In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import enum
import copy
import connect4.connect4 as game
from pympler import asizeof
import deeplearning.buffer as buf
import torch 
import torch.nn as nn
import torch.optim as optim
import deeplearning.mlp as mlp
import torch.nn.functional as F

In [2]:
env = game.Connect4()
randomPlayer1 = game.RandomPlayer()
greedyPlayer2 = game.GreedyRandomPlayer()
buffer = buf.ReplayBuffer(20000)

In [3]:
gm = game.GameManager([randomPlayer1, greedyPlayer2])
gm.play(1000, game.Connect4, buffer)
gm.info()

gm = game.GameManager([randomPlayer1, randomPlayer1])
gm.play(500, game.Connect4, buffer)
gm.info()

gm = game.GameManager([greedyPlayer2, greedyPlayer2])
gm.play(500, game.Connect4, buffer)
gm.info()

p1:  0.337 p2:  0.662 draw:  0.001
p1:  0.492 p2:  0.506 draw:  0.002
p1:  0.552 p2:  0.446 draw:  0.002


In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Define Decision Transformer

In [5]:
import transformers
from transformers import DecisionTransformerModel, DecisionTransformerConfig

# Config
state_dim = env.nrows * env.ncols
act_dim = env.ncols
vocab_size = 1  # Vocab size is for encoder, not for input tokens
state_dim, act_dim, vocab_size
print(f"state_dim: {state_dim}, act_dim: {act_dim}, vocab_size: {vocab_size}")

config = DecisionTransformerConfig(state_dim=state_dim, act_dim=act_dim, vocab_size=vocab_size)

  from .autonotebook import tqdm as notebook_tqdm


state_dim: 42, act_dim: 7, vocab_size: 1


In [6]:
model = DecisionTransformerModel(config)

In [7]:
# Load dataset from ReplayBuffer
BATCH_SIZE = 32
X, y = buffer.create_training_examples(BATCH_SIZE)
print(np.shape(X), np.shape(y))
X = torch.stack(X)  # Create tensor from list of tensors

(32, 44) (32,)


## Let's just design our own Dataset and DataLoader from the Buffer

In [8]:
database = buffer.buffer
np.shape(database), database[0]

((2000,),
 {'states': [array([[0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.]]),
   array([[0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 1.]]),
   array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0., -1.,  0.,  0.,  1.]]),
   array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.

In [9]:
# Buffer is 2000 episodes
# Let's look at the first episode
episode1 = database[0]
print(np.shape(episode1["states"]))  # (16, 6, 7) corresponds to 16 timesteps of 6x7 board
print(np.shape(episode1["actions"]))  # (16,) corresponds to 16 timesteps of 1 action
print(episode1["reward"])  # 1 if player 1 won, -1 if player 2 won, 0 if draw
print(episode1["elo"])  # Elo ratings of [player1, player2]

(16, 6, 7)
(16,)
-1
[940, 1060]


In [10]:
# Let's look at the random episode
episode = np.random.choice(database)
print(np.shape(episode["states"]))  # (seq_len, n_rows, n_cols)
print(np.shape(episode["actions"]))  # (seq_len,)
print(episode["reward"])  # 1 if player 1 won, -1 if player 2 won, 0 if draw
print(episode["elo"])  # Elo ratings of [player1, player2]

(11, 6, 7)
(11,)
1
[940.0, 940.0]


In [11]:
# Define a func to sample a batch of episodes
def sample_batch_episodes(buffer, batch_size=32):
    """
    Sample a batch of episodes from the buffer.
    Episodes should be sampled with probability weighted by their length.
    """
    # Get the lengths of all episodes in the buffer
    episode_lengths = [len(episode["states"]) for episode in buffer.buffer]
    
    # Calculate the probabilities for each episode based on their length
    probabilities = [length / sum(episode_lengths) for length in episode_lengths]
    
    # Sample episodes based on the calculated probabilities
    sampled_episodes = np.random.choice(buffer.buffer, size=batch_size, replace=False, p=probabilities)
    
    return sampled_episodes

In [12]:
# Create Hugging Face Dataset from pd.DataFrame
from datasets import Dataset
import pandas as pd

# database_dict = pd.DataFrame(database).to_dict(orient="list")
# dataset = Dataset.from_dict(database_dict)
df = pd.DataFrame(buffer.buffer)
print(np.shape(df.iloc[0]["states"]))
df["states"] = df["states"].apply(lambda x: np.reshape(x, (-1, env.nrows * env.ncols)))  # Flatten board state but keep sequence

print(np.shape(df.iloc[0]["states"]))
df.head()

(16, 6, 7)
(16, 42)


Unnamed: 0,states,actions,reward,elo
0,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[6, 3, 1, 0, 2, 3, 6, 0, 0, 6, 6, 0, 1, 3, 4, 3]",-1,"[940, 1060]"
1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[0, 5, 5, 5, 5, 1, 5, 2, 5, 2, 0, 0, 6, 1, 2, ...",-1,"[940, 1060]"
2,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[6, 2, 0, 3, 3, 3, 6, 3, 3, 3, 0, 6, 0, 0, 0, ...",1,"[940, 1060]"
3,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[0, 3, 1, 3, 1, 3, 4, 3]",-1,"[940, 1060]"
4,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[2, 3, 4, 3, 2, 3, 0, 6, 2, 3]",-1,"[940, 1060]"


In [95]:
act_enc = np.eye(env.ncols, env.ncols, dtype=int)  # One-hot encoding of actions
df["actions"] = df["actions"].apply(lambda actions: np.stack([act_enc[action] for action in actions]))
df

Unnamed: 0,states,actions,reward,elo
0,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0],...",-1,"[940, 1060]"
1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0],...",-1,"[940, 1060]"
2,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0, 0],...",1,"[940, 1060]"
3,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0],...",-1,"[940, 1060]"
4,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0],...",-1,"[940, 1060]"
...,...,...,...,...
1995,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0],...",1,"[1060.0, 1060.0]"
1996,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 1, 0, 0],...",1,"[1060.0, 1060.0]"
1997,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 1, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0],...",1,"[1060.0, 1060.0]"
1998,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 1, 0, 0, 0],...",1,"[1060.0, 1060.0]"


In [116]:
# Generate trajectories from each player's perspective
# The even elements of the list are player 1's perspective, the odd elements are player 2's perspective
# So take even elements for player 1, leave state and reward as is
# Take odd elements for player 2, flip reward and state

# Player 1's perspective
df_p1 = df.copy()
df_p1["states"] = df_p1["states"].apply(lambda x: x[::2])
df_p1["actions"] = df_p1["actions"].apply(lambda x: x[::2])
df_p1["elo"] = df_p1["elo"].apply(lambda x: x[0])

# Player 2's perspective
df_p2 = df.copy()
df_p2["states"] = -1 * df_p2["states"].apply(lambda x: x[1::2])  # Reversed board state
df_p2["actions"] = df_p2["actions"].apply(lambda x: x[1::2])
df_p2["reward"] = -1 * df_p2["reward"]  # Reversed reward
df_p2["elo"] = df_p2["elo"].apply(lambda x: x[1])

In [122]:
df = pd.concat([df_p1, df_p2])
df

Unnamed: 0,states,actions,reward,elo
0,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 0, 0, 0, 0, 1], [0, 1, 0, 0, 0, 0, 0],...",-1,940.0
1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0],...",-1,940.0
2,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 0, 0],...",1,940.0
3,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0],...",-1,940.0
4,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[[0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0],...",-1,940.0
...,...,...,...,...
1995,"[[-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0...","[[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0],...",-1,1060.0
1996,"[[-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0...","[[0, 0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0],...",-1,1060.0
1997,"[[-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0...","[[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0],...",-1,1060.0
1998,"[[-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0...","[[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0],...",-1,1060.0


In [123]:
dataset = Dataset.from_dict(df.to_dict(orient="list"))
print(dataset.features)
print(dataset)
print(dataset[0])
# del df to free up memory

{'states': Sequence(feature=Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None), length=-1, id=None), 'actions': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None), 'reward': Value(dtype='int64', id=None), 'elo': Value(dtype='float64', id=None)}
Dataset({
    features: ['states', 'actions', 'reward', 'elo'],
    num_rows: 4000
})
{'states': [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.

In [125]:
np.shape(dataset[0]["states"])

(8, 42)

In [126]:
df["states"].apply(len).max()

21

In [159]:
import random
from torch.nn.utils.rnn import pad_sequence

# Create Hugging Face DataCollator
class DecisionTransformerDataCollator:
    """DataCollator for Decision Transformer.
    When __call__ is called, should sample a random batch of episodes from
    the dataset, whilst padding the states and actions to the max length.
    Should also create an attention mask for the states and actions so that
    padding, the current action and future actions are masked.
    """
    max_len: int = 20  # subsets of the episode padded to this length
    state_dim: int = 42  # 6x7 board
    act_dim: int = 7  # 7 columns
    max_ep_len: int = 50  # Approx max episode length in dataset
    scale: float = max_ep_len  # Scale rewards per timestep?
    p_sample: np.array = None  # Distribution to sample episodes by length
    n_traj: int = 0  # to store the number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        self.act_dim = env.ncols
        self.state_dim = env.nrows * env.ncols
        self.dataset = dataset

        traj_lens = []
        for act in dataset["actions"]:
            traj_lens.append(len(act))
        
        self.n_traj = len(traj_lens)

        traj_lens = np.array(traj_lens)
        self.p_sample = traj_lens / np.sum(traj_lens)  # weight sampling prob by traj len

        self.max_ep_len = int(max(traj_lens))
    
    def _discount_cumsum(self):
        pass  # Single terminal rewards, so no need to discount

    def __call__(self, features):  # No idea what is calling this, and what features is.
        """
        Args:
            """

        batch_size = len(features)  # this doesn't really make sense but try
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample
        )

        # A batch of dataset features
        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []

        for ind in batch_inds:
            # for feature in features:
            feature = self.dataset[int(ind)]
            ep_len = len(feature["actions"])
            si = random.randint(0, ep_len - 1)  # Random timestep

            # get sequences from dataset
            # s.append(np.array(feature["states"]).reshape(1, -1, state_dim))
            s.append(torch.tensor(feature["states"]))

            # print(s)
            # a.append(np.array(feature["actions"]).reshape(1, -1, act_dim))
            a.append(torch.tensor(feature["actions"]))
            # r.append(np.array(feature["reward"]))
            reward = torch.tensor(feature["reward"], dtype=torch.int64)

            # d.append(np.array(feature["dones"][si : si + max_len]).reshape(1, -1))
            timesteps.append(torch.arange(ep_len))
            rtg.append(reward.expand(ep_len, 1))
            
            # Attend up to current timestep
            mask.append(torch.concatenate([torch.ones((si)), torch.zeros((self.max_ep_len - si))], axis=0))

            print(s[-1].shape, a[-1].shape, reward.shape, timesteps[-1].shape, mask[-1].shape)

        print(s)
        print("\n\n#####################\n\n")
        print(a)
        print("\n\n#####################\n\n")
        print(rtg)
        print("\n\n#####################\n\n")
        print(timesteps)
        print("\n\n#####################\n\n")
        print(mask)

        # padding sequences
        s = pad_sequence(s, batch_first=True)
        a = pad_sequence(a, batch_first=True)
        # r = pad_sequence(r, batch_first=True)
        # d = pad_sequence(d, batch_first=True)
        rtg = pad_sequence(rtg, batch_first=True)
        timesteps = pad_sequence(timesteps, batch_first=True)
        mask = pad_sequence(mask, batch_first=True)

        # # convert to tensors
        # s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        # a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        # # r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        # # d = torch.from_numpy(np.concatenate(d, axis=0))
        # rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        # timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        # mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        print(s.shape, a.shape, rtg.shape, timesteps.shape, mask.shape)
            
        return {
            "states": s.float(),
            "actions": a.float(),
            # "reward": r,
            # "dones": d,
            "returns_to_go": rtg.float(),
            "timesteps": timesteps,
            "attention_mask": mask
        }

In [None]:
class DecisionTransformerGymDataCollator:
    return_tensors: str = "pt"
    max_len: int = 20 #subsets of the episode we use for training
    state_dim: int = 17  # size of state space
    act_dim: int = 6  # size of action space
    max_ep_len: int = 1000 # max episode length in the dataset
    scale: float = 1000.0  # normalization of rewards/returns
    state_mean: np.array = None  # to store state means
    state_std: np.array = None  # to store state stds
    p_sample: np.array = None  # a distribution to take account trajectory lengths
    n_traj: int = 0 # to store the number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["observations"][0])
        self.dataset = dataset
        # calculate dataset stats for normalization of states
        states = []
        traj_lens = []
        for obs in dataset["observations"]:
            states.extend(obs)
            traj_lens.append(len(obs))
        self.n_traj = len(traj_lens)
        states = np.vstack(states)
        self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
        
        traj_lens = np.array(traj_lens)
        self.p_sample = traj_lens / sum(traj_lens)

    def _discount_cumsum(self, x, gamma):
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        batch_size = len(features)
        # this is a bit of a hack to be able to sample of a non-uniform distribution
        batch_inds = np.random.choice(
            np.arange(self.n_traj),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )
        # a batch of dataset features
        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
        
        for ind in batch_inds:
            # for feature in features:
            feature = self.dataset[int(ind)]
            si = random.randint(0, len(feature["rewards"]) - 1)

            # get sequences from dataset
            s.append(np.array(feature["observations"][si : si + self.max_len]).reshape(1, -1, self.state_dim))
            a.append(np.array(feature["actions"][si : si + self.max_len]).reshape(1, -1, self.act_dim))
            r.append(np.array(feature["rewards"][si : si + self.max_len]).reshape(1, -1, 1))

            d.append(np.array(feature["dones"][si : si + self.max_len]).reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff
            rtg.append(
                self._discount_cumsum(np.array(feature["rewards"][si:]), gamma=1.0)[
                    : s[-1].shape[1]   # TODO check the +1 removed here
                ].reshape(1, -1, 1)
            )
            if rtg[-1].shape[1] < s[-1].shape[1]:
                print("if true")
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            s[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, self.state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - self.state_mean) / self.state_std
            a[-1] = np.concatenate(
                [np.ones((1, self.max_len - tlen, self.act_dim)) * -10.0, a[-1]],
                axis=1,
            )
            r[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, self.max_len - tlen)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, self.max_len - tlen, 1)), rtg[-1]], axis=1) / self.scale
            timesteps[-1] = np.concatenate([np.zeros((1, self.max_len - tlen)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, self.max_len - tlen)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        d = torch.from_numpy(np.concatenate(d, axis=0))
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }


In [147]:
from torch.nn.utils.rnn import pad_sequence
batch_inds = [22, 31, 43, 56, 5]
s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
max_len = 20
max_ep_len = 42
state_dim = 42
act_dim = 7

import random

for ind in batch_inds:
    # for feature in features:
    feature = dataset[int(ind)]
    ep_len = len(feature["actions"])
    si = random.randint(0, ep_len - 1)  # Random timestep

    # get sequences from dataset
    # s.append(np.array(feature["states"]).reshape(1, -1, state_dim))
    s.append(torch.tensor(feature["states"]))

    # print(s)
    # a.append(np.array(feature["actions"]).reshape(1, -1, act_dim))
    a.append(torch.tensor(feature["actions"]))
    # r.append(np.array(feature["reward"]))
    reward = torch.tensor(feature["reward"], dtype=torch.int64)

    # d.append(np.array(feature["dones"][si : si + max_len]).reshape(1, -1))
    timesteps.append(torch.arange(ep_len))
    rtg.append(reward.expand(ep_len))
    
    # Attend up to current timestep
    mask.append(torch.concatenate([torch.ones((si)), torch.zeros((max_ep_len - si))], axis=0))

    print(s[-1].shape, a[-1].shape, reward.shape, timesteps[-1].shape, mask[-1].shape)

print(s)
print("\n\n#####################\n\n")
print(a)
print("\n\n#####################\n\n")
print(rtg)
print("\n\n#####################\n\n")
print(timesteps)
print("\n\n#####################\n\n")
print(mask)

# padding sequences
s = pad_sequence(s, batch_first=True)
a = pad_sequence(a, batch_first=True)
# r = pad_sequence(r, batch_first=True)
# d = pad_sequence(d, batch_first=True)
rtg = pad_sequence(rtg, batch_first=True)
timesteps = pad_sequence(timesteps, batch_first=True)
mask = pad_sequence(mask, batch_first=True)

# # convert to tensors
# s = torch.from_numpy(np.concatenate(s, axis=0)).float()
# a = torch.from_numpy(np.concatenate(a, axis=0)).float()
# # r = torch.from_numpy(np.concatenate(r, axis=0)).float()
# # d = torch.from_numpy(np.concatenate(d, axis=0))
# rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
# timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
# mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

print(s.shape, a.shape, rtg.shape, timesteps.shape, mask.shape)

return_dict = {
    "states": s,
    "actions": a,
    # "rewards": r,
    "returns_to_go": rtg,
    "timesteps": timesteps,
    "attention_mask": mask,
}    

torch.Size([6, 42]) torch.Size([6, 7]) torch.Size([]) torch.Size([6]) torch.Size([42])
torch.Size([7, 42]) torch.Size([7, 7]) torch.Size([]) torch.Size([7]) torch.Size([42])
torch.Size([5, 42]) torch.Size([5, 7]) torch.Size([]) torch.Size([5]) torch.Size([42])
torch.Size([6, 42]) torch.Size([6, 7]) torch.Size([]) torch.Size([6]) torch.Size([42])
torch.Size([5, 42]) torch.Size([5, 7]) torch.Size([]) torch.Size([5]) torch.Size([42])
[tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         

In [87]:
# # Create Hugging Face DataCollator
# from datasets import Features
# class DecisionTransformerDataCollator:
#     """
#     Data collator for Decision Transformer.
#     It pads the inputs and masks the padding.
#     """
    
#     def __call__(self, features: Features):
#         # Get the max sequence length in the batch
#         max_length = max([len(feature["states"]) for feature in features])
        
#         # Pad all sequences to the max length
#         padded_states = [F.pad(feature["states"], (0, 0, 0, max_length - len(feature["states"]))) for feature in features]
#         padded_actions = [F.pad(feature["actions"], (0, max_length - len(feature["actions"]))) for feature in features]
#         padded_rewards = [F.pad(feature["reward"], (0, max_length - len(feature["reward"]))) for feature in features]
#         padded_elos = [F.pad(feature["elo"], (0, max_length - len(feature["elo"]))) for feature in features]
        
#         # Stack the tensors
#         batch_states = torch.stack(padded_states)
#         batch_actions = torch.stack(padded_actions)
#         batch_rewards = torch.stack(padded_rewards)
#         batch_elos = torch.stack(padded_elos)
        
#         # Create attention mask
#         batch_attention_mask = torch.ones_like(batch_states)
#         batch_attention_mask[batch_states == 0] = 0
        
#         # Create inputs dict
#         inputs = {
#             "states": batch_states,
#             "actions": batch_actions,
#             "rewards": batch_rewards,
#             "elos": batch_elos,
#             "attention_mask": batch_attention_mask
#         }
        
#         return inputs

NameError: name 'dataclass' is not defined

In [160]:
# Subclass DecisionTransformerModel
class TrainableDT(DecisionTransformerModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        output = super().forward(**kwargs)
        # add the DT loss
        action_preds = output[1]
        action_targets = kwargs["actions"]
        attention_mask = kwargs["attention_mask"]
        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]
        action_targets = action_targets.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]
        
        loss = torch.mean((action_preds - action_targets) ** 2)

        return {"loss": loss}

    def original_forward(self, **kwargs):
        return super().forward(**kwargs)


In [161]:
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
    output_dir="output/",
    remove_unused_columns=False,
    num_train_epochs=120,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
)

collator = DecisionTransformerDataCollator(dataset)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)

trainer.train()


  0%|          | 0/7560 [02:47<?, ?it/s]
  0%|          | 0/7560 [00:00<?, ?it/s]

torch.Size([20, 42]) torch.Size([20, 7]) torch.Size([]) torch.Size([20]) torch.Size([21])
torch.Size([6, 42]) torch.Size([6, 7]) torch.Size([]) torch.Size([6]) torch.Size([21])
torch.Size([9, 42]) torch.Size([9, 7]) torch.Size([]) torch.Size([9]) torch.Size([21])
torch.Size([6, 42]) torch.Size([6, 7]) torch.Size([]) torch.Size([6]) torch.Size([21])
torch.Size([18, 42]) torch.Size([18, 7]) torch.Size([]) torch.Size([18]) torch.Size([21])
torch.Size([18, 42]) torch.Size([18, 7]) torch.Size([]) torch.Size([18]) torch.Size([21])
torch.Size([6, 42]) torch.Size([6, 7]) torch.Size([]) torch.Size([6]) torch.Size([21])
torch.Size([5, 42]) torch.Size([5, 7]) torch.Size([]) torch.Size([5]) torch.Size([21])
torch.Size([11, 42]) torch.Size([11, 7]) torch.Size([]) torch.Size([11]) torch.Size([21])
torch.Size([15, 42]) torch.Size([15, 7]) torch.Size([]) torch.Size([15]) torch.Size([21])
torch.Size([4, 42]) torch.Size([4, 7]) torch.Size([]) torch.Size([4]) torch.Size([21])
torch.Size([12, 42]) torch.S

RuntimeError: shape '[64, 60]' is invalid for input of size 4032