In [None]:
import os
import time
import math
import copy
import spacy
import GPUtil
import pandas as pd
from typing import *
from itertools import chain

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset

import altair as alt
from altair import Chart

# Turn off Altair's max-rows check so charts can embed large DataFrames inline
# (presumably needed by plots later in this notebook — confirm).
alt.data_transformers.disable_max_rows()

## Positional Encoding

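The positional encoding module is added so that the transformer can make use of word order: both the absolute position of each token within the text and its position relative to the other tokens. Periodic functions (sine and cosine) of geometrically decreasing frequencies are used:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right).$$

Every position therefore receives a unique combination of values. Each sine/cosine pair of the same frequency behaves like an orthogonal 2-D rotation, so by the trigonometric angle-addition identities the encoding of position $pos + k$ is a fixed linear function of the encoding of $pos$. This makes relative offsets easy for the model to learn. In addition, a dropout layer is applied after the positional encoding is added to the embeddings. It reduces overfitting during training by preventing over-dependence on exact token positions.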

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: Embedding dimension; must equal the last dimension of the input.
        dropout: Dropout probability applied to ``x + PE``.
        max_len: Longest sequence length supported; encodings for positions
            ``0 .. max_len - 1`` are precomputed once here.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Frequencies 1 / 10000^(2i/d_model), computed as exp(2i * -ln(10000)/d_model).
        # The float cast and the scaling must both happen *inside* exp().
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model): leading dim broadcasts over the batch
        # Buffer rather than Parameter: part of the module, but never updated by the optimizer.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``dropout(x + PE[:, :seq_len])``.

        Assumes x is (batch, seq_len, d_model) with seq_len <= max_len — the
        slice on dim 1 and the broadcast against ``pe`` require it.
        """
        # Slice to the input length; detach() keeps gradients out of the fixed encoding.
        x = x + self.pe[:, : x.size(1)].detach()
        return self.dropout(x)

## Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        assert head_dim * num_heads == d_model

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        