# Dependencies/setup
To begin, prepare the colab environment by clicking the play button below and make sure you are using a GPU runtime. This will install all dependencies for the future code. This can take up to 1.5 minutes.

In [None]:
# Fetch the two NVIDIA JSON descriptors from the ManiSkill2 repository's
# docker directory and place them in the system Vulkan ICD directory and the
# glvnd EGL vendor directory respectively.
!mkdir -p /usr/share/vulkan/icd.d
!wget -q https://raw.githubusercontent.com/haosulab/ManiSkill2/main/docker/nvidia_icd.json
!wget -q https://raw.githubusercontent.com/haosulab/ManiSkill2/main/docker/10_nvidia.json
!mv nvidia_icd.json /usr/share/vulkan/icd.d
!mv 10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
# dependencies
# System Vulkan development package, a ManiSkill2 fork installed straight from
# GitHub, and the Python packages imported later in the notebook.
# NOTE(review): none of the pip installs are version-pinned.
!apt-get install -y --no-install-recommends libvulkan-dev
!pip install git+https://github.com/arnavg115/ManiSkill2.git
!pip install --upgrade --no-cache-dir gdown
!pip install transformers wandb

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libvulkan1
Recommended packages:
  mesa-vulkan-drivers | vulkan-icd
The following NEW packages will be installed:
  libvulkan-dev libvulkan1
0 upgraded, 2 newly installed, 0 to remove and 23 not upgraded.
Need to get 1,020 kB of archives.
After this operation, 17.2 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libvulkan1 amd64 1.3.204.1-2 [128 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libvulkan-dev amd64 1.3.204.1-2 [892 kB]
Fetched 1,020 kB in 1s (1,278 kB/s)
Selecting previously unselected package libvulkan1:amd64.
(Reading database ... 120903 files and directories currently installed.)
Preparing to unpack .../libvulkan1_1.3.204.1-2_amd64.deb ...
Unpacking libvulkan1:amd64 (1.3.204.1-2) ...
Selecting previously unselected package libvulkan-dev:amd64.
Preparing to un

In [None]:
# Detect whether the notebook is running on Google Colab by probing for the
# `google.colab` package. Only an import failure means "not on Colab"; any
# other exception is a real problem and is allowed to propagate.
try:
    import google.colab  # noqa: F401  (imported only to test availability)

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    import site

    # Re-run site initialisation; presumably this makes the packages installed
    # by the setup cell importable without restarting the runtime.
    site.main()

# Applying Decision Transformers to ManiSkill2


This notebook implements the [decision transformer](https://sites.google.com/berkeley.edu/decision-transformer) model and applies it to some of the environments from [maniskill2](https://maniskill2.github.io/).

In [None]:
# make all the important imports
import gymnasium as gym
import numpy as np
import h5py
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
import mani_skill2.envs
from mani_skill2.utils.wrappers import RecordEpisode
from torch.utils.data import Dataset
import wandb
import random
import os



# 0. Initialize environment and download dataset.
The original decision transformer paper uses offline reinforcement learning. This form of reinforcement learning learns from recorded trajectories produced by other policies. It contrasts with online reinforcement learning in that the model doesn't interact with the environment during training.

In our case we download these recorded trajectories. However, these trajectories don't contain reward data, so we then use the replay-trajectory feature of ManiSkill2 to also save the rewards.

In [None]:
# @markdown Specifies which environment is used.
# The id is reused below to create the environment and to build the demo file paths.
env_id = "LiftCube-v0"  # @param ["PickCube-v0", "LiftCube-v0", "StackCube-v0"]

In [None]:
# init env
# Uses state observations and the `pd_ee_delta_pose` controller; the
# demonstrations replayed below are converted to the same two modes.
env = gym.make(env_id, obs_mode="state", control_mode="pd_ee_delta_pose")

In [None]:
# download the trajectories
# `{env_id}` is interpolated from the Python variable; files go under ./demos.
!python -m mani_skill2.utils.download_demo {env_id} -o demos

Downloading v0 demonstrations: 1/1, LiftCube-v0
3.08Mit [00:01, 2.07Mit/s]               
32.8kit [00:00, 135kit/s]       


In [None]:
# replay trajectories and save the reward data.
# The replay uses state observations and the pd_ee_delta_pose controller so the
# result matches `env`. Presumably it writes
# trajectory.state.pd_ee_delta_pose.h5 next to the input file (that is the file
# the dataset cell loads) -- TODO confirm.
!python -m mani_skill2.trajectory.replay_trajectory --traj-path \
    demos/v0/rigid_body/{env_id}/trajectory.h5 --save-traj \
    --obs-mode state --target-control-mode pd_ee_delta_pose --num-procs 2 --record-reward true

0step [00:00, ?step/s][2023-12-12 22:33:28.901] [svulkan2] [[31m[1merror[m] GLFW error: X11: The DISPLAY environment variable is missing. You may suppress this message by setting environment variable SAPIEN_NO_DISPLAY=1
[2023-12-12 22:33:29.135] [svulkan2] [[31m[1merror[m] GLFW error: The GLFW library is not initialized. You may suppress this message by setting environment variable SAPIEN_NO_DISPLAY=1

0step [00:00, ?step/s][A[2023-12-12 22:33:29.672] [svulkan2] [[31m[1merror[m] GLFW error: X11: The DISPLAY environment variable is missing. You may suppress this message by setting environment variable SAPIEN_NO_DISPLAY=1
[2023-12-12 22:33:29.826] [svulkan2] [[31m[1merror[m] GLFW error: The GLFW library is not initialized. You may suppress this message by setting environment variable SAPIEN_NO_DISPLAY=1
  logger.warn(
  logger.warn(
  logger.warn(
Replaying traj_0:  16% 14/89 [00:00<00:01, 52.09step/s, control_mode=pd_ee_delta_pose, obs_mode=state]
  logger.warn(
  logger.wa

# 1. Dataset
Here we are initializing the dataset and reading from the saved trajectories we created in the previous step.

In [None]:
def load_h5_data(data):
    """Recursively copy an h5py group-like mapping into a plain dict.

    Entries that are `h5py.Dataset` objects are read fully with `[:]`; every
    other entry is treated as a nested group and converted recursively.
    """
    result = {}
    for key in data.keys():
        node = data[key]
        if isinstance(node, h5py.Dataset):
            result[key] = node[:]
            continue
        result[key] = load_h5_data(node)
    return result


class ManiSkill2Dataset(Dataset):
    """In-memory dataset of ManiSkill2 demonstration trajectories.

    Each item is one whole trajectory, returned as a tuple of float tensors
    `(observations, actions, rewards, success_flags)`.
    """

    def __init__(self, dataset_file: str, load_count=-1) -> None:
        """Load trajectories from `dataset_file` and its sibling `.json` file.

        Args:
            dataset_file: Path to the `.h5` trajectory file. The metadata file
                is looked up by replacing `.h5` with `.json` in this path.
            load_count: Number of episodes to load; `-1` loads all of them.
        """
        self.dataset_file = dataset_file
        # for details on how the code below works, see the
        # quick start tutorial
        import h5py
        from mani_skill2.utils.io_utils import load_json

        self.data = h5py.File(dataset_file, "r")
        json_path = dataset_file.replace(".h5", ".json")
        self.json_data = load_json(json_path)
        self.episodes = self.json_data["episodes"]

        self.env_info = self.json_data["env_info"]
        self.env_id = self.env_info["env_id"]
        self.env_kwargs = self.env_info["env_kwargs"]

        self.observations = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.total_frames = 0
        # Filled lazily by get_state_stats().
        self._state_stats = None
        if load_count == -1:
            load_count = len(self.episodes)
        for eps_id in tqdm(range(load_count)):
            eps = self.episodes[eps_id]
            trajectory = self.data[f"traj_{eps['episode_id']}"]
            trajectory = load_h5_data(trajectory)

            # we use :-1 here to ignore the last observation as that
            # is the terminal observation which has no actions
            self.observations.append(trajectory["obs"][:-1])
            self.actions.append(trajectory["actions"])
            # The first stored reward entry is dropped.
            # NOTE(review): presumably this aligns rewards with actions -- confirm.
            self.rewards.append(trajectory["rewards"][1:])
            # The per-step "success" flags are what callers receive as `done`.
            self.dones.append(trajectory["success"])

    def get_state_stats(self):
        """Return per-dimension `(mean, std + 1e-6)` over all observations.

        The statistics are computed once and cached, because stacking every
        trajectory is expensive and callers request the statistics repeatedly.
        Fresh copies are returned so callers cannot corrupt the cache.
        """
        cached = getattr(self, "_state_stats", None)
        if cached is None:
            arr = np.vstack(self.observations)
            cached = (np.mean(arr, axis=0), np.std(arr, axis=0) + 1e-6)
            self._state_stats = cached
        mean, std = cached
        return mean.copy(), std.copy()

    def __len__(self):
        """Number of loaded trajectories."""
        return len(self.observations)

    def __getitem__(self, idx):
        """Return trajectory `idx` as `(obs, action, rew, done)` float tensors."""
        action = torch.from_numpy(self.actions[idx]).float()
        obs = torch.from_numpy(self.observations[idx]).float()
        rew = torch.from_numpy(self.rewards[idx]).float()
        done = torch.from_numpy(self.dones[idx]).float()
        return obs, action, rew, done

In [None]:
dataset = ManiSkill2Dataset(
    f"demos/v0/rigid_body/{env_id}/trajectory.state.pd_ee_delta_pose.h5"
)
# quick check that the env is configured correctly
assert env.action_space.shape[0] == dataset[0][1].shape[-1]

  0%|          | 0/98 [00:00<?, ?it/s]

In [None]:
len(dataset)

98

# 2. Define load batch
Now that the dataset is loaded into the `dataset` variable, we define a custom `load_batch` function that samples batches of sub-trajectories and performs the required preprocessing steps (padding and state normalization). Decision transformers use 3 inputs, namely state, action, and returns-to-go.

While the other two are already defined in the dataset, the last one needs to be computed from the rewards. Returns-to-go represents the future rewards and is computed by taking the sum of the rewards from a certain timestep forwards. For example, if we are at timestep 1 we compute the returns-to-go by summing up the rewards from timestep 1 to the end of the trajectory. This is implemented in `discount_cumsum`.

In [None]:
def discount_cumsum(x, gamma):
    """Return the discounted suffix sums of `x` along its first axis.

    Element `t` of the result is `x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...`.
    With `gamma=1.0` this is the plain returns-to-go.

    Args:
        x: Array-like of rewards, indexed along axis 0.
        gamma: Discount factor applied per step.

    Returns:
        A numpy array shaped (and typed) like `x`. An empty input yields an
        empty output instead of raising.
    """
    returns = np.zeros_like(x)
    if x.shape[0] == 0:
        return returns
    # Work backwards so each entry reuses the already-computed suffix sum.
    returns[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        returns[t] = x[t] + gamma * returns[t + 1]
    return returns

In [None]:
def load_batch(
    batch_size=16,
    max_len=20,
    num_trajectories=100,
    state_dim=42,
    act_dim=7,
    max_ep_len=90,
    scale=50,
    device="cuda:0",
):
    """Sample a batch of left-padded sub-trajectories from the global `dataset`.

    Args:
        batch_size: Number of sub-trajectories in the batch.
        max_len: Context length; every sequence is left-padded to this length.
        num_trajectories: Only the first `num_trajectories` trajectories are
            sampled from. Values larger than `len(dataset)` are clamped, so the
            default can no longer index past the end of the dataset.
        state_dim: Observation dimensionality.
        act_dim: Action dimensionality.
        max_ep_len: Timestep indices are clipped to `max_ep_len - 1`.
        scale: Currently unused.
            NOTE(review): returns-to-go are emitted in raw reward units, while
            `evaluate_episode_rtg` divides rewards (and `eval_episodes` the
            target return) by a scale of 10 -- confirm whether training RTGs
            should be scaled the same way.
        device: Torch device the returned tensors are moved to.

    Returns:
        Tuple `(s, a, r, d, rtg, timesteps, mask)` with shapes
        `(B, max_len, state_dim)`, `(B, max_len, act_dim)`, `(B, max_len, 1)`,
        `(B, max_len)`, `(B, max_len + 1, 1)`, `(B, max_len)`, `(B, max_len)`.
        `mask` is 1 for real steps and 0 for padding.
    """
    num_trajectories = min(num_trajectories, len(dataset))
    batch_inds = np.random.choice(
        np.arange(num_trajectories),
        size=batch_size,
        replace=True,
    )
    state_mean, state_std = dataset.get_state_stats()
    s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
    for traj_idx in batch_inds:
        obs, action, rew, done = (np.asarray(x) for x in dataset[traj_idx])

        # Pick a start index that leaves max_len + 1 steps where possible;
        # trajectories that are too short start at 0 and get padded below.
        si = random.randint(0, max(rew.shape[0] - 1 - max_len, 0))

        s.append(obs[si : si + max_len].reshape(1, -1, state_dim))
        a.append(action[si : si + max_len].reshape(1, -1, act_dim))
        r.append(rew[si : si + max_len].reshape(1, -1, 1))
        d.append(done[si : si + max_len].reshape(1, -1))
        tlen = s[-1].shape[1]

        timesteps.append(np.arange(si, si + tlen).reshape(1, -1))
        timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len - 1

        # Returns-to-go carry one extra trailing step (tlen + 1 entries).
        rtg.append(
            discount_cumsum(rew[si:], gamma=1.0)[: tlen + 1].reshape(1, -1, 1)
        )
        if rtg[-1].shape[1] <= tlen:
            # The trajectory ended inside the window: the extra step is zero.
            rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

        # Left-pad every modality to max_len, then normalise the states.
        pad = max_len - tlen
        s[-1] = np.concatenate([np.zeros((1, pad, state_dim)), s[-1]], axis=1)
        s[-1] = (s[-1] - state_mean) / state_std
        a[-1] = np.concatenate([np.ones((1, pad, act_dim)) * -10.0, a[-1]], axis=1)
        r[-1] = np.concatenate([np.zeros((1, pad, 1)), r[-1]], axis=1)
        d[-1] = np.concatenate([np.ones((1, pad)) * 2, d[-1]], axis=1)
        rtg[-1] = np.concatenate([np.zeros((1, pad, 1)), rtg[-1]], axis=1)
        timesteps[-1] = np.concatenate([np.zeros((1, pad)), timesteps[-1]], axis=1)
        mask.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))

    def _to_tensor(chunks, dtype):
        # Stack the per-sample arrays along the batch axis and move them over.
        return torch.from_numpy(np.concatenate(chunks, axis=0)).to(
            dtype=dtype, device=device
        )

    s = _to_tensor(s, torch.float32)
    a = _to_tensor(a, torch.float32)
    r = _to_tensor(r, torch.float32)
    d = _to_tensor(d, torch.long)
    rtg = _to_tensor(rtg, torch.float32)
    timesteps = _to_tensor(timesteps, torch.long)
    # The mask keeps its numpy dtype; only the device changes.
    mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)
    return s, a, r, d, rtg, timesteps, mask

# 3. Define model
The actual reinforcement model uses the basic decoder only architecture seen in models like `GPT-2`. In fact this implementation uses a modified version of the `GPT-2` model with a few extra additions.

Decision transformer works by taking in the three inputs and running each of them through a linear layer that is analogous to the embedding that text tokens go through in a language transformer. We also have a position embedding based on the timestep. This timestep embedding is then added to each of these embeddings. The three embeddings are then interleaved such that one timestep is made up of 3 tokens. This input is then fed into `GPT-2`. The hidden states from `GPT-2` are then fed through separate linear layers to obtain state (not used) and action predictions.

There is additional code that then takes these action predictions to obtain an action.

In [None]:
from transformers import GPT2Model, GPT2Config


class DecisionTransformer(nn.Module):
    """Return-conditioned sequence model built on a GPT-2 backbone.

    Returns-to-go, states and actions are each embedded with their own linear
    layer, offset by a learned timestep embedding, interleaved as
    `(R_1, s_1, a_1, R_2, s_2, a_2, ...)` and fed to `GPT2Model` through
    `inputs_embeds`. Linear heads on the hidden states produce state, action
    and return predictions.
    """

    def __init__(
        self,
        state_dim,
        act_dim,
        hidden_size,
        max_length=None,
        max_ep_len=4096,
        **kwargs
    ):
        """Build the backbone, the per-modality embeddings and the heads.

        Args:
            state_dim: Observation dimensionality.
            act_dim: Action dimensionality.
            hidden_size: Embedding width (`n_embd` of the GPT-2 config).
            max_length: Context length used by `get_action`; `None` disables
                truncation and padding there.
            max_ep_len: Size of the timestep embedding table.
            **kwargs: Forwarded to `GPT2Config`.
        """
        super().__init__()
        # vocab_size is a placeholder: inputs are passed as embeddings.
        config = GPT2Config(vocab_size=1, n_embd=hidden_size, **kwargs)
        self.transformer = GPT2Model(config)
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.max_length = max_length
        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_return = torch.nn.Linear(1, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)
        self.hidden_size = hidden_size
        self.embed_ln = nn.LayerNorm(hidden_size)
        self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
        # Tanh bounds the predicted actions to [-1, 1].
        self.predict_action = nn.Sequential(
            nn.Linear(hidden_size, self.act_dim), nn.Tanh()
        )
        self.predict_return = torch.nn.Linear(hidden_size, 1)

    def forward(self, states, actions, rewards, rtg, timesteps, attention_mask):
        """Run the model on a batch of sequences.

        Args:
            states: `(batch, seq, state_dim)` tensor.
            actions: `(batch, seq, act_dim)` tensor.
            rewards: Unused; kept for interface compatibility.
            rtg: `(batch, seq, 1)` returns-to-go.
            timesteps: `(batch, seq)` integer timestep indices.
            attention_mask: `(batch, seq)` mask (1 = attend, 0 = padding), or
                `None` to attend to every position.

        Returns:
            `(state_preds, action_preds, return_preds)`, each `(batch, seq, dim)`.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # attention mask for GPT: 1 if can be attended to, 0 if not.
            # Created on the inputs' device so it also works for a GPU model.
            attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=states.device
            )

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        returns_embeddings = self.embed_return(rtg)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        returns_embeddings = returns_embeddings + time_embeddings

        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        stacked_inputs = (
            torch.stack(
                (returns_embeddings, state_embeddings, action_embeddings), dim=1
            )
            .permute(0, 2, 1, 3)
            .reshape(batch_size, 3 * seq_length, self.hidden_size)
        )
        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = (
            torch.stack((attention_mask, attention_mask, attention_mask), dim=1)
            .permute(0, 2, 1)
            .reshape(batch_size, 3 * seq_length)
        )

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs["last_hidden_state"]

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions
        action_preds = self.predict_action(x[:, 1])  # predict next action given state
        state_preds = self.predict_state(x[:, 2])
        return_preds = self.predict_return(x[:, 2])

        return state_preds, action_preds, return_preds

    @staticmethod
    def _left_pad(x, length):
        """Left-pad `x` with zeros along dim 1 so it has `length` steps."""
        pad_shape = (x.shape[0], length - x.shape[1]) + tuple(x.shape[2:])
        return torch.cat([torch.zeros(pad_shape, device=x.device), x], dim=1)

    def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
        """Predict the next action for a single episode history.

        The history is reshaped to a batch of one. When `max_length` is set it
        is truncated to the last `max_length` steps and left-padded with zeros,
        with a matching attention mask.

        Returns:
            The action prediction for the most recent step, shape `(act_dim,)`.
        """
        # we don't care about the past rewards in this model

        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            states = states[:, -self.max_length :]
            actions = actions[:, -self.max_length :]
            returns_to_go = returns_to_go[:, -self.max_length :]
            timesteps = timesteps[:, -self.max_length :]

            # pad all tokens to sequence length
            attention_mask = torch.cat(
                [
                    torch.zeros(self.max_length - states.shape[1]),
                    torch.ones(states.shape[1]),
                ]
            )
            attention_mask = attention_mask.to(
                dtype=torch.long, device=states.device
            ).reshape(1, -1)
            states = self._left_pad(states, self.max_length).to(dtype=torch.float32)
            actions = self._left_pad(actions, self.max_length).to(dtype=torch.float32)
            returns_to_go = self._left_pad(returns_to_go, self.max_length).to(
                dtype=torch.float32
            )
            # Padding promotes the indices to float; cast back to long.
            timesteps = self._left_pad(timesteps, self.max_length).to(dtype=torch.long)
        else:
            attention_mask = None

        _, action_preds, return_preds = self.forward(
            states,
            actions,
            None,
            returns_to_go,
            timesteps,
            attention_mask=attention_mask,
            **kwargs
        )

        return action_preds[0, -1]

# 4. Define trainer
Here we define a trainer parent class and a sequence trainer child class to help train the model. This trainer contains details such as the optimizer, loss, and learning rate scheduler.

In [None]:
import time


class Trainer:
    """Generic training driver: runs optimisation steps, then evaluation.

    Subclasses may override `train_step` to change how a batch is consumed.
    """

    def __init__(
        self,
        model,
        optimizer,
        batch_size,
        get_batch,
        loss_fn,
        scheduler=None,
        eval_fns=None,
        state_dim=None,
    ):
        """Store the training components and start the wall-clock timer."""
        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.get_batch = get_batch
        self.loss_fn = loss_fn
        self.scheduler = scheduler
        self.eval_fns = eval_fns if eval_fns is not None else []
        self.diagnostics = {}
        self.state_dim = state_dim
        self.start_time = time.time()

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run `num_steps` training steps followed by every evaluation function.

        Returns:
            Dict of timing, evaluation and loss statistics, plus whatever
            `self.diagnostics` currently holds.
        """
        logs = {}

        # --- training phase ---
        train_start = time.time()
        self.model.train()
        train_losses = []
        for _ in range(num_steps):
            train_losses.append(self.train_step())
            if self.scheduler is not None:
                self.scheduler.step()
        logs["time/training"] = time.time() - train_start

        # --- evaluation phase ---
        eval_start = time.time()
        self.model.eval()
        for eval_fn in self.eval_fns:
            for name, value in eval_fn(self.model).items():
                logs[f"evaluation/{name}"] = value

        logs["time/total"] = time.time() - self.start_time
        logs["time/evaluation"] = time.time() - eval_start
        logs["training/train_loss_mean"] = np.mean(train_losses)
        logs["training/train_loss_std"] = np.std(train_losses)
        logs.update(self.diagnostics)

        if print_logs:
            print("=" * 80)
            print(f"Iteration {iter_num}")
            for name, value in logs.items():
                print(f"{name}: {value}")

        return logs

    def train_step(self):
        """Run one optimisation step and return the loss as a Python float.

        This base version expects `get_batch` to return six values and the
        model's `forward` to accept `masks` and `target_return` keywords.
        """
        batch = self.get_batch(self.batch_size, state_dim=self.state_dim)
        states, actions, rewards, dones, attention_mask, returns = batch
        state_target = torch.clone(states)
        action_target = torch.clone(actions)
        reward_target = torch.clone(rewards)

        state_preds, action_preds, reward_preds = self.model.forward(
            states,
            actions,
            rewards,
            masks=None,
            attention_mask=attention_mask,
            target_return=returns,
        )

        # note: currently indexing & masking is not fully correct
        loss = self.loss_fn(
            state_preds,
            action_preds,
            reward_preds,
            state_target[:, 1:],
            action_target,
            reward_target[:, 1:],
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.detach().cpu().item()


class SequenceTrainer(Trainer):
    """Trainer for sequence models that optimises the action predictions only."""

    def train_step(self):
        """Run one optimisation step on a batch and return the loss as a float.

        The batch must hold seven tensors. The last returns-to-go step is
        dropped before the forward pass, and only positions whose attention
        mask is positive contribute to the loss.
        """
        batch = self.get_batch(self.batch_size, state_dim=self.state_dim)
        states, actions, rewards, dones, rtg, timesteps, attention_mask = batch
        action_target = torch.clone(actions)

        state_preds, action_preds, reward_preds = self.model.forward(
            states,
            actions,
            rewards,
            rtg[:, :-1],
            timesteps,
            attention_mask=attention_mask,
        )

        # Flatten (batch, seq) and keep only the non-padding positions.
        act_dim = action_preds.shape[2]
        keep = attention_mask.reshape(-1) > 0
        action_preds = action_preds.reshape(-1, act_dim)[keep]
        action_target = action_target.reshape(-1, act_dim)[keep]

        loss = self.loss_fn(None, action_preds, None, None, action_target, None)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.25)
        self.optimizer.step()

        with torch.no_grad():
            squared_error = (action_preds - action_target) ** 2
            self.diagnostics["training/action_error"] = (
                torch.mean(squared_error).detach().cpu().item()
            )

        return loss.detach().cpu().item()

# 5. Define variables and initialize model
Here we define the hyperparameters and other parameters needed for the model and training. Many of these specify variables for the transformer such as the embed dimension and the number of layers, while others specify variables for the loss or learning rate scheduler.

In [None]:
# Hyperparameters for the model, optimiser and evaluation. The whole dict is
# also passed to wandb as the run config when logging is enabled.
variant = {
    "embed_dim": 256,  # hidden size of the transformer
    "n_layer": 4,
    "n_head": 1,
    "dropout": 0.1,  # used for both the residual and the attention dropout
    "activation_function": "relu",
    "warmup_steps": 10_000,  # scheduler steps until the full learning rate is reached
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "max_iters": 15,  # number of train/eval iterations in the training loop
    "K": 30,  # context length given to the model as `max_length`
    "max_ep_len": 100,  # rollout length used during evaluation
    "num_eval_episodes": 100,  # rollouts per target return at each evaluation
}
K = variant["K"]
# Dimensions are read from the environment so they match the chosen task.
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
max_ep_len = variant["max_ep_len"]
num_eval_episodes = variant["num_eval_episodes"]
# Normalisation statistics reused by the evaluation functions below.
state_mean, state_std = dataset.get_state_stats()

In [None]:
# Build the decision transformer from the hyperparameters above. The timestep
# embedding table is sized from `max_ep_len` rather than a duplicated literal,
# so it stays in sync with the rollout length used during evaluation.
model = DecisionTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    max_length=K,
    max_ep_len=max_ep_len,
    hidden_size=variant["embed_dim"],
    n_layer=variant["n_layer"],
    n_head=variant["n_head"],
    n_inner=4 * variant["embed_dim"],
    activation_function=variant["activation_function"],
    n_positions=1024,
    resid_pdrop=variant["dropout"],
    attn_pdrop=variant["dropout"],
)
model = model.to(device="cuda:0")

# 6. Create eval function
Here we define the evaluation functions to test how the model performs. The evaluation is two-fold: `evaluate_episode_rtg` performs the evaluation for an individual episode, and `eval_episodes` runs `evaluate_episode_rtg` a number of times. The latter function is also used to log some key statistics about the evaluation, such as the success rate and mean return.



In [None]:
def evaluate_episode_rtg(
    env,
    state_dim,
    act_dim,
    model,
    max_ep_len=100,
    scale=10.0,
    state_mean=0.0,
    state_std=1.0,
    device="cuda:0",
    target_return=None,
    mode="normal",
):
    """Roll out one return-conditioned episode and report how it went.

    Args:
        env: Gymnasium-style environment (`reset()` returns `(obs, info)`,
            `step()` returns a 5-tuple whose info dict has a "success" entry).
        state_dim: Observation dimensionality.
        act_dim: Action dimensionality.
        model: Policy exposing `get_action(states, actions, rewards,
            returns_to_go, timesteps)`.
        max_ep_len: Number of environment steps to run.
        scale: Each reward is divided by this before being subtracted from the
            running target return.
        state_mean: Scalar or array used to centre the states.
        state_std: Scalar or array used to scale the states.
        device: Torch device for the model and the history tensors.
        target_return: Initial return-to-go fed to the model (required).
        mode: "noise" perturbs the initial state; "delayed" keeps the target
            return constant; anything else decrements it by the scaled reward.

    Returns:
        `(episode_return, episode_length, success)`, where `success` is the
        "success" entry of the info dict from the final step.

    Raises:
        ValueError: If `target_return` is not provided.
    """
    if target_return is None:
        raise ValueError("target_return must be provided")

    model.eval()
    model.to(device=device)

    # as_tensor accepts the scalar defaults as well as numpy arrays.
    state_mean = torch.as_tensor(state_mean, device=device)
    state_std = torch.as_tensor(state_std, device=device)

    state = env.reset()[0]
    if mode == "noise":
        state = state + np.random.normal(0, 0.1, size=state.shape)

    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = (
        torch.from_numpy(state)
        .reshape(1, state_dim)
        .to(device=device, dtype=torch.float32)
    )
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    target_return = torch.tensor(
        target_return, device=device, dtype=torch.float32
    ).reshape(1, 1)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    episode_return, episode_length = 0, 0
    success = False
    for t in range(max_ep_len):

        # add padding
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = model.get_action(
            (states.to(dtype=torch.float32) - state_mean) / state_std,
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            target_return.to(dtype=torch.float32),
            timesteps.to(dtype=torch.long),
        )
        actions[-1] = action
        action = action.detach().cpu().numpy()

        # The terminated/truncated flags are ignored: the rollout always runs
        # for max_ep_len steps.
        state, reward, _terminated, _truncated, info = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        if mode != "delayed":
            pred_return = target_return[0, -1] - (reward / scale)
        else:
            pred_return = target_return[0, -1]
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat(
            [timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)],
            dim=1,
        )

        episode_return += reward
        episode_length += 1
        success = info["success"]

    return episode_return, episode_length, success

In [None]:
def eval_episodes(target_rew):
    """Build an evaluation function for the given target return.

    The returned callable takes a model, runs `num_eval_episodes` rollouts with
    `evaluate_episode_rtg` (target return and rewards both divided by 10) and
    returns a dict with the mean/std of returns and lengths and the success rate.
    """

    def fn(model):
        rollouts = []
        for _ in range(num_eval_episodes):
            with torch.no_grad():
                rollouts.append(
                    evaluate_episode_rtg(
                        env,
                        state_dim,
                        act_dim,
                        model,
                        max_ep_len=max_ep_len,
                        scale=10,
                        target_return=target_rew / 10,
                        mode="normal",
                        state_mean=state_mean,
                        state_std=state_std,
                        device="cuda:0",
                    )
                )

        # Split the (return, length, success) triples into separate series.
        returns = [ret for ret, _, _ in rollouts]
        lengths = [length for _, length, _ in rollouts]
        success = sum(s for _, _, s in rollouts)

        return {
            f"target_{target_rew}_return_mean": np.mean(returns),
            f"target_{target_rew}_return_std": np.std(returns),
            f"target_{target_rew}_length_mean": np.mean(lengths),
            f"target_{target_rew}_length_std": np.std(lengths),
            f"target_{target_rew}_success_rate": success / len(returns),
        }

    return fn

# 7. Instantiate optimizer, scheduler, wandb and trainer
Here we instantiate the optimizer, scheduler and trainer objects to be used in the next step.
Using Wandb for logging can also be toggled on, but it requires an api key from Wandb.

In [None]:
# AdamW over all model parameters, configured from the hyperparameter dict.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=variant["learning_rate"],
    weight_decay=variant["weight_decay"],
)
# Linear warm-up: the learning-rate multiplier grows from 1/warmup_steps to 1
# over `warmup_steps` scheduler steps and then stays at 1.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda steps: min((steps + 1) / variant["warmup_steps"], 1)
)

# The loss is the mean squared error between predicted and target actions; the
# state and reward arguments of the loss function are ignored. After every
# training iteration the model is evaluated with target returns of 50 and 25.
trainer = SequenceTrainer(
    model=model,
    optimizer=optimizer,
    batch_size=32,
    get_batch=load_batch,
    scheduler=scheduler,
    loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2),
    eval_fns=[eval_episodes(tar) for tar in [50, 25]],
    state_dim=state_dim,
)

In [None]:
# Toggle for Weights & Biases logging; when enabled, a run is started with the
# hyperparameter dict as its config and a random suffix in its name.
use_wandb = False  # @param {type:"boolean"}
if use_wandb:
    wandb.init(
        name=f"maniskill2-{env_id}-{random.randint(0, 100000)}",
        group=f"maniskill2-{env_id}",
        project="decision-transformer",
        config=variant,
    )

# 8. Train the Model

In [None]:
# Each iteration runs a fixed number of optimisation steps followed by the
# evaluation functions registered on the trainer. The loop variable is named
# `iter_num` so that the builtin `iter` is not shadowed for later cells.
steps_per_iter = 1000
for iter_num in range(variant["max_iters"]):
    outputs = trainer.train_iteration(
        steps_per_iter, iter_num=iter_num, print_logs=True
    )
    if use_wandb:
        wandb.log(outputs)

Iteration 0
time/training: 129.9936866760254
evaluation/target_50_return_mean: 8.628591140859395
evaluation/target_50_return_std: 2.3996948189285146
evaluation/target_50_length_mean: 100.0
evaluation/target_50_length_std: 0.0
evaluation/target_50_success_rate: 0.0
evaluation/target_25_return_mean: 8.964036443337456
evaluation/target_25_return_std: 2.9256796689606714
evaluation/target_25_length_mean: 100.0
evaluation/target_25_length_std: 0.0
evaluation/target_25_success_rate: 0.0
time/total: 399.8114821910858
time/evaluation: 220.05823349952698
training/train_loss_mean: 0.13538279953598975
training/train_loss_std: 0.1235495622643161
training/action_error: 0.01846698857843876
Iteration 1
time/training: 126.23758864402771
evaluation/target_50_return_mean: 12.202699902284937
evaluation/target_50_return_std: 4.24248125644988
evaluation/target_50_length_mean: 100.0
evaluation/target_50_length_std: 0.0
evaluation/target_50_success_rate: 0.0
evaluation/target_25_return_mean: 11.71880450951496

# 9. Saving model weights and a video
Here we use ManiSkill2's video functionality to record a video of the model in action.
There is also code to save the model weights for later use.

In [None]:
from mani_skill2.utils.wrappers import RecordEpisode

# A separate environment instance with camera rendering, wrapped so that the
# episode is recorded into ./videos.
video_env = RecordEpisode(
    gym.make(
        env_id,
        render_mode="cameras",
        enable_shadow=True,
        obs_mode="state",
        control_mode="pd_ee_delta_pose",
    ),
    "./videos",
    info_on_video=True,
)
# Compute the normalisation statistics once instead of once per argument.
state_mean, state_std = dataset.get_state_stats()
# NOTE(review): `eval_episodes` passes its target return divided by 10, whereas
# this call passes 50 unscaled -- confirm which convention is intended.
out = evaluate_episode_rtg(
    video_env,
    state_dim,
    act_dim,
    model,
    target_return=50,
    state_mean=state_mean,
    state_std=state_std,
)
video_env.flush_video()
if use_wandb:
    wandb.log({f"video_{env_id}": wandb.Video("./videos/0.mp4")})
from IPython.display import Video

Video("./videos/0.mp4", embed=True)

In [None]:
# Save the weights into the wandb run directory when a run is active, and into
# the current working directory otherwise. `wandb.run` is None when logging is
# disabled, so it must not be dereferenced unconditionally.
weights_dir = wandb.run.dir if use_wandb and wandb.run is not None else "."
weights_path = os.path.join(weights_dir, f"weights_{env_id}.pt")
torch.save(model.state_dict(), weights_path)
if use_wandb:
    # Point wandb at the file that was actually written.
    wandb.save(weights_path, base_path=weights_dir)