# Dependencies/setup
Installs all the necessary dependencies and sets up the required driver files.

In [1]:
!mkdir -p /usr/share/vulkan/icd.d
!wget -q https://raw.githubusercontent.com/haosulab/ManiSkill2/main/docker/nvidia_icd.json
!wget -q https://raw.githubusercontent.com/haosulab/ManiSkill2/main/docker/10_nvidia.json
!mv nvidia_icd.json /usr/share/vulkan/icd.d
!mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
# dependencies
!apt-get install -y --no-install-recommends libvulkan-dev
!pip install git+https://github.com/arnavg115/ManiSkill2.git
!pip install --upgrade --no-cache-dir gdown
!pip install diffusers wandb

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libvulkan1
Recommended packages:
  mesa-vulkan-drivers | vulkan-icd
The following NEW packages will be installed:
  libvulkan-dev libvulkan1
0 upgraded, 2 newly installed, 0 to remove and 30 not upgraded.
Need to get 1,020 kB of archives.
After this operation, 17.2 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libvulkan1 amd64 1.3.204.1-2 [128 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libvulkan-dev amd64 1.3.204.1-2 [892 kB]
Fetched 1,020 kB in 2s (414 kB/s)
Selecting previously unselected package libvulkan1:amd64.
(Reading database ... 121658 files and directories currently installed.)
Preparing to unpack .../libvulkan1_1.3.204.1-2_amd64.deb ...
Unpacking libvulkan1:amd64 (1.3.204.1-2) ...
Selecting previously unselected package libvulkan-dev:amd64.
Preparing to unpa

# Diffusion Policy Using Maniskill-2
This notebook implements [diffusion policy](https://diffusion-policy.cs.columbia.edu/) using the environments in ManiSkill2. This notebook uses code from the example jupyter notebook found [here](https://colab.research.google.com/drive/1gxdkgRVfM55zihY9TFLja97cSVZOZq2B?usp=sharing).

In [2]:
import collections
import math
import random

import gymnasium as gym
import h5py
import numpy as np
import torch
import torch.nn as nn
from diffusers.optimization import get_scheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from torch.utils.data import Dataset
from tqdm.notebook import tqdm

# NOTE(review): presumably imported for its side effect of registering the
# ManiSkill2 environments with gymnasium -- do not remove.
import mani_skill2.envs
from mani_skill2.utils.io_utils import load_json
from mani_skill2.utils.wrappers import RecordEpisode

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

# 0. Initialize environment, download dataset, and define hyper parameters.


In [10]:
# The form variables are named `env_id` / `use_wandb` so they do not collide with
# the gym environment object (`env`) and the `wandb` module used by later cells.
env_id = "LiftCube-v0"  # @param ["LiftCube-v0", "StackCube-v0"]
# @markdown If loss or evaluation metrics should be published to wandb.
use_wandb = True  # @param {type:"boolean"}
# @markdown How many observations are fed as part of the global condition
obs_horizon = 2  # @param {type:"integer"}
# @markdown How many predictions the model makes
pred_horizon = 16  # @param {type:"integer"}
# @markdown How many predictions are acted upon.
action_horizon = 8  # @param {type:"integer"}
batch_size = 256  # @param {type:"integer"}
# @markdown How many epochs the model is trained on
n_epochs = 1000  # @param {type:"integer"}
# @markdown The policy is evaluated every `eval_interval` epochs
eval_interval = 100  # @param {type:"integer"}
# @markdown Number of diffusion (denoising) iterations
num_diffusion_iters = 100  # @param {type:"integer"}


# Single source of truth for every hyperparameter read by the cells below.
config = {
    "env": env_id,
    "dataset": f"demos/v0/rigid_body/{env_id}/trajectory.state.pd_ee_delta_pose.h5",
    "pred_horizon": pred_horizon,
    "obs_horizon": obs_horizon,
    "action_horizon": action_horizon,
    "num_eval_eps": 10,
    "eval_ep_len": 100,
    "batch_size": batch_size,
    "wandb": use_wandb,
    "n_epochs": n_epochs,
    "eval_interval": eval_interval,
    "num_diffusion_iters": num_diffusion_iters,
}

In [5]:
# Build the simulation environment named in the config. The control mode is the
# same `pd_ee_delta_pose` that appears in the demonstration file name stored in
# config["dataset"].
env = gym.make(
    config["env"],
    obs_mode="state",
    control_mode="pd_ee_delta_pose",
    render_mode="cameras",
    enable_shadow=True,
)

In [6]:
# @markdown Download preprocessed mani-skill2 trajectories
!wget https://huggingface.co/datasets/a11g/maniskill2-replayed-trajectories/resolve/main/demos.zip?download=true -O demos.zip

--2024-01-24 17:46:51--  https://huggingface.co/datasets/a11g/maniskill2-replayed-trajectories/resolve/main/demos.zip?download=true
Resolving huggingface.co (huggingface.co)... 13.33.33.55, 13.33.33.102, 13.33.33.110, ...
Connecting to huggingface.co (huggingface.co)|13.33.33.55|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/fc/c9/fcc95cd9e7677aafe247b0d349c4e09c4db622091f64f08a66db3846092956ea/7af10271357a249edec6f14683f52e4cd6c5eb60261ed0b70696391444a0bf40?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27demos.zip%3B+filename%3D%22demos.zip%22%3B&response-content-type=application%2Fzip&Expires=1706377611&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwNjM3NzYxMX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ZjL2M5L2ZjYzk1Y2Q5ZTc2NzdhYWZlMjQ3YjBkMzQ5YzRlMDljNGRiNjIyMDkxZjY0ZjA4YTY2ZGIzODQ2MDkyOTU2ZWEvN2FmMTAyNzEzNTdhMjQ5ZWRl

In [9]:
!unzip demos

Archive:  demos.zip
   creating: demos/
   creating: demos/v0/
   creating: demos/v0/rigid_body/
   creating: demos/v0/rigid_body/LiftCube-v0/
  inflating: demos/v0/rigid_body/LiftCube-v0/trajectory.json  
  inflating: demos/v0/rigid_body/LiftCube-v0/trajectory.h5  
  inflating: demos/v0/rigid_body/LiftCube-v0/trajectory.state.pd_ee_delta_pose.h5  
  inflating: demos/v0/rigid_body/LiftCube-v0/trajectory.state.pd_ee_delta_pose.json  
   creating: demos/v0/rigid_body/StackCube-v0/
  inflating: demos/v0/rigid_body/StackCube-v0/trajectory.json  
  inflating: demos/v0/rigid_body/StackCube-v0/trajectory.h5  
  inflating: demos/v0/rigid_body/StackCube-v0/trajectory.state.pd_ee_delta_pose.h5  
  inflating: demos/v0/rigid_body/StackCube-v0/trajectory.state.pd_ee_delta_pose.json  
   creating: demos/v0/rigid_body/PickCube-v0/
  inflating: demos/v0/rigid_body/PickCube-v0/trajectory.json  
  inflating: demos/v0/rigid_body/PickCube-v0/trajectory.h5  
  inflating: demos/v0/rigid_body/PickCube-v0/tra

# 1. Define Dataset
The dataset is defined here, along with helper methods that prepare the data to be fed into the model. We feed the model observations of shape (B, obs_horizon, obs_space) and actions of shape (B, pred_horizon, action_space), where B stands for the batch size. We then define a dataloader which is used during training.

In [8]:
def load_h5_data(data):
    """Recursively copy an HDF5 group into a plain nested ``dict``.

    Entries that are ``h5py.Dataset`` objects are read completely with ``[:]``;
    every other entry is treated as a sub-group and converted by a recursive call.
    """
    return {
        key: (
            data[key][:]
            if isinstance(data[key], h5py.Dataset)
            else load_h5_data(data[key])
        )
        for key in data.keys()
    }


def create_sample_indices(
    episode_ends: np.ndarray,
    sequence_length: int,
    pad_before: int = 0,
    pad_after: int = 0,
):
    """Enumerate every training window of ``sequence_length`` steps.

    Episodes are stored separately, so each entry of ``episode_ends`` is used as
    the length of one episode and all buffer indices are relative to that
    episode's own start.

    A window may begin up to ``pad_before`` steps before the episode and finish
    up to ``pad_after`` steps after it. For each window one row is produced:

        [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx, episode]

    ``buffer_*`` is the slice of real data inside the episode, ``sample_*`` is
    where that slice sits inside the ``sequence_length``-long window (the rest
    has to be padded by the consumer), and ``episode`` is the episode index.

    Returns the rows stacked in a single ``np.ndarray``.
    """
    rows = []
    for episode, length in enumerate(episode_ends):
        first_window = -pad_before
        last_window = length - sequence_length + pad_after

        # `range` excludes its stop value, hence the + 1
        for window_start in range(first_window, last_window + 1):
            window_end = window_start + sequence_length
            # clip the window to the data that actually exists
            buffer_start_idx = max(window_start, 0)
            buffer_end_idx = min(window_end, length)
            # number of steps clipped at the front / position where real data stops
            sample_start_idx = buffer_start_idx - window_start
            sample_end_idx = sequence_length - (window_end - buffer_end_idx)
            rows.append(
                [
                    buffer_start_idx,
                    buffer_end_idx,
                    sample_start_idx,
                    sample_end_idx,
                    episode,
                ]
            )
    return np.array(rows)


def sample_sequence(
    data, seq_len, buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx, i
):
    """Cut one window out of episode ``i`` and edge-pad it to ``seq_len`` steps.

    Args:
        data: sequence of per-episode arrays whose first axis is time.
        seq_len: length of the window that is returned.
        buffer_start_idx, buffer_end_idx: slice of real data inside the episode.
        sample_start_idx: number of steps missing at the front of the window;
            they are filled with copies of the first real step.
        sample_end_idx: position inside the window where real data stops; the
            remaining ``seq_len - sample_end_idx`` steps are filled with copies
            of the last real step.
        i: index of the episode inside ``data``.

    Returns:
        The window as an array. When no padding is required this is the plain
        slice of ``data[i]``.

    Raises:
        IndexError: if padding is required but the selected slice is empty.
    """
    sample = data[i][buffer_start_idx:buffer_end_idx]
    n_front = sample_start_idx
    n_back = seq_len - sample_end_idx
    if n_front <= 0 and n_back <= 0:
        return sample
    if len(sample) == 0:
        raise IndexError("cannot pad an empty sample")

    # Padding is built with repeat/concatenate on the leading axis, so it works
    # for per-step data of any rank rather than only for 2-D arrays.
    pieces = []
    if n_front > 0:
        pieces.append(np.repeat(sample[:1], n_front, axis=0))
    pieces.append(sample)
    if n_back > 0:
        pieces.append(np.repeat(sample[-1:], n_back, axis=0))
    return np.concatenate(pieces, axis=0)


class ManiSkill2Dataset(Dataset):
    """Windowed view over ManiSkill2 demonstration trajectories.

    All requested episodes are read into memory at construction time. One
    dataset item is a window of ``config["pred_horizon"]`` steps, returned as an
    ``(actions, observations)`` pair where only the first
    ``config["obs_horizon"]`` observations of the window are kept.
    """

    def __init__(self, config, load_count=-1) -> None:
        """Load the trajectories and pre-compute every sample window.

        Args:
            config: dict providing ``"dataset"`` (path of the ``.h5`` file; the
                ``.json`` metadata is expected next to it), ``"pred_horizon"``,
                ``"obs_horizon"`` and ``"action_horizon"``.
            load_count: number of episodes to load; ``-1`` loads all of them.
        """
        self.dataset_file = config["dataset"]
        # for details on how the code below works, see the
        # quick start tutorial
        self.data = h5py.File(config["dataset"], "r")
        json_path = config["dataset"].replace(".h5", ".json")
        self.json_data = load_json(json_path)
        self.episodes = self.json_data["episodes"]

        self.env_info = self.json_data["env_info"]
        self.env_id = self.env_info["env_id"]
        self.env_kwargs = self.env_info["env_kwargs"]

        self.observations = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.total_frames = 0
        if load_count == -1:
            load_count = len(self.episodes)
        for eps_id in tqdm(range(load_count)):
            eps = self.episodes[eps_id]
            trajectory = self.data[f"traj_{eps['episode_id']}"]
            trajectory = load_h5_data(trajectory)

            # we use :-1 here to ignore the last observation as that
            # is the terminal observation which has no actions
            self.observations.append(trajectory["obs"][:-1])
            self.actions.append(trajectory["actions"])
        # one entry per episode: its number of (observation, action) steps
        ends = [action.shape[0] for action in self.actions]
        self.action_space = self.actions[0].shape[-1]
        self.obs_space = self.observations[0].shape[-1]

        self.episode_ends = np.array(ends)
        # one row per training window, see create_sample_indices
        self.inds = create_sample_indices(
            self.episode_ends,
            config["pred_horizon"],
            config["obs_horizon"] - 1,
            config["action_horizon"] - 1,
        )
        self.pred_horizon = config["pred_horizon"]
        self.obs_horizon = config["obs_horizon"]

    def get_state_stats(self):
        """Return the per-dimension mean and std (plus 1e-6) of all observations."""
        arr = np.vstack(self.observations)
        return np.mean(arr, axis=0), np.std(arr, axis=0) + 1e-6

    def __len__(self):
        """Number of sample windows.

        This has to be the number of windows, not the number of episodes:
        ``__getitem__`` indexes ``self.inds``, and a DataLoader only ever asks
        for indices below ``len(dataset)``.
        """
        return len(self.inds)

    def __getitem__(self, idx):
        """Return ``(actions, observations)`` for window ``idx`` as float32 tensors.

        ``actions`` covers ``pred_horizon`` steps; ``observations`` is cut down
        to the first ``obs_horizon`` steps of the same window.
        """
        action = sample_sequence(self.actions, self.pred_horizon, *self.inds[idx])
        obs = sample_sequence(self.observations, self.pred_horizon, *self.inds[idx])

        return torch.from_numpy(action).to(torch.float32), torch.from_numpy(
            obs[: self.obs_horizon, :]
        ).to(torch.float32)

In [11]:
# Load every demonstration episode listed for config["dataset"] into memory
# (load_count is left at its default of -1, i.e. all episodes).
dataset = ManiSkill2Dataset(config=config)

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Shuffled mini-batches of (action, observation) windows taken from `dataset`.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config["batch_size"],
    num_workers=1,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill the worker process after each epoch
    persistent_workers=True,
)

# 2. Defining the model
This is taken from the original implementation and it implements a conditional u-net. The following descriptions come from the original notebook. We then instantiate the model and noise scheduler. We also define an EMA of the model parameters as well as a learning rate scheduler.

The model works by taking noise and running it through the reverse (denoising) diffusion process to arrive at a sequence of actions. The observations are used as the global condition for the model, similar to how a text prompt is used in a text-to-image diffusion model.

In [None]:
# @markdown ### **Network**
# @markdown
# @markdown Defines a 1D UNet architecture `ConditionalUnet1D`
# @markdown as the noise prediction network
# @markdown
# @markdown Components
# @markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
# @markdown - `Downsample1d` Strided convolution to reduce temporal resolution
# @markdown - `Upsample1d` Transposed convolution to increase temporal resolution
# @markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
# @markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
# @markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection.
# @markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalars (the diffusion iteration)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Map ``x`` of shape (B,) to (B, 2 * (dim // 2)): sines first, then cosines."""
        n_freqs = self.dim // 2
        # frequencies fall off geometrically from 1 down to 1/10000
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(-log_step * torch.arange(n_freqs, device=x.device))
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class Downsample1d(nn.Module):
    """Strided 1-D convolution (kernel 3, stride 2, padding 1) that keeps the channel count."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the strided convolution to ``x``."""
        downsampled = self.conv(x)
        return downsampled


class Upsample1d(nn.Module):
    """Transposed 1-D convolution (kernel 4, stride 2, padding 1) that keeps the channel count."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        upsampled = self.conv(x)
        return upsampled


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    The convolution is padded with ``kernel_size // 2`` on both sides.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        half_kernel = kernel_size // 2
        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=half_kernel),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through convolution, normalisation and activation."""
        return self.block(x)


class ConditionalResidualBlock1D(nn.Module):
    """Two stacked ``Conv1dBlock`` layers with FiLM conditioning and a residual path."""

    def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
        super().__init__()

        first_conv = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second_conv = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first_conv, second_conv])

        # FiLM modulation https://arxiv.org/abs/1709.07871
        # the encoder emits one scale and one bias per output channel
        self.out_channels = out_channels
        film_channels = 2 * out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, film_channels),
            nn.Unflatten(-1, (-1, 1)),
        )

        # a 1x1 convolution aligns the channel count of the skip path when needed
        if in_channels != out_channels:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, cond):
        """
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim]

        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        hidden = self.blocks[0](x)

        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        # first half of the encoder output scales, second half shifts
        scale, bias = film.unbind(dim=1)
        hidden = scale * hidden + bias

        hidden = self.blocks[1](hidden)
        return hidden + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    """1-D conditional U-Net used as the noise prediction network.

    The input sequence is processed along its time axis by a stack of
    ``ConditionalResidualBlock1D`` layers. Every block is conditioned on the
    embedding of the diffusion step concatenated with the optional global
    condition.
    """

    def __init__(
        self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8,
    ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level (any sequence of ints).
          The length of this sequence determines the number of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        # The default is an immutable tuple; it is copied to a list so callers
        # may pass either a list or a tuple.
        down_dims = list(down_dims)
        all_dims = [input_dim] + down_dims
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList(
            [
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=kernel_size,
                    n_groups=n_groups,
                ),
                ConditionalResidualBlock1D(
                    mid_dim,
                    mid_dim,
                    cond_dim=cond_dim,
                    kernel_size=kernel_size,
                    n_groups=n_groups,
                ),
            ]
        )

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # the deepest level keeps its temporal resolution
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(
                nn.ModuleList(
                    [
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                        ),
                        ConditionalResidualBlock1D(
                            dim_out,
                            dim_out,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                        ),
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        up_modules = nn.ModuleList([])
        # The up path has one level fewer than the down path (in_out[1:]), which
        # matches the number of Downsample1d layers above, so every up level
        # upsamples. The previous `is_last` test here could never be true.
        for dim_in, dim_out in reversed(in_out[1:]):
            up_modules.append(
                nn.ModuleList(
                    [
                        # input channels are doubled by the skip connection
                        ConditionalResidualBlock1D(
                            dim_out * 2,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                        ),
                        ConditionalResidualBlock1D(
                            dim_in,
                            dim_in,
                            cond_dim=cond_dim,
                            kernel_size=kernel_size,
                            n_groups=n_groups,
                        ),
                        Upsample1d(dim_in),
                    ]
                )
            )

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        # Registration order is kept as it was so that the order of
        # `self.parameters()` (used by the optimizer and the EMA) is unchanged.
        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print(
            "number of parameters: {:e}".format(
                sum(p.numel() for p in self.parameters())
            )
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep,  # TODO: Union[torch.Tensor, float, int]
        global_cond=None,
    ):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1, -2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor(
                [timesteps], dtype=torch.long, device=sample.device
            )
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([global_feature, global_cond], dim=-1)

        x = sample
        # activations saved on the way down, consumed as skip connections on the way up
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1, -2)
        # (B,T,C)
        return x

In [None]:
obs_horizon = config["obs_horizon"]
pred_horizon = config["pred_horizon"]
action_horizon = config["action_horizon"]

# observation and action dimensions as found in the loaded demonstrations
obs_dim = dataset.obs_space
action_dim = dataset.action_space
# recorded in the config because the evaluation code sizes its action tensor from it
config["action_dim"] = action_dim

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim, global_cond_dim=obs_dim * obs_horizon
)

# for this demo, we use DDPMScheduler with 100 diffusion iterations
# (setdefault keeps a value that is already in the config and stores the
# default otherwise, so later cells can read config["num_diffusion_iters"])
num_diffusion_iters = config.setdefault("num_diffusion_iters", 100)
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choice of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule="squaredcos_cap_v2",
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type="epsilon",
)

# device transfer (falls back to the CPU when no CUDA device is visible)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# stored as a plain string so the config stays serialisable
config["device"] = str(device)
_ = noise_pred_net.to(device)

num_epochs = config["n_epochs"]

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(parameters=noise_pred_net.parameters(), power=0.75)

# AdamW optimizer
# Note that EMA parameters are not optimized
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(), lr=1e-4, weight_decay=1e-6
)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs,
)

number of parameters: 6.642305e+07


# 3. Defining Evaluation Function
Here we define the code used to evaluate our model and gauge its success rate.

In [None]:
def evaluate(model: "ConditionalUnet1D", env, noise_scheduler, config, device):
    """Roll the policy out in ``env`` and report success and mean reward per episode.

    For each of ``config["num_eval_eps"]`` episodes the last
    ``config["obs_horizon"]`` observations are used as the condition, an action
    sequence of ``config["pred_horizon"]`` steps is denoised from Gaussian
    noise, and ``config["action_horizon"]`` of those actions are executed before
    re-planning. An episode stops after ``config["eval_ep_len"]`` steps or as
    soon as the environment reports ``terminated`` or ``truncated``.

    Args:
        model: noise prediction network, called as
            ``model(sample=..., timestep=..., global_cond=...)``.
        env: gymnasium-style environment.
        noise_scheduler: scheduler providing ``set_timesteps``, ``timesteps``
            and ``step``.
        config: hyperparameter dict. ``"action_dim"`` and
            ``"num_diffusion_iters"`` are optional; they default to the last
            dimension of ``env.action_space`` and to the scheduler's number of
            training timesteps.
        device: device on which the tensors are created.

    Returns:
        ``(successes, mean_rewards)``: two lists with one entry per episode.
        ``successes`` holds ``info["success"]`` of the last executed step.
    """
    action_dim = config.get("action_dim", env.action_space.shape[-1])
    num_inference_steps = config.get(
        "num_diffusion_iters", noise_scheduler.config.num_train_timesteps
    )
    # slice of the predicted sequence that is actually executed
    first = config["obs_horizon"] - 1
    last = first + config["action_horizon"]

    successes = []
    mean_rewards = []
    model.eval()

    for _ in range(config["num_eval_eps"]):
        obs, info = env.reset()
        rewards = []
        # the history starts out filled with copies of the first observation
        obs_deque = collections.deque(
            [obs] * config["obs_horizon"], maxlen=config["obs_horizon"]
        )

        steps = 0
        done = False
        while steps < config["eval_ep_len"] and not done:
            obs_seq = np.stack(obs_deque)

            with torch.no_grad():
                noisy_action = torch.randn(
                    (1, config["pred_horizon"], action_dim), device=device
                )
                obs_cond = torch.from_numpy(obs_seq).to(
                    device=device, dtype=torch.float32
                )
                # (1, obs_horizon * obs_dim)
                obs_cond = obs_cond.unsqueeze(0).flatten(1)
                noise_scheduler.set_timesteps(num_inference_steps)
                for k in noise_scheduler.timesteps:
                    # predict noise
                    noise_pred = model(
                        sample=noisy_action, timestep=k, global_cond=obs_cond
                    )

                    # inverse diffusion step (remove noise)
                    noisy_action = noise_scheduler.step(
                        model_output=noise_pred, timestep=k, sample=noisy_action
                    ).prev_sample
            actions = noisy_action[0].detach().to(device="cpu").numpy()

            for action in actions[first:last]:
                observation, reward, terminated, truncated, info = env.step(action)
                obs_deque.append(observation)
                rewards.append(reward)
                steps += 1
                # an environment must be reset once it reports the end of an
                # episode, so no further actions are sent to it
                done = bool(terminated or truncated)
                if done or steps >= config["eval_ep_len"]:
                    break
        successes.append(info["success"])
        # max(..., 1) only matters for the degenerate eval_ep_len <= 0 case
        mean_rewards.append(sum(rewards) / max(len(rewards), 1))
    model.train()
    return successes, mean_rewards

# 4. Train function
Here we define a training function for our model. This code also evaluates the model every `eval_interval` epochs.

In [None]:
def train(
    noise_pred_net: ConditionalUnet1D,
    optimizer,
    noise_scheduler,
    dataloader,
    ema,
    lr_scheduler,
    env,
    config: dict,
    device: str,
):
    """Train the noise prediction network and evaluate it periodically.

    Every ``config["eval_interval"]`` epochs (except epoch 0) the EMA weights
    are evaluated with ``evaluate`` and the metrics are logged: to wandb when
    ``config["wandb"]`` is true, otherwise with ``print``.

    Args:
        noise_pred_net: network being optimised (modified in place).
        optimizer: optimizer over ``noise_pred_net``'s parameters.
        noise_scheduler: diffusion scheduler used to add noise to the actions.
        dataloader: yields ``(action, observation)`` batches.
        ema: exponential moving average of the network parameters.
        lr_scheduler: learning-rate scheduler, stepped once per batch.
        env: environment passed on to ``evaluate``.
        config: hyperparameter dict (``"n_epochs"``, ``"obs_horizon"``,
            ``"eval_interval"``, ``"wandb"`` plus the keys ``evaluate`` reads).
        device: device the batches are moved to.

    Returns:
        ``noise_pred_net`` with the EMA weights copied into it.
    """
    if config["wandb"]:
        import wandb

        log_fn = wandb.log
    else:
        # without wandb the metrics would otherwise be computed and dropped
        log_fn = print

    log = {}
    noise_pred_net.train()
    for epoch in tqdm(range(config["n_epochs"])):
        epoch_loss = []
        for batch in dataloader:
            obs = batch[1].to(device=device)
            action = batch[0].to(device=device)
            B = obs.shape[0]
            obs_cond = obs[:, : config["obs_horizon"], :]
            # (B, obs_horizon * obs_dim)
            obs_cond = obs_cond.flatten(start_dim=1)
            noise = torch.randn(action.shape, device=device)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (B,), device=device
            ).long()

            # add noise to the clean actions according to the noise magnitude at each diffusion iteration
            # (this is the forward diffusion process)
            noisy_actions = noise_scheduler.add_noise(action, noise, timesteps)

            # predict the noise residual
            noise_pred = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)

            # L2 loss
            loss = nn.functional.mse_loss(noise_pred, noise)

            # optimize
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # step lr scheduler every batch
            # this is different from standard pytorch behavior
            lr_scheduler.step()
            ema.step(noise_pred_net.parameters())

            # logging
            epoch_loss.append(loss.item())

        if epoch % config["eval_interval"] == 0 and epoch > 0:
            log["training/epoch_loss_mean"] = np.mean(np.array(epoch_loss))
            log["training/epoch_loss_std"] = np.std(np.array(epoch_loss))

            # Evaluate with the EMA weights, then put the raw training weights
            # back. Copying the EMA weights in without restoring would make the
            # optimizer continue from the averaged weights after every
            # evaluation.
            ema.store(noise_pred_net.parameters())
            ema.copy_to(noise_pred_net.parameters())
            try:
                s, r = evaluate(noise_pred_net, env, noise_scheduler, config, device)
            finally:
                ema.restore(noise_pred_net.parameters())

            log["eval/success_rate"] = np.mean(np.array(s))
            log["eval/reward_avg"] = np.mean(np.array(r))
            log["eval/reward_std"] = np.std(np.array(r))
            log_fn(log)

    # training is over: the network now permanently takes the EMA weights
    ema_noise_pred_net = noise_pred_net
    ema.copy_to(ema_noise_pred_net.parameters())
    return ema_noise_pred_net

# 5. Training the model.
This may take some time to train the model. If you selected wandb logging you will need to input your wandb api key here.

In [None]:
if config["wandb"]:
    # Imported under an alias because the bare name `wandb` may already be bound
    # to the boolean logging flag of the configuration form.
    import wandb as wandb_lib

    wandb_lib.init(
        name=f"maniskill2-{config['env']}-{random.randint(0, 100000)}",
        group=f"maniskill2-{config['env']}",
        project="diffusion-policy",
        # log the hyperparameters that are actually used for this run
        config=config,
    )

train(
    noise_pred_net,
    optimizer,
    noise_scheduler,
    dataloader,
    ema,
    lr_scheduler,
    env,
    config,
    device,
)
# train() finishes by copying the EMA weights into `noise_pred_net` in place, so
# from here on it is the EMA model used by the cells below.
ema_noise_pred_net = noise_pred_net

# 6. Render a Video
Here we use ManiSkill2's record episode functionality to record a video of our model in action.

In [None]:
# Wrap the environment so that the rollouts are recorded under ./videos
video_env = RecordEpisode(env, "./videos", info_on_video=True)
# `noise_pred_net` carries the EMA weights once train() has returned, and
# `device` is the device the network was moved to.
evaluate(noise_pred_net, video_env, noise_scheduler, config, device)
video_env.flush_video()