In [None]:
import sys

# NOTE(review): hard-coded absolute path to one developer's checkout; presumably
# this is what makes the `src` package imported further down resolvable. Anyone
# else running the notebook has to edit this line -- consider deriving it from a
# configurable location instead.
sys.path.append("/Users/evanracah/Dropbox/projects/rl-representation-learning/")

In [None]:
import wandb
import pandas as pd
import numpy as np
from copy import deepcopy

In [None]:
from src.atari_ram_annotations import summary_key_dict, atari_dict, unused_keys, detailed_key_dict

In [None]:
def remove_dicts_in_dict(dic):
    """Flatten ``dic`` in place so that no value is a dict or a list.

    * dict values are removed;
    * list values are replaced by one entry per element, keyed
      ``"<key>_<index>"``; elements that are themselves lists are flattened the
      same way (``"<key>_<i>_<j>"``) and elements that are dicts are dropped.

    The same dictionary object is returned for convenience.  Note that a
    generated key silently overwrites an existing entry with the same name.
    """
    # Keys whose values still need to be removed or expanded, in insertion order.
    pending = [k for k, v in dic.items() if isinstance(v, (dict, list))]
    while pending:
        key = pending.pop(0)
        if key not in dic:
            # Already handled (can only happen when generated keys collide).
            continue
        value = dic.pop(key)
        if not isinstance(value, list):
            continue  # plain dict value: just drop it
        for i, item in enumerate(value):
            if isinstance(item, dict):
                continue  # nested dicts are dropped, like top-level ones
            flat_key = "{}_{}".format(key, i)
            dic[flat_key] = item
            if isinstance(item, list):
                # Nested list: expand it on a later pass so nothing unhashable remains.
                pending.append(flat_key)

    return dic
            

In [None]:
def get_dicts_for(method, verbose=True):
    """Query wandb for finished probe runs of ``method`` and summarise them per game.

    Parameters
    ----------
    method : str
        Value matched against ``config.method`` of the runs.
    verbose : bool
        When True, print every game that has more than one run together with
        the config entries in which those runs differ.

    Returns
    -------
    run_summary_dict : dict
        ``{game: {summary_name: score, ..., "overall": score}}`` where each
        summary score is the mean of the matching "mean" metrics of the run and
        ``"overall"`` is the mean of the summary scores (NaN if the run matched
        none).  When a game listed in ``atari_dict`` has several runs only the
        one with the largest ``_timestamp`` is kept.
    runs_info : dict
        ``{"<Game>_<run name>": flattened run config}`` for every run returned
        by the query (games in ``rem_games`` are *not* filtered out here).
    summary_metrics : dict
        ``{"<Game>_<run name>": run.summary_metrics}`` without ``rem_games``.
    runs : list
        The run objects returned by the wandb query.
    """
    rem_games = ["defender","enduro"]
    api = wandb.Api()

    # Filters shared by every method.
    filters = {"config.method": method,
               "config.collect_mode":"random_agent",
               "config.probe_steps": 50000,
               "state": "finished"}
    # NOTE(review): the list below spells "random_cnn" with an underscore, while
    # the methods list used elsewhere in this notebook says "random-cnn"; with the
    # hyphenated spelling the encoder filters below are applied to it as well.
    # Confirm which spelling wandb stores before changing either side.
    if method not in ["random_cnn","supervised","majority"]:
        filters.update({"config.train_encoder":True,
                        "config.feature_size":256,
                        "config.entropy_threshold":0.3})
    runs = list(api.runs("curl-atari/curl-atari-2", filters))

    def run_key(run):
        # "<Game>_<wandb run name>", e.g. env "PongNoFrameskip-v4" -> "Pong_<name>".
        return run.config["env_name"].replace("NoFrameskip-v4","") + "_" + run.name

    def game_of(run_name):
        # The game is everything before the first underscore of a run key.
        return run_name.split("_")[0]

    runs_info = {run_key(run): remove_dicts_in_dict(run.config) for run in runs}
    summary_metrics = {run_key(run): run.summary_metrics for run in runs}
    summary_metrics = {k:v for k,v in summary_metrics.items() if game_of(k).lower() not in rem_games}

    # Per-run summaries: one mean per category of summary_key_dict plus "overall".
    run_summary_dict = {}
    for run_name, run in summary_metrics.items():
        run_summary = {}
        for summary_name, key_list in summary_key_dict.items():
            run_mean_values = {k.lower():v for k,v in run.items()
                               if "mean" in k and "var" not in k
                               and any(sum_key.lower() in k.lower() for sum_key in key_list)}
            if len(run_mean_values) > 0:
                run_summary[summary_name] = np.mean(list(run_mean_values.values()))
        # Explicit NaN instead of np.mean([]) so no "mean of empty slice" warning is emitted.
        run_summary["overall"] = np.mean(list(run_summary.values())) if run_summary else np.nan
        run_summary_dict[run_name] = run_summary

    # Keep only the most recent run per game.
    for game_name in atari_dict.keys():
        # Compare the game prefix exactly: a substring test would also match the
        # wandb run name and could attribute a run to the wrong game (or to two
        # games, making the pop below fail the second time).
        duplicates = [run_name for run_name in summary_metrics.keys()
                      if game_of(run_name).lower() == game_name]
        num_duplicates = len(duplicates)
        if num_duplicates <= 1:
            continue

        if verbose:
            print("Game {} is in summary {} times as:".format(game_name, num_duplicates))
            print("\t",duplicates)
            print("\tDifferences are: ")
            rn1 = duplicates[0]
            dup1 = runs_info[rn1]
            for rn in duplicates[1:]:
                dupi = runs_info[rn]
                print("\t\t run {} differs with run {}".format(rn1,rn),set(dup1.items()).symmetric_difference(set(dupi.items())))

        duplicates_time_stamps = [summary_metrics[run_name]["_timestamp"] for run_name in duplicates]
        most_recent_index = np.argmax(duplicates_time_stamps)
        for i, run_name in enumerate(duplicates):
            if i != most_recent_index:
                run_summary_dict.pop(run_name)

    # Re-key the surviving runs by game name only.
    run_summary_dict = {game_of(run_name): run_summary
                        for run_name, run_summary in run_summary_dict.items()}

    return run_summary_dict, runs_info, summary_metrics, runs
        

In [None]:
def compute_overall_summary_dict(run_summary_dict):
    """Average per-run summary scores into a single score per category.

    Parameters
    ----------
    run_summary_dict : dict
        ``{run_or_game_name: {category: score}}``.  Names may be a bare game
        name or ``"<Game>_<suffix>"``; only the part before the first
        underscore is used for the duplicate check.

    Returns
    -------
    dict
        ``{category: mean score over runs}`` for every key of
        ``summary_key_dict`` plus ``"overall"`` (and any other category that
        appears in the input).  NaN scores are ignored; a category without any
        finite score maps to NaN.

    A warning is printed for every game of ``atari_dict`` that appears more
    than once, because such duplicates are averaged like separate games.
    """
    for game_name in atari_dict.keys():
        duplicates = [run_name for run_name in run_summary_dict.keys()
                      if run_name.split("_")[0].lower() == game_name]
        num_duplicates = len(duplicates)
        if num_duplicates > 1:
            print("You still have duplicates! Game {} appears {} times:".format(game_name, num_duplicates))
            print("\t",duplicates)
            print("")

    scores_by_category = {k:[] for k in list(summary_key_dict.keys()) + ["overall"]}
    for run_name, key_dict in run_summary_dict.items():
        for key,value in key_dict.items():
            # A run that matched no metric carries NaN; skip it so one such run
            # does not turn the whole average into NaN.
            if not np.isnan(value):
                scores_by_category.setdefault(key, []).append(value)

    # Explicit NaN for empty categories avoids numpy's "mean of empty slice" warning.
    overall_summary_dict = {k:(np.mean(v) if len(v) > 0 else np.nan)
                            for k,v in scores_by_category.items()}

    return overall_summary_dict

In [None]:
# Fetch and summarise the finished "naff" runs. verbose is left at its default
# (True), so games with several runs are reported together with their config
# differences.
run_summary_dict,runs_info, summary_metrics, runs = get_dicts_for("naff")

In [None]:
# Average the per-game summaries of the run set loaded above into one score per
# category; as the last expression of the cell the resulting dict is displayed.
compute_overall_summary_dict(run_summary_dict)

In [None]:
# Game-by-game table: one LaTeX row per game with the "overall" probe score
# (in percent) of every method, followed by a row of per-method averages.
methods = ["majority","random-cnn", "naff","vae","appo","cpc","spatial-appo","supervised"]
gbg_dict = {env:{} for env in atari_dict.keys()}
overall_dict = {}


def _format_score(value):
    """Render a fractional score as a percentage cell; missing/NaN -> "n/a"."""
    if value is None or np.isnan(value):
        return "n/a"
    return "%8.2f"%(100*value)


for method in methods:
    run_summary_dict, _, _, _ = get_dicts_for(method, verbose=False)
    overall_summary_dict = compute_overall_summary_dict(run_summary_dict)
    for env, dic in run_summary_dict.items():
        # Only store real numbers: a placeholder string here would later be
        # multiplied by 100 and passed to "%8.2f", which raises TypeError.
        # Missing entries are rendered as "n/a" at print time instead.
        if "overall" in dic:
            # setdefault: tolerate games that have runs but are absent from atari_dict.
            gbg_dict.setdefault(env.lower(), {})[method] = dic["overall"]
    overall_dict[method] = overall_summary_dict["overall"]

print(" & ".join(["game"] + methods), "\\\\")

for game, dic in gbg_dict.items():
    if game in ["videopinball", "asteroids"]:
        continue
    print(" & ".join([game] + [_format_score(dic.get(k)) for k in methods]), "\\\\")


print(" & ".join(["overall"] + [_format_score(overall_dict.get(k)) for k in methods]), "\\\\")

In [None]:
# Per-category table: one LaTeX row per method with the average score (in
# percent) of every category in summary_key_dict.
keys = list(summary_key_dict.keys())
# Every row below starts with the method name, so the header needs a matching
# first column and the same row terminator as the data rows.
print(" & ".join(["method"] + keys), "\\\\")
for method in methods:
    run_summary_dict,runs_info, summary_metrics, runs = get_dicts_for(method, verbose=False)
    overall_summary_dict = compute_overall_summary_dict(run_summary_dict)

    nums = ["%8.2f"%(100*overall_summary_dict[k]) if k in overall_summary_dict else "n/a" for k in keys]
    print(method + " & "," & ".join(nums), "\\\\")

In [None]:
# Count, for every game, how many of its annotated entries fall into each
# category of detailed_key_dict, and print the counts as LaTeX table rows.
keys = list(detailed_key_dict.keys())

summary_stats = {}
for env, annotations in atari_dict.items():
    summary_stats[env] = {
        category: sum(1 for label in annotations.keys() if label in detailed_key_dict[category])
        for category in keys
    }

print(" & ".join(["game"] + keys))

for env, counts in summary_stats.items():
    row = [env] + [str(counts[category]) for category in keys]
    print(" & ".join(row), "\\\\")

In [None]:
def get_pretrained_rl_dicts(algo):
    """Summarise finished probe runs of pretrained RL agents trained with ``algo``.

    Only runs whose summary metrics contain ``test_mean_reward_per_episode``
    are used.  Runs are keyed by game name (env name without
    ``NoFrameskip-v4``); if several runs share a game, the last one returned by
    the query wins.

    Returns
    -------
    overall_summary_dict : dict
        Mean over games of every category of ``summary_key_dict``, of
        ``"overall"`` and of ``"reward_per_episode"`` (NaN when no game
        contributed a value).
    run_summary_dict : dict
        ``{game: {category: score, ..., "overall": score,
        "reward_per_episode": reward}}``.  A category score is the mean of the
        run's matching accuracy metrics; ``"overall"`` is the mean of the
        category scores (NaN if the run matched no category).
    """
    api = wandb.Api()
    runs = list(api.runs("curl-atari/curl-atari-2", {"config.method": "pretrained-rl-agent",
                                                     "config.zoo_algos":[algo],
                                                     "config.probe_collect_mode":"atari_zoo",
                                                     "state": "finished"}))

    runs = [run for run in runs if "test_mean_reward_per_episode" in run.summary_metrics.keys()]

    run_dict = {run.config["env_name"].replace("NoFrameskip-v4","")  :run.summary_metrics for run in runs}

    run_summary_dict = {}
    collected = {k:[] for k in list(summary_key_dict.keys()) + ["overall", "reward_per_episode"]}

    for run_name, run in run_dict.items():
        run_summary = {}
        for summary_name, key_list in summary_key_dict.items():
            run_mean_values = {k.lower():v for k,v in run.items() if "acc" in k and "stderr" not in k\
                               and  any(sum_key.lower() in k.lower() for sum_key in key_list)}
            if len(run_mean_values) > 0:
                score = np.mean(list(run_mean_values.values()))
                collected[summary_name].append(score)
                run_summary[summary_name] = score

        # Callers index "overall" on both returned dicts, so always provide it
        # (unless a category of that name already did): the mean of the
        # category scores, computed before the reward is added.
        if "overall" not in run_summary:
            if run_summary:
                run_summary["overall"] = np.mean(list(run_summary.values()))
                collected["overall"].append(run_summary["overall"])
            else:
                run_summary["overall"] = np.nan

        reward = run["test_mean_reward_per_episode"]
        collected["reward_per_episode"].append(reward)
        run_summary["reward_per_episode"] = reward
        run_summary_dict[run_name] = run_summary

    # Explicit NaN for empty categories avoids numpy's "mean of empty slice" warning.
    overall_summary_dict = {k:(np.mean(v) if len(v) > 0 else np.nan) for k,v in collected.items()}
    return overall_summary_dict, run_summary_dict

# Pretrained-RL-agent tables: per-game probe score / return for a2c and apex,
# then the per-algorithm averages and the probe-vs-return correlation.
rlag = {k:{} for k in atari_dict.keys()}
rlov = {}

# One wandb query per algorithm feeds both the per-game and the overall table.
for algo in ["a2c", "apex"]:
    osd, run_summary_dict = get_pretrained_rl_dicts(algo)
    for k, v in run_summary_dict.items():
        # setdefault: tolerate games that have runs but are absent from atari_dict.
        rlag.setdefault(k.lower(), {})[algo] = (v["overall"]*100,v['reward_per_episode'])
    rlov[algo + "probe avg"] = osd["overall"]*100
    rlov[algo + "avg return per ep per game"] = osd['reward_per_episode']


print(" & ".join(["game", "a2c probe", "a2c returns","apex probe", "apex returns"]), "\\\\")
for game, dic in rlag.items():
    # Only games that have results for both algorithms get a row.
    if "a2c" in dic and "apex" in dic:
        all_res = list(dic["a2c"]) + list(dic["apex"])
        all_res = ["%.2f"%(float(n)) for n in all_res]
        print(game, " & " ," & ".join(all_res), "\\\\")


a2c0 = "%8.2f"%rlov["a2cprobe avg"]
a2c1 = "%8.2f"%rlov["a2cavg return per ep per game"]

apex0 = "%8.2f"%rlov["apexprobe avg"]
apex1 = "%8.2f"%rlov["apexavg return per ep per game"]

# Each correlation uses every game that has a result for that algorithm; the
# membership test avoids a KeyError for games covered by only one algorithm.
alla2c_probes = [v["a2c"][0] for v in rlag.values() if "a2c" in v]
alla2c_returns = [v["a2c"][1] for v in rlag.values() if "a2c" in v]

a2c_corr = np.corrcoef(alla2c_probes,alla2c_returns)[0][1]

allapex_probes =  [v["apex"][0] for v in rlag.values() if "apex" in v]
allapex_returns =  [v["apex"][1] for v in rlag.values() if "apex" in v]

apex_corr = np.corrcoef(allapex_probes,allapex_returns)[0][1]

print(" & ".join(["method", "probe_score", "avg_returns all games","correlation"]))
print(" & ".join(["a2c",a2c0,a2c1,str(a2c_corr)]))
print(" & ".join(["apex",apex0,apex1,str(apex_corr)]))