In [2]:
import numpy as np
from transformers import DecisionTransformerModel, DecisionTransformerConfig
from time import sleep
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Reproducibility: seed torch's global RNG before the model's random weight
# initialisation and before the torch.rand-based fake data generated later on.
SEED = 0
torch.manual_seed(SEED)

# Upper bound for timestep values, passed to the config as max_ep_len.
# Timesteps are frame numbers (apparently sampled every 10th frame), which is why
# this value is large.
# NOTE(review): if max_ep_len sizes a timestep embedding table, valid timesteps are
# presumably 0..max_ep_length-1 (not 0..max_ep_length) — confirm.
max_ep_length = 11264
# The maximum sequence length that this model might ever be used with. Typically set
# this to something large just in case (e.g., 512 or 1024 or 2048).
n_positions = 2**10
batch_size = 64  # The batch size to use for training.
state_dim = (10, 10, 10)
# Length of a flattened state vector; np.prod works for a state_dim of any rank.
state_dim_flatten = int(np.prod(state_dim))
action_dim = 6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = DecisionTransformerConfig(
    state_dim=state_dim_flatten,
    act_dim=action_dim,
    max_ep_len=max_ep_length,
    n_positions=n_positions,
)
model = DecisionTransformerModel(config).to(device=device)

# In-memory footprint of the model: every parameter and registered buffer
# contributes (number of elements) * (bytes per element).
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

# Convert bytes -> MiB -> GiB for display.
size_model_mb = (param_size + buffer_size) / 1024**2
size_model_gb = size_model_mb / 1024
print(f"Size of model: {size_model_mb:.2f} MB ({size_model_gb:.2f} GB)")

Size of model: 12.26 MB (0.01 GB)


In [11]:
# fake data

def fake_data(sequence_length, *, state_size=None, act_size=None, target_device=None, generator=None):
    """Generate one random, unbatched fake trajectory for smoke-testing the model.

    All float tensors are drawn uniformly from [0, 1) with ``torch.rand``.

    Args:
        sequence_length: Number of timesteps L in the fake trajectory.
        state_size: Length of each flattened state vector. Defaults to the
            notebook-level ``state_dim_flatten``.
        act_size: Length of each action vector. Defaults to the notebook-level
            ``action_dim``.
        target_device: Device to allocate the tensors on. Defaults to the
            notebook-level ``device``.
        generator: Optional ``torch.Generator`` for reproducible draws. It must
            live on the same device type as ``target_device``.

    Returns:
        Tuple ``(states, actions, target_return, timesteps)`` with shapes
        ``(L, state_size)``, ``(L, act_size)``, ``(L, 1)`` and ``(L,)``.
        The first three are float32. ``timesteps`` is int64 and runs 1..L.

    Note:
        No batch dimension is included; the model presumably needs one
        (e.g. ``.unsqueeze(0)``) — confirm against the calling code.
        NOTE(review): the timesteps are 1-based. If the model indexes a
        timestep embedding of size ``max_ep_len``, L == max_ep_len would go
        out of range — confirm the intended convention.
    """
    # Resolve notebook-level defaults lazily so callers can override any of them.
    if state_size is None:
        state_size = state_dim_flatten
    if act_size is None:
        act_size = action_dim
    if target_device is None:
        target_device = device

    def _uniform(*shape):
        # float32 samples from U[0, 1) on the requested device / generator.
        return torch.rand(shape, generator=generator, device=target_device, dtype=torch.float32)

    # Draw order (returns, states, actions) is fixed so a given seed or
    # generator always yields the same tensors.
    target_return = _uniform(sequence_length, 1)  # stand-in for the expected future return
    states = _uniform(sequence_length, state_size)
    actions = _uniform(sequence_length, act_size)
    # Integer index of each step, 1-based.
    timesteps = torch.arange(1, sequence_length + 1, device=target_device, dtype=torch.long)

    return states, actions, target_return, timesteps


# Smoke test: generate a two-step fake trajectory and report each tensor's shape.
(states, actions, target_return, timesteps) = fake_data(sequence_length=2)

print(
    f"states: {states.shape}, actions: {actions.shape}, target_return: {target_return.shape}, timesteps: {timesteps.shape}"
)



states: torch.Size([2, 1000]), actions: torch.Size([2, 6]), target_return: torch.Size([2, 1]), timesteps: torch.Size([2])
