In [7]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [8]:
import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import wandb

from aux_task_discovery.utils.constants import WANDB_PROJECT, WANDB_ENTITY

##### Set the sweep ID and the agent hyperparameter whose values will be compared

In [9]:
# ID of the W&B sweep to analyse; joined with WANDB_ENTITY and WANDB_PROJECT to
# form the sweep path, and used in the name of the saved results file.
SWEEP_ID = "iegwo7wi"
# Key looked up in each run's config['agent_args']; runs are grouped by its value.
PARAM = "learning_rate"

##### Query W&B API for sweep runs

In [10]:
# Look the sweep up through the W&B public API using its
# "<entity>/<project>/<sweep id>" path.
api = wandb.Api()
sweep_path = "/".join((WANDB_ENTITY, WANDB_PROJECT, SWEEP_ID))
sweep = api.sweep(sweep_path)
# Collection of runs that belong to the sweep.
sweep_runs = sweep.runs

##### For each hyperparameter value, calculate the mean episode length using the last 10% of episodes from each run with that value


In [11]:
# Fraction of each run's final logged episodes used to score that run.
TAIL_FRACTION = 0.1

# All runs are expected to share the same episode budget; take it from the first run.
max_episodes = sweep_runs[0].config['max_episodes']
n_episode_comp = round(max_episodes * TAIL_FRACTION)
# A zero-length tail would silently yield NaN means below, so fail loudly instead.
assert n_episode_comp > 0, "max_episodes is too small to select any trailing episodes"

# Per-run tail means, grouped by the value of PARAM that the run used.
run_means_by_param = defaultdict(list)
for run in sweep_runs:
    assert run.config['max_episodes'] == max_episodes, "Max episodes must be the same for all runs"
    run_param = run.config['agent_args'][PARAM]
    # One row per logged history record that contains both requested keys.
    data = pd.DataFrame(list(run.scan_history(keys=['episode_len', 'episode'])))
    # An empty frame has no 'episode_len' column; report which run is the problem.
    assert not data.empty, f"Run {run.id} has no logged episode history"
    run_means_by_param[run_param].append(data.tail(n_episode_comp)['episode_len'].mean())

# Every hyperparameter value must be backed by the same number of runs, otherwise
# the averages are not comparable. Derive the count from the grouped data itself
# rather than from whichever run happened to be iterated last.
run_counts = {len(run_means) for run_means in run_means_by_param.values()}
assert len(run_counts) == 1, "Number of runs must be the same for all param values"
n_runs = run_counts.pop()

# Final summary: hyperparameter value -> mean of its runs' tail means.
# Built as a new plain dict so the per-run lists are not overwritten in place.
mean_episode_lens = {
    key: np.mean(run_means) for key, run_means in run_means_by_param.items()
}

print(f"Mean episode length for last {n_episode_comp} episodes averaged across {n_runs} runs")
for key, val in mean_episode_lens.items():
    print(f"{PARAM} = {key}: {val}")



Mean episode length for last 20 episodes averaged across 10 runs
learning_rate = 0.04: 671.605
learning_rate = 0.01: 296.69500000000005
learning_rate = 0.0025: 77.77000000000001
learning_rate = 0.000625: 273.31999999999994


##### Save results

In [12]:
# Persist the sweep configuration and the per-value summary as a plain-text report
# named after the sweep ID and the compared hyperparameter.
results_dir = Path("../sweep_results")
# Create the output directory if it is missing so the cell also works on a fresh checkout.
results_dir.mkdir(parents=True, exist_ok=True)
results_path = results_dir / f"{SWEEP_ID}_{PARAM}_episode_lens.txt"

with open(results_path, "w", encoding="utf-8") as f:
    f.write(f"Sweep ID: {SWEEP_ID}\n")
    f.write("Sweep Config:\n")
    # Pretty-print the sweep configuration so the report is self-describing.
    json.dump(sweep.config, f, indent=4)
    f.write(f"\n\nMean episode length for last {n_episode_comp} episodes averaged across {n_runs} runs:\n")
    for key, val in mean_episode_lens.items():
        f.write(f"{PARAM} = {key}: {val}\n")