In [1]:
import numpy as np
import torch
import torch.nn as nn

In [2]:
def softmax(Z):
    """Numerically stable softmax taken over the last axis of ``Z``.

    The maximum of each last-axis slice is subtracted before exponentiating,
    so the largest exponent is exp(0) = 1 and large inputs do not overflow.
    Entries equal to -inf (e.g. masked attention scores) come out as exactly 0.

    Args:
        Z: NumPy array of scores; any number of leading dimensions.

    Returns:
        Array with the same shape as ``Z`` whose last-axis slices sum to 1.
    """
    shifted = Z - Z.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    normalizer = exps.sum(axis=-1, keepdims=True)
    return exps / normalizer

def self_attention(X, mask, W_KQV, W_out):
    """Single-head scaled dot-product self-attention in NumPy.

    Args:
        X: input array whose last two axes are (T, d_in); extra leading axes
            are treated as batch dimensions by the matrix products.
        mask: additive mask broadcastable to the (T, T) score matrix
            (0 to keep a position, -inf to block it).
        W_KQV: packed projection of shape (d_in, 3 * d_k). ``X @ W_KQV`` is
            split into three equal chunks along the last axis, named K, Q, V
            in that order.
        W_out: output projection applied on the right of the attended values.

    Returns:
        Tuple ``(output, attn)``: the projected attention output and the
        (T, T) attention weights (with the same leading batch axes as ``X``).

    Note:
        The K/Q names follow this file's own convention; only the order of the
        chunks in ``W_KQV`` matters. NOTE(review): with PyTorch's packed
        ``in_proj_weight`` the first chunk is PyTorch's *query* projection, so
        ``K @ Q^T`` here corresponds to PyTorch's ``q @ k^T`` -- confirm if the
        naming is ever relied on.
    """
    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
    # Scale by the key/query width rather than the input width: the two are
    # equal for a square (d, 3d) projection, but only the former is the
    # correct sqrt(d_k) factor when the projection changes the feature size.
    scores = K @ Q.swapaxes(-1, -2) / np.sqrt(K.shape[-1])
    attn = softmax(scores + mask)
    return attn @ V @ W_out, attn

In [3]:
# Small causal-mask example: -inf strictly above the diagonal and 0 elsewhere,
# so adding M to a (T, T) score matrix blocks the strictly-upper-triangular entries.
T = 5
M = torch.triu(-float("inf")*torch.ones(T,T),1)

In [4]:
# PyTorch reference: single-head, bias-free attention on one sequence with a causal mask.
T, d = 100, 64
attn = nn.MultiheadAttention(d, 1, bias=False, batch_first=True)
M = torch.triu(-float("inf")*torch.ones(T,T),1)
X = torch.randn(1,T,d)  # (batch=1, T, d), matching batch_first=True
Y_, A_ = attn(X,X,X, attn_mask=M)  # Y_: attention output, A_: attention weights

In [5]:
# Run the NumPy implementation on the single sequence X[0] with PyTorch's weights.
# PyTorch stores projection weights as (out_features, in_features); they are
# transposed here so they can right-multiply the input (X @ W).
Y, A = self_attention(X[0].numpy(), M.numpy(), 
                      attn.in_proj_weight.detach().numpy().T,
                      attn.out_proj.weight.detach().numpy().T)

In [6]:
# Sanity check: one d-dimensional output row per time step, i.e. (T, d).
Y.shape

(100, 64)

In [7]:
# Norm of the difference between the NumPy and PyTorch results; values near
# float32 round-off indicate the two implementations agree.
print(np.linalg.norm(A - A_[0].detach().numpy()))  # attention weights
print(np.linalg.norm(Y - Y_[0].detach().numpy()))  # attention output

2.4646016e-07
1.4785239e-06


# Mini-batching

In [8]:
# Batched matmul demo: `@` multiplies the last two axes, (10, 3) @ (3, 6),
# and carries the leading (5, 4) axes through as batch dimensions.
C = np.random.randn(5, 4, 10, 3)
D = np.random.randn(5, 4, 3, 6)
(C@D).shape

(5, 4, 10, 6)

In [9]:
# Batched PyTorch reference: B sequences through the single-head `attn`
# module created in an earlier cell (it is not re-created here).
B, T, d = 50, 100, 64
X = torch.randn(B, T, d)
M = torch.triu(-float('inf')*torch.ones(T, T), 1)  # same causal mask, shared by every sequence
Y_, A_ = attn(X, X, X, attn_mask=M)

In [10]:
# The same self_attention function is now given the full 3-D batch: its matmuls
# and softmax act on the trailing axes and the (T, T) mask broadcasts over the batch.
Y, A = self_attention(X.numpy(), M.numpy(), 
                      attn.in_proj_weight.detach().numpy().T,
                      attn.out_proj.weight.detach().numpy().T)

In [11]:
# Difference in attention weights over the whole batch (NumPy vs PyTorch).
print(np.linalg.norm(A - A_.detach().numpy()))

1.8331532e-06


# Multihead attention

In [12]:
def multihead_attention(X, mask, heads, W_KQV, W_out):
    """Multi-head scaled dot-product self-attention in NumPy.

    Args:
        X: input array of shape (..., T, d); any number of leading batch axes
            (including none) is accepted.
        mask: additive mask broadcastable to the per-head (T, T) score
            matrices (0 to keep a position, -inf to block it).
        heads: number of attention heads; must divide ``d`` evenly.
        W_KQV: packed projection; ``X @ W_KQV`` is split into three equal
            chunks along the last axis, named K, Q, V in that order. Each
            chunk is expected to have width ``d``.
        W_out: output projection applied to the re-concatenated heads.

    Returns:
        Tuple ``(output, attn)``: the output of shape (..., T, d) after
        ``W_out`` and the per-head attention weights of shape
        (..., heads, T, T).

    Raises:
        ValueError: if ``d`` is not divisible by ``heads``.
    """
    *batch, T, d = X.shape
    if d % heads != 0:
        raise ValueError(
            f"feature dimension {d} is not divisible by the number of heads {heads}"
        )
    head_dim = d // heads

    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
    # (..., T, d) -> (..., heads, T, head_dim): cut the feature axis into heads,
    # then move the head axis in front of the time axis so `@` treats it as batch.
    K, Q, V = [
        a.reshape(*batch, T, heads, head_dim).swapaxes(-3, -2) for a in (K, Q, V)
    ]
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(head_dim) + mask)
    # Undo the head split: (..., heads, T, head_dim) -> (..., T, d).
    merged = (attn @ V).swapaxes(-3, -2).reshape(*batch, T, d)
    return merged @ W_out, attn

In [13]:
# Number of attention heads; with d = 64 each head has width 64 // 4 = 16.
heads = 4

In [14]:
# Replace `attn` with a multi-head PyTorch module and compute the reference output.
# NOTE(review): by default PyTorch returns attention weights averaged over heads,
# so A_ is not directly comparable to the per-head weights of the NumPy version -- confirm.
attn = nn.MultiheadAttention(d, heads, bias=False, batch_first=True)
Y_, A_ = attn(X, X, X, attn_mask = M)

In [15]:
# NumPy multi-head attention using the PyTorch module's weights, transposed so
# they right-multiply the input.
Y, A = multihead_attention(X.numpy(), M.numpy(), heads,
                      attn.in_proj_weight.detach().numpy().T,
                      attn.out_proj.weight.detach().numpy().T)

In [16]:
# Difference between the NumPy and PyTorch multi-head outputs over the whole batch.
np.linalg.norm(Y - Y_.detach().numpy())

1.1083209e-05

# Transformer Block

In [29]:
def layer_norm(Z, eps):
    """Normalize ``Z`` along its last axis to zero mean and unit variance.

    No learnable scale or shift is applied. The variance is NumPy's default
    (population, ``ddof=0``) variance.

    Args:
        Z: NumPy array; statistics are taken over the last axis.
        eps: constant added to the variance before the square root, for
            numerical stability.

    Returns:
        Array with the same shape as ``Z``.
    """
    mu = Z.mean(axis=-1, keepdims=True)
    sigma2 = Z.var(axis=-1, keepdims=True)
    return (Z - mu) / np.sqrt(sigma2 + eps)
   
def relu(Z):
    """Rectified linear unit: the elementwise maximum of ``Z`` and zero."""
    floor = 0
    return np.maximum(Z, floor)

def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):
    """One bias-free Transformer block with layer norm applied after each residual.

    The block is: multi-head self-attention + residual -> layer norm ->
    two-layer ReLU feed-forward + residual -> layer norm.

    Args:
        X: input array passed to ``multihead_attention``.
        mask: additive attention mask passed to ``multihead_attention``.
        heads: number of attention heads.
        W_KQV: packed attention projection weights.
        W_out: attention output projection weights.
        W_ff1: first feed-forward weight matrix (applied before the ReLU).
        W_ff2: second feed-forward weight matrix (applied after the ReLU).
        eps: stability constant used by both layer norms.

    Returns:
        The block output; the attention weights are discarded.
    """
    # Attention sub-layer; keep only the output, not the attention weights.
    attended = multihead_attention(X, mask, heads, W_KQV, W_out)[0]
    Z = layer_norm(attended + X, eps)

    # Position-wise feed-forward sub-layer.
    feed_forward = relu(Z @ W_ff1) @ W_ff2
    return layer_norm(Z + feed_forward, eps)

In [30]:
# PyTorch reference encoder layer; dropout is disabled so the forward pass is deterministic.
trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward=128, dropout=0.0, batch_first=True)
# Zero the feed-forward biases because the NumPy `transformer` has no bias terms.
# NOTE(review): attention biases and the LayerNorm scale/shift are left at their
# initial values -- presumably zero / identity at init; confirm.
trans.linear1.bias.data.zero_()
trans.linear2.bias.data.zero_();
Y_ = trans(X, M)  # M is passed positionally as the attention mask


In [31]:
# NumPy transformer block using the PyTorch layer's weights (transposed so they
# right-multiply the input) and the same eps as its first layer norm.
Y = transformer(X.numpy(), M.numpy(), heads,
                trans.self_attn.in_proj_weight.detach().numpy().T, 
                trans.self_attn.out_proj.weight.detach().numpy().T,
                trans.linear1.weight.detach().numpy().T,
                trans.linear2.weight.detach().numpy().T,
                trans.norm1.eps)

In [32]:
# Difference between the NumPy and PyTorch transformer-block outputs over the whole batch.
np.linalg.norm(Y - Y_.detach().numpy())

6.323718e-05