# Self-attention: NumPy implementation

In [1]:
import numpy as np
import torch
import torch.nn as nn

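$$Y = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V$$

where $Q, K, V \in \mathbb{R}^{T \times d}$ are the query, key and value inputs to self-attention ($T$ = sequence length, $d$ = hidden dimension).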

We make self-attention a module in its own right: the weights that project the input into queries, keys and values are part of the layer itself.
Concretely, we implement the following:

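$$Y = \mathrm{softmax}\left(\frac{X W_Q (X W_K)^\top}{\sqrt{d}} + M\right) X W_V W_{\text{out}} \qquad \text{(self-attention layer)}$$

- $W_Q, W_K, W_V, W_{\text{out}}$ are learned parameters, and $M$ is an additive mask whose entries are $0$ (attend) or $-\infty$ (blocked).
- $W_{\text{out}}$ is implemented as a linear layer.
- We ignore the biases, i.e. set them to zero.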

In [11]:
def softmax(Z):
    """Numerically stable softmax taken over the last axis of ``Z``.

    The per-row maximum is subtracted before exponentiating so ``np.exp``
    cannot overflow. Softmax is shift-invariant, so the result is unchanged.
    """
    shifted = Z - np.max(Z, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    # Normalize each row so it sums to 1.
    return exps / np.sum(exps, axis=-1, keepdims=True)

def self_attention(X, mask, W_KQV, W_out):
    """Single-head, single-example self-attention in NumPy.

    Args:
        X: input sequence, shape (T, d).
        mask: additive attention mask broadcastable to (T, T); 0 lets a
            position be attended to, -inf blocks it (e.g. a causal mask).
        W_KQV: stacked projection weights, shape (d, 3*d). This is PyTorch's
            ``in_proj_weight`` transposed, so the column blocks are ordered
            [W_q | W_k | W_v].
        W_out: output projection, shape (d, d).

    Returns:
        (Y, attn): the layer output of shape (T, d) and the (T, T)
        attention matrix.
    """
    # Instead of three separate products
    #   Q = X @ W_q,  K = X @ W_k,  V = X @ W_v
    # do one product X @ [W_q | W_k | W_v] and split the columns into three
    # equal (T, d) pieces. The first piece is the query (PyTorch's order).
    Q, K, V = np.split(X @ W_KQV, 3, axis=1)

    # Scaled dot-product scores: row i holds query i's similarity to every key.
    attn = softmax(Q @ K.T / np.sqrt(X.shape[1]) + mask)

    # Also return the attention matrix so it can be inspected and compared.
    return attn @ V @ W_out, attn

# PyTorch reference implementation

In [7]:
torch.manual_seed(0)  # reproducible weights/inputs under Restart & Run All
T, d = 100, 64    # T - sequence length, d - hidden dimension
# Single-head reference module; bias=False to match the bias-free NumPy version.
attn = nn.MultiheadAttention(d, 1, bias=False, batch_first=True)
# Causal mask: -inf strictly above the diagonal (0 on and below it),
# so position t can only attend to positions <= t.
M = torch.triu(-float("inf")*torch.ones(T, T), 1)
X = torch.randn(1, T, d)
# Self-attention: the same X is passed as query, key and value.
Y_, A_ = attn(X, X, X, attn_mask=M)  # reference output and attention weights

In [8]:
# Packed projection weights: (3*d, d), i.e. W_q, W_k, W_v stacked row-wise.
attn.in_proj_weight.size()

torch.Size([192, 64])

In [12]:
# PyTorch stores weights as (out, in), so transpose them for the X @ W convention.
Y, A = self_attention(
    X=X[0].numpy(),
    mask=M.numpy(),
    W_KQV=attn.in_proj_weight.detach().numpy().T,   # (d, 3d)
    W_out=attn.out_proj.weight.detach().numpy().T,  # attn.out_proj = W_out, (d, d)
)

In [13]:
err = np.linalg.norm(Y - Y_[0].detach().numpy())
# float32 round-off is ~1e-6 here; a real mismatch would be orders of magnitude larger.
assert err < 1e-3, f"NumPy and PyTorch outputs differ: {err:.2e}"
err

1.0793809e-06

## minibatch

For a single (unbatched) example, $K$ has shape $T \times d$.

With a minibatch, we want to hold many examples in a single tensor:

- In RNNs we used the layout `T x batch_size x d`.
- For Transformers we use `batch_size x T x d`, which is why we passed `batch_first=True`.

### batch matrix multiplication (BMM)

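Batched matrix multiplication with the `@` operator versus a single big 2-D multiplication. When one operand is a plain matrix, `@` broadcasts it over all leading axes of the other. This is equivalent to flattening those axes into rows, doing one matmul and reshaping back: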

In [14]:
# A stack of (10, 3) matrices with leading axes (5, 4), times one (3, 6) matrix:
# matmul broadcasts D across every leading axis of C.
C = np.random.randn(5, 4, 10, 3)
D = np.random.randn(3, 6)
np.matmul(C, D).shape

(5, 4, 10, 6)

In [15]:
# Collapse all leading axes into rows: 5 * 4 * 10 = 200 rows of length 3.
np.reshape(C, (-1, 3)).shape

(200, 3)

In [18]:
# Same result via one big 2-D matmul: flatten C's leading axes into rows,
# multiply, then restore them. Shapes come from the operands rather than
# hard-coding (5, 4, 10, 6).
(C.reshape(-1, C.shape[-1]) @ D).reshape(*C.shape[:-1], D.shape[-1])

array([[[[ 1.02717692e+00,  8.21823268e-03,  2.02327841e+00,
          -3.15319705e+00, -2.29444417e+00, -2.54323402e+00],
         [-1.23547753e+00, -5.77141305e-01, -9.33934721e-01,
           2.76233799e+00,  1.19655706e+00,  1.05973464e+00],
         [ 1.89762579e+00,  7.73718134e-01,  1.60330800e+00,
          -3.65269988e+00, -1.81517128e+00, -2.13318242e+00],
         ...,
         [-2.61846089e+00, -1.79046769e+00, -4.24185160e-01,
           4.48281260e+00,  8.29589404e-01,  2.93113769e-01],
         [-1.85535457e+00, -1.08410920e+00, -5.80924178e-01,
           2.23480597e+00,  5.60950331e-01,  1.03186110e+00],
         [-4.41042969e+00, -3.14262290e+00,  1.11497803e-01,
           4.30181498e+00, -2.18064722e-01,  4.57400066e-01]],

        [[-9.39807324e-02,  1.31683046e-01, -4.95477399e-01,
           2.84469885e-01,  4.72310457e-01,  7.32721226e-01],
         [ 1.48667235e+00,  9.53396616e-01,  1.69274697e-01,
          -1.19235373e+00, -2.96242764e-02, -5.88732228e-01],


In [19]:
# Broadcasting matmul directly; matches the reshape-based result above.
np.matmul(C, D)

array([[[[ 1.02717692e+00,  8.21823268e-03,  2.02327841e+00,
          -3.15319705e+00, -2.29444417e+00, -2.54323402e+00],
         [-1.23547753e+00, -5.77141305e-01, -9.33934721e-01,
           2.76233799e+00,  1.19655706e+00,  1.05973464e+00],
         [ 1.89762579e+00,  7.73718134e-01,  1.60330800e+00,
          -3.65269988e+00, -1.81517128e+00, -2.13318242e+00],
         ...,
         [-2.61846089e+00, -1.79046769e+00, -4.24185160e-01,
           4.48281260e+00,  8.29589404e-01,  2.93113769e-01],
         [-1.85535457e+00, -1.08410920e+00, -5.80924178e-01,
           2.23480597e+00,  5.60950331e-01,  1.03186110e+00],
         [-4.41042969e+00, -3.14262290e+00,  1.11497803e-01,
           4.30181498e+00, -2.18064722e-01,  4.57400066e-01]],

        [[-9.39807324e-02,  1.31683046e-01, -4.95477399e-01,
           2.84469885e-01,  4.72310457e-01,  7.32721226e-01],
         [ 1.48667235e+00,  9.53396616e-01,  1.69274697e-01,
          -1.19235373e+00, -2.96242764e-02, -5.88732228e-01],


In [21]:
# True batch matmul: C holds five (10, 3) matrices and D five (3, 6) matrices;
# matmul pairs them up along the leading (batch) axis.
C = np.random.randn(5, 10, 3)
D = np.random.randn(5, 3, 6)
np.matmul(C, D).shape

(5, 10, 6)

In [22]:
# Self-attention in batch form: works for both batched and unbatched input.
# NOTE: this redefinition shadows the single-example version above.
def self_attention(X, mask, W_KQV, W_out):
    """Single-head self-attention over the last two axes of ``X``.

    Args:
        X: input of shape (..., T, d), e.g. (B, T, d) or (T, d).
        mask: additive attention mask broadcastable to (..., T, T).
        W_KQV: stacked projections, shape (d, 3*d), column blocks
            [W_q | W_k | W_v] (PyTorch ``in_proj_weight`` transposed).
        W_out: output projection, shape (d, d).

    Returns:
        (Y, attn): output of shape (..., T, d) and attention of shape (..., T, T).
    """
    # Split along the feature axis; PyTorch's packed order is query, key, value.
    Q, K, V = np.split(X @ W_KQV, 3, axis=-1)
    # swapaxes transposes only the last two axes, leaving batch axes alone.
    attn = softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(X.shape[-1]) + mask)
    return attn @ V @ W_out, attn

In [23]:
B, T, d = 50, 100, 64   # batch size, sequence length, hidden dimension
X = torch.randn(B, T, d)
# Causal mask: -inf strictly above the diagonal, 0 elsewhere.
M = torch.full((T, T), float("-inf")).triu(1)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [24]:
# The whole batch goes through at once; the (T, T) mask broadcasts over B.
Y, A = self_attention(
    X=X.numpy(),
    mask=M.numpy(),
    W_KQV=attn.in_proj_weight.detach().numpy().T,
    W_out=attn.out_proj.weight.detach().numpy().T,
)

In [25]:
err = np.linalg.norm(Y - Y_.detach().numpy())
# Round-off over the whole batch is ~1e-5; a real bug would be far larger.
assert err < 1e-3, f"NumPy and PyTorch outputs differ: {err:.2e}"
err

7.6204237e-06

In [26]:
err = np.linalg.norm(A - A_.detach().numpy())
assert err < 1e-3, f"NumPy and PyTorch attention weights differ: {err:.2e}"
err

9.892931e-07

Now we have almost all the ingredients needed to build a transformer.

The only piece still missing is multi-head attention.

## Multihead Attention

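In practice, attention is usually not applied with a single head like this.

The problem with single-head self-attention is this: when we form the $QK^\top$ matrix, each query gets only one softmax distribution over the keys. Every position can therefore attend to the others in only one way. Multi-head attention fixes this by:

- splitting the hidden dimension $d$ into `heads` slices of width $d/\text{heads}$,
- running attention independently in each slice,
- concatenating the per-head results and applying $W_{\text{out}}$.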



In [45]:
def multihead_attention(X, mask, heads, W_KQV, W_out):
    """Batched multi-head self-attention in NumPy.

    Args:
        X: input of shape (B, T, d).
        mask: additive attention mask broadcastable to (B, heads, T, T),
            e.g. a (T, T) causal mask with 0 / -inf entries.
        heads: number of attention heads; must divide d.
        W_KQV: stacked projections, shape (d, 3*d), column blocks
            [W_q | W_k | W_v] (PyTorch ``in_proj_weight`` transposed).
        W_out: output projection, shape (d, d).

    Returns:
        (Y, attn): Y of shape (B, T, d) and the per-head attention weights of
        shape (B, heads, T, T). PyTorch's MultiheadAttention returns the
        head-averaged weights by default, i.e. ``attn.mean(axis=1)``.

    Raises:
        ValueError: if d is not divisible by heads.
    """
    B, T, d = X.shape
    if d % heads != 0:
        raise ValueError(f"hidden size d={d} is not divisible by heads={heads}")
    d_head = d // heads

    # (B, T, 3d) -> three (B, T, d) tensors, in PyTorch's query/key/value order.
    Q, K, V = np.split(X @ W_KQV, 3, axis=-1)

    # Give each head its own d_head-wide slice of the features:
    # (B, T, d) -> (B, T, heads, d_head) -> (B, heads, T, d_head).
    Q, K, V = [a.reshape(B, T, heads, d_head).swapaxes(1, 2) for a in (Q, K, V)]

    # Per-head scaled dot-product attention; scale by the head width, not d.
    attn = softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(d_head) + mask)

    # Undo the head split (concatenating heads back to width d), then project.
    out = (attn @ V).swapaxes(1, 2).reshape(B, T, d)
    return out @ W_out, attn

In [46]:
heads = 4
# Re-binds `attn` to a fresh 4-head module (replacing the single-head one).
attn = nn.MultiheadAttention(embed_dim=d, num_heads=heads, bias=False, batch_first=True)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [47]:
Y, A = multihead_attention(
    X=X.numpy(),
    mask=M.numpy(),
    heads=heads,
    W_KQV=attn.in_proj_weight.detach().numpy().T,
    W_out=attn.out_proj.weight.detach().numpy().T,
)

In [48]:
err = np.linalg.norm(Y - Y_.detach().numpy())
assert err < 1e-3, f"NumPy and PyTorch multi-head outputs differ: {err:.2e}"
err

7.77548e-06

In [49]:
# PyTorch averages the attention weights over heads by default
# (average_attn_weights=True), so A_ is (B, T, T) while our A is
# (B, heads, T, T). Comparing them directly cannot broadcast; show the error
# instead of letting it abort Restart & Run All.
try:
    np.linalg.norm(A - A_.detach().numpy())
except ValueError as exc:
    print(f"ValueError: {exc}")

ValueError: operands could not be broadcast together with shapes (50,4,100,100) (50,100,100) 

In [50]:
# Per-head weights from the NumPy version: (B, heads, T, T).
np.shape(A)

(50, 4, 100, 100)

In [51]:
# PyTorch's head-averaged weights: (B, T, T).
A_.size()

torch.Size([50, 100, 100])

In [52]:
# Average our per-head weights over the head axis to match PyTorch's default.
err = np.linalg.norm(A.mean(axis=1) - A_.detach().numpy())
assert err < 1e-3, f"head-averaged attention weights differ: {err:.2e}"
err

7.814069e-07

# Creating a transformer block

1. Take the input $X$.
2. Project it into queries, keys and values ($Q$, $K$, $V$).
3. Pass them through (multi-head) self-attention.
4. Add the input back as a residual connection, then apply layer norm.
5. Feed-forward network: linear → ReLU → linear.
6. Add another residual connection.
7. Apply layer norm.

In [60]:
def layer_norm(Z, eps):
    """Normalize ``Z`` over its last axis to zero mean and unit variance.

    Uses NumPy's default (biased, ddof=0) variance plus ``eps`` for numerical
    stability. There is no learnable gain or bias in this version.
    """
    centered = Z - Z.mean(axis=-1, keepdims=True)
    scale = np.sqrt(Z.var(axis=-1, keepdims=True) + eps)
    return centered / scale

def relu(Z):
    """Elementwise rectified linear unit: max(Z, 0)."""
    floor = 0
    rectified = np.maximum(Z, floor)
    return rectified

def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):
    """One post-norm transformer encoder block in NumPy (all biases omitted).

    Computes
        Z   = LayerNorm(X + MultiHeadAttention(X))
        out = LayerNorm(Z + ReLU(Z @ W_ff1) @ W_ff2)

    Args:
        X: input of shape (B, T, d).
        mask: additive attention mask, e.g. a (T, T) causal mask.
        heads: number of attention heads.
        W_KQV, W_out: attention weights, as for ``multihead_attention``.
        W_ff1: first feed-forward weight, shape (d, d_ff).
        W_ff2: second feed-forward weight, shape (d_ff, d).
        eps: layer-norm epsilon.

    Returns:
        Output of shape (B, T, d).
    """
    # Attention sub-layer, computed once; [0] keeps the output and drops the
    # attention weights.
    attn_out = multihead_attention(X, mask, heads, W_KQV, W_out)[0]
    Z = layer_norm(X + attn_out, eps)
    # Feed-forward sub-layer (linear -> ReLU -> linear) with a residual connection.
    return layer_norm(Z + relu(Z @ W_ff1) @ W_ff2, eps)

In [61]:
trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward=128,
                                   dropout=0.0, batch_first=True)
# The NumPy transformer() has no biases, so zero every bias in the layer:
# the attention projections, both feed-forward linears and both layer norms.
# The numpy layer_norm also has no gain; this relies on the LayerNorm weights
# being at their default initialization (1) — NOTE(review): confirm if changed.
with torch.no_grad():
    for name, param in trans.named_parameters():
        if name.endswith("bias"):
            param.zero_()
Y_ = trans(X, src_mask=M)

In [62]:
# Transpose PyTorch's (out, in) weights to the X @ W convention used in NumPy.
Y = transformer(
    X=X.numpy(),
    mask=M.numpy(),
    heads=heads,
    W_KQV=trans.self_attn.in_proj_weight.detach().numpy().T,
    W_out=trans.self_attn.out_proj.weight.detach().numpy().T,
    W_ff1=trans.linear1.weight.detach().numpy().T,
    W_ff2=trans.linear2.weight.detach().numpy().T,
    eps=trans.norm1.eps,
)

In [63]:
err = np.linalg.norm(Y - Y_.detach().numpy())
# Two layer norms amplify round-off a little (~1e-5 to 1e-4); still far below a real bug.
assert err < 1e-3, f"NumPy and PyTorch transformer outputs differ: {err:.2e}"
err

6.161502e-05

In [65]:
# Re-displays the multi-head attention comparison: A and A_ still hold the
# results from In[46]/In[47]. The transformer cells only reassigned Y and Y_.
np.linalg.norm(np.mean(A, axis=1) - A_.detach().numpy())

7.814069e-07