In [2]:
!pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.1.3-py3-none-any.whl (777 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m777.7/777.7 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.3.0-py3-none-any.whl (840 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.2/840.2 kB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.10.0 pytorch-lightning-2.1.3 torchmetrics-1.3.0


In [7]:
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torch.nn.utils.rnn import PackedSequence

import pytorch_lightning as pl

from typing import *


In [4]:
# Hyperparameters
# Maximum sequence length. Used as the default `seq_len` of Transformer_Model
# and to size the sinusoidal position-encoding table it builds.
MAX_SEQ: int = 150


In [None]:
def position_encoding_init(n_position, emb_dim, device=None):
    """Build the fixed sinusoid position-encoding table.

    For every row ``pos >= 1`` and column ``j`` the angle is
    ``pos / 10000 ** (2 * (j // 2) / emb_dim)``; even columns hold ``sin(angle)``
    and odd columns hold ``cos(angle)``. Row 0 is left all-zero so it can act as
    the encoding of a padding position.

    Args:
        n_position: number of rows (positions) in the table.
        emb_dim: number of columns (encoding width). 0 yields a ``(n_position, 0)``
            table.
        device: where to place the result. ``None`` keeps the historical
            behaviour: CUDA when available, otherwise CPU.

    Returns:
        A float32 tensor of shape ``(n_position, emb_dim)``.
    """
    positions = torch.arange(n_position, dtype=torch.float64).unsqueeze(1)
    dims = torch.arange(emb_dim, dtype=torch.float64)
    # 2 * (j // 2) maps column pairs (0, 1), (2, 3), ... onto one shared frequency.
    exponents = 2 * torch.div(dims, 2, rounding_mode="floor") / emb_dim

    # Angles are computed in float64 and then cast to float32 before sin/cos,
    # matching the precision path of the previous numpy-based implementation.
    angles = (positions / torch.pow(10000.0, exponents)).to(torch.float32)

    table = angles.clone()
    # Row 0 stays zero (its angles are all 0); only rows 1.. get sin/cos.
    table[1:, 0::2] = torch.sin(angles[1:, 0::2])
    table[1:, 1::2] = torch.cos(angles[1:, 1::2])

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return table.to(device)


In [None]:
class VariationalDropout(nn.Module):
    """
    Applies the same dropout mask across the temporal dimension. See
    https://arxiv.org/abs/1512.05287 for more details.

    Note that this is not applied to the recurrent activations in the
    LSTM like the above paper. Instead, it is applied to the inputs
    and outputs of the recurrent layer.

    Accepted inputs:
      * a padded tensor shaped (seq, batch, features) when ``batch_first`` is
        False, or (batch, seq, features) when ``batch_first`` is True;
      * a ``PackedSequence`` (its ``sorted_indices``/``unsorted_indices`` are
        preserved in the output).

    One Bernoulli keep-mask per (sequence, feature) is drawn per call and shared
    by every time step of that sequence; surviving units are scaled by
    ``1 / (1 - dropout)``. The module is the identity in eval mode or when
    ``dropout <= 0``.
    """

    def __init__(self, dropout: float, batch_first: Optional[bool] = False):
        super().__init__()
        self.dropout = dropout
        self.batch_first = batch_first

    def forward(
        self, x: Union[torch.Tensor, PackedSequence]
    ) -> Union[torch.Tensor, PackedSequence]:
        if not self.training or self.dropout <= 0.0:
            return x

        keep_prob = 1 - self.dropout

        if isinstance(x, PackedSequence):
            data, batch_sizes = x.data, x.batch_sizes
            # One mask row per sequence, indexed in the packed (sorted) order.
            mask = data.new_empty(int(batch_sizes[0]), data.size(-1)).bernoulli_(
                keep_prob
            )
            # Packed rows are time-major: time step t holds sequences
            # 0..batch_sizes[t]-1, so a row's sequence index is its offset
            # from the first row of its time step.
            step_starts = torch.cumsum(batch_sizes, 0) - batch_sizes
            seq_index = torch.arange(data.size(0)) - torch.repeat_interleave(
                step_starts, batch_sizes
            )
            data = self._apply_mask(data, mask[seq_index.to(data.device)], keep_prob)
            # Carry the permutation indices through so enforce_sorted=False
            # inputs unpack back to the caller's original batch order.
            return PackedSequence(
                data, batch_sizes, x.sorted_indices, x.unsorted_indices
            )

        # Padded input: draw a (batch, features) mask and broadcast it over time.
        # The batch axis is 0 for batch_first and 1 otherwise.
        if self.batch_first:
            mask_shape = (x.size(0), 1, x.size(2))
        else:
            mask_shape = (1, x.size(1), x.size(2))
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        return self._apply_mask(x, mask, keep_prob)

    @staticmethod
    def _apply_mask(x: torch.Tensor, mask: torch.Tensor, keep_prob: float) -> torch.Tensor:
        """Zero units where ``mask`` is 0 and rescale the survivors by 1/keep_prob."""
        out = x.masked_fill(mask == 0, 0)
        # keep_prob == 0 (dropout == 1) drops everything; skip the 0/0 -> NaN.
        return out / keep_prob if keep_prob > 0 else out


In [None]:
class Transformer_Model(nn.Module):
    """Causally-masked Transformer mapping token ids to per-position tag logits.

    Token ids are embedded (``emb_dim`` wide). When ``pe_dim > 0`` the embeddings
    are concatenated with a fixed sinusoidal position encoding (``pe_dim`` wide).
    The result runs through an ``nn.TransformerEncoder`` with a causal mask.

    If ``encoder_only`` is False, a second embedding of ``X2`` is decoded against
    the encoder output by an ``nn.TransformerDecoder``, using the same causal mask
    as ``tgt_mask``. A final linear layer projects to ``nb_tags`` outputs per
    position.

    Args:
        nb_tags: number of token ids accepted by the embedding(s) and number of
            output logits per position.
        nb_layers: number of encoder layers (and decoder layers, if built).
        pe_dim: width of the position-encoding features; 0 disables them.
        emb_dim: token-embedding width. The model width is ``emb_dim + pe_dim``
            and must be divisible by ``nhead``. Note that the defaults
            (100 + 0 with 8 heads) do not satisfy this, so callers must pass
            compatible values.
        batch_size, seq_len: fixed sizes used to reshape inputs to
            ``(seq_len, batch_size, -1)`` and to size the causal mask.
        dropout: dropout probability for the input/middle/output dropouts and
            inside the Transformer layers.
        encoder_only: if False, also build the decoder stack; ``forward`` then
            requires ``X2``.
        nhead: attention heads per layer (previously hard-coded to 8).
        dim_feedforward: feed-forward width per layer (previously hard-coded
            to 1024).
    """

    def __init__(
        self,
        nb_tags,
        nb_layers=1,
        pe_dim=0,
        emb_dim=100,
        batch_size=1,
        seq_len=MAX_SEQ,
        dropout=0.0,
        encoder_only=True,
        nhead=8,
        dim_feedforward=1024,
    ):
        super(Transformer_Model, self).__init__()

        self.nb_layers = nb_layers
        self.emb_dim = emb_dim
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.pe_dim = pe_dim
        self.dropout = dropout
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward

        self.nb_tags = nb_tags

        self.encoder_only = encoder_only

        # build actual NN
        self.__build_model()

    def __build_model(self):
        """Create all submodules.

        Submodules are registered in a fixed order, which fixes both the
        state_dict key order and the order in which random initialisation is
        drawn.
        """
        self.embedding = nn.Embedding(self.nb_tags, self.emb_dim)

        if not self.encoder_only:
            self.embedding2 = nn.Embedding(self.nb_tags, self.emb_dim)

        # Fixed (non-learned) sinusoidal table, registered as a non-persistent
        # buffer. It follows the module across .to()/.cuda(), but it is still not
        # saved in state_dict (the old plain attribute was not saved either).
        # The table is moved to the CPU so it starts on the same device as the
        # freshly created parameters.
        # The table has at least seq_len rows, so a seq_len > MAX_SEQ can be
        # indexed.
        n_position = max(MAX_SEQ, self.seq_len)
        self.register_buffer(
            "pos_emb",
            position_encoding_init(n_position, self.pe_dim).detach().cpu(),
            persistent=False,
        )

        self.dropout_i = nn.Dropout(self.dropout)

        input_size = self.pe_dim + self.emb_dim

        # NOTE(review): nn.TransformerEncoder deep-copies this layer, so the
        # parameters of `transformerLayerI` itself are never used in forward.
        # They are kept for state_dict compatibility. Under DDP this needs
        # find_unused_parameters=True.
        self.transformerLayerI = nn.TransformerEncoderLayer(
            d_model=input_size,
            nhead=self.nhead,
            dropout=self.dropout,
            dim_feedforward=self.dim_feedforward,
        )

        self.transformerI = nn.TransformerEncoder(
            self.transformerLayerI,
            num_layers=self.nb_layers,
        )

        self.dropout_m = nn.Dropout(self.dropout)

        if not self.encoder_only:
            # design decoder
            self.transformerLayerO = nn.TransformerDecoderLayer(
                d_model=input_size,
                nhead=self.nhead,
                dropout=self.dropout,
                dim_feedforward=self.dim_feedforward,
            )

            self.transformerO = nn.TransformerDecoder(
                self.transformerLayerO,
                num_layers=self.nb_layers,
            )

            self.dropout_o = nn.Dropout(self.dropout)

        # output layer which projects back to tag space
        self.hidden_to_tag = nn.Linear(input_size, self.nb_tags)

    def __pos_encode(self, p):
        """Look up rows of the fixed position-encoding table for indices ``p``."""
        return self.pos_emb[p]

    def forward(self, X, p, X2=None, train_embedding=True):
        """Run the model and return logits of shape (seq_len, batch_size, nb_tags).

        Args:
            X: token ids for the encoder, reshaped to (seq_len, batch_size, -1)
                after embedding.
            p: indices into the position-encoding table, used only when
                ``pe_dim > 0``. Presumably one position per token; TODO confirm
                with callers.
            X2: token ids for the decoder; required when ``encoder_only`` is
                False.
            train_embedding: sets ``requires_grad`` on the embedding weight(s),
                freezing or unfreezing them.

        Raises:
            ValueError: if the decoder is built but ``X2`` is None.
        """
        if not self.encoder_only and X2 is None:
            raise ValueError("X2 (decoder input) is required when encoder_only=False")

        self.embedding.weight.requires_grad = train_embedding
        if not self.encoder_only:
            self.embedding2.weight.requires_grad = train_embedding

        # ---------------------
        # Combine inputs
        # NOTE(review): view() reinterprets memory; this lines up correctly only
        # if X is laid out sequence-major (e.g. (seq_len, batch_size)). A
        # (batch, seq) tensor would be silently scrambled. Confirm the caller's
        # layout.
        X = self.embedding(X).view(self.seq_len, self.batch_size, -1)

        # Causal mask: 0.0 on/below the diagonal and -inf above it, so position
        # i attends only to positions <= i. It is built on the activations'
        # device instead of being forced to CUDA, so CPU runs on GPU machines
        # work.
        self.mask = torch.triu(
            torch.full((self.seq_len, self.seq_len), float("-inf"), device=X.device),
            diagonal=1,
        )

        if self.pe_dim > 0:
            P = self.__pos_encode(p).view(self.seq_len, self.batch_size, -1)
            X = torch.cat((X, P), 2)

        X = self.dropout_i(X)

        # Run through transformer encoder
        M = self.transformerI(X, mask=self.mask)
        M = self.dropout_m(M)

        if self.encoder_only:
            return self.hidden_to_tag(M)

        # ---------------------
        # Decoder stack (reuses the encoder's position features P)
        X = self.embedding2(X2).view(self.seq_len, self.batch_size, -1)

        if self.pe_dim > 0:
            X = torch.cat((X, P), 2)

        X = self.dropout_i(X)

        X = self.transformerO(X, M, tgt_mask=self.mask, memory_mask=None)
        X = self.dropout_o(X)

        # run through linear layer
        return self.hidden_to_tag(X)
