In [3]:
# This notebook uses a basic GPT-based Decision Transformer in an offline reinforcement learning setting to build a stock-trading bot.

In [None]:
# import libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# define the masked causal attention
# based on https://github.com/nikhilbarhate99/min-decision-transformer/blob/master/decision_transformer/model.py
class MaskedAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    Each position in the sequence may attend only to itself and to earlier
    positions; later positions are masked out before the softmax.

    Args:
        h_dim: hidden (embedding) dimension of the input; it is split evenly
            across the heads, so it must be divisible by ``n_heads``.
        max_T: maximum sequence length the causal mask is built for.
        n_heads: number of attention heads.
        drop_p: dropout probability applied to the attention weights and to
            the projected output.

    Raises:
        ValueError: if ``h_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        if h_dim % n_heads != 0:
            raise ValueError(
                f"h_dim ({h_dim}) must be divisible by n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        self.drop_p = drop_p

        # linear maps producing queries, keys and values for all heads at once
        self.Q_net = nn.Linear(h_dim, h_dim)
        self.K_net = nn.Linear(h_dim, h_dim)
        self.V_net = nn.Linear(h_dim, h_dim)

        # output projection applied after the heads are concatenated
        self.proj_net = nn.Linear(h_dim, h_dim)

        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        # lower-triangular matrix of ones: entry (i, j) is 1 iff j <= i,
        # shaped [1, 1, max_T, max_T] so it broadcasts over batch and heads
        mask = torch.tril(torch.ones(max_T, max_T)).view(1, 1, max_T, max_T)

        # register_buffer keeps the mask as a constant tensor: it is not a
        # model parameter, so it is not updated during backpropagation
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape [B, T, C] (batch, sequence length, hidden dim).

        Returns:
            Tensor of shape [B, T, C].

        Raises:
            ValueError: if T exceeds the ``max_T`` the mask was built for.
        """
        B, T, C = x.shape  # batch size, sequence length, hidden dimension
        N, D = self.n_heads, C // self.n_heads  # number of heads, dimension of each head

        max_T = self.mask.shape[-1]
        if T > max_T:
            raise ValueError(
                f"sequence length ({T}) exceeds max_T ({max_T})"
            )

        # compute the query, key and value, split into heads
        Q = self.Q_net(x).view(B, T, N, D).transpose(1, 2)  # [B, N, T, D]
        K = self.K_net(x).view(B, T, N, D).transpose(1, 2)  # [B, N, T, D]
        V = self.V_net(x).view(B, T, N, D).transpose(1, 2)  # [B, N, T, D]

        # scaled dot-product scores: QK^T / sqrt(D) -> [B, N, T, T]
        weights = Q @ K.transpose(2, 3) / math.sqrt(D)
        # mask the future tokens; -inf becomes 0 after the softmax
        weights = weights.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        normalized_weights = F.softmax(weights, dim=-1)  # softmax over the key axis
        A = self.att_drop(normalized_weights)  # attention weights with dropout

        # weighted sum of the values -> [B, N, T, D]
        attention = A @ V

        # gather heads and restore the original hidden dimension: [B, T, N*D]
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)

        out = self.proj_drop(self.proj_net(attention))

        return out