In [2]:
import numpy as np
import torch
import torch.nn as nn

In [12]:
def softmax(Z):
    """Softmax taken over the last axis of ``Z``.

    The maximum of every last-axis slice is subtracted before exponentiating,
    so the largest exponent is always 0. Entries equal to ``-inf`` therefore
    become exactly 0, provided their slice also holds a finite value.

    Returns an array with the same shape as ``Z`` whose last-axis slices
    each sum to 1.
    """
    row_peak = Z.max(axis=-1, keepdims=True)
    exps = np.exp(Z - row_peak)
    normaliser = exps.sum(axis=-1, keepdims=True)
    return exps / normaliser

def self_attention(X, mask, W_KQV, W_out):
    """Single-head self-attention for one unbatched (2-D) input, in NumPy.

    ``X @ W_KQV`` is cut into three equal slices along axis 1. The score
    matrix is the first slice times the transpose of the second, divided by
    ``sqrt(X.shape[1])``, with ``mask`` added on top; ``softmax`` then
    normalises every row of that matrix.

    Returns a pair ``(output, weights)`` where ``output`` is
    ``weights @ third_slice @ W_out`` and ``weights`` is the softmaxed
    score matrix.
    """
    projected = X @ W_KQV
    # Rows of the score matrix come from the first slice, columns from the second.
    keys, queries, values = np.split(projected, 3, axis=1)
    scale = np.sqrt(X.shape[1])
    scores = keys @ queries.T / scale + mask
    weights = softmax(scores)
    return weights @ values @ W_out, weights

In [4]:
# Reference setup: a one-head PyTorch attention layer without biases, run on a
# single random sequence of T positions with d features each.
T, d = 100, 64
attn = nn.MultiheadAttention(d, 1, bias=False, batch_first=True)
# Additive mask: -inf strictly above the main diagonal, 0 on and below it, so
# after softmax each row keeps weight only on columns up to its own index.
M = torch.triu(-float('inf') * torch.ones(T, T), 1)
X = torch.randn(1, T, d)  # leading axis of size 1 because batch_first=True
# Self-attention: the same tensor is passed as query, key and value.
# Y_ is the layer output and A_ the attention weights returned alongside it.
Y_, A_ = attn(X, X, X, attn_mask=M)

In [13]:
# Run the NumPy implementation on the same sequence and weights as the PyTorch
# layer. The weights are transposed because the NumPy code right-multiplies
# (X @ W).
Y, A = self_attention(
    X[0].numpy(),
    M.numpy(),
    attn.in_proj_weight.detach().numpy().T,
    attn.out_proj.weight.detach().numpy().T,
)

In [14]:
# Norm of the element-wise gap between the NumPy output and the PyTorch
# output; a tiny value means the two implementations agree numerically.
np.linalg.norm(Y - Y_.detach().numpy())

1.18326e-06

# Minibatch

In [15]:
# Shape check for batched matmul: `@` treats the leading axes of C as batch
# dimensions and multiplies every trailing (10, 3) matrix by the one (3, 6)
# matrix D.
C = np.random.randn(5, 4, 10, 3)
D = np.random.randn(3, 6)
(C @ D).shape

(5, 4, 10, 6)

In [22]:
def self_attention(X, mask, W_KQV, W_out):
    """Single-head self-attention that also works with leading batch axes.

    Same computation as the 2-D version, but every axis is addressed from
    the end: the projection is split along the last axis, the second slice
    is transposed by swapping its last two axes, and the scale is
    ``sqrt(X.shape[-1])``. ``mask`` is added to the scores by broadcasting.

    Returns ``(weights @ third_slice @ W_out, weights)``.
    """
    projected = X @ W_KQV
    keys, queries, values = np.split(projected, 3, axis=-1)
    scale = np.sqrt(X.shape[-1])
    # Swap only the last two axes so any leading batch axes stay in place.
    scores = keys @ queries.swapaxes(-1, -2) / scale + mask
    weights = softmax(scores)
    return weights @ values @ W_out, weights

In [17]:
# Batched reference: B random sequences pushed through the existing `attn`
# layer (not rebuilt here, so its weights are unchanged).
B, T, d = 50, 100, 64
X = torch.randn(B, T, d)
# Same additive mask as before: -inf strictly above the diagonal, 0 elsewhere.
M = torch.triu(-float('inf') * torch.ones(T, T), 1)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [23]:
# Feed the whole (B, T, d) batch to the NumPy implementation in one call; the
# (T, T) mask is broadcast across the batch axis.
Y, A = self_attention(
    X.numpy(),
    M.numpy(),
    attn.in_proj_weight.detach().numpy().T,
    attn.out_proj.weight.detach().numpy().T,
)

In [24]:
# Norm of the gap between the batched NumPy output and the PyTorch output,
# accumulated over every entry of the batch.
np.linalg.norm(Y - Y_.detach().numpy())

8.6435775e-06

# Multihead Attention

In [32]:
def multihead_attention(X, mask, heads, W_KQV, W_out):
    """Multi-head self-attention over a batch of sequences, in NumPy.

    Parameters
    ----------
    X : ndarray, shape (B, T, d)
        Batch of input sequences.
    mask : ndarray
        Added to the scaled scores before the softmax; must broadcast
        against shape (B, heads, T, T).
    heads : int
        Number of attention heads; must divide ``d`` evenly.
    W_KQV : ndarray
        Projection applied as ``X @ W_KQV``. The result is split into three
        equal slices along the last axis, and each slice must be ``d`` wide
        for the per-head reshape to succeed.
    W_out : ndarray
        Output projection applied to the re-merged heads.

    Returns
    -------
    (output, attn)
        ``output`` is the attention result multiplied by ``W_out``;
        ``attn`` holds the softmaxed weights with shape (B, heads, T, T).

    Raises
    ------
    ValueError
        If ``d`` is not divisible by ``heads``. Without this check the same
        exception type would surface later from ``reshape``, but with a
        message about array sizes rather than the real cause.
    """
    B, T, d = X.shape
    if d % heads != 0:
        raise ValueError(
            f"model dimension d={d} is not divisible by heads={heads}"
        )
    head_dim = d // heads

    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
    # (B, T, d) -> (B, T, heads, head_dim) -> (B, heads, T, head_dim):
    # give every head its own leading axis so matmul batches over it.
    K, Q, V = [a.reshape(B, T, heads, head_dim).swapaxes(1, 2) for a in (K, Q, V)]
    # Scores are scaled by the per-head width, not the full model width.
    attn = softmax((K @ Q.swapaxes(-1, -2)) / np.sqrt(head_dim) + mask)
    # Undo the head split: (B, heads, T, head_dim) -> (B, T, d), then project.
    return (attn @ V).swapaxes(1, 2).reshape(B, T, d) @ W_out, attn

In [33]:
# Rebuild the PyTorch layer with several heads (bias-free again), then compare
# it with the NumPy implementation on the current batch X and mask M.
heads = 4
attn = nn.MultiheadAttention(d, heads, bias=False, batch_first=True)
Y_, A_ = attn(X, X, X, attn_mask=M)
Y, A = multihead_attention(
    X.numpy(),
    M.numpy(),
    heads,
    attn.in_proj_weight.detach().numpy().T,
    attn.out_proj.weight.detach().numpy().T,
)
# Norm of the output gap between the two implementations.
np.linalg.norm(Y - Y_.detach().numpy())

8.918194e-06

# Transformer Block

In [34]:
def layer_norm(Z, eps):
    """Standardise ``Z`` along its last axis, with no learned scale or shift.

    Every last-axis slice has its mean subtracted and is then divided by
    ``sqrt(variance + eps)``, where the variance is NumPy's default
    population variance (``ddof=0``).
    """
    centered = Z - Z.mean(axis=-1, keepdims=True)
    spread = np.sqrt(Z.var(axis=-1, keepdims=True) + eps)
    return centered / spread


def relu(Z):
    """Element-wise rectifier: each entry of ``Z`` is replaced by max(entry, 0)."""
    floor = 0
    return np.maximum(Z, floor)


def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):
    """One transformer block in NumPy: attention, then a two-layer feed-forward.

    Each sub-layer is wrapped in a residual sum, and ``layer_norm`` is
    applied after each sum with the same ``eps``. No bias terms or learned
    normalisation parameters are involved; only the attention output is
    used, and the attention weights are discarded.

    Returns the normalised output of the feed-forward residual.
    """
    attended, _ = multihead_attention(X, mask, heads, W_KQV, W_out)
    Z = layer_norm(X + attended, eps)
    feed_forward = relu(Z @ W_ff1) @ W_ff2
    return layer_norm(Z + feed_forward, eps)

In [42]:
# PyTorch reference block: same width and head count as above, a 128-wide
# feed-forward layer, and dropout set to 0.
trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward=128, dropout=0.0, batch_first=True)
# The NumPy `transformer` takes no bias arguments, so zero the feed-forward
# biases to make the two computations comparable.
trans.linear1.bias.data.zero_()
trans.linear2.bias.data.zero_()
# NOTE(review): the attention projection biases and the LayerNorm scale/shift
# are left as constructed. The NumPy version ignores them, so the comparison
# presumably relies on PyTorch initialising them to zero / identity - confirm.
Y_ = trans(X, M)  # M is passed positionally as the layer's mask argument

In [43]:
# Run the NumPy block with the PyTorch layer's weights, transposed because the
# NumPy code right-multiplies (X @ W). The single eps read from norm1 is used
# for both normalisation steps.
Y = transformer(
    X.numpy(),
    M.numpy(),
    heads,
    trans.self_attn.in_proj_weight.detach().numpy().T,
    trans.self_attn.out_proj.weight.detach().numpy().T,
    trans.linear1.weight.detach().numpy().T,
    trans.linear2.weight.detach().numpy().T,
    trans.norm1.eps,
)

In [45]:
# Norm of the gap between the NumPy transformer block and the PyTorch layer
# over the whole batch.
np.linalg.norm(Y - Y_.detach().numpy())

5.888834e-05