In [7]:
import inspect
import os
import sys
from pathlib import Path

import analysis
import pandas as pd

# In Jupyter, __file__ is not defined, so use the current working directory
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

import matcher

from frankenstein.tools import arithmetic, data_retrieval

ARITHMETIC_TOOL_NAMES = [name for name, _ in inspect.getmembers(arithmetic, predicate=inspect.isfunction)]
DATA_TOOL_NAMES = [name for name, _ in inspect.getmembers(data_retrieval, predicate=inspect.isfunction)]

run_dir = Path('runs')
dfs = {f.stem: pd.read_json(f, orient='records', lines=True, precise_float=True) for f in run_dir.iterdir()}
print(f'Found {len(dfs)} runs in {run_dir}:')
for name in dfs:
    print(f'  {name}')

m = matcher.Matcher()

Found 10 runs in runs:
  gpt-4o-mini_answerable-full_all-tools_3-shot
  gpt-4.1-mini_answerable-full_all-tools_1-shot
  gpt-4o-mini_answerable-full_data-tools_0-shot
  gpt-4.1-mini_answerable-full_data-tools_0-shot
  gpt-4.1-mini_answerable-full_all-tools_0-shot
  gpt-4o-mini_answerable-full_all-tools_1-shot
  gpt-4.1-mini_answerable-full_all-tools_3-shot
  gpt-4.1-mini_answerable-partial_all-tools_0-shot
  gpt-4o-mini_answerable-full_all-tools_0-shot
  gpt-4o-mini_answerable-partial_all-tools_0-shot


In [8]:
summary_rows = []
for run_name, run_df in dfs.items():
    # Compute tool call columns if not already present
    # Keep only rows where df['answer_format'] is not bool
    run_df = run_df[run_df['answer_format'] != bool]
    if 'gold_tool_calls' not in run_df.columns:
        run_df['gold_tool_calls'] = run_df.apply(analysis.get_gold_tool_calls, axis=1)
    if 'pred_tool_calls' not in run_df.columns:
        run_df['pred_tool_calls'] = run_df.apply(analysis.get_pred_tool_calls, axis=1)
    if 'precision' not in run_df.columns:
        run_df['precision'] = run_df.apply(analysis.get_precision, axis=1)
    if 'coverage' not in run_df.columns:
        run_df['coverage'] = run_df.apply(analysis.get_coverage, axis=1)
    if 'recall' not in run_df.columns:
        run_df['recall'] = run_df.apply(analysis.get_recall, axis=1)
    if 'error_made' not in run_df.columns:
        run_df['error_made'] = run_df.apply(analysis.get_error_made, axis=1)
    if 'no_search_for_indicator_names' not in run_df.columns:
        run_df['no_search_for_indicator_names'] = run_df.apply(analysis.get_no_search_for_indicator_names, axis=1)

    summary = {
        'run': run_name,
        'n': len(run_df),
        'accuracy': run_df['correct'].mean() if 'correct' in run_df.columns else None,
        'precision_mean': run_df['precision'].mean(),
        'precision_std': run_df['precision'].std(),
        'coverage_mean': run_df['coverage'].mean(),
        'coverage_std': run_df['coverage'].std(),
        'recall_mean': run_df['recall'].mean(),
        'recall_std': run_df['recall'].std(),
        'error_rate': run_df['error_made'].mean(),
        'no_search_rate': run_df['no_search_for_indicator_names'].mean(),
    }
    summary_rows.append(summary)

summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values(by='run')
summary_df = summary_df.round(3)  # Round to 3 decimal places
summary_df

Unnamed: 0,run,n,accuracy,precision_mean,precision_std,coverage_mean,coverage_std,recall_mean,recall_std,error_rate,no_search_rate
4,gpt-4.1-mini_answerable-full_all-tools_0-shot,400,0.705,0.813,0.213,0.346,0.458,0.375,0.491,0.528,0.318
1,gpt-4.1-mini_answerable-full_all-tools_1-shot,400,0.702,0.828,0.198,0.339,0.459,0.368,0.497,0.37,0.368
6,gpt-4.1-mini_answerable-full_all-tools_3-shot,400,0.7,0.833,0.22,0.273,0.428,0.298,0.461,0.358,0.495
3,gpt-4.1-mini_answerable-full_data-tools_0-shot,400,0.495,0.823,0.184,0.067,0.212,0.077,0.225,0.205,0.19
7,gpt-4.1-mini_answerable-partial_all-tools_0-shot,236,0.538,0.764,0.202,0.166,0.372,0.167,0.372,0.644,0.127
8,gpt-4o-mini_answerable-full_all-tools_0-shot,400,0.68,0.665,0.302,0.239,0.41,0.251,0.42,0.52,0.542
5,gpt-4o-mini_answerable-full_all-tools_1-shot,400,0.638,0.631,0.336,0.162,0.351,0.173,0.362,0.402,0.688
0,gpt-4o-mini_answerable-full_all-tools_3-shot,400,0.642,0.639,0.331,0.166,0.358,0.175,0.366,0.398,0.668
2,gpt-4o-mini_answerable-full_data-tools_0-shot,400,0.358,0.675,0.347,0.049,0.195,0.053,0.199,0.518,0.548
9,gpt-4o-mini_answerable-partial_all-tools_0-shot,5,0.4,0.614,0.4,0.0,0.0,0.0,0.0,0.6,0.4
