# Loading, evaluating, and visualizing runs

## Plotting imports

In [1]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

## Helper functions

In [2]:
import gym
import torch.nn as nn
from gym import ObservationWrapper
from gym.spaces import flatten_space
from gym.wrappers import FilterObservation
from rrc.env import initializers, cube_env
from rrc.env.reward_fns import *
from rrc.env.wrappers import MonitorPyBulletWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.torch_layers import CombinedExtractor
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3 import HerReplayBuffer, SAC, TD3
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

class FlattenGoalObs(ObservationWrapper):
    def __init__(self, env, observation_keys):
        super().__init__(env)
        obs_space = self.env.observation_space
        obs_dict = {k: flatten_space(obs_space[k]) for k in observation_keys}
        self.observation_space = gym.spaces.Dict(obs_dict)

    def observation(self, obs):
        n_obs = {}
        for k in self.observation_space.spaces:
            if isinstance(obs[k], dict):
                obs_list = [obs[k][k2] for k2 in self.env.observation_space[k]]
                n_obs[k] = np.concatenate(obs_list)
            else:
                n_obs[k] = obs[k]
        return n_obs


class HERCombinedExtractor(CombinedExtractor):
    """
    HERCombinedExtractor is a combined extractor which only extracts pre-specified observation_keys to include in
    the observation, while retaining them at the environment level so that they may still be stored in the replay buffer
    """

    def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256, observation_keys: list = []):
        # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
        super(CombinedExtractor, self).__init__(observation_space, features_dim=1)

        extractors = {}

        total_concat_size = 0
        for key in observation_keys:
            subspace = observation_space.spaces[key]
            # The observation key is a vector, flatten it if needed
            extractors[key] = nn.Flatten()
            total_concat_size += get_flattened_obs_dim(subspace)

        self.extractors = nn.ModuleDict(extractors)

        # Update the features dim manually
        self._features_dim = total_concat_size


def make_model(ep_len, lr, exp_dir=None, env=None, use_goal=True,
               use_sde=False):
    if use_goal:
        obs_keys = ['desired_goal', 'observation']
    else:
        obs_keys = ['observation']

    policy_kwargs = dict(
                    log_std_init=-3,
                    features_extractor_class=HERCombinedExtractor,
                    features_extractor_kwargs=dict(observation_keys=obs_keys))
    if use_sde:
        sde_kwargs = dict(
                use_sde=True,
                use_sde_at_warmup=True,
                sde_sample_freq=64)
    else:
        sde_kwargs = {}

    rb_kwargs = dict(
                    n_sampled_goal=4,
                    goal_selection_strategy='future',
                    online_sampling=False,
                    max_episode_length=ep_len)

    model = SAC('MultiInputPolicy', env,
                # tensorboard_log=exp_dir,
                replay_buffer_class=HerReplayBuffer,
                # Parameters for HER
                replay_buffer_kwargs=rb_kwargs,
                policy_kwargs=policy_kwargs,
                verbose=1, buffer_size=int(1e6),
                learning_starts=1500,
                learning_rate=lr,
                gamma=0.99, batch_size=256, **sde_kwargs)
    return model


def env_fn_generator(diff=3, initializer=initializers.training_init,
                     episode_length=500, relative_goal=True, reward_fn=None,
                     save_mp4=False, save_dir='', save_freq=10, **env_kwargs):
    if reward_fn is None:
        reward_fn = training_reward3
    else:
        if reward_fn == 'train1':
            reward_fn = training_reward1
        elif reward_fn == 'train2':
            reward_fn = training_reward2
        elif reward_fn == 'train3':
            reward_fn = training_reward3
        elif reward_fn == 'competition':
            reward_fn = competition_reward

    def env_fn():
        env = cube_env.CubeEnv(None, diff,
                initializer=initializer,
                episode_length=episode_length,
                relative_goal=relative_goal,
                reward_fn=reward_fn,
                torque_factor=.1,
                **env_kwargs)
        if save_mp4:
            env = MonitorPyBulletWrapper(env, save_dir, save_freq)
        env = FlattenGoalObs(env, ['desired_goal', 'achieved_goal', 'observation'])
        return Monitor(env, info_keywords=('ori_err', 'pos_err'))
    return env_fn


wandb_root = '/scr-ssd/ksrini/spinningup/notebooks'
get_save_path = lambda run: '/'.join([wandb_root] + run.config['exp_dir'].split('/')[1:])

def display_video(path=None, run=None):
    if run:
        path = get_save_path(run.config['exp_dir'])
    return Video(path, embed=True, width=640)

## Visualizing a trajectory (mp4 video)

In [3]:
from IPython.display import Video

In [4]:
Video('/scr-ssd/ksrini/spinningup/notebooks/videos/sim-22.mp4', embed=True, width=640)

## Loading a wandb run

In [4]:
import wandb
import pandas as pd
import numpy as np
import os.path as osp
from trifinger_simulation.tasks.move_cube import Pose
import json

api = wandb.Api()

In [5]:
def run_history(run_id, proj='cvxrl', keys=None, samples=1000):
    if len(run_id.split('/')) == 2:
        proj, run_id = run_id.split('/')
    if keys:
        data = api.run('krshna/{}/{}'.format(proj, run_id)).history(samples=samples, keys=keys, pandas=False)
        while len(data) == 1:
            data = data[0]
        return pd.DataFrame(data)
    history = api.run('krshna/{}/{}'.format(proj, run_id)).history(samples=samples, keys=keys)
    keep_cols = ([c for c in history.columns if 'train' in c] 
     + [c for c in history.columns if 'rollout' in c]
     + [c for c in history.columns if 'episodes' in c])

    new_df = {}
    for c in keep_cols:
        new_df[c] = history[c][~history[c].isna()].values
    new_df['global_step'] = history['global_step'][~history[c].isna()].values
    return new_df

In [7]:
history = run_history('xfwj0484')

In [8]:
history['global_step']

array([ 16000,  38000,  84000,  94000,  96000, 128000, 152000, 168000,
       190000, 194000, 196000, 206000, 222000, 224000, 248000, 268000,
       304000, 316000, 336000, 364000, 390000, 458000, 462000, 498000,
       506000, 512000, 524000, 530000, 538000, 560000, 568000, 580000,
       608000, 632000, 678000, 708000, 730000, 736000, 740000, 746000,
       780000, 782000, 790000, 796000, 820000, 822000, 842000, 844000,
       852000, 860000, 878000, 886000, 914000, 942000, 946000, 948000,
       950000, 976000, 982000, 990000, 996000])

In [13]:
run_id = 'v8j2y9w8'
run = api.run('krshna/cvxrl/{}'.format(run_id))

In [14]:
initializer = initializers.fixed_g_init(4, Pose.from_json(json.load(open('/scr-ssd/ksrini/spinningup/notebooks/goal.json', 'r'))).to_dict())

# Create DummyVecEnv
# env_fn = env_fn_generator(save_mp4=True, visualization=True, save_dir=osp.join(get_save_path(run), 'videos'), save_freq=1,
#                           initializer=initializer)
# env = DummyVecEnv([env_fn])

# Create model
use_goal = not(run.config.get('no_goal'))
use_sde = not(run.config.get('no_sde'))
model = make_model(500, run.config['lr'], use_goal=use_goal, use_sde=use_sde, env=env)
model.load(osp.join(get_save_path(run), 'best_model.zip'), env)
# save_path = './data/HER-SAC_rrc-diff3/2021-06-17_10-19-29/best_model.npz'

Using cuda device


<stable_baselines3.sac.sac.SAC at 0x7f67801f9048>

## Evaluating a policy

In [15]:
from stable_baselines3.common.evaluation import evaluate_policy

In [16]:
evaluate_policy(model, env, n_eval_episodes=1)

(0.012224, 0.0)

In [18]:
display_video(env.envs[0].videos[1])