# Performance comparison table

This notebook evaluates various rollout datasets to re-produce Table 1 of the main paper.
For these calculations, the rollout dataset containing response sequences from the corresponding models is read from MongoDB. 
This dataset can be downloaded as an archive here and then imported using `mongorestore` as outlined in [RolloutDatasets.ipynb](RolloutDatasets.ipynb).

In [1]:
import sys; sys.path.append("..")

import logging
import numpy as np
import pandas as pd

from searchformer.rollout import RolloutDataStore
from searchformer.utils import mongodb_client


logging.basicConfig(
    level=logging.DEBUG,
    format="%(levelname)s - %(asctime)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

Selecting datasets used for evaluation.

In [2]:
datastore = RolloutDataStore()

rollout_data = datastore.list_all()
rollout_data = rollout_data[rollout_data["rollout_repeats"] == 64]
rollout_data = rollout_data[rollout_data["checkpoint_id"].isin(
    [
        "sokoban-7722-m-plan-only-100k-0",
        "sokoban-7722-m-plan-only-100k-1",
        "sokoban-7722-m-plan-only-100k-2",
        "sokoban-7722-m-trace-plan-100k-0",
        "sokoban-7722-m-trace-plan-100k-1",
        "sokoban-7722-m-trace-plan-100k-2",
        "sokoban-7722-m-trace-plan-100k-0-step-1",
        "sokoban-7722-m-trace-plan-100k-1-step-1",
        "sokoban-7722-m-trace-plan-100k-2-step-1",
        "sokoban-7722-m-trace-plan-100k-0-step-2",
        "sokoban-7722-m-trace-plan-100k-1-step-2",
        "sokoban-7722-m-trace-plan-100k-2-step-2",
        "sokoban-7722-m-trace-plan-100k-0-step-3",
        "sokoban-7722-m-trace-plan-100k-1-step-3",
        "sokoban-7722-m-trace-plan-100k-2-step-3",
        "sokoban-7722-l-plan-only-100k-0",
        "sokoban-7722-l-plan-only-100k-1",
        "sokoban-7722-l-plan-only-100k-2",
        "sokoban-7722-l-trace-plan-100k-0",
        "sokoban-7722-l-trace-plan-100k-1",
        "sokoban-7722-l-trace-plan-100k-2",
        "sokoban-7722-xl-plan-only-100k-0",
        "sokoban-7722-xl-plan-only-100k-1",
        "sokoban-7722-xl-plan-only-100k-2",
    ]
)]
rollout_data["id"] = [str(i) for i in rollout_data["_id"]]
rollout_data.drop(columns=["_id"], inplace=True)
rollout_data.head()

INFO - 2024-04-26 14:02:09 - root - Connecting to mongodb://localhost:27017/mongo


Unnamed: 0,checkpoint_id,dataset_name,sampler_name,rollout_len,rollout_repeats,prefix_len,min_reasoning_len,max_reasoning_len,id
341,sokoban-7722-m-trace-plan-100k-0,sokoban.7-by-7-walls-2-boxes-2.with-box-40k,probability,11000,64,0,0,10000,65b6bbb143977a604499c572
342,sokoban-7722-m-plan-only-100k-0,sokoban.7-by-7-walls-2-boxes-2.with-box-40k,probability,2000,64,0,0,10000,65b6bbc479cc93df45d17a5c
343,sokoban-7722-m-plan-only-100k-1,sokoban.7-by-7-walls-2-boxes-2.with-box-40k,probability,2000,64,0,0,10000,65b6bbc50fe338eeca6cea91
344,sokoban-7722-m-plan-only-100k-2,sokoban.7-by-7-walls-2-boxes-2.with-box-40k,probability,2000,64,0,0,10000,65b6bbc6b4bf8d5e7493570e
345,sokoban-7722-l-plan-only-100k-0,sokoban.7-by-7-walls-2-boxes-2.with-box-40k,probability,2000,64,0,0,10000,65b6bbfde67a96d8e48e008b


The following code iterates over the rollout test datasets and evaluates how often an optimal or correct plan is contained in the recorded responses.

In [3]:
import multiprocessing as mp
from typing import Optional
from searchformer.sokoban import evaluate_rollout

def eval_rollout_dataset(dataset_id: str) -> Optional[pd.DataFrame]:
    logging.debug(f"Loading dataset with id {dataset_id}.")
    datastore = RolloutDataStore()
    dataset = datastore.load_by_id(dataset_id)
    rollout_df_list = []
    for rollout in dataset.rollout_test_it():
        assert len(rollout.rollouts) == 64, "Incorrect number of generated sequences for prompt."
        eval_df = evaluate_rollout(rollout).to_dataframe()
        rollout_df_list.append(eval_df)

    if len(rollout_df_list) == 0:
        return None

    rollout_df = pd.concat(rollout_df_list)
    rollout_df.rename(columns={"_id": "sequence_id"}, inplace=True)
    rollout_df["id"] = dataset_id
    return rollout_df


with mp.Pool() as p:
    evals_df_list = p.map(eval_rollout_dataset, rollout_data["id"].tolist())
evals_df = pd.concat(filter(lambda d: d is not None, evals_df_list))  # type: ignore
evals_df = pd.merge(evals_df, rollout_data, on="id", how="left")
evals_df["plan_optimal"] = evals_df["plan_correct"] & (evals_df["plan_length"] == evals_df["optimal_plan_length"])
evals_df = evals_df[["checkpoint_id", "sequence_id", "plan_optimal", "plan_correct", "trace_tokens", "reasoning_length", "plan_length", "optimal_plan_length"]].copy()
evals_df["tokens_improved"] = evals_df["reasoning_length"] - evals_df["trace_tokens"]
evals_df["tokens_improved"] = evals_df["tokens_improved"].values * evals_df["plan_optimal"].values.astype(np.int32)

ckpt_2_set = pd.DataFrame(
    data=[
        ("sokoban-7722-m-plan-only-100k-0", "45M Solution only"),
        ("sokoban-7722-m-plan-only-100k-1", "45M Solution only"),
        ("sokoban-7722-m-plan-only-100k-2", "45M Solution only"),
        ("sokoban-7722-m-trace-plan-100k-0", "45M Search augmented"),
        ("sokoban-7722-m-trace-plan-100k-1", "45M Search augmented"),
        ("sokoban-7722-m-trace-plan-100k-2", "45M Search augmented"),
        ("sokoban-7722-m-trace-plan-100k-0-step-1", "45M Searchformer, step 1"),
        ("sokoban-7722-m-trace-plan-100k-1-step-1", "45M Searchformer, step 1"),
        ("sokoban-7722-m-trace-plan-100k-2-step-1", "45M Searchformer, step 1"),
        ("sokoban-7722-m-trace-plan-100k-0-step-2", "45M Searchformer, step 2"),
        ("sokoban-7722-m-trace-plan-100k-1-step-2", "45M Searchformer, step 2"),
        ("sokoban-7722-m-trace-plan-100k-2-step-2", "45M Searchformer, step 2"),
        ("sokoban-7722-m-trace-plan-100k-0-step-3", "45M Searchformer, step 3"),
        ("sokoban-7722-m-trace-plan-100k-1-step-3", "45M Searchformer, step 3"),
        ("sokoban-7722-m-trace-plan-100k-2-step-3", "45M Searchformer, step 3"),
        ("sokoban-7722-l-plan-only-100k-0", "175M Solution only"),
        ("sokoban-7722-l-plan-only-100k-1", "175M Solution only"),
        ("sokoban-7722-l-plan-only-100k-2", "175M Solution only"),
        ("sokoban-7722-l-trace-plan-100k-0", "175M Search augmented"),
        ("sokoban-7722-l-trace-plan-100k-1", "175M Search augmented"),
        ("sokoban-7722-l-trace-plan-100k-2", "175M Search augmented"),
        ("sokoban-7722-xl-plan-only-100k-0", "757M Solution only"),
        ("sokoban-7722-xl-plan-only-100k-1", "757M Solution only"),
        ("sokoban-7722-xl-plan-only-100k-2", "757M Solution only"),
    ],
    columns=["checkpoint_id", "set"]
)
evals_df = pd.merge(evals_df, ckpt_2_set, on=["checkpoint_id"], how="left")
evals_df.rename(
    columns={
        "reasoning_length": "trace_len",
        "trace_tokens": "rollout_len"
    }, 
    inplace=True
)

pygame 2.5.2 (SDL 2.28.2, Python 3.10.13)
Hello from the pygame community. https://www.pygame.org/contribute.html


DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bbb143977a604499c572.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bbfde67a96d8e48e008b.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bbc6b4bf8d5e7493570e.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bbc50fe338eeca6cea91.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bbc479cc93df45d17a5c.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bc118f67678f03ec99db.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bc24ff07df4ca9191ab8.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bc2552b35c10a841e311.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bc26297b5484efad216f.
DEBUG - 2024-04-26 14:02:10 - root - Loading dataset with id 65b6bc2a36dfc75fae7af3d8.
INFO - 2024-04-26 14:02:10 - root - Connecting to mongodb://localhost:27017/mongo
DEBUG - 2024-04-26 14:02:10 - root - Loading dat

Evaluation results for each generated sequence are stored in this dataframe.

In [4]:
evals_df.head()

Unnamed: 0,checkpoint_id,sequence_id,plan_optimal,plan_correct,rollout_len,trace_len,plan_length,optimal_plan_length,tokens_improved,set
0,sokoban-7722-m-trace-plan-100k-0,-9223366637686761758,True,True,2511,2337,10,10,-174,45M Search augmented
1,sokoban-7722-m-trace-plan-100k-0,-9223366637686761758,True,True,3198,2337,10,10,-861,45M Search augmented
2,sokoban-7722-m-trace-plan-100k-0,-9223366637686761758,True,True,2514,2337,10,10,-177,45M Search augmented
3,sokoban-7722-m-trace-plan-100k-0,-9223366637686761758,True,True,3069,2337,10,10,-732,45M Search augmented
4,sokoban-7722-m-trace-plan-100k-0,-9223366637686761758,True,True,2943,2337,10,10,-606,45M Search augmented


### Optimally solved tasks

In [5]:
optimally_answered = evals_df[["checkpoint_id", "set", "sequence_id", "plan_optimal"]].groupby(["checkpoint_id", "set", "sequence_id"]).any().reset_index()
optimally_answered = optimally_answered[["checkpoint_id", "set", "plan_optimal"]].groupby(["checkpoint_id", "set"]).mean().reset_index()
optimally_answered["plan_optimal"] *= 100
optimally_answered.rename(columns={"plan_optimal": "optimally answered [in %]"}, inplace=True)
optimally_answered.sort_values(by=["set", "checkpoint_id"])

Unnamed: 0,checkpoint_id,set,optimally answered [in %]
3,sokoban-7722-l-trace-plan-100k-0,175M Search augmented,93.0
4,sokoban-7722-l-trace-plan-100k-1,175M Search augmented,91.5
5,sokoban-7722-l-trace-plan-100k-2,175M Search augmented,95.0
0,sokoban-7722-l-plan-only-100k-0,175M Solution only,89.5
1,sokoban-7722-l-plan-only-100k-1,175M Solution only,91.0
2,sokoban-7722-l-plan-only-100k-2,175M Solution only,88.5
9,sokoban-7722-m-trace-plan-100k-0,45M Search augmented,91.0
13,sokoban-7722-m-trace-plan-100k-1,45M Search augmented,93.5
17,sokoban-7722-m-trace-plan-100k-2,45M Search augmented,88.0
10,sokoban-7722-m-trace-plan-100k-0-step-1,"45M Searchformer, step 1",93.0


In [6]:
optimally_answered_avg = optimally_answered[["set", "optimally answered [in %]"]].groupby(["set"]).mean()
optimally_answered_sem = optimally_answered[["set", "optimally answered [in %]"]].groupby(["set"]).sem().rename(columns={"optimally answered [in %]": "optimally answered [in %] [err]"})
pd.merge(optimally_answered_avg, optimally_answered_sem, on="set").reset_index()

Unnamed: 0,set,optimally answered [in %],optimally answered [in %] [err]
0,175M Search augmented,93.166667,1.013794
1,175M Solution only,89.666667,0.726483
2,45M Search augmented,90.833333,1.589899
3,"45M Searchformer, step 1",93.5,1.040833
4,"45M Searchformer, step 2",93.333333,0.600925
5,"45M Searchformer, step 3",93.666667,1.589899
6,45M Solution only,86.833333,0.333333
7,757M Solution only,92.166667,1.20185


### Correctly solved tasks

In [7]:
correctly_answered = evals_df[["checkpoint_id", "set", "sequence_id", "plan_correct"]].groupby(["checkpoint_id", "set", "sequence_id"]).any().reset_index()
correctly_answered = correctly_answered[["checkpoint_id", "set", "plan_correct"]].groupby(["checkpoint_id", "set"]).mean().reset_index()
correctly_answered["plan_correct"] *= 100
correctly_answered.rename(columns={"plan_correct": "correctly answered [in %]"}, inplace=True)
correctly_answered.sort_values(by=["set", "checkpoint_id"])

Unnamed: 0,checkpoint_id,set,correctly answered [in %]
3,sokoban-7722-l-trace-plan-100k-0,175M Search augmented,94.5
4,sokoban-7722-l-trace-plan-100k-1,175M Search augmented,94.0
5,sokoban-7722-l-trace-plan-100k-2,175M Search augmented,97.0
0,sokoban-7722-l-plan-only-100k-0,175M Solution only,95.0
1,sokoban-7722-l-plan-only-100k-1,175M Solution only,96.0
2,sokoban-7722-l-plan-only-100k-2,175M Solution only,95.0
9,sokoban-7722-m-trace-plan-100k-0,45M Search augmented,92.0
13,sokoban-7722-m-trace-plan-100k-1,45M Search augmented,94.5
17,sokoban-7722-m-trace-plan-100k-2,45M Search augmented,91.0
10,sokoban-7722-m-trace-plan-100k-0-step-1,"45M Searchformer, step 1",96.5


In [8]:
correctly_answered_avg = correctly_answered[["set", "correctly answered [in %]"]].groupby(["set"]).mean()
correctly_answered_sem = correctly_answered[["set", "correctly answered [in %]"]].groupby(["set"]).sem().rename(columns={"correctly answered [in %]": "correctly answered [in %] [err]"})
pd.merge(correctly_answered_avg, correctly_answered_sem, on="set").reset_index()

Unnamed: 0,set,correctly answered [in %],correctly answered [in %] [err]
0,175M Search augmented,95.166667,0.927961
1,175M Solution only,95.333333,0.333333
2,45M Search augmented,92.5,1.040833
3,"45M Searchformer, step 1",95.5,1.0
4,"45M Searchformer, step 2",96.0,0.5
5,"45M Searchformer, step 3",95.5,0.763763
6,45M Solution only,90.333333,1.013794
7,757M Solution only,96.5,0.0


### SWC scores

This computes the SWC scores for each Sokoban run.
The score is defined as 

$$\text{SWC} := \frac{1}{n} \sum_{i=1}^n c_i \frac{ l_i^* }{ \max \{ l_i,l_i^* \} }.$$

First, the A* reference dataset is loaded.
This dataset was generated by repeatedly running A* search on each test task and mapping each execution trace to a token sequence.
The following cell loads the resulting dataset from MongoDB.

In [9]:
def trace_2_stats(trace_doc):
    return {
        "sequence_id": trace_doc["sequence_id"],
        "prompt_len": len(trace_doc["prompt"]),
        "reasoning_len": len(trace_doc["reasoning"]),
        "plan_len": len(trace_doc["plan"]),
    }


client = mongodb_client()
db = client["sokobanAStarRefDataDB"]
astar_trace_len = pd.DataFrame(map(trace_2_stats, db.trace_variance.find()))
astar_trace_len = astar_trace_len[["sequence_id", "reasoning_len"]].groupby(["sequence_id"]).mean().reset_index().rename(columns={"reasoning_len": "trace_len A*"})

INFO - 2024-04-26 14:02:39 - root - Connecting to mongodb://localhost:27017/mongo


## SWC calculations

The SWC score is calculated for every Sokoban experiment.
Specifically, the score

$$\text{SWC} := \frac{1}{n} \sum_{i=1}^n c_i \frac{ l_i^* }{ \max \{ l_i,l_i^* \} }.$$

In [10]:
opt_plan_len = evals_df[["sequence_id", "optimal_plan_length"]].drop_duplicates()

group_cols = ["set", "checkpoint_id", "sequence_id"]

pred_plan_len_seq = evals_df[["set", "checkpoint_id", "sequence_id", "plan_correct", "plan_length"]].copy()
pred_plan_records = []
for cols, df in pred_plan_len_seq.groupby(group_cols):
    seq_dict = {k: v for k, v in zip(group_cols, cols)}
    
    correct_predictions = df[df["plan_correct"]]
    if len(correct_predictions) > 0:
        seq_dict["plan_length"] = correct_predictions["plan_length"].values.min()
        seq_dict["plan_correct"] = True
    else:
        seq_dict["plan_length"] = 0
        seq_dict["plan_correct"] = False

    pred_plan_records.append(seq_dict)
pred_plan_len = pd.DataFrame(pred_plan_records)
pred_plan_len = pd.merge(pred_plan_len, opt_plan_len, on=["sequence_id"])
pred_plan_len

Unnamed: 0,set,checkpoint_id,sequence_id,plan_length,plan_correct,optimal_plan_length
0,175M Search augmented,sokoban-7722-l-trace-plan-100k-0,-9223366637686761758,10,True,10
1,175M Search augmented,sokoban-7722-l-trace-plan-100k-1,-9223366637686761758,10,True,10
2,175M Search augmented,sokoban-7722-l-trace-plan-100k-2,-9223366637686761758,10,True,10
3,175M Solution only,sokoban-7722-l-plan-only-100k-0,-9223366637686761758,10,True,10
4,175M Solution only,sokoban-7722-l-plan-only-100k-1,-9223366637686761758,10,True,10
...,...,...,...,...,...,...
4795,45M Solution only,sokoban-7722-m-plan-only-100k-1,-9189986881111192192,8,True,8
4796,45M Solution only,sokoban-7722-m-plan-only-100k-2,-9189986881111192192,8,True,8
4797,757M Solution only,sokoban-7722-xl-plan-only-100k-0,-9189986881111192192,8,True,8
4798,757M Solution only,sokoban-7722-xl-plan-only-100k-1,-9189986881111192192,8,True,8


In [11]:
plan_corr = pred_plan_len["plan_correct"].values.astype(np.float32)
pred_len = pred_plan_len["plan_length"].values.astype(np.float32)
opt_len = pred_plan_len["optimal_plan_length"].values.astype(np.float32)
max_len = np.max(np.stack((pred_len, opt_len)), axis=0)

pred_plan_len["spl"] = plan_corr * opt_len / max_len
spl_df = pred_plan_len[["set", "checkpoint_id", "spl"]].groupby(["set", "checkpoint_id"]).mean().reset_index()

spl_df_avg = spl_df[["set", "spl"]].groupby(["set"]).mean().reset_index()
spl_df_sem = spl_df[["set", "spl"]].groupby(["set"]).sem().reset_index().rename(columns={"spl": "spl [sem]"})
pd.merge(spl_df_avg, spl_df_sem, on=["set"])

Unnamed: 0,set,spl,spl [sem]
0,175M Search augmented,0.948871,0.009564
1,175M Solution only,0.946101,0.004461
2,45M Search augmented,0.923504,0.010904
3,"45M Searchformer, step 1",0.953079,0.009672
4,"45M Searchformer, step 2",0.957214,0.004789
5,"45M Searchformer, step 3",0.952837,0.008897
6,45M Solution only,0.898819,0.008539
7,757M Solution only,0.958283,0.001732


## ILR calculations

The ILR score is defined as

$$\text{ILR} := \frac{1}{n} \sum_{i=1}^n c_i \frac{ t_i^* }{ t_i }.$$

This calculation is only done for the search-augmented and Searchformer models with the following ids.

In [12]:
run_ids = [
        "sokoban-7722-m-trace-plan-100k-0"       ,
        "sokoban-7722-m-trace-plan-100k-1"       ,
        "sokoban-7722-m-trace-plan-100k-2"       ,
        "sokoban-7722-m-trace-plan-100k-0-step-1",
        "sokoban-7722-m-trace-plan-100k-1-step-1",
        "sokoban-7722-m-trace-plan-100k-2-step-1",
        "sokoban-7722-m-trace-plan-100k-0-step-2",
        "sokoban-7722-m-trace-plan-100k-1-step-2",
        "sokoban-7722-m-trace-plan-100k-2-step-2",
        "sokoban-7722-m-trace-plan-100k-0-step-3",
        "sokoban-7722-m-trace-plan-100k-1-step-3",
        "sokoban-7722-m-trace-plan-100k-2-step-3",
        "sokoban-7722-l-trace-plan-100k-0"       ,
        "sokoban-7722-l-trace-plan-100k-1"       ,
        "sokoban-7722-l-trace-plan-100k-2"       ,
]

### ILR-on-solved

ILR calculation on sequences ending in a correct plan.

First the average rollout length across all sequences ending in an optimal plan are computed.
If no sequence ended in a correct plan, then the task is flagged with `plan_correct` set to false.

In [13]:
group_cols = ["set", "checkpoint_id", "sequence_id"]

trace_len_records = []
for cols, df in evals_df.groupby(group_cols):
    seq_dict = {k: v for k, v in zip(group_cols, cols)}
    
    correct_predictions = df[df["plan_correct"]]
    if len(correct_predictions) > 0:
        seq_dict["rollout_len"] = correct_predictions["rollout_len"].values.mean()
        seq_dict["plan_correct"] = True
    else:
        seq_dict["rollout_len"] = 1
        seq_dict["plan_correct"] = False

    trace_len_records.append(seq_dict)
corr_trace_len = pd.DataFrame(trace_len_records)
corr_trace_len = pd.merge(corr_trace_len, astar_trace_len, on=["sequence_id"])
corr_trace_len = corr_trace_len[corr_trace_len["checkpoint_id"].isin(run_ids)]
corr_trace_len

Unnamed: 0,set,checkpoint_id,sequence_id,rollout_len,plan_correct,trace_len A*
0,175M Search augmented,sokoban-7722-l-trace-plan-100k-0,-9223366637686761758,2688.984127,True,2722.968750
1,175M Search augmented,sokoban-7722-l-trace-plan-100k-1,-9223366637686761758,2625.796875,True,2722.968750
2,175M Search augmented,sokoban-7722-l-trace-plan-100k-2,-9223366637686761758,2733.532258,True,2722.968750
6,45M Search augmented,sokoban-7722-m-trace-plan-100k-0,-9223366637686761758,2644.809524,True,2722.968750
7,45M Search augmented,sokoban-7722-m-trace-plan-100k-1,-9223366637686761758,2651.774194,True,2722.968750
...,...,...,...,...,...,...
4789,"45M Searchformer, step 2",sokoban-7722-m-trace-plan-100k-1-step-2,-9189986881111192192,621.920635,True,793.640625
4790,"45M Searchformer, step 2",sokoban-7722-m-trace-plan-100k-2-step-2,-9189986881111192192,775.583333,True,793.640625
4791,"45M Searchformer, step 3",sokoban-7722-m-trace-plan-100k-0-step-3,-9189986881111192192,530.516129,True,793.640625
4792,"45M Searchformer, step 3",sokoban-7722-m-trace-plan-100k-1-step-3,-9189986881111192192,546.929825,True,793.640625


In [14]:
plan_corr = corr_trace_len["plan_correct"].values.astype(np.float32)
astar_trace_len_arr = corr_trace_len["trace_len A*"].values.astype(np.float32)
model_trace_len_arr = corr_trace_len["rollout_len"].values.astype(np.float32)

corr_trace_len["irl-on-solved"] = plan_corr * astar_trace_len_arr / model_trace_len_arr
corr_trace_len_ilr = corr_trace_len[["set", "checkpoint_id", "irl-on-solved"]].groupby(["set", "checkpoint_id"]).mean().reset_index()

corr_ilr_df_avg = corr_trace_len_ilr[["set", "irl-on-solved"]].groupby(["set"]).mean().reset_index()
corr_ilr_df_sem = corr_trace_len_ilr[["set", "irl-on-solved"]].groupby(["set"]).sem().reset_index().rename(columns={"irl-on-solved": "irl-on-solved [sem]"})
pd.merge(corr_ilr_df_avg, corr_ilr_df_sem, on=["set"])

Unnamed: 0,set,irl-on-solved,irl-on-solved [sem]
0,175M Search augmented,0.925034,0.010261
1,45M Search augmented,0.908225,0.020143
2,"45M Searchformer, step 1",1.054411,0.019714
3,"45M Searchformer, step 2",1.157814,0.024698
4,"45M Searchformer, step 3",1.292427,0.044076


### ILR-on-optimal

ILR calculation on sequences ending in an optimal plan.

First the average rollout length across all sequences ending in an optimal plan are computed.
If no sequence ended in an optimal plan, then the task is flagged with `plan_optimal` set to false.

In [15]:
group_cols = ["set", "checkpoint_id", "sequence_id"]

trace_len_records = []
for cols, df in evals_df.groupby(group_cols):
    seq_dict = {k: v for k, v in zip(group_cols, cols)}
    
    correct_predictions = df[df["plan_optimal"]]
    if len(correct_predictions) > 0:
        seq_dict["rollout_len"] = correct_predictions["rollout_len"].values.mean()
        seq_dict["plan_optimal"] = True
    else:
        seq_dict["rollout_len"] = 1
        seq_dict["plan_optimal"] = False

    trace_len_records.append(seq_dict)
opt_trace_len = pd.DataFrame(trace_len_records)
opt_trace_len = pd.merge(opt_trace_len, astar_trace_len, on=["sequence_id"])
opt_trace_len = opt_trace_len[opt_trace_len["checkpoint_id"].isin(run_ids)]
opt_trace_len

Unnamed: 0,set,checkpoint_id,sequence_id,rollout_len,plan_optimal,trace_len A*
0,175M Search augmented,sokoban-7722-l-trace-plan-100k-0,-9223366637686761758,2639.209677,True,2722.968750
1,175M Search augmented,sokoban-7722-l-trace-plan-100k-1,-9223366637686761758,2590.952381,True,2722.968750
2,175M Search augmented,sokoban-7722-l-trace-plan-100k-2,-9223366637686761758,2733.532258,True,2722.968750
6,45M Search augmented,sokoban-7722-m-trace-plan-100k-0,-9223366637686761758,2644.809524,True,2722.968750
7,45M Search augmented,sokoban-7722-m-trace-plan-100k-1,-9223366637686761758,2651.774194,True,2722.968750
...,...,...,...,...,...,...
4789,"45M Searchformer, step 2",sokoban-7722-m-trace-plan-100k-1-step-2,-9189986881111192192,621.920635,True,793.640625
4790,"45M Searchformer, step 2",sokoban-7722-m-trace-plan-100k-2-step-2,-9189986881111192192,636.448276,True,793.640625
4791,"45M Searchformer, step 3",sokoban-7722-m-trace-plan-100k-0-step-3,-9189986881111192192,530.516129,True,793.640625
4792,"45M Searchformer, step 3",sokoban-7722-m-trace-plan-100k-1-step-3,-9189986881111192192,546.929825,True,793.640625


In [16]:
plan_corr = opt_trace_len["plan_optimal"].values.astype(np.float32)
astar_trace_len_arr = opt_trace_len["trace_len A*"].values.astype(np.float32)
model_trace_len_arr = opt_trace_len["rollout_len"].values.astype(np.float32)

opt_trace_len["irl-on-optimal"] = plan_corr * astar_trace_len_arr / model_trace_len_arr
opt_trace_len_ilr = opt_trace_len[["set", "checkpoint_id", "irl-on-optimal"]].groupby(["set", "checkpoint_id"]).mean().reset_index()

opt_ilr_df_avg = opt_trace_len_ilr[["set", "irl-on-optimal"]].groupby(["set"]).mean().reset_index()
opt_ilr_df_sem = opt_trace_len_ilr[["set", "irl-on-optimal"]].groupby(["set"]).sem().reset_index().rename(columns={"irl-on-optimal": "irl-on-optimal [sem]"})
pd.merge(opt_ilr_df_avg, opt_ilr_df_sem, on=["set"])

Unnamed: 0,set,irl-on-optimal,irl-on-optimal [sem]
0,175M Search augmented,0.932697,0.010753
1,45M Search augmented,0.9192,0.018816
2,"45M Searchformer, step 1",1.06222,0.015116
3,"45M Searchformer, step 2",1.180596,0.011529
4,"45M Searchformer, step 3",1.343152,0.067
