# Transformer Networks

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline
```

Most of the paper (*Attention Is All You Need*) was clear, but it was difficult for me to immediately see how the model works across sequences of different lengths. I don't mean how or why the positional encoding works (that part is intuitive), but rather how the matrices are constructed to fit together :).

I refer to implementation details from 'The Annotated Transformer' and the 'tensor2tensor' library.

## Encoder Decoder

An encoder-decoder [structure](https://arxiv.org/abs/1409.0473) is used: an input sequence of symbols, $x = (x_1, x_2, \dots, x_n)$, is encoded into a sequence of continuous representations, $\mathbf{z} = (z_1, z_2, \dots, z_n)$. This is then decoded into an output sequence of symbols, $y = (y_1, y_2, \dots, y_m)$, whose length $m$ need not equal $n$. Generation happens one symbol at a time: it is [auto-regressive](https://arxiv.org/abs/1308.0850), consuming the previously generated symbols as additional input when generating the next one. Encoder-decoder models have usually used a recurrent architecture; the Transformer replaces the recurrence with attention.
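The auto-regressive loop can be sketched independently of any particular architecture. In this minimal sketch, `encode` and `decode_step` are hypothetical stand-ins for the encoder and for a single decoder step; the toy versions at the bottom simply copy the source back in upper case. The output length is decided by when the end-of-sequence symbol is emitted, not by the input length.

```python
def autoregressive_decode(encode, decode_step, source, eos, max_len=50):
    """Generate output symbols one at a time, feeding back earlier outputs."""
    z = encode(source)  # continuous representation z_1..z_n
    outputs = []
    for _ in range(max_len):
        # Each step sees the whole encoding and every symbol generated so far.
        next_symbol = decode_step(z, outputs)
        outputs.append(next_symbol)
        if next_symbol == eos:
            break
    return outputs


# Hypothetical toy components: "encode" upper-cases each symbol and the
# decoder emits them back one by one, followed by an end-of-sequence marker.
def toy_encode(source):
    return [s.upper() for s in source]


def toy_decode_step(z, outputs):
    i = len(outputs)
    return z[i] if i < len(z) else "<eos>"


print(autoregressive_decode(toy_encode, toy_decode_step, "abc", eos="<eos>"))
# ['A', 'B', 'C', '<eos>']
```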

## Input Representation

This work used a Byte Pair Encoding (BPE) scheme, which is a subword tokenization of the vocabulary: rare words are split into frequent subword units, so the model can still represent them instead of collapsing them all into a single UNK symbol. To build this representation, an iterative algorithm repeatedly merges the most frequent pair of adjacent symbols, starting from individual characters. Below is the pseudocode provided by the original authors with slight annotations.
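The merge loop can be sketched in a few lines. This is a minimal illustration on a toy word-frequency table, written for these notes rather than copied from any published code; the helper names and the `</w>` end-of-word marker are assumptions of the sketch.

```python
import collections
import re


def get_pair_counts(vocab):
    """Count adjacent symbol pairs, weighted by how often each word occurs."""
    pairs = collections.Counter()
    for word, freq in vocab.items():
        symbols = word.split()
        for a, b in zip(symbols, symbols[1:]):
            pairs[a, b] += freq
    return pairs


def merge_pair(pair, vocab):
    """Replace every standalone occurrence of `pair` with one merged symbol."""
    bigram = re.escape(" ".join(pair))
    # Match whole symbols only: not preceded or followed by a non-space char.
    pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
    merged = "".join(pair)
    return {pattern.sub(merged, word): freq for word, freq in vocab.items()}


def learn_bpe(vocab, num_merges):
    """Greedily merge the most frequent adjacent pair, num_merges times."""
    merges = []
    for _ in range(num_merges):
        pairs = get_pair_counts(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_pair(best, vocab)
        merges.append(best)
    return merges, vocab


# Words are written as space-separated characters; "</w>" marks a word end.
toy_vocab = {"l o w </w>": 5, "l o w e r </w>": 2,
             "n e w e s t </w>": 6, "w i d e s t </w>": 3}
merges, final_vocab = learn_bpe(toy_vocab, num_merges=3)
print(merges)  # [('e', 's'), ('es', 't'), ('est', '</w>')]
```

The learned merge list is then applied in the same order to split new text into subword units, so an unseen word such as "lowest" can still be built from known pieces.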

## Scaled Dot-Product Attention
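The Transformer uses a stateless auto-regressive strategy: rather than carrying a recurrent hidden state, each decoding step attends to the encoded source words (encoded position by position, not summarized into a single vector) and to the output words generated so far. The primary feature it uses is scaled dot-product attention.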
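For reference, the paper defines the operation implemented by `attention` as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $d_k$ is the dimension of the queries and keys. The paper motivates the $1/\sqrt{d_k}$ factor by suggesting that, for large $d_k$, unscaled dot products grow large in magnitude and push the softmax into regions with extremely small gradients.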

```python
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    # compatibility function: dot product of every query with every key
    scores = torch.matmul(query, key.transpose(-2, -1))
    # scale by sqrt(d_k)
    scores = scores / math.sqrt(d_k)
    # optional mask: positions where mask == 0 get a large negative score
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # alpha: attention weights, normalized over the keys
    p_attn = F.softmax(scores, dim=-1)
    # optional dropout on the attention weights
    if dropout is not None:
        p_attn = dropout(p_attn)
    # output: linear combinations of the values
    return torch.matmul(p_attn, value), p_attn
```
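The `mask` argument is what lets one batch hold sequences of different lengths: shorter sequences are padded to the length of the longest one, and the padded key positions are masked out so they receive zero attention weight. Below is a minimal self-contained sketch of that idea; `padding_mask` and `masked_attention` are illustrative helpers written for these notes (the latter mirrors `attention` above without dropout), not library functions.

```python
import math

import torch
import torch.nn.functional as F


def padding_mask(lengths, max_len):
    """(batch, 1, max_len) mask: True for real tokens, False for padding."""
    positions = torch.arange(max_len).unsqueeze(0)          # (1, max_len)
    return (positions < lengths.unsqueeze(1)).unsqueeze(1)  # (batch, 1, max_len)


def masked_attention(query, key, value, mask):
    """Scaled dot-product attention; masked keys get (numerically) zero weight."""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # The (batch, 1, max_len) mask broadcasts over every query position.
    scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn


torch.manual_seed(0)
lengths = torch.tensor([3, 5])         # two sequences of different lengths
max_len, d_model = 5, 4
x = torch.randn(2, max_len, d_model)   # batch padded to the longest sequence
mask = padding_mask(lengths, max_len)
out, p_attn = masked_attention(x, x, x, mask)
print(p_attn.shape)   # torch.Size([2, 5, 5])
print(p_attn[0, 0])   # the last two weights are 0: padded keys are ignored
```

For the real tokens, the result matches running attention on the unpadded sequence alone. Queries at padded positions still produce outputs, but in an encoder built like the one in The Annotated Transformer those positions are never attended to later, because the same mask hides them as keys.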

TODO(cjlovering): include image from notability.

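```python
def SelfAttention(X):
    # Plain self-attention: queries, keys and values are all the input itself
    # (no learned projections are applied here).
    Q, K, V = X, X, X
    return attention(Q, K, V)
```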

### Self Attention
With a single input vector (so one query and one key), self-attention has no effect. The output is a linear combination of the values whose weights come from a softmax and therefore sum to 1; with only one value, its weight is 1, so the vector can only reproduce itself: an identity function.

```python
SelfAttention(torch.FloatTensor([[0.1,0.1,0.8]]))
```

```
(tensor([[0.1000, 0.1000, 0.8000]]), tensor([[1.]]))
```

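When two of the vectors are more *compatible* (a larger dot product, hence a larger attention weight), their outputs become more similar: each is pulled toward an average of the two. The remaining vector is also mixed with the others, but its output stays the most *different*.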

```python
X = torch.FloatTensor([
    [0.20,0.15,0.65],
    [0.15,0.10,0.75],
    [0.75,0.05,0.05]
])
out, alpha = SelfAttention(X)
out.numpy()
```

```
array([[0.3436604 , 0.10272714, 0.50955087],
       [0.3374398 , 0.10344668, 0.5166535 ],
       [0.3969706 , 0.09622269, 0.44894177]], dtype=float32)
```
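To see how much each vector contributes, the attention weights for this example can be recomputed directly. This small self-contained sketch repeats the computation done by `SelfAttention(X)` with plain tensor operations: each row of `alpha` sums to 1, and each output row is that row of `alpha` applied to `X`.

```python
import math

import torch
import torch.nn.functional as F

X = torch.tensor([[0.20, 0.15, 0.65],
                  [0.15, 0.10, 0.75],
                  [0.75, 0.05, 0.05]])

# Same computation as SelfAttention(X): Q = K = V = X, no learned projections.
scores = X @ X.T / math.sqrt(X.size(-1))
alpha = F.softmax(scores, dim=-1)  # row i: how much output i draws from each input
out = alpha @ X                    # convex combinations of the input rows

print(alpha)
print(alpha.sum(dim=-1))           # every row sums to 1
```

For these short vectors the dot products are small, so every weight comes out at roughly a third. That is why all three outputs drift toward the mean of the inputs, with the first two rows ending up closest to each other.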

### Multi Head Attention

```python
def clones(module, N):
    "Produce N identical, independently parameterized copies of a layer."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # Projections for query, key and value, plus the final output projection
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
```
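Since the question at the top was how the matrices fit together, here is a minimal self-contained sketch that traces only the tensor shapes through the three steps of `forward`, using random tensors in place of the learned projections (the sizes are arbitrary assumptions). It also shows why the padding mask is unsqueezed: the extra axis lets one mask broadcast over every head.

```python
import torch

torch.manual_seed(0)
nbatches, seq_len, h, d_k = 2, 5, 4, 8
d_model = h * d_k                            # 32
x = torch.randn(nbatches, seq_len, d_model)  # stands in for a projected input l(x)

# Step 1: split d_model into h heads of size d_k and move the head axis next
# to the batch axis: (batch, seq, d_model) -> (batch, h, seq, d_k).
heads = x.view(nbatches, -1, h, d_k).transpose(1, 2)
print(heads.shape)    # torch.Size([2, 4, 5, 8])

# Step 2: scores are computed independently for every head.
scores = torch.matmul(heads, heads.transpose(-2, -1)) / d_k ** 0.5
print(scores.shape)   # torch.Size([2, 4, 5, 5])

# A (batch, 1, seq) padding mask gains a head axis via unsqueeze(1), so it
# broadcasts over every head and every query position.
mask = torch.ones(nbatches, 1, seq_len, dtype=torch.bool)
mask[0, :, 3:] = False                       # first sequence: 3 real tokens
scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
out = torch.matmul(torch.softmax(scores, dim=-1), heads)
print(out.shape)      # torch.Size([2, 4, 5, 8])

# Step 3: move the heads back and merge them. After transpose() the tensor is
# not contiguous in memory, so contiguous() is needed before view().
merged = out.transpose(1, 2).contiguous().view(nbatches, -1, h * d_k)
print(merged.shape)   # torch.Size([2, 5, 32])

# The reshaping in step 3 exactly inverts the reshaping in step 1.
roundtrip = heads.transpose(1, 2).contiguous().view(nbatches, -1, d_model)
print(torch.equal(roundtrip, x))  # True
```

The sequence length never appears in the weight matrices, only in the `-1` of each `view`, which is why batches of any (padded) length pass through the same module.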