<a href="https://colab.research.google.com/github/elrrowwe/robust-principal-component-attention/blob/main/rpc_attention_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook contains my experiments with integrating robust principal component attention (RPCA) introduced in [here](https://arxiv.org/pdf/2406.13762) with the Karpathy from-scratch implementation of transformers.

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from sklearn.model_selection import train_test_split

In [17]:
# defining some hyperparameters
batch_size = 16  # how many sequences (sentences of some length) are processed at once
block_size = 32  # how long the aforementioned sequences are (context length)
train_iters = 2000
eval_interval = 100
lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
eval_iters = 200
n_embed = 64  # the number of embedding dimensions (the dimensionality of embedding vectors)
n_head = 4  # the number of heads in the multi-head attention layer
n_layer = 6
dropout = 0.0  # the dropout rate

cuda


In [3]:
oliver_twist = open('pg730.txt', 'r', encoding='utf8').read()

In [4]:
# printing the number of characters (tokens) in the dataset
print(len(list(oliver_twist)))

1079372


In [5]:
# getting all the characters, like in the n-gram model
chars = sorted(list(set(oliver_twist)))
vocab_size = len(chars)

In [6]:
# a simple (de-)tokenizer
stoi = {s: i for i,s in enumerate(chars)}
itos = {i: s for i,s in enumerate(chars)}

In [7]:
encode = lambda s: [stoi[ch] for ch in s]  # tokenize some characters
decode = lambda i: ''.join([itos[num] for num in i])  # detokenize some integers

In [8]:
# tokenizing the entire data set
enc = torch.tensor(encode(oliver_twist), dtype=torch.long, device=device)

In [9]:
# splitting the text into train, test portions
train, test = train_test_split(enc, shuffle=False, test_size=0.1)

In [10]:
def get_batch(split):
    """
    Get a batch of training examples (sentences) from the train/test dataset.
    """
    data = train if split == 'train' else test
    ix = torch.randint(len(data) - block_size, (batch_size,), device=device)
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])

    x, y = x.to(device), y.to(device)
    return x, y

In [11]:
@torch.no_grad()
def estimate_loss():
    """
    Estimate the average loss of the model on the train/test dataset.
    """
    out = {}  # the output placeholder dictionary
    bi.eval()
    for split in ['train', 'test']:
        losses = torch.zeros(eval_iters, device=device)

        for i in range(eval_iters):
            X, y = get_batch(split)
            logits, loss = bi(X, y)
            losses[i] = loss.item()

        out[split] = losses.mean()
    bi.train()
    return out

Attention implementation from [here](https://github.com/rachtsy/KPCA_code/blob/master/Robust/softmax.py)

In [12]:
# TODO: add explanatory comments

class Attention(nn.Module):
    """

    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                 robust=False, layerth=0, n=1, lambd=0, layer=0):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.n = n
        self.lambd = lambd
        self.layer = layer
        # sqrt (D)
        self.scale = head_dim ** -0.5
        self.layerth = layerth

        self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.robust = robust

    def forward(self, x):
        B, N, C = x.shape
        # q,k -> B -> heads -> n -> features
        qkv = self.qkv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        # robust principal components attention
        if (self.robust and self.layer < 0) or (self.robust and self.layerth==self.layer):
            l = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True)
            y = torch.zeros((B,self.num_heads,N,C // self.num_heads)).to(torch.device("cuda"), non_blocking=True)

            mu=N*C/4/k.norm(p=1,dim=[-1,-2],keepdim=True)

            for i in range(0,self.n-1):
                s = k-l+y/mu
                s_less = s.le(-self.lambd*mu).int()
                s_more = s.ge(self.lambd*mu).int()
                s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less
                k2 = k-s-y/mu
                l = (k2 @ k2.transpose(-2, -1)) * self.scale
                l = l.softmax(dim=-1)
                l = l @ v
                y = y+mu*(k-l-s)

            s = k-l+y/mu
            s_less = s.le(-self.lambd*mu).int()
            s_more = s.ge(self.lambd*mu).int()
            s = (s-self.lambd*mu)*s_more + (s+self.lambd*mu)*s_less
            k2 = k-s-y/mu
            l = (k2 @ k2.transpose(-2, -1)) * self.scale
            l = l.softmax(dim=-1)
            l = self.attn_drop(l)
            x = l @ v
            y = y+mu*(k-x-s)

        # symmetric attention
        else:
            attn = (k @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)

            attn = self.attn_drop(attn)

            x = (attn @ v)

        x = x.transpose(1, 2).reshape(B,N,C)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

In [13]:
class Dense(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed).to(device),
            nn.ReLU().to(device),
            nn.Linear(4 * n_embed, n_embed).to(device),
            nn.Dropout(dropout).to(device)
        )

    def forward(self, x):
        return self.net(x)

In [14]:
class Block(nn.Module):
    """

    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 dropout=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, layerth=None,
                 robust=False, n=1, lambd=0, layer=0):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                    attn_drop=attn_drop, proj_drop=drop, robust=robust,
                                    layerth=layerth, n=n, lambd=lambd, layer=layer)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Dense(n_embed=n_embed)
        self.layerth = layerth

    def forward(self, x):
        x = x + self.dropout(self.attn(self.norm1(x)))
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x

In [15]:
class BigramModel(nn.Module):
    def __init__(self,
                 embed_dim=n_embed,
                 num_heads=n_head,
                 qkv_bias=False,
                 dropout=dropout,
                 attn_drop_rate=dropout,
                 robust=True):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed).to(device)
        self.position_embedding_table = nn.Embedding(block_size, n_embed).to(device)
        self.lm_head = nn.Linear(n_embed, vocab_size).to(device)
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias,
                attn_drop=attn_drop_rate, layerth = i, robust=robust)
            for i in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed).to(device)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_embeddings = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_embeddings + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [18]:
bi = BigramModel().to(device)

optimizer = torch.optim.AdamW(bi.parameters(), lr=lr)

for i in range(train_iters):
    if i % eval_interval == 0:
        losses = estimate_loss()
        print(f'step: {i}, train_loss: {losses["train"]}, test loss: {losses["test"]}')

    xt, yt = get_batch('train')
    logits, loss = bi(xt, yt)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

idx = torch.zeros((1, 1), dtype=torch.long, device=device)

print(decode(bi.generate(idx, max_new_tokens=2000)[0].tolist()))

step: 0, train_loss: 4.561573028564453, test loss: 4.563760280609131
step: 100, train_loss: 2.6156508922576904, test loss: 2.607220411300659
step: 200, train_loss: 2.2952284812927246, test loss: 2.311758279800415
step: 300, train_loss: 0.36925262212753296, test loss: 0.38749638199806213
step: 400, train_loss: 0.18664829432964325, test loss: 0.2270771712064743
step: 500, train_loss: 0.13469599187374115, test loss: 0.18306760489940643
step: 600, train_loss: 0.11643478274345398, test loss: 0.14675810933113098
step: 700, train_loss: 0.11148347705602646, test loss: 0.13251574337482452
step: 800, train_loss: 0.09771157056093216, test loss: 0.11405660957098007
step: 900, train_loss: 0.09336519986391068, test loss: 0.1087425947189331
step: 1000, train_loss: 0.09329553693532944, test loss: 0.10358624905347824
step: 1100, train_loss: 0.09017336368560791, test loss: 0.09857236593961716
step: 1200, train_loss: 0.08745908737182617, test loss: 0.09451554715633392
step: 1300, train_loss: 0.0866059511