In [1]:
import os
import pickle
import sys
from copy import deepcopy

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from gymnasium.wrappers import RescaleAction
from torch import nn
from torch.distributions import Normal
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, random_split

sys.path.append(os.path.abspath(".."))

from rlib.algorithms.sac import sac
from rlib.common.buffer import ReplayBuffer, RolloutBuffer
from rlib.common.evaluation import get_trajectory, save_frames_as_gif, validation
from rlib.common.logger import TensorBoardLogger
from rlib.common.policies import DeterministicMlpPolicy, MlpQCritic, StochasticMlpPolicy

%load_ext autoreload
%autoreload 2

### Expert data

In [59]:
env = gym.make("Pendulum-v1", render_mode="rgb_array")

min_action, max_action = -1, 1
env = RescaleAction(env, min_action, max_action)

In [60]:
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
trajectory_len = 200

print(obs_dim, action_dim)

3 1


In [61]:
with open("./models/pendulum_stoc_expert", "rb") as file:
    expert_actor = pickle.load(file)

In [62]:
rb = RolloutBuffer()

In [69]:
rb.collect_rollouts(env, expert_actor, trajectories_n=50)

In [70]:
data = rb.get_data()
expert_trajectories = rb.get_trajectories(data)

In [73]:
expert_trajectories[0]["rewards"]

tensor([[-9.3213e+00],
        [-9.3062e+00],
        [-9.2314e+00],
        [-9.1002e+00],
        [-8.9172e+00],
        [-8.6883e+00],
        [-8.4207e+00],
        [-8.1240e+00],
        [-7.8083e+00],
        [-7.4833e+00],
        [-7.1585e+00],
        [-6.8543e+00],
        [-6.6680e+00],
        [-6.7582e+00],
        [-7.1258e+00],
        [-7.7598e+00],
        [-8.6369e+00],
        [-9.7198e+00],
        [-1.0865e+01],
        [-9.9924e+00],
        [-9.0551e+00],
        [-8.0828e+00],
        [-7.1102e+00],
        [-6.1688e+00],
        [-5.2086e+00],
        [-4.3655e+00],
        [-3.9019e+00],
        [-3.8279e+00],
        [-4.1469e+00],
        [-4.8658e+00],
        [-5.9872e+00],
        [-7.4943e+00],
        [-9.3196e+00],
        [-1.1418e+01],
        [-1.2574e+01],
        [-1.1125e+01],
        [-9.5973e+00],
        [-8.0586e+00],
        [-6.5907e+00],
        [-5.2656e+00],
        [-4.1284e+00],
        [-3.1889e+00],
        [-2.4376e+00],
        [-1

In [78]:
expert_trajectories[49]["q_estimations"].shape

torch.Size([200, 1])

### Decision Transformer

In [38]:
class DecisionTransformer(nn.Module):
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        trajectory_len: int,
        embedding_dim: int = 32,
        nhead: int = 4,
        num_layers: int = 1,
    ):
        super().__init__()

        self.R_embedding = nn.Linear(1, embedding_dim)
        self.s_embedding = nn.Linear(obs_dim, embedding_dim)
        self.a_embedding = nn.Linear(action_dim, embedding_dim)

        self.t_embedding = nn.Embedding(trajectory_len, embedding_dim)

        decoder_layer = nn.TransformerDecoderLayer(
            embedding_dim,
            nhead,
            batch_first=True,
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
        self.head = nn.Linear(embedding_dim, action_dim)

    def forward(self, R, s, a, t):
        """
        Args:
            R (torch.Tensor): (B, T, 1)
            s (torch.Tensor): (B, T, obs_dim)
            a (torch.Tensor): (B, T, action_dim)
            t (torch.Tensor): (B, T, 1)

        Returns:
            a_pred: (torch.Tensor): (B, T, action_dim)
        """
        t_emb = self.t_embedding(t).squeeze()

        R_emb = self.R_embedding(R) + t_emb  # (B, T, C)
        s_emb = self.s_embedding(s) + t_emb  # (B, T, C)
        a_emb = self.a_embedding(a) + t_emb  # (B, T, C)

        B, T, C = R_emb.shape

        token_emb = torch.cat((R_emb, s_emb, a_emb), dim=1)  # (B, 3T, C)

        memory = torch.zeros_like(token_emb)
        mask = nn.Transformer.generate_square_subsequent_mask(3 * T)

        hidden_states = self.transformer.forward(token_emb, memory, mask)  # (B, 3T, C)
        a_hidden = hidden_states[:, 2::3, :]  # (B, T, C)
        a_pred = self.head(a_hidden)  # (B, T, 1)

        return a_pred

In [39]:
dt = DecisionTransformer(obs_dim, action_dim, trajectory_len)

In [58]:
class TrajectoryDataset(Dataset):
    def __init__(self, trajectories, K):
        self.sequences

    def __getitem__(self, i):
        return (
            self.trajectories["q_estimations"][i],
            self.trajectories["observations"][i],
            self.trajectories["actions"][i],
            self.trajectories["timesteps"][i],
        )

    def __len__(self):
        return self.trajectories["observations"].shape[0]

    def _add_timesteps(self):
        rollout_size = self.trajectories["observations"].shape[0]
        self.trajectories["timesteps"] = torch.zeros(
            (rollout_size, 1), dtype=torch.int64
        )

        count = 0
        for i in range(rollout_size):
            self.trajectories["timesteps"][i] = count
            count += 1

            if self.trajectories["terminated"][i] or self.trajectories["truncated"][i]:
                count = 0


In [45]:
def train(
    decision_transformer: DecisionTransformer,
    optimizer: Adam,
    train_dataloader: DataLoader,
    test_dataloader: DataLoader,
    total_epochs: int = 10,
):
    logger = TensorBoardLogger(log_dir="./tb_logs/dt_")

    for epoch_n in range(total_epochs):
        loss = {"train": 0, "test": 0}

        for R, s, a, t in train_dataloader:
            print(
                R.shape,
                s.shape,
            )
            a_preds = decision_transformer(R, s, a, t)
            batch_loss = ((a_preds - a) ** 2).mean()
            loss["train"] += batch_loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            for R, s, a, t in test_dataloader:
                a_preds = decision_transformer(R, s, a, t)
                batch_loss = ((a_preds - a) ** 2).mean()
                loss["test"] += batch_loss.item()

        logger.log_scalars(loss, epoch_n)


In [46]:
dataset = TrajectoryDataset(expert_trajectories)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, (train_size, test_size))
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True)

In [47]:
decision_transformer = DecisionTransformer(obs_dim, action_dim, trajectory_len)
optimizer = Adam(decision_transformer.parameters(), lr=1e-3)

In [48]:
train(
    decision_transformer,
    optimizer,
    train_dataloader,
    test_dataloader,
)

torch.Size([128, 1]) torch.Size([128, 3])


ValueError: not enough values to unpack (expected 3, got 2)

In [157]:
loss = {"train": 0, "test": 0}

for R, s, a, t in train_dataloader:
    print(R.shape, s.shape, a.shape, t.shape)

    print(R.dtype, s.dtype, a.dtype, t.dtype)

    a_preds = decision_transformer(R, s, a, t)
    batch_loss = ((a_preds - a) ** 2).mean()
    loss["train"] += batch_loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    break

torch.Size([128, 1]) torch.Size([128, 3]) torch.Size([128, 1]) torch.Size([128, 1])
torch.float32 torch.float32 torch.float32 torch.int64


AssertionError: was expecting embedding dimension of 32, but got 96

In [None]:
def eval():
    pass

## db

In [None]:
num_embeddings = 10
embedding_dim = 32

emb_table = nn.Embedding(num_embeddings, embedding_dim)

In [None]:
embedding_dim = 32

decoder_layer = nn.TransformerDecoderLayer(
    d_model=embedding_dim,
    nhead=4,
    batch_first=True,
)

transformer = nn.TransformerDecoder(decoder_layer, num_layers=3)

B = 4
T = 10
C = embedding_dim

tgt = torch.zeros((B, T, C))
tgt_mask = nn.Transformer.generate_square_subsequent_mask(T)

memory = torch.zeros((B, 1, C))

transformer.forward(
    tgt=tgt,
    memory=memory,
    tgt_mask=tgt_mask,
).shape

torch.Size([4, 10, 32])