# Test the training procedure of CatBot
Based on: 
https://github.com/RussTedrake/manipulation/blob/master/rl/box_flipup.ipynb

Some insight into tensorboard plots:
https://medium.com/aureliantactics/understanding-ppo-plots-in-tensorboard-cbc3199b9ba2

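To view the training curves, run `tensorboard --logdir <LOGDIR>`, where `<LOGDIR>` is the `tensorboard_log` directory passed to the model in the training cell below (`./ppo_cat_bot_logs`).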

In [3]:
import os
import gym
import numpy as np
import datetime
import torch
from pydrake.all import StartMeshcat

from catbot.RL.catbot_rl_env import CatBotEnv

from psutil import cpu_count

# Use half of the reported CPUs for parallel work, but never fewer than one.
# cpu_count() can return None when the count cannot be determined, and a
# single-CPU host would otherwise yield int(1 / 2) == 0 workers.
num_cpu = max(1, (cpu_count() or 2) // 2)

# Optional imports (these are heavy dependencies for just this one notebook)
# `sb3_available` starts False and is flipped to True only when every
# stable_baselines3 import below succeeds; on ImportError a hint is printed
# instead of raising, so this cell itself always completes.
sb3_available = False
try:
    # RL algorithms used by the training and playback cells.
    from stable_baselines3 import PPO, SAC
    # Vectorized-environment helpers used to build the training environment.
    from stable_baselines3.common.vec_env import SubprocVecEnv
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.monitor import Monitor
    # Callback used to write periodic checkpoints during training.
    from stable_baselines3.common.callbacks import CheckpointCallback

    sb3_available = True
except ImportError:
    # Report the missing dependency and how to install it.
    print("stable_baselines3 not found")
    print("Consider 'pip3 install stable_baselines3'.")


In [4]:
# Register the CatBot environment under the id that the playback cells pass
# to gym.make().
gym.envs.register(
    id="CatBot-v0",
    entry_point="catbot.RL.catbot_rl_env:CatBotEnv",
)

# Start the Meshcat visualizer; the handle is handed to the environments
# created for playback further down.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7000


In [None]:
# ---------------------------------------------------------------------------
# Train a PPO (default) or SAC policy on a vectorized CatBot environment.
#
# Tunables are collected at the top of the cell. The cell leaves `model`
# (the trained policy) and the configuration names in the notebook namespace
# for the cells that follow.
# ---------------------------------------------------------------------------
if not sb3_available:
    # Fail with a clear message instead of a NameError on make_vec_env below.
    raise ImportError(
        "stable_baselines3 is required for training. "
        "Consider 'pip3 install stable_baselines3'."
    )

observations = "state"
time_limit = 4
total_timesteps = 10000

n_envs = 8  # Was num_cpu
train_seed = 0
checkpoint_every_timesteps = 1000
tensorboard_log_dir = "./ppo_cat_bot_logs"

use_pretrained_model = False
model_zip_fname = "./models/PPO_C6_0254.zip"
sac = False
algo_name = "SAC" if sac else "PPO"

env = make_vec_env(
    CatBotEnv,
    n_envs=n_envs,
    seed=train_seed,
    vec_env_cls=SubprocVecEnv,
    env_kwargs={
        "observations": observations,
        "time_limit": time_limit,
    },
)

try:
    # CheckpointCallback counts calls to the vectorized env.step(), each of
    # which advances n_envs timesteps. Dividing by n_envs makes checkpoints
    # land every `checkpoint_every_timesteps` timesteps instead of every
    # checkpoint_every_timesteps * n_envs.
    checkpoint_callback = CheckpointCallback(
        save_freq=max(checkpoint_every_timesteps // n_envs, 1),
        save_path='./checkpoints/',
        name_prefix=f'{algo_name}_CHECKPOINT_TEST',
        save_replay_buffer=True,
        save_vecnormalize=True,
    )

    print('starting training')
    print(algo_name)
    if sac:
        model = SAC(
            "MlpPolicy",
            env,
            verbose=0,
            batch_size=32,
            tensorboard_log=tensorboard_log_dir,
            seed=train_seed)
    elif use_pretrained_model:
        # Continue training from a previously saved PPO model.
        model = PPO.load(model_zip_fname, env)
    else:
        model = PPO(
            "MlpPolicy",
            env,
            verbose=0,
            n_steps=4,
            n_epochs=2,
            batch_size=32,
            tensorboard_log=tensorboard_log_dir,
            # Seed the policy as well, matching the SAC branch, so that a
            # fresh run is reproducible.
            seed=train_seed)

    # Single training entry point shared by all three branches above.
    model.learn(
        total_timesteps=total_timesteps,
        progress_bar=True,
        callback=checkpoint_callback,
    )
finally:
    # Always shut down the SubprocVecEnv worker processes, including on error
    # or interrupt. Later cells rebind `env`, which would otherwise orphan
    # them. `model` remains usable for save() and predict() afterwards.
    env.close()


## Save Model

In [None]:
# Save the trained `model` under ./models with a timestamped file name.
save_dir = './models'
# Create the directory on first use; os.listdir() would otherwise raise
# FileNotFoundError on a fresh checkout.
os.makedirs(save_dir, exist_ok=True)
saved_model_fnames = os.listdir(save_dir)  # models already present before this save

datetime_val = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
# Tag the file with the actual algorithm class (e.g. PPO or SAC) rather than
# a hard-coded "PPO", so SAC runs are not mislabeled.
algo_name = type(model).__name__
model_fname = f'{datetime_val}_{algo_name}.zip'

# Build the path from save_dir so the directory is defined in one place.
save_fname = os.path.join(save_dir, model_fname)
model.save(save_fname)
print(f'Saved model to: {model_fname}')

Saved model to: 230510_030745_PPO.zip


### Show the just-trained model

In [None]:
# Play back the policy held in `model` in a single Meshcat-visualized CatBot
# environment, reusing the `observations` / `time_limit` settings already in
# the notebook namespace.
env = gym.make(
    "CatBot-v0",
    meshcat=meshcat,
    observations=observations,
    time_limit=time_limit,
)
env.simulator.set_target_realtime_rate(1.0)

obs = env.reset()
for i in range(500):
    # Deterministic action from the policy for the current observation.
    action, _state = model.predict(obs, deterministic=True)
    step_result = env.step(action)
    obs, reward, done, info = step_result
    env.render()
    if not done:
        continue
    # Episode finished: start a new one and keep playing.
    obs = env.reset()

## Show pretrained model

In [16]:
# Load a saved policy, play it back in a Meshcat-visualized CatBot
# environment, and report the cumulative reward of every finished episode
# plus their mean.
#
# Previously evaluated checkpoints, kept for reference (uncomment one to use):
# model_zip_fname = "./checkpoints/230510_125444_PPO_P0_C10_390000_steps.zip"
# model_zip_fname = "./checkpoints/230510_131356_PPO_P0_C10_400000_steps.zip"
# model_zip_fname = "./checkpoints/230510_132402_PPO_P0_C11_300000_steps.zip"
# model_zip_fname = "./checkpoints/230510_134537_PPO_P0_C12_30000_steps.zip"
# model_zip_fname = "./checkpoints/230510_135133_PPO_P0_C12_400000_steps.zip"
# model_zip_fname = "./checkpoints/230510_140638_PPO_P0_C12_20000_steps.zip"
# model_zip_fname = "./checkpoints/230510_142001_PPO_P0_C12_sde_800000_steps.zip"
# model_zip_fname = "./checkpoints/230510_160054_PPO_P0_C13_sde_20000_steps.zip"
# model_zip_fname = "./checkpoints/230510_160600_PPO_P0_C13_sde_800000_steps.zip"
# model_zip_fname = "./checkpoints/230510_171803_PPO_P0_C13_SAC_60000_steps.zip"
# model_zip_fname = "./checkpoints/230510_171803_PPO_P0_C13_SAC_800000_steps.zip"
# model_zip_fname = "./checkpoints/230510_170053_PPO_P1_C13_sde_800000_steps.zip"
# model_zip_fname = "./checkpoints/230510_181232_SAC_P0_C13_1600000_steps.zip"
# model_zip_fname = "./checkpoints/230510_230410_SAC_P1_C13_Half_init_1700000_steps.zip"
model_zip_fname = "./models/PPO_800000_good.zip"
# model_zip_fname = "./models/SAC_800000_good.zip"
observations = "state"
time_limit = 8

env = gym.make("CatBot-v0", meshcat=meshcat, observations=observations, time_limit=time_limit)
env.simulator.set_target_realtime_rate(2)

# `sac` must match the algorithm that produced `model_zip_fname`.
# sac = True
sac = False
if sac:
    model = SAC.load(model_zip_fname, env)
else:
    model = PPO.load(model_zip_fname, env)

obs = env.reset()
cum_reward = 0
reward_list = []
for i in range(4000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    cum_reward += reward
    env.render()
    if done:
        # Episode finished: record its return and start a new episode.
        print('Reward: ', cum_reward)
        reward_list.append(cum_reward)
        cum_reward = 0
        obs = env.reset()

# np.mean([]) returns nan and emits a RuntimeWarning; report the empty case
# explicitly when no episode completed within the step budget.
if reward_list:
    print('Mean reward: ', np.mean(np.array(reward_list)))
else:
    print('Mean reward: n/a (no episode finished)')

Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  4500.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  1800.0
Reward:  0.0
Reward:  3900.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  3200.0
Reward:  0.0
Reward:  4000.0
Reward:  0.0
Reward:  0.0
Reward:  2000.0
Reward:  0.0
Reward:  200.0
Reward:  0.0
Reward:  0.0
Reward:  2300.0
Reward:  0.0
Reward:  300.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  0.0
Reward:  200.0
Reward:  3300.0
Reward:  1700.0
Reward:  0.0
Reward:  4500.0
Reward:  0.0
Reward:  900.0
Mean reward:  669.3877551020408
