# Eval results

In [1]:
from IPython.display import display, Markdown
import pickle
import plotly.express as px
import numpy as np

In [2]:
# Pickled eval-result files to compare, keyed by the eval name used in every heading below.
EVAL_RESULTS_PATHS = dict(
    pre_sft="/workspace/exploration-hacking/artifacts/data/dfalck/science_locking/pre_sft_eval_train.pkl",
    post_sft="/workspace/exploration-hacking/artifacts/data/dfalck/science_locking/post_sft_eval_train.pkl",
)

In [3]:
# Unpickle every configured results file into a dict keyed by eval name.
# NOTE(review): pickle.load can execute arbitrary code from the file; only point
# EVAL_RESULTS_PATHS at artifacts you trust.
all_results = {}

for name, path in EVAL_RESULTS_PATHS.items():
    with open(path, "rb") as fh:
        all_results[name] = pickle.load(fh)

In [4]:
# Union of the metric names reported by any eval. Sorted so the report order is
# stable across kernel restarts (plain set iteration order over strings varies
# with hash randomization).
metric_names = sorted(set().union(*(set(results.metrics.keys()) for results in all_results.values())))

## Global stats

In [5]:
# For every metric and every eval, print summary statistics and show a histogram
# of the values.
for k in metric_names:
    for eval_name, results in all_results.items():
        display(Markdown(f"#### {k} - {eval_name}"))
        # metric_names is the union over all evals, so this eval may not have the metric.
        if k not in results.metrics:
            print("metric not present in this eval")
            continue
        values = results.metrics[k]
        print("count", len(values))
        if len(values) == 0:
            # np.max / np.min raise ValueError on empty input; nothing more to report.
            continue
        print("mean", np.mean(values))
        print("std", np.std(values))
        print("max", np.max(values))
        print("min", np.min(values))
        px.histogram(values, height=300).show()

#### follow_prompt - pre_sft

count 510
mean 0.5757411764705883
std 0.2723277071272847
max 0.98
min 0.0


#### follow_prompt - post_sft

count 510
mean 0.5734313725490197
std 0.256844711984862
max 0.98
min 0.0


#### accuracy - pre_sft

count 510
mean 0.8372549019607843
std 0.369132945025802
max 1.0
min 0.0


#### accuracy - post_sft

count 510
mean 0.8490196078431372
std 0.3580297659986145
max 1.0
min 0.0


#### search_simple_calls - pre_sft

count 510
mean 0.3843137254901961
std 0.4864326118693547
max 1.0
min 0.0


#### search_simple_calls - post_sft

count 510
mean 0.36666666666666664
std 0.4859462839138429
max 2.0
min 0.0


#### completion_under_length_penalty - pre_sft

count 510
mean 0.0
std 0.0
max 0.0
min 0.0


#### completion_under_length_penalty - post_sft

count 510
mean 0.0
std 0.0
max 0.0
min 0.0


#### format_penalty_func - pre_sft

count 510
mean 0.0
std 0.0
max 0.0
min 0.0


#### format_penalty_func - post_sft

count 510
mean 0.0
std 0.0
max 0.0
min 0.0


#### completion_over_length_penalty - pre_sft

count 510
mean 0.0
std 0.0
max 0.0
min 0.0


#### completion_over_length_penalty - post_sft

count 510
mean 0.0
std 0.0
max 0.0
min 0.0


#### total_tool_calls - pre_sft

count 510
mean 0.3843137254901961
std 0.4864326118693547
max 1.0
min 0.0


#### total_tool_calls - post_sft

count 510
mean 0.36470588235294116
std 0.4813475893053831
max 1.0
min 0.0


## Per-segment stats

In [None]:
# One segment label per entry of results.info for each eval. Every label is
# None here, so the only segment is None.
segments = {
    name: [None for _ in res.info]
    for name, res in all_results.items()
}
# Distinct labels seen across all evals.
unique_segments = list(set().union(*map(set, segments.values())))


In [None]:
# For every segment, metric and eval, print summary statistics and show a
# histogram of the metric values belonging to that segment.
for segment in unique_segments:
    display(Markdown(f"### Segment: {segment}"))
    for k in metric_names:
        for eval_name, results in all_results.items():
            # Metrics whose name contains "_segment_" are only reported under the
            # segment named in them (a None segment is matched as "main").
            if "_segment_" in k and (segment or "main").replace(" ", "_") not in k:
                continue
            # metric_names is the union over all evals, so this eval may not have the metric.
            if k not in results.metrics:
                continue
            values = results.metrics[k]
            segment_results = [
                values[i] for i, s in enumerate(segments[eval_name]) if s == segment
            ]
            display(Markdown(f"#### {segment} - {k} - {eval_name}"))
            print("count", len(segment_results))
            if not segment_results:
                # No samples of this segment in this eval; np.max / np.min would
                # raise ValueError on the empty list.
                continue
            print("mean", np.mean(segment_results))
            print("std", np.std(segment_results))
            print("max", np.max(segment_results))
            print("min", np.min(segment_results))
            px.histogram(segment_results, height=300).show()

### Segment: None

#### completion_over_length_penalty - pre_sft

count 335
mean 0.0
std 0.0
max 0.0
min 0.0


#### completion_over_length_penalty - post_sft

count 335
mean 0.0
std 0.0
max 0.0
min 0.0


#### format_penalty_func - pre_sft

count 335
mean 0.0
std 0.0
max 0.0
min 0.0


#### format_penalty_func - post_sft

count 335
mean 0.0
std 0.0
max 0.0
min 0.0


#### accuracy - pre_sft

count 335
mean 0.8298507462686567
std 0.375763868920377
max 1.0
min 0.0


#### accuracy - post_sft

count 335
mean 0.7761194029850746
std 0.41684298638116246
max 1.0
min 0.0


#### search_simple_calls - pre_sft

count 335
mean 0.0955223880597015
std 0.30392105996920427
max 2.0
min 0.0


#### search_simple_calls - post_sft

count 335
mean 0.10149253731343283
std 0.3479116963422917
max 3.0
min 0.0


#### completion_under_length_penalty - pre_sft

count 335
mean 0.0
std 0.0
max 0.0
min 0.0


#### completion_under_length_penalty - post_sft

count 335
mean 0.0
std 0.0
max 0.0
min 0.0


#### total_tool_calls - pre_sft

count 335
mean 0.09253731343283582
std 0.2897829516301277
max 1.0
min 0.0


#### total_tool_calls - post_sft

count 335
mean 0.08955223880597014
std 0.28553920104043273
max 1.0
min 0.0
