# Fine-tuning Gemma 3 for Tool Use

Code authored by: Shaw Talebi

### imports

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import LoraConfig, TaskType, get_peft_model
from trl import SFTConfig, SFTTrainer

from dotenv import load_dotenv
load_dotenv()

True

### load data

In [None]:
# fetch the tool-use fine-tuning dataset (all splits) by its Hub identifier
dataset_id = "shawhin/tool-use-finetuning"
ds = load_dataset(dataset_id)
ds

DatasetDict({
    train: Dataset({
        features: ['query', 'query_type', 'trace', 'num_tools_available', 'tool_needed', 'tool_name'],
        num_rows: 477
    })
    validation: Dataset({
        features: ['query', 'query_type', 'trace', 'num_tools_available', 'tool_needed', 'tool_name'],
        num_rows: 60
    })
    test: Dataset({
        features: ['query', 'query_type', 'trace', 'num_tools_available', 'tool_needed', 'tool_name'],
        num_rows: 60
    })
})

In [None]:
import numpy as np

# Pin NumPy's global RNG so the random subsampling done while filtering
# selects the same rows on every fresh run
np.random.seed(seed=42)

def filter_dataset(example):
    """Decide whether a dataset example is kept.

    Rules:
      * ``query_type == 'no_tool'``: kept with probability 0.2, using one
        draw from NumPy's global RNG.
      * ``query_type == 'easy'`` and ``tool_needed == True``: always kept.
      * anything else: dropped.

    Args:
        example: Mapping with a ``query_type`` key and, for "easy" queries,
            a ``tool_needed`` key.

    Returns:
        bool: True to keep the example, False to drop it.
    """
    kind = example['query_type']

    # Random 20% subsample of the queries that need no tool
    if kind == 'no_tool':
        return np.random.random() < 0.2

    # Of the remaining query types only tool-requiring "easy" ones survive
    return bool(kind == 'easy' and example['tool_needed'] == True)

# Apply the filtering; ds is a DatasetDict, so each split is filtered in turn.
# NOTE(review): filter_dataset draws from NumPy's global RNG, so which
# "no_tool" rows are kept depends on the seed set earlier in this cell and on
# the order in which rows are visited.
ds = ds.filter(filter_dataset)

Filter:   0%|          | 0/477 [00:00<?, ? examples/s]

Filter:   0%|          | 0/60 [00:00<?, ? examples/s]

Filter:   0%|          | 0/60 [00:00<?, ? examples/s]

### load model

In [None]:
# Report whether PyTorch can use the MPS backend on this machine;
# mps_device is only defined when it can.
if not torch.backends.mps.is_available():
    print("MPS is not available. Please check your macOS version, PyTorch installation, and hardware.")
else:
    mps_device = torch.device("mps")
    print(f"MPS is available: {mps_device}")

MPS is available: mps


In [None]:
# load model
model_name = "google/gemma-3-1b-it"

# Choose the device at runtime instead of assuming Apple Silicon: MPS is still
# preferred when present (original behaviour), otherwise fall back to CUDA and
# finally to CPU so the notebook does not fail on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,  # place the whole model on the selected device
    attn_implementation="eager",
)

# load tokenizer (same checkpoint, so the chat template matches the model)
tokenizer = AutoTokenizer.from_pretrained(model_name)

### preprocess data

In [None]:
def preprocess(row):
    """Render one conversation trace into a single training string.

    The first message of ``row['trace']`` is relabelled with the ``system``
    role, and the whole conversation is rendered with the tokenizer's chat
    template. The input row is left unmodified.

    Args:
        row: Dataset example whose ``trace`` field is a non-empty list of
            chat messages (mappings that carry a ``role`` key).

    Returns:
        dict: ``{"text": <rendered conversation string>}``. The text is the
        templated string, not token ids.

    Raises:
        IndexError: If ``row['trace']`` is empty.
    """
    trace = row['trace']

    # Relabel the first message as the system prompt on a copy, so the
    # caller's row is not mutated as a side effect.
    messages = [{**trace[0], 'role': 'system'}, *trace[1:]]

    # tokenize=False yields the formatted string, so no tensor format is
    # requested; add_generation_prompt=False because the trace already
    # contains the turns to train on.
    return {
        "text": tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
    }

In [None]:
# Run preprocess over every split; each example gains a "text" column holding
# the chat-template-rendered conversation.
ds = ds.map(preprocess)

Map:   0%|          | 0/196 [00:00<?, ? examples/s]

Map:   0%|          | 0/27 [00:00<?, ? examples/s]

Map:   0%|          | 0/22 [00:00<?, ? examples/s]

### define LoRA hyperparameters

In [None]:
# LoRA hyperparameters
r = 16                          # rank of the low-rank update matrices
lora_alpha = 32                 # LoRA scaling factor
lora_dropout = 0.05             # dropout applied inside the LoRA layers
target_modules = "all-linear"   # PEFT shortcut for adapting the model's linear layers

# TaskType comes from peft (imported at the top of the notebook); CAUSAL_LM
# matches the AutoModelForCausalLM being fine-tuned.
peft_config = LoraConfig(
    r=r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=target_modules,
    bias="none",                # bias terms are not trained
    task_type=TaskType.CAUSAL_LM,
)

In [None]:
# Wrap the base model with the LoRA adapter described by peft_config
# (get_peft_model is imported from peft at the top of the notebook).
# From here on, `model` is a PeftModel.
model = get_peft_model(model, peft_config)

# Print how many parameters are trainable versus the total count
model.print_trainable_parameters()

trainable params: 13,045,760 || all params: 1,012,931,712 || trainable%: 1.2879


### define training hyperparameters

In [None]:
# hyperparameters
lr = 2e-4
num_epochs = 3
batch_size = 1
finetuned_model_name = "gemma-3-1b-tool-use"

# define training arguments
training_args = SFTConfig(
    # --- where checkpoints are written ---
    output_dir=f"models/{finetuned_model_name}",

    # --- optimisation ---
    learning_rate=lr,
    num_train_epochs=num_epochs,
    warmup_ratio=0.03,
    max_grad_norm=0.3,

    # --- batching: 1 example per step, gradients accumulated over 8 steps ---
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=8,

    # --- numeric precision: both mixed-precision modes are switched off ---
    bf16=False,
    fp16=False,

    # --- log, evaluate and checkpoint on the same 20-step cadence ---
    logging_steps=20,
    eval_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=20,

    # --- after training, reload the checkpoint with the lowest eval loss ---
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

### fine-tune model

In [None]:
%%time
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    processing_class=tokenizer,
    peft_config=peft_config,
)
trainer.train()

Adding EOS to train dataset:   0%|          | 0/196 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/196 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/196 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss,Validation Loss
20,0.7132,0.223053
40,0.1506,0.135271
60,0.0973,0.122439




CPU times: user 3min 24s, sys: 2min 32s, total: 5min 57s
Wall time: 9min 57s


TrainOutput(global_step=75, training_loss=0.26859559933344523, metrics={'train_runtime': 596.8328, 'train_samples_per_second': 0.985, 'train_steps_per_second': 0.126, 'total_flos': 2295453414240000.0, 'train_loss': 0.26859559933344523})

### push to hub

In [None]:
# push to hub
username = "charbull"
repo_id = f"{username}/{finetuned_model_name}"

# Trainer.push_to_hub's first positional argument is the commit message, not
# the repository id, so passing the repo id there only labels the commit. The
# destination repo is taken from args.hub_model_id, so set it explicitly.
# NOTE(review): this takes effect only if the trainer has not already
# initialised its Hub repo (e.g. by an earlier push in this session).
trainer.args.hub_model_id = repo_id
trainer.push_to_hub(commit_message=f"Upload {finetuned_model_name}")

training_args.bin:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/52.2M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/shawhin/gemma-3-1b-tool-use/commit/c05fb405b2a0985f86f9467f09d2725a91a7d109', commit_message='shawhin/gemma-3-1b-tool-use', commit_description='', oid='c05fb405b2a0985f86f9467f09d2725a91a7d109', pr_url=None, repo_url=RepoUrl('https://huggingface.co/shawhin/gemma-3-1b-tool-use', endpoint='https://huggingface.co', repo_type='model', repo_id='shawhin/gemma-3-1b-tool-use'), pr_revision=None, pr_num=None)