In [1]:
import torch
import torch.nn as nn
import numpy as np
import gym
from decision_transformer import DecisionTransformer
from dataset.trajectory_dataset import TrajectoryDataset

# --- Environment, hyperparameters, model and dataset -------------------------
env = gym.make('CartPole-v1')

# Hyperparameters. Only "hidden_size", "c_len", "activation_function",
# "dropout" and "device" are read in this cell; the remaining keys are
# presumably consumed by training code elsewhere -- TODO confirm.
config = {
    "learning_rate": 2e-4,
    "epochs": 100,
    "batch_size": 32,
    "hidden_size": 64,
    "c_len": 50,
    "device": "auto",
    "weight_decay": 1e-4,
    "betas": (0.9, 0.999),
    "activation_function": "relu",
    'dropout': 0.1,
    "warmup_steps": 10000,
    "num_workers": 0,
}

# Problem dimensions are taken from the environment's spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
c_len = config["c_len"]

# Use the GPU when one is present and fall back to the CPU otherwise, so this
# cell no longer raises on machines without CUDA. On a CUDA machine this is
# equivalent to the previous unconditional `.cuda()` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DecisionTransformer(
    state_dim,
    action_dim,
    config["hidden_size"],
    c_len,
    200,   # NOTE(review): opaque positional argument -- confirm its meaning against the DecisionTransformer signature
    True,  # NOTE(review): opaque positional flag -- confirm its meaning against the DecisionTransformer signature
    n_head=1,
    n_layer=3,
    n_inner=4 * config['hidden_size'],
    activation_function=config['activation_function'],
    n_positions=1024,
    resid_pdrop=config['dropout'],
    attn_pdrop=config['dropout'],
    device=config["device"],
).to(device)

train_dataset = TrajectoryDataset(c_len, state_dim, action_dim)

64 1
64 1
64 1
number of parameters: 0.15M
done


In [2]:
import sys
import numpy
# Raising the summarisation threshold to sys.maxsize makes NumPy print every
# element of an array instead of abbreviating it with "...". This setting only
# affects NumPy array reprs; torch tensors have their own print options.
PRINT_THRESHOLD = sys.maxsize
numpy.set_printoptions(threshold=PRINT_THRESHOLD)

In [5]:
# --- Push one dataset sample through the model --------------------------------
# Place the inputs on whatever device the model's weights live on, rather than
# assuming CUDA is available.
device = next(model.parameters()).device


def _as_batch(array, device, dtype=None):
    """Return `array` as a tensor on `device` with a leading batch axis of size 1.

    Accepts either a torch tensor or a NumPy array. When `dtype` is None the
    existing dtype is kept.
    """
    tensor = array if torch.is_tensor(array) else torch.from_numpy(array)
    return tensor.to(device=device, dtype=dtype).unsqueeze(0)


states, actions, returns, dones, timesteps, attn_mask = train_dataset[3]
states = _as_batch(states, device, torch.float)
actions = _as_batch(actions, device, torch.float)
returns = _as_batch(returns, device, torch.float)
timesteps = _as_batch(timesteps, device, torch.long)
attn_mask = _as_batch(attn_mask, device)  # dtype left as delivered by the dataset

# Repeat every mask entry three times in place (m0, m0, m0, m1, m1, m1, ...):
# stacking gives (B, 3, T), the permute gives (B, T, 3), and the reshape
# flattens that to (B, 1, 3*T).
stacked_attn_mask = torch.stack(
    (attn_mask, attn_mask, attn_mask), dim=1
).permute(0, 2, 1).reshape(attn_mask.shape[0], 1, 3 * attn_mask.shape[1])

# Outer product of the repeated mask with itself: (B, 3*T, 3*T), non-zero only
# where both positions have a non-zero mask entry.
attention_mask = stacked_attn_mask.transpose(-1, -2) @ stacked_attn_mask

# NOTE(review): `stuff` is never passed to the model below. It places -1e9 on
# positions whose mask is 1, whereas the tensor printed during the forward pass
# appears to carry -1e9 on the positions whose mask is 0 -- confirm which
# polarity is intended before reusing this value.
stuff = -1e9 * attention_mask

print(attn_mask)
action_preds = model(
    states, actions, returns, timesteps=timesteps, attn_mask=attn_mask
)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],
       device='cuda:0', dtype=torch.float64)
tensor([[[[-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          ...,
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ...,  5.5217e-03,
           -1.3782e-02,  3.0311e-04],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ...,  1.5389e-02,
            2.9173e-02,  2.3165e-02],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ...,  1.9076e-03,
            5.4983e-03,  1.5087e-02]]]], device='cuda:0',
       grad_fn=<AddBackward0>)
====
tensor([[[[1.