In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters describing the size of the GPT model.

    ``n_kv_head`` is appended after the pre-existing fields so that positional
    construction with the original five arguments keeps working.
    """

    # Maximum sequence length.
    sequence_len: int = 1024
    # Number of entries in the token vocabulary.
    vocab_size: int = 50304
    # Number of layers.
    n_layer: int = 12
    # Number of attention (query) heads.
    n_head: int = 6
    # Embedding width; CausalSelfAttention asserts it is divisible by n_head.
    n_embd: int = 768
    # Number of key/value heads. CausalSelfAttention reads this attribute from
    # the config, so it must exist; the default matches n_head's default.
    n_kv_head: int = 6

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with bias-free linear projections.

    Args:
        config: Object exposing ``n_head`` and ``n_embd`` (and optionally
            ``n_kv_head``). ``n_embd`` must be divisible by ``n_head``.
        layer_idx: Index of the layer this module belongs to; stored as-is.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        # Fall back to n_head when the config does not define n_kv_head, so
        # configs without that field can still build this module.
        # NOTE(review): n_kv_head is only recorded here; the K/V projections
        # below are sized with n_head, exactly as before.
        self.n_kv_head = getattr(config, "n_kv_head", config.n_head)
        self.n_embd = config.n_embd
        # Validate before deriving head_dim so a bad config fails up front.
        assert self.n_embd % self.n_head == 0
        self.head_dim = self.n_embd // self.n_head
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (B, T, C) with C equal to ``n_embd``.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()

        # Project and split the channel dimension into heads: (B, T, H, D).
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)

        # Move heads ahead of time for the attention call: (B, H, T, D).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, H, T, D)

        # (B, H, T, D) -> (B, T, H, D) -> (B, T, C). The transpose leaves a
        # non-contiguous tensor, so .contiguous() is required before .view().
        y = y.transpose(1, 2).contiguous().view(B, T, -1)

        y = self.c_proj(y)  # (B, T, C)

        return y

In [None]:
class MLP(nn.Module):
    """Two-layer bias-free feed-forward block with a squared-ReLU activation.

    The hidden layer is four times as wide as ``config.n_embd``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        # Expansion layer first, then projection back to the embedding width.
        self.c_fc = nn.Linear(width, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, width, bias=False)

    def forward(self, x):
        """Return ``c_proj(relu(c_fc(x)) ** 2)`` for input ``x``."""
        hidden = self.c_fc(x)
        activated = torch.square(F.relu(hidden))
        return self.c_proj(activated)
