In [14]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import numpy
import math

In [60]:
class CausalSelfAttention(nn.Module):

    def __init__(self, input_dim, n_heads, head_embed, dropout_p):
        super().__init__()
        self.input_dim = input_dim
        self.n_heads = n_heads
        self.head_embed = head_embed
        self.dropout_p = dropout_p
        
        self.qkv_proj = nn.Linear(input_dim, 3 * head_embed * n_heads, bias = False)
        
        self.output_proj = nn.Linear(head_embed * n_heads, head_embed * n_heads)

        self.dropout_output = nn.Dropout(dropout_p)

    def forward(self, x):
        B, T, input_dim = x.size()

        q, k, v = self.qkv_proj(x).chunk(3, dim = 2)
        q = q.view(B, T, n_heads, head_embed).transpose(-2, -1) # B, n_heads, T, head_embed
        k = k.view(B, T, n_heads, head_embed).transpose(-2, -1) # B, n_heads, T, head_embed
        v = v.view(B, T, n_heads, head_embed).transpose(-2, -1) # B, n_heads, T, head_embed
        
        output = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p = self.dropout_p, is_causal = True) # B, n_heads, T, head_embed
        output = output.transpose(1,2).contiguous().view(B, T, n_heads * head_embed) 

        output = self.dropout_output(self.output_proj(output))

        return output

class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

def new_gelu(x):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class MLP(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, dropout_p):
        super().__init__()
        self.c_fc = nn.Linear(input_dim, hidden_dim, bias = False)
        self.c_proj = nn.Linear(hidden_dim, output_dim, bias = False)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        x = self.c_fc(x)
        x = new_gelu(x)
        x = self.c_proj(x)
        return self.dropout(x)

class Block(nn.Module):

    def __init__(self, n_heads, head_embed, dropout_p):
        super().__init__()
        self.ln_1 = LayerNorm(n_heads * head_embed, bias = False)
        self.attn = CausalSelfAttention(n_heads * head_embed, n_heads, head_embed, dropout_p)
        self.ln_2 = LayerNorm(n_heads * head_embed, bias = False)
        self.mlp = MLP(n_heads * head_embed, 4 * n_heads * head_embed, n_heads * head_embed, dropout_p)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class DecoderTransformer(nn.Module):

    def __init__(self, seq_len, input_dim, n_blocks, n_heads, head_embed, dropout_p):
        super().__init__()
        self.input_mlp = MLP(input_dim, n_heads * head_embed, n_heads * head_embed, dropout_p)
        self.block_layers = nn.Sequential(*[Block(n_heads, head_embed, dropout_p) for i in range(n_blocks)])
        self.flatten = nn.Flatten(start_dim = -2) # B, T, (n_heads * head_embed) => B, (T * n_heads * head_embed)-
        self.output_mlp = MLP(seq_len * n_heads * head_embed,  seq_len * n_heads * head_embed, input_dim, dropout_p)

    def forward(self, x):
        x = self.input_mlp(x)
        for layer in self.block_layers:
            x = layer(x)
        x = self.flatten(x)
        x = self.output_mlp(x)
        return x

In [61]:
a = torch.randn(10, 20,30)
flatten = nn.Flatten(-2)
flatten(a).size()

torch.Size([10, 600])

In [62]:
B = 30
T = 10
input_dim = 3
x = torch.randn(B, T, input_dim)

n_heads = 5
head_embed = 12
n_blocks = 3
transformer = DecoderTransformer(T, input_dim, n_blocks, n_heads, head_embed, dropout_p = 0.1)

output = transformer(x)

In [63]:
output.size()

torch.Size([30, 3])

In [64]:
output[0]

tensor([-0.0872, -0.0219,  0.0000], grad_fn=<SelectBackward0>)

In [65]:
output

tensor([[-0.0872, -0.0219,  0.0000],
        [ 0.0000,  0.1498, -0.0055],
        [-0.0969,  0.0209,  0.1691],
        [ 0.0090,  0.0396,  0.1347],
        [ 0.0000,  0.1975,  0.0000],
        [-0.0358,  0.1874,  0.0247],
        [-0.1399,  0.1477,  0.3660],
        [-0.1733,  0.0761, -0.0695],
        [ 0.0000,  0.0000,  0.0911],
        [ 0.1030,  0.0042,  0.0493],
        [ 0.0600,  0.0232,  0.1133],
        [ 0.1529,  0.0000,  0.0398],
        [-0.0022, -0.0204,  0.1763],
        [ 0.0077, -0.0000,  0.0000],
        [-0.0228,  0.0279,  0.0882],
        [-0.1414, -0.0218,  0.2705],
        [-0.1848, -0.1446,  0.1293],
        [-0.0097,  0.1370,  0.2138],
        [ 0.0524, -0.0408,  0.0000],
        [-0.1029, -0.0034,  0.0000],
        [-0.1501,  0.0475,  0.1693],
        [ 0.0376,  0.0168,  0.1765],
        [ 0.0251, -0.1684, -0.0251],
        [ 0.1444,  0.0000,  0.1557],
        [-0.2240,  0.1352,  0.0896],
        [-0.0319, -0.1210,  0.1697],
        [ 0.0344, -0.0119, -0.0269],
 