In [1]:
# %%  ── Runtime flags ───────────────────────────────────────────────────────────
# These environment variables are set before any heavy library is imported below.
# CUDA_LAUNCH_BLOCKING=1: presumably a debugging aid (synchronous CUDA kernel launches,
# which slows training) — TODO confirm it is still wanted for full runs.
%env CUDA_LAUNCH_BLOCKING=1
# NOTE(review): TQDM_NOTEBOOK looks intended to turn off tqdm's notebook widgets;
# confirm that some library in this stack actually reads it.
%env TQDM_NOTEBOOK=0
# W&B project name; read back through os.environ["WANDB_PROJECT"] in the configuration cell.
%env WANDB_PROJECT=gemma-nlp_lab2_mmmlu_ft_corrected

env: CUDA_LAUNCH_BLOCKING=1
env: TQDM_NOTEBOOK=0
env: WANDB_PROJECT=gemma-nlp_lab2_mmmlu_ft_corrected


In [2]:
# %%  ── One-shot dependency install ───────────────
!pip -q install datasets transformers accelerate bitsandbytes \
               sentencepiece wandb huggingface_hub peft \
               matplotlib seaborn pandas scikit-learn

In [3]:
# %%  ── Imports ────────────────────────────────────────────────────────────────
import os, re, warnings, json, math, random, time
import torch, numpy as np, pandas as pd
from tqdm import tqdm

from datasets            import load_dataset, DatasetDict, concatenate_datasets
from transformers        import (AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
                                 Trainer, BitsAndBytesConfig, default_data_collator)
from peft                import (LoraConfig, get_peft_model,
                                 prepare_model_for_kbit_training)
from sklearn.metrics     import accuracy_score
from huggingface_hub     import login
import wandb

warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# %%  ── Configuration & Weights & Biases ───────────────────────────────────────
# --- Model and data sources ---------------------------------------------------
MODEL_NAME = "google/gemma-2-2b"
DATASET_NAME = "openai/MMMLU"
DATASET_CONFIGS = ["DE_DE", "FR_FR"]

# --- Length limits and throughput knobs ----------------------------------------
# Questions whose rendered prompt exceeds this many tokens are dropped.
MAX_PROMPT_TOKENS_FOR_FILTER = 256
# Fixed length every training example (prompt + answer) is padded to.
MAX_SEQ_LENGTH = 300
# Worker processes for dataset map/filter.
NUM_PROC = 4
# Batch size used by the generation-based accuracy evaluation.
EVAL_BATCH_SIZE = 8

# --- Run bookkeeping -------------------------------------------------------------
OUTPUT_DIR = "./outputs"
RUN_NAME = "gemma-2b-mmmlu-de-fr-qlora-corrected-2epochs-lr1e-4"

# The output directory doubles as the W&B run directory, so create it first.
os.makedirs(OUTPUT_DIR, exist_ok=True)
wandb.init(
    project=os.environ["WANDB_PROJECT"],
    name=RUN_NAME,
    dir=OUTPUT_DIR,
    mode="online",
)

[34m[1mwandb[0m: Currently logged in as: [33mdenis-katkalo[0m ([33mdenis-katkalo-kpi[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# %%  ── HF login (needed for Gemma) ────────────────────────────────────────────
# Read the Hugging Face token from the environment; without one, only warn and
# let the later `from_pretrained` calls fall back to `token=True`.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("⚠️  HF_TOKEN not set – make sure you accepted Gemma licence via web UI.")
else:
    login(token=HF_TOKEN, add_to_git_credential=False)

⚠️  HF_TOKEN not set – make sure you accepted Gemma licence via web UI.


In [6]:
# %%  ── Tokenizer ────────────────
# Load the tokenizer that matches MODEL_NAME; `HF_TOKEN or True` passes the explicit
# token when one was found in the environment and `True` otherwise.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN or True)
# Reuse EOS as the padding token. Padded positions are masked out by the attention
# mask and given label -100 in `_tokenise`, so the pad id itself is not trained on.
# NOTE(review): Gemma tokenizers appear to ship a dedicated <pad> token — confirm
# whether this override is actually needed.
tokenizer.pad_token    = tokenizer.eos_token
# Left padding keeps the end of every prompt adjacent to the generated tokens
# when prompts are batched for `generate`.
tokenizer.padding_side = "left"

## 1. Load, filter & split the dataset

In [7]:
# %%  ── Load DE + FR 'test' splits and merge ───────────────────────────────────
# One 'test' split per language configuration, then a single merged dataset.
raw_ds_parts = [
    load_dataset(path=DATASET_NAME, name=cfg, split="test")
    for cfg in DATASET_CONFIGS
]
raw_ds = concatenate_datasets(raw_ds_parts)
print(f"Loaded {len(raw_ds)} total examples.")

Loaded 28084 total examples.


In [8]:
# %%  ── Drop very long questions ────────────────
# Prompt shared by the length filter, the training tokenisation and the
# generation-based evaluation. Placeholders: {question}, {A}, {B}, {C}, {D}.
# The text ends right after "Correct Answer:" with no trailing space; the training
# target built in `_tokenise` supplies the leading space before the answer letter.
PROMPT_TEMPLATE = """The following is a multiple-choice question. \
Provide the letter of the correct answer.

Question: {question}
Options:
(A) {A}
(B) {B}
(C) {C}
(D) {D}
Correct Answer:"""

def _prompt_len_ok(ex):
    """Keep only items whose prompt (excluding the answer part) fits within the token limit."""
    prompt = PROMPT_TEMPLATE.format(
        question = ex["Question"],   # ← use the **uppercase** field names
        A = ex["A"], B = ex["B"], C = ex["C"], D = ex["D"]
    )
    return len(tokenizer(prompt).input_ids) <= MAX_PROMPT_TOKENS_FOR_FILTER

# Apply the prompt-length predicate in parallel and report how many rows survive.
ds_filtered = raw_ds.filter(function=_prompt_len_ok, num_proc=NUM_PROC)
print(f"After length filter: {ds_filtered.num_rows} examples.")

Filter (num_proc=4): 100%|██████████| 28084/28084 [00:01<00:00, 17819.43 examples/s]

After length filter: 23696 examples.





In [9]:
# %%  ── Stratified 80-10-10 split by subject (use built-in label encoding) ─────
# Stratification needs a ClassLabel column, so encode "Subject" in place first.
ds_encoded = ds_filtered.class_encode_column("Subject")

# First cut: 80 % train versus a 20 % hold-out, stratified by subject.
train_val_test = ds_encoded.train_test_split(
    test_size=0.20,
    seed=42,
    stratify_by_column="Subject",
)
# Second cut: halve the hold-out into validation and test (10 % each overall).
val_test_split = train_val_test["test"].train_test_split(
    test_size=0.50,
    seed=42,
    stratify_by_column="Subject",
)

raw_train, raw_val, raw_test = (
    train_val_test["train"],
    val_test_split["train"],
    val_test_split["test"],
)

print(f"Train {len(raw_train)} | Val {len(raw_val)} | Test {len(raw_test)}")

Flattening the indices: 100%|██████████| 23696/23696 [00:00<00:00, 35759.92 examples/s]
Casting to class labels: 100%|██████████| 23696/23696 [00:00<00:00, 38400.89 examples/s]

Train 18956 | Val 2370 | Test 2370





In [None]:
# %%  ── Build tokenised training / validation sets ────────────────────────────
def _tokenise(example):
    prompt_text = PROMPT_TEMPLATE.format(**{
        "question": example["Question"],
        "A": example["A"], "B": example["B"],
        "C": example["C"], "D": example["D"]
    })
    # Tokenize the prompt part first
    prompt_tokenized = tokenizer(prompt_text, add_special_tokens=False) # No BOS/EOS for prompt part alone
    target_letter = example["Answer"].strip().upper()
    target_text_with_space = " " + target_letter 
    target_tokenized = tokenizer(target_text_with_space, add_special_tokens=False)

    # Combine prompt and target tokens for input_ids
    # Add BOS at the beginning and EOS at the end of the combined sequence
    input_ids = [tokenizer.bos_token_id] + prompt_tokenized['input_ids'] + target_tokenized['input_ids'] + [tokenizer.eos_token_id]
    attention_mask = [1] * len(input_ids)

    # Create labels: mask prompt tokens, keep answer tokens and the final EOS
    labels = [-100] * (1 + len(prompt_tokenized['input_ids'])) + target_tokenized['input_ids'] + [tokenizer.eos_token_id]

    # Pad sequences to MAX_SEQ_LENGTH
    padding_length = MAX_SEQ_LENGTH - len(input_ids)
    if padding_length > 0:
        # Pad on the left for Causal LMs
        input_ids = [tokenizer.pad_token_id] * padding_length + input_ids
        attention_mask = [0] * padding_length + attention_mask
        labels = [-100] * padding_length + labels
    elif padding_length < 0:
        # Truncate from the right 
        input_ids = input_ids[:MAX_SEQ_LENGTH]
        attention_mask = attention_mask[:MAX_SEQ_LENGTH]
        labels = labels[:MAX_SEQ_LENGTH]
        # Ensure the last token is EOS if truncated, and its label is EOS or -100 if it was pad before
        if input_ids[-1] != tokenizer.eos_token_id:
             input_ids[-1] = tokenizer.eos_token_id
             if labels[-1] != -100 : # Only change if it was a valid label before truncation made it not EOS
                labels[-1] = tokenizer.eos_token_id # Or -100 if we don't want to predict EOS after truncation
    
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "label_letter": target_letter, # Keep original letter for accuracy calc
        "subject": example["Subject"]
    }

# Tokenise the three splits in train / val / test order; the raw columns are
# dropped so only the fields returned by `_tokenise` remain.
train_ds, val_ds, test_ds = (
    split.map(_tokenise, remove_columns=split.column_names, num_proc=NUM_PROC)
    for split in (raw_train, raw_val, raw_test)
)

Map (num_proc=4): 100%|██████████| 18956/18956 [00:03<00:00, 5009.14 examples/s] 
Map (num_proc=4): 100%|██████████| 2370/2370 [00:00<00:00, 7772.11 examples/s]
Map (num_proc=4): 100%|██████████| 2370/2370 [00:00<00:00, 7976.85 examples/s]


In [11]:
# %%  ── Load Gemma-2-2 B with 4-bit QLoRA (eager attn, cache off) ─────────────
# LoRA adapter settings: adapters are attached to the attention projections
# (q/k/v/o) and the MLP projections (gate/up/down).
lora_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
)

# 4-bit weight loading with bfloat16 as the compute dtype.
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN or True,
    quantization_config=bnb_cfg,
    device_map="auto",
    attn_implementation="eager",   # recommended for Gemma-2
)

# The KV-cache is switched off for training because gradient checkpointing is enabled.
base_model.config.use_cache = False
base_model.gradient_checkpointing_enable()

# Prepare the quantised model for k-bit training, then wrap it with the LoRA adapters.
base_model = prepare_model_for_kbit_training(base_model)
model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()

Loading checkpoint shards: 100%|██████████| 3/3 [00:22<00:00,  7.46s/it]


trainable params: 41,533,440 || all params: 2,655,875,328 || trainable%: 1.5638


## 2. Finetuning

In [None]:
# %%  ── Trainer setup ──

training_args = TrainingArguments(
    output_dir           = OUTPUT_DIR,
    run_name             = RUN_NAME,
    # bf16 needs hardware support, not merely a CUDA device; on GPUs without it
    # this falls back to full precision instead of failing at construction time.
    bf16                 = torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    per_device_train_batch_size = 4,
    per_device_eval_batch_size  = 4,
    gradient_accumulation_steps = 4,   # effective batch size = 4 * 4 = 16
    num_train_epochs     = 2,
    learning_rate        = 1e-4,
    lr_scheduler_type    = "cosine",
    warmup_ratio         = 0.05,
    logging_steps        = 25,
    eval_strategy        = "epoch",
    save_strategy        = "epoch",    # must match eval_strategy for load_best_model_at_end
    # With load_best_model_at_end the best checkpoint is always retained, so up to
    # two checkpoints (best + most recent) can remain on disk.
    save_total_limit     = 1,
    load_best_model_at_end = True,     # best = lowest eval loss (default metric)
    # Trainer cannot infer the label columns through the PEFT wrapper, so name
    # them explicitly rather than relying on the empty-list fallback.
    label_names          = ["labels"],
    report_to            = ["wandb"],
    seed                 = 42,
)

trainer = Trainer(
    model           = model,
    args            = training_args,
    train_dataset   = train_ds,
    eval_dataset    = val_ds,
    # Examples are already padded to a fixed length with labels attached, so the
    # default collator only has to stack them into tensors.
    data_collator   = default_data_collator,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


## 3. Baseline accuracy (model **before** finetune)

In [13]:
# %%  ── Helper: extract first answer letter ─────
import re
LETTER_RE = re.compile(r"\b([A-Da-d])\b")

def _first_letter(text: str) -> str:
    """Return the first occurrence of A-D (case-insensitive) or '' if absent."""
    match = LETTER_RE.search(text)
    return match.group(1).upper() if match else ""

In [None]:
# %%  ── Helper: accuracy evaluator ───────────────────
from tqdm.auto import tqdm

def accuracy_on_raw(raw_dataset_eval, desc="Baseline eval", batch_size=None):
    """Measure multiple-choice accuracy of the global ``model`` by greedy generation.

    Rows are rendered with PROMPT_TEMPLATE, batched, and a few new tokens are
    generated per prompt; the first stand-alone A-D letter in the continuation is
    compared with the row's ``Answer`` column.

    Args:
        raw_dataset_eval: un-tokenised dataset with the columns ``Question``,
            ``A``, ``B``, ``C``, ``D`` and ``Answer``.
        desc: label shown on the progress bar.
        batch_size: prompts per generation call; defaults to EVAL_BATCH_SIZE.

    Returns:
        Fraction of rows answered correctly, or 0.0 for an empty dataset.

    Raises:
        ValueError: if ``batch_size`` is not a positive integer.

    Side effects:
        Puts ``model`` into eval mode and leaves it there.
    """
    if batch_size is None:
        batch_size = EVAL_BATCH_SIZE
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    device = next(model.parameters()).device
    n_examples = len(raw_dataset_eval)
    n_batches = (n_examples + batch_size - 1) // batch_size
    correct = total = 0

    # Loop-invariant autocast settings: bf16 autocast is active on CUDA only; the
    # device-type argument merely has to be one that torch.autocast accepts.
    autocast_device = device.type if device.type != "mps" else "cpu"
    autocast_enabled = device.type == "cuda"

    for idx in tqdm(range(n_batches), desc=desc):
        start = idx * batch_size
        end = min(start + batch_size, n_examples)
        batch = raw_dataset_eval.select(range(start, end))

        prompts = [
            PROMPT_TEMPLATE.format(question=q, A=a, B=b, C=c, D=d)
            for q, a, b, c, d in zip(
                batch["Question"], batch["A"], batch["B"],
                batch["C"], batch["D"]
            )
        ]

        # The tokenizer's default special-token handling applies here. The cap
        # equals the length filter's limit, so rows that passed the filter are
        # not truncated.
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_PROMPT_TOKENS_FOR_FILTER,
        ).to(device)

        with torch.no_grad(), torch.autocast(autocast_device, dtype=torch.bfloat16, enabled=autocast_enabled):
            gen_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=5,   # enough for " A" + EOS or similar
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Every row in the batch has the same padded prompt width, so the
        # generated part starts at the same offset for all of them.
        prompt_len = inputs["input_ids"].shape[1]

        for ans, g_full in zip(batch["Answer"], gen_ids):
            pred_text = tokenizer.decode(g_full[prompt_len:], skip_special_tokens=True).strip()
            pred_letter = _first_letter(pred_text)

            correct += int(pred_letter == ans.strip().upper())
            total += 1

    return correct / total if total else 0.0

In [15]:
# %%  ── Baseline inference accuracy on TEST set ────────────────
# Accuracy on the held-out test rows before any fine-tuning step has run.
baseline_test_acc = accuracy_on_raw(raw_dataset_eval=raw_test, desc="Baseline test eval")
print(f"🔹 Baseline (pre-FT) TEST accuracy: {baseline_test_acc:.4f}")
wandb.log({"baseline_test_accuracy": baseline_test_acc})

Baseline test eval: 100%|██████████| 297/297 [07:21<00:00,  1.49s/it]

🔹 Baseline (pre-FT) TEST accuracy: 0.3751





In [16]:
# %%  ── Finetune ───────────────────────────────────────────────────────────────
# Run fine-tuning with the Trainer configured above; the returned TrainOutput is
# displayed as the cell result.
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.5477,0.522862
2,0.3388,0.53891


TrainOutput(global_step=2370, training_loss=0.4726693630218506, metrics={'train_runtime': 13491.651, 'train_samples_per_second': 2.81, 'train_steps_per_second': 0.176, 'total_flos': 1.409906483048448e+17, 'train_loss': 0.4726693630218506, 'epoch': 2.0})

## 4. Final Evaluation and Inference (Example)

In [17]:
# %%  ── Perplexity on validation / test sets ───────────────────────────────────
# Loss on the tokenised validation and test splits. Labels are masked everywhere
# except the answer tokens and EOS, so exp(loss) is a perplexity over those
# supervised positions only.
val_metrics = trainer.evaluate(eval_dataset=val_ds)
test_metrics = trainer.evaluate(eval_dataset=test_ds)

val_ppl, test_ppl = (
    math.exp(metrics["eval_loss"]) for metrics in (val_metrics, test_metrics)
)

print(f"🔹 Post-FT validation ppl: {val_ppl:.2f}")
print(f"🔹 Post-FT test       ppl: {test_ppl:.2f}")

final_metrics = {
    "final_val_loss": val_metrics["eval_loss"],
    "final_val_ppl" : val_ppl,
    "final_test_loss": test_metrics["eval_loss"],
    "final_test_ppl" : test_ppl,
}
wandb.log(final_metrics)

🔹 Post-FT validation ppl: 1.69
🔹 Post-FT test       ppl: 1.69


In [18]:
# %%  ── Accuracy after finetune on TEST set ───────────────────────────────────
# Same generation-based accuracy as the baseline, now with the fine-tuned weights.
ft_test_accuracy = accuracy_on_raw(raw_dataset_eval=raw_test, desc="Post-FT test eval")
print(f"🔹 Post-FT TEST accuracy: {ft_test_accuracy:.4f}")
wandb.log({"post_ft_test_accuracy": ft_test_accuracy})

Post-FT test eval: 100%|██████████| 297/297 [03:21<00:00,  1.47it/s]

🔹 Post-FT TEST accuracy: 0.5392





In [19]:
# %%  ── Example inference ─────────────────────────────
example_q     = "Was ist die Hauptstadt von Frankreich?"
example_opts  = {"A":"Berlin","B":"Paris","C":"London","D":"Madrid"}

example_prompt = PROMPT_TEMPLATE.format(
    question = example_q,
    A = example_opts["A"],
    B = example_opts["B"],
    C = example_opts["C"],
    D = example_opts["D"]
)

model.eval()  # inference mode (disables dropout)
# The model was placed by `device_map="auto"` at load time, so take the device
# from its parameters (as `accuracy_on_raw` does) rather than calling `.to()` on
# a 4-bit quantised model.
device = next(model.parameters()).device

inputs = tokenizer(example_prompt, return_tensors="pt", padding=True, truncation=True, max_length=MAX_PROMPT_TOKENS_FOR_FILTER).to(device)

# bf16 autocast is only enabled on CUDA; the device-type argument merely has to
# be one that torch.autocast accepts.
with torch.no_grad(), torch.autocast(device.type if device.type != 'mps' else 'cpu', dtype=torch.bfloat16, enabled=(device.type=="cuda")):
    gen_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id
    )

# Keep only the continuation: everything after the prompt tokens.
prompt_len = inputs["input_ids"].shape[1]
generated_tokens = gen_ids[0][prompt_len:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
generated_letter = _first_letter(generated_text)

print(f"\nQuestion: {example_q}")
print("Options:")
for k,v in example_opts.items():
    print(f"  ({k}) {v}")
print(f"Model answer letter: ({generated_letter})")
wandb.log({"example_question": example_q, "example_answer_predicted": generated_letter})


Question: Was ist die Hauptstadt von Frankreich?
Options:
  (A) Berlin
  (B) Paris
  (C) London
  (D) Madrid
Model answer letter: (B)


In [20]:
# Close the W&B run opened by `wandb.init` in the configuration cell.
wandb.finish()

0,1
baseline_test_accuracy,▁
eval/loss,▁█▁▁
eval/runtime,█▅▅▁
eval/samples_per_second,▁▄▄█
eval/steps_per_second,▁▅▄█
final_test_loss,▁
final_test_ppl,▁
final_val_loss,▁
final_val_ppl,▁
post_ft_test_accuracy,▁

0,1
baseline_test_accuracy,0.37511
eval/loss,0.52199
eval/runtime,244.3603
eval/samples_per_second,9.699
eval/steps_per_second,2.427
example_answer_predicted,B
example_question,Was ist die Hauptsta...
final_test_loss,0.52199
final_test_ppl,1.68537
final_val_loss,0.52286
