In [136]:
import os
import functools
import itertools
import random
import math
from pathlib import Path
import pdb

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import spacy
from cached_property import cached_property

from rouge import Rouge

In [82]:
# Load the large English spaCy pipeline; the transforms below read `token.vector`
# from it. Components this notebook never uses are switched off.
UNUSED_PIPES = ["tagger", "ner", "textcat"]
nlp = spacy.load("en_core_web_lg", disable=UNUSED_PIPES)

In [83]:
# Preprocessed articles; rows are split into train/test/eval via the `set` column.
ARTICLES_PATH = "data/articles-processed.parquet.gzip"
articles = pd.read_parquet(ARTICLES_PATH)

In [135]:
class NNModel(nn.Module):
    """Base class for networks that can save and restore themselves.

    The positional and keyword arguments given to ``__init__`` are remembered, so
    ``load`` can re-instantiate the subclass with the same constructor arguments
    before restoring its weights. Subclasses must forward their constructor
    arguments to ``super().__init__`` for that to work.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        # Kept verbatim so `load` can call `cls(*args, **kwargs)`.
        self._args = args
        self._kwargs = kwargs

    def save(self, path):
        """Write the weights and the constructor arguments to ``path`` with ``torch.save``."""
        torch.save(
            {
                'state': self.state_dict(),
                'args': self._args,
                'kwargs': self._kwargs
            },
            path
        )

    @classmethod
    def load(cls, path, map_location=None):
        """Rebuild a model saved with ``save``.

        path         : file written by ``save``
        map_location : forwarded to ``torch.load`` (e.g. "cpu" to load GPU weights on CPU)

        Raises FileNotFoundError when ``path`` does not exist.
        """
        if not Path(path).exists():
            raise FileNotFoundError(f"No saved model found at {path}")

        data = torch.load(path, map_location=map_location)

        model = cls(*data['args'], **data['kwargs'])
        model.load_state_dict(data['state'])

        return model

In [84]:
class DiscriminatorNet(NNModel):
    """Discriminator network — currently a placeholder.

    TODO: implement the (batch_num, hidden_size) -> (batch_num, 1) mapping described
    in ``forward``; right now the input is returned unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, hidden_state):
        """
        The forward pass for the network

        hidden_state : tensor (batch_num, hidden_size)

        returns      : intended to be tensor (batch_num, 1); as a placeholder the
                       input tensor is passed through untouched
        """
        passthrough = hidden_state
        return passthrough

In [85]:
class SummarizeNet(NNModel):
    """Summarizer network (architecture not implemented yet).

    hidden_size : size of the encoder's hidden state. It is forwarded to
                  ``NNModel.__init__`` so ``NNModel.save`` records it and
                  ``NNModel.load`` can rebuild the model with ``cls(hidden_size)``.
    """

    def __init__(self, hidden_size):
        # Forward the constructor argument: NNModel stores *args and NNModel.load
        # re-creates the model with cls(*args) before loading the weights.
        super().__init__(hidden_size)

        self.hidden_size = hidden_size

    def forward(self, word_embeddings, generate="sentence"):
        """
        The forward pass for the network

        word_embeddings : tensor (batch_num, max_seq_len, embedding_dim)
                          NOTE(review): the pipeline (MergeBatch) produces padded word
                          embeddings, not vocabulary-sized vectors — confirm.

        returns         : tuple of two tensors. Trainer.work_batch unpacks them as
                          (reconstructed word embeddings, discriminator probabilities);
                          NOTE(review): an earlier design returned
                          (vocabulary probabilities, encoder hidden state) — reconcile.

        Raises NotImplementedError until the architecture is written.
        """
        raise NotImplementedError("SummarizeNet.forward is not implemented yet")

In [86]:
class ArticlesDataset(Dataset):
    """Dataset exposing every article twice: once by body text, once by headline.

    Sample ``i`` refers to article ``i // 2``. Even ``i`` yields the article's ``text``
    (``mode`` 0); odd ``i`` yields its ``headline`` (``mode`` 1). Indexing returns a
    DataFrame with one row per requested sample (columns ``set``, ``mode``, ``text``,
    ``title``), then passes it through ``transforms`` in order.

    dataframe  : DataFrame with ``set``, ``text``, ``headline`` and ``normalized_title``
    mode       : value of the ``set`` column selecting the split (e.g. "train")
    transforms : callables applied in sequence to the batch DataFrame
    """

    def __init__(self, dataframe, mode, transforms=None):
        self.data = dataframe[dataframe['set'] == mode]
        # A fresh list per instance: no shared mutable default.
        self.transforms = list(transforms) if transforms is not None else []
        self.mode = mode

    def __len__(self):
        # Two samples (body text and headline) per article.
        return 2 * len(self.data)

    @staticmethod
    def _as_index_list(idx):
        """Normalize an int, a list/tuple of ints or a torch tensor into a list of ints."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if isinstance(idx, (list, tuple)):
            return [int(i) for i in idx]

        return [int(idx)]

    def __getitem__(self, idx):
        sample_ids = self._as_index_list(idx)

        # Integer division: iloc rejects the float positions that `/` would produce.
        articles = self.data.iloc[[i // 2 for i in sample_ids]]
        use_headline = [i % 2 == 1 for i in sample_ids]

        data = pd.DataFrame(
            {
                'set': [self.mode] * len(sample_ids),
                'mode': [1 if headline else 0 for headline in use_headline],
                'text': [
                    headline if use else text
                    for text, headline, use in zip(
                        articles['text'], articles['headline'], use_headline
                    )
                ],
                'title': articles['normalized_title'].tolist(),
            }
        )

        for transform in self.transforms:
            data = transform(data)

        return data

In [87]:
class TextToParsedDoc(object):
    """Transform that parses each row's ``text`` with the given spaCy pipeline.

    Adds a ``doc`` column holding the parsed document of every row and returns the
    same (mutated) frame.
    """

    def __init__(self, nlp):
        self.nlp = nlp

    def __call__(self, sample):
        texts = sample['text']
        sample['doc'] = texts.map(lambda text: self.nlp(text))
        return sample

In [88]:
class WordsToVectors(object):
    """Transform turning each parsed ``doc`` into a 2-D array of token vectors.

    One row per token; a zero row is appended after every token that is followed by
    whitespace (presumably so spaces can be reconstructed when decoding).
    """

    def __init__(self, nlp):
        self.nlp = nlp

    def document_embeddings(self, doc):
        rows = []
        for token in doc:
            rows.append(token.vector)
            if token.whitespace_ != '':
                # stand-in row for the whitespace that follows the token
                rows.append(np.zeros_like(token.vector))

        return np.stack(rows)

    def __call__(self, sample):
        sample['word_embeddings'] = sample['doc'].map(self.document_embeddings)

        return sample

In [89]:
class AddNoiseToEmbeddings(object):
    """Transform that randomly zeroes whole token rows of each embedding matrix.

    Adds a ``noisy_word_embeddings`` column next to ``word_embeddings``. Each row
    (token) is independently replaced by zeros with probability
    ``probability_of_mask_for_word``.

    probability_of_mask_for_word : float in [0, 1]
    seed                         : optional seed for reproducible masks

    Raises ValueError when the probability is outside [0, 1].
    """

    def __init__(self, probability_of_mask_for_word, seed=None):
        if not 0 <= probability_of_mask_for_word <= 1:
            raise ValueError(
                "probability_of_mask_for_word must be in [0, 1], "
                f"got {probability_of_mask_for_word}"
            )

        self.probability_of_mask_for_word = probability_of_mask_for_word
        self.rng = np.random.default_rng(seed)

    def mask_vector(self, vector):
        """
        Masks words with zeros randomly.

        vector : array whose first axis indexes tokens, e.g. (seq_len, embedding_dim).
        Returns an array of the same shape and dtype in which every row is either
        kept or zeroed.
        """
        # One uniform draw per row in [0, 1). A row survives when its draw is >= p,
        # so p=0 keeps every row and p=1 masks every row.
        keep = self.rng.random(vector.shape[0]) >= self.probability_of_mask_for_word

        # Broadcast the per-row mask across the remaining axis. Casting to the input
        # dtype avoids the float32 -> float64 upcast that an integer mask causes.
        return vector * keep[:, np.newaxis].astype(vector.dtype)

    def __call__(self, sample):
        sample['noisy_word_embeddings'] = sample['word_embeddings'].apply(self.mask_vector)

        return sample

In [90]:
class MergeBatch(object):
    """Final transform: turns the per-row frame into a dict of column lists.

    The ``doc`` column is dropped from the incoming frame (in place). The two
    embedding columns are zero-padded along the sequence axis to the longest entry
    and stacked into single 3-D arrays.
    """

    def stack_vectors(self, vectors):
        longest = max(vector.shape[0] for vector in vectors)

        padded = []
        for vector in vectors:
            missing_rows = longest - vector.shape[0]
            padded.append(np.pad(vector, [(0, missing_rows), (0, 0)]))

        return np.stack(padded)

    def __call__(self, sample):
        del sample['doc']

        merged = sample.to_dict(orient="list")

        for column in ('word_embeddings', 'noisy_word_embeddings'):
            merged[column] = self.stack_vectors(merged[column])

        return merged

In [125]:
class DataLoader(object):
    """Minimal shuffling batch loader over a dataset indexed by lists of sample ids.

    Each pass visits every sample exactly once, in random order. When the dataset size
    is not a multiple of ``batch_size``, the last batch is topped up with randomly
    drawn samples so every batch has ``batch_size`` ids.

    ``num_workers`` is stored for API compatibility but is not used; loading is
    sequential.

    Raises ValueError when ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size=8, num_workers=1):
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers

    @property
    def epoch_size(self):
        """Samples yielded per pass: len(dataset) rounded up to a multiple of batch_size."""
        return math.ceil(len(self.dataset) / self.batch_size) * self.batch_size

    def __iter__(self):
        dataset_size = len(self.dataset)

        # A permutation (sampling without replacement) guarantees full coverage;
        # random.choices alone would repeat some samples and skip others.
        ids = random.sample(range(dataset_size), dataset_size)
        ids += random.choices(range(dataset_size), k=self.epoch_size - dataset_size)

        for start_ix in range(0, self.epoch_size, self.batch_size):
            yield self.dataset[ids[start_ix:(start_ix + self.batch_size)]]

In [52]:
class ArticlesBatch:
    """A merged batch (mapping of column -> values) plus its global batch id.

    Columns are exposed as attributes, e.g. ``batch.word_embeddings``. Unknown
    names raise AttributeError.
    """

    def __init__(self, data, id):
        self.data = data
        self.id = id

    def __getattr__(self, name):
        # Only called for names not found normally. `data` is read via __dict__
        # because instances created without __init__ (copy, deepcopy, unpickling)
        # would otherwise re-enter __getattr__('data') and recurse forever.
        data = self.__dict__.get("data")

        if data is not None and name in data:
            return data[name]

        raise AttributeError(f"Attribute missing: {name}")

In [227]:
class Decoder(object):
    """Maps word-embedding matrices back to text via nearest neighbours in the spaCy vocab."""

    def __init__(self, nlp):
        self.nlp = nlp

    def decode_embeddings(self, word_embeddings):
        """
        Decodes a single document. Word embeddings given are of shape (N, D)
        where N is the number of lexemes and D the dimensionality of the embedding vector.

        Each row is replaced by the lower-cased text of its most similar vocabulary
        entry; entries flagged out-of-vocabulary become a single space. The joined
        string is stripped of leading/trailing whitespace.
        """
        vocab = self.nlp.vocab
        nearest_keys = vocab.vectors.most_similar(word_embeddings, n=1)[0]

        pieces = []
        for key_row in nearest_keys:
            lexeme = vocab[key_row[0]]
            if lexeme.is_oov:
                pieces.append(" ")
            else:
                pieces.append(lexeme.text.lower())

        return "".join(pieces).strip()

In [None]:
class UpdateInfo(object):
    """Report of one optimizer update (or evaluation/test step) produced by the trainer.

    decoder         : object used to turn word embeddings back into text (a Decoder)
    batch           : the ArticlesBatch the update was computed on, or None
    word_embeddings : model output for that batch, or None
    loss_sum        : loss accumulated since the previous update, or None
    mode            : "train", "eval" or "test"

    Instances can be summed (``acc += info``): losses add up and the most recent
    batch / embeddings are kept.
    """

    def __init__(self, decoder, batch, word_embeddings, loss_sum, mode):
        self.decoder = decoder
        self.batch = batch
        self.word_embeddings = word_embeddings
        self.loss_sum = loss_sum
        self.mode = mode

    @classmethod
    def empty(cls, nlp=None, mode=None):
        """Return an accumulator with no batch and no loss yet.

        ``nlp`` is stored as the decoder (callers pass a Decoder). It is optional so
        ``UpdateInfo.empty(mode=...)`` also works.
        """
        return cls(nlp, None, None, None, mode)

    @property
    def from_train(self):
        """True when this update comes from the training split."""
        return self.mode == "train"

    @property
    def from_evaluate(self):
        """True when this update comes from the evaluation split."""
        return self.mode == "eval"

    def __str__(self):
        batch_id = self.batch.id if self.batch is not None else "-"
        # float() also unwraps scalar tensors into a readable number.
        loss = float(self.loss_sum) if self.loss_sum is not None else "-"
        return f"{self.mode} | {batch_id}\t| Loss: {loss}\t"

    def __add__(self, other):
        """Combine two reports into a new one: losses add up, the latest batch wins."""
        if not isinstance(other, UpdateInfo):
            return NotImplemented

        if self.loss_sum is None:
            loss_sum = other.loss_sum
        elif other.loss_sum is None:
            loss_sum = self.loss_sum
        else:
            loss_sum = self.loss_sum + other.loss_sum

        return UpdateInfo(
            self.decoder if self.decoder is not None else other.decoder,
            other.batch if other.batch is not None else self.batch,
            other.word_embeddings if other.word_embeddings is not None else self.word_embeddings,
            loss_sum,
            self.mode if self.mode is not None else other.mode,
        )

In [130]:
class BaseTrainer:
    """Shared training / evaluation driver for the summarization model.

    Builds one ArticlesDataset per split ("train", "test", "eval") and lazily creates
    the model and the optimizer. Gradients are accumulated over ``update_every``
    batches. Checkpoints are saved and restored under
    ``checkpoints/<name>/batch-#<batch id>``. Subclasses implement ``work_batch``.
    """

    # Directory-name prefix of every checkpoint; the suffix is the batch id.
    CHECKPOINT_PREFIX = "batch-#"

    def __init__(self, name, nlp, dataframe,
                 optimizer_class_name,
                 model_args, optimizer_args,
                 batch_size, update_every, loader_workers,
                 probability_of_mask_for_word
                ):
        self.name = name

        # Only the training split gets word masking. Test/eval use probability 0,
        # so their noisy embeddings equal the clean ones.
        self.datasets = {
            "train": self._make_dataset(nlp, dataframe, "train", probability_of_mask_for_word),
            "test": self._make_dataset(nlp, dataframe, "test", 0),
            "eval": self._make_dataset(nlp, dataframe, "eval", 0),
        }

        self.batch_size = batch_size
        self.update_every = update_every
        self.loader_workers = loader_workers

        self.optimizer_class_name = optimizer_class_name

        self.model_args = model_args
        self.optimizer_args = optimizer_args

        self.current_batch_id = 0

        self.decoder = Decoder(nlp)

        if self.has_checkpoint:
            self.load_last_checkpoint()

    @staticmethod
    def _make_dataset(nlp, dataframe, mode, probability_of_mask_for_word):
        """Build the dataset for one split with the standard transform pipeline."""
        return ArticlesDataset(
            dataframe,
            mode,
            transforms=[
                TextToParsedDoc(nlp),
                WordsToVectors(nlp),
                AddNoiseToEmbeddings(probability_of_mask_for_word),
                MergeBatch(),
            ],
        )

    @cached_property
    def model(self):
        """The SummarizeNet, loaded from the current checkpoint directory when present."""
        try:
            return SummarizeNet.load(f"{self.checkpoint_path}/model.pth")
        except FileNotFoundError:
            # NOTE(review): model_args is passed as one positional argument
            # (SummarizeNet's hidden_size); confirm that is the intended contract.
            return SummarizeNet(self.model_args)

    @cached_property
    def optimizer(self):
        """An instance of torch.optim.<optimizer_class_name> over the model's parameters."""
        class_ = getattr(torch.optim, self.optimizer_class_name)

        return class_(self.model.parameters(), **self.optimizer_args)

    @property
    def checkpoint_path(self):
        """Checkpoint directory for the current batch id."""
        return f"checkpoints/{self.name}/{self.CHECKPOINT_PREFIX}{self.current_batch_id}"

    def save_checkpoint(self):
        """Save the model and the trainer/optimizer state into ``checkpoint_path``."""
        os.makedirs(self.checkpoint_path, exist_ok=True)

        self.model.save(f"{self.checkpoint_path}/model.pth")

        torch.save(
            {
                'current_batch_id': self.current_batch_id,
                'batch_size': self.batch_size,
                'update_every': self.update_every,
                # Saved because load_last_checkpoint restores it.
                'loader_workers': self.loader_workers,
                'optimizer_class_name': self.optimizer_class_name,
                'optimizer_args': self.optimizer_args,
                'optimizer_state_dict': self.optimizer.state_dict()
            },
            f"{self.checkpoint_path}/trainer.pth"
        )

    @property
    def checkpoint_directories(self):
        """Existing checkpoint directories of this run, highest batch id first."""
        root = Path("checkpoints") / self.name
        if not root.is_dir():
            return []

        numbered = []
        for directory in root.glob(f"{self.CHECKPOINT_PREFIX}*"):
            suffix = directory.name[len(self.CHECKPOINT_PREFIX):]
            if directory.is_dir() and suffix.isdigit():
                numbered.append((int(suffix), str(directory)))

        # Numeric sort: a descending string sort would rank "batch-#9" above "batch-#10".
        return [path for _, path in sorted(numbered, reverse=True)]

    @property
    def has_checkpoint(self):
        """True when at least one checkpoint directory exists for this run."""
        return len(self.checkpoint_directories) > 0

    def load_last_checkpoint(self):
        """Restore the trainer state, model and optimizer from the newest checkpoint."""
        path = self.checkpoint_directories[0]

        data = torch.load(f"{path}/trainer.pth")

        self.batch_size = data['batch_size']
        self.update_every = data['update_every']
        # Checkpoints written before loader_workers was saved keep the configured value.
        self.loader_workers = data.get('loader_workers', self.loader_workers)

        self.optimizer_class_name = data['optimizer_class_name']
        self.optimizer_args = data['optimizer_args']

        self.current_batch_id = data['current_batch_id']

        # Drop any cached model/optimizer so they are rebuilt: `model` now loads from
        # the restored checkpoint_path, and `optimizer` wraps that model's parameters.
        self.__dict__.pop('model', None)
        self.__dict__.pop('optimizer', None)

        self.optimizer.load_state_dict(data['optimizer_state_dict'])

    def batches(self, mode):
        """Yield ArticlesBatch objects from the ``mode`` split forever, reshuffling each pass.

        Each yielded batch advances ``current_batch_id`` by one and carries the new id.
        NOTE(review): train, eval and test batches share this single counter.

        Raises ValueError when the split has no samples.
        """
        dataset = self.datasets[mode]
        if len(dataset) == 0:
            # Without this guard the loop below would spin forever without yielding.
            raise ValueError(f"No samples in the '{mode}' split")

        while True:
            loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.loader_workers
            )

            for data in loader:
                self.current_batch_id += 1

                yield ArticlesBatch(data, id=self.current_batch_id)

    def work_batch(self, batch):
        """Return ``(loss, word_embeddings)`` for one batch; implemented by subclasses."""
        raise NotImplementedError

    def updates(self, mode="train", update_every=None):
        """Yield an UpdateInfo every ``update_every`` batches of the ``mode`` split.

        In "train" mode, gradients are back-propagated and the optimizer steps on every
        yielded update. Other modes run without gradients and never touch the optimizer.
        """
        if update_every is None:
            update_every = self.update_every

        training = mode == "train"
        loss_sum = 0.0

        for batch in self.batches(mode):
            self.model.train(training)

            # Only training needs the autograd graph. Back-propagating eval/test
            # losses would add their gradients to the ones accumulated for the
            # next optimizer step.
            with torch.set_grad_enabled(training):
                loss, word_embeddings = self.work_batch(batch)
                loss = loss / (update_every * self.batch_size)

                if training:
                    loss.backward()

            # A plain float, so the running sum does not keep every batch's graph alive.
            loss_sum += loss.item()

            if torch.is_tensor(word_embeddings):
                word_embeddings = word_embeddings.detach()

            # Accumulated gradients: step only every `update_every` batches to lower
            # gradient variance while fitting on a commodity GPU.
            if batch.id % update_every == 0:
                if training:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                yield UpdateInfo(self.decoder, batch, word_embeddings, loss_sum, mode=mode)

                loss_sum = 0.0

    def train_and_evaluate_updates(self, evaluate_every=100):
        """Yield training updates, interleaving one eval update whenever the batch id
        is a multiple of ``evaluate_every``."""
        train_updates = self.updates(mode="train")
        evaluate_updates = self.updates(mode="eval")

        for update_info in train_updates:
            yield update_info

            if update_info.batch.id % evaluate_every == 0:
                yield next(evaluate_updates)

    def test_updates(self):
        """Yield one UpdateInfo per test batch (no gradient accumulation)."""
        return self.updates(mode="test", update_every=1)

In [None]:
class Trainer(BaseTrainer):
    """Trainer whose loss compares the reconstructed word embeddings with the clean ones.

    The constructor is inherited unchanged from BaseTrainer.
    """

    def compute_loss(self, word_embeddings, original_word_embeddings, discriminate_probs):
        """Cosine embedding loss (target +1: outputs should align with the originals)
        plus binary cross-entropy of ``discriminate_probs`` against all-zero targets.

        word_embeddings          : model output tensor, last axis = embedding dim
        original_word_embeddings : clean embeddings (tensor or numpy array) of the
                                   same shape
        discriminate_probs       : probabilities in [0, 1]
        """
        # cosine_embedding_loss expects (N, D) inputs, so any leading
        # (batch, sequence, ...) axes are flattened into N.
        embedding_dim = word_embeddings.shape[-1]
        predicted = word_embeddings.reshape(-1, embedding_dim)
        # The batch data comes from MergeBatch as numpy arrays; move it to the
        # model output's dtype/device.
        target = torch.as_tensor(
            original_word_embeddings, dtype=predicted.dtype, device=predicted.device
        ).reshape(-1, embedding_dim)

        embeddings_loss = F.cosine_embedding_loss(
            predicted,
            target,
            torch.ones(predicted.shape[0], dtype=predicted.dtype, device=predicted.device)
        )

        discriminator_loss = F.binary_cross_entropy(
            discriminate_probs,
            torch.zeros_like(discriminate_probs)
        )

        return embeddings_loss + discriminator_loss

    def work_batch(self, batch):
        """Run the model on the batch's noisy embeddings; return (loss, output embeddings)."""
        # MergeBatch yields numpy arrays; the model needs a float tensor.
        noisy_word_embeddings = torch.as_tensor(batch.noisy_word_embeddings, dtype=torch.float32)

        word_embeddings, discriminate_probs = self.model(noisy_word_embeddings)

        # We're diverging from the article here by outputting the word embeddings
        # instead of the probabilities for each word in a vocabulary. The loss
        # therefore uses the cosine embedding loss coupled with the discriminator loss.
        # NOTE(review): SummarizeNet's design returns an encoder hidden state as the
        # second output, not discriminator probabilities — confirm the pairing.
        return (
            self.compute_loss(word_embeddings, batch.word_embeddings, discriminate_probs),
            word_embeddings
        )

In [None]:
class InNotebookTrainer(Trainer):
    """Trainer driver that prints progress into the notebook output.

    The constructor is inherited unchanged from Trainer/BaseTrainer.
    """

    @staticmethod
    def _describe(update_info, running_loss):
        """Format one update as a log line: mode, batch id, loss and running loss."""
        return (
            f"{update_info.mode} | {update_info.batch.id}\t"
            f"| Loss: {float(update_info.loss_sum):.6f}\t"
            f"| Running loss: {running_loss:.6f}"
        )

    def train(self):
        """Train, printing every update and saving a checkpoint after each evaluation.

        The underlying batch generator cycles over the training split forever, so this
        only stops when interrupted (e.g. by interrupting the kernel).
        """
        running_loss = {}

        for update_info in self.train_and_evaluate_updates():
            mode = update_info.mode
            running_loss[mode] = running_loss.get(mode, 0.0) + float(update_info.loss_sum)

            print(self._describe(update_info, running_loss[mode]))

            if mode == "eval":
                self.save_checkpoint()

    def test(self):
        """Run one pass over the test split, print each batch and return the summed loss."""
        # test_updates never ends on its own; stop after one pass over the split.
        batches_per_pass = math.ceil(len(self.datasets["test"]) / self.batch_size)
        total_loss = 0.0

        for update_info in itertools.islice(self.test_updates(), batches_per_pass):
            total_loss += float(update_info.loss_sum)
            print(self._describe(update_info, total_loss))

        return total_loss

In [228]:
%%time

document = """
The recommended way to upgrade from Ubuntu 14.04 LTS is to first upgrade to 16.04 LTS, then to 18.04 LTS, which will continue to receive support until April 2023. Ubuntu has LTS -> LTS upgrades, allowing you to skip intermediate non-LTS releases, but we can’t skip intermediate LTS releases; we have to go via 16.04, unless we want to do a fresh install of 18.04 LTS.
"""

word_embeddings = [
    [ t.vector ] if t.whitespace_ == '' else [ t.vector, np.zeros_like(t.vector) ] for t in nlp(document)
]

word_embeddings = np.stack(
    [
        vector for vectors in word_embeddings for vector in vectors
    ]
)

print(word_embeddings.shape)

"".join(
    [
        token.text.lower() if not token.is_oov else " "
        for token in [
            nlp.vocab[ks[0]]
            for ks in nlp.vocab.vectors.most_similar(
                word_embeddings, n=1
            )[0]
        ]
    ]
).strip()

(145, 300)
CPU times: user 4.4 s, sys: 1.78 s, total: 6.18 s
Wall time: 1.37 s


'the recommended way to upgrade from ubuntu 14.04 lts is to first upgrade to 16.04 lts, then to 18.04 lts, which will continue to receive support until april 2023. ubuntu has lts -> lts upgrades, allowing you to skip intermediate non-lts releases, but we can’t skip intermediate lts releases; we have to go via 16.04, unless we want to do a fresh install of 18.04 lts.'

In [148]:
# Peek at the raw spaCy token vectors for a short text: one row per token.
sentence_doc = nlp("This is just a test sentence. This is another sentence")
np.stack([token.vector for token in sentence_doc])

array([[-0.087595 ,  0.35502  ,  0.063868 , ...,  0.03446  , -0.15027  ,
         0.40673  ],
       [-0.084961 ,  0.502    ,  0.0023823, ..., -0.21511  , -0.26304  ,
        -0.0060173],
       [-0.025563 ,  0.44424  , -0.24555  , ..., -0.029137 ,  0.062257 ,
         0.090782 ],
       ...,
       [-0.084961 ,  0.502    ,  0.0023823, ..., -0.21511  , -0.26304  ,
        -0.0060173],
       [-0.062456 ,  0.026028 , -0.2255   , ..., -0.19075  , -0.26296  ,
         0.32319  ],
       [-0.23011  ,  0.24952  , -0.40514  , ...,  0.049255 , -0.22886  ,
        -0.23064  ]], dtype=float32)

In [198]:
# Scratch check: an all-ones (1, 300) array has no zero entries, so this is True.
np.all(np.ones((1, 300)))

True