### Fetch run results from Weights & Biases (W&B)

In [29]:
import ast

import pandas as pd
import wandb
api = wandb.Api()

# Project is addressed as "<entity>/<project-name>".
runs = api.runs("draftrec/atari_pretrain")

# One record per run:
#   summary - the run's output metrics; `_json_dict` is used to omit large files
#   config  - the run's hyperparameters, minus special keys starting with "_"
#   name    - the human-readable run name
records = [
    {
        "summary": run.summary._json_dict,
        "config": {key: value for key, value in run.config.items()
                   if not key.startswith('_')},
        "name": run.name,
    }
    for run in runs
]
runs_df = pd.DataFrame(records, columns=["summary", "config", "name"])

# Dict-valued cells are written as their Python repr strings.
runs_df.to_csv("project.csv")

### Read runs

In [30]:
# Reload the snapshot written by the previous cell. Column 0 holds the saved
# DataFrame index, so it is used as the index rather than as a data column.
data = pd.read_csv("project.csv", index_col=0)
data

Unnamed: 0,summary,config,name
0,"{'loss': 1.9288443381786349, '_step': 4000, '_...","{'env': {'game': 'DemonAttack', 'type': 'atari...",amber-dust-3284
1,"{'idm_acc': 0.18005816142819822, 'idm_loss': 1...","{'env': {'game': 'DemonAttack', 'type': 'atari...",sandy-smoke-3283
2,"{'idm_loss': 1.8171095966100692, 'obs_loss': 1...","{'env': {'game': 'DemonAttack', 'type': 'atari...",hearty-salad-3282
3,"{'min_grad_norm': 0, 'pos_sim': 0.825347212672...","{'env': {'game': 'DemonAttack', 'type': 'atari...",unique-river-3281
4,"{'act_loss': 0, 'mean_grad_norm': 0.0483209920...","{'env': {'game': 'DemonAttack', 'type': 'atari...",tough-leaf-3280
...,...,...,...
651,"{'pos_neg_diff': 0.87976934150327, 'positive_s...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",usual-dawn-161
652,"{'_timestamp': 1659280381, 'negative_sim': 0.5...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",different-silence-156
653,"{'negative_sim': 0.18186782717704772, 'pos_neg...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",light-oath-154
654,"{'loss': -0.9872578982114792, '_step': 117001,...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",honest-armadillo-153


In [None]:
import torch

# Download the best checkpoint of one run. `wandb.restore` returns None when the
# run has no such file, otherwise an open file handle to the downloaded copy.
# Load from that handle's path instead of re-typing it, and close the handle,
# since it is not otherwise read.
RUN_PATH = "draftrec/atari_pretrain/8qeiqddf"
CHECKPOINT_NAME = 'Boxing/0/best/model.pth'
best_model = wandb.restore(CHECKPOINT_NAME, run_path=RUN_PATH)
if best_model is None:
    raise FileNotFoundError(f"{CHECKPOINT_NAME} not found in run {RUN_PATH}")
best_model.close()

# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from runs you trust. map_location="cpu" lets a
# checkpoint saved on GPU load on a machine without CUDA.
checkpoint = torch.load(best_model.name, map_location="cpu")

### Filter runs by experiment name and group name

In [33]:
# Keep only runs from one experiment/group; change these to select other runs.
TARGET_EXP_NAME = 'trajformer_vid'
TARGET_GROUP_NAME = 'baseline_10'

configs = data['config']
indexs = []  # positional (iloc) indices of the matching rows
for idx, config in enumerate(configs):
    # Configs were saved as Python dict reprs. literal_eval parses them
    # without executing arbitrary code (unlike eval).
    cfg = ast.literal_eval(config)

    # .get: a run missing either key simply does not match (instead of
    # aborting the whole loop with a KeyError).
    exp_name = cfg.get('exp_name')
    group_name = cfg.get('group_name')

    if exp_name == TARGET_EXP_NAME and group_name == TARGET_GROUP_NAME:
        indexs.append(idx)

In [34]:
# Select the matching rows by position. NOTE: this overwrites `data`, so
# re-running this cell on its own re-applies the positional indices to the
# already-filtered frame. Re-run from the "Read runs" cell instead.
data = data.iloc[indexs, :]
data

Unnamed: 0,summary,config,name
4,"{'act_loss': 0, 'mean_grad_norm': 0.0483209920...","{'env': {'game': 'DemonAttack', 'type': 'atari...",tough-leaf-3280
61,"{'act_loss': 0, 'min_grad_norm': 0, 'act_acc':...","{'env': {'game': 'Boxing', 'type': 'atari', 'f...",hearty-cosmos-3223
62,"{'_wandb': {'runtime': 49708}, 'act_acc': 0, '...","{'env': {'game': 'Breakout', 'type': 'atari', ...",swift-sound-3221
63,"{'loss': 0.209914782102192, '_step': 234370, '...","{'env': {'game': 'BattleZone', 'type': 'atari'...",royal-paper-3222
64,"{'_step': 234370, 'epoch': 10, 'act_acc': 0, '...","{'env': {'game': 'BankHeist', 'type': 'atari',...",vocal-sun-3220
72,"{'pos_sim': 0.9517223587850244, 'act_loss': 0,...","{'env': {'game': 'Alien', 'type': 'atari', 'fr...",stellar-wind-3212
73,"{'reward_f1': 0.6561181434599156, 'max_grad_no...","{'env': {'game': 'Asterix', 'type': 'atari', '...",honest-oath-3210
74,"{'_wandb': {'runtime': 49209}, 'act_acc': 0, '...","{'env': {'game': 'Amidar', 'type': 'atari', 'f...",drawn-firebrand-3209
75,"{'act_acc': 0, '_runtime': 49266, 'idm_loss': ...","{'env': {'game': 'Assault', 'type': 'atari', '...",chocolate-firebrand-3211


### Generate JSON results file

In [17]:
def get_results_dict(data):
    """Collect per-run evaluation scores from a runs DataFrame.

    Args:
        data: DataFrame with ``summary`` and ``config`` columns, each cell
            holding the Python repr of that run's summary / config dict.

    Returns:
        A list (despite the function name) with one ``[0, game, score, 0]``
        entry per row that has an ``env`` section in its config and an
        ``eval_mean_traj_game_scores`` key in its summary. Other rows are
        skipped. The two ``0`` fields are placeholders whose meaning is set by
        whoever consumes the JSON file; TODO confirm.

    Raises:
        ValueError / SyntaxError: if a cell is not a Python literal. Cells are
            parsed with ``ast.literal_eval``, so they are never executed as code.
        KeyError: if a config has an ``env`` section without ``game``.
    """
    results = []
    for summary_str, config_str in zip(data['summary'], data['config']):
        summary = ast.literal_eval(summary_str)
        config = ast.literal_eval(config_str)

        if 'env' not in config:
            continue

        game = config['env']['game']
        # Explicit membership check instead of a bare `except:`, which would
        # also swallow unrelated errors (and KeyboardInterrupt).
        if 'eval_mean_traj_game_scores' not in summary:
            continue
        score = summary['eval_mean_traj_game_scores']
        results.append([0, game, score, 0])

    return results
    
# Uses whatever `data` frame is in the kernel when this runs; it is the
# filtered one only if the filter cells ran first.
results = get_results_dict(data=data)

In [18]:
import json

def generate_json_file(file_name, results):
    """Write ``results`` to ``file_name + '.json'`` as ``{"data": results}``.

    Args:
        file_name: output path without the ``.json`` suffix (must be a str,
            since the suffix is appended with ``+``).
        results: any JSON-serialisable value, stored under the ``"data"`` key.

    An existing file at that path is overwritten.
    """
    output_path = file_name + '.json'
    payload = dict(data=results)
    with open(output_path, 'w') as out_file:
        json.dump(payload, out_file)

In [19]:
# NOTE(review): the output name 'byol_impala' does not match the exp_name
# filter used earlier ('trajformer_vid'). Also, this cell's execution count (19)
# is lower than the filter cells' (33, 34), so the saved JSON may not reflect
# the filtered runs. Confirm, and re-run top to bottom before trusting it.
generate_json_file(file_name='byol_impala', results=results)