In [127]:
import numpy as np

import math

import torch
import torch.nn as nn
from torch.nn import functional as F

# Attentions

## Vanilla

In [128]:
class VanillaAttention(nn.Module):
    """Exact causal multi-head self-attention (the baseline for this notebook).

    Queries, keys and values for all heads come from a single fused linear
    layer. If PyTorch exposes ``scaled_dot_product_attention`` that kernel is
    used with ``is_causal=True``; otherwise attention is computed explicitly
    against a lower-triangular mask stored in the ``bias`` buffer.

    Args:
        config: object providing ``n_embd``, ``n_head``, ``block_size`` and
            ``bias`` (forwarded as the ``bias`` argument of both linear layers).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # fused q/k/v projection, then the output projection
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size

        # the fused kernel only exists in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if self.flash:
            return

        print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # lower-triangular mask: position i may only look at positions <= i
        size = config.block_size
        lower_triangle = torch.tril(torch.ones(size, size))
        self.register_buffer("bias", lower_triangle.view(1, 1, size, size))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, n_embd); returns (B, T, n_embd)."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        queries, keys, values = (
            split_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        if not self.flash:
            # explicit attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
            scale = 1.0 / math.sqrt(keys.size(-1))
            scores = (queries @ keys.transpose(-2, -1)) * scale
            hidden = self.bias[:, :, :seq_len, :seq_len] == 0
            weights = nn.functional.softmax(scores.masked_fill(hidden, float('-inf')), dim=-1)
            mixed = weights @ values  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        else:
            mixed = torch.nn.functional.scaled_dot_product_attention(
                queries, keys, values, attn_mask=None, is_causal=True
            )

        # put the heads back side by side, then project
        merged = mixed.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)

## Nystrom

In [129]:
class NystromAttention(nn.Module):
    """
    Nystrom-style approximation of multi-head softmax self-attention.

    Queries and keys are summarised by ``n_landmarks`` landmark vectors, each the
    mean of one contiguous segment of the sequence. Attention is approximated by
    the product of three smaller softmax matrices,
    ``softmax(Q K_l^T) @ inv(softmax(Q_l K_l^T)) @ softmax(Q_l K^T)``
    (each scaled by ``1 / sqrt(head_size)``), with the inverse approximated
    iteratively, and the result is applied to the values.

    Args:
        config: object providing ``n_embd``, ``n_head``, ``block_size``, ``bias``
            and ``attention_config``. ``attention_config`` may be a dict holding
            ``'nystrom_landmarks'``; 32 landmarks are used when it is empty/None
            or lacks that key.

    NOTE(review): only the landmark-to-key factor is masked (with the first
    ``n_landmarks`` rows of the token-level lower-triangular mask). The
    query-to-landmark factor is unmasked and landmarks average over whole
    segments, so the output at position t can depend on later positions.
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_landmarks = config.attention_config.get('nystrom_landmarks', 32) if config.attention_config else 32
        print(f'Nystrom with landmarks {self.n_landmarks}')
        if self.n_landmarks < 1:
            raise ValueError(f"nystrom_landmarks must be >= 1, got {self.n_landmarks}")

        # key, query, value projections for all heads
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # lower-triangular mask of shape (1, 1, block_size, block_size)
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """
        Apply the approximate attention.

        Args:
            x: tensor of shape (B, T, n_embd). T must be a positive multiple of
                ``n_landmarks`` and must not exceed ``block_size``.

        Returns:
            Tensor of shape (B, T, n_embd).

        Raises:
            ValueError: if T violates the constraints above.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        hs = C // self.n_head

        # Without this check a non-divisible T could be silently folded into the
        # batch axis by the landmark reshape and fail later with an opaque error.
        if T < self.n_landmarks or T % self.n_landmarks != 0:
            raise ValueError(
                f"sequence length {T} must be a positive multiple of the number of "
                f"landmarks ({self.n_landmarks})"
            )
        if T > self.bias.size(-1):
            raise ValueError(f"sequence length {T} exceeds block_size {self.bias.size(-1)}")

        # calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)

        params = {'B': B, 'nh': self.n_head, 'T': T, 'hs': hs}

        # Segment means of queries and keys: (B, nh, n_landmarks, hs)
        q_landmarks = self.__get_landmark_representation(q, self.n_landmarks, **params)
        k_landmarks = self.__get_landmark_representation(k, self.n_landmarks, **params)

        scale = math.sqrt(hs)

        # query -> landmark weights: (B, nh, T, n_landmarks)
        L = F.softmax(q @ k_landmarks.transpose(-1, -2) / scale, dim=-1)
        # approximate inverse of the landmark -> landmark weights: (B, nh, n_landmarks, n_landmarks)
        P = self.__iterative_inv(F.softmax(q_landmarks @ k_landmarks.transpose(-1, -2) / scale, dim=-1))

        # landmark -> key weights, masked with the top-left corner of the token-level mask
        N_prod = (q_landmarks @ k.transpose(-1, -2))
        N_masked = N_prod.masked_fill(self.bias[:, :, :self.n_landmarks, :T] == 0, float('-inf'))
        N = F.softmax(N_masked / scale, dim=-1)  # (B, nh, n_landmarks, T)

        # Approximate (B, nh, T, T) attention weights
        att = L @ P @ N

        # Apply attention to values and reshape
        y = att @ v  # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        return self.c_proj(y)

    def __get_landmark_representation(self, tensor, num_landmarks, B, nh, T, hs):
        """Average ``tensor`` (B, nh, T, hs) over ``num_landmarks`` contiguous segments -> (B, nh, num_landmarks, hs)."""
        # The batch axis is given explicitly (not -1) so an incompatible T raises
        # instead of being absorbed into the batch dimension.
        tensor_reshaped = tensor.reshape(B, nh, num_landmarks, T // num_landmarks, hs)
        tensor_landmarks = tensor_reshaped.mean(dim=-2)
        return tensor_landmarks

    def __iterative_inv(self, mat, n_iter=6):
        """
        Iteratively approximate the inverse of ``mat``.

        Args:
            mat: tensor of shape (B, nh, m, m); expected to be a row-softmax matrix.
            n_iter: number of refinement iterations.

        Returns:
            Tensor of shape (B, nh, m, m).

        NOTE(review): the update rule looks like the iterative pseudo-inverse
        scheme used by Nystromformer - confirm against the reference.
        """
        # Match dtype as well as device so half/bfloat16 inputs do not hit a
        # float32 identity inside the matmuls below.
        I = torch.eye(mat.size(-1), device=mat.device, dtype=mat.dtype)
        K = mat

        # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
        # Initial guess: K^T scaled by the largest column sum of each (B, nh) matrix
        V = 1 / torch.max(torch.sum(K, dim=-2), dim = -1).values[:, :, None, None] * K.transpose(-1, -2)

        for _ in range(n_iter):
            KV = torch.matmul(K, V)
            V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV)))
        return V

## Linformer

In [130]:
import math
import torch
import torch.nn as nn
import pickle

def gen_causal_mask(input_size, dim_k, full_attention=False):
    """
    Build the boolean causal mask used by Linformer attention.

    The result has shape (input_size, dim_k) and entry [i, j] is True exactly
    when j <= i. With full_attention=True, dim_k is ignored and the mask is
    the square (input_size, input_size) lower-triangular mask.
    """
    # Build the upper-triangular matrix in transposed layout, then flip it.
    rows = input_size if full_attention else dim_k
    upper = torch.ones(rows, input_size).triu()
    return (upper == 1).transpose(0, 1)

class LinformerAttention(nn.Module):
    """
    Linformer-style self-attention: keys and values are projected along the
    sequence axis from ``block_size`` positions down to ``linformer_k`` slots
    before attention is computed.

    Args:
        config: object providing ``n_embd``, ``n_head``, ``block_size`` and
            ``bias``; an optional ``causal`` attribute (default True) enables
            masking.

    The projection matrices ``E`` (keys) and ``F`` (values) are randomly
    initialised and then, if the pickle files named by ``E_INIT_PATH`` /
    ``F_INIT_PATH`` exist and contain a matrix of the expected shape under
    ``INIT_KEY``, replaced by the stored values.

    NOTE(review): the causal mask hides projected slot j from query t when
    j > t, but every projected slot is a weighted sum over all T positions, so
    the output at position t can still depend on later positions.
    """

    # Optional pre-computed projection matrices (paths are relative to the working directory).
    E_INIT_PATH = "../out-shakespeare-char/linformer_E_256.pkl"
    F_INIT_PATH = "../out-shakespeare-char/linformer_F_256.pkl"
    INIT_KEY = "block_1"

    def __init__(self, config):
        super().__init__()
        # Ensure the embedding is divisible by number of heads
        assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads."

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size

        # linformer_k: reduced sequence length for the key/value projections
        self.linformer_k = config.block_size // 4
        self.causal = getattr(config, "causal", True)

        # Single linear layer to produce queries, keys, and values
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Learnable projection matrices for keys and values:
        # they map from the original sequence length (block_size) to linformer_k.
        # The random init always runs first so the RNG stream does not depend on
        # whether the pickle files are present.
        self.E = nn.Parameter(torch.randn(self.block_size, self.linformer_k))
        self.F = nn.Parameter(torch.randn(self.block_size, self.linformer_k))

        # Replace the random init with stored matrices when they are available.
        E_init = self._load_projection(self.E_INIT_PATH)
        if E_init is not None:
            self.E = nn.Parameter(E_init)

        F_init = self._load_projection(self.F_INIT_PATH)
        if F_init is not None:
            self.F = nn.Parameter(F_init)

        # Precompute causal mask if needed (mask shape: (block_size, linformer_k))
        if self.causal:
            self.register_buffer("causal_mask", gen_causal_mask(self.block_size, self.linformer_k, full_attention=False))

    def _load_projection(self, path):
        """
        Try to read a (block_size, linformer_k) projection matrix from ``path``.

        Returns the matrix as a tensor, or None when the file does not exist or
        the stored matrix has a different shape (a warning is printed in the
        latter case). Other errors (unreadable pickle, missing key) propagate.
        """
        try:
            with open(path, "rb") as f:
                # NOTE(security): pickle.load can execute arbitrary code; only
                # point these paths at files you produced yourself.
                stored = pickle.load(f)
        except FileNotFoundError:
            return None

        matrix = torch.as_tensor(stored[self.INIT_KEY])
        expected = (self.block_size, self.linformer_k)
        if tuple(matrix.shape) != expected:
            print(f"WARNING: ignoring {path}: stored shape {tuple(matrix.shape)} != expected {expected}")
            return None
        return matrix

    def forward(self, x):
        """
        x: (batch_size, seq_len, n_embd) with seq_len <= block_size.
        Returns a tensor of the same shape.

        Raises:
            ValueError: if seq_len exceeds block_size.
        """
        B, T, C = x.size()
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")

        # Compute queries, keys, values in one go and split them up
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # Reshape and transpose to get shape (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Apply Linformer projection: project along the sequence dimension.
        # Only the first T rows of E / F are used so shorter sequences work too.
        # k_proj: (B, n_head, linformer_k, head_dim)
        # v_proj: (B, n_head, linformer_k, head_dim)
        k_proj = torch.einsum('bhtd,tk->bhkd', k, self.E[:T])
        v_proj = torch.einsum('bhtd,tk->bhkd', v, self.F[:T])

        # Compute attention scores using the projected keys.
        # q: (B, n_head, T, head_dim)
        # k_proj.transpose(-2, -1): (B, n_head, head_dim, linformer_k)
        att = (q @ k_proj.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if self.causal:
            # Use only the first T rows of the precomputed causal mask.
            # mask shape: (T, linformer_k); slot 0 is always visible, so no row is fully masked.
            mask = self.causal_mask[:T, :]
            att = att.masked_fill(~mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        att = att.softmax(dim=-1)

        # Multiply attention weights with projected values.
        # Resulting shape: (B, n_head, T, head_dim)
        y = att @ v_proj

        # Reassemble multi-head outputs.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


## Performer

In [131]:
class CausalPerformerAttention(nn.Module):
    """Causal linear attention with random exponential features.

    Queries and keys are mapped through ``_prime`` to ``n_features`` positive
    features; causal attention is then obtained from running (prefix) sums over
    the sequence axis, so the (T, T) attention matrix is never formed.

    Args:
        config: object providing ``n_embd``, ``n_head``, ``block_size`` and
            ``attention_config`` (a dict that may hold ``'performer_features'``;
            64 is used when it is empty/None or lacks that key).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_size = self.n_embd // self.n_head

        attn_cfg = config.attention_config
        self.n_features = attn_cfg.get('performer_features', 64) if attn_cfg else 64

        # fused q/k/v projection and output projection (bias always enabled here)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=True)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=True)

        # fixed (non-trainable) random projection shared by all heads
        self.register_buffer("proj", torch.randn(self.head_size, self.n_features) * 0.1)

        # lower-triangular mask buffer; forward() does not read it, causality
        # comes from the prefix sums
        size = config.block_size
        self.register_buffer("bias", torch.tril(torch.ones(size, size)).view(1, 1, size, size))

    def forward(self, x):
        """Map ``x`` of shape (B, T, n_embd) to an output of the same shape."""
        B, T, C = x.size()

        # (B, T, C) -> three tensors of shape (B, n_head, T, head_size)
        q, k, v = [
            part.view(B, T, self.n_head, self.head_size).transpose(1, 2)
            for part in self.c_attn(x).split(C, dim=2)
        ]

        # positive random features, each (B, n_head, T, n_features)
        k_feat = self._prime(k)
        q_feat = self._prime(q)

        # running sums over positions <= t
        running_k = torch.cumsum(k_feat, dim=2)                     # (B, n_head, T, n_features)
        outer_kv = k_feat.unsqueeze(-1) * v.unsqueeze(-2)           # (B, n_head, T, n_features, head_size)
        running_kv = torch.cumsum(outer_kv, dim=2)                  # (B, n_head, T, n_features, head_size)

        # un-normalised output and its normaliser (epsilon guards against division by zero)
        weighted = torch.einsum('b n t f, b n t f d -> b n t d', q_feat, running_kv)
        normaliser = torch.einsum('b n t f, b n t f -> b n t', q_feat, running_k) + 1e-6

        merged = (weighted / normaliser.unsqueeze(-1)).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)

    def _prime(self, x):
        """Feature map phi(x) = exp(xW - ||x||^2 / 2) / sqrt(n_features)."""
        projected = torch.einsum('b n t d, d f -> b n t f', x, self.proj)
        sq_norm = torch.sum(x**2, dim=-1, keepdim=True)
        features = torch.exp(projected - 0.5 * sq_norm)
        return features * (1.0 / math.sqrt(self.n_features))

# Config

In [132]:
class GPTConfig:
    """Hyper-parameters shared by the attention modules and embedding helpers.

    Values are plain class attributes; instances override them by assignment
    (e.g. ``config.block_size = 256``).
    """
    block_size: int = 4096  # maximum sequence length
    vocab_size: int = 65  # number of token ids; presumably the character-level Shakespeare vocabulary - confirm
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    # Per-attention options read via .get(); this dict is a class attribute and
    # therefore shared by all instances unless an instance reassigns it.
    attention_config = {'nystrom_landmarks': 64}
    # Passed as the ``bias`` flag of nn.Linear, which is a boolean switch
    # (the previous value 0.2 only worked because it is truthy).
    bias: bool = True

# Util

In [133]:
def prepare_tokens(x, config, wte, wpe, ln):
    """Embed token ids and add position embeddings, then normalise.

    Args:
        x: integer tensor of token ids with shape (b, t).
        config: unused here; kept so the signature stays the same for callers.
        wte: token embedding module, called as ``wte(x)``.
        wpe: position embedding module, called on positions ``0..t-1``.
        ln: normalisation module applied to the summed embeddings.

    Returns:
        Tensor of shape (b, t, n_embd): ``ln(wte(x) + wpe(positions))``.
    """
    _, t = x.size()
    # Positions live on the same device as the token ids (previously hard-coded
    # to CPU, which breaks as soon as the inputs are moved to an accelerator).
    pos = torch.arange(0, t, dtype=torch.long, device=x.device).unsqueeze(0)  # shape (1, t)
    tok_emb = wte(x)  # token embeddings of shape (b, t, n_embd)
    pos_emb = wpe(pos)  # position embeddings of shape (1, t, n_embd), broadcast over the batch
    return ln(tok_emb + pos_emb)

In [134]:
def get_config(block_size):
    """Create a config for ``block_size`` plus freshly initialised embedding layers.

    Returns:
        Tuple ``(config, wte, wpe, ln)``: the ``GPTConfig`` with its
        ``block_size`` overridden, a token embedding (vocab_size x n_embd), a
        position embedding (block_size x n_embd) and a bias-free LayerNorm
        over n_embd.
    """
    cfg = GPTConfig()
    cfg.block_size = block_size

    # Creation order is kept (token table first) so seeded runs draw the same weights.
    token_embedding = nn.Embedding(cfg.vocab_size, cfg.n_embd)
    position_embedding = nn.Embedding(cfg.block_size, cfg.n_embd)
    layer_norm = nn.LayerNorm(cfg.n_embd, bias=False)

    return cfg, token_embedding, position_embedding, layer_norm

In [135]:
def get_tokens(x_np_array, config, wte, wpe, ln):
    """Convert a 2-D NumPy array of token ids to embedded, normalised inputs.

    Prints the array shape, stacks the rows into one tensor and passes it to
    ``prepare_tokens`` together with the config and embedding modules.
    """
    print(f'Tokens shape (batch_size, context_window): {x_np_array.shape}')
    rows = [torch.from_numpy(row) for row in x_np_array]
    token_ids = torch.stack(rows)
    return prepare_tokens(token_ids, config, wte, wpe, ln)

In [136]:
@torch.no_grad()
def get_attention_matrix(att_type, block_size, x):
    """Run one attention variant on the first ``block_size`` tokens of the first row of ``x``.

    Seeds are reset on every call so that each variant is built from the same
    random stream and the same embeddings.

    Args:
        att_type: one of 'linformer', 'nystrom', 'vanilla', 'performer'.
        block_size: context length used for the config and for slicing ``x``.
        x: 2-D NumPy array of token ids, shape (batch, n_tokens).

    Returns:
        The attention layer's output (not the attention weights), a tensor of
        shape (1, block_size, n_embd) computed without gradient tracking.

    Raises:
        ValueError: if ``att_type`` is not a known variant.
    """
    # Resolve the class first so a typo fails immediately with a clear message
    # instead of an unbound-variable error further down.
    if att_type == 'linformer':
        attention_cls = LinformerAttention
    elif att_type == 'nystrom':
        attention_cls = NystromAttention
    elif att_type == 'vanilla':
        attention_cls = VanillaAttention
    elif att_type == 'performer':
        attention_cls = CausalPerformerAttention
    else:
        raise ValueError(
            f"unknown attention type {att_type!r}; expected one of "
            "'linformer', 'nystrom', 'vanilla', 'performer'"
        )

    torch.manual_seed(42)
    np.random.seed(42)
    config, wte, wpe, ln = get_config(block_size)
    attention = attention_cls(config)

    tokens = get_tokens(x[:1, :block_size], config, wte, wpe, ln)
    att_matrix = attention(tokens)
    print(f'obtained attention matrix of shape {att_matrix.shape}')

    return att_matrix

In [137]:
# Metrics

In [138]:
# Read the comma-separated token ids, then add a leading batch axis and cast to int64.
x_array = np.loadtxt('./shakespeare_char/tokens10k.txt', delimiter=',')
x_np = x_array[np.newaxis, ...].astype(np.int64)

In [139]:
%%time
# Timed smoke test: one forward pass of the baseline attention over the first
# `block_size` tokens. `block_size` is also read by the cells that follow.
attention_type = 'vanilla'
block_size = 256
attention_matrix = get_attention_matrix(attention_type, block_size, x_np)

Tokens shape (batch_size, context_window): (1, 256)
obtained attention matrix of shape torch.Size([1, 256, 384])
CPU times: user 10.6 ms, sys: 4.58 ms, total: 15.1 ms
Wall time: 5.7 ms


In [140]:
# Standalone config / embedding / layer-norm set for `block_size`.
# NOTE(review): get_attention_matrix builds its own set internally, and nothing in
# the visible cells below reads these names - confirm whether this cell is still needed.
config, wte, wpe, ln = get_config(block_size)

In [141]:
def _flat_cosine(a, b):
    """Cosine similarity between two tensors after flattening each to 1-D."""
    return nn.functional.cosine_similarity(a.view(-1), b.view(-1), dim=0)


# Reference: exact softmax attention
vanilla_out = get_attention_matrix('vanilla', block_size, x_np).clone().detach()

# Nystrom approximation vs. reference
nystrom_out = get_attention_matrix('nystrom', block_size, x_np).clone().detach()
sim_nys = _flat_cosine(vanilla_out, nystrom_out)
print(f"Vanilla vs Nystrom similarity: {sim_nys.item():.4f}")

# Linformer approximation vs. reference
linformer_out = get_attention_matrix('linformer', block_size, x_np).clone().detach()
sim_lin = _flat_cosine(vanilla_out, linformer_out)
print(f"Vanilla vs Linformer similarity: {sim_lin.item():.4f}")

# Performer approximation vs. reference
performer_out = get_attention_matrix('performer', block_size, x_np).clone().detach()
sim_per = _flat_cosine(vanilla_out, performer_out)
print(f"Vanilla vs Performer similarity: {sim_per.item():.4f}")

Tokens shape (batch_size, context_window): (1, 256)
obtained attention matrix of shape torch.Size([1, 256, 384])
Nystrom with landmarks 64
Tokens shape (batch_size, context_window): (1, 256)
obtained attention matrix of shape torch.Size([1, 256, 384])
Vanilla vs Nystrom similarity: 0.7824
Tokens shape (batch_size, context_window): (1, 256)
obtained attention matrix of shape torch.Size([1, 256, 384])
Vanilla vs Linformer similarity: 0.0281
Tokens shape (batch_size, context_window): (1, 256)
obtained attention matrix of shape torch.Size([1, 256, 384])
Vanilla vs Performer similarity: 0.5644
