<a href="https://colab.research.google.com/github/grahamstelzer/fundamentals/blob/main/transformer_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# setup
import torch
import torch.nn as nn
import math

In [None]:
# Seed torch's global RNG so the weight initialization and random inputs below are reproducible.
SEED = 0
torch.manual_seed(SEED)
print(f"pytorch version:  {torch.__version__}")

In [None]:
# "config" - toy sizes for the smoke test below
batch_size, seq_len = 2, 4  # sequences per batch, positions per sequence
input_dim = 16              # feature size of the raw input (before Encoder.input_proj)
d_model = 8                 # model width; must be divisible by num_heads
num_heads = 2               # attention heads -> head_dim = d_model // num_heads = 4
d_ff = 32                   # hidden size of the position-wise feed-forward layer

In [None]:
# "base" mha
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built on scaled dot-product attention.

    Q, K and V are all projected from the same input ``x``. The ``d_model``
    feature dimension is split evenly into ``num_heads`` heads of size
    ``head_dim = d_model // num_heads``.

    Args:
        d_model: size of the last dimension of the input and the output.
        num_heads: number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()  # call nn.Module constructor
        # Raise instead of assert so the check survives `python -O`; an uneven
        # split would make the per-head reshape in forward() impossible.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # learned projections for queries, keys, values and the output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (B, S, d_model).
            mask: optional boolean tensor broadcastable to (B, num_heads, S, S),
                e.g. an (S, S) causal mask. ``True`` marks key positions a query
                may attend to; ``False`` positions are excluded. ``None``
                (default) lets every query attend to every key. A query row
                whose keys are all masked produces NaN.

        Returns:
            tensor of shape (B, S, d_model).
        """
        B, S, _ = x.shape

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        def split_heads(t):
            # (B, S, d_model) -> (B, S, H, head_dim) -> (B, H, S, head_dim)
            return t.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)

        # Scaled dot product. The scale is sqrt(head_dim), the size of each
        # head's key vectors, not sqrt(d_model).
        scores = Qh @ Kh.transpose(-2, -1) / math.sqrt(self.head_dim)  # (B, H, S, S)

        if mask is not None:
            # -inf before softmax gives excluded keys exactly zero weight
            scores = scores.masked_fill(~mask.bool(), float("-inf"))

        attn = scores.softmax(dim=-1)  # normalize over keys
        out = attn @ Vh  # (B, H, S, head_dim)

        # merge heads: (B, H, S, head_dim) -> (B, S, H, head_dim) -> (B, S, d_model);
        # .contiguous() is required because view() cannot follow a transpose
        out = out.transpose(1, 2).contiguous().view(B, S, self.d_model)

        return self.W_o(out)  # final linear transform




In [None]:
# feed forward (position-wise)
class FFN(nn.Module):
    """Position-wise feed-forward sublayer: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model).

    nn.Linear acts on the last dimension only, so the same two-layer MLP is
    applied independently at every position and leading dimensions pass through.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # built in this order so parameters draw from the RNG in the same sequence
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        contract = nn.Linear(d_ff, d_model)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Map (..., d_model) -> (..., d_model)."""
        return self.net(x)


In [None]:
# "transformer block" - this could also be called an "encoder block".
#   the common convention nowadays is a single transformer block with optional
#   cross-attention and an attention mask, reused for both the "encoder" and the "decoder".

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: a self-attention sublayer followed by a feed-forward sublayer.

    Each sublayer computes ``x + sublayer(LayerNorm(x))``. The LayerNorm is
    applied to the sublayer's *input* (pre-norm), not after the residual add
    (post-norm).
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # Registration order fixes the state_dict key order and the order in
        # which submodules initialize their parameters.
        self.norm1 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model, d_ff)

    def forward(self, x):
        """Apply both sublayers; ``x`` is (B, S, d_model), the 3-D shape MultiHeadAttention unpacks."""
        # attention sublayer + residual
        h = x + self.mha(self.norm1(x))
        # feed-forward sublayer + residual
        return h + self.ffn(self.norm2(h))


In [None]:
# "encoder" (not an encoder block): a stack of TransformerBlocks
class Encoder(nn.Module):
    """Project input features to ``d_model``, then apply ``num_layers`` TransformerBlocks in order.

    Args:
        input_dim: size of the last dimension of the raw input.
        d_model: model width used by every block.
        num_heads: attention heads per block (must divide ``d_model``).
        d_ff: hidden size of each block's feed-forward sublayer.
        num_layers: number of stacked blocks (default 2).
    """

    def __init__(self, input_dim, d_model, num_heads, d_ff, num_layers=2):
        super().__init__()

        # The blocks' residual adds and LayerNorms operate at width d_model, so
        # raw features of size input_dim are first mapped into that space.
        self.input_proj = nn.Linear(input_dim, d_model)

        self.layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Map (B, S, input_dim) -> (B, S, d_model) by projecting, then running each block in order."""
        x = self.input_proj(x)
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
# "main" to test just the model code
x = torch.randn(batch_size, seq_len, input_dim)
encoder = Encoder(input_dim, d_model, num_heads, d_ff, num_layers=2)

output = encoder(x)
# input_proj maps input_dim -> d_model, and every block preserves that width
assert output.shape == (batch_size, seq_len, d_model)
print(output.shape)
print(output)

# `output` is a non-leaf tensor (computed by the encoder), so `output.grad` is never
# populated, and reading it only triggers a PyTorch warning. To check that gradients
# flow, backprop a scalar loss and inspect a leaf parameter's gradient instead.
output.sum().backward()
print(encoder.input_proj.weight.grad)