In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

In [2]:
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        self.c_attn = nn.Linear(config.n_embd, 3*config.n_embd)
        self.n_embd = config.n_embd
        self.register_buffer('bias', torch.tril(torch.ones(config.n_embd, config.n_embd))).view(1, 1, config.n_embd, config.n_embd)

    def forward(self, x):
        B,T,C = x.size()
        qkv = self.c_attn(x)
        q,k,v = qkv.chunk(3, dim=2)
        att = (q @ k.transpose(-2, -1)) * (self.n_embd ** -0.5)
        att = att.masked_fill(self.bias == 0, float('-inf'))
        y = att @ v
        return y

In [3]:
@dataclass
class GPTConfig:
    n_embd:int = 8
    n_heads:int = 8
    n_layer:int = 12
    vocab_size:int = 64
    block_size:int = 8

In [4]:
class MLP(nn.Module):
    def __init__(self, config):
        self.net = nn.Sequential([
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
        ])
    def forward(self, x):
        return self.net(x)

In [5]:
class Block(nn.Module):
    def __init__(self, config):
        self.c_attn = CausalSelfAttention(config)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)
    def forward(self, x):
        x = x + self.c_attn(self.ln_1(x))
        x = x + self.c_proj(self.ln_2(x))
        return x

In [6]:
class GPT(nn.Module):
    def __init__(self, config:GPTConfig):
        self.config = config
    def forward(self, idx):
        

SyntaxError: invalid syntax (1070035621.py, line 3)