# Gemma2-2B Fine-Tuning

## Imports

In [1]:
!pip install transformers datasets torch peft
!pip install -q kagglehub
!pip install wandb



In [2]:
import wandb, os

import huggingface_hub
huggingface_hub.notebook_login()

try:
  from kaggle_secrets import UserSecretsClient
  user_secrets = UserSecretsClient()
  os.environ['KAGGLE_KEY'] = user_secrets.get_secret("KAGGLE_KEY")
  os.environ['KAGGLE_USERNAME'] = user_secrets.get_secret("KAGGLE_USERNAME")
  os.environ['WANDB_API_KEY'] = user_secrets.get_secret("WANDB_API_KEY")
except:
  try:
    from google.colab import userdata
    os.environ['KAGGLE_KEY'] = userdata.get("KAGGLE_KEY")
    os.environ['KAGGLE_USERNAME'] = userdata.get("KAGGLE_USERNAME")
    os.environ['WANDB_API_KEY'] = userdata.get("WANDB_API_KEY")
  except:
    print("Could't get secrets")

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import os
import shutil
import torch

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the path to the Hugging Face cache directory
hf_cache_home = os.path.expanduser("~/.cache/huggingface/hub")

# Check if the cache directory exists and remove it
if os.path.exists(hf_cache_home):
    print(f"Clearing Hugging Face cache at: {hf_cache_home}")
    shutil.rmtree(hf_cache_home)
    print("Hugging Face cache cleared.")
else:
    print("Hugging Face cache directory not found, nothing to clear.")

# # Uninstall existing torch, torchvision, and transformers installations
# !pip uninstall -y torch torchvision transformers

# # Install the latest compatible versions of torch and torchvision
# # Explicitly specifying version for torch and torchvision can help prevent conflicts.
# # Given the original `torch` version was 2.6.0, I will try to reinstall with a compatible `torchvision`.
# !pip install torch torchvision --upgrade

# # Reinstall transformers to ensure compatibility with the new torch/torchvision
!pip install transformers --upgrade

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import huggingface_hub

# Log in to Hugging Face (if not already logged in, this will prompt)
huggingface_hub.notebook_login()

model_id = "google/gemma-3-270m-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
    attn_implementation="eager" # Explicitly use eager attention to avoid XLA issues
)

model.to(device)

print("Model and tokenizer loaded successfully.")

Hugging Face cache directory not found, nothing to clear.


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.35k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/536M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

Model and tokenizer loaded successfully.


## Baseline sample tests

First, making sure it generates something

In [4]:
prompt = "If you have 5 apples but someone takes 2, how many do you have left? Show your work, then write your final answer on the last line after ####"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.


If you have 5 apples but someone takes 2, how many do you have left? Show your work, then write your final answer on the last line after ####

**Answer:** 7



## MATH dataset

In [5]:
from datasets import load_dataset

ds = load_dataset("qwedsacf/competition_math")

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001-7320a6f3aba8eb(…):   0%|          | 0.00/4.85M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12500 [00:00<?, ? examples/s]

In [6]:
ds_split = ds['train'].train_test_split(test_size=0.10, seed=2025)

train_val_split = ds_split['train'].train_test_split(test_size=0.1111, seed=2025)

math_train = train_val_split['train']
math_val   = train_val_split['test']
math_test  = ds_split['test']

print("Train size:", len(math_train))
print("Val size:",   len(math_val))
print("Test size:",  len(math_test))

Train size: 10000
Val size: 1250
Test size: 1250


In [7]:
print(math_train[0])

{'problem': 'What is the distance between the two intersections of $y=x^2$ and $x+y=1$?', 'level': 'Level 5', 'type': 'Algebra', 'solution': 'To find the $x$-coordinates of the intersections, substitute $x^2$ for $y$ in $x+y=1$ and solve for $x$, resulting in  \\begin{align*}\nx+x^2&=1 \\\\\n\\Rightarrow \\qquad x^2+x-1&=0 \\\\\n\\Rightarrow \\qquad x&=\\frac{-1\\pm\\sqrt{1+4}}2=\\frac{-1\\pm\\sqrt5}2\\\\\n\\end{align*}Using each of these coordinates to solve for $y$ gives us the intersections at $\\left(\\frac{-1+\\sqrt5}2,\\frac{3-\\sqrt5}2\\right)$ and $\\left(\\frac{-1-\\sqrt5}2,\\frac{3+\\sqrt5}2\\right)$.  Using the distance formula, we have \\begin{align*}\n&\\sqrt{ \\left(\\frac{-1+\\sqrt5}{2}-\\frac{-1-\\sqrt5}{2}\\right)^2 + \\left(\\frac{3-\\sqrt5}2-\\frac{3+\\sqrt5}2\\right)^2 }\\\\\n&\\qquad=\\sqrt{\\left(\\frac{2\\sqrt5}2\\right)^2 + \\left(-\\frac{2\\sqrt5}2\\right)^2}\\\\\n&\\qquad=\\sqrt{ 2\\sqrt5^2 }\\\\\n&\\qquad=\\boxed{\\sqrt{10}}.\n\\end{align*}'}


In [8]:
print(math_train[1])
print(math_train[2])
print(math_train[3])

{'problem': 'For how many integer values of $a$ does the equation $x^2 + ax + 5a = 0$ have integer solutions for $x$?', 'level': 'Level 5', 'type': 'Algebra', 'solution': "Suppose the roots of the quadratic are given by $m$ and $n$. Note that $$(x-m)(x-n) = x^2 - (m+n)x + mn = x^2 + ax + 5a,$$ and setting coefficients equal, it follows that  \\begin{align*}\nm + n &= -a \\\\\nmn &= 5a\n\\end{align*} (This also follows directly from Vieta's formulas.) Notice that the $a$ can be canceled by either dividing or noting that $$0 = 5a + 5 \\cdot (-a) = mn + 5(m+n).$$\n\nSimon's Favorite Factoring Trick can now be applied: $$mn + 5m + 5n + 25 = (m+5)(n+5) = 25.$$ It follows that $m+5$ and $n+5$ are divisors of $25$, whose pairs of divisors are given by $\\pm \\{(1,25),(5,5),(25,1)\\}$. Solving, we see that $(m,n)$ is in the set $$\\{(-4,20),(0,0),(20,-4),(-6,-30),(-10,-10),(-30,-6)\\}.$$ However, the two pairs of symmetric solutions yield redundant values for $a$, so it follows that the answer

The inspection of `math_train` entries confirmed the presence of numerous LaTeX commands and special characters such as `$`, `\`, `\begin{align*}` (and `\end{align*}`), `\\`, `\cdot`, `\ldots`, `\frac`, `\ge`, `\ge`, `\boxed`, `\binom`, `\times`, `\%`, and others. These need to be removed or replaced to make the text more readable and suitable for natural language processing tasks.

Now, let's create a cleaning function to handle these patterns. We'll start with a basic function and refine it as needed.

**Reasoning**:
Based on the observed patterns of LaTeX commands and special characters, I will now create a Python function to clean the text data. This function will use regular expressions to remove or replace these unwanted elements, making the text suitable for further processing.



## TODO: Cleaning MATH dataset

In [9]:
import re

# --- Utility: remove surrounding spaces ---
def _strip(s):
    return s.strip()

# --- Step 1: Convert LaTeX math syntax to plain math BEFORE tagging ---
def _normalize_latex(expr):

    # Fractions
    expr = re.sub(r'\\frac\s*\{([^}]+)\}\s*\{([^}]+)\}', r'(\1)/(\2)', expr)

    # Binomials
    expr = re.sub(r'\\binom\s*\{([^}]+)\}\s*\{([^}]+)\}', r'C(\1, \2)', expr)

    # Square roots
    expr = re.sub(r'\\sqrt\s*\{([^}]+)\}', r'sqrt(\1)', expr)

    # Operators
    replacements = {
        r'\\cdot': '*',
        r'\\times': '*',
        r'\\ge': '>=',
        r'\\le': '<=',
        r'\\neq': '!=',
        r'\\approx': '~',
        r'\\equiv': '==',
        r'\\infty': 'infinity',
        r'\\ldots': '...',
        r'\\sum': 'sum',
        r'\\prod': 'product',
    }
    for key, val in replacements.items():
        expr = re.sub(key, val, expr)

    # Greek letters
    greek = [
        'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'theta', 'lambda',
        'mu', 'nu', 'pi', 'rho', 'sigma', 'tau', 'phi', 'chi', 'psi', 'omega'
    ]
    for g in greek:
        expr = re.sub(rf'\\{g}', g, expr)

    # Formatting removal
    expr = re.sub(r'\\text\{([^}]+)\}', r'\1', expr)
    expr = re.sub(r'\\boxed\{([^}]+)\}', r'\1', expr)

    # Remove \left, \right
    expr = re.sub(r'\\left', '', expr)
    expr = re.sub(r'\\right', '', expr)

    # Remove stray command backslashes
    expr = re.sub(r'\\([a-zA-Z]+)', '', expr)
    expr = expr.replace('\\', '')

    return expr.strip()


# --- Step 2: Wrap math expressions in [MATH ...] spans ---
def _wrap_math(text):

    # Inline $...$
    def inline_repl(m):
        expr = _normalize_latex(m.group(1))
        return f"[MATH {expr}]"

    text = re.sub(r'\$(.+?)\$', inline_repl, text)

    # Display \[...\]
    def display_repl(m):
        expr = _normalize_latex(m.group(1))
        return f"[MATH {expr}]"

    text = re.sub(r'\\\[(.+?)\\\]', display_repl, text, flags=re.DOTALL)

    # Parentheses math \(...\)
    text = re.sub(r'\\\((.+?)\\\)', inline_repl, text)

    return text


# --- Step 3: Handle align, gather, eqnarray blocks ---
def _replace_align_envs(text):

    pattern = r'\\begin\{(align\*?|gather\*?|eqnarray\*?)\}(.+?)\\end\{\1\}'

    def repl(m):
        block = m.group(2)

        # Split on line breaks or & alignment markers
        lines = re.split(r'\\\\', block)
        cleaned = []
        for line in lines:
            line = line.replace('&', '')
            line = _normalize_latex(line)
            if line.strip():
                cleaned.append(f"[MATH {line.strip()}]")
        return " ".join(cleaned)

    return re.sub(pattern, repl, text, flags=re.DOTALL)


# --- FINAL CLEAN FUNCTION ---
def clean_text(text):

    text = _replace_align_envs(text)
    text = _wrap_math(text)

    # Remove duplicate spaces
    text = re.sub(r'\s+', ' ', text).strip()

    # Fix malformed MATH blocks like "[MATH x]" or "[MATHx]" -> "[MATH x ]"
    text = re.sub(r'\[MATH([^A-Za-z0-9])', r'[MATH \1', text)

    # Ensure closing brackets always have at least 1 space before ]
    text = re.sub(r'\[MATH\s+([^\]]*?)\]', lambda m: f"[MATH {m.group(1).strip()} ]", text)

    return text.strip()

In [10]:
clean_text("Let\n\\[k = \\frac{a^3 + 6}{a} = \\frac{b^3 + 6}{b} = \\frac{c^3 + 6}{c}.\\]Then $a,$ $b,$ and $c$ are all roots of\n\\[k = \\frac{x^3 + 6}{x},\\]or $x^3 - kx + 6 = 0.$  By Vieta's formulas, $a + b + c = 0.$\n\nAlso,\n\\begin{align*}\na^3 - ka + 6 &= 0, \\\\\nb^3 - kb + 6 &= 0, \\\\\nc^3 - kc + 6 &= 0.\n\\end{align*}Adding these, we get $a^3 + b^3 + c^3 - k(a + b + c) + 18 = 0,$ so $a^3 + b^3 + c^3 = k(a + b + c) - 18 = \\boxed{-18}.$")

"Let [MATH k = (a^3 + 6)/(a) = (b^3 + 6)/(b) = (c^3 + 6)/(c). ]Then [MATH a, ] [MATH b, ] and [MATH c ] are all roots of [MATH k = (x^3 + 6)/(x), ]or [MATH x^3 - kx + 6 = 0. ] By Vieta's formulas, [MATH a + b + c = 0. ] Also, [MATH a^3 - ka + 6 = 0, ] [MATH b^3 - kb + 6 = 0, ] [MATH c^3 - kc + 6 = 0. ]Adding these, we get [MATH a^3 + b^3 + c^3 - k(a + b + c) + 18 = 0, ] so [MATH a^3 + b^3 + c^3 = k(a + b + c) - 18 = -18. ]"

In [11]:
#applying to training/testing splits
math_train = math_train.map(lambda example: {'cleaned_problem': clean_text(example['problem'])})
math_train = math_train.map(lambda example: {'cleaned_solution': clean_text(example['solution'])})

math_test = math_test.map(lambda example: {'cleaned_problem': clean_text(example['problem'])})
math_test = math_test.map(lambda example: {'cleaned_solution': clean_text(example['solution'])})

math_val = math_val.map(lambda example: {'cleaned_problem': clean_text(example['problem'])})
math_val = math_val.map(lambda example: {'cleaned_solution': clean_text(example['solution'])})


print("Cleaning applied to math_train, math_test and math_val datasets.")
print("Sample of math_train after cleaning:")
print(math_train[0]['problem'])
print(math_train[0]['cleaned_problem'])
print(math_train[0]['solution'])
print(math_train[0]['cleaned_solution'])

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1250 [00:00<?, ? examples/s]

Map:   0%|          | 0/1250 [00:00<?, ? examples/s]

Map:   0%|          | 0/1250 [00:00<?, ? examples/s]

Map:   0%|          | 0/1250 [00:00<?, ? examples/s]

Cleaning applied to math_train, math_test and math_val datasets.
Sample of math_train after cleaning:
What is the distance between the two intersections of $y=x^2$ and $x+y=1$?
What is the distance between the two intersections of [MATH y=x^2 ] and [MATH x+y=1 ]?
To find the $x$-coordinates of the intersections, substitute $x^2$ for $y$ in $x+y=1$ and solve for $x$, resulting in  \begin{align*}
x+x^2&=1 \\
\Rightarrow \qquad x^2+x-1&=0 \\
\Rightarrow \qquad x&=\frac{-1\pm\sqrt{1+4}}2=\frac{-1\pm\sqrt5}2\\
\end{align*}Using each of these coordinates to solve for $y$ gives us the intersections at $\left(\frac{-1+\sqrt5}2,\frac{3-\sqrt5}2\right)$ and $\left(\frac{-1-\sqrt5}2,\frac{3+\sqrt5}2\right)$.  Using the distance formula, we have \begin{align*}
&\sqrt{ \left(\frac{-1+\sqrt5}{2}-\frac{-1-\sqrt5}{2}\right)^2 + \left(\frac{3-\sqrt5}2-\frac{3+\sqrt5}2\right)^2 }\\
&\qquad=\sqrt{\left(\frac{2\sqrt5}2\right)^2 + \left(-\frac{2\sqrt5}2\right)^2}\\
&\qquad=\sqrt{ 2\sqrt5^2 }\\
&\qquad=\box

In [12]:
math_train[0]

{'problem': 'What is the distance between the two intersections of $y=x^2$ and $x+y=1$?',
 'level': 'Level 5',
 'type': 'Algebra',
 'solution': 'To find the $x$-coordinates of the intersections, substitute $x^2$ for $y$ in $x+y=1$ and solve for $x$, resulting in  \\begin{align*}\nx+x^2&=1 \\\\\n\\Rightarrow \\qquad x^2+x-1&=0 \\\\\n\\Rightarrow \\qquad x&=\\frac{-1\\pm\\sqrt{1+4}}2=\\frac{-1\\pm\\sqrt5}2\\\\\n\\end{align*}Using each of these coordinates to solve for $y$ gives us the intersections at $\\left(\\frac{-1+\\sqrt5}2,\\frac{3-\\sqrt5}2\\right)$ and $\\left(\\frac{-1-\\sqrt5}2,\\frac{3+\\sqrt5}2\\right)$.  Using the distance formula, we have \\begin{align*}\n&\\sqrt{ \\left(\\frac{-1+\\sqrt5}{2}-\\frac{-1-\\sqrt5}{2}\\right)^2 + \\left(\\frac{3-\\sqrt5}2-\\frac{3+\\sqrt5}2\\right)^2 }\\\\\n&\\qquad=\\sqrt{\\left(\\frac{2\\sqrt5}2\\right)^2 + \\left(-\\frac{2\\sqrt5}2\\right)^2}\\\\\n&\\qquad=\\sqrt{ 2\\sqrt5^2 }\\\\\n&\\qquad=\\boxed{\\sqrt{10}}.\n\\end{align*}',
 'cleaned_pro

In [20]:
import re
import torch
from tqdm import tqdm

def extract_answer(text):
    """
    Extract the final numeric answer from model output.
    Priority:
        1. Last non-empty [MATH ...] block
        2. Last \boxed{...}
        3. Last number ONLY IF NO math/boxed blocks found
    """

    # 1. Extract full MATH blocks: [MATH 42], [MATH -7], etc.
    blocks = re.findall(r'\[MATH\s+([^\]]+)\]', text)
    if blocks:  # only non-empty matches
        last = blocks[-1].strip()
        candidates = re.findall(r'-?\d+(?:_\d+)?(?:\.\d+)?', last)
        if candidates:
            return candidates[-1]

    # 2. Boxed answers
    boxed = re.findall(r'\\boxed\{([^}]*)\}', text)
    if boxed:
        candidates = re.findall(r'-?\d+(?:_\d+)?(?:\.\d+)?', boxed[-1])
        if candidates:
            return candidates[-1]

    # 3. LAST RESORT — only if no MATH or boxed answer found
    #    This prevents greedy extraction of numbers from the explanation
    return None



def eval_math(model, tokenizer, data, batch_size=16, max_new_tokens=128, verbose=False, num_eval=None):
    """Evaluate model accuracy on your cleaned math dataset."""
    model.eval()
    torch.cuda.empty_cache()

    prompt_template = (
        "Solve the math problem; show your work in a 'Reasoning:' section.\n"
        "You MUST finish with:\n\n"
        "Final Answer:\n"
        "<answer>\n\n"
        "Question:\n{question}\n\n"
        "Reasoning:\n"
    )

    if num_eval is None:
        num_eval = len(data)

    correct = 0
    formatted = 0

    items = list(range(num_eval))

    for start in tqdm(range(0, num_eval, batch_size)):
        batch_idx = items[start:start + batch_size]
        batch = [data[i] for i in batch_idx]

        prompts = [
            prompt_template.format(question=item["cleaned_problem"])
            for item in batch
        ]

        inputs = tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=1024,
            return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        for j, item in enumerate(batch):
            input_len = inputs["input_ids"][j].shape[0]
            gen_ids = outputs[j][input_len:]
            gen_text = tokenizer.decode(gen_ids, skip_special_tokens=True)

            pred = extract_answer(gen_text)
            gt = extract_answer(item["cleaned_solution"])

            if pred is not None:
                formatted += 1
                if gt is not None and pred == gt:
                    correct += 1

    total = num_eval
    format_rate = formatted / total
    accuracy = correct / formatted if formatted else 0

    print(f"\nFormat rate: {format_rate:.2%} ({formatted}/{total})")
    print(f"Accuracy:    {accuracy:.2%} ({correct}/{formatted})")

    return format_rate, accuracy



In [14]:
#testing eval function
eval_math(model, tokenizer, math_train, verbose=True, num_eval=3)

Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.


Problem 0
What is the distance between the two intersections of [MATH y=x^2 ] and [MATH x+y=1 ]?

Real solution: To find the [MATH x ]-coordinates of the intersections, substitute [MATH x^2 ] for [MATH y ] in [MATH x+y=1 ] and solve for [MATH x ], resulting in [MATH x+x^2=1 ] [MATH x^2+x-1=0 ] [MATH x={-1(1+4)}2={-15}2 ]Using each of these coordinates to solve for [MATH y ] gives us the intersections at [MATH <=ft({-1+5}2,{3-5}2) ] and [MATH <=ft({-1-5}2,{3+5}2) ]. Using the distance formula, we have [MATH sqrt( <=ft((-1+5)/(2)-(-1-5)/(2))^2 + <=ft({3-5)2-{3+5}2)^2 } ] [MATH =sqrt(<=ft({25)2)^2 + <=ft(-{25}2)^2} ] [MATH =sqrt( 25^2 ) ] [MATH =sqrt(10). ]
Ground Truth: 10
Generated output:  2

Final Answer: 2

Pred Answer: None

--------------------------------------------------



Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.


Problem 1
For how many integer values of [MATH a ] does the equation [MATH x^2 + ax + 5a = 0 ] have integer solutions for [MATH x ]?

Real solution: Suppose the roots of the quadratic are given by [MATH m ] and [MATH n ]. Note that [MATH $(x-m)(x-n) = x^2 - (m+n)x + mn = x^2 + ax + 5a, ][MATH and setting coefficients equal, it follows that [MATH  m + n = -a ] [MATH mn = 5a ] (This also follows directly from Vieta's formulas.) Notice that the]a[MATH can be canceled by either dividing or noting that ][MATH 0 = 5a + 5 * (-a) = mn + 5(m+n). ]$ Simon's Favorite Factoring Trick can now be applied: [MATH $mn + 5m + 5n + 25 = (m+5)(n+5) = 25. ][MATH It follows that ]m+5[MATH and ]n+5[MATH are divisors of ]25[MATH , whose pairs of divisors are given by ]\pm \{(1,25),(5,5),(25,1)\}[MATH . Solving, we see that ](m,n)[MATH is in the set ][MATH {(-4,20),(0,0),(20,-4),(-6,-30),(-10,-10),(-30,-6)}. ][MATH However, the two pairs of symmetric solutions yield redundant values for ]a[MATH , so it follows

(0.0, 0.0)

# SFT training

## Format dataset questions


In [15]:
def format_math(example, tokenizer):
    """Format example for training with proper label masking.

    Args:
        example: Dictionary with 'cleaned_problem' and 'cleaned_solution' keys
        tokenizer: HuggingFace tokenizer

    Returns:
        Dictionary with 'input_ids', 'attention_mask', and 'labels'
        Labels are -100 for question tokens (ignored in loss) and
        actual token IDs for answer tokens.
    """
    prompt_template = """Solve the math problem; show your work in a "Reasoning:" section (2-4 sentences only).
You MUST finish with a single line that starts exactly with:

Final Answer:
<single token with final answer, e.g. -18, 42_7, 3/2, 1.25>

Rules:
- The Final Answer line must contain *only* the final answer (no text, no units, no brackets, no LaTeX).
- Do NOT include any additional text after the Final Answer line.

Question:
{question}

Reasoning:
"""
    question_text = prompt_template.format(question=example['cleaned_problem'])
    answer_text = f"{example['cleaned_solution']}"

    # Tokenize question and answer separately
    question_tokens = tokenizer(question_text, add_special_tokens=True, truncation=True,max_length=1024)
    answer_tokens = tokenizer(answer_text, add_special_tokens=False, truncation=True,max_length=1024)

    # Combine
    input_ids = question_tokens['input_ids'] + answer_tokens['input_ids']
    attention_mask = question_tokens['attention_mask'] + answer_tokens['attention_mask']

    # Create labels: -100 for question (ignored), actual tokens for answer
    labels = [-100] * len(question_tokens['input_ids']) + answer_tokens['input_ids']

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

# Example:
print(f"First Question Tokenized: {tokenizer(math_train[0]['cleaned_problem'], add_special_tokens=True)}")
print()
print(f"First Answer Tokenized: {tokenizer(math_train[0]['cleaned_solution'], add_special_tokens=True)}")
print()
example_formatted = format_math(math_train[0], tokenizer)
print(f"Combined tokenized input: {example_formatted['input_ids']}")
print()
print(f"Prediction targets/labels: {example_formatted['labels']}")

First Question Tokenized: {'input_ids': [2, 3689, 563, 506, 5149, 1534, 506, 1156, 69811, 529, 870, 202896, 570, 236784, 236781, 236884, 236778, 4422, 532, 870, 202896, 1123, 236862, 236762, 236784, 236770, 4422, 236881], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

First Answer Tokenized: {'input_ids': [2, 2021, 1586, 506, 870, 202896, 1123, 4422, 236772, 56220, 529, 506, 69811, 236764, 22784, 870, 202896, 1123, 236884, 236778, 4422, 573, 870, 202896, 570, 4422, 528, 870, 202896, 1123, 236862, 236762, 236784, 236770, 4422, 532, 8974, 573, 870, 202896, 1123, 7975, 9113, 528, 870, 202896, 1123, 236862, 236781, 236884, 236778, 236784, 236770, 4422, 870, 202896, 1123, 236884, 236778, 236862, 236781, 236772, 236770, 236784, 236771, 4422, 870, 202896, 1123, 2638, 236772, 236770, 236769, 236770, 236862, 236812, 4567, 236778, 2638, 236772, 236770, 236810, 236783, 236778, 4422, 17123, 1546, 529, 1239, 15375, 531, 8974, 573, 870, 20289

## Train model

In [16]:
# You do not need to modify this code
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback, TrainerCallback, DataCollatorForSeq2Seq
from peft import LoraConfig, get_peft_model, TaskType
import torch
from transformers import DataCollatorForLanguageModeling

# Define a simple callback to print the train loss periodically
class PrintLossCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and 'loss' in logs:
            print(f"Step {state.global_step}: Train loss={logs['loss']:.4f}")

def train_math(model, tokenizer, train_data, val_data, test_data,
                batch_size, gradient_accumulation_steps,
                learning_rate, max_epochs, patience,
                lora_r, lora_drop):
    """Fine-tune model with LoRA on MATH."""

    # Set padding token
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id

    # Apply LoRA
    lora_config = LoraConfig(r=lora_r, lora_alpha=lora_r * 2, #EDIT: acc to gemma notation, lora config
                             target_modules=[
                                 "q_proj", "k_proj", "v_proj", "o_proj",
                                  # "gate_proj", "up_proj", "down_proj" --> bec 3x MLPs
                                  ],
                             lora_dropout=lora_drop, bias="none",
                             task_type=TaskType.CAUSAL_LM)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()


    # --- REQUIRED FIX FOR GEMMA3 + PEFT + GRAD CHECKPOINTING ---
    model.gradient_checkpointing_enable()     # must be AFTER LoRA
    model.enable_input_require_grads()        # ensures tensors require grad
    model.config.use_cache = False            # avoids incompatible caches
    model.train()                             # ensure training mode

    model.to(device)
    # Format and tokenize with proper masking
    train_tokenized = train_data.map(
        lambda x: format_math(x, tokenizer),
        remove_columns=train_data.column_names
    )
    val_tokenized = val_data.map(
        lambda x: format_math(x, tokenizer),
        remove_columns=val_data.column_names
    )

    # Data collator for padding with label padding
    #data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, label_pad_token_id=-100, padding="longest")
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # Training arguments
    args = TrainingArguments(
        output_dir="./math_checkpoints",
        num_train_epochs=max_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=1, #EDIT: VALIDATIOIN IS SMALLER FOR MEMORY ISSUES
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        logging_steps=20,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        fp16=False,
        bf16=True,  #edit: true for TPU usage, smaller GPU usage
        report_to="none",
        gradient_checkpointing=True,
        eval_accumulation_steps=1, #EDIT: again, mem saving for val set
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_tokenized,
        eval_dataset=val_tokenized,
        data_collator=data_collator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=patience), PrintLossCallback()],
    )

    # Train
    trainer.train()

    #save training
    trainer.save_model("./sft_model_full")
    tokenizer.save_pretrained("./sft_model_full")

    # Final evaluation
    model.config.use_cache = True
    print("\nFinal test results:")
    eval_math(model, tokenizer, math_test)

    return model

## Setting hyperparameters

In [17]:
# TODO: Set your hyperparameters here

max_length = 768
batch_size = 1  # Try 2, 4, or 8
gradient_accumulation_steps = 32  # Try 4 or 8
learning_rate = 2e-5 # Try 1e-4, 5e-5, or 1e-5
max_epochs = 1  # Try 5-10
patience = 2  # Try 2 or 3
lora_r = 4  # Try 8, 16, or 32
lora_drop =0.05  # Try 0.05 or 0.1
lora_alpha=16

# Free GPU memory and reload a fresh model for fine-tuning
# This is to ensure that if you run this cell multiple times
# with different hyperparameters, you always start from the
# clean original pretrained model (not a partially fine-tuned one)
import gc
del model
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

#model_id = google/gemma-3-270m-it
#tpu_device = xm.xla_device()
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
model.to(device)
model.config.use_flash_attn = False

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

model.config.pad_token_id = tokenizer.pad_token_id

print("Fresh model loaded and ready for fine-tuning")

# Run training
model = train_math(
    model=model,
    tokenizer=tokenizer,
    train_data=math_train,
    val_data=math_val, #turn off for now (storage issues)
    test_data=math_test,
    batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    max_epochs=max_epochs,
    patience=patience,
    lora_r=lora_r,
    lora_drop=lora_drop
)

Fresh model loaded and ready for fine-tuning
trainable params: 368,640 || all params: 268,466,816 || trainable%: 0.1373


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1250 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss
1,1.6984,1.713297


Step 20: Train loss=2.6670
Step 40: Train loss=2.5377
Step 60: Train loss=2.3607
Step 80: Train loss=2.2679
Step 100: Train loss=2.1327
Step 120: Train loss=2.0798
Step 140: Train loss=2.0021
Step 160: Train loss=1.9379
Step 180: Train loss=1.9292
Step 200: Train loss=1.8384
Step 220: Train loss=1.7836
Step 240: Train loss=1.7770
Step 260: Train loss=1.7109
Step 280: Train loss=1.6735
Step 300: Train loss=1.6984

Final test results:


  0%|          | 0/1250 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  0%|          | 1/1250 [00:06<2:12:13,  6.35s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  0%|          | 2/1250 [00:12<2:11:12,  6.31s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  0%|          | 3/1250 [00:18<2:10:30,  6.28s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  0%|          | 4/1250 [00:25<2:10:09,  6.27s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  0%|          | 5/1250 [00:31<2:10:05,  6.27s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  0%|          | 6/1250 [00:37<2:09:55,  6.27s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  1%|          | 7/1250 [00:43<2:09:33,  6.25s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  1%|          | 8/1250 [00:50<2:09:21,  6.25s/it]Setting `pad_tok

KeyboardInterrupt: 

In [21]:
format_rate, accuracy = eval_math(
    model,
    tokenizer,
    math_test,   # or math_val
    max_new_tokens=128
)


  0%|          | 0/79 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  1%|▏         | 1/79 [00:07<10:04,  7.75s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  3%|▎         | 2/79 [00:15<09:55,  7.73s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  4%|▍         | 3/79 [00:23<09:48,  7.74s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  5%|▌         | 4/79 [00:30<09:37,  7.70s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  6%|▋         | 5/79 [00:38<09:29,  7.70s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  8%|▊         | 6/79 [00:46<09:20,  7.68s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
  9%|▉         | 7/79 [00:53<09:11,  7.67s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
 10%|█         | 8/79 [01:01<09:02,  7.64s/it]Setting `pad_token_id` to `eos_token_id`:1 for ope


Format rate: 47.36% (592/1250)
Accuracy:    9.46% (56/592)





In [None]:
from matplotlib import pyplot as plt

# Helper function to visualize performance during training
def plot_training_curves(train_losses, val_accuracies):
    """Plot training loss and validation accuracy curves.

    Parameters
    ----------
    train_losses : list of float
        Training loss values for each epoch. Should have one value per epoch.
    val_accuracies : list of float
        Validation accuracy values for each epoch. Should have same length as
        train_losses. Accuracy values should be between 0 and 1 (or 0 and 100
        if using percentages).

    Returns
    -------
    None
        Displays matplotlib figure with two subplots showing training curves.

    Examples
    --------
    >>> train_losses = [0.8, 0.6, 0.4, 0.3, 0.2]
    >>> val_accuracies = [0.75, 0.80, 0.85, 0.87, 0.88]
    >>> plot_training_curves(train_losses, val_accuracies)
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.plot(train_losses)
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True)

    ax2.plot(val_accuracies)
    ax2.set_title('Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:
# Finally, plot training curves
plot_training_curves(train_losses, val_accuracies)

# Ablation Study

diff loRA params

## Deploying with Streamlit

In [23]:
!pip install streamlit



In [32]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_id = "google/gemma-3-270m-it"  # or your gemma3 model_id

# 1. Load base pretrained model
base = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype="auto",
    device_map="cpu"
)

# 2. Load your adapter weights
model = PeftModel.from_pretrained(base, "./math_checkpoints/checkpoint-313")

# 3. Merge LoRA → full model
model = model.merge_and_unload()

# 4. Save full merged model
model.save_pretrained("final_sft_model")
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
tokenizer.save_pretrained("final_sft_model")

print("Merged model saved to final_sft_model/")


Merged model saved to final_sft_model/


In [38]:
%%writefile app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_PATH = "final_sft_model"

@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        fix_mistral_regex=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    return tokenizer, model

tokenizer, model = load_model()


Overwriting app.py


In [39]:
!streamlit run app.py --server.enableCORS false --server.enableXsrfProtection false



Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.32.157.225:8501[0m
[0m
[34m  Stopping...[0m
[34m  Stopping...[0m
