# Loading, evaluating, and visualizing runs

## Visualization imports

In [1]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

import numpy as np
import pandas as pd
import pybullet as p
import matplotlib.pyplot as plt
from IPython.display import Video

## Helper functions (env wrappers, model builders, video display)

In [2]:
import gym
import torch.nn as nn
from gym import ObservationWrapper
from gym.spaces import flatten_space
from gym.wrappers import FilterObservation
from rrc.env import initializers, cube_env, make_env
from rrc.env.reward_fns import *
from rrc.env.wrappers import MonitorPyBulletWrapper, ResidualPDWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.torch_layers import CombinedExtractor
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3 import HerReplayBuffer, SAC, TD3, PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize


class FlattenGoalObs(ObservationWrapper):
    """Keep only `observation_keys` of a Dict observation and flatten each entry.

    The wrapped observation space is a Dict whose values are
    `gym.spaces.flatten_space(original_subspace)`; `observation` produces the
    matching values with `gym.spaces.flatten`, so nested Dict entries become one
    vector and every value matches the dtype/shape the space declares.

    Args:
        env: environment with a Dict observation space.
        observation_keys: keys of the original observation to keep.
    """

    def __init__(self, env, observation_keys):
        super().__init__(env)
        obs_space = self.env.observation_space
        self.observation_space = gym.spaces.Dict(
            {k: flatten_space(obs_space[k]) for k in observation_keys})

    def observation(self, obs):
        # Flatten with the ORIGINAL sub-space: this is the counterpart of
        # flatten_space, so scalars, multi-dim arrays and nested Dicts all end up
        # as the 1-D vector declared in self.observation_space (the previous
        # manual concatenation left non-dict values unflattened and failed on
        # 0-d entries).
        return {k: gym.spaces.flatten(self.env.observation_space[k], obs[k])
                for k in self.observation_space.spaces}


class HERCombinedExtractor(CombinedExtractor):
    """
    Combined features extractor that only feeds pre-specified `observation_keys`
    to the policy, while every key stays in the environment's observation so it
    can still be stored in the (HER) replay buffer.

    Each selected key gets an `nn.Flatten` extractor; the flattened tensors are
    presumably concatenated by the inherited CombinedExtractor.forward, so
    features_dim is the sum of the flattened sizes.

    Args:
        observation_space: Dict observation space of the environment.
        cnn_output_dim: unused here; accepted to mirror CombinedExtractor's signature.
        observation_keys: sequence of keys of `observation_space` to use as features.
            Must be non-empty; duplicate keys are ignored.

    Raises:
        ValueError: if `observation_keys` is empty.
        KeyError: if a key is not present in `observation_space`.
    """

    def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256, observation_keys=()):
        # Deduplicate while preserving order: a repeated key would otherwise be
        # counted twice in features_dim but get only one ModuleDict entry.
        keys = list(dict.fromkeys(observation_keys))
        if not keys:
            raise ValueError('HERCombinedExtractor needs at least one observation key')
        missing = [k for k in keys if k not in observation_space.spaces]
        if missing:
            raise KeyError('observation_keys {} not in observation space keys {}'.format(
                missing, list(observation_space.spaces)))

        # Intentionally skip CombinedExtractor.__init__ (it would build an extractor
        # for every key) and call the next initializer in the MRO instead. The real
        # features_dim is only known after the loop, so 1 is a placeholder.
        super(CombinedExtractor, self).__init__(observation_space, features_dim=1)

        # Every selected key is treated as a vector and simply flattened.
        self.extractors = nn.ModuleDict({key: nn.Flatten() for key in keys})

        # Update the features dim manually
        self._features_dim = sum(get_flattened_obs_dim(observation_space.spaces[key]) for key in keys)



def make_model(ep_len, lr, exp_dir=None, env=None, use_goal=True,
               use_sde=False, log_std_init=-3, load_path=None,
               residual=False):
    """Build an SAC model with a HER replay buffer, optionally restoring a checkpoint.

    Args:
        ep_len: forwarded to HerReplayBuffer as `max_episode_length`.
        lr: learning rate.
        exp_dir: currently unused (tensorboard logging is commented out).
        env: goal-conditioned environment (Dict observation space).
        use_goal: if True, features use 'desired_goal', 'achieved_goal' and
            'observation'; otherwise only 'observation'.
        use_sde: enable gSDE (noise resampled every 64 steps, not during warm-up).
        log_std_init: forwarded to the policy as `policy_kwargs['log_std_init']`.
        load_path: optional checkpoint .zip, or a directory containing best_model.zip.
        residual: forwarded to the SAC constructor. NOTE(review): stock
            stable_baselines3 SAC has no `residual` argument; presumably a project fork.

    Returns:
        The SAC model; when `load_path` is given, the model loaded from that file.
    """
    # Local import: the cell importing `osp` comes after the cell defining this function.
    import os.path as osp

    if use_goal:
        obs_keys = ['desired_goal', 'achieved_goal', 'observation']
    else:
        obs_keys = ['observation']

    policy_kwargs = dict(
                    log_std_init=log_std_init,
                    features_extractor_class=HERCombinedExtractor,
                    features_extractor_kwargs=dict(observation_keys=obs_keys))
    if use_sde:
        sde_kwargs = dict(
                use_sde=True,
                use_sde_at_warmup=False,
                sde_sample_freq=64)
    else:
        sde_kwargs = {}

    rb_kwargs = dict(n_sampled_goal=4,
                     goal_selection_strategy='future',
                     online_sampling=False,
                     max_episode_length=ep_len)

    model = SAC('MultiInputPolicy', env,
                # tensorboard_log=exp_dir,
                replay_buffer_class=HerReplayBuffer,
                # Parameters for HER
                replay_buffer_kwargs=rb_kwargs,
                policy_kwargs=policy_kwargs,
                verbose=1, buffer_size=int(1e6),
                learning_starts=20000,
                learning_rate=lr,
                gamma=0.99, batch_size=256, residual=residual, **sde_kwargs)
    if load_path is not None:
        if osp.isdir(load_path):
            load_path = osp.join(load_path, 'best_model.zip')
        # `load` is a classmethod that returns a NEW model. Its result used to be
        # discarded, so the checkpoint was silently never used.
        model = SAC.load(load_path, env=env)
    return model



def make_ppo_model(ep_len, lr, exp_dir=None, env=None, use_goal=True,
                   use_sde=True, dry_run=False, log_std_init=-3):
    """Build a PPO model whose features come from selected Dict observation keys.

    Args:
        ep_len: unused; kept for signature parity with `make_model`.
        lr: learning rate.
        exp_dir: tensorboard log directory; ignored when `dry_run` is True.
        env: environment with a Dict observation space (e.g. wrapped with FlattenGoalObs).
        use_goal: if True, features use 'desired_goal' and 'observation';
            otherwise only 'observation'.
        use_sde: enable gSDE, resampling the exploration noise every 4 steps.
        dry_run: if True, no tensorboard logs are written.
        log_std_init: forwarded to the policy as `policy_kwargs['log_std_init']`
            (default -3, the previously hard-coded value).

    Returns:
        A freshly constructed (untrained) PPO model.
    """
    if use_goal:
        obs_keys = ['desired_goal', 'observation']
    else:
        obs_keys = ['observation']

    policy_kwargs = dict(
                    log_std_init=log_std_init,
                    features_extractor_class=HERCombinedExtractor,
                    features_extractor_kwargs=dict(observation_keys=obs_keys))
    if use_sde:
        sde_kwargs = dict(
                use_sde=True,
                sde_sample_freq=4)
    else:
        sde_kwargs = {}
    # A dry run must not write logs (this condition used to be inverted).
    tensorboard_log = None if dry_run else exp_dir
    # HERCombinedExtractor indexes a Dict observation space. SB3 handles Dict
    # spaces only via 'MultiInputPolicy' ('MlpPolicy' is rejected); this matches make_model.
    model = PPO('MultiInputPolicy', env,
                tensorboard_log=tensorboard_log,
                policy_kwargs=policy_kwargs,
                verbose=1,
                learning_rate=lr,
                n_steps=1000,
                gamma=0.99, batch_size=250, **sde_kwargs)
    return model



wandb_root = '/scr-ssd/ksrini/spinningup/notebooks'


def get_save_path(run):
    """Map a wandb run to its local save directory.

    Drops the first component of `run.config['exp_dir']` and re-roots the rest
    under `wandb_root`.
    """
    rel_parts = run.config['exp_dir'].split('/')[1:]
    return '/'.join([wandb_root] + rel_parts)


def display_video(path=None, run=None, video_name=None):
    """Embed an mp4 in the notebook output (640px wide).

    Args:
        path: path to a video file, or to a directory when `video_name` is given.
        run: optional wandb run; if given, `path` is replaced by the run's save
            directory from `get_save_path`.
        video_name: optional file name inside the `videos/` subdirectory of that
            path, e.g. 'sim-20.mp4'.

    Returns:
        An IPython.display.Video with the file embedded.

    Raises:
        ValueError: if neither `path` nor `run` is given.
    """
    import os.path as osp

    if run is not None:
        # get_save_path takes the run object itself (it reads run.config);
        # passing run.config['exp_dir'] here used to raise AttributeError.
        path = get_save_path(run)
    if path is None:
        raise ValueError('display_video needs a `path` or a `run`')
    if video_name is not None:
        path = osp.join(path, 'videos', video_name)
    return Video(path, embed=True, width=640)

## Visualizing a trajectory (mp4 video)

In [4]:
# Same as Video(path, embed=True, width=640), via the helper defined above.
display_video('/scr-ssd/ksrini/spinningup/notebooks/videos/sim-22.mp4')

## Loading a Weights & Biases (wandb) run

In [3]:
import os
import wandb
import pandas as pd
import numpy as np
import os.path as osp
from trifinger_simulation.tasks.move_cube import Pose
import json

# wandb API client; used by run_history and the run-loading cells below.
api = wandb.Api()

In [9]:
def run_history(run_id, proj='cvxrl', keys=None, samples=1000, entity='krshna'):
    """Fetch the logged metric history of a wandb run.

    Args:
        run_id: wandb run id, or '<project>/<run_id>' to override `proj`.
        proj: wandb project, used when `run_id` has no project prefix.
        keys: optional metric keys. When given, only these are requested and a
            DataFrame is returned.
        samples: number of history rows to sample.
        entity: wandb entity that owns the project.

    Returns:
        With `keys`: a pandas DataFrame built from the sampled rows.
        Without `keys`: a dict mapping every column whose name contains 'train',
        'episodes' or 'rollout' to its non-NaN values. Also includes
        'global_step', taken at the rows where the last such column is non-NaN.

    Raises:
        ValueError: without `keys`, if the run logged no matching column.
    """
    if len(run_id.split('/')) == 2:
        proj, run_id = run_id.split('/')
    wandb_run = api.run('{}/{}/{}'.format(entity, proj, run_id))
    if keys:
        data = wandb_run.history(samples=samples, keys=keys, pandas=False)
        while len(data) == 1:
            data = data[0]
        return pd.DataFrame(data)

    history = wandb_run.history(samples=samples, keys=keys)
    keep_cols = ([c for c in history.columns if 'train' in c]
                 + [c for c in history.columns if 'episodes' in c]
                 + [c for c in history.columns if 'rollout' in c])
    if not keep_cols:
        raise ValueError('run {} has no train/episodes/rollout metrics'.format(run_id))

    # Metrics are logged at different steps, so each keeps only its own non-NaN values.
    new_df = {c: history[c][~history[c].isna()].values for c in keep_cols}
    # global_step is aligned with the rows of the last kept column. This used to
    # depend implicitly on the loop variable leaking out of the loop.
    ref_col = keep_cols[-1]
    new_df['global_step'] = history['global_step'][~history[ref_col].isna()].values
    return new_df

In [13]:
# NOTE(review): `hist` is created by the `run_history(run_id)` cell further down
# (In [10]); in document order on a fresh kernel this cell raises NameError.
hist.keys()

dict_keys(['train/actor_loss', 'train/critic_loss', 'train/ent_coef', 'train/ent_coef_loss', 'train/learning_rate', 'train/n_updates', 'time/episodes', 'rollout/ep_len_mean', 'rollout/ep_ori_err_mean', 'rollout/ep_pos_err_mean', 'rollout/ep_rew_mean', 'global_step'])

## PPO run

In [4]:
# PPO run to inspect.
run_id = '1tq7w30i'
run = api.run(f'krshna/cvxrl/{run_id}')

In [None]:
# Rebuild the PPO training environment and model for this run (this cell was never executed).
# NOTE(review): `make_env_cls` and `make_vec_env` are not defined or imported anywhere in this
# notebook (`make_vec_env` presumably comes from stable_baselines3.common.env_util), so this
# cell raises NameError on a fresh kernel until they are imported.
env_cls = make_env_cls(diff=run.config['diff'], episode_length=500, reward_fn='train2',
                       initializer='center', torque_factor=1., force_factor=1.)

# Wrap each env with ResidualPDWrapper, then keep (and flatten) only the goal-conditioned keys.
wrapper = lambda env: FlattenGoalObs(ResidualPDWrapper(env, force_factor=.1, torque_factor=.25), 
                                     observation_keys=['desired_goal', 'achieved_goal', 'observation'])

# 10 envs; the Monitor is asked to also record the 'ori_err' / 'pos_err' info entries.
env = make_vec_env(env_cls, n_envs=10, wrapper_class=wrapper,
        monitor_kwargs=dict(info_keywords=('ori_err', 'pos_err')))

# Positional args: ep_len=500, lr=3e-4, exp_dir=None, env=env, use_goal=True.
model = make_ppo_model(500, 3e-4, None, env, True)

## SAC run

In [5]:
# SAC run to inspect.
run_id = 'xcpyxowe'
run = api.run(f'krshna/cvxrl/{run_id}')

In [10]:
# Dict of metric name -> non-NaN values, plus aligned 'global_step'.
hist = run_history(run_id=run_id)

In [11]:
# Keep only the series that line up one-to-one with global_step.
df = pd.DataFrame({key: values for key, values in hist.items()
                   if len(values) == len(hist['global_step'])})
df.columns

In [None]:
# Keep the longest (most frequently logged) series and plot the rollout position and
# orientation errors. Rows with mean position error >= 1.0 are excluded, presumably
# to hide outliers.
max_len = max(len(values) for values in hist.values())  # computed once, not per key
df = pd.DataFrame({key: values for key, values in hist.items() if len(values) == max_len})
ax = df[df['rollout/ep_pos_err_mean'] < 1.].plot(
    x='global_step', y=['rollout/ep_pos_err_mean', 'rollout/ep_ori_err_mean'])
ax.figure.suptitle('Pos/Ori error - Residual');

In [7]:
# List the checkpoint/evaluation files saved for this run.
!ls {get_save_path(run)}

2e+05-steps.zip  best_model.zip  evaluations.npz


In [15]:
# Alternative (unused): fixed-goal initializer read from goal.json
# initializer = initializers.fixed_g_init(4, Pose.from_json(json.load(open('/scr-ssd/ksrini/spinningup/notebooks/goal.json', 'r'))).to_dict())

# Single-env DummyVecEnv configured like the training run, with visualization enabled.
env_cls = (cube_env.ContactForceCubeEnv if run.config.get('contact')
           else cube_env.CubeEnv)
env_fn = make_env.env_fn_generator(
    diff=run.config['diff'],
    visualization=True,
    save_freq=1,
    initializer=run.config['init'],
    reward_fn=run.config['rew_fn'],
    residual=run.config.get('residual', False),
    env_cls=env_cls,
)
env = DummyVecEnv([env_fn])

# Rebuild the SAC model with the run's hyperparameters (weights are loaded later).
use_goal = not run.config.get('no_goal')
use_sde = not run.config.get('no_sde')
load_path = None  # alternative: osp.join(get_save_path(run), '2e+05-steps.zip')
model = make_model(
    run.config.get('ep_len'),
    run.config['lr'],
    env=env,
    use_goal=use_goal,
    use_sde=use_sde,
    load_path=load_path,
)



Using cuda device


In [16]:
# Re-list the run's save directory before choosing a checkpoint.
!ls {get_save_path(run)}

2e+05-steps.zip  best_model.zip  evaluations.npz


In [17]:
# Restore the '2e+05-steps' checkpoint. `load` is a classmethod that returns a new model.
load_path = osp.join(get_save_path(run), '2e+05-steps.zip')
model = SAC.load(load_path, env=env)

## Evaluating a policy

In [19]:
from stable_baselines3.common.evaluation import evaluate_policy

In [20]:
# (mean episode reward, std of episode reward) over 10 evaluation episodes.
evaluate_policy(model=model, env=env, n_eval_episodes=10)

(819.3640668, 364.7574312471647)

In [None]:
# NOTE(review): closing here breaks the rollout cell below, which still uses
# `env.envs[0]`; run this only once you are done with the environment.
env.close()

In [30]:
# Roll out one episode in the single env inside the DummyVecEnv,
# collecting the predicted actions (acs) and rewards (rs).
gym_env = env.envs[0]
o, d = gym_env.reset(), False
rs, acs = [], []
while not d:
    ac, _ = model.predict(o)  # default (non-deterministic) prediction
    acs.append(ac)
    o, r, d, i = gym_env.step(ac)
    rs.append(r)

In [71]:
# Deterministic action for the last observation of the rollout above.
model.predict(observation=o, deterministic=True)

(array([-0.06126785,  0.00481224, -0.0344879 ,  0.05529511, -0.01432532,
         0.04956543,  0.01054287, -0.05547428,  0.06105804], dtype=float32),
 None)

In [46]:
# Check the run's save directory again (e.g. for newly written files).
!ls {get_save_path(run)}

2e+05-steps.zip  best_model.zip  evaluations.npz


In [50]:
# Embed a recorded simulation video from this run's videos/ directory.
display_video(path=osp.join(get_save_path(run), 'videos', 'sim-20.mp4'))