# In [None]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F  
import math
import os
import time
import inspect


class CasualSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    The class name keeps its original spelling ("Casual") so existing code that
    references it keeps working; the layer implements *causal* attention: each
    position may only attend to itself and earlier positions.

    Args:
        config: object exposing ``n_embd`` (embedding size), ``n_head``
            (number of attention heads) and ``block_size`` (maximum sequence
            length supported by the causal mask). ``n_embd`` must be divisible
            by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        # Every head receives an equal slice of the embedding dimension.
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, computed in one batched matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular causal mask of shape (1, 1, block_size, block_size).
        # Not really a bias, more of a mask, but named "bias" following the
        # OpenAI/HF naming (this name is also the state_dict key).
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, C) with C == n_embd and T <= block_size.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
        # Scaled dot-product scores:
        # (B, n_head, T, head_dim) @ (B, n_head, head_dim, T) -> (B, n_head, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Block attention to future positions using the (1, 1, T, T) slice of
        # the lower-triangular mask; -inf becomes 0 after softmax.
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        # Normalize over the key dimension; the shape stays (B, n_head, T, T).
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, n_head, T, head_dim)
        # Re-assemble the heads side by side: (B, T, C). contiguous() is
        # required because transpose() returns a non-contiguous view.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
    
class MLP(nn.Module):
    """Feed-forward block: Linear (n_embd -> 4*n_embd) -> GELU(tanh) -> Linear back.

    Args:
        config: object exposing ``n_embd``, the input/output feature size.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        # tanh approximation of GELU
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` (last dimension == n_embd).

        Returns a tensor with the same shape as ``x``.
        """
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

@dataclass
class GPTConfig:
    """Model hyperparameters.

    The defaults describe a 12-layer, 12-head model with a 768-dimensional
    embedding, a 1024-token context window and a 50257-token vocabulary.
    """

    # maximum sequence length (context window)
    block_size: int = 1_024
    # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    vocab_size: int = 50_257
    # number of layers
    n_layer: int = 12
    # number of attention heads
    n_head: int = 12
    # embedding dimension
    n_embd: int = 768
