# 1. Import Libraries

In [6]:
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
# NOTE(review): BibTeXLexer is not referenced in the visible cells; it looks
# like an accidental IDE auto-import that adds a pygments dependency.
from pygments.lexers.bibtex import BibTeXLexer

# 2. Define Parameters

In [7]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model.

    ``hidden_dim`` and ``head_size`` may be left as ``None``; they are then
    derived in ``__post_init__`` (``hidden_dim = n_embed`` and
    ``head_size = hidden_dim // n_head``). This keeps every dimension
    consistent when ``n_embed`` or ``n_head`` is overridden. Explicitly
    passed values are kept unchanged.

    Raises:
        ValueError: if ``head_size`` must be derived but ``hidden_dim`` is
            not divisible by ``n_head``.
    """

    # Maximum context length (number of token positions).
    block_size: int = 512

    # Sequences per batch (presumably consumed by the training/data code).
    batch_size: int = 12
    n_layer: int = 12
    n_head: int = 12

    # Token/position embeddings are n_embed wide while the transformer
    # blocks use hidden_dim, so the two must match; hidden_dim therefore
    # defaults to n_embed.
    n_embed: int = 768
    hidden_dim: Optional[int] = None

    dropout: float = 0.1
    # Width of one attention head; defaults to hidden_dim // n_head so the
    # concatenated heads are exactly hidden_dim wide.
    head_size: Optional[int] = None

    # Vocabulary size of the official GPT-2 tokenizer.
    vocab_size: int = 50257

    def __post_init__(self):
        if self.hidden_dim is None:
            self.hidden_dim = self.n_embed
        if self.head_size is None:
            if self.hidden_dim % self.n_head != 0:
                raise ValueError(
                    f"hidden_dim ({self.hidden_dim}) must be divisible by "
                    f"n_head ({self.n_head})"
                )
            self.head_size = self.hidden_dim // self.n_head

# 3. Define GPT Structure

### 3.1 Single-Head Attention

In [10]:

class SingleHeadAttention(nn.Module):
    """One causal (masked) self-attention head.

    Projects the input into query/key/value vectors of width
    ``config.head_size``. Each position attends only to itself and to
    earlier positions.

    Args:
        config: object providing ``hidden_dim``, ``head_size``,
            ``block_size`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.head_size = config.head_size
        self.key = nn.Linear(config.hidden_dim, config.head_size)
        self.value = nn.Linear(config.hidden_dim, config.head_size)
        self.query = nn.Linear(config.hidden_dim, config.head_size)

        # Lower-triangular (block_size x block_size) causal mask. Entry
        # [i, j] is 1 when position i may attend to position j (j <= i).
        # It is registered as a buffer, so it follows the module across
        # devices and is stored in the state dict without being a trainable
        # parameter.
        self.register_buffer(
            "attention_mask",
            torch.tril(torch.ones(config.block_size, config.block_size)),
        )

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: tensor of shape (batch, seq_len, hidden_dim) with
                ``seq_len <= block_size``.

        Returns:
            Tensor of shape (batch, seq_len, head_size).

        Raises:
            ValueError: if ``seq_len`` exceeds the mask's ``block_size``.
        """
        _, seq_len, _ = x.size()
        block_size = self.attention_mask.size(0)
        if seq_len > block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {block_size}"
            )

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Scaled dot-product scores with shape (batch, seq_len, seq_len).
        # Dividing by sqrt(d_k) stops the logits from growing with
        # head_size. Large logits would saturate the softmax and shrink its
        # gradients.
        weight = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)

        # Future positions are set to -inf so that softmax gives them
        # exactly zero probability.
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0,
            float("-inf"),
        )
        weight = F.softmax(weight, dim=-1)

        # Dropout on the attention probabilities (regularization).
        weight = self.dropout(weight)

        # Attention-weighted sum of value vectors: (batch, seq_len, head_size).
        return weight @ v

### 3.2 Multi-head Attention
Four kinds of weight tensors: query, key and value (q, k, v; one set per head) plus the output projection (proj).

In [12]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention assembled from independent single heads.

    Each of the ``config.n_head`` heads processes the same input. Their
    outputs are concatenated along the last axis, then passed through a
    ``hidden_dim -> hidden_dim`` linear projection and dropout. The
    concatenated width therefore has to equal ``hidden_dim``.
    """

    def __init__(self, config):
        super().__init__()
        head_count = config.n_head
        self.heads = nn.ModuleList(
            SingleHeadAttention(config) for _ in range(head_count)
        )
        width = config.hidden_dim
        self.proj = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Run every head on ``x`` and fuse their outputs."""
        per_head = [head(x) for head in self.heads]
        fused = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(fused))

### 3.3 Feed Forward

In [19]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * hidden_dim, apply GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        # Classic 4x expansion; SwiGLU-style variants use roughly 8/3x.
        inner = 4 * width
        layers = [
            nn.Linear(width, inner),  # up-projection
            nn.GELU(),  # activation
            nn.Linear(inner, width),  # down-projection
            nn.Dropout(config.dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last (feature) dimension of ``x``."""
        result = self.net(x)
        return result

### 3.4 Block

In [21]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    The input is normalized before the attention sub-layer and again before
    the feed-forward sub-layer. Each sub-layer's output is added back to its
    input through a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        dim = config.hidden_dim
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``x`` updated by the attention and feed-forward residuals."""
        attended = x + self.att(self.ln1(x))
        return attended + self.ffn(self.ln2(attended))

### 3.5 GPT (embedding, position, norm, mlp, block)

In [27]:
class GPT(nn.Module):
    """Decoder-only GPT language model.

    Token embeddings and learned absolute position embeddings are summed and
    fed through ``config.n_layer`` transformer blocks. A final LayerNorm and
    a bias-free output head follow; the head's weight is tied to the
    token-embedding table.

    Args:
        config: object providing ``vocab_size``, ``n_embed``,
            ``block_size`` and ``n_layer``, plus whatever ``Block`` reads.
    """

    def __init__(self, config):
        super().__init__()

        # Components: embedding, position, norm, mlp, block.
        # Possible modern upgrades: learned positions -> RoPE,
        # LayerNorm -> RMSNorm, GELU MLP -> SwiGLU, MHA -> GQA.
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embed)
        # One n_embed-wide vector per position, so it can be added to the
        # token embedding.
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embed)
        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layer)]
        )
        self.ln_final = nn.LayerNorm(config.n_embed)

        # No bias: the logits go straight into softmax / cross-entropy.
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # Weight tying, commonly used by small models to save parameters.
        # nn.Linear(n_embed, vocab_size).weight has shape
        # (vocab_size, n_embed), the same as the embedding table, so both
        # modules can share one tensor.
        self.token_embedding_table.weight = self.lm_head.weight

        # Re-initialize every Linear/Embedding weight from N(0, 0.02).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (batch, seq_len), with
                ``seq_len <= block_size``.
            targets: optional integer target token ids with the same shape
                as ``idx``.

        Returns:
            ``(logits, loss)``. ``logits`` has shape
            (batch, seq_len, vocab_size). ``loss`` is the mean cross-entropy
            over all positions, or ``None`` when ``targets`` is not given.

        Raises:
            ValueError: if ``seq_len`` exceeds ``block_size``.
        """
        batch, seq_len = idx.size()
        max_len = self.position_embedding_table.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {max_len}"
            )

        # (batch, seq_len, n_embed)
        token_emb = self.token_embedding_table(idx)
        # (seq_len, n_embed); positions are created on idx's device.
        pos_emb = self.position_embedding_table(
            torch.arange(seq_len, device=idx.device)
        )

        # The same position vectors are broadcast-added to every sequence.
        x = token_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_final(x)

        # (batch, seq_len, vocab_size)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits and (N,) class indices,
            # so flatten the batch and sequence dimensions together.
            vocab_size = logits.size(-1)
            loss = F.cross_entropy(
                logits.reshape(batch * seq_len, vocab_size),
                targets.reshape(batch * seq_len),
            )

        return logits, loss

# 4. Build Input Dataset