In [1]:
import numpy as np

In [2]:
def softmax(x, axis=-1):
    """Numerically stable softmax along one axis.

    Parameters
    ----------
    x : array_like
        Input scores. Entries of ``-inf`` (e.g. from an additive attention
        mask) receive probability 0.
    axis : int, optional
        Axis along which the probabilities are normalized. Defaults to the
        last axis.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose entries sum to 1 along
        ``axis``.

    Notes
    -----
    A slice that consists only of ``-inf`` has no finite maximum to shift
    by, so that slice comes out as ``nan``.
    """
    x = np.asarray(x)
    # Subtracting the per-slice maximum does not change the result
    # mathematically, but it keeps np.exp from overflowing on large scores.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    x_exp = np.exp(shifted)
    return x_exp / np.sum(x_exp, axis=axis, keepdims=True)

In [3]:
# numpy has the same broadcasting rules as PyTorch when there is no ambiguity
B, C, E = 11, 123, 30000 # batch, context length, embedding dim
SEED = 0  # fixed seed so that Restart & Run All reproduces the same numbers
rng = np.random.default_rng(SEED)
# q carries a batch axis while k and v do not; the matrix products in the
# following cells broadcast the same keys/values over every batch element.
q = rng.normal(loc=0, scale=1, size=(B, C, E))
k = rng.normal(loc=0, scale=1, size=(C, E))
v = rng.normal(loc=0, scale=1, size=(C, E))

In [4]:
# q and k hold independent N(μ=0, σ^2=1) entries, so every dot product is a sum of E
# independent zero-mean, unit-variance terms: its variance is E (typical size ~ sqrt(E)).
logits = np.matmul(q, k.T)

# Left unscaled, logits of that size would push softmax towards one-hot vectors.
# Dividing by sqrt(E) brings the variance of the logits back to roughly 1.
logits_regd = np.divide(logits, np.sqrt(E))
print("variance", np.round(np.var(logits_regd, axis=-1), 2))

variance [[1.06 1.   1.13 ... 1.01 0.99 1.05]
 [0.96 1.01 1.04 ... 0.85 0.94 1.1 ]
 [0.98 0.87 0.98 ... 0.9  1.12 1.01]
 ...
 [1.01 1.08 0.9  ... 0.84 0.95 1.15]
 [0.99 1.01 0.84 ... 0.9  0.99 0.87]
 [1.12 1.   1.12 ... 0.89 1.17 1.06]]


In [5]:
# Classical self-attention
# Feed softmax the sqrt(E)-scaled logits: the raw logits have a variance of
# about E, which would make every row of weights collapse to a near one-hot vector.
weights = softmax(logits_regd)
print(logits.shape, weights.shape, weights.sum(-1))
# (B, C, C) weights @ (C, E) values -> (B, C, E); v is broadcast over the batch axis.
attention = weights @ v
print(attention.shape)

(11, 123, 123) (11, 123, 123) [[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]]
(11, 123, 30000)


In [6]:
# Causal self-attention
# Additive mask: -inf strictly above the diagonal, 0 on and below it.
upper = np.triu(np.ones((C, C), dtype=bool), k=1)
mask = np.where(upper, -np.inf, 0)

# Walk through batch element 0; its logits are rounded to 2 decimals first.
logit_0 = logits_regd[0].round(2)
masked_0 = logit_0 + mask
probs_0 = softmax(masked_0)
print("idx 0. mask\n", mask)
print("idx 0. logit\n", logit_0)
print("idx 0. masked\n", masked_0)
print("idx 0. softmax\n", probs_0)
print("idx 0. probs check\n", probs_0.sum(-1))

# The (C, C) mask broadcasts over the batch axis of the (B, C, C) logits.
causal_self_attention = softmax(logits_regd + mask) @ v
print("Full output shape", causal_self_attention.shape)

idx 0. mask
 [[  0. -inf -inf ... -inf -inf -inf]
 [  0.   0. -inf ... -inf -inf -inf]
 [  0.   0.   0. ... -inf -inf -inf]
 ...
 [  0.   0.   0. ...   0. -inf -inf]
 [  0.   0.   0. ...   0.   0. -inf]
 [  0.   0.   0. ...   0.   0.   0.]]
idx 0. logit
 [[ 0.5   0.82 -0.27 ...  0.8  -0.73 -0.12]
 [ 0.92  0.06  0.73 ... -0.24  0.03 -0.32]
 [ 0.54  0.66  0.68 ... -0.44 -1.71 -1.65]
 ...
 [-1.4  -0.95  0.07 ... -0.41 -0.42 -0.93]
 [ 2.1  -0.56  0.62 ... -1.16 -0.78 -0.3 ]
 [ 0.59  1.14  1.33 ... -1.5  -0.43  0.35]]
idx 0. masked
 [[ 0.5   -inf  -inf ...  -inf  -inf  -inf]
 [ 0.92  0.06  -inf ...  -inf  -inf  -inf]
 [ 0.54  0.66  0.68 ...  -inf  -inf  -inf]
 ...
 [-1.4  -0.95  0.07 ... -0.41  -inf  -inf]
 [ 2.1  -0.56  0.62 ... -1.16 -0.78  -inf]
 [ 0.59  1.14  1.33 ... -1.5  -0.43  0.35]]
idx 0. softmax
 [[1.         0.         0.         ... 0.         0.         0.        ]
 [0.70266065 0.29733935 0.         ... 0.         0.         0.        ]
 [0.30508541 0.34398284 0.35093175 ... 0