In [58]:
import os
import pandas as pd
import numpy as np
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt

# (environment id, timestep budget) pairs.
# The list order is the order in which per-environment result files are read
# further down; the budget is the exclusive upper bound on `n_train_timesteps`
# applied when the max-reward table is built.
ENVS = [
    ("CartPole-v1", 5e5),
    ("LunarLander-v2", 5e5),
    ("Swimmer-v4", 5e5),
    ("HalfCheetah-v4", 3e6),
    ("Boxing-v5", 1e8),
    ("SpaceInvaders-v5", 1e8),
    # --
    ("Acrobot-v1", 5e5),
    ("Pendulum-v1", 5e5),
    ("BipedalWalker-v3", 2e6),
    # --
    ("Hopper-v4", 1e6),
    ("Walker2d-v4", 2e6),
    ("Ant-v4", 1e7),
    ("Humanoid-v4", 1e7),
    # --
    ("Atlantis-v5", 2e7),
    ("BeamRider-v5", 2e7),
    ("Pong-v5", 2e7),
    ("CrazyClimber-v5", 2e7),
    ("Enduro-v5", 2e7),
    # --
    ("Qbert-v5", 2e7),
    ("Seaquest-v5", 2e7),
]


# One row per environment: (environment id, <float>, <int>, reward threshold).
# The code visible in this notebook reads only the first and the last field
# (the threshold a test reward must reach for the min-timesteps table).
# NOTE(review): the two middle fields are not consumed here — presumably
# per-environment hyper-parameters; confirm against their actual consumer.
ALL_GYM_ENVS = [
    ("CartPole-v1", 0.1, 4, 475),
    ("Acrobot-v1", 0.05, 4, -100),
    ("Pendulum-v1", 0.1, 32, -100),
    ("LunarLander-v2", 0.1, 32, 200),
    ("BipedalWalker-v3", 0.1, 48, 300),
    ("Swimmer-v4", 0.1, 4, 360),
    ("HalfCheetah-v4", 0.05, 17, 4800),
    ("Hopper-v4", 0.05, 32, 3000),
    ("Walker2d-v4", 0.05, 51, 3000),
    ("Ant-v4", 0.05, 108, 5000),
    ("Humanoid-v4", 0.01, 128, 6000),
]

def load_data(env_folder):
    """Load every run's ``stats.csv`` below ``env_folder`` into one DataFrame.

    The directory layout walked is ``env_folder/<strat>/<run>/stats.csv``.
    Entries that are not directories are ignored at both levels. The strategy
    directory name is parsed for its settings: the text after ``sigma-`` must
    read ``<float>-lambda-<int>``, and the text before ``-norm-`` (or the whole
    name when that marker is absent) is taken as the method name.

    Columns added to each run's rows:

    * ``run``           -- 0-based index of the run directory within its
                           strategy, counted in sorted name order over run
                           directories only (stray files do not consume an
                           index, and the numbering is reproducible across
                           file systems).
    * ``folder``        -- name of the run directory.
    * ``strat``         -- name of the strategy directory.
    * ``method``        -- ``strat`` up to the first ``-norm-``.
    * ``sigma0``        -- float parsed from ``strat``.
    * ``lambda``        -- int parsed from ``strat``.
    * ``train``         -- row-wise max of ``best`` and ``current``.
    * ``expected_test`` -- row-wise max of ``best_median`` and
                           ``current_median``.

    Parameters
    ----------
    env_folder : str or path-like
        Root folder of one environment's results.

    Returns
    -------
    pandas.DataFrame
        All runs concatenated with a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If no run directory is found below ``env_folder``, or a strategy
        directory name does not follow the expected pattern.
    FileNotFoundError
        If ``env_folder`` does not exist or a run directory has no
        ``stats.csv``.
    """
    stats_list = []
    # sorted(): os.listdir order is arbitrary, which would make `run` ids
    # differ between machines.
    for strat in sorted(os.listdir(env_folder)):
        strat_path = os.path.join(env_folder, strat)
        if not os.path.isdir(strat_path):
            continue

        run_names = [
            name
            for name in sorted(os.listdir(strat_path))
            if os.path.isdir(os.path.join(strat_path, name))
        ]
        if not run_names:
            # Nothing to load; skip before parsing the name so an empty,
            # oddly named directory stays harmless.
            continue

        # The settings are constant for the whole strategy, so parse them once.
        sig, lamb = strat.split("sigma-")[1].split("-lambda-")
        method = strat.split("-norm-")[0]
        sigma0 = float(sig)
        lambda_ = int(lamb)

        for i, run in enumerate(run_names):
            stats = pd.read_csv(
                os.path.join(strat_path, run, "stats.csv"),
                skipinitialspace=True,
            )
            stats['run'] = i
            stats['folder'] = run
            stats['strat'] = strat
            stats['method'] = method
            stats['sigma0'] = sigma0
            stats['lambda'] = lambda_
            stats['train'] = stats[['best', 'current']].max(axis=1)
            stats['expected_test'] = stats[['best_median', 'current_median']].max(axis=1)
            stats_list.append(stats)

    if not stats_list:
        # pd.concat([]) raises an opaque "No objects to concatenate"; keep the
        # exception type but say which folder was empty.
        raise ValueError(f"no run directories found under {env_folder!r}")
    return pd.concat(stats_list, ignore_index=True)


In [62]:
# Results of the RL baselines, stored as a single pickled frame.
rl_data = pd.read_pickle("data/rl_data2.pkl")

# Columns kept from every per-environment ES result frame.
es_columns = [
    "method", "run", "n_train_episodes", "n_train_timesteps",
    "test", "train", "lambda", "sigma0",
]

# Collect the per-environment frames and concatenate once at the end rather
# than growing a DataFrame inside the loop (which re-copies all previous rows
# on every iteration and starts from an empty frame).
es_frames = []
for env_name, _budget in ENVS:  # the budget is not needed while loading
    if env_name.endswith("v5"):
        # Environments whose id ends in "v5" are stored as data.pkl;
        # their names are echoed as they are read.
        env_data = pd.read_pickle(f"data/{env_name}/data.pkl")
        print(env_name)
    else:
        env_data = pd.read_pickle(f"data/{env_name}/data_hyp.pkl")
    # .assign returns a new frame, so the `env` column is never written into
    # a slice of the loaded frame (no SettingWithCopyWarning).
    es_frames.append(env_data[es_columns].assign(env=env_name))

es_data = pd.concat(es_frames) if es_frames else pd.DataFrame()

# Single table holding both ES and RL results, with a fresh 0..n-1 index.
data = pd.concat([es_data, rl_data]).reset_index(drop=True)

Boxing-v5
SpaceInvaders-v5
Atlantis-v5
BeamRider-v5
Pong-v5
CrazyClimber-v5
Enduro-v5
Qbert-v5
Seaquest-v5


In [95]:
# Column headers of the LaTeX tables and, position by position, the `method`
# value in `data` that fills each column.
titles = "CSA-ES", "CMA-ES", "sep-CMA-ES", "DQN", "PPO", "SAC", "DQN*", "PPO*", "SAC*",
keys = 'csa', 'cma-es', 'sep-cma-es', 'dqn_large', 'ppo_large', 'sac_large', 'dqn_small', 'ppo_small', 'sac_small',


def _latex_sci(value):
    r"""Format ``value`` with one significant digit as LaTeX, e.g. 5e5 -> ``$5\cdot 10^5$``.

    Only the ``e+0`` marker is rewritten, so exponents outside 0..9 are left
    in plain ``e`` notation.
    """
    # Raw string: "\c" is not a valid escape sequence in a normal literal.
    return f"${value:.0e}$".replace("e+0", r"\cdot 10^")


def _reward_cell(value):
    """Render a mean max reward; ' - ' only when the method has no entry.

    Tested with ``is not None`` so that a legitimate reward of 0 is printed
    instead of being mistaken for missing data.
    """
    return ' - ' if value is None else str(value)


def _timesteps_cell(value):
    r"""Render the timesteps needed to reach the threshold.

    ``None`` (method absent for this environment) -> ' - ';
    ``0`` (method present but never reached the threshold) -> ``$\infty$``;
    anything else -> LaTeX scientific notation.
    """
    if value is None:
        return ' - '
    if value == 0:
        return r"$\infty$"
    return _latex_sci(value)


row_end = ' \\\\ \n'
max_reward_table = ' & Timesteps & ' + ' & '.join(titles) + row_end
min_timesteps_table = ' & Threshold & ' + ' & '.join(titles) + row_end

# Timestep budget per environment, built once instead of once per iteration.
budgets = dict(ENVS)

for env_name, *_, threshold in ALL_GYM_ENVS:
    env_data = data[data.env == env_name]
    max_time = budgets[env_name]

    # Best test reward per (method, run) within the budget, averaged over the
    # runs of each method and truncated to int.
    # NOTE(review): the positional "test" passed to .mean() is received as the
    # method's first parameter rather than as a column selector; it is kept
    # unchanged to preserve behavior — confirm and drop it if `test` is numeric.
    within_budget = env_data[env_data.n_train_timesteps < max_time]
    max_reward_dict = dict(
        within_budget.groupby(["method", "run"])['test'].max()
        .groupby("method").mean("test").astype(int)
    )

    # Earliest timestep at which any run of a method reached the threshold.
    # Methods present for this environment start at 0, the "never reached"
    # marker that _timesteps_cell turns into \infty.
    reached = env_data[env_data.test >= threshold]
    min_time_dict = dict.fromkeys(set(env_data.method), 0)
    min_time_dict.update(dict(
        reached.groupby(["method", "run"])['n_train_timesteps'].min()
        .groupby("method").min().astype(int)
    ))

    max_reward_table += (
        env_name + ' & ' + _latex_sci(max_time) + ' & '
        + ' & '.join(_reward_cell(max_reward_dict.get(k)) for k in keys)
        + row_end
    )
    min_timesteps_table += (
        env_name + ' & ' + f'{threshold} & '
        + ' & '.join(_timesteps_cell(min_time_dict.get(k)) for k in keys)
        + row_end
    )

print(min_timesteps_table)

 & Threshold & CSA-ES & CMA-ES & sep-CMA-ES & DQN & PPO & SAC & DQN* & PPO* & SAC* \\ 
CartPole-v1 & 475 & $3\cdot 10^3$ & $2\cdot 10^3$ & $3\cdot 10^3$ & $2\cdot 10^4$ & $6\cdot 10^4$ &  -  & $\infty$ & $\infty$ &  -  \\ 
Acrobot-v1 & -100 & $4\cdot 10^3$ & $5\cdot 10^3$ & $4\cdot 10^3$ & $2\cdot 10^4$ & $7\cdot 10^4$ &  -  & $1\cdot 10^5$ & $\infty$ &  -  \\ 
Pendulum-v1 & -100 & $\infty$ & $\infty$ & $\infty$ &  -  & $\infty$ & $\infty$ &  -  & $\infty$ & $\infty$ \\ 
LunarLander-v2 & 200 & $5\cdot 10^4$ & $7\cdot 10^4$ & $6\cdot 10^4$ & $3\cdot 10^5$ & $4\cdot 10^5$ &  -  & $\infty$ & $\infty$ &  -  \\ 
BipedalWalker-v3 & 300 & $2\cdot 10^6$ & $2\cdot 10^6$ & $5\cdot 10^6$ &  -  & $\infty$ & $2\cdot 10^5$ &  -  & $\infty$ & $\infty$ \\ 
Swimmer-v4 & 360 & $4\cdot 10^5$ & $3\cdot 10^5$ & $7\cdot 10^5$ &  -  & $\infty$ & $\infty$ &  -  & $\infty$ & $\infty$ \\ 
HalfCheetah-v4 & 4800 & $2\cdot 10^6$ & $1\cdot 10^6$ & $2\cdot 10^6$ &  -  & $\infty$ & $5\cdot 10^4$ &  -  & $\infty$ & $\

In [79]:


# .max("test")["test"]

method
cma-es         1642
csa            3041
dqn_large     21326
ppo_large     64820
sep-cma-es     2751
Name: n_train_timesteps, dtype: object