In [1]:
# This notebook is an exploration of the work by Peter Bloem
# http://peterbloem.nl/blog/transformers

In [2]:
# A naive implementation would loop over all vectors to compute the weights and outputs.
# We wish to express self-attention in terms of matrix multiplications instead.
# The input is a sequence of t vectors of dimension k, expressed as a t by k matrix X.
# Including a minibatch dimension b gives us an input tensor of size:
# (b,t,k)
# The set of all raw dot products w'_ij forms a matrix, which we can compute by multiplying X by its transpose.

In [9]:
# Basic Self Attention

# import torch
# import torch.nn.functional as F

# Assume we have some tensor x with size (b,t,k) 
# x = ...

# raw_weights = torch.bmm(x, x.transpose(1,2))

# https://pytorch.org/docs/stable/generated/torch.bmm.html
# torch.bmm is a batched matrix multiplication. It applies matrix multiplication over batches of matrices. 
# https://pytorch.org/docs/stable/generated/torch.transpose.html
# dimensions 1 and 2 are swapped by transpose (i.e. t and k)

# We apply row-wise softmax
# weights = F.softmax(raw_weights, dim=2)

# Multiply the weight matrix by X to compute the output sequence
# y = torch.bmm(weights, x)

In [8]:
# Complete Self Attention

import torch
from torch import nn
import torch.nn.functional as F 

class SelfAttention(nn.Module):
    """Multi-head self-attention where every head operates at the full embedding width k.

    Each of the h heads has its own key, query and value projection, each a k x k
    matrix. Placing the h matrices of one kind side by side gives a single k x (h*k)
    projection, so the keys, queries and values of all heads are each computed with
    one matrix multiplication.

    :param k: dimension of the input vectors (and of each head)
    :param heads: number of attention heads
    """

    def __init__(self, k, heads=8):
        super().__init__()
        self.k = k
        self.heads = heads

        width = k * heads
        # Registration order (keys, queries, values, unify) is kept fixed so that
        # seeded parameter initialisation and the state_dict layout are unchanged.
        self.tokeys = nn.Linear(k, width, bias=False)
        self.toqueries = nn.Linear(k, width, bias=False)
        self.tovalues = nn.Linear(k, width, bias=False)
        self.unifyheads = nn.Linear(width, k)

    def forward(self, x):
        """Attend over the sequence x.

        :param x: tensor of size (b, t, k); the last dimension should equal self.k
        :return: tensor of size (b, t, k)
        """
        b, t, k = x.size()
        h = self.heads

        def fold_heads(projection):
            # (b, t, h*k) -> (b, t, h, k) -> (b, h, t, k) -> (b*h, t, k): moving the
            # heads into the batch dimension lets torch.bmm process all heads at once.
            return (
                projection(x)
                .view(b, t, h, k)
                .transpose(1, 2)
                .contiguous()
                .view(b * h, t, k)
            )

        queries = fold_heads(self.toqueries)
        keys = fold_heads(self.tokeys)
        values = fold_heads(self.tovalues)

        # Scaling queries and keys by k^(1/4) each is equivalent (up to rounding)
        # to scaling their dot product by k^(1/2).
        scale = k ** 0.25
        queries = queries / scale
        keys = keys / scale

        # (b*h, t, t) raw weights, normalised row-wise
        weights = F.softmax(torch.bmm(queries, keys.transpose(1, 2)), dim=2)

        # Per-head outputs; un-fold the heads and place them next to the embedding
        # dimension so every position gets one concatenated h*k vector.
        per_head = torch.bmm(weights, values).view(b, h, t, k)
        concatenated = per_head.transpose(1, 2).contiguous().view(b, t, h * k)

        # Project the concatenation back down to k dimensions
        return self.unifyheads(concatenated)

In [10]:
# The Transformer Block

class TransformerBlock(nn.Module):
    """A post-norm transformer block.

    Self-attention is followed by a position-wise feedforward network. Each sub-layer
    is wrapped in a residual connection and then a LayerNorm.

    :param k: embedding dimension
    :param heads: number of attention heads
    """

    def __init__(self, k, heads):
        super().__init__()
        hidden = 4 * k  # feedforward hidden layer is four times as wide as its input/output

        self.attention = SelfAttention(k, heads=heads)
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)
        self.ff = nn.Sequential(nn.Linear(k, hidden), nn.ReLU(), nn.Linear(hidden, k))

    def forward(self, x):
        """Map a (b, t, k) tensor to a (b, t, k) tensor."""
        # residual + norm around attention, then residual + norm around the feedforward
        x = self.norm1(x + self.attention(x))
        return self.norm2(x + self.ff(x))
    
# The hidden layer of the feedforward network is 4 times as wide as its input and output.
# Smaller widths may work as well and save memory, but the hidden layer should be wider than the input/output layers.

In [11]:
# Classification Transformer

# The instances are movie reviews, tokenized into sequences of words
# The classification labels are positive and negative
# The architecture will simply be a large chain of transformer blocks
# What we need to work out is how to feed it the input sequences,
# and how to transform the final output sequence into a single classification.

# The most common way is to apply global average pooling to the final output sequence
# and to map the result to a softmaxed class vector

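# Self-attention is permutation equivariant, so a stack of these layers followed by global average pooling
# is permutation invariant: on its own, the model cannot see word order.
# We hence create a second vector, of the same dimension k, that represents the position of the word in the
# current sequence, and add it to the word embedding.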

In [None]:
class Transformer(nn.Module):
    """Sequence classifier built from transformer blocks.

    The model applies token and position embeddings, then a stack of transformer
    blocks, then global average pooling over the sequence, and finally a linear
    layer that produces class log-probabilities.

    :param k: embedding dimension
    :param heads: number of attention heads per block
    :param depth: number of transformer blocks
    :param seq_length: maximum sequence length (size of the position-embedding table)
    :param num_tokens: vocabulary size
    :param num_classes: number of output classes
    """

    def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes):
        super().__init__()

        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_length, k)

        self.tblocks = nn.Sequential(*[TransformerBlock(k=k, heads=heads) for _ in range(depth)])

        self.toprobs = nn.Linear(k, num_classes)

    def forward(self, x):
        """
        :param x: A (b,t) tensor of integer values representing words (in some predetermined vocabulary);
                  t must not exceed seq_length
        :return: A (b,c) tensor of log-probabilities over the classes
        """

        # Generate token embeddings
        tokens = self.token_emb(x)
        b, t, k = tokens.size()

        # Generate position embeddings. The indices must be created on the same
        # device as the input: a bare torch.arange(t) lives on the CPU and breaks
        # the lookup once the model and data are moved to a GPU.
        positions = torch.arange(t, device=x.device)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

        x = tokens + positions
        x = self.tblocks(x)

        # Average-pool over the t dimension and project to class scores
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x, dim=1)