In [18]:
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
import pickle, json, re, gc
from datasets import load_dataset, Dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

In [2]:
# Pick the device first: querying the CUDA compute capability on a machine
# without a GPU raises, so the dtype decision must be guarded by it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Compute capability major version >= 8 -> use bfloat16, otherwise float16.
    torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    # CPU fallback: stay in full precision.
    torch_dtype = torch.float32

torch_dtype, device

(torch.bfloat16, device(type='cuda'))

In [4]:
# Local directory holding the base checkpoint and its tokenizer files.
model_name = "../../models/gemma-3-4b-it/"

# Keyword arguments forwarded to `from_pretrained` when loading the model.
model_kwargs = {
    "attn_implementation": "eager",  # alternative: "flash_attention_2"
    "torch_dtype": torch_dtype,      # alternative: "auto"
    "device_map": "auto",
}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = Gemma3ForConditionalGeneration.from_pretrained(model_name, **model_kwargs)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Sanity check: number of tokens the loaded tokenizer reports.
len(tokenizer)

262145

#### Data Loading

In [6]:
# NOTE(security): pickle.load can execute arbitrary code while deserialising —
# only open pickle files that come from a trusted source.
with open("../data/xlam-function-calling-60k-train_data.pkl", "rb") as f:
    train_data = pickle.load(f)

with open("../data/xlam-function-calling-60k-validation_data.pkl", "rb") as f:
    validation_data = pickle.load(f)


# The loaded objects are passed straight to Dataset.from_list — presumably each
# is a list of dicts with id/query/answers/tools keys; confirm against the pickles.
train_dataset = Dataset.from_list(train_data)
validation_dataset = Dataset.from_list(validation_data)


# Work on fixed slices: train rows 2000-2499 (500 rows), validation rows 100-249 (150 rows).
# `train_dataset` is rebound to its slice, while the validation slice gets the
# new name `valid_dataset` (the full set stays in `validation_dataset`).
train_dataset = train_dataset.select(range(2000, 2500))
valid_dataset = validation_dataset.select(range(100, 250))
train_dataset, valid_dataset

(Dataset({
     features: ['id', 'query', 'answers', 'tools'],
     num_rows: 500
 }),
 Dataset({
     features: ['id', 'query', 'answers', 'tools'],
     num_rows: 150
 }))

### Updated prompt (05-06-2025)

In [7]:
def preprocess_structured(sample):
    """Render one function-calling record as a single chat-formatted string.

    Args:
        sample: Mapping with a "query" string plus "tools" and "answers"
            fields, both JSON-encoded strings. "tools" decodes to a list of
            tool specs (each with "name", "description" and optionally
            "parameters"); "answers" decodes to a list of call objects.

    Returns:
        dict with a single "text" key: the system/user/assistant conversation
        rendered by the module-level `tokenizer`'s chat template, without a
        trailing generation prompt.

    Raises:
        Whatever decoding raised (e.g. json.JSONDecodeError, KeyError), after
        printing the offending sample.
    """
    try:
        tools = json.loads(sample["tools"])
        answers = json.loads(sample["answers"])
        user_query = sample["query"]
    except Exception:
        print("Error decoding JSON:", sample)
        # Bare raise keeps the original traceback intact.
        raise

    system_prompt = "You are a helpful assistant that can call functions to help answer user queries. When you need to use a tool, format your response with <function_call> tags containing valid JSON."

    # Pretty-print each tool spec. ensure_ascii=False keeps non-ASCII text
    # verbatim, matching how the answers are serialised below; otherwise the
    # prompt shows \uXXXX escapes while the target shows the real characters.
    tools_formatted = [
        json.dumps(
            {
                "name": tool["name"],
                "description": tool["description"],
                "parameters": tool.get("parameters", {}),
            },
            indent=2,
            ensure_ascii=False,
        )
        for tool in tools
    ]
    tools_text = "Available tools:\n" + "\n\n".join(tools_formatted)

    # One compact <function_call> block per expected call, newline-separated.
    assistant_reply = "\n".join(
        f"<function_call>\n{json.dumps(answer, ensure_ascii=False, separators=(',', ':'))}\n</function_call>"
        for answer in answers
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{tools_text}\n\nUser query: {user_query}"},
        {"role": "assistant", "content": assistant_reply},
    ]

    return {
        "text": tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
    }

In [8]:
# Raw columns are dropped after mapping so only the rendered "text" field remains.
raw_columns = ["id", "query", "answers", "tools"]

dataset_train, dataset_validation = (
    split.map(preprocess_structured, remove_columns=raw_columns)
    for split in (train_dataset, valid_dataset)
)
dataset_train, dataset_validation

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

(Dataset({
     features: ['text'],
     num_rows: 500
 }),
 Dataset({
     features: ['text'],
     num_rows: 150
 }))

In [9]:
# Show the first rendered training example to eyeball the prompt format.
print(dataset_train[0]["text"])

<bos><start_of_turn>user
You are a helpful assistant that can call functions to help answer user queries. When you need to use a tool, format your response with <function_call> tags containing valid JSON.

Available tools:
{
  "name": "get_joke_of_the_day_by_category",
  "description": "Fetches the joke of the day from a specified category using the World of Jokes API.",
  "parameters": {
    "category": {
      "description": "The category of joke to be fetched.",
      "type": "str",
      "default": "Money"
    }
  }
}

User query: Fetch the joke of the day from the 'nerdy' category.<end_of_turn>
<start_of_turn>model
<function_call>
{"name":"get_joke_of_the_day_by_category","arguments":{"category":"nerdy"}}
</function_call><end_of_turn>



In [12]:
# LoRA adapter settings handed to SFTTrainer. Values in trailing comments are
# alternatives noted by the author.
peft_config = LoraConfig(
    lora_alpha=32,  # alternative: 16
    lora_dropout=0.1,  # alternative: 0.05
    r=32,  # alternative: 16
    bias="none",
    # target_modules="all-linear",
    # Explicit list of projection module names to adapt, instead of "all-linear".
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    # modules_to_save=["lm_head", "embed_tokens"]  # only needed when training newly added special tokens; none are added here
)

In [16]:
# Checkpoints, the final adapter and the TensorBoard logs all live under this directory.
output_dir = "/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1"

training_arguments = SFTConfig(
    output_dir=output_dir,
    logging_dir=f"{output_dir}/logs",
    report_to="tensorboard",
    push_to_hub=False,

    # Batch size 1 with 16 accumulation steps per optimiser update.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    num_train_epochs=3,

    # Optimiser / schedule.
    learning_rate=2e-4,
    max_grad_norm=0.3,
    weight_decay=0.1,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",  # alternative: "constant"
    optim="paged_adamw_8bit",

    # Mixed precision follows the dtype the model was loaded in, rather than
    # assuming bf16 (torch_dtype is float16 on GPUs with compute capability < 8).
    bf16=torch_dtype == torch.bfloat16,
    fp16=torch_dtype == torch.float16,

    # Logging / evaluation / checkpointing.
    logging_steps=10,  # print the loss every 10 steps
    eval_strategy="steps",
    eval_steps=10,  # made explicit; matches the evaluation cadence seen in the training log
    save_strategy="steps",
    # With 500 examples, 16 accumulation steps and 3 epochs the run has ~93
    # optimiser steps, so the former save_steps=100 never wrote a checkpoint and
    # load_best_model_at_end had nothing to restore. 30 is a multiple of
    # eval_steps, as load_best_model_at_end requires.
    save_steps=30,
    save_total_limit=2,  # keep at most two checkpoints on disk
    load_best_model_at_end=True,

    # Memory savers.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    torch_compile=False,

    # Sequence handling.
    packing=False,
    max_seq_length=1024,  # NOTE(review): newer TRL releases name this `max_length` — confirm against the installed version
    dataset_kwargs={
        # The "text" column already contains the chat-template special tokens.
        "add_special_tokens": False,
        "append_concat_token": False,
    },
)

In [19]:
# Collect unreachable Python objects first so the tensors they own are released,
# then hand the now-unused cached blocks back to the CUDA driver.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

537

In [20]:
# Assemble the trainer inputs: base model, run configuration, both dataset
# splits, the tokenizer (passed as `processing_class`) and the LoRA settings.
trainer_kwargs = {
    "model": model,
    "args": training_arguments,
    "peft_config": peft_config,
    "processing_class": tokenizer,
    "train_dataset": dataset_train,
    "eval_dataset": dataset_validation,
}
trainer = SFTTrainer(**trainer_kwargs)

Converting train dataset to ChatML:   0%|          | 0/500 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/150 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/150 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/150 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/150 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [15]:
# Run fine-tuning with the configuration above.
# NOTE(review): the logged training loss (~6-22) is far above the validation
# loss (~0.46) — possibly reported un-normalised over the 16 gradient
# accumulation steps; confirm against the installed transformers/TRL versions.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss
10,21.9256,0.654367
20,9.9157,0.548147
30,8.2939,0.504273
40,6.7757,0.485462
50,7.0654,0.472899
60,6.5302,0.463355
70,6.2407,0.459858
80,5.9112,0.459946
90,5.6436,0.45909


TrainOutput(global_step=93, training_loss=8.610176024898406, metrics={'train_runtime': 726.0552, 'train_samples_per_second': 2.066, 'train_steps_per_second': 0.128, 'total_flos': 1.794163375706256e+16, 'train_loss': 8.610176024898406})

In [16]:
# Persist the trained model; no path is given, so the trainer's configured
# output directory is presumably used (the merge section loads the adapter from there).
trainer.save_model()

### Merging the LoRA adapter and saving the merged model

In [1]:
import torch
import json
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration

In [2]:
# Pick the device first: querying the CUDA compute capability on a machine
# without a GPU raises, so the dtype decision must be guarded by it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Compute capability major version >= 8 -> use bfloat16, otherwise float16.
    torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    # CPU fallback: stay in full precision.
    torch_dtype = torch.float32

torch_dtype, device

(torch.bfloat16, device(type='cuda'))

In [3]:
# Locations: the untouched base checkpoint, the trained adapter directory, and
# the destination for the merged model.
model_name = "../../models/gemma-3-4b-it"
peft_model_id = "/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1/"
save_folder = "/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1-merged"

# The tokenizer is read from the adapter directory; the weights come from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,  # alternative: "auto"
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Attach the trained LoRA adapter to the base model, then fold the adapter
# weights into it to obtain a standalone model.
# NOTE(review): `model` is rebound here and merging presumably modifies the base
# weights in place, so re-running this cell without reloading the base model
# may apply the adapter twice — confirm before re-executing.
model = PeftModel.from_pretrained(model, peft_model_id)
merged_model = model.merge_and_unload()

In [5]:
# Write the merged weights and the tokenizer files into the same folder so it
# can be loaded as a self-contained checkpoint.
merged_model.save_pretrained(save_folder)
tokenizer.save_pretrained(save_folder)

('/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1-merged/tokenizer_config.json',
 '/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1-merged/special_tokens_map.json',
 '/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1-merged/tokenizer.model',
 '/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1-merged/added_tokens.json',
 '/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1-merged/tokenizer.json')

In [10]:
# Inspect the configuration carried by the merged model.
merged_model.config

Gemma3Config {
  "architectures": [
    "Gemma3ForConditionalGeneration"
  ],
  "boi_token_index": 255999,
  "eoi_token_index": 256000,
  "eos_token_id": [
    1,
    106
  ],
  "image_token_index": 262144,
  "initializer_range": 0.02,
  "mm_tokens_per_image": 256,
  "model_type": "gemma3",
  "text_config": {
    "attention_bias": false,
    "attention_dropout": 0.0,
    "attn_logit_softcapping": null,
    "cache_implementation": "hybrid",
    "final_logit_softcapping": null,
    "head_dim": 256,
    "hidden_activation": "gelu_pytorch_tanh",
    "hidden_size": 2560,
    "initializer_range": 0.02,
    "intermediate_size": 10240,
    "max_position_embeddings": 131072,
    "model_type": "gemma3_text",
    "num_attention_heads": 8,
    "num_hidden_layers": 34,
    "num_key_value_heads": 4,
    "query_pre_attn_scalar": 256,
    "rms_norm_eps": 1e-06,
    "rope_local_base_freq": 10000.0,
    "rope_scaling": {
      "factor": 8.0,
      "rope_type": "linear"
    },
    "rope_theta": 1000000.0

In [7]:
# Explicitly write the merged model's config to config.json in the save folder.
# NOTE(review): save_pretrained above presumably already wrote this file — confirm whether this step is still needed.
merged_model.config.to_json_file(f"{save_folder}/config.json")

In [8]:
# Reload the adapter configuration saved with the fine-tuned adapter to
# double-check the LoRA hyperparameters that were actually used.
peft_config = PeftConfig.from_pretrained(peft_model_id)
peft_config

LoraConfig(task_type='CAUSAL_LM', peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='../../models/gemma-3-4b-it/', revision=None, inference_mode=True, r=32, target_modules={'gate_proj', 'k_proj', 'up_proj', 'v_proj', 'q_proj', 'o_proj', 'down_proj'}, exclude_modules=None, lora_alpha=32, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', trainable_token_indices=None, loftq_config={}, eva_config=None, corda_config=None, use_dora=False, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False), lora_bias=False)

In [9]:
# Imported locally in this cell; consider moving it next to the other imports.
from transformers import AutoProcessor

# Copy the base model's processor into the merged folder.
# NOTE(review): the processor is loaded from the base checkpoint, not the adapter
# directory; saving it may rewrite the tokenizer files stored earlier — presumably
# identical since no tokens were added, but confirm.
processor = AutoProcessor.from_pretrained(model_name)
processor.save_pretrained(save_folder)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


['/mnt/data1/mani/finetuned_models/gemma-3-4b-it-function-calling-V1-merged/processor_config.json']