### Introduction
This notebook runs inference with the [Diffuser](https://arxiv.org/abs/2205.09991) planning model for model-based RL. It is modified from the authors' [original](https://colab.research.google.com/drive/1YajKhu-CUIGBJeQPehjVPJcK_b38a8Nc?usp=sharing#scrollTo=57hSzI4mCgat). If you are new to reinforcement learning, the Hugging Face [Reinforcement Learning Course](https://huggingface.co/blog/deep-rl-intro) is a good primer.

> Colab made by [natolambert](https://twitter.com/natolambert) and is still WIP!

![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)


### Setup

#### apt install requirements

These packages are mainly needed to install MuJoCo and run it in Colab.
The setup is adapted from this (fairly recent) [demo](https://colab.research.google.com/drive/1KGMZdRq6AemfcNscKjgpRzXqfhUtCf-V?usp=sharing).

In [1]:
# installations primarily needed for MuJoCo
!apt-get install -y \
    libgl1-mesa-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libosmesa6-dev \
    software-properties-common

!apt-get install -y patchelf

Reading package lists... Done
Building dependency tree       
Reading state information... Done
libgl1-mesa-dev is already the newest version (20.0.8-0ubuntu1~18.04.1).
libgl1-mesa-dev set to manually installed.
software-properties-common is already the newest version (0.96.24.32.18).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
Suggested packages:
  glew-utils
The following NEW packages will be installed:
  libgl1-mesa-glx libglew-dev libglew2.0 libosmesa6 libosmesa6-dev
0 upgraded, 5 newly installed, 0 to remove and 49 not upgraded.
Need to get 2,916 kB of archives.
After this operation, 12.6 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 libgl1-mesa-glx amd64 20.0.8-0ubuntu1~18.04.1 [5,532 B]
Get:2 http://archive.ubuntu.com/ubuntu bionic/universe amd64 libglew2.0 amd64 2.0.0-5 [140 kB]
Get:3 http://archive.ubuntu.com/ubuntu bionic/univ

#### Clone repos and install dependencies

In [2]:
%cd /content

# install latest HF diffusers
!git clone https://github.com/huggingface/diffusers #git@github.com:huggingface/diffusers
!pip install -q /content/diffusers #!pip install -e .
!pip install -q accelerate datasets transformers unidecode

# additional RL-specific requirements
%pip install -f https://download.pytorch.org/whl/torch_stable.html \
                typed-argument-parser \
                scikit-image==0.17.2 \
                scikit-video==1.1.11 \
                gitpython \
                einops \
                pillow \
                free-mujoco-py \
                gym \
                protobuf==3.20.1 \
                git+https://github.com/rail-berkeley/d4rl.git \
                mediapy \
                Pillow==9.0.0


/content
Cloning into 'diffusers'...
remote: Enumerating objects: 2723, done.
remote: Counting objects: 100% (578/578), done.
remote: Compressing objects: 100% (175/175), done.
remote: Total 2723 (delta 363), reused 558 (delta 348), pack-reused 2145
Receiving objects: 100% (2723/2723), 824.68 KiB | 13.52 MiB/s, done.
Resolving deltas: 100% (1751/1751), done.
  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
     |████████████████████████████████| 101 kB 4.9 MB/s


In [3]:
## import d4rl for handling mujoco environments
## cythonize mujoco-py at first import
import d4rl

Compiling /usr/local/lib/python3.7/dist-packages/mujoco_py/cymj.pyx because it changed.
[1/1] Cythonizing /usr/local/lib/python3.7/dist-packages/mujoco_py/cymj.pyx
running build_ext
building 'mujoco_py.cymj' extension
creating /usr/local/lib/python3.7/dist-packages/mujoco_py/generated/_pyxbld_2.0.2.13_37_linuxcpuextensionbuilder
creating /usr/local/lib/python3.7/dist-packages/mujoco_py/generated/_pyxbld_2.0.2.13_37_linuxcpuextensionbuilder/temp.linux-x86_64-3.7
creating /usr/local/lib/python3.7/dist-packages/mujoco_py/generated/_pyxbld_2.0.2.13_37_linuxcpuextensionbuilder/temp.linux-x86_64-3.7/usr
creating /usr/local/lib/python3.7/dist-packages/mujoco_py/generated/_pyxbld_2.0.2.13_37_linuxcpuextensionbuilder/temp.linux-x86_64-3.7/usr/local
creating /usr/local/lib/python3.7/dist-packages/mujoco_py/generated/_pyxbld_2.0.2.13_37_linuxcpuextensionbuilder/temp.linux-x86_64-3.7/usr/local/lib
creating /usr/local/lib/python3.7/dist-packages/mujoco_py/generated/_pyxbld_2.0.2.13_37_linuxcpuexten

No module named 'flow'
No module named 'carla'


### Conditional sampling & planning
In this section, we will create the environment, handle the data, and run the diffusion model.

#### Constants & imports
This colab is designed to run with pretrained models from the hopper environment. As more models are trained, this can be extended.

In [4]:
import torch
import tqdm
import einops
import numpy as np
from diffusers import DDPMScheduler, TemporalUNet

import gym 
env_name = "hopper-medium-expert-v2"
env = gym.make(env_name)
data = env.get_dataset()

torch.cuda.get_device_name(0)
DEVICE = 'cuda:0'
DTYPE = torch.float
n_samples = 4
horizon = 128
state_dim = env.observation_space.shape[0] 
action_dim = env.action_space.shape[0]
num_inference_steps = 100 



Downloading dataset: http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco_v2/hopper_medium_expert-v2.hdf5 to /root/.d4rl/datasets/hopper_medium_expert-v2.hdf5


load datafile: 100%|██████████| 9/9 [00:03<00:00,  2.72it/s]


#### Helper functions

In [5]:
# TODO: maybe rename as scale
def normalize(x_in, data, key):
    # min-max scale x_in to [-1, 1] using the dataset statistics for `key`
    upper = np.max(data[key], axis=0)
    lower = np.min(data[key], axis=0)
    x_out = 2 * (x_in - lower) / (upper - lower) - 1
    return x_out

def de_normalize(x_in, data, key):
    # inverse of `normalize`: map values in [-1, 1] back to the dataset range
    upper = np.max(data[key], axis=0)
    lower = np.min(data[key], axis=0)
    x_out = lower + (upper - lower) * (1 + x_in) / 2
    return x_out

def to_np(x_in):
    if torch.is_tensor(x_in):
        x_in = x_in.detach().cpu().numpy()
    return x_in

def to_torch(x_in, dtype=None, device=None):
    dtype = dtype or DTYPE
    device = device or DEVICE
    if type(x_in) is dict:
        return {k: to_torch(v, dtype, device) for k, v in x_in.items()}
    elif torch.is_tensor(x_in):
        return x_in.to(device).type(dtype)
    return torch.tensor(x_in, dtype=dtype, device=device)

def to_device(x_in, device=DEVICE):
    if torch.is_tensor(x_in):
        return x_in.to(device)
    elif type(x_in) is dict:
        return {k: to_device(v, device) for k, v in x_in.items()}
    else:
        # the original printed the type and dropped into pdb, which was never imported
        raise TypeError(f'Unrecognized type in `to_device`: {type(x_in)}')

def reset_x0(x_in, cond, act_dim):
    # overwrite the observation part of each conditioned timestep with its fixed value
    for key, val in cond.items():
        x_in[:, key, act_dim:] = val.clone()
    return x_in
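
As a quick sanity check of the two scaling helpers, the sketch below applies the same min-max formulas to a small made-up dataset and confirms that `de_normalize` inverts `normalize`. The `toy_data` dictionary and its values are assumptions for illustration only, not D4RL data.

```python
import numpy as np

def normalize(x_in, data, key):
    # scale each feature to [-1, 1] using the per-dimension min and max of data[key]
    upper = np.max(data[key], axis=0)
    lower = np.min(data[key], axis=0)
    return 2 * (x_in - lower) / (upper - lower) - 1

def de_normalize(x_in, data, key):
    # map values in [-1, 1] back to the original range of data[key]
    upper = np.max(data[key], axis=0)
    lower = np.min(data[key], axis=0)
    return lower + (upper - lower) * (1 + x_in) / 2

# hypothetical stand-in for env.get_dataset(): 3 observations with 2 features each
toy_data = {"observations": np.array([[0.0, 10.0], [2.0, 20.0], [4.0, 30.0]])}

obs = np.array([1.0, 25.0])
scaled = normalize(obs, toy_data, "observations")
restored = de_normalize(scaled, toy_data, "observations")
print(scaled)
print(restored)
```

Note that a feature whose value never changes in the dataset gives `upper == lower`, so the division in `normalize` would be by zero for that dimension.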

#### Sample Initial State

In [6]:
# from diffusers import DDPMScheduler
from diffusers import TemporalUNet

generator = torch.Generator(device='cuda')
generator_cpu = torch.Generator(device='cpu')

## Can set environment seed for debugging
# torch.manual_seed(0)
# np.random.seed(0)
# env.seed(1996)

obs = env.reset()
obs_raw = obs



Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

[(14, 32), (32, 128), (128, 256)]


Downloading:   0%|          | 0.00/14.8M [00:00<?, ?B/s]

#### Initialize Diffusion model
In this section, we create a scheduler and load a pretrained model from the Hub. An important detail in the RL setting is to save `conditions`, which pin the first state of every sampled trajectory to the current state of the environment. This way the model only optimizes trajectories that start from the current state (which is crucial for making decisions!).

In [None]:
scheduler = DDPMScheduler(timesteps=100, beta_schedule="squaredcos_cap_v2")

# 3 different pretrained models are available for this task.
# The horizon represents the length of trajectories used in training.
network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE)
# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE)
# network = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE)

clip_denoised = network.clip_denoised
predict_epsilon = network.predict_epsilon

# normalize observations for forward passes
obs = normalize(obs, data, 'observations')

## add a batch dimension and repeat for multiple samples
## [ observation_dim ] --> [ n_samples x observation_dim ]
obs = obs[None].repeat(n_samples, axis=0)
conditions = {
    0: to_torch(obs, device=DEVICE)
  }

# constants for inference
batch_size = len(conditions[0])
shape = (batch_size, horizon, state_dim+action_dim)

# sample random initial noise vector
x1 = torch.randn(shape, device=DEVICE, generator=generator)
x = reset_x0(x1, conditions, action_dim)
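
To make the conditioning step concrete, here is a minimal NumPy sketch of what `reset_x0` does to the noise tensor. It assumes the `[actions | observations]` layout along the last axis used in the cell above. The toy sizes, the zero-valued placeholder observation, and the name `reset_x0_np` are illustrative assumptions.

```python
import numpy as np

# toy sizes; Hopper has 3 action dimensions and 11 observation dimensions
n_samples, horizon, action_dim, state_dim = 2, 4, 3, 11

def reset_x0_np(x_in, cond, act_dim):
    # NumPy analogue of reset_x0: for every conditioned timestep t,
    # overwrite the observation slots (everything after the actions)
    for t, val in cond.items():
        x_in[:, t, act_dim:] = val.copy()
    return x_in

rng = np.random.default_rng(0)
x = rng.standard_normal((n_samples, horizon, action_dim + state_dim))

# placeholder for the normalized current observation, repeated per sample
current_obs = np.zeros((n_samples, state_dim))
x = reset_x0_np(x, {0: current_obs}, action_dim)

print(x[0, 0])  # actions are still noise, observations are pinned
print(x[0, 1])  # later timesteps are left untouched
```

Only the observation entries at timestep 0 are replaced; the actions at that timestep and all later timesteps remain random and are filled in by the diffusion process.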

#### Run the Diffusion Process / Generate Trajectories

In [7]:
eta = 1.0 # noise factor for sampling reconstructed state
for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
    timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long)
    with torch.no_grad():
      residual = network(x, conditions, timesteps)
    
    obs_reconstruct = scheduler.step(residual.cpu(), x.cpu(), i, predict_epsilon=False)

    if eta > 0:
      noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device)
      posterior_variance = scheduler.get_variance(i) # * noise
      # no noise when t == 0
      # NOTE: original implementation missing sqrt on posterior_variance
      obs_reconstruct = obs_reconstruct + int(i>0) * (0.5 * posterior_variance) * eta* noise  # MJ had as log var, exponentiated

    obs_reconstruct_postcond = reset_x0(obs_reconstruct, conditions, action_dim)
    x = to_torch(obs_reconstruct_postcond)


100%|██████████| 100/100 [00:02<00:00, 48.84it/s]
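
The loop above adds noise scaled by the scheduler's posterior variance. For reference, the sketch below re-derives that quantity in plain NumPy from the standard DDPM definitions with a squared-cosine beta schedule. It is a standalone illustration, not the `diffusers` implementation: the function names `cosine_betas` and `posterior_variance` are made up here, and the library's scheduler may clip or otherwise differ in detail.

```python
import math
import numpy as np

def cosine_betas(num_steps, max_beta=0.999):
    # squared-cosine schedule: alpha_bar(t) = cos^2((t + 0.008) / 1.008 * pi / 2)
    def alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_steps):
        t1, t2 = i / num_steps, (i + 1) / num_steps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), capped for numerical stability
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)

def posterior_variance(betas):
    # DDPM posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
    alphas_cumprod = np.cumprod(1.0 - betas)
    alphas_cumprod_prev = np.concatenate([[1.0], alphas_cumprod[:-1]])
    return betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

betas = cosine_betas(100)
variance = posterior_variance(betas)
std = np.sqrt(variance)  # standard deviation of the posterior at each step

print(variance[:3])
print(std[:3])
```

The variance at step 0 is exactly zero, which matches the "no noise when t == 0" guard in the loop. In the standard DDPM sampling step the noise is scaled by the standard deviation `std`, that is, the square root of this variance.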


### Render the samples

#### Renderer Class
Rendering from MuJoCo has historically been difficult, especially in headless environments such as Colab. Here is a modified version of the renderer from the original Diffuser code. Additionally, a TODO is to investigate this web-based [viewer](https://github.com/kevinzakka/mjc_viewer).

In [8]:
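# Code adapted from Michael Janner
# source: https://github.com/jannerm/diffuser/blob/main/diffuser/utils/rendering.py
import warnings  # used by set_state
import imageio  # used by MuJoCoRenderer.composite when saving images
import mujoco_py as mjc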

def env_map(env_name):
    '''
        map D4RL dataset names to custom fully-observed
        variants for rendering
    '''
    if 'halfcheetah' in env_name:
        return 'HalfCheetahFullObs-v2'
    elif 'hopper' in env_name:
        return 'HopperFullObs-v2'
    elif 'walker2d' in env_name:
        return 'Walker2dFullObs-v2'
    else:
        return env_name

def get_image_mask(img):
    background = (img == 255).all(axis=-1, keepdims=True)
    mask = ~background.repeat(3, axis=-1)
    return mask

def atmost_2d(x):
    while x.ndim > 2:
        x = x.squeeze(0)
    return x

def set_state(env, state):
    qpos_dim = env.sim.data.qpos.size
    qvel_dim = env.sim.data.qvel.size
    if not state.size == qpos_dim + qvel_dim:
        warnings.warn(
            f'[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, '
            f'but got state of size {state.size}')
        state = state[:qpos_dim + qvel_dim]

    env.set_state(state[:qpos_dim], state[qpos_dim:])


class MuJoCoRenderer:
    '''
        default mujoco renderer
    '''

    def __init__(self, env):
        if type(env) is str:
            env = env_map(env)
            self.env = gym.make(env)
        else:
            self.env = env
        ## - 1 because the envs in renderer are fully-observed
        ## @TODO : clean up
        self.observation_dim = np.prod(self.env.observation_space.shape) - 1
        self.action_dim = np.prod(self.env.action_space.shape)
        try:
            self.viewer = mjc.MjRenderContextOffscreen(self.env.sim)
        except Exception:
            print('[ utils/rendering ] Warning: could not initialize offscreen renderer')
            self.viewer = None

    def pad_observation(self, observation):
        state = np.concatenate([
            np.zeros(1),
            observation,
        ])
        return state

    def pad_observations(self, observations):
        qpos_dim = self.env.sim.data.qpos.size
        ## xpos is hidden
        xvel_dim = qpos_dim - 1
        xvel = observations[:, xvel_dim]
        xpos = np.cumsum(xvel) * self.env.dt
        states = np.concatenate([
            xpos[:,None],
            observations,
        ], axis=-1)
        return states

    def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None):

        if type(dim) == int:
            dim = (dim, dim)

        if self.viewer is None:
            return np.zeros((*dim, 3), np.uint8)

        if render_kwargs is None:
            xpos = observation[0] if not partial else 0
            render_kwargs = {
                'trackbodyid': 2,
                'distance': 3,
                'lookat': [xpos, -0.5, 1],
                'elevation': -20
            }

        for key, val in render_kwargs.items():
            if key == 'lookat':
                self.viewer.cam.lookat[:] = val[:]
            else:
                setattr(self.viewer.cam, key, val)

        if partial:
            state = self.pad_observation(observation)
        else:
            state = observation

        qpos_dim = self.env.sim.data.qpos.size
        if not qvel or state.shape[-1] == qpos_dim:
            qvel_dim = self.env.sim.data.qvel.size
            state = np.concatenate([state, np.zeros(qvel_dim)])

        set_state(self.env, state)

        self.viewer.render(*dim)
        data = self.viewer.read_pixels(*dim, depth=False)
        data = data[::-1, :, :]
        return data

    def _renders(self, observations, **kwargs):
        images = []
        for observation in observations:
            img = self.render(observation, **kwargs)
            images.append(img)
        return np.stack(images, axis=0)

    def renders(self, samples, partial=False, **kwargs):
        if partial:
            samples = self.pad_observations(samples)
            partial = False

        sample_images = self._renders(samples, partial=partial, **kwargs)

        composite = np.ones_like(sample_images[0]) * 255

        for img in sample_images:
            mask = get_image_mask(img)
            composite[mask] = img[mask]

        return composite

    def composite(self, savepath, paths, dim=(1024, 256), **kwargs):

        render_kwargs = {
            'trackbodyid': 2,
            'distance': 10,
            'lookat': [5, 2, 0.5],
            'elevation': 0
        }
        images = []
        for path in paths:
            ## [ H x obs_dim ]
            path = atmost_2d(path)
            img = self.renders(to_np(path), dim=dim, partial=True, qvel=True, render_kwargs=render_kwargs, **kwargs)
            images.append(img)
        images = np.concatenate(images, axis=0)

        if savepath is not None:
            imageio.imsave(savepath, images)
            print(f'Saved {len(paths)} samples to: {savepath}')

        return images

    def render_rollout(self, savepath, states, **video_kwargs):
        if type(states) is list: states = np.array(states)
        images = self._renders(states, partial=True)
        save_video(savepath, images, **video_kwargs)

    def render_plan(self, savepath, actions, observations_pred, state, fps=30):
        ## [ batch_size x horizon x observation_dim ]
        observations_real = rollouts_from_state(self.env, state, actions)

        ## there will be one more state in `observations_real`
        ## than in `observations_pred` because the last action
        ## does not have an associated next_state in the sampled trajectory
        observations_real = observations_real[:,:-1]

        images_pred = np.stack([
            self._renders(obs_pred, partial=True)
            for obs_pred in observations_pred
        ])

        images_real = np.stack([
            self._renders(obs_real, partial=False)
            for obs_real in observations_real
        ])

        ## [ batch_size x horizon x H x W x C ]
        images = np.concatenate([images_pred, images_real], axis=-2)
        save_videos(savepath, *images)

    def render_diffusion(self, savepath, diffusion_path, **video_kwargs):
        '''
            diffusion_path : [ n_diffusion_steps x batch_size x 1 x horizon x joined_dim ]
        '''
        render_kwargs = {
            'trackbodyid': 2,
            'distance': 10,
            'lookat': [10, 2, 0.5],
            'elevation': 0,
        }

        diffusion_path = to_np(diffusion_path)

        n_diffusion_steps, batch_size, _, horizon, joined_dim = diffusion_path.shape

        frames = []
        for t in reversed(range(n_diffusion_steps)):
            print(f'[ utils/renderer ] Diffusion: {t} / {n_diffusion_steps}')

            ## [ batch_size x horizon x observation_dim ]
            states_l = diffusion_path[t].reshape(batch_size, horizon, joined_dim)[:, :, :self.observation_dim]

            frame = []
            for states in states_l:
                img = self.composite(None, states, dim=(1024, 256), partial=True, qvel=True, render_kwargs=render_kwargs)
                frame.append(img)
            frame = np.concatenate(frame, axis=0)

            frames.append(frame)

        save_video(savepath, frames, **video_kwargs)

    def __call__(self, *args, **kwargs):
        return self.renders(*args, **kwargs)
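
The `pad_observations` method reconstructs the hidden x-position by integrating the x-velocity over time. The sketch below shows that step in isolation on a toy array. The values of `qpos_dim` and `dt` are assumptions chosen to resemble Hopper; in the class they are read from `self.env`.

```python
import numpy as np

qpos_dim = 6   # assumed: 6 position coordinates, including the hidden x-position
dt = 0.008     # assumed simulation timestep, for illustration only
horizon, obs_dim = 5, 11

observations = np.zeros((horizon, obs_dim))
# the observation drops x-position, so x-velocity follows the qpos_dim - 1 visible positions
xvel_index = qpos_dim - 1
observations[:, xvel_index] = 2.0  # constant forward velocity

# integrate velocity to recover position, then prepend it as the first column
xpos = np.cumsum(observations[:, xvel_index]) * dt
states = np.concatenate([xpos[:, None], observations], axis=-1)

print(xpos)
print(states.shape)
```

The resulting `states` array has one more column than the observations, which is the full simulator state that `set_state` expects.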


#### Helpers

In [9]:
import os
import io
import base64
import skvideo.io
from IPython.display import HTML
from IPython import display as ipythondisplay

# from MJ's Diffuser code 
# https://github.com/jannerm/diffuser/blob/76ae49ae85ba1c833bf78438faffdc63b8b4d55d/diffuser/utils/colab.py#L79
def mkdir(savepath):
    """
        returns `True` iff `savepath` is created
    """
    if not os.path.exists(savepath):
        os.makedirs(savepath)
        return True
    else:
        return False


def show_sample(renderer, observations, filename='sample.mp4', savebase='/content/videos'):
    '''
    observations : [ batch_size x horizon x observation_dim ]
    '''

    mkdir(savebase)
    savepath = os.path.join(savebase, filename)

    images = []
    for rollout in observations:
        ## [ horizon x height x width x channels ]
        img = renderer._renders(rollout, partial=True)
        images.append(img)

    ## [ horizon x height x (batch_size * width) x channels ]
    images = np.concatenate(images, axis=2)

    import mediapy as media
    media.show_video(images, codec='h264', fps=60)
    # save_video(savepath, images)
    # show_video(savepath, height=200)

def show_video(path, height=400):
  video = io.open(path, 'r+b').read()
  encoded = base64.b64encode(video)
  ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
              loop controls style="height: {0}px;">
              <source src="data:video/mp4;base64,{1}" type="video/mp4" />
           </video>'''.format(height, encoded.decode('ascii'))))


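def save_video(filename, video_frames, fps=60, video_format='mp4'):
    assert fps == int(fps), fps
    # create the parent folder; calling mkdir(filename) would create a directory at the video path itself
    folder = os.path.dirname(filename)
    if folder:
        mkdir(folder)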

    skvideo.io.vwrite(
        filename,
        video_frames,
        inputdict={
            '-r': str(int(fps)),
        },
        outputdict={
            '-f': video_format,
            '-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74
        }
    )
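
The side-by-side layout produced by `show_sample` comes from concatenating the per-rollout frame stacks along the width axis. A minimal sketch with made-up frame sizes (each rollout is filled with a constant so the tiles are easy to tell apart):

```python
import numpy as np

batch_size, horizon, height, width = 4, 3, 8, 8  # toy sizes

# one stack of frames per rollout: [ horizon x height x width x channels ]
rollouts = [
    np.full((horizon, height, width, 3), i, dtype=np.uint8)
    for i in range(batch_size)
]

# axis=2 is the width axis, so rollouts end up next to each other in every frame
video = np.concatenate(rollouts, axis=2)

print(video.shape)
```

Each frame of `video` is `batch_size * width` pixels wide, matching the shape comment in `show_sample`.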




#### Show Plans
This section renders the 4 trajectories (`n_samples = 4`) sampled by the diffusion model, all starting from the same initial state in the environment.

In [11]:
render = MuJoCoRenderer(env)

In [12]:
de_normalized = de_normalize(to_np(x[:,:,action_dim:]), data, 'observations')
show_sample(render, de_normalized)
# investigating broken pipes, hmm. I think it's memory issues in a colab (doesn't like some errors)
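
Each sampled trajectory stores actions in the first `action_dim` columns and observations in the remaining ones, which is why the cell above slices `x[:, :, action_dim:]` before de-normalizing. The sketch below shows the split on a random stand-in for `x`. The array contents are placeholders, and using the first planned action for control is one common receding-horizon choice rather than something this notebook does.

```python
import numpy as np

n_samples, horizon, action_dim, state_dim = 4, 128, 3, 11
rng = np.random.default_rng(0)

# stand-in for the sampled plans: [ n_samples x horizon x (action_dim + state_dim) ]
x = rng.uniform(-1, 1, size=(n_samples, horizon, action_dim + state_dim))

actions = x[:, :, :action_dim]       # planned actions, still in normalized units
observations = x[:, :, action_dim:]  # planned observations, still in normalized units

# a receding-horizon controller would typically execute only the first action of a plan
first_actions = actions[:, 0]

print(actions.shape, observations.shape, first_actions.shape)
```

Both parts are in normalized units, so actions would need `de_normalize(..., data, 'actions')` before being passed to `env.step`, just as observations are de-normalized before rendering.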



### TEMPORARY

#### Save model to disk tool

In [13]:
# from google.colab import files
# torch.save(network.state_dict(), 'diffusion_model_hopper.pt')

# # download checkpoint file
# files.download('diffusion_model_hopper.pt')