# Fine-tuning gemma 3 for Tool Use

Code authored by: Shaw Talebi

### imports

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import get_peft_model, LoraConfig, TaskType
from trl import SFTConfig, SFTTrainer
import numpy as np

from dotenv import load_dotenv
load_dotenv()

True

### load data

In [2]:
# load dataset
ds = load_dataset("shawhin/tool-use-finetuning")
ds

DatasetDict({
    train: Dataset({
        features: ['query', 'query_type', 'trace', 'num_tools_available', 'tool_needed', 'tool_name'],
        num_rows: 477
    })
    validation: Dataset({
        features: ['query', 'query_type', 'trace', 'num_tools_available', 'tool_needed', 'tool_name'],
        num_rows: 60
    })
    test: Dataset({
        features: ['query', 'query_type', 'trace', 'num_tools_available', 'tool_needed', 'tool_name'],
        num_rows: 60
    })
})

In [3]:
# # filter out "hard" queries
# ds = ds.filter(lambda x: x['query_type'] != "hard")

In [4]:
# # only use "easy" queries
# ds = ds.filter(lambda x: x['query_type'] == "easy")

In [5]:
import numpy as np

# Set random seed for reproducibility
np.random.seed(42)

def filter_dataset(example):
    # Keep all "easy" queries that need tools
    if example['query_type'] == 'easy' and example['tool_needed'] == True:
        return True
    
    # Keep 20% of "no_tool" queries
    if example['query_type'] == 'no_tool':
        return np.random.random() < 0.2
    
    # Exclude everything else
    return False

# Apply the filtering
ds = ds.filter(filter_dataset)

Filter:   0%|          | 0/477 [00:00<?, ? examples/s]

Filter:   0%|          | 0/60 [00:00<?, ? examples/s]

Filter:   0%|          | 0/60 [00:00<?, ? examples/s]

### load model

In [6]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    print(f"MPS is available: {mps_device}")
else:
    print("MPS is not available. Please check your macOS version, PyTorch installation, and hardware.")

MPS is available: mps


In [7]:
# load model
model_name = "google/gemma-3-1b-it"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="mps",
    attn_implementation='eager'
)

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

### preprocess data

In [8]:
def preprocess(row):
    # replace first user message role to system
    messages = row['trace']
    messages[0]['role'] = 'system'

    # add tokenized text to dataset
    return {
        "text": tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False, return_tensors="pt")
    }

In [9]:
ds = ds.map(preprocess)

Map:   0%|          | 0/196 [00:00<?, ? examples/s]

Map:   0%|          | 0/27 [00:00<?, ? examples/s]

Map:   0%|          | 0/22 [00:00<?, ? examples/s]

### define LoRA hyperparameters

In [10]:
r = 16
lora_alpha = 32
lora_dropout = 0.05
target_modules = "all-linear"

peft_config = LoraConfig(r=r,
                         lora_alpha=lora_alpha,
                         lora_dropout=lora_dropout,
                         target_modules=target_modules,
                         bias="none",
                         task_type=TaskType.CAUSAL_LM)

In [11]:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 13,045,760 || all params: 1,012,931,712 || trainable%: 1.2879


### custom eval

In [12]:
# from transformers import EvalPrediction

# import numpy as np
# from transformers import EvalPrediction

# def compute_tool_calling_metric(eval_pred: EvalPrediction):
#     """Fixed metric for SFTTrainer"""
#     predictions = eval_pred.predictions
#     labels = eval_pred.label_ids
    
#     print("ðŸ”¥ METRIC FUNCTION CALLED! ðŸ”¥")
#     print(f"Predictions shape: {predictions.shape}")
#     print(f"Labels shape: {labels.shape}")
    
#     # Convert logits to token IDs if needed
#     if len(predictions.shape) == 3:  # [batch, seq_len, vocab_size] - these are logits
#         predictions = np.argmax(predictions, axis=-1)
    
#     # Handle different prediction formats
#     if isinstance(predictions, tuple):
#         predictions = predictions[0]
    
#     # Replace -100 with pad_token_id for proper decoding
#     labels_for_decode = np.where(labels != -100, labels, tokenizer.pad_token_id)
    
#     # Ensure predictions and labels are the same length
#     min_len = min(predictions.shape[1], labels_for_decode.shape[1])
#     predictions = predictions[:, :min_len]
#     labels_for_decode = labels_for_decode[:, :min_len]
    
#     try:
#         # Decode predictions and labels
#         decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
#         decoded_labels = tokenizer.batch_decode(labels_for_decode, skip_special_tokens=True)
        
#         print(f"Example pred: {decoded_preds[0][:100]}...")
#         print(f"Example label: {decoded_labels[0][:100]}...")
        
#         scores = []
#         for pred, label in zip(decoded_preds, decoded_labels):
#             # Check if tool was expected
#             tool_expected = "<tool_call>" in label
#             # Check if model called a tool
#             model_called_tool = "<tool_call>" in pred
            
#             # Score: 1 if model correctly called/didn't call tool, 0 otherwise
#             if tool_expected == model_called_tool:
#                 scores.append(1.0)
#             else:
#                 scores.append(0.0)
        
#         result = {"tool_called_when_needed": np.mean(scores) * 100}
#         print(f"Metric result: {result}")
#         return result
        
#     except Exception as e:
#         print(f"Error in decoding: {e}")
#         print(f"Predictions type: {type(predictions)}")
#         print(f"Labels type: {type(labels_for_decode)}")
#         return {"tool_called_when_needed": 0.0}

### define training hyperparameters

In [13]:
# hyperparameters
lr = 2e-4
num_epochs = 3
batch_size = 1
finetuned_model_name = "gemma-3-1b-tool-use"

# define training arguments
training_args = SFTConfig(
    output_dir=f"models/{finetuned_model_name}",
    num_train_epochs=num_epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=8,
    warmup_ratio = 0.03,
    max_grad_norm = 0.3,
    eval_strategy="steps",
    save_strategy="steps",
    logging_steps=20,
    eval_steps=20,
    save_steps=20,
    load_best_model_at_end=True,
    bf16=False,
    fp16=False,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

### fine-tune model

In [14]:
%%time
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],#.select(range(100)),
    eval_dataset=ds["validation"],#.select(range(6)),
    processing_class=tokenizer,
    peft_config=peft_config,
)
trainer.train()

Adding EOS to train dataset:   0%|          | 0/196 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/196 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/196 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/27 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss,Validation Loss
20,0.7132,0.223053
40,0.1506,0.135271
60,0.0973,0.122439




CPU times: user 3min 24s, sys: 2min 32s, total: 5min 57s
Wall time: 9min 57s


TrainOutput(global_step=75, training_loss=0.26859559933344523, metrics={'train_runtime': 596.8328, 'train_samples_per_second': 0.985, 'train_steps_per_second': 0.126, 'total_flos': 2295453414240000.0, 'train_loss': 0.26859559933344523})

### push to hub

In [15]:
# push to hub
username = "shawhin"
trainer.push_to_hub(f"{username}/{finetuned_model_name}")

training_args.bin:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/52.2M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/shawhin/gemma-3-1b-tool-use/commit/c05fb405b2a0985f86f9467f09d2725a91a7d109', commit_message='shawhin/gemma-3-1b-tool-use', commit_description='', oid='c05fb405b2a0985f86f9467f09d2725a91a7d109', pr_url=None, repo_url=RepoUrl('https://huggingface.co/shawhin/gemma-3-1b-tool-use', endpoint='https://huggingface.co', repo_type='model', repo_id='shawhin/gemma-3-1b-tool-use'), pr_revision=None, pr_num=None)