In [2]:
import math
from time import sleep

import numpy as np
import torch
from transformers import DecisionTransformerModel, DecisionTransformerConfig

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# --- Model / data configuration -------------------------------------------------
# Upper bound for values that can appear in `timesteps`. Timesteps are frame numbers,
# and we skip 10 frames per step, so they get large. The HF Decision Transformer sizes
# its timestep embedding with `max_ep_len`, so every timestep must stay below this value.
max_ep_length = 11264
# Maximum sequence length this model might ever be used with (e.g. 512, 1024 or 2048).
# NOTE(review): the Decision Transformer interleaves return/state/action tokens, so this
# presumably has to cover 3 tokens per timestep — confirm against the context length used.
n_positions = 2**10
batch_size = 64  # batch size to use for training
state_dim = (10, 10, 10)  # shape of one (unflattened) state
state_dim_flatten = math.prod(state_dim)  # works for states of any rank, not just 3-D
action_dim = 6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# action_tanh=True squashes predicted actions into [-1, 1].
config = DecisionTransformerConfig(state_dim=state_dim_flatten, act_dim=action_dim, max_ep_len=max_ep_length, n_positions=n_positions, action_tanh=True)
model = DecisionTransformerModel(config).to(device=device)

# --- Model memory footprint: parameters + registered buffers, in bytes ------------
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(buf.nelement() * buf.element_size() for buf in model.buffers())

size_model_mb = (param_size + buffer_size) / 1024**2
size_model_gb = size_model_mb / 1024
print(f"Size of model: {size_model_mb:.2f} MB ({size_model_gb:.2f} GB)")

Size of model: 12.26 MB (0.01 GB)


In [38]:
# fake data

def fake_data(batch, sequence_length, *, state_size=None, action_size=None, target_device=None):
    """Build a random batch shaped like Decision Transformer inputs.

    All float tensors are float32 and drawn with ``torch.rand`` (uniform on [0, 1)).

    Args:
        batch: number of sequences in the batch.
        sequence_length: number of timesteps per sequence.
        state_size: flattened state dimension; defaults to the notebook's ``state_dim_flatten``.
        action_size: action dimension; defaults to the notebook's ``action_dim``.
        target_device: device for the tensors; defaults to the notebook's ``device``.

    Returns:
        ``(states, actions, target_return, timesteps)`` with shapes
        ``(batch, sequence_length, state_size)``, ``(batch, sequence_length, action_size)``,
        ``(batch, sequence_length, 1)`` and ``(batch, sequence_length)``. ``timesteps`` is
        int64 and holds ``1..sequence_length`` in every row.
    """
    if state_size is None:
        state_size = state_dim_flatten
    if action_size is None:
        action_size = action_dim
    if target_device is None:
        target_device = device

    float_opts = dict(device=target_device, dtype=torch.float32)
    # Draw order (return, states, actions) matches the original so seeded runs reproduce.
    target_return = torch.rand((batch, sequence_length, 1), **float_opts)  # expected future return
    states = torch.rand((batch, sequence_length, state_size), **float_opts)
    actions = torch.rand((batch, sequence_length, action_size), **float_opts)
    # Same 1-based timestep ramp in every row. Built with arange + repeat so that batch == 0
    # still yields shape (0, sequence_length); a nested empty list would collapse to (0,).
    timesteps = torch.arange(1, sequence_length + 1, device=target_device, dtype=torch.long).repeat(batch, 1)

    return states, actions, target_return, timesteps


seq = 2  # number of timesteps in the demo sequence
b = 1  # batch size; the squeezed printouts below assume a single sequence
states, actions, target_return, timesteps = fake_data(b, seq)

print(f"states: {states.shape}, actions: {actions.shape}, target_return: {target_return.shape}, timesteps: {timesteps.shape}")


# forward pass — inference only: eval mode and no autograd graph
model.eval()
attention_mask = torch.ones((b, seq), device=device, dtype=torch.float32)  # attend to every timestep

with torch.no_grad():
    state_preds, action_preds, return_preds = model(
        states=states,
        actions=actions,
        rewards=None,  # not used in forward pass https://github.com/huggingface/transformers/blob/v4.27.2/src/transformers/models/decision_transformer/modeling_decision_transformer.py#L831
        returns_to_go=target_return,  # return-to-go is the sum of all future rewards
        timesteps=timesteps,
        attention_mask=attention_mask,
        return_dict=False,
    )

# Action prediction for the last timestep of the first sequence. Index explicitly: the
# squeeze below only drops the batch dim when b == 1, so `action_preds[-1]` after it
# would select the last *batch element* (not the last timestep) whenever b > 1.
last_action_pred = action_preds[0, -1]

# remove batch dim (a no-op when b > 1)
state_preds, action_preds, return_preds = torch.squeeze(state_preds, 0), torch.squeeze(action_preds, 0), torch.squeeze(return_preds, 0)
print(f"state_preds: {state_preds.shape}, action_preds: {action_preds.shape}, return_preds: {return_preds.shape}")

# NOTE(review): with action_tanh=True these are continuous values in [-1, 1]; taking the
# argmax treats them as scores for a discrete action — confirm that is intended.
actionIdx = torch.argmax(last_action_pred).item()
pred_arr = last_action_pred.cpu().numpy()
print(f"pred_arr: {pred_arr},  actionIdx: {actionIdx}")


states: torch.Size([1, 2, 1000]), actions: torch.Size([1, 2, 6]), target_return: torch.Size([1, 2, 1]), timesteps: torch.Size([1, 2])
state_preds: torch.Size([2, 1000]), action_preds: torch.Size([2, 6]), return_preds: torch.Size([2, 1])
pred_arr: [ 0.14693639 -0.23803829 -0.1962808  -0.10674385  0.01430107 -0.14090946],  actionIdx: 0


In [43]:
# Sanity check on the fake data: torch.rand samples U[0, 1), so the mean should be ~0.5.
s, a, r, t = fake_data(batch_size, 6)  # (states, actions, target_return, timesteps), 6 timesteps

a.mean().item()

0.5065848231315613

In [45]:

# NOTE(review): dead cell — its only line is commented out. The TypeError shown below is
# stale output from when this ran with `x` as a plain Python list (lists reject tuple
# indices such as `[:, 2]`). Convert first (presumably `torch.as_tensor(x)[:, 2]`) or delete this cell.
#x[:, 2] # get t's for actions


TypeError: list indices must be integers or slices, not tuple