# Preprocessing

In [None]:
!pip -q install -U transformers datasets evaluate sacrebleu accelerate sentencepiece peft bitsandbytes

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read by path.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
import torch, time, evaluate, random
import pandas as pd
from datasets import load_dataset, Dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

In [None]:
BASE_DIR = "/content/drive/MyDrive/dataset_splits_opus100_10k"


def load_pair(pair):
    """Load the train/validation/test CSV splits for one language pair.

    Args:
        pair: Name of the sub-directory of ``BASE_DIR`` that holds the
            three split files (``train.csv``, ``val.csv``, ``test.csv``).

    Returns:
        The result of ``load_dataset("csv", ...)`` for those files, with the
        splits keyed ``"train"``, ``"validation"`` and ``"test"``.
    """
    split_files = {
        "train": "train.csv",
        "validation": "val.csv",
        "test": "test.csv",
    }
    pair_dir = f"{BASE_DIR}/{pair}"
    return load_dataset(
        "csv",
        data_files={
            split: f"{pair_dir}/{file_name}"
            for split, file_name in split_files.items()
        },
    )

# Baseline (Before Fine-tuning)

In [None]:
MODEL = "facebook/nllb-200-distilled-600M"

# Short language codes used in this notebook -> language tags handed to the
# tokenizer (as `src_lang`) and looked up as the forced first decoder token.
LANG = {
    "en": "eng_Latn",
    "id": "ind_Latn",
    "vi": "vie_Latn",
    "ko": "kor_Hang"
}

tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Untouched copy of the checkpoint; used for the "before fine-tuning" scores.
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL,
    device_map="auto",
    dtype=torch.float16
)

# Second copy of the same checkpoint; this is the one wrapped with LoRA below.
# It is loaded with the same `dtype` keyword as `base_model` so both copies are
# created through one consistent call (the original mixed `dtype` and
# `torch_dtype`).
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL,
    device_map="auto",
    dtype=torch.float16
)

# Optional memory saver, left disabled; uncomment both lines if training runs
# out of GPU memory.
# model.gradient_checkpointing_enable()
# model.config.use_cache = False

# LoRA configuration: rank-16 adapters on the attention query/value projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)

# Wrap the second copy with the adapters and report how many parameters train.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
@torch.inference_mode()
def translate_batch(
    model,
    tokenizer,
    texts,
    src_lang: str,
    tgt_lang: str,
    batch_size: int = 16, # 8 if OOM
    max_input_len: int = 256,
    max_new_tokens: int = 256,
    num_beams: int = 4
):
    """Translate a sequence of texts in mini-batches with beam search.

    The model is switched to evaluation mode for the duration of the call so
    that generation is not affected by dropout when the model was left in
    training mode (for example right after a training run). The previous
    mode is restored before returning.

    Args:
        model: Model exposing ``parameters()`` and ``generate()``.
        tokenizer: Tokenizer matching ``model``; its ``src_lang`` attribute is
            overwritten with ``src_lang`` as a side effect.
        texts: Iterable of source strings.
        src_lang: Language tag assigned to ``tokenizer.src_lang``.
        tgt_lang: Language tag whose token id is forced as the first
            generated token.
        batch_size: Number of texts tokenized and generated per step.
        max_input_len: Inputs are truncated to this many tokens.
        max_new_tokens: Upper bound on generated tokens per text.
        num_beams: Beam width passed to ``generate``.

    Returns:
        List of decoded translations, in the same order as ``texts``.
    """
    # Materialise once so that slicing below always yields a plain list,
    # whatever sequence type the caller passed in.
    texts = list(texts)

    device = next(model.parameters()).device

    tokenizer.src_lang = src_lang
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # Generation does not change the module's mode on its own; do it here so
    # dropout layers are inactive while translating.
    was_training = getattr(model, "training", False)
    model.eval()

    outputs = []
    try:
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]

            # Tokenize
            inputs = tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_input_len
            ).to(device)

            # Generate translations
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=forced_bos_token_id,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams
            )

            # Decode
            decoded = tokenizer.batch_decode(
                generated_tokens,
                skip_special_tokens=True
            )
            outputs.extend(decoded)
    finally:
        # Leave the model in the mode the caller had it in.
        if was_training:
            model.train()

    return outputs

In [None]:
def evaluate_test(
    model,
    tokenizer,
    ds,
    src_lang: str,
    tgt_lang: str,
    batch_size: int = 16, # 8 if OOM
    max_input_len: int = 256,
    max_new_tokens: int = 256,
    num_beams: int = 4,
    reverse: bool = False,
    show_examples: bool = False,
    num_examples: int = 3,
    seed: int = 42
):
    """Translate the test split of ``ds`` and score it with BLEU and chrF.

    Args:
        model: Model passed through to ``translate_batch``.
        tokenizer: Tokenizer passed through to ``translate_batch``.
        ds: Mapping with a ``"test"`` split that has ``"source"`` and
            ``"target"`` columns.
        src_lang: Language tag of the texts being translated.
        tgt_lang: Language tag to translate into.
        batch_size: Generation batch size.
        max_input_len: Truncation length for the inputs.
        max_new_tokens: Upper bound on generated tokens per text.
        num_beams: Beam width.
        reverse: When true, the ``"target"`` column is translated and the
            ``"source"`` column is used as the reference.
        show_examples: When true, print a few source/prediction/reference
            triples.
        num_examples: How many triples to print (capped at the test size).
        seed: Seed for choosing which examples to print.

    Returns:
        Dict with ``"BLEU"`` and ``"chrF"`` scores and ``"Speed"`` in
        translated sentences per second (``0.0`` if no time elapsed).
    """
    bleu = evaluate.load("sacrebleu")
    chrf = evaluate.load("chrf")

    # Pick which column is translated and which one is the reference, and
    # turn both into plain lists so they can be sliced and indexed uniformly.
    src_column, ref_column = ("target", "source") if reverse else ("source", "target")
    test_split = ds["test"]
    sources = list(test_split[src_column])
    refs = list(test_split[ref_column])

    # Translate; perf_counter is a monotonic clock meant for measuring intervals.
    t0 = time.perf_counter()
    preds = translate_batch(
        model,
        tokenizer,
        sources,
        src_lang,
        tgt_lang,
        batch_size=batch_size,
        max_input_len=max_input_len,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams
    )
    infer_time = time.perf_counter() - t0
    speed = len(sources) / infer_time if infer_time > 0 else 0.0

    # Compute metrics (BLEU is given one single-element reference list per prediction).
    bleu_score = bleu.compute(
        predictions=preds,
        references=[[r] for r in refs]
    )["score"]

    chrf_score = chrf.compute(
        predictions=preds,
        references=refs
    )["score"]

    if show_examples:
        print("\nExamples:\n")

        # A private generator picks the same indices for a given seed without
        # reseeding the process-wide `random` state.
        rng = random.Random(seed)
        indices = rng.sample(range(len(sources)), min(num_examples, len(sources)))

        for idx in indices:
            print("SOURCE     :", sources[idx])
            print("PREDICTION :", preds[idx])
            print("REFERENCE  :", refs[idx])

    return {
        "BLEU": float(bleu_score),
        "chrF": float(chrf_score),
        "Speed": float(speed)
    }

In [None]:
# (dataset folder name, source code, target code) for every English-centric pair.
pairs = [(f"en_{code}", "en", code) for code in ("ko", "id", "vi")]

In [None]:
# Scores of the untouched checkpoint, keyed by "<src>-><tgt>".
baseline_results = {}

print("\nEvaluation Before Fine-tuning")
for pair_name, src, tgt in pairs:
    ds = load_pair(pair_name)

    # Forward direction first, then the reverse one (columns swapped).
    for from_code, to_code, swapped in ((src, tgt, False), (tgt, src, True)):
        scores = evaluate_test(
            base_model,
            tokenizer,
            ds,
            src_lang=LANG[from_code],
            tgt_lang=LANG[to_code],
            reverse=swapped,
        )
        baseline_results[f"{from_code}->{to_code}"] = scores

        print(f"\nEvaluating {from_code} -> {to_code} :")
        print(f"BLEU            : {scores['BLEU']:.2f}")
        print(f"chrF            : {scores['chrF']:.2f}")
        print(f"Inference Speed : {scores['Speed']:.2f} sentences/s\n")
        print("-" * 40)

# Fine-tuning

In [None]:
def preprocess(batch, src_lang, tgt_lang):
    """Tokenize a batch of source/target sentence pairs for seq2seq training.

    Args:
        batch: Mapping with ``"source"`` and ``"target"`` lists of strings.
        src_lang: Short language code (a key of ``LANG``) of the source side.
        tgt_lang: Short language code (a key of ``LANG``) of the target side.

    Returns:
        The tokenized source inputs with the tokenized target ids stored
        under ``"labels"``. Sequences are truncated to 256 tokens and left
        unpadded (padding is deferred to the data collator).
    """
    # Both sides must be configured on the shared tokenizer: `src_lang` drives
    # the encoding of the inputs and `tgt_lang` the encoding of `text_target`.
    # Previously `tgt_lang` was accepted but never applied.
    tokenizer.src_lang = LANG[src_lang]
    tokenizer.tgt_lang = LANG[tgt_lang]

    model_inputs = tokenizer(
        batch["source"],
        truncation=True,
        padding=False,
        max_length=256
    )

    labels = tokenizer(
        text_target=batch["target"],
        truncation=True,
        max_length=256
    )

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

In [None]:
def finetune_model(pair_name, src, tgt):
    """Fine-tune the module-level ``model`` on one language pair.

    The train split of the pair is tokenized with ``preprocess`` and used for
    one epoch of training. No checkpoints are saved and no evaluation is run
    during training. Because the trainer receives the module-level ``model``,
    repeated calls keep working on that same object.

    Args:
        pair_name: Dataset folder name passed to ``load_pair``.
        src: Short source language code forwarded to ``preprocess``.
        tgt: Short target language code forwarded to ``preprocess``.

    Returns:
        Tuple ``(trained_model, info)`` where ``info`` holds ``"train_loss"``,
        ``"training_time"`` (seconds) and ``"loss_table"`` (a DataFrame with
        ``Step`` and ``Training Loss`` columns built from the trainer log).
    """
    raw_splits = load_pair(pair_name)

    def tokenize(examples):
        return preprocess(examples, src, tgt)

    # Replace the raw text columns with token ids in every split.
    encoded = raw_splits.map(
        tokenize,
        batched=True,
        remove_columns=raw_splits["train"].column_names,
    )

    collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    # Effective batch size is 2 * 8 = 16 sequences per optimizer step.
    arg_values = {
        "output_dir": f"./ft_{pair_name}",
        "per_device_train_batch_size": 2,
        "gradient_accumulation_steps": 8,
        "learning_rate": 2e-5,
        "num_train_epochs": 1,
        "fp16": True,
        "logging_steps": 100,
        "save_strategy": "no",
        "eval_strategy": "no",
        "report_to": "none",
    }
    trainer = Seq2SeqTrainer(
        model=model,
        args=Seq2SeqTrainingArguments(**arg_values),
        train_dataset=encoded["train"],
        data_collator=collator,
    )

    started_at = time.time()
    result = trainer.train()
    elapsed = time.time() - started_at

    # Keep only the log entries that carry a training loss.
    loss_rows = []
    for entry in trainer.state.log_history:
        if "loss" in entry and "step" in entry:
            loss_rows.append(
                {"Step": entry["step"], "Training Loss": entry["loss"]}
            )

    info = {
        "train_loss": result.training_loss,
        "training_time": elapsed,
        "loss_table": pd.DataFrame(loss_rows),
    }
    return trainer.model, info

In [None]:
def finetune_and_evaluate(
    pair_name: str,
    src: str,
    tgt: str,
    model,
    tokenizer
):
    """Score ``model`` on the test split of one pair, in both directions.

    Despite the name, no training happens here; the function only evaluates.

    Args:
        pair_name: Dataset folder name passed to ``load_pair``.
        src: Short source language code (a key of ``LANG``).
        tgt: Short target language code (a key of ``LANG``).
        model: Model to evaluate.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        Dict mapping ``"<src>-><tgt>"`` and ``"<tgt>-><src>"`` to the metric
        dicts returned by ``evaluate_test``.
    """
    ds = load_pair(pair_name)

    results = {}
    # Forward direction first, then the reverse one (columns swapped).
    for from_code, to_code, swapped in ((src, tgt, False), (tgt, src, True)):
        results[f"{from_code}->{to_code}"] = evaluate_test(
            model,
            tokenizer,
            ds,
            src_lang=LANG[from_code],
            tgt_lang=LANG[to_code],
            reverse=swapped,
        )
    return results

In [None]:
# Scores after fine-tuning, keyed by "<src>-><tgt>".
finetune_results = {}

print("\nEvaluation After Fine-tuning")
for pair_name, src, tgt in pairs:
    print(f"\nFine-tuning {src} -> {tgt} :")

    ft_model, train_info = finetune_model(pair_name, src, tgt)

    # Evaluate both forward and reverse directions
    eval_results = finetune_and_evaluate(
        pair_name=pair_name,
        src=src,
        tgt=tgt,
        model=ft_model,
        tokenizer=tokenizer
    )

    finetune_results.update(eval_results)

    # Print both directions, labelling each block so the two sets of numbers
    # can be told apart.
    for direction, metrics in eval_results.items():
        print(f"\nEvaluating {direction} :")
        print(f"BLEU            : {metrics['BLEU']:.2f}")
        print(f"chrF            : {metrics['chrF']:.2f}")
        print(f"Inference Speed : {metrics['Speed']:.2f} sentences/s\n")
        print("-"*40)

    if train_info is not None:
        if "train_loss" in train_info:
            print(f"Training Loss  : {train_info['train_loss']:.4f}")
        if "training_time" in train_info:
            print(f"Training Time  : {train_info['training_time']:.2f} sec")
        if "loss_table" in train_info:
            # A bare expression inside a loop displays nothing, so print the
            # table explicitly.
            print(train_info["loss_table"].to_string(index=False))

    # Release cached GPU memory before the next pair.
    torch.cuda.empty_cache()

print("\n6-direction evaluation completed after fine-tuning.")

# Comparison

In [None]:
def compare(baseline_results, finetune_results, pairs):
    """Print baseline and fine-tuned metrics side by side for every direction.

    Args:
        baseline_results: Dict mapping ``"<src>-><tgt>"`` to a metric dict
            with ``"BLEU"``, ``"chrF"`` and ``"Speed"`` entries.
        finetune_results: Dict of the same shape for the fine-tuned model.
        pairs: Iterable of ``(name, src, tgt)`` tuples; both directions of
            every pair are reported, forward first.

    Raises:
        KeyError: If a direction is missing from either result dict.
    """
    directions = [
        label
        for _, src, tgt in pairs
        for label in (f"{src}->{tgt}", f"{tgt}->{src}")
    ]
    separator = "-" * 40

    for direction in directions:
        before = baseline_results[direction]
        after = finetune_results[direction]

        # The leading empty entry yields the blank line before each report.
        report = [
            "",
            f"Direction:          {direction}",
            f"BLEU (Baseline):    {before['BLEU']:.2f}",
            f"BLEU (Fine-tuned):  {after['BLEU']:.2f}",
            f"chrF (Baseline):    {before['chrF']:.2f}",
            f"chrF (Fine-tuned):  {after['chrF']:.2f}",
            f"Speed (Baseline):   {before['Speed']:.2f} sentences/s",
            f"Speed (Fine-tuned): {after['Speed']:.2f} sentences/s",
            separator,
        ]
        print("\n".join(report))

# Print the before/after report for every evaluated direction.
compare(
    baseline_results=baseline_results,
    finetune_results=finetune_results,
    pairs=pairs,
)