In [1]:
import inspect
import os
import sys
from pathlib import Path

import analysis
import pandas as pd

# In Jupyter, __file__ is not defined, so use the current working directory
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

import matcher

from frankenstein.tools import arithmetic, data_retrieval

ARITHMETIC_TOOL_NAMES = [name for name, _ in inspect.getmembers(arithmetic, predicate=inspect.isfunction)]
DATA_TOOL_NAMES = [name for name, _ in inspect.getmembers(data_retrieval, predicate=inspect.isfunction)]

run_dir = Path('runs')
dfs = {f.stem: pd.read_json(f, orient='records', lines=True, precise_float=True) for f in run_dir.iterdir()}
print(f'Found {len(dfs)} runs in {run_dir}:')
for name in dfs:
    print(f'  {name}')

m = matcher.Matcher()

Found 49 runs in runs:
  Mistral-Small-3.1-24B_answerable-full_all-tools_0-shot
  Llama-3.1-8B-Instruct_answerable-full_all-tools_1-shot
  Qwen3-14B_answerable-full_data-tools_0-shot
  Qwen3-14B_answerable-full_all-tools_0-shot
  Llama-3.2-3B-Instruct_answerable-full_all-tools_3-shot
  Qwen3-4B_answerable-full_data-tools_0-shot
  Qwen3-32B_answerable-full_all-tools_1-shot
  Mistral-Small-3.1-24B_answerable-partial_all-tools_0-shot
  gpt-4.1-mini_answerable-full_all-tools_1-shot
  Llama-3.1-70B-Instruct_answerable-full_all-tools_3-shot
  gpt-4o-mini_answerable-full_all-tools_1-shot
  Llama-3.3-70B-Instruct_answerable-partial_all-tools_0-shot
  Llama-3.1-8B-Instruct_answerable-full_data-tools_0-shot
  Llama-3.2-3B-Instruct_answerable-full_data-tools_0-shot
  Llama-3.1-8B-Instruct_answerable-full_all-tools_3-shot
  Qwen3-32B_answerable-full_data-tools_0-shot
  Qwen3-4B_answerable-full_all-tools_0-shot
  Llama-3.3-70B-Instruct_answerable-full_all-tools_0-shot
  Qwen3-14B_answerable-partial

In [2]:
summary_rows = []
for run_name, run_df in dfs.items():
    print(f'Processing run: {run_name}')
    # Compute tool call columns if not already present
    # Keep only rows where df['answer_format'] is not bool
    run_df = run_df[run_df['answer_format'] != bool]
    if 'gold_tool_calls' not in run_df.columns:
        run_df['gold_tool_calls'] = run_df.apply(analysis.get_gold_tool_calls, axis=1)
    if 'pred_tool_calls' not in run_df.columns:
        run_df['pred_tool_calls'] = run_df.apply(analysis.get_pred_tool_calls, axis=1)
    if 'true_positives' not in run_df.columns:
        run_df['true_positives'] = run_df.apply(analysis.get_true_positives, axis=1)
    if 'false_positives' not in run_df.columns:
        run_df['false_positives'] = run_df.apply(analysis.get_false_positives, axis=1)
    if 'false_negatives' not in run_df.columns:
        run_df['false_negatives'] = run_df.apply(analysis.get_false_negatives, axis=1)
    if 'precision' not in run_df.columns:
        run_df['precision'] = run_df.apply(analysis.get_precision, axis=1)
    if 'coverage' not in run_df.columns:
        run_df['coverage'] = run_df.apply(analysis.get_coverage, axis=1)
    if 'recall' not in run_df.columns:
        run_df['recall'] = run_df.apply(analysis.get_recall, axis=1)
    if 'error_made' not in run_df.columns:
        run_df['error_made'] = run_df.apply(analysis.get_error_made, axis=1)
    if 'no_search_for_indicator_names' not in run_df.columns:
        run_df['no_search_for_indicator_names'] = run_df.apply(analysis.get_no_search_for_indicator_names, axis=1)

    summary = {
        'run': run_name,
        'n': len(run_df),
        'accuracy': run_df['correct'].mean() if 'correct' in run_df.columns else None,
        'precision_mean': run_df['precision'].mean(),
        'precision_std': run_df['precision'].std(),
        'coverage_mean': run_df['coverage'].mean(),
        'coverage_std': run_df['coverage'].std(),
        'recall_mean': run_df['recall'].mean(),
        'recall_std': run_df['recall'].std(),
        'error_rate': run_df['error_made'].mean(),
        'no_search_rate': run_df['no_search_for_indicator_names'].mean(),
    }
    summary_rows.append(summary)

summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.sort_values(by='run')
summary_df = summary_df.round(3)  # Round to 3 decimal places
summary_df

Processing run: Mistral-Small-3.1-24B_answerable-full_all-tools_0-shot
Processing run: Llama-3.1-8B-Instruct_answerable-full_all-tools_1-shot
Processing run: Qwen3-14B_answerable-full_data-tools_0-shot
Processing run: Qwen3-14B_answerable-full_all-tools_0-shot
Processing run: Llama-3.2-3B-Instruct_answerable-full_all-tools_3-shot
Processing run: Qwen3-4B_answerable-full_data-tools_0-shot
Processing run: Qwen3-32B_answerable-full_all-tools_1-shot
Processing run: Mistral-Small-3.1-24B_answerable-partial_all-tools_0-shot
Processing run: gpt-4.1-mini_answerable-full_all-tools_1-shot
Processing run: Llama-3.1-70B-Instruct_answerable-full_all-tools_3-shot
Processing run: gpt-4o-mini_answerable-full_all-tools_1-shot
Processing run: Llama-3.3-70B-Instruct_answerable-partial_all-tools_0-shot
Processing run: Llama-3.1-8B-Instruct_answerable-full_data-tools_0-shot
Processing run: Llama-3.2-3B-Instruct_answerable-full_data-tools_0-shot
Processing run: Llama-3.1-8B-Instruct_answerable-full_all-tool

Unnamed: 0,run,n,accuracy,precision_mean,precision_std,coverage_mean,coverage_std,recall_mean,recall_std,error_rate,no_search_rate
39,Llama-3.1-70B-Instruct_answerable-full_all-too...,400,0.265,0.414,0.357,0.35,0.346,0.35,0.346,0.385,0.672
21,Llama-3.1-70B-Instruct_answerable-full_all-too...,400,0.382,0.496,0.339,0.429,0.346,0.429,0.346,0.325,0.775
9,Llama-3.1-70B-Instruct_answerable-full_all-too...,400,0.315,0.434,0.371,0.336,0.341,0.336,0.341,0.23,0.828
32,Llama-3.1-70B-Instruct_answerable-full_data-to...,400,0.3,0.746,0.28,0.548,0.27,0.548,0.27,0.535,0.498
38,Llama-3.1-70B-Instruct_answerable-partial_all-...,260,0.177,0.37,0.306,0.312,0.335,0.312,0.335,0.627,0.431
33,Llama-3.1-8B-Instruct_answerable-full_all-tool...,400,0.035,0.076,0.205,0.045,0.136,0.045,0.136,0.202,0.912
1,Llama-3.1-8B-Instruct_answerable-full_all-tool...,400,0.032,0.083,0.196,0.038,0.102,0.038,0.102,0.248,0.915
14,Llama-3.1-8B-Instruct_answerable-full_all-tool...,400,0.07,0.115,0.233,0.059,0.133,0.059,0.133,0.225,0.845
12,Llama-3.1-8B-Instruct_answerable-full_data-too...,400,0.085,0.382,0.33,0.2,0.231,0.2,0.231,0.68,0.68
47,Llama-3.2-3B-Instruct_answerable-full_all-tool...,400,0.06,0.037,0.104,0.028,0.097,0.028,0.097,0.372,0.712
