In [97]:
import numpy as np

import math

import torch
import torch.nn as nn
from torch.nn import functional as F

# Attentions

## Vanilla

In [98]:
class VanillaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                 .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                                               is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = nn.functional.softmax(att, dim=-1)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)
        return y 

## Nystrom

In [99]:
class NystromAttention(nn.Module):
    """
    Linformer self-attention mechanism with linear complexity.
    Projects keys and values to a lower dimensional space for efficiency.
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Default Linformer config
        self.n_landmarks = config.attention_config.get('nystrom_landmarks', 32) if config.attention_config else 32
        print(f'Nystrom with landmarks {self.n_landmarks}')

        # key, query, value projections for all heads
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        
        # mask
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        hs = C // self.n_head

        # calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        
        params = {'B': B, 'nh': self.n_head, 'T': T, 'hs': hs}
        
        # Project keys and values to lower dimensional space
        q_landmarks = self.__get_landmark_representation(q, self.n_landmarks, **params)
        k_landmarks = self.__get_landmark_representation(k, self.n_landmarks, **params)
        
        # Compute the attention matrix
        L = F.softmax(q @ k_landmarks.transpose(-1, -2) / math.sqrt(hs), dim=-1)
        P = self.__iterative_inv(F.softmax(q_landmarks @ k_landmarks.transpose(-1, -2) / math.sqrt(hs), dim=-1))

        N_prod = (q_landmarks @ k.transpose(-1, -2))
        # print(N_prod.shape)
        # print(self.bias.shape)
        # print(q_landmarks.shape)
        N_masked = N_prod.masked_fill(self.bias[:, :, :self.n_landmarks, :T] == 0, float('-inf'))
        N = F.softmax(N_masked / math.sqrt(hs), dim=-1)
        
        # Compute attention scores
        att = L @ P @ N

        # Apply attention to values and reshape
        y = att @ v  # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        return self.c_proj(y)
    
    def __get_landmark_representation(self, tensor, num_landmarks, B, nh, T, hs):
        tensor_reshaped = tensor.reshape(-1, nh, num_landmarks, T // num_landmarks, hs) # (B, nh, T, hs)
        tensor_landmarks = tensor_reshaped.mean(dim=-2)
        return tensor_landmarks

    def __iterative_inv(self, mat, n_iter=6):
        I = torch.eye(mat.size(-1), device=mat.device)
        K = mat

        # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
        V = 1 / torch.max(torch.sum(K, dim=-2), dim = -1).values[:, :, None, None] * K.transpose(-1, -2)

        for _ in range(n_iter):
            KV = torch.matmul(K, V)
            V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV)))
        return V

## Linformer

In [100]:
class LinformerAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.head_size = config.n_embd // config.n_head

        # default: k = block_size // 4, or from config
        self.linformer_k = config.attention_config.get('linformer_k', config.block_size//4) \
                           if config.attention_config else config.block_size//4

        # Q, K, V projections
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=True)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=True)

        # Initialize E and F with mean of K and V projections
        self.E = self.initialize_projection_matrix(config.block_size, self.linformer_k)
        self.F = self.initialize_projection_matrix(config.block_size, self.linformer_k)

        # naive causal mask
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def initialize_projection_matrix(self, block_size, linformer_k):
        # Initialize a tensor that averages over the input tokens
        init_matrix = torch.ones(self.n_head, linformer_k, block_size) / block_size
        return nn.Parameter(init_matrix)

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(C, dim=2)

        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # slice E, F to current T => shape (n_head, linformer_k, T)
        E = self.E[:, :, :T]
        F = self.F[:, :, :T]

        # project K, V
        k_projected = torch.einsum('hkt,bhtd->bhkd', E, k)  # (B, n_head, k_lin, head_size)
        v_projected = torch.einsum('hkt,bhtd->bhkd', F, v)  # (B, n_head, k_lin, head_size)

        # attention
        att = torch.matmul(q, k_projected.transpose(-2, -1)) / math.sqrt(self.head_size)
        # 'att' is (B, n_head, T, linformer_k)

        # naive masking: we slice the standard T x T mask to T x k
        # This is not truly correct for a "causal Linformer," but it at least
        # prevents the largest mismatch where Linformer sees all tokens.
        causal_mask = self.bias[:, :, :T, :k_projected.size(2)]  # shape (1,1,T,k_lin)
        att = att.masked_fill(causal_mask == 0, float('-inf'))

        att = nn.functional.softmax(att, dim=-1)
        y = torch.matmul(att, v_projected)  # (B, n_head, T, head_size)

        # reassemble
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


## Performer

In [101]:
class CausalPerformerAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_size = self.n_embd // self.n_head

        self.n_features = config.attention_config.get('performer_features', 64) \
            if config.attention_config else 64

        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=True)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=True)

        # random projection for performer
        proj = torch.randn(self.head_size, self.n_features) * 0.1
        self.register_buffer("proj", proj)

        # causal mask for safety (though the prefix-sum approach is also doing it)
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # random feature mapping
        k_prime = self._prime(k)  # (B, n_head, T, n_features)
        q_prime = self._prime(q)

        # prefix sums along T
        kprime_v = k_prime.unsqueeze(-1) * v.unsqueeze(-2)   # (B, n_head, T, n_features, head_size)
        prefix_k = torch.cumsum(k_prime, dim=2)              # (B, n_head, T, n_features)
        prefix_kv = torch.cumsum(kprime_v, dim=2)            # (B, n_head, T, n_features, head_size)

        # numerator: q_prime[t, :] dot prefix_kv[t, :, :]
        numerator = torch.einsum('b n t f, b n t f d -> b n t d', q_prime, prefix_kv)
        # denominator
        denominator = torch.einsum('b n t f, b n t f -> b n t', q_prime, prefix_k) + 1e-6

        out = numerator / denominator.unsqueeze(-1)

        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)

    def _prime(self, x):
        # phi(x) = exp(xW - ||x||^2/2) / sqrt(n_features)
        norm_sq = torch.sum(x**2, dim=-1, keepdim=True)
        x_proj = torch.einsum('b n t d, d f -> b n t f', x, self.proj)
        x_exp = torch.exp(x_proj - 0.5 * norm_sq)
        return x_exp * (1.0 / math.sqrt(self.n_features))

# Config

In [102]:
class GPTConfig:
    block_size: int = 4096
    vocab_size: int = 65  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    attention_config = {'nystrom_landmarks': 64}
    bias = 0.2

# Util

In [103]:
def prepare_tokens(x, config, wte, wpe, ln):
    # In: token embeddings of shape (1, t)
    b, t = x.size()
    pos = torch.arange(0, t, dtype=torch.long, device='cpu').unsqueeze(0)  # shape (1, t)
    tok_emb = wte(x)  # token embeddings of shape (b, t, n_embd)
    pos_emb = wpe(pos)  # position embeddings of shape (1, t, n_embd)
    coded_x = tok_emb + pos_emb
    norm_x = ln(coded_x)
    return norm_x # token embeddings of shape (1, t, n_embd)

In [104]:
def get_config(block_size):
    config = GPTConfig()
    config.block_size = block_size

    wte=nn.Embedding(config.vocab_size, config.n_embd)
    wpe=nn.Embedding(config.block_size, config.n_embd)
    ln = nn.LayerNorm(config.n_embd, bias=False)

    return config, wte, wpe, ln

In [105]:
def get_tokens(x_np_array, config, wte, wpe, ln):
    print(f'Tokens shape (batch_size, context_window): {x_np_array.shape}')
    x = torch.stack([torch.from_numpy(tokens) for tokens in x_np_array])
    return prepare_tokens(x, config, wte, wpe, ln)

In [106]:
with torch.no_grad():
    def get_attention_matrix(att_type, block_size, x):
        torch.manual_seed(42)
        np.random.seed(42)
        config, wte, wpe, ln = get_config(block_size)
    
        if att_type == 'linformer':
            attention = LinformerAttention(config)
        elif att_type == 'nystrom':
            attention = NystromAttention(config)
        elif att_type == 'vanilla':
            attention = VanillaAttention(config)
        elif att_type == 'performer':
            attention = CausalPerformerAttention(config)
    
        tokens = get_tokens(x[:1, :block_size], config, wte, wpe, ln)
        att_matrix = attention.forward(tokens)
        print(f'obtained attention matrix of shape {att_matrix.shape}')
        att_matrix = att_matrix
            
        return att_matrix

In [107]:
# Metrics

In [108]:
x_array = np.loadtxt('./shakespeare_char/tokens10k.txt', delimiter=',')
x_np = np.array([x_array]).astype(np.int64)

In [109]:
%%time
attention_type = 'vanilla'
block_size = 4096
attention_matrix = get_attention_matrix(attention_type, block_size, x_np)

Tokens shape (batch_size, context_window): (1, 4096)
obtained attention matrix of shape torch.Size([1, 4096, 384])
CPU times: user 304 ms, sys: 57.9 ms, total: 362 ms
Wall time: 63.4 ms


In [110]:
vanilla_out = get_attention_matrix('vanilla', block_size, x_np).clone().detach()

# compare with Nystrom
nystrom_out = get_attention_matrix('nystrom', block_size, x_np).clone().detach()
sim_nys = nn.functional.cosine_similarity(vanilla_out.view(-1), nystrom_out.view(-1), dim=0)
print(f"Vanilla vs Nystrom similarity: {sim_nys.item():.4f}")

# compare with Linformer
linformer_out = get_attention_matrix('linformer', block_size, x_np).clone().detach()
sim_lin = nn.functional.cosine_similarity(vanilla_out.view(-1), linformer_out.view(-1), dim=0)
print(f"Vanilla vs Linformer similarity: {sim_lin.item():.4f}")

# compare with Performer
performer_out = get_attention_matrix('performer', block_size, x_np).clone().detach()
sim_per = nn.functional.cosine_similarity(vanilla_out.view(-1), performer_out.view(-1), dim=0)
print(f"Vanilla vs Performer similarity: {sim_per.item():.4f}")

Tokens shape (batch_size, context_window): (1, 4096)
obtained attention matrix of shape torch.Size([1, 4096, 384])
Nystrom with landmarks 64
Tokens shape (batch_size, context_window): (1, 4096)
obtained attention matrix of shape torch.Size([1, 4096, 384])
Vanilla vs Nystrom similarity: 0.7488
Tokens shape (batch_size, context_window): (1, 4096)
obtained attention matrix of shape torch.Size([1, 4096, 384])
Vanilla vs Linformer similarity: 0.9479
Tokens shape (batch_size, context_window): (1, 4096)
obtained attention matrix of shape torch.Size([1, 4096, 384])
Vanilla vs Performer similarity: 0.7824


In [111]:
config, wte, wpe, ln = get_config(256)

In [112]:
x = torch.stack([torch.from_numpy(tokens) for tokens in x_np])

In [113]:
wte(x)[0][:5][:5]

tensor([[-0.5744, -0.3336,  1.2779,  ..., -0.4308,  0.8104,  2.1108],
        [ 0.2076, -0.2405,  0.5607,  ...,  0.1064,  0.9199,  2.4935],
        [-1.2587,  0.3004, -0.1708,  ...,  0.3788, -0.2477,  0.1624],
        [ 0.8966, -0.5991, -0.1178,  ...,  1.4480, -0.7148,  0.1802],
        [-0.0349,  1.2517, -0.3017,  ..., -1.5413, -0.3463,  0.0140]],
       grad_fn=<SliceBackward0>)