In [1]:
from typing import Any, List, cast

import numpy as np
from inspect_ai.analysis.beta import (
    EvalInfo,
    EvalModel,
    EvalTask,
    SampleColumn,
    SampleSummary,
    evals_df,
    samples_df,
)
from inspect_ai.analysis.beta._dataframe.extract import score_values
from inspect_ai.scorer import value_to_float
from inspect_viz import Data
from pydantic import JsonValue

from evals import evals_bar_plot, evals_table

# Shared converter built once at module level; scores() below calls it on each
# score's "value" field.
# NOTE(review): presumably maps Inspect score values (e.g. "C"/"I", bools,
# numbers) to floats — confirm against the value_to_float documentation.
to_float = value_to_float()


def scores(x: JsonValue) -> JsonValue:
    """Reduce a sample's scores mapping to a single headline value.

    Args:
        x: Mapping keyed by name whose entries each expose a ``"value"`` field.

    Returns:
        The ``to_float`` conversion of the first entry's ``"value"`` (in
        mapping order), or ``0.0`` when the mapping has no entries. Every
        entry is converted, even though only the first result is returned.
    """
    by_name = cast(dict[str, Any], x)
    converted = []
    for entry in by_name.values():
        converted.append(to_float(entry["value"]))
    if not converted:
        return 0.0
    return converted[0]


# Extra per-sample column named "score": it takes whatever is found at the
# "scores" path and reduces it with scores() defined above.
simple_score = SampleColumn("score", path="scores", value=scores)

# Build the samples frame from the logs under "logs/aime/", requesting the
# summary, eval-info, task and model column groups plus the custom score column.
# NOTE(review): presumably yields one row per sample (per epoch) — confirm
# against the samples_df documentation.
df = samples_df(
    "logs/aime/",
    columns=SampleSummary + EvalInfo + EvalTask + EvalModel + [simple_score],
)

In [2]:
from evals import evals_heatmap_plot

# Wrap the samples frame as inspect_viz Data and hand it to the heatmap helper;
# being the cell's last expression, the returned component is shown as output.
evals_heatmap_plot(Data.from_dataframe(df))

Component(spec='{"vconcat":[{"hconcat":[{"input":"select","multiple":false,"value":"all","label":"id: ","from"…

### First, calculate pass@k in pandas

In [3]:
def pass_at_k(k: int, values: List[float]) -> float:
    """Estimate pass@k from the scores of repeated attempts at one problem.

    With ``n = len(values)`` attempts of which ``c = sum(values)`` are
    correct, this evaluates ``1 - C(n - c, k) / C(n, k)`` in its numerically
    stable product form ``1 - prod_{i = n - c + 1}^{n} (1 - k / i)``.

    NOTE(review): the formula assumes every score is 0 or 1, so that the sum
    is a count of correct attempts — confirm for scorers that can emit
    partial credit.

    Args:
        k: Number of attempts drawn; must be at least 1.
        values: Per-attempt scores for a single problem.

    Returns:
        The estimated probability that at least one of ``k`` attempts drawn
        without replacement from ``values`` is correct.

    Raises:
        ValueError: If ``k`` is smaller than 1, or if fewer than ``k`` values
            are supplied.
    """
    # A non-positive k has no meaning here and would otherwise yield 0.0 or
    # even a negative "probability" from the product below.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}.")

    total = len(values)
    if total < k:
        raise ValueError(
            f"Cannot compute pass@{k} for {total} values. "
            "Please ensure that the number of values is at least k."
        )

    correct = sum(values)
    incorrect = total - correct

    # Fewer than k failures: any draw of k attempts must include a success.
    if incorrect < k:
        return 1.0

    # Probability that all k draws miss, as a running product over
    # i = incorrect + 1 .. total (empty product == 1.0 when nothing is correct).
    numerators = np.arange(incorrect + 1, total + 1)
    probs_miss = np.prod(1.0 - k / numerators)
    return 1.0 - cast(float, probs_miss.item())


# Keep only the last "/"-separated segment of each model identifier;
# non-string entries are passed through unchanged.
df["model"] = df["model"].apply(
    lambda name: name.rsplit("/", 1)[-1] if isinstance(name, str) else name
)
ks = list(range(1, 6))

# Collect every score recorded for the same (model, task, sample id) into a list.
grouped = (
    df.groupby(["model", "task_name", "id"])["score"]
    .apply(list)
    .reset_index(name="scores")
)


# One column per k (keyed by the integer k itself) holding pass@k for each row.
for k in ks:
    grouped[k] = grouped["scores"].apply(
        lambda attempt_scores, k=k: pass_at_k(k, attempt_scores)
    )

In [4]:
from evals import evals_pass_at_k_heatmap_plot

# Plot the grouped pass@k table; min_k/max_k match the range of k columns
# (1 through 5) added to `grouped` in the previous cell.
evals_pass_at_k_heatmap_plot(Data.from_dataframe(grouped), min_k=1, max_k=5)

  table = pa.Table.from_pandas(self, schema=requested_schema)


Component(spec='{"vconcat":[{"hconcat":[{"input":"select","multiple":false,"value":"all","label":"id: ","from"…