In [1]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

This tutorial is based on/ copied from http://peterbloem.nl/blog/transformers.

# Self-attention

Self-attention maps a sequence of vectors $x_1, ..., x_t$ to an output sequence of vectors $y_1, ..., y_t$ by taking weighted averages of the input:

$$y_i = \sum_j w_{ij}x_j$$

Here, $w_{ij}$ captures the interaction between inputs $x_i$ and $x_j$. For example, with the softmax over the inner products, i.e.

$$w'_{ij} = x_i^\text{T}x_j$$

$$w_{ij} = \frac{\exp(w'_{ij})}{\sum_jw'_{ij}}$$.

<img src="imgs/self-attention.svg" alt="drawing" width="500"/>

In [30]:
# Our input x is a sequence of t vectors of dimension k. 
# Also, we want to process it in a batch of size b later on.
# So our dimension is [b, t, k].

# Let's start by using a random tensor for x.
b, t, l = 8, 4, 10
x = torch.rand(size=(b, t, l))
print(f'x: {x.shape}')

# To compute w', we use the batch matrix multiplication bmm.
# This results in dimension [b, t, t].
w_prime = torch.bmm(x, x.transpose(1, 2))

# By applying the softmax over the last dimension of w_prime, we obtain w.
w = F.softmax(w_prime, dim=-1)
print(f'w: {w.shape}')

# Now to obtain the sequence y (of dimension [b, t, k]), we take the weighted (by w) average of X.
y = torch.bmm(w, x)
print(f'y: {y.shape}')

x: torch.Size([8, 4, 10])
w: torch.Size([8, 4, 4])
y: torch.Size([8, 4, 10])


## Query, Key, Value
In this basic form of self-attention a single vector $x_i$ is used for three different tasks:
1. Used in the weights for its own output $y_i$. -> **query**
2. Used in the weights for the j-th output $y_j$. -> **key**
3. Used as part of the weighted sum.  -> **value**

To disentangle this 3 different 'roles' of $x_i$, we introduce a (learnable) linear transformation for each. In particular, we need 3 $k \times k$ weight matrices $W_q, W_k, W_v$:

$$q_i = W_qx_i \qquad \text(Query)$$

$$k_i = W_kx_i \qquad \text(Key)$$

$$v_i = W_vx_i \qquad \text(Value)$$

This gives the self-attention layer some controllable parameters, and allows it to modify the incoming vectors to suit the three roles they must play.

<img src="imgs/key-query-value.svg" alt="drawing" width="500"/>

## Scaling the dot product

The softmax function can be sensitive to very large input values. These kill the gradient, and slow down learning. The average value of the dot product grows with the embedding dimension **k**, therefore, it helps to scale the dot product depending on this value:

$$w'_{ij}= \frac{q_i^\text{T}k_j}{\sqrt{k}}$$

We use $\sqrt{k}$ in the denominator because that's the euclidean length of a unit vector in $\mathbb{R}^k$.

## Multi-head attention

We can increase the representational power of the self attention by combining them. Instead of using only a single set of 3 transformation matrices $W_q, W_k, W_v$, we use many of them (indexed with $r$) $W^r_q, W^r_k, W^r_v$. These are called *attention heads*.

Using the individual attention heads, we produce multiple output vectors $y^r_i$ for a single input vector $x_i$. We can then concatenate the $y^r_i$ vectors and pass them through another linear transformation to reduce the dimension back to $k$.

Note for the implementation:
While we think about the attention heads as $h$ separate sets of three matrices (of shape $k\times k$), we implement it by 'stacking' them such that we have only a single set of three matrices of shape $k\times h*k$. This way we can compute all the concatenated queries, keys, and values in a single matrix multiplication.

## Implementation of a SelfAttention Module

In [2]:
# Let's implement a SelfAttention torch module.

class SelfAttention(nn.Module):
    """
    A SelfAttention model.
    
    Args:
        k: The embedding dimension.
        heads: The number of attention heads.
    """
    def __init__(self, k: int, heads: int=8):
        super().__init__()
        self.k, self.h = k, heads
        
        self.Wq = nn.Linear(k, k * heads, bias=False)
        self.Wk = nn.Linear(k, k * heads, bias=False)
        self.Wv = nn.Linear(k, k * heads, bias=False)
        
        # This unifies the outputs of the different heads into 
        # a single k-dimensional vector.
        self.unifyheads = nn.Linear(heads * k, k)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: The input embedding of shape [b, t, k].
            
        Returns:
            Self attention tensor of shape [b, t, k].
        """
        b, t, k = x.size()
        h = self.h
        
        # Transform the input embeddings x of shape [b, t, k] to queries, keys, values.
        # The output shape is [b, t, k, k*h] which we transform into [b, t, h, k].
        queries = self.Wq(x).view(b, t, h, k)
        keys = self.Wk(x).view(b, t, h, k)
        values = self.Wv(x).view(b, t, h, k)
        
        # Fold heads into the batch dimension.
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, k)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
        values = values.transpose(1, 2).contiguous().view(b * h, t, k)
        
        # Compute the product of queries and keys and scale with sqrt(k).
        # The tensor w' has shape (b*h, t, t) containing raw weights.
        w_prime = torch.bmm(queries, keys.transpose(1, 2)) / np.sqrt(k)

        # Compute w by normalizing w' over the last dimension.
        w = F.softmax(w_prime, dim=-1) 
        
        # Apply the self attention to the values.
        out = torch.bmm(w, values).view(b, h, t, k)
        
        # Swap h, t back.
        out = out.transpose(1, 2).contiguous().view(b, t, h * k)
        
        # Unify heads to arrive at shape [b, t, k].
        return self.unifyheads(out)


In [53]:
# Test it out.
b, t, k, h = 2, 4, 6, 8
sa = SelfAttention(k=k, heads=h)
x = torch.rand(size=(b, t, k))
sa(x).shape

torch.Size([2, 4, 6])

# Transformers

The transformer architecture consists of multiple transformer blocks that typically look like this: 

<img src="imgs/transformer-block.svg" alt="drawing" width="500"/>
It combines a self attention layer, layer normalization, a feed forward layer and another layer normalization. Additionally, it uses residual connections around the self attention and feed forward layer.

In [3]:
class TransformerBlock(nn.Module):
    """
    A Transformer block consisting of self attention and ff-layer.
    
    Args:
        k (int): The embedding dimension.
        heads (int): The number of attention heads.
    """
    def __init__(self, k: int, heads: int=8, n_mlp: int=4):
        super().__init__()
        
        # The self attention layer.
        self.attention = SelfAttention(k, heads=heads)
        
        # The two layer norms.
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)
        
        # The feed-forward layer.
        self.ff = nn.Sequential(
            nn.Linear(k, n_mlp*k),
            nn.ReLU(),
            nn.Linear(n_mlp*k, k)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: The input embedding of shape [b, t, k].
            
        Returns:
            Transformer output tensor of shape [b, t, k].
        """
        x_prime = self.attention(x)
        x = self.norm1(x_prime + x)
        
        x_prime = self.ff(x)
        return self.norm2(x_prime + x)
        

# Sentiment Classification with Transformers

<img src="imgs/classifier.svg" alt="drawing" width="500"/>

In [47]:
class TextClassificationTransformer(nn.Module):
    """
    Stacked Transformer blocks for sequence classification.
    
    Args:
        k (int): The embedding dimension.
        heads (int): The number of attention heads for each transformer block.
        depth (int): The number of transformer blocks.
        max_seq_len (int): The maximum number of tokens of each sequence.
        num_classes (int): The number of classification classes.
    """
    def __init__(self, k: int, heads: int=8, depth: int=4,
                max_seq_len: int=100, num_tokens: int=50000, 
                num_classes: int=2):
        super().__init__()
        
        self.num_tokens = num_tokens
        
        # Embeddings for tokens and position.
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(max_seq_len, k)
        
        # The stacked transformer blocks.
        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(k=k, heads=heads) for _ in range(depth)]
        )
        
        # Mapping of final output sequence to class logits.
        self.classification = nn.Linear(k, num_classes)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): A tensor of shape (b, t) of integer values
                representing words in some predetermined vocabulary.
        
        Returns:
            A tensor of shape (b, c) of logits over the classes
                (c is the number of classes).
        """
        # Generate token embeddings.
        # Shape: [b, t, k]
        tokens = self.token_emb(x)
        b, t, k = tokens.size()
        
        # Generate position embeddings.
        # Shape: [b, t, k]
        positions = self.pos_emb(torch.arange(t)).unsqueeze(0).expand(b, t, k)
        
        # Add the two embeddings.
        embedding = tokens + positions
        
        # Feed the embedding into the transformer blocks.
        # Shape: [b, t, k]
        x = self.transformer_blocks(embedding)
        
        # Compute the mean latent vector for each sequence.
        # The mean is applied over dim=1 (time).
        # Shape: [b, k]
        x = x.mean(dim=1)
        
        # Classify.
        # Shape: [b, num_classes]
        return self.classification(x)

In [49]:
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam

### Data Loader

In [66]:

class IMDBDataModule(LightningDataModule):
    """
    LightningDataModule to load the IMDB movie review sentiment data.
    """ 
    
    def __init__(self, batch_size: int):
        super().__init__()
        self.batch_size = batch_size
        
    def setup(self, num_words: int, max_seq_len: int):
        """
        Initial loading of the dataset and transformation.
        
        Args:
            num_words (int): The vocabulary size. The vocabulary is 
                sorted by frequency of appearance in the dataset.
            max_seq_len (int): The maximum number of tokens per
                review.
        """
        (self.x_train, self.y_train), (self.x_test, self.y_test) = imdb.load_data(
            num_words=num_words, 
            maxlen=max_seq_len
        )
        print(f'# Training Examples: {len(self.y_train)}')
        print(f'# Test Examples: {len(self.y_test)}')
        
        self.word2idx = dict(
            **{k: v+3 for k, v in imdb.get_word_index().items()},
            **{'<PAD>': 0,
               '<START>': 1,
               '<UNK>': 2,
               '<UNUSED>': 3,
              },
        )
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        
        # Pad the inputs and convert to torch Tensors.
        self.x_train = pad_sequences(self.x_train, maxlen=max_seq_len, value = 0.0)
        self.x_test = pad_sequences(self.x_test, maxlen=max_seq_len, value = 0.0)
    
    def example(self):
        """Returns a random training example."""        
        idx = np.random.randint(0, len(self.x_train))
        x, y = self.x_train[idx], self.y_train[idx]
        review = ' '.join(self.idx2word[token_id] for token_id in x if token_id > 1)
        sentiment = 'POSITIVE' if y else 'NEGATIVE'
        return f'{review}\nSentiment: {sentiment}'
    
    def train_dataloader(self):
        dataset = TensorDataset(torch.LongTensor(self.x_train), 
                                torch.LongTensor(self.y_train))
        return DataLoader(dataset, self.batch_size)
                                
    def test_dataloader(self):
        dataset = TensorDataset(torch.LongTensor(self.x_test), 
                                torch.LongTensor(self.y_test))
        return DataLoader(dataset, self.batch_size)
    
    def val_dataloader(self):
        dataset = TensorDataset(torch.LongTensor(self.x_test), 
                                torch.LongTensor(self.y_test))
        return DataLoader(dataset, self.batch_size)
    
imdb_data = IMDBDataModule(128)
imdb_data.setup(num_words=30000,
                max_seq_len=100)
    

# Training Examples: 2773
# Test Examples: 2963


In [67]:
# Copy the nn.Module from above and use it as LightningModule here.

class TextClassificationTransformer(LightningModule):
    """
    Stacked Transformer blocks for sequence classification.
    
    Args:
        k (int): The embedding dimension.
        heads (int): The number of attention heads for each transformer block.
        depth (int): The number of transformer blocks.
        max_seq_len (int): The maximum number of tokens of each sequence.
        num_classes (int): The number of classification classes.
    """
    def __init__(self, k: int, heads: int=8, depth: int=4,
                max_seq_len: int=100, num_tokens: int=50000, 
                num_classes: int=2):
        super().__init__()
        
        self.num_tokens = num_tokens
        
        # Embeddings for tokens and position.
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(max_seq_len, k)
        
        # The stacked transformer blocks.
        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(k=k, heads=heads) for _ in range(depth)]
        )
        
        # Mapping of final output sequence to class logits.
        self.classification = nn.Linear(k, num_classes)
        
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = pl.metrics.Accuracy()
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): A tensor of shape (b, t) of integer values
                representing words in some predetermined vocabulary.
        
        Returns:
            A tensor of shape (b, c) of logits over the classes
                (c is the number of classes).
        """
        # Generate token embeddings.
        # Shape: [b, t, k]
        tokens = self.token_emb(x)
        b, t, k = tokens.size()
        
        # Generate position embeddings.
        # Shape: [b, t, k]
        positions = self.pos_emb(torch.arange(t)).unsqueeze(0).expand(b, t, k)
        
        # Add the two embeddings.
        embedding = tokens + positions
        
        # Feed the embedding into the transformer blocks.
        # Shape: [b, t, k]
        x = self.transformer_blocks(embedding)
        
        # Compute the mean latent vector for each sequence.
        # The mean is applied over dim=1 (time).
        # Shape: [b, k]
        x = x.mean(dim=1)
        
        # Classify.
        # Shape: [b, num_classes]
        return self.classification(x)
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        
        # Forward pass.
        logits = self(x)
        
        # Compute the loss with CrossEntropy.
        loss = self.criterion(logits, y)
        
        # Log the metrics.
        self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('acc', self.accuracy(logits, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # Lightning automatically disables gradients and puts model in eval mode.
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        
        # Log the metrics.
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('test_acc', self.accuracy(logits, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
    def val_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)
        
        
        

In [68]:
NUM_WORDS = 30000
MAX_SEQ_LEN = 100
EMBEDDING_DIM = 100
BATCH_SIZE = 32

imdb_data = IMDBDataModule(batch_size=BATCH_SIZE)
imdb_data.setup(num_words=NUM_WORDS,
                max_seq_len=MAX_SEQ_LEN)

model = TextClassificationTransformer(k=EMBEDDING_DIM,
                                      max_seq_len=MAX_SEQ_LEN,
                                      num_tokens=NUM_WORDS)
trainer = pl.Trainer(max_epochs=2,
                     default_root_dir='ckpts')
trainer.fit(model, imdb_data)

  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name               | Type             | Params
--------------------------------------------------------
0 | token_emb          | Embedding        | 3 M   
1 | pos_emb            | Embedding        | 10 K  
2 | transformer_blocks | Sequential       | 1 M   
3 | classification     | Linear           | 202   
4 | criterion          | CrossEntropyLoss | 0     
5 | accuracy           | Accuracy         | 0     


# Training Examples: 2773
# Test Examples: 2963




HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…




1

In [69]:
trainer.test()



HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': tensor(0.2381),
 'acc_epoch': tensor(0.5889),
 'acc_step': tensor(0.2381),
 'loss': tensor(0.8217),
 'loss_epoch': tensor(0.6754),
 'loss_step': tensor(0.8217),
 'test_acc': tensor(0.5263),
 'test_acc_epoch': tensor(0.5413),
 'test_loss': tensor(0.6920),
 'test_loss_epoch': tensor(0.6883)}
--------------------------------------------------------------------------------



[{'loss_step': 0.8216840028762817,
  'acc_step': 0.2380952388048172,
  'loss': 0.8216840028762817,
  'acc': 0.2380952388048172,
  'loss_epoch': 0.6754302978515625,
  'acc_epoch': 0.588943600654602,
  'test_loss_epoch': 0.6882695555686951,
  'test_acc_epoch': 0.5413432121276855,
  'test_loss': 0.692012369632721,
  'test_acc': 0.5263158082962036}]