## Imports and Environment Setup


In [1]:
#%%capture

# @title Install MuJoCo
!pip install mujoco
#@title Check if installation was successful

# from google.colab import files

import distutils.util
import os
import subprocess
# if subprocess.run('nvidia-smi').returncode:
#   raise RuntimeError(
#       'Cannot communicate with GPU. '
#       'Make sure you are using a GPU Colab runtime. '
#       'Go to the Runtime menu and select Choose runtime type.')

# # Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# # This is usually installed as part of an Nvidia driver package, but the Colab
# # kernel doesn't install its driver via APT, and as a result the ICD is missing.
# # (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
# NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
# if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
#   with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
#     f.write("""{
#     "file_format_version" : "1.0.0",
#     "ICD" : {
#         "library_path" : "libEGL_nvidia.so.0"
#     }
# }
# """)

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# Torch setup
import torch
from torch import nn, zeros
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import random
import copy

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl
Checking that the installation succeeded:
Installation successful.
Installing mediapy:


  import distutils.util
2025-04-16 14:00:41.380786: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-16 14:00:41.380823: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-16 14:00:41.381752: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-16 14:00:41.386526: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Using device: cuda


In [2]:
# @title Load simple connector MJCF as `xml`
# MJCF model for the plugging task (compiled by PluggingEnv).
#  - Gravity is disabled via <flag gravity="disable"/>; no <option timestep> is set here.
#  - Body "f" (fixed to the world) is the socket: five translucent box geoms
#    (left/right/side1/side2 walls and a bottom plate) plus four blue spheres fb1..fb4
#    sitting at the inner corners of the socket mouth (x=+-0.5, y=+-0.45, z=0).
#    Site "FT" sits at the origin of body "f" and carries the force/torque sensors.
#  - Body "m" (starts at z=1) is the plug: a box with half-extents .45 x .4 x 1.0, a slide
#    joint "pos_z" and two hinges "rot_y"/"rot_x" placed at the box's top face (pos 0 0 1).
#    Four invisible (alpha 0) spheres mb1..mb4 sit at the box's bottom corners, and site
#    "plug" marks the centre of its bottom face. Camera "track" follows this body.
#  - Actuators (order: pos_z, rot_x, rot_y) are <general> with biastype="affine",
#    gainprm=10 and biasprm="0 0 -10", i.e. force = 10*ctrl - 10*joint_velocity, which
#    makes them behave like velocity servos; ctrl is limited to [-3, 3].
#  - Sensor "plugpos" is the position of site "plug" expressed in the frame of geom "bottom".
#  - <contact> excludes each sphere body vs. its parent and every fb*/mb* sphere pair.
# NOTE(review): <visual> and top-level <default> each appear twice; MuJoCo evidently
# accepts this (the model compiles), but merging them would be easier to read.
xml = """
<mujoco>
  <default>
    <geom density="1" solimp="0.0 0.1 0.1 0.5 2"/>
    <!-- <geom solimp="0.0 0.1 0.1 0.5 2" /> -->
  </default>

  <visual>
    <map force="0.1"/>
    <headlight ambient="0.7 0.7 0.7"/>
    <rgba contactforce="0.7 1.0 1.0 .6"/>
  </visual>

  <visual>
  <global offwidth="1024" offheight="1024"/>
  <rgba haze="0.15 0.25 0.35 1"/>
</visual>


  <option>
    <flag gravity="disable"/>
  </option>


<asset>
    <texture type="skybox" builtin="gradient" rgb1=".3 .5 .7" rgb2="0 0 0" width="32" height="512"/>
    <texture name="body" type="cube" builtin="flat" mark="cross" width="128" height="128" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" markrgb="1 1 1" random="0.01"/>
    <material name="body" texture="body" texuniform="true" rgba="0.8 0.6 .4 1"/>
    <texture name="grid" type="2d" builtin="checker" width="512" height="512" rgb1=".1 .2 .3" rgb2=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance="0."/>
  </asset>

  <worldbody>
    <light diffuse=".5 .5 .5" pos="0 0 3" dir="0 0 -1"/>
    <geom name="floor" pos="0 0 -3" size="0 0 0.05" type="plane" material="grid"/>
    <light name="spotlight" mode="targetbodycom" target="m" diffuse=".8 .8 .8" specular="0.3 0.3 0.3" pos="0 -6 4" cutoff="30"/>
    <body name="f" >
      <body name = "fb1" >
        <geom name="fb1" size=".075" pos=".5 .45 0" rgba="0 0 1 1" priority="1"  friction=".6 0.005 0.0001"/>
      </body>
      <body name = "fb2" >
        <geom name="fb2" size=".075" pos="-.5 .45 0" rgba="0 0 1 1" priority="1"  friction=".6 0.005 0.0001"/>
      </body>
      <body name = "fb3" >
        <geom name="fb3" size=".075" pos=".5 -.45 0" rgba="0 0 1 1" priority="1"  friction=".6 0.005 0.0001"/>
      </body>
      <body name = "fb4" >
        <geom name="fb4" size=".075" pos="-.5 -.45 0" rgba="0 0 1 1" priority="1"  friction=".6 0.005 0.0001"/>
      </body>
      <geom name="left" type="box" pos="-0.65 0 -0.85" size=".15 .45 .85" rgba=".4 .4 .4 .2"/>
      <geom name="right" type="box" pos="0.65 0 -0.85" size=".15 .45 .85" rgba=".4 .4 .4 .2"/>
      <geom name="bottom" type="box" pos="0 0 -1.85" size=".8 .45 .15" rgba=".4 .4 .4 .2"/>
      <geom name="side1" type="box" pos="0 0.6 -1" size=".8 .15 1" rgba=".4 .4 .4 .2"/>
      <geom name="side2" type="box" pos="0 -0.6 -1" size=".8 .15 1" rgba=".4 .4 .4 .2"/>
      <site name="FT"/>
    </body>
    

    <body name="m" pos="0 0 1">
      <joint type="slide" name="pos_z" axis="0 0 1"/>
      <joint pos="0 0 1" type="hinge" name="rot_y" axis="0 1 0"/>
      <joint pos="0 0 1" type="hinge" name="rot_x" axis="1 0 0"/>
    <geom name="male" type="box" size = ".45 .4 1.0" rgba=".4 .2 .1 .5" />
      <body name = "mb1" >
        <geom name="mb1" size=".085" pos="0.45 -.4 -1.0" rgba="1 0 0 0"/>
      </body>
      <body name = "mb2" >
        <geom name="mb2" size=".085" pos="-0.45 -.4 -1.0" rgba="1 0 0 0"/>
      </body>
      <body name = "mb3" >
        <geom name="mb3" size=".085" pos=".45 .4 -1.0" rgba="1 0 0 0"/>
      </body>
      <body name = "mb4" >
        <geom name="mb4" size=".085" pos="-.45 .4 -1.0" rgba="1 0 0 0"/>
      </body>
      <site name="plug" pos="0 0 -1"/>
      <camera name="track" pos="0 -6 0" xyaxes="1 0 0 0 .2 1" mode="track"/>
    </body>
  </worldbody>

  <default>
    <general ctrlrange="-3 3" ctrllimited="true" biastype="affine"/>
  </default>

  <actuator>
    <general name="pos_z" joint="pos_z" gainprm="10" biasprm="0 0 -10"/>
    <general name="rot_x" joint="rot_x" gainprm="10" biasprm="0 0 -10"/>
    <general name="rot_y" joint="rot_y" gainprm="10" biasprm="0 0 -10"/>
  </actuator>

  <sensor>
    <force name="force" site="FT"/>
    <torque name="torque" site="FT"/>
    <framepos name="plugpos" objtype="site" objname="plug" reftype="geom" refname="bottom"/>
  </sensor>

  <contact>
    <exclude body1="f" body2="fb1"/>
    <exclude body1="f" body2="fb2"/>
    <exclude body1="f" body2="fb3"/>
    <exclude body1="f" body2="fb4"/>
    <exclude body1="m" body2="mb1"/>
    <exclude body1="m" body2="mb2"/>
    <exclude body1="m" body2="mb3"/>
    <exclude body1="m" body2="mb4"/>
    <exclude body1="fb1" body2="mb1"/>
    <exclude body1="fb1" body2="mb2"/>
    <exclude body1="fb1" body2="mb3"/>
    <exclude body1="fb1" body2="mb4"/>
    <exclude body1="fb2" body2="mb1"/>
    <exclude body1="fb2" body2="mb2"/>
    <exclude body1="fb2" body2="mb3"/>
    <exclude body1="fb2" body2="mb4"/>
    <exclude body1="fb3" body2="mb1"/>
    <exclude body1="fb3" body2="mb2"/>
    <exclude body1="fb3" body2="mb3"/>
    <exclude body1="fb3" body2="mb4"/>
    <exclude body1="fb4" body2="mb1"/>
    <exclude body1="fb4" body2="mb2"/>
    <exclude body1="fb4" body2="mb3"/>
    <exclude body1="fb4" body2="mb4"/>
  </contact>



</mujoco>
"""

In [3]:
# Environment Setup as a custom gym environment
from typing import Optional
import numpy as np
import gymnasium as gym
from scipy.spatial.transform import Rotation

class PluggingEnv(gym.Env):
    """Gymnasium env: drive the plug (MJCF body "m") into a randomly tilted socket (body "f").

    Every reset recompiles ``xml``, then tilts the socket by two angles, either sampled or
    taken from ``options["rpy"]``. The tilt is applied with a scipy ``'yxz'`` Euler
    rotation written to ``model.body("f").quat``. The third angle is always forced to 0.

    Observation (15 floats, in this order):
        qpos (3: pos_z, rot_y, rot_x), qvel (3), force at site "FT" (3),
        torque at site "FT" (3), socket tilt ``rpy`` (3).
    Action (3 floats):
        Written to ``data.ctrl`` after clipping to [-3, 3]. The MJCF actuator order
        is pos_z, rot_x, rot_y.

    Args:
        xml: MJCF string; stored so every reset rebuilds this same model.
        reset_noise_scale: random tilts are drawn uniformly from
            [-reset_noise_scale, reset_noise_scale).
        record_video: if True, render one frame from camera "track" per step().
    """

    def __init__(self, xml, reset_noise_scale, record_video=False):
        # Keep the MJCF so reset() rebuilds *this* model, not whatever a global `xml` holds.
        self.xml = xml
        self.dt = 1/600  # physics timestep; re-applied on every rebuild in _build_model()
        self.n_steps_per_call = 10  # mj_step substeps per env.step()
        self._build_model()
        self.observation_space = gym.spaces.Box(low=-40, high=40, shape=(6+6+3,)) # 3 pos, 3 vel, 6 dof force/torque, 3 socket angles
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) # 3 joint actions
        self.max_steps = 1000
        self.step_count = 0
        self.reset_noise_scale = reset_noise_scale
        self.record_video = record_video
        # One rendered frame per env.step(), i.e. per dt * n_steps_per_call simulated seconds.
        self.framerate = 1/(self.dt*self.n_steps_per_call)
        self.renderer = None
        # Always defined so get_frames() works even when not recording.
        # Frames are not cleared on reset, so they accumulate across episodes.
        self.frames = []

    def _build_model(self):
        """Compile ``self.xml`` into a fresh model/data pair and apply ``self.dt``.

        Recompiling resets every model field, so the timestep must be re-applied here.
        Otherwise it reverts to the MJCF value, which is MuJoCo's default because the
        MJCF's <option> sets no timestep.
        """
        self.model = mujoco.MjModel.from_xml_string(self.xml)
        self.model.opt.timestep = self.dt
        self.data = mujoco.MjData(self.model)

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Rebuild the model, tilt the socket, and return the first observation.

        Args:
            seed: forwarded to gym.Env.reset, which seeds ``self.np_random``.
            options: optional dict. If it has an ``"rpy"`` entry (3 angles), those angles
                are used instead of random ones; the third is still forced to 0.

        Returns:
            ``(observation, info)`` from one zero-action step().
        """
        # Seeds self.np_random, which the random tilt below draws from.
        super().reset(seed=seed)

        self._build_model()
        self.step_count = 0
        mujoco.mj_resetData(self.model, self.data)
        self.data.qpos = np.array([0, 0, 0])
        self.data.qvel = np.zeros_like(self.data.qpos)

        if options is not None and 'rpy' in options:
            # Copy (as float) so forcing the third angle to 0 never mutates the caller's array.
            self.rpy = np.array(options['rpy'], dtype=float)
        else:
            self.rpy = 2*self.reset_noise_scale*(self.np_random.random(3)-0.5)
        self.rpy[2] = 0
        rotation = Rotation.from_euler('yxz', self.rpy)
        self.model.body("f").quat = rotation.as_quat(scalar_first=True)

        if self.record_video:
            # The renderer is bound to the model just rebuilt; release the previous one first.
            if self.renderer is not None:
                self.renderer.close()
            self.renderer = mujoco.Renderer(self.model, width=1000, height=1000)
            self.options = mujoco.MjvOption()
            # Turn on the contact-force visualizer.
            mujoco.mjv_defaultOption(self.options)
            self.options.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True

        # Take one zero-action step so obs/info come from the regular step() path.
        # This also advances physics and counts toward step_count.
        obs, _, _, _, info = self.step(np.zeros_like(self.data.ctrl))
        return obs, info

    def _check_truncate(self, state):
        """Return True when the plug state is out of bounds.

        ``state`` is qpos+qvel as built by _get_obs(): index 0 is pos_z, 1:3 are the two
        hinge angles, and 3 is the pos_z velocity.
        NOTE(review): currently unused; step() always reports truncated=False.
        """
        state_check = state[0] > 3
        angle_check = np.linalg.norm(state[1:3]) > 0.2
        speed_check = abs(state[3]) > 0.1
        return state_check or angle_check or speed_check

    def _get_obs(self):
        """Return ``(state, force, rpy)``.

        ``state`` is qpos+qvel (copied), ``force`` is force+torque at site "FT", and
        ``rpy`` is the socket tilt.
        """
        state = np.concatenate((self.data.qpos.copy(), self.data.qvel.copy()))
        force = np.concatenate((
            self.data.sensor('force').data,
            self.data.sensor('torque').data,
        ))
        return state, force, self.rpy

    def _get_reward(self, state, action, force, rpy):
        """Shaped reward: the sum of four terms.

        - position: ``-||plugpos||``, the plug site's distance from the "bottom" geom frame
        - bonus: +100 while that distance is below 0.25
        - force: ``0.1 * exp(-||force/torque||)``
        - angle: ``-||(rot_y, rot_x) - rpy[:2]||``, the mismatch between hinge angles and socket tilt

        ``state`` and ``action`` are accepted for interface symmetry but unused.
        """
        force_scale = 1
        plug_distance = np.linalg.norm(self.data.sensor('plugpos').data)
        position_reward = -plug_distance
        bonus_reward = 100 if plug_distance < 0.25 else 0
        force_reward = 0.1*np.exp(-np.linalg.norm(force)/force_scale)
        angle_reward = -np.linalg.norm(self.data.qpos[1:3] - rpy[0:2])
        return position_reward + force_reward + bonus_reward + angle_reward

    def step(self, action):
        """Apply ``action``, advance ``n_steps_per_call`` physics steps, and return the gym 5-tuple.

        NOTE(review): reaching ``max_steps`` is reported as ``terminated``. Under the
        Gymnasium API a time limit is normally ``truncated``. This is left unchanged because
        the training code may rely on this flag to end episodes; confirm before switching.
        """
        self.data.ctrl = np.clip(action, -3, 3)
        mujoco.mj_step(self.model, self.data, nstep=self.n_steps_per_call)
        state, force, rpy = self._get_obs()
        reward = self._get_reward(state, action, force, rpy)
        observation = np.concatenate((state, force, rpy))
        self.step_count += 1
        terminated = self.step_count >= self.max_steps
        truncated = False
        info = dict()
        if self.record_video:
            self.renderer.update_scene(self.data, "track", self.options)
            self.frames.append(self.renderer.render())
        return observation, reward, terminated, truncated, info

    def get_frames(self):
        """Return the list of rendered RGB frames (empty unless ``record_video`` was set)."""
        return self.frames


# Register the env under a namespaced id so it can also be built with
# gym.make("DRL/PluggingEnv", xml=..., reset_noise_scale=...).
gym.register(id="DRL/PluggingEnv", entry_point=PluggingEnv)



In [4]:
class ReplayBuffer:
    """FIFO experience-replay store of per-transition tensors.

    Args:
        capacity: maximum number of transitions kept; the oldest are evicted first.
        batch_size: default minibatch size returned by :meth:`sample`.

    Attribute names (``buffer``, ``batch_size``) are unchanged so that previously pickled
    buffers keep working with these methods.
    """

    def __init__(self, capacity=1_000_000, batch_size=32):
        self.buffer = deque(maxlen=capacity)
        self.batch_size = batch_size

    def __len__(self):
        return len(self.buffer)

    def store(self, state, action, reward, next_state, done):
        """Append a batch of transitions, one per row of the leading dimension.

        Rows are zipped into ``(s, a, r, s', not_done)`` tuples, where
        ``not_done = 1 - done`` as a float32 tensor. ``done`` may be a list, a numpy
        array or a tensor.
        """
        not_done = 1 - torch.as_tensor(done, dtype=torch.float32)
        self.buffer.extend(zip(state, action, reward, next_state, not_done))

    def sample(self, batch_size=None):
        """Uniformly sample transitions without replacement and move them to ``device``.

        Args:
            batch_size: number of transitions; defaults to ``self.batch_size``.

        Returns:
            ``[states, actions, rewards, next_states, not_dones]``, each stacked along dim 0.

        Raises:
            ValueError: (from ``random.sample``) if fewer transitions are stored than requested.
        """
        n_samples = self.batch_size if batch_size is None else batch_size
        batch = random.sample(self.buffer, n_samples)
        return [torch.stack(column).to(device) for column in zip(*batch)]

In [5]:
import pickle

# Load the previously collected replay buffer (presumably a pickled ReplayBuffer instance).
# Unpickling such an object needs the ReplayBuffer class from the cell above to be defined.
# SECURITY: pickle.load can execute arbitrary code; only load files you created yourself.
replay_buffer_path = "replay_buffer.pkl"
with open(replay_buffer_path, "rb") as buffer_file:
    replay_buffer = pickle.load(buffer_file)

In [22]:
def obs_to_training_data(obs):
    """Split a batch of observations into OSI network inputs and regression targets.

    Columns are indexed along dim 1. The trailing three columns are the socket tilt
    ``rpy``, and they are dropped from the inputs. The targets are the first two of those
    three angles; the third is always 0 in the env.

    Returns:
        ``(features, targets)``: ``obs`` without its last 3 columns, and columns ``-3:-1``.
    """
    feature_cols = slice(None, -3)
    target_cols = slice(-3, -1)
    return obs[:, feature_cols], obs[:, target_cols]

In [27]:
# Peek at one minibatch to size the OSI (system-identification) network from the stored data.
obs, action, _, next_obs, _ = replay_buffer.sample()
n_obs = obs.size()[1]
n_actions = action.size()[1]

print(n_obs, n_actions)

# Derive layer sizes from the same split the training loop uses (obs_to_training_data),
# instead of hard-coding "n_obs + n_actions - 3" inputs and 2 outputs.
sample_features, sample_targets = obs_to_training_data(obs)
n_osi_inputs = sample_features.size(1) + n_actions  # observation without the socket angles, plus the action
n_osi_outputs = sample_targets.size(1)              # the two estimated socket angles

osi_net = nn.Sequential(
    nn.Linear(n_osi_inputs, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, n_osi_outputs),
).to(device)

optimizer = Adam(osi_net.parameters(), 3e-4)
loss_fcn = torch.nn.MSELoss()

15 3


In [28]:
# OSI supervised training loop: regress the socket's two tilt angles from
# (observation without those angles, action) sampled from the replay buffer.

# tensorboard label can be changed with e.g. 'runs/unique_hyperparam_test'
writer = SummaryWriter(log_dir='runs/OSI')

N_OSI_STEPS = 30_000
for i in range(N_OSI_STEPS):
    obs, action, _, _, _ = replay_buffer.sample()  # states, actions, rewards, next_states, not_dones
    obs_new, ground_truth = obs_to_training_data(obs)
    estimate = osi_net(torch.cat((obs_new, action), dim=1))

    # MSELoss takes (input, target); the value is symmetric, but keep the documented order.
    loss = loss_fcn(estimate, ground_truth)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    writer.add_scalar("stats/mse_loss", loss.item(), i)

# Flush pending events so TensorBoard shows the full curve, and release the event file.
writer.close()

print("Done Training!")

Done Training!


In [20]:
# Launch TensorBoard
%load_ext tensorboard
%tensorboard --logdir runs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 306080), started 1 day, 18:00:49 ago. (Use '!kill 306080' to kill it.)