In [None]:
# LLM 
# Autoregressive Generierung 
# Next Token Prediction 
# 1. Schritt Decoder [ <bos> ] -> "ich"
# 2. Schritt Decoder [ (<bos>,"ich")] -> "bin" 
# 3. Schritt Decoder [ (<bos>,"ich", "bin")] -> "ein"
# 4. Schritt Decoder [ (<bos>,"ich", "bin", "ein")] -> "Junge"
# 5. Schritt Decoder [ (<bos>,"ich", "bin", "ein", "Junge")] -> "<eos>"
# Masked Language Modelling
# BERT (Encoder-Only)
# Ich gehe in die ________ und trinke Bier. 
# ["<bos>", "Ich", ....., <mask>, "und", ...., "<eos>"] -> "Kneipe"

# (<bos>,"ich", "bin", "ein", <mask>) -> "Junge"

In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
import time

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask.

    A single linear layer produces queries, keys and values in one pass. The
    result is cut into ``n_head`` heads, attention is computed per head, the
    heads are merged again and passed through an output projection.

    Args:
        config: Object exposing ``n_embd`` (embedding width) and ``n_head``
            (number of heads); ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        # The embedding is sliced into n_head equal parts and re-joined at the
        # end, so the width has to divide evenly.
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # Fused projection: queries, keys and values side by side.
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd)
        # Output projection applied after the heads are merged.
        self.c_proj = nn.Linear(self.n_embd, self.n_embd)

    def forward(self, x):
        """Run causal self-attention over ``x``.

        Args:
            x: Tensor of shape (B, T, C): batch size, sequence length and
                embedding width (``n_embd``).

        Returns:
            Tensor of shape (B, T, C).
        """
        batch, seq_len, width = x.shape
        # hs = C // nh, e.g. C = 768 and n_head = 12 give a head size of 64.
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = (to_heads(part) for part in self.c_attn(x).chunk(3, dim=-1))

        # Fused attention kernel; is_causal=True applies the look-behind mask.
        attended = F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # (B, nh, T, hs) -> (B, T, C): put the heads back next to each other.
        merged = attended.transpose(1, 2).reshape(batch, seq_len, width)
        return self.c_proj(merged)



class MLP(nn.Module):
    """Position-wise feed-forward network: expand, GELU, project back.

    Args:
        config: Object exposing ``n_embd``; the hidden layer is four times
            that width.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden)    # up-projection
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width)  # down-projection

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """One pre-norm transformer block: causal self-attention, then an MLP.

    Each sub-layer is applied to a layer-normalised copy of its input and the
    result is added back onto the residual stream.

    Args:
        config: Configuration object exposing ``n_embd``; it is also passed
            on unchanged to ``CausalSelfAttention`` and ``MLP``.
    """

    def __init__(self, config):
        # `config` has to be a parameter: the body reads it, and the block is
        # instantiated as `Block(config)`.
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)  # causal self-attention = masked self-attention
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply attention and MLP sub-layers with residual connections.

        Args:
            x: Input tensor; its last dimension is normalised by the
                ``LayerNorm`` layers, so it must equal ``n_embd``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.attn(self.ln_1(x))  # pre-norm: normalise before the sub-layer
        x = x + self.mlp(self.ln_2(x))
        return x



@dataclass
class GPTConfig:
    """Hyperparameters of the GPT model; the defaults are GPT-2 sized."""

    # Maximum number of positions, i.e. the length of the context window.
    block_size: int = 1_024
    # Number of tokens in the vocabulary (gpt2 tokenizer).
    vocab_size: int = 50_257
    # Number of transformer blocks.
    n_layer: int = 12
    # Number of attention heads per block.
    n_head: int = 12
    # Embedding width.
    n_embd: int = 768


class GPT(nn.Module):
    """GPT-2 style language model (Generative Pretrained Transformer).

    Token and position embeddings are summed, passed through ``n_layer``
    transformer blocks and a final layer norm, and projected to vocabulary
    logits by a bias-free linear head.

    Args:
        config: Configuration object exposing ``vocab_size``, ``block_size``,
            ``n_layer`` and ``n_embd``; it is also handed to every ``Block``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),  # one embedding per context position
            # nn.ModuleList registers every block as a submodule and is
            # iterable; plain nn.Module accepts no positional arguments.
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        # Maps the final hidden states to one logit per vocabulary entry.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids with shape (B, T), where B is the
                batch size and T the sequence length; T must not exceed
                ``config.block_size``.
            targets: Optional tensor of target token ids. It is flattened and
                compared against the B*T rows of flattened logits, so it is
                expected to hold one id per position (presumably shape (B, T)).

        Returns:
            Tuple ``(logits, loss)``. ``logits`` has shape
            (B, T, vocab_size); ``loss`` is the cross-entropy over all
            positions, or ``None`` when ``targets`` is not given.
        """
        B, T = idx.size()
        assert T <= self.config.block_size, (
            f"sequence length {T} exceeds block_size {self.config.block_size}"
        )
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)  # (T, n_embd)
        tok_emb = self.transformer.wte(idx)  # (B, T, n_embd)
        x = tok_emb + pos_emb                # position embeddings broadcast over the batch

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss