In [5]:
# mT5 (RU -> EN) Fine-tuning

!pip -q install -U datasets sacrebleu transformers accelerate sentencepiece tqdm

import os, random, re, math, json, time
from dataclasses import dataclass
from typing import List, Dict

import torch
from torch.utils.data import DataLoader

from datasets import load_dataset
import sacrebleu
from tqdm.auto import tqdm

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    get_linear_schedule_with_warmup
)

def hr(title=None, ch="="):
    """Print a section divider to stdout.

    With a truthy ``title`` the output is a blank line followed by the title
    framed by ten ``ch`` characters on each side. Without one (``None`` or an
    empty string) a plain rule of 32 ``ch`` characters is printed.
    """
    if not title:
        print(ch * 32)
        return

    bar = ch * 10
    print(f"\n{bar} {title} {bar}")

def print_kv(k, v, pad=22):
    """Print ``k`` left-aligned in a field of ``pad`` characters, then ``: v``.

    Keys longer than ``pad`` are not truncated; they simply push the colon
    to the right.
    """
    label = format(k, f"<{pad}")
    print(f"{label}: {v}")

# Seed Python's `random` and PyTorch (CPU generator plus every CUDA device)
# with one fixed value so repeated runs start from the same RNG state.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)


# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hr("Runtime")
# NOTE: this label is longer than print_kv's default pad of 22, so its colon
# does not line up with the rows printed below it.
print_kv("torch.cuda.is_available()", torch.cuda.is_available())
if DEVICE.type == "cuda":
    # Only device index 0 is reported.
    print_kv("GPU", torch.cuda.get_device_name(0))
print_kv("Device", DEVICE)

hr("Mount Drive")
# DRIVE_ON gates every Google Drive write later in the script (checkpoint and
# final-model mirroring). Set it to False to skip Drive entirely; it is also
# switched off automatically below when mounting fails.
DRIVE_ON = True
DRIVE_DIR = "/content/drive/MyDrive/mt5_ru_en"

if DRIVE_ON:
    try:
        # Imported inside the try so that a missing `google.colab` package
        # is handled the same way as a failed mount.
        from google.colab import drive
        drive.mount("/content/drive")
        os.makedirs(DRIVE_DIR, exist_ok=True)
        print_kv("Drive save dir", DRIVE_DIR)
    except Exception as e:
        # Deliberately broad and best-effort: any failure just disables
        # Drive mirroring and the run continues with local saves only.
        DRIVE_ON = False
        print("Drive mount failed, continuing without Drive:", e)


# ---- Data -------------------------------------------------------------
DATASET_NAME = "Helsinki-NLP/opus-100"
DATASET_CONFIG = "en-ru"      # contains both en and ru in 'translation'
# Keys read from each example's 'translation' dict: source and target side.
SRC_LANG = "ru"
TGT_LANG = "en"

# Checkpoint name passed to AutoTokenizer / AutoModelForSeq2SeqLM.
MODEL_NAME = "google/mt5-small"

# Number of leading rows kept from each split (dev/test are additionally
# clamped to the split size where they are selected).
N_TRAIN = 50_000     # training subset
N_DEV   = 2_000
N_TEST  = 2_000

# Tokenizer truncation limits (max_length) for source and target text.
MAX_SRC_LEN = 96
MAX_TGT_LEN = 96

# ---- Optimisation -----------------------------------------------------
EPOCHS = 3

# DataLoader batch sizes; GRAD_ACCUM micro-batches are accumulated per
# optimizer step, so the effective train batch is TRAIN_BS * GRAD_ACCUM.
TRAIN_BS = 4
EVAL_BS  = 8
GRAD_ACCUM = 4

# AdamW learning rate / weight decay; WARMUP_RATIO is the fraction of all
# optimizer updates spent in linear warmup.
LR = 2e-4
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.06

# Max gradient norm passed to clip_grad_norm_.
GRAD_CLIP = 1.0


# Enables CUDA autocast + GradScaler in the training loop when True.
USE_AMP = False
# Upper bound on how many dev/test sentences are decoded for BLEU after
# each epoch.
EVAL_SENTS_DEV  = 500
EVAL_SENTS_TEST = 500


# ---- Generation -------------------------------------------------------
# Passed to model.generate as num_beams / max_length.
NUM_BEAMS = 4
GEN_MAX_LEN = 128


# ---- Output locations -------------------------------------------------
OUT_DIR = "/content/mt5_ru_en"
CKPT_DIR = os.path.join(OUT_DIR, "checkpoints")
# Also creates OUT_DIR, since CKPT_DIR is nested inside it.
os.makedirs(CKPT_DIR, exist_ok=True)


hr("Load dataset")
# The loaded object is indexed by split name; the "train", "validation" and
# "test" splits are all required below.
ds = load_dataset(DATASET_NAME, DATASET_CONFIG)
train_data = ds["train"]
dev_data   = ds["validation"]
test_data  = ds["test"]

print_kv("Train size", len(train_data))
print_kv("Dev size", len(dev_data))
print_kv("Test size", len(test_data))
# Reports the configured N_TRAIN, not a count measured from the data.
print_kv("Train used", N_TRAIN)


hr("Load model/tokenizer")
# Tokenizer and seq2seq model come from the same pretrained checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.to(DEVICE)

# Total element count over every parameter tensor (trainable or not).
n_params = sum(p.numel() for p in model.parameters())
print_kv("Model", MODEL_NAME)
print_kv("Params", f"{n_params:,}")


hr("Tokenize/encode")

def preprocess(batch):
    """Tokenize one batch of translation pairs for seq2seq training.

    Args:
        batch: A batched mapping whose ``"translation"`` entry is a list of
            dicts keyed by language code; ``SRC_LANG`` and ``TGT_LANG`` must
            both be present in every dict.

    Returns:
        The tokenizer output for the source texts (truncated to
        ``MAX_SRC_LEN``, unpadded) with an extra ``"labels"`` entry holding
        the target token ids (truncated to ``MAX_TGT_LEN``, unpadded).
        Padding is left to the collator so each batch is padded only to its
        own longest sequence.
    """
    pairs = batch["translation"]
    src_texts = [ex[SRC_LANG] for ex in pairs]
    tgt_texts = [ex[TGT_LANG] for ex in pairs]

    model_inputs = tokenizer(
        src_texts,
        max_length=MAX_SRC_LEN,
        truncation=True,
        padding=False
    )

    # `text_target=` is the supported replacement for the deprecated
    # `tokenizer.as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=tgt_texts,
        max_length=MAX_TGT_LEN,
        truncation=True,
        padding=False
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# Keep only the leading rows of each split. Every size is clamped to the
# split length so an over-large N_* setting shrinks the subset instead of
# raising an out-of-range error in `select`.
train_small = train_data.select(range(min(N_TRAIN, len(train_data))))
dev_small   = dev_data.select(range(min(N_DEV, len(dev_data))))
test_small  = test_data.select(range(min(N_TEST, len(test_data))))

# Tokenize in batches and drop the raw columns, leaving only the fields
# produced by `preprocess`.
train_enc = train_small.map(preprocess, batched=True, remove_columns=train_small.column_names)
dev_enc   = dev_small.map(preprocess, batched=True, remove_columns=dev_small.column_names)
test_enc  = test_small.map(preprocess, batched=True, remove_columns=test_small.column_names)

print_kv("Train encoded", len(train_enc))
print_kv("Dev encoded", len(dev_enc))
print_kv("Test encoded", len(test_enc))

@dataclass
class DataCollatorSeq2SeqSimple:
    """Pad a list of tokenized examples into one tensor batch.

    Each feature must carry ``input_ids``, ``attention_mask`` and ``labels``.
    Encoder inputs and labels are padded separately via ``tokenizer.pad``;
    label positions equal to the pad token id are then overwritten with
    -100 (presumably so the model's loss skips padding -- -100 is the
    conventional ignore index).
    """

    tokenizer: AutoTokenizer

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        encoder_rows = [
            {"input_ids": row["input_ids"], "attention_mask": row["attention_mask"]}
            for row in features
        ]
        batch = self.tokenizer.pad(encoder_rows, padding=True, return_tensors="pt")

        # Labels go through the same padding routine under the "input_ids"
        # key, because that is the field name `pad` operates on.
        label_rows = [{"input_ids": row["labels"]} for row in features]
        padded_labels = self.tokenizer.pad(label_rows, padding=True, return_tensors="pt")

        target_ids = padded_labels["input_ids"]
        pad_positions = target_ids == self.tokenizer.pad_token_id
        target_ids[pad_positions] = -100

        batch["labels"] = target_ids
        return batch

# One collator instance is shared by all three loaders.
collator = DataCollatorSeq2SeqSimple(tokenizer)

# Only the training loader shuffles; dev/test keep dataset order.
train_loader = DataLoader(train_enc, batch_size=TRAIN_BS, shuffle=True, collate_fn=collator)
dev_loader   = DataLoader(dev_enc, batch_size=EVAL_BS, shuffle=False, collate_fn=collator)
test_loader  = DataLoader(test_enc, batch_size=EVAL_BS, shuffle=False, collate_fn=collator)

hr("Dataloaders")
print_kv("Train batches/epoch", len(train_loader))
# These two rows echo the configured BLEU subset sizes, not loader lengths.
print_kv("Dev eval sents", EVAL_SENTS_DEV)
print_kv("Test eval sents", EVAL_SENTS_TEST)

hr("Optimizer/Scheduler")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# The schedule is measured in optimizer updates, not micro-batches. `ceil`
# budgets one update for a trailing, partially filled accumulation group.
total_update_steps = math.ceil(len(train_loader) / GRAD_ACCUM) * EPOCHS
warmup_steps = int(WARMUP_RATIO * total_update_steps)

# Linear warmup for `warmup_steps` updates, then linear decay until
# `total_update_steps`.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_update_steps
)

print_kv("Total updates", total_update_steps)
print_kv("Warmup steps", warmup_steps)
print_kv("LR", LR)
print_kv("Grad accum", GRAD_ACCUM)
print_kv("Effective batch", TRAIN_BS * GRAD_ACCUM)
print_kv("AMP", USE_AMP)

# `torch.cuda.amp.GradScaler(...)` is deprecated in favour of the
# device-generic `torch.amp.GradScaler("cuda", ...)`. The scaler is only
# enabled when AMP is requested and a CUDA device is in use.
scaler = torch.amp.GradScaler("cuda", enabled=(USE_AMP and DEVICE.type == "cuda"))

@torch.no_grad()
def decode_subset_bleu(split_name, raw_split, max_sents=500):
    """Translate the first sentences of a raw split and score them with BLEU.

    Args:
        split_name: Label shown in the progress-bar description.
        raw_split: Indexable collection of untokenized examples; each item
            must expose a ``"translation"`` dict containing ``SRC_LANG`` and
            ``TGT_LANG`` entries.
        max_sents: Upper bound on how many leading examples are decoded.

    Returns:
        The object returned by ``sacrebleu.corpus_bleu`` for the decoded
        hypotheses against the single reference set (callers in this file
        read its ``.score``).

    Side effects:
        Switches the global ``model`` to eval mode and leaves it there.
    """
    limit = min(max_sents, len(raw_split))
    hypotheses = []
    references = []

    model.eval()
    progress = tqdm(range(limit), desc=f"{split_name} BLEU ({limit} sents)")
    for idx in progress:
        pair = raw_split[idx]["translation"]
        source_text = pair[SRC_LANG]
        reference_text = pair[TGT_LANG]

        # Sentences are decoded one at a time, so no padding is involved.
        encoded = tokenizer(
            source_text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_SRC_LEN,
        ).to(DEVICE)
        output_ids = model.generate(
            **encoded,
            num_beams=NUM_BEAMS,
            max_length=GEN_MAX_LEN,
        )

        hypotheses.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))
        references.append(reference_text)

    return sacrebleu.corpus_bleu(hypotheses, [references])

hr("Save helpers")

def save_checkpoint(epoch, dev_bleu, test_bleu, tag="epoch"):
    """Write a resumable training checkpoint and mirror it to Drive.

    The checkpoint bundles the model, optimizer and scheduler state dicts
    with the epoch number, both BLEU scores and the main hyper-parameters.

    Args:
        epoch: Epoch number; stored in the checkpoint and used in the
            file name.
        dev_bleu: Dev BLEU score (anything ``float()`` accepts).
        test_bleu: Test BLEU score (anything ``float()`` accepts).
        tag: File-name prefix; the file is named ``{tag}_{epoch}.pt``.

    The file is always written under ``CKPT_DIR``. When ``DRIVE_ON`` is set
    it is then copied into ``DRIVE_DIR``. A failed Drive copy is reported
    but does not raise, matching the best-effort Drive mount, so a full or
    disconnected Drive cannot abort a training run.
    """
    # Local import: this helper is the only user of shutil in the script.
    import shutil

    fname = f"{tag}_{epoch}.pt"
    local_path = os.path.join(CKPT_DIR, fname)

    ckpt = {
        "epoch": epoch,
        "model_name": MODEL_NAME,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "dev_bleu": float(dev_bleu),
        "test_bleu": float(test_bleu),
        "config": {
            "N_TRAIN": N_TRAIN,
            "MAX_SRC_LEN": MAX_SRC_LEN,
            "MAX_TGT_LEN": MAX_TGT_LEN,
            "LR": LR,
            "TRAIN_BS": TRAIN_BS,
            "GRAD_ACCUM": GRAD_ACCUM,
            "EPOCHS": EPOCHS,
            "NUM_BEAMS": NUM_BEAMS
        }
    }

    torch.save(ckpt, local_path)
    print_kv("Saved local", local_path)

    if DRIVE_ON:
        drive_path = os.path.join(DRIVE_DIR, fname)
        try:
            # Copy the file just written instead of serializing the whole
            # model + optimizer state a second time.
            shutil.copyfile(local_path, drive_path)
        except OSError as e:
            print("Drive checkpoint copy failed, local copy kept:", e)
        else:
            print_kv("Saved drive", drive_path)

hr("Train")

# One (epoch, dev BLEU, test BLEU) tuple per finished epoch.
bleu_curve = []

# Mixed precision is only active when requested AND running on CUDA.
amp_enabled = USE_AMP and DEVICE.type == "cuda"
batches_per_epoch = len(train_loader)

# Counts optimizer updates across all epochs.
global_step = 0
for epoch in range(1, EPOCHS + 1):
    model.train()
    optimizer.zero_grad(set_to_none=True)

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=True)
    running_loss = 0.0
    steps_in_epoch = 0

    for step, batch in enumerate(pbar, start=1):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}

        # `torch.amp.autocast("cuda", ...)` replaces the deprecated
        # `torch.cuda.amp.autocast(...)` wrapper.
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            out = model(**batch)
            # Scale so gradients summed over GRAD_ACCUM micro-batches
            # average rather than add up.
            loss = out.loss / GRAD_ACCUM

        if amp_enabled:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Undo the accumulation scaling so the logged loss is per-batch.
        running_loss += float(loss.item()) * GRAD_ACCUM
        steps_in_epoch += 1

        # Update after every full accumulation group, and also on the last
        # batch of the epoch. Otherwise the gradients of a trailing partial
        # group would be thrown away by the next epoch's zero_grad, and the
        # scheduler would take fewer steps than `total_update_steps`
        # (which rounds up). A partial group is still divided by GRAD_ACCUM,
        # so its update is proportionally smaller.
        if step % GRAD_ACCUM == 0 or step == batches_per_epoch:
            if amp_enabled:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)

            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

            if amp_enabled:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

        if step % 50 == 0:
            pbar.set_postfix(loss=f"{running_loss/steps_in_epoch:.4f}", lr=f"{scheduler.get_last_lr()[0]:.2e}")

        # Fail fast: once the running mean is NaN/Inf it never recovers.
        if not math.isfinite(running_loss/steps_in_epoch):
            raise RuntimeError("Loss became NaN/Inf. Reduce LR, keep AMP off, or increase grad clipping.")

    avg_train_loss = running_loss / max(1, steps_in_epoch)

    hr(f"Iteration/Epoch {epoch} evaluation", ch="-")
    # BLEU is computed on the raw (untokenized) subsets, capped at the
    # configured number of sentences per split.
    dev_bleu = decode_subset_bleu("Dev", dev_small, max_sents=EVAL_SENTS_DEV)
    test_bleu = decode_subset_bleu("Test", test_small, max_sents=EVAL_SENTS_TEST)

    print_kv("Avg train loss", f"{avg_train_loss:.4f}")
    print_kv("Dev BLEU", f"{dev_bleu.score:.2f}")
    print_kv("Test BLEU", f"{test_bleu.score:.2f}")

    bleu_curve.append((epoch, float(dev_bleu.score), float(test_bleu.score)))

    save_checkpoint(epoch, dev_bleu.score, test_bleu.score, tag="epoch")

hr("Demo translations")

# A few hand-written Russian sentences for a quick qualitative check of the
# fine-tuned model.
examples_ru = [
    "я люблю машинное обучение",
    "сегодня погода очень хорошая, но немного холодно",
    "пожалуйста, скажи мне где находится ближайшая станция метро"
]

model.eval()
for i, ru_sent in enumerate(examples_ru, 1):
    # Same truncation and generation settings as the BLEU evaluation.
    inputs = tokenizer(ru_sent, return_tensors="pt", truncation=True, max_length=MAX_SRC_LEN).to(DEVICE)
    gen = model.generate(**inputs, num_beams=NUM_BEAMS, max_length=GEN_MAX_LEN)
    en_hyp = tokenizer.decode(gen[0], skip_special_tokens=True)

    print(f"\nExample {i}")
    print_kv("RU", ru_sent, pad=6)
    # "EN*" labels the model's hypothesis.
    print_kv("EN*", en_hyp,  pad=6)

hr("BLEU per epoch")
# Recap of the (epoch, dev BLEU, test BLEU) tuples collected during training.
for (e, d, t) in bleu_curve:
    print(f"Epoch {e}: Dev {d:.2f} | Test {t:.2f}")

hr("Save final model")
# Save the model as it stands after the last epoch, together with its
# tokenizer, via save_pretrained.
FINAL_DIR = os.path.join(OUT_DIR, "final")
os.makedirs(FINAL_DIR, exist_ok=True)

model.save_pretrained(FINAL_DIR)
tokenizer.save_pretrained(FINAL_DIR)
print_kv("Final local", FINAL_DIR)

if DRIVE_ON:
    # Mirror the same artifacts to Drive. Unlike the mount step, a failure
    # here is not caught.
    FINAL_DRIVE = os.path.join(DRIVE_DIR, "final")
    os.makedirs(FINAL_DRIVE, exist_ok=True)
    model.save_pretrained(FINAL_DRIVE)
    tokenizer.save_pretrained(FINAL_DRIVE)
    print_kv("Final drive", FINAL_DRIVE)



torch.cuda.is_available(): True
GPU                   : Tesla T4
Device                : cuda

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive save dir        : /content/drive/MyDrive/mt5_ru_en

Train size            : 1000000
Dev size              : 2000
Test size             : 2000
Train used            : 50000

Model                 : google/mt5-small
Params                : 300,176,768



Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Train encoded         : 50000
Dev encoded           : 2000
Test encoded          : 2000

Train batches/epoch   : 12500
Dev eval sents        : 500
Test eval sents       : 500

Total updates         : 9375
Warmup steps          : 562
LR                    : 0.0002
Grad accum            : 4
Effective batch       : 16
AMP                   : False




  scaler = torch.cuda.amp.GradScaler(enabled=(USE_AMP and DEVICE.type == "cuda"))


Epoch 1/3:   0%|          | 0/12500 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  with torch.cuda.amp.autocast(enabled=(USE_AMP and DEVICE.type == "cuda")):



---------- Iteration/Epoch 1 evaluation ----------


Dev BLEU (500 sents):   0%|          | 0/500 [00:00<?, ?it/s]

Test BLEU (500 sents):   0%|          | 0/500 [00:00<?, ?it/s]

Avg train loss        : 4.6437
Dev BLEU              : 17.61
Test BLEU             : 17.18
Saved local           : /content/mt5_ru_en/checkpoints/epoch_1.pt
Saved drive           : /content/drive/MyDrive/mt5_ru_en/epoch_1.pt


Epoch 2/3:   0%|          | 0/12500 [00:00<?, ?it/s]


---------- Iteration/Epoch 2 evaluation ----------


Dev BLEU (500 sents):   0%|          | 0/500 [00:00<?, ?it/s]

Test BLEU (500 sents):   0%|          | 0/500 [00:00<?, ?it/s]

Avg train loss        : 2.7955
Dev BLEU              : 21.04
Test BLEU             : 19.03
Saved local           : /content/mt5_ru_en/checkpoints/epoch_2.pt
Saved drive           : /content/drive/MyDrive/mt5_ru_en/epoch_2.pt


Epoch 3/3:   0%|          | 0/12500 [00:00<?, ?it/s]


---------- Iteration/Epoch 3 evaluation ----------


Dev BLEU (500 sents):   0%|          | 0/500 [00:00<?, ?it/s]

Test BLEU (500 sents):   0%|          | 0/500 [00:00<?, ?it/s]

Avg train loss        : 2.6145
Dev BLEU              : 20.98
Test BLEU             : 20.32
Saved local           : /content/mt5_ru_en/checkpoints/epoch_3.pt
Saved drive           : /content/drive/MyDrive/mt5_ru_en/epoch_3.pt


Example 1
RU    : я люблю машинное обучение
EN*   : I love a machine training

Example 2
RU    : сегодня погода очень хорошая, но немного холодно
EN*   : It's very good, but it's very cold.

Example 3
RU    : пожалуйста, скажи мне где находится ближайшая станция метро
EN*   : Please tell me where the station is located

Epoch 1: Dev 17.61 | Test 17.18
Epoch 2: Dev 21.04 | Test 19.03
Epoch 3: Dev 20.98 | Test 20.32

Final local           : /content/mt5_ru_en/final
Final drive           : /content/drive/MyDrive/mt5_ru_en/final
