In [118]:
import inspect
import os
import sys
from pathlib import Path

import pandas as pd

# In Jupyter, __file__ is not defined, so use the current working directory
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

import matcher

from frankenstein.tools import arithmetic, data_retrieval

# Names of all function members of each tool module, as reported by `inspect.getmembers`.
ARITHMETIC_TOOL_NAMES = [name for name, _ in inspect.getmembers(arithmetic, predicate=inspect.isfunction)]
DATA_TOOL_NAMES = [name for name, _ in inspect.getmembers(data_retrieval, predicate=inspect.isfunction)]

# Load every run file under `runs/` as a line-delimited JSON frame, keyed by the file stem.
# Only regular, non-hidden files are read (a sub-directory or an OS metadata file would make `read_json` fail),
# and they are sorted because `iterdir()` yields entries in arbitrary order.
run_dir = Path('runs')
run_files = sorted(f for f in run_dir.iterdir() if f.is_file() and not f.name.startswith('.'))
dfs = {f.stem: pd.read_json(f, orient='records', lines=True, precise_float=True) for f in run_files}
print(f'Found {len(dfs)} runs in {run_dir}:')
for name in dfs:
    print(f'  {name}')

m = matcher.Matcher()

Found 6 runs in runs:
  gpt-4o-mini_answerable-full
  Llama-3.2-3B-Instruct_answerable-full
  gpt-4.1-mini_answerable-full
  gpt-4o-mini_answerable-fulls
  Qwen3-4B_answerable-full
  Qwen3-14B_answerable-full


In [119]:
# Select one run to analyse; the key is the stem of one of the run files loaded into `dfs`.
# No copy is taken, so columns added to `df` below are added to the frame stored in `dfs` as well.
df = dfs['gpt-4o-mini_answerable-full']

In [120]:
# Display a single record to see which columns a run file provides.
df.head(1)

Unnamed: 0,id,question_template,question,actions,answer,slot_values,answerable,data_availability,answer_format,messages,tokens,pred,correct,error
0,796d6084-9389-4eee-ad9b-66a711a0b5e2,SubjectPropertyRank,"Among countries in Central America, what was t...","[{'name': 'get_country_code_from_name', 'argum...",7,"{'subject': 'HND', 'property': 'SH.IMM.MEAS', ...",True,full,int,"[{'role': 'system', 'content': 'You are a help...",2888,7,True,0


In [121]:
def get_gold_tool_calls(row: pd.Series) -> list[dict]:
    """Get the gold tool calls from a row (Series).

    The row's `actions` are left untouched: each call's arguments are copied before being normalised.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame containing the actions. Each action must provide a 'name' and an 'arguments'
        mapping.

    Returns
    -------
    list[dict]
        A list of dictionaries representing the tool calls, where each dictionary contains the 'name' and
        'arguments' of the tool call. A 'values' argument, when present, is returned sorted.

    """
    tool_calls = []
    for action in row['actions']:
        # Shallow copy: replacing 'values' below must not write back into the data stored in the DataFrame.
        arguments = dict(action['arguments'])

        # For functions with a 'values' argument (which takes a list of values), sort the values so that the
        # comparison with the predicted calls does not depend on argument order.
        if 'values' in arguments:
            arguments['values'] = sorted(arguments['values'])

        tool_calls.append({'name': action['name'], 'arguments': arguments})

    return tool_calls

In [122]:
def get_pred_tool_calls(row: pd.Series) -> list[dict]:
    """Extract and post-process the predicted tool calls from a row's messages.

    Tool calls are read from every message that carries a `tool_calls` entry and paired with the content of the
    `tool` message that answers them (matched on the call id). They are then transformed as follows:

    1. `less_than` calls are rewritten as `greater_than` calls with `value_a` and `value_b` swapped.
    2. `final_answer` calls are dropped.
    3. A `search_for_indicator_names` call whose result contains the indicator named in
       `row['slot_values']['property_original']` counts as a successful search, and its arguments are replaced by
       `{'keywords': <that indicator name>}` so that it matches the gold call.
    4. A `values` argument (a list of values) is sorted to make the comparison order-independent.

    The messages themselves are never modified: arguments are copied before being normalised.

    Parameters
    ----------
    row : pd.Series
        A single row from the DataFrame containing the messages and tool calls.

    Returns
    -------
    list[dict]
        A list of dictionaries representing the post-processed tool calls, each with a 'name' and 'arguments'.

    """
    messages = row['messages']
    tool_calls = []

    # Extract tool calls from messages, together with the result each one produced.
    for msg in messages:
        for call in msg.get('tool_calls') or []:
            function = call['function']
            call_id = call.get('id')
            # A call may have no matching tool message (e.g. the last call of a run); its result is then None.
            result = next(
                (
                    reply['content']
                    for reply in messages
                    if reply.get('tool_call_id') == call_id and reply['role'] == 'tool'
                ),
                None,
            )
            arguments = function['arguments']
            if isinstance(arguments, dict):
                # Shallow copy so the normalisation below never writes back into the stored messages.
                arguments = dict(arguments)
            tool_calls.append({'name': function['name'], 'arguments': arguments, 'result': result})

    # Normalise the arguments of the `less_than` tool call to `greater_than`
    for call in tool_calls:
        if call['name'] == 'less_than':
            call['name'] = 'greater_than'
            value_a = call['arguments']['value_a']
            value_b = call['arguments']['value_b']
            call['arguments'] = {'value_a': value_b, 'value_b': value_a}

    # Remove any final_answer tool calls
    tool_calls = [call for call in tool_calls if call['name'] != 'final_answer']

    # search_for_indicator_names: check if any of these calls resulted in the correct indicator names.
    for call in tool_calls:
        if call['name'] != 'search_for_indicator_names':
            continue
        result = call['result']
        # Only a list of hits can contain a match; an error string or a missing result is an unsuccessful search.
        if not isinstance(result, (list, tuple)):
            continue
        target = row['slot_values']['property_original']
        if any(isinstance(hit, dict) and hit.get('indicator_name') == target for hit in result):
            # This counts as a successful search.
            # Because it's successful, rewrite it to match the gold call (full indicator name) to aid analysis.
            call['arguments'] = {'keywords': target}

    # For functions with a 'values' argument (which takes a list of values), we should sort the values to perform
    # a fair comparison.
    for call in tool_calls:
        arguments = call['arguments']
        if isinstance(arguments, dict) and 'values' in arguments:
            arguments['values'] = sorted(arguments['values'])

    # Finally, return only the fields that take part in the comparison (no ids, no results).
    return [{'name': call['name'], 'arguments': call['arguments']} for call in tool_calls]

In [123]:
# Add the gold and the post-processed predicted tool calls as two new columns, computed row by row.
df['gold_tool_calls'] = df.apply(get_gold_tool_calls, axis=1)
df['pred_tool_calls'] = df.apply(get_pred_tool_calls, axis=1)

In [124]:
def get_true_false_positives(row: pd.Series) -> tuple[list[dict], list[dict]]:
    """Extract the true and false positives from a row.

    Not quite as simple as standard true/false positives because we do not include repeated calls as true positives:
    each gold call can be matched by at most one predicted call, and any further identical prediction counts as a
    false positive.

    The row is not modified, so the function can be applied to the same data repeatedly with identical results.

    Parameters
    ----------
    row : pd.Series
        A single row from the DataFrame containing the predicted and gold tool calls.

    Returns
    -------
    tuple[list[dict], list[dict]]
        A tuple containing two lists: the first is the list of true positives, and the second is the list of false positives.

    """
    pred = row['pred_tool_calls']
    # Match against a copy: removing matched calls from the row's own list would empty the stored gold calls and
    # change the outcome of every later pass over the same row.
    unmatched_gold = list(row['gold_tool_calls'])
    tp = []
    fp = []
    for p in pred:
        if p in unmatched_gold:
            tp.append(p)
            # Consume the gold call so a repeated prediction cannot match it again.
            unmatched_gold.remove(p)
        else:
            fp.append(p)
    return tp, fp


def get_precision(row: pd.Series) -> float:
    """Compute the precision of a row's predicted tool calls.

    Precision is the number of true positives divided by the total number of predicted tool calls (true positives
    plus false positives). A row without any predicted tool call gets a precision of 0.0.

    If the row does not already hold both 'true_positives' and 'false_positives', they are derived with
    `get_true_false_positives` and stored on the row before being used.

    Parameters
    ----------
    row : pd.Series
        A single row from the DataFrame containing the predicted and gold tool calls.

    Returns
    -------
    float
        The precision of the predicted tool calls.

    """
    has_both_lists = 'true_positives' in row and 'false_positives' in row
    if not has_both_lists:
        # Derive the two lists and keep them on the row.
        row['true_positives'], row['false_positives'] = get_true_false_positives(row)

    n_true = len(row['true_positives'])
    n_predicted = n_true + len(row['false_positives'])

    if n_predicted > 0:
        return n_true / n_predicted
    return 0.0

In [125]:
# Per-row precision of the predicted tool calls, measured against the gold tool calls.
df['precision'] = df.apply(get_precision, axis=1)

In [126]:
# Distribution of tool-call precision over the rows whose answer was correct.
df.loc[df['correct'], 'precision'].describe()

count    16.000000
mean      0.783679
std       0.343360
min       0.052632
25%       0.774457
50%       1.000000
75%       1.000000
max       1.000000
Name: precision, dtype: float64

In [127]:
# Distribution of tool-call precision over the rows whose answer was incorrect.
df.loc[~df['correct'], 'precision'].describe()

count    4.000000
mean     0.481677
std      0.302030
min      0.051724
25%      0.416777
50%      0.558704
75%      0.623604
max      0.757576
Name: precision, dtype: float64

In [128]:
def get_error_made(row: pd.Series) -> bool:
    """Return True if the model made a tool call which returned an error.

    A tool call counts as an error when the content of the `tool` message answering it (matched on the call id) is
    a string starting with 'Error:'. Calls without a matching tool message are not counted as errors.

    Parameters
    ----------
    row : pd.Series
        A single row from the DataFrame containing the messages.

    Returns
    -------
    bool
        True if the model made a tool call which returned an error, False otherwise.

    """
    messages = row['messages']

    for msg in messages:
        for call in msg.get('tool_calls') or []:
            call_id = call.get('id')

            # Resolve the result of the tool call from the messages (None when the call was never answered).
            result = next(
                (
                    reply['content']
                    for reply in messages
                    if reply.get('tool_call_id') == call_id and reply['role'] == 'tool'
                ),
                None,
            )

            # One failing call is enough, so stop scanning straight away.
            if isinstance(result, str) and result.startswith('Error:'):
                return True

    return False


def get_no_search_for_indicator_names(row: pd.Series) -> bool:
    """Tell whether a row's predicted tool calls contain no `search_for_indicator_names` call.

    Parameters
    ----------
    row : pd.Series
        A single row from the DataFrame containing the predicted tool calls.

    Returns
    -------
    bool
        True if none of the predicted tool calls is named search_for_indicator_names (including when there are no
        predicted tool calls at all), False otherwise.

    """
    searched = any(call['name'] == 'search_for_indicator_names' for call in row['pred_tool_calls'])
    return not searched

In [129]:
# `precision` was already computed in an earlier cell, so it is not recomputed here: a second pass would be
# redundant at best, and would give different values whenever the matching step alters `gold_tool_calls` in place.
df['error_made'] = df.apply(get_error_made, axis=1)
df['no_search_for_indicator_names'] = df.apply(get_no_search_for_indicator_names, axis=1)

In [130]:
# Contingency table: answer correctness (rows) against whether any tool call returned an error (columns).
pd.crosstab(
    index=df['correct'],
    columns=df['error_made'],
    rownames=['Correct Answer Given'],
    colnames=['Error Made'],
)

Error Made,False,True
Correct Answer Given,Unnamed: 1_level_1,Unnamed: 2_level_1
False,2,2
True,12,4


In [131]:
# Contingency table: answer correctness (rows) against whether the model skipped the indicator-name search (columns).
pd.crosstab(
    index=df['correct'],
    columns=df['no_search_for_indicator_names'],
    rownames=['Correct Answer Given'],
    colnames=['No Search for Indicator Names'],
)

No Search for Indicator Names,False,True
Correct Answer Given,Unnamed: 1_level_1,Unnamed: 2_level_1
False,2,2
True,4,12
