In [1]:
import sys

#sys.path.append("/Users/evanracah/Dropbox/projects/atari-representation-learning/")

import wandb
import pandas as pd
import numpy as np
from copy import deepcopy

In [2]:
def add_bold_max(latex_str, ignore_last_column=True, ignore_first_column=False, margin=0.01):
    strs = []
    for r in latex_str.split("\\\\"):
        if "." in r:
            s = r.split("&")
            nums = [float(n) for n in s[1:] if "." in n or "NaN" in  n]
            #print(nums)
            am_nums = deepcopy(nums)
            if ignore_last_column:
                am_nums = am_nums[:-1]
                
            if ignore_first_column:
                am_nums = am_nums[1:]
  
                
                
            filt_nums = [n if not np.isnan(n) else -1 for n in am_nums ]
            amax, maxx = [np.argmax(filt_nums)], np.max(filt_nums)
            for i,n in enumerate(filt_nums):
                if i == amax[0]:
                    continue
                if n+margin >= maxx:
                    amax.append(i)
            for am in amax:
                if ignore_first_column:
                    am = am+1
                nums[am] = "\\textbf{%5.2f} "%(nums[am])
            nums = ["%5.2f"%n if not isinstance(n,str) else n for n in nums ]
            nr = [s[0]] + nums
            nr = "  &  ".join(nr)
        else:
            nr = r
        strs.append(nr)

    final_latex = "\\\\".join(strs)
    final_latex = final_latex.replace("nan","NaN")
    return final_latex

def print_latex(ldf,ignore_last_column=True, ignore_first_column=False):
    latex_ldf = ldf.to_latex()
    print(add_bold_max(latex_ldf, ignore_last_column=ignore_last_column, ignore_first_column=ignore_first_column))

In [3]:
wandb_proj = "eracah/coors-production"

In [4]:
def get_table_for(metric_name):
    api = wandb.Api()
    runs = list(api.runs(wandb_proj))
    #                      {"state": "crashed", 
    #                                                 }))

    runs = [run for run in runs if  metric_name in run.summary_metrics ]

    rd = [run.summary_metrics for run in runs if  metric_name in run.summary_metrics]
    df = pd.DataFrame(rd)
    df['env_name'] = [run.config['env_name'].split("NoFrameskip")[0].lower().capitalize() for run in runs]
    df['method'] = [ run.config['method'] for run in runs]

    df = df[[metric_name, "method", "env_name"]]
    methods = ["random-cnn", "infonce", "supervised"] #methods = list(set(df.method))
    mdfs = [df[df.method==method].set_index("env_name").rename(columns={metric_name:method})[method] for method in methods]

    final_df = pd.concat(mdfs,axis=1,sort=True)

    final_df = final_df.round(2)


    latex_ldf = final_df.to_latex()
    bold_latex = add_bold_max(latex_ldf, ignore_last_column=True, ignore_first_column=False)
    print(latex_ldf)
    return latex_ldf

In [5]:
metric_name = "assigned_slot_explicitness_linear_across_categories_avg_f1"
latex_str = get_table_for(metric_name)

for line in latex_str.split("\n"):
    if "&" in line:
        line = line.replace("\\\\", "&  & \\\\")
    print(line)

\begin{tabular}{lrrr}
\toprule
{} &  random-cnn &  infonce &  supervised \\
\midrule
Asteroids     &        0.70 &     0.77 &         NaN \\
Berzerk       &        0.55 &     0.73 &        0.54 \\
Bowling       &         NaN &     0.71 &        0.96 \\
Boxing        &        0.01 &     0.17 &        0.82 \\
Breakout      &        0.19 &     0.52 &        0.91 \\
Demonattack   &        0.16 &     0.30 &        0.81 \\
Freeway       &         NaN &     0.62 &        0.96 \\
Frostbite     &        0.23 &     0.69 &        0.81 \\
Hero          &        0.28 &     0.66 &        0.98 \\
Mspacman      &         NaN &     0.46 &        0.79 \\
Pitfall       &         NaN &     0.42 &        0.88 \\
Pong          &        0.08 &     0.36 &        0.86 \\
Privateeye    &         NaN &     0.80 &        0.99 \\
Qbert         &        0.36 &     0.46 &        0.75 \\
Riverraid     &         NaN &     0.13 &        0.59 \\
Seaquest      &         NaN &     0.63 &        0.79 \\
Spaceinvaders &    

In [6]:
metric_name = "assigned_slot_explicitness_mlp_across_categories_avg_f1"

latex_str = get_table_for(metric_name)

\begin{tabular}{lrrr}
\toprule
{} &  random-cnn &  infonce &  supervised \\
\midrule
Asteroids     &         NaN &     0.77 &         NaN \\
Berzerk       &        0.78 &     0.77 &        0.58 \\
Bowling       &         NaN &     0.90 &        0.98 \\
Boxing        &         NaN &     0.47 &        0.83 \\
Breakout      &        0.54 &     0.72 &        0.93 \\
Demonattack   &         NaN &     0.43 &        0.82 \\
Freeway       &         NaN &     0.99 &        0.97 \\
Frostbite     &        0.83 &     0.88 &        0.83 \\
Hero          &        0.76 &     0.82 &        0.98 \\
Mspacman      &         NaN &     0.58 &        0.81 \\
Pitfall       &         NaN &     0.65 &        0.90 \\
Pong          &        0.39 &     0.70 &        0.89 \\
Privateeye    &         NaN &     0.92 &        0.99 \\
Qbert         &        0.53 &     0.57 &        0.77 \\
Riverraid     &         NaN &     0.25 &        0.59 \\
Seaquest      &         NaN &     0.77 &        0.80 \\
Spaceinvaders &    

In [7]:
for line in latex_str.split("\n"):
    if "&" in line:
        line = line.replace("\\\\", "&  & \\\\")
    print(line)

\begin{tabular}{lrrr}
\toprule
{} &  random-cnn &  infonce &  supervised &  & \\
\midrule
Asteroids     &         NaN &     0.77 &         NaN &  & \\
Berzerk       &        0.78 &     0.77 &        0.58 &  & \\
Bowling       &         NaN &     0.90 &        0.98 &  & \\
Boxing        &         NaN &     0.47 &        0.83 &  & \\
Breakout      &        0.54 &     0.72 &        0.93 &  & \\
Demonattack   &         NaN &     0.43 &        0.82 &  & \\
Freeway       &         NaN &     0.99 &        0.97 &  & \\
Frostbite     &        0.83 &     0.88 &        0.83 &  & \\
Hero          &        0.76 &     0.82 &        0.98 &  & \\
Mspacman      &         NaN &     0.58 &        0.81 &  & \\
Pitfall       &         NaN &     0.65 &        0.90 &  & \\
Pong          &        0.39 &     0.70 &        0.89 &  & \\
Privateeye    &         NaN &     0.92 &        0.99 &  & \\
Qbert         &        0.53 &     0.57 &        0.77 &  & \\
Riverraid     &         NaN &     0.25 &        0.59 &  

In [None]:
def get_main_df(wandb_proj="curl-atari/curl-atari-post-neurips-2", collect_mode="random_agent"):
#     if metric == "f1":
#         metric_name = "f1score"
#     elif metric == "acc":
#         metric_name = "test_acc"
    api = wandb.Api()
    
    runs = list(api.runs(wandb_proj, 
                         {"state": "finished", 
                          "config.collect_mode":collect_mode,
                                                    }))

    #df = pd.DataFrame()
    rd = [run.summary_metrics for run in runs]
    df = pd.DataFrame(rd)
    df['env_name'] = [run.config['env_name'].split("NoFrameskip")[0].lower() for run in runs]
    ms = []
    for run in runs:
        method = run.config['method']
        method = translate_method_name(method)
        ms.append(method)
    df['method'] = ms
    for metric_name in ["f1_score", "test_acc"]:
        metrics = []
        for run in runs:
            if "mean_mean_" + metric_name in  run.summary_metrics:
                metrics.append(run.summary_metrics["mean_mean_" + metric_name] )
            elif "mean_" + metric_name in run.summary_metrics:
                metrics.append(run.summary_metrics["mean_" + metric_name])
            else:
                metrics.append(np.nan)

        df[metric_name] = metrics
    df["timestamp"] = [run.summary_metrics['_timestamp'] if "_timestamp" in run.summary_metrics else 0  for run in runs  ]

    if collect_mode == "random_agent":
        for method in ["pixel-pred"]:
            cf = df[df["method"]==method]
            cf = keep_most_recent(cf)
            df = df[df["method"]!=method]
            df = pd.concat([df,cf],axis=0,sort=True)
    return df

In [1]:
from atariari.benchmark.ram_annotations import atari_dict
from atariari.benchmark.categorization import summary_key_dict as skd,  unused_keys, detailed_key_dict, all_keys

In [2]:
def get_game(df, game, methods=["cpc","st-dim"], metric_name="f1score"):
    #game = "seaquest"

    gdf = df[df["env_name"]==game]
    mdfs = [gdf[(gdf["method"]==method)] for method in methods] 

    mdf = pd.concat(mdfs,axis=0)
    cols = [c for c in mdf.columns if metric_name in c and c.replace("_"+metric_name, "") in atari_dict[game].keys()]
    col_change_dic = {col:col.split("_"+metric_name)[0] for col in cols}
    cols = ["method"] +cols 

    gmdf = mdf[cols]
    gmdf = gmdf.rename(columns=col_change_dic)

    gmdf = gmdf.round(2)
    gmdf = gmdf.set_index("method")
    gmdf = gmdf.T
    return gmdf



def add_bold_max(latex_str, ignore_last_column=True, ignore_first_column=False, margin=0.01):
    strs = []
    for r in latex_str.split("\\\\"):
        if "." in r:
            s = r.split("&")
            nums = [float(n) for n in s[1:] if "." in n or "NaN" in  n]
            #print(nums)
            am_nums = deepcopy(nums)
            if ignore_last_column:
                am_nums = am_nums[:-1]
                
            if ignore_first_column:
                am_nums = am_nums[1:]
  
                
                
            filt_nums = [n if not np.isnan(n) else -1 for n in am_nums ]
            amax, maxx = [np.argmax(filt_nums)], np.max(filt_nums)
            for i,n in enumerate(filt_nums):
                if i == amax[0]:
                    continue
                if n+margin >= maxx:
                    amax.append(i)
            for am in amax:
                if ignore_first_column:
                    am = am+1
                nums[am] = "\\textbf{%5.2f} "%(nums[am])
            nums = ["%5.2f"%n if not isinstance(n,str) else n for n in nums ]
            nr = [s[0]] + nums
            nr = "  &  ".join(nr)
        else:
            nr = r
        strs.append(nr)

    final_latex = "\\\\".join(strs)
    final_latex = final_latex.replace("nan","NaN")
    return final_latex

def keep_most_recent(df):
    cf =df.reset_index()

    cfs = [cf[cf["env_name"]==game] for game in atari_dict.keys()]

    cfs =  [cdf["timestamp"].idxmax() for cdf in cfs]

    cf = cf.iloc[cfs]
    #print(cf["env_name"])
    return cf

def translate_method_name(method):
        if method == "spatial-appo":
            method = "jsd-st-dim"
        elif method == "infonce-stdim":
            method = "st-dim"
        elif method == "global-infonce-stdim":
            method = "global-t-dim"
        elif method == "naff":
            method = "pixel-pred"
        elif method == "majority":
            method = "maj-clf"
        elif method == "global-local-infonce-stdim":
            method = "gl-st-dim"
        return method
    

def get_main_df(wandb_proj="curl-atari/curl-atari-post-neurips-2", collect_mode="random_agent"):
#     if metric == "f1":
#         metric_name = "f1score"
#     elif metric == "acc":
#         metric_name = "test_acc"
    api = wandb.Api()
    
    runs = list(api.runs(wandb_proj, 
                         {"state": "finished", 
                          "config.collect_mode":collect_mode,
                                                    }))

    #df = pd.DataFrame()
    rd = [run.summary_metrics for run in runs]
    df = pd.DataFrame(rd)
    df['env_name'] = [run.config['env_name'].split("NoFrameskip")[0].lower() for run in runs]
    ms = []
    for run in runs:
        method = run.config['method']
        method = translate_method_name(method)
        ms.append(method)
    df['method'] = ms
    for metric_name in ["f1_score", "test_acc"]:
        metrics = []
        for run in runs:
            if "mean_mean_" + metric_name in  run.summary_metrics:
                metrics.append(run.summary_metrics["mean_mean_" + metric_name] )
            elif "mean_" + metric_name in run.summary_metrics:
                metrics.append(run.summary_metrics["mean_" + metric_name])
            else:
                metrics.append(np.nan)

        df[metric_name] = metrics
    df["timestamp"] = [run.summary_metrics['_timestamp'] if "_timestamp" in run.summary_metrics else 0  for run in runs  ]

    if collect_mode == "random_agent":
        for method in ["pixel-pred"]:
            cf = df[df["method"]==method]
            cf = keep_most_recent(cf)
            df = df[df["method"]!=method]
            df = pd.concat([df,cf],axis=0,sort=True)
    return df

def compute_cat_df(df, metric_name="f1score"):
    for cat,cat_keys in skd.items():
        cols = [c for c in df.columns if c.split("_" + metric_name)[0] in cat_keys]
        df[cat] = df[cols].mean(axis=1)
    

    
    
    cat_df = df.loc[:,["env_name","method"] + list(skd.keys())]
    
    cat_df["overall"] = cat_df.loc[:,list(skd.keys())].mean(axis=1)

    return cat_df

def get_method_df(cat_df,method):
    #method = translate_method_name(method)
    cat_df = cat_df.loc[:,["method","env_name", "overall"]]
    final_cols = ["env_name", "overall"]
    mdf = cat_df.loc[cat_df.method==method].loc[:,final_cols].rename(columns={"overall":method})
    mdf = mdf.set_index("env_name")
    return mdf
    

def cat_avg(df, methods, metric="f1"):
    if metric == "f1":
        metric_name = "f1score"
    elif metric == "acc":
        metric_name = "test_acc"
    cat_df = compute_cat_df(df, metric_name=metric_name)
    cdfs = []
    for method in methods:
        mdf = get_method_df(cat_df,method)
#         print(method,len(mdf))
        cdfs.append(mdf)

    
    ldf = pd.concat(cdfs, axis=1,sort=True)
    ldf.loc['mean'] = ldf.mean()
    ldf=ldf.round(2)
    #latex_ldf = ldf.to_latex()
    #print(add_bold_max(latex_ldf, ignore_last_column=ignore_last_column))
    return ldf

def manual_add(raw_df,ldf,method, metric="f1"):
    if metric == "f1":
        metric_name = "f1score"
    elif metric == "acc":
        metric_name = "test_acc"
    #method = translate_method_name(method)
    cat_df = compute_cat_df(raw_df,metric_name=metric_name)
    mdf = get_method_df(cat_df, method)
    mdf = mdf.sort_values(by="env_name")
    mdf.loc['mean'] = mdf.mean()
    mdf = mdf.round(2)
    ldf[method] = mdf[method].values
    return ldf

def print_latex(ldf,ignore_last_column=True, ignore_first_column=False):
    latex_ldf = ldf.to_latex()
    print(add_bold_max(latex_ldf, ignore_last_column=ignore_last_column, ignore_first_column=ignore_first_column))

def all_cat(df,methods,ignore_last_column=True, metric="f1"):
    if metric == "f1":
        metric_name = "f1score"
    elif metric == "acc":
        metric_name = "test_acc"
    cat_df = compute_cat_df(df, metric_name=metric_name)
    cw_df = pd.DataFrame()

    for method in methods:
        mdf = cat_df[cat_df["method"]==method]
        mdf = mdf[list(skd.keys())]
        r = mdf.mean(axis=0)
        cw_df[method]  =r

    cw_df = cw_df.T
    cw_df = cw_df.rename(columns={"small_object_localization": "Small Loc.","agent_localization":"Agent Loc.","other_localization": "Other Loc.", "score_clock_lives_display": "Score/Clock/Lives/Display", "misc_keys": "Misc."})
    cw_df = cw_df.T
    cw_df = cw_df.round(2)
    

 
    return cw_df   

### Table 1

In [3]:
# skdd = deepcopy(skd)
# skdd["overall"] = all_keys
# keys = ['agent_localization',
#  'small_object_localization',
#  'other_localization',
#  'score_clock_lives_display',
#  'misc_keys',
#  'overall']#list(summary_key_dict.keys())
# summary_stats = {env:{sk:0 for sk in keys} for env in atari_dict.keys()}

# for env in atari_dict.keys():
#     for k in atari_dict[env].keys():
#         for sk,v in skdd.items():
#             if k in v:
#                 summary_stats[env][sk] +=1

# print(" & ".join(["game"] + list(keys)))

# for env,v in summary_stats.items():
#     print( " & ".join([env] + list(str(v[k]) for k in keys)), "\\\\")
# print(" & ".join(["total"] + [str(len(skdd[k])) for k in keys]), "\\\\")

### Table 2 
(Average over Categories instead of State Variables)

In [4]:
ra_df = get_main_df(wandb_proj="curl-atari/curl-atari-post-neurips-2", collect_mode="random_agent")
methods = [ "maj-clf","random-cnn","vae","pixel-pred","cpc","st-dim", "supervised"]  #"pixel-pred"
methods_filt = deepcopy(methods)
ldf = cat_avg(df=ra_df,methods=methods_filt, metric="f1")
ldf= ldf[methods]
ldf["vae"]["pitfall"] = ra_df[(ra_df.method == "vae") & (ra_df.env_name == "pitfall")].across_categories_avg_f1

print_latex(ldf)

\begin{tabular}{lrrrrrrr}
\toprule
{} &  maj-clf &  random-cnn &       vae &  pixel-pred &   cpc &  st-dim &  supervised \\
\midrule
asteroids          &   0.28  &   0.34  &   0.36  &   0.34  &   0.42  &  \textbf{ 0.49}   &   0.52\\
berzerk            &   0.18  &   0.43  &   0.45  &  \textbf{ 0.55}   &  \textbf{ 0.56}   &   0.53  &   0.68\\
bowling            &   0.33  &   0.48  &   0.50  &   0.81  &   0.90  &  \textbf{ 0.96}   &   0.95\\
boxing             &   0.01  &   0.19  &   0.20  &   0.44  &   0.29  &  \textbf{ 0.58}   &   0.83\\
breakout           &   0.17  &   0.51  &   0.57  &   0.70  &   0.74  &  \textbf{ 0.88}   &   0.94\\
demonattack        &   0.16  &   0.26  &   0.26  &   0.32  &   0.57  &  \textbf{ 0.69}   &   0.83\\
freeway            &   0.01  &   0.50  &   0.01  &  \textbf{ 0.81}   &   0.47  &  \textbf{ 0.81}   &   0.98\\
frostbite          &   0.08  &   0.57  &   0.51  &   0.72  &  \textbf{ 0.76}   &  \textbf{ 0.75}   &   0.85\\
hero               &   0.22  &   0.75

### Table 3
Categorical Summaries

In [5]:
methods = [ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim", "supervised"] 
cw_df = all_cat(ra_df,methods, metric="f1")
print_latex(cw_df)

\begin{tabular}{lrrrrrrr}
\toprule
{} &  maj-clf &  random-cnn &   vae &  pixel-pred &   cpc &  st-dim &  supervised \\
\midrule
Small Loc.                  &   0.14  &   0.19  &   0.18  &   0.31  &   0.42  &  \textbf{ 0.51}   &   0.66\\
Agent Loc.                  &   0.12  &   0.31  &   0.32  &   0.48  &   0.43  &  \textbf{ 0.58}   &   0.81\\
Other Loc.                  &   0.14  &   0.50  &   0.39  &   0.61  &   0.66  &  \textbf{ 0.69}   &   0.80\\
Score/Clock/Lives/Display   &   0.13  &   0.58  &   0.54  &   0.76  &   0.83  &  \textbf{ 0.87}   &   0.91\\
Misc.                       &   0.26  &   0.59  &   0.63  &   0.70  &   0.71  &  \textbf{ 0.75}   &   0.83\\
\bottomrule
\end{tabular}



### Table 4
Incompleteness Table

In [6]:
boxing_df = get_game(ra_df, "boxing",methods=["vae","pixel-pred","cpc","global-t-dim", "st-dim"])
print_latex(boxing_df,ignore_last_column=False)

\begin{tabular}{lrrrrr}
\toprule
method &   vae &  pixel-pred &   cpc &  global-t-dim &  st-dim \\
\midrule
clock          &   0.03  &   0.27  &   0.79  &   0.81  &  \textbf{ 0.92} \\
enemy\_score    &   0.19  &   0.58  &   0.59  &  \textbf{ 0.74}   &   0.70\\
enemy\_x        &   0.32  &   0.49  &   0.15  &   0.17  &  \textbf{ 0.51} \\
enemy\_y        &   0.22  &  \textbf{ 0.42}   &   0.04  &   0.16  &   0.38\\
player\_score   &   0.08  &   0.32  &   0.56  &   0.45  &  \textbf{ 0.88} \\
player\_x       &   0.33  &   0.54  &   0.19  &   0.13  &  \textbf{ 0.56} \\
player\_y       &   0.16  &  \textbf{ 0.43}   &   0.04  &   0.14  &   0.37\\
\bottomrule
\end{tabular}



### Table 5
No computation needed (text tables)

### Table 6

In [7]:
ppo_df = get_main_df(collect_mode="pretrained_ppo")

In [8]:
ppo_df[(ppo_df.method == "supervised") & (ppo_df.env_name == "tennis")].mean_mean_f1score

34    0.55911
Name: mean_mean_f1score, dtype: float64

In [9]:
methods = [ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim", "supervised"] 
methods_filt = deepcopy(methods)
ldf = cat_avg(df=ppo_df,methods=methods_filt)
ldf= ldf[methods]
ldf["supervised"]["asteroids"] = ppo_df[(ppo_df.method == "supervised") & (ppo_df.env_name == "asteroids")].across_categories_avg_f1
ldf["supervised"]["berzerk"] = ppo_df[(ppo_df.method == "supervised") & (ppo_df.env_name == "berzerk")].across_categories_avg_f1
ldf = ldf.drop("mean")

Add PPO Mean Reward Column on the Left

In [10]:
api = wandb.Api()
wandb_proj = "curl-atari/pretrained-rl-agents-2"
runs = list(api.runs(wandb_proj, 
                     {"state": "finished", 
                                                }))

#df = pd.DataFrame()
rd = [run.summary_metrics for run in runs]
df = pd.DataFrame(rd)

df['env_name'] = [run.config['env_name'].split("NoFrameskip")[0].lower() for run in runs]

df.columns

reward_col = df.set_index("env_name")

r_df = reward_col.loc[:,["Mean Rewards"]]

cdf = pd.concat((r_df,ldf),axis=1,sort=True)

cdf.loc["mean"] = cdf.mean().round(2)
cdf.iloc[22,0] = np.nan

In [11]:
cdf= cdf.rename(columns={"Mean Rewards": "mean agent rewards"})

In [12]:
print_latex(cdf, ignore_first_column=True)

\begin{tabular}{lrrrrrrrr}
\toprule
{} &  mean agent rewards &  maj-clf &  random-cnn &   vae &  pixel-pred &   cpc &  st-dim &  supervised \\
\midrule
asteroids          &  489862.00  &   0.23  &   0.31  &   0.35  &   0.31  &   0.38  &  \textbf{ 0.40}   &   0.56\\
berzerk            &  1913.00  &   0.13  &   0.33  &   0.35  &   0.39  &   0.38  &  \textbf{ 0.43}   &   0.61\\
bowling            &  29.80  &   0.23  &   0.61  &   0.51  &   0.81  &   0.90  &  \textbf{ 0.98}   &   0.98\\
boxing             &  93.30  &   0.05  &   0.30  &   0.32  &   0.57  &   0.32  &  \textbf{ 0.66}   &   0.87\\
breakout           &  580.40  &   0.09  &   0.34  &   0.59  &   0.47  &   0.55  &  \textbf{ 0.66}   &   0.87\\
demonattack        &  428165.00  &   0.03  &   0.19  &   0.18  &   0.26  &   0.43  &  \textbf{ 0.58}   &   0.76\\
freeway            &  33.50  &   0.01  &   0.36  &   0.02  &  \textbf{ 0.60}   &   0.38  &  \textbf{ 0.60}   &   0.76\\
frostbite          &  3561.00  &   0.13  &   0.57  &   0.

### Table 7

In [13]:
cw_df = all_cat(ppo_df,methods)
print_latex(cw_df)

\begin{tabular}{lrrrrrrr}
\toprule
{} &  maj-clf &  random-cnn &   vae &  pixel-pred &   cpc &  st-dim &  supervised \\
\midrule
Small Loc.                  &   0.10  &   0.13  &   0.14  &   0.27  &   0.31  &  \textbf{ 0.41}   &   0.65\\
Agent Loc.                  &   0.11  &   0.34  &   0.34  &   0.48  &   0.45  &  \textbf{ 0.54}   &   0.83\\
Other Loc.                  &   0.14  &   0.47  &   0.38  &   0.56  &   0.58  &  \textbf{ 0.61}   &   0.74\\
Score/Clock/Lives/Display   &   0.05  &   0.44  &   0.50  &   0.71  &   0.74  &  \textbf{ 0.80}   &   0.90\\
Misc.                       &   0.19  &   0.53  &   0.57  &   0.62  &   0.65  &  \textbf{ 0.67}   &   0.83\\
\bottomrule
\end{tabular}



### Table 8
Ablation Version of Table 2

In [14]:
methods = ["static-dim-2", "jsd-st-dim", "global-t-dim", "st-dim"] 
adf = cat_avg(ra_df,methods)
adf=adf.rename(columns={"static-dim-2":"static-dim"})
print_latex(adf, ignore_last_column=False)

\begin{tabular}{lrrrr}
\toprule
{} &  static-dim &  jsd-st-dim &  global-t-dim &  st-dim \\
\midrule
asteroids          &   0.37  &   0.44  &   0.38  &  \textbf{ 0.49} \\
berzerk            &   0.41  &   0.49  &   0.49  &  \textbf{ 0.53} \\
bowling            &   0.34  &   0.91  &   0.77  &  \textbf{ 0.96} \\
boxing             &   0.09  &  \textbf{ 0.61}   &   0.32  &   0.58\\
breakout           &   0.19  &   0.85  &   0.71  &  \textbf{ 0.88} \\
demonattack        &   0.30  &   0.44  &   0.43  &  \textbf{ 0.69} \\
freeway            &   0.02  &   0.70  &   0.76  &  \textbf{ 0.81} \\
frostbite          &   0.27  &   0.52  &   0.68  &  \textbf{ 0.75} \\
hero               &   0.59  &   0.85  &   0.87  &  \textbf{ 0.93} \\
montezumarevenge   &   0.17  &   0.55  &   0.67  &  \textbf{ 0.78} \\
mspacman           &   0.17  &   0.70  &   0.53  &  \textbf{ 0.72} \\
pitfall            &   0.22  &   0.47  &   0.44  &  \textbf{ 0.60} \\
pong               &   0.13  &  \textbf{ 0.80}   &   0.65  

### Figure 3 Data

In [15]:
jsd, globalt, stdim = [list(adf.values[:,i]) for i in range(3)]
ablation_arrays =dict(jsd=jsd, globalt=globalt, stdim=stdim, keys=list(adf.index.values))
print(ablation_arrays)

{'jsd': [0.37, 0.41, 0.34, 0.09, 0.19, 0.3, 0.02, 0.27, 0.59, 0.17, 0.17, 0.22, 0.13, 0.25, 0.41, 0.16, 0.41, 0.4, 0.17, 0.25, 0.21, 0.12, 0.26], 'globalt': [0.44, 0.49, 0.91, 0.61, 0.85, 0.44, 0.7, 0.52, 0.85, 0.55, 0.7, 0.47, 0.8, 0.79, 0.59, 0.28, 0.55, 0.44, 0.57, 0.4, 0.54, 0.32, 0.58], 'stdim': [0.38, 0.49, 0.77, 0.32, 0.71, 0.43, 0.76, 0.68, 0.87, 0.67, 0.53, 0.44, 0.65, 0.81, 0.57, 0.33, 0.59, 0.44, 0.52, 0.47, 0.53, 0.18, 0.55], 'keys': ['asteroids', 'berzerk', 'bowling', 'boxing', 'breakout', 'demonattack', 'freeway', 'frostbite', 'hero', 'montezumarevenge', 'mspacman', 'pitfall', 'pong', 'privateeye', 'qbert', 'riverraid', 'seaquest', 'spaceinvaders', 'tennis', 'venture', 'videopinball', 'yarsrevenge', 'mean']}


### Table 9
Ablation Version of Table 3

In [16]:
methods = ["static-dim-2","jsd-st-dim", "global-t-dim", "st-dim"] 
acdf = all_cat(ra_df, methods)
acdf=acdf.rename(columns={"static-dim-2":"static-dim"})
print_latex(acdf,ignore_last_column=False)

\begin{tabular}{lrrrr}
\toprule
{} &  static-dim &  jsd-st-dim &  global-t-dim &  st-dim \\
\midrule
Small Loc.                  &   0.18  &   0.44  &   0.37  &  \textbf{ 0.51} \\
Agent Loc.                  &   0.19  &   0.47  &   0.43  &  \textbf{ 0.58} \\
Other Loc.                  &   0.27  &   0.64  &   0.53  &  \textbf{ 0.69} \\
Score/Clock/Lives/Display   &   0.33  &   0.69  &   0.76  &  \textbf{ 0.87} \\
Misc.                       &   0.41  &   0.64  &   0.66  &  \textbf{ 0.75} \\
\bottomrule
\end{tabular}



###  Table 10 
RL Probe With Data from PPO

In [17]:
ppo_df = get_main_df(collect_mode="pretrained_ppo")
methods = [ "maj-clf","random-cnn","pretrained-rl-agent"] 
ldf = cat_avg(df=ppo_df,methods=methods)
print_latex(ldf,ignore_last_column=False)

\begin{tabular}{lrrr}
\toprule
{} &  maj-clf &  random-cnn &  pretrained-rl-agent \\
\midrule
asteroids          &   0.23  &  \textbf{ 0.31}   &  \textbf{ 0.31} \\
berzerk            &   0.13  &  \textbf{ 0.33}   &   0.30\\
bowling            &   0.23  &  \textbf{ 0.61}   &   0.48\\
boxing             &   0.05  &  \textbf{ 0.30}   &   0.12\\
breakout           &   0.09  &  \textbf{ 0.34}   &   0.23\\
demonattack        &   0.03  &  \textbf{ 0.19}   &   0.16\\
freeway            &   0.01  &  \textbf{ 0.36}   &   0.26\\
frostbite          &   0.13  &  \textbf{ 0.57}   &   0.43\\
hero               &   0.12  &  \textbf{ 0.54}   &   0.42\\
montezumarevenge   &   0.08  &  \textbf{ 0.68}   &   0.07\\
mspacman           &   0.07  &  \textbf{ 0.34}   &   0.26\\
pitfall            &   0.16  &  \textbf{ 0.39}   &   0.23\\
pong               &   0.02  &  \textbf{ 0.10}   &   0.09\\
privateeye         &   0.24  &  \textbf{ 0.71}   &   0.31\\
qbert              &   0.06  &  \textbf{ 0.36}   &   0.3

### Table 11
Accuracy Version of Table 2

In [18]:
ra_df = get_main_df(wandb_proj="curl-atari/curl-atari-post-neurips-2", collect_mode="random_agent")
methods = [ "maj-clf","random-cnn","vae","pixel-pred","cpc","st-dim", "supervised"]  #"pixel-pred"
methods_filt = deepcopy(methods)
ldf = cat_avg(df=ra_df,methods=methods_filt, metric="acc")
ldf= ldf[methods]
ldf["vae"]["pitfall"] = ra_df[(ra_df.method == "vae") & (ra_df.env_name == "pitfall")].across_categories_avg_f1

print_latex(ldf)

\begin{tabular}{lrrrrrrr}
\toprule
{} &  maj-clf &  random-cnn &       vae &  pixel-pred &   cpc &  st-dim &  supervised \\
\midrule
asteroids          &   0.37  &   0.42  &   0.41  &   0.43  &   0.48  &  \textbf{ 0.52}   &   0.53\\
berzerk            &   0.30  &   0.48  &   0.46  &  \textbf{ 0.56}   &  \textbf{ 0.57}   &   0.54  &   0.69\\
bowling            &   0.43  &   0.54  &   0.56  &   0.83  &   0.90  &  \textbf{ 0.96}   &   0.95\\
boxing             &   0.05  &   0.22  &   0.23  &   0.45  &   0.32  &  \textbf{ 0.59}   &   0.83\\
breakout           &   0.28  &   0.55  &   0.61  &   0.71  &   0.75  &  \textbf{ 0.89}   &   0.94\\
demonattack        &   0.26  &   0.30  &   0.31  &   0.35  &   0.58  &  \textbf{ 0.70}   &   0.83\\
freeway            &   0.06  &   0.53  &   0.07  &  \textbf{ 0.85}   &   0.49  &   0.82  &   0.99\\
frostbite          &   0.19  &   0.59  &   0.54  &   0.72  &  \textbf{ 0.76}   &  \textbf{ 0.75}   &   0.85\\
hero               &   0.34  &   0.78  &   0.72

### Table 12
Accuracy Version of Table 3

In [19]:
methods = [ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim", "supervised"] 
cw_df = all_cat(ra_df,methods, metric="acc")
print_latex(cw_df)

\begin{tabular}{lrrrrrrr}
\toprule
{} &  maj-clf &  random-cnn &   vae &  pixel-pred &   cpc &  st-dim &  supervised \\
\midrule
Small Loc.                  &   0.23  &   0.29  &   0.26  &   0.36  &   0.46  &  \textbf{ 0.53}   &   0.67\\
Agent Loc.                  &   0.21  &   0.37  &   0.37  &   0.51  &   0.46  &  \textbf{ 0.59}   &   0.81\\
Other Loc.                  &   0.22  &   0.54  &   0.42  &   0.63  &   0.67  &  \textbf{ 0.70}   &   0.80\\
Score/Clock/Lives/Display   &   0.24  &   0.61  &   0.56  &   0.77  &   0.84  &  \textbf{ 0.87}   &   0.91\\
Misc.                       &   0.38  &   0.61  &   0.65  &   0.71  &   0.72  &  \textbf{ 0.75}   &   0.83\\
\bottomrule
\end{tabular}



### Tables 13-34
Fine-Grained Table for Every Game for Random Agent

In [21]:
for game in atari_dict.keys():
    metric_name="f1score"
    if game == "pitfall":
        methods=[ "maj-clf","random-cnn", "pixel-pred","cpc","st-dim", "supervised"]
    else:
        methods=[ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim", "supervised"]
    gmdf = get_game(ra_df,game,
                    methods=methods,
                   metric_name=metric_name)
    gmdf = gmdf.dropna()
    if game == "pitfall":
        vgmdf = get_game(ra_df,game="pitfall",
                    methods=["vae"],
                   metric_name="f1")
        gmdf = pd.concat((vgmdf,gmdf),axis=1)
        methods=[ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim", "supervised"]
        gmdf = gmdf[methods]
        

    if "supervised" in gmdf:
        ignore_last_column=True
        s = "|"
    else:
        s=""
        ignore_last_column=False
    print("\\begin{table*}[ht] \\caption{%s fine-grained results. Breakdown of F1 Scores for every state variable in %s  for every method for probes where data was collected by random agent} \\label{%s-inc} \\vskip 0.15in \\begin{center} \\begin{small} \\begin{sc}"%(game.capitalize(),game.capitalize(),game.capitalize()))   
    print(" \\scalebox{0.9}{ \\begin{adjustbox}{center}")
    print(add_bold_max(gmdf.to_latex(),ignore_last_column=ignore_last_column).replace("rrr}", "rr%sr}"%s))
    print("\\end{adjustbox}}")
    print("""\\end{sc}
\\end{small}
\\end{center}
\\vskip -0.1in
\\end{table*}""")

\begin{table*}[ht] \caption{Asteroids fine-grained results. Breakdown of F1 Scores for every state variable in Asteroids  for every method for probes where data was collected by random agent} \label{Asteroids-inc} \vskip 0.15in \begin{center} \begin{small} \begin{sc}
 \scalebox{0.9}{ \begin{adjustbox}{center}
\begin{tabular}{lrrrrrr|r}
\toprule
method &  maj-clf &  random-cnn &   vae &  pixel-pred &   cpc &  st-dim &  supervised \\
\midrule
enemy\_asteroids\_x\_0         &   0.02  &   0.33  &  \textbf{ 0.37}   &   0.20  &  \textbf{ 0.37}   &   0.34  &   0.48\\
enemy\_asteroids\_x\_10        &   0.00  &   0.17  &   0.11  &   0.16  &   0.22  &  \textbf{ 0.25}   &   0.28\\
enemy\_asteroids\_x\_11        &   0.03  &   0.11  &   0.07  &   0.14  &   0.20  &  \textbf{ 0.26}   &   0.22\\
enemy\_asteroids\_x\_12        &  \textbf{ 0.58}   &   0.32  &   0.36  &   0.25  &   0.34  &   0.46  &   0.42\\
enemy\_asteroids\_x\_1         &   0.07  &  \textbf{ 0.44}   &   0.27  &   0.24  &   0.24  &   0.

### Tables 35-56
Fine-Grained Table for Every Game for PPO Agent

In [22]:
for game in atari_dict.keys():
    metric_name="f1score"
    
    # must manually add supervised for these ones
    if game == "asteroids" or game == "berzerk":
        methods=[ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim"]
    else:
        methods=[ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim", "supervised"]
    gmdf = get_game(ppo_df,game,
                    methods=methods,
                   metric_name=metric_name)
    gmdf = gmdf.dropna()
    if "supervised" in gmdf:
        ignore_last_column=True
        s = "|"
    else:
        s=""
        ignore_last_column=False
    print("\\begin{table*}[ht] \\caption{%s fine-grained results. Breakdown of F1 Scores for every state variable in %s  for every method for probes where data was collected by a pretrained PPO agent that was trained for 50M frames} \\label{%s-inc-ppo} \\vskip 0.15in \\begin{center} \\begin{small} \\begin{sc}"%(game.capitalize(),game.capitalize(),game.capitalize()))   
    print(" \\scalebox{0.9}{ \\begin{adjustbox}{center}")
    print(add_bold_max(gmdf.to_latex(),ignore_last_column=ignore_last_column).replace("rrr}", "rr%sr}"%s))
    print("\\end{adjustbox}}")
    print("""\\end{sc}
\\end{small}
\\end{center}
\\vskip -0.1in
\\end{table*}""")

\begin{table*}[ht] \caption{Asteroids fine-grained results. Breakdown of F1 Scores for every state variable in Asteroids  for every method for probes where data was collected by a pretrained PPO agent that was trained for 50M frames} \label{Asteroids-inc-ppo} \vskip 0.15in \begin{center} \begin{small} \begin{sc}
 \scalebox{0.9}{ \begin{adjustbox}{center}
\begin{tabular}{lrrrrrr}
\toprule
method &  maj-clf &  random-cnn &   vae &  pixel-pred &   cpc &  st-dim \\
\midrule
player\_x                    &   0.03  &   0.09  &   0.14  &   0.13  &   0.17  &  \textbf{ 0.21} \\
player\_y                    &   0.20  &   0.30  &   0.27  &   0.32  &   0.32  &  \textbf{ 0.41} \\
player\_score\_low            &   0.07  &   0.22  &   0.34  &   0.19  &   0.50  &  \textbf{ 0.73} \\
player\_missile\_x1           &   0.02  &   0.07  &   0.11  &   0.10  &   0.12  &  \textbf{ 0.14} \\
player\_missile\_x2           &   0.02  &   0.06  &   0.10  &   0.08  &   0.10  &  \textbf{ 0.12} \\
player\_missile\_y1   