In [106]:
from collections import defaultdict
import numpy as np
import wandb

from uniMASK.utils import (
    average_dictionaries,
    create_dir_if_not_exists,
    delete_dir_if_exists,
    load_from_json,
    mean_and_std_err,
    save_as_json,
    to_numpy,
)

In [104]:
import pandas as pd

# Connect to the W&B public API and fetch every run of the project,
# which is addressed as "<entity>/<project-name>".
api = wandb.Api()
runs = api.runs("fbert/maze_expl_sweeps_reproduce")

name_list = []
config_list = []
summary_list = []
for run in runs:
    # Display name of the run.
    name_list.append(run.name)

    # Input settings of the run; keys with a leading underscore are dropped.
    config = {
        key: value
        for key, value in run.config.items()
        if not key.startswith("_")
    }
    config_list.append(config)

    # Logged output key/values (e.g. accuracy); ._json_dict omits large files.
    summary_list.append(run.summary._json_dict)

# Build one frame per source, then join them side by side (one row per run).
name_df = pd.DataFrame({"name": name_list})
config_df = pd.DataFrame.from_records(config_list)
summary_df = pd.DataFrame.from_records(summary_list)
all_df = pd.concat([name_df, config_df, summary_df], axis=1)

In [97]:
def _best_index(values):
    """Return the position of the largest non-NaN entry of ``values``.

    ``np.argmax`` treats NaN as the maximum, so a single missing reward would
    be reported as the "best" entry. NaNs are skipped here. When every entry
    is NaN there is no best entry, and the plain ``np.argmax`` result
    (index 0) is returned so callers still receive an integer index.
    """
    arr = np.asarray(values, dtype=float)
    if np.isnan(arr).all():
        return np.argmax(arr)
    return np.nanargmax(arr)


def get_dicts(seq_len):
    """Aggregate the best evaluation rewards of all runs with a given ``seq_len``.

    Reads the notebook-level ``all_df`` frame, keeps the rows whose ``seq_len``
    column equals ``seq_len``, and groups them by run name with its last 7
    characters removed (presumably a per-seed suffix -- TODO confirm).

    Args:
        seq_len: value matched against the ``seq_len`` column of ``all_df``.

    Returns:
        A 4-tuple of dicts keyed by the truncated run name:
          * ``mean_and_std_err`` of the group's ``best_eval_avg_rew_BCBC`` values,
          * ``mean_and_std_err`` of the group's ``best_eval_avg_rew_RCRC`` values,
          * index (within the group, in row order) of the highest BC reward,
          * index (within the group, in row order) of the highest RC reward.
        NaN rewards are ignored when picking the highest reward; a group whose
        rewards are all NaN gets index 0.
    """
    df = all_df[all_df.seq_len == seq_len]

    # group key -> {"bc_rew": [...], "rc_rew": [...]}, filled in row order.
    rewards = defaultdict(lambda: defaultdict(list))
    rows = zip(df["name"], df["best_eval_avg_rew_BCBC"], df["best_eval_avg_rew_RCRC"])
    for run_name, bc_rew, rc_rew in rows:
        group = rewards[run_name[:-7]]
        group["bc_rew"].append(bc_rew)
        group["rc_rew"].append(rc_rew)

    bc_table_values = {k: mean_and_std_err(v["bc_rew"]) for k, v in rewards.items()}
    rc_table_values = {k: mean_and_std_err(v["rc_rew"]) for k, v in rewards.items()}
    best_bc_table_values = {k: _best_index(v["bc_rew"]) for k, v in rewards.items()}
    best_rc_table_values = {k: _best_index(v["rc_rew"]) for k, v in rewards.items()}
    return bc_table_values, rc_table_values, best_bc_table_values, best_rc_table_values

In [102]:
# Aggregate the runs with seq_len == 10. In return order:
#   a = bc_table_values, b = rc_table_values,
#   c = best_bc_table_values, d = best_rc_table_values.
a, b, c, d = get_dicts(10)

In [91]:
# Display the BC and RC reward tables (mean_and_std_err result per run group).
a,b

({'900N_10len_DT_BC_rl_sl0_t_enc_DT': (1.581528819093314, 0.05752544208101805),
  '900N_10len_DT_RC_rl_sl0_t_enc_DT': (nan, nan),
  '900N_10len_RC_rl_NN': (nan, nan),
  '900N_10len_BC_rl_NN': (1.8314716384128409, 0.066168293128948),
  '900N_10len_DT_RC_rl_sl0_DT': (nan, nan),
  '900N_10len_DT_BC_rl_sl0_DT': (2.7419418408825975, 0.01129227507346532),
  '900N_10len_RC_rl': (nan, nan),
  '900N_10len_BC_rl': (2.4676030231346324, 0.044116914384106684),
  '900N_10len_rnd_BC_rl': (2.3621858031193916, 0.05770550366157036),
  '900N_10len_BC_RC_rl': (2.3902412642475896, 0.03391566144801734),
  '900N_10len_rnd_rl': (2.292935913451514, 0.0707849456179969)},
 {'900N_10len_DT_BC_rl_sl0_t_enc_DT': (nan, nan),
  '900N_10len_DT_RC_rl_sl0_t_enc_DT': (1.7000489033190438,
   0.07455501826692258),
  '900N_10len_RC_rl_NN': (1.8751192178712928, 0.0647570515695031),
  '900N_10len_BC_rl_NN': (nan, nan),
  '900N_10len_DT_RC_rl_sl0_DT': (2.725090147117382, 0.02092296001031855),
  '900N_10len_DT_BC_rl_sl0_DT': (n

In [103]:
# Display the per-group index tables for the BC and RC rewards.
c, d

({'900N_10len_DT_BC_rl_sl0_t_enc_DT': 4,
  '900N_10len_DT_RC_rl_sl0_t_enc_DT': 0,
  '900N_10len_RC_rl_NN': 0,
  '900N_10len_BC_rl_NN': 2,
  '900N_10len_DT_RC_rl_sl0_DT': 0,
  '900N_10len_DT_BC_rl_sl0_DT': 2,
  '900N_10len_RC_rl': 0,
  '900N_10len_BC_rl': 2,
  '900N_10len_rnd_BC_rl': 3,
  '900N_10len_BC_RC_rl': 3,
  '900N_10len_rnd_rl': 0},
 {'900N_10len_DT_BC_rl_sl0_t_enc_DT': 0,
  '900N_10len_DT_RC_rl_sl0_t_enc_DT': 2,
  '900N_10len_RC_rl_NN': 4,
  '900N_10len_BC_rl_NN': 0,
  '900N_10len_DT_RC_rl_sl0_DT': 2,
  '900N_10len_DT_BC_rl_sl0_DT': 0,
  '900N_10len_RC_rl': 4,
  '900N_10len_BC_rl': 0,
  '900N_10len_rnd_BC_rl': 0,
  '900N_10len_BC_RC_rl': 3,
  '900N_10len_rnd_rl': 0})