In [1]:
import wandb
import pandas as pd
import numpy as np

In [2]:
from src.atari_ram_annotations import summary_key_dict, atari_dict, unused_keys

In [4]:
def get_dicts_for(method):
    """Summarise the finished wandb runs of one method, per game and overall.

    Queries the "curl-atari/curl-atari-2" project for finished runs whose
    ``config.method`` equals ``method`` and whose ``config.collect_mode`` is
    "random_agent". For "cpc", "vae" and "spatial-appo" the query additionally
    requires ``config.train_encoder == True`` and ``config.feature_size == 256``.

    For every game (env name with "NoFrameskip-v4" stripped, excluding
    "VideoPinball" and "Asteroids") and every category in ``summary_key_dict``,
    the summary is the mean of all summary metrics whose name contains "acc",
    does not contain "stderr", and contains (case-insensitively) one of the
    category's keys.

    Args:
        method: value matched against ``config.method``.

    Returns:
        Tuple ``(overall_summary_dict, run_summary_dict, runs)``:
          * overall_summary_dict: category -> mean of the per-game summaries
            (NaN when no game produced a value for that category).
          * run_summary_dict: game -> {category -> summary}; categories with
            no matching metric are absent.
          * runs: every run returned by the query, including excluded games.
    """
    rem_games = ["VideoPinball", "Asteroids"]

    # Build the query once; only the learned-encoder methods need the two
    # extra constraints.
    filters = {"config.method": method,
               "config.collect_mode": "random_agent"}
    if method in ["cpc", "vae", "spatial-appo"]:
        filters["config.train_encoder"] = True
        filters["config.feature_size"] = 256
    filters["state"] = "finished"

    api = wandb.Api()
    runs = list(api.runs("curl-atari/curl-atari-2", filters))

    # If several runs share a game, the last one returned wins.
    run_dict = {run.config["env_name"].replace("NoFrameskip-v4", ""): run.summary_metrics
                for run in runs}
    run_dict = {game: metrics for game, metrics in run_dict.items() if game not in rem_games}

    run_summary_dict = {game: {} for game in run_dict.keys()}
    overall_summary_dict = {name: [] for name in summary_key_dict.keys()}

    # Lower-case the category keys once instead of once per metric name.
    lowered_keys = {name: [sum_key.lower() for sum_key in key_list]
                    for name, key_list in summary_key_dict.items()}

    for run_name, run in run_dict.items():
        for summary_name, key_list in lowered_keys.items():
            run_mean_values = {k.lower(): v for k, v in run.items()
                               if "acc" in k and "stderr" not in k
                               and any(sum_key in k.lower() for sum_key in key_list)}
            if len(run_mean_values) > 0:
                run_summary = np.mean(list(run_mean_values.values()))
                overall_summary_dict[summary_name].append(run_summary)
                run_summary_dict[run_name][summary_name] = run_summary

    # np.mean([]) also yields NaN but emits a RuntimeWarning; be explicit.
    overall_summary_dict = {k: (np.mean(v) if len(v) > 0 else np.nan)
                            for k, v in overall_summary_dict.items()}
    return overall_summary_dict, run_summary_dict, runs
        

In [11]:
# Per-game table: one row per game, one column per method, values in percent.
methods = ["random-cnn", "vae", "cpc", "spatial-appo", "supervised"]
gbg_dict = {env:{} for env in atari_dict.keys()}

overall_dict = {method:{} for method in methods}


def _format_percent(value):
    """Format a fraction as a percentage cell; string placeholders become "n/a"."""
    # Placeholders such as "n/a" are stored as strings; multiplying them by
    # 100 and feeding the result to "%8.2f" would raise a TypeError.
    if isinstance(value, str):
        return "n/a"
    return "%8.2f" % (100 * value)


for method in methods:
    overall_summary_dict, run_summary_dict, _runs = get_dicts_for(method)
    for env, dic in run_summary_dict.items():
        # A run without an "overall" summary (including an empty summary)
        # is recorded with the "n/a" placeholder.
        gbg_dict[env.lower()][method] = dic.get("overall", "n/a")
    overall_dict[method] = overall_summary_dict["overall"]

print(" & ".join(["game"] + methods), "\\\\")

for game, dic in gbg_dict.items():
    if game in ["videopinball", "asteroids"]:
        continue
    print(" & ".join([game] + [_format_percent(dic.get(k, "n/a")) for k in methods]), "\\\\")


print(" & ".join(["overall"] + [_format_percent(overall_dict.get(k, "n/a")) for k in methods]), "\\\\")

In [7]:
# Summary category names, grouped by the table they are printed in.
overall_keys = ["overall"]

# Categories whose names end in "localization".
loc_keys = [
    "localization",
    "small_object_localization",
    "agent_localization",
    "enemy_localization",
]

# All remaining categories.
misc_keys = [
    "relative_position",
    "direction",
    "score",
    "level_room",
    "count_display",
    "existence",
    "speed",
]

In [10]:

# Query wandb once per method and reuse the overall summaries for both
# tables; previously every method was fetched twice.
overall_by_method = {method: get_dicts_for(method)[0] for method in methods}

# Table 1: localization categories, one row per method (values in percent).
print(" & ".join(loc_keys))
for method in methods:
    loc_nums = ["%8.2f"%(100*overall_by_method[method][k]) for k in loc_keys]
    print(method + " & "," & ".join(loc_nums), "\\\\")

# Table 2: remaining categories, one row per method (values in percent).
print(" & ".join(misc_keys))
for method in methods:
    misc_nums = ["%8.2f"%(100*overall_by_method[method][k]) for k in misc_keys]
    print(method + " & "," & ".join(misc_nums), "\\\\")

In [10]:
# RL agent parsing

In [9]:
# Column selections; note that loc_keys here also contains "overall".
loc_keys = [
    "overall",
    "localization",
    "small_object_localization",
    "agent_localization",
    "enemy_localization",
]

keys = [
    "overall",
    "small_object_localization",
    "agent_localization",
    "enemy_localization",
    "score_lives",
    "count_display",
    "level_room",
    "direction",
    "existence",
]

# For every game, count how many of its atari_dict entries are listed under
# each summary category.
summary_stats = {
    game: {
        category: sum(1 for label in atari_dict[game].keys() if label in members)
        for category, members in summary_key_dict.items()
    }
    for game in atari_dict.keys()
}

header_cells = ["game"] + list(keys)
print(" & ".join(header_cells))

# One table row per game, restricted to the columns selected in `keys`.
for game, counts in summary_stats.items():
    row_cells = [game] + [str(counts[column]) for column in keys]
    print(" & ".join(row_cells), "\\\\")



def get_pretrained_rl_dicts(algo):
    """Summarise the finished pretrained-RL-agent wandb runs of one algorithm.

    Queries the "curl-atari/curl-atari-2" project for finished runs with
    ``config.method == "pretrained-rl-agent"``, ``config.zoo_algos == [algo]``
    and ``config.probe_collect_mode == "atari_zoo"``, and keeps only runs whose
    summary contains "test_mean_reward_per_episode".

    For every game (env name with "NoFrameskip-v4" stripped) and every category
    in ``summary_key_dict``, the summary is the mean of all summary metrics
    whose name contains "acc", does not contain "stderr", and contains
    (case-insensitively) one of the category's keys.

    Args:
        algo: algorithm name matched (as a one-element list) against
            ``config.zoo_algos``.

    Returns:
        Tuple ``(overall_summary_dict, run_summary_dict)``:
          * overall_summary_dict: category -> mean of the per-game summaries,
            plus "reward_per_episode" -> mean of the per-game
            "test_mean_reward_per_episode" values (NaN for an entry with no
            values).
          * run_summary_dict: game -> {category -> summary} plus
            "reward_per_episode"; categories with no matching metric are
            absent.
    """
    api = wandb.Api()
    runs = list(api.runs("curl-atari/curl-atari-2", {"config.method": "pretrained-rl-agent",
                                                     "config.zoo_algos": [algo],
                                                     "config.probe_collect_mode": "atari_zoo",
                                                     "state": "finished"}))

    # The reward is read unconditionally below, so drop runs that lack it.
    runs = [run for run in runs if "test_mean_reward_per_episode" in run.summary_metrics.keys()]

    # If several runs share a game, the last one returned wins.
    run_dict = {run.config["env_name"].replace("NoFrameskip-v4", ""): run.summary_metrics
                for run in runs}

    run_summary_dict = {game: {} for game in run_dict.keys()}
    overall_summary_dict = {k: [] for k in list(summary_key_dict.keys()) + ["reward_per_episode"]}

    # Lower-case the category keys once instead of once per metric name.
    lowered_keys = {name: [sum_key.lower() for sum_key in key_list]
                    for name, key_list in summary_key_dict.items()}

    for run_name, run in run_dict.items():
        for summary_name, key_list in lowered_keys.items():
            run_mean_values = {k.lower(): v for k, v in run.items()
                               if "acc" in k and "stderr" not in k
                               and any(sum_key in k.lower() for sum_key in key_list)}
            if len(run_mean_values) > 0:
                run_summary = np.mean(list(run_mean_values.values()))
                overall_summary_dict[summary_name].append(run_summary)
                run_summary_dict[run_name][summary_name] = run_summary
        reward = run["test_mean_reward_per_episode"]
        overall_summary_dict["reward_per_episode"].append(reward)
        run_summary_dict[run_name]["reward_per_episode"] = reward

    # np.mean([]) also yields NaN but emits a RuntimeWarning; be explicit.
    overall_summary_dict = {k: (np.mean(v) if len(v) > 0 else np.nan)
                            for k, v in overall_summary_dict.items()}
    return overall_summary_dict, run_summary_dict

# rlag: game -> {algo -> (probe score in percent, reward per episode)}
# rlov: aggregate probe score / reward per algorithm
rlag = {k:{} for k in atari_dict.keys()}
rlov = {}

# Query wandb once per algorithm and fill both structures from that result.
for algo in ["a2c", "apex"]:
    osd, run_summary_dict = get_pretrained_rl_dicts(algo)
    for k, v in run_summary_dict.items():
        rlag[k.lower()][algo] = (v["overall"]*100, v['reward_per_episode'])
    rlov[algo + "probe avg"] = osd["overall"]*100
    rlov[algo + "avg return per ep per game"] = osd['reward_per_episode']


# Per-game table; only games that have results for both algorithms are shown.
print(" & ".join(["game", "a2c probe", "a2c returns", "apex probe", "apex returns"]), "\\\\")
for game, dic in rlag.items():
    if "a2c" in dic and "apex" in dic:
        all_res = list(dic["a2c"]) + list(dic["apex"])
        all_res = ["%.2f"%(float(n)) for n in all_res]
        print(game, " & " ," & ".join(all_res), "\\\\")


a2c0 = "%8.2f" % rlov["a2cprobe avg"]
a2c1 = "%8.2f" % rlov["a2cavg return per ep per game"]

apex0 = "%8.2f" % rlov["apexprobe avg"]
apex1 = "%8.2f" % rlov["apexavg return per ep per game"]

# Filter on the algorithm actually being indexed: a game that only has
# results for the other algorithm must be skipped, not raise a KeyError.
alla2c_probes = [v["a2c"][0] for v in rlag.values() if "a2c" in v]
alla2c_returns = [v["a2c"][1] for v in rlag.values() if "a2c" in v]

# Off-diagonal entry of the 2x2 correlation matrix.
a2c_corr = np.corrcoef(alla2c_probes, alla2c_returns)[0][1]

allapex_probes = [v["apex"][0] for v in rlag.values() if "apex" in v]
allapex_returns = [v["apex"][1] for v in rlag.values() if "apex" in v]

apex_corr = np.corrcoef(allapex_probes, allapex_returns)[0][1]

# Summary table: one row per algorithm.
print(" & ".join(["method", "probe_score", "avg_returns all games","correlation"]))
print(" & ".join(["a2c",a2c0,a2c1,str(a2c_corr)]))
print(" & ".join(["apex",apex0,apex1,str(apex_corr)]))