In [63]:
import numpy as np
import torch
import torch.nn as nn

# Attention
Let's begin by implementing scaled dot-product attention:
$$ y = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d}}\right) V $$

The main idea is to implement attention as a module/layer on its own. It takes $X$ as input and applies attention to it, using the following conventional form:
$$ y = \mathrm{softmax}\left(\frac{(X W_{Q})(X W_{K})^{T}}{\sqrt{d}}\right) X W_{V} W_{out} $$

where $W_{out}$ is just a linear projection layer that maps the new representations obtained via the attention mechanism into another space. It could be folded into $W_{V}$, but it is better not to, and to keep the two tasks separate (in particular in the multi-head attention case).
 

In [64]:
def softmax(Z, axis=-1):
    """Numerically stable softmax along ``axis``.

    Parameters
    ----------
    Z : array_like
        Input scores. Converted with ``np.asarray``, so nested lists are
        accepted as well as arrays; the caller's object is never modified.
    axis : int, default -1
        Axis along which the values are normalised.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``Z`` whose entries along ``axis`` are
        non-negative and sum to 1.

    Notes
    -----
    ``-inf`` entries map to a weight of exactly 0. A slice made up entirely
    of ``-inf`` produces NaN, because ``-inf - (-inf)`` is undefined.
    """
    Z = np.asarray(Z)
    # Shifting by the max along the axis does not change the result
    # mathematically, but it keeps np.exp from overflowing on large scores.
    exp_Z = np.exp(Z - Z.max(axis=axis, keepdims=True))
    return exp_Z / exp_Z.sum(axis=axis, keepdims=True)

In [65]:
def self_attention(X, mask, W_KQV, W_out):
    """Single-head scaled dot-product self-attention for one sequence.

    Parameters
    ----------
    X : np.ndarray, shape (n, d_in)
        Input sequence, one row per position.
    mask : np.ndarray broadcastable to (n, n), or None
        Additive mask applied to the scores before the softmax: ``-inf``
        blocks a (row, column) pair and 0 leaves it untouched. For example, a
        causal mask with ``-inf`` strictly above the diagonal stops a position
        from attending to later positions. ``None`` means no masking.
    W_KQV : np.ndarray, shape (d_in, 3 * d)
        Fused input projection. Its three equal column blocks are used, in
        order, as the query, key and value projections: rows of the returned
        attention matrix are indexed by the first block and columns by the
        second. The parameter name is kept for backward compatibility.
    W_out : np.ndarray, shape (d, d_out)
        Output projection applied after the attention-weighted sum.

    Returns
    -------
    tuple of np.ndarray
        ``(Y, attn)``: the projected output of shape (n, d_out) and the
        attention weights of shape (n, n), whose rows sum to 1.
    """
    # One large matmul followed by a split is cheaper than three separate ones.
    Q, K, V = np.split(X @ W_KQV, 3, axis=1)
    # Scale by the projected query/key dimension d, not the input width; the
    # two only coincide when the projection is square.
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    if mask is not None:
        scores = scores + mask
    attn = softmax(scores)
    return attn @ V @ W_out, attn

In [66]:
# Reference only: constructor signature printed by `nn.MultiheadAttention?`
# Init signature:
# nn.MultiheadAttention(
#     embed_dim,
#     num_heads,
#     dropout=0.0,
#     bias=True,
#     add_bias_kv=False,
#     add_zero_attn=False,
#     kdim=None,
#     vdim=None,
#     batch_first=False,
#     device=None,
#     dtype=None,
# ) -> None

In [67]:
# Seed the global RNG so the randomly initialised layer weights and the random
# input are identical on every Restart & Run All.
torch.manual_seed(0)

NELEMS, FEATURES = 100, 64
# Reference implementation to compare against: one head, no biases, and
# inputs laid out as (batch, sequence, feature).
pyt_attn = nn.MultiheadAttention(embed_dim=FEATURES, num_heads=1, bias=False, batch_first=True)
# Additive causal mask: -inf strictly above the diagonal, 0 on and below it,
# so position i can only attend to positions j <= i.
M = torch.triu(-float("inf")* torch.ones(NELEMS, NELEMS), diagonal=1)

X = torch.randn(1, NELEMS, FEATURES)  # a batch containing a single sequence
# Call signature is multihead_attn(query, key, value); passing X three times
# makes it self-attention.
Y_true, attn_true = pyt_attn(X, X, X, attn_mask=M)
# attn_mask: if specified, a 2D or 3D mask preventing attention to certain positions

In [68]:
# Run the NumPy implementation on the single sequence, reusing the PyTorch
# layer's weights. Both weight matrices are transposed because the NumPy code
# right-multiplies (X @ W).
Y, attn = self_attention(
    X=X[0].numpy(),
    mask=M.numpy(),
    W_KQV=pyt_attn.in_proj_weight.detach().numpy().T,
    W_out=pyt_attn.out_proj.weight.detach().numpy().T,
)

In [69]:
# Display the attention weights returned by self_attention; with the causal
# mask M, every entry above the diagonal should be exactly 0.
attn

array([[1.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.74691534, 0.2530846 , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.23462467, 0.17931736, 0.58605796, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.01249573, 0.02230426, 0.00682052, ..., 0.00766308, 0.        ,
        0.        ],
       [0.01006045, 0.00422531, 0.00532255, ..., 0.02608492, 0.01317528,
        0.        ],
       [0.00710144, 0.01114072, 0.00821384, ..., 0.00815385, 0.03464449,
        0.00397532]], dtype=float32)

In [70]:
# Sanity check: 2-norm of the (flattened) difference between our attention
# weights and PyTorch's. A value near float32 round-off means they agree.
np.linalg.norm(attn - attn_true.detach().numpy())

1.6083706e-07

In [71]:
# Same check for the layer output: 2-norm of the (flattened) difference
# between our Y and PyTorch's Y_true.
np.linalg.norm(Y - Y_true.detach().numpy())

1.1628191e-06

# Minibatching

In [72]:
def self_attention(X, mask, W_KQV, W_out):
    """Single-head scaled dot-product self-attention with optional batching.

    All operations act on the last two axes, so any leading axes of ``X`` are
    treated as batch dimensions.

    Parameters
    ----------
    X : np.ndarray, shape (..., n, d_in)
        Input sequences, one row per position.
    mask : np.ndarray broadcastable to (..., n, n), or None
        Additive mask applied to the scores before the softmax: ``-inf``
        blocks a (row, column) pair and 0 leaves it untouched. For example, a
        causal mask with ``-inf`` strictly above the diagonal stops a position
        from attending to later positions. ``None`` means no masking.
    W_KQV : np.ndarray, shape (d_in, 3 * d)
        Fused input projection. Its three equal column blocks are used, in
        order, as the query, key and value projections: rows of the returned
        attention matrix are indexed by the first block and columns by the
        second. The parameter name is kept for backward compatibility.
    W_out : np.ndarray, shape (d, d_out)
        Output projection applied after the attention-weighted sum.

    Returns
    -------
    tuple of np.ndarray
        ``(Y, attn)``: the projected output of shape (..., n, d_out) and the
        attention weights of shape (..., n, n), normalised over the last axis.
    """
    # One large matmul followed by a split is cheaper than three separate ones.
    Q, K, V = np.split(X @ W_KQV, 3, axis=-1)
    # Scale by the projected query/key dimension d, not the input width; the
    # two only coincide when the projection is square.
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(Q.shape[-1])
    if mask is not None:
        scores = scores + mask
    attn = softmax(scores)
    return attn @ V @ W_out, attn

In [73]:
# Seed the global RNG so the randomly initialised layer weights and the random
# input batch are identical on every Restart & Run All.
torch.manual_seed(0)

BATCH, NELEMS, FEATURES = 50, 100, 64
# Reference implementation to compare against: one head, no biases, and
# inputs laid out as (batch, sequence, feature).
pyt_attn = nn.MultiheadAttention(embed_dim=FEATURES, num_heads=1, bias=False, batch_first=True)
# Additive causal mask: -inf strictly above the diagonal, 0 on and below it.
# The same 2D mask is applied to every sequence in the batch.
M = torch.triu(-float("inf")* torch.ones(NELEMS, NELEMS), diagonal=1)

X = torch.randn(BATCH, NELEMS, FEATURES)
# Call signature is multihead_attn(query, key, value); passing X three times
# makes it self-attention.
Y_true, attn_true = pyt_attn(X, X, X, attn_mask=M)
# attn_mask: if specified, a 2D or 3D mask preventing attention to certain positions

In [74]:
# Run the NumPy implementation on the whole batch at once, reusing the PyTorch
# layer's weights. Both weight matrices are transposed because the NumPy code
# right-multiplies (X @ W).
Y, attn = self_attention(
    X=X.numpy(),
    mask=M.numpy(),
    W_KQV=pyt_attn.in_proj_weight.detach().numpy().T,
    W_out=pyt_attn.out_proj.weight.detach().numpy().T,
)

In [75]:
# Sanity check over the whole batch: 2-norm of the flattened difference
# between our attention weights and PyTorch's.
np.linalg.norm(attn - attn_true.detach().numpy())

1.5335097e-06

In [76]:
# Same check for the batched layer output: 2-norm of the flattened difference
# between our Y and PyTorch's Y_true.
np.linalg.norm(Y - Y_true.detach().numpy())

1.0895182e-05

# Multihead Attention

In [None]:
# In single-head attention, each element (i, j) of the Q @ K.T matrix is the dot product between q_i and k_j.