In [115]:
import os
import functools
import itertools
import random
import math
import pdb

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import spacy
from cached_property import cached_property

In [82]:
# Load the "en_core_web_lg" spaCy pipeline with the "tagger", "ner" and
# "textcat" components disabled.
nlp = spacy.load("en_core_web_lg", disable=["tagger", "ner", "textcat"])

In [83]:
# Read the processed articles table that every dataset below is sliced from.
articles = pd.read_parquet(
    path="data/articles-processed.parquet.gzip",
)

In [84]:
class DiscriminatorNet(nn.Module):
    """
    Discriminator stub that operates on a hidden state.

    No layers are defined yet: ``forward`` hands its argument back untouched.
    """

    def __init__(self):
        super().__init__()

    def forward(self, hidden_state):
        """
        Run the forward pass.

        hidden_state : tensor (batch_num, hidden_size)

        returns      : the very same ``hidden_state`` object, unchanged.

        NOTE(review): the intended output was documented as a tensor of shape
        (batch_num, 1); the current body is a pass-through and does not
        produce that shape.
        """
        output = hidden_state

        return output

In [85]:
class SummarizeNet(nn.Module):
    """
    Summarizer stub.

    No layers are defined yet. ``forward`` returns a placeholder that has the
    documented tuple structure so callers can already unpack the result.
    """

    def __init__(self, hidden_size):
        """
        hidden_size : int, width of the hidden state returned by ``forward``
        """
        super().__init__()

        self.hidden_size = hidden_size

    def forward(self, word_embeddings, generate="sentence"):
        """
        The forward pass for the network

        word_embeddings : tensor (batch_num, max_seq_len, vocab_len)
        generate        : currently unused by the body

        returns         : tuple (
                            tensor (batch_num, max_seq_len, vocab_len),
                            tensor (batch_num, hidden_size)
                          )

        The first element is meant to carry the per-position output; for now
        it is the input itself, passed through unchanged.

        The second element is meant to be the encoder's hidden state; for now
        it is an all-zero tensor with the input's dtype and device.
        """
        # The previous body returned an undefined name (`x`), which raised a
        # NameError on every call. Until real layers exist, return a
        # placeholder that honours the documented return structure.
        batch_num = word_embeddings.shape[0]
        hidden_state = word_embeddings.new_zeros((batch_num, self.hidden_size))

        return word_embeddings, hidden_state

In [86]:
class ArticlesDataset(Dataset):
    """
    Dataset view over the rows of ``dataframe`` whose ``set`` column equals ``mode``.

    Every source row is exposed twice: an even sample index ``2*k`` yields the
    ``text`` column of row ``k``, the following odd index ``2*k + 1`` yields
    its ``headline`` column.
    """

    def __init__(self, dataframe, mode, transforms=None):
        """
        dataframe  : pandas DataFrame with the columns ``set``, ``text``,
                     ``headline`` and ``normalized_title``
        mode       : value of the ``set`` column to keep
        transforms : optional list of callables, applied in order to the
                     DataFrame built by ``__getitem__``
        """
        self.data = dataframe[dataframe["set"] == mode]
        # `None` instead of a literal `[]` default: a mutable default would be
        # shared by every instance created without explicit transforms.
        self.transforms = [] if transforms is None else transforms
        self.mode = mode

    def __len__(self):
        # two samples (text and headline) per stored row
        return 2 * len(self.data)

    def __getitem__(self, idx):
        """
        idx : int, sequence of ints, or tensor of ints

        returns : the output of the last transform; with no transforms, a
                  DataFrame with the columns ``set``, ``mode`` (0 = text,
                  1 = headline), ``text`` and ``title``, one row per index
        """
        if torch.is_tensor(idx):
            # tolist() yields a list for 1-D tensors and a scalar for 0-d ones
            idx = idx.tolist()

        if isinstance(idx, (list, tuple, range)):
            asked = [int(i) for i in idx]
        else:
            asked = [int(idx)]

        # Integer floor division: `.iloc` only accepts integer positions.
        rows = self.data.iloc[[i // 2 for i in asked]]
        wants_headline = [i % 2 == 1 for i in asked]

        texts = [
            headline if is_headline else text
            for text, headline, is_headline in zip(
                rows['text'], rows['headline'], wants_headline
            )
        ]

        # Built from plain lists so nothing is written back onto a slice of
        # self.data; the index of the selected rows is kept.
        data = pd.DataFrame(
            {
                'set': [self.mode] * len(asked),
                'mode': [1 if is_headline else 0 for is_headline in wants_headline],
                'text': texts,
                'title': rows['normalized_title'].tolist()
            },
            index=rows.index
        )

        for transform in self.transforms:
            data = transform(data)

        return data

In [87]:
class TextToParsedDoc(object):
    """Transform that adds a ``doc`` column holding ``nlp(text)`` for every row."""

    def __init__(self, nlp):
        # callable applied to the content of each row's 'text' column
        self.nlp = nlp

    def __call__(self, sample):
        """
        sample : DataFrame with a ``text`` column

        returns : the same DataFrame object, with the ``doc`` column set in place
        """
        def parse_row(row):
            return self.nlp(row['text'])

        parsed = sample.apply(parse_row, axis=1)
        sample['doc'] = parsed

        return sample

In [88]:
class WordsToVectors(object):
    """Transform that stacks the ``vector`` of every token in ``doc`` into one array per row."""

    def __init__(self, nlp):
        # kept on the instance; __call__ itself does not read it
        self.nlp = nlp

    def __call__(self, sample):
        """
        sample : DataFrame with a ``doc`` column of iterables whose items
                 expose a ``vector`` attribute

        returns : the same DataFrame object, with a ``word_embeddings`` column
                  holding ``np.stack`` of the row's token vectors
        """
        def embed_row(row):
            token_vectors = [token.vector for token in row['doc']]
            return np.stack(token_vectors)

        sample['word_embeddings'] = sample.apply(embed_row, axis=1)

        return sample

In [89]:
class AddNoiseToEmbeddings(object):
    """Transform that zeroes out whole word vectors at random."""

    def __init__(self, probability_of_mask_for_word):
        # probability with which any single word (row) gets zeroed
        self.probability_of_mask_for_word = probability_of_mask_for_word
        self.rng = np.random.default_rng()

    def mask_vector(self, vector):
        """
        Zero complete rows of ``vector`` at random.

        vector : array (seq_len, vector_len)

        returns : ``vector`` multiplied by a 0/1 mask that is constant along
                  each row; a row is zeroed with probability
                  ``probability_of_mask_for_word``
        """
        seq_len, vector_len = vector.shape[0], vector.shape[1]
        drop_probability = self.probability_of_mask_for_word

        # one 0/1 draw per word: 0 drops the word, 1 keeps it
        keep_flags = self.rng.choice(
            [0, 1],
            seq_len,
            p=[drop_probability, (1 - drop_probability)]
        )

        # widen the per-word flags to the full (seq_len, vector_len) shape
        mask = np.tile(keep_flags.reshape((seq_len, 1)), (1, vector_len))

        return vector * mask

    def __call__(self, sample):
        """
        sample : DataFrame with a ``word_embeddings`` column

        returns : the same DataFrame object, with a ``noisy_word_embeddings``
                  column holding the masked copy of every embedding array
        """
        noisy = sample['word_embeddings'].apply(self.mask_vector)
        sample['noisy_word_embeddings'] = noisy

        return sample

In [90]:
class MergeBatch(object):
    """Transform that turns a per-row DataFrame into a dict of batch-level values."""

    def stack_vectors(self, vectors):
        """
        Zero-pad every array along its first axis to the longest one and stack them.

        vectors : sequence of arrays (seq_len_i, vector_len)

        returns : array (len(vectors), max_seq_len, vector_len)
        """
        longest = max(vector.shape[0] for vector in vectors)

        padded = []
        for vector in vectors:
            missing_rows = longest - vector.shape[0]
            # pad only at the end of the sequence axis, never along the vector axis
            padded.append(np.pad(vector, [(0, missing_rows), (0, 0)]))

        return np.stack(padded)

    def __call__(self, sample):
        """
        sample : DataFrame with the columns ``doc``, ``word_embeddings`` and
                 ``noisy_word_embeddings``

        returns : dict mapping each remaining column to a list, except the two
                  embedding columns, which become stacked arrays. The ``doc``
                  column is deleted from ``sample`` itself.
        """
        del sample['doc']

        merged = sample.to_dict(orient="list")

        for column in ('word_embeddings', 'noisy_word_embeddings'):
            merged[column] = self.stack_vectors(merged[column])

        return merged

In [125]:
class DataLoader(object):
    """
    Minimal batch iterator over a dataset that accepts a list of indices.

    Each pass draws ``epoch_size`` indices with ``random.choices`` (sampling
    with replacement) and feeds them to the dataset ``batch_size`` at a time.
    """

    def __init__(self, dataset, batch_size=8, num_workers=1):
        self.dataset = dataset
        self.batch_size = batch_size
        # stored only; nothing in this class reads it
        self.num_workers = num_workers

    @property
    def epoch_size(self):
        """Dataset length rounded up to a whole number of batches."""
        batch_count = math.ceil(len(self.dataset) / self.batch_size)

        return batch_count * self.batch_size

    def __iter__(self):
        """Yield ``self.dataset[indices]`` for consecutive slices of the sampled indices."""
        total = self.epoch_size
        sampled = random.choices(range(0, len(self.dataset)), k=total)

        offset = 0
        while offset < total:
            yield self.dataset[sampled[offset:(offset + self.batch_size)]]
            offset += self.batch_size

In [52]:
class ArticlesBatch:
    """
    One batch of data plus its running id.

    Entries of ``data`` can be read as attributes: ``batch.text`` is
    ``batch.data['text']``.
    """

    def __init__(self, data, id):
        """
        data : container supporting ``in`` and ``[...]`` lookups by name
        id   : identifier of this batch
        """
        self.data = data
        self.id = id

    def __getattr__(self, name):
        """
        Fallback lookup, used only when normal attribute access fails.

        raises : AttributeError when ``name`` is not an entry of ``data``
        """
        # Read `data` straight from the instance dict. Going through
        # `self.data` would re-enter __getattr__ whenever `data` is not set
        # yet (e.g. on an instance created with __new__) and recurse until
        # a RecursionError is raised.
        data = self.__dict__.get("data")

        if data is not None and name in data:
            return data[name]

        raise AttributeError(f"Attribute missing: {name}")

In [None]:
class UpdateInfo(object):
    """
    Summary of one update: the last batch seen, the accumulated loss and the
    mode ("train" or "eval") that produced it.

    Instances can be added together to keep a running total, starting from
    ``UpdateInfo.empty()``.
    """

    def __init__(self, batch, loss_sum, mode):
        """
        batch    : object with an ``id`` attribute, or None for an empty info
        loss_sum : accumulated loss
        mode     : name of the mode that produced the update, or None
        """
        self.batch = batch
        self.loss_sum = loss_sum
        self.mode = mode

    @classmethod
    def empty(cls):
        """Neutral starting value for accumulation: no batch, zero loss, no mode."""
        return cls(batch=None, loss_sum=0, mode=None)

    @property
    def from_train(self):
        """True when this update was produced in "train" mode."""
        return self.mode == "train"

    @property
    def from_evaluate(self):
        """True when this update was produced in "eval" mode."""
        return self.mode == "eval"

    def __add__(self, other):
        """
        Combine two infos into a new one: the losses are summed, batch and
        mode are taken from ``other`` unless it carries none. This also backs
        ``+=``, which rebinds the name to the new object.
        """
        if not isinstance(other, UpdateInfo):
            return NotImplemented

        return UpdateInfo(
            batch=other.batch if other.batch is not None else self.batch,
            loss_sum=self.loss_sum + other.loss_sum,
            mode=other.mode if other.mode is not None else self.mode
        )

    def __str__(self):
        # `self.loss_sum`: the bare name `loss_sum` is not defined in this
        # scope and raised a NameError.
        batch_id = "-" if self.batch is None else self.batch.id

        return f"{self.mode} | {batch_id}\t| Loss: {self.loss_sum}\t"

In [130]:
class BaseTrainer:
    """
    Shared training plumbing: datasets, batch iteration, gradient accumulation
    and checkpointing.

    Usage: construct, call ``configure(...)``, then iterate ``updates()`` or
    ``train_and_evaluate_updates()``.

    Subclasses must implement ``batch_loss`` and provide ``summarize_model``,
    ``discriminate_model`` and an ``optimizer`` attribute; the latter is read
    by ``updates`` and ``save_checkpoint`` but never set by this class.
    """

    def __init__(self, name, nlp, dataframe):
        """
        name      : run name, used in the checkpoint path
        nlp       : callable handed to the text transforms of the datasets
        dataframe : DataFrame the "train" / "test" / "eval" datasets are built from
        """
        self.name = name
        # Kept so the datasets can be built lazily: they need the masking
        # probability, which is only known after configure().
        self.nlp = nlp
        self.dataframe = dataframe

        # number of train batches handed out so far
        self.current_batch_id = 0

    def configure(self, summarize_model_kwargs, discriminate_model_kwargs,
                  batch_size, update_every, save_every, loader_workers,
                  probability_of_mask_for_word):
        """
        Store the run settings.

        update_every                 : number of batches whose gradients are
                                       accumulated before one optimizer step
        probability_of_mask_for_word : masking probability for the "train"
                                       dataset ("test" and "eval" use 0)
        """
        self.batch_size = batch_size
        self.update_every = update_every
        self.save_every = save_every
        self.loader_workers = loader_workers
        self.probability_of_mask_for_word = probability_of_mask_for_word

        self.summarize_model_kwargs = summarize_model_kwargs
        self.discriminate_model_kwargs = discriminate_model_kwargs

        # drop datasets cached under a previous configuration
        self.__dict__.pop("datasets", None)

    def _build_dataset(self, mode, probability_of_mask_for_word):
        """Build the dataset for one mode with the standard transform chain."""
        return ArticlesDataset(
            self.dataframe,
            mode,
            transforms=[
                TextToParsedDoc(self.nlp),
                WordsToVectors(self.nlp),
                AddNoiseToEmbeddings(probability_of_mask_for_word),
                MergeBatch()
            ]
        )

    @cached_property
    def datasets(self):
        """Dict with the "train", "test" and "eval" datasets, built on first access."""
        if not hasattr(self, "probability_of_mask_for_word"):
            raise RuntimeError("configure() must be called before the datasets are used")

        return {
            "train": self._build_dataset("train", self.probability_of_mask_for_word),
            "test": self._build_dataset("test", 0),
            "eval": self._build_dataset("eval", 0)
        }

    @cached_property
    def summarize_model(self):
        """The summarizer network; to be provided by subclasses."""
        raise NotImplementedError

    @cached_property
    def discriminate_model(self):
        """The discriminator network; to be provided by subclasses."""
        raise NotImplementedError

    def save_checkpoint(self):
        """Write the run settings and the model / optimizer state under ``checkpoints/``."""
        checkpoint_path = f"checkpoints/{self.name}/batch-#{self.current_batch_id}"
        os.makedirs(checkpoint_path, exist_ok=True)

        torch.save(
            {
                'current_batch_id': self.current_batch_id,
                'batch_size': self.batch_size,
                'update_every': self.update_every,
                'save_every': self.save_every,
                # Nothing in this class sets the lambda_* attributes; store
                # None rather than fail when a subclass does not define them.
                'lambda_article': getattr(self, 'lambda_article', None),
                'lambda_sentence': getattr(self, 'lambda_sentence', None),
                'summarize_model_state': self.summarize_model.state_dict(),
                'discriminate_model_state': self.discriminate_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            },
            f"{checkpoint_path}/state.pth"
        )

    @staticmethod
    def load_checkpoint(name, dataframe):
        """Restore a trainer from a checkpoint (not implemented yet)."""
        raise NotImplementedError

    def batches(self, mode):
        """
        Endlessly yield ``ArticlesBatch`` objects for ``mode``, epoch after epoch.

        Ids start at 1 and grow by one per batch. Train batches advance
        ``self.current_batch_id``; every other mode counts privately, so
        interleaved evaluation never disturbs the train numbering.
        """
        is_training = mode == "train"
        local_batch_id = 0

        while True:
            loader = DataLoader(
                self.datasets[mode],
                batch_size=self.batch_size,
                num_workers=self.loader_workers
            )

            produced_any = False

            for data in loader:
                produced_any = True

                if is_training:
                    self.current_batch_id += 1
                    batch_id = self.current_batch_id
                else:
                    local_batch_id += 1
                    batch_id = local_batch_id

                yield(
                    ArticlesBatch(
                        data,
                        id=batch_id
                    )
                )

            if not produced_any:
                # an empty dataset would otherwise spin in this loop forever
                return

    def batch_loss(self, batch):
        """Loss tensor for one batch; to be implemented by subclasses."""
        raise NotImplementedError

    def train_and_evaluate_updates(self, evaluate_every=100):
        """
        Yield train updates; after every train update whose batch id is a
        multiple of ``evaluate_every``, also yield one evaluation update.
        """
        train_updates = self.updates(mode="train")
        evaluate_updates = self.updates(mode="eval")

        for update_info in train_updates:
            yield(update_info)

            if update_info.batch.id % evaluate_every == 0:
                yield(next(evaluate_updates))

    def updates(self, mode="train"):
        """
        Yield one ``UpdateInfo`` per ``update_every`` batches of ``mode``.

        In "train" mode gradients are accumulated and the optimizer steps
        once per yielded update. In any other mode the loss is computed
        without gradient tracking and no parameters are touched.
        """
        is_training = mode == "train"
        loss_sum = 0

        for batch in self.batches(mode):
            # Grad mode is switched only around the loss computation, never
            # across a `yield`: it is global state and would otherwise leak
            # into whatever runs while this generator is suspended.
            with torch.set_grad_enabled(is_training):
                loss = self.batch_loss(batch) / (self.update_every * self.batch_size)

            if is_training:
                loss.backward()

            # detach so the running sum does not keep every batch's graph alive
            loss_sum += loss.detach()

            # we're doing the accumulated gradients trick to get the gradients variance
            # down while being able to use commodity GPU:
            if batch.id % self.update_every == 0:
                if is_training:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                yield(UpdateInfo(batch, loss_sum, mode=mode))

                loss_sum = 0

    def test(self):
        """Evaluate on the "test" dataset (not implemented yet)."""
        raise NotImplementedError

In [None]:
class Trainer(BaseTrainer):
    """Trainer whose loss couples an embedding reconstruction term with a discriminator term."""

    def __init__(self, *args, **kwargs):
        # `self` was missing from the signature, which made the zero-argument
        # super() call fail.
        super().__init__(*args, **kwargs)

    def compute_loss(self, word_embeddings, original_word_embeddings, discriminate_probs):
        """
        word_embeddings          : tensor (..., embedding_size) produced by the model
        original_word_embeddings : array or tensor with the same shape, the targets
        discriminate_probs       : tensor of probabilities from the discriminator

        returns : scalar tensor, cosine embedding loss plus binary cross entropy
        """
        embedding_size = word_embeddings.shape[-1]

        # cosine_embedding_loss compares rows of 2-D inputs, so every leading
        # dimension is folded into one row axis.
        predicted = word_embeddings.reshape(-1, embedding_size)
        expected = torch.as_tensor(
            original_word_embeddings,
            dtype=predicted.dtype,
            device=predicted.device
        ).reshape(-1, embedding_size)

        # NOTE(review): zero-padded positions are not masked out here and so
        # take part in the loss - confirm whether that is intended.
        embeddings_loss = F.cosine_embedding_loss(
            predicted,
            expected,
            # target 1 for every row: each pair should be similar
            torch.ones(predicted.shape[0], dtype=predicted.dtype, device=predicted.device)
        )

        # NOTE(review): the target is all zeros regardless of the batch's
        # `mode` entries - confirm this is the intended adversarial target.
        discriminator_loss = F.binary_cross_entropy(
            discriminate_probs,
            torch.zeros_like(discriminate_probs)
        )

        return embeddings_loss + discriminator_loss

    def batch_loss(self, batch):
        """Run both models on ``batch`` and return the combined loss tensor."""
        # The batch may carry numpy arrays; convert them before they reach the model.
        noisy_word_embeddings = torch.as_tensor(
            batch.noisy_word_embeddings,
            dtype=torch.get_default_dtype()
        )

        word_embeddings, state = self.summarize_model(noisy_word_embeddings)

        # the discriminator guessing which mode (article or headline) was one state
        # created for to deal with the "segregation" problem described in the paper:
        discriminate_probs = self.discriminate_model(state)

        # we're diverging from the article here by outputting the word embeddings
        # instead of the probabilities for each word in a vocabulary
        # our loss function is using the cosine embedding loss coupled with
        # the discriminator loss:
        return self.compute_loss(
            word_embeddings,
            batch.word_embeddings,
            discriminate_probs
        )

In [None]:
# The constructor takes (name, nlp, dataframe); every other setting goes
# through configure(), which has to run before any batch is requested.
trainer = Trainer('first-try', nlp, articles)

# NOTE(review): batch_size=8 is the value previously passed to the
# constructor; all other settings are placeholders to be tuned.
trainer.configure(
    summarize_model_kwargs={"hidden_size": 512},
    discriminate_model_kwargs={},
    batch_size=8,
    update_every=4,
    save_every=100,
    loader_workers=1,
    probability_of_mask_for_word=0.1,
)

cumulative_train_info = UpdateInfo.empty()
cumulative_evaluate_info = UpdateInfo.empty()

for update_info in trainer.train_and_evaluate_updates():
    if update_info.from_train:
        cumulative_train_info += update_info

        print(f"{cumulative_train_info}")

    if update_info.from_evaluate:
        cumulative_evaluate_info += update_info

        print(f"{cumulative_evaluate_info}")

        # a checkpoint is written after every evaluation update
        trainer.save_checkpoint()