# Test the training procedure of CatBot
Based on: 
https://github.com/RussTedrake/manipulation/blob/master/rl/box_flipup.ipynb

For help interpreting the PPO plots shown in TensorBoard, see:
https://medium.com/aureliantactics/understanding-ppo-plots-in-tensorboard-cbc3199b9ba2

To view the training curves, run `tensorboard --logdir ./ppo_cat_bot_logs` (the `tensorboard_log` directory passed to `PPO` below).

In [1]:
import os
import gym
import numpy as np
import torch
from pydrake.all import StartMeshcat

from catbot.RL.catbot_rl_env import CatBotEnv

from psutil import cpu_count

# Default number of parallel environment workers: half the logical CPU count.
# psutil.cpu_count() returns None when the count cannot be determined, so fall
# back to 1, and never go below one worker (int(1 / 2) would give 0).
num_cpu = max(1, (cpu_count() or 1) // 2)

# Optional imports (these are heavy dependencies for just this one notebook).
# `sb3_available` records whether they succeeded so later cells can check it
# instead of failing with a NameError on PPO / make_vec_env.
sb3_available = False
try:
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import SubprocVecEnv
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.monitor import Monitor

    sb3_available = True
except ImportError:
    print("stable_baselines3 not found")
    print("Consider 'pip3 install stable_baselines3'.")


In [2]:
# Launch the Meshcat visualizer (it logs the URL to connect to), then make the
# CatBot environment constructible via gym.make("CatBot-v0", ...).
meshcat = StartMeshcat()

gym.envs.register(
    entry_point="catbot.RL.catbot_rl_env:CatBotEnv",
    id="CatBot-v0",
)

INFO:drake:Meshcat listening for connections at http://localhost:7000


In [5]:
# Environment configuration shared by the training and visualization envs.
observations = "state"
time_limit = 2
seed = 0  # seeds both the vectorized env and the PPO model

# Set to True to skip training and load a previously saved model instead.
use_pretrained_model = False
pretrained_model_path = "./models/PPO_test_0.zip"

if not sb3_available:
    raise ImportError(
        "stable_baselines3 is required to train or load the CatBot policy; "
        "install it with 'pip3 install stable_baselines3'."
    )

if use_pretrained_model:
    # No env is attached: model.predict() below does not need one.
    model = PPO.load(pretrained_model_path)
else:
    # Parallel training environments. make_vec_env already wraps each env in a
    # Monitor, so no explicit Monitor wrapper is needed here.
    train_env = make_vec_env(
        CatBotEnv,
        n_envs=8,  # Was num_cpu
        seed=seed,
        vec_env_cls=SubprocVecEnv,
        env_kwargs={
            "observations": observations,
            "time_limit": time_limit,
        },
    )

    # NOTE(review): n_steps=4 with 8 envs gives only 32 samples per rollout,
    # i.e. a single minibatch per epoch -- confirm this is intentional.
    model = PPO(
        "MlpPolicy",
        train_env,
        verbose=0,
        n_steps=4,
        n_epochs=2,
        batch_size=32,
        seed=seed,
        tensorboard_log="./ppo_cat_bot_logs",
    )
    model.learn(total_timesteps=50000, progress_bar=True)

    # Shut down the SubprocVecEnv worker processes; they are not used by the
    # single-env visualization below and would otherwise keep running.
    # (Calling model.learn() again requires attaching a fresh env first.)
    train_env.close()

# Single, Meshcat-visualized environment for watching the learned policy.
env = gym.make("CatBot-v0", meshcat=meshcat, observations=observations, time_limit=time_limit)

env.simulator.set_target_realtime_rate(1.0)

Output()

In [4]:
# Play the learned policy for 500 steps in the visualized environment,
# starting a new episode whenever the current one reports done.
obs = env.reset()
for _step in range(500):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if not done:
        continue
    obs = env.reset()

KeyboardInterrupt: 

In [10]:
# Save the model under the first unused name PPO_test_<i>.zip in save_dir,
# so models from earlier runs are never overwritten.
save_dir = './models'
# Create the directory on first use; os.listdir raises FileNotFoundError otherwise.
os.makedirs(save_dir, exist_ok=True)
saved_model_fnames = set(os.listdir(save_dir))

i = 0
while f'PPO_test_{i}.zip' in saved_model_fnames:
    i += 1
model_fname = f'PPO_test_{i}.zip'

save_fname = os.path.join(save_dir, model_fname)
model.save(save_fname)
print(f'Saved model to: {save_fname}')

Saved model to: PPO_test_3.zip
