# Efficient Multi-Head Attention

Also an exercise in understanding operand dimensionality in batched matmuls.

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Hyperparameters
batch_size = 4     # number of independent examples
block_size = 1024  # max sequence length
n_embd = 768       # total embedding dim, both in and out, divisible by n_head
n_head = 12        # number of heads

# Init
# One shared set of random weights and inputs, so every attention
# implementation below can be loaded with the same tensors and compared.
# Shapes are derived from n_embd (instead of literal 2304 / 768) so they stay
# consistent if the hyperparameter is changed; for n_embd = 768 the generated
# values are identical to the hard-coded version.
torch.manual_seed(42)
c_attn_W = torch.randn(3 * n_embd, n_embd) / (3 * n_embd) ** 0.5  # fused q/k/v projection weight
c_attn_b = torch.randn(3 * n_embd)                                # fused q/k/v projection bias
c_proj_W = torch.randn(n_embd, n_embd) / n_embd ** 0.5            # output projection weight
c_proj_b = torch.randn(n_embd)                                    # output projection bias
x = torch.randn(batch_size, block_size, n_embd)                   # input batch (B, T, C)

In [3]:
# Reference Implementation (unoptimized, Karpathy video)

class GPTConfig:
    """Bare container for the three attention hyperparameters.

    Attributes:
        block_size: max sequence length.
        n_head: number of heads.
        n_embd: embedding dimension.
    """

    def __init__(self, block_size, n_embd, n_head):
        # Values are stored as given; nothing is validated here.
        self.block_size, self.n_head, self.n_embd = block_size, n_head, n_embd

class CausalSelfAttentionKarpathy1(nn.Module):
    """Reference causal multi-head self-attention with an explicit masked softmax.

    Queries, keys and values for all heads come from one fused linear layer
    (`c_attn`); the per-head results are concatenated and passed through the
    output projection (`c_proj`).

    Args:
        config: object exposing `n_embd`, `n_head` and `block_size`;
            `n_embd` must be divisible by `n_head`.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # fused key/query/value projection for all heads, then the output projection
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular causal mask. It is called "bias" only to follow the
        # OpenAI/HF naming; it is a mask, not an additive bias.
        lower_tri = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer(
            "bias", lower_tri.view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x):
        """Apply causal self-attention to `x` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()  # batch size, sequence length, embedding dim (n_embd)
        # nh = number of heads, hs = head size, C = nh * hs
        # (e.g. GPT-2 124M: nh=12, hs=64, C=768)
        hs = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs): the head axis becomes a batch axis
            return t.view(B, T, self.n_head, hs).transpose(1, 2)

        q, k, v = (to_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

        # attention: materializes the full (T, T) score matrix for every head
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))  # (B, nh, T, T)
        is_future = self.bias[:, :, :T, :T] == 0
        weights = F.softmax(scores.masked_fill(is_future, float('-inf')), dim=-1)

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs), then heads side by side
        y = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

In [4]:
config = GPTConfig(block_size=block_size, n_embd=n_embd, n_head=n_head)
csa_k1 = CausalSelfAttentionKarpathy1(config)

# Load Weights
# List what the module exposes, then overwrite its projections with the
# shared tensors so the result is comparable with the other implementations.
csa_k1_state = csa_k1.state_dict()
for name, tensor in csa_k1_state.items():
    print(name, tensor.shape)
print('--- dst ---')
csa_k1_state.update({
    'c_attn.weight': c_attn_W.clone(),
    'c_attn.bias': c_attn_b.clone(),
    'c_proj.weight': c_proj_W.clone(),
    'c_proj.bias': c_proj_b.clone(),
})
csa_k1.load_state_dict(csa_k1_state)

# Run
y_k1 = csa_k1(x)
print(y_k1.shape)
print(y_k1.sum().item())  # scalar fingerprint used to compare implementations

bias torch.Size([1, 1, 1024, 1024])
c_attn.weight torch.Size([2304, 768])
c_attn.bias torch.Size([2304])
c_proj.weight torch.Size([768, 768])
c_proj.bias torch.Size([768])
--- dst ---
torch.Size([4, 1024, 768])
-18783.59375


In [5]:
# Init
# PyTorch's built-in multi-head attention; every option except batch_first is
# spelled out at the value the original cell passed.
csa_pt = nn.MultiheadAttention(
    n_embd,
    n_head,
    dropout=0.0,
    bias=True,
    add_bias_kv=False,
    add_zero_attn=False,
    kdim=None,
    vdim=None,
    batch_first=True,  # ! inputs are (batch, seq, feature)
    device=None,
    dtype=None,
)

# Copy Weights
print('--- src ---')
csa_pt_state = csa_pt.state_dict()
for name, tensor in csa_pt_state.items():
    print(name, tensor.shape)
csa_pt_state.update({
    'in_proj_weight': c_attn_W.clone(),
    'in_proj_bias': c_attn_b.clone(),
    'out_proj.weight': c_proj_W.clone(),
    'out_proj.bias': c_proj_b.clone(),
})
csa_pt.load_state_dict(csa_pt_state)

# Run
attn_mask = torch.nn.Transformer.generate_square_subsequent_mask(block_size)
# The call returns (output, attention_weights); only the output is kept.
y_pt = csa_pt(query=x, key=x, value=x, attn_mask=attn_mask, is_causal=True)[0]
print(y_pt.shape)
print(y_pt.sum().item())

--- src ---
in_proj_weight torch.Size([2304, 768])
in_proj_bias torch.Size([2304])
out_proj.weight torch.Size([768, 768])
out_proj.bias torch.Size([768])
torch.Size([4, 1024, 768])
-18783.6015625


In [None]:
# My implementation - loopy (one module per head)

class Head(nn.Module):
    """One causal self-attention head with explicit weight tensors.

    The query/key/value weights and biases are registered as buffers rather
    than stored as plain attributes, so they follow module-wide conversions
    (`.to()`, `.double()`, ...) and appear in `state_dict()`. Buffers, unlike
    `nn.Parameter`s, still accept direct tensor assignment such as
    `head.query_W = w.clone()`. They are not returned by `parameters()`.

    Args:
        block_size: side length of the lower-triangular causal mask, i.e. the
            longest sequence `forward` can handle.
        n_embd: size of the last dimension of the input.
        head_size: size of the last dimension of the output.
    """
    def __init__(self, block_size, n_embd, head_size):
        super().__init__()

        # Creation order is unchanged, so a seeded RNG yields the same values.
        self.register_buffer('query_W', torch.randn(head_size, n_embd))
        self.register_buffer('query_b', torch.randn(head_size))
        self.register_buffer('key_W', torch.randn(head_size, n_embd))
        self.register_buffer('key_b', torch.randn(head_size))
        self.register_buffer('value_W', torch.randn(head_size, n_embd))
        self.register_buffer('value_b', torch.randn(head_size))
        self.register_buffer('tril', torch.tril(torch.ones((block_size, block_size))))

    def forward(self, x):
        """Map `x` of shape (B, T, n_embd) to the head output of shape (B, T, head_size)."""
        _, T, _ = x.shape  # the unpacking also enforces a rank-3 input

        x_query = x @ self.query_W.T + self.query_b  # B,T,H
        x_key = x @ self.key_W.T + self.key_b  # B,T,H
        x_value = x @ self.value_W.T + self.value_b  # B,T,H

        H = x_key.shape[-1]
        W_affin = x_query @ x_key.mT / H**0.5  # B,T,T <- B,T,H @ B,H,T
        # A position may only attend to itself and to earlier positions.
        W_affin = W_affin.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        W_affin = torch.softmax(W_affin, dim=-1)

        return W_affin @ x_value  # B,T,H <- B,T,T @ B,T,H
    
class MultiHead(nn.Module):
    """Multiple self-attention heads evaluated one after another in a Python loop.

    Each head maps the input to `n_embd // n_head` features; the head outputs
    are concatenated along the last dimension and passed through an output
    projection `x @ proj_W.T + proj_b`.

    The projection tensors are registered as buffers rather than stored as
    plain attributes, so they follow module-wide conversions (`.to()`,
    `.double()`, ...) and appear in `state_dict()`, while still accepting
    direct tensor assignment (`m.proj_W = w.clone()`).

    Args:
        block_size: forwarded to every head.
        n_head: number of heads; must divide `n_embd`.
        n_embd: input and output feature size.
    """
    def __init__(self, block_size, n_head, n_embd):
        super().__init__()

        # Same guard as the batched implementations in this file: otherwise the
        # concatenated head outputs would not have n_embd features.
        assert n_embd % n_head == 0
        head_size = n_embd // n_head
        self.heads = nn.ModuleList(
            [Head(block_size, n_embd, head_size) for _ in range(n_head)]
        )

        self.register_buffer('proj_W', torch.randn(n_embd, n_embd))
        self.register_buffer('proj_b', torch.randn(n_embd))

    def forward(self, x):
        """Run every head on `x`, concatenate along the last dim, then project."""
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = x @ self.proj_W.T + self.proj_b
        return x


In [7]:
csa_m1 = MultiHead(block_size=block_size, n_head=n_head, n_embd=n_embd)

# The fused projection is stacked along dim 0 as [query; key; value],
# each part n_embd rows long.
q_W, k_W, v_W = c_attn_W.split(n_embd, dim=0)
q_b, k_b, v_b = c_attn_b.split(n_embd, dim=0)

# Cut every part into per-head slices of n_embd // n_head rows.
head_size = n_embd // n_head
q_W_multi, k_W_multi, v_W_multi = (t.split(head_size, dim=0) for t in (q_W, k_W, v_W))
q_b_multi, k_b_multi, v_b_multi = (t.split(head_size, dim=0) for t in (q_b, k_b, v_b))

# Hand head i its own copy of the i-th slice, then install the output projection.
for i, h in enumerate(csa_m1.heads):
    h.query_W, h.query_b = q_W_multi[i].clone(), q_b_multi[i].clone()
    h.key_W, h.key_b = k_W_multi[i].clone(), k_b_multi[i].clone()
    h.value_W, h.value_b = v_W_multi[i].clone(), v_b_multi[i].clone()
csa_m1.proj_W, csa_m1.proj_b = c_proj_W.clone(), c_proj_b.clone()

y_m1 = csa_m1(x)
print(y_m1.shape)
print(y_m1.sum().item())

torch.Size([4, 1024, 768])
-18783.595703125


In [None]:
# My implementation - efficient BLAS

class CausalSelfAttentionMarcin2(nn.Module):
    """Causal multi-head self-attention with all heads computed in batched matmuls.

    Args:
        block_size: side length of the causal mask, i.e. the longest sequence
            `forward` can handle.
        n_head: number of heads; must divide `n_embd`.
        n_embd: input and output feature size.
    """
    def __init__(self, block_size, n_head, n_embd):
        super().__init__()

        assert n_embd % n_head == 0
        self.n_head = n_head

        self.c_attn = nn.Linear(n_embd, 3*n_embd)  # fused q/k/v projection
        self.c_proj = nn.Linear(n_embd, n_embd)    # output projection
        # Lower-triangular causal mask, shaped to broadcast over batch and heads.
        self.register_buffer('bias', torch.tril(torch.ones((1, 1, block_size, block_size))))

    def forward(self, x):
        """Apply causal self-attention to `x` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()
        hs = C // self.n_head
        qkv = self.c_attn(x)
        # Split by the input's own channel size C instead of a notebook-level
        # `n_embd` global, so the module works for any configuration and
        # outside the notebook.
        q, k, v = qkv.split(C, dim=2)  # B, T, nh*hs
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # B,nh,T,hs
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)  # B,nh,T,hs
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)  # B,nh,T,hs

        W_affin = q @ k.mT / hs**0.5  # B,nh,T,hs @ B,nh,hs,T -> B,nh,T,T
        # A position may only attend to itself and to earlier positions.
        W_affin = W_affin.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        W_affin = torch.softmax(W_affin, dim=-1)  # B,nh,T,T
        raw = W_affin @ v    # B,nh,T,T @ B,nh,T,hs -> B,nh,T,hs

        # The transpose leaves a non-contiguous tensor; view() needs a contiguous one.
        y = raw.transpose(1, 2).contiguous().view(B, T, C)  # B,T,nh*hs

        return self.c_proj(y)

In [12]:
csa_m2 = CausalSelfAttentionMarcin2(block_size=block_size, n_head=n_head, n_embd=n_embd)

# Overwrite the randomly initialised projections with the shared tensors so
# the output is comparable with the other implementations.
csa_m2_state = csa_m2.state_dict()
csa_m2_state.update({
    'c_attn.weight': c_attn_W.clone(),
    'c_attn.bias': c_attn_b.clone(),
    'c_proj.weight': c_proj_W.clone(),
    'c_proj.bias': c_proj_b.clone(),
})
csa_m2.load_state_dict(csa_m2_state)

y_m2 = csa_m2(x)
print(y_m2.shape)
print(y_m2.sum().item())  # scalar fingerprint used to compare implementations

torch.Size([4, 1024, 768])
-18783.59375


In [None]:
# My implementation - fused fast attention

class CausalSelfAttentionMarcin2(nn.Module):
    """Causal multi-head self-attention using the fused attention kernel.

    The scaled scores, causal masking, softmax and value product are delegated
    to `F.scaled_dot_product_attention(..., is_causal=True)`, so the
    (B, nh, T, T) weight matrix is not built explicitly by this module.

    Args:
        block_size: side length of the registered causal mask buffer.
        n_head: number of heads; must divide `n_embd`.
        n_embd: input and output feature size.
    """
    def __init__(self, block_size, n_head, n_embd):
        super().__init__()

        assert n_embd % n_head == 0
        self.n_head = n_head

        self.c_attn = nn.Linear(n_embd, 3*n_embd)  # fused q/k/v projection
        self.c_proj = nn.Linear(n_embd, n_embd)    # output projection
        # Not read by forward() (the fused call masks via is_causal); kept so
        # the state_dict keys stay the same as before this change.
        self.register_buffer('bias', torch.tril(torch.ones((1, 1, block_size, block_size))))

    def forward(self, x):
        """Apply causal self-attention to `x` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()
        hs = C // self.n_head
        qkv = self.c_attn(x)
        # Split by the input's own channel size C instead of a notebook-level
        # `n_embd` global, so the module works for any configuration.
        q, k, v = qkv.split(C, dim=2)  # B, T, nh*hs
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # B,nh,T,hs
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)  # B,nh,T,hs
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)  # B,nh,T,hs

        # Fused softmax(q @ k^T / sqrt(hs) + causal mask) @ v.
        raw = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # B,nh,T,hs

        # The transpose leaves a non-contiguous tensor; view() needs a contiguous one.
        y = raw.transpose(1, 2).contiguous().view(B, T, C)  # B,T,nh*hs

        return self.c_proj(y)

In [11]:
# Appendix - Check linear layer equivalence
# Compare nn.Linear against the hand-written `x @ W.T + b` on the same weights.
lin = nn.Linear(n_embd, n_embd)
sd = lin.state_dict()
for name, tensor in sd.items():
    print(name, tensor.shape)
sd.update(weight=c_proj_W.clone(), bias=c_proj_b.clone())
lin.load_state_dict(sd)

y = lin(x)
y2 = x @ c_proj_W.T + c_proj_b
print(y[:3, :3, :3])
print(y2[:3, :3, :3])
# allclose runs with its default tolerances (rtol=1e-05, atol=1e-08). The
# recorded run printed False although the largest absolute difference was
# only about 1e-6, so the max-abs-diff line is the more informative check.
print(torch.allclose(y, y2))
print((y - y2).abs().max())

weight torch.Size([768, 768])
bias torch.Size([768])
tensor([[[ 3.9189, -0.4557, -0.6260],
         [-0.3068, -1.9111, -0.4942],
         [ 1.1698, -2.4789,  0.7141]],

        [[ 2.0960, -2.3606, -1.1278],
         [ 1.2358, -2.2632, -0.9145],
         [ 1.7210, -1.1523,  0.2450]],

        [[ 1.8260, -3.1186, -0.2035],
         [ 1.1445, -0.6755,  0.3767],
         [ 1.1435, -1.0695, -1.0294]]], grad_fn=<SliceBackward0>)
tensor([[[ 3.9189, -0.4557, -0.6260],
         [-0.3068, -1.9111, -0.4942],
         [ 1.1698, -2.4789,  0.7141]],

        [[ 2.0960, -2.3606, -1.1278],
         [ 1.2358, -2.2632, -0.9145],
         [ 1.7210, -1.1523,  0.2450]],

        [[ 1.8260, -3.1186, -0.2035],
         [ 1.1445, -0.6755,  0.3767],
         [ 1.1435, -1.0695, -1.0294]]])
False
tensor(9.5367e-07, grad_fn=<MaxBackward1>)
