### Get results from Wandb

In [1]:
import pandas as pd 
import wandb
api = wandb.Api()

# Project is specified by <entity/project-name>
runs = api.runs("draftrec/atari_pretrain")

summary_list, config_list, id_list = [], [], []
for run in runs: 
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files 
    summary_list.append(run.summary._json_dict)

    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config_list.append(
        {k: v for k,v in run.config.items()
          if not k.startswith('_')})

    # .name is the human-readable name of the run.
    id_list.append(run.id)

runs_df = pd.DataFrame({
    "summary": summary_list,
    "config": config_list,
    "id": id_list,
    })

runs_df.to_csv("project.csv")

### Read runs

In [2]:
data_ = pd.read_csv('project.csv', index_col=0)
data_

Unnamed: 0,summary,config,id
0,{},"{'env': {'game': 'Boxing', 'type': 'atari', 'f...",kw6hqbf6
1,{},"{'env': {'game': 'Boxing', 'type': 'atari', 'f...",2ivz19nz
2,{},"{'env': {'game': 'Boxing', 'type': 'atari', 'f...",9wqpicwo
3,{'_wandb': {'runtime': 28}},"{'env': {'game': 'Boxing', 'type': 'atari', 'f...",f0i66cyl
4,{},"{'env': {'game': 'Boxing', 'type': 'atari', 'f...",3dst0211
...,...,...,...
1275,"{'pos_neg_diff': 0.87976934150327, 'positive_s...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",5zlpym7v
1276,"{'loss': -0.9878902941942216, '_step': 117001,...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",3oh4f4sw
1277,"{'loss': -0.9783781695365906, '_step': 117001,...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",385p3sw7
1278,"{'negative_sim': 0.17696081340312958, 'pos_neg...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",njf5gwab


### Filter based on conditions

In [3]:
group_name = 'baseline'
exp_name = 'clt_traj_cons_no_proj'
model_path = '0/10/model.pth'

In [4]:
configs = data_['config']
indexs = []
for idx, config in enumerate(configs):
    cfg = eval(config)
    if len(cfg) == 0:
        continue
        
    run_exp_name = cfg['exp_name']
    run_group_name = cfg['group_name']

    # condition
    if run_exp_name == exp_name and run_group_name == group_name:
        if 'env' in cfg:
            indexs.append(idx)

In [5]:
data = data_.iloc[indexs]
envs = []
for config in data['config']:
    envs.append(eval(config)['env']['game'])
data['env'] = envs
print(len(data))
data

12


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data['env'] = envs


Unnamed: 0,summary,config,id,env
21,"{'_step': 200000, 'act_acc': 0.696096105813980...","{'env': {'game': 'CrazyClimber', 'type': 'atar...",3b8bcyof,CrazyClimber
22,"{'pos_neg_diff': 0.941808152616024, 'max_grad_...","{'env': {'game': 'Freeway', 'type': 'atari', '...",bkmyy6jj,Freeway
23,"{'rtg_loss': 1.111785558372736, 'reward_ratio'...","{'env': {'game': 'ChopperCommand', 'type': 'at...",f6rlmwdk,ChopperCommand
24,"{'_step': 202000, '_timestamp': 1669436278, 'm...","{'env': {'game': 'DemonAttack', 'type': 'atari...",jxr9ejhs,DemonAttack
72,"{'mean_grad_norm': 0.1223612709985516, 'best_m...","{'env': {'game': 'Boxing', 'type': 'atari', 'f...",32nzfqdn,Boxing
73,"{'rtg_loss': 1.2629908736805282, 'epoch': 10, ...","{'env': {'game': 'Breakout', 'type': 'atari', ...",2batx68n,Breakout
74,"{'best_metric_val': 0.17895647602832812, '_ste...","{'env': {'game': 'BankHeist', 'type': 'atari',...",4pgopiu3,BankHeist
75,"{'mean_grad_norm': 0.10300133017908912, 'idm_a...","{'env': {'game': 'BattleZone', 'type': 'atari'...",2s6qu43q,BattleZone
80,"{'mean_grad_norm': 0.1353523378893935, 'best_m...","{'env': {'game': 'Alien', 'type': 'atari', 'fr...",1xi3ji6w,Alien
81,"{'epoch': 10, 'idm_loss': 2.323826084938153, '...","{'env': {'game': 'Amidar', 'type': 'atari', 'f...",2v6wvz3p,Amidar


### 1. Restore Saved Models

In [6]:
import pathlib
base_path = str(pathlib.Path().resolve())

artifact_dict = {}
for run_id, env in zip(data['id'], data['env']):
    print(env, run_id)
    try:
        name = env + '/' + model_path 
        path = base_path + '/' + name
        wandb.restore(name, run_path="draftrec/atari_pretrain/" + run_id)
        artifact_dict[path] = name 
    except:
        continue

CrazyClimber 3b8bcyof
Freeway bkmyy6jj
ChopperCommand f6rlmwdk
DemonAttack jxr9ejhs
Boxing 32nzfqdn
Breakout 2batx68n
BankHeist 4pgopiu3
BattleZone 2s6qu43q
Alien 1xi3ji6w
Amidar 2v6wvz3p
Asterix oz7rk7ss
Assault 3t5ggjqr


### Save to artifact

In [7]:
wandb.init(project='atari_pretrain',
           entity='draftrec',
           group=exp_name,
           settings=wandb.Settings(start_method="thread"))  
artifact = wandb.Artifact(name=exp_name, type='model')

# save models
for path, name in artifact_dict.items():
    artifact.add_file(path, name=name)

wandb.run.finish_artifact(artifact)

[34m[1mwandb[0m: Currently logged in as: [33mjoonleesky[0m ([33mdraftrec[0m). Use [1m`wandb login --relogin`[0m to force relogin
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "


<wandb.sdk.wandb_artifacts.Artifact at 0x7f444c2b8d60>

### Remove artifact

In [8]:
import shutil
for env in data['env']:
    shutil.rmtree('./'+env, ignore_errors=True)

### 2. Generate json file

In [None]:
def get_results_dict(data):
    results = []
    for idx in range(len(data)):
        row = data.iloc[idx]
        summary = eval(row['summary'])
        config = eval(row['config'])

        if 'env' not in config:
            continue

        game = config['env']['game']
        try:
            import pdb
            pdb.set_trace()
            score = summary['eval_mean_traj_game_scores']
        except:
            continue
        results.append([0, game, score, 0])

    return results
    
results = get_results_dict(data)

In [None]:
data

In [18]:
import json

def generate_json_file(file_name, results):
    data = {}
    data['data'] = results
    path = file_name + '.json'
    with open(path, 'w') as json_file:
        json.dump(data, json_file)

In [19]:
generate_json_file('byol_impala', results)