# Optimized Fine-tuning of Gemma-2B-IT on Colab (A100 GPU)

This notebook provides a complete workflow for fine-tuning the `google/gemma-2b-it` model with QLoRA on a Google Colab A100 GPU. It combines 4-bit quantization, Flash Attention 2, and `bfloat16` compute to reduce memory use and speed up training.

## Step 1: Setup and Environment

We install all necessary libraries, including `flash-attn` for optimization on A100 GPUs.

In [None]:
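# Colab already ships a CUDA build of PyTorch. Reinstalling a cu118 build can clash with the
# system CUDA toolkit that flash-attn compiles against, so we keep the preinstalled PyTorch.
!pip install -q -U accelerate bitsandbytes peft transformers trl datasets huggingface_hub
# flash-attn's install instructions recommend disabling build isolation
!pip install -q flash-attn --no-build-isolation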

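Log in to the Hugging Face Hub to download the model and push your fine-tuned adapter. Gemma is a gated model, so accept its license on the model page first, and use a token with write access so the adapter can be pushed later.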

In [None]:
from huggingface_hub import login

login()
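Flash Attention 2 and native `bfloat16` both need an Ampere-or-newer GPU, so it is worth checking which GPU Colab actually assigned before going further. The helpers below (`supports_fa2_and_bf16`, `pick_attn_implementation`) are hypothetical names for illustration; they assume a compute-capability threshold of 8.0 and fall back to PyTorch's `"sdpa"` attention otherwise.

```python
# Hypothetical helpers: decide whether the current GPU can use Flash Attention 2 and bf16.
# ASSUMPTION: both need compute capability >= 8.0 (Ampere or newer); the A100 is 8.0.


def supports_fa2_and_bf16(capability: tuple[int, int]) -> bool:
    """Return True for compute capability 8.0 or newer (Ampere, Ada, Hopper)."""
    major, _minor = capability
    return major >= 8


def pick_attn_implementation(capability: tuple[int, int]) -> str:
    """Choose an `attn_implementation` value to pass to from_pretrained."""
    # "sdpa" is PyTorch's scaled_dot_product_attention, a safe fallback on older GPUs
    return "flash_attention_2" if supports_fa2_and_bf16(capability) else "sdpa"


if __name__ == "__main__":
    try:
        import torch

        if torch.cuda.is_available():
            cap = torch.cuda.get_device_capability()
            print(torch.cuda.get_device_name(), cap, pick_attn_implementation(cap))
        else:
            print("No CUDA GPU detected; switch the Colab runtime to an A100.")
    except ImportError:
        print("PyTorch is not installed.")
```

On an A100 this should report compute capability `(8, 0)`. Passing the returned string as `attn_implementation` in Step 3 would let the same notebook fall back gracefully on an older GPU such as a T4 (compute capability 7.5).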

## Step 2: Load and Prepare Data

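We load two datasets: English-Rutooro sentence pairs for translation, and a Rutooro multitask instruction dataset. Formatting them for an instruction-tuned model like `gemma-2b-it` requires the tokenizer's chat template, so that step happens in Step 3, after the tokenizer is loaded.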

In [None]:
from datasets import load_dataset, concatenate_datasets

# Load the datasets
translation_dataset = load_dataset("michsethowusu/english-tooro_sentence-pairs_mt560", split='train')
multitask_dataset = load_dataset("cle-13/rutooro_multitask", split='train')
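The formatting function in Step 3 assumes the translation set exposes `en`/`tt` columns and the multitask set `instruction`/`output` (with an optional `input`). A quick pre-flight check catches a schema mismatch before any mapping runs. This is a minimal sketch; `check_columns` and `EXPECTED_COLUMNS` are hypothetical names, and the column names are taken from that function rather than verified against the Hub datasets.

```python
# Hypothetical pre-flight check: confirm each dataset exposes the columns that
# format_for_chat_template (Step 3) reads. ASSUMPTION: the column names below are
# the ones that function expects; they are not verified against the Hub here.

EXPECTED_COLUMNS = {
    "translation": {"en", "tt"},
    "multitask": {"instruction", "output"},  # "input" is optional
}


def missing_columns(column_names, required):
    """Return the sorted list of required columns absent from column_names."""
    return sorted(set(required) - set(column_names))


def check_columns(name, column_names, required):
    """Raise ValueError if any required column is missing, otherwise print a summary."""
    missing = missing_columns(column_names, required)
    if missing:
        raise ValueError(f"{name}: missing columns {missing}; found {sorted(column_names)}")
    print(f"{name}: OK {sorted(column_names)}")


# In the notebook you would call, for example:
# check_columns("translation", translation_dataset.column_names, EXPECTED_COLUMNS["translation"])
# check_columns("multitask", multitask_dataset.column_names, EXPECTED_COLUMNS["multitask"])

# Standalone demo with a made-up column list:
check_columns("demo", ["en", "tt"], EXPECTED_COLUMNS["translation"])
```

If the check fails, adjust the keys tested in `format_for_chat_template` to match the real column names rather than renaming the dataset.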

## Step 3: Load Model and Tokenizer (Optimized for A100)

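We now load the instruction-tuned model, `google/gemma-2b-it`, with Flash Attention 2. It computes attention in tiles without materializing the full attention matrix, which reduces memory use and speeds up training on Ampere GPUs such as the A100.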

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/gemma-2b-it"

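# Configure 4-bit NF4 quantization (the "Q" in QLoRA); matrix multiplications
# run in bfloat16, which the A100 supports natively
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the model with Flash Attention 2. FA2 only runs in fp16/bf16, so the
# layers that are not quantized are loaded in bfloat16 as well.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)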

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = 'right'
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
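To see what 4-bit loading buys, a back-of-the-envelope estimate of weight memory helps. This sketch counts weights only (no activations, LoRA parameters, optimizer state, or quantization constants) and assumes roughly 2.5 billion parameters for `gemma-2b`; `weight_memory_gb` is a hypothetical helper.

```python
# Back-of-the-envelope weight-memory estimate (weights only; activations,
# LoRA parameters, optimizer state and quantization constants are ignored).
# ASSUMPTION: gemma-2b has roughly 2.5 billion parameters.

BYTES_PER_PARAM = {
    "float32": 4.0,
    "bfloat16": 2.0,
    "nf4": 0.5,  # 4 bits per weight
}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate weight memory in GB, using 1 GB = 1e9 bytes."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


N_PARAMS = 2.5e9  # assumed parameter count, see above

for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: ~{weight_memory_gb(N_PARAMS, dtype):.2f} GB")
```

`model.get_memory_footprint()` reports the actual figure for the loaded model. Expect it to exceed the `nf4` line, because bitsandbytes quantizes only linear layers and leaves others, such as the embedding matrix, in `bfloat16`.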

### Data Formatting with Chat Templates

Instead of manually building an instruction string, we use the tokenizer's `apply_chat_template` method, which reproduces the exact turn markers and special tokens that `gemma-2b-it` uses. The function below transforms each example into the list of messages the template expects (`[{"role": "user", ...}, {"role": "assistant", ...}]`). Gemma's template has no `system` role, so all instructions go in the user turn.

In [None]:
def format_for_chat_template(sample):
    # Handle the multitask dataset
    if 'instruction' in sample:
        user_content = sample['instruction']
        if sample.get('input'):
            user_content += "\n" + sample['input']
        assistant_content = sample['output']
    # Handle the translation dataset
    elif 'en' in sample and 'tt' in sample:
        user_content = f"Translate this to Rutooro: {sample['en']}"
        assistant_content = sample['tt']
    else:
        raise ValueError(f"Unexpected sample structure: {sample.keys()}")

    messages = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content}
    ]
    
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}

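# Apply the formatting; drop the original columns so both datasets share a single
# "text" column and can be concatenated
formatted_translation_dataset = translation_dataset.map(
    format_for_chat_template, remove_columns=translation_dataset.column_names
)
formatted_multitask_dataset = multitask_dataset.map(
    format_for_chat_template, remove_columns=multitask_dataset.column_names
)

# Merge and split (seeded so the split is reproducible)
merged_dataset = concatenate_datasets([formatted_translation_dataset, formatted_multitask_dataset])
merged_dataset = merged_dataset.shuffle(seed=42)
dataset_splits = merged_dataset.train_test_split(test_size=0.1, seed=42)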
train_dataset = dataset_splits['train']
test_dataset = dataset_splits['test']

print(f"Training set size: {len(train_dataset)}")
print(f"Testing set size: {len(test_dataset)}")
print("\nExample entry:\n", train_dataset[0]['text'])
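The `text` strings printed above follow Gemma's turn format. To make that format explicit, here is a hand-written approximation for illustration only. It assumes the template shipped with `google/gemma-2b-it` emits `<bos>`, wraps each turn in `<start_of_turn>`/`<end_of_turn>`, and renames `assistant` to `model`; `gemma_chat_format` is a hypothetical helper, and real code should always call `tokenizer.apply_chat_template`.

```python
# Hand-written approximation of the Gemma chat format, for illustration only.
# ASSUMPTION: mirrors the template shipped with google/gemma-2b-it
# (<bos>, <start_of_turn>/<end_of_turn> markers, "assistant" renamed to "model",
# no system role). Use tokenizer.apply_chat_template in real code.


def gemma_chat_format(messages, add_generation_prompt=False, bos="<bos>"):
    parts = [bos]
    for message in messages:
        role = message["role"]
        if role == "system":
            raise ValueError("Gemma's template does not support a system role")
        if role == "assistant":
            role = "model"
        parts.append(f"<start_of_turn>{role}\n{message['content'].strip()}<end_of_turn>\n")
    if add_generation_prompt:
        # Opens the model's turn so generation continues as the assistant
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


example = gemma_chat_format([
    {"role": "user", "content": "Translate this to Rutooro: Good morning."},
    {"role": "assistant", "content": "TRANSLATION"},  # placeholder, not real output
])
print(example)
```

Comparing this with `train_dataset[0]['text']` is a quick sanity check. Because the template already starts with `<bos>`, tokenizing these strings with the tokenizer's default `add_special_tokens=True` can produce a doubled `<bos>`, which is worth keeping in mind whenever they are tokenized again.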

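## Step 4: Configure QLoRA

LoRA adapters of rank `r=8`, scaled by `lora_alpha / r = 2`, are attached to every attention projection (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and MLP projection (`gate_proj`, `up_proj`, `down_proj`). Only these adapters are trained; combined with the 4-bit base model from Step 3, this is QLoRA.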

In [None]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
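Each target module gets its own low-rank pair: for a frozen weight $W \in \mathbb{R}^{d_{out} \times d_{in}}$, LoRA learns $B \in \mathbb{R}^{d_{out} \times r}$ and $A \in \mathbb{R}^{r \times d_{in}}$, adding $r(d_{in} + d_{out})$ parameters per layer. The sketch below totals this for the configuration above. It assumes the `gemma-2b` dimensions from its published config (hidden size 2048, MLP size 16384, 18 layers, 8 query heads, 1 key/value head, head dimension 256); verify them against `model.config`.

```python
# Count the trainable parameters LoRA adds for the target_modules in lora_config.
# For W (d_out x d_in), LoRA learns B (d_out x r) and A (r x d_in): r * (d_in + d_out)
# extra parameters, with the update B @ A scaled by lora_alpha / r.
# ASSUMPTION: gemma-2b dimensions as published in its config; check model.config.

HIDDEN, INTERMEDIATE, LAYERS = 2048, 16384, 18
N_HEADS, N_KV_HEADS, HEAD_DIM = 8, 1, 256

# (d_in, d_out) for each targeted projection in one decoder layer
PROJECTIONS = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(r: int) -> int:
    """Total LoRA parameters across all layers for rank r."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in PROJECTIONS.values())
    return per_layer * LAYERS


r, lora_alpha = 8, 16
print(f"trainable LoRA parameters at r={r}: {lora_params(r):,}")
print(f"scaling factor lora_alpha / r = {lora_alpha / r}")
```

Under these assumptions the adapter has about 9.8 million parameters, a small fraction of the base model. Calling `print_trainable_parameters()` on the PEFT-wrapped model (for example `trainer.model` after Step 6) reports the real count.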

## Step 5: Define Training Arguments (Optimized for A100)

We set `bf16=True` so training runs in bfloat16 on the A100's Tensor Cores. bfloat16 keeps float32's exponent range, so unlike `fp16` it needs no loss scaling to avoid overflow and underflow.

In [None]:
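from trl import SFTConfig

# SFTConfig extends transformers.TrainingArguments with SFT-specific options
# (text field, sequence length, packing) that recent TRL releases no longer
# accept as SFTTrainer keyword arguments
training_args = SFTConfig(
    output_dir="gemma-2b-it-rutooro-A100",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch size: 2 x 4 = 8 sequences per optimizer step
    learning_rate=2e-4,
    logging_steps=25,
    bf16=True,  # A100 optimization
    push_to_hub=True,
    report_to="tensorboard",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    save_strategy="epoch",
    eval_strategy="epoch",  # renamed from `evaluation_strategy`, which recent transformers releases reject
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    dataset_text_field="text",
    max_length=1024,  # named `max_seq_length` in older TRL releases
    packing=True,
)

## Step 6: Fine-tuning with SFTTrainer

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,  # named `tokenizer=` in older TRL releases
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=lora_config,
    args=training_args,
)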

trainer.train()
trainer.save_model(f"{training_args.output_dir}/final_adapter")
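With `packing=True`, what the trainer counts as a training example is a packed 1024-token sequence, not a raw dataset row, so the run length is easiest to reason about from the number of packed sequences. The sketch below estimates optimizer and warmup steps for the settings in Step 5. It assumes a single GPU, and `training_schedule` is a hypothetical helper; the trainer's own rounding of partial accumulation windows can shift the count slightly depending on the library version.

```python
import math

# Rough optimizer-step count for the run configured in Step 5.
# ASSUMPTIONS: single GPU; num_sequences is the number of packed sequences
# (usually far fewer than raw rows); the trainer's rounding may differ slightly.


def training_schedule(num_sequences, per_device_batch=2, grad_accum=4,
                      epochs=3, warmup_ratio=0.1, num_gpus=1):
    effective_batch = per_device_batch * grad_accum * num_gpus
    # Micro-batches per epoch (the last one may be partial)
    batches_per_epoch = math.ceil(num_sequences / (per_device_batch * num_gpus))
    # One optimizer step per grad_accum micro-batches
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    # warmup_ratio is converted to a step count by rounding up
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


# Hypothetical example: 10,000 packed sequences
print(training_schedule(10_000))
```

With `logging_steps=25`, a run of this size would log training loss roughly 150 times; the real number of packed sequences is visible from the length of the trainer's training dataloader.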

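## Step 7: Save Final Model to Google Drive

Colab's local disk is wiped when the runtime is recycled, so we copy the output directory (intermediate checkpoints plus the `final_adapter` LoRA weights) to Google Drive. What gets saved is the LoRA adapter, not a merged standalone model.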

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
source_dir = training_args.output_dir
destination_dir = f"/content/drive/MyDrive/{source_dir}"
os.makedirs(destination_dir, exist_ok=True)
!cp -r {source_dir}/* {destination_dir}/
print(f"Model saved to: {destination_dir}")
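A pure-Python alternative to the `!cp -r` shell command is `shutil.copytree`, which also makes it easy to copy only the final adapter instead of every checkpoint (checkpoints also hold optimizer and scheduler state). This is a sketch; `backup_to_drive` is a hypothetical helper, and `dirs_exist_ok=True` requires Python 3.8 or newer.

```python
import os
import shutil

# Hypothetical helper: pure-Python equivalent of the `!cp -r` cell above.
# dirs_exist_ok=True (Python 3.8+) lets repeated runs overwrite an existing copy.


def backup_to_drive(source_dir: str, destination_dir: str, only_adapter: bool = False) -> str:
    """Copy the training output (or only its final_adapter subfolder); return the target path."""
    if only_adapter:
        source_dir = os.path.join(source_dir, "final_adapter")
        destination_dir = os.path.join(destination_dir, "final_adapter")
    shutil.copytree(source_dir, destination_dir, dirs_exist_ok=True)
    return destination_dir


# Usage in the notebook (after drive.mount):
# backup_to_drive(training_args.output_dir,
#                 f"/content/drive/MyDrive/{training_args.output_dir}",
#                 only_adapter=True)
```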

## Step 8: Inference

To run inference, we format the prompt with the same chat template, passing `add_generation_prompt=True` so the prompt ends by opening the model's turn.

In [None]:
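# Use the trained PEFT model held by the trainer (4-bit base model + LoRA adapter)
model = trainer.model
model.eval()

prompt_text = "Translate this to Rutooro: I am going to the market."

# Format the prompt using the chat template; add_generation_prompt=True appends
# the "<start_of_turn>model" header so the model answers as the assistant
messages = [{"role": "user", "content": prompt_text}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# The template already inserts <bos>, so don't let the tokenizer add a second one
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)

# Decode only the newly generated tokens, not the echoed prompt
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)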
print(response_text)
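The `top_k=50, top_p=0.95` arguments restrict which tokens sampling may choose. The sketch below is a simplified illustration on a dictionary of made-up probabilities: top-k keeps the k most likely tokens, then top-p keeps the smallest prefix of those (after renormalizing) whose cumulative probability reaches p. transformers applies these as logits processors on tensors; `top_k_top_p_filter` is a hypothetical helper that only shows the idea.

```python
# Simplified illustration of top-k followed by top-p (nucleus) filtering.
# transformers implements this on logits tensors; this dict version shows the idea only.


def top_k_top_p_filter(probs: dict, top_k: int = 50, top_p: float = 0.95) -> dict:
    # Keep the top_k most likely tokens and renormalize them
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    total = sum(p for _, p in ranked)
    ranked = [(tok, p / total) for tok, p in ranked]

    # Keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break  # at least one token is always kept

    # Renormalize the survivors; sampling then draws from this distribution
    norm = sum(p for _, p in kept)
    return {tok: p / norm for tok, p in kept}


toy = {"a": 0.5, "b": 0.3, "c": 0.15, "d": 0.05}  # made-up next-token probabilities
print(top_k_top_p_filter(toy, top_k=3, top_p=0.8))
```

Setting `do_sample=False` instead bypasses these filters and picks the single most likely token at each step (greedy decoding), which may be preferable when a repeatable translation is wanted.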

## Step 9: Push to Hub and Conclude

In [None]:
trainer.push_to_hub()
print("Notebook complete!")