In [218]:
import ast
import json
from pathlib import Path

import pandas as pd

In [219]:
# Location of the JSON Lines run log, one directory above the working directory.
LOG_PATH = Path("..") / "log.jsonl"

In [220]:
def get_all_files_matching(**kwargs):
    """
    Load the run log and keep only the rows that match every keyword filter.

    The log at ``LOG_PATH`` is read as JSON Lines on each call. Each keyword
    is treated as ``column=value``: a row survives only if that column equals
    that value. With no keywords the whole log is returned.

    Returns:
        A DataFrame of the matching rows. Rows keep the index labels they
        had in the full log.

    Raises:
        KeyError: if a keyword does not name a column of the log.
    """
    matches = pd.read_json(LOG_PATH, lines=True)
    for column, wanted in kwargs.items():
        # Narrow the frame one filter at a time.
        matches = matches.loc[matches[column] == wanted]
    return matches

In [221]:
# Select the log rows recorded for the Meta-Llama-3.1-8B-Instruct model on the
# "answerable" split.
files = get_all_files_matching(
    model_name="/public/hf/models/meta-llama/Meta-Llama-3.1-8B-Instruct/",
    split="answerable",
)

In [222]:
# File name of the first matching run. `.iloc[0]` is positional: the filtered
# frame keeps the index labels of the full log, so the label `0` only exists
# when the very first log entry happened to match.
Path(files["filename"].iloc[0]).name

'2025-05-08T19:25:53.jsonl'

In [223]:
# Load the first matching run's JSON Lines file from ../runs into a DataFrame.
# `.iloc[0]` picks the first matching row by position; a plain `[0]` is a label
# lookup and fails whenever log row 0 is not among the matches.
p = Path("..", "runs", Path(files["filename"].iloc[0]).name)
# Read explicitly as UTF-8 instead of relying on the platform's locale default.
with p.open("r", encoding="utf-8") as f:
    df = pd.read_json(f, lines=True)

In [224]:
def get_model_final_answer(row):
    """
    Return the model's final answer for a run row, as a string.

    Args:
        row: DataFrame whose "messages" column holds a list of message dicts;
            the entry for the first index label is used.

    Returns:
        ``str()`` of the "content" of the first message with role "tool" and
        name "final_answer", or None when there is no such message.
    """
    by_label = row["messages"].to_dict()
    conversation = by_label[row.index[0]]

    def _is_final_answer(message):
        # Check the role first so "name" is only read on tool messages.
        if message["role"] != "tool":
            return False
        return message["name"] == "final_answer"

    hits = list(filter(_is_final_answer, conversation))
    if not hits:
        return None
    return str(hits[0]["content"])


def get_gold_final_answer(row):
    """
    Return the gold final answer recorded in a run row's action sequence.

    Args:
        row: DataFrame whose "actions" column holds a list of action dicts
            (each with a "name"); the first row is used.

    Returns:
        The "result" value of the first action named "final_answer", or None
        when the sequence contains no such action.

    Raises:
        KeyError: if the "final_answer" action has no "result" key.
    """
    actions = row["actions"].iloc[0]
    # Walk the actions directly: indexing the filtered *list* with "result"
    # (as the previous version did) always raised TypeError.
    for action in actions:
        if action["name"] == "final_answer":
            return action["result"]
    return None


def _parse_tool_arguments(arguments):
    """Decode a tool call's "arguments" payload into Python data without executing it."""
    if not isinstance(arguments, str):
        # Already decoded (e.g. a dict): nothing to parse.
        return arguments
    try:
        return json.loads(arguments)
    except json.JSONDecodeError:
        # Not valid JSON: accept a Python-literal repr such as "{'a': 1}".
        return ast.literal_eval(arguments)


def get_model_tool_calls(
    row,
    clean=False,
):
    """
    Collect the tool calls issued by the assistant in a run row.

    For every message with role "assistant" and a non-empty "tool_calls"
    list, the "function" entry of the first call is taken (any further calls
    in the same message are ignored). Its "arguments" string is decoded into
    Python data.

    The stored messages are left untouched: each returned call is a new dict,
    so the function can be called repeatedly on the same row.

    Args:
        row: DataFrame whose "messages" column holds a list of message dicts;
            the first row is used.
        clean: If True, drop calls whose name is "think".

    Returns:
        A list of dicts, one per tool call, with "arguments" decoded.

    Raises:
        ValueError or SyntaxError: if an "arguments" string is neither valid
            JSON nor a Python literal.
    """
    messages = row["messages"].iloc[0]

    tool_calls = []
    for msg in messages:
        # `.get` also skips assistant messages whose "tool_calls" is None or [].
        if msg["role"] != "assistant" or not msg.get("tool_calls"):
            continue
        function = msg["tool_calls"][0]["function"]
        # NOTE(security): the arguments are model-generated text, so they are
        # parsed as data instead of being passed to eval().
        parsed = _parse_tool_arguments(function["arguments"])
        # Copy rather than assign in place, so the DataFrame's dicts keep
        # their original string payload.
        tool_calls.append({**function, "arguments": parsed})

    # If clean, remove 'think' tool calls
    if clean:
        tool_calls = [call for call in tool_calls if call["name"] != "think"]

    return tool_calls


def get_gold_tool_calls(
    row,
    clean=False,
):
    """
    Return the gold tool calls of a run row, without their recorded results.

    Each action is copied with its "result" key (if any) left out. The stored
    actions are not modified, so the row can still be used afterwards to read
    the results, and the function can be called repeatedly.

    Args:
        row: DataFrame whose "actions" column holds a list of action dicts;
            the first row is used.
        clean: If True, drop calls whose name is "think", mirroring
            ``get_model_tool_calls``.

    Returns:
        A new list of dicts, one per action, in the original order.
    """
    actions = row["actions"].iloc[0]

    # Build stripped copies instead of deleting "result" from the DataFrame's
    # own dicts.
    tool_calls = [
        {key: value for key, value in action.items() if key != "result"}
        for action in actions
    ]

    if clean:
        tool_calls = [call for call in tool_calls if call.get("name") != "think"]

    return tool_calls

In [225]:
# Inspect one example. The seed is fixed so the same row is drawn on every
# run; change `random_state` to look at a different example.
row = df.sample(1, random_state=0)
gold = get_gold_tool_calls(row)
model = get_model_tool_calls(row)

In [226]:
# Display the gold tool calls of the sampled row.
gold

[{'name': 'get_country_codes_in_region',
  'arguments': {'region_name': 'Western Asia'}},
 {'name': 'get_indicator_code_from_name',
  'arguments': {'indicator_name': 'Immunization, DPT (% of children ages 12-23 months)'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'ARM',
   'indicator_code': 'SH.IMM.IDPT',
   'year': '2019'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'AZE',
   'indicator_code': 'SH.IMM.IDPT',
   'year': '2019'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'BHR',
   'indicator_code': 'SH.IMM.IDPT',
   'year': '2019'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'CYP',
   'indicator_code': 'SH.IMM.IDPT',
   'year': '2019'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'GEO',
   'indicator_code': 'SH.IMM.IDPT',
   'year': '2019'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'IRQ',
   'indicator_code': 'SH.IMM.IDPT',
   'year': '2019'}},
 {'name': 'retrieve_value',
  'ar

In [227]:
def make_hashable(value):
    """
    Recursively convert a value into a hashable type.

    - Dictionaries are converted to tuples of (key, value) pairs sorted by
      key, so two dicts with the same items give the same result whatever
      their insertion order.
    - Lists and tuples are converted to tuples, with every element converted
      recursively (so a tuple holding a list or dict becomes hashable too).
    - Sets are converted to frozensets.
    - Other types are returned as-is.

    Raises:
        TypeError: if a dictionary's keys cannot be ordered against each
            other (e.g. a mix of str and int keys).
    """
    if isinstance(value, dict):
        # Keys are unique, so sorting the pairs never has to compare values.
        return tuple(sorted((k, make_hashable(v)) for k, v in value.items()))
    if isinstance(value, (list, tuple)):
        return tuple(make_hashable(v) for v in value)
    if isinstance(value, (set, frozenset)):
        # Set members are hashable already; only the container needs freezing.
        return frozenset(value)
    return value

In [228]:
def calculate_metrics_pure_python(gold, model):
    """
    Score the model's tool calls against the gold tool calls.

    Both lists are reduced to sets of hashable values via ``make_hashable``,
    so repeated identical calls count once and call order is ignored.

    Returns:
        A dict with three entries:
        - "precision": matched calls / distinct model calls
        - "recall": matched calls / distinct gold calls
        - "accuracy": matched calls / distinct calls in either list
          (intersection over union)
        Any ratio whose denominator set is empty is reported as 0.
    """
    gold_calls = {make_hashable(call) for call in gold}
    model_calls = {make_hashable(call) for call in model}

    n_matched = len(gold_calls & model_calls)
    either = gold_calls | model_calls

    # Default every score to 0 and overwrite it only when its denominator
    # is non-empty.
    precision = recall = accuracy = 0
    if model_calls:
        # Matched plus spurious calls is exactly the distinct model calls.
        precision = n_matched / len(model_calls)
    if gold_calls:
        # Matched plus missed calls is exactly the distinct gold calls.
        recall = n_matched / len(gold_calls)
    if either:
        accuracy = n_matched / len(either)

    return {"precision": precision, "recall": recall, "accuracy": accuracy}

In [229]:
# Display the gold and model tool-call lists together for comparison.
gold, model

([{'name': 'get_country_codes_in_region',
   'arguments': {'region_name': 'Western Asia'}},
  {'name': 'get_indicator_code_from_name',
   'arguments': {'indicator_name': 'Immunization, DPT (% of children ages 12-23 months)'}},
  {'name': 'retrieve_value',
   'arguments': {'country_code': 'ARM',
    'indicator_code': 'SH.IMM.IDPT',
    'year': '2019'}},
  {'name': 'retrieve_value',
   'arguments': {'country_code': 'AZE',
    'indicator_code': 'SH.IMM.IDPT',
    'year': '2019'}},
  {'name': 'retrieve_value',
   'arguments': {'country_code': 'BHR',
    'indicator_code': 'SH.IMM.IDPT',
    'year': '2019'}},
  {'name': 'retrieve_value',
   'arguments': {'country_code': 'CYP',
    'indicator_code': 'SH.IMM.IDPT',
    'year': '2019'}},
  {'name': 'retrieve_value',
   'arguments': {'country_code': 'GEO',
    'indicator_code': 'SH.IMM.IDPT',
    'year': '2019'}},
  {'name': 'retrieve_value',
   'arguments': {'country_code': 'IRQ',
    'indicator_code': 'SH.IMM.IDPT',
    'year': '2019'}},
  {'n

In [230]:
# Score the sampled row's model tool calls against its gold tool calls.
calculate_metrics_pure_python(gold, model)

{'precision': 0.5128205128205128,
 'recall': 0.9090909090909091,
 'accuracy': 0.4878048780487805}