In [1]:
! pip -q install pandas plotly
! pip -q install nbformat

In [18]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import glob
import json
from plotly.subplots import make_subplots
import math
import os

In [19]:
result_files = glob.glob("eval_results/preliminary/*/*_iter_*/*.json")
baseline_hptl2_dedup = glob.glob("/scratch/project_462000353/pyysalos/second-hplt-eval/eval_results/results/hplt_v2_dedup_iter_00*/*.json")
baseline_fineweb = glob.glob("/scratch/project_462000353/pyysalos/second-hplt-eval/eval_results/results/fineweb_iter_00*/*.json")

tasks = {"leaderboard|hellaswag|0":"hellaswag",
         "lighteval|arc:easy|0":"arc_easy",
         "lighteval|openbookqa|0": "openbookqa",
         "lighteval|piqa|0": "piqa"}

In [32]:
def extract_results(result_files, recalc_mean=False, max_step=15000,
                    tokens_per_step=2.1e6, task_names=None):
    """Load evaluation result JSON files into one DataFrame, one row per file.

    Each path must look like ``.../<register>_iter_<step>/<file>.json``. The
    register label and integer training step are parsed from the parent
    directory name (the step is whatever follows the *last* ``_iter_``).

    Parameters
    ----------
    result_files : iterable of str
        Paths to result JSON files. Each file must contain ``results["all"]``
        and ``results[<task>]`` for every key of ``task_names``.
    recalc_mean : bool, default False
        If True, copy the file's own ``acc_norm`` to ``acc_norm_original`` and
        replace ``acc_norm`` with the row-wise mean of the per-task
        ``acc_norm_<task_name>`` columns.
    max_step : int or None, default 15000
        Files whose step is greater than this are skipped without being
        opened. ``None`` disables the cut-off.
    tokens_per_step : float, default 2.1e6
        Multiplier used to derive the ``tokens`` column from ``step``.
    task_names : dict or None, default None
        Mapping ``task key in results -> short name`` used as a column suffix.
        ``None`` means the notebook-level ``tasks`` dict.

    Returns
    -------
    pandas.DataFrame
        Columns: ``register``, ``step`` (as str), the ``results["all"]``
        metrics, ``<metric>_<task_name>`` for every per-task metric,
        ``tokens`` (float) and, if ``recalc_mean``, ``acc_norm_original``.
        Rows are sorted by numeric step (stable, so file order breaks ties).

    Raises
    ------
    ValueError
        If a directory name lacks ``_iter_`` or its step is not an integer,
        or if no file remains after the step filter (e.g. the glob was empty).
    """
    if task_names is None:
        task_names = tasks  # notebook-level config from the paths/tasks cell

    records = []
    for path in result_files:
        run_dir = os.path.basename(os.path.dirname(path))
        register, sep, step_str = run_dir.rpartition("_iter_")
        if not sep:
            raise ValueError(f"cannot parse '<register>_iter_<step>' from {path!r}")
        step = int(step_str)
        # Filter before opening so skipped checkpoints cost no I/O.
        if max_step is not None and step > max_step:
            continue

        with open(path, "r") as fh:
            results = json.load(fh)["results"]

        record = {"register": register, "step": step, **results["all"]}
        for task, task_name in task_names.items():
            record.update({f"{metric}_{task_name}": value
                           for metric, value in results[task].items()})
        records.append(record)

    if not records:
        raise ValueError("extract_results: no result files to load "
                         "(empty input, or every step exceeds max_step)")

    # Build the frame once from records instead of concatenating one-row frames.
    df = pd.DataFrame(records)
    # Sort while step is still numeric, then turn it into a string: plotly
    # orders a categorical axis by first appearance and ignores
    # pd.Categorical(ordered=True), so the row order is what counts.
    df = df.sort_values("step", kind="stable")
    df["tokens"] = df["step"] * tokens_per_step
    df["step"] = df["step"].astype(str)

    if recalc_mean:
        df["acc_norm_original"] = df["acc_norm"]
        per_task_cols = [f"acc_norm_{name}" for name in task_names.values()]
        df["acc_norm"] = df[per_task_cols].mean(axis=1)
    return df

# Register-specific runs: keep each file's own "all" metrics unchanged.
df1 = extract_results(result_files, recalc_mean=False)

# Baselines: acc_norm is recomputed as the mean over the `tasks` columns
# (their original value is kept as acc_norm_original). Presumably this puts
# them on the same task subset as the register runs; confirm.
df2, df3 = (
    extract_results(baseline_hptl2_dedup, recalc_mean=True),
    extract_results(baseline_fineweb, recalc_mean=True),
)

In [33]:
# Stack register runs and both baselines into one long frame for plotting.
# ignore_index: every source frame carries its own 0..n-1 labels, so a plain
# concat yields duplicate index labels (ambiguous .loc lookups).
df = pd.concat([df1, df2, df3], ignore_index=True)
display(df)

Unnamed: 0,register,acc,acc_arc_easy,acc_hellaswag,acc_norm,acc_norm_arc_easy,acc_norm_hellaswag,acc_norm_openbookqa,acc_norm_piqa,acc_norm_stderr,...,acc_openbookqa,acc_piqa,acc_stderr,acc_stderr_arc_easy,acc_stderr_hellaswag,acc_stderr_openbookqa,acc_stderr_piqa,step,tokens,acc_norm_original
58,LY,0.311213,0.305976,0.260805,0.330152,0.292508,0.265087,0.232,0.531012,0.011070,...,0.134,0.544070,0.010177,0.009456,0.004382,0.015250,0.011620,1000,2.100000e+09,
27,ne,0.325739,0.338805,0.266481,0.343494,0.317761,0.261601,0.250,0.544614,0.011236,...,0.140,0.557671,0.010311,0.009712,0.004412,0.015533,0.011588,1000,2.100000e+09,
156,IP,0.322605,0.308923,0.272257,0.346753,0.306397,0.278431,0.232,0.570185,0.011095,...,0.126,0.583243,0.010070,0.009481,0.004442,0.014856,0.011503,1000,2.100000e+09,
25,SP,0.330287,0.335017,0.261402,0.349751,0.326599,0.265983,0.258,0.548422,0.011308,...,0.154,0.570729,0.010444,0.009685,0.004385,0.016158,0.011549,1000,2.100000e+09,
50,HI,0.355569,0.322391,0.283609,0.389250,0.316919,0.301434,0.264,0.674646,0.011198,...,0.140,0.676279,0.010135,0.009591,0.004498,0.015533,0.010917,1000,2.100000e+09,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
27,fineweb,0.432052,0.562710,0.409480,0.516159,0.490741,0.521211,0.322,0.730686,0.011880,...,0.208,0.731774,0.011243,0.010179,0.004907,0.018170,0.010337,13000,2.730000e+10,0.465828
15,fineweb,0.429527,0.552609,0.412567,0.513739,0.480640,0.524895,0.322,0.727421,0.011897,...,0.212,0.736670,0.011211,0.010203,0.004913,0.018297,0.010276,13500,2.835000e+10,0.464575
7,fineweb,0.429570,0.555135,0.412169,0.512926,0.482323,0.530870,0.310,0.728509,0.011841,...,0.224,0.723613,0.011312,0.010197,0.004912,0.018664,0.010434,14000,2.940000e+10,0.463242
22,fineweb,0.430369,0.558502,0.413165,0.515199,0.485269,0.530472,0.316,0.729053,0.011861,...,0.214,0.730686,0.011243,0.010189,0.004914,0.018360,0.010350,14500,3.045000e+10,0.465060


In [35]:
# One line per register (baselines included), average normalized accuracy vs tokens.
# Note: the Alphabet palette has 26 colours; more registers than that reuse colours.
fig = px.line(
    df,
    x="tokens",
    y="acc_norm",
    color="register",
    color_discrete_sequence=px.colors.qualitative.Alphabet,
    # Readable axis / legend / hover labels instead of raw column names.
    labels={
        "tokens": "Tokens",
        "acc_norm": "Average normalized accuracy",
        "register": "Register",
    },
)
fig.update_layout(
    title="Average accuracy by Register",
    height=600,
    width=850,
    showlegend=True,
    template="simple_white",
)
fig.show()