In [None]:
import sys

# NOTE(review): machine-specific absolute path; presumably added so the project-local
# `src` package imported below can be found — adjust to your own checkout before running.
sys.path.append("/Users/evanracah/Dropbox/projects/rl-representation-learning/")

import wandb
import pandas as pd
import numpy as np
from copy import deepcopy

from src.atari_ram_annotations import summary_key_dict as skd, atari_dict, unused_keys, detailed_key_dict, all_keys

In [None]:
def get_game(df, game, methods=["cpc","st-dim"]):
    """Build the fine-grained F1 table for one game.

    Parameters
    ----------
    df : pandas.DataFrame
        Run table containing ``env_name`` and ``method`` columns plus
        ``<name>_f1score`` columns.
    game : str
        Value of ``env_name`` to select; also the key looked up in ``atari_dict``.
    methods : list of str
        Methods to keep; rows are stacked in this order.

    Returns
    -------
    pandas.DataFrame
        Scores rounded to 2 decimals, transposed so that each row is a key of
        ``atari_dict[game]`` and each column is a run labelled by its method.
    """
    game_rows = df[df["env_name"] == game]
    stacked = pd.concat(
        [game_rows[game_rows["method"] == name] for name in methods], axis=0
    )

    # Keep only f1score columns whose prefix is annotated for this game, and
    # remember the short (prefix-only) name each one is renamed to.
    score_cols = []
    short_names = {}
    for col in stacked.columns:
        if "f1score" not in col:
            continue
        prefix = col.split("_f1score")[0]
        if prefix in atari_dict[game].keys():
            score_cols.append(col)
            short_names[col] = prefix

    table = stacked[["method"] + score_cols].rename(columns=short_names)
    return table.round(2).set_index("method").T



def add_bold_max(latex_str, ignore_last_column=True, margin=0.01):
    """Wrap the best score(s) of every numeric table row in ``\\textbf{}``.

    The LaTeX source is split on the row terminator ``\\\\``. A chunk is treated
    as a data row when it contains a "." and at least one cell after the first
    contains "." or "NaN"; such cells are parsed as floats and re-emitted with
    ``%5.2f``. Cells after the first that match neither pattern are not carried
    over into the rebuilt row. All other chunks are passed through unchanged.

    Parameters
    ----------
    latex_str : str
        LaTeX table source, e.g. the output of ``DataFrame.to_latex()``.
    ignore_last_column : bool
        If True, the last numeric cell of each row is excluded from the
        comparison (it is still printed).
    margin : float
        Cells whose value plus ``margin`` reaches the row maximum are bolded too.

    Returns
    -------
    str
        The rewritten LaTeX source, with every "nan" replaced by "NaN".

    Notes
    -----
    NaN cells never take part in the comparison, so they are never bolded, and
    a row with nothing to compare (a single numeric cell with
    ``ignore_last_column=True``, or only NaNs) is re-emitted without bolding
    instead of raising.
    """
    rebuilt = []
    for r in latex_str.split("\\\\"):
        if "." not in r:
            rebuilt.append(r)
            continue

        s = r.split("&")
        nums = [float(n) for n in s[1:] if "." in n or "NaN" in n]
        if not nums:
            # The "." came from the row label only; nothing numeric to rewrite.
            rebuilt.append(r)
            continue

        am_nums = nums[:-1] if ignore_last_column else list(nums)
        comparable = [(i, n) for i, n in enumerate(am_nums) if not np.isnan(n)]
        if comparable:
            # max() returns the first maximal pair, i.e. the first column on ties.
            best_i, maxx = max(comparable, key=lambda pair: pair[1])
            to_bold = [best_i] + [
                i for i, n in comparable if i != best_i and n + margin >= maxx
            ]
            for i in to_bold:
                nums[i] = "\\textbf{%5.2f} " % (nums[i])

        cells = [n if isinstance(n, str) else "%5.2f" % n for n in nums]
        rebuilt.append("  &  ".join([s[0]] + cells))

    final_latex = "\\\\".join(rebuilt)
    final_latex = final_latex.replace("nan", "NaN")
    return final_latex

def keep_most_recent(df):
    """Return, for each game in ``atari_dict``, the row with the largest ``timestamp``.

    Parameters
    ----------
    df : pandas.DataFrame
        Run table with ``env_name`` and ``timestamp`` columns.

    Returns
    -------
    pandas.DataFrame
        One row per game that has at least one row in ``df``, taken from
        ``df.reset_index()`` (so the previous index appears as a column).
        Games without rows are skipped; rows whose ``env_name`` is not a key
        of ``atari_dict`` are dropped.
    """
    # After reset_index() the labels equal the positions, so the labels that
    # idxmax() returns can be passed to iloc.
    cf = df.reset_index()

    newest_rows = []
    for game in atari_dict.keys():
        game_rows = cf[cf["env_name"] == game]
        # idxmax() raises ValueError on an empty Series, so skip games
        # that have no run in this frame instead of aborting the whole table.
        if game_rows.empty:
            continue
        newest_rows.append(game_rows["timestamp"].idxmax())

    return cf.iloc[newest_rows]

def get_main_df(wandb_proj="curl-atari/curl-atari-post-neurips-2", collect_mode="random_agent"):
    """Download finished W&B runs and assemble them into one DataFrame.

    Parameters
    ----------
    wandb_proj : str
        W&B project path passed to ``wandb.Api().runs``.
    collect_mode : str
        Only runs whose ``config.collect_mode`` equals this value are fetched.

    Returns
    -------
    pandas.DataFrame
        One row per run built from ``run.summary_metrics``, with added columns
        ``env_name`` (config value cut at "NoFrameskip" and lower-cased),
        ``method`` (config value, renamed for a few known methods), ``f1`` and
        ``timestamp``. For ``collect_mode == "random_agent"`` the "cpc" rows
        are reduced with ``keep_most_recent`` before being appended back.
    """
    # Raw method names stored in the run config -> names used in the tables.
    # Methods not listed here are kept as they are.
    display_names = {
        "spatial-appo": "jsd-st-dim",
        "infonce-stdim": "st-dim",
        "global-infonce-stdim": "global-t-dim",
        "naff": "pixel-pred",
        "majority": "maj-clf",
        "global-local-infonce-stdim": "gl-st-dim",
    }

    def _first_present(summary, keys, default):
        # Value of the first key found in the summary, else the default.
        for key in keys:
            if key in summary:
                return summary[key]
        return default

    api = wandb.Api()
    run_filter = {"state": "finished", "config.collect_mode": collect_mode}
    runs = list(api.runs(wandb_proj, run_filter))

    df = pd.DataFrame([run.summary_metrics for run in runs])
    df['env_name'] = [run.config['env_name'].split("NoFrameskip")[0].lower() for run in runs]

    raw_methods = [run.config['method'] for run in runs]
    df['method'] = [display_names.get(name, name) for name in raw_methods]

    df["f1"] = [
        _first_present(run.summary_metrics, ("mean_mean_f1score", "mean_f1score"), np.nan)
        for run in runs
    ]
    df["timestamp"] = [
        _first_present(run.summary_metrics, ("_timestamp",), 0) for run in runs
    ]

    if collect_mode == "random_agent":
        is_cpc = df["method"] == "cpc"
        latest_cpc = keep_most_recent(df[is_cpc])
        df = pd.concat([df[~is_cpc], latest_cpc], axis=0, sort=True)
    return df

def compute_cat_df(df):
    """Average the per-key F1 columns of each run into one score per category.

    For every ``(category, keys)`` pair in ``skd``, the columns of ``df`` whose
    name before "_f1score" is in ``keys`` are averaged row-wise. ``overall`` is
    the row-wise mean of the category scores.

    Parameters
    ----------
    df : pandas.DataFrame
        Run table with ``env_name`` and ``method`` columns plus
        ``<key>_f1score`` columns. It is not modified.

    Returns
    -------
    pandas.DataFrame
        Columns ``env_name``, ``method``, one column per key of ``skd`` and
        ``overall``; same rows and index as ``df``.
    """
    categories = list(skd.keys())

    # Work on a copy so the caller's frame does not gain the category columns
    # and the assignments below never write into a slice of it.
    cat_df = df.loc[:, ["env_name", "method"]].copy()
    for cat, cat_keys in skd.items():
        cols = [c for c in df.columns if c.split("_f1score")[0] in cat_keys]
        cat_df[cat] = df[cols].mean(axis=1)

    cat_df["overall"] = cat_df.loc[:, categories].mean(axis=1)
    return cat_df

def get_method_df(cat_df,method):
    """Extract the ``overall`` score of one method, indexed by game.

    Parameters
    ----------
    cat_df : pandas.DataFrame
        Frame with at least ``method``, ``env_name`` and ``overall`` columns.
    method : str
        Method whose rows are selected; also the name of the output column.

    Returns
    -------
    pandas.DataFrame
        Single column named ``method`` holding the ``overall`` values, with
        ``env_name`` as the index.
    """
    trimmed = cat_df.loc[:, ["method", "env_name", "overall"]]
    is_method = trimmed["method"] == method
    scores = trimmed.loc[is_method].loc[:, ["env_name", "overall"]]
    scores = scores.rename(columns={"overall": method})
    return scores.set_index("env_name")
    

def cat_avg(df, methods):
    """Build the game-by-method table of category-averaged F1 scores.

    Parameters
    ----------
    df : pandas.DataFrame
        Run table accepted by ``compute_cat_df``.
    methods : list of str
        Methods to include, one output column each, in this order.

    Returns
    -------
    pandas.DataFrame
        Rows are games (sorted by the concat) plus a final ``mean`` row holding
        the column means; values are rounded to 2 decimals.
    """
    cat_df = compute_cat_df(df)
    per_method = [get_method_df(cat_df, name) for name in methods]

    table = pd.concat(per_method, axis=1, sort=True)
    table.loc['mean'] = table.mean()
    return table.round(2)

def manual_add(raw_df,ldf,method):
    """Append one method's column to an existing game-by-method table.

    The method's ``overall`` scores are computed from ``raw_df``, sorted by
    ``env_name``, extended with a ``mean`` row and rounded to 2 decimals. The
    values are then written into ``ldf`` by position, so ``ldf`` must already
    have the same number of rows in the same order.

    Parameters
    ----------
    raw_df : pandas.DataFrame
        Run table accepted by ``compute_cat_df``.
    ldf : pandas.DataFrame
        Table to extend; it is modified in place.
    method : str
        Method to add; becomes the new column name.

    Returns
    -------
    pandas.DataFrame
        The same ``ldf`` object with the extra column.
    """
    scores = get_method_df(compute_cat_df(raw_df), method)
    scores = scores.sort_values(by="env_name")
    scores.loc['mean'] = scores.mean()
    rounded = scores.round(2)

    # Positional assignment: .values discards the index of the new column.
    ldf[method] = rounded[method].values
    return ldf

def print_latex(ldf,ignore_last_column=True):
    """Print ``ldf`` as a LaTeX table with the best score(s) of each row in bold.

    Parameters
    ----------
    ldf : pandas.DataFrame
        Table to render with ``to_latex()``.
    ignore_last_column : bool
        Forwarded to ``add_bold_max``.
    """
    bolded = add_bold_max(ldf.to_latex(), ignore_last_column=ignore_last_column)
    print(bolded)

def all_cat(df,methods,ignore_last_column=True):
    """Build the category-by-method table of mean F1 scores.

    Parameters
    ----------
    df : pandas.DataFrame
        Run table accepted by ``compute_cat_df``.
    methods : list of str
        Methods to include, one output column each.
    ignore_last_column : bool
        Accepted but not used by this function.

    Returns
    -------
    pandas.DataFrame
        Rows are the categories of ``skd`` (the five known ones renamed to
        their display labels), columns are methods; each value is the mean of
        that category over the method's rows, rounded to 2 decimals.
    """
    display_labels = {
        "small_object_localization": "Small Loc.",
        "agent_localization": "Agent Loc.",
        "other_localization": "Other Loc.",
        "score_clock_lives_display": "Score/Clock/Lives/Display",
        "misc_keys": "Misc.",
    }
    categories = list(skd.keys())
    cat_df = compute_cat_df(df)

    per_method = {}
    for method in methods:
        method_rows = cat_df[cat_df["method"] == method]
        per_method[method] = method_rows[categories].mean(axis=0)

    cw_df = pd.DataFrame(per_method)
    # Categories are the row labels here, so they are renamed on the index.
    cw_df = cw_df.rename(index=display_labels)
    return cw_df.round(2)

### Table 1

In [None]:
# Disabled cell (every line is commented out). The code below would count, for each
# game in atari_dict, how many of its keys fall into each summary category (plus an
# "overall" bucket built from all_keys) and print the counts as " & "-separated rows.
# skdd = deepcopy(skd)
# skdd["overall"] = all_keys
# keys = ['agent_localization',
#  'small_object_localization',
#  'other_localization',
#  'score_clock_lives_display',
#  'misc_keys',
#  'overall']#list(summary_key_dict.keys())
# summary_stats = {env:{sk:0 for sk in keys} for env in atari_dict.keys()}

# for env in atari_dict.keys():
#     for k in atari_dict[env].keys():
#         for sk,v in skdd.items():
#             if k in v:
#                 summary_stats[env][sk] +=1

# print(" & ".join(["game"] + list(keys)))

# for env,v in summary_stats.items():
#     print( " & ".join([env] + list(str(v[k]) for k in keys)), "\\\\")
# print(" & ".join(["total"] + [str(len(skdd[k])) for k in keys]), "\\\\")

### Table 2 
(Average over Categories instead of State Variables)

In [None]:
# Table 2: game-by-method scores for data collected by the random agent.
ra_df = get_main_df(collect_mode="random_agent")
methods = ["maj-clf", "random-cnn", "vae", "pixel-pred", "cpc", "st-dim", "supervised"]

# "vae" is left out of cat_avg and appended afterwards with manual_add.
methods_filt = [name for name in methods if name != "vae"]

ldf = manual_add(ra_df, cat_avg(df=ra_df, methods=methods_filt), "vae")
# Restore the column order given by `methods`.
ldf = ldf[methods]
print_latex(ldf)

### Table 3
Categorical Summaries

In [None]:
# Table 3: category-by-method summary for the random-agent data (uses ra_df from the Table 2 cell).
methods = ["maj-clf", "random-cnn", "vae", "pixel-pred", "cpc", "st-dim", "supervised"]
cw_df = all_cat(df=ra_df, methods=methods)
print_latex(ldf=cw_df)

### Table 4
Incompleteness Table

In [None]:
# Table 4: fine-grained Boxing scores; every column takes part in the bold-max comparison.
boxing_methods = ["vae", "pixel-pred", "cpc", "global-t-dim", "st-dim"]
boxing_df = get_game(ra_df, "boxing", methods=boxing_methods)
print_latex(ldf=boxing_df, ignore_last_column=False)

### Table 5
No computation needed (text tables)

### Table 6

In [None]:
# Table 6: game-by-method scores for data collected with collect_mode="pretrained_ppo".
ppo_df = get_main_df(collect_mode="pretrained_ppo")
methods = ["maj-clf", "random-cnn", "vae", "pixel-pred", "cpc", "st-dim", "supervised"]

# "supervised" is left out of cat_avg and appended afterwards with manual_add.
methods_filt = [name for name in methods if name != "supervised"]

ldf = manual_add(ppo_df, cat_avg(df=ppo_df, methods=methods_filt), "supervised")
# Restore the column order given by `methods`.
ldf = ldf[methods]
print_latex(ldf)

### Table 7

In [None]:
# Table 7: category-by-method summary for the pretrained-PPO data (uses ppo_df from the Table 6 cell).
# The method list is spelled out here, as in the Table 3 cell, so the result does not
# depend on which cell assigned `methods` last.
methods = ["maj-clf", "random-cnn", "vae", "pixel-pred", "cpc", "st-dim", "supervised"]
cw_df = all_cat(ppo_df, methods)
print_latex(cw_df)

### Table 8
Ablation Version of Table 2

In [None]:
# Table 8: ablation variant of Table 2; every column takes part in the bold-max comparison.
methods = ["jsd-st-dim", "global-t-dim", "st-dim"]
adf = cat_avg(df=ra_df, methods=methods)
print_latex(ldf=adf, ignore_last_column=False)

### Figure 3 Data

In [None]:
# Figure 3 data: the first three columns of adf as lists, plus the row labels.
# Columns are taken by position, so this relies on the column order of adf.
ablation_scores = adf.values
jsd = list(ablation_scores[:, 0])
globalt = list(ablation_scores[:, 1])
stdim = list(ablation_scores[:, 2])
ablation_arrays = {
    "jsd": jsd,
    "globalt": globalt,
    "stdim": stdim,
    "keys": list(adf.index.values),
}
print(ablation_arrays)

### Table 9
Ablation Version of Table 3

In [None]:
# Table 9: ablation variant of Table 3; every column takes part in the bold-max comparison.
methods = ["jsd-st-dim", "global-t-dim", "st-dim"]
acdf = all_cat(df=ra_df, methods=methods)
print_latex(ldf=acdf, ignore_last_column=False)

### Table 10
RL probe with data collected by the pretrained PPO agent

In [None]:
# Table 10: baselines vs. the pretrained RL agent on collect_mode="pretrained_ppo" data.
ppo_df = get_main_df(collect_mode="pretrained_ppo")
methods = ["maj-clf", "random-cnn", "pretrained-rl-agent"]
ldf = cat_avg(df=ppo_df, methods=methods)
print_latex(ldf=ldf, ignore_last_column=False)

### Tables 11-33
Fine-Grained Table for Every Game

In [None]:
# Tables 11-33: one LaTeX table environment per game in atari_dict.
for game in atari_dict.keys():
    gmdf = get_game(
        ra_df,
        game,
        methods=["maj-clf", "random-cnn", "vae", "pixel-pred", "cpc", "st-dim", "supervised"],
    ).dropna()

    # When a "supervised" column is present it is kept out of the bold-max
    # comparison and a "|" is inserted before the last column of the column spec.
    ignore_last_column = "supervised" in gmdf
    s = "|" if ignore_last_column else ""

    title = game.capitalize()
    print("\\begin{table*}[ht] \\caption{%s fine-grained results. Breakdown of F1 Scores for every state variable in %s  for every method for probes where data was collected by random agent} \\label{%s-inc} \\vskip 0.15in \\begin{center} \\begin{small} \\begin{sc}"%(title, title, title))
    print(" \\scalebox{0.9}{ \\begin{adjustbox}{center}")
    bolded = add_bold_max(gmdf.to_latex(), ignore_last_column=ignore_last_column)
    print(bolded.replace("rrr}", "rr%sr}" % s))
    print("\\end{adjustbox}}")
    print("""\\end{sc}
\\end{small}
\\end{center}
\\vskip -0.1in
\\end{table*}""")