In [1]:
import sys
#sys.path.append("/Users/evanracah/Dropbox/projects/atari-representation-learning/")
import wandb
import pandas as pd
import numpy as np
from copy import deepcopy
import json
from matplotlib import pyplot as plt
import tabulate
# W&B project path; passed as the first argument to `api.runs(...)` by the run-fetching helpers.
wandb_proj = "eracah/coors-production"


def grab_relevant_tr_runs():
    """Fetch the finished runs tagged ``"train"`` from the W&B project.

    Queries ``wandb_proj`` with ``order='-created_at'`` and prints the ids of
    the returned runs on a single space-separated line.

    Returns:
        list: the run objects yielded by ``wandb.Api().runs(...)``.
    """
    query = {"state": "finished", "tags": "train"}
    train_runs = list(wandb.Api().runs(wandb_proj, order='-created_at', filters=query))
    run_ids = (run.id for run in train_runs)
    print(" ".join(run_ids))
    return train_runs

def grab_relevant_wandb_runs(tags="eval"):
    """Fetch the finished runs carrying ``tags`` from the W&B project.

    Args:
        tags: value used for the ``"tags"`` entry of the run filter.

    Returns:
        list: the run objects yielded by ``wandb.Api().runs(...)``.
    """
    query = {"state": "finished", "tags": tags}
    return list(wandb.Api().runs(wandb_proj, filters=query))

def get_ablations(config):
    """Build the loss-ablation suffix appended to a run's method name.

    Args:
        config: mapping-like run config; must contain ``'method'``.

    Returns:
        str: ``"_"`` followed by the ``"losses"`` entries joined with ``"_"``
        when the method is ``"slot-stdim"`` and ``"losses"`` is present and
        non-empty; otherwise the empty string.
    """
    if config['method'] != "slot-stdim":
        return ""
    if "losses" not in config:
        return ""
    losses = config["losses"]
    if len(losses) == 0:
        return ""
    return "_" + "_".join(losses)

def is_random_cnn(config):
    """Return True when ``config`` has a truthy ``"random_cnn"`` entry, else False."""
    return bool("random_cnn" in config and config["random_cnn"])

def make_dataframe_from_runs(runs_list, metric_name):
    """Build a long-format table with one row per run.

    Args:
        runs_list: run objects exposing ``summary_metrics`` and ``config``
            (``config`` must provide ``'env_name'`` and ``'method'``).
        metric_name: summary-metric column to keep.

    Returns:
        pandas.DataFrame: columns ``[metric_name, "method", "env_name"]``.
    """
    def _env_label(run):
        # Keep the text before "NoFrameskip", then lower-case it with a leading capital.
        return run.config['env_name'].split("NoFrameskip")[0].lower().capitalize()

    def _method_label(run):
        cfg = run.config
        # Prefix only when the run used a random CNN and the method name does not say so already.
        needs_prefix = is_random_cnn(cfg) and "random" not in cfg['method']
        prefix = "random_" if needs_prefix else ""
        return prefix + cfg['method'] + get_ablations(cfg)

    table = pd.DataFrame([run.summary_metrics for run in runs_list])
    table['env_name'] = [_env_label(run) for run in runs_list]
    table['method'] = [_method_label(run) for run in runs_list]
    return table[[metric_name, "method", "env_name"]]


def make_mean_df(df, metric_name):
    """Average ``metric_name`` over repeated runs and pivot to an env x method table.

    Args:
        df: long-format frame with ``"env_name"``, ``"method"`` and
            ``metric_name`` columns (other columns are ignored).
        metric_name: column to average.

    Returns:
        pandas.DataFrame: one row per environment, one column per method,
        plus a final ``"Overall"`` row holding the mean of the per-environment
        means. All values are rounded to 3 decimals.
    """
    # Aggregate only the metric column: with any extra non-numeric column in `df`,
    # DataFrameGroupBy.mean() raises TypeError on pandas >= 2.0.
    per_group = df.groupby(["env_name", "method"])[metric_name].mean().reset_index()
    # Pivot so that each method gets its own column.
    mdf = per_group.pivot(index="env_name", columns="method", values=metric_name)
    # Drop the axis names left behind by the pivot.
    mdf.columns.name = None
    mdf.index.name = None
    # "Overall" is computed from the unrounded per-environment means.
    overall = pd.DataFrame([mdf.mean(axis=0)], index=["Overall"])
    return pd.concat([mdf, overall], sort=False).round(3)

def make_std_df(df, metric_name):
    """Compute the per-(env, method) standard deviation of ``metric_name``.

    Args:
        df: long-format frame with ``"env_name"``, ``"method"`` and
            ``metric_name`` columns (other columns are ignored).
        metric_name: column whose spread is computed.

    Returns:
        pandas.DataFrame: one row per environment, one column per method
        (pandas' default ``std()``, so a group with a single run yields NaN),
        plus an ``"Overall"`` row equal to the square root of the mean
        per-environment variance. Values are not rounded.
    """
    # Aggregate only the metric column: with any extra non-numeric column in `df`,
    # DataFrameGroupBy.std() raises on pandas >= 2.0.
    per_group = df.groupby(["env_name", "method"])[metric_name].std().reset_index()
    # Pivot so that method names become columns.
    sdf = per_group.pivot(index="env_name", columns="method", values=metric_name)
    sdf.columns.name = None
    sdf.index.name = None
    # Standard deviations are combined through their variances, not averaged directly.
    overall_std = (sdf ** 2).mean(axis=0).pow(0.5)
    overall = pd.DataFrame([overall_std], index=["Overall"])
    return pd.concat([sdf, overall], sort=False)

def make_count_df(df, metric_name):
    """Count the non-null ``metric_name`` values per (environment, method).

    Returns:
        pandas.DataFrame: one row per environment, one column per method,
        plus an ``"Overall"`` row holding the column sums.
    """
    group_keys = ["env_name", "method"]
    counts = (
        df.groupby(group_keys)
        .count()
        .reset_index()
        .pivot(index="env_name", columns="method", values=metric_name)  # methods become columns
    )
    counts.index.name = None
    counts.columns.name = None
    overall_row = pd.DataFrame([counts.sum(axis=0)], index=["Overall"])
    return pd.concat([counts, overall_row], sort=False)

def make_stderr_df(std_df, count_df):
    """Divide ``std_df`` element-wise by the square root of ``count_df``, rounded to 3 decimals."""
    root_n = count_df.pow(0.5)
    return std_df.div(root_n).round(3)

    
def add_error_bars(mean_df, stderr_df, count_df):
    """Render every cell as ``"<mean>+-<stderr> (<count>)"``.

    Cells whose standard error is NaN are rendered as the bare mean. The
    result starts as a copy of ``stderr_df``; each column named in
    ``mean_df`` is overwritten with the rendered strings.

    Returns:
        pandas.DataFrame: the annotated copy; the inputs are left untouched.
    """
    def _render(mn, stderr, count):
        if np.isnan(stderr):
            return str(mn)
        return "%s+-%s (%s)" % (mn, stderr, int(count))

    annotated = stderr_df.copy()
    for method in mean_df.columns:
        triples = zip(mean_df[method], stderr_df[method], count_df[method])
        annotated[method] = [_render(*triple) for triple in triples]
    return annotated

def get_df(metric_name, methods=None, tags="eval"):
    """Fetch runs and build the summary tables for ``metric_name``.

    Args:
        metric_name: summary-metric column to tabulate.
        methods: optional collection of method names to keep; a falsy value
            keeps every method.
        tags: forwarded to ``grab_relevant_wandb_runs``.

    Returns:
        tuple: ``(final_df, count_df, mean_df, stderr_df)`` where ``final_df``
        is the string table produced by ``add_error_bars``.
    """
    per_run = make_dataframe_from_runs(
        grab_relevant_wandb_runs(tags=tags), metric_name=metric_name
    )
    if methods:
        per_run = per_run[per_run["method"].isin(methods)]

    means = make_mean_df(per_run, metric_name)
    stds = make_std_df(per_run, metric_name)
    counts = make_count_df(per_run, metric_name=metric_name)
    stderrs = make_stderr_df(stds, counts)
    annotated = add_error_bars(means, stderrs, counts)
    return annotated, counts, means, stderrs

    

def plot_bars(mean_df, std_df=None, da_title="", map_dict=None):
    """Draw a grouped bar chart: one group per row of ``mean_df``, one bar per column.

    Args:
        mean_df: DataFrame of bar heights; the index supplies the x tick
            labels and the columns supply the legend entries.
        std_df: optional DataFrame with the same columns, used as ``yerr``;
            ``None`` draws zero-length error bars.
        da_title: title set on the axes.
        map_dict: optional mapping from column name to legend label; columns
            missing from it keep their own name.
    """
    x_labels = list(mean_df.index)
    columns = list(mean_df.columns)
    mean_heights = [mean_df[[col]].to_numpy().squeeze() for col in columns]

    # Compare against None: `if std_df:` raises ValueError for a DataFrame
    # because its truth value is ambiguous.
    if std_df is not None:
        std_heights = [std_df[[col]].to_numpy().squeeze() for col in columns]
    else:
        std_heights = [0.0] * len(columns)

    bar_labels = [map_dict.get(col, col) for col in columns] if map_dict else columns

    x_inds = np.arange(len(x_labels))  # the label locations
    width = 0.10  # the width of the bars
    _, ax = plt.subplots()
    # Bar i of every group is shifted right by i bar-widths from the group's tick.
    for i, (label, heights, errs) in enumerate(zip(bar_labels, mean_heights, std_heights)):
        ax.bar(x=x_inds + width * i, height=heights, label=label, yerr=errs, width=width)
    ax.set_xticks(x_inds)
    ax.set_xticklabels(x_labels)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.07),
              fancybox=True, shadow=True, ncol=5)
    ax.set_title(da_title)
    
def add_bold_max(latex_str, ignore_last_column=True, ignore_first_column=False, margin=0.01):
    r"""Bold the best value(s) in every numeric row of a LaTeX table.

    The table text is split on the LaTeX row terminator ``\\``. Chunks
    without a ``"."`` (header, rules, trailer) pass through unchanged. In
    every other chunk the numeric cells are re-rendered with ``%5.2f`` and
    the largest candidate, together with every candidate within ``margin``
    of it, is wrapped in ``\textbf{...}``. NaN cells are printed as ``NaN``
    and are never bolded.

    Args:
        latex_str: LaTeX table source, e.g. the output of ``DataFrame.to_latex()``.
        ignore_last_column: exclude the last numeric cell of each row from
            the comparison (it is still printed).
        ignore_first_column: exclude the first numeric cell of each row from
            the comparison (it is still printed).
        margin: a candidate ``v`` is bolded when ``v + margin >= max``.

    Returns:
        str: the rewritten LaTeX source.
    """
    rebuilt = []
    for row in latex_str.split("\\\\"):
        if "." not in row:
            rebuilt.append(row)
            continue

        cells = row.split("&")
        # NOTE(review): cells holding neither "." nor "NaN" (e.g. bare integers) are skipped and
        # therefore vanish from the rebuilt row — confirm whether such cells can occur.
        values = [float(cell) for cell in cells[1:] if "." in cell or "NaN" in cell]

        first = 1 if ignore_first_column else 0
        last = len(values) - 1 if ignore_last_column else len(values)
        # Only finite candidates compete. Substituting a sentinel such as -1 for NaN would let a
        # NaN cell win (and be bolded) whenever the real values are at or below the sentinel.
        contenders = {i: values[i] for i in range(first, last) if not np.isnan(values[i])}

        bold = set()
        if contenders:  # may be empty: too few columns, or every candidate is NaN
            best = max(contenders.values())
            bold = {i for i, v in contenders.items() if v + margin >= best}
            # The first maximum is always bolded, even with a negative margin.
            bold.add(max(contenders, key=contenders.get))

        formatted = []
        for i, value in enumerate(values):
            # Render NaN here rather than patching "nan" across the whole table afterwards,
            # which would also rewrite row labels that happen to contain "nan".
            text = "  NaN" if np.isnan(value) else "%5.2f" % value
            if i in bold:
                text = "\\textbf{%s} " % text
            formatted.append(text)
        rebuilt.append("  &  ".join([cells[0]] + formatted))

    return "\\\\".join(rebuilt)


In [2]:
# Summary-metric key to tabulate; passed to get_df() as `metric_name` in a later cell.
metric_name =  'concat_overall_localization_avg_r2_lin_reg'

In [3]:
# Method columns selected for the main results table (used as `mean_df[acc_methods]`).
acc_methods = ['random-cnn',

 'slot-stdim_scn_sdl',
                "cswm",
'stdim',
 'supervised']
# Same list without 'stdim'; in the cells shown it is referenced only by the
# commented-out dci_* get_df calls at the end of this cell.
disent_methods = ['random-cnn',

 'slot-stdim_scn_sdl',
                "cswm",
 'supervised']
# Method columns selected for the ablation table (used as `mean_df[ablations]`).
ablations = ['slot-stdim_scn' ,'slot-stdim_scn_sdl']

# Display names for method columns, applied through DataFrame.rename(columns=...) and plot_bars(map_dict=...).
map_dict={"slot-stdim_hcn": "slot-dim_loss1only", "slot-stdim_scn": "scn_loss1only", "slot-stdim_scn_sdl": "scn", "slot-stdim_hcn_smdl":"slot-stdim"}

# Fetches the runs (default tags="eval") and builds the annotated, count, mean and standard-error tables.
final_df, count_df, mean_df,stderr_df = get_df(metric_name=metric_name)
#dcid_df, count_df, dcid_mean_df, dcid_std_df = get_df(metric_name="dci_d_gbt", methods=disent_methods)
#dcic_df, count_df, dcic_mean_df, dcic_std_df = get_df(metric_name="dci_c_gbt",methods=disent_methods)


In [4]:
# Remove the Qbert and Hero rows. `drop` raises KeyError for a missing label,
# so re-running this cell on the already-filtered frame fails.
mean_df = mean_df.drop(["Qbert"]).drop(["Hero"])

In [5]:
# Keep only the `acc_methods` columns and rename them to their display names.
main_df = mean_df[acc_methods].rename(columns=map_dict)

In [10]:
# Print the main table as LaTeX with the best value per row in bold
# (the last column is excluded from the comparison by add_bold_max's default).
print(add_bold_max(main_df.to_latex()))

\begin{tabular}{lrrrrr}
\toprule
{} &  random-cnn &    scn &   cswm &  stdim &  supervised \\
\midrule
Asteroids          &   0.07  &   0.06  &    NaN  &  \textbf{ 0.14}   &   0.32\\
Berzerk            &   0.30  &  \textbf{ 0.42}   &   0.40  &  \textbf{ 0.42}   &   0.71\\
Bowling            &   0.44  &   0.86  &   0.95  &  \textbf{ 0.98}   &   1.00\\
Boxing             &   0.60  &   0.81  &   0.41  &  \textbf{ 0.94}   &   1.00\\
Breakout           &   0.30  &   0.47  &  \textbf{ 0.59}   &   0.57  &   0.64\\
Demonattack        &   0.05  &  \textbf{ 0.22}   &   0.06  &   0.18  &   0.67\\
Freeway            &   0.78  &  \textbf{ 0.90}   &   0.69  &   0.82  &   0.99\\
Frostbite          &  \textbf{ 0.79}   &  \textbf{ 0.78}   &   0.74  &   0.78  &   0.97\\
Montezumarevenge   &   0.66  &   0.85  &   0.84  &  \textbf{ 0.86}   &   0.99\\
Mspacman           &   0.11  &   0.07  &  -0.13  &  \textbf{ 0.23}   &   0.80\\
Pitfall            &   0.37  &   0.34  &   0.36  &  \textbf{ 0.50}   &   0.85

In [35]:
# Keep only the ablation columns and rename them to their display names.
abl_df = mean_df[ablations].rename(columns=map_dict)

In [38]:
# Print the ablation table as plain LaTeX (no bolding).
print(abl_df.to_latex())

\begin{tabular}{lrr}
\toprule
{} &  scn\_loss1only &    scn \\
\midrule
Asteroids        &          0.038 &  0.057 \\
Berzerk          &          0.318 &  0.420 \\
Bowling          &          0.867 &  0.860 \\
Boxing           &          0.763 &  0.810 \\
Breakout         &          0.464 &  0.474 \\
Demonattack      &          0.159 &  0.221 \\
Freeway          &          0.688 &  0.902 \\
Frostbite        &          0.769 &  0.777 \\
Montezumarevenge &          0.838 &  0.848 \\
Mspacman         &          0.110 &  0.071 \\
Pitfall          &          0.415 &  0.335 \\
Pong             &          0.744 &  0.780 \\
Privateeye       &          0.600 &  0.585 \\
Riverraid        &          0.305 &  0.355 \\
Seaquest         &          0.473 &  0.492 \\
Spaceinvaders    &          0.504 &  0.514 \\
Tennis           &          0.491 &  0.515 \\
Venture          &          0.165 &  0.185 \\
Videopinball     &          0.415 &  0.387 \\
Yarsrevenge      &          0.105 &  0.139 \\
Overall 

In [8]:
# NOTE(review): `dcid_mean_df` and `main_methods` are not assigned in any cell shown here
# (the `dcid_mean_df` assignment in the In [3] cell is commented out), so this cell raises
# NameError on a fresh top-to-bottom run — confirm where they are meant to come from.
slot_mod = dcid_mean_df[main_methods]
print(slot_mod.to_latex())

\begin{tabular}{lrrrr}
\toprule
{} &  random-cnn &  slot-stdim\_scn\_sdl &   cswm &  supervised \\
\midrule
Asteroids        &       0.307 &               0.317 &    NaN &         NaN \\
Boxing           &       0.272 &               0.290 &  0.029 &       0.952 \\
Breakout         &       0.474 &                 NaN &    NaN &       0.767 \\
Demonattack      &       0.237 &               0.295 &  0.096 &         NaN \\
Freeway          &         NaN &               0.545 &    NaN &       0.491 \\
Frostbite        &         NaN &                 NaN &  0.260 &         NaN \\
Montezumarevenge &       0.593 &                 NaN &    NaN &         NaN \\
Mspacman         &       0.230 &               0.213 &    NaN &       0.659 \\
Pitfall          &       0.444 &                 NaN &  0.499 &         NaN \\
Pong             &       0.384 &                 NaN &    NaN &         NaN \\
Privateeye       &       0.532 &                 NaN &  0.444 &         NaN \\
Qbert            &     

In [9]:
# NOTE(review): `dcic_mean_df` and `main_methods` are not assigned in any cell shown here
# (the `dcic_mean_df` assignment in the In [3] cell is commented out), so this cell raises
# NameError on a fresh top-to-bottom run — confirm where they are meant to come from.
slot_mod = dcic_mean_df[main_methods]
print(slot_mod.to_latex())

\begin{tabular}{lrrrr}
\toprule
{} &  random-cnn &  slot-stdim\_scn\_sdl &   cswm &  supervised \\
\midrule
Asteroids        &       0.213 &               0.241 &    NaN &         NaN \\
Boxing           &       0.180 &               0.224 &  0.091 &       0.557 \\
Breakout         &       0.150 &                 NaN &    NaN &       0.431 \\
Demonattack      &       0.117 &               0.162 &  0.105 &         NaN \\
Freeway          &         NaN &               0.363 &    NaN &       0.609 \\
Frostbite        &         NaN &                 NaN &  0.341 &         NaN \\
Montezumarevenge &       0.274 &                 NaN &    NaN &         NaN \\
Mspacman         &       0.118 &               0.111 &    NaN &       0.422 \\
Pitfall          &       0.342 &                 NaN &  0.395 &         NaN \\
Pong             &       0.169 &                 NaN &    NaN &         NaN \\
Privateeye       &       0.308 &                 NaN &  0.260 &         NaN \\
Qbert            &     

In [9]:
# NOTE(review): `games`, `main_methods`, `ablations1`, `ablations2`, `dcid_mean_df` and
# `dcic_mean_df` are not assigned in any cell shown here — confirm they are defined before
# this cell is run, otherwise it raises NameError.
# Accuracy plot for the main methods, restricted to the rows listed in `games`.
plot_bars(mean_df.loc[games][main_methods], da_title="Slot Accuracy Scores (R^2 Score)", map_dict=map_dict)
#plot_bars(dcid_mean_df.loc[games][main_methods], da_title="Slot Modularity Scores",map_dict=map_dict)
#plot_bars(dcic_mean_df.loc[games][main_methods], da_title="Slot Compactness Scores",map_dict=map_dict)

# Same three scores for the first ablation set.
plot_bars(mean_df.loc[games][ablations1], da_title="Slot Accuracy Scores (R^2 Score)",map_dict=map_dict)
plot_bars(dcid_mean_df.loc[games][ablations1], da_title="Slot Modularity Scores",map_dict=map_dict)
plot_bars(dcic_mean_df.loc[games][ablations1], da_title="Slot Compactness Scores",map_dict=map_dict)

# Mutates `games` in place, so the plots below exclude Breakout; `remove` raises ValueError
# if "Breakout" is already gone (e.g. when this cell is re-run).
games.remove("Breakout")

# Same three scores for the second ablation set.
plot_bars(mean_df.loc[games][ablations2], da_title="Slot Accuracy Scores (R^2 Score)",map_dict=map_dict)
plot_bars(dcid_mean_df.loc[games][ablations2], da_title="Slot Modularity Scores",map_dict=map_dict)
plot_bars(dcic_mean_df.loc[games][ablations2], da_title="Slot Compactness Scores",map_dict=map_dict)

