# Fine-tuning mT5-base on Cebuano GSM8K with LoRA

This notebook fine-tunes the `google/mt5-base` model on a Cebuano translation of the GSM8K dataset (`gsm8k_cebuano.jsonl`).
It uses **LoRA (Low-Rank Adaptation)** so that the model can be trained efficiently on a consumer-grade GPU, such as the ones available in Google Colab.

## 1. Install Dependencies

In [None]:
!pip install transformers datasets peft accelerate bitsandbytes sentencepiece

## 2. Imports and Setup

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
import os

# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")

## 3. Load Dataset

In [None]:
DATA_FILE = "gsm8k_cebuano.jsonl"

# Fail fast when the data file is missing. Merely printing a warning would
# leave `dataset` undefined and surface later as an unrelated NameError.
if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(
        f"{DATA_FILE} not found. Please upload your translated file."
    )

dataset = load_dataset("json", data_files=DATA_FILE, split="train")
print(f"Loaded {len(dataset)} examples.")

# Optional: keep only examples whose similarity score is at least 0.75,
# when the file provides a `similarity_score` column.
if "similarity_score" in dataset.column_names:
    original_len = len(dataset)
    dataset = dataset.filter(lambda x: x["similarity_score"] >= 0.75)
    print(
        f"Filtered {original_len} -> {len(dataset)} examples (score >= 0.75)."
    )

# Split into train/val (90% / 10%). The fixed seed makes the split
# reproducible, so evaluation numbers are comparable across runs.
dataset = dataset.train_test_split(test_size=0.1, seed=42)
print(dataset)

## 4. Preprocessing

In [None]:
MODEL_ID = "google/mt5-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def preprocess_function(examples):
    """Tokenize a batch of question/answer pairs for seq2seq training.

    Args:
        examples: A batch (dict of lists) containing the columns
            ``cebuano_question`` and ``cebuano_answer``.

    Returns:
        The tokenized inputs, with an added ``labels`` entry holding the
        token ids of the answers.
    """
    # Format: "Question: <question>"
    inputs = [f"Question: {q}" for q in examples["cebuano_question"]]
    targets = examples["cebuano_answer"]

    # Both sides are truncated to 512 tokens; padding is left to the data
    # collator so that each batch is padded only to its own longest sequence.
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    labels = tokenizer(targets, max_length=512, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# Drop the raw text columns so the tokenized splits contain only the fields
# produced by `preprocess_function`, instead of relying on the Trainer to
# prune unused columns later.
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

## 5. Model Setup (LoRA)

In [None]:
from transformers import BitsAndBytesConfig

# Load model in 8-bit to save memory. The quantization settings are passed
# through `quantization_config`; passing `load_in_8bit=True` directly to
# `from_pretrained` is deprecated.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_ID,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Prepare the quantized model for k-bit training
model = prepare_model_for_kbit_training(model)

# LoRA Configuration: rank-8 adapters on the attention query ("q") and
# value ("v") projections; bias terms are not trained.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

## 6. Training

In [None]:
import inspect

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# mT5 is known to be numerically unstable under fp16 mixed precision
# (NaN/inf or zero losses). Use bf16 on GPUs that support it natively
# (compute capability >= 8, i.e. Ampere or newer) and fall back to full
# precision everywhere else.
use_bf16 = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 8
)

# Newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy`. Pick whichever name the installed version accepts, since
# the notebook installs an unpinned transformers.
_training_arg_names = inspect.signature(
    Seq2SeqTrainingArguments.__init__
).parameters
eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in _training_arg_names
    else "evaluation_strategy"
)

# Effective batch size per device: 4 samples x 4 accumulation steps = 16.
training_args = Seq2SeqTrainingArguments(
    output_dir="mt5-gsm8k-cebuano",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-3,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="epoch",
    bf16=use_bf16,
    **{eval_strategy_key: "epoch"},
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)

# Train
trainer.train()

## 7. Save and Inference

In [None]:
model.save_pretrained("mt5-gsm8k-cebuano-final")


def solve_problem(question):
    """Generate an answer for a single question with the fine-tuned model.

    Args:
        question: The question text; it is wrapped in the same
            ``"Question: ..."`` prompt format used during preprocessing.

    Returns:
        The decoded model output with special tokens removed.
    """
    # The model is still in training mode after `trainer.train()`; switch to
    # eval mode so dropout (including LoRA dropout) is disabled at inference.
    model.eval()

    # Truncate to the same 512-token limit used for the training inputs, and
    # move the tensors to wherever `device_map="auto"` placed the model rather
    # than assuming it matches the globally detected device string.
    inputs = tokenizer(
        f"Question: {question}",
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Test on the first held-out example
sample_q = dataset["test"][0]["cebuano_question"]
print(f"Question: {sample_q}")
print(f"Answer: {solve_problem(sample_q)}")