In [1]:
# Turn off jedi for IPython's tab completer (sets Completer.use_jedi to False).
%config Completer.use_jedi = False

In [2]:
# imports
import math
import warnings
from pathlib import Path
warnings.filterwarnings('ignore')

import spacy
import pandas as pd

# core torch
import torch
from torch import nn
from torch.utils.data import random_split, DataLoader

# torch ecosystem
import torchtext
from torchtext.datasets import WMT14
from torchtext.data.utils import get_tokenizer
from torchtext import data
import pytorch_lightning as pl

# The notebook is pinned to exact library versions; fail fast on any mismatch.
pinned_versions = (
    (torch, '1.7.1'),
    (spacy, '3.0.1'),
    (torchtext, '0.8.0'),
    (pl, '1.1.7'),
)
for library, pinned in pinned_versions:
    assert library.__version__ == pinned

# Report the versions actually in use, one per line.
print(
    f'torch version: {torch.__version__}',
    f'spacy version: {spacy.__version__}',
    f'torchtext version: {torchtext.__version__}',
    f'torch lightning version: {pl.__version__}',
    sep='\n',
)

torch version: 1.7.1
spacy version: 3.0.1
torchtext version: 0.8.0
torch lightning version: 1.1.7


In [3]:
# source code
class TranslationDataModule(pl.LightningDataModule):
    """
    Lightning data module for the WMT14 English -> German translation corpus.

    The module downloads the corpus through torchtext, optionally writes a
    reduced (sub-sampled, length-filtered, lower-cased) copy of the training
    and validation files, builds the source/target vocabularies and exposes
    bucketed iterators for the train/validation/test splits.

    Args:
        data_dir: root directory; corpus files are read from / written to
            ``data_dir / 'wmt14'``.
        batch_size: number of sentence pairs per batch.
        reduced: when True, ``prepare_data`` writes the reduced files and
            ``setup`` trains/validates on them instead of the full corpus.
        english_tokenizer: callable mapping a string to a list of tokens,
            used for the source (English) side.
        german_tokenizer: callable mapping a string to a list of tokens,
            used for the target (German) side.
    """
    def __init__(self, data_dir: Path, batch_size: int, reduced: bool, english_tokenizer, german_tokenizer):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.reduced = reduced

        # tokenizers
        self.english_tokenizer = english_tokenizer
        self.german_tokenizer = german_tokenizer

    @staticmethod
    def _read_lines(path):
        """Return the lines of a UTF-8 text file without their trailing newline."""
        with open(path, encoding='utf-8') as handle:
            return [line.rstrip('\n') for line in handle]

    @staticmethod
    def _write_lines(path, lines):
        """Write the given strings to ``path`` as UTF-8, one per line."""
        with open(path, 'w', encoding='utf-8') as handle:
            handle.writelines(f'{line}\n' for line in lines)

    def _reduce_data(self,
                     source_suffix='train.tok.clean.bpe.32000',
                     target_suffix='train_reduced',
                     max_sent_len=30,
                     max_data_size=None
                     ):
        '''
        Write a reduced copy of one parallel corpus split.

        Reads ``<source_suffix>.en`` / ``<source_suffix>.de`` from the wmt14
        directory, optionally sub-samples the sentence pairs, lower-cases them,
        drops pairs where either side is empty or longer than ``max_sent_len``
        tokens, and writes the result to ``<target_suffix>.en`` / ``.de``.

        Args:
            source_suffix: file stem (without language extension) to read.
            target_suffix: file stem (without language extension) to write.
            max_sent_len: maximum number of tokens allowed on either side.
            max_data_size: number of pairs to sample before filtering; a falsy
                value keeps every pair. Values larger than the corpus are
                clamped to the corpus size.

        Raises:
            ValueError: if the two language files have different line counts.
        '''
        corpus_dir = self.data_dir / 'wmt14'
        en_source = str(corpus_dir / source_suffix) + '.en'
        de_source = str(corpus_dir / source_suffix) + '.de'

        # Plain line-by-line reading keeps the two files aligned; a CSV parser
        # with sep='\n' is quote-sensitive and skips blank lines independently
        # per file, which can silently shift the sentence pairs.
        english_lines = self._read_lines(en_source)
        german_lines = self._read_lines(de_source)
        if len(english_lines) != len(german_lines):
            raise ValueError(
                f'parallel files differ in length: {en_source} has {len(english_lines)} lines, '
                f'{de_source} has {len(german_lines)} lines'
            )
        df = pd.DataFrame({'english': english_lines, 'german': german_lines})

        # restrict dataframe size (sampling more rows than exist would raise)
        if max_data_size:
            df = df.sample(n=min(max_data_size, len(df)))

        # preprocessing : move out of this func
        df = df.assign(english=df['english'].str.lower(), german=df['german'].str.lower())

        # remove empty and very long sentences
        english_sent_len = df['english'].map(lambda sentence: len(self.english_tokenizer(sentence)))
        german_sent_len = df['german'].map(lambda sentence: len(self.german_tokenizer(sentence)))
        non_empty = (df['english'].str.strip() != '') & (df['german'].str.strip() != '')
        short_enough = (english_sent_len <= max_sent_len) & (german_sent_len <= max_sent_len)
        df = df[non_empty & short_enough]

        en_target = str(corpus_dir / target_suffix) + '.en'
        de_target = str(corpus_dir / target_suffix) + '.de'

        self._write_lines(en_target, df['english'])
        self._write_lines(de_target, df['german'])

        print(f'number of sentence pairs in reduced dataset = {len(df)}')

    def prepare_data(self, max_sent_len=30, max_train_data_size=1_000_000):
        """
        Download WMT14 and, when ``self.reduced`` is set, write the reduced files.

        This hook writes to disk, so in distributed settings it is meant to be
        run from a single process.

        Args:
            max_sent_len: maximum tokens per sentence kept in the reduced files.
            max_train_data_size: number of training pairs sampled before the
                length filter is applied. The validation split is not sampled.
        """
        # download
        WMT14.download(self.data_dir)

        # limit
        if self.reduced:
            # reduce training
            self._reduce_data(
                max_sent_len=max_sent_len,
                max_data_size=max_train_data_size)

            # reduce validation
            self._reduce_data(
                source_suffix='newstest2013.tok.bpe.32000',
                target_suffix='valid_reduced',
                max_sent_len=max_sent_len,
                max_data_size=None
                )

    def setup(self, src_vocab_max_size=50_000, trgt_vocab_max_size=50_000, stage=None):
        '''
        Create the torchtext fields, load the splits and build both vocabularies.

        Args:
            src_vocab_max_size: maximum size of the source (English) vocabulary.
            trgt_vocab_max_size: maximum size of the target (German) vocabulary.
            stage: Lightning stage name; not used by this method.

        Sets ``self.src_field``, ``self.trgt_field``, ``self.train``,
        ``self.valid`` and ``self.test``.
        '''
        # Lightning's hook signature is setup(stage) and the stage is a string
        # such as 'fit' or 'test'; if it arrives positionally it lands in the
        # first parameter, so move it to `stage` and fall back to the default.
        if isinstance(src_vocab_max_size, str):
            stage, src_vocab_max_size = src_vocab_max_size, 50_000

        eos_token = '<eos>'
        self.src_field = torchtext.data.Field(
            tokenize=self.english_tokenizer,
            eos_token=eos_token,
            batch_first=True,
            lower=True,
            )
        self.trgt_field = torchtext.data.Field(
            tokenize=self.german_tokenizer,
            eos_token=eos_token,
            batch_first=True,
            lower=True
            )

        root = str(self.data_dir)
        train_data = 'train_reduced' if self.reduced else 'train.tok.clean.bpe.32000'
        valid_data = 'valid_reduced' if self.reduced else 'newstest2013.tok.bpe.32000'

        self.train, self.valid, self.test = WMT14.splits(
            exts=('.en', '.de'),
            fields=(self.src_field, self.trgt_field),
            root=root,
            train=train_data,
            validation=valid_data
            )
        self.src_field.build_vocab(self.train, max_size=src_vocab_max_size)
        self.trgt_field.build_vocab(self.train, max_size=trgt_vocab_max_size)

    def _bucket_iterator(self, dataset, train: bool):
        """Build a BucketIterator over ``dataset`` that groups pairs of similar length."""
        return torchtext.data.BucketIterator(
            dataset=dataset,
            batch_size=self.batch_size,
            # train=False disables shuffling for evaluation splits
            train=train,
            # torchtext translation examples expose their fields as `src` and `trg`
            sort_key=lambda example: torchtext.data.interleave_keys(len(example.src), len(example.trg))
        )

    def train_dataloader(self):
        """Return the bucketed iterator over the training split."""
        return self._bucket_iterator(self.train, train=True)

    def val_dataloader(self):
        """Return the bucketed iterator over the validation split."""
        return self._bucket_iterator(self.valid, train=False)

    def test_dataloader(self):
        """Return the bucketed iterator over the test split."""
        return self._bucket_iterator(self.test, train=False)

In [30]:
# helper functions
def generate_square_subsequent_mask(size: int):
    """
    Build a (size, size) float attention mask for causal decoding.

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf
    when j > i (future positions are blocked).
    """
    row_index = torch.arange(size).unsqueeze(1)  # (size, 1)
    col_index = torch.arange(size).unsqueeze(0)  # (1, size)
    is_future = col_index > row_index  # True strictly above the diagonal
    return torch.zeros(size, size).masked_fill(is_future, float('-inf'))


# positional encoder
class PositionalEncoding(nn.Module):
    """
    Classic Attention-is-all-you-need sinusoidal positional encoding.
    From PyTorch docs, extended to accept an odd ``d_model``.

    Args:
        d_model: embedding dimension of the input.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length the module can encode.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the surplus angle before filling the cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model), broadcasts over the batch axis
        # A buffer follows the module across devices but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Add the positional encoding to ``x`` and apply dropout.

        Input
            x: (S, B, E) with S <= max_len and E == d_model
        Output
            (S, B, E)

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f'sequence length {seq_len} exceeds the maximum supported length {self.pe.size(0)}'
            )
        x = x + self.pe[:seq_len, :]
        return self.dropout(x)

# transformer
class Transformer(nn.Module):
    """
    Classic Transformer that both encodes and decodes.

    A single embedding table is shared by the encoder input and the decoder
    input, so ``num_classes`` must cover every index that appears in either.

    Prediction-time inference is done greedily.

    NOTE: ``predict`` seeds decoding with start token 0 and pre-fills the
    output with token 1 by default; pass ``start_token`` / ``fill_token`` to
    ``predict`` if the vocabulary uses different indices.

    Args:
        num_classes: size of the embedding table and of the output layer.
        max_output_length: number of tokens produced by ``predict``; also the
            size of the pre-computed causal mask.
        dim: embedding / model dimension (also used as feed-forward width).
    """

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 128):
        super().__init__()

        # Parameters
        self.dim = dim
        self.max_output_length = max_output_length
        nhead = 4
        num_layers = 4
        dim_feedforward = dim

        # Encoder part
        self.embedding = nn.Embedding(num_classes, dim)

        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        # Decoder part
        # Pre-computed causal mask; sliced per call in decode().
        self.y_mask = generate_square_subsequent_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim, nhead=nhead, dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.fc = nn.Linear(self.dim, num_classes)

        # It is empirically important to initialize weights properly
        self.init_weights()

    def init_weights(self):
        """Initialise the embedding and output weights uniformly in [-0.1, 0.1] and zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits
        """
        encoded_x = self.encode(x)  # (Sx, B, E)
        output = self.decode(y, encoded_x)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
        Output
            (Sx, B, E) embedding
        """
        x = x.permute(1, 0)  # (Sx, B)
        x = self.embedding(x) * math.sqrt(self.dim)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        x = self.transformer_encoder(x)  # (Sx, B, E)
        return x

    def decode(self, y: torch.Tensor, encoded_x: torch.Tensor) -> torch.Tensor:
        """
        Input
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (Sy, B, C) logits
        """
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        if Sy <= self.y_mask.shape[0]:
            y_mask = self.y_mask[:Sy, :Sy]
        else:
            # Slicing the cached mask would silently yield a too-small mask for
            # targets longer than max_output_length; build one of the right size.
            y_mask = generate_square_subsequent_mask(Sy)
        y_mask = y_mask.type_as(encoded_x)  # (Sy, Sy), same dtype/device as the memory
        output = self.transformer_decoder(y, encoded_x, y_mask)  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    def predict(self, x: torch.Tensor, start_token: int = 0, fill_token: int = 1) -> torch.Tensor:
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Decoding always runs for ``max_output_length`` positions; it does not
        stop early at an end token, so callers truncate the result themselves.
        Dropout is only disabled if the module is in eval mode.

        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            start_token: index placed at position 0 to seed decoding
            fill_token: index used to pre-fill the not-yet-decoded positions
        Output
            (B, max_output_length) predicted token indices, position 0 being start_token
        """
        # Inference only: argmax is not differentiable, so skip graph construction.
        with torch.no_grad():
            encoded_x = self.encode(x)

            output_tokens = torch.full(
                (x.shape[0], self.max_output_length), fill_token, dtype=torch.long, device=x.device
            )  # (B, max_length)
            output_tokens[:, 0] = start_token  # Set start token
            for Sy in range(1, self.max_output_length):
                y = output_tokens[:, :Sy]  # (B, Sy)
                output = self.decode(y, encoded_x)  # (Sy, B, C)
                output = torch.argmax(output, dim=-1)  # (Sy, B)
                output_tokens[:, Sy] = output[-1]  # Set the last output token
        return output_tokens

# lightning Model
class LitModel(pl.LightningModule):
    """
    Simple PyTorch-Lightning model to train our Transformer.

    Batches are expected to expose ``src`` and ``trg`` tensors of token
    indices, shaped (B, Sx) and (B, Sy).

    Args:
        model: the sequence-to-sequence module; called as ``model(x, y)`` and
            expected to return (B, C, Sy) logits.
        padding_index: target index ignored by the cross-entropy loss.
    """

    def __init__(self, model, padding_index):
        super().__init__()
        self.save_hyperparameters()
        self.model = model
        self.loss = nn.CrossEntropyLoss(ignore_index=padding_index)

    def _teacher_forced_loss(self, batch):
        """
        Compute the next-token cross-entropy loss for one batch.

        The decoder is fed the target without its last token and scored
        against the target without its first token, so the prediction at
        position t only sees tokens before t + 1. Feeding and scoring the
        same unshifted sequence would let the model copy its own input.

        NOTE: the first target token is therefore only ever used as decoder
        input; a dedicated start token on the target side makes every real
        token a prediction target.
        """
        x = batch.src.to(self.device)
        y = batch.trg.to(self.device)
        logits = self.model(x, y[:, :-1])  # (B, C, Sy - 1)
        return self.loss(logits, y[:, 1:])

    def training_step(self, batch, batch_ind):
        """Return the training loss for ``batch`` and log it as ``train_loss``."""
        loss = self._teacher_forced_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_ind):
        """Log the validation loss for ``batch`` as ``val_loss`` (shown in the progress bar)."""
        loss = self._teacher_forced_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        """Use Adam with its default hyper-parameters over all parameters."""
        return torch.optim.Adam(self.parameters())

In [34]:
# directory on the local file system where the data is downloaded to and read from
data_dir = Path('.')

def english_tokenizer(text):
    """
    Split English text into tokens on single spaces.

    Consecutive spaces therefore produce empty-string tokens.
    """
    return text.split(' ')
def german_tokenizer(text):
    """
    Split German text into tokens on single spaces.

    Consecutive spaces therefore produce empty-string tokens.
    """
    return text.split(' ')

# build the data module on the reduced corpus, then derive the vocabularies
batch_size = 64
dm = TranslationDataModule(
    data_dir,
    batch_size,
    True,  # reduced: work on the sub-sampled, length-filtered files
    english_tokenizer,
    german_tokenizer,
)
dm.prepare_data(max_train_data_size=100_000)
dm.setup()

source_vocab, target_vocab = dm.src_field.vocab, dm.trgt_field.vocab
target_vocab_size = len(target_vocab)

number of sentence pairs in reduced dataset = 54929
number of sentence pairs in reduced dataset = 1941


In [None]:
load_from_checkpoint = False
# checkpoint file to resume from; only read when load_from_checkpoint is True
pretrained_checkpoint_path = None

# create lightning model
if load_from_checkpoint:
    if not pretrained_checkpoint_path:
        raise ValueError('set pretrained_checkpoint_path when load_from_checkpoint is True')
    print(f'loading pretrained model from {pretrained_checkpoint_path}')
    # keep the loaded module: it is the one handed to trainer.fit below
    lit_model = LitModel.load_from_checkpoint(pretrained_checkpoint_path)
    model = lit_model.model
else:
    # create transformer model
    # The Transformer embeds source and target tokens with the same table,
    # so the table must cover the larger of the two vocabularies.
    num_classes = max(len(source_vocab), target_vocab_size)
    model = Transformer(num_classes=num_classes, max_output_length=32)

    # the loss is computed on target tokens, so ignore the *target* pad index
    pad_index = target_vocab.stoi[dm.trgt_field.pad_token]
    lit_model = LitModel(model, padding_index=pad_index)

# create trainer
early_stop_callback = pl.callbacks.EarlyStopping(monitor='val_loss')
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_loss')

trainer = pl.Trainer(
    max_epochs=5,
    gpus=2,
    num_nodes=1,
    callbacks=[early_stop_callback, checkpoint_callback],
    progress_bar_refresh_rate=79,
    distributed_backend='dp'
    )

# train
trainer.fit(lit_model, datamodule=dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type             | Params
-------------------------------------------
0 | model | Transformer      | 6.5 M 
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
6.5 M     Trainable params
0         Non-trainable params
6.5 M     Total params


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
english_tokenizer = dm.english_tokenizer
german_tokenizer = dm.german_tokenizer

source_vocab = dm.src_field.vocab
target_vocab = dm.trgt_field.vocab

# alternative example: 'a republican strategy to counter the re-election of obama.'
sentence = 'the country agreed on a common strategy'
# the source field lower-cases its input, so do the same before the vocabulary lookup
tokens = english_tokenizer(sentence.lower())
indices = [source_vocab.stoi[token] for token in tokens] + [source_vocab.stoi[dm.src_field.eos_token]]
x = torch.tensor(indices).unsqueeze(0)  # (1, Sx)

# get the model
checkpoint_path = checkpoint_callback.best_model_path
lit_model = LitModel.load_from_checkpoint(checkpoint_path)
lit_model.eval()  # disable dropout so the prediction is deterministic

with torch.no_grad():
    prediction = lit_model.model.predict(x)  # (1, max_output_length)

# position 0 holds the start token that seeded decoding, not a predicted word
target_tokens = []
for index in prediction[0, 1:].tolist():
    token = target_vocab.itos[index]
    if token == dm.trgt_field.eos_token:
        break
    target_tokens.append(token)
target_sentence = ' '.join(target_tokens)
target_sentence