In [None]:
import re
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Central configuration for the whole notebook; every tunable lives here.
CFG = dict(
    # Competition test split that gets translated.
    data_path="/kaggle/input/deep-past-initiative-machine-translation/test.csv",
    # Checkpoints whose parameters are weight-averaged by load_blended_model();
    # entry [1] is the base model and also the tokenizer source.
    models=[
        "/kaggle/input/byt5-base-big-data2",
        "/kaggle/input/byt5-akkadian-model",
        "/kaggle/input/d/assiaben/final-byt5/byt5-akkadian-optimized-34x",
    ],
    # One relative blend weight per entry in `models` (normalized to sum to 1
    # before averaging). NOTE(review): values look like validation scores — confirm.
    weights=[0.995, 0.985, 0.99],
    # Run on the GPU when one is visible, otherwise fall back to the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # Tokenizer truncation length for inputs (presumably bytes for ByT5 checkpoints).
    max_len=512,
    batch_size=8,
    # Keyword arguments forwarded verbatim to model.generate().
    gen_params=dict(
        num_beams=8,
        max_new_tokens=512,
        length_penalty=1.10,
        early_stopping=True,
    ),
)

In [None]:
def clean_input(text):
    """Normalize gap markers in a raw transliteration before tokenization.

    Runs of three or more dots or of ellipsis characters become ``<big_gap>``.
    Runs of two or more ``x`` (and a whitespace-delimited lone ``x``) become ``<gap>``.

    NOTE(review): the ``\\s+x\\s+`` alternative consumes the surrounding
    whitespace, so "a x b" becomes "a<gap>b". This is kept as-is in case the
    checkpoints were trained on this exact preprocessing — TODO confirm.

    Args:
        text: a transliteration cell; missing values (NaN/None) are allowed.

    Returns:
        The normalized string, or "" for a missing value.
    """
    if pd.isna(text):
        return ""
    with_big_gaps = re.sub(r'(\.{3,}|…+|……)', '<big_gap>', str(text))
    return re.sub(r'(xx+|\s+x\s+)', '<gap>', with_big_gaps)

# Unicode subscript digits -> plain ASCII digits.
_SUBSCRIPT_DIGITS = str.maketrans("₀₁₂₃₄₅₆₇₈₉", "0123456789")
# Punctuation / editorial marks deleted from model output. Gap tags are shielded
# first because '<' and '>' are in this set. '⌉' completes the ⌈⌉/⌊⌋ bracket pairs.
_BAD_CHARS = str.maketrans('', '', '!?()"—–<>⌈⌉⌊⌋[]+ʾ/;')
# (decimal-suffix regex, Unicode fraction) pairs; "N.5" -> "N ½", "0.5" -> "½".
_FRACTIONS = (
    (r'\.5\b', '½'),
    (r'\.25\b', '¼'),
    (r'\.75\b', '¾'),
    (r'\.33+\d*\b', '⅓'),
    (r'\.66+\d*\b', '⅔'),
)


def hnr(text):
    """Post-process one decoded model translation into normalized submission text.

    Steps, in order:
      1. map ḫ/Ḫ to h/H and subscript digits to ASCII digits;
      2. normalize gap markers ([x], (x), lone x -> <gap>; ..., …, [...] -> <big_gap>),
         merging adjacent gap tags into a single <big_gap>;
      3. drop parenthesized annotations such as "(fem.)", "(pl.)", "(?)", "(!)";
      4. delete punctuation/editorial characters while preserving the gap tags;
      5. rewrite decimals .5/.25/.75/.33…/.66… as Unicode fractions;
      6. collapse immediately repeated words and repeated 2–4-word phrases;
      7. tidy spacing before punctuation, duplicated '.'/',', whitespace and edge dashes.

    Args:
        text: decoded model output. Non-str or blank input is tolerated.

    Returns:
        The cleaned, single-spaced string; "" for non-str or blank input.
    """
    if not isinstance(text, str) or not text.strip():
        return ""

    t = text.replace('ḫ', 'h').replace('Ḫ', 'H')
    t = t.translate(_SUBSCRIPT_DIGITS)

    # Gap markers.
    t = re.sub(r'(\[x\]|\(x\)|\bx\b)', '<gap>', t, flags=re.I)
    t = re.sub(r'(\.{3,}|…|\[\.+\])', '<big_gap>', t)
    t = re.sub(r'<gap>\s*<gap>', ' <big_gap> ', t)
    t = re.sub(r'<big_gap>\s*<big_gap>', ' <big_gap> ', t)

    # Grammatical / editorial annotations in parentheses.
    t = re.sub(
        r'\((fem|plur|pl|sing|singular|plural|\?|!)\.?\s*\w*\)',
        '',
        t,
        flags=re.I,
    )

    # Shield gap tags with NUL sentinels so deleting '<' and '>' keeps them intact.
    t = t.replace('<gap>', '\x00GAP\x00').replace('<big_gap>', '\x00BIG\x00')
    t = t.translate(_BAD_CHARS)
    t = t.replace('\x00GAP\x00', ' <gap> ').replace('\x00BIG\x00', ' <big_gap> ')
    t = t.replace("ד", "")

    for pat, frac in _FRACTIONS:
        # The bare "0.x" rule must run BEFORE the generic "N.x" rule. (\d+) also
        # matches a lone 0, so running the generic rule first turned "0.5" into
        # "0 ½" and left this rule unreachable.
        t = re.sub(r'\b0' + pat, frac, t)
        t = re.sub(r'(\d+)' + pat, r'\1 ' + frac, t)

    # Collapse immediately repeated words, then repeated 4-, 3- and 2-word phrases.
    t = re.sub(r'\b(\w+)(?:\s+\1\b)+', r'\1', t)
    for n in range(4, 1, -1):
        pat = r'\b((?:\w+\s+){' + str(n - 1) + r'}\w+)(?:\s+\1\b)+'
        t = re.sub(pat, r'\1', t)

    t = re.sub(r'\s+([.,:])', r'\1', t)
    t = re.sub(r'([.,])\1+', r'\1', t)
    return re.sub(r'\s+', ' ', t).strip().strip('-').strip()

def load_blended_model(model_paths=None, weights=None, base_idx=1):
    """Build one seq2seq model as a weighted average of several checkpoints.

    Each checkpoint is loaded with ``AutoModelForSeq2SeqLM.from_pretrained``. The
    floating-point tensors of its state dict are combined as a weighted mean,
    normalized per key by the weights that actually contributed to that key. The
    checkpoint at ``base_idx`` provides:
      * the architecture;
      * any key the other checkpoints lack;
      * every non-floating-point buffer, copied unchanged.
    Averaging integer buffers would round-trip them through float and then
    truncate them on copy back, e.g. 4.9999 -> 4.

    Checkpoints are loaded one at a time and freed after accumulation, to bound
    peak memory.

    Args:
        model_paths: checkpoint paths; defaults to ``CFG['models']``.
        weights: one relative weight per path (normalized here); defaults to
            ``CFG['weights']``.
        base_idx: index of the base checkpoint; defaults to 1, matching the
            original behavior (and the tokenizer taken from ``CFG['models'][1]``).

    Returns:
        The blended model moved to ``CFG['device']``, in eval mode, cast to float32.

    Raises:
        ValueError: if there are no paths, if paths and weights differ in
            length, or if the weights sum to zero.
        IndexError: if ``base_idx`` is out of range.
    """
    import warnings

    paths = list(CFG['models'] if model_paths is None else model_paths)
    raw_weights = list(CFG['weights'] if weights is None else weights)
    if not paths or len(paths) != len(raw_weights):
        raise ValueError(
            f"need one weight per model, got {len(paths)} models "
            f"and {len(raw_weights)} weights"
        )
    total_score = sum(raw_weights)
    if total_score == 0:
        raise ValueError("model weights must not sum to zero")
    W = [w / total_score for w in raw_weights]
    # Normalize negative indices and raise IndexError when out of range.
    base_idx = range(len(paths))[base_idx]

    base_model = AutoModelForSeq2SeqLM.from_pretrained(paths[base_idx])
    final_sd = base_model.state_dict()

    # Weighted running sums. `W * v` allocates fresh tensors, so the base model's
    # own parameters stay untouched until load_state_dict. `norm` holds the total
    # weight that contributed to each key.
    acc = {k: W[base_idx] * v for k, v in final_sd.items() if v.is_floating_point()}
    norm = dict.fromkeys(acc, W[base_idx])

    for i, path in enumerate(paths):
        if i == base_idx:
            continue
        other_sd = AutoModelForSeq2SeqLM.from_pretrained(path).state_dict()
        mismatched = []
        for k, running in acc.items():
            other = other_sd.get(k)
            if other is None:
                continue  # key absent from this checkpoint: it simply does not contribute
            if other.shape != running.shape:
                mismatched.append(k)  # would crash or silently broadcast
                continue
            running.add_(other.to(running.dtype), alpha=W[i])
            norm[k] += W[i]
        if mismatched:
            warnings.warn(
                f"{path}: skipped {len(mismatched)} tensor(s) with mismatched shapes, "
                f"e.g. {mismatched[:3]}"
            )
        del other_sd  # release this checkpoint before loading the next one

    # Write the means back into the base state dict itself, rather than a copy,
    # to keep its `_metadata`, which load_state_dict consults.
    for k, running in acc.items():
        final_sd[k] = running.div_(norm[k])

    base_model.load_state_dict(final_sd)
    return base_model.to(CFG['device']).eval().float()

class AkkadDataset(Dataset):
    """Map-style dataset of ``(id, prompt)`` pairs for seq2seq generation.

    Each prompt is the task prefix "translate Akkadian to English: " followed by
    the row's ``transliteration`` value converted with ``str``.

    Args:
        df: a DataFrame with ``id`` and ``transliteration`` columns.
    """

    def __init__(self, df):
        prefix = "translate Akkadian to English: "
        self.ids = df['id'].tolist()
        self.texts = [prefix + str(source) for source in df['transliteration']]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        return (self.ids[i], self.texts[i])

In [None]:
if __name__ == "__main__":
    # Read the test split and normalize gap markers in the source text.
    df = pd.read_csv(CFG['data_path'])
    df['transliteration'] = df['transliteration'].map(clean_input)

    model = load_blended_model()
    # The tokenizer is taken from CFG['models'][1], the checkpoint the blend is built on.
    tokenizer = AutoTokenizer.from_pretrained(CFG['models'][1])

    def _collate(batch):
        """Split (id, prompt) pairs and tokenize the prompts into one padded batch."""
        batch_ids = [item[0] for item in batch]
        encoded = tokenizer(
            [item[1] for item in batch],
            max_length=CFG['max_len'],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        return batch_ids, encoded

    loader = DataLoader(
        AkkadDataset(df),
        batch_size=CFG['batch_size'],
        shuffle=False,  # keep the test order for the submission
        num_workers=2,
        collate_fn=_collate,
    )

    results = []
    with torch.inference_mode():
        for ids, inputs in loader:
            generated = model.generate(
                input_ids=inputs['input_ids'].to(CFG['device']),
                attention_mask=inputs['attention_mask'].to(CFG['device']),
                **CFG['gen_params'],
            )
            texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
            results.extend((row_id, hnr(txt)) for row_id, txt in zip(ids, texts))

    sub = pd.DataFrame(results, columns=['id', 'translation'])
    sub.to_csv("submission.csv", index=False)

In [None]:
# Echo every submission row: id, translation, then a 20-dash separator.
for row_id, translation in zip(sub['id'], sub['translation']):
    print(f"ID: {row_id}\nTranslation: {translation}\n{'-' * 20}")