# Fintuning Gemma 3 QAT

Adapt a pre-trained model (in our case, the gemma-2-2b-it-4bit model from the MLX Community) to handle function-calling by using the mlx-lm package. 

This involves creating a specialized chat template, preprocessing a dataset of function call interactions, and applying LoRA for efficient fine-tuning.

https://medium.com/@levchevajoana/fine-tuning-a-model-for-function-calling-with-mlx-lm-d00d587e2559

In [49]:
import json
import os
from enum import Enum
from typing import Dict, List, Tuple, Union, Any, Dict, List, Optional


import mlx.optimizers as optim
from datasets import load_dataset
from mlx.utils import tree_flatten
from mlx_lm import generate, load
from mlx_lm.tuner import TrainingArgs, datasets, linear_to_lora_layers, train


import argparse
import json
import math
import os
import sys
import time
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten

In [2]:
# Load the model
model_path = "mlx-community/gemma-3-4b-it-qat-4bit"
model, tokenizer = load(model_path)

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

customize the tokenizer’s chat template to define the structure of our conversational interactions

This template embeds special tokens (like `<bos>`, `<start_of_turn>`, `<think>`, and `<tool_call>`) that mark the different stages of the conversation — from the user’s prompt to the model’s internal reasoning and eventual function call.

In [3]:
tokenizer.chat_template = (
    "{{ bos_token }}"
    "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}"
    "{% for message in messages %}"
    "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] | trim + '<end_of_turn><eos>\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
)

## Dataset Preparation and Preprocessing

We use the dataset Jofthomas/hermes-function-calling-thinking-V1 which contains conversations involving function calls.

In [4]:
dataset_path = "Jofthomas/hermes-function-calling-thinking-V1"

dataset = load_dataset(dataset_path)
dataset

DatasetDict({
    train: Dataset({
        features: ['conversations'],
        num_rows: 3570
    })
})

In [5]:
dataset = dataset.rename_column("conversations", "messages")

In [6]:
first = dataset['train']["messages"][0]

In [7]:
first[0]["content"]

"You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.Here are the available tools:<tools> [{'type': 'function', 'function': {'name': 'get_stock_price', 'description': 'Get the current stock price of a company', 'parameters': {'type': 'object', 'properties': {'company': {'type': 'string', 'description': 'The name of the company'}}, 'required': ['company']}}}, {'type': 'function', 'function': {'name': 'get_movie_details', 'description': 'Get details about a movie', 'parameters': {'type': 'object', 'properties': {'title': {'type': 'string', 'description': 'The title of the movie'}}, 'required': ['title']}}}] </tools>Use the following pydantic model json schema for each tool call you will make: {'title': 'FunctionCall', 'type': 'object', 'properties': {'arguments': {'title': 'Arguments', 'type': 'object'

## Preprocessing function

In [15]:
def preprocess(sample):
    messages = sample["messages"]
    first_message = messages[0]

    # Instead of adding a system message, we merge the content into the first user message
    if first_message["role"] == "system":
        system_message_content = first_message["content"]
        # Merge system content with the first user message
        messages[1]["content"] = (
            system_message_content
            + "Also, before making a call to a function take the time to plan the function to take. Make that thinking process between <think>{your thoughts}</think>\n\n"
            + messages[1]["content"]
        )
        # Remove the system message from the conversation
        messages.pop(0)

    return {"text": tokenizer.apply_chat_template(messages, tokenize=False), "1": str(len(messages))}

In [16]:
dataset = dataset.map(preprocess, remove_columns="messages")
dataset = dataset["train"].train_test_split(0.1)
dataset

Map:   0%|          | 0/3570 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', '1'],
        num_rows: 3213
    })
    test: Dataset({
        features: ['text', '1'],
        num_rows: 357
    })
})

In [17]:
dataset['train']["text"][0]

'<bos><start_of_turn>human\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.You may call one or more functions to assist with the user query. Don\'t make assumptions about what values to plug into functions.Here are the available tools:<tools> [{\'type\': \'function\', \'function\': {\'name\': \'get_movie_details\', \'description\': \'Get details of a movie based on its title\', \'parameters\': {\'type\': \'object\', \'properties\': {\'title\': {\'type\': \'string\', \'description\': \'The title of the movie\'}}, \'required\': [\'title\']}}}, {\'type\': \'function\', \'function\': {\'name\': \'calculate_tip\', \'description\': \'Calculate the amount of tip based on a bill\', \'parameters\': {\'type\': \'object\', \'properties\': {\'bill_amount\': {\'type\': \'number\', \'description\': \'The total amount of the bill\'}, \'tip_percentage\': {\'type\': \'number\', \'description\': \'The percentage of tip to be added\'}}, \'re

## Training with LoRA adapters

In [18]:
adapter_path = "adapters_fc"
os.makedirs(adapter_path, exist_ok=True)
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
adapter_file_path = os.path.join(adapter_path, "adapters.safetensors")

In [19]:
lora_config = {
    "num_layers": 8,
    "lora_parameters": {
        "rank": 16,
        "scale": 64,
        "dropout": 0.05,
    },
}
with open(adapter_config_path, "w") as f:
    json.dump(lora_config, f, indent=4)

In [20]:
training_args = TrainingArgs(
    adapter_file=adapter_file_path,
    iters=1,
    steps_per_eval=50,
)

In [21]:
# Freeze the original model parameters
_ = model.freeze()

In [22]:
# convert selected linear layers to LoRA layers to make only a small subset of parameters trainable.
linear_to_lora_layers(model, lora_config["num_layers"], lora_config["lora_parameters"])

In [23]:
num_train_params = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
print(f"Number of trainable parameters: {num_train_params}")

Number of trainable parameters: 1048576


In [24]:
# activate training mode while still preserving the frozen state of the main model parameters.
_ = model.train()

In [34]:
# configure a metrics tracker to log both training and validation losses


def loss(model, inputs, targets, lengths):
    # Run model on inputs
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    # Mask padding tokens
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    # Calculate the loss
    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    ce = ce.sum() / ntoks
    return ce, ntoks

class Metrics:
    def __init__(self) -> None:
        self.train_losses: List[Tuple[int, float]] = []
        self.val_losses: List[Tuple[int, float]] = []

    def on_train_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
        self.train_losses.append((info["iteration"], info["train_loss"]))

    def on_val_loss_report(self, info: Dict[str, Union[float, int]]) -> None:
        self.val_losses.append((info["iteration"], info["val_loss"]))


metrics = Metrics()

In [35]:
# create mlx-lm–suitable datasets by first defining the following configuration about our datasets

configs = {
    "mask_prompt": False,
    "prompt_feature": "prompt",
    "text_feature": "text",
    "completion_feature": "completion",
    "chat_feature": "messages",
}

In [36]:
train_set = datasets.create_dataset(
    dataset["train"],
    tokenizer,
    configs
)

In [37]:
val_set = datasets.create_dataset(
    dataset["test"],
    tokenizer,
    configs
)

In [None]:
# def train_imp(model, train_set, val_set, optimizer, loss, tokenizer, args):
#     # Create value and grad function for loss
#     loss_value_and_grad = nn.value_and_grad(model, loss)

#     losses = []
#     n_tokens = 0

#     # Main training loop
#     start = time.perf_counter()
#     for it, batch in zip(
#         range(args.iters),
#         iterate_batches(train_set, tokenizer, args.batch_size, train=True),
#     ):
#         # Forward and backward pass
#         (lvalue, toks), grad = loss_value_and_grad(model, *batch)

#         # Model update
#         optimizer.update(model, grad)
#         mx.eval(model.parameters(), optimizer.state, lvalue)

#         # Record loss
#         losses.append(lvalue.item())
#         n_tokens += toks.item()

#         # Report training loss if needed
#         if (it + 1) % args.steps_per_report == 0:
#             train_loss = np.mean(losses)

#             stop = time.perf_counter()
#             print(
#                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
#                 f"It/sec {args.steps_per_report / (stop - start):.3f}, "
#                 f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
#             )
#             losses = []
#             n_tokens = 0
#             start = time.perf_counter()

#         # Report validation loss if needed
#         if it == 0 or (it + 1) % args.steps_per_eval == 0:
#             stop = time.perf_counter()
#             val_loss = evaluate(
#                 model, val_set, loss, tokenizer, args.batch_size, args.val_batches
#             )
#             print(
#                 f"Iter {it + 1}: "
#                 f"Val loss {val_loss:.3f}, "
#                 f"Val took {(time.perf_counter() - stop):.3f}s"
#             )

#             start = time.perf_counter()

#         # Save adapter weights if needed
#         if (it + 1) % args.save_every == 0:
#             mx.savez(
#                 args.adapter_file, **dict(tree_flatten(model.trainable_parameters()))
#             )
#             print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.")

In [51]:
# def iterate_batches(dset, tokenizer, batch_size, train=False):
#     # Shuffle indices
#     while True:
#         indices = np.arange(len(dset))
#         if train:
#             indices = np.random.permutation(indices)

#         # Collect batches from dataset
#         for i in range(0, len(indices) - batch_size + 1, batch_size):
#             # Encode batch
#             batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
#             lengths = [len(x) for x in batch]

#             # Check if any sequence is longer than 2048 tokens
#             if max(lengths) > 2048:
#                 print(
#                     "[WARNING] Some sequences are longer than 2048 tokens. "
#                     "Consider pre-splitting your data to save memory."
#                 )

#             # Pad to the max length
#             batch_arr = np.zeros((batch_size, max(lengths)), np.int32)

#             for j in range(batch_size):
#                 batch_arr[j, : lengths[j]] = batch[j]
#             batch = mx.array(batch_arr)
#             yield batch[:, :-1], batch[:, 1:], mx.array(lengths)

#         if not train:
#             break

class CacheDataset:
    def __init__(self, data: Any):
        self._data = data
        self._proc_data = [None] * len(data)

    def itemlen(self, idx: int):
        return len(self._data[idx])

    def __getitem__(self, idx: int):
        if self._proc_data[idx] is None:
            self._proc_data[idx] = self._data.process(self._data[idx])
        return self._proc_data[idx]

    def __len__(self):
        return len(self._data)

In [None]:
# Start the fine-tuning process by calling the train() function

train(
        model=model,
        args=training_args,
        optimizer=optim.Adam(learning_rate=1e-5),
        train_dataset=CacheDataset(train_set),
        val_dataset=CacheDataset(val_set),
        training_callback=metrics,
    )

Starting training..., iters: 1


: 

In [30]:
debug

> [0;32m/Users/margarito/opt/anaconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/tuner/trainer.py[0m(94)[0;36m<lambda>[0;34m()[0m
[0;32m     92 [0;31m        [0mlen_fn[0m [0;34m=[0m [0;32mlambda[0m [0midx[0m[0;34m:[0m [0mdataset[0m[0;34m.[0m[0mitemlen[0m[0;34m([0m[0midx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     93 [0;31m    [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 94 [0;31m        [0mlen_fn[0m [0;34m=[0m [0;32mlambda[0m [0midx[0m[0;34m:[0m [0mdataset[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     95 [0;31m    [0midx[0m [0;34m=[0m [0msorted[0m[0;34m([0m[0mrange[0m[0;34m([0m[0mlen[0m[0;34m([0m[0mdataset[0m[0;34m)[0m[0;34m)[0m[0;34m,[0m [0mkey[0m[0;34m=[0m[0mlen_fn[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     96 [0;31m    [0;32mif[0m [0mlen[0m[0;34m([0m[0mdataset[0m[0;34m)[0m [0;3

In [41]:
val_set[0].keys()

dict_keys(['text'])