# Evaluate Fine-tuned Model

Code authored by: Shaw Talebi

### imports

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel
from utils.tool_calling import parse_tool_call, call_tool
import pandas as pd

### load data

In [2]:
# Fetch the tool-use fine-tuning dataset from the Hugging Face Hub;
# only its held-out "test" split is evaluated in this notebook.
ds = load_dataset("shawhin/tool-use-finetuning")
ds_test = ds['test']

### load models

In [3]:
# Base (instruction-tuned) checkpoint, placed on the "mps" device.
model_name = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="mps")

# Tokenizer of the base checkpoint; the pipelines reuse it for both models.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Fine-tuned checkpoint from the Hub, loaded onto the same "mps" device as the base model.
finetuned_model_name = "shawhin/gemma-3-1b-tool-use"
finetuned_model = AutoModelForCausalLM.from_pretrained(finetuned_model_name, device_map="mps")

### generate responses using models

In [5]:
def evaluate_model_tool_calling(generator, row):
    """
    Evaluate one test example: did the model decide correctly whether to call
    a tool, did it pick the expected tool, and did that tool call execute?

    Parameters:
    generator: a text-generation pipeline (or compatible callable) that accepts
        a list of chat messages plus ``return_full_text=True`` and returns a
        list whose first element has a ``'generated_text'`` list of messages,
        the last of which is the model's reply.
    row (dict): a dataset example with a ``'trace'`` list of chat messages and,
        optionally, ``'tool_name'`` holding the expected tool (None/NaN when
        no tool should be called). ``row`` is not modified.

    Returns:
    dict with keys:
        response: the model's reply message.
        model_called_tool: True if the reply contains a ``<tool_call>`` marker.
        model_tool_name: parsed tool name, or None if no well-formed call.
        model_called_tool_when_needed: True if a tool was called exactly when
            one was expected.
        model_called_correct_tool: True if the parsed tool name equals the
            expected one.
        tool_call_success: True if executing the parsed tool raised nothing.
    """
    expected_tool_name = row.get('tool_name', None)

    # A missing tool name (None, or NaN / the string 'nan' after a pandas
    # round-trip) means the correct behavior is to answer without a tool.
    tool_needed = expected_tool_name is not None and str(expected_tool_name).lower() != 'nan'

    # Copy every message dict (not just the list) so relabelling the first
    # message as the system prompt does not mutate the caller's row.
    messages = [dict(message) for message in row['trace']]
    messages[0]['role'] = 'system'

    # Prompt with the first two messages only; the last message of the
    # returned conversation is the model's reply.
    output = generator(messages[:2], return_full_text=True)[0]
    response = output['generated_text'][-1]

    model_called_tool = False
    model_tool_name = None
    model_called_correct_tool = False
    tool_call_success = False

    if "<tool_call>" in response['content']:
        model_called_tool = True
        parsed_result = parse_tool_call(response['content'])

        # None means the marker was present but the call was malformed; the
        # remaining flags then keep their False / None defaults.
        if parsed_result is not None:
            model_tool_name, tool_args = parsed_result
            model_called_correct_tool = (model_tool_name == expected_tool_name)

            try:
                call_tool(model_tool_name, tool_args)
            except Exception:
                # Any error raised while executing the tool counts as a failed
                # call; KeyboardInterrupt/SystemExit still propagate.
                tool_call_success = False
            else:
                tool_call_success = True

    # Correct decision: a tool was called if and only if one was expected.
    model_called_tool_when_needed = (model_called_tool == tool_needed)

    return {
        'response': response,
        'model_called_tool': model_called_tool,
        'model_tool_name': model_tool_name,
        'model_called_tool_when_needed': model_called_tool_when_needed,
        'model_called_correct_tool': model_called_correct_tool,
        'tool_call_success': tool_call_success,
    }


In [6]:
# Build one text-generation pipeline per model (base first, then fine-tuned),
# both sharing the base tokenizer and configured with temperature=0.1.
# NOTE(review): temperature only has an effect when sampling (do_sample) is
# enabled in the generation config — confirm for both checkpoints.
base_generator, finetuned_generator = (
    pipeline("text-generation", model=llm, tokenizer=tokenizer, temperature=0.1)
    for llm in (model, finetuned_model)
)

Device set to use mps
Device set to use mps


In [7]:
%%time
results_data = []

# Each test example yields one record per model: base model first, then the
# fine-tuned model. Both share the same row metadata and differ only in the
# generator used and the model name recorded.
models_under_test = (
    (model_name, base_generator),
    (finetuned_model_name, finetuned_generator),
)

for i, row in enumerate(ds_test):
    print("Evaluating row:", i)
    for evaluated_model_name, generator in models_under_test:
        results_data.append({
            'model_name': evaluated_model_name,
            'query': row['query'],
            'query_type': row['query_type'],
            'num_tools_available': row['num_tools_available'],
            'expected_tool_name': row['tool_name'],
            **evaluate_model_tool_calling(generator, row),
        })

Evaluating row: 0
Evaluating row: 1
Evaluating row: 2
Evaluating row: 3
Evaluating row: 4
Evaluating row: 5
Evaluating row: 6
Evaluating row: 7
Evaluating row: 8
Evaluating row: 9
Evaluating row: 10
Evaluating row: 11
Evaluating row: 12
Evaluating row: 13
Evaluating row: 14
Evaluating row: 15
Evaluating row: 16
Evaluating row: 17
Evaluating row: 18
Evaluating row: 19
Evaluating row: 20
Evaluating row: 21
Evaluating row: 22
Evaluating row: 23
Evaluating row: 24
Evaluating row: 25
Evaluating row: 26
Evaluating row: 27
Evaluating row: 28
Evaluating row: 29
Evaluating row: 30
Evaluating row: 31
Evaluating row: 32
Evaluating row: 33
Evaluating row: 34
Evaluating row: 35
Evaluating row: 36
Evaluating row: 37
Evaluating row: 38
Evaluating row: 39
Evaluating row: 40
Evaluating row: 41
Evaluating row: 42
Evaluating row: 43
Evaluating row: 44
Evaluating row: 45
Evaluating row: 46
Evaluating row: 47
Evaluating row: 48
Evaluating row: 49
Evaluating row: 50
Evaluating row: 51
Evaluating row: 52
Eva

In [8]:
# One row per (example, model) evaluation record.
results_df = pd.DataFrame(data=results_data)

In [9]:
# Persist the raw per-example evaluation records (the index is a plain row number, so omit it).
results_df.to_csv(path_or_buf='data/eval_results.csv', index=False)

### evaluate models

In [10]:
def compare_model_performance(results_df):
    """
    Compare model performance based on tool calling metrics.

    "Tool Called When Needed" is averaged over every row; for rows where no
    tool was expected, the flag rewards *not* calling a tool. The other two
    metrics are averaged only over rows that expect a tool call.

    Parameters:
    results_df (pd.DataFrame): DataFrame containing evaluation results with columns:
        - model_name
        - model_called_tool_when_needed
        - model_called_correct_tool
        - tool_call_success
        - expected_tool_name (None/NaN or the string 'nan' when no tool is expected)

    Returns:
    pd.DataFrame: one row per model_name (the index) with the three metrics
        as percentages rounded to two decimals. Tool-only metrics are NaN for
        a model that has no tool-needed rows.
    """
    # Rows whose expected tool is missing (NaN/None, or the literal string
    # 'nan' left over from a CSV round-trip) are "no tool needed" examples.
    expected = results_df['expected_tool_name']
    tool_needed_rows = results_df[expected.notna() & (expected != 'nan')]

    # Decision accuracy is measured over all rows.
    performance_metrics = results_df.groupby('model_name').agg({
        'model_called_tool_when_needed': 'mean',
    })

    if len(tool_needed_rows) > 0:
        # Tool-choice and execution metrics only make sense where a tool was needed.
        tool_metrics = tool_needed_rows.groupby('model_name').agg({
            'model_called_correct_tool': 'mean',
            'tool_call_success': 'mean',
        })
        performance_metrics = pd.concat([performance_metrics, tool_metrics], axis=1)
    else:
        # NaN rather than None keeps these columns numeric (float64) instead of object.
        performance_metrics['model_called_correct_tool'] = float('nan')
        performance_metrics['tool_call_success'] = float('nan')

    performance_metrics.columns = [
        'Tool Called When Needed (%)',
        'Correct Tool Called (%)',
        'Tool Call Success (%)',
    ]

    # Scale to percentages first, then round, so values such as 93.33 are not
    # left with floating-point noise from multiplying an already-rounded fraction.
    return (performance_metrics * 100).round(2)

In [11]:
def compare_model_performance_by_column(results_df, split_column):
    """
    Compare model performance based on tool calling metrics, split by specified column.

    "Tool Called When Needed" is averaged over every row in a group. The other
    two metrics are averaged only over rows that expect a tool call, so they
    are NaN for groups (e.g. a 'no_tool' query type) with no such rows.

    Parameters:
    results_df (pd.DataFrame): DataFrame containing evaluation results with columns:
        - model_name
        - model_called_tool_when_needed
        - model_called_correct_tool
        - tool_call_success
        - expected_tool_name (None/NaN or the string 'nan' when no tool is expected)
        - split_column (the column to split analysis by)
    split_column (str): Column name to split the analysis by (e.g., 'query_type', 'num_tools_available')

    Returns:
    pd.DataFrame: metrics indexed by (model_name, split_column), expressed as
        percentages rounded to two decimals.
    """
    group_keys = ['model_name', split_column]

    # Rows whose expected tool is missing (NaN/None, or the literal string
    # 'nan' left over from a CSV round-trip) are "no tool needed" examples.
    expected = results_df['expected_tool_name']
    tool_needed_rows = results_df[expected.notna() & (expected != 'nan')]

    # Decision accuracy is measured over all rows of each group.
    performance_metrics = results_df.groupby(group_keys).agg({
        'model_called_tool_when_needed': 'mean',
    })

    if len(tool_needed_rows) > 0:
        # Tool-choice and execution metrics only make sense where a tool was
        # needed; concat aligns on the index, leaving NaN for groups without such rows.
        tool_metrics = tool_needed_rows.groupby(group_keys).agg({
            'model_called_correct_tool': 'mean',
            'tool_call_success': 'mean',
        })
        performance_metrics = pd.concat([performance_metrics, tool_metrics], axis=1)
    else:
        # NaN rather than None keeps these columns numeric (float64) instead of object.
        performance_metrics['model_called_correct_tool'] = float('nan')
        performance_metrics['tool_call_success'] = float('nan')

    performance_metrics.columns = [
        'Tool Called When Needed (%)',
        'Correct Tool Called (%)',
        'Tool Call Success (%)',
    ]

    # Scale to percentages first, then round, to avoid floating-point noise.
    return (performance_metrics * 100).round(2)

#### global results

In [12]:
# Overall per-model metrics across the whole test split.
eval_summary = compare_model_performance(results_df=results_df)

In [13]:
# Show the overall per-model summary as this cell's output.
eval_summary

Unnamed: 0_level_0,Tool Called When Needed (%),Correct Tool Called (%),Tool Call Success (%)
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
google/gemma-3-1b-it,90.0,35.0,50.0
shawhin/gemma-3-1b-tool-use,93.33,52.5,67.5


In [14]:
# Write the summary to file. Keep the index: it holds model_name, and dropping
# it (index=False) would leave metric rows that can't be traced to a model.
eval_summary.to_csv('data/eval_summary.csv', index=True)

#### results by query type

In [15]:
# Per-model metrics broken down by query difficulty / type.
eval_summary_by_type = compare_model_performance_by_column(results_df=results_df, split_column='query_type')

In [16]:
# Show the per-model, per-query-type breakdown as this cell's output.
eval_summary_by_type

Unnamed: 0_level_0,Unnamed: 1_level_0,Tool Called When Needed (%),Correct Tool Called (%),Tool Call Success (%)
model_name,query_type,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
google/gemma-3-1b-it,easy,100.0,50.0,65.0
google/gemma-3-1b-it,hard,100.0,20.0,35.0
google/gemma-3-1b-it,no_tool,70.0,,
shawhin/gemma-3-1b-tool-use,easy,95.0,60.0,80.0
shawhin/gemma-3-1b-tool-use,hard,85.0,45.0,55.0
shawhin/gemma-3-1b-tool-use,no_tool,100.0,,
