In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.autograd import Function
from torch import nn
import torch.nn.functional as F
import ipdb
import math

In [3]:
# Print tensors with 7 digits of floating-point precision.
torch.set_printoptions(precision=7)

# nn.GELU

In [4]:
# Evaluate exact GELU and its tanh approximation on the same sample points.
x = torch.tensor([0.1, 1.0, 1.2])
gelu, gelu_approx = nn.GELU(), nn.GELU(approximate="tanh")
print(gelu(x))
display(gelu_approx(x))

tensor([0.0539828, 0.8413447, 1.0619165])


tensor([0.0539828, 0.8411920, 1.0617028])

# CausalSelfAttention

In [7]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted by a causal (lower-triangular) mask.

    ``config`` must provide ``n_embd``, ``n_head``, ``block_size``,
    ``attn_pdrop`` and ``resid_pdrop``; ``n_embd`` must be divisible by
    ``n_head``. If ``config`` also has an ``action_dim`` attribute,
    ``observation_dim`` is read as well and every ``joined_dim``-th key column
    of the mask (``joined_dim = observation_dim + action_dim + 2``, starting
    at column ``joined_dim - 1``) is zeroed so that no query attends to it.

    The mask is stored as the buffer ``mask`` with shape
    ``(1, 1, block_size, block_size)``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        # Built as a plain 2-D tensor first so the column edit below never
        # depends on squeeze(), which collapses to 0-d when block_size == 1.
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        ## mask previous value estimates
        if hasattr(config, "action_dim"):
            joined_dim = config.observation_dim + config.action_dim + 2
            mask[:, joined_dim - 1::joined_dim] = 0
        ##
        self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        """Apply masked multi-head attention to ``x``.

        Args:
            x: tensor unpacked as ``(B, T, C)``; the head split below requires
                ``C`` to be divisible by ``n_head``.
            layer_past: accepted but not used by this implementation.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        ## [ B x n_heads x T x head_dim ]
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        ## [ B x n_heads x T x T ]
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Masked positions get -inf so they receive zero weight after softmax.
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # Keep a copy of the post-softmax, pre-dropout weights on the module.
        self._attn_map = att.clone()
        att = self.attn_drop(att)
        ## [ B x n_heads x T x head_size ]
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        ## [ B x T x embedding_dim ]
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

In [8]:
class Config:
    """Bare attribute container; fields are attached to instances after creation."""

# Tiny hyper-parameters for exercising the modules in this notebook.
cfg = Config()
vars(cfg).update(
    n_embd=8,
    block_size=12,
    action_dim=1,
    observation_dim=2,
    attn_pdrop=0.0,
    resid_pdrop=0.0,
    n_head=2,
)
def main():
    """Build a ``CausalSelfAttention`` from ``cfg`` and run it once.

    Returns the layer's output for a random input of shape ``(3, 2, 8)``.
    """
    attention = CausalSelfAttention(cfg)
    sample = torch.randn(3, 2, 8)
    return attention(sample)
main()

tensor([[[-0.2076447, -0.1660343, -0.1446521,  0.8652914, -0.3342855,
          -0.2586271,  0.2681955,  0.3942418],
         [-0.0934719, -0.1169063,  0.1851222,  0.7146784, -0.1560096,
          -0.2340157,  0.0883798,  0.0436949]],

        [[-0.1535415, -0.3847658,  0.4095368,  0.3740130, -0.2884263,
          -0.1888891, -0.4992743, -0.6241646],
         [-0.2291546, -0.2408308,  0.1749529,  0.4173355, -0.1803498,
          -0.2365421, -0.2945327, -0.1723328]],

        [[ 0.0732064, -0.0597422,  0.1265753,  0.8852328, -0.1112194,
          -0.2290128, -0.1146231, -0.2866505],
         [ 0.1019692,  0.0156879,  0.3203210,  0.6205373, -0.0756834,
          -0.0020490, -0.3810693, -0.5421953]]], grad_fn=<ViewBackward0>)

# Block

In [12]:
class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added residually.

    ``config`` must provide ``n_embd`` and ``resid_pdrop`` plus whatever
    ``CausalSelfAttention`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        # Feed-forward path: expand 4x, GELU, project back, then dropout.
        feed_forward = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.resid_pdrop),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))

In [14]:
def block_test():
    """Build a ``Block`` from ``cfg`` and run it once.

    Returns the block's output for a random input of shape ``(3, 2, 8)``.
    """
    layer = Block(cfg)
    sample = torch.randn(3, 2, 8)
    return layer(sample)
block_test().shape

torch.Size([3, 2, 8])

# Mask

In [79]:
def main(block_size=650, joined_dim=25):
    """Count the entries left in a causal mask after zeroing periodic columns.

    Builds a ``block_size x block_size`` lower-triangular mask of ones, zeroes
    every ``joined_dim``-th column starting at column ``joined_dim - 1``, and
    sums what remains.

    Args:
        block_size: side length of the square mask.
        joined_dim: column period of the zeroed entries.

    Returns:
        0-d float tensor holding the number of ones still in the mask.
    """
    mask = torch.tril(torch.ones(block_size, block_size))
    # The mask is already 2-D, so it is indexed directly (no squeeze needed).
    mask[:, joined_dim - 1::joined_dim] = 0
    return mask.sum()
main()

tensor(203424.)

# Dropout

In [82]:
def main():
    """Apply ``nn.Dropout(0.1)`` to a random ``10 x 10`` tensor and return the result."""
    dropout = nn.Dropout(0.1)
    sample = torch.randn(10, 10)
    return dropout(sample)
main()

tensor([[-1.3905e+00, -4.6586e-01,  7.7713e-02,  4.4452e-01,  1.4362e-01,
         -1.8046e-01, -1.0321e+00,  1.1369e+00, -4.5345e-01, -9.0033e-01],
        [-1.1170e-03, -1.8191e-02, -8.9934e-01,  2.9823e-01,  1.2260e+00,
          2.1215e-01,  9.9756e-01,  1.5730e+00, -0.0000e+00, -1.6621e+00],
        [ 1.2702e+00, -1.5433e+00,  4.2365e-01,  1.5921e+00,  2.3519e-01,
          4.2690e-01, -9.3740e-01, -4.9528e-01, -3.7724e-01,  1.4033e+00],
        [-1.1512e+00,  3.0650e-01, -9.3674e-02, -7.4107e-01, -6.8297e-02,
         -1.5684e+00,  1.6519e-01, -9.2251e-02,  0.0000e+00,  4.4000e-01],
        [-1.0474e+00,  3.4574e-02,  1.2142e+00, -1.3816e-01,  1.4342e+00,
         -5.6411e-01, -1.2610e+00, -0.0000e+00, -1.1343e+00,  9.3370e-01],
        [-2.0787e+00, -1.8534e-01,  1.2783e-02, -1.6651e-01,  3.5310e-02,
          4.7878e-01, -1.2086e+00, -0.0000e+00, -4.2451e-01, -5.1759e-01],
        [-2.2724e+00,  4.8052e-01,  2.1005e+00, -5.8680e-01, -1.1117e+00,
          8.6624e-01,  4.7306e-0

In [91]:
# Batched matmul broadcasts the leading dims: (4, 1) with (1, 2) -> (4, 2).
x = torch.randn(4, 1, 3, 3)
y = torch.ones(1, 2, 3, 4)
torch.matmul(x, y).shape

torch.Size([4, 2, 3, 4])