# The Transformer Decoder
Implement a transformer decoder using Trax.

- Translate the mathematics of attention into numpy code
- How multi-head causal attention fits into a GPT-2 transformer decoder
- How to build one with Trax layers
- Implementation of causal attention from scratch
- Exploit the handy-dandy tl.CausalAttention() layer
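
As a rough illustration of the "from scratch" item above, here is a minimal NumPy sketch of single-head causal attention. The function name `causal_attention` and the toy shapes are assumptions made for this example; it is not the Trax implementation.

```python
import numpy as np

def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal (lower-triangular) mask.

    q, k, v: arrays of shape (seq_len, d_head).
    Returns the attended values and the attention weights.
    """
    seq_len, d_head = q.shape
    # Similarity of every query with every key, scaled by sqrt(d_head)
    scores = q @ k.T / np.sqrt(d_head)
    # Position i may only attend to positions j <= i
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, -1e9)
    # Numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(4, 8))
v = rng.normal(size=(4, 8))
out, weights = causal_attention(q, k, v)
```

Each row of `weights` sums to 1 and is zero above the diagonal, so the first position can only attend to itself.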

Components and Flow of a Transformer Decoder  
![alt-text](images/C4_W2_L6_transformer-decoder_S01_transformer-decoder.png)

In [1]:
import sys
import os
import time
import numpy as np
import gin

import textwrap
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# print entire NumPy arrays without truncation
np.set_printoptions(threshold=sys.maxsize)



### Sentence gets embedded, add positional encoding
Embed the words, then create vectors representing each word's position in its sentence, $\in \{0, 1, 2, \cdots, K\}$ = `range(max_len)`, where `max_len = K + 1`.

In [2]:
def PositionalEncoder(vocab_size, d_model, dropout, max_len, mode):
    """
    Returns a list of layers that:
    1. Takes a block of text as input
    2. Embeds the words in that text and
    3. Adds positional encoding
        i.e. Associates a number in range(max_len) with
        each word in each sentence of embedded input text
    The input is a list of tokenized blocks of text
    
    Args
    vocab_size (int): vocab_size
    d_model: depth of embedding
    dropout(float): dropout rate 
    max_len: maximum symbol length for positional encoding
    mode(str): 'train' or 'eval'
    """
    
    # Embeddings inputs and positional encoder
    return [
        # Add embedding layer of dimension(vocab_size, d_model)
        tl.Embedding(vocab_size, d_model),
        # Use dropout with rate and mode specified
        tl.Dropout(rate=dropout, mode=mode),
        # Add positional encoding layer with maximum input length and mode specified
        tl.PositionalEncoding(max_len=max_len, mode=mode)
    ]
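
To make "adds positional encoding" concrete, here is a minimal NumPy sketch that assumes the fixed sinusoidal scheme from the original Transformer paper. The notebook does not show what `tl.PositionalEncoding` computes internally, so treat the formula, the helper name `sinusoidal_positions`, and the assumption that `d_model` is even as illustrative only.

```python
import numpy as np

def sinusoidal_positions(max_len, d_model):
    """Return a (max_len, d_model) table of sinusoidal position vectors.

    Assumes d_model is even: even columns hold sines, odd columns cosines.
    """
    positions = np.arange(max_len)[:, np.newaxis]            # (max_len, 1)
    # One frequency per pair of columns, decreasing geometrically
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(positions * div_term)
    pe[:, 1::2] = np.cos(positions * div_term)
    return pe

# Toy usage: add position information to 5 already-embedded tokens
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(5, 16))
pe = sinusoidal_positions(max_len=10, d_model=16)
encoded = embeddings + pe[:embeddings.shape[0]]
```

Only the first `seq_len` rows of the table are used, which is why `max_len` must be at least as large as the longest input.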

## Multi-Head Causal Attention
The layers and array dimensions involved in multi-head causal attention, which looks only at previous words in the input text:
![alt-text](images/C4_W2_L5_multi-head-attention_S05_multi-head-attention-concatenation_stripped.png)

- `tl.CausalAttention()` does all of the above.
- Q, K, and V all come from the same input text, so there is no need to pass it three times; `tl.Branch` handles this.
- Each branch within a `tl.Branch()` layer performs parallel operations on copies of the layer's inputs.
- For causal attention, each branch (representing Q, K, and V) applies a linear transformation (i.e. a dense layer without a subsequent activation) to its copy of the input, then splits that result into heads.
- You can see the syntax for this in the screenshot from the `trax/layers/attention.py` source code below:
![alt_text](images/use-of-tl-Branch-in-tl-CausalAttention.png)
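
The branch-then-split step described in the bullets can be sketched in plain NumPy. This is only an illustration of the array shapes involved; the helper names `split_heads` and `merge_heads` and the random weight matrices are assumptions for this example, not Trax code.

```python
import numpy as np

def split_heads(x, n_heads):
    """Reshape (seq_len, d_model) into (n_heads, seq_len, d_head)."""
    seq_len, d_model = x.shape
    d_head = d_model // n_heads
    return x.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

def merge_heads(x):
    """Inverse of split_heads: (n_heads, seq_len, d_head) -> (seq_len, d_model)."""
    n_heads, seq_len, d_head = x.shape
    return x.transpose(1, 0, 2).reshape(seq_len, n_heads * d_head)

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 5, 16, 4
x = rng.normal(size=(seq_len, d_model))

# Three "branches": each applies its own linear map to the same input x,
# with no activation, then splits the result into heads.
w_q, w_k, w_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))
q, k, v = (split_heads(x @ w, n_heads) for w in (w_q, w_k, w_v))
```

After attention runs independently on each head, `merge_heads` concatenates the per-head results back into a `(seq_len, d_model)` array.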

## Feed-forward Layer
- Typically uses a ReLU activation, but we leave open the possibility of a different activation

In [3]:
def FeedForward(d_model, d_ff, dropout, mode, ff_activation):
    """
    Returns a list of layers that implement a feed-forward (FF) block
    
    The input is an activation tensor
    Args:
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.
        
    Returns:
        list: list of Trax layers that together map an activation tensor to an activation tensor.
    """
    
    # Create feed-forward block (list) with two dense layers with dropout and input normalized
    return [
        # Normalize layer inputs
        tl.LayerNorm(),
        # Add first feed forward (dense) layer (don't forget to set the correct value for n_units)
        tl.Dense(d_ff),
        # Add activation function passed in as a parameter
        ff_activation(), # ~ReLU
        # Add dropout with rate and mode specified (ie. do not use dropout during evaluation)
        tl.Dropout(rate=dropout, mode=mode),
        # Add second feed forward layer (don't forget to set the correct value for n_units)
        tl.Dense(d_model),
        # Add dropout with rate and mode specified (i.e. don't use dropout during evaluation)
        tl.Dropout(rate=dropout, mode=mode)
    ]
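
To see how the shapes flow through this block, here is a minimal NumPy sketch of the same sequence of operations in evaluation mode (dropout omitted). The function names, random weights, and the simplified layer norm without learned scale and bias are assumptions made for illustration.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize each row to zero mean and unit variance (no learned scale/bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def feed_forward(x, w1, b1, w2, b2):
    """LayerNorm -> Dense(d_ff) -> ReLU -> Dense(d_model)."""
    h = layer_norm(x)
    h = np.maximum(h @ w1 + b1, 0.0)   # expand to d_ff, then ReLU
    return h @ w2 + b2                 # project back to d_model

rng = np.random.default_rng(0)
seq_len, d_model, d_ff = 5, 16, 64
x = rng.normal(size=(seq_len, d_model))
w1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
w2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)
out = feed_forward(x, w1, b1, w2, b2)
```

The output has the same shape as the input, which is what allows the block to be wrapped in a residual connection.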

## Decoder block
- Here, we return a list containing two residual blocks. 
- The first wraps around the causal attention layer, whose inputs are normalized and to which we apply dropout regularization. 
- The second wraps around the feed-forward layer. 
- You may notice that the second call to `tl.Residual()` doesn't call a normalization layer before calling the feed-forward layer. 
- This is because the normalization layer is included in the feed-forward layer.
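
A residual block simply adds the output of the wrapped layers back onto their input. A minimal NumPy sketch, where `residual` and the toy `halve` sublayer are hypothetical names used only for this example:

```python
import numpy as np

def residual(x, sublayer):
    """Return x + sublayer(x); sublayer must preserve the shape of x."""
    return x + sublayer(x)

x = np.arange(6, dtype=float).reshape(2, 3)

# Toy sublayer standing in for attention or the feed-forward block
halve = lambda t: 0.5 * t
y = residual(x, halve)
```

Because the input is added back unchanged, a sublayer that outputs zeros leaves the activations untouched, which helps gradients flow through deep stacks.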

In [4]:
def DecoderBlock(
    d_model, d_ff, n_heads,
    dropout, mode, ff_activation
):
    """
    Returns a list of layers that implements a Transformer decoder block
    
    The input is an activation tensor
    
    Args:
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        list: list of Trax layers that together map an activation tensor to an activation tensor.
    """
    
    # Add list of two Residual Blocks: the attention with normalization and dropout and FF blocks
    
    return [
        tl.Residual(
            tl.LayerNorm(),
            # Add causal attention
            tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode)
        ),
        tl.Residual(
            # Add feed-forward block
            # We don't need to normalize the layer inputs here.
            # The feed-forward block takes care of that for us
            FeedForward(d_model, d_ff, dropout, mode, ff_activation)
        )
    ]

## The Transformer Decoder
That is, repeat the decoder block N times, then apply a dense layer and a softmax to produce the output.

In [6]:
def TransformerLM(
    vocab_size=33300,
    d_model=512,
    d_ff=2048,
    n_layers=6,
    n_heads=8,
    dropout=0.1,
    max_len=4096,
    mode='train',
    ff_activation=tl.Relu
):
    
    """
    Returns a Transformer Language Model
    
    The input to the model is a tensor of tokens (this model uses only the
    decoder part of the overall Transformer)
    
    Args:
        vocab_size (int): vocab size.
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_layers (int): number of decoder layers.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        max_len (int): maximum symbol length for positional encoding.
        mode (str): 'train', 'eval' or 'predict', predict mode is for fast inference.
        ff_activation (function): the non-linearity in feed-forward layer.
        
    Returns:
        trax.layers.combinators.Serial: A Transformer  language model as a layer that maps from a tensor of tokens
        to activations over a vocab set
    """
    
    # Create stack (list) of decoder blocks with n_layers with necessary parameters
    decoder_blocks = [
        DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation) for _ in range(n_layers)
    ]
    
    # Create the complete model as written in the figure
    return tl.Serial(
        # Use teacher forcing (feed output of previous step to current step)
        tl.ShiftRight(mode=mode),
        # Add embedding inputs and positional encoder
        PositionalEncoder(vocab_size, d_model, dropout, max_len, mode),
        # Add decoder blocks
        decoder_blocks,
        # Normalize layer
        tl.LayerNorm(),
        # Add dense layer of size vocab_size (one score per candidate next word)
        # (a.k.a. the logits layer; no activation is applied here)
        tl.Dense(vocab_size),
        # Get log-probabilities with LogSoftmax
        tl.LogSoftmax()
    )
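
The first and last steps of the model can be sketched in NumPy to show what they do to the data. This is an illustration under stated assumptions: `shift_right` mimics a one-position shift with a padding id of 0, and `log_softmax` is a standard numerically stable formulation; neither is copied from Trax.

```python
import numpy as np

def shift_right(tokens, pad_id=0):
    """Prepend pad_id and drop the last token of each row of a (batch, seq_len) array."""
    padded = np.pad(tokens, ((0, 0), (1, 0)), constant_values=pad_id)
    return padded[:, :-1]

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

tokens = np.array([[5, 6, 7, 8]])
shifted_tokens = shift_right(tokens)   # [[0 5 6 7]]

logits = np.array([[1.0, 2.0, 3.0]])
log_probs = log_softmax(logits)
```

With the shifted input, the prediction at position `i` is made from tokens before `i` only, and exponentiating `log_probs` gives a distribution over the vocabulary that sums to 1.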