# Stretch Reinforcement Learning with DM_Control and PPO

This notebook is a reference for an RL exercise with the Hello Robot Stretch.

Two learning tasks are defined in this notebook:
- [StretchPushCubeTraining](#training-task-definition-stretchpushcubetraining)
- [StretchPushCubeTrainingArmOnly](#training-task-definition-2-stretchpushcubetrainingarmonly)

Both tasks train Stretch to push a cube to a target location on a table.

The [PPO](https://en.wikipedia.org/wiki/Proximal_policy_optimization) algorithm is implemented using PyTorch in the [PPO Definition](#ppo-definition) section.

References:
- Google DeepMind [DM_Control Colab](https://colab.research.google.com/github/google-deepmind/dm_control/blob/main/tutorial.ipynb#scrollTo=JHSvxHiaopDb)
- CleanRL single-file [PPO algorithm implementation](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy)


<!-- Internal installation instructions. -->

## Install the dependencies using UV

- Get UV by following the instructions in the [README](../../README.md).
- Run `uv pip install -e ".[rlearning]"` to install the RL dependencies.

In [None]:
use_gpu = True # if available

try:
  import google.colab
  RUNNING_IN_COLAB = True
except ImportError:
  RUNNING_IN_COLAB = False

In [None]:
"""Install dependencies when running in Google Colab"""
if RUNNING_IN_COLAB:
  %pip install -q dm_control matplotlib

if RUNNING_IN_COLAB and use_gpu:
  import os
  import subprocess
  if subprocess.run('nvidia-smi').returncode:
    raise RuntimeError(
        'Cannot communicate with GPU. '
        'Make sure you are using a GPU Colab runtime. '
        'Go to the Runtime menu and select Choose runtime type.')

  # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
  # This is usually installed as part of an Nvidia driver package, but the Colab
  # kernel doesn't install its driver via APT, and as a result the ICD is missing.
  # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
  NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
  if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
    with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
      f.write("""{
      "file_format_version" : "1.0.0",
      "ICD" : {
          "library_path" : "libEGL_nvidia.so.0"
      }
  }
  """)

  # Configure dm_control to use the EGL rendering backend (requires GPU)
  %env MUJOCO_GL=egl

  print('Checking that the dm_control installation succeeded...')
  try:
    from dm_control import suite
    env = suite.load('cartpole', 'swingup')
    pixels = env.physics.render()
  except Exception as e:
    raise RuntimeError(
        'Something went wrong during installation. Check the shell output above '
        'for more information.\n'
        'If using a hosted Colab runtime, make sure you enable GPU acceleration '
        'by going to the Runtime menu and selecting "Choose runtime type".') from e
  else:
    del pixels, suite, env

### GPU Device (If available)

In [None]:
# Pick the fastest available device (CUDA, then Apple MPS, then CPU) so that training runs faster
import platform
from typing import Literal
import torch


device: Literal['cuda', 'mps', 'cpu'] = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

device = device if use_gpu else "cpu"

if use_gpu and platform.system() != "Darwin":
  # Configure dm_control to use the EGL rendering backend (requires GPU)
  %env MUJOCO_GL=egl

print(f"Using {device} device")

### Utils

In [None]:
# Adapted from the Google DeepMind dm_control Colab notebook:
import numpy as np
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from IPython.display import HTML, display
import PIL.Image

# The basic mujoco wrapper.
from dm_control import mujoco

# Access to enums and MuJoCo library functions.
from dm_control.mujoco.wrapper.mjbindings import enums
from dm_control.mujoco.wrapper.mjbindings import mjlib

# Use svg backend for figure rendering
%config InlineBackend.figure_format = 'svg'

# Font sizes
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12
plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


def display_video(frames, framerate=30):
  height, width, _ = frames[0].shape
  dpi = 70
  orig_backend = matplotlib.get_backend()
  matplotlib.use('Agg')  # Switch to headless 'Agg' to inhibit figure rendering.
  fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
  matplotlib.use(orig_backend)  # Switch back to the original backend.
  ax.set_axis_off()
  ax.set_aspect('equal')
  ax.set_position([0, 0, 1, 1])
  im = ax.imshow(frames[0])
  def update(frame):
    im.set_data(frame)
    return [im]
  interval = 1000/framerate
  anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                  interval=interval, blit=True, repeat=False)
  return HTML(anim.to_html5_video())

# PPO and Task Definitions

### Training Task Definition: `StretchPushCubeTraining`
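The reward and termination rule used by both tasks can be sketched without MuJoCo. This is a minimal stdlib sketch, assuming positions are plain `(x, y, z)` tuples in metres; `push_cube_reward`, `is_done`, and the starting cube position are hypothetical stand-ins for `reward()`, `check_is_done()`, and `_get_cube_pos()` in the class below.

```python
import math


def push_cube_reward(cube_pos, target_pos):
    """Negative Euclidean distance between the cube and its target (closer is better)."""
    return -math.dist(cube_pos, target_pos)


def is_done(cube_pos, target_pos, tolerance=0.05):
    """The task counts as solved once the cube is within `tolerance` of the target."""
    return math.dist(cube_pos, target_pos) < tolerance


# Target = starting cube position shifted by push_cube_by, as in the task constructor.
start = (0.0, -0.5, 0.75)  # hypothetical starting cube position
push_cube_by = (0.0, -0.2, 0.0)
target = tuple(s + d for s, d in zip(start, push_cube_by))

print(push_cube_reward(start, target))      # about -0.2
print(is_done(start, target))               # False
print(is_done((0.0, -0.68, 0.75), target))  # True: roughly 0.02 away
```

Because the reward is dense (every step reports the current distance), the policy gets a learning signal even before it ever reaches the 0.05 tolerance.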

In [None]:
import time
from stretch_mujoco.enums.actuators import Actuators

scene_option = mujoco.wrapper.core.MjvOption()
scene_option.flags[enums.mjtVisFlag.mjVIS_JOINT] = True

In [None]:
"""
This is a training task that tries to push a cube to a target location using all the joints and actuators on the robot.
"""
from functools import cache


class StretchPushCubeTraining:
    def __init__(self, physics: mujoco.Physics, push_cube_by:tuple[float,float,float]):
        self.physics = physics
        self.target_position =  self._get_cube_pos() + push_cube_by

        print("Using all joints:", self._get_joints())

        # Define state size: joint positions, joint velocities, and the 3-D position of object1
        self.state_size = len(self._get_joints()) * 2 + 3
        
        # Define action size: continuous joint actions
        self.action_size = len(self._get_joints())  # num joints to control

        self.frames = []
        self.render_rate = 1/30 #1/Hz
        self.time_last_render = time.perf_counter()
        self.last_step_time = time.perf_counter()

        self.current_distance_to_target = float('inf')

    @cache
    def _get_joints(self):
        """Gets joint names in MJCF"""
        return [name for j in self._get_actuators() for name in j.get_joint_names_in_mjcf()]
    
    @cache
    def _get_actuator_names(self):
        return [j.name for j in self._get_actuators()]
    
    @cache
    def _get_actuators(self):
        return Actuators.get_actuated_joints()

    @cache
    def _get_cube_id(self):
        return self.physics.model.name2id("object1", "body")

    def _get_cube_pos(self):
        return self.physics.data.xpos[self._get_cube_id()]
    
    @cache
    def _get_cube_original_pos(self):
        return self.physics.model.body("object1").pos
    
    def arm_joint_pos(self):
        return self.physics.named.data.qpos[self._get_joints()]

    def arm_joint_vel(self):
        return self.physics.named.data.qvel[self._get_joints()]
    
    def reset(self, use_home_pose=True):
        # Reset the simulation and clear the recorded frames
        self.frames = []

        self.physics.reset(0 if use_home_pose else None)

        if use_home_pose:
            # Resetting to the keyframe alone does not reach the home pose,
            # so drive the actuators there manually with the "home" controls:
            self.physics.data.ctrl = self.physics.model.keyframe("home").ctrl
            for _ in range(400):
                self.physics.step()
                self.render()
        
        self.current_distance_to_target = float('inf')

        return np.concatenate([self.arm_joint_pos(), self.arm_joint_vel(), self._get_cube_original_pos()])
        

    def reward(self):
        # Calculate the reward (negative distance to target position of object1)
        object_pos = self._get_cube_pos()
        self.current_distance_to_target = np.linalg.norm(object_pos - self.target_position).astype(np.float32)
        return -self.current_distance_to_target  # Negative because we want to minimize the distance

    def check_is_done(self):
        return self.current_distance_to_target < 0.05  # Done if the object is close enough to the target

    def step(self, action):

        time_until_next_step = self.physics.model.opt.timestep - (time.perf_counter() - self.last_step_time)
        if time_until_next_step > 0:
            # Throttle to real time: wait until one physics timestep of wall-clock time has passed since the last step.
            time.sleep(time_until_next_step)

        # Apply the action to the joints
        for index, name in enumerate(self._get_actuator_names()):
            self.physics.data.actuator(name).ctrl = action[index]
        
        # Step the simulation forward
        self.physics.step()

        self.last_step_time = time.perf_counter()

        # Get the current state (qpos, qvel, object1 position)
        state = np.concatenate([self.arm_joint_pos(), self.arm_joint_vel(), self._get_cube_pos()])

        return state

    def render(self):
        
        elapsed = time.perf_counter() - self.time_last_render
        if elapsed > self.render_rate:
            self.time_last_render = time.perf_counter()

            pixels = self.physics.render(scene_option=scene_option)

            self.frames.append(pixels)
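
Both tasks build the observation by concatenating joint positions, joint velocities, and the cube's 3-D position, which is why `state_size` is `2 * num_joints + 3`. A sketch of that layout, assuming each controlled joint contributes exactly one `qpos` and one `qvel` entry (true for hinge and slide joints, not for free or ball joints); `observation_layout`, `split_observation`, and the joint names are hypothetical, not part of the notebook.

```python
def observation_layout(joint_names):
    """Return (state_size, slices) for the [joint pos | joint vel | cube xyz] observation."""
    n = len(joint_names)
    slices = {
        "joint_pos": slice(0, n),
        "joint_vel": slice(n, 2 * n),
        "cube_pos": slice(2 * n, 2 * n + 3),
    }
    return 2 * n + 3, slices


def split_observation(obs, joint_names):
    """Split a flat observation back into named parts (handy when debugging a policy)."""
    _, slices = observation_layout(joint_names)
    return {key: list(obs[s]) for key, s in slices.items()}


joints = ["joint_a", "joint_b"]  # hypothetical joint names
size, _ = observation_layout(joints)
print(size)  # 7

obs = [0.1, 0.2, 0.0, -0.1, 0.3, -0.5, 0.75]
print(split_observation(obs, joints))
# {'joint_pos': [0.1, 0.2], 'joint_vel': [0.0, -0.1], 'cube_pos': [0.3, -0.5, 0.75]}
```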



### Training Task Definition #2: `StretchPushCubeTrainingArmOnly`
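`StretchPushCubeTrainingArmOnly` only overrides `_get_actuators()`. Because the inherited `_get_joints()` calls `self._get_actuators()`, that single override changes which joints are observed and controlled. The toy classes below (hypothetical names, no MuJoCo) sketch the pattern and show that `functools.cache` on a method keys its cache on `self`, so each instance and subclass gets its own result.

```python
from functools import cache


class Task:
    @cache
    def _get_actuators(self):
        return ("lift", "arm", "wrist_yaw", "base")  # hypothetical actuator names

    @cache
    def _get_joints(self):
        # Dispatches through self, so subclasses only need to override _get_actuators().
        return [f"joint_{name}" for name in self._get_actuators()]


class ArmOnlyTask(Task):
    @cache
    def _get_actuators(self):
        return ("lift", "arm")


print(Task()._get_joints())         # ['joint_lift', 'joint_arm', 'joint_wrist_yaw', 'joint_base']
print(ArmOnlyTask()._get_joints())  # ['joint_lift', 'joint_arm']
```

One trade-off: a method-level `functools.cache` keeps a reference to every instance it has seen, so those objects are never garbage-collected. That is harmless for a handful of long-lived task objects like the ones in this notebook.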

In [None]:
"""
This is a training task that tries to push a cube to a target location using the arm joints only.
"""
class StretchPushCubeTrainingArmOnly(StretchPushCubeTraining):
    def __init__(self, physics: mujoco.Physics, push_cube_by:tuple[float,float,float]):
        self.physics = physics
        self.target_position =  self._get_cube_pos() + push_cube_by

        print("Using arm joints only:", self._get_joints())

        # Define state size: joint positions, joint velocities, and the 3-D position of object1
        self.state_size = len(self._get_joints()) * 2 + 3

        # Define action size: one continuous action per arm joint
        self.action_size = len(self._get_joints())  # num joints to control

        self.frames = []
        self.render_rate = 1/30 #1/Hz
        self.time_last_render = time.perf_counter()
        self.last_step_time = time.perf_counter()

        self.current_distance_to_target = float('inf')


    @cache
    def _get_actuators(self):
        return Actuators.get_arm_joints()
    


### PPO Definition
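The core of the update below is CleanRL's clipped surrogate objective. Here is a per-sample, pure-Python sketch of the same arithmetic, plus the `(ratio - 1) - logratio` estimator used for `approx_kl`. `ppo_policy_loss` and `approx_kl` are hypothetical names; the notebook's code works on PyTorch tensors over whole minibatches.

```python
import math


def ppo_policy_loss(new_logprobs, old_logprobs, advantages, clip_coef=0.2):
    """Clipped surrogate loss averaged over samples (mirrors pg_loss in train())."""
    losses = []
    for new_lp, old_lp, adv in zip(new_logprobs, old_logprobs, advantages):
        ratio = math.exp(new_lp - old_lp)
        clipped = min(max(ratio, 1 - clip_coef), 1 + clip_coef)
        # The max of the negated terms equals the min of the (unnegated) objectives.
        losses.append(max(-adv * ratio, -adv * clipped))
    return sum(losses) / len(losses)


def approx_kl(new_logprobs, old_logprobs):
    """Low-variance KL estimate (r - 1) - log r, averaged over samples."""
    terms = []
    for new_lp, old_lp in zip(new_logprobs, old_logprobs):
        logratio = new_lp - old_lp
        terms.append((math.exp(logratio) - 1) - logratio)
    return sum(terms) / len(terms)


old = [0.0, 0.0]
new = [math.log(1.5), math.log(1.5)]  # new policy is 1.5x more likely to pick both actions
print(ppo_policy_loss(new, old, advantages=[1.0, -1.0]))  # about 0.15
print(approx_kl(new, old))  # about 0.095
```

With a ratio of 1.5, the sample with a positive advantage only gets credit up to `1 + clip_coef = 1.2`, while the sample with a negative advantage is penalized at the full 1.5. This pessimistic choice is what keeps each update close to the old policy.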

In [None]:
# References https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
import random
import time
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal


@dataclass
class PpoTrainingArgs:
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel environments (train() below assumes a single environment)"""
    num_steps: int = 2048
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for generalized advantage estimation (GAE)"""
    num_minibatches: int = 32
    """the number of mini-batches"""
    update_epochs: int = 10
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantage normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.0
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    seed: int = 1
    """PRNG seed"""
    save_model_to_path: str | None = None
    """the path to save the agent params at the end of training"""

    # to be filled in at runtime by train()
    batch_size: int = 0
    """the batch size (computed at runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed at runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed at runtime)"""

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class Agent(nn.Module):
    def __init__(self, state_size, action_size):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(state_size, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(state_size, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, action_size), std=0.01),
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_size))

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)


def train(env:StretchPushCubeTraining, training_args:PpoTrainingArgs):
    """
    Call this to train the task using PPO.

    References https://raw.githubusercontent.com/vwxyzjn/cleanrl/refs/heads/master/cleanrl/ppo_continuous_action.py
    """
    args = training_args

    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    
    agent = Agent(env.state_size,env.action_size).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    for iteration in range(1, args.num_iterations + 1):
        start_time = time.time()

        obs = torch.zeros((args.num_steps, args.num_envs, env.state_size)).to(device)
        actions = torch.zeros((args.num_steps, args.num_envs, env.action_size)).to(device)
        logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
        rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
        dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
        values = torch.zeros((args.num_steps, args.num_envs)).to(device)

        next_obs = env.reset()
        next_obs = torch.Tensor(next_obs).to(device)
        next_obs = next_obs.reshape(1, env.state_size)
        next_done = torch.zeros(args.num_envs).to(device)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # STEP PHYSICS
            next_obs = env.step(action.cpu().numpy()[0])
            reward = env.reward()
            reward =  torch.tensor(reward).to(device).view(-1)

            env.render()

            next_done = env.check_is_done()

            rewards[step] = reward
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor([1] if next_done else [0]).to(device)

            next_obs = next_obs.reshape(1, env.state_size)

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,env.state_size) )
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,env.action_size))
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

        print(f"Iteration {iteration}/{args.num_iterations}: Avg reward: {rewards.mean().item():.3f}. Steps: {args.num_steps}. Distance to target: {env.current_distance_to_target:.3f}. Took: {(time.time() - start_time):.3f}s")
        
        display(display_video(env.frames))

    if args.save_model_to_path is not None:
        torch.save(agent.state_dict(), args.save_model_to_path)
        print(f"model saved to {args.save_model_to_path}")

    return agent
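
The backward loop in `train()` is Generalized Advantage Estimation (GAE). Here is the same recursion on plain Python lists for a single environment, assuming the notebook's convention that `dones[t]` holds the done flag recorded before step `t` (that is, whether the previous transition ended the episode). `compute_gae` is a hypothetical helper, not part of the notebook.

```python
def compute_gae(rewards, values, dones, next_value, next_done, gamma=0.99, gae_lambda=0.95):
    """Return (advantages, returns) for one rollout, using the same recursion as train()."""
    num_steps = len(rewards)
    advantages = [0.0] * num_steps
    lastgaelam = 0.0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        # One-step TD error, cut at episode boundaries by nextnonterminal.
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


print(compute_gae([1, 1, 1], [0, 0, 0], [0, 0, 0], next_value=0.0, next_done=0.0,
                  gamma=1.0, gae_lambda=1.0))
# ([3.0, 2.0, 1.0], [3.0, 2.0, 1.0])
```

With `gamma = gae_lambda = 1` and zero value estimates, the advantages reduce to the undiscounted reward-to-go, which makes a quick sanity check for the indexing.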

# Do Training

In [None]:
import importlib.resources

models_path = str(importlib.resources.files("stretch_mujoco") / "models")
xml_path = models_path + "/scene.xml"
physics = mujoco.Physics.from_xml_path(xml_path)

In [None]:
print('timestep', physics.model.opt.timestep)
print('gravity', physics.model.opt.gravity)

In [None]:
push_cube_by=(0,-0.2,0)

stretchPushCubeTrainingArmOnly = StretchPushCubeTrainingArmOnly(physics, push_cube_by=push_cube_by)

In [None]:
"""
Display the starting pose and the expected final result, with the cube placed at its target position.
"""

# This block edits the MjSpec to move the cube to its target_position.
# We edit the spec and recompile it into a separate Physics instance, so the training `physics` is left untouched.
spec = mujoco.MjSpec.from_file(xml_path)
spec.find_body("object1").pos = stretchPushCubeTrainingArmOnly.target_position
# meshdir is relative here. It should be the same as spec.modelfiledir, but MuJoCo appears to expect a relative directory.
# If you get a "not found" error, this relative path is probably wrong for your working directory.
spec.meshdir = "../../stretch_mujoco/models/assets/"
spec.texturedir = spec.meshdir
spec.compile()
expected_model = spec.to_xml()
expected_physics = mujoco.Physics.from_xml_string(expected_model)



# Go to home pose
physics.data.ctrl = physics.model.keyframe("home").ctrl
expected_physics.data.ctrl = expected_physics.model.keyframe("home").ctrl
for x in range(400):
    physics.step()
    expected_physics.step()



# Display images:
pixels = physics.render()
display(PIL.Image.fromarray(pixels))

print("Expecting final result:")

pixels = expected_physics.render()
display(PIL.Image.fromarray(pixels))
expected_physics.free()


#### Training Parameters
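The next cell and the top of `train()` derive several sizes from the physics timestep. A sketch of that arithmetic, assuming a 0.002 s timestep (MuJoCo's default; the actual value comes from `scene.xml` and is printed above) and the `PpoTrainingArgs` defaults; `rollout_sizes` is a hypothetical helper.

```python
def rollout_sizes(timestep, seconds_of_sim_per_epoch=10, num_envs=1,
                  num_minibatches=32, total_timesteps=1_000_000):
    """Reproduce the sizes computed in the next cell and at the start of train()."""
    num_steps = int(seconds_of_sim_per_epoch * (1 / timestep) / 2)
    batch_size = num_envs * num_steps
    minibatch_size = batch_size // num_minibatches
    num_iterations = total_timesteps // batch_size
    return {"num_steps": num_steps, "batch_size": batch_size,
            "minibatch_size": minibatch_size, "num_iterations": num_iterations}


print(rollout_sizes(timestep=0.002))
# {'num_steps': 2500, 'batch_size': 2500, 'minibatch_size': 78, 'num_iterations': 400}
```

Under that assumption each iteration collects 2500 steps (5 simulated seconds), and `range(0, batch_size, minibatch_size)` in `train()` yields 33 minibatches: 32 of 78 samples plus a final one of 4.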

In [None]:
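seconds_of_sim_per_epoch = 10 # Simulated seconds per training iteration (train() resets the task once per iteration).
# Note: the division by 2 means each rollout actually covers seconds_of_sim_per_epoch / 2 simulated seconds.
max_steps_per_episode = int(seconds_of_sim_per_epoch * (1 / physics.model.opt.timestep) / 2)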

training_args = PpoTrainingArgs(
    num_steps=max_steps_per_episode, 
    update_epochs=100
)
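
With `anneal_lr=True` (the default), `train()` decays the learning rate linearly once per iteration. A sketch of that schedule, assuming 400 iterations (for example, 1,000,000 total timesteps at 2500 steps per rollout); `annealed_lr` is a hypothetical helper.

```python
def annealed_lr(iteration, num_iterations, base_lr=3e-4):
    """Linear decay as in train(): base_lr at iteration 1, base_lr / num_iterations at the last one."""
    frac = 1.0 - (iteration - 1.0) / num_iterations
    return frac * base_lr


num_iterations = 400  # assumed: 1_000_000 total timesteps // 2500 steps per rollout
for iteration in (1, 200, 400):
    print(iteration, annealed_lr(iteration, num_iterations))
```

The rate starts at `learning_rate` and reaches `learning_rate / num_iterations` on the final iteration, so it never hits exactly zero.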


#### Train `StretchPushCubeTrainingArmOnly`

In [None]:
training_args.save_model_to_path="./stretchPushCubeTrainingArmOnly.model"


train(
    env=stretchPushCubeTrainingArmOnly, 
    training_args=training_args
)

#### Train `StretchPushCubeTraining` with all the joints on Stretch

In [None]:
training_args.save_model_to_path="./stretchPushCubeTrainingAllJoints.model"

stretchPushCubeTrainingAllJoints = StretchPushCubeTraining(physics, push_cube_by=push_cube_by)

train(
    env=stretchPushCubeTrainingAllJoints,
    training_args=training_args
)