# Transformer Summarizer

- Explore summarization using the transformer model
- Implement transformer decoder from scratch
![alt_text](images/transformerNews.png)

## Outline

- [Introduction](#0)
- [Part 1: Importing the dataset](#1)
    - [1.1 Tokenize & Detokenize helper functions](#1.1)
    - [1.2 Preprocessing for Language Models: Concatenate It!](#1.2)
    - [1.3 Batching with bucketing](#1.3)
- [Part 2: Summarization with transformer](#2)
    - [2.1 Dot product attention](#2.1)
        - [Exercise 01](#ex01)
    - [2.2 Causal Attention](#2.2)
        - [Exercise 02](#ex02)
    - [2.3 Transformer decoder block](#2.3)
        - [Exercise 03](#ex03)
    - [2.4 Transformer Language model](#2.4)
        - [Exercise 04](#ex04)
- [Part 3: Training](#3)
    - [3.1 Training the model](#3.1)
        - [Exercise 05](#ex05)
- [Part 4: Evaluation](#4)
    - [4.1 Loading in a trained model](#4.1)
- [Part 5: Testing with your own input](#5) 
    - [Exercise 6](#ex06)
    - [5.1 Greedy decoding](#5.1)
        - [Exercise 07](#ex07)

<a name='0'></a>
### Introduction
- Summarization is an important task in NLP and could be useful for a consumer enterprise.
- For example, bots can scrape articles and summarize them, and you can then apply sentiment analysis to the summaries to gauge the sentiment about a particular stock.
- It can also be used to summarize long emails or articles.

In this notebook you will:

1. Use built-in functions to preprocess the data
2. Implement DotProductAttention
3. Implement Causal Attention
4. Understand how attention works
5. Build the transformer model
6. Evaluate the model
7. Summarize the article

In [1]:
import sys
import os

import numpy as np
import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# To print the entire np array
np.set_printoptions(threshold=sys.maxsize)

INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 


In [2]:
DATA_DIR = 'data'
VOCAB_DIR = os.path.join(DATA_DIR, 'vocab_dir')

In [3]:
# Importing CNN/DailyMail articles dataset
train_stream_fn = trax.data.TFDS('cnn_dailymail',
                                 data_dir=DATA_DIR,
                                 keys=('article', 'highlights'),
                                 train=True)

# This should be much faster as the data is downloaded already.
eval_stream_fn = trax.data.TFDS('cnn_dailymail',
                                data_dir=DATA_DIR,
                                keys=('article', 'highlights'),
                                train=False)

<a name='1.1'></a>
## 1.1 Tokenize & Detokenize helper functions

Just like in the previous assignment, the cell above loads in the encoder for you. Given any data set, you have to be able to map words to their indices, and indices to their words. The inputs and outputs to your [Trax](https://github.com/google/trax) models are usually tensors of numbers where each number corresponds to a word. If you were to process your data manually, you would have to make use of the following: 

- <span style='color:blue'> word2Ind: </span> a dictionary mapping the word to its index.
- <span style='color:blue'> ind2Word:</span> a dictionary mapping the index to its word.
- <span style='color:blue'> word2Count:</span> a dictionary mapping the word to the number of times it appears. 
- <span style='color:blue'> num_words:</span> total number of words that have appeared. 

Since you have already implemented these in previous assignments of the specialization, we will provide you with helper functions that will do this for you. Run the cell below to get the following functions:

- <span style='color:blue'> tokenize: </span> converts a text sentence to its corresponding token list (i.e. list of indices). Also converts words to subwords.
- <span style='color:blue'> detokenize: </span> converts a token list to its corresponding sentence (i.e. string).
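To make the manual bookkeeping listed above concrete, here is a minimal sketch of how the four structures could be built for a toy corpus. The corpus, the whitespace split, and the helper name `build_vocab` are illustrative assumptions; the notebook itself relies on the subword vocabulary loaded through `trax.data.tokenize`. `num_words` is interpreted here as the total number of tokens seen, which is also an assumption.

```python
from collections import Counter

def build_vocab(sentences):
    """Build word/index lookup tables from an iterable of sentences (hypothetical helper)."""
    word2Count = Counter(word for s in sentences for word in s.split())
    # Assign indices in order of first appearance
    word2Ind = {word: i for i, word in enumerate(word2Count)}
    ind2Word = {i: word for word, i in word2Ind.items()}
    # Interpreted here as the total number of tokens seen (an assumption)
    num_words = sum(word2Count.values())
    return word2Ind, ind2Word, word2Count, num_words

toy_corpus = ["the pope landed", "the crowd cheered"]
word2Ind, ind2Word, word2Count, num_words = build_vocab(toy_corpus)

# Encode a sentence to indices and decode it back to text
encoded = [word2Ind[w] for w in "the crowd landed".split()]
decoded = " ".join(ind2Word[i] for i in encoded)
print(encoded, decoded)
```

A word-level table like this cannot encode words it has never seen, which is one reason the helper functions in this notebook work with subwords instead.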

In [4]:
def tokenize(input_str, EOS=1):
    """Input str to features dict, ready for inference"""
  
    # Use the trax.data.tokenize method. It takes streams and returns streams,
    # we get around it by making a 1-element stream with `iter`.
    inputs =  next(trax.data.tokenize(
        iter([input_str]),
        vocab_dir=VOCAB_DIR,
        vocab_file='summarize32k.subword.subwords'
    ))
    
    # Mark the end of the sentence with EOS
    return list(inputs) + [EOS]

def detokenize(integers):
    """List of ints to str"""
  
    s = trax.data.detokenize(
        integers,
        vocab_dir=VOCAB_DIR,
        vocab_file='summarize32k.subword.subwords'
    )
    
    return wrapper.fill(s)

<a name='1.2'></a>
## 1.2 Preprocessing for Language Models: Concatenate It!
**Transformer Decoder**: using a language model to solve an input-output problem
- Language models predict the next word; they have no notion of separate inputs and targets.
- Concatenate the inputs with the targets, putting a separator in between.
- Create a mask with 0s at the input positions and 1s at the target positions, so that the model is not penalized for mispredicting the article and focuses only on the summary.
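
As a toy illustration of the concatenation and mask described above, the sketch below uses made-up token IDs in place of real subword indices. The IDs in `article` and `summary` are hypothetical; only the `EOS = 1` and `SEP = 0` conventions come from this notebook.

```python
EOS, SEP = 1, 0  # 1 ends a sequence, 0 separates (and pads)

article = [11, 12, 13]   # hypothetical token IDs for the article
summary = [21, 22]       # hypothetical token IDs for the summary

# article + EOS + SEP + summary + EOS
joint = article + [EOS, SEP] + summary + [EOS]

# 0s cover the article plus its EOS and SEP; 1s cover the summary plus its EOS
mask = [0] * (len(article) + 2) + [1] * (len(summary) + 1)

print(joint)  # [11, 12, 13, 1, 0, 21, 22, 1]
print(mask)   # [0, 0, 0, 0, 0, 1, 1, 1]
```

The mask has the same length as the joint sequence, so it can be used directly as per-token loss weights.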


In [5]:
SEP = 0 # Padding or separator token
EOS = 1 # End of sentence token

# Concatenate tokenized inputs and targets using 0 as separator
def preprocess(stream):
    for (article, summary) in stream:
        joint = np.array(list(article) + [EOS, SEP] + list(summary) + [EOS])
        mask = [0] * (len(list(article)) + 2) + [1]* (len(list(summary)) + 1) # Accounting for EOS and SEP
        yield joint, joint, np.array(mask)
        
# You can combine a few data preprocessing steps into a pipeline like this
input_pipeline = trax.data.Serial(
    # Tokenizes
    trax.data.Tokenize(vocab_dir=VOCAB_DIR, vocab_file='summarize32k.subword.subwords'),
    # Uses function defined above
    preprocess,
    # Filters out examples longer than 2048
    trax.data.FilterByLength(2048)
)

# Apply preprocessing to data streams
train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

train_input, train_target, train_mask = next(train_stream)

assert sum((train_input - train_target)**2) == 0 # They are the same in Language Model (LM)

In [6]:
# prints mask, 0s on article, 1s on summary
print(f'Single example mask:\n\n {train_mask}')

Single example mask:

 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1

In [7]:
# prints: [Example][<EOS>][<pad>][Example Summary][<EOS>]
print(f'Single example:\n\n {detokenize(train_input)}')

Single example:

 Pope Francis has just landed in the Philippines - to a crowd of
thousands, church bells and an airport greeting from President Benigno
Aquino. The Pope is in Asia, having begun his trip just days ago with
a stop in Sri Lanka, for a week-long tour, and has already caused a
huge spike in interest when it comes to travel in the region. Flight
searches to Sri Lanka have increased by 19 per cent, while the numbers
for the Philippines have increased by a whopping 51 per cent week on
week from last year. Pope Francis has just arrived in the Philippines
- and already travel searches for the region have spiked . The Pope's
six-day Asian tour includes stops in Manila, where adoring crowds
waited, and Sri Lanka . Interest in Manila has seen a particular
increase, up 60 per cent already, according to Cheapflights.co.uk .
Cheapflights.co.uk spokesperson Oonah Shiel attributes this 'halo
effect' to the Pope's genuine mass appeal. Of particular interest to
travellers is Manila, wher

<a name='1.3'></a>

## 1.3 Batching with bucketing

In [8]:
# Bucketing to create batched generators

# Buckets are defined in terms of boundaries and batch sizes
# batch_sizes[i] determines the batch size for items with length < boundaries[i]
# So below, we'll take a batch of 16 sentences of length < 128, 8 of length < 256,
# 4 of length < 512 and so on

boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]

# Create the streams
train_batch_stream = trax.data.BucketByLength(
    boundaries,
    batch_sizes
)(train_stream)

eval_batch_stream = trax.data.BucketByLength(
    boundaries,
    batch_sizes
)(eval_stream)
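
The rule described in the comments above can be sketched in plain Python. This only illustrates how a sequence length maps to a batch size under the stated boundaries; `batch_size_for_length` is a hypothetical helper, not part of `trax`, and the exact boundary handling inside `trax.data.BucketByLength` may differ.

```python
from bisect import bisect_right

boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]

def batch_size_for_length(length):
    """Return the batch size of the first bucket whose boundary exceeds `length`."""
    # bisect_right counts how many boundaries are <= length,
    # which is the index of the bucket the example falls into
    return batch_sizes[bisect_right(boundaries, length)]

for n in (100, 128, 300, 1195):
    print(n, batch_size_for_length(n))
```

Under this rule an example of length 1195 lands in the last bucket with batch size 1, which is consistent with the `(1, 1195)` batch shape printed further down.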

In [9]:
# Every execution will result in generation of a different article
# Try running this cell multiple times to see how the length of the examples affects the batch size
input_batch, _, mask_batch = next(train_batch_stream)

input_batch.shape

(1, 1195)

In [10]:
# print corresponding integer values
print(input_batch[0])

[ 2713   132  1668    18   973    28 19300   527    28   157 11703   527
  4811    28   968 25789   157  1547   450   412    22   346   126    78
  3155  1859     3     9 19300  1353  3174   669   391   145    28  2642
  2460  2264  1248 13337  7977     6   968 13944 26939  3975 16981     2
  1120     2  1779   229 17654    16   102   144  1496   132   213 22208
   186 10255     3     9 10064     2  1779  2232   320  4756   213  1610
   132  4910  2793  1344     6  4041  1581    78  1967    78  3155  1859
     2   229   897   412    28   763  1718     2   286   320  1120    91
   292     2  1248    28  2692  1463   186    28 10722   959 13380   314
  1782   129  1435  6304   996  1019   669 27884     4   213 10064 27872
  3898 25183   318   136    10   197     3  2598   527   213  1668   676
   527  1347  8048   793 27439  9275  1628  7350  1652  4172 27439  9275
  7583     7   129    40 21090   185    64 13094     2    35    51   790
     7    26   211   116  2574   527 24692     4   

In [11]:
# print the article and its summary
print('Article:\n\n', detokenize(input_batch[0]))

Article:

 Police in Texas have released a sketch of a man suspected of shooting
a TV weatherman multiple times as he left work on Wednesday morning.
The sketch was drawn  during a hospital bed interview with KCEN-TV
meteorologist Patrick Crawford, 35, who is recovering after being shot
in the stomach and shoulder. The suspect, who managed to escape the
scene in Bruceville-Eddy on foot on Wednesday morning, is described as
a white male, 30 to 35 years old, with a medium build and a receding
hairline. 'We are actively looking for [the suspect],' Trooper D.L.
Wilson of the Texas Department of Public Safety told ABC News. 'We had
troopers out overnight, but we didn't get any calls of suspicious
people. We can't rule out that he's not in the area, but more than
likely, he's left the area.' Scroll down for video . Suspect: KCEN
weatherman Patrick Crawford (left) helped police put together this
sketch (right) of the man who shot him as he left the TV studio on
Wednesday morning. He said he h

<a name='2'></a>
# Part 2: Summarization with transformer
![alt_text](images/transformer_decoder_zoomin.png)

<a name='2.1'></a>
## 2.1 Dot product attention 

Now you will implement dot product attention which takes in a query, key, value, and a mask. It returns the output.   
![alt_text](images/dotproduct.png)

Here are some helper functions that will help you create tensors and display useful information:
   - `create_tensor`  creates a `jax numpy array` from a list of lists.
   - `display_tensor` prints out the shape and the actual tensor.

In [12]:
def create_tensor(t):
    """
    Create tensor from list of lists
    """
    return jnp.array(t)

def display_tensor(t, name):
    """Display shape and tensor"""
    print(f'{name} shape: {t.shape}\n')
    print(f'{t}\n')

The formula for attention is this one:

$$
\text{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+M\right) V \tag{1}
$$

$d_{k}$ stands for the dimension of queries and keys.

The `query`, `key`, `value` and `mask` tensors are provided for this example.

Notice that the masking is done using very negative values that will yield a similar effect to using $-\infty $. 
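
To see why a very negative number behaves like $-\infty$ here, the sketch below compares the two after a softmax. It uses plain NumPy rather than `jnp`, and the `scores` values and the `softmax` helper are made up for illustration.

```python
import numpy as np

def softmax(x):
    """Row-wise softmax with the usual max-subtraction for stability."""
    x = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=-1, keepdims=True)

scores = np.array([[0.5, 2.0], [1.0, 3.0]])            # hypothetical scaled dot products
mask_large = np.array([[0.0, 0.0], [-1e9, 0.0]])       # "very negative" mask
mask_inf = np.array([[0.0, 0.0], [-np.inf, 0.0]])      # true -infinity mask

weights_large = softmax(scores + mask_large)
weights_inf = softmax(scores + mask_inf)

print(weights_large)
print(np.allclose(weights_large, weights_inf))
```

In both cases the masked position receives an attention weight of exactly zero, because the exponential of such a large negative number underflows to 0.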

In [13]:
q = create_tensor([[1, 0, 0], [0, 1, 0]])
display_tensor(q, 'query')
k = create_tensor([[1, 2, 3], [4, 5, 6]])
display_tensor(k, 'key')
v = create_tensor([[0, 1, 0], [1, 0, 1]])
display_tensor(v, 'value')
m = create_tensor([[0, 0], [-1e9, 0]])
display_tensor(m, 'mask')

query shape: (2, 3)

[[1 0 0]
 [0 1 0]]

key shape: (2, 3)

[[1 2 3]
 [4 5 6]]

value shape: (2, 3)

[[0 1 0]
 [1 0 1]]

mask shape: (2, 2)

[[ 0.e+00  0.e+00]
 [-1.e+09  0.e+00]]





In [14]:
q_dot_k = q @ k.T / jnp.sqrt(3)
display_tensor(q_dot_k, 'query dot key')

query dot key shape: (2, 2)

[[0.57735026 2.309401  ]
 [1.1547005  2.8867514 ]]



In [15]:
masked = q_dot_k + m
display_tensor(masked, 'masked query dot key')

masked query dot key shape: (2, 2)

[[ 5.7735026e-01  2.3094010e+00]
 [-1.0000000e+09  2.8867514e+00]]



In [16]:
display_tensor(masked @ v, 'masked query dot key dot value')

masked query dot key dot value shape: (2, 3)

[[ 2.3094010e+00  5.7735026e-01  2.3094010e+00]
 [ 2.8867514e+00 -1.0000000e+09  2.8867514e+00]]



- In order to use the previous dummy tensors to test some of the graded functions, a batch dimension should be added to them so they mimic the shape of real-life examples. 
- The mask is also replaced by a version of it that resembles the one that is used by trax:

In [17]:
q_with_batch = q[None, :]
display_tensor(q_with_batch, 'query with batch dim')

query with batch dim shape: (1, 2, 3)

[[[1 0 0]
  [0 1 0]]]



In [18]:
k_with_batch = k[None,:]
display_tensor(k_with_batch, 'key with batch dim')
v_with_batch = v[None,:]
display_tensor(v_with_batch, 'value with batch dim')
m_bool = create_tensor([[True, True], [False, True]])
display_tensor(m_bool, 'boolean mask')

key with batch dim shape: (1, 2, 3)

[[[1 2 3]
  [4 5 6]]]

value with batch dim shape: (1, 2, 3)

[[[0 1 0]
  [1 0 1]]]

boolean mask shape: (2, 2)

[[ True  True]
 [False  True]]



<a name='ex01'></a>
### Exercise 01

**Instructions:** Implement the dot product attention. Concretely, implement the following equation


$$
\text{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+M\right) V \tag{1}
$$

$Q$ - query, 
$K$ - key, 
$V$ - values, 
$M$ - mask, 
${d_k}$ - depth/dimension of the queries and keys (used for scaling down)

You can implement this formula using either `trax` numpy (`trax.fastmath.numpy`, imported here as `jnp`) or regular `numpy`, but it is recommended to use `jnp`.

Something to take into consideration is that within trax, the masks are tensors of `True/False` values not 0's and $-\infty$ as in the previous example. Within the graded function don't think of applying the mask by summing up matrices, instead use `jnp.where()` and treat the **mask as a tensor of boolean values with `False` for values that need to be masked and True for the ones that don't.**
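
As a small sketch of the point above, the following NumPy snippet applies a boolean mask with `np.where` and checks that it produces the same attention weights as adding a mask of very negative values. NumPy stands in for `jnp` here, and the `scores` values and `softmax` helper are illustrative assumptions.

```python
import numpy as np

def softmax(x):
    """Row-wise softmax with max-subtraction for stability."""
    x = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=-1, keepdims=True)

scores = np.array([[0.5, 2.0], [1.0, 3.0]])             # hypothetical scaled dot products
additive_mask = np.array([[0.0, 0.0], [-1e9, 0.0]])     # mask applied by summing
boolean_mask = np.array([[True, True], [False, True]])  # False marks masked positions

masked_by_sum = scores + additive_mask
# Keep the score where the mask is True, otherwise substitute a very negative value
masked_by_where = np.where(boolean_mask, scores, np.full_like(scores, -1e9))

print(softmax(masked_by_sum))
print(softmax(masked_by_where))
```

The two masked score matrices differ slightly at the masked position, but both values are so negative that the resulting softmax weights match.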

Also take into account that the real tensors are far more complex than the toy ones you just played with. Because of this avoid using shortened operations such as `@` for dot product or `.T` for transposing. Use `jnp.matmul()` and `jnp.swapaxes()` instead.
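
The difference matters as soon as a batch dimension is present. The NumPy sketch below (with `np` standing in for `jnp`, and an arbitrary toy tensor) shows that `.T` reverses every axis of a 3D array, while `swapaxes(-1, -2)` only swaps the last two.

```python
import numpy as np

batch = np.arange(2 * 3 * 4).reshape(2, 3, 4)   # (batch, seqlen, depth)

print(batch.T.shape)                      # (4, 3, 2): every axis is reversed
print(np.swapaxes(batch, -1, -2).shape)   # (2, 4, 3): only the last two axes swap

# Batched query-key product: one (seqlen, seqlen) score matrix per batch element
scores = np.matmul(batch, np.swapaxes(batch, -1, -2))
print(scores.shape)                       # (2, 3, 3)
```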

This is the self-attention block for the transformer decoder. Good luck!  

In [19]:
# UNQ_C1
# GRADED FUNCTION: DotProductAttention
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.
    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k)

    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by d)
    """

    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "Embedding dimensions of q, k, v aren't all the same"

    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###
    # Save depth/dimension of the query embedding for scaling down the dot product
    depth = query.shape[-1]

    # Calculate scaled query key dot product according to formula above
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
    
    # Apply the mask
    if mask is not None: # The 'None' in this line does not need to be replaced
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    
    # Softmax formula implementation
    # Use trax.fastmath.logsumexp of dots to avoid numerical overflow/underflow in the softmax
    # Hint: Last axis should be used and keepdims should be True
    # Note: softmax = e^(dots - logsumexp(dots)) = e^dots / sumexp(dots)
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    # Take exponential of dots minus logsumexp to get softmax
    # Use jnp.exp()
    dots = jnp.exp(dots - logsumexp)

    # Multiply dots by value to get self-attention
    # Use jnp.matmul()
    attention = jnp.matmul(dots, value)

    ## END CODE HERE ###
    
    return attention

In [20]:
DotProductAttention(q_with_batch, k_with_batch, v_with_batch, m_bool)

DeviceArray([[[0.8496746 , 0.15032545, 0.8496746 ],
              [1.        , 0.        , 1.        ]]], dtype=float32)
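
The function above computes the softmax as `exp(dots - logsumexp(dots))`. The sketch below illustrates why that form is preferred over dividing exponentials directly. It uses NumPy and a hand-written `logsumexp` as a stand-in for `trax.fastmath.logsumexp`, and the `dots` values are deliberately extreme toy numbers.

```python
import numpy as np

def logsumexp(x, axis=-1):
    """Stable log(sum(exp(x))) along `axis`, keeping the reduced dimension."""
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

dots = np.array([[1000.0, 1001.0], [-1e9, 2.0]])

# Naive softmax: exp(1000) overflows to inf, and inf / inf gives nan
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(dots) / np.sum(np.exp(dots), axis=-1, keepdims=True)

# Stable softmax via logsumexp
stable = np.exp(dots - logsumexp(dots))

print(naive)
print(stable)
```

The first row of the naive result is `nan`, while the logsumexp version yields valid weights that sum to 1 in every row.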

<a name='2.2'></a>

## 2.2 Causal Attention
Causal attention: multi-headed attention with a mask so that each position attends only to words that occurred before it.
![alt_text](images/causal.png)
A word can see everything that came before it. The multi-head computation is done by reshaping the tensors.

<a name='ex02'></a>
### Exercise 02
- Compute Attention Heads: Gets an input x of dimension (batch_size, seqlen, n_heads × d_head), splits the last (depth) dimension, and stacks it onto the zeroth dimension to allow matrix multiplication (batch_size × n_heads, seqlen, d_head).
- Dot Product Self Attention: Creates a mask matrix with True values on and below the diagonal and False values above it, then calls DotProductAttention, which implements dot product self-attention.
- Compute Attention Output: Undoes compute_attention_heads by splitting the first (vertical) dimension and stacking in the last (depth) dimension (batch_size, seqlen, n_heads × d_head). These operations concatenate the heads.

In [21]:
tensor2d = create_tensor(q)
display_tensor(tensor2d, 'query matrix (2D tensor)')

tensor4d2b = create_tensor([[q, q], [q, q]])
display_tensor(tensor4d2b, 'batch of two (multi-head) collections of query matrices (4D tensor)')

tensor3dc = create_tensor([jnp.concatenate([q, q], axis = -1)])
display_tensor(tensor3dc, 'one batch of concatenated heads of query matrices (3d tensor)')

tensor3dc3b = create_tensor([jnp.concatenate([q, q], axis = -1), jnp.concatenate([q, q], axis = -1), jnp.concatenate([q, q], axis = -1)])
display_tensor(tensor3dc3b, 'three batches of concatenated heads of query matrices (3d tensor)')

query matrix (2D tensor) shape: (2, 3)

[[1 0 0]
 [0 1 0]]

batch of two (multi-head) collections of query matrices (4D tensor) shape: (2, 2, 2, 3)

[[[[1 0 0]
   [0 1 0]]

  [[1 0 0]
   [0 1 0]]]


 [[[1 0 0]
   [0 1 0]]

  [[1 0 0]
   [0 1 0]]]]

one batch of concatenated heads of query matrices (3d tensor) shape: (1, 2, 6)

[[[1 0 0 1 0 0]
  [0 1 0 0 1 0]]]

three batches of concatenated heads of query matrices (3d tensor) shape: (3, 2, 6)

[[[1 0 0 1 0 0]
  [0 1 0 0 1 0]]

 [[1 0 0 1 0 0]
  [0 1 0 0 1 0]]

 [[1 0 0 1 0 0]
  [0 1 0 0 1 0]]]



It is important to know that the following 3 functions would normally be defined within the `CausalAttention` function further below. 

However this makes these functions harder to test. Because of this, these functions are shown individually using a `closure` (when necessary) that simulates them being inside of the `CausalAttention` function. This is done because they rely on some variables that can be accessed from within `CausalAttention`.

### Support Functions

<span style='color:blue'> compute_attention_heads </span>: Gets an input $x$ of dimension (batch_size, seqlen, n_heads $\times$ d_head) and splits the last (depth) dimension and stacks it to the zeroth dimension to allow matrix multiplication (batch_size $\times$ n_heads, seqlen, d_head).

**For the closures you only have to fill the inner function.**

In [22]:
def compute_attention_heads_closure(n_heads, d_head):
    """
    Function that simulates environment inside CausalAttention function
    Args:
        d_head (int): Dimensionality of heads
        n_heads (int): number of attention heads
    Returns
        function: compute_attention_heads function
    """
    
    def compute_attention_heads(x):
        """
        Compute the attention heads
        Args
            x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size, seqlen, n_heads X d_head)
        Returns
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size X n_heads, seqlen, d_head)
        """
        
        # Size of the x's batch dimension
        batch_size = x.shape[0]
        
        # Length of the sequence
        # Should be the size of x's first dimension without counting the batch dim
        seqlen = x.shape[1]
        # Reshape x using jnp.reshape()
        # batch_size, seqlen, n_heads*d_head -> batch_size, seqlen, n_heads, d_head
        x = jnp.reshape(x, [batch_size, seqlen, n_heads, d_head])
        # Transpose x using jnp.transpose()
        # batch_size, seqlen, n_heads, d_head -> batch_size, n_heads, seqlen, d_head
        # Note that the values within the tuple are the indexes of the dimensions of x and you must rearrange them
        x = jnp.transpose(x, [0, 2, 1, 3])
        # Reshape x using jnp.reshape()
        # batch_size, n_heads, seqlen, d_head -> batch_size*n_heads, seqlen, d_head
        x = jnp.reshape(x, [batch_size * n_heads, seqlen, d_head])
        
        return x
    
    return compute_attention_heads

In [23]:
display_tensor(tensor3dc3b, "input tensor")
result_cah = compute_attention_heads_closure(2,3)(tensor3dc3b)
display_tensor(result_cah, "output tensor")

input tensor shape: (3, 2, 6)

[[[1 0 0 1 0 0]
  [0 1 0 0 1 0]]

 [[1 0 0 1 0 0]
  [0 1 0 0 1 0]]

 [[1 0 0 1 0 0]
  [0 1 0 0 1 0]]]

output tensor shape: (6, 2, 3)

[[[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]]



**dot_product_self_attention** : Creates a mask matrix with `True` values on and below the diagonal and `False` values above it, then calls DotProductAttention, which implements dot product self attention.

In [24]:
def dot_product_self_attention(q, k, v):
    """
    Masked dot product self attention
    Args: 
        q: Queries
        k: Keys
        v: Values
    Returns:
        DeviceArray: Masked dot product self attention tensor
    """
    
    # Hint: mask size should be equal to L_q. Remember that q has shape (batch_size, L_q, d)
    mask_size = q.shape[1]
    
    # Creates a matrix with ones on and below the diagonal and 0s above. It should have shape (1, mask_size, mask_size)
    # Notice that the 1s and 0s are cast to True/False by setting dtype to jnp.bool_
    # Use jnp.tril() - Lower triangle of an array and jnp.ones()
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)
    
    return DotProductAttention(q, k, v, mask)

In [25]:
dot_product_self_attention(q_with_batch, k_with_batch, v_with_batch)

DeviceArray([[[0.        , 1.        , 0.        ],
              [0.8496746 , 0.15032543, 0.8496746 ]]], dtype=float32)
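
One way to convince yourself that the mask really enforces causality is to perturb the last position and check that earlier outputs do not move. The sketch below does this with a NumPy re-implementation; `causal_attention_np` and the toy input are illustrative assumptions, not the graded functions.

```python
import numpy as np

def causal_attention_np(q, k, v):
    """NumPy sketch of masked dot-product self-attention for (batch, seqlen, d) inputs."""
    seqlen, depth = q.shape[1], q.shape[-1]
    dots = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(depth)
    # True on and below the diagonal: each position sees itself and earlier positions
    mask = np.tril(np.ones((1, seqlen, seqlen), dtype=bool))
    dots = np.where(mask, dots, -1e9)
    dots = dots - dots.max(axis=-1, keepdims=True)
    weights = np.exp(dots)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, v)

x = np.arange(12, dtype=float).reshape(1, 4, 3) / 10   # toy (batch, seqlen, d) tensor
out = causal_attention_np(x, x, x)

# Perturb only the last position's key and value
k2, v2 = x.copy(), x.copy()
k2[:, -1] += 1.0
v2[:, -1] += 1.0
out2 = causal_attention_np(x, k2, v2)

print(np.allclose(out[:, :-1], out2[:, :-1]))  # earlier positions are unaffected
print(np.allclose(out[:, -1], out2[:, -1]))    # the last position does change
```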

**compute_attention_output** : Undoes compute_attention_heads by splitting first (vertical) dimension and stacking in the last (depth) dimension (batch_size, seqlen, n_heads × d_head). These operations concatenate (stack/merge) the heads. 

In [26]:
def compute_attention_output_closure(n_heads, d_head):
    """ Function that simulates environment inside CausalAttention function.
    Args:
        d_head (int):  dimensionality of heads.
        n_heads (int): number of attention heads.
    Returns:
        function: compute_attention_output function
    """
    
    def compute_attention_output(x):
        """ Compute the attention output.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size X n_heads, seqlen, d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size, seqlen, n_heads X d_head).
        """
        ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###
        
        # Length of the sequence
        # Should be size of x's first dimension without counting the batch dim
        seqlen = x.shape[1]
        # Reshape x using jnp.reshape() to shape (batch_size, n_heads, seqlen, d_head)
        x = jnp.reshape(x, [int(x.shape[0] / n_heads), n_heads, seqlen, d_head])
        # Transpose x using jnp.transpose() to shape (batch_size, seqlen, n_heads, d_head)
        x = jnp.transpose(x, [0, 2, 1, 3])
        
        ### END CODE HERE ###
        
        # Reshape to allow to concatenate the heads
        return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
    
    return compute_attention_output

In [27]:
display_tensor(result_cah, "input tensor")
result_cao = compute_attention_output_closure(2,3)(result_cah)
display_tensor(result_cao, "output tensor")

input tensor shape: (6, 2, 3)

[[[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]

 [[1 0 0]
  [0 1 0]]]

output tensor shape: (3, 2, 6)

[[[1 0 0 1 0 0]
  [0 1 0 0 1 0]]

 [[1 0 0 1 0 0]
  [0 1 0 0 1 0]]

 [[1 0 0 1 0 0]
  [0 1 0 0 1 0]]]
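
Because the toy tensor above repeats the same values in every head, it cannot reveal a wrong axis order. The NumPy sketch below repeats the round trip on a tensor whose entries are all distinct. `split_heads` and `merge_heads` are hypothetical stand-ins that mirror the reshape and transpose steps of the two closures.

```python
import numpy as np

def split_heads(x, n_heads, d_head):
    """(batch, seqlen, n_heads*d_head) -> (batch*n_heads, seqlen, d_head)."""
    batch, seqlen = x.shape[0], x.shape[1]
    x = x.reshape(batch, seqlen, n_heads, d_head)
    x = x.transpose(0, 2, 1, 3)
    return x.reshape(batch * n_heads, seqlen, d_head)

def merge_heads(x, n_heads, d_head):
    """(batch*n_heads, seqlen, d_head) -> (batch, seqlen, n_heads*d_head)."""
    seqlen = x.shape[1]
    x = x.reshape(-1, n_heads, seqlen, d_head)
    x = x.transpose(0, 2, 1, 3)
    return x.reshape(-1, seqlen, n_heads * d_head)

x = np.arange(3 * 2 * 6).reshape(3, 2, 6)   # every entry is distinct
heads = split_heads(x, n_heads=2, d_head=3)
restored = merge_heads(heads, n_heads=2, d_head=3)

print(heads.shape)                    # (6, 2, 3)
print(heads[0])                       # first head of the first batch element
print(np.array_equal(restored, x))    # True: merging undoes splitting
```

Here `heads[0]` holds the first three features of each position in the first example, and `heads[1]` holds the remaining three.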



### Causal Attention Function
![alt_text](images/masked-attention.png)

Implement and return a causal attention layer through `tl.Serial` with the following:

- `tl.Branch`: consists of 3 [`tl.Dense(d_feature)`, `ComputeAttentionHeads`] pairs to account for the queries, keys, and values
- `tl.Fn`: takes in the `dot_product_self_attention` function and uses it to compute the dot product attention using Q, K, V
- `tl.Fn`: takes in `compute_attention_output_closure` to merge the heads, which were split to allow parallel computation
- `tl.Dense`: final dense layer with dimension `d_feature`

In [28]:
def CausalAttention(
    d_feature,
    n_heads,
    compute_attention_heads_closure=compute_attention_heads_closure,
    dot_product_self_attention=dot_product_self_attention,
    compute_attention_output_closure=compute_attention_output_closure,
    mode='train'
):
    """
    Transformer-style multi-headed causal attention.

    Args:
        d_feature (int): dimensionality of feature embedding.
        n_heads (int): number of attention heads.
        compute_attention_heads_closure (function): Closure around compute_attention_heads.
        dot_product_self_attention (function): dot_product_self_attention function.
        compute_attention_output_closure (function): Closure around compute_attention_output.
        mode (str): 'train' or 'eval'.

    Returns:
        trax.layers.combinators.Serial: Multi-headed self-attention model.
    """
    
    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads
    
    ComputeAttentionHeads = tl.Fn('AttnHeads', compute_attention_heads_closure(n_heads, d_head), n_out=1)
    
    return tl.Serial(
        tl.Branch(
            [tl.Dense(d_feature), ComputeAttentionHeads],
            [tl.Dense(d_feature), ComputeAttentionHeads],
            [tl.Dense(d_feature), ComputeAttentionHeads]
        ),
        tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1), # takes QKV
        tl.Fn('AttnOutput', compute_attention_output_closure(n_heads, d_head), n_out=1), # to allow for parallel
        tl.Dense(d_feature) # Final dense layer
    )

In [29]:
print(CausalAttention(d_feature=512, n_heads=8))

Serial[
  Branch_out3[
    [Dense_512, AttnHeads]
    [Dense_512, AttnHeads]
    [Dense_512, AttnHeads]
  ]
  DotProductAttn_in3
  AttnOutput
  Dense_512
]
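
As a rough sanity check on the layer structure printed above, the trainable weights of this block can be counted by hand. The sketch assumes that each `Dense_512` layer holds a 512 × 512 kernel plus a bias vector and that the `Fn` layers hold no weights; `dense_params` is a hypothetical helper.

```python
def dense_params(n_in, n_out, use_bias=True):
    """Number of weights in a fully connected layer."""
    return n_in * n_out + (n_out if use_bias else 0)

d_feature, n_heads = 512, 8
d_head = d_feature // n_heads

# Three projections (Q, K, V) plus the final output projection
attention_params = 4 * dense_params(d_feature, d_feature)

print(d_head)            # 64
print(attention_params)  # 1050624
```

Note that the count does not depend on `n_heads`: splitting into heads only reshapes the projected tensors.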


<a name='2.3'></a>

## 2.3 Transformer decoder block
![alt_text](images/transformer_decoder_1.png)

To implement this function, you will have to call the `CausalAttention` (masked multi-head attention) function you implemented above. You will have to add a feed-forward block which consists of:

- `tl.LayerNorm`: used to layer normalize
- `tl.Dense`: the dense layer
- `ff_activation`: feed-forward activation (here, ReLU)
- `tl.Dropout`: dropout layer
- `tl.Dense`: dense layer
- `tl.Dropout`: dropout layer

Implement the entire block using 2 `tl.Residual`
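
To illustrate what a residual feed-forward block computes, here is a minimal NumPy sketch of `x + FeedForward(LayerNorm(x))`. It is not the `trax` implementation: the layer norm has no learned scale or bias, dropout is omitted (as in eval mode), and the weights, sizes, and function names are made up for illustration.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize the last axis to zero mean and unit variance (no learned scale/bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def feed_forward_residual(x, w1, b1, w2, b2):
    """x + Dense(ReLU(Dense(LayerNorm(x)))), with dropout omitted."""
    h = layer_norm(x)
    h = np.maximum(h @ w1 + b1, 0.0)   # first dense layer followed by ReLU
    h = h @ w2 + b2                    # second dense layer maps back to d_model
    return x + h                       # residual connection

d_model, d_ff = 8, 32                  # toy sizes, far smaller than 512 / 2048
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, d_model))   # (batch, seqlen, d_model)
w1, b1 = rng.normal(size=(d_model, d_ff)) * 0.1, np.zeros(d_ff)
w2, b2 = rng.normal(size=(d_ff, d_model)) * 0.1, np.zeros(d_model)

y = feed_forward_residual(x, w1, b1, w2, b2)
print(y.shape)   # (2, 5, 8): same shape as the input
```

Because of the residual connection, the second dense layer must map back to `d_model` so that its output can be added to the block's input.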

In [30]:
def DecoderBlock(d_model, d_ff, n_heads,
                 dropout, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        list: list of trax.layers.combinators.Serial that maps an activation tensor to an activation tensor.
    """
    
    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###
    
    # Create masked multi-head attention block using CausalAttention function
    causal_attention = CausalAttention( 
                        d_model,
                        n_heads=n_heads,
                        mode=mode
                        )

    # Create feed-forward block (list) with two dense layers with dropout and input normalized
    feed_forward = [ 
        # Normalize layer inputs
        tl.LayerNorm(),
        # Add first feed forward (dense) layer (don't forget to set the correct value for n_units)
        tl.Dense(n_units=d_ff),
        # Add activation function passed in as a parameter (you need to call it!)
        ff_activation(), # Generally ReLU
        # Add dropout with rate and mode specified (i.e., don't use dropout during evaluation)
        tl.Dropout(dropout, mode=mode),
        # Add second feed forward layer (don't forget to set the correct value for n_units)
        tl.Dense(n_units=d_model),
        # Add dropout with rate and mode specified (i.e., don't use dropout during evaluation)
        tl.Dropout(dropout, mode=mode)
    ]

    # Add list of two Residual blocks: the attention with normalization and dropout and feed-forward blocks
    return [
      tl.Residual(
          # Normalize layer input
          tl.LayerNorm(),
          # Add causal attention block previously defined (without parentheses)
          causal_attention,
          # Add dropout with rate and mode specified
          tl.Dropout(dropout, mode=mode)
        ),
      tl.Residual(
          # Add feed forward block (without parentheses)
          feed_forward
        ),
      ]
    ### END CODE HERE ###

In [31]:
print(DecoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1, mode='train', ff_activation=tl.Relu))

[Serial[
  Branch_out2[
    None
    Serial[
      LayerNorm
      Serial[
        Branch_out3[
          [Dense_512, AttnHeads]
          [Dense_512, AttnHeads]
          [Dense_512, AttnHeads]
        ]
        DotProductAttn_in3
        AttnOutput
        Dense_512
      ]
      Dropout
    ]
  ]
  Add_in2
], Serial[
  Branch_out2[
    None
    Serial[
      LayerNorm
      Dense_2048
      Relu
      Dropout
      Dense_512
      Dropout
    ]
  ]
  Add_in2
]]


<a name='2.4'></a>
## 2.4 Transformer Language Model

You will now bring it all together. In this part you will use all the subcomponents you previously built to make the final model. Concretely, the figure below shows the architecture you will implement.
![alt_text](images/transformer_decoder.png)

The Transformer language model is assembled from the following pieces:
- Positional Encoder
    - tl.Embedding
    - tl.Dropout
    - tl.PositionalEncoding
- A list of n_layers decoder blocks
- tl.Serial: takes in the following layers or lists of layers
    - tl.ShiftRight: shifts the tensor to the right by padding on axis 1
    - positional_encoder: encodes the text positions
    - decoder_blocks: the ones you created
    - tl.LayerNorm: a layer norm
    - tl.Dense: takes in the vocab size
    - tl.LogSoftmax: to predict
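
To make the shift-right step concrete, here is a small sketch on plain Python lists. It assumes a batch of token-id sequences and a padding id of 0; the helper `shift_right` is a hypothetical illustration of the idea, not the Trax layer itself:

```python
def shift_right(batch, n_positions=1, pad_id=0):
    """Shift each sequence right along axis 1: pad the front, drop the tail."""
    return [[pad_id] * n_positions + seq[:len(seq) - n_positions] for seq in batch]

tokens = [[5, 6, 7, 8],
          [9, 10, 11, 12]]
shifted = shift_right(tokens)
print(shifted)  # [[0, 5, 6, 7], [0, 9, 10, 11]]
```

At position t the model now sees the token from position t-1, so it is trained to predict each token from the ones before it.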
    

In [32]:
def TransformerLM(
    vocab_size=33300,
    d_model=512,
    d_ff=2048,
    n_layers=6,
    n_heads=8,
    dropout=0.1,
    max_len=4096,
    mode='train',
    ff_activation=tl.Relu
):
    
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)

    Args:
        vocab_size (int): vocab size.
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_layers (int): number of decoder layers.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        max_len (int): maximum symbol length for positional encoding.
        mode (str): 'train', 'eval' or 'predict', predict mode is for fast inference.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        trax.layers.combinators.Serial: A Transformer language model as a layer that maps from a tensor of tokens
        to activations over a vocab set.
    """
    # Embedding inputs and positional encoder
    positional_encoder = [ 
        # Add embedding layer of dimension (vocab_size, d_model)
        tl.Embedding(vocab_size, d_model),
        # Use dropout with rate and mode specified
        tl.Dropout(dropout, mode=mode),
        # Add positional encoding layer with maximum input length and mode specified
        tl.PositionalEncoding(max_len=max_len, mode=mode)]

    # Create stack (list) of decoder blocks with n_layers with necessary parameters
    decoder_blocks = [ 
        DecoderBlock(d_model, d_ff, n_heads,
                 dropout, mode, ff_activation) for _ in range(n_layers)]

    # Create the complete model as written in the figure
    return tl.Serial(
        # Use teacher forcing (feed output of previous step to current step)
        tl.ShiftRight(mode=mode), # Specify the mode!
        # Add positional encoder
        positional_encoder,
        # Add decoder blocks
        decoder_blocks,
        # Normalize layer
        tl.LayerNorm(),

        # Add dense layer of vocab_size (one logit per vocabulary token)
        # (a.k.a. logits layer; the non-linearity already lives inside the feed-forward blocks)
        tl.Dense(vocab_size),
        # Get log-probabilities with LogSoftmax
        tl.LogSoftmax()
    )

In [33]:
print(TransformerLM(n_layers=1))

Serial[
  ShiftRight(1)
  Embedding_33300_512
  Dropout
  PositionalEncoding
  Serial[
    Branch_out2[
      None
      Serial[
        LayerNorm
        Serial[
          Branch_out3[
            [Dense_512, AttnHeads]
            [Dense_512, AttnHeads]
            [Dense_512, AttnHeads]
          ]
          DotProductAttn_in3
          AttnOutput
          Dense_512
        ]
        Dropout
      ]
    ]
    Add_in2
  ]
  Serial[
    Branch_out2[
      None
      Serial[
        LayerNorm
        Dense_2048
        Relu
        Dropout
        Dense_512
        Dropout
      ]
    ]
    Add_in2
  ]
  LayerNorm
  Dense_33300
  LogSoftmax
]
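
The last two layers turn each position's activations into one score per vocabulary token and then into log-probabilities. A minimal sketch of a numerically stable log-softmax on a toy three-token vocabulary. The logits are made-up values and `log_softmax` is an illustrative helper, not the Trax layer:

```python
import math

def log_softmax(logits):
    """Return log(softmax(logits)), subtracting the max for numerical stability."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]

logits = [2.0, 1.0, 0.1]          # made-up scores for a 3-token vocabulary
log_probs = log_softmax(logits)
probs = [math.exp(lp) for lp in log_probs]
best = max(range(len(log_probs)), key=log_probs.__getitem__)

print(log_probs)
print(sum(probs))   # the exponentiated values sum to 1 (up to rounding)
print(best)         # index of the most likely token
```

Taking the argmax of the log-probabilities picks the same token as taking the argmax of the probabilities, which is what greedy decoding relies on later.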


<a name='3'></a>
# Part 3: Training

In [34]:
from trax.supervised import training

# UNQ_C8
# GRADED FUNCTION: training_loop
def training_loop(TransformerLM, train_gen, eval_gen, output_dir = "~/model"):
    '''
    Input:
        TransformerLM (function): Returns the model you are building (a trax.layers.combinators.Serial).
        train_gen (generator): Training stream of data.
        eval_gen (generator): Evaluation stream of data.
        output_dir (str): folder to save your file.
        
    Returns:
        trax.supervised.training.Loop: Training loop.
    '''
    output_dir = os.path.expanduser(output_dir)  # expand '~' to the user's home directory
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.01)

    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###
    train_task = training.TrainTask( 
      labeled_data=train_gen, # The training generator
      loss_layer=tl.CrossEntropyLoss(), # Loss function 
      optimizer=trax.optimizers.Adam(0.01), # Optimizer (Don't forget to set LR to 0.01)
      lr_schedule=lr_schedule,
      n_steps_per_checkpoint=10
    )

    eval_task = training.EvalTask( 
      labeled_data=eval_gen, # The evaluation generator
      metrics=[tl.CrossEntropyLoss(), tl.Accuracy()] # CrossEntropyLoss and Accuracy
    )

    ### END CODE HERE ###

    loop = training.Loop(TransformerLM(d_model=4,
                                       d_ff=16,
                                       n_layers=1,
                                       n_heads=2,
                                       mode='train'),
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)
    
    return loop
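
The schedule used above warms the learning rate up and then decays it with the reciprocal square root of the step. A rough sketch of that shape follows. The formula in `warmup_then_rsqrt` is an assumption chosen for illustration and may differ from the exact one `trax.lr.warmup_and_rsqrt_decay` implements:

```python
import math

def warmup_then_rsqrt(step, n_warmup_steps=1000, max_value=0.01):
    """Linear warm-up to max_value, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)                           # avoid division by zero at step 0
    warmup = step / n_warmup_steps                # rises linearly to 1.0
    decay = math.sqrt(n_warmup_steps / step)      # equals 1.0 at the end of warm-up
    return max_value * min(warmup, decay)

for step in (10, 500, 1000, 4000):
    print(step, warmup_then_rsqrt(step))
```

Under this sketch the rate is still tiny during the first few steps, so a very short run spends all of its time in warm-up.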

In [37]:
!rm -f model/model.pkl.gz
loop = training_loop(TransformerLM, train_batch_stream, eval_batch_stream, output_dir="model")
loop.run(10)


Step      1: Ran 1 train steps in 11.99 secs
Step      1: train CrossEntropyLoss |  10.41387844
Step      1: eval  CrossEntropyLoss |  10.41582870
Step      1: eval          Accuracy |  0.00000000

Step     10: Ran 9 train steps in 45.31 secs
Step     10: train CrossEntropyLoss |  10.41400909
Step     10: eval  CrossEntropyLoss |  10.41204834
Step     10: eval          Accuracy |  0.00000000
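
A quick sanity check on the numbers above: a model that spreads its probability evenly over the whole vocabulary has a cross-entropy of ln(vocab_size). Assuming the default `vocab_size=33300` from `TransformerLM`, that baseline can be computed directly:

```python
import math

vocab_size = 33300  # default vocab_size of TransformerLM

# Cross-entropy of a uniform prediction: -log(1 / vocab_size)
uniform_loss = -math.log(1.0 / vocab_size)
print(round(uniform_loss, 3))  # 10.413
```

The reported losses of about 10.41 sit right at this value, which suggests that after 10 steps the tiny model is still close to guessing uniformly.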


<a name='4'></a>
# Part 4: Evaluation

<a name='4.1'></a>
### 4.1 Loading in a trained model

    
Original (pretrained) model:

    TransformerLM(vocab_size=33300, d_model=512, d_ff=2048, n_layers=6, n_heads=8,
                  dropout=0.1, max_len=4096, ff_activation=tl.Relu)

Your model:

    TransformerLM(d_model=4, d_ff=16, n_layers=1, n_heads=2)

In [38]:
# Get the model architecture
# model = TransformerLM(mode='eval')

# load the pre-trained weights
# model.init_from_file('model/model.pkl.gz', weights_only=True)

<a name='5'></a>
# Part 5: Testing with your own input

In [40]:
def next_symbol(cur_output_tokens, model):
    """
    Returns the next symbol for a given sentence.

    Args:
        cur_output_tokens (list): tokenized sentence with EOS and PAD tokens at the end.
        model (trax.layers.combinators.Serial): The transformer model.

    Returns:
        int: id of the predicted next token.
    """
    
    # current output tokens length
    token_length = len(cur_output_tokens)
    # calculate the minimum power of 2 big enough to store token_length
    # HINT: use np.ceil() and np.log2()
    # add 1 to token_length so np.log2() doesn't receive 0 when token_length is 0
    padded_length = 2**int(np.ceil(np.log2(token_length+1)))

    # Fill cur_output_tokens with 0's until it reaches padded_length
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :] # Don't replace this 'None'! This is a way of setting the batch dim

    # model expects a tuple containing two padded tensors (with batch)
    output, _ = model((padded_with_batch, padded_with_batch)) 
    # HINT: output has shape (1, padded_length, vocab_size)
    # To get log_probs you need to index output with 0 in the first dim
    # token_length in the second dim and all of the entries for the last dim.
    log_probs = output[0, token_length, :]
    
    return int(np.argmax(log_probs))    
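
To see what the padding rule in `next_symbol` produces, the same formula can be evaluated for a few lengths. This sketch uses the standard `math` module instead of NumPy, and `padded_length` is a hypothetical helper name:

```python
import math

def padded_length(token_length):
    """Smallest power of 2 strictly greater than token_length."""
    return 2 ** int(math.ceil(math.log2(token_length + 1)))

for n in (0, 1, 3, 4, 7, 8):
    print(n, '->', padded_length(n))
# 0 -> 1, 1 -> 2, 3 -> 4, 4 -> 8, 7 -> 8, 8 -> 16
```

Because the padded length always exceeds `token_length`, the index `token_length` used to read `log_probs` is always a valid position in the padded sequence.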

In [41]:
# Test it out!
sentence_test_nxt_symbl = "I want to fly in the sky."
detokenize([next_symbol(tokenize(sentence_test_nxt_symbl)+[0], model)])

'Neither'

<a name='5.1'></a>
### 5.1 Greedy decoding

Now you will implement the greedy_decode algorithm, which calls the `next_symbol` function repeatedly until it produces the end-of-sentence token. It takes the input sentence and the trained model, and returns the decoded summary.

<a name='ex07'></a>
### Exercise 07

In [43]:
# UNQ_C10
# Decoding functions.
def greedy_decode(input_sentence, model):
    """Greedy decode function.

    Args:
        input_sentence (string): a sentence or article.
        model (trax.layers.combinators.Serial): Transformer model.

    Returns:
        string: summary of the input.
    """
    
    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###
    # Use tokenize()
    cur_output_tokens = tokenize(input_sentence) + [0]
    generated_output = [] 
    cur_output = 0 
    EOS = 1 
    
    while cur_output != EOS:
        # Get next symbol
        cur_output = next_symbol(cur_output_tokens, model)
        # Append next symbol to original sentence
        cur_output_tokens.append(cur_output)
        # Append next symbol to generated sentence
        generated_output.append(cur_output)
        print(detokenize(generated_output))
    
    ### END CODE HERE ###
    
    return detokenize(generated_output)
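
The loop above stops only when the model emits EOS, so a poorly trained model could keep generating indefinitely. A minimal sketch of the same greedy loop on token ids with an added step limit. Here `greedy_decode_ids`, `toy_next_token` and `max_new_tokens` are hypothetical names introduced for illustration, and the toy function stands in for `next_symbol(tokens, model)`:

```python
def greedy_decode_ids(prompt_ids, next_token_fn, eos_id=1, max_new_tokens=50):
    """Append the most likely next token until EOS or the step limit is reached."""
    tokens = list(prompt_ids)       # copy so the caller's list is not modified
    generated = []
    for _ in range(max_new_tokens):
        nxt = next_token_fn(tokens)
        tokens.append(nxt)
        generated.append(nxt)
        if nxt == eos_id:
            break
    return generated

def toy_next_token(tokens):
    # Hypothetical stand-in for the model: emits 7, 8, 9 and then EOS (1),
    # based only on how many tokens it has seen so far.
    script = [7, 8, 9, 1]
    return script[min(len(tokens) - 3, len(script) - 1)]

out = greedy_decode_ids([4, 5, 0], toy_next_token)
print(out)  # [7, 8, 9, 1]

# A "model" that never emits EOS is cut off by the step limit.
capped = greedy_decode_ids([4, 5, 0], lambda tokens: 2, max_new_tokens=5)
print(capped)  # [2, 2, 2, 2, 2]
```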

In [None]:
# Test it out on a sentence!
test_sentence = "It was a sunny day when I went to the market to buy some flowers. But I only found roses, not tulips."
print(wrapper.fill(test_sentence), '\n')
print(greedy_decode(test_sentence, model))

In [46]:
# Test it out with a whole article!
article = "It’s the posing craze sweeping the U.S. after being brought to fame by skier Lindsey Vonn, soccer star Omar Cummings, baseball player Albert Pujols - and even Republican politician Rick Perry. But now four students at Riverhead High School on Long Island, New York, have been suspended for dropping to a knee and taking up a prayer pose to mimic Denver Broncos quarterback Tim Tebow. Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were all suspended for one day because the ‘Tebowing’ craze was blocking the hallway and presenting a safety hazard to students. Scroll down for video. Banned: Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll (all pictured left) were all suspended for one day by Riverhead High School on Long Island, New York, for their tribute to Broncos quarterback Tim Tebow. Issue: Four of the pupils were suspended for one day because they allegedly did not heed to warnings that the 'Tebowing' craze at the school was blocking the hallway and presenting a safety hazard to students."
print(wrapper.fill(article), '\n')
print(greedy_decode(article, model))