In [1]:
import sys
#sys.path.append("/Users/evanracah/Dropbox/projects/atari-representation-learning/")
import wandb
import pandas as pd
import numpy as np
from copy import deepcopy
import json
from matplotlib import pyplot as plt
import tabulate
# W&B "entity/project" path that every api.runs(...) call in this notebook queries.
wandb_proj = "eracah/coors-production"


In [2]:
api = wandb.Api()

# Finished "stdim" runs tagged "train" whose `ablations` config is
# ["normalize", "structure-loss"].
runs = api.runs(
    wandb_proj,
    filters={
        "state": "finished",
        "tags": "train",
        "config.method": "stdim",
        "config.ablations": ["normalize", "structure-loss"],
    },
)

# Run ids, in the reverse of the order the API yielded them.
ids = [run.id for run in reversed(list(runs))]

In [3]:
def _env_label(config):
    """Return the short game label for a run config.

    Everything before ``"NoFrameskip"`` in ``config["env_name"]``, first letter
    upper-cased and the rest lower-cased, e.g. ``"MsPacmanNoFrameskip-v4"`` ->
    ``"Mspacman"``.
    """
    return config["env_name"].split("NoFrameskip")[0].capitalize()


def _method_label(config):
    """Return ``config["method"]`` joined with its ablations by ``"_"``.

    e.g. method ``"stdim"`` with ablations ``["structure-loss", "normalize"]``
    -> ``"stdim_structure-loss_normalize"``. A missing, empty or ``None``
    ``ablations`` entry yields just the method name.
    """
    ablations = config.get("ablations") or []
    return "_".join([config["method"], *ablations])


def get_table_for(metric_name, methods=("stdim", "scn")):
    """Build a per-game table of one W&B summary metric, one column per method.

    Fetches every finished run in ``wandb_proj`` and keeps those whose summary
    contains ``metric_name``. It then averages repeated runs of the same
    (game, method) pair and appends an ``"Overall"`` row.

    Parameters
    ----------
    metric_name : str
        Key in ``run.summary_metrics`` to tabulate.
    methods : sequence of str
        Method labels to use as columns, in order. A label is the run's
        ``config["method"]`` joined with its ``config["ablations"]`` by ``"_"``
        (e.g. ``"stdim_structure-loss_normalize"``). A label with no matching
        runs becomes an all-NaN column.

    Returns
    -------
    pd.DataFrame
        The index holds the game names (sorted) followed by ``"Overall"``. The
        columns are ``methods``. Values are rounded to 2 decimals, and missing
        (game, method) pairs are NaN. NOTE: ``"Overall"`` is the NaN-skipping mean
        over the games each method actually has. Methods covering different game
        sets are therefore not directly comparable on that row.

    Raises
    ------
    ValueError
        If no finished run reports ``metric_name``.
    """
    methods = list(methods)
    api = wandb.Api()
    # Only runs that actually logged this metric contribute to the table.
    runs = [
        run
        for run in api.runs(wandb_proj, filters={"state": "finished"})
        if metric_name in run.summary_metrics
    ]
    if not runs:
        raise ValueError(f"No finished runs in {wandb_proj!r} report {metric_name!r}")

    # Long format: one row per run.
    long_df = pd.DataFrame(
        {
            metric_name: [run.summary_metrics[metric_name] for run in runs],
            "method": [_method_label(run.config) for run in runs],
            "env_name": [_env_label(run.config) for run in runs],
        }
    )

    # Average repeated runs per game, separately for each requested method.
    per_method = {
        method: long_df.loc[long_df["method"] == method]
        .groupby("env_name")[metric_name]
        .mean()
        for method in methods
    }
    summary = pd.DataFrame(per_method).sort_index()

    # Overall is computed from the unrounded per-game means; rounding happens last.
    summary.loc["Overall"] = summary.mean()
    return summary.round(2)

In [4]:
def avg_runs(df, method):
    """Average a metric over repeated runs of each game.

    Parameters
    ----------
    df : pd.Series
        Metric values for one method, indexed by game (``env_name``), with one
        entry per run, so a game may repeat. Because only ``df.keys()`` and
        ``df[game]`` are used, a DataFrame with one column per game also works.
    method : str
        Name of the output column that holds the averaged metric.

    Returns
    -------
    pd.DataFrame
        One row per game, sorted by game, with columns ``["env_name", method]``.
        The columns are present even when ``df`` is empty.
    """
    rows = []
    # sorted() gives a deterministic row order; iterating a bare set does not.
    for game in sorted(set(df.keys())):
        # A repeated game yields a Series and a single run a scalar; both have .mean().
        rows.append({"env_name": game, method: df[game].mean()})
    # Explicit columns keep "env_name" available for set_index() when there are no runs.
    return pd.DataFrame(rows, columns=["env_name", method])

    

In [14]:
# Method labels (config method + "_" + ablations) compared side by side below.
methods = [
    "stdim_structure-loss",
    "stdim_structure-loss_normalize",
]

In [15]:
# Per-game averaged `concat_overall_avg_r2`, one column per entry of `methods`.
df = get_table_for("concat_overall_avg_r2", methods=methods)
print(
    tabulate.tabulate(
        df,
        headers=["env_name"] + list(methods),
        tablefmt="pipe",
    )
)

| env_name   |   stdim_structure-loss |   stdim_structure-loss_normalize |
|:-----------|-----------------------:|---------------------------------:|
| Breakout   |                   0.75 |                             0.75 |
| Freeway    |                 nan    |                             0.89 |
| Mspacman   |                   0.39 |                             0.37 |
| Overall    |                   0.57 |                             0.67 |


In [11]:
# Per-game averaged `concat_overall_localization_avg_r2`, one column per method.
ldf = get_table_for("concat_overall_localization_avg_r2", methods=methods)
print(
    tabulate.tabulate(
        ldf,
        headers=["env_name"] + list(methods),
        tablefmt="pipe",
    )
)

| env_name   |   stdim_structure-loss |   stdim_structure-loss_normalize |
|:-----------|-----------------------:|---------------------------------:|
| Breakout   |                   0.6  |                             0.62 |
| Freeway    |                 nan    |                             0.89 |
| Mspacman   |                   0.24 |                             0.23 |
| Overall    |                   0.42 |                             0.58 |


In [14]:
# Per-game averaged `concat_overall_localization_avg_f1`, one column per method.
ldf = get_table_for("concat_overall_localization_avg_f1", methods=methods)
print(
    tabulate.tabulate(
        ldf,
        headers=["env_name"] + list(methods),
        tablefmt="pipe",
    )
)

| env_name      |   stdim_structure-loss |   stdim_structure-loss_normalize |
|:--------------|-----------------------:|---------------------------------:|
| Boxing        |                   0.58 |                           nan    |
| Breakout      |                   0.79 |                             0.79 |
| Freeway       |                 nan    |                             0.98 |
| Mspacman      |                   0.6  |                             0.66 |
| Pong          |                   0.73 |                           nan    |
| Seaquest      |                   0.65 |                           nan    |
| Spaceinvaders |                   0.72 |                           nan    |
| Overall       |                   0.68 |                             0.81 |


In [None]:
| env_name      |   stdim |   scn |   scn_loss1-only |   scn_loss1-only_hinge-loss |
|:--------------|--------:|------:|-----------------:|----------------------------:|
| Boxing        |    0.73 |  0.22 |             0.21 |                        0.12 |
| Breakout      |    0.92 |  0.74 |             0.78 |                        0.5  |
| Freeway       |    0.92 |  0.52 |             0.65 |                        0.15 |
| Mspacman      |    0.71 |  0.48 |             0.53 |                        0.48 |
| Pong          |    0.84 |  0.67 |             0.6  |                        0.33 |
| Seaquest      |    0.68 |  0.6  |             0.65 |                        0.62 |
| Spaceinvaders |    0.65 |  0.42 |             0.43 |                        0.4  |
| Overall       |    0.71 |  0.54 |             0.55 |                        0.37 |

In [34]:
# Baseline results from a local CSV with a "Game" column plus one column per method.
stdf = pd.read_csv("~/Dropbox/atariari.csv")

# Capitalize game names the same way get_table_for labels env_name
# (presumably so the two tables align on the index when concatenated).
games = list(map(str.capitalize, stdf["Game"]))
stdf["Game"] = games

stdf = stdf.set_index("Game").loc[:, ["random-cnn", "vae", "pixel-pred", "cpc", "st-dim"]]

In [35]:
# NOTE(review): `ac_df` is not defined in any cell of this notebook, so this line
# raises NameError after Restart & Run All. The printed table has columns
# supervised/stdim/scn, so it was presumably built by something like
# get_table_for(<metric>, methods=["supervised", "stdim", "scn"]) in a deleted
# cell. TODO: restore that cell above.
both_df = pd.concat((ac_df, stdf), axis="columns")

In [36]:
# Markdown (pipe) table of our methods next to the CSV baselines.
print(
    tabulate.tabulate(
        both_df,
        headers=["env_name"] + list(both_df.columns),
        tablefmt="pipe",
    )
)

| env_name         |   supervised |   stdim |   scn |   random-cnn |   vae |   pixel-pred |   cpc |   st-dim |
|:-----------------|-------------:|--------:|------:|-------------:|------:|-------------:|------:|---------:|
| Asteroids        |         0.51 |    0.49 |  0.49 |         0.34 |  0.36 |         0.34 |  0.42 |     0.49 |
| Berzerk          |         0.64 |    0.59 |  0.56 |         0.43 |  0.45 |         0.55 |  0.56 |     0.53 |
| Bowling          |         0.92 |    0.84 |  0.76 |         0.48 |  0.5  |         0.81 |  0.9  |     0.96 |
| Boxing           |         0.79 |    0.73 |  0.22 |         0.19 |  0.2  |         0.44 |  0.29 |     0.58 |
| Breakout         |         0.93 |    0.92 |  0.74 |         0.51 |  0.57 |         0.7  |  0.74 |     0.88 |
| Demonattack      |         0.81 |    0.73 |  0.46 |         0.26 |  0.26 |         0.32 |  0.57 |     0.69 |
| Freeway          |         0.97 |    0.92 |  0.59 |         0.5  |  0.01 |         0.81 |  0.47 |     0.81 |
|

In [11]:
| env_name         |   supervised |   stdim |   scn |   random-cnn |   vae |   pixel-pred |   cpc |   st-dim |
|:-----------------|-------------:|--------:|------:|-------------:|------:|-------------:|------:|---------:|
| Asteroids        |         0.51 |    0.49 |  0.49 |         0.34 |  0.36 |         0.34 |  0.42 |     0.49 |
| Berzerk          |         0.64 |    0.59 |  0.56 |         0.43 |  0.45 |         0.55 |  0.56 |     0.53 |
| Bowling          |         0.92 |    0.84 |  0.76 |         0.48 |  0.5  |         0.81 |  0.9  |     0.96 |
| Boxing           |         0.79 |    0.73 |  0.22 |         0.19 |  0.2  |         0.44 |  0.29 |     0.58 |
| Breakout         |         0.93 |    0.92 |  0.74 |         0.51 |  0.57 |         0.7  |  0.74 |     0.88 |
| Demonattack      |         0.81 |    0.73 |  0.46 |         0.26 |  0.26 |         0.32 |  0.57 |     0.69 |
| Freeway          |         0.97 |    0.92 |  0.59 |         0.5  |  0.01 |         0.81 |  0.47 |     0.81 |
| Frostbite        |         0.88 |    0.71 |  0.65 |         0.57 |  0.51 |         0.72 |  0.76 |     0.75 |
| Hero             |         0.98 |    0.95 |  0.86 |         0.75 |  0.69 |         0.74 |  0.9  |     0.93 |
| Montezumarevenge |         0.93 |    0.83 |  0.72 |         0.68 |  0.38 |         0.74 |  0.75 |     0.78 |
| Mspacman         |         0.84 |    0.71 |  0.51 |         0.49 |  0.56 |         0.74 |  0.65 |     0.72 |
| Pitfall          |         0.93 |    0.81 |  0.59 |         0.34 |  0.35 |         0.44 |  0.46 |     0.6  |
| Pong             |         0.88 |    0.84 |  0.64 |         0.17 |  0.09 |         0.7  |  0.71 |     0.81 |
| Privateeye       |         0.99 |    0.87 |  0.75 |         0.7  |  0.71 |         0.83 |  0.81 |     0.91 |
| Qbert            |         0.83 |    0.64 |  0.52 |         0.49 |  0.49 |         0.52 |  0.65 |     0.73 |
| Riverraid        |         0.55 |    0.45 |  0.3  |         0.34 |  0.26 |         0.41 |  0.4  |     0.36 |
| Seaquest         |         0.8  |    0.68 |  0.6  |         0.57 |  0.56 |         0.62 |  0.66 |     0.67 |
| Spaceinvaders    |         0.79 |    0.65 |  0.42 |         0.41 |  0.52 |         0.57 |  0.54 |     0.57 |
| Tennis           |         0.79 |    0.69 |  0.55 |         0.41 |  0.29 |         0.57 |  0.6  |     0.6  |
| Venture          |         0.58 |    0.55 |  0.46 |         0.36 |  0.38 |         0.46 |  0.51 |     0.58 |
| Videopinball     |         0.78 |    0.59 |  0.45 |         0.37 |  0.45 |         0.57 |  0.58 |     0.61 |
| Yarsrevenge      |         0.45 |    0.46 |  0.13 |         0.22 |  0.08 |         0.19 |  0.39 |     0.42 |
| Overall          |         0.8  |    0.71 |  0.54 |         0.44 |  0.4  |         0.58 |  0.61 |     0.68 |

SyntaxError: invalid syntax (<ipython-input-11-ce1b7d381fe0>, line 1)

In [96]:
api = wandb.Api()

# Every run (no state filter) whose config method is "supervised".
sup_runs = api.runs(wandb_proj, filters={"config.method": "supervised"})

# Space-separated run ids; left as the cell's last expression so it is displayed.
" ".join(run.id for run in sup_runs)

