### Get results from Wandb

In [1]:
import pandas as pd 
import wandb
api = wandb.Api()

# Project is specified by <entity/project-name>
runs = api.runs("draftrec/atari_pretrain")

summary_list, config_list, id_list = [], [], []
for run in runs: 
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files 
    summary_list.append(run.summary._json_dict)

    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config_list.append(
        {k: v for k,v in run.config.items()
          if not k.startswith('_')})

    # .name is the human-readable name of the run.
    id_list.append(run.id)

runs_df = pd.DataFrame({
    "summary": summary_list,
    "config": config_list,
    "id": id_list,
    })

runs_df.to_csv("project.csv")

### Read runs

In [2]:
data_ = pd.read_csv('project.csv', index_col=0)
data_

Unnamed: 0,summary,config,id
0,{'_wandb': {'runtime': 6}},{},36godywq
1,{'_wandb': {'runtime': 6}},{},13sayzx8
2,{'_wandb': {'runtime': 6}},{},3fz8wbam
3,{'_wandb': {'runtime': 1}},"{'exp_name': 'gpt_video_cons_npred', 'override...",3b29kha1
4,{'_wandb': {'runtime': 2}},"{'exp_name': 'gpt_cont_video', 'overrides': ['...",3fbr2k1c
...,...,...,...
1571,"{'_runtime': 20190, '_timestamp': 1659282969, ...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",5zlpym7v
1572,"{'_step': 117001, '_wandb': {'runtime': 18252}...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",3oh4f4sw
1573,"{'positive_sim': 0.9868739223480224, 'loss': -...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",385p3sw7
1574,"{'positive_sim': 0.9873969799280168, 'loss': -...","{'env': {'game': 'pong', 'type': 'atari', 'hor...",njf5gwab


### Filter based on conditions

In [3]:
group_name = 'baseline'
exp_name = 'barlow'
model_path = '0/10/model.pth'

In [4]:
configs = data_['config']
indexs = []
for idx, config in enumerate(configs):
    cfg = eval(config)
    if len(cfg) == 0:
        continue
        
    run_exp_name = cfg['exp_name']
    run_group_name = cfg['group_name']

    # condition
    if run_exp_name == exp_name and run_group_name == group_name:
        if 'env' in cfg:
            indexs.append(idx)

In [5]:
data = data_.iloc[indexs]
envs = []
for config in data['config']:
    envs.append(eval(config)['env']['game'])
data['env'] = envs
print(len(data))
data

26


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data['env'] = envs


Unnamed: 0,summary,config,id,env
257,"{'_wandb': {'runtime': 11135}, 'neg_sim': 0.35...","{'env': {'game': 'UpNDown', 'seed': 42, 'type'...",23bzj0u3,UpNDown
258,"{'_step': 23440, 'epoch': 10, '_timestamp': 16...","{'env': {'game': 'Seaquest', 'seed': 42, 'type...",35tf9avz,Seaquest
271,"{'loss': 28.35670840956948, 'act_f1': 0.097864...","{'env': {'game': 'PrivateEye', 'seed': 42, 'ty...",2qv9hfzc,PrivateEye
272,"{'_step': 23440, 'epoch': 10, 'min_grad_norm':...","{'env': {'game': 'Qbert', 'seed': 42, 'type': ...",b8zg4678,Qbert
273,"{'loss': 29.03515445535833, '_wandb': {'runtim...","{'env': {'game': 'RoadRunner', 'seed': 42, 'ty...",1dvsqbut,RoadRunner
274,"{'_wandb': {'runtime': 10596}, 'act_f1': 0.300...","{'env': {'game': 'Pong', 'seed': 42, 'type': '...",1v07rwbk,Pong
278,"{'act_f1': 0.11976189469097366, 'pos_neg_diff'...","{'env': {'game': 'Krull', 'seed': 42, 'type': ...",3sqp74f9,Krull
279,"{'pos_sim': 0.9595307210629636, '_runtime': 11...","{'env': {'game': 'KungFuMaster', 'seed': 42, '...",1kv7msnc,KungFuMaster
280,"{'act_f1': 0.1725325578512591, '_timestamp': 1...","{'env': {'game': 'MsPacman', 'seed': 42, 'type...",2c1lcc42,MsPacman
281,"{'_runtime': 10681, 'reward_ratio': 0.01780089...","{'env': {'game': 'Kangaroo', 'seed': 42, 'type...",1rjdts2y,Kangaroo


### 1. Restore Saved Models

In [6]:
import pathlib
base_path = str(pathlib.Path().resolve())

artifact_dict = {}
for run_id, env in zip(data['id'], data['env']):
    print(env, run_id)
    try:
        name = env + '/' + model_path 
        path = base_path + '/' + name
        wandb.restore(name, run_path="draftrec/atari_pretrain/" + run_id)
        artifact_dict[path] = name 
    except:
        continue

UpNDown 23bzj0u3
Seaquest 35tf9avz
PrivateEye 2qv9hfzc
Qbert b8zg4678
RoadRunner 1dvsqbut
Pong 1v07rwbk
Krull 3sqp74f9
KungFuMaster 1kv7msnc
MsPacman 2c1lcc42
Kangaroo 1rjdts2y
Jamesbond 1vbc0umc
Frostbite 27v43tpn
Hero 315xv24d
Gopher 32t0h2fw
ChopperCommand 2v562c8j
CrazyClimber 158f3nfd
Freeway 3vtdxqvx
DemonAttack 5kds5mdo
Boxing 2kccraz3
Breakout 2oyutopa
BattleZone t6joigzx
BankHeist 14ha5knr
Asterix bmyrjvkx
Alien e3iqpk8o
Assault 1akvdlvi
Amidar 5ukfpvp8


### Save to artifact

In [7]:
wandb.init(project='atari_pretrain',
           entity='draftrec',
           group=exp_name,
           settings=wandb.Settings(start_method="thread"))  
artifact = wandb.Artifact(name=exp_name, type='model')

# save models
for path, name in artifact_dict.items():
    artifact.add_file(path, name=name)

wandb.run.finish_artifact(artifact)

[34m[1mwandb[0m: Currently logged in as: [33mjoonleesky[0m ([33mdraftrec[0m). Use [1m`wandb login --relogin`[0m to force relogin
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "


<wandb.sdk.wandb_artifacts.Artifact at 0x7fa9214e4610>

### Remove artifact

In [8]:
import shutil
for env in data['env']:
    shutil.rmtree('./'+env, ignore_errors=True)

### 2. Generate json file

In [None]:
def get_results_dict(data):
    results = []
    for idx in range(len(data)):
        row = data.iloc[idx]
        summary = eval(row['summary'])
        config = eval(row['config'])

        if 'env' not in config:
            continue

        game = config['env']['game']
        try:
            import pdb
            pdb.set_trace()
            score = summary['eval_mean_traj_game_scores']
        except:
            continue
        results.append([0, game, score, 0])

    return results
    
results = get_results_dict(data)

In [None]:
data

In [18]:
import json

def generate_json_file(file_name, results):
    data = {}
    data['data'] = results
    path = file_name + '.json'
    with open(path, 'w') as json_file:
        json.dump(data, json_file)

In [19]:
generate_json_file('byol_impala', results)