In [None]:
# Title: Fine-tuning Gemma-3 for Function Calling Tasks
# Environment: Python 3.8+, Ubuntu 20.04, CUDA 11.3+
# Libraries: transformers, datasets, trl, peft, huggingface_hub, dotenv
# Description:
# This script fine-tunes Google's Gemma-3 language model for function calling tasks using LoRA and HuggingFace.
# It loads a function-calling dataset, preprocesses it into a chat format, and trains the model with special tokens.
# The script also demonstrates how to push the trained model and tokenizer to the HuggingFace Hub,
# and how to load and use the fine-tuned model for inference.
#

In [None]:
!pip install "torch>=2.4.0" tensorboard flash-attn
!pip install git+https://github.com/huggingface/transformers@v4.49.0
!pip install --upgrade datasets==3.3.2 accelerate==1.4.0 evaluate==0.4.3  bitsandbytes==0.45.3 trl==0.15.2 peft==0.14.0 protobuf==3.20.3  sentencepiece

In [None]:
!pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3

In [None]:
# https://medium.com/@akriti.upadhyay/enhancing-gemma-3s-capabilities-with-fine-tuning-for-function-calling-f1bc74051abe

import torch, json, gc, os
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, BitsAndBytesConfig, set_seed
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, PeftModel, PeftConfig
from enum import Enum
from huggingface_hub import login
from google.colab import userdata

hf_token = os.getenv("hf-api")
login(token=hf_token)

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

seed = 42
set_seed(seed)

torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
model_name = "google/gemma-3-4b-it"
dataset_name = "Salesforce/xlam-function-calling-60k"

model_kwargs = dict(
    attn_implementation="flash_attention_2", # "eager", "sdpa", "flash_attention", "flash_attention_2"
    torch_dtype=torch_dtype,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=torch_dtype,
        llm_int8_enable_fp32_cpu_offload=True
    )
)

model = Gemma3ForConditionalGeneration.from_pretrained(model_name, **model_kwargs)

class ToolCallSpacialTokens(str, Enum):
    tools = "<tools>"
    eotools = "</tools>"
    think = "<think>"
    eothink = "</think>"
    tool_call="<tool_call>"
    eotool_call="</tool_call>"
    tool_response="<tool_response>"
    eotool_response="</tool_response>"
    pad_token = "<pad>"
    eos_token = "<eos>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    pad_token=ToolCallSpacialTokens.pad_token.value,
    additional_special_tokens=ToolCallSpacialTokens.list()
)

tokenizer.chat_template = """{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] | trim + '<end_of_turn><eos>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"""

model.resize_token_embeddings(len(tokenizer))
model.to(device)


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Gemma3ForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(4096, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipAttention(
              (k_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear4b

In [None]:
def query_gemma(prompt, model, tokenizer, max_new_tokens=256, temperature=0.01, top_p=0.95, repetition_penalty=1.1):
    """
    Queries the fine-tuned Gemma model with a given prompt.

    Args:
        prompt (str): The input prompt for the model.
        model: The loaded and fine-tuned Gemma model.
        tokenizer: The tokenizer for the model.
        max_new_tokens (int): Maximum number of new tokens to generate.
        temperature (float): Controls the randomness of the generation.
        top_p (float): The cumulative probability for top-p sampling.
        repetition_penalty (float): Penalty for repeating tokens.

    Returns:
        str: The decoded response from the model.
    """
    # Keep inputs on CPU initially
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

    # Move inputs to model's device right before generation
    outputs = model.generate(
        **inputs.to(model.device), # Move inputs to GPU here
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return response

# Example usage of the function
prompt = """<bos><start_of_turn>user
You have access to the following tools:

- numerical_derivative: Estimate the derivative of a mathematical function

User query:
I need to estimate the derivative of the function y = sin(x) at x = π/4 and x = π. Can you help with that?<end_of_turn><eos>
<start_of_turn>assistant
"""

response = query_gemma(prompt, model, tokenizer)
print(response)

<bos><start_of_turn>user
You have access to the following tools:

- numerical_derivative: Estimate the derivative of a mathematical function

User query:
I need to estimate the derivative of the function y = sin(x) at x = π/4 and x = π. Can you help with that?<end_of_turn><eos>
<start_of_turn>assistant
```tool_code
print(numerical_derivative(func='sin(x)', x=pi/4, h=0.001))
print(numerical_derivative(func='sin(x)', x=pi, h=0.001))
```<end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><en

In [None]:

def preprocess(sample):
    try:
        tools = json.loads(sample["tools"])
        answers = json.loads(sample["answers"])
        user_query = sample["query"]
    except Exception as e:
        print("Error decoding JSON:", sample)
        raise e

    messages = [
        {
            "role": "user",
            "content": (
                "You have access to the following tools:\n\n"
                + "\n\n".join(f"- {tool['name']}: {tool['description']}" for tool in tools)
                + "\n\nUser query:\n" + user_query
            )
        },
        {
            "role": "assistant",
            "content": "\n".join(
                f"<function_call>\n{json.dumps(answer)}\n</function_call>"
                for answer in answers
            )
        }
    ]

    return {
        "text": tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    }

dataset = load_dataset(dataset_name)
dataset = dataset["train"].map(preprocess, remove_columns=["id", "query", "answers", "tools"])
dataset = dataset.train_test_split(0.1)
print(dataset)

print(dataset["train"][19]["text"])

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)

training_arguments = SFTConfig(
    output_dir="gemma-3-4b-it-thinking-function_calling-V0",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=32,
    save_strategy="epoch",
    eval_strategy="epoch",
    logging_steps=50,
    learning_rate=3e-4,
    max_grad_norm=0.3,
    weight_decay=0.1,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    report_to=None,
    bf16=True,
    optim="paged_adamw_8bit",
    torch_compile=False,
    push_to_hub=False,
    num_train_epochs=3,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    packing=False,
    max_seq_length=512,
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True,
    }
)

torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    # tokenizer=tokenizer,
    peft_config=peft_config,
) # https://huggingface.co/docs/trl/en/sft_trainer

trainer.train()
trainer.save_model()

trainer.push_to_hub(f'mac999/gemma-3-4b-it-thinking-function_calling-V0-{seed}', commit_message="Pushing fine-tuned model with function calling capabilities")

tokenizer.eos_token = "<eos>"
tokenizer.push_to_hub(f"mac999/", token=True)

peft_model_id = f"mac999/gemma-3-4b-it-thinking-function_calling-V0-{seed}"
device = "auto"
config = PeftConfig.from_pretrained(peft_model_id)
model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it",
                                             device_map="auto",
                                             )
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(model, peft_model_id)
model.to(torch.bfloat16)
model.eval()


response = query_gemma(prompt, model, tokenizer)
print(response)

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

xlam_function_calling_60k.json:   0%|          | 0.00/96.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Map:   0%|          | 0/60000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 54000
    })
    test: Dataset({
        features: ['text'],
        num_rows: 6000
    })
})
<bos><start_of_turn>user
You have access to the following tools:

- max_points_on_line: Finds the maximum number of points that lie on the same straight line.

- solve_quadratic: Computes the roots of a quadratic equation given its coefficients.

- greatest_common_divisor: Computes the greatest common divisor (GCD) of two non-negative integers.

- chi_square_independence_test: Performs a Chi-Square test for independence on a 2x2 contingency table.

- can_attend_all_meetings: Determines if a person can attend all meetings given a list of meeting time intervals.

- binomial_probability: Calculates the probability of getting exactly k successes in n independent trials,

User query:
What are the roots of the quadratic equation 2x^2 - 3x + 1 = 0?<end_of_turn><eos>
<start_of_turn>assistant
<function_call>
{"name": "solv

Converting train dataset to ChatML:   0%|          | 0/54000 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/54000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/54000 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/54000 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/6000 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/6000 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/6000 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/6000 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmac999[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch,Training Loss,Validation Loss
