In [1]:
import sys
#sys.path.append("/Users/evanracah/Dropbox/projects/atari-representation-learning/")
import wandb
import pandas as pd
import numpy as np
from copy import deepcopy
import json
from matplotlib import pyplot as plt
import tabulate
# Weights & Biases project path whose runs are listed via wandb.Api().runs(...) in grab_relevant_wandb_runs.
wandb_proj = "eracah/coors-production"


In [6]:
def grab_relevant_wandb_runs(project=None, filters=None):
    """
    Fetch all runs from a W&B project that match a query filter.

    No filtering by method or by metric happens here. Every run matching
    ``filters`` is returned, whichever method produced it and whichever
    metrics it logged.

    Args:
        project: W&B project path passed to ``wandb.Api().runs``. Defaults to
            the module-level ``wandb_proj``.
        filters: query dict passed to ``wandb.Api().runs``. Defaults to
            ``{"state": "finished", "tags": "eval"}``, i.e. finished runs
            tagged "eval".

    Returns:
        list of wandb run objects. It is materialized into a list so callers
        can iterate it more than once.
    """
    if project is None:
        project = wandb_proj
    if filters is None:
        # Build a fresh dict per call rather than sharing a mutable default.
        filters = {"state": "finished", "tags": "eval"}
    api = wandb.Api()
    return list(api.runs(project, filters=filters))


def _env_label(run):
    """Table label for a run's env, e.g. "MsPacmanNoFrameskip-v4" -> "Mspacman"."""
    return run.config["env_name"].split("NoFrameskip")[0].lower().capitalize()


def _method_label(run):
    """Run's method name, suffixed with "_"-joined ablations when any are listed."""
    # A missing, empty, or None "ablations" entry all mean "no ablations".
    ablations = run.config.get("ablations") or []
    if len(ablations) > 0:
        return run.config["method"] + "_" + "_".join(ablations)
    return run.config["method"]


def make_dataframe_from_runs(runs_list, metric_name):
    """
    Build a long-format frame with one row per run: the metric value, method label and env label.

    Args:
        runs_list: iterable of wandb run objects. Each needs ``summary_metrics``
            (dict) and ``config`` (dict with "env_name", "method" and
            optionally "ablations").
        metric_name: summary-metric key to extract.

    Returns:
        DataFrame with columns ``[metric_name, "method", "env_name"]``.
        Runs that did not log the metric get NaN. An empty input gives an
        empty frame with those columns.

    Raises:
        KeyError: if runs were given but none of them logged ``metric_name``.
    """
    columns = [metric_name, "method", "env_name"]
    runs_list = list(runs_list)  # accept any iterable; it is walked twice below
    if not runs_list:
        return pd.DataFrame(columns=columns)
    if not any(metric_name in run.summary_metrics for run in runs_list):
        raise KeyError(
            f"metric {metric_name!r} not found in the summary of any of the {len(runs_list)} runs"
        )
    records = [
        {
            metric_name: run.summary_metrics.get(metric_name, np.nan),
            "method": _method_label(run),
            "env_name": _env_label(run),
        }
        for run in runs_list
    ]
    return pd.DataFrame(records, columns=columns)


def make_mean_df(df, metric_name):
    """
    Average ``metric_name`` over repeated runs and pivot to an env x method table.

    Args:
        df: long-format frame with at least "env_name", "method" and
            ``metric_name`` columns, one row per run.
        metric_name: column to average.

    Returns:
        DataFrame indexed by env_name with one column per method, values
        rounded to 2 decimals. (env, method) pairs with no runs are NaN.
    """
    # Aggregate only the metric column. Any other (possibly non-numeric)
    # column in df would make a whole-frame .mean() raise on pandas >= 2.0.
    per_group = df.groupby(["env_name", "method"])[metric_name].mean().reset_index()
    mdf = per_group.pivot(index="env_name", columns="method", values=metric_name)
    # Drop the axis names that pivot leaves behind so the table prints cleanly.
    mdf.columns.name = None
    mdf.index.name = None
    return mdf.round(2)

def make_std_df(df, metric_name):
    """
    Sample std of ``metric_name`` over repeated runs, pivoted to an env x method table.

    Args:
        df: long-format frame with at least "env_name", "method" and
            ``metric_name`` columns, one row per run.
        metric_name: column whose spread is computed.

    Returns:
        DataFrame indexed by env_name with one column per method, values
        rounded to 2 decimals. Uses pandas' default ddof=1, so a pair backed
        by a single run gets NaN.
    """
    # Aggregate only the metric column. Extra non-numeric columns would make
    # a whole-frame .std() raise on pandas >= 2.0.
    per_group = df.groupby(["env_name", "method"])[metric_name].std().reset_index()
    sdf = per_group.pivot(index="env_name", columns="method", values=metric_name)
    sdf.columns.name = None
    sdf.index.name = None
    return sdf.round(2)
  
def add_error_bars(mean_df, std_df):
    """
    Combine mean and std tables into a table of "mean+-std" strings.

    Cells whose std is missing (e.g. only one run) show just the mean.

    Args:
        mean_df: table of means (index = env, columns = methods).
        std_df: table of standard deviations with the same labels as mean_df.

    Returns:
        A new DataFrame of strings with mean_df's index and columns.
        Neither input is modified.
    """
    # Pair values by (row, column) label, not by position, so a std table in
    # a different row order cannot be silently mismatched with the means.
    std_aligned = std_df.reindex(index=mean_df.index, columns=mean_df.columns)

    def _format_cell(mean_value, std_value):
        if pd.isna(std_value):
            return str(mean_value)
        return str(mean_value) + "+-" + str(std_value)

    return pd.DataFrame(
        {
            method: [
                _format_cell(mn, sd)
                for mn, sd in zip(mean_df[method], std_aligned[method])
            ]
            for method in mean_df.columns
        },
        index=mean_df.index,
    )

def get_df(metric_name, runs=None):
    """
    Build an env x method table of "mean+-std" strings for one summary metric.

    Args:
        metric_name: key in each run's summary metrics to tabulate.
        runs: optional pre-fetched list of wandb runs. When None (the
            default), runs are fetched with ``grab_relevant_wandb_runs()``.
            Passing runs lets several metrics be tabulated from one API query.

    Returns:
        DataFrame indexed by environment name with one column per method.
    """
    if runs is None:
        runs = grab_relevant_wandb_runs()
    df = make_dataframe_from_runs(runs, metric_name=metric_name)
    mean_df = make_mean_df(df, metric_name)
    std_df = make_std_df(df, metric_name)
    return add_error_bars(mean_df, std_df)

In [7]:
# Tabulate one metric for two methods on a fixed subset of games, printed as a pipe (Markdown) table.
final_df = get_df(metric_name="concat_overall_localization_avg_r2")

envs_to_show = ["Asteroids", "Freeway", "Breakout", "Mspacman"]
methods_to_show = ["stdim", "slot-stdim"]
# One .loc call selects rows and columns together, avoiding chained indexing.
fdf = final_df.loc[envs_to_show, methods_to_show]
print(tabulate.tabulate(fdf, headers=["env_name", *fdf.columns], tablefmt="pipe"))

| env_name   |   stdim |   slot-stdim |
|:-----------|--------:|-------------:|
| Asteroids  |    0.23 |         0.18 |
| Freeway    |    0.77 |         0.74 |
| Breakout   |    0.58 |         0.48 |
| Mspacman   |    0.21 |         0.14 |
