In [None]:
# Writing up a decoder-only GPT architecture
# Multihead Attention
# Positional Encoding
# Feedforward
# Decoder
# GPT

In [None]:
import math

import torch
import torch.nn as nn

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused Q/K/V projection.

    Args:
        d_model: Width of the input and output features. Must be divisible
            by ``n_heads``.
        n_heads: Number of attention heads.
        seq_length: Maximum sequence length; sizes the pre-built causal mask.
    """

    def __init__(self, d_model, n_heads, seq_length):
        # nn.Module must be initialised before submodules or buffers are
        # assigned, otherwise the assignments below raise.
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.n_heads = n_heads
        self.d_model = d_model
        self.seq_length = seq_length
        self.d_k = d_model // n_heads

        # A single linear layer produces Q, K and V at once (3 * d_model wide).
        self.W_attn = nn.Linear(d_model, 3 * d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # Lower-triangular matrix of ones: position i may attend to j <= i.
        # Registered as a buffer so it follows the module across devices
        # without being treated as a trainable parameter.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(seq_length, seq_length)).view(1, 1, seq_length, seq_length),
        )

    def split_heads(self, x):
        """Reshape [batch, seq, d_model] into [batch, n_heads, seq, d_k]."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.n_heads, self.d_k).transpose(1, 2)

    def attention(self, q, k, v):
        """Scaled dot-product attention with causal masking.

        Args:
            q, k, v: Tensors of shape [batch, n_heads, seq, d_k].

        Returns:
            Tensor of shape [batch, n_heads, seq, d_k].
        """
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.d_k)

        # Slice the mask to the actual input length, which may be shorter
        # than the maximum the buffer was built for.
        q_len, k_len = q.size(-2), k.size(-2)
        causal_mask = self.causal_mask[:, :, :q_len, :k_len]

        # Masked scores must be -inf so softmax assigns them exactly zero
        # probability; a small finite fill would still leak future tokens.
        attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf"))

        attn_probs = nn.functional.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, v)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape [batch_size, seq_length, d_model].

        Returns:
            Tensor of shape [batch_size, seq_length, d_model].

        Raises:
            ValueError: If the input is longer than the configured maximum
                sequence length.
        """
        batch_size, seq_length, d_model = x.size()
        if seq_length > self.seq_length:
            raise ValueError(
                f"input sequence length {seq_length} exceeds maximum {self.seq_length}"
            )

        # Fused projection [batch, seq, 3 * d_model], split into three
        # [batch, seq, d_model] tensors.
        Q, K, V = self.W_attn(x).split(self.d_model, dim=2)

        # Each becomes [batch_size, n_heads, seq_length, d_k].
        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)

        attn_output = self.attention(Q, K, V)

        # Merge heads back: [batch, n_heads, seq, d_k] -> [batch, seq, d_model].
        # contiguous() is needed because transpose leaves a strided view.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, d_model)
        return self.W_o(attn_output)

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Width of the input and output features.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map a tensor with last dimension d_model to one of the same shape."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)

In [None]:
class DecoderLayer(nn.Module):
    """Pre-norm decoder block: self-attention then feed-forward, each wrapped
    in a residual connection.

    Args:
        d_model: Width of the features.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sublayer.
        seq_length: Maximum sequence length, forwarded to the attention layer.
        dropout: Dropout probability applied to each sublayer's output.
    """

    def __init__(self, d_model, n_heads, d_ff, seq_length, dropout):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, seq_length)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run one decoder block on x; the output has the same shape as x."""
        # LayerNorm is applied before each sublayer (pre-norm), and dropout
        # is applied to both sublayer outputs before the residual add.
        x = x + self.dropout(self.attn(self.norm1(x)))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

In [None]:
class GPT2(nn.Module):
    """Decoder-only language model with learned token and position embeddings.

    Args:
        d_model: Width of the embeddings and hidden states.
        n_heads: Number of attention heads per decoder layer.
        d_ff: Hidden width of each feed-forward sublayer.
        seq_length: Maximum sequence length (size of the position table).
        vocab_size: Number of tokens in the vocabulary.
        dropout: Dropout probability.
        n_layers: Number of stacked decoder layers.
    """

    def __init__(self, d_model, n_heads, d_ff, seq_length, vocab_size, dropout, n_layers):
        # nn.Module must be initialised before submodules can be assigned.
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)
        self.wpe = nn.Embedding(seq_length, d_model)
        self.seq_length = seq_length
        self.decoder = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff, seq_length, dropout) for _ in range(n_layers)]
        )

        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(p=dropout)

        # Special weight initialization: scale down the attention output
        # projections by the depth of the network. nn.Linear names its
        # parameter "weight", so the suffix has to be "W_o.weight".
        for pn, p in self.named_parameters():
            if pn.endswith('W_o.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layers))

    def forward(self, x, mask=None):
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape [batch_size, seq_length].
            mask: Accepted for backward compatibility and ignored; the
                decoder layers take only the hidden states and the causal
                mask is applied inside the attention module.

        Returns:
            Raw logits of shape [batch_size, seq_length, vocab_size].

        Raises:
            ValueError: If the input is longer than the position table.
        """
        seq_length = x.size(1)
        if seq_length > self.seq_length:
            raise ValueError(
                f"input sequence length {seq_length} exceeds maximum {self.seq_length}"
            )

        # Position indices follow the actual input length and live on the
        # same device as the token ids.
        pos = torch.arange(0, seq_length, device=x.device).unsqueeze(0)

        token_embeds = self.wte(x)
        pos_embeds = self.wpe(pos)  # [1, seq, d_model], broadcast over batch
        x = self.dropout(token_embeds + pos_embeds)

        for layer in self.decoder:
            x = layer(x)

        # Return raw logits: torch's CrossEntropyLoss applies log-softmax
        # internally, so no softmax is applied here.
        logits = self.fc(self.norm(x))
        return logits