In [74]:
import numpy as np

import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional
from torch.nn import functional as F

# Attentions

## Vanilla

In [75]:
class VanillaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                                               is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = nn.functional.softmax(att, dim=-1)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
            
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)
        return y 

## Nystrom

In [109]:
class NystromAttention(nn.Module):
    """
    Linformer self-attention mechanism with linear complexity.
    Projects keys and values to a lower dimensional space for efficiency.
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Default Linformer config
        self.n_landmarks = config.attention_config.get('nystrom_landmarks', 32) if config.attention_config else 32
        print(f'Nystrom with landmarks {self.n_landmarks}')

        # key, query, value projections for all heads
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        
        # mask
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        hs = C // self.n_head

        # calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        
        params = {'B': B, 'nh': self.n_head, 'T': T, 'hs': hs}
        
        # Project keys and values to lower dimensional space
        q_landmarks = self.__get_landmark_representation(q, self.n_landmarks, **params)
        k_landmarks = self.__get_landmark_representation(k, self.n_landmarks, **params)
        
        # Compute the attention matrix
        L = F.softmax(q @ k_landmarks.transpose(-1, -2) / math.sqrt(hs), dim=-1)
        P = self.__iterative_inv(F.softmax(q_landmarks @ k_landmarks.transpose(-1, -2) / math.sqrt(hs), dim=-1))

        N_prod = (q_landmarks @ k.transpose(-1, -2))
        # print(N_prod.shape)
        # print(self.bias.shape)
        # print(q_landmarks.shape)
        N_masked = N_prod.masked_fill(self.bias[:, :, :self.n_landmarks, :T] == 0, float('-inf'))
        N = F.softmax(N_masked / math.sqrt(hs), dim=-1)
        
        # Compute attention scores
        att = L @ P @ N

        # Apply attention to values and reshape
        y = att @ v  # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        return self.c_proj(y)
    
    def __get_landmark_representation(self, tensor, num_landmarks, B, nh, T, hs):
        tensor_reshaped = tensor.reshape(-1, nh, num_landmarks, T // num_landmarks, hs) # (B, nh, T, hs)
        tensor_landmarks = tensor_reshaped.mean(dim=-2)
        return tensor_landmarks

    def __iterative_inv(self, mat, n_iter=6):
        I = torch.eye(mat.size(-1), device=mat.device)
        K = mat

        # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
        V = 1 / torch.max(torch.sum(K, dim=-2), dim = -1).values[:, :, None, None] * K.transpose(-1, -2)

        for _ in range(n_iter):
            KV = torch.matmul(K, V)
            V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV)))
        return V

## Linformer

In [77]:
class LinformerAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.head_size = config.n_embd // config.n_head
        self.linformer_k = config.attention_config.get('linformer_k', config.block_size // 4) \
            if config.attention_config else config.block_size // 4

        # Q, K, V projections
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Linformer projection matrices E, F to go from T -> k
        # shape (n_head, linformer_k, block_size)
        self.E = nn.Parameter(torch.randn(config.n_head, self.linformer_k, config.block_size))
        self.F = nn.Parameter(torch.randn(config.n_head, self.linformer_k, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        # (B, T, 3*n_embd) -> split -> each (B, T, n_embd)
        q, k, v = self.c_attn(x).split(C, dim=2)

        # shape them into heads
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, n_head, T, head_size)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, n_head, T, head_size)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)  # (B, n_head, T, head_size)

        # Project K and V from T->k
        # E,F each is (n_head, linformer_k, block_size), slice to current T
        E = self.E[:, :, :T]  # => (n_head, linformer_k, T)
        F = self.F[:, :, :T]  # => (n_head, linformer_k, T)

        # Now do the batch multiplication:
        k_projected = torch.einsum('hkt, bhtd -> bhkd', E, k)  # (B, n_head, linformer_k, head_size)
        v_projected = torch.einsum('hkt, bhtd -> bhkd', F, v)  # (B, n_head, linformer_k, head_size)

        # Compute attention
        # q is (B, n_head, T, head_size)
        # k_projected is (B, n_head, linformer_k, head_size)
        # so q @ k_projected^T => (B, n_head, T, linformer_k)
        att = torch.matmul(
            q, k_projected.transpose(-2, -1)
        ) * (1.0 / math.sqrt(self.head_size))

        att = nn.functional.softmax(att, -1)

        # Then multiply by v_projected => (B, n_head, T, head_size)
        y = att @ v_projected

        # re-combine heads
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # final linear out
        y = self.c_proj(y)
        return y 

## Performer

In [78]:
class CausalPerformerAttention(nn.Module):
    """
    Causal Performer attention (FAVOR+) for GPT-style models.
    Uses random feature maps + prefix sums to enforce autoregressive masking.

    NOTE:
    - This is a *vectorized* implementation. For long T,
      memory consumption can be high (we store prefix sums of shape ~ (B, n_head, T, ...)).
    - If you need incremental generation, you would maintain prefix sums
      in a stateful manner instead of computing them for the entire sequence at once.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_size = self.n_embd // self.n_head

        # Number of random features for Performer
        self.n_features = config.attention_config.get('performer_features', 64) \
            if config.attention_config else 64

        # Q, K, V projections (linear)
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Create a random projection matrix for the feature map
        # Typically, it is not trained; we register it as a buffer
        # shape = (head_size, n_features)
        # scaled by e.g. 0.01 or 0.1 to keep exponent magnitudes stable
        proj = torch.randn(self.head_size, self.n_features) * 0.1
        self.register_buffer("proj", proj)

    def forward(self, x):
        """
        x: (B, T, C), where C = n_embd
        returns: (B, T, C)
        """
        B, T, C = x.size()

        # 1) Compute Q, K, V in one fused linear op
        qkv = self.c_attn(x)  # (B, T, 3*C)
        q, k, v = qkv.split(C, dim=2)  # each is (B, T, C)

        # 2) Reshape into heads: (B, n_head, T, head_size)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # 3) Map K, Q via the Performer random feature map: phi(k), phi(q)
        k_prime = self._prime(k)  # (B, n_head, T, n_features)
        q_prime = self._prime(q)  # (B, n_head, T, n_features)

        # 4) Compute prefix sums for enforcing causality.
        #    We'll do cumulative sums along time dimension, so each position t
        #    sees only sum_{<=t} instead of sum_{<=T}.

        # 4a) Compute k_prime * v => shape (B, n_head, T, n_features, head_size)
        #    We'll do an unsqueeze on the last dimension of k_prime or v
        #    to match them up.
        k_prime_expanded = k_prime.unsqueeze(-1)  # (B, n_head, T, n_features, 1)
        v_expanded = v.unsqueeze(-2)  # (B, n_head, T, 1, head_size)
        kprime_v = k_prime_expanded * v_expanded  # (B, n_head, T, n_features, head_size)

        # 4b) prefix sums along T
        prefix_k = torch.cumsum(k_prime, dim=2)  # (B, n_head, T, n_features)
        prefix_kprime_v = torch.cumsum(kprime_v, dim=2)  # (B, n_head, T, n_features, head_size)

        # 5) For each position t, the attention result is:
        #    numerator[t]   = q_prime[t] dot prefix_kprime_v[t]
        #    denominator[t] = q_prime[t] dot prefix_k[t]
        #    out[t] = numerator[t] / denominator[t]

        # We'll do that in a fully vectorized manner with einsum:
        # numerator shape => (B, n_head, T, head_size)
        numerator = torch.einsum(
            'b n t f, b n t f d -> b n t d',  # q_prime[t,f] * prefix_kprime_v[t,f,d] -> out[t,d]
            q_prime,
            prefix_kprime_v
        )
        # denominator shape => (B, n_head, T)
        denominator = torch.einsum(
            'b n t f, b n t f -> b n t',  # q_prime[t,f] * prefix_k[t,f] -> scalar
            q_prime,
            prefix_k
        ) + 1e-6  # avoid division by zero

        out = numerator / denominator.unsqueeze(-1)  # broadcast over 'd'

        # 7) Re-combine the heads: (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # 8) Final linear projection
        out = self.c_proj(out)
        return out

    def _prime(self, x):
        """
        Performer random feature map:
           phi(x) = exp(x * W - ||x||^2 / 2) / sqrt(n_features)
        where W is self.proj (shape [head_size, n_features]).

        x: shape (B, n_head, T, head_size)
        returns: (B, n_head, T, n_features)
        """
        # squared norm of x => (B, n_head, T, 1)
        norm_sq = torch.sum(x ** 2, dim=-1, keepdim=True)  # ||x||^2

        # x_proj => (B, n_head, T, n_features)
        x_proj = torch.einsum('b n t d, d f -> b n t f', x, self.proj)

        # exponent => exp(x_proj - norm_sq/2)
        x_exp = torch.exp(x_proj - 0.5 * norm_sq)

        # scale by 1 / sqrt(n_features)
        x_exp = x_exp * (1.0 / math.sqrt(self.n_features))
        return x_exp

# Config

In [122]:
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 65  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    attention_config = {'nystrom_landmarks': 64}
    bias = 0.2

# Util

In [80]:
def prepare_tokens(x, config, wte, wpe, ln):
    # In: token embeddings of shape (1, t)
    b, t = x.size()
    pos = torch.arange(0, t, dtype=torch.long, device='cpu').unsqueeze(0)  # shape (1, t)
    tok_emb = wte(x)  # token embeddings of shape (b, t, n_embd)
    pos_emb = wpe(pos)  # position embeddings of shape (1, t, n_embd)
    coded_x = tok_emb + pos_emb
    norm_x = ln(coded_x)
    return norm_x # token embeddings of shape (1, t, n_embd)

In [81]:
def get_config(block_size):
    config = GPTConfig()
    config.block_size = block_size

    wte=nn.Embedding(config.vocab_size, config.n_embd)
    wpe=nn.Embedding(config.block_size, config.n_embd)
    ln = nn.LayerNorm(config.n_embd, bias=False)

    return config, wte, wpe, ln

In [82]:
def get_tokens(x_np_array, config, wte, wpe, ln):
    print(f'Tokens shape (batch_size, context_window): {x_np_array.shape}')
    x = torch.stack([torch.from_numpy(tokens) for tokens in x_np_array])
    return prepare_tokens(x, config, wte, wpe, ln)

In [113]:
with torch.no_grad():
    def get_attention_matrix(att_type, block_size, x):
        torch.manual_seed(42)
        np.random.seed(42)
        config, wte, wpe, ln = get_config(block_size)
    
        if att_type == 'linformer':
            attention = LinformerAttention(config)
        elif att_type == 'nystrom':
            attention = NystromAttention(config)
        elif att_type == 'vanilla':
            attention = VanillaAttention(config)
        elif att_type == 'performer':
            attention = CausalPerformerAttention(config)
    
        tokens = get_tokens(x[:1, :block_size], config, wte, wpe, ln)
        att_matrix = attention.forward(tokens)
        print(f'obtained attention matrix of shape {att_matrix.shape}')
        att_matrix = att_matrix
            
        return att_matrix

In [84]:
# Metrics

In [85]:
x_array = np.loadtxt('/kaggle/input/attention-tokens/tokens8k.txt', delimiter=',')
x_np = np.array([x_array]).astype(np.int64)

In [100]:
%%time
attention_type = 'vanilla'
block_size = 256
attention_matrix = get_attention_matrix(attention_type, block_size, x_np)

Tokens shape (batch_size, context_window): (1, 256)
obtained attention matrix of shape torch.Size([1, 256, 384])
CPU times: user 23.7 ms, sys: 1.12 ms, total: 24.8 ms
Wall time: 19.4 ms


In [126]:
vanilla_matrix1 = get_attention_matrix('vanilla', 8192, x_np).clone().detach()
nystrom_matrix = get_attention_matrix('linformer', 8192, x_np).clone().detach()
sim = F.cosine_similarity(vanilla_matrix1.view(-1), nystrom_matrix.view(-1), dim=0)
print(sim.item())

Tokens shape (batch_size, context_window): (1, 8192)
obtained attention matrix of shape torch.Size([1, 8192, 384])
Tokens shape (batch_size, context_window): (1, 8192)
obtained attention matrix of shape torch.Size([1, 8192, 384])
-0.002036292804405093


In [88]:
config, wte, wpe, ln = get_config(256)

In [89]:
x = torch.stack([torch.from_numpy(tokens) for tokens in x_np])

In [90]:
wte(x)[0][:5][:5]

tensor([[-0.7175, -1.4476, -0.7953,  ..., -0.2047,  0.1062,  1.1078],
        [-1.0297, -0.1968,  0.7316,  ..., -1.1973, -0.3622, -0.2030],
        [-0.8943,  0.8867,  0.0396,  ...,  1.0203, -0.8696, -0.7008],
        [ 0.3200, -0.2651, -0.0264,  ..., -2.3089,  0.4150,  0.3817],
        [ 1.0118, -0.4148,  0.8462,  ..., -1.7770,  0.3753,  0.0305]],
       grad_fn=<SliceBackward0>)