In [None]:
# http://www.peterbloem.nl/blog/transformers
# https://github.com/pbloem/former

In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F

import math, copy, time

import numpy as np
import pandas as pd
import os
import glob
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn

# Machine epsilon for Python floats (float64). Presumably used later as a small
# numerical-stability constant — TODO confirm against the rest of the notebook.
eps = np.finfo(float).eps

# Default figure size (width, height) in inches for matplotlib figures.
plt.rcParams['figure.figsize'] = 10, 10
%matplotlib inline

# Automatically reload edited imported modules before each cell runs.
# The extension must be loaded before the %autoreload mode can be set.
%load_ext autoreload
%autoreload 2

### Attention module
* $h$ attention heads, i.e. $h$ separate sets of the three matrices $W_q$, $W_k$, $W_v$ (each $k \times k$)
* for efficiency, we combine these into three single $k \times kh$ matrices (one each for queries, keys and values)

In [4]:
class AttentionLayer(nn.Module):
    """Multi-head self-attention with "wide" heads.

    Every head has its own query/key/value projection of size ``k x k``. For
    efficiency the per-head matrices are fused into three ``k -> k * num_heads``
    linear layers. The outputs of all heads are concatenated and projected back
    down to dimension ``k`` by ``unify_layer``.

    Args:
        k: embedding dimension of the input vectors (also the per-head dim).
        num_heads: number of attention heads.
    """

    def __init__(self, k, num_heads=8):
        super(AttentionLayer, self).__init__()
        self.k = k
        self.num_heads = num_heads

        # Fused projections: each layer computes the keys/queries/values of all
        # heads at once; the output is later split into (num_heads, k).
        self.key_layer = nn.Linear(k, k * num_heads, bias=False)
        self.query_layer = nn.Linear(k, k * num_heads, bias=False)
        self.value_layer = nn.Linear(k, k * num_heads, bias=False)

        # Project the concatenated heads (num_heads * k) back down to k.
        self.unify_layer = nn.Linear(num_heads * k, k)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (b, t, k): a batch of b sequences of t vectors.

        Returns:
            Tensor of shape (b, t, k).

        Raises:
            ValueError: if the last dimension of ``x`` is not ``self.k``.
        """
        b_sz, t_sz, k_sz = x.size()
        h_sz = self.num_heads
        if k_sz != self.k:
            raise ValueError(
                f"input embedding dim {k_sz} does not match layer dim {self.k}"
            )

        keys = self.key_layer(x).view(b_sz, t_sz, h_sz, k_sz)
        queries = self.query_layer(x).view(b_sz, t_sz, h_sz, k_sz)
        values = self.value_layer(x).view(b_sz, t_sz, h_sz, k_sz)

        # The attention computation is identical for every head, so fold the
        # heads into the batch dimension:
        # (b, t, h, k) -> (b, h, t, k) -> (b*h, t, k).
        # transpose() yields a non-contiguous view, so .contiguous() is
        # required before .view().
        keys = keys.transpose(1, 2).contiguous().view(b_sz * h_sz, t_sz, k_sz)
        queries = queries.transpose(1, 2).contiguous().view(b_sz * h_sz, t_sz, k_sz)
        values = values.transpose(1, 2).contiguous().view(b_sz * h_sz, t_sz, k_sz)

        # Scale queries and keys by k^(1/4) each (together 1/sqrt(k)) before the
        # dot product, which keeps the intermediate values small.
        scale = k_sz ** (1.0 / 4.0)
        queries = queries / scale
        keys = keys / scale

        # Dot products between every pair of positions: (b*h, t, t).
        raw_weights = torch.bmm(queries, keys.transpose(1, 2))

        # Row-wise softmax: each query position's weights sum to 1.
        weights = F.softmax(raw_weights, dim=2)

        # Weighted sum of values: (b*h, t, k) -> (b, h, t, k).
        out = torch.bmm(weights, values).view(b_sz, h_sz, t_sz, k_sz)

        # Concatenate the heads: (b, h, t, k) -> (b, t, h, k) -> (b, t, h*k).
        out = out.transpose(1, 2).contiguous().view(b_sz, t_sz, h_sz * k_sz)

        # Project the concatenated heads back down to (b, t, k).
        return self.unify_layer(out)
        

### Transformer module

In [5]:
class TransformerBlock(nn.Module):
    """Post-norm transformer block.

    Applies multi-head self-attention followed by a position-wise MLP. Each
    sub-layer is wrapped in a residual connection followed by LayerNorm.

    Args:
        k: embedding dimension of the input vectors (last dim of ``x``).
        num_heads: number of attention heads passed to ``AttentionLayer``.
        ff_hidden_mult: width multiplier of the MLP hidden layer
            (hidden size = ``ff_hidden_mult * k``). Defaults to 4.
    """

    def __init__(self, k, num_heads=8, ff_hidden_mult=4):
        super(TransformerBlock, self).__init__()

        self.attention = AttentionLayer(k, num_heads)

        self.layer_norm1 = nn.LayerNorm(k)
        self.layer_norm2 = nn.LayerNorm(k)

        hidden = ff_hidden_mult * k
        self.mlp = nn.Sequential(
            nn.Linear(k, hidden),
            nn.ReLU(),
            nn.Linear(hidden, k),
        )

    def forward(self, x):
        """Run the block on ``x``; the output has the same shape as ``x``.

        Assumes ``x`` is (batch, seq, k), as required by ``AttentionLayer`` — TODO
        confirm for callers.
        """
        # Self-attention sub-layer, then residual connection + LayerNorm.
        x = self.layer_norm1(x + self.attention(x))
        # Position-wise MLP sub-layer, then residual connection + LayerNorm.
        return self.layer_norm2(x + self.mlp(x))

### Transformer