In [95]:
import sys
#sys.path.append("/Users/evanracah/Dropbox/projects/atari-representation-learning/")
import wandb
import pandas as pd
import numpy as np
from copy import deepcopy
import json
from matplotlib import pyplot as plt

# W&B project path passed to wandb.Api().runs(...) by get_table_for and the supervised-runs cell.
wandb_proj = "eracah/coors-production"

In [23]:
import tabulate

In [24]:
def add_bold_max(latex_str, ignore_last_column=True, ignore_first_column=False, margin=0.01):
    r"""Re-format a LaTeX table and bold the (near-)maximum of every data row.

    The source is split on the LaTeX row terminator ``\\``. Any row containing
    a "." is treated as a data row:

    * the first ``&``-separated cell is the row label and is kept verbatim;
    * every other cell containing "." or "NaN" that parses as a float is
      re-formatted with ``%5.2f`` (NaN is written as ``NaN``);
    * any other cell (integers, text) is kept verbatim, so columns stay aligned.

    Among the candidate numeric, non-NaN cells, every value ``v`` with
    ``v + margin >= row_max`` is wrapped in ``\textbf{...}``. Rows with no
    candidate are re-formatted but nothing is bolded.

    Args:
        latex_str: LaTeX table source, e.g. from ``DataFrame.to_latex()``.
        ignore_last_column: if True, the last value cell is never bolded
            (it is still re-formatted).
        ignore_first_column: if True, the first value cell is never bolded.
        margin: tolerance for near-ties; values within ``margin`` of the row
            maximum are bolded as well.

    Returns:
        The re-formatted LaTeX source.
    """
    def _parse(cell):
        # Same numeric test as before ("." or "NaN" in the cell); a cell that
        # still fails to parse (e.g. a header like "v1.0") is kept as text.
        if "." not in cell and "NaN" not in cell:
            return None
        try:
            return float(cell)
        except ValueError:
            return None

    def _fmt(value):
        # Emit "NaN" directly, so no global "nan" -> "NaN" rewrite is needed
        # (that rewrite also mangled labels that happen to contain "nan").
        return "%5s" % "NaN" if np.isnan(value) else "%5.2f" % value

    out_rows = []
    for row in latex_str.split("\\\\"):
        if "." not in row:
            out_rows.append(row)
            continue

        cells = row.split("&")
        label, raw_cells = cells[0], cells[1:]
        values = [_parse(c) for c in raw_cells]

        first = 1 if ignore_first_column else 0
        stop = len(values) - 1 if ignore_last_column else len(values)
        # NaN cells are never candidates; an all-NaN or empty row bolds nothing.
        candidates = [i for i in range(first, stop)
                      if values[i] is not None and not np.isnan(values[i])]
        if candidates:
            best = max(values[i] for i in candidates)
            bold = {i for i in candidates if values[i] + margin >= best}
        else:
            bold = set()

        out_cells = [label]
        for i, (raw, value) in enumerate(zip(raw_cells, values)):
            if value is None:
                out_cells.append(raw)
            elif i in bold:
                out_cells.append("\\textbf{%s} " % _fmt(value))
            else:
                out_cells.append(_fmt(value))
        out_rows.append("  &  ".join(out_cells))

    return "\\\\".join(out_rows)

def print_latex(ldf, ignore_last_column=True, ignore_first_column=False, margin=0.01):
    """Print ``ldf`` as LaTeX with each row's (near-)maximum in bold.

    Renders the frame with ``DataFrame.to_latex()`` and post-processes it with
    ``add_bold_max``. The flags and ``margin`` are forwarded unchanged;
    ``margin`` defaults to the same value ``add_bold_max`` uses.
    """
    latex_src = ldf.to_latex()
    print(add_bold_max(latex_src,
                       ignore_last_column=ignore_last_column,
                       ignore_first_column=ignore_first_column,
                       margin=margin))

In [25]:
def get_table_for(metric_name, methods=("stdim", "scn")):
    """Build a per-game table of one W&B summary metric, one column per method.

    Fetches all runs of ``wandb_proj`` and keeps those whose
    ``summary_metrics`` contain ``metric_name``. For each method, it averages
    the metric over runs of the same game via ``avg_runs``. It then appends an
    ``Overall`` row holding each column's mean over games.

    Args:
        metric_name: key looked up in each run's ``summary_metrics``.
        methods: values of ``run.config["method"]`` to tabulate, in column
            order. Any iterable of strings works; a tuple default avoids a
            shared mutable default.

    Returns:
        DataFrame indexed by game name, with one column per method. All values
        are rounded to 2 decimals, and the last row is ``"Overall"``.
    """
    api = wandb.Api()
    # Filter once; summaries and configs below are read from this same list.
    runs = [run for run in api.runs(wandb_proj) if metric_name in run.summary_metrics]

    # Pull out only the fields we need instead of every summary key.
    df = pd.DataFrame({
        metric_name: [run.summary_metrics[metric_name] for run in runs],
        "method": [run.config["method"] for run in runs],
        # Game name = env_name up to "NoFrameskip", lower-cased then capitalised.
        "env_name": [run.config["env_name"].split("NoFrameskip")[0].lower().capitalize()
                     for run in runs],
    })

    per_method = []
    for method in methods:
        scores = (df.loc[df["method"] == method]
                    .set_index("env_name")[metric_name]
                    .rename(method))
        per_method.append(avg_runs(scores, method).set_index("env_name"))
    final_df = pd.concat(per_method, axis=1, sort=True)

    # Across-game mean per method, computed before the row is added.
    final_df.loc["Overall"] = final_df.mean()
    return final_df.round(2)

In [89]:
def avg_runs(df, method, with_stats=False):
    """Average repeated runs per game.

    Args:
        df: Series (or DataFrame) keyed by game name. ``df[game]`` yields the
            value(s) recorded for that game: a scalar for a single run, or a
            Series when the key appears several times (duplicate index).
        method: name of the output column that holds the mean.
        with_stats: if True, also add ``stderr`` (sample std / sqrt(n), NaN
            for a single run) and ``nruns`` columns.

    Returns:
        DataFrame with one row per distinct key of ``df``, in first-seen order.
        Columns are ``["env_name", method]``, plus ``["stderr", "nruns"]``
        when ``with_stats``. The columns exist even for empty input, so
        callers can always ``set_index("env_name")``.
    """
    columns = ["env_name", method] + (["stderr", "nruns"] if with_stats else [])
    rows = []
    # dict.fromkeys de-duplicates while keeping first-seen order; iterating a
    # set of strings gives an order that varies between interpreter runs.
    for game in dict.fromkeys(df.keys()):
        samples = pd.Series(df[game])  # a lone scalar becomes a 1-element Series
        row = {"env_name": game, method: samples.mean()}
        if with_stats:
            count = len(samples)
            # Standard error is std / sqrt(n) (the former dead code used var).
            row["stderr"] = samples.std() / np.sqrt(count)
            row["nruns"] = count
        rows.append(row)
    return pd.DataFrame(rows, columns=columns)

    

In [90]:
df = get_table_for("concat_overall_avg_f1")

# Headers come from the frame itself so they follow whatever methods were tabulated.
print(tabulate.tabulate(df, headers=["env_name", *df.columns], tablefmt="pipe"))

| env_name         |   stdim |   scn |
|:-----------------|--------:|------:|
| Asteroids        |    0.48 |  0.48 |
| Berzerk          |    0.62 |  0.58 |
| Bowling          |    0.84 |  0.77 |
| Boxing           |    0.77 |  0.24 |
| Breakout         |    0.9  |  0.71 |
| Demonattack      |    0.44 |  0.28 |
| Freeway          |    0.98 |  0.68 |
| Frostbite        |    0.78 |  0.68 |
| Hero             |    0.95 |  0.87 |
| Montezumarevenge |    0.76 |  0.65 |
| Mspacman         |    0.67 |  0.45 |
| Pitfall          |    0.8  |  0.6  |
| Pong             |    0.87 |  0.74 |
| Privateeye       |    0.88 |  0.77 |
| Qbert            |    0.57 |  0.41 |
| Riverraid        |    0.43 |  0.29 |
| Seaquest         |    0.66 |  0.56 |
| Spaceinvaders    |    0.69 |  0.46 |
| Tennis           |    0.65 |  0.48 |
| Venture          |    0.44 |  0.32 |
| Videopinball     |    0.59 |  0.45 |
| Yarsrevenge      |    0.32 |  0.13 |
| Overall          |    0.68 |  0.53 |


In [91]:
ac_df = get_table_for("concat_across_categories_avg_f1")
# Headers come from the frame itself so they follow whatever methods were tabulated.
print(tabulate.tabulate(ac_df, headers=["env_name", *ac_df.columns], tablefmt="pipe"))

| env_name         |   stdim |   scn |
|:-----------------|--------:|------:|
| Asteroids        |    0.49 |  0.49 |
| Berzerk          |    0.59 |  0.56 |
| Bowling          |    0.84 |  0.76 |
| Boxing           |    0.73 |  0.22 |
| Breakout         |    0.92 |  0.74 |
| Demonattack      |    0.73 |  0.46 |
| Freeway          |    0.92 |  0.52 |
| Frostbite        |    0.71 |  0.65 |
| Hero             |    0.95 |  0.86 |
| Montezumarevenge |    0.83 |  0.72 |
| Mspacman         |    0.71 |  0.48 |
| Pitfall          |    0.81 |  0.59 |
| Pong             |    0.84 |  0.67 |
| Privateeye       |    0.87 |  0.75 |
| Qbert            |    0.64 |  0.52 |
| Riverraid        |    0.45 |  0.3  |
| Seaquest         |    0.68 |  0.6  |
| Spaceinvaders    |    0.65 |  0.42 |
| Tennis           |    0.69 |  0.55 |
| Venture          |    0.55 |  0.46 |
| Videopinball     |    0.59 |  0.45 |
| Yarsrevenge      |    0.46 |  0.13 |
| Overall          |    0.71 |  0.54 |


In [92]:
# Reference AtariARI results. The game names are capitalised and used as the
# index so the concat with ac_df below can align rows by game.
stdf = pd.read_csv("~/Dropbox/atariari.csv")
games = [game_name.capitalize() for game_name in stdf["Game"]]
stdf["Game"] = games
baseline_cols = ["random-cnn", "vae", "pixel-pred", "cpc", "st-dim"]
stdf = stdf.set_index("Game")[baseline_cols]
del baseline_cols

In [93]:
# Side-by-side: this notebook's per-game scores next to the AtariARI baselines.
both_df = pd.concat(objs=[ac_df, stdf], axis="columns")

In [94]:
# First header labels the index column; the rest are the frame's own columns.
print(tabulate.tabulate(both_df, headers=["env_name"] + both_df.columns.tolist(), tablefmt="pipe"))

| env_name         |   stdim |   scn |   random-cnn |   vae |   pixel-pred |   cpc |   st-dim |
|:-----------------|--------:|------:|-------------:|------:|-------------:|------:|---------:|
| Asteroids        |    0.49 |  0.49 |         0.34 |  0.36 |         0.34 |  0.42 |     0.49 |
| Berzerk          |    0.59 |  0.56 |         0.43 |  0.45 |         0.55 |  0.56 |     0.53 |
| Bowling          |    0.84 |  0.76 |         0.48 |  0.5  |         0.81 |  0.9  |     0.96 |
| Boxing           |    0.73 |  0.22 |         0.19 |  0.2  |         0.44 |  0.29 |     0.58 |
| Breakout         |    0.92 |  0.74 |         0.51 |  0.57 |         0.7  |  0.74 |     0.88 |
| Demonattack      |    0.73 |  0.46 |         0.26 |  0.26 |         0.32 |  0.57 |     0.69 |
| Freeway          |    0.92 |  0.52 |         0.5  |  0.01 |         0.81 |  0.47 |     0.81 |
| Frostbite        |    0.71 |  0.65 |         0.57 |  0.51 |         0.72 |  0.76 |     0.75 |
| Hero             |    0.95 |  0.86 |  

In [96]:
api = wandb.Api()

# Runs in the project whose config.method is "supervised".
sup_runs = api.runs(wandb_proj, filters={"config.method": "supervised"})

# Space-separated run ids, shown as the cell's output.
" ".join(run.id for run in sup_runs)

