In [1]:
import re
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# =========================
# Config
# =========================
# --- Input locations (Kaggle dataset mounts) ---
TEST_DATA_PATH = "/kaggle/input/deep-past-initiative-machine-translation/test.csv"
MODEL1_PATH = "/kaggle/input/byt5-base-big-data2"
MODEL2_PATH = "/kaggle/input/byt5-akkadian-model"
MODEL3_PATH = "/kaggle/input/train-gap-all-2/byt5-base-akkadian_gap_setence2"

# --- Runtime ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BATCH_SIZE = 8
# fp16 autocast stays off: decoding in fp32 is kept for score stability
# (ByT5 has been observed to degrade under fp16 decoding).
USE_AMP = False

# --- Encoder input ---
PREFIX = "translate Akkadian to English: "   # task prefix prepended to every source
MAX_SOURCE_LEN = 512                         # tokenizer truncation limit for prefixed sources

# --- Beam-search decoding (baseline-like starting point; tune later) ---
NUM_BEAMS = 6
MAX_NEW_TOKENS = 384    # generous cap so translations are not cut short
LENGTH_PENALTY = 1.05
EARLY_STOPPING = True

# =========================
# Gap normalisation (kept identical to the original preprocessing)
# =========================
# Ordered (pattern, replacement) rules applied one after another. Order matters:
# multi-run patterns such as "... ..." must collapse to a single <big_gap>
# before the bare "..." / "…" rules see the text.
# NOTE(review): `\.3` matches the literal two characters ".3", not three dots;
# possibly meant `\.{3}`. Left as-is because the model was presumably trained
# on text produced by exactly these rules — confirm before changing.
_GAP_RULES = tuple(
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        (r'\.3(?:\s+\.3)+\.{3}(?:\s+\.{3})+\s+\.{3}(?:\s+\.{3})+', '<big_gap>'),
        (r'\.3(?:\s+\.3)+\.{3}(?:\s+\.{3})+', '<big_gap>'),
        (r'\.{3}(?:\s+\.{3})+', '<big_gap>'),
        (r'xx', '<gap>'),
        (r' x ', ' <gap> '),
        (r'……', '<big_gap>'),
        (r'\.\.\.\.\.\.', '<big_gap>'),
        (r'…', '<big_gap>'),
        (r'\.\.\.', '<big_gap>'),
    )
)


def replace_gaps(text):
    """Rewrite lacuna markers in a transliteration as <gap> / <big_gap> tokens.

    Values that ``pd.isna`` reports as missing are returned unchanged; any
    other value is coerced to ``str`` and then passed through every rule in
    ``_GAP_RULES`` in order.

    Quirks preserved on purpose: "xx" is replaced anywhere, so "xxx" becomes
    "<gap>x". " x " consumes its surrounding spaces, so in "a x x b" only the
    first isolated x is replaced.
    """
    if pd.isna(text):
        return text
    result = str(text)
    for regex, replacement in _GAP_RULES:
        result = regex.sub(replacement, result)
    return result

# =========================
# Weighted checkpoint averaging: mixing weights
# =========================
# Relative quality scores for checkpoints 1, 2 and 3. They are normalised by
# their sum so the resulting mixing weights add up to 1.
perf1, perf2, perf3 = 0.98, 1.00, 0.40
total = sum((perf1, perf2, perf3))
w1, w2, w3 = (score / total for score in (perf1, perf2, perf3))
print(f"Weights: w1={w1:.4f}, w2={w2:.4f}, w3={w3:.4f}")

# =========================
# Weighted checkpoint averaging: load, merge, finalize
# =========================
import gc


def _merge_state_dicts(state_dicts, weights, base_idx):
    """Weighted-average several state dicts into the key layout of one of them.

    Args:
        state_dicts: list of ``state_dict()`` mappings (parameter name -> tensor).
        weights: mixing weight for each entry of ``state_dicts``.
        base_idx: index of the checkpoint whose keys, and fallback values,
            define the result.

    Returns:
        ``(merged, skipped)``. ``merged`` is the averaged mapping. ``skipped``
        lists the keys where another checkpoint had a tensor that could not be
        used (shape mismatch, or a non-floating base tensor).

    For each floating-point tensor of the base checkpoint, every other
    checkpoint holding the same key with the same shape contributes
    ``weight * tensor``. The weight of any checkpoint that cannot contribute is
    folded into the base checkpoint's weight, matching the original
    "missing key -> give its weight to model 2" rule. Non-floating tensors, and
    tensors no other checkpoint can contribute to, are taken from the base
    unchanged. Terms are summed in checkpoint-index order.
    """
    base_sd = state_dicts[base_idx]
    n_models = len(state_dicts)
    merged, skipped = {}, []
    for key, base_t in base_sd.items():
        others = [i for i in range(n_models) if i != base_idx and key in state_dicts[i]]
        if not torch.is_floating_point(base_t):
            # Averaging integer/bool buffers is meaningless; keep the base values.
            if others:
                skipped.append(key)
            merged[key] = base_t
            continue
        # A shape mismatch would either crash or silently broadcast; exclude it.
        donors = [i for i in others if state_dicts[i][key].shape == base_t.shape]
        if len(donors) < len(others):
            skipped.append(key)
        if not donors:
            merged[key] = base_t
            continue
        base_w = weights[base_idx]
        for i in range(n_models):
            if i != base_idx and i not in donors:
                base_w += weights[i]
        acc = None
        for i in range(n_models):
            if i == base_idx:
                term = base_w * base_t
            elif i in donors:
                term = weights[i] * state_dicts[i][key]
            else:
                continue
            acc = term if acc is None else acc + term
        merged[key] = acc
    return merged, skipped


print("Loading models...")
m1 = AutoModelForSeq2SeqLM.from_pretrained(MODEL1_PATH)
m2 = AutoModelForSeq2SeqLM.from_pretrained(MODEL2_PATH)
m3 = AutoModelForSeq2SeqLM.from_pretrained(MODEL3_PATH)

print("Averaging weights...")
final_sd, skipped_keys = _merge_state_dicts(
    [m1.state_dict(), m2.state_dict(), m3.state_dict()],
    [w1, w2, w3],
    base_idx=1,
)
if skipped_keys:
    print(f"Warning: {len(skipped_keys)} tensors kept from model 2 only (shape/dtype mismatch)")

# Models 1 and 3 are not needed once their weighted contributions are summed.
del m1, m3
gc.collect()

# Load the merged weights into model 2's own instance rather than instantiating
# a fourth full copy from MODEL2_PATH; the architecture and config are the same.
model = m2
model.load_state_dict(final_sd)
del final_sd
gc.collect()
model.to(DEVICE).eval()
model.float()   # keep fp32

tokenizer = AutoTokenizer.from_pretrained(MODEL2_PATH)

# =========================
# Test data: load, then normalise gap markers
# =========================
# Presumably the same gap normalisation the checkpoints were trained with — confirm.
test_df = pd.read_csv(TEST_DATA_PATH)
cleaned_translit = test_df["transliteration"].apply(replace_gaps)
test_df["transliteration"] = cleaned_translit

class InferenceDataset(Dataset):
    """Map-style dataset yielding ``(id, prompt)`` pairs for generation.

    Each prompt is ``PREFIX`` followed by the row's ``transliteration`` value
    coerced with ``astype(str)``, so a missing value becomes the text "nan".
    """

    def __init__(self, df):
        self.ids = df["id"].tolist()
        sources = df["transliteration"].astype(str).tolist()
        self.texts = [f"{PREFIX}{src}" for src in sources]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        sample_id = self.ids[idx]
        prompt = self.texts[idx]
        return sample_id, prompt

def collate_fn(batch):
    """Tokenize a list of ``(id, prompt)`` pairs into one padded batch.

    Prompts are truncated to ``MAX_SOURCE_LEN`` and padded to the longest
    prompt in the batch (dynamic padding).

    Returns:
        ``(ids, input_ids, attention_mask)``: the ids as a list, plus the two
        PyTorch tensors produced by the tokenizer.
    """
    sample_ids, prompts = zip(*batch)
    tokenizer_kwargs = dict(
        max_length=MAX_SOURCE_LEN,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    encoded = tokenizer(list(prompts), **tokenizer_kwargs)
    return list(sample_ids), encoded["input_ids"], encoded["attention_mask"]

# Ids travel with every batch, so correctness does not depend on order;
# shuffle=False just keeps submission rows in input order.
loader = DataLoader(
    dataset=InferenceDataset(test_df),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=DEVICE.type == "cuda",   # pinned host memory only helps GPU copies
)

# =========================
# Inference
# =========================
from contextlib import nullcontext

all_ids, all_pred = [], []
# Redundant inside inference_mode() below, but kept: it also leaves autograd
# disabled for any code that runs after this cell.
torch.set_grad_enabled(False)

# fp16 autocast is only requested on CUDA. On CPU a CUDA autocast would just be
# disabled by torch (with a warning), so plain fp32 is used directly.
use_amp = USE_AMP and DEVICE.type == "cuda"

with torch.inference_mode():
    for ids, input_ids, attention_mask in loader:
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        # Keep decoding in fp32 for score stability unless AMP is explicitly on.
        ctx = (
            torch.autocast(device_type="cuda", dtype=torch.float16)
            if use_amp
            else nullcontext()
        )
        with ctx:
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_beams=NUM_BEAMS,
                max_new_tokens=MAX_NEW_TOKENS,
                length_penalty=LENGTH_PENALTY,
                early_stopping=EARLY_STOPPING,
            )

        # An empty/whitespace-only decode gets a placeholder so no row is blank.
        decoded = [
            text.strip() or "broken text"
            for text in tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ]

        all_ids.extend(ids)
        all_pred.extend(decoded)

submission = pd.DataFrame({"id": all_ids, "translation": all_pred})
submission.to_csv("submission.csv", index=False)
print("Saved submission.csv")
print(submission.head())

Weights: w1=0.4118, w2=0.4202, w3=0.1681
Loading models...


2026-01-14 21:40:51.759827: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768426851.958169      24 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768426852.013152      24 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768426852.478118      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768426852.478163      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768426852.478166      24 computation_placer.cc:177] computation placer alr

Averaging weights...
Saved submission.csv
   id                                        translation
0   0  From the Kanesh colony to the <big_gap> of our...
1   1  In the tablet from the City, you wrote to me i...
2   2  Just as you hear our letter, he has given eith...
3   3  I sent our tablets to every single day and nig...
