In [124]:
import os
import pickle
import sys
from copy import deepcopy

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from gymnasium.wrappers import RescaleAction
from torch import nn
from torch.distributions import Normal
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, random_split

sys.path.append(os.path.abspath(".."))

from rlib.algorithms.sac import sac
from rlib.common.buffer import ReplayBuffer, RolloutBuffer
from rlib.common.evaluation import get_trajectory, save_frames_as_gif, validation
from rlib.common.logger import TensorBoardLogger
from rlib.common.policies import DeterministicMlpPolicy, MlpQCritic, StochasticMlpPolicy

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Expert data

In [41]:
env = gym.make("Pendulum-v1", render_mode="rgb_array")

min_action, max_action = -1, 1
env = RescaleAction(env, min_action, max_action)

In [111]:
obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
trajectory_len = 200

print(obs_dim, action_dim)

3 1


In [42]:
with open("./models/pendulum_stoc_expert", "rb") as file:
    expert_actor = pickle.load(file)

In [43]:
rb = RolloutBuffer()

In [44]:
rb.collect_rollouts(env, expert_actor, trajectories_n=50)

In [45]:
expert_trajectories = rb.get_data()

  return torch.tensor(value, dtype=dtype)


### Decision Transformer

In [169]:
class DecisionTransformer(nn.Module):
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        trajectory_len: int,
        embedding_dim: int = 32,
        nhead: int = 4,
        num_layers: int = 1,
    ):
        super().__init__()

        self.R_embedding = nn.Linear(1, embedding_dim)
        self.s_embedding = nn.Linear(obs_dim, embedding_dim)
        self.a_embedding = nn.Linear(action_dim, embedding_dim)

        self.t_embedding = nn.Embedding(trajectory_len, embedding_dim)

        decoder_layer = nn.TransformerDecoderLayer(
            embedding_dim,
            nhead,
            batch_first=True,
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
        self.head = nn.Linear(embedding_dim, action_dim)

    def forward(self, R, s, a, t):
        """
        Args:
            R (torch.Tensor): (B, T, 1)
            s (torch.Tensor): (B, T, obs_dim)
            a (torch.Tensor): (B, T, action_dim)
            t (torch.Tensor): (B, T, 1)

        Returns:
            a_pred: (torch.Tensor): (B, T, action_dim)
        """
        t_emb = self.t_embedding(t).squeeze()

        print(
            self.t_embedding(t).shape,
            self.R_embedding(R).shape,
        )

        R_emb = self.R_embedding(R) + t_emb
        s_emb = self.s_embedding(s) + t_emb
        a_emb = self.a_embedding(a) + t_emb

        token_emb = torch.cat((R_emb, s_emb, a_emb), dim=2)
        B, T, C = token_emb.shape

        print(token_emb.shape)

        memory = torch.zeros_like(token_emb)
        mask = nn.Transformer.generate_square_subsequent_mask(T)

        hidden_states = self.transformer.forward(token_emb, memory, mask)
        a_hidden = hidden_states[:, :, -1]
        a_pred = self.head(a_hidden)

        return a_pred

In [170]:
dt = DecisionTransformer(obs_dim, action_dim, trajectory_len)

In [171]:
B, T = 8, 10

R = torch.rand((B, T, 1))
s = torch.rand((B, T, 3))
a = torch.rand((B, T, 1))
t = torch.randint(0, 10, (B, T, 1))

dt.forward(R, s, a, t)

torch.Size([8, 10, 1, 32]) torch.Size([8, 10, 32])
torch.Size([8, 10, 96])


AssertionError: was expecting embedding dimension of 32, but got 96

In [152]:
class TrajectoryDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self._add_timesteps()

    def __getitem__(self, i):
        return (
            self.data["q_estimations"][i],
            self.data["observations"][i],
            self.data["actions"][i],
            self.data["timesteps"][i],
        )

    def __len__(self):
        return self.data["observations"].shape[0]

    def _add_timesteps(self):
        rollout_size = self.data["observations"].shape[0]
        self.data["timesteps"] = torch.zeros((rollout_size, 1), dtype=torch.int64)

        count = 0
        for i in range(rollout_size):
            self.data["timesteps"][i] = count
            count += 1

            if self.data["terminated"][i] or self.data["truncated"][i]:
                count = 0


In [153]:
def train(
    decision_transformer: DecisionTransformer,
    optimizer: Adam,
    train_dataloader: DataLoader,
    test_dataloader: DataLoader,
    total_epochs: int = 10,
):
    logger = TensorBoardLogger(log_dir="./tb_logs/dt_")

    for epoch_n in range(total_epochs):
        loss = {"train": 0, "test": 0}

        for R, s, a, t in train_dataloader:
            a_preds = decision_transformer(R, s, a, t)
            batch_loss = ((a_preds - a) ** 2).mean()
            loss["train"] += batch_loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            for R, s, a, t in test_dataloader:
                a_preds = decision_transformer(R, s, a, t)
                batch_loss = ((a_preds - a) ** 2).mean()
                loss["test"] += batch_loss.item()

        logger.log_scalars(loss, epoch_n)


In [154]:
dataset = TrajectoryDataset(expert_trajectories)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, (train_size, test_size))
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True)

In [155]:
decision_transformer = DecisionTransformer(obs_dim, action_dim, trajectory_len)
optimizer = Adam(decision_transformer.parameters(), lr=1e-3)

In [156]:
train(
    decision_transformer,
    optimizer,
    train_dataloader,
    test_dataloader,
)

AssertionError: was expecting embedding dimension of 32, but got 96

In [157]:
loss = {"train": 0, "test": 0}

for R, s, a, t in train_dataloader:
    print(
        R.shape,
        s.shape,
        a.shape,
        t.shape
    )

    print(
        R.dtype,
        s.dtype,
        a.dtype,
        t.dtype
    )

    a_preds = decision_transformer(R, s, a, t)
    batch_loss = ((a_preds - a) ** 2).mean()
    loss["train"] += batch_loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    break

torch.Size([128, 1]) torch.Size([128, 3]) torch.Size([128, 1]) torch.Size([128, 1])
torch.float32 torch.float32 torch.float32 torch.int64


AssertionError: was expecting embedding dimension of 32, but got 96

In [None]:
def eval():
    pass

## db

In [None]:
num_embeddings = 10
embedding_dim = 32

emb_table = nn.Embedding(num_embeddings, embedding_dim)

In [None]:
embedding_dim = 32

decoder_layer = nn.TransformerDecoderLayer(
    d_model=embedding_dim,
    nhead=4,
    batch_first=True,
)

transformer = nn.TransformerDecoder(decoder_layer, num_layers=3)

B = 4
T = 10
C = embedding_dim

tgt = torch.zeros((B, T, C))
tgt_mask = nn.Transformer.generate_square_subsequent_mask(T)

memory = torch.zeros((B, 1, C))

transformer.forward(
    tgt=tgt,
    memory=memory,
    tgt_mask=tgt_mask,
).shape

torch.Size([4, 10, 32])