In [1]:
import logging
import os
import sys
from pathlib import Path

import pandas as pd

# In Jupyter, __file__ is not defined, so use the current working directory
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))
# from frankenstein.graph import FrankensteinGraph
import matcher

# Raise the root logger's threshold to ERROR so that every logger inheriting
# the root level (presumably matcher and its dependencies) only emits errors.
# Loggers that explicitly set their own level are NOT affected by this.
logging.getLogger().setLevel(logging.ERROR)

# One DataFrame per run file, keyed by the file stem
# (e.g. 'gpt-4o-mini_answerable-full'); each file is read as JSON Lines.
# Only regular, non-hidden files are loaded: a stray directory such as
# .ipynb_checkpoints or a hidden file such as .DS_Store would otherwise make
# pd.read_json fail. Sorting makes the load order deterministic.
run_dir = Path('runs')
dfs = {
    f.stem: pd.read_json(f, orient='records', lines=True)
    for f in sorted(run_dir.iterdir())
    if f.is_file() and not f.name.startswith('.')
}

m = matcher.Matcher()



In [2]:
# Pick one run to inspect; keys of `dfs` are the stems of the files under runs/.
df = dfs['gpt-4o-mini_answerable-full']

In [3]:
# Inspect the raw message list of the first record in this run.
df.iloc[0]['messages']

[{'role': 'system',
  'content': "You are a helpful assistant tasked with answering questions that require multiple intermediate steps of reasoning to arrive at a final answer.\n\nThe questions involve using World Bank data for various countries and indicators.\n\nCreate a step-by-step plan to answer the question, and then execute each step of that plan to arrive at the final answer.\n\nThe conversation will only end after you call the `final_answer` tool with your final answer.\n\nYou have access to a set of tools to help you answer the question:\n\nPay attention to the tool names, arguments, descriptions, and the types of outputs they return, and think carefully about how to use them to solve the problem.\n\nIf there is a tool available that can help you with the next step, you must use it rather than trying to solve the problem without it.\n\nI will execute tool calls that you provide. You can use multiple tools in one step, but make sure you follow the correct format.\n\nUse the re

In [17]:
def get_gold_tool_calls(row: pd.Series) -> list[dict]:
    """Extract the gold (reference) tool calls from one run record.

    Args:
        row: A single row of a run DataFrame (any mapping works). Its
            ``'actions'`` entry is iterated; each action must provide
            ``'name'`` and ``'arguments'`` keys.

    Returns:
        A list of ``{'name': ..., 'arguments': ...}`` dicts, one per action,
        in the original order. Any other keys on an action are dropped, so
        the result has the same shape as ``get_pred_tool_calls`` output.
    """
    return [
        {'name': action['name'], 'arguments': action['arguments']}
        for action in row['actions']
    ]


def get_pred_tool_calls(row: pd.Series) -> list[dict]:
    """Extract the tool calls the model actually issued during one run.

    Walks the row's ``'messages'`` list in order and collects every entry of
    each message's ``'tool_calls'`` list.

    Args:
        row: A single row of a run DataFrame (any mapping works). Its
            ``'messages'`` entry is an iterable of message dicts. Each tool
            call is expected to carry a ``'function'`` dict with ``'name'``
            and ``'arguments'`` keys.

    Returns:
        A list of ``{'name': ..., 'arguments': ...}`` dicts in call order.
        ``'arguments'`` is passed through unchanged (it is not JSON-decoded).
    """
    tool_calls = []
    for msg in row['messages']:
        # A message may omit 'tool_calls' or carry it as None / [] (e.g. an
        # assistant turn serialized with `tool_calls: null`). Treat all of
        # these as "no calls" instead of crashing on `for call in None`.
        for call in msg.get('tool_calls') or []:
            function = call['function']
            tool_calls.append({'name': function['name'], 'arguments': function['arguments']})
    return tool_calls


In [18]:
# Gold (reference) tool calls for the first record, built from its 'actions' field.
get_gold_tool_calls(df.iloc[0])

[{'name': 'get_country_code_from_name',
  'arguments': {'country_name': 'Viet Nam'}},
 {'name': 'get_indicator_code_from_name',
  'arguments': {'indicator_name': 'Container port traffic (TEU: 20 foot equivalent units)'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2010'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2011'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2012'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2013'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2014'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2015'}},
 {'name': 'subtr

In [15]:
# Tool calls the model actually issued for the same record, built from its 'messages'.
# NOTE(review): this cell's execution count (In [15]) precedes the cell that
# defines get_pred_tool_calls (In [17]); re-run top-to-bottom to confirm the
# output below is not stale.
get_pred_tool_calls(df.iloc[0])

[{'name': 'get_country_code_from_name',
  'arguments': {'country_name': 'Viet Nam'}},
 {'name': 'get_indicator_code_from_name',
  'arguments': {'indicator_name': 'Container throughput at ports'}},
 {'name': 'search_for_indicator_codes',
  'arguments': {'keywords': ['Container', 'throughput', 'ports']}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2010'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2011'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2012'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2013'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 'VNM',
   'indicator_code': 'IS.SHP.GOOD.TU',
   'year': '2014'}},
 {'name': 'retrieve_value',
  'arguments': {'country_code': 