In [1]:
import os
import gc
import numpy as np
import pandas as pd
import torch
import evaluate
from datasets import load_dataset, Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq

os.environ["HF_HOME"] = "lfs/hyperturing1/0/aarushs/hf_cache"
os.environ["HF_MODELS_CACHE"] = "lfs/hyperturing1/0/aarushs/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "lfs/hyperturing1/0/aarushs/hf_cache"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

dataset = load_dataset("coastalcph/tydi_xor_rc")

def filter_te_with_answers(example):
    has_te_lang = example["lang"] == "te"
    has_answer = example["answer"] is not None and len(example["answer"]) > 0
    return has_te_lang and has_answer

train_dataset = dataset["train"].filter(filter_te_with_answers)
val_dataset = dataset["validation"].filter(filter_te_with_answers)

print(f"Telugu training examples: {len(train_dataset)}")
print(f"Telugu validation examples: {len(val_dataset)}")

model_checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

max_length = 576

# ============================================================================
# IMPROVED PROMPT FORMATTING FUNCTIONS
# ============================================================================

def format_prompt_qc(question, context):
    """
    Improved prompt with clear structure and task instruction.
    Uses delimiters and explicit task framing.
    """
    prompt = (
        f"Answer the following question based on the given context. "
        f"Provide a concise and accurate answer in Telugu.\n\n"
        f"Question: {question}\n\n"
        f"Context: {context}\n\n"
        f"Answer:"
    )
    return prompt

def format_prompt_q(question, lang="Telugu"):
    """
    Improved question-only prompt with task instruction.
    Explicitly states the expected output language.
    """
    if lang == "Telugu":
        prompt = (
            f"Answer the following question accurately and concisely in Telugu. "
            f"Use your knowledge to provide the best answer.\n\n"
            f"Question: {question}\n\n"
            f"Answer:"
        )
    else:  # English
        prompt = (
            f"Answer the following question accurately and concisely in English. "
            f"Use your knowledge to provide the best answer.\n\n"
            f"Question: {question}\n\n"
            f"Answer:"
        )
    return prompt

def format_prompt_qc_structured(question, context):
    """
    Alternative: More structured format with XML-like tags.
    Can help models better parse input components.
    """
    prompt = (
        f"<task>Answer the question in Telugu based on the context.</task>\n"
        f"<question>{question}</question>\n"
        f"<context>{context}</context>\n"
        f"<answer>"
    )
    return prompt

def format_prompt_q_structured(question, lang="Telugu"):
    """
    Alternative: Structured format for question-only prompts.
    """
    prompt = (
        f"<task>Answer the question in {lang}.</task>\n"
        f"<question>{question}</question>\n"
        f"<answer>"
    )
    return prompt

def format_prompt_qc_instructional(question, context):
    """
    Alternative: Instruction-following format.
    Explicitly guides the model's behavior.
    """
    prompt = (
        f"### Instruction:\n"
        f"Read the context below and answer the question in Telugu. "
        f"If the answer is not in the context, provide your best answer based on knowledge.\n\n"
        f"### Context:\n{context}\n\n"
        f"### Question:\n{question}\n\n"
        f"### Response:\n"
    )
    return prompt

# ============================================================================
# PREPROCESSING FUNCTIONS (with improved prompts)
# ============================================================================

def preprocess_function(examples):
    inputs = []
    targets = []
    for q, c, a_inlang, a_en in zip(examples["question"], examples["context"], examples["answer_inlang"], examples["answer"]):
        input_text = format_prompt_qc(q, c)
        target_text = a_inlang if (a_inlang is not None and len(a_inlang) > 0) else (a_en if a_en is not None else "")
        inputs.append(input_text)
        targets.append(target_text)
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True, padding=False)
    labels = tokenizer(targets, max_length=max_length, truncation=True, padding=False)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def preprocess_function_q_only(examples):
    inputs = []
    targets = []
    for q, a_inlang, a_en in zip(examples["question"], examples["answer_inlang"], examples["answer"]):
        input_text = format_prompt_q(q, "Telugu")
        target_text = a_inlang if (a_inlang is not None and len(a_inlang) > 0) else (a_en if a_en is not None else "")
        inputs.append(input_text)
        targets.append(target_text)
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True, padding=False)
    labels = tokenizer(targets, max_length=max_length, truncation=True, padding=False)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def preprocess_function_q_only_en_ans(examples):
    inputs = []
    targets = []
    for q, a in zip(examples["question"], examples["answer"]):
        input_text = format_prompt_q(q, "English")
        target_text = a if a is not None else ""
        inputs.append(input_text)
        targets.append(target_text)
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True, padding=False)
    labels = tokenizer(targets, max_length=max_length, truncation=True, padding=False)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

print("\nTokenizing datasets...")
tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)
tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=val_dataset.column_names)
tokenized_train_dataset_q_only = train_dataset.map(preprocess_function_q_only, batched=True, remove_columns=train_dataset.column_names)
tokenized_val_dataset_q_only = val_dataset.map(preprocess_function_q_only, batched=True, remove_columns=val_dataset.column_names)
tokenized_train_dataset_q_only_en_ans = train_dataset.map(preprocess_function_q_only_en_ans, batched=True, remove_columns=train_dataset.column_names)
tokenized_val_dataset_q_only_en_ans = val_dataset.map(preprocess_function_q_only_en_ans, batched=True, remove_columns=val_dataset.column_names)

print(f"Sample tokenized lengths:")
print(f"  Input IDs: {len(tokenized_train_dataset[0]['input_ids'])}")
print(f"  Labels: {len(tokenized_train_dataset[0]['labels'])}")
print(f"  Non -100 labels: {sum(1 for x in tokenized_train_dataset[0]['labels'] if x != -100)}")

rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

def train_and_eval(tokenized_train, tokenized_val, val_raw_dataset, output_dir, prompt_format_fn, answer_key="answer_inlang"):
    torch.cuda.empty_cache()
    gc.collect()
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_checkpoint, 
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    if hasattr(model.config, "dropout"):
        model.config.dropout = 0.28628261092381746
    if hasattr(model.config, "attention_dropout"):
        model.config.attention_dropout = 0.17909783477703217

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    args = TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=0.0001400032301305189,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=6,
        weight_decay=0.2880382450100906,
        save_total_limit=1,
        num_train_epochs=4,
        bf16=True,
        report_to=[],
        optim="adamw_torch_fused",
        dataloader_pin_memory=False,
        max_grad_norm=1.8717849416806842,
        warmup_ratio=0.2632858110781834,
        warmup_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        lr_scheduler_type="cosine_with_restarts",
        logging_first_step=True,
        label_smoothing_factor=0.11694641172038693,
        adam_beta1=0.8806332791282017,
        adam_beta2=0.9662516563582316,
        adam_epsilon=1.972007726100591e-08,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    train_result = trainer.train()
    model.eval()

    gen_kwargs = dict(
        max_new_tokens=100,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=False,
        num_beams=8,
        repetition_penalty=1.6283771357674486,
        length_penalty=1.5955228963302186,
    )

    answerable_predictions = []
    answerable_references = []
    unanswerable_predictions = []
    unanswerable_references = []
    with torch.no_grad():
        for i in range(min(200, len(val_raw_dataset))):
            example = val_raw_dataset[i]
            prompt = prompt_format_fn(example)
            input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length).input_ids.to(model.device)
            outputs = model.generate(input_ids, **gen_kwargs)
            pred_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            if example["answerable"]:
                if example["answer_inlang"] and len(example["answer_inlang"]) > 0:
                    ref_text = example["answer_inlang"]
                else:
                    ref_text = example["answer"]
                answerable_predictions.append(pred_text.strip())
                answerable_references.append(ref_text.strip())
            else:
                unanswerable_predictions.append(pred_text.strip())
                unanswerable_references.append("")

    answerable_rouge = rouge.compute(predictions=answerable_predictions, references=answerable_references, use_stemmer=True) if answerable_predictions else {"rouge1": 0, "rouge2": 0, "rougeL": 0}
    try:
        answerable_bleu = bleu.compute(predictions=answerable_predictions, references=[[r] for r in answerable_references]) if answerable_predictions else {"bleu": 0}
    except ZeroDivisionError:
        answerable_bleu = {"bleu": 0}
    unanswerable_rouge = rouge.compute(predictions=unanswerable_predictions, references=unanswerable_references, use_stemmer=True) if unanswerable_predictions else {"rouge1": 0, "rouge2": 0, "rougeL": 0}
    try:
        unanswerable_bleu = bleu.compute(predictions=unanswerable_predictions, references=[[r] for r in unanswerable_references]) if unanswerable_predictions else {"bleu": 0}
    except ZeroDivisionError:
        unanswerable_bleu = {"bleu": 0}
    all_predictions = answerable_predictions + unanswerable_predictions
    all_references = answerable_references + unanswerable_references
    overall_rouge = rouge.compute(predictions=all_predictions, references=all_references, use_stemmer=True) if all_predictions else {"rouge1": 0, "rouge2": 0, "rougeL": 0}
    try:
        overall_bleu = bleu.compute(predictions=all_predictions, references=[[r] for r in all_references]) if all_predictions else {"bleu": 0}
    except ZeroDivisionError:
        overall_bleu = {"bleu": 0}
    results = {
        "answerable_rouge1": round(answerable_rouge.get("rouge1", 0), 4),
        "answerable_rouge2": round(answerable_rouge.get("rouge2", 0), 4),
        "answerable_rougeL": round(answerable_rouge.get("rougeL", 0), 4),
        "answerable_bleu": round(answerable_bleu.get("bleu", 0), 4),
        "unanswerable_rouge1": round(unanswerable_rouge.get("rouge1", 0), 4),
        "unanswerable_rouge2": round(unanswerable_rouge.get("rouge2", 0), 4),
        "unanswerable_rougeL": round(unanswerable_rouge.get("rougeL", 0), 4),
        "unanswerable_bleu": round(unanswerable_bleu.get("bleu", 0), 4),
        "overall_rouge1": round(overall_rouge.get("rouge1", 0), 4),
        "overall_rouge2": round(overall_rouge.get("rouge2", 0), 4),
        "overall_rougeL": round(overall_rouge.get("rougeL", 0), 4),
        "overall_bleu": round(overall_bleu.get("bleu", 0), 4),
        "train_loss": round(train_result.training_loss, 4),
        "answerable_count": len(answerable_predictions),
        "unanswerable_count": len(unanswerable_predictions)
    }
    del model
    del trainer
    torch.cuda.empty_cache()
    gc.collect()
    return results

def prompt_fn_qc(example):
    return format_prompt_qc(example["question"], example["context"])

def prompt_fn_q_te(example):
    return format_prompt_q(example["question"], "Telugu")

def prompt_fn_q_en(example):
    return format_prompt_q(example["question"], "English")

print("\n" + "="*80)
print("Training Model 1: Telugu Question + English Context → Telugu Answer")
print("="*80)
res1 = train_and_eval(tokenized_train_dataset, tokenized_val_dataset, val_dataset, "mbart-te-qc", prompt_fn_qc, "answer_inlang")

print("\n" + "="*80)
print("Training Model 2: Telugu Question → Telugu Answer")
print("="*80)
res2 = train_and_eval(tokenized_train_dataset_q_only, tokenized_val_dataset_q_only, val_dataset, "mbart-te-q", prompt_fn_q_te, "answer_inlang")

print("\n" + "="*80)
print("Training Model 3: Telugu Question → English Answer")
print("="*80)
res3 = train_and_eval(tokenized_train_dataset_q_only_en_ans, tokenized_val_dataset_q_only_en_ans, val_dataset, "mbart-te-q-en", prompt_fn_q_en, "answer")

print("\n" + "="*100)
print("DETAILED RESULTS COMPARISON")
print("="*100)

results_data = []
for i, (name, results) in enumerate([("Q+C → A_te", res1), ("Q → A_te", res2), ("Q → A_en", res3)], 1):
    print(f"\n{'='*20} MODEL {i}: {name} {'='*20}")
    print(f"Training Loss: {results['train_loss']}")
    print(f"Answerable Examples: {results['answerable_count']}")
    print(f"Unanswerable Examples: {results['unanswerable_count']}")
    print("\nAnswerable Performance:")
    print(f"  ROUGE-1: {results['answerable_rouge1']}")
    print(f"  ROUGE-2: {results['answerable_rouge2']}")
    print(f"  ROUGE-L: {results['answerable_rougeL']}")
    print(f"  BLEU: {results['answerable_bleu']}")
    print("\nUnanswerable Performance:")
    print(f"  ROUGE-1: {results['unanswerable_rouge1']}")
    print(f"  ROUGE-2: {results['unanswerable_rouge2']}")
    print(f"  ROUGE-L: {results['unanswerable_rougeL']}")
    print(f"  BLEU: {results['unanswerable_bleu']}")
    print("\nOverall Performance:")
    print(f"  ROUGE-1: {results['overall_rouge1']}")
    print(f"  ROUGE-2: {results['overall_rouge2']}")
    print(f"  ROUGE-L: {results['overall_rougeL']}")
    print(f"  BLEU: {results['overall_bleu']}")
    results_data.append({
        'Model': name,
        'Train Loss': results['train_loss'],
        'Answerable Count': results['answerable_count'],
        'Unanswerable Count': results['unanswerable_count'],
        'Answerable ROUGE-1': results['answerable_rouge1'],
        'Unanswerable ROUGE-1': results['unanswerable_rouge1'],
        'Overall ROUGE-1': results['overall_rouge1'],
        'Answerable BLEU': results['answerable_bleu'],
        'Unanswerable BLEU': results['unanswerable_bleu'],
        'Overall BLEU': results['overall_bleu']
    })

summary_df = pd.DataFrame(results_data)
print(f"\n{'='*100}")
print("SUMMARY TABLE")
print("="*100)
print(summary_df.to_string(index=False))



Telugu training examples: 1355
Telugu validation examples: 384

Tokenizing datasets...


Map:   0%|          | 0/384 [00:00<?, ? examples/s]

Sample tokenized lengths:
  Input IDs: 223
  Labels: 3
  Non -100 labels: 3


`torch_dtype` is deprecated! Use `dtype` instead!



Training Model 1: Telugu Question + English Context → Telugu Answer


  trainer = Trainer(
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
/lfs/hyperturing1/0/aarushs/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/lfs/hyperturing1/0/aarushs/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/lfs/hyperturing1/0/aarushs/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/lfs/hyperturing1/0/aarushs/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'
/lfs/hyperturing1/0/aarushs/miniconda3/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replac

Epoch,Training Loss,Validation Loss
1,6.7902,4.339398
2,6.7902,3.372892
3,6.7902,3.087343
4,6.7902,2.918082


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].



Training Model 2: Telugu Question → Telugu Answer


  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,8.639,5.740492
2,8.639,4.923953
3,8.639,4.592388
4,8.639,4.510159


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].



Training Model 3: Telugu Question → English Answer


  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,8.6579,4.89909
2,8.6579,4.41686
3,8.6579,4.291759
4,8.6579,4.227438


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].



DETAILED RESULTS COMPARISON

Training Loss: 3.2922
Answerable Examples: 200
Unanswerable Examples: 0

Answerable Performance:
  ROUGE-1: 0.5092
  ROUGE-2: 0.2658
  ROUGE-L: 0.5106
  BLEU: 0.2541

Unanswerable Performance:
  ROUGE-1: 0
  ROUGE-2: 0
  ROUGE-L: 0
  BLEU: 0

Overall Performance:
  ROUGE-1: 0.5092
  ROUGE-2: 0.2658
  ROUGE-L: 0.5106
  BLEU: 0.2541

Training Loss: 5.2182
Answerable Examples: 200
Unanswerable Examples: 0

Answerable Performance:
  ROUGE-1: 0.0662
  ROUGE-2: 0.0225
  ROUGE-L: 0.0656
  BLEU: 0.0

Unanswerable Performance:
  ROUGE-1: 0
  ROUGE-2: 0
  ROUGE-L: 0
  BLEU: 0

Overall Performance:
  ROUGE-1: 0.0662
  ROUGE-2: 0.0225
  ROUGE-L: 0.0656
  BLEU: 0.0

Training Loss: 5.127
Answerable Examples: 200
Unanswerable Examples: 0

Answerable Performance:
  ROUGE-1: 0.0976
  ROUGE-2: 0.0301
  ROUGE-L: 0.0978
  BLEU: 0.0236

Unanswerable Performance:
  ROUGE-1: 0
  ROUGE-2: 0
  ROUGE-L: 0
  BLEU: 0

Overall Performance:
  ROUGE-1: 0.0976
  ROUGE-2: 0.0301
  ROUGE-L