In [7]:
import numpy as np

In [8]:
def softmax(x):
    x_exp = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return x_exp / np.sum(x_exp, axis=-1, keepdims=True)

In [9]:
# numpy has the same broadcasting rules as PyTorch when there is no ambiguity 
B, C, E = 2, 5, 3 # batch, context length, embedding dim
q = np.random.rand(B, C, E)
k = np.random.rand(C, E)
v = np.random.rand(C, E)

In [10]:
# Since q and k are independently random vectors i.e. N(μ=0, σ^2=1), the dot product grows proportionally to E..
logits = q @ k.T

# softmax would converge to one hot vectors if uncontrolled
# this is one way to regularize the values (sqrt(E) because we are summing E independent terms)
logits_regd = logits / np.sqrt(E)

In [11]:
# Classical self-attention
weights = softmax(logits)
print(logits.shape, weights.shape, weights.sum(-1))
attention = weights @ v
print(attention.shape)

(2, 5, 5) (2, 5, 5) [[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
(2, 5, 3)


In [12]:
# Causal self-attention
mask = np.where(np.triu(np.ones((C, C)), k=1) == 1, -np.inf, 0) # Fill upper triangle with -inf
print("idx 0. mask\n", mask)
print("idx 0. logit\n", logits_regd[0].round(2))
print("idx 0. masked\n", logits_regd[0].round(2) + mask)
print("idx 0. softmax\n", softmax(logits_regd[0].round(2) + mask))
print("idx 0. probs check\n", softmax(logits_regd[0].round(2) + mask).sum(-1))

causal_self_attention = softmax(logits_regd + mask) @ v
print("Full output shape", causal_self_attention.shape)

idx 0. mask
 [[  0. -inf -inf -inf -inf]
 [  0.   0. -inf -inf -inf]
 [  0.   0.   0. -inf -inf]
 [  0.   0.   0.   0. -inf]
 [  0.   0.   0.   0.   0.]]
idx 0. logit
 [[0.7  0.38 0.2  0.2  0.48]
 [0.36 0.18 0.12 0.6  0.44]
 [0.85 0.49 0.24 0.45 0.73]
 [0.6  0.32 0.18 0.38 0.49]
 [0.46 0.26 0.14 0.58 0.55]]
idx 0. masked
 [[0.7  -inf -inf -inf -inf]
 [0.36 0.18 -inf -inf -inf]
 [0.85 0.49 0.24 -inf -inf]
 [0.6  0.32 0.18 0.38 -inf]
 [0.46 0.26 0.14 0.58 0.55]]
idx 0. softmax
 [[1.         0.         0.         0.         0.        ]
 [0.54487889 0.45512111 0.         0.         0.        ]
 [0.44622395 0.31131988 0.24245617 0.         0.        ]
 [0.31100819 0.23505494 0.20434695 0.24958992 0.        ]
 [0.2097953  0.17176587 0.15234266 0.23654354 0.22955263]]
idx 0. probs check
 [1. 1. 1. 1. 1.]
Full output shape (2, 5, 3)
