In [None]:
import os
import time
import math
import copy
import spacy
import GPUtil
import pandas as pd
from typing import *
from itertools import chain

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset

import altair as alt
from altair import Chart

# Lift Altair's row cap on chart datasets so that large DataFrames can be plotted
# (by default Altair refuses datasets above its max-rows limit).
alt.data_transformers.disable_max_rows()

## Positional Encoding

The positional encoding module is added so that the transformer can understand word positions, that is, both the absolute position of each token within the text and the positions of tokens relative to each other. Periodic functions (sine and cosine) are used, as their orthogonality allows unique encodings to be described through combinations of them (via trigonometric identities). In addition, a dropout layer is applied after the positional encoding to reduce overfitting during training, as it prevents over-dependence on exact token positions.

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    For position ``pos`` and channel pair index ``i`` the encoding is

        PE[pos, 2i]     = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))

    The table is precomputed once for ``max_len`` positions and stored as a
    buffer, so it follows the module across devices and is saved in the
    state dict, but is never updated by the optimizer.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions to precompute. Inputs whose second
            dimension exceeds ``max_len`` fail with a shape mismatch on the
            addition in ``forward``.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Inverse frequencies 1 / 10000^(2i/d_model), written as
        # exp(2i * -ln(10000) / d_model) for numerical convenience.
        # The scaling has to happen INSIDE exp(); scaling after it yields
        # negative, exponentially growing "frequencies".
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  #Add batch dimension for input
        self.register_buffer("pe", pe)   #Register positional encoding as non-updatable tensor (not parameter)

    def forward(self, x):
        """Return ``dropout(x + PE)``.

        ``x`` is indexed as ``[batch, position, feature]``: the table is
        sliced to ``x.size(1)`` positions and broadcast over dimension 0.
        """
        x = x + self.pe[:, :x.size(1)].detach()  #Adjust to input size, stop gradient flowing through PE
        return self.dropout(x)

## Multi-Head Attention

The Multi-Head Attention module uses several sets of Query (Q), Key (K) and Value (V) matrices, where each set of matrices (belonging to one head) captures information about the text in a different regard. Each such set of matrices forms an Attention Module.

For instance, Head 1 with matrices {K_1, Q_1, V_1} may extract semantic information, while Head 2 with matrices {K_2, Q_2, V_2} may extract syntactic information. Each of these modules computes a weighted sum of its value vectors, using the attention probabilities as the weights.

In the end, the weighted sums are concatenated and projected through a final layer.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected into queries, keys and values, split
    into ``num_heads`` heads of size ``d_model // num_heads``, attended
    independently per head, concatenated back to ``d_model`` features and
    passed through a final output projection.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if num_heads <= 0 or d_model % num_heads != 0:
            #Ensure integer dimensions
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        #Linear (affine) transformations, acting like matrices
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q @ K^T / sqrt(head_dim)) @ V``.

        Args:
            Q, K, V: Per-head tensors whose last two dimensions are
                ``[sequence, head_dim]``.
            mask: Optional tensor broadcastable to the score tensor
                ``Q @ K^T``. Positions where ``mask == 0`` receive no
                attention. ``None`` (or the legacy ``False``) disables masking.
        """
        scale = math.sqrt(self.head_dim)
        K_t = K.transpose(-2, -1)    #Transpose to match dimensions and get right similarity scores

        attn_scores = torch.matmul(Q, K_t) / scale

        #Masked positions must end up with probability 0 after softmax. Setting their
        #score to 0 would not work (exp(0)=1), so they are filled with a very negative
        #logit instead: exp(-1e9) underflows to 0.
        #The mask is compared against None rather than tested for truthiness, because
        #the truth value of a multi-element tensor is ambiguous and raises.
        if mask is not None and mask is not False:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        attn_probs = F.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, V)

    def _project(self, x, linear):
        """Apply ``linear`` to ``x`` and split the result into heads.

        The projection output has shape [batch_size, sequence_length, d_model];
        it is reshaped to [batch_size, num_heads, sequence_length, head_dim].
        """
        batch_size = x.size(0)
        return linear(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, K, Q, V, mask=None):
        """Run multi-head attention.

        NOTE: the positional argument order is ``(K, Q, V)`` - keys first -
        which differs from the more common ``(Q, K, V)`` convention.

        Args:
            K: Input projected into keys, ``[batch, key_len, d_model]``.
            Q: Input projected into queries, ``[batch, query_len, d_model]``.
            V: Input projected into values, ``[batch, key_len, d_model]``.
            mask: Optional mask, broadcastable to
                ``[batch, num_heads, query_len, key_len]``; zeros mark
                positions that must not be attended to.

        Returns:
            Tensor of shape ``[batch, query_len, d_model]``.
        """
        batch_size = Q.size(0)

        Q_proj = self._project(Q, self.W_q)
        K_proj = self._project(K, self.W_k)
        V_proj = self._project(V, self.W_v)

        #Attention for each head (on the projected, per-head tensors)
        attn_output = self.scaled_dot_product_attention(Q_proj, K_proj, V_proj, mask)

        #Concatenate heads and reshape vector to size=d_model
        attn_output = attn_output.transpose(1,2).contiguous().view(batch_size, -1, self.d_model)

        #Project concatenated heads
        return self.W_o(attn_output)