This notebook provides examples to go along with the [textbook](http://manipulation.csail.mit.edu/rl.html).  I recommend having both windows open, side-by-side!

In [5]:
import gymnasium as gym
import numpy as np
import torch
from psutil import cpu_count
from pydrake.all import StartMeshcat

import box_flipup  # no-member
from manipulation.meshcat_utils import plot_surface
#from manipulation.utils import FindDataResource, RenderDiagram, running_as_notebook
from utils import FindResource, running_as_notebook
num_cpu = int(cpu_count() / 2) if running_as_notebook else 2

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

In [6]:
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7004


# RL for box flip-up

## State-feedback policy via PPO (with stiffness control)

In [7]:
observations = "state"
time_limit = 10 if running_as_notebook else 0.5

# Note: Models saved in stable baselines are version specific.  This one
# requires python3.8 (and cloudpickle==1.6.0).
zip = f"box_flipup_ppo_{observations}.zip"


# Use a callback so that the forked process imports the environment.
def make_boxflipup():
    import manipulation.envs.box_flipup

    return gym.make("BoxFlipUp-v0", observations=observations, time_limit=time_limit)


env = make_vec_env(
    make_boxflipup,
    n_envs=num_cpu,
    seed=0,
    vec_env_cls=SubprocVecEnv if running_as_notebook else DummyVecEnv,
)

use_pretrained_model = True
if use_pretrained_model:
    # TODO(russt): Save a trained model that works on Deepnote.
    model = PPO.load(FindResource(zip), env)
elif running_as_notebook:
    # This is a relatively small amount of training.  See rl/train_boxflipup.py
    # for a version that runs the heavyweight version with multiprocessing.
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=100000)
else:
    # For testing this notebook, we simply want to make sure that the code runs.
    model = PPO("MlpPolicy", env, n_steps=4, n_epochs=2, batch_size=8)
    model.learn(total_timesteps=4)

Process ForkServerProcess-14:
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/stable_baselines3/common/vec_env/subproc_vec_env.py", line 29, in _worker
    env = _patch_env(env_fn_wrapper.var())
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/stable_baselines3/common/env_util.py", line 98, in _init
    env = env_id(**env_kwargs)
          ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_491386/2212558267.py", line 13, in make_boxflipup
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/gymnasium/envs/registration.py", line 741, in make
    env_spec = _find_spec(id)
               ^^^^^^^^^^^^^^
  File "/ho

ConnectionResetError: [Errno 104] Connection reset by peer

_worker
    env = _patch_env(env_fn_wrapper.var())
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/stable_baselines3/common/env_util.py", line 98, in _init
    env = env_id(**env_kwargs)
          ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_491386/2212558267.py", line 13, in make_boxflipup
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/gymnasium/envs/registration.py", line 741, in make
    env_spec = _find_spec(id)
               ^^^^^^^^^^^^^^
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/gymnasium/envs/registration.py", line 527, in _find_spec
    _check_version_exists(ns, name, version)
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/gymnasium/envs/registration.py", line 393, in _check_version_exists
    _check_name_exists(ns, name)
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-

Process ForkServerProcess-6:
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/stable_baselines3/common/vec_env/subproc_vec_env.py", line 29, in _worker
    env = _patch_env(env_fn_wrapper.var())
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/stable_baselines3/common/env_util.py", line 98, in _init
    env = env_id(**env_kwargs)
          ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_491386/2212558267.py", line 13, in make_boxflipup
  File "/home/franka/Documents/franka-pipeline/.venv/lib/python3.12/site-packages/gymnasium/envs/registration.py", line 741, in make
    env_spec = _find_spec(id)
               ^^^^^^^^^^^^^^
  File "/hom

In [None]:
# Make a version of the env with meshcat.
env = gym.make("BoxFlipUp-v0", meshcat=meshcat, observations=observations)

obs, _ = env.reset()
meshcat.StartRecording()
for i in range(500 if running_as_notebook else 5):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated:
        obs, _ = env.reset()
meshcat.PublishRecording()

In [None]:
obs, _ = env.reset()
Q, Qdot = np.meshgrid(np.arange(0, np.pi, 0.05), np.arange(-2, 2, 0.05))
# TODO(russt): tensorize this...
V = 0 * Q
for i in range(Q.shape[0]):
    for j in range(Q.shape[1]):
        obs[2] = Q[i, j]
        obs[7] = Qdot[i, j]
        with torch.no_grad():
            V[i, j] = (
                model.policy.predict_values(model.policy.obs_to_tensor(obs)[0])[0]
                .cpu()
                .numpy()[0]
            )
V = V - np.min(np.min(V))
V = V / np.max(np.max(V))

meshcat.Delete()
meshcat.ResetRenderMode()
plot_surface(meshcat, "Critic", Q, Qdot, V, wireframe=True)

In [None]:
env = gym.make("BoxFlipUp-v0")

#RenderDiagram(env.simulator.get_system(), max_depth=1)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=8ac7c900-70d2-4af1-83c6-341a64fb0e14' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>