In [3]:
# This notebook uses a basic GPT-based Decision Transformer in an offline reinforcement learning setting to build a stock-trading bot.

In [None]:
# import libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# based on https://github.com/nikhilbarhate99/min-decision-transformer/blob/master/decision_transformer/model.py

# define the masked causal attention
class MaskedAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    Each position may attend only to itself and to earlier positions in the
    sequence; later positions are masked out before the softmax.

    Args:
        h_dim: size of the hidden (embedding) dimension; must be divisible
            by ``n_heads``.
        max_T: maximum sequence length the pre-built causal mask supports.
        n_heads: number of attention heads.
        drop_p: dropout probability applied to the attention weights and to
            the projected output.

    Raises:
        ValueError: if ``h_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        # forward() splits the hidden dimension evenly across heads; reject
        # configurations that would otherwise fail later inside .view()
        if h_dim % n_heads != 0:
            raise ValueError(
                f"h_dim ({h_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        self.drop_p = drop_p
        # feed forward networks which create the query, key and value
        self.Q_net = nn.Linear(h_dim, h_dim)
        self.K_net = nn.Linear(h_dim, h_dim)
        self.V_net = nn.Linear(h_dim, h_dim)

        # feed forward network which projects the attention to the correct dimension
        self.proj_net = nn.Linear(h_dim, h_dim)

        # dropout layers
        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        # lower-triangular mask: entry (i, j) is 1 when position i may attend to j (j <= i)
        mask = torch.tril(torch.ones(max_T, max_T)).view(1, 1, max_T, max_T)

        # register_buffer makes the mask a constant tensor: it moves with the
        # module across devices but is not a parameter and is never updated
        # during backpropagation
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape [B, T, H] with T <= max_T.

        Returns:
            Tensor of shape [B, T, H].
        """
        B, T, C = x.shape  # batch size, sequence length, hidden dimension
        N, D = self.n_heads, C // self.n_heads  # number of heads, dimension of each head

        # compute the query, key and value, split into heads
        Q = self.Q_net(x).view(B, T, N, D).transpose(1, 2)  # [B, N, T, D]
        K = self.K_net(x).view(B, T, N, D).transpose(1, 2)  # [B, N, T, D]
        V = self.V_net(x).view(B, T, N, D).transpose(1, 2)  # [B, N, T, D]

        # scaled dot-product scores: QK^T / sqrt(D) -> [B, N, T, T]
        weights = Q @ K.transpose(2, 3) / math.sqrt(D)
        # mask the future tokens so they get zero weight after the softmax
        weights = weights.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        normalized_weights = F.softmax(weights, dim=-1)  # softmax over the key positions
        A = self.att_drop(normalized_weights)  # attention weights with dropout

        # weighted sum of the values (this step was missing, leaving
        # `attention` undefined on the next line) -> [B, N, T, D]
        attention = A @ V

        # gather heads and project to correct dimension
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)
        out = self.proj_drop(self.proj_net(attention))

        return out

# define the attention block with layer normalization and residual connection as well as the feed forward network
class AttentionBlock(nn.Module):
    """Transformer block: masked attention and a feed-forward network.

    Each sub-layer is wrapped in a residual connection, and the sum is then
    layer-normalized (normalization is applied after the residual add).

    Args:
        h_dim: size of the hidden dimension.
        max_T: maximum sequence length, forwarded to ``MaskedAttention``.
        n_heads: number of attention heads, forwarded to ``MaskedAttention``.
        drop_p: dropout probability used by the attention layer and by the
            final dropout of the feed-forward network.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        self.attention = MaskedAttention(h_dim, max_T, n_heads, drop_p)
        self.norm1 = nn.LayerNorm(h_dim)
        self.norm2 = nn.LayerNorm(h_dim)

        # position-wise feed-forward network with a 4x wider hidden layer
        ffn_width = 4 * h_dim
        expand = nn.Linear(h_dim, ffn_width)
        contract = nn.Linear(ffn_width, h_dim)
        self.ffn = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(drop_p))

    def forward(self, x):
        """Run the block on ``x`` of shape [B, T, H]; output has the same shape."""
        # attention sub-layer: residual add, then LayerNorm
        attended = self.attention(x)
        hidden = self.norm1(x + attended)

        # feed-forward sub-layer: residual add, then LayerNorm
        transformed = self.ffn(hidden)
        return self.norm2(hidden + transformed)

# define the decision transformer
class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, act_dim, n_block, h_dim, context_len, n_heads, drop_p, max_timestep = 4096):
        super().__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.h_dim = h_dim

        # transformer blocks
        input_seq_len = 3 * context_len
        blocks = [AttentionBlock(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_block)]
        self.transformer = nn.Sequential(*blocks)

        # embedding layers
        self.state_emb = nn.Linear(state_dim, h_dim)
        self.act_emb = nn.Linear(act_dim, h_dim)

        # decision transformer blocks
        self.blocks = nn.ModuleList([AttentionBlock(h_dim, context_len, n_heads, drop_p) for _ in range(n_block)])

        # output layers
        self.out = nn.Linear(h_dim, act_dim)