In [1]:
import sys

#sys.path.append("/Users/evanracah/Dropbox/projects/atari-representation-learning/")

import wandb
import pandas as pd
import numpy as np
from copy import deepcopy

from aari.ram_annotations import atari_dict
from aari.categorization import summary_key_dict as skd,  unused_keys, detailed_key_dict, all_keys

In [26]:
def get_game(df, game, methods=["cpc","st-dim"]):
    """Build the per-state-variable F1 table for one game.

    Rows of ``df`` whose ``env_name`` equals ``game`` are collected one
    method at a time, in the order given by ``methods``, and stacked.
    A column is kept when its name contains ``f1score`` and the part before
    ``_f1score`` is a key of ``atari_dict[game]``; kept columns are renamed
    to that prefix.

    Returns a frame rounded to 2 decimals with one row per kept state
    variable and one column per selected run, labelled by method name.
    """
    suffix = "_f1score"
    game_rows = df[df["env_name"] == game]

    # Stack method by method so the result's column order follows `methods`.
    per_method = []
    for method in methods:
        per_method.append(game_rows[game_rows["method"] == method])
    stacked = pd.concat(per_method, axis=0)

    score_cols = []
    renames = {}
    for col in stacked.columns:
        if "f1score" not in col:
            continue
        prefix = col.split(suffix)[0]
        if prefix in atari_dict[game].keys():
            score_cols.append(col)
            renames[col] = prefix

    table = stacked[["method"] + score_cols].rename(columns=renames)
    return table.round(2).set_index("method").T



def add_bold_max(latex_str, ignore_last_column=True, margin=0.01):
    """Wrap the best score(s) of every numeric table row in ``\\textbf{}``.

    ``latex_str`` is split on the LaTeX row terminator (two backslashes).
    Chunks without a "." are passed through untouched.  In every other chunk
    the cells after the first ``&`` that contain "." or "NaN" are parsed as
    floats; the largest one, plus every value within ``margin`` of it, is
    emitted as ``\\textbf{..}``.  NaN cells never win (they are ranked as -1).

    Parameters
    ----------
    latex_str : str
        LaTeX table source, e.g. the output of ``DataFrame.to_latex()``.
    ignore_last_column : bool
        When true, the last numeric cell of each row is excluded from the
        comparison and is never bolded.
    margin : float
        Values this close to the row maximum are bolded as well.

    Returns
    -------
    str
        The table with numeric cells re-rendered as ``%5.2f`` and every
        occurrence of "nan" replaced by "NaN".
    """
    # Slack for binary floating-point rounding: 0.06 + 0.01 evaluates to
    # slightly less than 0.07, which would hide a tie at exactly `margin`.
    tol = 1e-9

    rows = []
    for r in latex_str.split("\\\\"):
        if "." not in r:
            rows.append(r)
            continue

        s = r.split("&")
        nums = [float(n) for n in s[1:] if "." in n or "NaN" in n]
        if not nums:
            # The "." came from non-numeric text (e.g. a label); nothing to rank.
            rows.append(r)
            continue

        candidates = nums[:-1] if ignore_last_column else nums
        filt_nums = [-1 if np.isnan(n) else n for n in candidates]

        bold = set()
        # With a single numeric column and ignore_last_column=True there is
        # nothing to compare, so no cell is bolded.
        if filt_nums:
            maxx = np.max(filt_nums)
            bold.add(int(np.argmax(filt_nums)))
            bold.update(i for i, n in enumerate(filt_nums) if maxx - n <= margin + tol)

        cells = [
            ("\\textbf{%5.2f} " if i in bold else "%5.2f") % n
            for i, n in enumerate(nums)
        ]
        rows.append("  &  ".join([s[0]] + cells))

    final_latex = "\\\\".join(rows)
    return final_latex.replace("nan", "NaN")

def keep_most_recent(df):
    """Keep, for each game in ``atari_dict``, only the row with the largest ``timestamp``.

    The index of ``df`` is reset first (the old index is kept as a column by
    ``reset_index()``), so the returned frame is labelled by row position.
    Games that have no rows in ``df`` are skipped.

    Returns a frame with at most one row per game, ordered like
    ``atari_dict.keys()``.
    """
    cf = df.reset_index()

    latest_labels = []
    for game in atari_dict.keys():
        stamps = cf.loc[cf["env_name"] == game, "timestamp"]
        # idxmax() raises on an empty selection; a game without runs simply
        # contributes no row.
        if stamps.empty:
            continue
        latest_labels.append(stamps.idxmax())

    # idxmax() returns index labels, so select by label rather than position.
    return cf.loc[latest_labels]

def translate_method_name(method):
    """Translate a run-config method name into the name used in the tables.

    Names without a known alias are returned unchanged.
    """
    aliases = (
        ("spatial-appo", "jsd-st-dim"),
        ("infonce-stdim", "st-dim"),
        ("global-infonce-stdim", "global-t-dim"),
        ("naff", "pixel-pred"),
        ("majority", "maj-clf"),
        ("global-local-infonce-stdim", "gl-st-dim"),
    )
    for internal_name, table_name in aliases:
        if method == internal_name:
            return table_name
    return method
    

def get_main_df(wandb_proj="curl-atari/curl-atari-post-neurips-2", collect_mode="random_agent"):
    """Load the finished runs of a wandb project into one DataFrame.

    Runs are filtered to ``state == "finished"`` and the requested
    ``config.collect_mode``.  Each row starts from a run's
    ``summary_metrics`` and gains these columns:

    * ``env_name``  - config ``env_name`` up to "NoFrameskip", lower-cased
    * ``method``    - config ``method`` passed through ``translate_method_name``
    * ``f1_score`` / ``test_acc`` - the ``mean_mean_<metric>`` summary value,
      else ``mean_<metric>``, else NaN
    * ``timestamp`` - summary ``_timestamp``, or 0 when absent

    For ``collect_mode == "random_agent"`` the "pixel-pred" rows are replaced
    by the output of ``keep_most_recent`` (latest run per game).
    """
    run_filters = {"state": "finished", "config.collect_mode": collect_mode}
    runs = list(wandb.Api().runs(wandb_proj, run_filters))

    df = pd.DataFrame([run.summary_metrics for run in runs])
    df["env_name"] = [
        run.config["env_name"].split("NoFrameskip")[0].lower() for run in runs
    ]
    df["method"] = [translate_method_name(run.config["method"]) for run in runs]

    for metric_name in ("f1_score", "test_acc"):
        nested_key = "mean_mean_" + metric_name
        plain_key = "mean_" + metric_name
        values = []
        for run in runs:
            summary = run.summary_metrics
            if nested_key in summary:
                values.append(summary[nested_key])
            elif plain_key in summary:
                values.append(summary[plain_key])
            else:
                values.append(np.nan)
        df[metric_name] = values

    timestamps = []
    for run in runs:
        summary = run.summary_metrics
        timestamps.append(summary["_timestamp"] if "_timestamp" in summary else 0)
    df["timestamp"] = timestamps

    if collect_mode != "random_agent":
        return df

    # Collapse the listed methods to their most recent run per game.
    for dedup_method in ("pixel-pred",):
        is_method = df["method"] == dedup_method
        latest = keep_most_recent(df[is_method])
        df = pd.concat([df[~is_method], latest], axis=0, sort=True)
    return df

def compute_cat_df(df, metric_name="f1score"):
    """Average per-state-variable scores into the categories defined by ``skd``.

    For every ``(category, keys)`` pair in ``skd`` the columns of ``df`` whose
    name, cut at ``"_" + metric_name``, is one of ``keys`` are averaged
    row-wise.  ``overall`` is the row-wise mean of the category columns.

    ``df`` itself is left unmodified, so repeated calls from different
    notebook cells do not accumulate category columns on the shared frame.

    Returns a new frame with ``env_name``, ``method``, one column per
    category and ``overall``.
    """
    suffix = "_" + metric_name
    categories = list(skd.keys())

    # Copy so the category columns are written to a fresh frame, not to a
    # slice of (or directly into) the caller's data.
    cat_df = df.loc[:, ["env_name", "method"]].copy()
    for cat, cat_keys in skd.items():
        cols = [c for c in df.columns if c.split(suffix)[0] in cat_keys]
        # Rows of cat_df and df are in the same order; assign positionally.
        cat_df[cat] = df[cols].mean(axis=1).values

    cat_df["overall"] = cat_df.loc[:, categories].mean(axis=1)
    return cat_df

def get_method_df(cat_df,method):
    """Extract the ``overall`` score of one method, indexed by game.

    Returns a single-column frame: the index is ``env_name`` and the column,
    named after ``method``, holds the ``overall`` values of the rows whose
    ``method`` column equals ``method``.
    """
    trimmed = cat_df.loc[:, ["method", "env_name", "overall"]]
    is_method = trimmed["method"] == method
    scores = trimmed.loc[is_method, ["env_name", "overall"]]
    return scores.rename(columns={"overall": method}).set_index("env_name")
    

def cat_avg(df, methods, metric="f1"):
    """Build the game-by-method table of ``overall`` category-averaged scores.

    Parameters
    ----------
    df : DataFrame
        Raw results frame, passed to ``compute_cat_df``.
    methods : list of str
        Methods to include; each becomes one column, in this order.
    metric : {"f1", "acc"}
        Selects the column suffix handed to ``compute_cat_df``
        ("f1score" or "test_acc").

    Returns
    -------
    DataFrame
        One row per game (sorted) plus a final ``mean`` row holding the
        column means, rounded to 2 decimals.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    metric_names = {"f1": "f1score", "acc": "test_acc"}
    if metric not in metric_names:
        # Previously an unknown metric surfaced as an UnboundLocalError.
        raise ValueError(
            "unknown metric %r; expected one of %s" % (metric, sorted(metric_names))
        )

    cat_df = compute_cat_df(df, metric_name=metric_names[metric])
    cdfs = [get_method_df(cat_df, method) for method in methods]

    ldf = pd.concat(cdfs, axis=1, sort=True)
    ldf.loc['mean'] = ldf.mean()
    return ldf.round(2)

def manual_add(raw_df,ldf,method, metric="f1"):
    """Append one method's ``overall`` column to an existing ``cat_avg`` table.

    The method's per-game scores are computed from ``raw_df``, sorted by
    ``env_name``, extended with a ``mean`` row, rounded to 2 decimals and
    written into ``ldf[method]``.  ``ldf`` is modified in place and returned.

    The values are assigned by position, not by index label, so the method
    must yield exactly as many rows as ``ldf`` has.
    NOTE(review): this presumably expects the same games in the same sorted
    order as ``ldf`` - confirm against the data.

    Raises
    ------
    ValueError
        If ``metric`` is not "f1" or "acc".
    """
    metric_names = {"f1": "f1score", "acc": "test_acc"}
    if metric not in metric_names:
        # Previously an unknown metric surfaced as an UnboundLocalError.
        raise ValueError(
            "unknown metric %r; expected one of %s" % (metric, sorted(metric_names))
        )

    cat_df = compute_cat_df(raw_df, metric_name=metric_names[metric])
    mdf = get_method_df(cat_df, method).sort_values(by="env_name")
    mdf.loc['mean'] = mdf.mean()
    mdf = mdf.round(2)
    ldf[method] = mdf[method].values
    return ldf

def print_latex(ldf,ignore_last_column=True):
    """Print ``ldf`` as a LaTeX table after passing it through ``add_bold_max``."""
    bolded = add_bold_max(ldf.to_latex(), ignore_last_column=ignore_last_column)
    print(bolded)

def all_cat(df,methods,ignore_last_column=True, metric="f1"):
    """Build the category-by-method table of scores averaged over games.

    Parameters
    ----------
    df : DataFrame
        Raw results frame, passed to ``compute_cat_df``.
    methods : list of str
        Methods to include; each becomes one column.
    ignore_last_column : bool
        Accepted for call compatibility; not used by this function.
    metric : {"f1", "acc"}
        Selects the column suffix handed to ``compute_cat_df``.

    Returns
    -------
    DataFrame
        One row per category of ``skd`` (with display names) and one column
        per method, rounded to 2 decimals.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    metric_names = {"f1": "f1score", "acc": "test_acc"}
    if metric not in metric_names:
        # Previously an unknown metric surfaced as an UnboundLocalError.
        raise ValueError(
            "unknown metric %r; expected one of %s" % (metric, sorted(metric_names))
        )

    cat_df = compute_cat_df(df, metric_name=metric_names[metric])
    categories = list(skd.keys())

    cw_df = pd.DataFrame()
    for method in methods:
        # Mean of every category over all rows (games/runs) of this method.
        cw_df[method] = cat_df.loc[cat_df["method"] == method, categories].mean(axis=0)

    display_names = {
        "small_object_localization": "Small Loc.",
        "agent_localization": "Agent Loc.",
        "other_localization": "Other Loc.",
        "score_clock_lives_display": "Score/Clock/Lives/Display",
        "misc_keys": "Misc.",
    }
    # Categories are the row labels, so rename the index directly.
    return cw_df.rename(index=display_names).round(2)

### Table 1

In [27]:
# skdd = deepcopy(skd)
# skdd["overall"] = all_keys
# keys = ['agent_localization',
#  'small_object_localization',
#  'other_localization',
#  'score_clock_lives_display',
#  'misc_keys',
#  'overall']#list(summary_key_dict.keys())
# summary_stats = {env:{sk:0 for sk in keys} for env in atari_dict.keys()}

# for env in atari_dict.keys():
#     for k in atari_dict[env].keys():
#         for sk,v in skdd.items():
#             if k in v:
#                 summary_stats[env][sk] +=1

# print(" & ".join(["game"] + list(keys)))

# for env,v in summary_stats.items():
#     print( " & ".join([env] + list(str(v[k]) for k in keys)), "\\\\")
# print(" & ".join(["total"] + [str(len(skdd[k])) for k in keys]), "\\\\")

### Table 2 
(Average over Categories instead of State Variables)

In [37]:
# Fetch the finished runs collected with the random agent (queries the wandb API).
ra_df = get_main_df(collect_mode="random_agent")

In [38]:
# Column order of the final table.
methods = [ "maj-clf","random-cnn","static-dim-2", "vae","pixel-pred","cpc","st-dim", "supervised"]  #"pixel-pred"
#methods = ["static-dim-2"]
# "vae" is left out of cat_avg and appended afterwards through manual_add.
methods_filt = deepcopy(methods)
methods_filt.remove("vae")
#methods_filt.remove("pixel-pred")

ldf = cat_avg(df=ra_df,methods=methods_filt, metric="f1")
ldf = manual_add(ra_df,ldf,"vae", metric="f1")
# Restore the intended column order (manual_add appends "vae" last).
ldf= ldf[methods]
#print_latex(ldf)

In [39]:
# Display the per-game table built in the previous cell.
ldf

Unnamed: 0,maj-clf,random-cnn,static-dim-2,vae,pixel-pred,cpc,st-dim,supervised
asteroids,0.28,0.34,0.37,0.36,0.34,0.42,0.49,0.52
berzerk,0.18,0.43,0.41,0.45,0.55,0.56,0.53,0.68
bowling,0.33,0.48,0.34,0.5,0.81,0.9,0.96,0.95
boxing,0.01,0.19,0.09,0.2,0.44,0.29,0.58,0.83
breakout,0.17,0.51,0.19,0.57,0.7,0.74,0.88,0.94
demonattack,0.16,0.26,0.3,0.25,0.32,0.57,0.69,0.83
freeway,0.01,0.5,0.02,0.26,0.81,0.47,0.81,0.98
frostbite,0.08,0.57,0.27,0.01,0.72,0.76,0.75,0.85
hero,0.22,0.75,0.59,0.51,0.74,0.9,0.93,0.98
montezumarevenge,0.08,0.68,0.17,0.69,0.74,0.75,0.78,0.87


### Table 3
Categorical Summaries

In [48]:
# Per-category scores averaged over games; print_latex's default leaves the
# last column ("supervised") out of the bolding comparison.
methods = [ "maj-clf","random-cnn","vae", "pixel-pred","cpc","dim","st-dim", "supervised"] 
cw_df = all_cat(ra_df,methods, metric="f1")
print_latex(cw_df)

\begin{tabular}{lrrrrrrrr}
\toprule
{} &  maj-clf &  random-cnn &   vae &  pixel-pred &   cpc &   dim &  st-dim &  supervised \\
\midrule
Small Loc.                  &   0.14  &   0.19  &   0.17  &   0.31  &   0.42  &   0.48  &  \textbf{ 0.51}   &   0.66\\
Agent Loc.                  &   0.12  &   0.31  &   0.30  &   0.47  &   0.43  &  \textbf{ 0.63}   &   0.58  &   0.81\\
Other Loc.                  &   0.14  &   0.50  &   0.36  &   0.60  &   0.66  &  \textbf{ 0.70}   &  \textbf{ 0.69}   &   0.80\\
Score/Clock/Lives/Display   &   0.13  &   0.58  &   0.53  &   0.76  &   0.83  &  \textbf{ 0.86}   &  \textbf{ 0.87}   &   0.91\\
Misc.                       &   0.26  &   0.59  &   0.65  &   0.70  &   0.71  &   0.72  &  \textbf{ 0.75}   &   0.83\\
\bottomrule
\end{tabular}



In [49]:
# Display the category table built in the previous cell.
cw_df

Unnamed: 0,maj-clf,random-cnn,vae,pixel-pred,cpc,dim,st-dim,supervised
Small Loc.,0.14,0.19,0.17,0.31,0.42,0.48,0.51,0.66
Agent Loc.,0.12,0.31,0.3,0.47,0.43,0.63,0.58,0.81
Other Loc.,0.14,0.5,0.36,0.6,0.66,0.7,0.69,0.8
Score/Clock/Lives/Display,0.13,0.58,0.53,0.76,0.83,0.86,0.87,0.91
Misc.,0.26,0.59,0.65,0.7,0.71,0.72,0.75,0.83


### Table 4
Incompleteness Table

In [None]:
# Per-state-variable F1 scores for Boxing; every column takes part in bolding.
boxing_df = get_game(ra_df, "boxing",methods=["vae","pixel-pred","cpc","global-t-dim", "st-dim"])
print_latex(boxing_df,ignore_last_column=False)

### Table 5
No computation needed (text tables)

### Table 6

In [None]:
# Same layout as Table 2, for runs whose collect_mode is "pretrained_ppo".
ppo_df = get_main_df(collect_mode="pretrained_ppo")
methods = [ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim", "supervised"] 
# "supervised" is left out of cat_avg and appended afterwards through manual_add.
methods_filt = deepcopy(methods)
methods_filt.remove("supervised")
ldf = cat_avg(df=ppo_df,methods=methods_filt)

ldf = manual_add(ppo_df,ldf,"supervised")
ldf= ldf[methods]
print_latex(ldf)

### Table 7

In [None]:
# Uses `ppo_df` and `methods` as left by the Table 6 cell.
cw_df = all_cat(ppo_df,methods)
print_latex(cw_df)

### Table 8
Ablation Version of Table 2

In [None]:
# Ablation variants only; every column takes part in bolding.
methods = ["jsd-st-dim", "global-t-dim", "st-dim"] 
adf = cat_avg(ra_df,methods)
print_latex(adf, ignore_last_column=False)

### Figure 3 Data

In [None]:
# Columns are taken by position, so this relies on `adf` having been built
# with methods ordered as jsd-st-dim, global-t-dim, st-dim (the Table 8 cell).
jsd, globalt, stdim = [list(adf.values[:,i]) for i in range(3)]
# `keys` holds adf's row labels (the games plus the final "mean" row).
ablation_arrays =dict(jsd=jsd, globalt=globalt, stdim=stdim, keys=list(adf.index.values))
print(ablation_arrays)

### Table 9
Ablation Version of Table 3

In [None]:
# Per-category scores for the ablation variants; every column takes part in bolding.
methods = ["jsd-st-dim", "global-t-dim", "st-dim"] 
acdf = all_cat(ra_df, methods)
print_latex(acdf,ignore_last_column=False)

###  Table 10 
RL Probe With Data from PPO

In [None]:
# NOTE(review): this repeats the wandb query already issued in the Table 6
# cell and rebinds `ppo_df`, `methods` and `ldf`.
ppo_df = get_main_df(collect_mode="pretrained_ppo")
methods = [ "maj-clf","random-cnn","pretrained-rl-agent"] 
ldf = cat_avg(df=ppo_df,methods=methods)
print_latex(ldf,ignore_last_column=False)

### Tables 11-33
Fine-Grained Table for Every Game

In [None]:
# Emit one complete LaTeX table environment per game in atari_dict.
for game in atari_dict.keys():
    gmdf = get_game(ra_df,game,methods=[ "maj-clf","random-cnn","vae", "pixel-pred","cpc","st-dim", "supervised"] )
    # Drop state variables that have a missing score in any column.
    gmdf = gmdf.dropna()
    # When a "supervised" column is present it is excluded from bolding and
    # a "|" is inserted into the column spec before the last column.
    if "supervised" in gmdf:
        ignore_last_column=True
        s = "|"
    else:
        s=""
        ignore_last_column=False
    print("\\begin{table*}[ht] \\caption{%s fine-grained results. Breakdown of F1 Scores for every state variable in %s  for every method for probes where data was collected by random agent} \\label{%s-inc} \\vskip 0.15in \\begin{center} \\begin{small} \\begin{sc}"%(game.capitalize(),game.capitalize(),game.capitalize()))   
    print(" \\scalebox{0.9}{ \\begin{adjustbox}{center}")
    # Rewrites the "rrr}" tail of the tabular column spec to "rr<s>r}".
    print(add_bold_max(gmdf.to_latex(),ignore_last_column=ignore_last_column).replace("rrr}", "rr%sr}"%s))
    print("\\end{adjustbox}}")
    print("""\\end{sc}
\\end{small}
\\end{center}
\\vskip -0.1in
\\end{table*}""")