In [1]:
!pip install bert_score

[0m

In [2]:
# --- EN→ES fine-tuning on ~8GB GPU -------------------
# Fine-tunes facebook/mbart-large-50 on a small OPUS-100 en-es sample and evaluates
# it with BERTScore. Settings (batch size 1 + gradient accumulation, gradient
# checkpointing, Adafactor) are chosen to fit in roughly 8GB of GPU memory.

import os
# Set before `import torch` so the CUDA caching allocator sees it when it is first
# initialised; capping the split block size can reduce fragmentation-driven OOMs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    set_seed,
)

SEED = 42
set_seed(SEED)  # seeds python `random`, numpy and torch in one call
MAX_LEN = 128   # max tokens kept for both source and target sequences

# BERTScore import: fall back to an in-process install if the pip cell was skipped.
try:
    from bert_score import score as bertscore_fn
except ImportError:
    # Only a *missing* package should trigger an install; any other error raised
    # while importing bert_score is a real problem and should surface as-is.
    import importlib
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "bert_score"])
    # The import system caches directory listings; without this, a package
    # installed after interpreter start-up may not be found.
    importlib.invalidate_caches()
    from bert_score import score as bertscore_fn

# print device info
if torch.cuda.is_available():
    print("✅ CUDA available:", torch.cuda.get_device_name(0))
else:
    print("⚠️ CUDA not available; running on CPU.")

# ---------------- Data: OPUS100 en-es, sample 5k/1k ----------------
ds_en_es = load_dataset("opus100", "en-es")

def small_split(ds, train_n=5000, val_n=1000, seed=42):
    """Carve a small, disjoint train/validation pair out of ``ds["train"]``.

    The train split is shuffled once with ``seed``. The first ``train_n`` shuffled
    rows become the training subset and the next ``val_n`` rows the validation
    subset. Both are clipped when the source split has fewer rows than requested.

    Returns:
        (train, val): the two subsets, in that order.
    """
    pool = ds["train"].shuffle(seed=seed)
    head = pool.select(range(min(train_n + val_n, len(pool))))
    kept = len(head)
    cut = min(train_n, kept)
    stop = min(train_n + val_n, kept)
    return head.select(range(cut)), head.select(range(cut, stop))

es_train, es_val = small_split(ds_en_es)
print(f"EN-ES: train={len(es_train)} val={len(es_val)}")

# ---------------- Model & tokenizer (mBART-50) ----------------
MODEL_NAME = "facebook/mbart-large-50"

tokenizer_es = MBart50TokenizerFast.from_pretrained(MODEL_NAME)
model_es = MBartForConditionalGeneration.from_pretrained(MODEL_NAME)
# Language codes the tokenizer prefixes to source texts and to target (label) texts.
tokenizer_es.src_lang = "en_XX"
tokenizer_es.tgt_lang = "es_XX"

# mBART-50 must be forced to emit the target-language code as its first generated
# token. generate() reads `generation_config` (not just the legacy `config`), so set
# it on both. The generation config is also what each checkpoint saves for inference.
es_lang_id = tokenizer_es.lang_code_to_id["es_XX"]
model_es.config.forced_bos_token_id = es_lang_id
if getattr(model_es, "generation_config", None) is not None:
    model_es.generation_config.forced_bos_token_id = es_lang_id

# Gradient checkpointing trades compute for activation memory. The KV cache is not
# needed during training and transformers disables it under checkpointing anyway.
model_es.config.use_cache = False
# Request non-reentrant checkpointing explicitly; leaving `use_reentrant` unset
# makes PyTorch emit a warning.
model_es.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

def preprocess_pair(dataset, src_key, tgt_key, tokenizer):
    """Tokenize a dataset with a ``translation`` column into seq2seq features.

    Each row's ``translation`` dict supplies the source text under ``src_key`` and
    the target text under ``tgt_key``. Both sides are truncated to ``MAX_LEN``
    tokens. Target ids are stored as ``labels`` without padding; the data collator
    pads them later (with -100). The raw ``translation`` column is dropped.
    """
    def _encode_targets(targets):
        # Newer tokenizers accept `text_target=`. Older ones raise TypeError and
        # need the `as_target_tokenizer()` context manager instead.
        try:
            return tokenizer(text_target=targets, max_length=MAX_LEN, truncation=True)
        except TypeError:
            with tokenizer.as_target_tokenizer():
                return tokenizer(targets, max_length=MAX_LEN, truncation=True)

    def _encode_batch(batch):
        pairs = batch["translation"]
        sources = [pair[src_key] for pair in pairs]
        targets = [pair[tgt_key] for pair in pairs]

        features = tokenizer(sources, max_length=MAX_LEN, truncation=True)
        features["labels"] = _encode_targets(targets)["input_ids"]
        return features

    return dataset.map(_encode_batch, batched=True, remove_columns=["translation"])

# Tokenize both splits with the same EN (source) / ES (target) column keys.
tokenized_es_train, tokenized_es_val = (
    preprocess_pair(split, "en", "es", tokenizer_es) for split in (es_train, es_val)
)

# Dynamic per-batch padding; labels are padded with -100 so the loss ignores them.
data_collator_es = DataCollatorForSeq2Seq(
    tokenizer=tokenizer_es,
    model=model_es,
    label_pad_token_id=-100,
)

def make_bertscore_metrics(tokenizer, lang_code_for_bertscore):
    """Build a ``compute_metrics`` callback that scores generations with BERTScore.

    Args:
        tokenizer: tokenizer used to decode both generated ids and label ids.
        lang_code_for_bertscore: language code forwarded to ``bert_score.score``
            as ``lang=`` (it selects the default scoring model for that language).

    Returns:
        A function taking ``(predictions, label_ids)`` and returning a dict with the
        mean ``bertscore_f1``, ``bertscore_precision`` and ``bertscore_recall``.
    """
    device = "cpu"  # keep metric on CPU to avoid eval VRAM spikes

    def compute_metrics(eval_pred):
        preds, labels = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]
        # When the Trainer concatenates outputs from eval batches of different
        # lengths it pads with -100. That is not a valid token id (fast tokenizers
        # raise on it), so map it back to the pad token, which decoding skips.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        P, R, F1 = bertscore_fn(decoded_preds, decoded_labels,
                                lang=lang_code_for_bertscore, device=device, verbose=False)
        return {
            "bertscore_f1": float(F1.mean()),
            "bertscore_precision": float(P.mean()),
            "bertscore_recall": float(R.mean()),
        }
    return compute_metrics

# SETTING PRECISION
# bf16 on GPUs with compute capability >= 8 (Ampere and newer), fp16 on older CUDA
# GPUs. The Trainer only allows fp16 AMP on accelerator devices (fp16=True on a
# CPU-only machine fails argument validation), so both flags stay off without CUDA.
has_cuda = torch.cuda.is_available()
use_bf16 = has_cuda and torch.cuda.get_device_capability(0)[0] >= 8
use_fp16 = has_cuda and not use_bf16

training_args_es = Seq2SeqTrainingArguments(
    optim='adafactor',               # factored second moments: much smaller optimizer state than AdamW
    output_dir="mbart50_opus_en_es",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,   # effective train batch size of 8
    eval_accumulation_steps=8,       # move accumulated eval outputs to CPU every 8 steps
    learning_rate=3e-5,
    evaluation_strategy="steps",     # newer transformers releases name this `eval_strategy`
    eval_steps=300,
    save_strategy="steps",
    save_steps=300,                  # must be a multiple of eval_steps for load_best_model_at_end
    load_best_model_at_end=True,
    metric_for_best_model="bertscore_f1",
    greater_is_better=True,
    logging_steps=300,
    predict_with_generate=True,      # evaluate on generated text so BERTScore can be computed
    fp16=use_fp16,
    bf16=use_bf16,
    report_to="none",
)

trainer_es = Seq2SeqTrainer(
    model=model_es,
    args=training_args_es,
    train_dataset=tokenized_es_train,
    eval_dataset=tokenized_es_val,
    tokenizer=tokenizer_es,
    data_collator=data_collator_es,
    compute_metrics=make_bertscore_metrics(tokenizer_es, "es"),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

print("🚀 Training EN→ES ...")
es_out = trainer_es.train()
print(es_out)
best_es = trainer_es.state.best_model_checkpoint
print("✅ Best EN-ES checkpoint:", best_es)

# With load_best_model_at_end=True the in-memory model is now the best checkpoint.
# Save it (and the tokenizer passed to the trainer) to output_dir so that loading
# from output_dir, e.g. when no best checkpoint is recorded, actually finds a model.
trainer_es.save_model()

2025-08-14 19:10:29.220512: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-08-14 19:10:29.220580: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-08-14 19:10:29.222104: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-08-14 19:10:29.230529: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


✅ CUDA available: Quadro RTX 4000


Downloading readme: 0.00B [00:00, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/237k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/238k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating test split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2000 [00:00<?, ? examples/s]

EN-ES: train=5000 val=1000


tokenizer_config.json:   0%|          | 0.00/531 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

🚀 Training EN→ES ...


You're using a MBart50TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Bertscore F1,Bertscore Precision,Bertscore Recall
300,3.4065,1.835719,0.833806,0.842518,0.826207
600,1.8525,1.668589,0.842682,0.849811,0.83648
900,1.2597,1.638796,0.846448,0.85105,0.842609
1200,1.2424,1.646887,0.846144,0.852111,0.841018
1500,0.8452,1.757743,0.845749,0.848992,0.843239




tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=1500, training_loss=1.7212730000813803, metrics={'train_runtime': 6780.0517, 'train_samples_per_second': 2.212, 'train_steps_per_second': 0.277, 'total_flos': 437700955668480.0, 'train_loss': 1.7212730000813803, 'epoch': 2.4})
✅ Best EN-ES checkpoint: mbart50_opus_en_es/checkpoint-900


In [3]:
# ---- Inference: EN → ES using the saved fine-tuned model --------------------
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Prefer the best checkpoint recorded by the trainer. Otherwise fall back to the
# trainer's output_dir, which only holds a model if one was saved there
# (e.g. via trainer.save_model()).
model_dir = best_es if isinstance(best_es, str) and best_es else training_args_es.output_dir
print("🔎 Loading model from:", model_dir)

# Load tokenizer & model (checkpoints include the tokenizer because it was passed
# to the trainer).
tokenizer_inf = MBart50TokenizerFast.from_pretrained(model_dir)
model_inf = MBartForConditionalGeneration.from_pretrained(model_dir)

# Language settings for mBART-50 (EN -> ES)
tokenizer_inf.src_lang = "en_XX"
tokenizer_inf.tgt_lang = "es_XX"
# Force the Spanish language code as the first generated token. generate() reads
# `generation_config`, so set it there too rather than relying on the legacy
# `config` attribute alone.
es_bos_id_inf = tokenizer_inf.lang_code_to_id["es_XX"]
model_inf.config.forced_bos_token_id = es_bos_id_inf
if getattr(model_inf, "generation_config", None) is not None:
    model_inf.generation_config.forced_bos_token_id = es_bos_id_inf

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
model_inf.to(device)
model_inf.eval()  # inference mode: disables dropout

def translate_en_to_es(texts, max_new_tokens=64, max_source_len=128, **generate_kwargs):
    """Translate English text to Spanish with the loaded fine-tuned mBART-50 model.

    Uses the module-level ``tokenizer_inf``, ``model_inf`` and ``device``.

    Args:
        texts: a single English string or a list of English strings.
        max_new_tokens: cap on generated tokens per sentence (kept modest to avoid
            VRAM spikes).
        max_source_len: inputs longer than this many tokens are truncated.
        **generate_kwargs: extra options forwarded to ``model_inf.generate``
            (e.g. ``num_beams=4``). Greedy decoding is used unless overridden.

    Returns:
        List[str] of Spanish translations, one per input, in input order.
        An empty input list returns ``[]`` without calling the tokenizer or model.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        return []

    enc = tokenizer_inf(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_source_len,
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    # Pass the target-language BOS explicitly, as the mBART-50 docs do. Relying on
    # model.config alone is fragile because generate() reads generation_config.
    generate_kwargs.setdefault("forced_bos_token_id", tokenizer_inf.lang_code_to_id["es_XX"])

    with torch.no_grad():
        gen = model_inf.generate(**enc, max_new_tokens=max_new_tokens, **generate_kwargs)

    return tokenizer_inf.batch_decode(gen, skip_special_tokens=True)

# Quick sanity check: translate a few sentences and print EN/ES side by side.
examples = [
    "The committee will meet next Tuesday to review the proposal.",
    "Please submit the signed contract by the end of the week.",
    "This device requires regular maintenance to function properly.",
]
# NOTE(review): this rebinds `es_out`, which the training cell used for the
# TrainOutput; consider a distinct name if both values are needed later.
es_out = translate_en_to_es(examples, max_new_tokens=64)
for pair_src, pair_tgt in zip(examples, es_out):
    print("\nEN: " + pair_src + "\nES: " + pair_tgt)

🔎 Loading model from: mbart50_opus_en_es/checkpoint-900

EN: The committee will meet next Tuesday to review the proposal.
ES: El comité se reúne el martes para examinar la propuesta.

EN: Please submit the signed contract by the end of the week.
ES: Por favor, envuélvete el contrato firmado antes del fin de semana.

EN: This device requires regular maintenance to function properly.
ES: Este dispositivo necesita mantenimiento regular para funcionar correctamente.
