In [205]:
import numpy as np
import torch
import torch.nn as nn
import random

import transformers

from model import TrajectoryModel
from trajectory_gpt2 import GPT2Model

import gymnasium as gym
import argparse
import time

In [206]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [239]:
env = gym.make("Acrobot-v1")
print(env.observation_space)
env.action_space.n

Box([ -1.        -1.        -1.        -1.       -12.566371 -28.274334], [ 1.        1.        1.        1.       12.566371 28.274334], (6,), float32)


3

In [207]:
env = gym.make("Acrobot-v1")

n_traj = 100
traj_states = torch.zeros((n_traj,traj_len,6))
traj_actions = torch.zeros(n_traj,traj_len)
traj_rewards = torch.zeros(n_traj,traj_len)

state_dim = 6
act_dim = 1
traj_len = 500
n = 0
while n<n_traj:
    state, info = env.reset()
    state = torch.tensor(state)

    for i in range(traj_len):
        action = env.action_space.sample()
        state_next, reward, terminated, truncated, info = env.step(action)

        traj_actions[n,i] = action
        traj_rewards[n,i] = reward
        traj_states[n,i,:] = state

        state = torch.tensor(state_next)

        if (terminated or truncated) and (i<499):
            n = n - 1
            break
            #state, info = env.reset()
    env.close()
    n = n + 1

In [210]:
torch.save( traj_states, "traj_states.pt")
torch.save(traj_actions, "traj_actions.pt")

In [213]:
traj_states = torch.load("traj_states.pt")
traj_states.size()

torch.Size([100, 500, 6])

In [194]:
def get_batch(batch_size, n_traj, traj_states, traj_actions):
    idxs = random.choices(range(n_traj), k=batch_size)
    batch_states = traj_states[idxs,:,:]
    batch_actions = traj_actions[idxs,:]
    
    return batch_states, batch_actions

In [195]:
class DecisionTransformer(TrajectoryModel):

    """
    This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...)
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            hidden_size,
            max_length=None,
            max_ep_len=500,
            action_tanh=True,
            **kwargs
    ):
        super().__init__(state_dim, act_dim, max_length=max_length)

        self.hidden_size = hidden_size
        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            **kwargs
        )

        # note: the only difference between this GPT2Model and the default Huggingface version
        # is that the positional embeddings are removed (since we'll add those ourselves)
        self.transformer = GPT2Model(config)

        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        #self.embed_return = torch.nn.Linear(1, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)

        # note: we don't predict states or returns for the paper
        self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
        self.predict_action = nn.Sequential(
            *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
        )
        self.predict_return = torch.nn.Linear(hidden_size, 1)

    def forward(self, states, actions, timesteps, attention_mask=None):

        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # attention mask for GPT: 1 if can be attended to, 0 if not
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        #returns_embeddings = self.embed_return(returns_to_go)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        #returns_embeddings = returns_embeddings + time_embeddings
        
        #print(state_embeddings.size())
        #print(action_embeddings.size())
        #print(returns_embeddings.size())
        
#         print(torch.stack(
#             (returns_embeddings, state_embeddings, action_embeddings), dim=1
#         ).size())

        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        stacked_inputs = torch.stack(
            (state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 2*seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask), dim=1
        ).permute(0, 2, 1).reshape(batch_size, 2*seq_length)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs['last_hidden_state']

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 2, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions
        #return_preds = self.predict_return(x[:,2])  # predict next return given state and action
        state_preds = self.predict_state(x[:,1])    # predict next state given state and action
        action_preds = self.predict_action(x[:,0])  # predict next action given state

        return state_preds, action_preds

    def get_state(self, states, actions, timesteps, **kwargs):
        # we don't care about the past rewards in this model

        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            states = states[:,-self.max_length:]
            actions = actions[:,-self.max_length:]
            timesteps = timesteps[:,-self.max_length:]

            # pad all tokens to sequence length
            attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
            attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
            states = torch.cat(
                [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
                dim=1).to(dtype=torch.float32)
            actions = torch.cat(
                [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim),
                             device=actions.device), actions],
                dim=1).to(dtype=torch.float32)
            timesteps = torch.cat(
                [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps],
                dim=1
            ).to(dtype=torch.long)
        else:
            attention_mask = None

        state_preds, action_preds = self.forward(
            states, actions, timesteps, attention_mask=attention_mask, **kwargs)
        
#         print(state_preds.size())
#         print(state_preds[0,-1].size())

        return state_preds[0,-1]

In [196]:
model = DecisionTransformer(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=20,
            max_ep_len=500,
            hidden_size=128,
            n_layer=3,
            n_head=1,
            n_inner=4*128,
            activation_function='relu',
            n_positions=1024,
            resid_pdrop=0.1,
            attn_pdrop=0.1,
            n_ctx=1500
        )

In [197]:
batch_size = 100
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.1,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda steps: min((steps+1)/10000, 1)
)

n_epochs = 1000

losses = torch.zeros(n_epochs)

for n in range(n_epochs):
    state_batch, action_batch = get_batch(batch_size, n_traj, traj_states, traj_actions)

    state_pred, action_preds = model.forward(state_batch, (action_batch.unsqueeze(-1)), torch.arange(0,500,1).unsqueeze(0))

    loss = nn.MSELoss()
    l = loss(state_pred,state_batch)
    losses[n] = l.detach().cpu().item()
    optimizer.zero_grad()
    l.backward()
    optimizer.step()
    scheduler.step()

KeyboardInterrupt: 

In [184]:
states_i, actions_i = get_batch(1, n_traj, traj_states, traj_actions)
actions_i.size()

torch.Size([1, 500])

In [188]:
model.eval()
model.to(device=device)

state = states_i[0,0,:]
action = actions_i[0,0]

# we keep all the histories on the device
# note that the latest action and reward will be "padding"
states = state.reshape(1, state_dim).to(device=device, dtype=torch.float32)
print(states)
actions = action.reshape(1, act_dim).to(device=device, dtype=torch.float32)
print(actions)
#rewards = torch.zeros(0, device=device, dtype=torch.float32)
#target_return = torch.tensor(target_return, device=device, dtype=torch.float32)
sim_states = []

episode_return, episode_length = 0, 0
for t in range(499):

    # add padding
    actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
    #rewards = torch.cat([rewards, torch.zeros(1, device=device)])

    pred_state = model.get_state(
        states.to(dtype=torch.float32) ,
        actions.to(dtype=torch.float32),
        torch.arange(0,t+1,1).unsqueeze(0)
    )
    
    actions[-1] = actions_i[0,t+1].reshape(1, act_dim)

    cur_state = (pred_state).to(device=device).reshape(1, state_dim)
    states = torch.cat([states, cur_state], dim=0)

tensor([[ 0.9999,  0.0117,  0.9994, -0.0350, -0.0586,  0.0875]])
tensor([[1.]])


In [189]:
states.size()

torch.Size([500, 6])

In [191]:
loss = nn.MSELoss()
l = loss(states.unsqueeze(0),states_i)
print(l)

tensor(1.4659, grad_fn=<MseLossBackward0>)
