# In [1]:
import torch
import torch.nn as nn
import math

# -------------------------------
# GELU Activation (GPT uses this)
# -------------------------------
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``
    element-wise; the output has the same shape as the input.
    """

    def forward(self, x):
        # Inner polynomial of the approximation: x + 0.044715 * x^3.
        inner = x + 0.044715 * (x ** 3)
        # Gate lies in (0, 2); halved below it becomes a smooth 0..1 switch.
        gate = 1 + torch.tanh(math.sqrt(2 / math.pi) * inner)
        return 0.5 * x * gate


# -------------------------------
# Multi-Head Self Attention
# -------------------------------
class SelfAttention(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads; each head works on
            ``d_model // num_heads`` features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Combined QKV projection (like GPT-2)
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        # Lazily built causal-mask cache. It is non-persistent so it never
        # shows up in state_dict(): otherwise checkpoints saved after a
        # forward pass would carry a "mask" key that a fresh module rejects.
        self.register_buffer("mask", None, persistent=False)

    def _causal_mask(self, T, device):
        """Return a (1, 1, T, T) boolean lower-triangular mask on ``device``."""
        cached = self.mask
        # Rebuild only when the cache is missing, too small, or lives on a
        # different device than the input; smaller T is served by slicing.
        if cached is None or cached.size(-1) < T or cached.device != device:
            lower = torch.tril(torch.ones(T, T, device=device)).bool()
            cached = lower.view(1, 1, T, T)
            self.mask = cached
        return cached[:, :, :T, :T]

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, sequence_length, d_model).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # batch, sequence_length, d_model
        B, T, D = x.size()

        # True where a query position may attend to a key position.
        allowed = self._causal_mask(T, x.device)

        qkv = self.qkv(x)  # (B, T, 3*d_model)
        q, k, v = qkv.split(self.d_model, dim=2)

        # Shape into heads: (B, heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores; future positions get -inf so that the
        # softmax assigns them zero weight.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(~allowed, float("-inf"))

        att = torch.softmax(scores, dim=-1)

        # Weighted sum
        out = att @ v  # (B, heads, T, head_dim)
        # Merge the heads back into a single d_model-wide feature axis.
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        return self.output_proj(out)


# -------------------------------
# Feed Forward MLP (GPT uses GELU)
# -------------------------------
class FeedForward(nn.Module):
    """Two-layer MLP: Linear(d_model -> ff_dim), GELU, Linear(ff_dim -> d_model).

    Args:
        d_model: Width of the input and output features.
        ff_dim: Width of the hidden layer.
    """

    def __init__(self, d_model, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(d_model, ff_dim)
        self.gelu = GELU()
        self.fc2 = nn.Linear(ff_dim, d_model)

    def forward(self, x):
        """Expand to ``ff_dim``, apply GELU, and project back to ``d_model``."""
        hidden = self.fc1(x)
        activated = self.gelu(hidden)
        return self.fc2(activated)


# -------------------------------
# Decoder Block
# -------------------------------
class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each with a residual.

    Args:
        d_model: Feature width carried through the block.
        num_heads: Number of attention heads passed to ``SelfAttention``.
        ff_dim: Hidden width passed to ``FeedForward``.
    """

    def __init__(self, d_model, num_heads, ff_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, ff_dim)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual branches."""
        # Attention branch: normalize first, then add the result back.
        attended = self.attn(self.ln1(x))
        x = x + attended

        # MLP branch, same normalize-then-add pattern.
        transformed = self.ff(self.ln2(x))
        return x + transformed


# -------------------------------
# GPT-Style Decoder-only Model
# -------------------------------
class GPT(nn.Module):
    """Decoder-only language model with learned token and position embeddings.

    Args:
        vocab_size: Number of distinct token ids; also the width of the
            output logits.
        d_model: Embedding / hidden width.
        num_heads: Attention heads per decoder block.
        n_layers: Number of stacked decoder blocks.
        max_len: Number of learned positions, i.e. the longest sequence the
            model accepts.
        ff_dim: Hidden width of each block's feed-forward MLP.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, n_layers=6, max_len=1024, ff_dim=2048):
        super().__init__()

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)

        self.blocks = nn.ModuleList([
            DecoderBlock(d_model, num_heads, ff_dim)
        for _ in range(n_layers)])

        self.ln_final = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        """Compute next-token logits.

        Args:
            idx: Integer tensor of token ids with shape (batch, seq_len).

        Returns:
            Tensor of logits with shape (batch, seq_len, vocab_size).

        Raises:
            IndexError: If ``seq_len`` exceeds ``max_len``.
        """
        _, T = idx.size()

        # Fail early with a readable message instead of letting the position
        # lookup fail with an opaque out-of-range error.
        max_len = self.pos_emb.num_embeddings
        if T > max_len:
            raise IndexError(
                f"sequence length {T} exceeds max_len {max_len} of the position embedding"
            )

        tok = self.tok_emb(idx)
        # Position embeddings are (T, d_model) and broadcast over the batch.
        pos = self.pos_emb(torch.arange(T, device=idx.device))

        x = tok + pos  # input embeddings

        for block in self.blocks:
            x = block(x)

        x = self.ln_final(x)

        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits
