# Transformers & Attention

In [1]:
# Move the working directory two levels up. The recorded output shows this lands in the
# KTorch project directory, presumably so that `nn`, `autograd` and `core` can be
# imported as top-level packages by the next cell.
# NOTE(review): a relative `%cd` is not idempotent - re-running this cell without a
# kernel restart moves up two more levels and the imports will then fail.
%cd ../..

/home/karimgamaleldin/projects/KTorch


In [2]:
# imports 
import numpy as np
from nn import MultiheadAttention, Linear, ReLU, MSELoss
from autograd import Tensor
from core import KTorch

### Multi-head attention

In [3]:
# Reproducible random input batch plus a single random regression target.
# The statements below stay in this exact order because they all draw from
# (or may draw from) the global NumPy RNG seeded here.
np.random.seed(0)
BATCH, SEQ_LEN, EMBED_DIM, NUM_HEADS = 16, 10, 64, 8
x = np.random.randn(BATCH, SEQ_LEN, EMBED_DIM).astype(np.float32)
y = np.random.randint(0, 100, size=1).astype(np.float32)

# Wrap the raw arrays so the KTorch layers can consume them.
x_t = Tensor(x)
y_t = Tensor(y)

# Layers under test: attention -> linear head -> ReLU.
mha = MultiheadAttention(EMBED_DIM, NUM_HEADS, add_bias_kv=True)
linear = Linear(EMBED_DIM, 1)
relu = ReLU()

# The same tensor is passed for all three attention inputs.
mha_output = mha(x_t, x_t, x_t)
# Only the last index along axis 1 is fed to the linear head.
last_step = mha_output[:, -1, :]
linear_output = linear(last_step)
relu_output = relu(linear_output)
relu_output.shape

(16, 1)

In [4]:
# Manual calculation: recompute the forward pass with plain NumPy, reading the
# parameters straight out of the KTorch layers, to serve as a reference result.
q_w, k_w, v_w = mha.q.weight.data, mha.k.weight.data, mha.v.weight.data
q_b, k_b, v_b = mha.q.bias.data, mha.k.bias.data, mha.v.bias.data

# Take the dimensions from the input array instead of repeating shape literals,
# so this cell keeps working if the input cell is changed.
batch, seq_len, embed_dim = x.shape
n_heads = 8  # must match the head count given to MultiheadAttention above
head_dim = embed_dim // n_heads
per_head_shape = (batch, seq_len, n_heads, head_dim)

# Input projections.
q = x @ q_w + q_b
k = x @ k_w + k_b
v = x @ v_w + v_b

# np.array_equal compares shapes as well as values, so a result that merely
# broadcasts against the reference cannot pass by accident (which
# `(a == b).all()` would allow).
print(
    np.array_equal(q, mha.q(x_t).data),
    np.array_equal(k, mha.k(x_t).data),
    np.array_equal(v, mha.v(x_t).data),
)

# Split the embedding axis into heads: (batch, seq, embed) -> (batch, heads, seq, head_dim).
q = q.reshape(per_head_shape).transpose(0, 2, 1, 3)
k = k.reshape(per_head_shape).transpose(0, 2, 1, 3)
v = v.reshape(per_head_shape).transpose(0, 2, 1, 3)

# Scaled dot-product scores, one (seq, seq) matrix per head.
attention = np.matmul(q, k.transpose(0, 1, 3, 2))
attention = attention * np.sqrt(head_dim, dtype=np.float32)**-1
# Softmax over the last axis; the row maximum is subtracted before exp so the
# exponentials cannot overflow.
t = np.exp(attention - np.max(attention, axis=-1, keepdims=True)).astype(np.float32)
attention = t / np.sum(t, axis=-1, keepdims=True).astype(np.float32)
output = attention @ v
# Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed).
output = output.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)
output = output @ mha.out.weight.data + mha.out.bias.data
# Same slice as the KTorch path: last index along axis 1.
output = output[:, -1, :]
lin_manual = output @ linear.weight.data + linear.bias.data
relu_manual = np.maximum(0, lin_manual)
relu_manual.shape

True True True


(16, 1)

In [5]:
# Check equivalence between the KTorch forward pass and the NumPy reference.
# np.array_equal requires identical shapes as well as identical values, so a
# result that only matches through broadcasting is reported as a mismatch.
print(np.array_equal(relu_output.data, relu_manual))

True


In [6]:
# Loss computed by KTorch on the network output and the wrapped target.
loss = MSELoss()
loss_value = loss(relu_output, y_t)

# Reference value with NumPy: mean of the squared differences between the
# manual forward-pass result and the raw target array.
squared_error = (relu_manual - y)**2
loss_manual = squared_error.mean()
np.all(np.equal(loss_value.data, loss_manual))

True

In [7]:
# Backward: run backpropagation starting from the loss.
# NOTE(review): the recorded output of this cell is a ValueError - an output operand of
# shape (64,64) does not match the broadcast shape (16,64,64). Presumably the gradient of
# a 2-D weight used in a batched matmul is not summed over the batch axis before being
# accumulated in place - TODO confirm in the autograd implementation; the notebook does
# not survive Restart & Run All until that is fixed.
loss_value.backward()

ValueError: non-broadcastable output operand with shape (64,64) doesn't match the broadcast shape (16,64,64)