<a href="https://colab.research.google.com/github/jnises/llmog/blob/finetuning/training/finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

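This notebook is based on the [Unsloth](https://unsloth.ai/) Gemma 3 fine-tuning notebook. It fine-tunes `unsloth/gemma-3-1b-it` with LoRA adapters on the `jnises/llmog-conversations` dataset and pushes the merged model to the Hugging Face Hub.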

In [None]:
import os
is_colab = any('COLAB_' in k for k in os.environ.keys())

if is_colab:
  from google.colab import userdata
  HF_TOKEN = userdata.get('HF_TOKEN')
else:
  !pip install dotenv
  from dotenv import load_dotenv
  load_dotenv()
  HF_TOKEN = os.environ["HF_TOKEN"]

In [None]:
%%capture
# TODO specify the exact versions of dependencies here. This was written around the end of May 2025, so if unsloth has breaking changes try reverting to the version that was current at that time
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install transformers==4.51.3
    !pip install --no-deps unsloth

In [None]:
from unsloth import FastModel
import torch

# Options for loading the base checkpoint, gathered in one place so they are
# easy to tweak.
base_model_options = {
    "model_name": "unsloth/gemma-3-1b-it",
    # Choose any for long context!
    "max_seq_length": 2048,
    # This is only a 1B model, so heavy quantization is not needed:
    # 4-bit is off and 8-bit is on.
    "load_in_4bit": False,
    "load_in_8bit": True,
    "full_finetuning": False,
}

model, tokenizer = FastModel.from_pretrained(**base_model_options)

In [None]:
# Attach LoRA adapters so that only a small number of parameters is updated
# during training.
lora_options = {
    # Text-only finetune: leave the vision layers alone.
    "finetune_vision_layers": False,
    # Language layers, attention modules and MLP modules are all adapted.
    "finetune_language_layers": True,
    "finetune_attention_modules": True,
    "finetune_mlp_modules": True,
    # Larger r = higher accuracy, but might overfit.
    "r": 8,
    # Recommended: alpha at least equal to r.
    "lora_alpha": 8,
    "lora_dropout": 0,
    "bias": "none",
    "random_state": 3407,
}

model = FastModel.get_peft_model(model, **lora_options)

In [None]:
from unsloth.chat_templates import get_chat_template

# Rebind `tokenizer` to the result of get_chat_template for the "gemma-3"
# chat template.
tokenizer = get_chat_template(tokenizer, chat_template="gemma-3")

In [None]:
# Fetch the "train" split of the conversation dataset from Hugging Face.

import datasets

dataset = datasets.load_dataset(
    path="jnises/llmog-conversations",
    split="train",
)

In [None]:
# Pass the dataset through Unsloth's standardize_data_formats, which tries to
# convert it to the format expected for finetuning.
# Note: `dataset` is rebound to the result, so this cell is not a pure
# function of the freshly loaded dataset when re-run on its own.

from unsloth.chat_templates import standardize_data_formats

dataset = standardize_data_formats(dataset)

In [None]:
# We now have to apply the chat template for Gemma-3 onto the conversations, and save it to text.
# We remove the <bos> token using removeprefix('<bos>') since we're finetuning.
# The Processor will add this token before training and the model expects only one.


def formatting_prompts_func(examples):
    """Render every conversation of a batch into a chat-template string.

    Args:
        examples: Batch mapping whose "conversations" entry is an iterable of
            conversations accepted by the global tokenizer's
            ``apply_chat_template``.

    Returns:
        A dict with a single "text" key holding one rendered string per
        conversation, in input order. A leading "<bos>" is stripped from each
        string.
    """
    rendered = []
    for conversation in examples["conversations"]:
        prompt = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=False
        )
        # Drop the template's leading <bos> so the token is not duplicated later.
        rendered.append(prompt.removeprefix("<bos>"))
    return {"text": rendered}


# Run formatting_prompts_func over the dataset in batched mode. The "text"
# entry it returns is the field the trainer config refers to via
# dataset_text_field="text".
dataset = dataset.map(formatting_prompts_func, batched=True)

In [None]:
from trl import SFTConfig, SFTTrainer

# Training hyperparameters, kept apart from the trainer wiring below.
training_args = SFTConfig(
    dataset_text_field="text",
    dataset_num_proc=2,
    # Batch of 2 per device with 4 accumulation steps (GA mimics a larger batch).
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    # One full pass over the data. For a quicker run, replace this line
    # with `max_steps=30`.
    num_train_epochs=1,
    # Optimizer and schedule. Reduce the learning rate to 2e-5 for long runs.
    optim="adamw_8bit",
    learning_rate=2e-4,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    warmup_steps=5,
    seed=3407,
    logging_steps=1,
    # Set this to e.g. WandB to enable experiment tracking.
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    # No evaluation set is configured; one can be supplied here.
    eval_dataset=None,
    args=training_args,
)

In [None]:
# Wrap the trainer with Unsloth's train_on_responses_only so that the loss is
# computed on the assistant outputs only and the user's inputs are ignored.
# This helps increase accuracy of finetunes.

from unsloth.chat_templates import train_on_responses_only

# Markers passed as the start of a user turn and of a model turn.
user_turn_marker = "<start_of_turn>user\n"
model_turn_marker = "<start_of_turn>model\n"

trainer = train_on_responses_only(
    trainer,
    instruction_part=user_turn_marker,
    response_part=model_turn_marker,
)

In [None]:
# @title Show current memory stats
# Record the GPU's total memory and the peak reserved memory before training.
# The final-stats cell reads start_gpu_memory and max_memory.
gpu_stats = torch.cuda.get_device_properties(0)
reserved_bytes_before = torch.cuda.max_memory_reserved()
total_bytes = gpu_stats.total_memory

# Bytes -> GB (three divisions by 1024), rounded to three decimals.
start_gpu_memory = round(reserved_bytes_before / 1024 / 1024 / 1024, 3)
max_memory = round(total_bytes / 1024 / 1024 / 1024, 3)

print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Run the training loop. The return value is kept in trainer_stats; the
# final-stats cell reads trainer_stats.metrics['train_runtime'] from it.
# To resume a training run, call trainer.train(resume_from_checkpoint=True) instead.

trainer_stats = trainer.train()

In [None]:
# @title Show final memory and time stats
# Compare the peak reserved GPU memory after training with the values captured
# before training (start_gpu_memory, max_memory).
peak_reserved_bytes = torch.cuda.max_memory_reserved()
used_memory = round(peak_reserved_bytes / 1024 / 1024 / 1024, 3)
used_percentage = round(used_memory / max_memory * 100, 3)

# Memory attributed to training = peak after training minus the starting value.
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

# Training time, in seconds and in minutes.
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime'] / 60, 2)} minutes used for training.")

# Memory, in absolute terms and relative to the GPU's total memory.
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

In [None]:
# Test the trained model.
# According to the Gemma-3 team, the recommended settings for inference are temperature = 1.0, top_p = 0.95, top_k = 64

# Flip to True to run a single manual inference check after training.
# It is off by default so that "Run All" goes straight from training to upload.
RUN_INFERENCE_CHECK = False

if RUN_INFERENCE_CHECK:
    from unsloth.chat_templates import get_chat_template

    tokenizer = get_chat_template(
        tokenizer,
        chat_template="gemma-3",
    )
    # One system turn with the log-rating instructions, followed by one user
    # turn containing the log lines to rate.
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": """You are a developer log analyzer.
  Given a sequence of log lines. Rate only the last line. Use the prior lines only for context.
  If a prior line looks unrelated to the last one, disregard it.
  Rate the last line by how interesting you think it is for diagnosing an issue with the system.
  Output EXACTLY in this format:
  ```
  Very brief single-sentence analysis on a single line
  SCORE: 0-100
  ```

  Do NOT include any code examples, snippets, or additional explanations.
  Keep responses strictly limited to the analysis and score.
  Do NOT include any additional framing such as ````.
  Do NOT start the analysis with "The last line" or similar redundant information.

  Score guide:
  Low (0-30): Routine/minor info
  Medium (31-70): Noteworthy/important
  High (71-100): Critical/security issues
  """,
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": """[2025-02-25 08:07:33] [INFO] Transaction started - xid=a7f392e1
  [2025-02-25 08:07:33] [WARNING] Slow database query detected - duration=1530ms - query="SELECT * FROM orders JOIN order_items WHERE customer_id = ?"
  [2025-02-25 08:07:33] [INFO] Transaction committed - xid=a7f392e1
  """,
                }
            ],
        },
    ]
    # Render the prompt as a string first; it is tokenized separately below.
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # Must add for generation
        # this seems to be needed for the 1B model
        tokenize=False,
    )
    outputs = model.generate(
        **tokenizer([text], return_tensors="pt").to("cuda"),
        max_new_tokens=64,  # Increase for longer outputs!
        # Recommended Gemma-3 settings!
        temperature=1.0,
        top_p=0.95,
        top_k=64,
    )
    # An expression inside an `if` body is not echoed as the cell's output,
    # so the decoded generations are printed explicitly.
    for decoded in tokenizer.batch_decode(outputs):
        print(decoded)

In [None]:
# Upload the unquantized, merged model together with the tokenizer to the
# Hugging Face repo named by HF_REPO, authenticating with HF_TOKEN.

model_name = "gemma-3-1b-llmog"
HF_REPO = f"jnises/{model_name}"

model.push_to_hub_merged(HF_REPO, tokenizer, token=HF_TOKEN)