In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchinfo import summary
from tqdm import tqdm

In [None]:
with open("text8.txt", "r", encoding="utf-8") as f:
    text = f.read()

vocab = " abcdefghijklmnopqrstuvwxyz"
# Token ids 0..len(vocab)-1 are the characters of `vocab`; the next id is the
# <unk> token used for any character outside it.
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
char_to_idx["<unk>"] = len(char_to_idx)
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

# Number of token ids, including <unk> (len(vocab) alone is one short).
print(f"Vocab size: {len(char_to_idx):,}")

In [None]:
# Encode the whole corpus as token ids; characters outside the vocabulary map to
# <unk>. Guarded so re-running this cell does not try to re-encode the tensor it
# already produced (iterating a tensor yields tensors, not characters).
if isinstance(text, str):
    unk_idx = char_to_idx["<unk>"]
    text = torch.tensor([char_to_idx.get(char, unk_idx) for char in text], dtype=torch.long)

In [None]:
seq_len = 128
split = 0.2  # fraction of the corpus (taken from the end) held out for validation

train_len = int((1 - split) * len(text))
train_text, val_text = text[:train_len], text[train_len:]
print(f"Train size: {len(train_text):,}")
print(f"Val size: {len(val_text):,}")


def _chunk(ids: torch.Tensor) -> torch.Tensor:
    """Reshape a 1-D id stream into rows of `seq_len`, dropping the ragged tail."""
    n_rows = len(ids) // seq_len
    return ids[: n_rows * seq_len].view(n_rows, seq_len)


train_text, val_text = _chunk(train_text), _chunk(val_text)
print(f"Train size: {train_text.shape}")
print(f"Val size: {val_text.shape}")


def _make_loader(rows: torch.Tensor, shuffle: bool) -> torch.utils.data.DataLoader:
    """Wrap a (n_rows, seq_len) tensor in a DataLoader yielding 1-tuples `(batch,)`."""
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(rows),
        batch_size=128,
        shuffle=shuffle,
        num_workers=2,
        persistent_workers=True,
        pin_memory=True,
    )


train_loader = _make_loader(train_text, shuffle=True)
val_loader = _make_loader(val_text, shuffle=False)

In [None]:
def encode(text: str) -> torch.Tensor:
    """Convert a string into a 1-D LongTensor of token ids.

    The string is lower-cased first; any character missing from `char_to_idx`
    is mapped to the ``<unk>`` id.
    """
    token_ids = []
    for ch in text.lower():
        if ch in char_to_idx:
            token_ids.append(char_to_idx[ch])
        else:
            token_ids.append(char_to_idx["<unk>"])
    return torch.tensor(token_ids, dtype=torch.long)

def decode(tensor: torch.Tensor) -> str:
    """Convert a tensor of token ids back into a string.

    Accepts a 1-D tensor, or a 0-d tensor holding a single id (which is what
    ``generate(max_len=0).squeeze()`` returns). The ``<unk>`` id is rendered as
    the literal text "<unk>".
    """
    ids = tensor.tolist()
    if not isinstance(ids, list):  # 0-d tensor: tolist() returns a bare Python number
        ids = [ids]
    return "".join(idx_to_char[idx] for idx in ids)

# Round-trip demo: "," and "!" are outside the vocabulary and come back as <unk>.
encoded = encode("Hello, World!")
decoded = decode(encoded)
print(encoded, decoded, sep="\n")

In [None]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using device: {device}")

@torch.no_grad()
def evaluate(model, data_loader):
    """Return the mean per-token loss of `model` over `data_loader`.

    Each batch must be a 1-tuple ``(inputs,)`` of token ids shaped (B, T). The
    model predicts ``inputs[:, 1:]`` from ``inputs[:, :-1]`` and is scored
    with ``model.loss``. Batch losses are weighted by their number of target
    tokens, so a short final batch does not skew the average. The model's
    train/eval mode is restored before returning.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        pbar = tqdm(data_loader, desc="Evaluating", leave=False)
        for (inputs,) in pbar:
            inputs = inputs.to(device)
            x = inputs[:, :-1]
            y = inputs[:, 1:]

            batch_loss = model.loss(model(x), y).item()
            # assumes model.loss is a mean over every target (no ignored targets in y)
            n_tokens = y.numel()
            total_loss += batch_loss * n_tokens
            total_tokens += n_tokens

            pbar.set_postfix(loss=f"{batch_loss:.4f}")
    finally:
        model.train(was_training)

    if total_tokens == 0:
        raise ValueError("evaluate() got a data_loader that yielded no batches")
    return total_loss / total_tokens

In [None]:
def rotary_position_embedding(dim, seq_len, base=10000):
    """Build the rotary-position (RoPE) sin/cos table.

    Returns a float tensor of shape (seq_len, dim) whose row ``p`` is
    ``[sin(p*w_0), cos(p*w_0), sin(p*w_1), cos(p*w_1), ...]`` with
    ``w_i = base ** (-2i / dim)``. `dim` must be even, otherwise the final
    reshape fails.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, dtype=torch.float)
    angles = torch.outer(positions, inv_freq)  # (seq_len, dim // 2)

    # Interleave sin/cos per frequency: (seq_len, dim // 2, 2) -> (seq_len, dim)
    table = torch.stack((angles.sin(), angles.cos()), dim=-1)
    return table.reshape(seq_len, dim)

def apply_rotary_pos_emb(q, k, rotary_emb):
    """Rotate query/key channel pairs by position-dependent angles (RoPE).

    Args:
        q, k: tensors whose last two axes are (T, head_dim); head_dim must be even.
        rotary_emb: table from `rotary_position_embedding`, shape (>= T, head_dim),
            each row laid out as [sin, cos, sin, cos, ...] per channel pair.
            Only its first T rows are used (T taken from q).

    Returns:
        (q_rot, k_rot), same shapes as q and k. Each adjacent channel pair
        (x[2i], x[2i+1]) at position p is rotated by that row's angle:
            x'[2i]   = x[2i]   * cos - x[2i+1] * sin
            x'[2i+1] = x[2i+1] * cos + x[2i]   * sin
        This is a true rotation: vector norms are preserved, and q.k then
        depends only on the relative offset between positions.
    """
    seq_len = q.size(-2)
    rotary_emb = rotary_emb[:seq_len, :]  # Trim to the sequence length if necessary
    cos_pos = rotary_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = rotary_emb[..., 0::2].repeat_interleave(2, dim=-1)

    def rotate_pairs(x):
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...): a 90-degree turn of
        # every channel pair. (torch.roll mixed neighbouring pairs and lost the sign.)
        return torch.stack((-x[..., 1::2], x[..., 0::2]), dim=-1).flatten(-2)

    q_rot = q * cos_pos + rotate_pairs(q) * sin_pos
    k_rot = k * cos_pos + rotate_pairs(k) * sin_pos

    return q_rot, k_rot

class SelfAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Args:
        n_embd: model width C; must be divisible by `n_head`, and the per-head
            width C // n_head must be even (RoPE rotates channel pairs).
        n_head: number of attention heads.
        attn_dropout: dropout probability on the attention weights; applied
            only while the module is in training mode.
        is_causal: passed to `scaled_dot_product_attention` as its causal-mask flag.
    """

    def __init__(self, n_embd: int, n_head: int, attn_dropout: float = 0.0, is_causal: bool = True):
        super().__init__()
        assert n_embd % n_head == 0, f"Embedding dimension {n_embd} should be divisible by number of heads {n_head}"

        self.c_attn = nn.Linear(n_embd, n_embd * 3)
        self.c_proj = nn.Linear(n_embd, n_embd)

        self.n_head = n_head
        self.n_embd = n_embd
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal

        # Lazily built RoPE table (rows = positions). It is a plain attribute, not
        # a buffer, so it is not moved by .to(); forward() rebuilds it when a longer
        # sequence or a different device shows up.
        self.rotary_emb = None

    def forward(self, x) -> torch.Tensor:
        """Attend over x of shape (B, T, C) with C == n_embd; returns (B, T, C)."""
        B, T, C = x.size()
        head_dim = C // self.n_head

        if (
            self.rotary_emb is None
            or self.rotary_emb.size(0) < T
            or self.rotary_emb.device != x.device
        ):
            self.rotary_emb = rotary_position_embedding(head_dim, T, base=10000).to(x.device)

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        # (B, T, C) -> (B, T, n_head, head_dim) -> (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k = apply_rotary_pos_emb(q, k, self.rotary_emb)

        # SDPA applies dropout whenever dropout_p > 0; it does not know about
        # self.training, so disable it explicitly outside training.
        dropout_p = self.attn_dropout if self.training else 0.0
        y = (
            F.scaled_dot_product_attention(
                q, k, v, is_causal=self.is_causal, dropout_p=dropout_p
            )
            .transpose(1, 2)
            .contiguous()
            .view(B, T, C)
        )

        # output projection
        return self.c_proj(y)

class MultiHeadAttention(nn.Module):
    """Run `num_heads` independent SelfAttention modules and mix their outputs.

    Each SelfAttention is itself a `num_heads`-way attention over the full
    `emb_dim`. Their outputs are concatenated (width emb_dim * num_heads) and
    `fc` projects the result back to emb_dim.

    NOTE(review): this nests multi-head attention inside num_heads copies
    (num_heads**2 heads in total, plus two output projections). Confirm this is
    intended rather than a single SelfAttention(emb_dim, num_heads).
    """

    def __init__(self, num_heads, emb_dim, atten_dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        self.emb_dim = emb_dim

        attn_modules = []
        for _ in range(num_heads):
            attn_modules.append(SelfAttention(emb_dim, num_heads, atten_dropout))
        self.heads = nn.ModuleList(attn_modules)

        self.fc = nn.Linear(emb_dim * num_heads, emb_dim)

    def forward(self, x):
        """Apply every head to x of shape (B, T, emb_dim) and project back to emb_dim."""
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.fc(concatenated)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: residual attention, then residual SiLU MLP.

    Args:
        emb_dim: model width.
        num_heads: head count passed to MultiHeadAttention.
        m: hidden-width multiplier of the MLP (hidden size m * emb_dim).
    """

    def __init__(self, emb_dim, num_heads, m=4):
        super().__init__()
        # Sub-modules are registered in the order attention, ln1, ln2, fc.
        self.attention = nn.Sequential(MultiHeadAttention(num_heads, emb_dim))
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

        hidden = m * emb_dim
        self.fc = nn.Sequential(
            nn.Linear(emb_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, emb_dim),
        )

    def forward(self, x):
        """Return x after the attention and MLP residual updates (same shape as x)."""
        h = x + self.attention(self.ln1(x))
        return h + self.fc(self.ln2(h))

class Model(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.tok_emb = nn.Embedding(len(vocab), emb_dim)
        self.pos_emb = nn.Parameter(torch.randn(seq_len, emb_dim))

        self.blocks = nn.Sequential(
            TransformerBlock(emb_dim, num_heads=4, m=4),
            TransformerBlock(emb_dim, num_heads=4, m=4),
            TransformerBlock(emb_dim, num_heads=4, m=4),
        )

        self.lm_head = nn.Linear(emb_dim, len(vocab), bias=False)
        
        # init weights
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x):
        if x.shape[1] > seq_len:
            x = x[:, -seq_len:]

        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb[:tok_emb.shape[1]]
        
        x = tok_emb + pos_emb

        for block in self.blocks:
            x = block(x)

        x = self.lm_head(x)
        return x
    
    def loss(self, y_pred, y):
        return torch.nn.functional.cross_entropy(y_pred.reshape(-1, y_pred.shape[-1]), y.reshape(-1), ignore_index=-1)
    
    @torch.no_grad()
    def generate(self, start: list[int] | torch.Tensor | None = None, max_len: int = 100, temperature: float = 1.0, top_k: int = 0):
        if start is None:
            start = torch.randint(len(vocab), (1, 1), device=device)
        elif isinstance(start, list):
            start = torch.tensor(start, dtype=torch.long, device=device).unsqueeze(0)
        x = start

        for _ in tqdm(range(max_len)):
            y_pred = self(x)
            y_pred = y_pred[:, -1, :] / temperature
            if top_k > 0:
                y_pred = torch.topk(y_pred, top_k, dim=-1).values
            next_char = torch.multinomial(torch.nn.functional.softmax(y_pred, dim=-1), 1)
            x = torch.cat([x, next_char], dim=1)
        
        return x
    
# Instantiate the model and sanity-check it before training: sample 128
# characters from the untrained network, then show a layer summary for a
# (64, 64) batch of token ids.
model = Model(emb_dim=32).to(device)
print(decode(model.generate(max_len=128).squeeze()))
summary(model, input_size=(64, 64), dtypes=[torch.long], device=device, depth=2)

In [None]:
# Plain Adam over all model parameters, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# `val_loss` is the latest validation loss, shown in the progress bar. Reuse the
# value from a previous run of this cell if there is one; otherwise start at inf.
val_loss = globals().get("val_loss", float("inf"))

num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for (inputs,) in pbar:
        inputs = inputs.to(device)

        # Next-character prediction: position t of `x` is trained to predict y[:, t].
        x = inputs[:, :-1]
        y = inputs[:, 1:]

        optimizer.zero_grad()
        # Same objective that `evaluate` reports, so train and val losses are comparable.
        loss = model.loss(model(x), y)
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item(), val_loss=val_loss)

    val_loss = evaluate(model, val_loader)

In [None]:
# Final validation loss of the trained model (displayed as the cell output).
evaluate(model=model, data_loader=val_loader)

In [None]:
# Sample 128 characters after a random start character, keeping the len(vocab)
# highest-scoring tokens at each step, and display the decoded text.
decode(model.generate(max_len=128, top_k=len(vocab)).reshape(-1))