In [1]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-1.5.0-py3-none-any.whl (1.0 MB)
[K     |████████████████████████████████| 1.0 MB 1.9 MB/s 
[?25hCollecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.6.0-py3-none-any.whl (329 kB)
[K     |████████████████████████████████| 329 kB 52.0 MB/s 
[?25hCollecting PyYAML>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 57.7 MB/s 
[?25hCollecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.10.1-py3-none-any.whl (125 kB)
[K     |████████████████████████████████| 125 kB 71.4 MB/s 
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 58.6 MB/s 
[?25hCollecting aiohttp
  Downloading aiohttp-3.8.0-cp37-cp37m-manylinux_2_5_x8

In [2]:
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
N = 10000
S = 32  # target sequence length. input sequence will be twice as long
C = 128  # number of "classes", including 0, the "start token", and 1, the "end token"

Y = (torch.rand((N * 10, S - 2)) * (C - 2)).long() + 2  # Only generate ints in (2, 99) range

# Make sure we only have unique rows
Y = torch.tensor(np.unique(Y, axis=0)[:N])
X = torch.repeat_interleave(Y, 2, dim=1)

# Add special 0 "start" and 1 "end" tokens to beginning and end
Y = torch.cat([torch.zeros((N, 1)), Y, torch.ones((N, 1))], dim=1).long()
X = torch.cat([torch.zeros((N, 1)), X, torch.ones((N, 1))], dim=1).long()

# Look at the data
print(X, X.shape)
print(Y, Y.shape)
print(Y.min(), Y.max())

tensor([[  0,   2,   2,  ..., 119, 119,   1],
        [  0,   2,   2,  ...,  68,  68,   1],
        [  0,   2,   2,  ...,  18,  18,   1],
        ...,
        [  0,  14,  14,  ...,  21,  21,   1],
        [  0,  14,  14,  ..., 123, 123,   1],
        [  0,  14,  14,  ...,  98,  98,   1]]) torch.Size([10000, 62])
tensor([[  0,   2,   2,  ...,  73, 119,   1],
        [  0,   2,   2,  ...,  98,  68,   1],
        [  0,   2,   2,  ...,  45,  18,   1],
        ...,
        [  0,  14,  87,  ...,   4,  21,   1],
        [  0,  14,  88,  ...,  43, 123,   1],
        [  0,  14,  88,  ...,  70,  98,   1]]) torch.Size([10000, 32])
tensor(0) tensor(127)


In [7]:
BATCH_SIZE = 128
TRAIN_FRAC = 0.8

dataset = list(zip(X, Y))  # This fulfills the pytorch.utils.data.Dataset interface

# Split into train and val
num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
data_train, data_val = torch.utils.data.random_split(dataset, (num_train, num_val))

dataloader_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE)
dataloader_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE)

# Sample batch
x, y = next(iter(dataloader_train))
print(f'x: {x.size()}')
print(f'y: {y.size()}')

x: torch.Size([128, 62])
y: torch.Size([128, 32])


In [33]:
class PositionalEncoding(nn.Module):
    """
    Classic Attention-is-all-you-need positional encoding.
    From PyTorch docs.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)



In [38]:
def generate_square_subsequent_mask(size: int):
    """Generate a triangular (size, size) mask. From PyTorch docs."""
    mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

In [39]:
# test
temp = generate_squre_subsequent_mask(10)
print(temp)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


In [40]:
class Transformer(nn.Module):
    """
    Classic Transformer that both encodes and decodes.
    
    Prediction-time inference is done greedily.

    NOTE: start token is hard-coded to be 0, end token to be 1. If changing, update predict() accordingly.
    """

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 128):
        super().__init__()

        # Parameters
        self.dim = dim
        self.max_output_length = max_output_length
        nhead = 4
        num_layers = 4
        dim_feedforward = dim

        # Encoder part
        self.embedding = nn.Embedding(num_classes, dim)
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        # Decoder part
        self.y_mask = generate_square_subsequent_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.fc = nn.Linear(self.dim, num_classes)

        # It is empirically important to initialize weights properly
        self.init_weights()
    
    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)
      
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits
        """
        encoded_x = self.encode(x)  # (Sx, B, E)
        output = self.decode(y, encoded_x)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
        Output
            (Sx, B, E) embedding
        """
        x = x.permute(1, 0)  # (Sx, B, E)
        x = self.embedding(x) * math.sqrt(self.dim)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        x = self.transformer_encoder(x)  # (Sx, B, E)
        return x

    def decode(self, y: torch.Tensor, encoded_x: torch.Tensor) -> torch.Tensor:
        """
        Input
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (Sy, B, C) logits
        """
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)  # (Sy, Sy)
        output = self.transformer_decoder(y, encoded_x, y_mask)  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits
        """
        encoded_x = self.encode(x)
        
        output_tokens = (torch.ones((x.shape[0], self.max_output_length))).type_as(x).long() # (B, max_length)
        output_tokens[:, 0] = 0  # Set start token
        for Sy in range(1, self.max_output_length):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(y, encoded_x)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1:]  # Set the last output token
        return output_tokens

In [41]:
model = Transformer(num_classes=C, max_output_length=y.shape[1])
logits = model(x, y[:, :-1])
print(x.shape, y.shape, logits.shape)
print(x[0:1])
print(model.predict(x[0:1]))

torch.Size([128, 62]) torch.Size([128, 32]) torch.Size([128, 128, 31])
tensor([[  0,  13,  13,  34,  34, 120, 120, 102, 102,  15,  15,  59,  59,  64,
          64,  79,  79,  64,  64,  92,  92,  77,  77,  54,  54,  92,  92, 127,
         127,  71,  71,   8,   8,  94,  94,  88,  88,   2,   2,  24,  24, 101,
         101,  45,  45,  95,  95, 117, 117, 124, 124, 124, 124,  15,  15,  47,
          47,  65,  65,  97,  97,   1]])
tensor([[  0,   6, 115,  27,  14, 115,  27,  27,   6,  27,   6, 115, 115,   2,
          27,   2,  27,   2,  27,   2,  27,  27,   2,   2,  27,   2,  25,   2,
          27,   2,   2,  19]])


In [43]:
class LitModel(pl.LightningModule):
    """Simple PyTorch-Lightning model to train our Transformer."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = nn.CrossEntropyLoss()
        self.val_acc = pl.metrics.Accuracy()

    def training_step(self, batch, batch_ind):
        x, y = batch
        # Teacher forcing: model gets input up to the last character,
        # while ground truth is from the second character onward.
        logits = self.model(x, y[:, :-1])
        loss = self.loss(logits, y[:, 1:])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_ind):
        x, y = batch
        logits = self.model(x, y[:, :-1])
        loss = self.loss(logits, y[:, 1:])
        self.log("val_loss", loss, prog_bar=True)
        pred = self.model.predict(x)
        self.val_acc(pred, y)
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


model = Transformer(num_classes=C, max_output_length=y.shape[1])
lit_model = LitModel(model)
early_stop_callback = pl.callbacks.EarlyStopping(monitor='val_loss')
trainer = pl.Trainer(max_epochs=5, gpus=[0], callbacks=[early_stop_callback], progress_bar_refresh_rate=79)
trainer.fit(lit_model, dataloader_train, dataloader_val)

AttributeError: ignored

In [44]:
x, y = next(iter(dataloader_val))
print('Input:', x[:1])
pred = lit_model.model.predict(x[:1])
print('Truth/Pred:')
print(torch.cat((y[:1], pred)))

Input: tensor([[  0,   3,   3,  56,  56,  34,  34,  74,  74, 121, 121,  81,  81,  30,
          30,  99,  99,  69,  69,  28,  28,  13,  13,   8,   8,  52,  52,  95,
          95, 104, 104, 101, 101, 110, 110,  66,  66,  25,  25,  45,  45,  59,
          59, 107, 107,  32,  32,  51,  51,  33,  33,  39,  39,  36,  36,  78,
          78,  86,  86,   6,   6,   1]])


NameError: ignored