# Gemma 3 fine-tuning using Hugging Face

In this notebook, we fine-tune Gemma 3 ([google/gemma-3-1b-it](https://huggingface.co/google/gemma-3-1b-it)) on the [Natural Language to Regex dataset](https://huggingface.co/datasets/inclinedadarsh/nl-to-regex).

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## Installing the required libraries

In [2]:
%pip install "torch>2.3.0" wandb

%pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

%pip install -U datasets accelerate evaluate bitsandbytes peft trl

Note: you may need to restart the kernel to use updated packages.
Collecting git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
  Cloning https://github.com/huggingface/transformers (to revision v4.49.0-Gemma-3) to /tmp/pip-req-build-m7_2f53s
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-m7_2f53s
  Running command git checkout -q 367bab469b0ef32017e2a0a0a5dbac5d36002f03
  Resolved https://github.com/huggingface/transformers to commit 367bab469b0ef32017e2a0a0a5dbac5d36002f03
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... done
  Created wheel for transformers: filename=transformers-4.50.0.dev0-py3-none-any.whl size=10936468 sha256=54103165a8b1c12ea09241ac79cf7d20c9
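After installing (and possibly restarting the kernel), it can help to confirm which versions were actually picked up. The snippet below is a minimal sketch using only the standard library; the helper name `installed_versions` is made up for illustration.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages installed in the cell above.
PACKAGES = ["torch", "transformers", "datasets", "accelerate",
            "evaluate", "bitsandbytes", "peft", "trl", "wandb"]


def installed_versions(names):
    """Return {package: version string, or None if the package is missing}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:>14}: {ver or 'not installed'}")
```

Since `transformers` is built from a Git tag here, its reported version is a development version (the wheel above is named `4.50.0.dev0`).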

## Setting up the necessary configuration

In [3]:
# Loading the secrets from kaggle

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN") # Make sure the HF_TOKEN has read and write access (write access to push the model to hub)
WANDB_KEY = user_secrets.get_secret("WANDB_KEY")

In [4]:
from huggingface_hub import login
login(HF_TOKEN)

In [6]:
import wandb
wandb.login(key=WANDB_KEY)
wandb.init(project="gemma-3-finetune", name="second-run")
# Initialize the wandb project here

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## Getting the dataset ready

Let's get the dataset ready using the `load_dataset` function from the `datasets` library. We'll load the `inclinedadarsh/nl-to-regex` dataset.

In [7]:
from datasets import load_dataset

In [8]:
dataset = load_dataset('inclinedadarsh/nl-to-regex')

README.md:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

dataset.csv:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/824 [00:00<?, ? examples/s]

In [13]:
system_message = "You are a helpful natural language to regex converter. The user will provide some prompt, and you have to create a regex according to it."

def format_example(example):
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": example['user']},
            {"role": "assistant", "content": example['assistant']}
        ]
    }

dataset = dataset.map(format_example, remove_columns=dataset['train'].features, batched=False)

Map:   0%|          | 0/824 [00:00<?, ? examples/s]
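To see what `format_example` produces, here is a small standalone sketch that applies it to a single row. The row is made up for illustration and is not taken from the dataset; only the `user` and `assistant` column names come from the cell above.

```python
system_message = (
    "You are a helpful natural language to regex converter. "
    "The user will provide some prompt, and you have to create a regex according to it."
)


def format_example(example):
    # Same conversion as in the notebook: one row becomes a three-message chat.
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": example["user"]},
            {"role": "assistant", "content": example["assistant"]},
        ]
    }


# Hypothetical row, invented for this example.
row = {"user": "lines that start with a digit", "assistant": r"^\d.*"}

formatted = format_example(row)
for message in formatted["messages"]:
    print(f"{message['role']:>9}: {message['content']}")
```

Each training example ends up as a system/user/assistant conversation, which is the `messages` layout that `SFTTrainer` converts with the chat template later on.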

In [15]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [16]:
model_name = 'google/gemma-3-1b-it'

In [17]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_storage=torch.float16
)
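Why bother with 4-bit loading? A rough back-of-the-envelope estimate of the weight memory is sketched below. It assumes about one billion parameters (suggested by the "1b" in the model name) and ignores quantization constants, any layers that stay in higher precision, activations, and optimizer state, so treat the numbers as ballpark figures only.

```python
# Assumption: roughly one billion parameters, as the "1b" in the model name suggests.
approx_params = 1_000_000_000


def weight_gigabytes(num_params, bits_per_param):
    """Storage needed for the weights alone, in GB (10**9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


fp16_gb = weight_gigabytes(approx_params, 16)  # 16-bit weights
nf4_gb = weight_gigabytes(approx_params, 4)    # 4-bit (NF4) weights

print(f"float16 weights: ~{fp16_gb:.1f} GB")
print(f"4-bit weights:   ~{nf4_gb:.1f} GB")
```

The 16-bit estimate lines up with the 2.00G `model.safetensors` download shown when the model is loaded.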

In [18]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation='eager',
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map='auto'
)

config.json:   0%|          | 0.00/899 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

In [19]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

## Setting up LoRA parameters for parameter-efficient fine-tuning

In [21]:
from peft import LoraConfig

In [22]:
peft_config = LoraConfig(
    lora_alpha=8,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules='all-linear',
    task_type="CAUSAL_LM",
)
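To get a feel for what `r=8` and `lora_alpha=8` mean, here is a small arithmetic sketch. For one linear layer, LoRA trains two small matrices (`r x d_in` and `d_out x r`) instead of the full weight, and the update is scaled by `lora_alpha / r`. The layer shape below is hypothetical and chosen only for illustration; it is not read from the Gemma 3 config.

```python
def lora_extra_params(d_in, d_out, r):
    """Trainable parameters LoRA adds to one linear layer: A (r x d_in) plus B (d_out x r)."""
    return r * (d_in + d_out)


r, lora_alpha = 8, 8
scaling = lora_alpha / r  # factor applied to the low-rank update

# Hypothetical layer shape, for illustration only.
d_in, d_out = 1024, 4096
full = d_in * d_out
extra = lora_extra_params(d_in, d_out, r)

print(f"scaling factor:       {scaling}")
print(f"full weight params:   {full:,}")
print(f"LoRA trainable params: {extra:,} ({100 * extra / full:.2f}% of the layer)")
```

With `target_modules='all-linear'`, adapters of this kind are attached to every linear layer, so the total trainable parameter count is the sum of this quantity over those layers.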

## Setting up the training arguments

In [23]:
from trl import SFTConfig

In [24]:
training_args = SFTConfig(
    output_dir="./gemma-finetune",
    max_seq_length=512,
    packing=True,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim='adamw_torch_fused',
    logging_steps=2,
    save_strategy='epoch',
    learning_rate=2e-4,
    fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,  # note: has no effect with the 'constant' scheduler; 'constant_with_warmup' would apply it
    lr_scheduler_type='constant',
    push_to_hub=False,
    report_to='wandb',
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True
    }
)

In [25]:
from trl import SFTTrainer

In [27]:
trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=peft_config,
    processing_class=tokenizer,
    train_dataset=dataset['train']
)



Converting train dataset to ChatML:   0%|          | 0/824 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/824 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/824 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/824 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
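The "Packing train dataset" step comes from `packing=True`: several short examples are combined into fixed-length sequences so that less of each batch is wasted on padding. Below is a simplified sketch of the idea in plain Python. It is only an illustration, not TRL's actual implementation, and the token IDs are made up.

```python
def pack_sequences(tokenized_examples, seq_length, eos_id):
    """Concatenate examples (each followed by eos_id) and cut the stream into fixed-size chunks."""
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        stream.append(eos_id)  # separator between examples
    # Keep only full chunks; the leftover tail is dropped in this sketch.
    n_full = len(stream) // seq_length
    return [stream[i * seq_length:(i + 1) * seq_length] for i in range(n_full)]


# Made-up token IDs for four short "examples".
examples = [[5, 6, 7], [8, 9], [10, 11, 12, 13], [14]]
packed = pack_sequences(examples, seq_length=4, eos_id=1)

for chunk in packed:
    print(chunk)
```

This is how 824 short examples can turn into far fewer training sequences of `max_seq_length=512` tokens each.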


## Let's train!

In [28]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


| Step | Training Loss |
|------|---------------|
| 2    | 4.0608        |
| 4    | 3.0744        |
| 6    | 2.6516        |
| 8    | 2.4062        |
| 10   | 2.2734        |
| 12   | 2.0792        |
| 14   | 1.7952        |
| 16   | 1.6342        |
| 18   | 1.4476        |
| 20   | 1.3281        |


TrainOutput(global_step=75, training_loss=1.0530573407808939, metrics={'train_runtime': 172.5795, 'train_samples_per_second': 1.79, 'train_steps_per_second': 0.435, 'total_flos': 642734385911808.0, 'train_loss': 1.0530573407808939})
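The numbers in `TrainOutput` can be cross-checked against the training arguments with a bit of arithmetic. The sketch below assumes a single GPU doing the training; the per-epoch sequence count is a rough estimate derived from the step count, not a value reported by the trainer.

```python
# Values taken from the training arguments and the TrainOutput above.
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_train_epochs = 3
global_step = 75
train_runtime = 172.5795  # seconds

# Assumption: a single GPU, so the effective batch is batch size x accumulation steps.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps
steps_per_epoch = global_step / num_train_epochs
approx_sequences_per_epoch = steps_per_epoch * effective_batch
steps_per_second = global_step / train_runtime

print(f"effective batch size:        {effective_batch}")
print(f"optimizer steps per epoch:   {steps_per_epoch:.0f}")
print(f"approx. packed sequences per epoch: {approx_sequences_per_epoch:.0f}")
print(f"steps per second:            {steps_per_second:.3f}")
```

The computed steps per second agrees with the `train_steps_per_second` value in the output, and the estimate of roughly 100 packed sequences per epoch shows how much packing shrinks the 824 original examples.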

## Pushing the model to the Hugging Face Hub

In [29]:
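# Note: the first positional argument of `Trainer.push_to_hub` is the commit message, not the repo id.
# The target repo comes from `hub_model_id` in the training arguments (default: the name of `output_dir`),
# which is why the output below shows the repo `inclinedadarsh/gemma-finetune`.
trainer.push_to_hub("inclinedadarsh/gemma-3-1b-it-nl-to-regex")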

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s]

adapter_model.safetensors:   0%|          | 0.00/26.1M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.69k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/inclinedadarsh/gemma-finetune/commit/10cc112fb85693e16c7cd956a29fd46283c3811e', commit_message='inclinedadarsh/gemma-3-1b-it-nl-to-regex', commit_description='', oid='10cc112fb85693e16c7cd956a29fd46283c3811e', pr_url=None, repo_url=RepoUrl('https://huggingface.co/inclinedadarsh/gemma-finetune', endpoint='https://huggingface.co', repo_type='model', repo_id='inclinedadarsh/gemma-finetune'), pr_revision=None, pr_num=None)

## Inference from the model

### Before running inference, let's free up the memory

In [31]:
del model
del trainer
torch.cuda.empty_cache()

### Inference pipeline

In [33]:
import torch
from transformers import pipeline

In [34]:
tuned_model_name = 'inclinedadarsh/gemma-3-1b-nl-to-regex'

In [35]:
model = AutoModelForCausalLM.from_pretrained(
    tuned_model_name,
    device_map='auto',
    torch_dtype=torch.float16,
    attn_implementation='eager'
)

adapter_config.json:   0%|          | 0.00/851 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/26.1M [00:00<?, ?B/s]

In [36]:
tokenizer = AutoTokenizer.from_pretrained(tuned_model_name)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [37]:
from random import randint
import re

pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)

Device set to use cuda:0


In [41]:
rand_idx = randint(0, len(dataset['train']) - 1)  # randint is inclusive on both ends
test_sample = dataset['train'][rand_idx]

In [42]:
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample['messages'][:2], tokenize=False, add_generation_prompt=True)
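To make it clearer what `apply_chat_template` returns here, the sketch below rebuilds the prompt by hand. It is only an approximation inferred from the generated text printed at the end of this notebook (the system message is folded into the user turn), not the tokenizer's actual template; the function name `approximate_gemma_prompt` is made up.

```python
def approximate_gemma_prompt(messages):
    """Hand-written approximation of the prompt format visible in this notebook's output."""
    parts = ["<bos>"]
    pending_system = ""
    for message in messages:
        if message["role"] == "system":
            # The system text is prepended to the next user turn.
            pending_system = message["content"] + "\n\n"
            continue
        role = "model" if message["role"] == "assistant" else message["role"]
        content = message["content"]
        if role == "user" and pending_system:
            content = pending_system + content
            pending_system = ""
        parts.append(f"<start_of_turn>{role}\n{content}<end_of_turn>\n")
    # Equivalent of add_generation_prompt=True: open the model's turn.
    parts.append("<start_of_turn>model\n")
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a helpful natural language to regex converter. The user will provide some prompt, and you have to create a regex according to it."},
    {"role": "user", "content": "lines containing at least 2 words"},
]
print(approximate_gemma_prompt(messages))
```

Passing only `messages[:2]` leaves out the reference answer, so the model has to produce the assistant turn itself.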

In [43]:
# Greedy decoding: with do_sample=False the sampling settings (temperature, top_k, top_p) are ignored, so they are omitted.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, eos_token_id=stop_token_ids, disable_compile=True)



In [44]:
outputs

[{'generated_text': '<bos><start_of_turn>user\nYou are a helpful natural language to regex converter. The user will provide some prompt, and you have to create a regex according to it.\n\nlines containing at least 2 words<end_of_turn>\n<start_of_turn>model\n(.*\\b[A-Za-z]+\\b.*){2}'}]
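The generated text still contains the prompt, so the regex has to be cut out before it can be used. Here is a minimal sketch of one way to do that and to try the pattern on a few sample lines with `re`. The helper `extract_completion` and the sample lines are made up for illustration, and this is only a quick sanity check, not an evaluation of the model.

```python
import re

# The generated text from the output above, copied as a Python string.
generated_text = (
    "<bos><start_of_turn>user\n"
    "You are a helpful natural language to regex converter. The user will provide some prompt, and you have to create a regex according to it.\n\n"
    "lines containing at least 2 words<end_of_turn>\n"
    "<start_of_turn>model\n"
    r"(.*\b[A-Za-z]+\b.*){2}"
)


def extract_completion(text, marker="<start_of_turn>model\n"):
    """Return everything after the last model-turn marker."""
    return text.rsplit(marker, 1)[-1].strip()


regex = extract_completion(generated_text)
pattern = re.compile(regex)

# Made-up sample lines to try the pattern on.
for line in ["hello world", "hello", "42 cats and dogs", ""]:
    print(f"{line!r:>20} -> {bool(pattern.search(line))}")
```

A model can also emit an invalid pattern, in which case `re.compile` raises `re.error`; wrapping the call in a `try`/`except` would be a reasonable addition when checking many samples.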