In [1]:
%load_ext autoreload
%autoreload 2

# Natural Language Processing Demystified | Transformers, Pre-training, and Transfer Learning

## Imports

In [2]:
!pip install BPEmb

from typing import Optional
from functools import partial

import math
import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
import haiku as hk

jax.config.update('jax_platform_name', 'gpu')
seed = 0
np.random.seed(seed)
tf.keras.utils.set_random_seed(seed)

from bpemb import BPEmb

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


2023-07-15 22:14:03.467441: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-07-15 22:14:03.497469: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.




# Transformers From Scratch

We'll build a transformer from scratch, layer-by-layer. We'll start with the **Multi-Head Self-Attention** layer since that's the most involved bit. Once we have that working, the rest of the model will look familiar if you've been following the course so far.

## Multi-Head Self-Attention

#### Scaled Dot Product Self-Attention


Inside each attention head is a **Scaled Dot Product Self-Attention** operation as we covered in the slides. Given *queries*, *keys*, and *values*, the operation returns a new "mix" of the values.

$$Attention(Q, K, V) = softmax(\frac{QK^T)}{\sqrt{d_k}})V$$

The following function implements this and also takes a mask to account for padding and for masking future tokens for decoding (i.e. **look-ahead mask**). We'll cover masking later in the notebook.

In [3]:
def scaled_dot_product_attention(query, key, value, mask=None):
    key_dim = jnp.shape(key)[-1]
    scaled_scores = jnp.matmul(query, key.swapaxes(-1, -2)) / np.sqrt(key_dim)
    scaled_scores = scaled_scores.astype(jnp.float32)

    if mask is not None:
        scaled_scores = jnp.where(mask==0, -np.inf, scaled_scores)

    softmax = jax.nn.softmax
    weights = softmax(scaled_scores) 
    return jnp.matmul(weights, value), weights

In [4]:
seq_len = 3
embed_dim = 4

queries = np.random.rand(seq_len, embed_dim)
keys = np.random.rand(seq_len, embed_dim)
values = np.random.rand(seq_len, embed_dim)

print("Queries:\n", queries)

Queries:
 [[0.5488135  0.71518937 0.60276338 0.54488318]
 [0.4236548  0.64589411 0.43758721 0.891773  ]
 [0.96366276 0.38344152 0.79172504 0.52889492]]


In [5]:
output, attn_weights = scaled_dot_product_attention(queries, keys, values)

print("Output\n", output, "\n")
print("Weights\n", attn_weights)

Output
 [[0.3879928  0.5350287  0.13623255 0.7588937 ]
 [0.39492333 0.53118    0.1382761  0.755936  ]
 [0.39012212 0.53832173 0.12822185 0.75093997]] 

Weights
 [[0.26820898 0.34192598 0.38986498]
 [0.2510088  0.35895866 0.39003247]
 [0.25670317 0.3151627  0.4281342 ]]


#### Dense layer

In [6]:
class Dense(hk.Module):
    """ A 1-layer MLP. Adapted from https://theaisummer.com/jax-transformer/#the-linear-layer """

    def __init__(self,
                 output_dim: float,
                 name: Optional[str] = None,
                 activation = jax.nn.relu):
        super().__init__(name=name)
        self.hiddens = output_dim
        self.activation = activation

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # hiddens = x.shape[-1]
        # initializer = hk.initializers.VarianceScaling(self._init_scale)
        x = hk.Linear(self.hiddens)(x)
        return self.activation(x)

#### Generating queries, keys, and values for multiple heads

In [7]:
batch_size = 1
seq_len = 3
embed_dim = 12
num_heads = 3
head_dim = embed_dim // num_heads

print(f"Dimension of each head: {head_dim}")

Dimension of each head: 4


Using separate weight matrices per head

Suppose these are our input embeddings. Here we have a batch of 1 containing a sequence of length 3, with each element being a 12-dimensional embedding.

In [8]:
x = np.random.rand(batch_size, seq_len, embed_dim).round(1)
print("Input shape: ", x.shape, "\n")
print("Input:\n", x)

Input shape:  (1, 3, 12) 

Input:
 [[[0.6 0.6 0.9 0.7 0.4 0.4 0.7 0.1 0.7 0.7 0.2 0.1]
  [0.3 0.4 0.6 0.4 1.  0.1 0.2 0.2 0.7 0.3 0.5 0.2]
  [0.2 0.1 0.7 0.1 0.2 0.4 0.8 0.1 0.8 0.1 1.  0.5]]]


We'll declare three sets of query weights (one for each head), three sets of key weights, and three sets of value weights. Remember each weight matrix should have a dimension of  d x d/h .

In [9]:
# The query weights for each head.
wq0 = np.random.rand(embed_dim, head_dim).round(1)
wq1 = np.random.rand(embed_dim, head_dim).round(1)
wq2 = np.random.rand(embed_dim, head_dim).round(1)

# The key weights for each head. 
wk0 = np.random.rand(embed_dim, head_dim).round(1)
wk1 = np.random.rand(embed_dim, head_dim).round(1)
wk2 = np.random.rand(embed_dim, head_dim).round(1)

# The value weights for each head.
wv0 = np.random.rand(embed_dim, head_dim).round(1)
wv1 = np.random.rand(embed_dim, head_dim).round(1)
wv2 = np.random.rand(embed_dim, head_dim).round(1)

print("The three sets of query weights (one for each head):")
print("wq0:\n", wq0)
print("wq1:\n", wq1)
print("wq2:\n", wq1)

The three sets of query weights (one for each head):
wq0:
 [[1.  0.6 0.7 0. ]
 [0.3 0.1 0.3 0.1]
 [0.3 0.4 0.1 0.7]
 [0.6 0.3 0.5 0.1]
 [0.6 0.9 0.3 0.7]
 [0.1 0.7 0.3 0.2]
 [0.6 0.  0.8 0. ]
 [0.7 0.3 0.7 1. ]
 [0.2 0.6 0.6 0.6]
 [0.2 1.  0.4 0.8]
 [0.7 0.3 0.8 0.4]
 [0.9 0.6 0.9 0.7]]
wq1:
 [[0.7 0.5 1.  0.6]
 [0.4 0.6 0.  0.3]
 [0.7 0.3 0.6 0.4]
 [0.1 0.3 0.6 0.6]
 [0.6 0.7 0.7 0.4]
 [0.9 0.4 0.4 0.9]
 [0.8 0.7 0.1 0.9]
 [0.7 1.  0.1 0.9]
 [0.2 0.6 0.1 0.8]
 [0.8 0.6 0.4 0.1]
 [0.7 0.5 0.7 0.9]
 [1.  0.9 0.  0.4]]
wq2:
 [[0.7 0.5 1.  0.6]
 [0.4 0.6 0.  0.3]
 [0.7 0.3 0.6 0.4]
 [0.1 0.3 0.6 0.6]
 [0.6 0.7 0.7 0.4]
 [0.9 0.4 0.4 0.9]
 [0.8 0.7 0.1 0.9]
 [0.7 1.  0.1 0.9]
 [0.2 0.6 0.1 0.8]
 [0.8 0.6 0.4 0.1]
 [0.7 0.5 0.7 0.9]
 [1.  0.9 0.  0.4]]


We'll generate our *queries*, *keys*, and *values* for each head by multiplying our input by the weights.

In [10]:
# Geneated queries, keys, and values for the first head.
q0 = np.dot(x, wq0)
k0 = np.dot(x, wk0)
v0 = np.dot(x, wv0)

# Geneated queries, keys, and values for the second head.
q1 = np.dot(x, wq1)
k1 = np.dot(x, wk1)
v1 = np.dot(x, wv1)

# Geneated queries, keys, and values for the third head.
q2 = np.dot(x, wq2)
k2 = np.dot(x, wk2)
v2 = np.dot(x, wv2)

These are the resulting *query*, *key*, and *value* vectors for the first head.

In [11]:
print("Q, K, and V for first head:\n")

print(f"q0 {q0.shape}:\n", q0, "\n")
print(f"k0 {k0.shape}:\n", k0, "\n")
print(f"v0 {v0.shape}:\n", v0)

Q, K, and V for first head:

q0 (1, 3, 4):
 [[[2.75 2.9  2.86 2.35]
  [2.44 2.6  2.34 2.42]
  [2.54 2.11 2.95 2.14]]] 

k0 (1, 3, 4):
 [[[3.18 3.45 3.1  2.62]
  [2.32 2.39 2.34 2.27]
  [2.26 2.38 2.7  2.3 ]]] 

v0 (1, 3, 4):
 [[[3.99 4.23 2.71 2.65]
  [3.16 3.67 1.82 2.6 ]
  [3.19 3.77 2.08 2.75]]]


Now that we have our Q, K, V vectors, we can just pass them to our self-attention operation. Here we're calculating the output and attention weights for the first head.

In [12]:
out0, attn_weights0 = scaled_dot_product_attention(q0, k0, v0)

print("Output from first attention head: ", out0, "\n")
print("Attention weights from first head: ", attn_weights0)

Output from first attention head:  [[[3.960351  4.211777  2.6832957 2.6515236]
  [3.9453316 4.202458  2.6695168 2.6521323]
  [3.9410832 4.199962  2.6660192 2.652564 ]]] 

Attention weights from first head:  [[[0.96347106 0.01419647 0.0223325 ]
  [0.9450061  0.02245    0.03254395]
  [0.9397198  0.02309422 0.03718603]]]


Here are the other two (attention weights are ignored).

In [13]:
out1, _ = scaled_dot_product_attention(q1, k1, v1)
out2, _ = scaled_dot_product_attention(q2, k2, v2)

print("Output from second attention head: ", out1, "\n")
print("Output from third attention head: ", out2,)

Output from second attention head:  [[[2.6091306 2.6788611 2.7399383 2.6570723]
  [2.6075811 2.6769052 2.7397096 2.6520724]
  [2.6086218 2.6782243 2.739854  2.6554472]]] 

Output from third attention head:  [[[3.350872  2.3895252 2.9842615 3.5212994]
  [3.318234  2.3826733 2.9540384 3.495711 ]
  [3.337096  2.3867323 2.9716554 3.5106113]]]


As we covered in the slides, once we have each head's output, we concatenate them and then put them through a linear layer for further processing.

In [14]:
combined_out_a = np.concatenate((out0, out1, out2), axis=-1)
print(f"Combined output from all heads {combined_out_a.shape}:")
print(combined_out_a)

# The final step would be to run combined_out_a through a linear/dense layer 
# for further processing.

Combined output from all heads (1, 3, 12):
[[[3.960351  4.211777  2.6832957 2.6515236 2.6091306 2.6788611 2.7399383
   2.6570723 3.350872  2.3895252 2.9842615 3.5212994]
  [3.9453316 4.202458  2.6695168 2.6521323 2.6075811 2.6769052 2.7397096
   2.6520724 3.318234  2.3826733 2.9540384 3.495711 ]
  [3.9410832 4.199962  2.6660192 2.652564  2.6086218 2.6782243 2.739854
   2.6554472 3.337096  2.3867323 2.9716554 3.5106113]]]


So that's a complete run of **multi-head self-attention** using separate sets of weights per head.<br>

Let's now get the same thing done using a single query weight matrix, single key weight matrix, and single value weight matrix.<br><br>
These were our separate per-head query weights:

In [15]:
print("Query weights for first head: \n", wq0, "\n")
print("Query weights for second head: \n", wq1, "\n")
print("Query weights for third head: \n", wq2)

Query weights for first head: 
 [[1.  0.6 0.7 0. ]
 [0.3 0.1 0.3 0.1]
 [0.3 0.4 0.1 0.7]
 [0.6 0.3 0.5 0.1]
 [0.6 0.9 0.3 0.7]
 [0.1 0.7 0.3 0.2]
 [0.6 0.  0.8 0. ]
 [0.7 0.3 0.7 1. ]
 [0.2 0.6 0.6 0.6]
 [0.2 1.  0.4 0.8]
 [0.7 0.3 0.8 0.4]
 [0.9 0.6 0.9 0.7]] 

Query weights for second head: 
 [[0.7 0.5 1.  0.6]
 [0.4 0.6 0.  0.3]
 [0.7 0.3 0.6 0.4]
 [0.1 0.3 0.6 0.6]
 [0.6 0.7 0.7 0.4]
 [0.9 0.4 0.4 0.9]
 [0.8 0.7 0.1 0.9]
 [0.7 1.  0.1 0.9]
 [0.2 0.6 0.1 0.8]
 [0.8 0.6 0.4 0.1]
 [0.7 0.5 0.7 0.9]
 [1.  0.9 0.  0.4]] 

Query weights for third head: 
 [[0.7 0.2 0.5 0.1]
 [0.2 0.  0.8 0.2]
 [0.3 0.9 0.7 0. ]
 [0.2 0.6 0.6 0.2]
 [0.9 0.6 0.5 0.6]
 [0.7 0.3 0.4 0.2]
 [0.2 0.9 0.7 0.5]
 [0.2 0.3 0.1 0.4]
 [0.3 0.7 0.4 0.2]
 [0.  0.1 0.7 0.5]
 [0.5 0.9 1.  0.2]
 [0.7 0.3 0.  0.8]]


In [16]:
wq = np.concatenate((wq0, wq1, wq2), axis=1)
print(f"Single query weight matrix {wq.shape}: \n", wq)

Single query weight matrix (12, 12): 
 [[1.  0.6 0.7 0.  0.7 0.5 1.  0.6 0.7 0.2 0.5 0.1]
 [0.3 0.1 0.3 0.1 0.4 0.6 0.  0.3 0.2 0.  0.8 0.2]
 [0.3 0.4 0.1 0.7 0.7 0.3 0.6 0.4 0.3 0.9 0.7 0. ]
 [0.6 0.3 0.5 0.1 0.1 0.3 0.6 0.6 0.2 0.6 0.6 0.2]
 [0.6 0.9 0.3 0.7 0.6 0.7 0.7 0.4 0.9 0.6 0.5 0.6]
 [0.1 0.7 0.3 0.2 0.9 0.4 0.4 0.9 0.7 0.3 0.4 0.2]
 [0.6 0.  0.8 0.  0.8 0.7 0.1 0.9 0.2 0.9 0.7 0.5]
 [0.7 0.3 0.7 1.  0.7 1.  0.1 0.9 0.2 0.3 0.1 0.4]
 [0.2 0.6 0.6 0.6 0.2 0.6 0.1 0.8 0.3 0.7 0.4 0.2]
 [0.2 1.  0.4 0.8 0.8 0.6 0.4 0.1 0.  0.1 0.7 0.5]
 [0.7 0.3 0.8 0.4 0.7 0.5 0.7 0.9 0.5 0.9 1.  0.2]
 [0.9 0.6 0.9 0.7 1.  0.9 0.  0.4 0.7 0.3 0.  0.8]]


In the same vein, pretend we declared a single key weight matrix, and single value weight matrix.

In [17]:
wk = np.concatenate((wk0, wk1, wk2), axis=1)
wv = np.concatenate((wv0, wv1, wv2), axis=1)

print(f"Single key weight matrix {wk.shape}:\n", wk, "\n")
print(f"Single value weight matrix {wv.shape}:\n", wv)

Single key weight matrix (12, 12):
 [[0.3 0.4 0.6 0.8 0.7 0.8 0.3 0.6 0.1 0.1 0.7 0.2]
 [0.6 0.9 0.3 0.8 0.1 0.5 1.  0.9 0.4 0.6 0.9 0.7]
 [0.2 1.  0.7 0.2 0.3 1.  0.2 0.9 0.3 0.1 0.1 0.3]
 [0.9 0.7 0.3 0.2 0.9 0.8 0.6 0.9 0.3 0.5 0.7 0.7]
 [0.5 0.  0.2 0.4 0.3 0.8 0.6 0.  0.3 0.4 0.2 0.8]
 [0.4 0.5 0.3 0.6 0.3 0.1 1.  0.5 0.1 0.7 0.8 0.8]
 [0.9 0.1 0.5 0.1 0.5 0.6 0.4 0.1 0.3 0.4 0.6 0.3]
 [0.7 0.4 0.6 0.2 0.8 0.2 0.5 0.2 0.4 0.2 0.5 0. ]
 [0.1 0.5 0.4 0.9 0.1 0.9 1.  1.  0.8 0.1 0.5 0.3]
 [0.8 0.7 0.9 0.1 0.9 0.8 0.3 0.1 0.6 1.  0.6 0. ]
 [0.6 0.6 1.  0.3 0.4 0.2 0.1 0.1 0.4 0.5 0.5 0.7]
 [0.2 0.1 0.  0.9 0.7 0.  0.8 0.1 0.3 0.1 0.4 1. ]] 

Single value weight matrix (12, 12):
 [[0.2 0.9 0.5 0.5 0.6 0.  0.  0.4 0.7 0.1 0.7 0.9]
 [0.9 0.5 0.7 0.4 0.1 0.3 0.2 0.3 0.3 0.4 0.4 0.7]
 [0.9 0.7 0.7 0.3 0.1 0.  0.1 0.6 0.2 0.2 0.4 0.3]
 [0.8 0.6 0.2 0.2 1.  1.  0.4 0.2 0.8 0.2 0.5 0.9]
 [0.8 1.  0.5 0.6 0.6 0.5 1.  0.1 0.6 0.9 0.9 0.8]
 [0.9 0.5 1.  0.6 0.8 0.3 0.2 0.7 0.7 1.  1.  0.5]
 [0.8


Now we can calculate all our queries, keys, and values with three dot products.

In [18]:
q_s = np.dot(x, wq)
k_s = np.dot(x, wk)
v_s = np.dot(x, wv)

These are our resulting query vectors (we'll call them "combined queries"). How do we simulate different heads with this?

In [19]:
print(f"Query vectors using a single weight matrix {q_s.shape}:\n", q_s)

Query vectors using a single weight matrix (1, 3, 12):
 [[[2.75 2.9  2.86 2.35 3.53 3.2  2.57 3.41 2.13 3.14 3.66 1.64]
  [2.44 2.6  2.34 2.42 2.75 2.8  2.22 2.75 2.2  2.74 2.82 1.54]
  [2.54 2.11 2.95 2.14 3.31 2.85 1.89 3.49 2.12 3.34 2.95 1.51]]]


Somehow, we need to separate these vectors such they're treated like three separate sets by the self-attention operation.

In [20]:
print(q0, "\n")
print(q1, "\n")
print(q2)

[[[2.75 2.9  2.86 2.35]
  [2.44 2.6  2.34 2.42]
  [2.54 2.11 2.95 2.14]]] 

[[[3.53 3.2  2.57 3.41]
  [2.75 2.8  2.22 2.75]
  [3.31 2.85 1.89 3.49]]] 

[[[2.13 3.14 3.66 1.64]
  [2.2  2.74 2.82 1.54]
  [2.12 3.34 2.95 1.51]]]


Notice how each set of per-head queries looks like we took the combined queries, and chopped them vertically every four dimensions.
<br><br>
We can split our combined queries into $\text{d}\ \text{x}\ \text{d/h}$ heads using **reshape** and **transpose**.<br><br>
The first step is to *reshape* our combined queries from a shape of:<br>
(batch_size, seq_len, embed_dim)<br>

into a shape of<br>
 (batch_size, seq_len, num_heads, head_dim).
 <br>

 https://www.tensorflow.org/api_docs/python/tf/reshape

In [21]:
# Note: we can achieve the same thing by passing -1 instead of seq_len.
q_s_reshaped = jnp.reshape(q_s, (batch_size, seq_len, num_heads, head_dim))
print(f"Combined queries: {q_s.shape}\n", q_s, "\n")
print(f"Reshaped into separate heads: {q_s_reshaped.shape}\n", q_s_reshaped)

Combined queries: (1, 3, 12)
 [[[2.75 2.9  2.86 2.35 3.53 3.2  2.57 3.41 2.13 3.14 3.66 1.64]
  [2.44 2.6  2.34 2.42 2.75 2.8  2.22 2.75 2.2  2.74 2.82 1.54]
  [2.54 2.11 2.95 2.14 3.31 2.85 1.89 3.49 2.12 3.34 2.95 1.51]]] 

Reshaped into separate heads: (1, 3, 3, 4)
 [[[[2.75 2.9  2.86 2.35]
   [3.53 3.2  2.57 3.41]
   [2.13 3.14 3.66 1.64]]

  [[2.44 2.6  2.34 2.42]
   [2.75 2.8  2.22 2.75]
   [2.2  2.74 2.82 1.54]]

  [[2.54 2.11 2.95 2.14]
   [3.31 2.85 1.89 3.49]
   [2.12 3.34 2.95 1.51]]]]


At this point, we have our desired shape. The next step is to *transpose* it such that simulates vertically chopping our combined queries. By transposing, our matrix dimensions become:<br>
(batch_size, num_heads, seq_len, head_dim)<br>

https://www.tensorflow.org/api_docs/python/tf/transpose

In [22]:
q_s_transposed = jnp.transpose(q_s_reshaped, axes=[0, 2, 1, 3])
print(f"Queries transposed into \"separate\" heads {q_s_transposed.shape}:\n", 
      q_s_transposed)

Queries transposed into "separate" heads (1, 3, 3, 4):
 [[[[2.75 2.9  2.86 2.35]
   [2.44 2.6  2.34 2.42]
   [2.54 2.11 2.95 2.14]]

  [[3.53 3.2  2.57 3.41]
   [2.75 2.8  2.22 2.75]
   [3.31 2.85 1.89 3.49]]

  [[2.13 3.14 3.66 1.64]
   [2.2  2.74 2.82 1.54]
   [2.12 3.34 2.95 1.51]]]]


If we compare this against the separate per-head queries we calculated previously, we see the same result except we now have all our queries in a single matrix.

In [23]:
print("The separate per-head query matrices from before: ")
print(q0, "\n")
print(q1, "\n")
print(q2)

The separate per-head query matrices from before: 
[[[2.75 2.9  2.86 2.35]
  [2.44 2.6  2.34 2.42]
  [2.54 2.11 2.95 2.14]]] 

[[[3.53 3.2  2.57 3.41]
  [2.75 2.8  2.22 2.75]
  [3.31 2.85 1.89 3.49]]] 

[[[2.13 3.14 3.66 1.64]
  [2.2  2.74 2.82 1.54]
  [2.12 3.34 2.95 1.51]]]


In [24]:
k_s_transposed = jnp.transpose(jnp.reshape(k_s, (batch_size, -1, num_heads, head_dim)), axes=[0, 2, 1, 3])
v_s_transposed = jnp.transpose(jnp.reshape(v_s, (batch_size, -1, num_heads, head_dim)), axes=[0, 2, 1, 3])

print(f"Keys for all heads in a single matrix {k_s.shape}: \n", k_s_transposed, "\n")
print(f"Values for all heads in a single matrix {v_s.shape}: \n", v_s_transposed)

Keys for all heads in a single matrix (1, 3, 12): 
 [[[[3.18 3.45 3.1  2.62]
   [2.32 2.39 2.34 2.27]
   [2.26 2.38 2.7  2.3 ]]

  [[2.9  4.27 3.36 3.43]
   [2.06 3.3  2.73 2.35]
   [2.03 2.69 2.58 2.19]]

  [[2.28 2.48 3.32 2.6 ]
   [1.94 1.76 2.27 2.5 ]
   [1.93 1.63 2.42 2.55]]]] 

Values for all heads in a single matrix (1, 3, 12): 
 [[[[3.99 4.23 2.71 2.65]
   [3.16 3.67 1.82 2.6 ]
   [3.19 3.77 2.08 2.75]]

  [[2.61 2.68 2.74 2.66]
   [2.4  2.37 2.78 1.85]
   [2.4  2.56 2.48 2.41]]

  [[3.4  2.4  3.03 3.56]
   [2.76 2.32 2.52 3.12]
   [2.42 2.14 2.04 2.73]]]]


Set up this way, we can now calculate the outputs from all attention heads with a single call to our self-attention operation.

In [25]:
all_heads_output, all_attn_weights = scaled_dot_product_attention(q_s_transposed, 
                                                                  k_s_transposed, 
                                                                  v_s_transposed)
print("Self attention output:\n", all_heads_output)

Self attention output:
 [[[[3.960351  4.211777  2.6832957 2.6515236]
   [3.9453316 4.202458  2.6695168 2.6521323]
   [3.9410832 4.199962  2.6660192 2.652564 ]]

  [[2.6091306 2.6788611 2.7399383 2.6570723]
   [2.6075811 2.6769052 2.7397096 2.6520724]
   [2.6086218 2.6782243 2.739854  2.6554472]]

  [[3.350872  2.3895252 2.9842615 3.5212994]
   [3.318234  2.3826733 2.9540384 3.495711 ]
   [3.337096  2.3867323 2.9716554 3.5106113]]]]


As a sanity check, we can compare this against the outputs from individual heads we calculated earlier:

In [26]:
print("Per head outputs from using separate sets of weights per head:")
print(out0, "\n")
print(out1, "\n")
print(out2)

Per head outputs from using separate sets of weights per head:
[[[3.960351  4.211777  2.6832957 2.6515236]
  [3.9453316 4.202458  2.6695168 2.6521323]
  [3.9410832 4.199962  2.6660192 2.652564 ]]] 

[[[2.6091306 2.6788611 2.7399383 2.6570723]
  [2.6075811 2.6769052 2.7397096 2.6520724]
  [2.6086218 2.6782243 2.739854  2.6554472]]] 

[[[3.350872  2.3895252 2.9842615 3.5212994]
  [3.318234  2.3826733 2.9540384 3.495711 ]
  [3.337096  2.3867323 2.9716554 3.5106113]]]


To get the final concatenated result, we need to reverse our **reshape** and **transpose** operation, starting with the **transpose** this time.

In [27]:
combined_out_b = jnp.reshape(jnp.transpose(all_heads_output, axes=[0, 2, 1, 3]), 
                            newshape=(batch_size, seq_len, embed_dim))
print("Final output from using single query, key, value matrices:\n", 
      combined_out_b, "\n")
print("Final output from using separate query, key, value matrices per head:\n", 
      combined_out_a)

Final output from using single query, key, value matrices:
 [[[3.960351  4.211777  2.6832957 2.6515236 2.6091306 2.6788611 2.7399383
   2.6570723 3.350872  2.3895252 2.9842615 3.5212994]
  [3.9453316 4.202458  2.6695168 2.6521323 2.6075811 2.6769052 2.7397096
   2.6520724 3.318234  2.3826733 2.9540384 3.495711 ]
  [3.9410832 4.199962  2.6660192 2.652564  2.6086218 2.6782243 2.739854
   2.6554472 3.337096  2.3867323 2.9716554 3.5106113]]] 

Final output from using separate query, key, value matrices per head:
 [[[3.960351  4.211777  2.6832957 2.6515236 2.6091306 2.6788611 2.7399383
   2.6570723 3.350872  2.3895252 2.9842615 3.5212994]
  [3.9453316 4.202458  2.6695168 2.6521323 2.6075811 2.6769052 2.7397096
   2.6520724 3.318234  2.3826733 2.9540384 3.495711 ]
  [3.9410832 4.199962  2.6660192 2.652564  2.6086218 2.6782243 2.739854
   2.6554472 3.337096  2.3867323 2.9716554 3.5106113]]]


We can encapsulate everything we just covered in a class.

---



In [28]:
class MultiHeadSelfAttention(hk.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        self.d_head = self.d_model // self.num_heads

        self.wq = hk.Linear(self.d_model)
        self.wk = hk.Linear(self.d_model)
        self.wv = hk.Linear(self.d_model)

        # Linear layer to generate the final output.
        self.dense = hk.Linear(self.d_model)

    def split_heads(self, x):
        batch_size = x.shape[0]

        split_inputs = jnp.reshape(x, (batch_size, -1, self.num_heads, self.d_head))
        return jnp.transpose(split_inputs, axes=[0, 2, 1, 3])

    def merge_heads(self, x):
        batch_size = x.shape[0]

        merged_inputs = jnp.transpose(x, axes=[0, 2, 1, 3])
        return jnp.reshape(merged_inputs, (batch_size, -1, self.d_model))

    def __call__(self, q, k, v, mask):
        qs = self.wq(q)
        ks = self.wk(k)
        vs = self.wv(v)

        qs = self.split_heads(qs)
        ks = self.split_heads(ks)
        vs = self.split_heads(vs)

        output, attn_weights = scaled_dot_product_attention(qs, ks, vs, mask)
        output = self.merge_heads(output)

        return self.dense(output), attn_weights


In [29]:
# mhsa = MultiHeadSelfAttention(12, 3)
def fwdx(q, k, v, mask):
    mhsa = MultiHeadSelfAttention(12, 3)
    return mhsa(q, k, v, mask)
    
mhsa = hk.transform(fwdx)
rng = hk.PRNGSequence(jax.random.PRNGKey(seed))
p = mhsa.init(next(rng), x, x, x, None)
output, attn_weights = mhsa.apply(p, None, x, x, x, None)
print(f"MHSA output{output.shape}:")
print(output)

MHSA output(1, 3, 12):
[[[ 0.08374095  0.28195387  0.45835775 -0.20913911 -0.00389449
    0.51052946  0.38135567 -0.07603626 -0.8463101   0.7846994
    0.3223069   0.7794864 ]
  [ 0.08046716  0.27548125  0.46089178 -0.21426184  0.00246251
    0.5044182   0.38591275 -0.08152059 -0.8504362   0.7838822
    0.32661867  0.7838007 ]
  [ 0.08387989  0.27426255  0.47376037 -0.2213613   0.00847191
    0.49273473  0.37900913 -0.08729805 -0.8445233   0.77479714
    0.3281104   0.7750339 ]]]


  param = init(shape, dtype)


## Encoder Block

We can now build our **Encoder Block**. In addition to the **Multi-Head Self Attention** layer, the **Encoder Block** also has **skip connections**, **layer normalization steps**, and a **two-layer feed-forward neural network**. The original **Attention Is All You Need** paper also included some **dropout** applied to the self-attention output which isn't shown in the illustration below (see references for a link to the paper).

<div>
<img src="https://drive.google.com/uc?export=view&id=1D8sLDyQMqqhCjHWOn-I7rZKHugWxFyLy" width="500"/>
</div>

Since a two-layer feed forward neural network is used in multiple places in the transformer, here's a function which creates and returns one.

In [30]:
def feed_forward_network(d_model, hidden_dim):
    return hk.Sequential([
        Dense(hidden_dim, activation=jax.nn.relu),
        Dense(d_model)
    ])

This is our encoder block containing all the layers and steps from the preceding illustration (plus dropout).

In [31]:
# mhsa = MultiHeadSelfAttention(12, 3)
# def fwdx(q, k, v, mask):
#     mhsa = MultiHeadSelfAttention(12, 3)
#     return mhsa(q, k, v, mask)
    
# mhsa = hk.transform(fwdx)
# rng = hk.PRNGSequence(jax.random.PRNGKey(seed))
# p = mhsa.init(next(rng), x, x, x, None)
# output, attn_weights = mhsa.apply(p, None, x, x, x, None)
# print(f"MHSA output{output.shape}:")
# print(output)


class EncoderBlock(hk.Module):
    def __init__(self, d_model, num_heads, hidden_dim, dropout_rate=0.1):
        super(EncoderBlock, self).__init__()
        
        self.dropout_rate = dropout_rate

        self.mhsa = MultiHeadSelfAttention(d_model, num_heads)
        # self.params_mhsa = self.mhsa.init(x, x, x, None)
        self.ffn = feed_forward_network(d_model, hidden_dim)

        # self.dropout1 = partial(hk.dropout, rng=next(rng), rate=dropout_rate)
        # self.dropout2 = partial(hk.dropout, rng=next(rng), rate=dropout_rate)

        self.layernorm1 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
        self.layernorm2 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    
    def __call__(self, x, training, mask):
        mhsa_output, attn_weights = self.mhsa(x, x, x, mask)
        mhsa_output = hk.dropout(next(rng), self.dropout_rate, mhsa_output) #, training=training)
        mhsa_output = self.layernorm1(x + mhsa_output)

        ffn_output = self.ffn(mhsa_output)
        ffn_output = hk.dropout(next(rng), self.dropout_rate, ffn_output) #, training=training)
        output = self.layernorm2(mhsa_output + ffn_output)

        return output, attn_weights


Suppose we have an embedding dimension of 12, and we want 3 attention heads and a feed forward network with a hidden dimension of 48 (4x the embedding dimension). We would declare and use a single encoder block like so:

In [32]:


def fwd_enc(x, training, mask):
    encoder_block = EncoderBlock(12, 3, 48)
    return encoder_block(x, training, mask)
    
encoder_block = hk.transform(fwd_enc)
params_enc = encoder_block.init(next(rng), x, True, None)
block_output,  _ = encoder_block.apply(params_enc, next(rng), x, True, None)
print(f"Output from single encoder block {block_output.shape}:")
print(block_output)

Output from single encoder block (1, 3, 12):
[[[ 0.75087804 -1.5568879   0.4482081   0.51610124 -0.86587226
    0.9495493   0.83768344 -0.03637511  0.1033235   1.3288491
   -0.41613346 -2.0593245 ]
  [ 0.5299162  -1.9707414  -0.98887664  0.32056147  0.9441586
    0.5829572  -0.16058876  0.62748796  0.40224558  1.0025263
    0.601236   -1.8908825 ]
  [-0.4772026  -1.9451121  -0.66309965 -0.35051018 -0.9115741
    1.1237394   0.894805    0.15113567  1.0693713   0.20748584
    1.6580435  -0.7570817 ]]]


## Word and Positional Embeddings

Let's now deal with the actual input to the **initial** encoder block. The inputs are going to be *positional word embeddings*. That is, word embeddings with some positional information added to them.
<br>

Let's start with **subword** tokenization. For demonstration, we'll use a subword tokenizer called **BPEmb**. It uses **Byte-Pair Encoding** and supports over two hundred languages. 

https://bpemb.h-its.org/


In [33]:
# Load the English tokenizer.
bpemb_en = BPEmb(lang="en")

downloading https://nlp.h-its.org/bpemb/en/en.wiki.bpe.vs10000.model


100%|██████████| 400869/400869 [00:00<00:00, 4519790.14B/s]


downloading https://nlp.h-its.org/bpemb/en/en.wiki.bpe.vs10000.d100.w2v.bin.tar.gz


100%|██████████| 3784656/3784656 [00:00<00:00, 25668469.85B/s]


The library comes with embeddings for a number of words.

In [34]:
bpemb_vocab_size, bpemb_embed_size = bpemb_en.vectors.shape
print("Vocabulary size:", bpemb_vocab_size)
print("Embedding size:", bpemb_embed_size)

Vocabulary size: 10000
Embedding size: 100


In [35]:
# Embedding for the word "car".
bpemb_en.vectors[bpemb_en.words.index('car')]

array([-0.305548, -0.325598, -0.134716, -0.078735, -0.660545,  0.076211,
       -0.735487,  0.124533, -0.294402,  0.459688,  0.030137,  0.174041,
       -0.224223,  0.486189, -0.504649, -0.459699,  0.315747,  0.477885,
        0.091398,  0.427867,  0.016524, -0.076833, -0.899727,  0.493158,
       -0.022309, -0.422785, -0.154148,  0.204981,  0.379834,  0.070588,
        0.196073, -0.368222,  0.473406,  0.007409,  0.004303, -0.007823,
       -0.19103 , -0.202509,  0.109878, -0.224521, -0.35741 , -0.611633,
        0.329958, -0.212956, -0.497499, -0.393839, -0.130101, -0.216903,
       -0.105595, -0.076007, -0.483942, -0.139704, -0.161647,  0.136985,
        0.415363, -0.360143,  0.038601, -0.078804, -0.030421,  0.324129,
        0.223378, -0.523636, -0.048317, -0.032248, -0.117367,  0.470519,
        0.225816, -0.222065, -0.225007, -0.165904, -0.334389, -0.20157 ,
        0.572352, -0.268794,  0.301929, -0.005563,  0.387491,  0.261031,
       -0.11613 ,  0.074982, -0.008433,  0.259987, 

We don't need the embeddings since we're going to use our own embedding layer. What we're interested in are the subword tokens and their respective ids. The ids will be used as indexes into our embedding layer.<br>

If this doesn't sound familiar, refer to the module on word vectors:<br>
https://www.nlpdemystified.org/course/word-vectors

These are the subword tokens for our example sentence from the slides. **BPEmb** places underscores in front of any tokens which are whole words or intended to begin words.<br>

Remember that subword tokenizers are trained using count frequencies over a corpus. So these subword tokens are specific to **BPEmb**. Another subword tokenizer may output something different. This is why it's important that when we use a pretrained model, we make sure to use the pretrained model's tokenizer. We'll see this when we use pretrained transformers later in this module.

In [36]:
sample_sentence = "Where can I find a pizzeria?"
tokens = bpemb_en.encode(sample_sentence)
print(tokens)

['▁where', '▁can', '▁i', '▁find', '▁a', '▁p', 'iz', 'zer', 'ia', '?']


We can retrieve each subword token's respective id using the *encode_ids* method.

In [37]:
token_seq = np.array(bpemb_en.encode_ids("Where can I find a pizzeria?"))
print(token_seq)

[ 571  280  386 1934    4   24  248 4339  177 9967]


Now that we have a way to tokenize and vectorize sentences, we can declare and use an embedding layer with the same vocabulary size as **BPEmb** and a desired embedding size.

In [47]:
token_embed = hk.transform(lambda x: hk.Embed(bpemb_vocab_size, embed_dim)(x))
p_token = token_embed.init(next(rng), token_seq)
token_embeddings = token_embed.apply(p_token, next(rng), token_seq)

# The untrained embeddings for our sample sentence.
print("Embeddings for: ", sample_sentence, f"({token_embeddings.shape})")
print(token_embeddings)

Embeddings for:  Where can I find a pizzeria? ((10, 12))
[[ 1.6094604  -0.52538157 -0.8895934  -0.5229291  -1.085535    0.42052966
  -1.6960709   1.1735058   0.37995493 -0.49312064 -0.42313188  0.15053356]
 [ 1.7967463   0.3124195   0.55778956 -0.9394734   0.420552    0.33601767
  -0.18352616 -1.0491136  -0.64903814 -0.15722333  0.9599779   1.6771305 ]
 [ 0.30299115 -0.31363884 -0.11355441 -0.03616833 -0.5260365  -0.53790313
  -0.4002148   0.8586032  -0.92141944 -1.4005889  -0.473717    1.5010474 ]
 [-0.04111395 -1.6209657  -0.22683966  1.1904185   0.67703724 -0.29590532
   0.29054746  0.56238264 -0.7022962  -1.0431658  -0.5001555   0.30058387]
 [-0.08979569 -0.23829778 -0.77581775  0.13661242 -0.70797837 -0.14180613
   1.4101925   0.4633062  -0.802325    0.20152093 -1.5932127  -0.32765976]
 [ 0.08688173 -0.84189755  0.1865582  -0.13941963 -0.22614297  1.7828015
  -0.36164925 -0.5667005   1.6301054  -0.14879316  0.56012404  1.0910765 ]
 [-0.10597399 -0.37248015  1.5519582  -0.8088619  

Next, we need to add *positional* information to each token embedding. As we covered in the slides, the original paper used sinusoidals but it's more common these days to just use another set of embeddings. We'll do the latter here.<br>

Here, we're declaring an embedding layer with rows equalling a maximum sequence length and columns equalling our token embedding size. We then generate a vector of position ids.

In [56]:
max_seq_len = 256
pos_embed = hk.transform(lambda x: hk.Embed(max_seq_len, embed_dim)(x))

# Generate ids for each position of the token sequence.
pos_idx = jnp.arange(len(token_seq))
print(pos_idx)

[0 1 2 3 4 5 6 7 8 9]


We'll use these position ids to index into the positional embedding layer.

In [58]:
# These are our positon embeddings.
p_pos = pos_embed.init(next(rng), pos_idx)
position_embeddings = pos_embed.apply(p_pos, next(rng), pos_idx)
print(f"Position embeddings for the input sequence ({position_embeddings.shape})\n", position_embeddings)

Position embeddings for the input sequence ((10, 12))
 [[ 0.7781124   1.7142845  -0.2877947  -0.23030427  0.7116986   0.8313217
  -0.12955669 -0.15172009  0.05667705  0.08876359 -1.1676164   0.05004088]
 [-0.6225034   1.2481763  -0.328634    0.02548295 -1.8445315  -0.0901474
   1.7251247   0.41959327 -1.4409891   0.64706844  0.17834616 -0.42675495]
 [ 1.7740337  -0.13828552 -1.1180384  -0.41133064 -1.1012032  -1.6601293
   0.99315846 -0.58263636 -1.106893    0.5538767  -0.64504117  0.7844581 ]
 [ 1.5640007  -0.02428322 -0.79117036  0.755527    1.2078457  -0.3640467
  -1.3301423  -0.3062533   0.55079937 -0.6502779  -1.1676248   0.19947718]
 [-0.29926735 -1.040653    1.5888288   1.4027246  -0.2158562  -1.52093
  -0.04889803 -0.73962593  0.0248518   0.23585802  1.0188191   0.07387792]
 [ 0.59272784 -0.12436684  0.77351147  1.0164343  -1.0248847  -0.22747943
  -1.4248497  -0.8895201   0.08861105 -0.4611798  -1.1168866   1.851782  ]
 [ 0.9001088   1.1408491   0.7369474  -0.9238413   0.50584

The final step is to add our token and position embeddings. The result will be the input to the first encoder block.

In [60]:
input = token_embeddings + position_embeddings
print(f"Input to the initial encoder block ({input.shape}):\n", input)

Input to the initial encoder block ((10, 12)):
 [[ 2.3875728   1.188903   -1.1773882  -0.7532333  -0.37383646  1.2518513
  -1.8256276   1.0217857   0.43663198 -0.40435705 -1.5907483   0.20057444]
 [ 1.1742429   1.5605959   0.22915557 -0.91399044 -1.4239795   0.24587026
   1.5415986  -0.62952036 -2.0900273   0.4898451   1.138324    1.2503755 ]
 [ 2.077025   -0.45192435 -1.2315928  -0.44749898 -1.6272397  -2.1980324
   0.59294367  0.27596682 -2.0283124  -0.8467122  -1.1187582   2.2855055 ]
 [ 1.5228868  -1.6452489  -1.01801     1.9459455   1.8848829  -0.65995204
  -1.0395948   0.25612932 -0.15149683 -1.6934438  -1.6677804   0.50006104]
 [-0.38906306 -1.2789508   0.81301105  1.539337   -0.92383456 -1.6627362
   1.3612945  -0.27631974 -0.7774732   0.43737894 -0.57439363 -0.25378183]
 [ 0.6796096  -0.96626437  0.96006966  0.8770147  -1.2510277   1.555322
  -1.786499   -1.4562206   1.7187164  -0.60997295 -0.5567626   2.9428585 ]
 [ 0.79413486  0.76836896  2.2889056  -1.7327032   0.10422492  

## Encoder

Now that we have an encoder block and a way to embed our tokens with position information, we can create the **encoder** itself.<br>

Given a batch of vectorized sequences, the encoder creates positional embeddings, runs them through its encoder blocks, and returns contextualized tokens.

In [109]:
class Encoder(hk.Module):
    def __init__(self, num_blocks, d_model, num_heads, hidden_dim, src_vocab_size,
                max_seq_len, dropout_rate=0.1):
        super(Encoder, self).__init__()
        
        self.dropout_rate = dropout_rate

        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.token_embed = hk.Embed(src_vocab_size, self.d_model)
        self.pos_embed = hk.Embed(max_seq_len, self.d_model)

        # The original Attention Is All You Need paper applied dropout to the
        # input before feeding it to the first encoder block.
        # self.dropout = tf.keras.layers.Dropout(dropout_rate)

        # Create encoder blocks.
        self.blocks = [EncoderBlock(self.d_model, num_heads, hidden_dim, dropout_rate) 
            for _ in range(num_blocks)]
    
    def __call__(self, input, training, mask):
        token_embeds = self.token_embed(input)

        # Generate position indices for a batch of input sequences.
        num_pos = input.shape[0] * self.max_seq_len
        pos_idx = np.resize(np.arange(self.max_seq_len), num_pos)
        pos_idx = np.reshape(pos_idx, input.shape)
        pos_embeds = self.pos_embed(pos_idx)

        x = hk.dropout(next(rng), self.dropout_rate, token_embeds + pos_embeds) #, training=training)

        # Run input through successive encoder blocks.
        for block in self.blocks:
            x, weights = block(x, training, mask)

        return x, weights

If you're wondering about this code block here:


```
num_pos = input.shape[0] * self.max_seq_len
pos_idx = np.resize(np.arange(self.max_seq_len), num_pos)
pos_idx = np.reshape(pos_idx, input.shape)
pos_embeds = self.pos_embed(pos_idx)
```


This generates positional embeddings for a *batch* of input sequences. Suppose this was our batch of input sequences to the encoder.

In [99]:
# Batch of 3 sequences, each of length 10 (10 is also the 
# maximum sequence length in this case).
seqs = np.random.randint(0, 10000, size=(3, 10))
print(seqs.shape)
print(seqs)

(3, 10)
[[6772 8373 2283 4289 6732  120 8717  222 3676 4332]
 [3632 9768 9114 7326 7557 2544 6432   62 6490 2023]
 [8342 8595 2437 2852  127 1508 9989 2508 7322 9299]]


We need to retrieve a positional embedding for every element in this batch. The first step is to create the respective positional ids...

In [100]:
pos_ids = np.resize(np.arange(seqs.shape[1]), seqs.shape[0] * seqs.shape[1])
print(pos_ids)


[0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9]


...and then reshape them to match the input batch dimensions.

In [101]:
pos_ids = np.reshape(pos_ids, (3, 10))
print(pos_ids.shape)
print(pos_ids)

(3, 10)
[[0 1 2 3 4 5 6 7 8 9]
 [0 1 2 3 4 5 6 7 8 9]
 [0 1 2 3 4 5 6 7 8 9]]


We can now retrieve position embeddings for every token embedding.

In [102]:
pos_embed.apply(p_pos, next(rng), pos_ids).shape

(3, 10, 12)

Let's try our encoder on a batch of sentences.

In [103]:
input_batch = [
    "Where can I find a pizzeria?",
    "Mass hysteria over listeria.",
    "I ain't no circle back girl."
]

bpemb_en.encode(input_batch)

[['▁where', '▁can', '▁i', '▁find', '▁a', '▁p', 'iz', 'zer', 'ia', '?'],
 ['▁mass', '▁hy', 'ster', 'ia', '▁over', '▁l', 'ister', 'ia', '.'],
 ['▁i', '▁a', 'in', "'", 't', '▁no', '▁circle', '▁back', '▁girl', '.']]

In [104]:
input_seqs = bpemb_en.encode_ids(input_batch)
print("Vectorized inputs:")
input_seqs

Vectorized inputs:


[[571, 280, 386, 1934, 4, 24, 248, 4339, 177, 9967],
 [1535, 1354, 1238, 177, 380, 43, 871, 177, 9935],
 [386, 4, 6, 9937, 9915, 467, 5410, 810, 3692, 9935]]

Note how the input sequences aren't the same length in this batch. In this case, we need to pad them out so that they are. If you're unfamiliar with why, refer to the notebook on Recurrent Neural Networks:<br>
https://colab.research.google.com/github/nitinpunjabi/nlp-demystified/blob/main/notebooks/nlpdemystified_recurrent_neural_networks.ipynb<br>

We'll do this using *pad_sequences*.<br>
https://www.tensorflow.org/api_docs/python/tf/keras/utils/pad_sequences

In [123]:
padded_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(input_seqs, padding="post")
print("Input to the encoder:")
print(padded_input_seqs.shape)
print(padded_input_seqs)

Input to the encoder:
(3, 10)
[[ 571  280  386 1934    4   24  248 4339  177 9967]
 [1535 1354 1238  177  380   43  871  177 9935    0]
 [ 386    4    6 9937 9915  467 5410  810 3692 9935]]


Since our input now has padding, now's a good time to cover **masking**.
<br>

So given a mask, wherever there's a mask position set to 0, the corresponding position in the attention scores will be set to *-inf*. The resulting attention weight for the position will then be zero and no attending will occur for that position.
<br>

In the slides, we covered *look-ahead* masks for the decoder to prevent it from attending to future tokens, but we also need masks for padding.
<br>

In total, there are three masks involved:
1. The *encoder mask* to mask out any padding in the encoder sequences.

2. The *decoder mask* which is used in the decoder's **first** multi-head self-attention layer. It's a <u>combination of two masks</u>: one to account for the padding in target sequences, and the look-ahead mask.

3. The *memory mask* which is used in the decoder's **second** multi-head self-attention layer. The keys and values for this layer are going to be the encoder's output, and this mask will ensure the decoder doesn't attend to any encoder output which corresponds to padding. In practice, 1 and 3 are often the same.

The *scaled_dot_product_attention* function has this line:
```
  if mask is not None:
    scaled_scores = tf.where(mask==0, -np.inf, scaled_scores)
```

Let's create an encoder mask for our batch of input sequences.<br>

Wherever there's padding, we want the mask position set to zero.

In [124]:
enc_mask = np.where(padded_input_seqs, 1, 0)
print("Input:")
print(padded_input_seqs, '\n')
print("Encoder mask:")
print(enc_mask)

Input:
[[ 571  280  386 1934    4   24  248 4339  177 9967]
 [1535 1354 1238  177  380   43  871  177 9935    0]
 [ 386    4    6 9937 9915  467 5410  810 3692 9935]] 

Encoder mask:
[[1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 0]
 [1 1 1 1 1 1 1 1 1 1]]


Keep in mind that the dimension of the attention matrix (for this example) is going to be:<br>
*(batch size, number of heads, query size, key size)*<br>
(3, 3, 10, 10)

So we need to expand the mask dimensions like so:

In [125]:
enc_mask = enc_mask[:, None, None, :]
enc_mask

array([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],


       [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]],


       [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]])

This way, the encoder mask will now be *broadcasted*.<br>
https://www.tensorflow.org/xla/broadcasting

Now we can declare an encoder and pass it batches of vectorized sequences.

In [126]:

def build_encoder(padded_input_seqs, training, mask):
    num_encoder_blocks = 6

    # d_model is the embedding dimension used throughout.
    d_model = 12

    num_heads = 3

    # Feed-forward network hidden dimension width.
    ffn_hidden_dim = 48

    src_vocab_size = bpemb_vocab_size
    max_input_seq_len = padded_input_seqs.shape[1]

    return Encoder(
        num_encoder_blocks,
        d_model,
        num_heads,
        ffn_hidden_dim,
        src_vocab_size,
        max_input_seq_len)(padded_input_seqs, training, mask)
    
encoder = hk.transform(build_encoder)
params_enc2 = encoder.init(next(rng), padded_input_seqs, training=True, mask=enc_mask)

We can now pass our input sequences and mask to the encoder.

In [127]:
encoder_output, attn_weights = encoder.apply(params_enc2, next(rng), 
                                             padded_input_seqs, training=True, mask=enc_mask)
print(f"Encoder output {encoder_output.shape}:")
print(encoder_output)

Encoder output (3, 10, 12):
[[[ 1.41915858e+00 -3.07213008e-01  1.10027516e+00 -4.75663185e-01
   -1.30171478e-01  5.72483182e-01  5.82186282e-01  4.49498504e-01
   -1.79342103e+00  1.08983183e+00 -1.35931861e+00 -1.14764631e+00]
  [ 1.74795461e+00 -1.60028651e-01  7.03835070e-01 -3.66131395e-01
   -6.39971673e-01  6.04712069e-01 -8.59635890e-01  6.99656904e-01
   -3.45304072e-01  1.41395795e+00 -1.04601347e+00 -1.75303125e+00]
  [ 1.19105601e+00  6.01129889e-01  3.93038094e-02  4.87445354e-01
   -6.61754668e-01  8.12084973e-01  1.74837068e-01  3.17051888e-01
   -1.31270194e+00  1.46537018e+00 -1.42688251e+00 -1.68694019e+00]
  [ 8.62344146e-01  2.49831170e-01  9.24054027e-01  5.88451862e-01
   -8.11926842e-01  3.88371974e-01  6.01497246e-03  1.65743917e-01
   -8.38069797e-01  1.54521656e+00 -7.26764679e-01 -2.35326767e+00]
  [-1.23476624e+00  6.44548893e-01  1.37193394e+00  3.45725179e-01
   -1.04841232e+00  1.03857934e+00 -4.68277752e-01  3.51037592e-01
    2.77033776e-01  1.38381481

## Decoder Block

Let's build the **Decoder Block**. Everything we did to create the **encoder** block applies here. The major differences are that the **Decoder Block** has:
1. a **Multi-Head Cross-Attention** layer which uses the encoder's outputs as the keys and values.

2. an extra skip/residual connection along with an extra layer normalization step.

<div>
<img src="https://drive.google.com/uc?export=view&id=1WVT4SX49bnta4uscOTF4xrsxFI4PbPER" width="500"/>
</div>

In [143]:
class DecoderBlock(hk.Module):
    def __init__(self, d_model, num_heads, hidden_dim, dropout_rate=0.1):
        super(DecoderBlock, self).__init__()

        self.mhsa1 = MultiHeadSelfAttention(d_model, num_heads)
        self.mhsa2 = MultiHeadSelfAttention(d_model, num_heads)

        self.ffn = feed_forward_network(d_model, hidden_dim)

        self.dropout1 = partial(hk.dropout, rng=next(rng), rate=dropout_rate)
        self.dropout2 = partial(hk.dropout, rng=next(rng), rate=dropout_rate)
        self.dropout3 = partial(hk.dropout, rng=next(rng), rate=dropout_rate)

        self.layernorm1 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
        self.layernorm2 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
        self.layernorm3 = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    
    # Note the decoder block takes two masks. One for the first MHSA, another
    # for the second MHSA.
    def __call__(self, encoder_output, target, training, decoder_mask, memory_mask):
        mhsa_output1, attn_weights = self.mhsa1(target, target, target, decoder_mask)
        if training:
            mhsa_output1 = self.dropout1(x=mhsa_output1)
        mhsa_output1 = self.layernorm1(mhsa_output1 + target)

        mhsa_output2, attn_weights = self.mhsa2(mhsa_output1, encoder_output, 
                                                encoder_output, 
                                                memory_mask)
        if training:
            mhsa_output2 = self.dropout2(x=mhsa_output2)
        mhsa_output2 = self.layernorm2(mhsa_output2 + mhsa_output1)

        ffn_output = self.ffn(mhsa_output2)
        if training:
            ffn_output = self.dropout3(x=ffn_output)
        output = self.layernorm3(ffn_output + mhsa_output2)

        return output, attn_weights


## Decoder

The decoder is almost the same as the encoder except it takes the encoder's output as part of its input, and it takes two masks: the decoder mask and memory mask.

In [151]:
class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_blocks, d_model, num_heads, hidden_dim, target_vocab_size,
                max_seq_len, dropout_rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.token_embed = hk.Embed(target_vocab_size, self.d_model)
        self.pos_embed = hk.Embed(max_seq_len, self.d_model)

        self.dropout = partial(hk.dropout, rng=next(rng), rate=dropout_rate)

        self.blocks = [DecoderBlock(self.d_model, num_heads, hidden_dim, dropout_rate) for _ in range(num_blocks)]

    def __call__(self, encoder_output, target, training, decoder_mask, memory_mask):
        token_embeds = self.token_embed(target)

        # Generate position indices.
        num_pos = target.shape[0] * self.max_seq_len
        pos_idx = np.resize(np.arange(self.max_seq_len), num_pos)
        pos_idx = np.reshape(pos_idx, target.shape)

        pos_embeds = self.pos_embed(pos_idx)

        if training:
            x = self.dropout(x=token_embeds + pos_embeds)

        for block in self.blocks:
            x, weights = block(encoder_output, x, training, decoder_mask, memory_mask)

        return x, weights

Before we try the decoder, let's cover the masks involved. The decoder takes two masks:

The *decoder mask* which is a <u>combination of two masks</u>: one to account for the padding in target sequences, and the look-ahead mask. This mask is used in the decoder's **first** multi-head self-attention layer.

The *memory mask* which is used in the decoder's **second** multi-head self-attention. The keys and values for this layer are going to be the encoder's output, and this mask will ensure the decoder doesn't attend to any encoder output which corresponds to padding.

Suppose this is our batch of vectorized target *input* sequences for the decoder. These values are just made up.<br>

**Note**: If you need a refresher on how to prepare target input and output sequences for the decoder, refer to the [seq2seq notebook](https://colab.research.google.com/github/nitinpunjabi/nlp-demystified/blob/main/notebooks/nlpdemystified_seq2seq_and_attention.ipynb).



In [152]:
# Made up values.
target_input_seqs = [
    [1, 652, 723, 123, 62],
    [1, 25,  98, 129, 248, 215, 359, 249],
    [1, 2369, 1259, 125, 486],
]

As we did with the encoder input sequences, we need to pad out this batch so that all sequences within it are the same length.

In [153]:
padded_target_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(target_input_seqs, padding="post")
print("Padded target inputs to the decoder:")
print(padded_target_input_seqs.shape)
print(padded_target_input_seqs)

Padded target inputs to the decoder:
(3, 8)
[[   1  652  723  123   62    0    0    0]
 [   1   25   98  129  248  215  359  249]
 [   1 2369 1259  125  486    0    0    0]]


We can create the padding mask the same way we did for the encoder.

In [154]:
dec_padding_mask = np.where(padded_target_input_seqs, 1, 0).astype(jnp.float32)
dec_padding_mask = dec_padding_mask[:, None, None, :]
print(dec_padding_mask)

[[[[1. 1. 1. 1. 1. 0. 0. 0.]]]


 [[[1. 1. 1. 1. 1. 1. 1. 1.]]]


 [[[1. 1. 1. 1. 1. 0. 0. 0.]]]]


As we covered in the slides, the look-ahead mask is a diagonal where the lower half are 1s and the upper half are zeros. This is easy to create using the *band_part* method:<br>
https://www.tensorflow.org/api_docs/python/tf/linalg/band_part

In [155]:
target_input_seq_len = padded_target_input_seqs.shape[1]
look_ahead_mask = jnp.tril(np.ones((target_input_seq_len, 
                                    target_input_seq_len)))
print(look_ahead_mask)

[[1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1.]]


To create the decoder mask, we just need to combine the padding and look-ahead masks. Note how the columns of the resulting decoder mask are all zero for padding positions.

In [156]:
dec_mask = jnp.minimum(dec_padding_mask, look_ahead_mask)
print("The decoder mask:")
print(dec_mask)

The decoder mask:
[[[[1. 0. 0. 0. 0. 0. 0. 0.]
   [1. 1. 0. 0. 0. 0. 0. 0.]
   [1. 1. 1. 0. 0. 0. 0. 0.]
   [1. 1. 1. 1. 0. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]]]


 [[[1. 0. 0. 0. 0. 0. 0. 0.]
   [1. 1. 0. 0. 0. 0. 0. 0.]
   [1. 1. 1. 0. 0. 0. 0. 0.]
   [1. 1. 1. 1. 0. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]
   [1. 1. 1. 1. 1. 1. 0. 0.]
   [1. 1. 1. 1. 1. 1. 1. 0.]
   [1. 1. 1. 1. 1. 1. 1. 1.]]]


 [[[1. 0. 0. 0. 0. 0. 0. 0.]
   [1. 1. 0. 0. 0. 0. 0. 0.]
   [1. 1. 1. 0. 0. 0. 0. 0.]
   [1. 1. 1. 1. 0. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]
   [1. 1. 1. 1. 1. 0. 0. 0.]]]]


We can now declare a decoder and pass it everything it needs. In our case, the *memory* mask is the same as the *encoder* mask.

In [157]:

def build_decoder(encoder_output, target, training, decoder_mask, memory_mask):
    return Decoder(6, 12, 3, 48, 10000, 8)(
        encoder_output, target, training, decoder_mask, memory_mask)
    
decoder = hk.transform(build_decoder)
p_dec = decoder.init(next(rng), encoder_output, padded_target_input_seqs, 
                            True, dec_mask, enc_mask)
    
decoder_output, _ = decoder.apply(p_dec, next(rng), encoder_output, padded_target_input_seqs, 
                            True, dec_mask, enc_mask)
print(f"Decoder output {decoder_output.shape}:")
print(decoder_output)

Decoder output (3, 8, 12):
[[[-1.6127183e+00 -1.8213560e-01  9.1229409e-01  3.8112000e-01
    7.0431960e-01  2.9676458e-01 -1.6670176e-01  9.1213053e-01
    1.8696641e+00 -7.3093414e-01 -1.3832730e+00 -1.0005306e+00]
  [-6.9025362e-01  1.3053727e+00  1.8864352e+00 -2.3502527e-01
    5.5938315e-01 -4.6635270e-01 -8.7049794e-01  7.3003548e-01
    8.5302389e-01 -1.1479421e+00 -1.4559454e+00 -4.6823293e-01]
  [-1.2876213e+00 -4.4866640e-02  1.6226444e+00  1.1253241e+00
    4.6261898e-01  1.1346519e+00 -8.4329134e-01  6.9205856e-01
    3.4550142e-01 -1.4324353e+00 -9.2756349e-01 -8.4702140e-01]
  [ 9.5239811e-02  7.6274380e-02  1.6914363e+00  9.3968374e-01
    7.9112315e-01  5.6380188e-01 -7.2576153e-01 -2.5724638e-01
    6.4836198e-01 -1.1518744e+00 -2.1691794e+00 -5.0185937e-01]
  [-6.6998619e-01 -1.9678682e-01  1.2874272e+00  1.3559465e+00
    2.8068951e-01 -4.4451751e-02 -7.4557924e-01  5.5065371e-02
    1.6681513e+00 -5.4560304e-01 -2.0309603e+00 -4.1391230e-01]
  [-9.8530191e-01 -8.06

## Transformer

We now have all the pieces to build the **Transformer** itself, and it's pretty simple. 

In [159]:
class Transformer(tf.keras.Model):
    def __init__(self, num_blocks, d_model, num_heads, hidden_dim, source_vocab_size,
                target_vocab_size, max_input_len, max_target_len, dropout_rate=0.1):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_blocks, d_model, num_heads, hidden_dim, source_vocab_size, 
                            max_input_len, dropout_rate)
        
        self.decoder = Decoder(num_blocks, d_model, num_heads, hidden_dim, target_vocab_size,
                            max_target_len, dropout_rate)
        
        # The final dense layer to generate logits from the decoder output.
        self.output_layer = Dense(target_vocab_size)

    def call(self, input_seqs, target_input_seqs, training, encoder_mask,
            decoder_mask, memory_mask):
        encoder_output, encoder_attn_weights = self.encoder(input_seqs, 
                                                            training, encoder_mask)

        decoder_output, decoder_attn_weights = self.decoder(encoder_output, 
                                                            target_input_seqs, training,
                                                            decoder_mask, memory_mask)

        return self.output_layer(decoder_output), encoder_attn_weights, decoder_attn_weights


In [161]:
def build_transformer(input_seqs, target_input_seqs, 
                      training, encoder_mask, decoder_mask, memory_mask):
    return Transformer(
        num_blocks = 6,
        d_model = 12,
        num_heads = 3,
        hidden_dim = 48,
        source_vocab_size = bpemb_vocab_size,
        target_vocab_size = 7000, # made-up target vocab size.
        max_input_len = padded_input_seqs.shape[1],
        max_target_len = padded_target_input_seqs.shape[1])(input_seqs, target_input_seqs, 
                      training, encoder_mask, decoder_mask, memory_mask)

transformer = hk.transform(build_transformer)
params = transformer.init(next(rng), padded_input_seqs, 
                          padded_target_input_seqs, True, 
                          enc_mask, dec_mask, memory_mask=enc_mask)

transformer_output, _, _ = transformer.apply(params, next(rng), padded_input_seqs, 
                                       padded_target_input_seqs, True, 
                                       enc_mask, dec_mask, memory_mask=enc_mask)
print(f"Transformer output {transformer_output.shape}:")
print(transformer_output) # If training, we would use this output to calculate losses.

Transformer output (3, 8, 7000):
[[[0.         0.02891155 1.180111   ... 0.70502913 0.2725506  0.16392171]
  [0.         0.         0.         ... 0.71455824 0.13397743 0.82286805]
  [0.2733927  0.776259   0.32521093 ... 0.         0.         0.58686864]
  ...
  [0.13541673 0.73615843 0.6665622  ... 0.37348574 0.         1.038203  ]
  [0.24842629 0.7897228  0.6983148  ... 0.26553378 0.         0.7705957 ]
  [0.1324051  0.85790825 0.50942504 ... 0.37143385 0.         1.234074  ]]

 [[0.         0.         0.         ... 0.50757706 0.82317847 0.6037749 ]
  [0.         0.         0.         ... 0.         1.2882888  0.5017553 ]
  [0.         0.         0.02330121 ... 0.         1.5237322  0.10844213]
  ...
  [0.         0.         0.         ... 0.39341128 0.30955768 0.02940781]
  [0.         0.         0.         ... 0.45490438 0.         0.05493074]
  [0.         0.         0.         ... 0.05803161 0.         0.        ]]

 [[0.         0.         0.5005577  ... 0.         0.         0

That's the whole original transformer from scratch. From here, if you want to train this transformer, you can use the same approach we used when we built the translation model with attention in the [seq2seq notebook](https://colab.research.google.com/github/nitinpunjabi/nlp-demystified/blob/main/notebooks/nlpdemystified_seq2seq_and_attention.ipynb#scrollTo=x8Ef_eWXjWMn&line=3&uniqifier=1). Remember to use a learning rate warmup (Refer to the paper for more information on this).

It's useful to know how these models work under the hood, but to train our own transformer to get impressive results is expensive. Both in terms of compute and data.<br>

Fortunately, there's a zoo of **pretrained** transformer models we can use. We'll explore that next.

# Pre-Training and Transfer Learning with Hugging Face and OpenAI

**IMPORTANT**<br>
Enable **GPU acceleration** by going to *Runtime > Change Runtime Type*. Keep in mind that, on certain tiers, you're not guaranteed GPU access depending on usage history and current load.
<br><br>
Also, if you're running this in the cloud rather than a local Jupyter server on your machine, then the notebook will *timeout* after a period of inactivity.
<br><br>
Refer to this link on how to run Colab notebooks locally on your machine to avoid this issue:<br>
https://research.google.com/colaboratory/local-runtimes.html

We'll explore pre-training and transfer learning using the **Transformers** library from [Hugging Face](https://huggingface.co/). **Transformers** is an API and toolkit to download pre-trained models and further train them as needed. <br>

We'll start with the **pipelines** module which abstracts a lot of operations such as tokenization, vectorization, inference, etc.<br>

With **Transformers pipelines**, we can just feed text input and get text output. And there are **pipelines** for common tasks including classification, NER, summarization, etc.<br>
https://huggingface.co/docs/transformers/index<br>
https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#pipelines

To get started, we'll need to install **Transformers**.


In [162]:
!pip install transformers
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m64.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting filelock (from transformers)
  Downloading filelock-3.12.2-py3-none-any.whl (10 kB)
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m259.6 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m116.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Do

In [163]:
import operator
import pandas as pd
import tensorflow as tf
import transformers

from datasets import load_dataset
from tensorflow import keras
from transformers import AutoTokenizer
from transformers import pipeline
from transformers import TFAutoModelForQuestionAnswering