# 1. Import Libraries

In [6]:
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

# 2. Define Parameters

In [7]:
@dataclass
class GPTConfig:
    """Hyper-parameters for the GPT model.

    ``hidden_dim`` and ``head_size`` are derived in ``__post_init__`` when
    left as ``None``. Overriding ``n_embed`` or ``n_head`` at construction
    time therefore keeps them consistent. Previously they were frozen to the
    class-level defaults (768 and 64).

    Raises:
        ValueError: if ``head_size`` must be derived but ``hidden_dim`` is
            not divisible by ``n_head``.
    """

    # Maximum text length (context size); also sizes the causal mask.
    block_size: int = 512

    batch_size: int = 12
    n_layer: int = 12
    n_head: int = 12

    # Embedding width. hidden_dim (a.k.a. hidden_size) defaults to n_embed so
    # the model keeps one width throughout; intended to allow tying the
    # embedding weight.
    n_embed: int = 768
    hidden_dim: Optional[int] = None

    dropout: float = 0.1
    # Per-head width; defaults to hidden_dim // n_head.
    head_size: Optional[int] = None

    # Vocabulary size of the official GPT-2 tokenizer.
    vocab_size: int = 50257

    def __post_init__(self):
        # Derive dependent sizes from the actual instance values rather than
        # the class-body defaults.
        if self.hidden_dim is None:
            self.hidden_dim = self.n_embed
        if self.head_size is None:
            if self.hidden_dim % self.n_head != 0:
                raise ValueError(
                    f"hidden_dim ({self.hidden_dim}) must be divisible by "
                    f"n_head ({self.n_head})"
                )
            self.head_size = self.hidden_dim // self.n_head

# 3. Define GPT Structure

### 3.1 Single Head Attention

In [10]:

class SingleHeadAttention(nn.Module):
    """A single causal (decoder-style) self-attention head.

    Projects the input to key/query/value vectors of width
    ``config.head_size``. Each position attends only to itself and to
    earlier positions.

    Args:
        config: object providing ``hidden_dim``, ``head_size``,
            ``block_size`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.head_size = config.head_size
        self.key = nn.Linear(config.hidden_dim, config.head_size)
        self.value = nn.Linear(config.hidden_dim, config.head_size)
        self.query = nn.Linear(config.hidden_dim, config.head_size)

        # Lower-triangular (tril) causal mask of size block_size x block_size:
        # entry (i, j) is 1 when position i may attend to position j <= i.
        # Registered as a buffer rather than a Parameter: it gets no
        # gradients but still follows the module across devices.
        self.register_buffer(
            "attention_mask",
            torch.tril(torch.ones(config.block_size, config.block_size)),
        )

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(batch, seq_len, hidden_dim)``. ``seq_len``
                must not exceed ``block_size``.

        Returns:
            Tensor of shape ``(batch, seq_len, head_size)``.
        """
        # Unpacking also asserts that x is 3-D.
        _, seq_len, _ = x.size()
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Dot-product scores over the last two dims: (batch, seq_len, seq_len).
        # Dividing by sqrt(d_k) keeps scores small so softmax does not
        # saturate (large scores -> vanishing gradients).
        weight = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)

        # Block attention to future positions. -inf becomes exactly 0 after
        # softmax (+inf, as before, produced NaN/incorrect weights).
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0,
            float("-inf"),
        )
        weight = F.softmax(weight, dim=-1)

        # Dropout on the attention probabilities.
        weight = self.dropout(weight)

        # Weighted sum of value vectors.
        return weight @ v

### 3.2 Multi-head attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention composed of ``SingleHeadAttention`` heads.

    Runs ``config.n_head`` independent heads on the same input, concatenates
    their outputs along the last (feature) dimension, then mixes them with a
    linear projection back to ``config.hidden_dim``.

    Args:
        config: object providing ``n_head``, ``head_size``, ``hidden_dim``
            and ``dropout``, plus whatever ``SingleHeadAttention`` reads.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute name `head` kept for compatibility with existing code.
        self.head = nn.ModuleList(
            [SingleHeadAttention(config) for _ in range(config.n_head)]
        )
        # Concatenated width is n_head * head_size; project back to hidden_dim.
        self.proj = nn.Linear(config.n_head * config.head_size, config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Run every head on ``x`` and combine their outputs.

        Args:
            x: input tensor; assumed ``(batch, seq_len, hidden_dim)`` as
                required by ``SingleHeadAttention``.

        Returns:
            Tensor whose last dimension is ``hidden_dim``.
        """
        out = torch.cat([h(x) for h in self.head], dim=-1)
        out = self.proj(out)
        return self.dropout(out)





In [11]:
# 3. feed forward (MLP)
# 4. block
# 5. GPT (embedding, position, norm, mlp, block)