In [1]:
import torch
import torch.nn as nn

In [2]:
# Initial size settings for the notebook; later cells reassign all three names.
seq_len, batch_size, d_model = 100, 32, 512

In [3]:
class Attention(nn.Module):
    """Causal scaled dot-product attention.

    Computes ``softmax(mask(q @ k^T) / sqrt(d)) @ v`` where the mask hides
    every key position that lies after the query position.

    Args:
        d: Size used for the scaling factor (``scale = d ** 0.5``).
    """

    def __init__(self, d):
        super().__init__()
        self.softmax = nn.Softmax(-1)
        self.scale = d ** 0.5

    def forward(self, q, k, v):
        """Apply causally masked attention.

        Args:
            q: Query tensor; its second-to-last dimension is used as the
                sequence length for the mask.
            k: Key tensor; must have the same sequence length as ``q`` so the
                square mask can broadcast over the attention logits.
            v: Value tensor combined with the attention weights.

        Returns:
            The attention-weighted combination of ``v``.
        """
        logits = q @ k.mT
        # Build the mask on the same device as the inputs; a CPU mask would
        # make masked_fill fail when q/k/v live on an accelerator.
        causal_mask = self._get_causal_mask(q.shape[-2], device=q.device)
        # Masking before scaling is safe: -inf divided by a positive scale
        # stays -inf, so masked positions get zero weight after softmax.
        logits = logits.masked_fill(causal_mask, -torch.inf)
        weights = self.softmax(logits / self.scale)
        return weights @ v

    def _get_causal_mask(self, seq_len, device=None):
        """Return a ``(seq_len, seq_len)`` bool mask, True strictly above the diagonal.

        Args:
            seq_len: Side length of the square mask.
            device: Device to allocate the mask on; ``None`` keeps the
                default device (the previous behaviour).
        """
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device), diagonal=1
        ).bool()
        return mask

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from four linear maps and one `Attention` module.

    The input is projected to queries, keys and values, each projection is
    cut into ``n_heads`` equal slices along the last dimension, the slices
    are attended independently, and the per-head results are concatenated
    and passed through the output projection.

    Args:
        d_model: Input and output feature size; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        bias: Whether the four linear projections carry a bias term.
    """

    def __init__(self, d_model, n_heads, bias=True):
        super().__init__()
        self.n_heads = n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        head_dim = d_model // n_heads
        # Keep this construction order: it fixes parameter order and the
        # random initialisation sequence.
        self.w_q = nn.Linear(d_model, d_model, bias=bias)
        self.w_k = nn.Linear(d_model, d_model, bias=bias)
        self.w_v = nn.Linear(d_model, d_model, bias=bias)
        self.attention = Attention(head_dim)
        self.w_o = nn.Linear(d_model, d_model, bias=bias)

    def _split_heads(self, projected):
        """Cut the last dimension into ``n_heads`` chunks and stack them on a new leading axis."""
        return torch.stack(projected.chunk(self.n_heads, -1), 0)

    def forward(self, x):
        """Run self-attention on ``x`` and return a tensor with the same last dimension."""
        queries = self._split_heads(self.w_q(x))
        keys = self._split_heads(self.w_k(x))
        values = self._split_heads(self.w_v(x))
        # The leading axis indexes heads; unbind it and rejoin the heads
        # along the feature dimension.
        per_head = self.attention(queries, keys, values)
        merged = torch.cat(per_head.unbind(0), -1)
        return self.w_o(merged)

    def _get_causal_mask(self, seq_len):
        """Return a square bool mask that is True strictly above the diagonal.

        Not called by ``forward`` in this class.
        """
        upper = torch.ones(seq_len, seq_len).triu(diagonal=1)
        return upper.bool()

In [4]:
# Hyperparameters used by the model-building cells below.
d_model = 512
n_heads = 8
d_ff = 2048

In [5]:
# Built-in attention module: count its trainable parameters.
mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
n_params = sum(
    param.numel() for param in mha.parameters() if param.requires_grad
)
print(n_params)

# Hand-written module: the trainable parameter count should match the one above.
custom_mha = MultiHeadAttention(d_model, n_heads)
n_custom_params = sum(
    param.numel() for param in custom_mha.parameters() if param.requires_grad
)
print(n_custom_params)

1050624
1050624


In [6]:
batch_size = 8
seq_len = 10
# NOTE: the batch dimension of this probe tensor is the literal 2, not batch_size.
x = torch.randn(2, seq_len, d_model)
# The built-in module returns (output, attention weights); only the output is compared.
out1, _ = mha(x, x, x)
out2 = custom_mha(x)
print(out1.shape, out2.shape)

torch.Size([2, 10, 512]) torch.Size([2, 10, 512])


In [7]:
class FeedFowardNetwork(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Expansion layer, then projection back to d_model.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply ``fc1``, the ReLU non-linearity, then ``fc2``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

In [8]:
class TransformerLayer(nn.Module):
    """One post-norm transformer layer: attention and feed-forward sub-layers.

    Each sub-layer's output is added to its input and the sum is passed
    through a LayerNorm.

    Args:
        d_model: Feature size of the layer's input and output.
        n_heads: Number of attention heads passed to `MultiHeadAttention`.
        d_ff: Hidden size passed to `FeedFowardNetwork`.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        # Construction order is kept because it fixes parameter order and
        # the random initialisation sequence.
        self.mha = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedFowardNetwork(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the layer output for input ``x``."""
        # Residual connection around attention, then normalise.
        attended = self.norm1(self.mha(x) + x)
        # Residual connection around the feed-forward block, then normalise.
        return self.norm2(self.ffn(attended) + attended)

In [9]:
# Built-in encoder layer: count its trainable parameters.
transformer_encoder = nn.TransformerEncoderLayer(
    d_model, n_heads, d_ff, batch_first=True
)
n_params = sum(
    param.numel()
    for param in transformer_encoder.parameters()
    if param.requires_grad
)
print(n_params)

# Hand-written layer: the count should match the one printed above.
transformer_layer = TransformerLayer(d_model, n_heads, d_ff)
n_params = sum(
    param.numel()
    for param in transformer_layer.parameters()
    if param.requires_grad
)
print(n_params)

3152384
3152384


In [10]:
class Transformer(nn.Module):
    """Token embedding plus learned positional table followed by a stack of `TransformerLayer`s.

    Args:
        vocab_size: Number of rows in the token embedding table.
        n_layers: Number of `TransformerLayer` modules in the stack.
        d_model: Feature size of embeddings and layers.
        n_heads: Number of attention heads per layer.
        d_ff: Hidden size of each layer's feed-forward block.
        max_len: Number of rows in the positional table, i.e. the longest
            sequence the model accepts.
    """

    def __init__(self, vocab_size, n_layers, d_model, n_heads, d_ff, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional table, initialised with torch.rand.
        self.pe = nn.Parameter(torch.rand((max_len, d_model)))
        self.transformer_layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])

    def forward(self, x):
        """Embed token ids, add positional rows and run every layer in order.

        Args:
            x: Integer token ids; dimension 1 is used as the sequence length.

        Returns:
            The output of the last layer.

        Raises:
            RuntimeError: If the sequence is longer than the positional table.
        """
        x = self.embedding(x)
        seq_len = x.shape[1]
        max_len = self.pe.shape[0]
        # Slicing past the end of the table silently truncates, which used to
        # surface as an opaque broadcast error, or, when max_len == 1, as one
        # positional row silently shared by every token.
        if seq_len > max_len:
            raise RuntimeError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        x = x + self.pe[:seq_len, :]
        for layer in self.transformer_layers:
            x = layer(x)
        return x

In [11]:
# Size settings for the full hand-written model.
vocab_size, n_layers, max_len = 10000, 6, 500
model = Transformer(vocab_size, n_layers, d_model, n_heads, d_ff, max_len)
# Count trainable parameters, including the embedding and positional tables.
n_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(n_params)

24290304


In [15]:
# Build the built-in encoder under its own name so the custom `model` defined
# earlier is not overwritten: the cell that follows calls `model(x)` with
# integer token ids, which only the custom Transformer accepts.
reference_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, n_heads, d_ff, batch_first=True),
    num_layers=n_layers
)
n_params = sum(p.numel() for p in reference_encoder.parameters() if p.requires_grad)
# The built-in encoder has no token or positional tables, so add the sizes of
# the two tables the custom model owns before comparing totals.
n_params_embedding = vocab_size * d_model + max_len * d_model
print(n_params + n_params_embedding)

24290304


In [12]:
# Random token ids in [0, vocab_size) with shape (batch_size, seq_len).
x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
out = model(x)
print(out.shape)

torch.Size([8, 10, 512])
