In [None]:
import gc

import numpy as np
import torch
from datasets import load_dataset
from evaluate import load
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq

full_dataset = load_dataset("coastalcph/tydi_xor_rc")

model_checkpoint = "google/mt5-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Only relevant for checkpoints that ship without a pad token: reuse EOS so
# that batches can still be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

bleu_metric = load("bleu")
rouge_metric = load("rouge")

train_dataset = full_dataset["train"]
val_dataset = full_dataset["validation"]


def _is_telugu_with_answer(ex):
    """Return True for Telugu rows that carry a usable in-language answer.

    A plain ``!= ""`` comparison is not enough here: ``None != ""`` is true,
    so rows whose ``answer_inlang`` is missing would be kept and later turned
    into empty training targets / empty references. Missing, empty and
    whitespace-only answers are all rejected.
    """
    answer = ex["answer_inlang"]
    return ex["lang"] == "te" and bool(answer and str(answer).strip())


te_train = train_dataset.filter(_is_telugu_with_answer)
te_val = val_dataset.filter(_is_telugu_with_answer)

print(f"Telugu train samples with answer_inlang: {len(te_train)}")
print(f"Telugu val samples with answer_inlang: {len(te_val)}")

def _as_text(value):
    """Return ``value`` as a string, mapping ``None`` to the empty string."""
    return "" if value is None else str(value)


def _tokenize_pairs(inputs, targets):
    """Tokenize parallel lists of source and target strings.

    Sources are truncated to 512 tokens and targets to 128 tokens. No padding
    is applied here: the script trains with ``DataCollatorForSeq2Seq``, which
    pads every batch to its own longest sequence and fills label padding with
    -100. Padding each example to the maximum length up front would only make
    every batch as expensive as the longest possible one.

    Args:
        inputs: list of source strings.
        targets: list of target strings, aligned with ``inputs``.

    Returns:
        The tokenizer output for ``inputs`` with an added ``"labels"`` entry
        holding the token ids of ``targets``.
    """
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    labels = tokenizer(targets, max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


def prepare_data_model1(examples):
    """Build features for model 1: question + context -> in-language answer.

    Args:
        examples: a batch (dict of column lists) with ``question``,
            ``context`` and ``answer_inlang`` columns.

    Returns:
        Tokenized inputs with ``labels``, as produced by ``_tokenize_pairs``.
    """
    inputs = [
        f"Question: {_as_text(q)} Context: {_as_text(c)}"
        for q, c in zip(examples["question"], examples["context"])
    ]
    targets = [_as_text(a) for a in examples["answer_inlang"]]
    return _tokenize_pairs(inputs, targets)


def prepare_data_model2(examples):
    """Build features for model 2: question only -> in-language answer.

    Args:
        examples: a batch (dict of column lists) with ``question`` and
            ``answer_inlang`` columns.

    Returns:
        Tokenized inputs with ``labels``, as produced by ``_tokenize_pairs``.
    """
    inputs = [f"Question: {_as_text(q)}" for q in examples["question"]]
    targets = [_as_text(a) for a in examples["answer_inlang"]]
    return _tokenize_pairs(inputs, targets)


def prepare_data_model3(examples):
    """Build features for model 3: question only -> ``answer`` column.

    The prompt is prefixed with ``"Translate to English. "`` and the target is
    taken from the ``answer`` column instead of ``answer_inlang``.

    Args:
        examples: a batch (dict of column lists) with ``question`` and
            ``answer`` columns.

    Returns:
        Tokenized inputs with ``labels``, as produced by ``_tokenize_pairs``.
    """
    inputs = [
        f"Translate to English. Question: {_as_text(q)}"
        for q in examples["question"]
    ]
    targets = [_as_text(a) for a in examples["answer"]]
    return _tokenize_pairs(inputs, targets)

def _field_text(value):
    """Return ``value`` as a string, with ``None`` mapped to ``""``."""
    return "" if value is None else str(value)


def _score_generations(model, build_input, reference_column, rouge_tokenizer=None):
    """Generate an answer for every row of ``te_val`` and score it.

    Args:
        model: model exposing ``generate``, ``device``, ``training``,
            ``eval`` and ``train``.
        build_input: callable ``(example, question_text) -> str`` producing
            the prompt for one validation row.
        reference_column: name of the ``te_val`` column holding the reference
            answer.
        rouge_tokenizer: optional callable ``str -> list[str]`` forwarded to
            the ROUGE metric. When ``None`` the metric's default tokenizer is
            used.

    Returns:
        Dict with ``bleu``, ``rouge1``, ``rouge2`` and ``rougeL`` scores.
    """
    predictions = []
    references = []

    # Generation must not run with dropout active; put the model in eval mode
    # for the duration of the loop and restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # Iterate rows directly so each one is decoded from the dataset once,
        # rather than once per accessed field.
        for i, example in enumerate(te_val):
            q = _field_text(example["question"])
            input_text = build_input(example, q)
            inputs = tokenizer(
                input_text, return_tensors="pt", max_length=512, truncation=True
            ).to(model.device)

            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=128)

            pred = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            ref = _field_text(example[reference_column])

            # Show the first three examples as a quick sanity check.
            if i < 3:
                print(f"\nExample {i+1}:")
                print(f"Question: {q[:100]}...")
                print(f"Prediction: {pred[:100]}...")
                print(f"Reference: {ref[:100]}...")

            predictions.append(pred)
            references.append(ref)
    finally:
        model.train(was_training)

    bleu = bleu_metric.compute(predictions=predictions, references=[[r] for r in references])

    # Only pass ``tokenizer`` when one was requested so the default ROUGE
    # behaviour is left exactly as-is otherwise.
    rouge_kwargs = {} if rouge_tokenizer is None else {"tokenizer": rouge_tokenizer}
    rouge = rouge_metric.compute(predictions=predictions, references=references, **rouge_kwargs)

    return {
        "bleu": bleu["bleu"],
        "rouge1": rouge["rouge1"],
        "rouge2": rouge["rouge2"],
        "rougeL": rouge["rougeL"]
    }


def compute_metrics_model1(model):
    """Score model 1 (question + context -> Telugu answer) on ``te_val``.

    ROUGE is computed on whitespace tokens: the metric's default tokenizer
    keeps only lowercase ASCII letters and digits, which would discard Telugu
    text entirely and make the scores meaningless.

    Args:
        model: the fine-tuned model to evaluate.

    Returns:
        Dict with ``bleu``, ``rouge1``, ``rouge2`` and ``rougeL`` scores.
    """
    def build_input(example, question):
        context = _field_text(example["context"])
        return f"Question: {question} Context: {context}"

    return _score_generations(model, build_input, "answer_inlang", rouge_tokenizer=str.split)


def compute_metrics_model2(model):
    """Score model 2 (question only -> Telugu answer) on ``te_val``.

    Uses whitespace tokens for ROUGE for the same reason as model 1: the
    default tokenizer drops non-ASCII text.

    Args:
        model: the fine-tuned model to evaluate.

    Returns:
        Dict with ``bleu``, ``rouge1``, ``rouge2`` and ``rougeL`` scores.
    """
    def build_input(example, question):
        return f"Question: {question}"

    return _score_generations(model, build_input, "answer_inlang", rouge_tokenizer=str.split)


def compute_metrics_model3(model):
    """Score model 3 (question only -> ``answer`` column) on ``te_val``.

    References come from the ``answer`` column and ROUGE keeps its default
    tokenizer.

    Args:
        model: the fine-tuned model to evaluate.

    Returns:
        Dict with ``bleu``, ``rouge1``, ``rouge2`` and ``rougeL`` scores.
    """
    def build_input(example, question):
        return f"Translate to English. Question: {question}"

    return _score_generations(model, build_input, "answer")

# Tokenize the Telugu train/validation splits once per task formulation.
# The raw columns are dropped so only the tokenizer outputs and labels remain.
tokenized_train_m1, tokenized_val_m1 = (
    split.map(prepare_data_model1, batched=True, remove_columns=split.column_names)
    for split in (te_train, te_val)
)

tokenized_train_m2, tokenized_val_m2 = (
    split.map(prepare_data_model2, batched=True, remove_columns=split.column_names)
    for split in (te_train, te_val)
)

tokenized_train_m3, tokenized_val_m3 = (
    split.map(prepare_data_model3, batched=True, remove_columns=split.column_names)
    for split in (te_train, te_val)
)

# Per-model metrics, keyed "model1" / "model2" / "model3".
results = dict()

print("\n" + "="*60)
print("MODEL 1: Telugu Question + English Context -> Telugu Answer")
print("="*60)

model_m1 = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, device_map="auto")

# LoRA adapters on the modules named "q" and "v"; all other weights stay frozen.
lora_config_m1 = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

model_m1 = get_peft_model(model_m1, lora_config_m1)
model_m1.print_trainable_parameters()

# fp16 mixed precision is known to overflow with T5-family checkpoints
# (NaN / zero loss). Use bf16 where the GPU supports it and fall back to
# full precision otherwise; this also keeps the script runnable without CUDA.
use_bf16_m1 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

args_m1 = TrainingArguments(
    output_dir="mt5-base-te-qa-context",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=50,
    weight_decay=0.01,
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to=[],
    push_to_hub=False,
    # 2 examples per step x 4 accumulation steps per optimizer update.
    gradient_accumulation_steps=4,
    bf16=use_bf16_m1,
)

data_collator_m1 = DataCollatorForSeq2Seq(tokenizer, model=model_m1)

trainer_m1 = Trainer(
    model=model_m1,
    args=args_m1,
    train_dataset=tokenized_train_m1,
    eval_dataset=tokenized_val_m1,
    data_collator=data_collator_m1,
)

trainer_m1.train()
eval_results = trainer_m1.evaluate()

print("Computing BLEU and ROUGE...")
metrics = compute_metrics_model1(model_m1)

results["model1"] = {
    "eval_loss": eval_results["eval_loss"],
    **metrics
}

print(f"Model 1 - Loss: {results['model1']['eval_loss']:.4f}, BLEU: {results['model1']['bleu']:.4f}, ROUGE-L: {results['model1']['rougeL']:.4f}")

# The collator was built with model=model_m1 and therefore keeps the model
# alive; it has to be dropped together with the model and trainer, and a GC
# pass is needed to break reference cycles before the CUDA cache can
# actually be released for the next model.
del model_m1, trainer_m1, data_collator_m1
gc.collect()
torch.cuda.empty_cache()

print("\n" + "="*60)
print("MODEL 2: Telugu Question -> Telugu Answer")
print("="*60)

model_m2 = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, device_map="auto")

# LoRA adapters on the modules named "q" and "v"; all other weights stay frozen.
lora_config_m2 = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

model_m2 = get_peft_model(model_m2, lora_config_m2)
model_m2.print_trainable_parameters()

# fp16 mixed precision is known to overflow with T5-family checkpoints
# (NaN / zero loss). Use bf16 where the GPU supports it and fall back to
# full precision otherwise; this also keeps the script runnable without CUDA.
use_bf16_m2 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

args_m2 = TrainingArguments(
    output_dir="mt5-base-te-qa-no-context",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=50,
    weight_decay=0.01,
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to=[],
    push_to_hub=False,
    # 2 examples per step x 4 accumulation steps per optimizer update.
    gradient_accumulation_steps=4,
    bf16=use_bf16_m2,
)

data_collator_m2 = DataCollatorForSeq2Seq(tokenizer, model=model_m2)

trainer_m2 = Trainer(
    model=model_m2,
    args=args_m2,
    train_dataset=tokenized_train_m2,
    eval_dataset=tokenized_val_m2,
    data_collator=data_collator_m2,
)

trainer_m2.train()
eval_results = trainer_m2.evaluate()

print("Computing BLEU and ROUGE...")
metrics = compute_metrics_model2(model_m2)

results["model2"] = {
    "eval_loss": eval_results["eval_loss"],
    **metrics
}

print(f"Model 2 - Loss: {results['model2']['eval_loss']:.4f}, BLEU: {results['model2']['bleu']:.4f}, ROUGE-L: {results['model2']['rougeL']:.4f}")

# The collator was built with model=model_m2 and therefore keeps the model
# alive; it has to be dropped together with the model and trainer, and a GC
# pass is needed to break reference cycles before the CUDA cache can
# actually be released for the next model.
del model_m2, trainer_m2, data_collator_m2
gc.collect()
torch.cuda.empty_cache()

print("\n" + "="*60)
print("MODEL 3: Telugu Question -> English Answer")
print("="*60)

model_m3 = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, device_map="auto")

# LoRA adapters on the modules named "q" and "v"; all other weights stay frozen.
lora_config_m3 = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

model_m3 = get_peft_model(model_m3, lora_config_m3)
model_m3.print_trainable_parameters()

# fp16 mixed precision is known to overflow with T5-family checkpoints
# (NaN / zero loss). Use bf16 where the GPU supports it and fall back to
# full precision otherwise; this also keeps the script runnable without CUDA.
use_bf16_m3 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

args_m3 = TrainingArguments(
    output_dir="mt5-base-te-qa-en-answer",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=50,
    weight_decay=0.01,
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    report_to=[],
    push_to_hub=False,
    # 2 examples per step x 4 accumulation steps per optimizer update.
    gradient_accumulation_steps=4,
    bf16=use_bf16_m3,
)

data_collator_m3 = DataCollatorForSeq2Seq(tokenizer, model=model_m3)

trainer_m3 = Trainer(
    model=model_m3,
    args=args_m3,
    train_dataset=tokenized_train_m3,
    eval_dataset=tokenized_val_m3,
    data_collator=data_collator_m3,
)

trainer_m3.train()
eval_results = trainer_m3.evaluate()

print("Computing BLEU and ROUGE...")
metrics = compute_metrics_model3(model_m3)

results["model3"] = {
    "eval_loss": eval_results["eval_loss"],
    **metrics
}

print(f"Model 3 - Loss: {results['model3']['eval_loss']:.4f}, BLEU: {results['model3']['bleu']:.4f}, ROUGE-L: {results['model3']['rougeL']:.4f}")

# Side-by-side summary of the three runs: one row per model, one column per
# metric, every cell left-aligned and numbers shown with four decimals.
_rule = "=" * 80
print("\n" + _rule)
print("FINAL RESULTS")
print(_rule)
print(f"{'Model':<30} {'Loss':<10} {'BLEU':<10} {'ROUGE-1':<10} {'ROUGE-2':<10} {'ROUGE-L':<10}")
print("-" * 80)
for _label, _key in (
    ("M1: Q+C->A_te", "model1"),
    ("M2: Q->A_te", "model2"),
    ("M3: Q->A_en", "model3"),
):
    _row = results[_key]
    print(f"{_label:<30} {_row['eval_loss']:<10.4f} {_row['bleu']:<10.4f} {_row['rouge1']:<10.4f} {_row['rouge2']:<10.4f} {_row['rougeL']:<10.4f}")
print(_rule)

# Bare expression so a notebook renders the raw results dict as cell output.
results