In [1]:
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
import pickle, json, re
from datasets import load_dataset, Dataset

In [2]:
# Pick the device first: torch.cuda.get_device_capability() raises when no CUDA
# device is available, so the dtype choice must be guarded by availability.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # bfloat16 on compute capability >= 8 (Ampere or newer), float16 on older GPUs.
    torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    # CPU fallback: full precision (half-precision kernels are limited/slow on CPU).
    torch_dtype = torch.float32

torch_dtype, device

In [3]:
model_path = "../finetuned_models/gemma-3-4b-it-function-calling-V1-merged/"

# Fine-tuned checkpoint (per the path, a merged model) and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch_dtype,  # or "auto" to use the dtype stored in the checkpoint config
)
# Switch to inference behaviour (e.g. disables dropout).
model.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Gemma3ForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(4096, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
            

In [4]:
def batch_inference_preprocessing(batch, chat_tokenizer=None):
    """Build chat-formatted generation prompts for a batch of function-calling samples.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        batch: Column-oriented batch with ``"tools"``, ``"query"`` and ``"id"`` lists.
            Each ``tools`` entry is either a JSON string or an already-parsed list of
            dicts with ``name``/``description`` (``parameters`` optional).
        chat_tokenizer: Object exposing ``apply_chat_template``. Defaults to the
            notebook-global ``tokenizer``.

    Returns:
        Dict with ``"id"`` and ``"formatted_prompt"`` lists, one entry per input row.
        Rows whose tools cannot be parsed or formatted get ``formatted_prompt=None``
        instead of being dropped. This keeps the output row-aligned with the batch,
        since dropping rows while the input columns are kept makes the batched
        map's column lengths disagree.
    """
    chat_tok = chat_tokenizer if chat_tokenizer is not None else tokenizer

    # Loop-invariant; must match the system prompt used during training.
    system_prompt = "You are a helpful assistant that can call functions to help answer user queries. When you need to use a tool, format your response with <function_call> tags containing valid JSON. Always provide the function call in the exact format requested."

    results = {"id": [], "formatted_prompt": []}

    for tools_json, query, sample_id in zip(batch["tools"], batch["query"], batch["id"]):
        formatted_prompt = None
        try:
            # Tools may arrive as a JSON string or already decoded.
            tools = json.loads(tools_json) if isinstance(tools_json, str) else tools_json

            # Same per-tool JSON rendering as in training.
            tools_formatted = [
                json.dumps(
                    {
                        "name": tool["name"],
                        "description": tool["description"],
                        "parameters": tool.get("parameters", {}),
                    },
                    indent=2,
                )
                for tool in tools
            ]
            tools_text = "Available tools:\n" + "\n\n".join(tools_formatted)

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"{tools_text}\n\nUser query: {query}"},
            ]

            # Render to text only; tokenization happens at generation time.
            formatted_prompt = chat_tok.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            print(f"Error processing sample {sample_id}: {e}")
            print(f"Tools: {tools_json}")
            print(f"Query: {query}")

        # Always emit a row (None on failure) to stay aligned with the input batch.
        results["id"].append(sample_id)
        results["formatted_prompt"].append(formatted_prompt)

    return results

In [5]:
# NOTE(review): pickle.load can execute arbitrary code; only load this file if it
# comes from a trusted source (JSON/parquet would be safer for shared data).
with open("../data/xlam-function-calling-60k-updated-test_data.pkl", "rb") as f:
    test_records = pickle.load(f)  # distinct name: `test_data` is later a single row

test_dataset = Dataset.from_list(test_records)

# Adds a `formatted_prompt` column built with the fine-tuned tokenizer's chat template.
dataset_test = test_dataset.map(batch_inference_preprocessing, batched=True)

dataset_test

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'query', 'answers', 'tools', 'formatted_prompt'],
    num_rows: 600
})

In [6]:
SAMPLE_INDEX = 370  # a two-call example (fibonacci + generate_password)

# Show the fully templated prompt the model will receive.
test_data = dataset_test[SAMPLE_INDEX]
print(test_data["formatted_prompt"])

<bos><start_of_turn>user
You are a helpful assistant that can call functions to help answer user queries. When you need to use a tool, format your response with <function_call> tags containing valid JSON. Always provide the function call in the exact format requested.

Available tools:
{
  "name": "fibonacci",
  "description": "Calculates the nth Fibonacci number.",
  "parameters": {
    "n": {
      "description": "The position of the Fibonacci number.",
      "type": "int"
    }
  }
}

{
  "name": "generate_password",
  "description": "Generates a random password of specified length and character types.",
  "parameters": {
    "length": {
      "description": "The length of the password. Defaults to 12.",
      "type": "int, optional",
      "default": 12
    },
    "include_special": {
      "description": "Whether to include special characters in the password. Defaults to True.",
      "type": "bool, optional",
      "default": true
    }
  }
}

{
  "name": "is_subset",
  "descript

In [7]:
# Tokenize the already-templated prompt. The chat template already emitted <bos>
# (visible in the printed prompt), so the tokenizer must not add it again.
inputs = tokenizer(
    test_data["formatted_prompt"],
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)
input_len = inputs["input_ids"].shape[-1]  # prompt length, used to strip the echo below

# Greedy decoding (do_sample=False); keep only the tokens generated after the prompt.
with torch.inference_mode():
    generation = model.generate(**inputs, do_sample=False, max_new_tokens=100)[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)



<function_call>
{"name":"fibonacci","arguments":{"n":12}}
</function_call>
<function_call>
{"name":"generate_password","arguments":{"length":10,"include_special":false}}
</function_call>


In [8]:
def extract_tool_calls(text):
    """Parse every ``<function_call>{...}</function_call>`` block in model output.

    Args:
        text: Decoded model output. Non-string or empty input yields ``[]``.

    Returns:
        List of decoded JSON objects in order of appearance. A block whose payload
        is not valid JSON is reported and skipped, so one malformed call does not
        discard the valid ones.
    """
    if not isinstance(text, str) or not text:
        return []

    tool_calls = []
    # Non-greedy body anchored on the closing tag, so nested braces (e.g. an
    # "arguments" object) still extend to the last "}" before </function_call>.
    for match in re.findall(r"<function_call>\s*(\{.*?\})\s*</function_call>", text, re.DOTALL):
        try:
            tool_calls.append(json.loads(match.strip()))
        except json.JSONDecodeError as e:
            print(f"Error: {e}")
    return tool_calls

In [9]:
# Put the user query next to the tool calls parsed from the fine-tuned model's reply.
query_text = test_data["query"]
print("QUERY is :", query_text)
print("\n After Fine-tuning Output:", extract_tool_calls(decoded))

QUERY is : Calculate the 12th Fibonacci number and generate a random password of length 10 without special characters

 After Fine-tuning Output: [{'name': 'fibonacci', 'arguments': {'n': 12}}, {'name': 'generate_password', 'arguments': {'length': 10, 'include_special': False}}]


In [10]:
# Ground-truth tool calls from the dataset, for comparison with the parsed output.
expected_calls = test_data["answers"]
print("Original Output :", expected_calls)

Original Output : [{"name": "fibonacci", "arguments": {"n": 12}}, {"name": "generate_password", "arguments": {"length": 10, "include_special": false}}]


### Without fine-tuning (base `gemma-3-4b-it`)

In [3]:
model_path = "../../models/gemma-3-4b-it/"

# Base checkpoint, loaded with the same arguments as the fine-tuned model.
# (`wf_` prefix: presumably "without fine-tuning", per the section header.)
wf_tokenizer = AutoTokenizer.from_pretrained(model_path)
wf_model = Gemma3ForConditionalGeneration.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch_dtype,  # or "auto" to use the dtype stored in the checkpoint config
)
# Switch to inference behaviour (e.g. disables dropout).
wf_model.eval()

In [4]:
def batch_inference_preprocessing(batch, chat_tokenizer=None):
    """Build chat-formatted generation prompts for a batch of function-calling samples.

    Base-model variant: identical prompt construction, but the default tokenizer
    is the notebook-global ``wf_tokenizer``. Intended for
    ``Dataset.map(..., batched=True)``.

    Args:
        batch: Column-oriented batch with ``"tools"``, ``"query"`` and ``"id"`` lists.
            Each ``tools`` entry is either a JSON string or an already-parsed list of
            dicts with ``name``/``description`` (``parameters`` optional).
        chat_tokenizer: Object exposing ``apply_chat_template``. Defaults to the
            notebook-global ``wf_tokenizer``.

    Returns:
        Dict with ``"id"`` and ``"formatted_prompt"`` lists, one entry per input row.
        Rows whose tools cannot be parsed or formatted get ``formatted_prompt=None``
        instead of being dropped. This keeps the output row-aligned with the batch,
        since dropping rows while the input columns are kept makes the batched
        map's column lengths disagree.
    """
    chat_tok = chat_tokenizer if chat_tokenizer is not None else wf_tokenizer

    # Loop-invariant; must match the system prompt used during training.
    system_prompt = "You are a helpful assistant that can call functions to help answer user queries. When you need to use a tool, format your response with <function_call> tags containing valid JSON. Always provide the function call in the exact format requested."

    results = {"id": [], "formatted_prompt": []}

    for tools_json, query, sample_id in zip(batch["tools"], batch["query"], batch["id"]):
        formatted_prompt = None
        try:
            # Tools may arrive as a JSON string or already decoded.
            tools = json.loads(tools_json) if isinstance(tools_json, str) else tools_json

            # Same per-tool JSON rendering as in training.
            tools_formatted = [
                json.dumps(
                    {
                        "name": tool["name"],
                        "description": tool["description"],
                        "parameters": tool.get("parameters", {}),
                    },
                    indent=2,
                )
                for tool in tools
            ]
            tools_text = "Available tools:\n" + "\n\n".join(tools_formatted)

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"{tools_text}\n\nUser query: {query}"},
            ]

            # Render to text only; tokenization happens at generation time.
            formatted_prompt = chat_tok.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            print(f"Error processing sample {sample_id}: {e}")
            print(f"Tools: {tools_json}")
            print(f"Query: {query}")

        # Always emit a row (None on failure) to stay aligned with the input batch.
        results["id"].append(sample_id)
        results["formatted_prompt"].append(formatted_prompt)

    return results

In [5]:
# NOTE(review): pickle.load can execute arbitrary code; only load this file if it
# comes from a trusted source (JSON/parquet would be safer for shared data).
with open("../data/xlam-function-calling-60k-updated-test_data.pkl", "rb") as f:
    test_records = pickle.load(f)  # distinct name: `test_data` is later a single row

test_dataset = Dataset.from_list(test_records)

# Uses the most recent definition of batch_inference_preprocessing (the wf_tokenizer
# variant), so prompts follow the base model's chat template.
dataset_test = test_dataset.map(batch_inference_preprocessing, batched=True)

dataset_test

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Dataset({
    features: ['id', 'query', 'answers', 'tools', 'formatted_prompt'],
    num_rows: 600
})

In [5]:
SAMPLE_INDEX = 370  # same example as in the fine-tuned run
test_data = dataset_test[SAMPLE_INDEX]

In [13]:
# Same greedy generation as for the fine-tuned model, but with the base model and
# its tokenizer. Inputs must go to the base model's device, not `model.device`
# (the fine-tuned model, which may not even be loaded in this kernel).
inputs = wf_tokenizer(
    test_data["formatted_prompt"], return_tensors="pt", add_special_tokens=False
).to(wf_model.device)
input_len = inputs["input_ids"].shape[-1]  # prompt length, used to strip the echo below

# NOTE(review): 100 new tokens appears to have truncated the base model's reply
# (it stops mid-sentence before the second call); consider a larger budget for a
# fairer comparison.
with torch.inference_mode():
    generation = wf_model.generate(**inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]

decoded = wf_tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)

Here’s the breakdown of your request and the corresponding function calls:

First, I need to calculate the 12th Fibonacci number. I'll use the `fibonacci` tool for this.

<function_call>
{
  "tool": "fibonacci",
  "parameters": {
    "n": 12
  }
}
</function_call>

Second, I need to generate a random password of length 10 without


In [14]:
# Ground-truth tool calls for the comparison sample.
reference_answers = test_data["answers"]
print(reference_answers)

[{"name": "fibonacci", "arguments": {"n": 12}}, {"name": "generate_password", "arguments": {"length": 10, "include_special": false}}]


In [15]:
# The natural-language query both models were given.
sample_query = test_data["query"]
print(sample_query)

Calculate the 12th Fibonacci number and generate a random password of length 10 without special characters
