# SAC-SVG(H): POPLIN Sweep and Experiments

This notebook reproduces Table 2 of our [SAC-SVG(H) paper](https://arxiv.org/abs/2008.12775)
that considers the OpenAI gym locomotion tasks used
in POPLIN (Wang and Ba, 2019).
After setting up the code and a multitask environment with
hydra, and uncommenting the `hydra/sweeper` line in `config/train.yaml`,
you can launch our hyper-parameter search with:

```bash
./train.py -m experiment=full_poplin_sweep
```

Then you can use this notebook as a starting point to analyze the
progress and results of these experiments.

In [1]:
# Load IPython's autoreload extension (configured further down).
%load_ext autoreload

from collections import defaultdict
from pprint import pprint

import matplotlib.pyplot as plt
plt.style.use('bmh')

# Mode 2: reload all modules before executing each cell, so edits to
# `svg.analysis` are picked up without restarting the kernel.
%autoreload 2
from svg.analysis import sweep_summary, plot_agg, plot_ablation

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
# Load the sweep stored under `d` with `sweep_summary`.
# NOTE(review): the next cell reassigns `d`, `all_summary`, `groups`, `agg`
# and `configs` from a different sweep directory, so everything loaded here
# is overwritten before any later cell reads it.
d = '../exp/2021.03.16/1542_test/'
all_summary, groups, agg, configs = sweep_summary(d)

In [3]:
# Read the sweep stored under `d` with `sweep_summary`, then preview the
# first per-run rows and the aggregated table.
# NOTE(review): the recorded output of this cell lists `mbpo_*` environments
# and no `agent.horizon` column, while the next cell looks up `gym_*` /
# `pets_cheetah` rows and an `agent.horizon` column -- confirm that this is
# the intended sweep directory.
d = '../exp/2021.04.17/2048_mbpo/'
all_summary, groups, agg, configs = sweep_summary(d)

display(all_summary.head(), agg)

Unnamed: 0,env,seed,best_eval_rew,last_eval_rew,d,env_name
0,mbpo_cheetah,1,16490.721131,16211.209097,../exp/2021.04.17/2048_mbpo/0/,mbpo_cheetah
1,mbpo_cheetah,3,15814.255879,15037.872012,../exp/2021.04.17/2048_mbpo/2/,mbpo_cheetah
2,mbpo_cheetah,10,16241.069622,15590.25518,../exp/2021.04.17/2048_mbpo/9/,mbpo_cheetah
3,mbpo_hopper,2,4086.168831,2401.692851,../exp/2021.04.17/2048_mbpo/11/,mbpo_hopper
4,mbpo_cheetah,7,14560.622332,14172.839787,../exp/2021.04.17/2048_mbpo/6/,mbpo_cheetah


Unnamed: 0_level_0,Unnamed: 1_level_0,best_eval_rew,best_eval_rew,last_eval_rew,last_eval_rew
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,mean,std
env_name,env,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
mbpo_ant,mbpo_ant,8288.696077,1183.103427,8163.55002,1312.316999
mbpo_cheetah,mbpo_cheetah,14571.224437,2871.214498,13221.85443,3860.682656
mbpo_hopper,mbpo_hopper,3917.269708,169.12088,3642.448877,531.034875
mbpo_humanoid,mbpo_humanoid,9626.282362,774.896734,8062.06507,2809.256164
mbpo_walker2d,mbpo_walker2d,6028.896575,307.840072,5753.337798,647.904448


In [4]:
# Start by finding the best hyper-parameters and printing
# out the LaTeX-formatted results table.
#
# For every (environment, horizon) pair this picks the aggregated row with
# the highest mean `last_eval_rew`, records that row's entropy-schedule
# settings in `best_hypers`, and remembers in `best_horizons` the horizon > 0
# with the best mean for each environment. Both dictionaries are read by the
# plotting cell below.

env_raws = ['gym_ant', 'gym_hopper', 'gym_fswimmer', 'gym_cheetah', 'gym_walker2d', 'pets_cheetah']
agg_flat = agg.reset_index()

# Fail with an actionable message instead of a bare KeyError when the loaded
# sweep was not run over `agent.horizon`.
if 'agent.horizon' not in agg_flat.columns:
    raise KeyError(
        "'agent.horizon' is not a column of the aggregated sweep results; "
        "this cell needs a sweep over agent.horizon (check the experiment "
        "directory passed to sweep_summary)"
    )

# Sort the horizons numerically through `int`.
horizons = sorted(agg_flat['agent.horizon'].unique(), key=int)

# Settings recorded for the best row of every (env, horizon) pair.
hyper_keys = (
    'learn_temp.init_targ_entr',
    'learn_temp.final_targ_entr',
    'learn_temp.entr_decay_factor',
)

horizon_table = defaultdict(list)  # horizon (or 'best') -> one LaTeX cell per env
best_hypers = {}                   # (env, horizon) -> {setting name: value}
best_horizons = {}                 # env -> best horizon among those > 0
for env in env_raws:
    # The environment filter does not depend on the horizon: compute it once.
    agg_env_all = agg_flat[agg_flat.env_name == env]
    best_val = best_s = None
    for horizon in horizons:
        # Build the mask from the frame that is being indexed. A mask computed
        # on the full `agg_flat` has a different index, which makes pandas
        # reindex the boolean key and emit a UserWarning.
        agg_env = agg_env_all[agg_env_all['agent.horizon'] == horizon]
        best_i = agg_env.last_eval_rew['mean'].idxmax()
        best_rew = agg_env.last_eval_rew.loc[best_i]
        s = f"{best_rew['mean']:.2f} $\\pm$ {best_rew['std']:.2f}"
        best_df = agg_env.loc[best_i]
        horizon_table[horizon].append(s)

        # Only horizons > 0 compete for the per-environment "best" row.
        if int(horizon) > 0:
            if best_val is None or best_rew['mean'] > best_val:
                best_s = s
                best_val = best_rew['mean']
                best_horizons[env] = horizon

        best_hypers[(env, horizon)] = {
            key: best_df[key].values[0] for key in hyper_keys
        }
    horizon_table['best'].append(best_s)

s = 'SAC-SVG & ' + ' & '.join(horizon_table['best']) + r' \\'
print(s)
for horizon in horizons:
    # NOTE(review): the row label is the configured horizon minus one, so a
    # horizon of 0 would be printed as SAC-SVG(-1) -- confirm this offset is
    # intended.
    s = f'SAC-SVG({int(horizon)-1}) & '
    s += ' & '.join(horizon_table[horizon])
    s += r' \\'
    print(s)

KeyError: 'agent.horizon'

In [None]:
# Produce the plots of the experiments
#
# For each environment, collect the run directories of the best horizon found
# above and of horizon '0' (both restricted to their best hyper-parameters)
# and hand them to `plot_ablation`, which saves one PDF per environment.

env_raws = ['gym_ant', 'gym_hopper', 'gym_fswimmer', 'gym_cheetah', 'gym_walker2d', 'pets_cheetah']
env_pretties = ['Ant', 'Hopper', 'Swimmer', 'Cheetah', 'Walker2d', 'PETS Cheetah']
# One value per environment, passed as `axhline` to `plot_ablation`.
# NOTE(review): presumably the returns reported by POPLIN -- confirm the source.
poplin_lims = [2330, 2055, 334, 4235, 597, 12227.9]
# Every (env, horizon, hyper-parameter) selection is expected to match exactly
# this many runs.
n_expected_runs = 10

table_data = []
for env, env_pretty, poplin_lim in zip(env_raws, env_pretties, poplin_lims):
    # Use a dedicated name so the `groups` value returned by `sweep_summary`
    # in the loading cell is not overwritten.
    plot_groups = []
    print(f'\n=== {env}')
    for horizon in [best_horizons[env], '0']:
        print(f'--- H={horizon}')
        hypers = best_hypers[(env, horizon)]
        mask = (all_summary['env_name'] == env) & (all_summary['agent.horizon'] == horizon)
        for k, val in hypers.items():
            mask = mask & (all_summary[k] == val)
        pprint(hypers)
        runs = all_summary[mask]
        assert len(runs) == n_expected_runs, (
            f'expected {n_expected_runs} runs for env={env}, horizon={horizon}, '
            f'found {len(runs)}'
        )

        plot_groups.append({
            'roots': runs.d.values,
            'tag': str(horizon),
        })

    plot_ablation(
        groups=plot_groups,
        save=f'../data/fig/poplin_{env}.pdf',
        title=env_pretty,
        # xmax=5e5,
        legend=False,
        axhline=poplin_lim,
    )