# TRANSFORMER

The Transformer was proposed by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin in their paper from 2017:  
<i>[1] "Attention is all you need".</i>

This is, without a doubt, one of the most impactful papers of 2017. The paper introduces several refinements of the soft attention mechanism (scaled dot-product and multi-head attention) and shows that <i>seq2seq</i> modeling can be done without recurrence. The proposed "transformer" model is built entirely on the self-attention mechanism, without any sequence-aligned recurrent architecture.

![Intro](img/transformer_intro.png "Intro")

Before building the transformer model we need a few additional fundamental components:
* Self-attention
* Masked decoding
* Multi-headed self-attention
* Adding nonlinearities
* Positional encoding
* Cross-attention

In [1]:
import numpy as np
import sys
sys.path.append("../")

from src.utils.gradient_check import eval_numerical_gradient, rel_error

# for auto-reloading external modules
%load_ext autoreload
%autoreload 2

## ATTENTION AND SELF-ATTENTION

Attention in deep learning can broadly be interpreted as a vector of importance weights.  
In a sequence-to-sequence model a score $ \alpha_{t,i} $ is assigned to the pair of output at position <i>t</i> ($ y_{t} $) and input at position <i>i</i> ($ x_{i} $).  
For example:  

$$ \beta_{t,i} = y_{t}^{T}x_{i} $$  
$$ \alpha_{t,i} = \frac{e^{\beta_{t,i}}}{\displaystyle\sum_{j}{e^{\beta_{t,j}}}} $$


The attention vector $ \alpha_{t} $ defines how much of each source hidden state should be considered for the output at position <i>t</i>.  
We compute the attention output as a weighted sum of the source hidden states, using the weights $ \alpha_{t} $:

$$ \displaystyle c_{t} = \sum_{i} \alpha_{t, i} x_{i} $$

This attention output is then used to generate the next output.
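The two formulas above (dot-product scores, then a softmax-weighted sum) can be sketched for a single output position in a few lines of NumPy. The helper name `dot_product_attention` and the toy vectors are illustrative assumptions, not part of the original notebook:

```python
import numpy as np

def dot_product_attention(y_t, x_src):
    """Attention of a single output y_t (shape (D,)) over source states x_src (shape (T, D))."""
    beta = x_src @ y_t                    # beta_{t,i} = y_t^T x_i, shape (T,)
    e = np.exp(beta - beta.max())         # shifting by the max leaves the softmax unchanged
    alpha_t = e / e.sum()                 # alpha_{t,i}: non-negative and summing to 1
    c_t = alpha_t @ x_src                 # c_t = sum_i alpha_{t,i} x_i, shape (D,)
    return alpha_t, c_t

x_src = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])
y_t = np.array([2.0, 0.0])
alpha_t, c_t = dot_product_attention(y_t, x_src)
print(alpha_t.round(3), c_t.round(3))
```

Here $x_1$ and $x_3$ have the same dot product with $y_t$ (both equal 2), so they receive equal weight, while $x_2$ (dot product 0) receives much less.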

![Attention](img/transformer_attention.png "Attention")

Self-attention, also known as intra-attention, is an attention mechanism relating different positions of a single sequence in order to compute a representation of the same sequence.

![Self Attention](img/transformer_selfattention.png "Self Attention")

Generally, attention can be described as operating on <i>keys</i> $ k_{i} $, <i>queries</i> $ q_{i} $, and <i>values</i> $ v_{i} $: the attention vector $ \alpha_{t} $ is computed using key-query affinities, and the attention output $ c_{t} $ is calculated as a weighted sum over the values.
$$ \beta_{t,i} = q_{t}^{T}k_{i} $$  
$$ \alpha_{t,i} = \frac{e^{\beta_{t,i}}}{\displaystyle\sum_{j}{e^{\beta_{t,j}}}} $$
$$ \displaystyle c_{t} = \sum_{i} \alpha_{t, i} v_{i} $$
In the case of the sequence-to-sequence model we have: $ k_{i} = x_{i} $; $ q_{i} = y_{i} $; $ v_{i} = x_{i} $ - the keys and the values are equal to the source hidden states, and the queries are equal to the outputs.

In self-attention, from each input $ x_{i} $ of the layer we create three vectors: $ k_{i}, q_{i}, v_{i}$. These vectors are obtained by multiplying $ x_{i} $ with three matrices $ W^{K}, W^{Q}, W^{V} $ (i.e. $ k_{i} = x_{i}W^{K} $, and similarly for $ q_{i} $ and $ v_{i} $), which are learned during training. In the code below these matrices are called `K`, `Q` and `V`. They allow different aspects of the $ x $ vectors to be used/emphasized in each of the three roles.

![KQV](img/transformer_selfattention_kqv.png "KQV")

One final variation is to scale the scores $ \beta_{t, i} $ just before applying the softmax. The reasoning behind this is that when the dimensionality $ d $ of the <b>keys</b> and <b>queries</b> becomes large, their dot products also tend to become large: if the components of $ q_{t} $ and $ k_{i} $ are independent with zero mean and unit variance, then $ q_{t}^{T}k_{i} $ has variance $ d $. Large inputs push the softmax into regions where it saturates and its gradients become extremely small. For this reason we divide $ \beta_{t, i} $ by $ \sqrt{d} $. Thus:

$$ \alpha_{t,i} = \large\frac{e^{\frac{\beta_{t,i}}{\sqrt{d}}}}{\displaystyle\sum_{j}{e^{\frac{\beta_{t,j}}{\sqrt{d}}}}} $$
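The variance argument above can be checked empirically. The sketch below assumes, as in that argument, that query and key components are independent standard normal variables; the helper `score_std` is purely illustrative. It estimates the standard deviation of $ q^{T}k $ with and without the $ 1/\sqrt{d} $ scaling:

```python
import numpy as np

rng_scale = np.random.default_rng(0)

def score_std(d, n=5000, scale=False):
    """Empirical std of n dot products q^T k, where q and k have i.i.d. N(0, 1) entries."""
    q = rng_scale.standard_normal((n, d))
    k = rng_scale.standard_normal((n, d))
    beta = np.sum(q * k, axis=1)          # n independent dot products
    if scale:
        beta = beta / np.sqrt(d)          # the 1/sqrt(d) scaling
    return beta.std()

for d in (4, 64, 512):
    print(f"d={d}: std(q.k)={score_std(d):.2f}, std(q.k/sqrt(d))={score_std(d, scale=True):.2f}")
```

Without scaling the spread should grow roughly like $ \sqrt{d} $, while the scaled scores stay close to 1 regardless of the dimension.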


In [2]:
np.random.seed(42)
np.set_printoptions(precision=3)

T = 2  # number of time-steps
D = 4  # input dimension
M = 3  # output dimension

x = np.random.randn(T, D)
K = np.random.randn(D, M)
Q = np.random.randn(D, M)
V = np.random.randn(D, M)

print("x:\n", x)
print("K:\n", K)
print("Q:\n", Q)
print("V:\n", V)

x:
 [[ 0.497 -0.138  0.648  1.523]
 [-0.234 -0.234  1.579  0.767]]
K:
 [[-0.469  0.543 -0.463]
 [-0.466  0.242 -1.913]
 [-1.725 -0.562 -1.013]
 [ 0.314 -0.908 -1.412]]
Q:
 [[ 1.466 -0.226  0.068]
 [-1.425 -0.544  0.111]
 [-1.151  0.376 -0.601]
 [-0.292 -0.602  1.852]]
V:
 [[-0.013 -1.058  0.823]
 [-1.221  0.209 -1.96 ]
 [-1.328  0.197  0.738]
 [ 0.171 -0.116 -0.301]]


The input to the self-attention layer is a sequence $ [x_{1}, x_{2}, ..., x_{T}] $, and the output is again a sequence of the same length $ [z_{1}, z_{2}, ..., z_{T}] $. Here $ z_{t} $ is the attention output obtained by taking the input $ x_{t} $ and attending to every input in the sequence (including itself).

$$ z_{t} = \sum_{i} \alpha_{t, i} v_{i} $$
$$ z_{t} = \alpha_{t} \, xV $$

where $ \alpha_{t} $ is treated as a row vector and the rows of the matrix $ xV $ are the value vectors $ v_{i} $.
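As a sanity check that the two expressions for $ z_{t} $ agree, the sketch below computes the output once with an explicit double loop over $ t $ and $ i $ and once with a single matrix product. All names and the random data are illustrative:

```python
import numpy as np

rng = np.random.default_rng(1)
T_demo, D_demo, M_demo = 3, 4, 2
x_demo = rng.standard_normal((T_demo, D_demo))
Wk, Wq, Wv = (rng.standard_normal((D_demo, M_demo)) for _ in range(3))

k_demo, q_demo, v_demo = x_demo @ Wk, x_demo @ Wq, x_demo @ Wv   # rows are k_i, q_i, v_i

def softmax_rows(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

alpha_demo = softmax_rows(q_demo @ k_demo.T / np.sqrt(M_demo))   # (T, T), row t is alpha_t

# Loop form: z_t = sum_i alpha_{t,i} v_i
z_loop = np.zeros((T_demo, M_demo))
for t in range(T_demo):
    for i in range(T_demo):
        z_loop[t] += alpha_demo[t, i] * v_demo[i]

# Matrix form: row t of alpha @ (xV) is alpha_t (xV)
z_mat = alpha_demo @ v_demo
print(np.allclose(z_loop, z_mat))
```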

In [3]:
# Naive implementation of softmax (np.exp can overflow for large scores).
def softmax(s):
    return np.exp(s) / np.sum(np.exp(s), axis=-1, keepdims=True)
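The `softmax` above is a direct transcription of the formula. The sketch below, using the illustrative names `softmax_naive` and `softmax_shifted`, shows why it is called naive: large scores overflow `np.exp`. Since $ \text{softmax}(s) = \text{softmax}(s - c) $ for any constant $ c $, subtracting the row maximum gives the same result without overflow:

```python
import numpy as np

def softmax_naive(s):
    return np.exp(s) / np.sum(np.exp(s), axis=-1, keepdims=True)

def softmax_shifted(s):
    # Shifting by the row max keeps every exponent <= 0, so np.exp cannot overflow.
    e = np.exp(s - np.max(s, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

s = np.array([[1000.0, 1001.0, 1002.0]])
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(s)      # exp(1000) overflows to inf, and inf / inf gives nan
stable = softmax_shifted(s)
print(naive)
print(stable)
```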

In [4]:
print("keys:\n", x @ K)
print("queries:\n", x @ Q)
print("values:\n", x @ V)

print("\nattention scores:\n", softmax(x @ Q @ K.T @ x.T / np.sqrt(M)))

keys:
 [[-0.807 -1.511 -2.773]
 [-2.264 -1.769 -2.127]]
queries:
 [[-0.265 -0.71   2.45 ]
 [-2.051  0.312  0.431]]
values:
 [[-0.437 -0.603  0.699]
 [-1.677  0.421  1.201]]

attention scores:
 [[0.224 0.776]
 [0.137 0.863]]


In [5]:
def self_attention_naive(x, K, Q, V):
    """
    Inputs:
    - x: A numpy array of shape (T, D).
    - K: A numpy array of shape (D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (D, M) containing the weights for the value matrix.

    Returns:
    - out: Output data of shape (T, M)
    """
    M = K.shape[-1]

    keys = x @ K      # matrix of shape (T, M)
    queries = x @ Q   # matrix of shape (T, M)
    values = x @ V    # matrix of shape (T, M)

    beta = queries @ keys.T    # matrix of shape (T, T)
    beta /= np.sqrt(M)         # scale before applying softmax
    alpha = softmax(beta)
    
    out = alpha @ values

    return out

In [6]:
z = self_attention_naive(x, K, Q, V)
print("\nz shape:", z.shape)
print("z:\n", z)


z shape: (2, 3)
z:
 [[-1.399  0.191  1.089]
 [-1.507  0.28   1.132]]


Looking at the formulas, we can see that the whole computation can be pipelined into a few matrix operations.  
First, compute all query-key dot products in one chain of matrix multiplications: $ \beta = xQK^{T}x^{T} / \sqrt{M} $, where $ M $ is the key/query dimension.  
Next, apply the softmax to each row and compute the weighted averages with another matrix multiplication: $ z = \text{softmax}(\beta) \space xV $.  

![Matrix form](img/transformer_selfattention_matrix.png "Matrix form")

In [7]:
def self_attention(x, K, Q, V):
    """
    Inputs:
    - x: A numpy array of shape (T, D).
    - K: A numpy array of shape (D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (D, M) containing the weights for the value matrix.

    Returns:
    - out: Output data of shape (T, M)
    """
    M = K.shape[-1]
    return softmax(x @ Q @ K.T @ x.T / np.sqrt(M)) @ x @ V

In [8]:
z = self_attention(x, K, Q, V)
print("z shape:", z.shape)
print("z:\n", z)

z shape: (2, 3)
z:
 [[-1.399  0.191  1.089]
 [-1.507  0.28   1.132]]


To perform the backward pass we need to compute the derivative of the loss $\ell$ with respect to $ x, K, Q, V $ (the loss $ \ell $ is a single number, i.e. $ \ell \in \mathbb{R} $).  
Assume we have the upstream derivative of $\ell$ with respect to a tensor $z$, $\displaystyle \frac{d\ell}{dz}$. Then, for any tensor $y$ such that $z=f(y)$, the partial derivative of $\ell$ with respect to a single element $y_{j}$ of the tensor $y$ is:  
$$ \frac{\partial \ell}{\partial y_{j}} = \sum_{i} \frac{\partial \ell}{\partial z_{i}} \frac{\partial z_{i}}{\partial y_{j}} $$
where the summation is over all the elements of the tensor $z$.  

Suppose that $C=AB$, where $A$, $B$, and $C$ are matrices, and that $\ell=f(C)$. In this case we have that:

$$
\begin{align*}
& \frac{d\ell}{dB} = A^{T}\frac{d\ell}{dC} \\
& \frac{d\ell}{dA} = \frac{d\ell}{dC}B^{T} \\
\end{align*}
$$
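The two matrix rules follow from the element-wise chain rule above: since $ C_{ij} = \sum_{k} A_{ik}B_{kj} $, we get $ \partial \ell / \partial A_{ik} = \sum_{j} (d\ell/dC)_{ij} B_{kj} $ and $ \partial \ell / \partial B_{kj} = \sum_{i} (d\ell/dC)_{ij} A_{ik} $. The sketch below compares both forms on random matrices; `G_mm` plays the role of an arbitrary upstream gradient $ d\ell/dC $, and all names are illustrative:

```python
import numpy as np

rng_mm = np.random.default_rng(2)
A_mm = rng_mm.standard_normal((3, 4))
B_mm = rng_mm.standard_normal((4, 2))
G_mm = rng_mm.standard_normal((3, 2))   # upstream gradient dl/dC for C = A B

# Matrix form of the rules
dA_mm = G_mm @ B_mm.T                   # dl/dA = dl/dC B^T
dB_mm = A_mm.T @ G_mm                   # dl/dB = A^T dl/dC

# Element-wise chain rule written with einsum:
#   dl/dA_ik = sum_j dl/dC_ij * B_kj   and   dl/dB_kj = sum_i dl/dC_ij * A_ik
dA_elem = np.einsum("ij,kj->ik", G_mm, B_mm)
dB_elem = np.einsum("ij,ik->kj", G_mm, A_mm)

print(np.allclose(dA_mm, dA_elem), np.allclose(dB_mm, dB_elem))   # True True
```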

Using these formulas we can derive the derivatives needed for the backward pass of the self-attention layer. To simplify the derivation we will use the following notation for the derivative of the loss $\ell$ with respect to a tensor $A$: $dA = \displaystyle \frac{d\ell}{dA}$. A dot between matrices, as in $\alpha.x$, denotes matrix multiplication.

For the derivative of $\ell$ with respect to $V$ we have the following:
$$
\begin{align*}
& z = (\alpha.x)V \\
& \frac{d\ell}{dV} = (\alpha.x)^{T}\frac{d\ell}{dz} \\
& dV = (\alpha.x)^{T}dz
\end{align*}
$$

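For the derivative of $\ell$ with respect to $Q$ we have the following (here and below we ignore the $1/\sqrt{M}$ scaling of $\beta$; it only multiplies $d\beta$ by $1/\sqrt{M}$, which is why the code divides the result of `softmax_backward` by `np.sqrt(M)`):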
$$
\begin{align*}
& z = \alpha.(x.V) = \text{softmax}(\beta) \space xV \\
& d\alpha = dz.(x.V)^{T} \\
& d\beta = \text{softmax_backward}(d\alpha) \\
& \beta = (xQ).(K^{T}x^{T}) \\
& d(xQ) = d\beta.(K^{T}x^{T})^{T} = d\beta.xK \\
& dQ = x^{T}d(xQ) \\
& dQ = x^{T} d\beta.xK \\
\end{align*}
$$

Similarly for the derivative of $\ell$ with respect to $K$ we have:  
$$
\begin{align*}
& \beta = (xQ).(K^{T}x^{T}) \\
& d(K^{T}x^{T}) = (xQ)^{T}.d\beta = Q^{T}x^{T}.d\beta \\
& d(K^{T}) = d(K^{T}x^{T}).x = Q^{T}x^{T}.d\beta.x \\
& dK = x^{T}d\beta^{T}.xQ \\
\end{align*}
$$

The last equation follows from the fact that $ d(K^{T}) $ is the matrix $\displaystyle \frac{d\ell}{dK^{T}} $, whose $(r, c)$ entry is $ \partial \ell / \partial K_{cr} $ ($K$ has shape $D \times M$, so $K^{T}$ has shape $M \times D$):  
$$
\frac{d\ell}{dK^{T}} =
\begin{bmatrix} 
\frac{\partial \ell}{\partial K_{11}} & \frac{\partial \ell}{\partial K_{21}} & \dots & \frac{\partial \ell}{\partial K_{D1}} \\
\frac{\partial \ell}{\partial K_{12}} & \frac{\partial \ell}{\partial K_{22}} & \dots & \frac{\partial \ell}{\partial K_{D2}} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial \ell}{\partial K_{1M}} & \frac{\partial \ell}{\partial K_{2M}} & \dots & \frac{\partial \ell}{\partial K_{DM}} \\
\end{bmatrix}
$$
Hence $ d(K^{T}) = (dK)^{T} $.

Finally, for the derivative of $\ell$ with respect to $x$ we have:  
$$
\begin{align*}
& z = \text{softmax}(xQK^{T}x^{T})xV = \alpha.x.V \\
& \frac{d\ell}{dx} = \frac{d\ell}{dz}\frac{dz}{dx} = \frac{d\ell}{dz}\frac{dz}{d\alpha}\frac{d\alpha}{dx} + \frac{d\ell}{dz}\frac{dz}{d(xV)}\frac{d(xV)}{dx} \\
\end{align*}
$$

Let:

$\displaystyle dx^{(1)} = \frac{d\ell}{dz}\frac{dz}{d\alpha}\frac{d\alpha}{dx} = \frac{d\ell}{dz}\frac{dz}{d\alpha}\frac{d\alpha}{d\beta}\frac{d\beta}{dx} = \frac{d\ell}{d\beta}\frac{d\beta}{dx}$

and

$\displaystyle dx^{(2)} = \frac{d\ell}{dz}\frac{dz}{d(xV)}\frac{d(xV)}{dx} = \frac{d\ell}{d(xV)}\frac{d(xV)}{dx}$  

For the first part of the derivative $dx^{(1)}$ we have:  
$$
\begin{align*}
& \beta = xQK^{T}x^{T} \\
& dx^{(11)} = ((xQK^{T})^{T} d\beta)^{T} = d\beta^{T} xQK^{T}\\
& dx^{(12)} = d\beta(QK^{T}x^{T})^{T} = d\beta x KQ^{T} \\
& dx^{(1)} = d\beta^{T} xQK^{T} + d\beta x KQ^{T} \\
\end{align*}
$$

For the second part of the derivative $dx^{(2)}$ we have:  
$$
\begin{align*}
& z = \alpha.x.V \\
& d(xV) = \alpha^{T}.dz \\
& dx^{(2)} = d(xV).V^{T} = \alpha^{T}.dz.V^{T} \\
\end{align*}
$$

Finally we have:

$$ dx = \alpha^{T}.dz.V^{T} + d\beta.x.KQ^{T} + d\beta^{T}.x.QK^{T} $$

<b>Proof</b> of the expression for the first part of the derivative, $dx^{(1)}$:  
Let $ A = QK^{T} $ and $ B = xA $. Then $ \beta = xQK^{T}x^{T} = xAx^{T} = Bx^{T} $, and  
$ \displaystyle \beta_{mn} = \sum_{r} b_{mr} x_{nr} = \sum_{r} \left( \sum_{s} x_{ms}a_{sr} \right) x_{nr} $  

Thus $\beta_{m,n}$ depends on the <i>m-th</i> and on the <i>n-th</i> rows of the matrix $x$.  
We know that:  

$ \displaystyle \frac{\partial \ell}{\partial x_{ij}} = \sum_{k,l} \frac{\partial \ell}{\partial \beta_{kl}} \frac{\partial \beta_{kl}}{\partial x_{ij}} $  

Since $ x_{ij} $ is in the <i>i-th</i> row, we know that only the elements from the <i>i-th</i> row and the elements from the <i>i-th</i> column of the matrix $\beta$ will depend on $x_{ij}$. All other partial derivatives will vanish.

$ \displaystyle \frac{\partial \beta_{kl}}{\partial x_{ij}} = 0 $ whenever $ k \neq i $ and $ l \neq i $

And so we have:  

$ \displaystyle \frac{\partial \ell}{\partial x_{ij}} = \sum_{l} \frac{\partial \ell}{\partial \beta_{il}} \frac{\partial \beta_{il}}{\partial x_{ij}} +  \sum_{k} \frac{\partial \ell}{\partial \beta_{ki}} \frac{\partial \beta_{ki}}{\partial x_{ij}} $

We have to be careful with the above summation and count the term $ \displaystyle \frac{\partial \ell}{\partial \beta_{ii}}\frac{\partial \beta_{ii}}{\partial x_{ij}} $ only once!

$$
\begin{align*}
& \frac{\partial \beta_{il}}{\partial x_{ij}} = \frac{\displaystyle \partial \sum_{r} b_{ir}x_{lr}}{\partial x_{ij}} = \sum_{r} x_{lr} \frac{\partial b_{ir}}{\partial x_{ij}} = \sum_{r} x_{lr} \frac{\displaystyle \partial \sum_{s} x_{is}a_{sr}}{\partial x_{ij}} = \sum_{r} x_{lr}a_{jr} \space\text{- for $l \neq i$} \\
& \frac{\partial \beta_{ki}}{\partial x_{ij}} = \frac{\displaystyle \partial \sum_{r} b_{kr}x_{ir}}{\partial x_{ij}} = \sum_{r} b_{kr} \frac{\partial x_{ir}}{\partial x_{ij}} = b_{kj} = \sum_{s} x_{ks}a_{sj} \space\text{- for $k \neq i$} \\
& \frac{\partial \beta_{ii}}{\partial x_{ij}} = \frac{\displaystyle \partial \sum_{r} b_{ir}x_{ir}}{\partial x_{ij}} = \sum_{r} x_{ir} \frac{\partial b_{ir}}{\partial x_{ij}} + \sum_{r} b_{ir} \frac{\partial x_{ir}}{\partial x_{ij}} = \sum_{r} x_{ir}a_{jr} + \sum_{s} x_{is}a_{sj} \\
\end{align*}
$$

From the above we can see that the term $ \displaystyle \frac{\partial \beta_{ii}}{\partial x_{ij}} $ nicely completes the other two sums.  
And so we have:  

$$ \frac{\partial \ell}{\partial x_{ij}} = \sum_{l} \frac{\partial \ell}{\partial \beta_{il}} \sum_{r} x_{lr}a_{jr} + \sum_{k} \frac{\partial \ell}{\partial \beta_{ki}} \sum_{s} x_{ks}a_{sj} $$

Each inner sum over $r$ (or $s$) is an entry of a matrix product, and the outer sum over $l$ (or $k$) is again a matrix product. Thus the derivative with respect to the entire matrix $x$ can be written as products of three matrices:

$$ \frac{d\ell}{dx} = d\beta.xA^{T} + d\beta^{T}.xA $$

With $ A = QK^{T} $ (so $ A^{T} = KQ^{T} $) this is exactly $dx^{(1)}$.
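The result $ \frac{d\ell}{dx} = d\beta.xA^{T} + d\beta^{T}.xA $ can also be confirmed numerically. In the sketch below a random, non-symmetric `A_q` stands in for $ QK^{T} $, and the loss $ \ell = \sum (d\beta \odot \beta) $ is chosen so that $ d\ell/d\beta $ equals the random matrix `dbeta_q` exactly; all names are illustrative. Because $ \ell $ is quadratic in $ x $, centered differences are exact up to rounding:

```python
import numpy as np

rng = np.random.default_rng(3)
T_q, D_q = 3, 4
x_q = rng.standard_normal((T_q, D_q))
A_q = rng.standard_normal((D_q, D_q))      # stands in for Q K^T (not symmetric)
dbeta_q = rng.standard_normal((T_q, T_q))  # upstream gradient dl/dbeta

def loss_q(x):
    # l = sum(dbeta * beta) with beta = x A x^T, so dl/dbeta = dbeta exactly
    return np.sum(dbeta_q * (x @ A_q @ x.T))

dx_analytic = dbeta_q @ x_q @ A_q.T + dbeta_q.T @ x_q @ A_q

eps_fd = 1e-6
dx_numeric = np.zeros_like(x_q)
for i in range(T_q):
    for j in range(D_q):
        e = np.zeros_like(x_q)
        e[i, j] = eps_fd
        dx_numeric[i, j] = (loss_q(x_q + e) - loss_q(x_q - e)) / (2 * eps_fd)

print("max abs difference:", np.max(np.abs(dx_analytic - dx_numeric)))
```

Dropping either of the two terms makes the comparison fail, which is a quick way to see that $ x $ enters $ \beta $ twice.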

In [9]:
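# A more stable implementation of softmax: shifting each row by its maximum
# leaves the result unchanged but keeps np.exp from overflowing.
def softmax_forward(x):
    shifted = x - np.max(x, axis=-1, keepdims=True)
    out = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
    cache = out
    return out, cache


def softmax_backward(dout, cache):
    # Row-wise softmax of shape (N, D): each row has Jacobian diag(out) - out out^T.
    out = cache
    N, D = out.shape
    diag = np.expand_dims(out, axis=2) * np.expand_dims(np.identity(D), axis=0)   # (N, D, D)
    outer = np.matmul(np.expand_dims(out, axis=2), np.expand_dims(out, axis=1))   # (N, D, D)
    # Vector-Jacobian product: each row of dout times its (symmetric) Jacobian
    dx = np.matmul(np.expand_dims(dout, axis=1), diag-outer).squeeze(axis=1)
    return dx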
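`softmax_backward` above materializes an $ N \times D \times D $ tensor of Jacobians. Expanding the product row by row gives $ dout\,(\text{diag}(p) - pp^{T}) = p \odot (dout - \langle dout, p \rangle) $, which needs only $ O(ND) $ memory. The sketch below compares the two forms; the function names are illustrative and are not used elsewhere in this notebook:

```python
import numpy as np

def softmax_rows(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_backward_jacobian(dout, p):
    # Explicit (N, D, D) Jacobians diag(p) - p p^T, one per row
    J = np.einsum("nd,de->nde", p, np.eye(p.shape[1])) - np.einsum("nd,ne->nde", p, p)
    return np.einsum("nd,nde->ne", dout, J)

def softmax_backward_fast(dout, p):
    # Same vector-Jacobian product without forming any D x D matrix
    return p * (dout - np.sum(dout * p, axis=-1, keepdims=True))

rng = np.random.default_rng(4)
p_demo = softmax_rows(rng.standard_normal((5, 6)))
dout_demo = rng.standard_normal((5, 6))
print(np.allclose(softmax_backward_jacobian(dout_demo, p_demo),
                  softmax_backward_fast(dout_demo, p_demo)))
```

A useful side property: each row of the softmax gradient sums to zero, since the outputs of a softmax always sum to 1.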

In [10]:
def self_attention_forward(x, K, Q, V):
    """
    Inputs:
    - x: A numpy array of shape (T, D).
    - K: A numpy array of shape (D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (D, M) containing the weights for the value matrix.

    Returns a tuple of:
    - out: Output data of shape (T, M)
    - cache: Values needed for the backward pass.
    """
    M = K.shape[-1]

    beta = x @ Q @ K.T @ x.T
    beta /= np.sqrt(M)
    alpha, softmax_cache = softmax_forward(beta)
    out = alpha @ x @ V

    cache = (softmax_cache, alpha, x, K, Q, V)
    return out, cache


def self_attention_backward(dout, cache):
    """
    Inputs:
    - dout: Upstream derivative of the self-attention output of shape (T, M)
    - cache: A tuple of values from the forward pass.

    Returns:
    - dx: Gradient with respect to x, of shape (T, D).
    - dK: Gradient with respect to K, of shape (D, M).
    - dQ: Gradient with respect to Q, of shape (D, M).
    - dV: Gradient with respect to V, of shape (D, M).
    """
    softmax_cache, alpha, x, K, Q, V = cache
    M = dout.shape[-1]

    dV = (alpha @ x).T @ dout
    dalpha = dout @ (x @ V).T
    dbeta = softmax_backward(dalpha, softmax_cache) / np.sqrt(M)
    dQ = x.T @ dbeta @ x @ K
    dK = x.T @ dbeta.T @ x @ Q
    dx = alpha.T @ dout @ V.T
    dx += dbeta @ x @ K @ Q.T + dbeta.T @ x @ Q @ K.T

    return dx, dK, dQ, dV

In [11]:
z, cache = self_attention_forward(x, K, Q, V)
print("z shape:", z.shape)
print("z:\n", z)

z shape: (2, 3)
z:
 [[-1.399  0.191  1.089]
 [-1.507  0.28   1.132]]


In [12]:
# Dummy loss function. Returns loss value, and gradient of the loss w.r.t. the input.
def loss(x):
    return np.sum(x), np.ones_like(x)

To check the implementation of the backward pass we compare it against a numerical gradient. If the implementation is correct, the relative error between the numerical and the analytic gradients should be on the order of 1e-8 or smaller for each of the parameters.
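The notebook imports `eval_numerical_gradient` and `rel_error` from `src.utils.gradient_check`, whose source is not shown here. The sketch below is a hypothetical stand-in, assuming the usual centered-difference scheme; it is not the repository's implementation. Note that the lambdas `f` below ignore their argument and read the global arrays, so the checker presumably perturbs the parameter array in place, as this sketch does:

```python
import numpy as np

def numerical_gradient_sketch(f, x, h=1e-5):
    """Centered-difference gradient of a scalar function f(x) w.r.t. every entry of x.

    x is perturbed in place and restored, so f may also read x through a closure.
    """
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + h
        f_plus = f(x)
        x[idx] = old - h
        f_minus = f(x)
        x[idx] = old                              # restore the original value
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad

def rel_error_sketch(a, b):
    """Largest element-wise relative error, guarded against division by zero."""
    return np.max(np.abs(a - b) / np.maximum(1e-8, np.abs(a) + np.abs(b)))

# Usage on a function with a known gradient: f(w) = sum(exp(w)), so df/dw = exp(w)
w = np.random.default_rng(5).standard_normal((3, 4))
err = rel_error_sketch(numerical_gradient_sketch(lambda a: np.sum(np.exp(a)), w), np.exp(w))
print("max relative error: %e" % err)
```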

In [13]:
l, dz = loss(z)
dx, dK, dQ, dV = self_attention_backward(dz, cache)
params = {"x":x, "K":K, "Q":Q, "V":V}
grads = {"x":dx, "K":dK, "Q":dQ, "V":dV}

# These should all be less than 1e-8 or so.
f = lambda a: loss(self_attention_forward(x, K, Q, V)[0])[0]
for _name in grads:
    grad_numeric = eval_numerical_gradient(f, params[_name], verbose=False)
    print("%s max relative error: %e" % (_name, rel_error(grad_numeric, grads[_name])))

x max relative error: 1.453740e-10
K max relative error: 2.955696e-09
Q max relative error: 3.938783e-09
V max relative error: 1.953700e-11


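Finally, to make the self-attention layer work on a batch of input sequences, note that the weight matrices ($K, Q, V$) are shared by all sequences in the batch. The gradient of the loss $\ell$ with respect to a weight matrix $W$ is therefore the sum of the gradients contributed by each sequence in the batch:

$$ \frac{d\ell}{dW} = \sum_{i=0}^{N-1} \frac{d\ell}{dW^{i}} $$

where $\displaystyle \frac{d\ell}{dW^{i}}$ denotes the gradient computed from the $i$-th sequence alone.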
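To see that summing per-sequence gradients is the same as treating the whole batch as one long stack of rows, consider the simplest layer with a shared weight, $ y^{(n)} = x^{(n)}W $, whose per-sequence gradient is $ x^{(n)T}dy^{(n)} $. The sketch below (illustrative names, random data) accumulates $ dW $ in three equivalent ways; the first mirrors the `np.sum(..., axis=0)` pattern used in `batch_selfattention_backward`:

```python
import numpy as np

rng = np.random.default_rng(6)
N_b, T_b, D_b, M_b = 3, 2, 4, 5
x_b = rng.standard_normal((N_b, T_b, D_b))
dy_b = rng.standard_normal((N_b, T_b, M_b))   # upstream gradient of y = x @ W

# 1) Per-sequence gradients x_n^T dy_n, summed over the batch axis
dW_sum = np.sum(x_b.transpose(0, 2, 1) @ dy_b, axis=0)

# 2) Fold batch and time into a single axis, then one matrix product
dW_flat = x_b.reshape(-1, D_b).T @ dy_b.reshape(-1, M_b)

# 3) einsum contracting over both the batch index n and the time index t
dW_einsum = np.einsum("ntd,ntm->dm", x_b, dy_b)

print(np.allclose(dW_sum, dW_flat), np.allclose(dW_sum, dW_einsum))
```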

In [14]:
def batch_selfattention_forward(x, K, Q, V):
    """
    Inputs:
    - x: A numpy array of shape (N, T, D).
    - K: A numpy array of shape (D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (D, M) containing the weights for the value matrix.

    Returns a tuple of:
    - out: Output data of shape (N, T, M)
    - cache: Values needed for the backward pass.
    """
    N, T, D = x.shape
    M = K.shape[-1]
 
    beta = x @ np.expand_dims(Q, axis=0) @ np.expand_dims(K.T, axis=0) @ x.transpose(0,2,1)
    beta /= np.sqrt(M)
    alpha, softmax_cache = softmax_forward(beta.reshape(N*T, -1))   # alpha has shape (N*T, T)
    alpha = alpha.reshape(N, T, -1)
    out = alpha @ x @ np.expand_dims(V, axis=0)

    cache = (softmax_cache, alpha, x, K, Q, V)
    return out, cache


def batch_selfattention_backward(dout, cache):
    """
    Inputs:
    - dout: Upstream derivative of the self-attention output of shape (N, T, M)
    - cache: A tuple of values from the forward pass.

    Returns:
    - dx: Gradient with respect to x, of shape (N, T, D).
    - dK: Gradient with respect to K, of shape (D, M).
    - dQ: Gradient with respect to Q, of shape (D, M).
    - dV: Gradient with respect to V, of shape (D, M).
    """
    softmax_cache, alpha, x, K, Q, V = cache
    N, T, M = dout.shape

    dV = np.sum((alpha @ x).transpose(0,2,1) @ dout, axis=0)
    dalpha = dout @ (x @ np.expand_dims(V, axis=0)).transpose(0,2,1)
    dbeta = softmax_backward(dalpha.reshape(N*T,-1), softmax_cache) / np.sqrt(M)
    dbeta = dbeta.reshape(N, T, -1)
    dQ = np.sum(x.transpose(0,2,1) @ dbeta @ x @ np.expand_dims(K, axis=0), axis=0)
    dK = np.sum(x.transpose(0,2,1) @ dbeta.transpose(0,2,1) @ x @ np.expand_dims(Q, axis=0), axis=0)
    dx = alpha.transpose(0,2,1) @ dout @ np.expand_dims(V.T, axis=0)    
    dx += dbeta @ x @ np.expand_dims(K @ Q.T, axis=0)
    dx += dbeta.transpose(0,2,1) @ x @ np.expand_dims(Q @ K.T, axis=0)

    return dx, dK, dQ, dV

In [15]:
N = 3   # batch size
x = np.stack([x] * N)
print("x_batch shape:", x.shape)
z, cache = batch_selfattention_forward(x, K, Q, V)
print("z shape:", z.shape)
print("z:\n", z)

x_batch shape: (3, 2, 4)
z shape: (3, 2, 3)
z:
 [[[-1.399  0.191  1.089]
  [-1.507  0.28   1.132]]

 [[-1.399  0.191  1.089]
  [-1.507  0.28   1.132]]

 [[-1.399  0.191  1.089]
  [-1.507  0.28   1.132]]]


In [16]:
l, dz = loss(z)
dx, dK, dQ, dV = batch_selfattention_backward(dz, cache)
params = {"x":x, "K":K, "Q":Q, "V":V}
grads = {"x":dx, "K":dK, "Q":dQ, "V":dV}

# These should all be less than 1e-8 or so.
f = lambda a: loss(batch_selfattention_forward(x, K, Q, V)[0])[0]
for _name in grads:
    grad_numeric = eval_numerical_gradient(f, params[_name], verbose=False)
    print("%s max relative error: %e" % (_name, rel_error(grad_numeric, grads[_name])))

x max relative error: 1.453740e-10
K max relative error: 4.305619e-09
Q max relative error: 3.938783e-09
V max relative error: 2.859633e-11


## MASKED DECODING

To use self-attention in the decoder, we need to make sure the model cannot peek at the future. In the decoder, the self-attention layer is only allowed to attend to earlier positions (and the current one) in the output sequence. This is done by masking future positions, i.e. setting their attention scores to $-\infty$.

![Mask Attention](img/transformer_maskattention.png "Mask Attention")

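Setting a score $\beta_{t,i} = -\infty$ is equivalent to forcing $\exp(\beta_{t,i}) = 0$, so the masked positions receive zero attention weight.  
This implementation uses the simple approach of adding a large negative constant, $-10^{20}$, to the masked scores; after the softmax their weights underflow to exactly $0$.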
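A causal mask does not have to be written out by hand: `np.triu` with `k=1` marks every position above the diagonal, i.e. every future position. The sketch below builds such a mask and, as an alternative to the $-10^{20}$ constant, uses `np.where` to insert $-\infty$ directly (the names are illustrative). For $T = 2$ the mask equals the `mask` array used below:

```python
import numpy as np

T_m = 4
causal_mask = np.triu(np.ones((T_m, T_m), dtype=bool), k=1)   # True = future position, masked

scores = np.random.default_rng(7).standard_normal((T_m, T_m))
masked = np.where(causal_mask, -np.inf, scores)   # exp(-inf) == 0 exactly
# Every row keeps at least its diagonal entry, so the row max is finite
masked -= masked.max(axis=-1, keepdims=True)
alpha_m = np.exp(masked)
alpha_m /= alpha_m.sum(axis=-1, keepdims=True)

print(causal_mask.astype(int))
print(alpha_m.round(3))
```

With $-\infty$ there is no need to also zero the upstream gradient of the masked entries, because their softmax outputs are exactly zero.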

In [17]:
def masked_batch_selfattention_forward(x, K, Q, V, mask=False):
    """
    Inputs:
    - x: A numpy array of shape (N, T, D).
    - K: A numpy array of shape (D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (D, M) containing the weights for the value matrix.
    - mask: A numpy array of shape (T, T) of boolean values. Flag=True masks the value.

    Returns a tuple of:
    - out: Output data of shape (N, T, M)
    - cache: Values needed for the backward pass.
    """
    N, T, D = x.shape
    M = K.shape[-1]

    beta = x @ np.expand_dims(Q, axis=0) @ np.expand_dims(K.T, axis=0) @ x.transpose(0,2,1)
    beta /= np.sqrt(M)
    beta += np.expand_dims(mask * (-1e20), axis=0)                  # masked scores -> -1e20, so exp() underflows to 0
    alpha, softmax_cache = softmax_forward(beta.reshape(N*T, -1))   # alpha has shape (N*T, T)
    alpha = alpha.reshape(N, T, -1)
    out = alpha @ x @ np.expand_dims(V, axis=0)
    cache = (softmax_cache, alpha, mask, x, K, Q, V)
    return out, cache


def masked_batch_selfattention_backward(dout, cache):
    """
    Inputs:
    - dout: Upstream derivative of the self-attention output of shape (N, T, M)
    - cache: A tuple of values from the forward pass.

    Returns:
    - dx: Gradient with respect to x, of shape (N, T, D).
    - dK: Gradient with respect to K, of shape (D, M).
    - dQ: Gradient with respect to Q, of shape (D, M).
    - dV: Gradient with respect to V, of shape (D, M).
    """
    softmax_cache, alpha, mask, x, K, Q, V = cache
    N, T, M = dout.shape

    dV = np.sum((alpha @ x).transpose(0,2,1) @ dout, axis=0)
    dalpha = dout @ (x @ np.expand_dims(V, axis=0)).transpose(0,2,1)
    dalpha *= np.expand_dims(1-mask, axis=0)
    dbeta = softmax_backward(dalpha.reshape(N*T,-1), softmax_cache) / np.sqrt(M)
    dbeta = dbeta.reshape(N, T, -1)
    dQ = np.sum(x.transpose(0,2,1) @ dbeta @ x @ np.expand_dims(K, axis=0), axis=0)
    dK = np.sum(x.transpose(0,2,1) @ dbeta.transpose(0,2,1) @ x @ np.expand_dims(Q, axis=0), axis=0)
    dx = alpha.transpose(0,2,1) @ dout @ np.expand_dims(V.T, axis=0)    
    dx += dbeta @ x @ np.expand_dims(K @ Q.T, axis=0)
    dx += dbeta.transpose(0,2,1) @ x @ np.expand_dims(Q @ K.T, axis=0)

    return dx, dK, dQ, dV

In [18]:
mask = np.array([[False, True],    # the first output will attend only to the first value vector
                 [False, False]],  # the second output will attend to both the first and the second value vector
                dtype=bool)

print("""Observe that the output from the first time step z_1 attends only
to the value at the first time step v_1\n""")

z, cache = masked_batch_selfattention_forward(x, K, Q, V, mask)
print("z shape:", z.shape)
print("z:\n", z)

values = x[0] @ V
print("\nvalues:\n", values)

Observe that the output from the first time step z_1 attends only
to the value at the first time step v_1

z shape: (3, 2, 3)
z:
 [[[-0.437 -0.603  0.699]
  [-1.507  0.28   1.132]]

 [[-0.437 -0.603  0.699]
  [-1.507  0.28   1.132]]

 [[-0.437 -0.603  0.699]
  [-1.507  0.28   1.132]]]

values:
 [[-0.437 -0.603  0.699]
 [-1.677  0.421  1.201]]


In [19]:
l, dz = loss(z)
dx, dK, dQ, dV = masked_batch_selfattention_backward(dz, cache)
params = {"x":x, "K":K, "Q":Q, "V":V}
grads = {"x":dx, "K":dK, "Q":dQ, "V":dV}

# These should all be less than 1e-8 or so.
f = lambda a: loss(masked_batch_selfattention_forward(x, K, Q, V, mask)[0])[0]
for _name in grads:
    grad_numeric = eval_numerical_gradient(f, params[_name], verbose=False)
    print("%s max relative error: %e" % (_name, rel_error(grad_numeric, grads[_name])))

x max relative error: 5.350397e-11
K max relative error: 1.066312e-08
Q max relative error: 5.230506e-09
V max relative error: 2.457892e-11


## MULTI-HEADED SELF-ATTENTION

One problem with the proposed self-attention mechanism is that an output $z_{t}$ tends to be dominated by a single $v_{i}$, because the softmax concentrates most of the weight on the largest score.

$$ z_{t} = \sum_{i} \alpha_{t,i}v_{i} $$

If we want to "look" at multiple places at once we could define multiple <b>attention heads</b> using multiple $K$,$Q$,$V$ matrices. Each attention head performs the attention computation independently, and the outputs from all heads are then concatenated. This way each head gets to "look" at different things and constructs its value vectors differently.  
This modification expands the model's ability to focus on different positions and gives the attention layer multiple representation subspaces.

![Multi-head](img/transformer_multihead.png "Multi-head")

The dimension of the concatenated output is equal to the number of attention heads ($h$) multiplied by the dimension of the output of a single head ($d$), i.e. $dh$. Usually we want the output to have the same dimension as the input. To achieve this we multiply the concatenated result by a transformation matrix $W$ of shape $dh \times D$, where $D$ is the dimension of the input.  
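Concatenating the heads and multiplying by $W$ is the same as letting each head multiply only its own block of $d$ rows of $W$ and summing the results. The sketch below checks this on random data (illustrative names; `M_h` plays the role of $d$):

```python
import numpy as np

rng = np.random.default_rng(8)
h_demo, T_h, M_h, D_h = 3, 2, 3, 4
heads = [rng.standard_normal((T_h, M_h)) for _ in range(h_demo)]   # h head outputs, each (T, M)
W_out = rng.standard_normal((h_demo * M_h, D_h))                   # (hM, D) transformation

# Concatenate along the feature axis, then project: (T, hM) @ (hM, D) -> (T, D)
z_concat = np.concatenate(heads, axis=-1) @ W_out

# Equivalent: head i only multiplies rows i*M .. (i+1)*M - 1 of W
z_sum = sum(heads[i] @ W_out[i * M_h:(i + 1) * M_h] for i in range(h_demo))

print(z_concat.shape, np.allclose(z_concat, z_sum))
```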

![Multi-head KQV](img/transformer_multiheadKQV.png "Multi-head KQV")

In [20]:
h = 3  # number of attention heads
K = np.stack([K] * h)
Q = np.stack([Q] * h)
V = np.stack([V] * h)

print("K shape:", K.shape)
print("Q shape:", Q.shape)
print("V shape:", V.shape)

W = np.random.randn(h * M, D)
print("\nW shape:", W.shape)
print("W:\n", W)

K shape: (3, 4, 3)
Q shape: (3, 4, 3)
V shape: (3, 4, 3)

W shape: (9, 4)
W:
 [[-1.479 -0.72  -0.461  1.057]
 [ 0.344 -1.763  0.324 -0.385]
 [-0.677  0.612  1.031  0.931]
 [-0.839 -0.309  0.331  0.976]
 [-0.479 -0.186 -1.106 -1.196]
 [ 0.813  1.356 -0.072  1.004]
 [ 0.362 -0.645  0.361  1.538]
 [-0.036  1.565 -2.62   0.822]
 [ 0.087 -0.299  0.092 -1.988]]


In [21]:
def multihead_attention_naive(x, K, Q, V, W, mask=False):
    """
    Inputs:
    - x: A numpy array of shape (N, T, D).
    - K: A numpy array of shape (h, D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (h, D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (h, D, M) containing the weights for the value matrix.
    - W: A numpy array of shape (hM, D) containing the weights for the transformation matrix.
    - mask: A numpy array of shape (T, T) of boolean values. Flag=True masks the value.

    Returns:
    - out: Output data of shape (N, T, D)
    """
    N, T, D = x.shape
    h, _, M = K.shape
    heads_out = []
    
    for i in range(h):
        out, _ = masked_batch_selfattention_forward(x, K[i], Q[i], V[i], mask) # tensor of shape (N, T, M)
        heads_out.append(out)
    heads_out = np.dstack(heads_out)  # tensor of shape (N, T, hM)
    print("heads out:\n", heads_out)
    return heads_out @ np.expand_dims(W, axis=0)

In [22]:
print("x shape:", x.shape)

print("""\nNote that the output before the transformation is equal to
the output from a single head concatenated {:d} times""".format(h))

z = multihead_attention_naive(x, K, Q, V, W)
print("\nOutput after transformation:")
print("z shape:", z.shape)
print("z:\n", z)

print("\nMasked output:")
z = multihead_attention_naive(x, K, Q, V, W, mask)
print("z:\n", z)

x shape: (3, 2, 4)

Note that the output before the transformation is equal to
the output from a single head concatenated 3 times
heads out:
 [[[-1.399  0.191  1.089 -1.399  0.191  1.089 -1.399  0.191  1.089]
  [-1.507  0.28   1.132 -1.507  0.28   1.132 -1.507  0.28   1.132]]

 [[-1.399  0.191  1.089 -1.399  0.191  1.089 -1.399  0.191  1.089]
  [-1.507  0.28   1.132 -1.507  0.28   1.132 -1.507  0.28   1.132]]

 [[-1.399  0.191  1.089 -1.399  0.191  1.089 -1.399  0.191  1.089]
  [-1.507  0.28   1.132 -1.507  0.28   1.132 -1.507  0.28   1.132]]]

Output after transformation:
z shape: (3, 2, 4)
z:
 [[[ 2.946  4.086  0.168 -5.198]
  [ 3.152  4.305 -0.114 -5.654]]

 [[ 2.946  4.086  0.168 -5.198]
  [ 3.152  4.305 -0.114 -5.654]]

 [[ 2.946  4.086  0.168 -5.198]
  [ 3.152  4.305 -0.114 -5.654]]]

Masked output:
heads out:
 [[[-0.437 -0.603  0.699 -0.437 -0.603  0.699 -0.437 -0.603  0.699]
  [-1.507  0.28   1.132 -1.507  0.28   1.132 -1.507  0.28   1.132]]

 [[-0.437 -0.603  0.699 -0.437 -0.603  0

Finally, to implement the multi-head self-attention layer using only matrix algebra (without a Python loop over the heads), note that every head receives the same input tensor $x$ of shape $(N, T, D)$. The gradient of the loss $\ell$ with respect to $x$ is therefore the sum of the gradients flowing back through each attention head:

$$ \frac{d\ell}{dx} = \sum_{i=0}^{h-1} \frac{d\ell}{dx^{i}} $$

where $\displaystyle \frac{d\ell}{dx^{i}}$ is the gradient through head $i$.

In [23]:
def affine_forward(x, w, b=0):
    _x = x.reshape(x.shape[0], -1)
    out = np.dot(_x, w) + b
    cache = (x, w, b)
    return out, cache


def affine_backward(dout, cache):
    x, w, b = cache
    _x = x.reshape(x.shape[0], -1)

    db = np.sum(dout, axis=0)
    dw = np.dot(_x.T, dout)
    dx = np.dot(dout, w.T)
    dx = dx.reshape(x.shape)
    return dx, dw, db

In [24]:
def multihead_selfattention_forward(x, K, Q, V, W, mask=False):
    """
    Inputs:
    - x: A numpy array of shape (N, T, D).
    - K: A numpy array of shape (h, D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (h, D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (h, D, M) containing the weights for the value matrix.
    - W: A numpy array of shape (hM, D) containing the weights for the transformation matrix.
    - mask: A numpy array of shape (T, T) of boolean values. Flag=True masks the value.

    Returns a tuple of:
    - out: Output data of shape (N, T, D)
    - cache: Values needed for the backward pass.
    """
    N, T, D = x.shape
    h, _, M = K.shape

    beta = np.expand_dims(x, axis=1) @ np.expand_dims(Q, axis=0) \
         @ np.expand_dims(K.transpose(0,2,1), axis=0) @ np.expand_dims(x.transpose(0,2,1), axis=1)
    beta /= np.sqrt(M)
    beta += np.expand_dims(mask * (-1e20), axis=(0,1))              # masked scores -> -1e20, so exp() underflows to 0
    alpha, softmax_cache = softmax_forward(beta.reshape(N*h*T, -1))
    alpha = alpha.reshape(N, h, T, T)
    heads_out = alpha @ np.expand_dims(x, axis=1) @ np.expand_dims(V, axis=0)
    heads_out = heads_out.transpose(0, 2, 1, 3).reshape(N, T, h*M)
    out, affine_cache = affine_forward(heads_out.reshape(N*T, -1), W)
    out = out.reshape(N, T, D)

    cache = (softmax_cache, affine_cache, alpha, x, K, Q, V, mask)
    return out, cache


def multihead_selfattention_backward(dout, cache):
    """
    Inputs:
    - dout: Upstream derivative of the self-attention output of shape (N, T, D)
    - cache: A tuple of values from the forward pass.

    Returns:
    - dx: Gradient with respect to x, of shape (N, T, D).
    - dK: Gradient with respect to K, of shape (h, D, M).
    - dQ: Gradient with respect to Q, of shape (h, D, M).
    - dV: Gradient with respect to V, of shape (h, D, M).
    - dW: Gradient with respect to W, of shape (hM, D).
    """
    softmax_cache, affine_cache, alpha, x, K, Q, V, mask = cache
    N, T, D = dout.shape
    h, _, M = K.shape

    dheads_out, dW, _ = affine_backward(dout.reshape(N*T, -1), affine_cache)
    dheads_out = dheads_out.reshape(N, T, h, M).transpose(0, 2, 1, 3)

    dV = np.sum((alpha @ np.expand_dims(x, axis=1)).transpose(0,1,3,2) @ dheads_out, axis=0)
    dalpha = dheads_out @ (np.expand_dims(x, axis=1) @ np.expand_dims(V, axis=0)).transpose(0,1,3,2)
    dalpha *= np.expand_dims(1-mask, axis=(0,1))
    dbeta = softmax_backward(dalpha.reshape(N*h*T,-1), softmax_cache) / np.sqrt(M)
    dbeta = dbeta.reshape(N, h, T, T)
    dQ = np.sum(np.expand_dims(x.transpose(0,2,1), axis=1) @ dbeta @ np.expand_dims(x, axis=1) \
        @ np.expand_dims(K, axis=0), axis=0)
    dK = np.sum(np.expand_dims(x.transpose(0,2,1), axis=1) @ dbeta.transpose(0,1,3,2) \
        @ np.expand_dims(x, axis=1) @ np.expand_dims(Q, axis=0), axis=0)
    dx = alpha.transpose(0,1,3,2) @ dheads_out @ np.expand_dims(V.transpose(0,2,1), axis=0)
    dx += dbeta @ np.expand_dims(x, axis=1) @ np.expand_dims(K @ Q.transpose(0,2,1), axis=0)
    dx += dbeta.transpose(0,1,3,2) @ np.expand_dims(x, axis=1) @ np.expand_dims(Q @ K.transpose(0,2,1), axis=0)
    dx = np.sum(dx, axis=1)

    return dx, dK, dQ, dV, dW

In [25]:
z, cache = multihead_selfattention_forward(x, K, Q, V, W)
print("z shape:", z.shape)
print("z:\n", z)

z shape: (3, 2, 4)
z:
 [[[ 2.946  4.086  0.168 -5.198]
  [ 3.152  4.305 -0.114 -5.654]]

 [[ 2.946  4.086  0.168 -5.198]
  [ 3.152  4.305 -0.114 -5.654]]

 [[ 2.946  4.086  0.168 -5.198]
  [ 3.152  4.305 -0.114 -5.654]]]


In [26]:
l, dz = loss(z)
dx, dK, dQ, dV, dW = multihead_selfattention_backward(dz, cache)
params = {"x":x, "K":K, "Q":Q, "V":V}
grads = {"x":dx, "K":dK, "Q":dQ, "V":dV}

# These should all be less than 1e-8 or so.
f = lambda a: loss(multihead_selfattention_forward(x, K, Q, V, W)[0])[0]
for _name in grads:
    grad_numeric = eval_numerical_gradient(f, params[_name], verbose=False)
    print("%s max relative error: %e" % (_name, rel_error(grad_numeric, grads[_name])))

x max relative error: 4.381519e-09
K max relative error: 3.147993e-09
Q max relative error: 5.043319e-09
V max relative error: 1.714666e-09


In [27]:
z, cache = multihead_selfattention_forward(x, K, Q, V, W, mask)
print("\nMasked output:")
print("z:\n", z)


Masked output:
z:
 [[[ 1.114  2.13   2.684 -1.14 ]
  [ 3.152  4.305 -0.114 -5.654]]

 [[ 1.114  2.13   2.684 -1.14 ]
  [ 3.152  4.305 -0.114 -5.654]]

 [[ 1.114  2.13   2.684 -1.14 ]
  [ 3.152  4.305 -0.114 -5.654]]]


In [28]:
l, dz = loss(z)
dx, dK, dQ, dV, dW = multihead_selfattention_backward(dz, cache)
params = {"x":x, "K":K, "Q":Q, "V":V}
grads = {"x":dx, "K":dK, "Q":dQ, "V":dV}

# These should all be less than 1e-8 or so.
f = lambda a: loss(multihead_selfattention_forward(x, K, Q, V, W, mask)[0])[0]
for _name in grads:
    grad_numeric = eval_numerical_gradient(f, params[_name], verbose=False)
    print("%s max relative error: %e" % (_name, rel_error(grad_numeric, grads[_name])))

x max relative error: 7.402705e-10
K max relative error: 8.872991e-09
Q max relative error: 7.414079e-09
V max relative error: 5.805844e-10


## ADDING NON-LINEARITIES

So far every attention layer is a linear transformation of the previous layer (with non-linear weights on the value vectors). There are no elementwise non-linearities, so stacking more self-attention layers just re-averages the value vectors.  
To fix this, we add a feed-forward network that post-processes each output of the self-attention layer independently.

![Non-linearity](img/transformer_nonlinearity.png "Non-linearity")
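A minimal NumPy sketch of such a position-wise feed-forward network, assuming the two-layer ReLU form used in [1], $\text{FFN}(x) = \max(0, xW_{1} + b_{1})W_{2} + b_{2}$. The function `ffn_forward`, the weight names and the hidden size (8 here) are illustrative, not part of this notebook's `src` package; demo variables carry an `ffn_` prefix so they do not overwrite `x`, `N`, `T` or `D`.

```python
import numpy as np

def ffn_forward(x, W1, b1, W2, b2):
    """Position-wise feed-forward network (sketch).

    x: (N, T, D), W1: (D, F), b1: (F,), W2: (F, D), b2: (D,).
    The same weights are applied independently at every position t.
    Returns an array of shape (N, T, D).
    """
    hidden = np.maximum(0.0, x @ W1 + b1)    # elementwise ReLU: the non-linearity attention lacks
    return hidden @ W2 + b2                  # project back to the model dimension D

ffn_rng = np.random.default_rng(0)
ffn_x = ffn_rng.standard_normal((2, 3, 4))                     # (N, T, D)
ffn_W1, ffn_b1 = ffn_rng.standard_normal((4, 8)), np.zeros(8)
ffn_W2, ffn_b2 = ffn_rng.standard_normal((8, 4)), np.zeros(4)
ffn_out = ffn_forward(ffn_x, ffn_W1, ffn_b1, ffn_W2, ffn_b2)
print(ffn_out.shape)  # (2, 3, 4)
```

Because each position is processed on its own, this layer behaves like a 1x1 convolution over the sequence; all mixing across positions still happens in the attention layers.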

One additional detail of the architecture is that every layer has a residual connection and is followed by a layer-normalization step. Both residual connections and layer normalization are used only to help the model train better.  
Residual connections are thought to make the loss landscape considerably smoother, and layer normalization is thought to cut down on uninformative variation in the input vectors.

![Residuals](img/transformer_residuals.png "Residuals")
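The post-norm arrangement described above, $\text{LayerNorm}(x + \text{Sublayer}(x))$, can be sketched in a few lines. This is a forward pass only, under the assumption of a per-position normalization over the feature axis; `layernorm_forward`, `add_and_norm` and the learned `gain`/`bias` vectors (often written $\gamma$ and $\beta$) are illustrative names.

```python
import numpy as np

def layernorm_forward(x, gain, bias, eps=1e-5):
    """Normalize every vector over its last (feature) axis, then rescale and shift."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)   # zero mean, (almost) unit variance per position
    return gain * x_hat + bias

def add_and_norm(x, sublayer_out, gain, bias):
    """Residual connection followed by layer normalization: LN(x + Sublayer(x))."""
    return layernorm_forward(x + sublayer_out, gain, bias)

ln_rng = np.random.default_rng(0)
ln_x = 5.0 * ln_rng.standard_normal((2, 3, 4)) + 10.0  # inputs with a large offset and scale
ln_sub = ln_rng.standard_normal((2, 3, 4))             # stand-in for a sublayer's output
ln_out = add_and_norm(ln_x, ln_sub, gain=np.ones(4), bias=np.zeros(4))
print(ln_out.mean(axis=-1))  # close to 0 at every position
print(ln_out.std(axis=-1))   # close to 1 at every position
```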

In [29]:
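def residual_forward(x, y):
    """Residual connection: out = x + y, where y is the sublayer output. No cache is needed."""
    return x + y, None

def residual_backward(dout, cache):
    """The upstream gradient flows unchanged into both summands."""
    dx = dout
    dy = dout
    return dx, dy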

## POSITIONAL ENCODING

Since the self-attention layer does not account for the order of the input sequence, we need to encode the order in some way. Naively, we could append the position index to the input:

$$ \bar{x_{t}} = \begin{bmatrix} x_{t} \\ t \end{bmatrix}$$

However, this approach is not a good idea, because <b>absolute</b> position is less important than <b>relative</b> position. We want to represent position so that tokens with similar relative positions have similar positional encodings.

<b>One idea</b> is to use a <b>frequency-based</b> representation: concatenate sinusoidal functions of varying periods. Intuitively, this gives every input an indicator of whether it comes from the first or the second half of the sequence, another indicator of which quarter it comes from, and so on, down to whether it has an even or an odd index.  
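The sin/cos encoding from [1] can be read as a smooth version of these indicators: feature pair $(2i, 2i+1)$ holds $\sin(t/10000^{2i/d})$ and $\cos(t/10000^{2i/d})$, so low-index pairs oscillate quickly and high-index pairs slowly. A minimal sketch follows; the function name is hypothetical, and rows are positions, i.e. the transpose of a $d \times T$ layout.

```python
import numpy as np

def sinusoidal_encoding(T, d):
    """Return a (T, d) array whose row t is the positional encoding p_t.

    Column 2i holds sin(t / 10000^(2i/d)) and column 2i+1 the matching cos,
    so every sin/cos pair oscillates with its own period.
    """
    t = np.arange(T)[:, None]                        # (T, 1) positions
    two_i = np.arange(0, d, 2)[None, :]              # (1, ceil(d/2)) even column indices
    angles = t / np.power(10000.0, two_i / d)        # (T, ceil(d/2))
    p = np.zeros((T, d))
    p[:, 0::2] = np.sin(angles)
    p[:, 1::2] = np.cos(angles[:, : d // 2])         # drop the last angle when d is odd
    return p

pe = sinusoidal_encoding(50, 8)
print(pe.shape)  # (50, 8)
# For even d, p_s . p_{s+k} = sum_i cos(k * w_i) depends only on the offset k:
print(np.isclose(pe[3] @ pe[8], pe[10] @ pe[15]))  # True
```

The last line illustrates why this encoding suits relative positions: the similarity between two encodings depends only on how far apart the positions are.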

<b>Another idea</b> is to make all $p_{i}$ learnable parameters, i.e. to learn a matrix $p \in \mathbb{R}^{d \times T}$. This approach may work better than the sin/cos encoding, but the downside is that we need to pick a maximum sequence length and cannot generalize beyond it. Most systems use this approach.
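A learned encoding amounts to a trainable lookup table with one row per position. The sketch below (class and argument names are hypothetical) makes the limitation concrete: a position beyond `max_len` has no row to look up.

```python
import numpy as np

class LearnedPositionalEncoding:
    """Trainable table of positional vectors, one row per position (sketch)."""

    def __init__(self, max_len, d, rng):
        # Small random init; during training the rows would be updated by gradient descent.
        self.table = 0.01 * rng.standard_normal((max_len, d))

    def __call__(self, T):
        """Return the encodings p_0, ..., p_{T-1} as a (T, d) array."""
        max_len = self.table.shape[0]
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds max_len={max_len}")
        return self.table[:T]

learned_pe = LearnedPositionalEncoding(max_len=16, d=4, rng=np.random.default_rng(0))
print(learned_pe(10).shape)  # (10, 4)
try:
    learned_pe(20)
except ValueError as err:
    print(err)  # sequence length 20 exceeds max_len=16
```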

To incorporate the positional encoding we could just concatenate it to the input:

$$ \bar{x_{t}} = \begin{bmatrix} x_{t} \\ p_{t} \end{bmatrix} $$

Another approach is to add the positional encoding to the input. And yet another approach is to add the positional encoding to the key, query and value of the input.

$ \bar{x_{t}} = x_{t} + p_{t} $ (most people do this)  

or  

$ \bar{k_{t}} = k_{t} + p_{t} $  
$ \bar{q_{t}} = q_{t} + p_{t} $  
$ \bar{v_{t}} = v_{t} + p_{t} $  

![Position](img/transformer_position.png "Position")
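Adding and concatenating differ mainly in the resulting width. A shape-only sketch with random stand-ins for the token embeddings $x_{t}$ and the encodings $p_{t}$ (all names are illustrative):

```python
import numpy as np

pos_rng = np.random.default_rng(0)
n_batch, seq_len, d_model = 2, 5, 4
token_emb = pos_rng.standard_normal((n_batch, seq_len, d_model))  # x_t for every sequence
pos_enc = pos_rng.standard_normal((seq_len, d_model))             # p_t (sin/cos or learned)

# Addition: p_t broadcasts over the batch axis and the width stays d_model.
x_added = token_emb + pos_enc
# Concatenation: p_t is copied for every sequence and the width doubles.
x_concat = np.concatenate(
    [token_emb, np.broadcast_to(pos_enc, (n_batch, seq_len, d_model))], axis=-1)

print(x_added.shape)   # (2, 5, 4)
print(x_concat.shape)  # (2, 5, 8)
```

With addition, the $K$, $Q$, $V$ matrices keep their $(D, M)$ shape; concatenating a $D$-dimensional $p_{t}$ would turn them into $(2D, M)$ matrices.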



## CROSS-ATTENTION

Self-attention is when keys, queries, and values come from the same source. However, we want the decoder to attend to the input sequence. To do this, we take the output of the top encoder and transform it into a set of keys and values, which each decoder uses in its cross-attention layer.  
Let $ y = [y_{1}, ..., y_{T_{enc}}]$ be the output from the top encoder; $y \in \mathbb{R}^{T_{enc} \times d}$.  
Let $ x = [x_{1}, ..., x_{T_{dec}}]$ be the input to the decoder; $x \in \mathbb{R}^{T_{dec} \times d}$.  
Then the keys and values are drawn from the encoder (like a memory):
$$
\begin{align*}
& \text{keys} = yK \\
& \text{values} = yV \\
\end{align*}
$$

And the queries are drawn from the decoder:

$$ \text{queries} = xQ $$

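The cross-attention output is computed like the self-attention output, with the softmax taken over the encoder positions and the scores scaled by $\sqrt{M}$, where $M$ is the key/query dimension:

$$ \text{out} = \text{softmax}\left(\frac{xQK^{T}y^{T}}{\sqrt{M}}\right)yV $$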

![Cross attention](img/transformer_crossattention_matrix.png "Cross attention")

In [30]:
def cross_attention_naive(x, y, K, Q, V):
    """
    Inputs:
    - x: A numpy array of shape (Tdec, D), giving decoder inputs.
    - y: A numpy array of shape (Tenc, D), giving encoder outputs.
    - K: A numpy array of shape (D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (D, M) containing the weights for the value matrix.

    Returns:
    - out: Output data of shape (Tdec, M).
    """
    M = K.shape[-1]
    return softmax(x @ Q @ K.T @ y.T / np.sqrt(M)) @ y @ V

In [31]:
Tenc = T
Tdec = 4

y = np.random.randn(Tenc, D)
x = np.random.randn(Tdec, D)
z = cross_attention_naive(x, y, K[0], Q[0], V[0])

print("y shape:", y.shape)
print("z shape:", z.shape)
print("z:\n", z)

y shape: (2, 4)
z shape: (4, 3)
z:
 [[-1.699  0.752  0.58 ]
 [-2.284  0.682  0.421]
 [-1.566  0.768  0.616]
 [-2.338  0.676  0.407]]


In [32]:
def batch_cross_attention_forward(x, y, K, Q, V):
    """
    Inputs:
    - x: A numpy array of shape (N, Tdec, D), giving decoder inputs.
    - y: A numpy array of shape (N, Tenc, D), giving encoder outputs.
    - K: A numpy array of shape (D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (D, M) containing the weights for the value matrix.

    Returns a tuple of:
    - out: Output data of shape (N, Tdec, M)
    - cache: Values needed for the backward pass.
    """
    N, Tdec, D = x.shape
    Tenc = y.shape[1]
    M = K.shape[-1]

    beta = x @ np.expand_dims(Q, axis=0) @ np.expand_dims(K.T, axis=0) @ y.transpose(0,2,1)   # (N, Tdec, Tenc)
    beta /= np.sqrt(M)
    alpha, softmax_cache = softmax_forward(beta.reshape(N*Tdec, -1))   # one row of Tenc scores per query
    alpha = alpha.reshape(N, Tdec, Tenc)
    out = alpha @ y @ np.expand_dims(V, axis=0)

    cache = (softmax_cache, alpha, x, y, K, Q, V)
    return out, cache


def batch_cross_attention_backward(dout, cache):
    """
    Inputs:
    - dout: Upstream derivative of the cross-attention output of shape (N, Tdec, M)
    - cache: A tuple of values from the forward pass.

    Returns:
    - dx: Gradient with respect to x, of shape (N, Tdec, D).
    - dy: Gradient with respect to y, of shape (N, Tenc, D).
    - dK: Gradient with respect to K, of shape (D, M).
    - dQ: Gradient with respect to Q, of shape (D, M).
    - dV: Gradient with respect to V, of shape (D, M).
    """
    softmax_cache, alpha, x, y, K, Q, V = cache
    N, Tdec, M = dout.shape

    dV = np.sum((alpha @ y).transpose(0,2,1) @ dout, axis=0)
    dalpha = dout @ (y @ np.expand_dims(V, axis=0)).transpose(0,2,1)
    dbeta = softmax_backward(dalpha.reshape(N*Tdec,-1), softmax_cache) / np.sqrt(M)
    dbeta = dbeta.reshape(N, Tdec, -1)                 # (N, Tdec, Tenc)
    dQ = np.sum(x.transpose(0,2,1) @ dbeta @ y @ np.expand_dims(K, axis=0), axis=0)
    dK = np.sum(y.transpose(0,2,1) @ dbeta.transpose(0,2,1) @ x @ np.expand_dims(Q, axis=0), axis=0)
    dy = alpha.transpose(0,2,1) @ dout @ np.expand_dims(V.T, axis=0)
    dy += dbeta.transpose(0,2,1) @ x @ np.expand_dims(Q @ K.T, axis=0)
    dx = dbeta @ y @ np.expand_dims(K @ Q.T, axis=0)

    return dx, dy, dK, dQ, dV

In [33]:
N = 2
y = np.stack([y] * N)
x = np.stack([x] * N)
z, cache = batch_cross_attention_forward(x, y, K[0], Q[0], V[0])

print("y shape:", y.shape)
print("z shape:", z.shape)
print("z:\n", z)

y shape: (2, 2, 4)
z shape: (2, 4, 3)
z:
 [[[-1.699  0.752  0.58 ]
  [-2.284  0.682  0.421]
  [-1.566  0.768  0.616]
  [-2.338  0.676  0.407]]

 [[-1.699  0.752  0.58 ]
  [-2.284  0.682  0.421]
  [-1.566  0.768  0.616]
  [-2.338  0.676  0.407]]]


In [34]:
l, dz = loss(z)
dx, dy, dK, dQ, dV = batch_cross_attention_backward(dz, cache)
params = {"x":x, "y":y, "K":K[0], "Q":Q[0], "V":V[0]}
grads = {"x":dx, "y":dy, "K":dK, "Q":dQ, "V":dV}

# These should all be less than 1e-8 or so.
f = lambda a: loss(batch_cross_attention_forward(x, y, K[0], Q[0], V[0])[0])[0]
for _name in grads:
    grad_numeric = eval_numerical_gradient(f, params[_name], verbose=False)
    print("%s max relative error: %e" % (_name, rel_error(grad_numeric, grads[_name])))

x max relative error: 5.405021e-11
y max relative error: 3.270507e-10
K max relative error: 1.397881e-09
Q max relative error: 4.778638e-10
V max relative error: 3.774175e-11


Just as with self-attention, we want to "look" at multiple places at once and to construct different value vectors for each, so we define multiple cross-attention heads using multiple $K$, $Q$, $V$ matrices.

In [35]:
def multihead_cross_attention_naive(x, y, K, Q, V, W):
    """
    Inputs:
    - x: A numpy array of shape (N, Tdec, D), giving decoder inputs.
    - y: A numpy array of shape (N, Tenc, D), giving encoder outputs.
    - K: A numpy array of shape (h, D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (h, D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (h, D, M) containing the weights for the value matrix.
    - W: A numpy array of shape (hM, D) containing the weights for the transformation matrix.

    Returns:
    - out: Output data of shape (N, Tdec, D).
    """
    N, Tdec, D = x.shape
    h, _, M = K.shape
    heads_out = []

    for i in range(h):
        out, _ = batch_cross_attention_forward(x, y, K[i], Q[i], V[i]) # tensor of shape (N, Tdec, M)
        heads_out.append(out)
    heads_out = np.dstack(heads_out)  # tensor of shape (N, Tdec, hM)
    print("heads out:\n", heads_out)
    return heads_out @ np.expand_dims(W, axis=0)

In [36]:
print("x shape:", x.shape)

print("""\nNote that the output before transformation is equal to
the output from a single head concatenated {:d} times""".format(h))

z = multihead_cross_attention_naive(x, y, K, Q, V, W)
print("\nOutput after transformation:")
print("z shape:", z.shape)
print("z:\n", z)

x shape: (2, 4, 4)

Note that the output before transformation is equal to
the output from a single head concatenated 3 times
heads out:
 [[[-1.699  0.752  0.58  -1.699  0.752  0.58  -1.699  0.752  0.58 ]
  [-2.284  0.682  0.421 -2.284  0.682  0.421 -2.284  0.682  0.421]
  [-1.566  0.768  0.616 -1.566  0.768  0.616 -1.566  0.768  0.616]
  [-2.338  0.676  0.407 -2.338  0.676  0.407 -2.338  0.676  0.407]]

 [[-1.699  0.752  0.58  -1.699  0.752  0.58  -1.699  0.752  0.58 ]
  [-2.284  0.682  0.421 -2.284  0.682  0.421 -2.284  0.682  0.421]
  [-1.566  0.768  0.616 -1.566  0.768  0.616 -1.566  0.768  0.616]
  [-2.338  0.676  0.407 -2.338  0.676  0.407 -2.338  0.676  0.407]]]

Output after transformation:
z shape: (2, 4, 4)
z:
 [[[ 3.324  3.524 -2.345 -6.67 ]
  [ 4.444  4.265 -2.407 -8.694]
  [ 3.069  3.355 -2.33  -6.208]
  [ 4.547  4.333 -2.413 -8.881]]

 [[ 3.324  3.524 -2.345 -6.67 ]
  [ 4.444  4.265 -2.407 -8.694]
  [ 3.069  3.355 -2.33  -6.208]
  [ 4.547  4.333 -2.413 -8.881]]]


And again, to implement the multi-head cross-attention layer using only matrix algebra, note that every head reads the same $x$ and $y$, so the gradient of the loss $\ell$ with respect to each input tensor is the sum of the gradients from the individual attention heads.

$$ \frac{d\ell}{dx} = \sum_{i=0}^{h-1} \frac{d\ell}{dx^{i}} $$
$$ \frac{d\ell}{dy} = \sum_{i=0}^{h-1} \frac{d\ell}{dy^{i}} $$
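A toy numerical check of this rule, using hypothetical linear "heads" $x \mapsto xA_{i}$ in place of attention so that each per-head gradient, $\frac{d\ell}{dx^{i}} = G_{i}A_{i}^{T}$, is easy to write by hand. Summing over the head axis mirrors the `np.sum(dx, axis=1)` step in `multihead_crossattention_backward`, and the result matches a finite-difference estimate.

```python
import numpy as np

rng_heads = np.random.default_rng(0)
n_heads, seq_len, d_in = 3, 2, 4
x_shared = rng_heads.standard_normal((seq_len, d_in))          # one input read by every head
A_heads = rng_heads.standard_normal((n_heads, d_in, d_in))     # one weight matrix per head
G_heads = rng_heads.standard_normal((n_heads, seq_len, d_in))  # fixed "upstream gradients"

def toy_loss(x_in):
    # expand_dims broadcasts the same input across the head axis: (1, T, D) @ (h, D, D) -> (h, T, D)
    head_outputs = np.expand_dims(x_in, axis=0) @ A_heads
    return np.sum(G_heads * head_outputs)

# Analytic gradient: per-head gradients G_i A_i^T, summed over the head axis.
dx_analytic = np.sum(G_heads @ A_heads.transpose(0, 2, 1), axis=0)   # (T, D)

# Central finite differences for comparison.
fd_eps = 1e-6
dx_numeric = np.zeros_like(x_shared)
for idx in np.ndindex(*x_shared.shape):
    x_plus, x_minus = x_shared.copy(), x_shared.copy()
    x_plus[idx] += fd_eps
    x_minus[idx] -= fd_eps
    dx_numeric[idx] = (toy_loss(x_plus) - toy_loss(x_minus)) / (2 * fd_eps)

print(np.max(np.abs(dx_analytic - dx_numeric)))  # tiny: only round-off error remains
```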

In [37]:
def multihead_crossattention_forward(x, y, K, Q, V, W):
    """
    Inputs:
    - x: A numpy array of shape (N, Tdec, D), giving decoder inputs.
    - y: A numpy array of shape (N, Tenc, D), giving encoder outputs.
    - K: A numpy array of shape (h, D, M) containing the weights for the key matrix.
    - Q: A numpy array of shape (h, D, M) containing the weights for the query matrix.
    - V: A numpy array of shape (h, D, M) containing the weights for the value matrix.
    - W: A numpy array of shape (hM, D) containing the weights for the transformation matrix.

    Returns a tuple of:
    - out: Output data of shape (N, Tdec, D)
    - cache: Values needed for the backward pass.
    """
    N, Tdec, D = x.shape
    h, _, M = K.shape

    beta = np.expand_dims(x, axis=1) @ np.expand_dims(Q, axis=0) \
         @ np.expand_dims(K.transpose(0,2,1), axis=0) @ np.expand_dims(y.transpose(0,2,1), axis=1)
    beta /= np.sqrt(M)
    alpha, softmax_cache = softmax_forward(beta.reshape(N*h*Tdec, -1))
    alpha = alpha.reshape(N, h, Tdec, -1)
    heads_out = alpha @ np.expand_dims(y, axis=1) @ np.expand_dims(V, axis=0)
    heads_out = heads_out.transpose(0, 2, 1, 3).reshape(N, Tdec, h*M)
    out, affine_cache = affine_forward(heads_out.reshape(N*Tdec, -1), W)
    out = out.reshape(N, Tdec, D)

    cache = (softmax_cache, affine_cache, alpha, x, y, K, Q, V, W)
    return out, cache


def multihead_crossattention_backward(dout, cache):
    """
    Inputs:
    - dout: Upstream derivative of the multi-head cross-attention output of shape (N, Tdec, D)
    - cache: A tuple of values from the forward pass.

    Returns:
    - dx: Gradient with respect to x, of shape (N, Tdec, D).
    - dy: Gradient with respect to y, of shape (N, Tenc, D).
    - dK: Gradient with respect to K, of shape (h, D, M).
    - dQ: Gradient with respect to Q, of shape (h, D, M).
    - dV: Gradient with respect to V, of shape (h, D, M).
    - dW: Gradient with respect to W, of shape (hM, D).
    """
    softmax_cache, affine_cache, alpha, x, y, K, Q, V, W = cache
    N, Tdec, D = dout.shape
    h, _, M = K.shape

    dheads_out, dW, _ = affine_backward(dout.reshape(N*Tdec, -1), affine_cache)
    dheads_out = dheads_out.reshape(N, Tdec, h, M).transpose(0, 2, 1, 3)

    dV = np.sum((alpha @ np.expand_dims(y, axis=1)).transpose(0,1,3,2) @ dheads_out, axis=0)
    dalpha = dheads_out @ (np.expand_dims(y, axis=1) @ np.expand_dims(V, axis=0)).transpose(0,1,3,2)
    dbeta = softmax_backward(dalpha.reshape(N*h*Tdec,-1), softmax_cache) / np.sqrt(M)
    dbeta = dbeta.reshape(N, h, Tdec, -1)
    dQ = np.sum(np.expand_dims(x.transpose(0,2,1), axis=1) @ dbeta @ np.expand_dims(y, axis=1) \
        @ np.expand_dims(K, axis=0), axis=0)
    dK = np.sum(np.expand_dims(y.transpose(0,2,1), axis=1) @ dbeta.transpose(0,1,3,2) \
        @ np.expand_dims(x, axis=1) @ np.expand_dims(Q, axis=0), axis=0)
    dy = alpha.transpose(0,1,3,2) @ dheads_out @ np.expand_dims(V.transpose(0,2,1), axis=0)
    dy += dbeta.transpose(0,1,3,2) @ np.expand_dims(x, axis=1) @ np.expand_dims(Q @ K.transpose(0,2,1), axis=0)
    dx = dbeta @ np.expand_dims(y, axis=1) @ np.expand_dims(K @ Q.transpose(0,2,1), axis=0)
    dx = np.sum(dx, axis=1)
    dy = np.sum(dy, axis=1)

    return dx, dy, dK, dQ, dV, dW

In [38]:
z, cache = multihead_crossattention_forward(x, y, K, Q, V, W)
print("z shape:", z.shape)
print("z:\n", z)

z shape: (2, 4, 4)
z:
 [[[ 3.324  3.524 -2.345 -6.67 ]
  [ 4.444  4.265 -2.407 -8.694]
  [ 3.069  3.355 -2.33  -6.208]
  [ 4.547  4.333 -2.413 -8.881]]

 [[ 3.324  3.524 -2.345 -6.67 ]
  [ 4.444  4.265 -2.407 -8.694]
  [ 3.069  3.355 -2.33  -6.208]
  [ 4.547  4.333 -2.413 -8.881]]]


In [39]:
l, dz = loss(z)
dx, dy, dK, dQ, dV, dW = multihead_crossattention_backward(dz, cache)
params = {"x":x, "y":y, "K":K, "Q":Q, "V":V, "W":W}
grads = {"x":dx, "y":dy, "K":dK, "Q":dQ, "V":dV, "W":dW}

# These should all be less than 1e-8 or so.
f = lambda a: loss(multihead_crossattention_forward(x, y, K, Q, V, W)[0])[0]
for _name in grads:
    grad_numeric = eval_numerical_gradient(f, params[_name], verbose=False)
    print("%s max relative error: %e" % (_name, rel_error(grad_numeric, grads[_name])))

x max relative error: 3.424784e-09
y max relative error: 1.758840e-10
K max relative error: 3.060167e-08
Q max relative error: 1.574904e-08
V max relative error: 1.308661e-09
W max relative error: 5.105479e-11


## THE TRANSFORMER ENCODER-DECODER

The transformer encoder is composed of a multi-head self-attention layer and a feed-forward network.  
The input embeddings are combined with positional encodings and then fed to the self-attention layer, whose outputs are in turn fed to the feed-forward network. Every layer has a residual connection and is followed by a layer normalization.

![Encoder](img/transformer_encoder.png "Encoder")
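A forward-only NumPy sketch of one encoder layer, assuming post-layer-norm ordering ($\text{LN}(x + \text{Sublayer}(x))$), a ReLU feed-forward network without biases, and layer norm without learned gain and bias. It is not the `src.transformer` implementation; the helper names are hypothetical. The attention part uses the same $(h, D, M)$ weight layout and head-concatenation order as `multihead_selfattention_forward`.

```python
import numpy as np

def _softmax_lastaxis(z):
    z = z - z.max(axis=-1, keepdims=True)            # shift by the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def _layernorm(z, eps=1e-5):
    # Per-position normalization; learned gain and bias omitted for brevity.
    mu = z.mean(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(z.var(axis=-1, keepdims=True) + eps)

def encoder_block_sketch(x, K, Q, V, W, W1, W2):
    """One post-LN encoder layer: z = LN(x + MHSA(x)), then LN(z + FFN(z)).

    x: (N, T, D) inputs (embeddings already combined with positional encodings)
    K, Q, V: (h, D, M) per-head weights; W: (h*M, D); W1: (D, F); W2: (F, D).
    """
    N, T, _ = x.shape
    h, _, M = K.shape
    xh = x[:, None]                                   # (N, 1, T, D), broadcast over heads
    scores = (xh @ Q) @ (xh @ K).transpose(0, 1, 3, 2) / np.sqrt(M)   # (N, h, T, T)
    heads = _softmax_lastaxis(scores) @ (xh @ V)      # (N, h, T, M)
    attn = heads.transpose(0, 2, 1, 3).reshape(N, T, h * M) @ W       # (N, T, D)
    z = _layernorm(x + attn)                          # residual + layer norm
    ffn = np.maximum(0.0, z @ W1) @ W2                # position-wise ReLU network
    return _layernorm(z + ffn)                        # residual + layer norm

enc_rng = np.random.default_rng(0)
enc_N, enc_T, enc_D, enc_h, enc_M, enc_F = 2, 5, 4, 3, 3, 8
enc_x = enc_rng.standard_normal((enc_N, enc_T, enc_D))
enc_K, enc_Q, enc_V = (enc_rng.standard_normal((enc_h, enc_D, enc_M)) for _ in range(3))
enc_W = enc_rng.standard_normal((enc_h * enc_M, enc_D))
enc_W1 = enc_rng.standard_normal((enc_D, enc_F))
enc_W2 = enc_rng.standard_normal((enc_F, enc_D))
enc_out = encoder_block_sketch(enc_x, enc_K, enc_Q, enc_V, enc_W, enc_W1, enc_W2)
print(enc_out.shape)  # (2, 5, 4)
```

Stacking several such layers, each with its own weights, gives the full encoder; the output of the top layer is the $y$ that the decoder's cross-attention reads.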


The transformer decoder uses a masked self-attention layer, followed by a cross-attention layer, and a feed-forward network. The masked self-attention layer makes sure that the decoder does not "look" into the future by masking future positions. The cross-attention layer provides the functionality to attend to the input source sequence. Again, as with the encoder, every layer has a residual connection and is followed by a layer normalization.

![Decoder](img/transformer_decoder.png "Decoder")
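A sketch of how such a look-ahead mask can be built for a target of length $T$, following the convention of `multihead_selfattention_forward` (True = masked). The helper `causal_mask` is a hypothetical name, not part of `src`.

```python
import numpy as np

def causal_mask(T):
    """(T, T) boolean mask: entry (t, s) is True when s > t, i.e. position s is in the future."""
    return np.triu(np.ones((T, T), dtype=bool), k=1)

mask_demo = causal_mask(4)
print(mask_demo.astype(int))
# [[0 1 1 1]
#  [0 0 1 1]
#  [0 0 0 1]
#  [0 0 0 0]]

# Apply it with the same additive trick as multihead_selfattention_forward.
scores_demo = np.random.default_rng(0).standard_normal((4, 4))
scores_demo = scores_demo + mask_demo * (-1e20)
weights_demo = np.exp(scores_demo - scores_demo.max(axis=-1, keepdims=True))
weights_demo /= weights_demo.sum(axis=-1, keepdims=True)   # row-wise softmax
```

Row $t$ of the resulting weights puts probability mass only on positions $0, \dots, t$, so the first position can attend only to itself.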

In [40]:
from src.transformer import Transformer

In [41]:
# Initialize toy example to check the implementation.
np.random.seed(13)

batch_size = 2
src_seq_len = 4
src_vocab_size = 10
src_embed_dim = 4

tgt_seq_len = 7
tgt_vocab_size = 10
tgt_embed_dim = 4

hidden_dim = 5

null_idx = 0
start_idx = 1
end_idx = 2

src = np.random.randint(low=0, high=src_vocab_size, size=(batch_size, src_seq_len))
tgt = np.random.randint(low=0, high=tgt_vocab_size, size=(batch_size, tgt_seq_len + 1))

print("Example source:\n", src)
print("Example target:\n", tgt)

Example source:
 [[2 0 0 6]
 [2 4 9 3]]
Example target:
 [[4 2 6 5 9 4 2 0]
 [3 5 3 6 5 1 2 8]]


In [42]:
# Check the backward pass for the Transformer model.
D = 4
M = 3
h = 5
n_enc = 2
n_dec = 2
transformer_model = Transformer(D, M, h, n_enc, n_dec,
                                src_vocab_size, src_embed_dim,
                                tgt_vocab_size, tgt_embed_dim,
                                null_idx, dtype=np.float64)


model_loss, grads = transformer_model.loss(src, tgt)   # avoid shadowing the loss() helper used above
f = lambda _ : transformer_model.loss(src, tgt)[0]

for param_name in sorted(grads):
    param_grad_num = eval_numerical_gradient(f, transformer_model.params[param_name], verbose=False, h=1e-6)
    print("%s relative error: %e" % (param_name, rel_error(param_grad_num, grads[param_name])))

K_0_dec relative error: 4.546203e-06
K_0_enc relative error: 3.196035e-04
K_1_dec relative error: 1.205586e-06
K_1_enc relative error: 3.443817e-07
K_dec_cross relative error: 1.267143e-06
Q_0_dec relative error: 2.098227e-06
Q_0_dec_cross relative error: 9.284055e-07
Q_0_enc relative error: 1.371931e-06
Q_1_dec relative error: 1.673130e-07
Q_1_dec_cross relative error: 5.509773e-07
Q_1_enc relative error: 3.215609e-07
V_0_dec relative error: 3.681092e-08
V_0_enc relative error: 1.455361e-07
V_1_dec relative error: 4.208471e-07
V_1_enc relative error: 5.975341e-07
V_dec_cross relative error: 1.368287e-07
W_0_dec relative error: 6.207619e-08
W_0_dec_cross relative error: 3.587090e-08
W_0_dec_ff relative error: 3.649067e-09
W_0_enc relative error: 7.501286e-08
W_0_enc_ff relative error: 1.740355e-08
W_1_dec relative error: 4.006797e-07
W_1_dec_cross relative error: 2.670082e-08
W_1_dec_ff relative error: 1.062843e-08
W_1_enc relative error: 7.265027e-08
W_1_enc_ff relative error: 7.29826