In [None]:
# Silence every Python warning for the rest of the session.
import warnings
warnings.filterwarnings("ignore")

# Process-level environment variables. They are assigned here, before numpy,
# torch and transformers are imported further down in this cell.
# NOTE(review): the names suggest thread caps (OMP/MKL), asynchronous CUDA
# launches, the cuDNN v8 API and tokenizer parallelism — confirm each against
# the library versions actually installed.
import os
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

import re
import random
import math
from pathlib import Path
from typing import List
from collections import Counter
from dataclasses import dataclass

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Brown-to-gold colour ramp, registered as the default seaborn palette.
_BANNER_SHADES = (
    ("dark brown", "#2c1810"),
    ("medium brown", "#5c4a2a"),
    ("golden brown", "#8b6914"),
    ("gold", "#d4a843"),
    ("light gold", "#f0d68a"),
)
banner_palette = [hex_code for _shade_name, hex_code in _BANNER_SHADES]
sns.set_palette(banner_palette)

import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm.auto import tqdm

print('Done')

# Metrics for competition

In [None]:
from IPython.display import clear_output

# Directory that holds the pre-downloaded wheel files.
var="/kaggle/input/datasets/canicule/translate-akkadian-texts-wheel"
# Install optuna, sacrebleu and portalocker from the local wheels only:
# --no-index keeps pip away from the package index and --find-links points it
# at the wheel directory for anything else it needs to resolve.
!pip install \
  "$var"/optuna-4.7.0-py3-none-any.whl \
  "$var"/sacrebleu-2.6.0-py3-none-any.whl \
  "$var"/portalocker-3.2.0-py3-none-any.whl \
  --no-index \
  --find-links "$var"

#!pip install /kaggle/input/datasets/canicule/translate-akkadian-texts-wheel/optuna-4.7.0-py3-none-any.whl

#clear_output()

In [None]:
#!pip install sacrebleu

In [None]:
# ============================================================
# METRICS: sacrebleu with pure-Python fallback
# ============================================================
'''
USE_SACREBLEU = False
try:
    import sacrebleu
    USE_SACREBLEU = True
    print("sacrebleu loaded")
except ImportError:
    try:
        import subprocess, sys
        subprocess.check_call([sys.executable, "-m", "pip", "install", "sacrebleu", "-q"])
        import sacrebleu
        USE_SACREBLEU = True
        print("sacrebleu installed and loaded")
    except Exception:
        print("sacrebleu unavailable — using built-in BLEU/chrF++ implementation")

try:
    import optuna
    print("optuna loaded")
except ImportError:
    try:
        import subprocess, sys
        subprocess.check_call([sys.executable, "-m", "pip", "install", "optuna", "-q"])
        import optuna
        print("optuna installed and loaded")
    except Exception:
        raise ImportError("optuna is required but could not be installed. Enable internet or pre-install optuna.")

'''

def _ngrams(tokens, n):
    return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]


def _corpus_bleu_fallback(hypotheses, references, max_n=4):
    """Simplified corpus-level BLEU on whitespace tokens (brevity penalty, no smoothing).

    Args:
        hypotheses: iterable of system output strings.
        references: iterable of reference strings, paired positionally with ``hypotheses``.
        max_n: highest n-gram order included in the geometric mean.

    Returns:
        BLEU on a 0-100 scale; 0.0 as soon as any n-gram order has zero precision.
    """
    matched = [0] * max_n
    possible = [0] * max_n
    hyp_word_total = 0
    ref_word_total = 0

    for hyp_text, ref_text in zip(hypotheses, references):
        hyp_words, ref_words = hyp_text.split(), ref_text.split()
        hyp_word_total += len(hyp_words)
        ref_word_total += len(ref_words)
        for order in range(1, max_n + 1):
            hyp_counts = Counter(_ngrams(hyp_words, order))
            ref_counts = Counter(_ngrams(ref_words, order))
            # Counter intersection keeps min(count) per n-gram, i.e. the clipped matches.
            matched[order - 1] += sum((hyp_counts & ref_counts).values())
            possible[order - 1] += max(len(hyp_words) - order + 1, 0)

    precisions = [hits / total if total != 0 else 0 for hits, total in zip(matched, possible)]
    if not all(precisions):
        return 0.0

    mean_log_precision = sum(math.log(p) for p in precisions) / max_n
    if hyp_word_total >= ref_word_total:
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1 - ref_word_total / max(hyp_word_total, 1))
    return brevity_penalty * math.exp(mean_log_precision) * 100


def _chrf_pp_fallback(hypotheses, references, n_char=6, n_word=2, beta=2):
    """Simplified chrF++: one pooled F-beta over character and word n-gram matches.

    Match, hypothesis and reference n-gram counts are summed over all segments
    and over all orders (characters 1..``n_char``, words 1..``n_word``) before a
    single precision/recall pair is computed.

    Returns:
        The F-beta score on a 0-100 scale; 0.0 when nothing matches.
    """
    overlap = 0
    hyp_total = 0
    ref_total = 0

    for hyp_text, ref_text in zip(hypotheses, references):
        # Two views of each segment: characters, then whitespace tokens.
        views = (
            (list(hyp_text), list(ref_text), n_char),
            (hyp_text.split(), ref_text.split(), n_word),
        )
        for hyp_units, ref_units, max_order in views:
            for order in range(1, max_order + 1):
                hyp_counts = Counter(_ngrams(hyp_units, order))
                ref_counts = Counter(_ngrams(ref_units, order))
                overlap += sum((hyp_counts & ref_counts).values())
                hyp_total += sum(hyp_counts.values())
                ref_total += sum(ref_counts.values())

    precision = overlap / max(hyp_total, 1)
    recall = overlap / max(ref_total, 1)
    if precision + recall == 0:
        return 0.0

    beta_sq = beta ** 2
    f_score = (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)
    return f_score * 100


def _sentence_bleu_fallback(hypothesis, reference, max_n=4):
    """Sentence-level BLEU on whitespace tokens with add-1 smoothing at every order.

    Returns:
        BLEU on a 0-100 scale, including the brevity penalty.
    """
    hyp_words = hypothesis.split()
    ref_words = reference.split()

    log_precision_sum = 0
    for order in range(1, max_n + 1):
        hyp_counts = Counter(_ngrams(hyp_words, order))
        ref_counts = Counter(_ngrams(ref_words, order))
        clipped_hits = sum((hyp_counts & ref_counts).values())
        candidates = max(len(hyp_words) - order + 1, 0)
        # add-1 smoothing keeps the precision strictly positive
        log_precision_sum += math.log((clipped_hits + 1) / (candidates + 1))

    mean_log_precision = log_precision_sum / max_n
    if len(hyp_words) >= len(ref_words):
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1 - len(ref_words) / max(len(hyp_words), 1))
    return brevity_penalty * math.exp(mean_log_precision) * 100


# Optuna

In [None]:
# Report the PyTorch build and, when a CUDA device is visible, its name and memory.
cuda_ready = torch.cuda.is_available()
print(f"\nPyTorch: {torch.__version__} | CUDA: {cuda_ready}")
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} | {gpu_memory_gb:.1f}GB")

# Loading and analyzing data

In [None]:
# Read the competition train/test CSV files and report their shapes.
COMPETITION_DIR = Path("/kaggle/input/deep-past-initiative-machine-translation")
df_train = pd.read_csv(COMPETITION_DIR / "train.csv")
df_test = pd.read_csv(COMPETITION_DIR / "test.csv")
for split_name, split_frame in (("train", df_train), ("test", df_test)):
    n_rows, n_cols = split_frame.shape
    print(f"- The {split_name} set's shape is {n_rows} rows and {n_cols} columns.")
df_train.head()

In [None]:
# Print a concise summary of the training frame.
df_train.info()

In [None]:
# Show summary statistics for the training frame's columns.
df_train.describe()

In [None]:
# Per-row length features: whitespace-delimited word counts and character counts.
# Missing texts are replaced by empty strings first, so they count as length 0.
def _count_words(text):
    """Number of whitespace-separated tokens in ``text``."""
    return len(text.split())


train_source = df_train['transliteration'].fillna('')
train_target = df_train['translation'].fillna('')
test_source = df_test['transliteration'].fillna('')

df_train['src_word_count'] = train_source.map(_count_words)
df_train['tgt_word_count'] = train_target.map(_count_words)
df_train['src_char_count'] = train_source.str.len()
df_train['tgt_char_count'] = train_target.str.len()
df_test['src_word_count'] = test_source.map(_count_words)
df_test['src_char_count'] = test_source.str.len()

for heading, column in (
    ("Source (transliteration) word count stats:", 'src_word_count'),
    ("\nTarget (translation) word count stats:", 'tgt_word_count'),
):
    print(heading)
    print(df_train[column].describe())

In [None]:
def _length_histogram(ax, values, face_color, title, mark_center=False):
    """Draw a 50-bin histogram of ``values`` on ``ax``.

    With ``mark_center`` set, vertical lines for the mean and the median are
    added together with a legend.
    """
    ax.hist(values, bins=50, color=face_color, alpha=0.7, edgecolor='#2c1810')
    if mark_center:
        ax.axvline(values.mean(), color='#d4a843', linestyle='--', linewidth=2,
                   label=f"Mean: {values.mean():.1f}")
        ax.axvline(values.median(), color='#f0d68a', linestyle='-.', linewidth=2,
                   label=f"Median: {values.median():.1f}")
    ax.set_title(title, fontweight='bold')
    if mark_center:
        ax.legend()


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (column, bar colour, title, mark mean/median) — one entry per panel, row by row.
length_panels = [
    ('src_word_count', '#8b6914', 'Source (Transliteration) Word Count', True),
    ('tgt_word_count', '#5c4a2a', 'Target (Translation) Word Count', True),
    ('src_char_count', '#8b6914', 'Source Character Count', False),
    ('tgt_char_count', '#5c4a2a', 'Target Character Count', False),
]
for panel_ax, (column, face_color, title, mark_center) in zip(axes.flat, length_panels):
    _length_histogram(panel_ax, df_train[column], face_color, title, mark_center)

plt.tight_layout()
plt.show()

In [None]:
# Overlay train and test density estimates for each source-length feature.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

length_comparisons = (
    ('src_word_count', 'Source Word Count: Train vs Test'),
    ('src_char_count', 'Source Char Count: Train vs Test'),
)
for panel_ax, (column, title) in zip(axes, length_comparisons):
    sns.kdeplot(df_train[column], ax=panel_ax, label='Train', fill=True, alpha=0.5, color='#8b6914')
    sns.kdeplot(df_test[column], ax=panel_ax, label='Test', fill=True, alpha=0.3, color='#d4a843')
    panel_ax.set_title(title)
    panel_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Scatter of source vs target word counts with a least-squares line on top.
src_counts = df_train['src_word_count']
tgt_counts = df_train['tgt_word_count']

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(src_counts, tgt_counts, alpha=0.3, color='#8b6914', s=10)
ax.set_xlabel('Source Word Count (Transliteration)')
ax.set_ylabel('Target Word Count (Translation)')
ax.set_title('Source vs Target Length Relationship')

# Degree-1 polynomial fit: z holds (slope, intercept), p evaluates the line.
z = np.polyfit(src_counts, tgt_counts, 1)
p = np.poly1d(z)
x_line = np.linspace(0, src_counts.max(), 100)
ax.plot(x_line, p(x_line), color='#d4a843', linewidth=2, linestyle='--',
        label=f'Trend: y={z[0]:.2f}x + {z[1]:.2f}')
ax.legend()
plt.tight_layout()
plt.show()

print(f"Correlation between source and target word counts: {src_counts.corr(tgt_counts):.3f}")

In [None]:
# Analyze gap markers in source texts
# has_gap flags any text containing a standalone "x", "xx", "..." or "…".
df_train['has_gap'] = df_train['transliteration'].fillna('').str.contains(r'\bx\b|xx|\.\.\.|…', regex=True)
df_test['has_gap'] = df_test['transliteration'].fillna('').str.contains(r'\bx\b|xx|\.\.\.|…', regex=True)

print(f"Train texts with gaps: {df_train['has_gap'].sum()} ({df_train['has_gap'].mean()*100:.1f}%)")
print(f"Test texts with gaps:  {df_test['has_gap'].sum()} ({df_test['has_gap'].mean()*100:.1f}%)")

# Count gap markers per text (a run of two or more "x" counts as one marker)
df_train['gap_count'] = df_train['transliteration'].fillna('').apply(
    lambda x: len(re.findall(r'\bx\b|xx+|\.\.\.|…', x))
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Gap presence pie chart
# value_counts() orders its result by frequency, not by value, so the order is
# pinned to (False, True) here. Without this the positional labels below are
# swapped whenever texts with gaps are the majority, and the chart fails when
# only one of the two classes is present.
counts = df_train['has_gap'].value_counts().reindex([False, True], fill_value=0)
labels = ['No Gaps', 'Has Gaps']
colors = ['#8b6914', '#d4a843']
axes[0].pie(counts, labels=labels, colors=colors, autopct='%1.1f%%',
            textprops={'fontsize': 12, 'fontweight': 'bold'})
axes[0].set_title('Gap Marker Presence in Train', fontweight='bold')

# Gap count distribution (only texts that contain at least one marker)
axes[1].hist(df_train[df_train['gap_count'] > 0]['gap_count'], bins=30,
             color='#8b6914', alpha=0.7, edgecolor='#2c1810')
axes[1].set_title('Gap Count Distribution (texts with gaps)', fontweight='bold')
axes[1].set_xlabel('Number of Gap Markers')

plt.tight_layout()
plt.show()

In [None]:
# Most common words in target translations (lower-cased, whitespace-tokenised)
joined_translations = ' '.join(df_train['translation'].fillna('')).lower()
all_target_words = joined_translations.split()
word_counts = Counter(all_target_words)
top_30 = word_counts.most_common(30)
words, counts_list = zip(*top_30)

fig, ax = plt.subplots(figsize=(14, 6))
ax.bar(words, counts_list, color='#8b6914', edgecolor='#2c1810')
plt.xticks(rotation=45, ha='right')
ax.set_title('Top 30 Most Common Words in English Translations', fontweight='bold')
ax.set_ylabel('Frequency')
plt.tight_layout()
plt.show()

In [None]:
# IQR-based outlier count for each word-count column (fences at 1.5 * IQR).
for col in ('src_word_count', 'tgt_word_count'):
    column_values = df_train[col]
    Q1, Q3 = column_values.quantile(0.25), column_values.quantile(0.75)
    IQR = Q3 - Q1
    lower_bound, upper_bound = Q1 - 1.5 * IQR, Q3 + 1.5 * IQR

    outside_fences = (column_values < lower_bound) | (column_values > upper_bound)
    outliers = df_train[outside_fences]
    print(f"{col}: Lower={lower_bound:.0f}, Upper={upper_bound:.0f}, Outliers={len(outliers)}")

# Length ratio analysis (source counts are floored at 1 to avoid division by zero)
safe_source_counts = df_train['src_word_count'].clip(lower=1)
df_train['length_ratio'] = df_train['tgt_word_count'] / safe_source_counts
print("\nLength ratio (target/source) stats:")
print(df_train['length_ratio'].describe())

# Data Processing

In [None]:
# ============================================================
# PREPROCESSOR (exact match: chunky_v1_5_0)
# ============================================================

class OptimizedPreprocessor:
    """Rewrites gap notation in transliterations into ``<big_gap>`` / ``<gap>`` tokens.

    Big gaps (three or more dots, or ellipsis characters) are substituted first,
    then small gaps (runs of two or more ``x``, or a lone ``x`` together with its
    surrounding whitespace).
    """

    # (pattern key, replacement token) in the order the substitutions are applied.
    _SUBSTITUTIONS = (('big_gap', '<big_gap>'), ('small_gap', '<gap>'))

    def __init__(self):
        big_gap_re = re.compile(r'(\.{3,}|…+|……)')
        small_gap_re = re.compile(r'(xx+|\s+x\s+)')
        self.patterns = {'big_gap': big_gap_re, 'small_gap': small_gap_re}

    def preprocess_input_text(self, text: str) -> str:
        """Normalise one text; a missing value (``pd.isna``) becomes the empty string."""
        if pd.isna(text):
            return ""
        result = str(text)
        for key, token in self._SUBSTITUTIONS:
            result = self.patterns[key].sub(token, result)
        return result

    def preprocess_batch(self, texts: List[str]) -> List[str]:
        """Vectorised counterpart of ``preprocess_input_text`` for a list of texts."""
        series = pd.Series(texts).fillna('').astype(str)
        for key, token in self._SUBSTITUTIONS:
            series = series.str.replace(self.patterns[key], token, regex=True)
        return series.tolist()


# ============================================================
# AKKADIAN CLAUSE-BOUNDARY CHUNKING (exact match: chunky_v1_5_0)
# ============================================================

# Chunking limits, all measured in whitespace-separated words.
CHUNK_MIN_WORDS = 15   # a clause-marker break is only taken once a chunk has this many words
CHUNK_MAX_WORDS = 30   # a chunk is closed unconditionally at this size
CHUNK_THRESHOLD = 50   # texts with at most this many words are never split

# Regex fragments treated as clause boundaries; combined into one case-insensitive pattern.
CLAUSE_MARKERS = [
    r'KIŠIB\s+',
    r'IGI\s+',
    r'um-ma\s+',
    r'a-na\s+\S+\s+qí-bi',
    r'šu-ma\s+',
    r'\.\s+',
    r'\[\.\.\.\]\s*',
]
CLAUSE_PATTERN = re.compile('|'.join(CLAUSE_MARKERS), re.IGNORECASE)


def split_akkadian(text: str, max_words: int = CHUNK_MAX_WORDS, min_words: int = CHUNK_MIN_WORDS) -> List[str]:
    """Split a long transliteration into chunks of roughly ``min_words``..``max_words`` words.

    Texts with at most ``CHUNK_THRESHOLD`` words are returned unchanged as a
    one-element list. Otherwise words are accumulated one at a time and the
    current chunk is closed when either

    * it holds at least ``min_words`` words and ``CLAUSE_PATTERN`` matches
      anywhere in the accumulated chunk text (not only at its end), or
    * it has reached ``max_words`` words.

    Any leftover words form the final chunk.
    """
    tokens = text.split()
    if len(tokens) <= CHUNK_THRESHOLD:
        return [text]

    pieces: List[str] = []
    pending: List[str] = []
    for token in tokens:
        pending.append(token)
        joined = ' '.join(pending)
        # A trailing space lets the `\s+` markers match on the chunk's last word.
        has_marker = CLAUSE_PATTERN.search(joined + ' ') is not None
        at_clause_break = len(pending) >= min_words and has_marker
        at_size_limit = len(pending) >= max_words
        if at_clause_break or at_size_limit:
            pieces.append(joined.strip())
            pending = []

    leftover = ' '.join(pending).strip()
    if leftover:
        pieces.append(leftover)

    return pieces or [text]


# ============================================================
# POSTPROCESSOR (exact match: chunky_v1_5_0)
# ============================================================

def remove_phrase_repeats(text: str) -> str:
    """Collapse immediately repeated phrases of 3 to 8 words down to one occurrence.

    Phrase lengths are tried from 8 down to 3, each with one left-to-right pass
    over the result of the previous pass. Texts with fewer than six words (or
    empty texts) are returned untouched; otherwise the words are re-joined with
    single spaces.
    """
    if not text:
        return text
    tokens = text.split()
    if len(tokens) < 6:
        return text

    for span in (8, 7, 6, 5, 4, 3):
        total = len(tokens)
        kept = []
        pos = 0
        while pos < total:
            window = tokens[pos:pos + span]
            room_for_two = pos + 2 * span <= total
            if room_for_two and window == tokens[pos + span:pos + 2 * span]:
                # Keep the first copy, then skip every directly following copy.
                kept.extend(window)
                pos += span
                while pos + span <= total and tokens[pos:pos + span] == window:
                    pos += span
                continue
            kept.append(tokens[pos])
            pos += 1
        tokens = kept

    return ' '.join(tokens)


def trim_trailing_fragment(text: str) -> str:
    """Cut a dangling, unterminated tail off a long translation.

    Trailing whitespace is always stripped. If the result is longer than 100
    characters and ends in a letter, it is cut back to the last ``.``, ``?`` or
    ``!`` (or to a ``'`` that directly follows one). When no such terminator
    exists, the stripped text is returned whole.
    """
    if not text:
        return text
    stripped = text.rstrip()
    if not stripped:
        return stripped

    ends_mid_sentence = len(stripped) > 100 and stripped[-1].isalpha()
    if not ends_mid_sentence:
        return stripped

    terminators = '.?!'
    for pos in reversed(range(len(stripped))):
        char = stripped[pos]
        closes_sentence = char in terminators
        closes_quote = char == "'" and pos > 0 and stripped[pos - 1] in terminators
        if closes_sentence or closes_quote:
            return stripped[:pos + 1]
    return stripped


class VectorizedPostprocessor:
    """Cleans model translations in bulk with pandas string operations.

    Every batch gets character normalisation (ḫ/Ḫ to h/H, subscript digits to
    ASCII digits) and whitespace collapsing. With ``aggressive=True`` the batch
    is additionally run through gap-token normalisation, annotation and
    forbidden-character removal, decimal-to-fraction rewriting, repeated
    word/phrase removal, punctuation tidying and trailing-fragment trimming.
    """

    def __init__(self, aggressive: bool = True):
        """Compile the regexes and build the translation tables used per batch."""
        self.aggressive = aggressive
        self.patterns = {
            'gap': re.compile(r'(\[x\]|\(x\)|\bx\b)', re.I),
            'big_gap': re.compile(r'(\.{3,}|…|\[\.+\])'),
            # NOTE(review): `\..` demands one extra character after the dot, so a
            # bare "(fem.)" is not matched — confirm whether `\.?` was intended.
            'annotations': re.compile(r'\((fem|plur|pl|sing|singular|plural|\?|!)\..\s*\w*\)', re.I),
            'repeated_words': re.compile(r'\b(\w+)(?:\s+\1\b)+'),
            'whitespace': re.compile(r'\s+'),
            'punct_space': re.compile(r'\s+([.,:])'),
            'repeated_punct': re.compile(r'([.,])\1+'),
        }
        self.subscript_trans = str.maketrans('₀₁₂₃₄₅₆₇₈₉', '0123456789')
        self.special_chars_trans = str.maketrans('ḫḪ', 'hH')
        self.forbidden_chars = '!?()"——<>⌈⌋⌊[]+ʾ/;'
        self.forbidden_trans = str.maketrans('', '', self.forbidden_chars)

    def postprocess_batch(self, translations: List[str]) -> List[str]:
        """Return the cleaned translations, one output string per input entry.

        Entries that are not strings, or that are blank after stripping, are
        replaced by the empty string. An empty batch yields an empty list.
        """
        if len(translations) == 0:
            return []

        s = pd.Series(list(translations), dtype=object)
        # The mask must hold real booleans: the previous `isinstance(x, str) and
        # x.strip()` produced the stripped *string* for valid rows, so `~mask`
        # raised TypeError as soon as the batch contained a blank or non-string.
        valid_mask = s.map(lambda x: isinstance(x, str) and bool(x.strip())).astype(bool)
        s = s.where(valid_mask, '')

        s = s.str.translate(self.special_chars_trans)
        s = s.str.translate(self.subscript_trans)
        s = s.str.replace(self.patterns['whitespace'], ' ', regex=True)
        s = s.str.strip()

        if self.aggressive:
            s = s.str.replace(self.patterns['gap'], '<gap>', regex=True)
            s = s.str.replace(self.patterns['big_gap'], '<big_gap>', regex=True)
            s = s.str.replace('<gap> <gap>', '<big_gap>', regex=False)
            s = s.str.replace('<big_gap> <big_gap>', '<big_gap>', regex=False)
            s = s.str.replace(self.patterns['annotations'], '', regex=True)

            # Shield the gap tokens behind NUL-delimited placeholders so that
            # stripping the forbidden characters does not eat their "<" and ">".
            s = s.str.replace('<gap>', '\x00GAP\x00', regex=False)
            s = s.str.replace('<big_gap>', '\x00BIG\x00', regex=False)
            s = s.str.translate(self.forbidden_trans)
            s = s.str.replace('\x00GAP\x00', ' <gap> ', regex=False)
            s = s.str.replace('\x00BIG\x00', ' <big_gap> ', regex=False)

            # Fractions (exact match: chunky_v1_5_0 — only ½, ¼, ¾)
            # NOTE(review): the `(\d+)\.5` style rule runs first and also matches
            # "0.5", giving "0½"; the dedicated `\b0\.5\b` rules never fire.
            # Left as-is to stay identical to the reference pipeline.
            s = s.str.replace(r'(\d+)\.5\b', r'\1½', regex=True)
            s = s.str.replace(r'\b0\.5\b', '½', regex=True)
            s = s.str.replace(r'(\d+)\.25\b', r'\1¼', regex=True)
            s = s.str.replace(r'\b0\.25\b', '¼', regex=True)
            s = s.str.replace(r'(\d+)\.75\b', r'\1¾', regex=True)
            s = s.str.replace(r'\b0\.75\b', '¾', regex=True)

            # Remove repeated words, then repeated 4-, 3- and 2-word groups
            s = s.str.replace(self.patterns['repeated_words'], r'\1', regex=True)
            for n in range(4, 1, -1):
                pattern = r'\b((?:\w+\s+){' + str(n - 1) + r'}\w+)(?:\s+\1\b)+'
                s = s.str.replace(pattern, r'\1', regex=True)

            # Sliding-window phrase dedup
            s = s.apply(remove_phrase_repeats)

            s = s.str.replace(self.patterns['punct_space'], r'\1', regex=True)
            s = s.str.replace(self.patterns['repeated_punct'], r'\1', regex=True)
            s = s.str.replace(self.patterns['whitespace'], ' ', regex=True)
            s = s.str.strip().str.strip('-').str.strip()

            # Trim trailing incomplete fragments
            s = s.apply(trim_trailing_fragment)

        return s.tolist()


# Shared module-level instances; the dataset, tuning and inference code below
# reads them as globals.
preprocessor = OptimizedPreprocessor()
postprocessor = VectorizedPostprocessor(aggressive=True)
print("Preprocessor and Postprocessor initialized (chunky_v1_5_0 exact match).")


# Model Loading


In [None]:
# Local directory from which both the tokenizer and the seq2seq weights are loaded.
MODEL_PATH = "/kaggle/input/datasets/assiaben/final-byt5/byt5-akkadian-optimized-34x"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
print("Loading model...")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)
# Use the GPU when one is visible; otherwise stay on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move to the chosen device and switch to eval mode (this notebook only runs inference).
model = model.to(device).eval()

num_params = sum(p.numel() for p in model.parameters())
print(f"Model loaded: {num_params:,} parameters on {device}")

# Apply BetterTransformer if available
# Best-effort: any failure (package missing, transform raising) is reported and
# the unmodified model is kept.
try:
    from optimum.bettertransformer import BetterTransformer
    model = BetterTransformer.transform(model)
    print("BetterTransformer applied")
except Exception as e:
    print(f"BetterTransformer skipped: {e}")

In [None]:
# ============================================================
# DATASET & SAMPLER
# ============================================================

class AkkadianDataset(Dataset):
    """Map-style dataset of ``(sample_id, prompt)`` pairs built from a dataframe.

    Each prompt is the task prefix ``'translate Akkadian to English: '`` followed
    by the preprocessed ``transliteration`` text. Sample ids come from the ``id``
    column when it exists, otherwise from the row position.
    """

    def __init__(self, dataframe: pd.DataFrame, preprocessor):
        has_id_column = 'id' in dataframe.columns
        self.sample_ids = (
            dataframe['id'].tolist() if has_id_column else list(range(len(dataframe)))
        )
        cleaned_sources = preprocessor.preprocess_batch(dataframe['transliteration'].tolist())
        task_prefix = 'translate Akkadian to English: '
        self.input_texts = [task_prefix + source for source in cleaned_sources]
        print(f"Dataset created: {len(self.sample_ids)} samples")

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, index):
        sample_id = self.sample_ids[index]
        prompt = self.input_texts[index]
        return sample_id, prompt


class BucketBatchSampler(Sampler):
    """Batch sampler that groups samples of similar word count.

    Sample indices are sorted by the word count of their text, cut into
    ``num_buckets`` contiguous buckets (the last bucket takes the remainder) and
    each bucket is emitted in consecutive batches of at most ``batch_size``
    indices. Iteration order is deterministic: shortest bucket first.
    """

    def __init__(self, dataset, batch_size: int, num_buckets: int = 4):
        word_counts = [len(text.split()) for _, text in dataset]
        by_length = sorted(range(len(word_counts)), key=word_counts.__getitem__)
        per_bucket = max(1, len(by_length) // num_buckets)
        final_bucket = num_buckets - 1
        self.buckets = [
            by_length[k * per_bucket:] if k == final_bucket
            else by_length[k * per_bucket:(k + 1) * per_bucket]
            for k in range(num_buckets)
        ]
        self.batch_size = batch_size

    def __iter__(self):
        step = self.batch_size
        for bucket in self.buckets:
            yield from (bucket[start:start + step] for start in range(0, len(bucket), step))

    def __len__(self):
        # Ceiling division per bucket: a partially filled last batch still counts.
        return sum(-(-len(bucket) // self.batch_size) for bucket in self.buckets)

print("Dataset and Sampler classes ready.")


# Model evaluation and generation-parameter tuning

In [None]:
# Use a small sample for fast tuning
VAL_SIZE = 100
np.random.seed(42)

# Sample only rows that have both a transliteration and a reference translation:
# a missing reference would reach the scoring functions as NaN instead of text.
# val_indices still holds row positions in df_train. When nothing is missing the
# candidate positions are 0..len(df_train)-1, so the seeded draw should select
# the same rows as before.
complete_pair_mask = df_train[['transliteration', 'translation']].notna().all(axis=1)
candidate_positions = np.flatnonzero(complete_pair_mask.to_numpy())
val_indices = np.random.choice(
    candidate_positions,
    size=min(VAL_SIZE, len(candidate_positions)),
    replace=False,
)
df_val = df_train.iloc[val_indices].reset_index(drop=True)
print(f"Validation set: {len(df_val)} samples for Optuna tuning")


def translate_batch_with_params(texts, length_penalty, num_beams, max_new_tokens=512, batch_size=4):
    """Translate a list of texts with specific generation parameters.

    The texts are preprocessed, prefixed with the task prompt, tokenised
    (truncated to 512 input tokens), decoded with beam search in mini-batches
    and finally run through the shared postprocessor.

    Args:
        texts: list of raw transliteration strings.
        length_penalty: beam-search length penalty passed to ``model.generate``.
        num_beams: beam width passed to ``model.generate``.
        max_new_tokens: cap on generated tokens per sequence.
        batch_size: number of texts sent to the model per ``generate`` call
            (defaults to 4, the value that used to be hard-coded).

    Returns:
        List of cleaned translations, one per input text; ``[]`` for empty input.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(texts) == 0:
        # Nothing to translate: skip preprocessing, the model and postprocessing.
        return []

    preprocessed = preprocessor.preprocess_batch(texts)
    prefixed = ['translate Akkadian to English: ' + t for t in preprocessed]

    translations = []

    with torch.inference_mode():
        for i in range(0, len(prefixed), batch_size):
            batch = prefixed[i:i + batch_size]
            inputs = tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors='pt')
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            with autocast():
                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    num_beams=num_beams,
                    max_new_tokens=max_new_tokens,
                    length_penalty=length_penalty,
                    early_stopping=True,
                    use_cache=True,
                )

            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            translations.extend(decoded)

    cleaned = postprocessor.postprocess_batch(translations)
    return cleaned



In [None]:

def compute_bleu(predictions, references):
    """Corpus BLEU of ``predictions`` against one reference per segment (via sacrebleu)."""
    # The references are wrapped in a one-element list before being passed on.
    bleu_result = sacrebleu.corpus_bleu(predictions, [references])
    return bleu_result.score


def compute_chrf(predictions, references):
    """Corpus chrF++ (sacrebleu chrF with ``word_order=2``) against one reference per segment."""
    chrf_result = sacrebleu.corpus_chrf(predictions, [references], word_order=2)
    return chrf_result.score


def compute_sentence_bleu(hypothesis, reference):
    """Sentence-level BLEU of one hypothesis against one reference (via sacrebleu)."""
    sentence_result = sacrebleu.sentence_bleu(hypothesis, [reference])
    return sentence_result.score


def compute_competition_score(predictions, references):
    """Geometric mean of corpus BLEU and corpus chrF++.

    Returns 0.0 when either component is zero or negative.
    """
    component_scores = (
        compute_bleu(predictions, references),
        compute_chrf(predictions, references),
    )
    if any(score <= 0 for score in component_scores):
        return 0.0
    return math.sqrt(component_scores[0] * component_scores[1])


# Status lines for the cell output; sacrebleu is the only metrics backend in use.
for status_line in (
    "Translation and scoring functions ready.",
    "Metrics backend: sacrebleu",
):
    print(status_line)


In [None]:
import optuna, sacrebleu, portalocker

In [None]:
# Known good parameters from top public notebooks
PROVEN_PARAMS = [
    {'length_penalty': 1.5, 'num_beams': 8},   # chunky_v1_5_0 → 35.1
    {'length_penalty': 1.3, 'num_beams': 8},   # adaptive-beams → 35.1
]

# Seed for the Optuna sampler so that a re-run proposes the same trial sequence.
OPTUNA_SEED = 42


def objective(trial):
    """Optuna objective: competition score on ``df_val`` for one parameter draw.

    Draws ``length_penalty`` from [0.8, 2.0] and ``num_beams`` from 4..12,
    translates the validation sources with them and returns the geometric mean
    of BLEU and chrF++ against the validation references.
    """
    length_penalty = trial.suggest_float('length_penalty', 0.8, 2.0)
    num_beams = trial.suggest_int('num_beams', 4, 12)

    source_texts = df_val['transliteration'].tolist()
    reference_texts = df_val['translation'].tolist()

    predictions = translate_batch_with_params(
        source_texts,
        length_penalty=length_penalty,
        num_beams=num_beams,
    )

    score = compute_competition_score(predictions, reference_texts)
    return score


# TPE is Optuna's default sampler; creating it explicitly only adds the seed.
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=OPTUNA_SEED),
)

# Enqueue proven baselines so they are always evaluated first
for params in PROVEN_PARAMS:
    study.enqueue_trial(params)

study.optimize(objective, n_trials=50, timeout=3600 * 2)

# Compare Optuna best vs proven baselines
print("\n" + "=" * 60)
print("OPTUNA RESULTS")
print("=" * 60)
print(f"Best Score (geometric mean): {study.best_value:.2f}")
print(f"Best params: {study.best_params}")

# Show the ten best trials. Trials without a value (e.g. interrupted ones) are
# left out so that neither the sort nor the number formatting can fail.
print("\nAll trials (sorted by score):")
scored_trials = [t for t in study.trials if t.value is not None]
trials_sorted = sorted(scored_trials, key=lambda t: t.value, reverse=True)
for t in trials_sorted[:10]:
    tag = " [PROVEN BASELINE]" if t.params in PROVEN_PARAMS else ""
    print(f"  Trial {t.number}: score={t.value:.2f}, lp={t.params['length_penalty']:.3f}, beams={t.params['num_beams']}{tag}")
print("=" * 60)


In [None]:
# Left: score per trial with the best score marked. Right: score vs length_penalty.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
score_ax, penalty_ax = axes

all_trials = study.trials
trial_numbers = [trial.number for trial in all_trials]
trial_values = [trial.value for trial in all_trials]
lp_values = [trial.params['length_penalty'] for trial in all_trials]
best_score = study.best_value

# Trial scores
score_ax.plot(trial_numbers, trial_values, 'o-', color='#8b6914', markersize=8)
score_ax.axhline(best_score, color='#d4a843', linestyle='--', label=f'Best: {best_score:.2f}')
score_ax.set(xlabel='Trial Number', ylabel='Score (Geometric Mean)')
score_ax.set_title('Optuna Trial Scores', fontweight='bold')
score_ax.legend()

# length_penalty vs score
penalty_ax.scatter(lp_values, trial_values, c='#8b6914', s=60, edgecolors='#2c1810')
penalty_ax.set(xlabel='length_penalty', ylabel='Score')
penalty_ax.set_title('length_penalty vs Score', fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:

# Generation settings used from here on. The Optuna optimum is printed for
# comparison only; it is not adopted.
FIXED_LENGTH_PENALTY = 1.5
FIXED_NUM_BEAMS = 8

# Show Optuna comparison
optuna_best = study.best_params
print("Optuna best vs proven baseline:")
print(f"  Optuna:  lp={optuna_best['length_penalty']:.3f}, beams={optuna_best['num_beams']}, score={study.best_value:.2f}")
print(f"  Proven:  lp={FIXED_LENGTH_PENALTY}, beams={FIXED_NUM_BEAMS}")

# Evaluate proven params on validation
best_length_penalty, best_num_beams = FIXED_LENGTH_PENALTY, FIXED_NUM_BEAMS

val_predictions = translate_batch_with_params(
    df_val['transliteration'].tolist(),
    length_penalty=best_length_penalty,
    num_beams=best_num_beams,
)
val_references = df_val['translation'].tolist()

bleu_score = compute_bleu(val_predictions, val_references)
chrf_score = compute_chrf(val_predictions, val_references)
if bleu_score > 0 and chrf_score > 0:
    geo_mean = math.sqrt(bleu_score * chrf_score)
else:
    geo_mean = 0.0

for report_line in (
    f"\nValidation Results (proven params: lp={best_length_penalty}, beams={best_num_beams}):",
    f"  BLEU:  {bleu_score:.2f}",
    f"  chrF++: {chrf_score:.2f}",
    f"  Geometric Mean: {geo_mean:.2f}",
):
    print(report_line)


In [None]:

# Left: the three validation metrics as bars. Right: per-sample sentence BLEU.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
metric_ax, bleu_ax = axes

# Bar chart of metrics
metrics = ['BLEU', 'chrF++', 'Geometric Mean']
values = [bleu_score, chrf_score, geo_mean]
colors = ['#2c1810', '#8b6914', '#d4a843']
bars = metric_ax.bar(metrics, values, color=colors, edgecolor='#2c1810', linewidth=1.5)
for bar, val in zip(bars, values):
    # Value label centred just above each bar.
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 0.5
    metric_ax.text(label_x, label_y, f'{val:.2f}', ha='center', fontweight='bold', fontsize=12)
metric_ax.set_ylabel('Score')
metric_ax.set_title('Validation Metrics', fontweight='bold')
metric_ax.set_ylim(0, max(values) * 1.2)

# Per-sample BLEU distribution
per_sample_bleu = [
    compute_sentence_bleu(pred, ref)
    for pred, ref in zip(val_predictions, val_references)
]

bleu_ax.hist(per_sample_bleu, bins=30, color='#8b6914', alpha=0.7, edgecolor='#2c1810')
bleu_ax.axvline(np.mean(per_sample_bleu), color='#d4a843', linestyle='--', linewidth=2,
                label=f'Mean: {np.mean(per_sample_bleu):.1f}')
bleu_ax.set_xlabel('Sentence BLEU')
bleu_ax.set_ylabel('Count')
bleu_ax.set_title('Per-Sample BLEU Distribution', fontweight='bold')
bleu_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Print up to five validation examples (source, reference, prediction), each
# truncated to 150 characters, with their sentence BLEU.
print("Sample Predictions vs References:")
rule = "=" * 70
for i in range(min(5, len(val_predictions))):
    prediction, reference = val_predictions[i], val_references[i]
    s_bleu = compute_sentence_bleu(prediction, reference)
    source = df_val.iloc[i]['transliteration']
    print(rule)
    print(f"[{i}] BLEU: {s_bleu:.1f}")
    print(f"SRC:  {source[:150]}")
    print(f"REF:  {reference[:150]}")
    print(f"PRED: {prediction[:150]}")
print(rule)

In [None]:

BATCH_SIZE = 8
MAX_LENGTH = 512
NUM_WORKERS = 4
NUM_BUCKETS = 4

print(f"Inference config: length_penalty={best_length_penalty}, num_beams={best_num_beams}")
print(f"Test samples: {len(df_test)}")

test_dataset = AkkadianDataset(df_test, preprocessor)

def collate_fn(batch):
    """Collate dataset samples into a list of ids and one tokenized batch.

    Args:
        batch: sequence of samples; element 0 of each sample is taken as its
            id and element 1 as the text to tokenize.

    Returns:
        A tuple ``(ids, tokenized)`` where ``ids`` is a list in batch order
        and ``tokenized`` is the tokenizer output for all texts, padded to
        the longest text in the batch, truncated to ``MAX_LENGTH`` tokens and
        returned as PyTorch tensors.
    """
    ids, texts = [], []
    for sample in batch:
        ids.append(sample[0])
        texts.append(sample[1])

    tokenized = tokenizer(
        texts,
        max_length=MAX_LENGTH,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    return ids, tokenized

print(f"Dataset ready: {len(test_dataset)} samples")


In [None]:

# Accumulates (sample_id, translation) pairs for the submission.
results = []
# Dataset POSITIONS (not sample ids) of the texts translated in this phase.
chunked_ids = set()

# Generation settings applied to every chunk of a long text.
gen_config_chunk = dict(
    num_beams=best_num_beams,
    max_new_tokens=512,
    length_penalty=best_length_penalty,
    early_stopping=True,
    use_cache=True,
)

print("Phase 1: Translating long texts with clause-boundary chunking...")

with torch.inference_mode():
    for idx in range(len(test_dataset)):
        sample_id, input_text = test_dataset[idx]
        # Strip the task prefix so the word count and the splitter see only the source text.
        raw_text = input_text.replace('translate Akkadian to English: ', '')

        # Texts at or below the word threshold are skipped here; their
        # positions are not recorded in chunked_ids.
        if len(raw_text.split()) <= CHUNK_THRESHOLD:
            continue

        chunks = split_akkadian(raw_text)
        prefix = 'translate Akkadian to English: '
        chunk_translations = []

        # Translate each chunk on its own (batch of one), re-attaching the task prefix.
        for chunk in chunks:
            inputs = tokenizer(
                prefix + chunk,
                return_tensors='pt',
                max_length=MAX_LENGTH,
                truncation=True,
            ).to(device)
            generate_kwargs = dict(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **gen_config_chunk,
            )
            # Mixed-precision autocast is only entered when CUDA is available.
            if torch.cuda.is_available():
                with autocast():
                    outputs = model.generate(**generate_kwargs)
            else:
                outputs = model.generate(**generate_kwargs)
            translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
            chunk_translations.append(translation.strip())

        # Stitch the chunk translations back together with single spaces,
        # then run the same post-processing used for batches.
        full_translation = ' '.join(chunk_translations)
        cleaned = postprocessor.postprocess_batch([full_translation])[0]
        results.append((sample_id, cleaned))
        chunked_ids.add(idx)

print(f"Chunked {len(chunked_ids)} long texts")


In [None]:

print("Phase 2: Batch translating remaining texts...")

# Everything whose dataset position is not in chunked_ids still needs a translation.
if chunked_ids:
    short_indices = [i for i in range(len(test_dataset)) if i not in chunked_ids]
    short_dataset = torch.utils.data.Subset(test_dataset, short_indices)
else:
    short_dataset = test_dataset

use_cuda = torch.cuda.is_available()

if len(short_dataset) > 0:
    # DataLoader options shared by both batching strategies.
    loader_kwargs = {
        'collate_fn': collate_fn,
        'num_workers': NUM_WORKERS,
        'pin_memory': True,
    }
    if NUM_WORKERS > 0:
        # prefetch_factor / persistent_workers are only valid with worker
        # processes; recent PyTorch raises ValueError if prefetch_factor is
        # given together with num_workers=0.
        loader_kwargs['prefetch_factor'] = 2
        loader_kwargs['persistent_workers'] = True

    if len(short_dataset) >= NUM_BUCKETS:
        # Enough samples to bucket: let the bucket sampler form the batches.
        batch_sampler_short = BucketBatchSampler(short_dataset, BATCH_SIZE, NUM_BUCKETS)
        dataloader_short = DataLoader(
            short_dataset, batch_sampler=batch_sampler_short, **loader_kwargs
        )
    else:
        # Too few samples to bucket: plain sequential batches.
        dataloader_short = DataLoader(
            short_dataset, batch_size=BATCH_SIZE, shuffle=False, **loader_kwargs
        )

    base_gen_config = {
        'max_new_tokens': 512,
        'length_penalty': best_length_penalty,
        'early_stopping': True,
        'use_cache': True,
    }

    # Beam counts for the two length regimes; loop-invariant, so computed once.
    short_input_beams = int(max(4, best_num_beams // 2))
    long_input_beams = int(best_num_beams)

    with torch.inference_mode():
        for batch_idx, (batch_ids, tokenized) in enumerate(tqdm(dataloader_short, desc="Translating")):
            input_ids = tokenized.input_ids.to(device)
            attention_mask = tokenized.attention_mask.to(device)

            # Adaptive beams (exact match: chunky_v1_5_0)
            # One num_beams value is passed per generate() call, and it is
            # selected from the token count of the FIRST sequence in the batch:
            # fewer beams when that sequence has under 100 non-padding tokens.
            # NOTE(review): other sequences in the batch do not influence the choice.
            first_length = int(attention_mask[0].sum().item())
            adaptive_beams = short_input_beams if first_length < 100 else long_input_beams

            gen_config = {**base_gen_config, 'num_beams': adaptive_beams}

            # Mixed precision on GPU; with enabled=False the context is a no-op on CPU.
            with autocast(enabled=use_cuda):
                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **gen_config
                )

            translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            cleaned = postprocessor.postprocess_batch(translations)
            results.extend(zip(batch_ids, cleaned))

            # Periodically release cached GPU memory.
            if use_cuda and batch_idx % 10 == 0:
                torch.cuda.empty_cache()

print(f"\nTotal translations: {len(results)}")
if use_cuda:
    torch.cuda.empty_cache()


In [None]:
# Build submission
# Exact duplicate (id, translation) rows are dropped BEFORE sorting and
# renumbering, so the final frame has a contiguous 0..n-1 index.
submission = (
    pd.DataFrame(results, columns=['id', 'translation'])
    .drop_duplicates()
    .sort_values('id')
    .reset_index(drop=True)
)

# Validation checks
# Every test row must have exactly one prediction row.
assert len(submission) == len(df_test), f"Expected {len(df_test)} rows, got {len(submission)}"
# Translations that are empty or whitespace-only are counted, not rejected.
empty_count = submission['translation'].str.strip().eq('').sum()
translation_lengths = submission['translation'].str.len()
print(f"Submission shape: {submission.shape}")
print(f"Empty translations: {empty_count}")
print(f"Translation length range: [{translation_lengths.min()}, {translation_lengths.max()}]")

submission.to_csv('submission.csv', index=False)
print("\nSaved submission.csv")
# Rich preview of the first rows as the cell output.
submission.head(10)