### Fine-tuning Gemma-3 Model on Jeopardy Data with MLX

This notebook demonstrates a LoRA fine-tuning pipeline for the `gemma-3` model
(the `google/gemma-3-1b-it` checkpoint) using the `mlx_lm` library. Our custom
training data is a tab-separated file, `test_data/extra_matches.tsv`, that
contains Jeopardy questions and answers.

Columns in the TSV:
`round`, `clue_value`, `daily_double_value`, `category`, `comments`, `answer`, `question`, `air_date`, `notes`
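
Checking the header before loading can catch a mismatched file early. The following is a minimal sketch using only the standard library; `EXPECTED_COLUMNS` and `missing_columns` are illustrative names that do not appear elsewhere in this notebook.

```python
import csv
import io

# Column names as listed above; the helper below is hypothetical.
EXPECTED_COLUMNS = [
    "round", "clue_value", "daily_double_value", "category",
    "comments", "answer", "question", "air_date", "notes",
]

def missing_columns(tsv_file) -> list[str]:
    """Return the expected columns that are absent from the TSV header row."""
    header = next(csv.reader(tsv_file, delimiter="\t"), [])
    return [col for col in EXPECTED_COLUMNS if col not in header]

# In-memory sample header; pass open(path, newline="") for a real file.
sample = io.StringIO("round\tclue_value\tcategory\tanswer\tquestion\n")
print(missing_columns(sample))
```

An empty list means every expected column is present.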

In [1]:
%pip install -Uqq mlx mlx_lm transformers datasets 

Note: you may need to restart the kernel to use updated packages.


In [2]:
import pandas as pd
from pathlib import Path

In [3]:
# 1. Load the data (tab-separated)
data_path = Path("/Users/sid/Projects/code/JeopardyLLM/data")
df = pd.read_csv(data_path / "extra_matches.tsv", sep="\t")

# 2. Format the data for Gemma's chat template
def generate_prompt(row: pd.Series) -> str:
    return f"""<bos><start_of_turn>user
# Instructions
# You are Alex Trebek hosting the current season of Jeopardy! You will provide a clue (the 'answer' in Jeopardy terms), and a contestant will respond with the correct 'question'. 

Jeopardy Round : {row['round']}
Category : {row['category']}
Value : {row['clue_value']}
Air Date : {row['air_date']}
Comments : {row['comments']}
Notes : {row['notes']}
Question : {row['question']}<end_of_turn>
<start_of_turn>model
Here is your Jeopardy clue: "{row['answer']}"<end_of_turn><eos>"""

df["text"] = df.apply(generate_prompt, axis=1)
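
Optional columns such as `comments` and `notes` may be empty for many rows (an assumption about this dataset), and a missing value interpolated into an f-string appears as the literal text `nan` or `None`. The sketch below shows one way to blank those values before building the prompt. `clean_field`, `FIELDS`, and `render_fields` are hypothetical helpers, and the sample row is made up for illustration.

```python
import math

def clean_field(value) -> str:
    """Render a cell as text, mapping None and NaN to an empty string."""
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return ""
    return str(value)

# Prompt labels paired with TSV column names, following the template above.
FIELDS = [("Category", "category"), ("Comments", "comments"), ("Notes", "notes")]

def render_fields(row) -> str:
    """Build the 'Label : value' lines, leaving missing values blank."""
    return "\n".join(f"{label} : {clean_field(row.get(col))}" for label, col in FIELDS)

# Made-up row for illustration only; not taken from the dataset.
example_row = {"category": "SCIENCE", "comments": float("nan"), "notes": None}
rendered = render_fields(example_row)
print(rendered)
```

Because `row.get` is also available on a `pd.Series`, the same helpers could be called from inside `generate_prompt`.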

In [4]:
# 3. Split into train/valid sets & save to JSONL
split_index = int(len(df) * 0.9)
df_shuf = df.sample(frac=1, random_state=42)
train, valid = df_shuf[:split_index], df_shuf[split_index:]

out_dir = data_path / "training_data"
out_dir.mkdir(parents=True, exist_ok=True)
train[["text"]].to_json(out_dir / "train.jsonl", orient="records", lines=True)
valid[["text"]].to_json(out_dir / "valid.jsonl", orient="records", lines=True)
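
Before starting a training run, it can help to confirm that each JSONL line parses and carries a non-empty `text` field, since that is the key written above. This is a minimal sketch; `count_valid_records` is a hypothetical helper, and the demo file is a throwaway stand-in for `train.jsonl` and `valid.jsonl`.

```python
import json
import tempfile
from pathlib import Path

def count_valid_records(jsonl_path) -> int:
    """Count records with a non-empty string 'text' field; raise on a bad record."""
    count = 0
    with open(jsonl_path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            text = record.get("text")
            if not isinstance(text, str) or not text:
                raise ValueError(f"{jsonl_path}:{line_number} has no usable 'text' field")
            count += 1
    return count

# Demo on a temporary file; in the notebook, point this at the saved JSONL files.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.jsonl"
    demo_path.write_text('{"text": "first"}\n{"text": "second"}\n', encoding="utf-8")
    n_records = count_valid_records(demo_path)
print(n_records)
```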

In [None]:
# 4. LoRA Fine-tuning on gemma-3 using mlx_lm
# NOTE: Adjust --iters, --model, and hyperparameters as needed
!python -m mlx_lm.lora \
    --model google/gemma-3-1b-it \
    --train \
    --iters 600 \
    --adapter-path /Users/sid/Projects/code/JeopardyLLM/models \
    --data /Users/sid/Projects/code/JeopardyLLM/data/training_data 

Loading pretrained model
Fetching 8 files: 100%|████████████████████████| 8/8 [00:00<00:00, 23865.17it/s]
Loading datasets
Training
Trainable parameters: 0.035% (0.459M/1301.876M)
Starting training..., iters: 600
Iter 1: Val loss 5.892, Val took 12.856s
mx.metal.get_peak_memory is deprecated and will be removed in a future version. Use mx.get_peak_memory instead.
Iter 10: Train loss 5.390, Learning Rate 1.000e-05, It/sec 0.904, Tokens/sec 548.091, Trained Tokens 6066, Peak mem 7.353 GB
Iter 20: Train loss 4.186, Learning Rate 1.000e-05, It/sec 1.119, Tokens/sec 686.702, Trained Tokens 12201, Peak mem 7.353 GB
Iter 30: Train loss 3.540, Learning Rate 1.000e-05, It/sec 1.125, Tokens/sec 690.792, Trained Tokens 18344, Peak mem 7.353 GB
Iter 40: Train loss 3.086, Learning Rate 1.000e-05, It/sec 1.089, Tokens/sec 686.406, Trained Tokens 24648, Peak mem 7.353 GB
Iter 50: Train loss 2.792, Learning Rate 1.000e-05, It/sec 1.018, Tokens/sec 635.116, Trained Tokens 30888, Peak mem 7.353 GB
Iter 
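
To track convergence, the training-loss values can be pulled out of the console output. The sketch below assumes the log lines keep the `Iter N: Train loss X, ...` shape shown above; `TRAIN_LINE` and `parse_train_loss` are hypothetical names, and the sample text is copied from the output of this run.

```python
import re

# Matches lines such as "Iter 10: Train loss 5.390, Learning Rate ..."
TRAIN_LINE = re.compile(r"^Iter (\d+): Train loss ([\d.]+)")

def parse_train_loss(log_text: str) -> list[tuple[int, float]]:
    """Extract (iteration, train_loss) pairs; validation and other lines are skipped."""
    pairs = []
    for line in log_text.splitlines():
        match = TRAIN_LINE.match(line.strip())
        if match:
            pairs.append((int(match.group(1)), float(match.group(2))))
    return pairs

log = """\
Iter 1: Val loss 5.892, Val took 12.856s
Iter 10: Train loss 5.390, Learning Rate 1.000e-05, It/sec 0.904, Tokens/sec 548.091, Trained Tokens 6066, Peak mem 7.353 GB
Iter 20: Train loss 4.186, Learning Rate 1.000e-05, It/sec 1.119, Tokens/sec 686.702, Trained Tokens 12201, Peak mem 7.353 GB
"""
losses = parse_train_loss(log)
print(losses)  # [(10, 5.39), (20, 4.186)]
```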

In [None]:
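# 5. Merge LoRA weights to produce a fused model (uncomment to run)
# The adapter path must match the --adapter-path used for training in step 4.
# !python -m mlx_lm.fuse \
#     --model google/gemma-3-1b-it \
#     --adapter-path /Users/sid/Projects/code/JeopardyLLM/models \
#     --save-path gemma3_jeopardy_fused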

In [None]:
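# 6. (Optional) Test generation with the trained adapter
from mlx_lm import generate, load

model, tokenizer = load("google/gemma-3-1b-it", adapter_path="/Users/sid/Projects/code/JeopardyLLM/models/")

# Wrap the message in the model's chat template so it matches the turn format used in training
messages = [{"role": "user", "content": "Hey! Are you Alex Trebek? Test me on Jeopardy!"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# verbose=True already prints the generated text, so no separate print is needed
resp = generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=6000,
    verbose=True,
)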

In [None]:
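import numpy as np
from mlx_lm import generate

def evaluate_jeopardy(model, tokenizer, val_df, n_samples=10, retrieve_fn=None):
    """
    Generate a few sample Jeopardy clues and see how consistent they are.
    Optionally, retrieve relevant references to reduce hallucinations:
    retrieve_fn(answer, k) should return a DataFrame with an "answer" column.
    No retriever is defined in this notebook, so it defaults to None.
    """
    n_samples = min(n_samples, len(val_df))
    sample_indices = np.random.choice(len(val_df), size=n_samples, replace=False)
    responses = []
    for idx in sample_indices:
        row = val_df.iloc[idx]
        # Retrieve relevant references for the clue, if a retriever was supplied.
        references = []
        if retrieve_fn is not None:
            references = retrieve_fn(row["answer"], k=2)["answer"].tolist()

        # Prepare a prompt.
        content = f"""You are Alex Trebek:
Clue: {row['answer']}
(References: {references})

What is the best way to deliver this clue to the contestant?
"""
        messages = [{"role": "user", "content": content}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        response = generate(model, tokenizer, prompt=prompt, max_tokens=80)
        # Evaluate correctness, style, or factual consistency as needed.
        responses.append(response)
    return responses

# Example: responses = evaluate_jeopardy(model, tokenizer, valid, n_samples=5)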
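
The function above leaves the scoring step open. One simple option, offered here only as an assumption and not as the notebook's method, is a token-overlap F1 between a generated clue and the reference clue. `tokenize` and `token_f1` are hypothetical helpers; the metric measures lexical overlap only and says nothing about style or factual accuracy.

```python
import re
from collections import Counter

def tokenize(text: str) -> list[str]:
    """Lowercase the text and split it into alphanumeric tokens."""
    return re.findall(r"[a-z0-9]+", text.lower())

def token_f1(generated: str, reference: str) -> float:
    """Harmonic mean of token precision and recall; 0.0 when there is no overlap."""
    gen, ref = tokenize(generated), tokenize(reference)
    if not gen or not ref:
        return 0.0
    # Multiset intersection counts each shared token at most as often as it appears in both.
    overlap = sum((Counter(gen) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(gen)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

# Made-up strings for illustration only.
score = token_f1("The red planet", "red planet")
print(round(score, 3))
```

Scores near 1.0 indicate the generated clue reuses most of the reference wording, so low scores are worth reading manually before concluding the output is wrong.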
