In [1]:
import functools
import os
import pdb
import random

import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
from cached_property import cached_property
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler

In [2]:
# The transforms below only read token vectors, so the tagger, NER and
# text-classifier pipes are disabled to avoid running them on every text.
nlp = spacy.load("en_core_web_lg", disable=["tagger", "ner", "textcat"])

In [3]:
# Pre-processed articles; ArticlesDataset filters rows on the `set` column
# ("train" / "test" / "eval" are the values BaseTrainer asks for).
ARTICLES_PATH = "data/articles-processed.parquet.gzip"
articles = pd.read_parquet(ARTICLES_PATH)

In [257]:
# Scratch check left over from development. It looks like it was verifying the
# even/odd parity rule ArticlesDataset uses (even index -> text, odd -> headline).
# Nothing below reads its output; safe to delete.
2 % 2 == 0

True

In [4]:
class DiscriminatorNet(nn.Module):
    """Placeholder discriminator over encoder hidden states.

    Intended contract: take ``hidden_state`` of shape (batch_num, hidden_size)
    and return one probability per sample, shape (batch_num, 1).
    Not implemented yet: it has no parameters and ``forward`` hands its
    input back unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, hidden_state):
        """Identity stub: returns ``hidden_state`` as-is (the same tensor object)."""
        result = hidden_state
        return result

In [5]:
class SummarizeNet(nn.Module):
    """Placeholder for the denoising summarizer (encoder + decoder).

    hidden_size : int, width of the encoder hidden state ``forward`` is meant to return.
    """

    def __init__(self, hidden_size):
        super(SummarizeNet, self).__init__()

        self.hidden_size = hidden_size

    def forward(self, word_embeddings, generate="sentence"):
        """
        The forward pass for the network

        word_embeddings : tensor (batch_num, max_seq_len, embedding_dim)
                          NOTE(review): presumably the padded word vectors built by
                          MergeBatch; TODO confirm.
        generate        : "article" or "sentence", the kind of text to produce
                          (the two values Trainer.batch_loss passes)

        returns         : tuple (
                            tensor (batch_num, max_seq_len, vocab_len),
                            tensor (batch_num, hidden_size)
                          )

        First tensor in the returning tuple is a probability over the vocabulary
        for each sequence position.
        NOTE(review): Trainer.compute_loss compares this tensor with word embeddings
        via a cosine loss, so its last axis is presumably embedding_dim rather than
        vocab_len. Confirm.

        The second tensor is an encoder's hidden state

        raises          : NotImplementedError until the network is written.
        """
        # Unfinished: fail loudly and explicitly rather than with an unrelated NameError.
        raise NotImplementedError("SummarizeNet.forward is not implemented yet")

In [6]:
# Preview the loaded frame: headline, text, normalized_title and the `set` split label.
# NOTE(review): the bare frame renders a long truncated table; `articles.head()` is usually enough.
articles

Unnamed: 0_level_0,headline,text,normalized_title,set
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2,"\nGet a bachelor’s degree.,\nEnroll in a studi...",It is possible to become a VFX artist without...,HowtoBeaVisualEffectsArtist1,train
4,"\nKeep your reference materials, sketches, art...","As you start planning for a project or work, ...",HowtoBeanOrganizedArtist2,train
6,"\nCreate a compelling reel or portfolio.,\nLan...",This should be a short video showcasing the b...,HowtoBeaVisualEffectsArtist2,train
7,"\nJoin a professional society.,\nEnjoy working...",Networking is a great way to find new opportu...,HowtoBeaVisualEffectsArtist3,train
9,"\nMake a list of what your friends watch, read...",Use your friends’ conversations to figure out...,HowtoAlwaysCatchPopCultureReferences1,train
...,...,...,...,...
215359,"\nUse a childhood nickname.,\nUse your middle ...",You may have been called something other than...,HowtoPickaStageName2,train
215360,\nConsider changing the spelling of your name....,"If you have a name that you like, you might f...",HowtoPickaStageName3,train
215361,"\nTry out your name.,\nDon’t legally change yo...",Your name might sound great to you when you s...,HowtoPickaStageName4,train
215362,"\nUnderstand the process of relief printing.,\...",Relief printing is the oldest and most tradit...,HowtoIdentifyPrints1,train


In [7]:
class ArticlesDataset(Dataset):
    """Dataset over one split of the articles frame, exposing every row twice.

    Dataset index ``2*k`` is the article text of the k-th row of the split and
    ``2*k + 1`` is its headline, so ``len(dataset) == 2 * len(split)``.
    ``__getitem__`` accepts a single index or a list/tuple/array/tensor of
    indices. It always returns the whole batch after running ``transforms`` on it.
    """

    def __init__(self, dataframe, mode, transforms=None):
        """
        dataframe  : DataFrame with 'headline', 'text', 'normalized_title' and 'set' columns
        mode       : value of the 'set' column to keep (e.g. "train")
        transforms : callables applied in order to the batch DataFrame; None means none
        """
        self.data = dataframe[dataframe['set'] == mode]
        self.transforms = transforms if transforms is not None else []
        self.mode = mode

    def __len__(self):
        # Two samples (text + headline) per row of the split.
        return 2 * len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()  # 0-d tensor -> int, 1-d tensor -> list

        if isinstance(idx, (list, tuple, np.ndarray)):
            asked_ids = [int(i) for i in idx]
        else:
            asked_ids = [int(idx)]

        # Integer division so iloc receives ints; odd ids select the headline.
        row_positions = [i // 2 for i in asked_ids]
        wants_headline = [i % 2 == 1 for i in asked_ids]
        rows = self.data.iloc[row_positions]

        # Build every column from plain lists. This avoids aligning columns on the
        # (possibly duplicated) source index and never writes into a view of self.data.
        data = pd.DataFrame(
            {
                'set': [self.mode] * len(asked_ids),
                'mode': [1 if headline else 0 for headline in wants_headline],
                'text': [
                    row_headline if headline else row_text
                    for row_text, row_headline, headline
                    in zip(rows['text'], rows['headline'], wants_headline)
                ],
                'title': rows['normalized_title'].tolist(),
            }
        )

        for transform in self.transforms:
            data = transform(data)

        return data

In [8]:
class TextToParsedDoc(object):
    """Transform adding a ``doc`` column: ``nlp`` applied to each row's ``text``."""

    def __init__(self, nlp):
        # Any callable str -> parsed document; in this notebook, the spaCy pipeline.
        self.nlp = nlp

    def __call__(self, sample):
        """Parse ``sample['text']`` into ``sample['doc']`` (in place) and return ``sample``."""
        # Series.map stays a single column even for an empty batch. Row-wise
        # DataFrame.apply(axis=1) would return a whole DataFrame there, which
        # cannot be assigned to one column.
        sample['doc'] = sample['text'].map(self.nlp)
        return sample

In [39]:
class WordsToVectors(object):
    """Transform adding a ``word_embeddings`` column: for each ``doc``, the token
    vectors stacked into an array of shape (n_tokens, vector_dim)."""

    def __init__(self, nlp):
        # Used to learn the vector width for documents that have no tokens.
        self.nlp = nlp

    def _doc_to_matrix(self, doc):
        """Stack ``token.vector`` for every token in ``doc``."""
        vectors = [token.vector for token in doc]
        if not vectors:
            # np.stack([]) raises. An empty text (e.g. a blank headline) becomes a
            # zero-row matrix that still has the right embedding width.
            return np.zeros((0, self.nlp.vocab.vectors_length), dtype=np.float32)
        return np.stack(vectors)

    def __call__(self, sample):
        """Fill ``sample['word_embeddings']`` (in place) from ``sample['doc']`` and return ``sample``."""
        sample['word_embeddings'] = sample['doc'].map(self._doc_to_matrix)
        return sample

In [38]:
class AddNoiseToEmbeddings(object):
    """Transform adding ``noisy_word_embeddings``: a copy of ``word_embeddings``
    where each token (row) is independently replaced by a zero vector with
    probability ``probability_of_mask_for_word``.
    """

    def __init__(self, probability_of_mask_for_word, seed=None):
        """
        probability_of_mask_for_word : float in [0, 1]
        seed                         : optional seed for reproducible masking;
                                       None draws fresh OS entropy

        raises ValueError if the probability is outside [0, 1].
        """
        if not 0.0 <= probability_of_mask_for_word <= 1.0:
            raise ValueError(
                "probability_of_mask_for_word must be in [0, 1], "
                f"got {probability_of_mask_for_word}"
            )
        self.probability_of_mask_for_word = probability_of_mask_for_word
        self.rng = np.random.default_rng(seed)

    def mask_vector(self, vector):
        """
        Masks words with zeros randomly

        vector  : 2-D array (seq_len, embedding_dim)
        returns : array of the same shape and dtype, with whole rows zeroed
        """
        # random() is in [0, 1), so a row is kept with probability 1 - p.
        keep = self.rng.random(vector.shape[0]) >= self.probability_of_mask_for_word
        # A boolean mask broadcast over the embedding axis keeps vector's dtype.
        # An int64 0/1 mask would silently upcast float32 embeddings to float64.
        return vector * keep[:, np.newaxis]

    def __call__(self, sample):
        sample['noisy_word_embeddings'] = sample['word_embeddings'].apply(self.mask_vector)

        return sample

In [49]:
class MergeBatch(object):
    """Final transform: turn the per-sample DataFrame into a dict of column lists.

    Each variable-length embedding column is zero-padded and stacked into a
    single array of shape (batch, max_seq_len, dim).
    """

    # Columns holding (seq_len, dim) arrays that get padded and stacked, when present.
    _EMBEDDING_COLUMNS = ('word_embeddings', 'noisy_word_embeddings')

    def stack_vectors(self, vectors):
        """Right-pad each (seq_len, dim) array with zero rows up to the longest
        seq_len, then stack the results into (len(vectors), max_seq_len, dim)."""
        max_seq = max(vector.shape[0] for vector in vectors)
        padded = [
            np.pad(vector, [(0, max_seq - vector.shape[0]), (0, 0)])
            for vector in vectors
        ]
        return np.stack(padded)

    def __call__(self, sample):
        # Parsed docs are not needed past this point. Drop them without mutating
        # the caller's frame, and without failing if the column is absent.
        merged = sample.drop(columns=['doc'], errors='ignore').to_dict(orient="list")

        for column in self._EMBEDDING_COLUMNS:
            # Pipelines without AddNoiseToEmbeddings have no noisy column.
            if column in merged:
                merged[column] = self.stack_vectors(merged[column])

        return merged

In [None]:
demo_pipeline = [
    TextToParsedDoc(nlp),
    WordsToVectors(nlp),
    AddNoiseToEmbeddings(0.2),
    MergeBatch(),
]
demo_dataset = ArticlesDataset(articles, mode="train", transforms=demo_pipeline)

# One merged batch for dataset ids 0-4 (even id -> article text, odd id -> headline).
demo_dataset[[0, 1, 2, 3, 4]]

In [None]:
class ArticlesBatch:
    """Attribute-style view of one batch: ``batch.key`` reads ``batch.data[key]``.

    data : mapping (or DataFrame) of batch fields
    id   : batch number, as assigned by BaseTrainer.batches
    """

    def __init__(self, data, id):
        self.data = data
        self.id = id

    def __getattr__(self, name):
        # Only reached for names not found normally. `data` is read through
        # __dict__ because, on a half-built instance (copy.copy, unpickling),
        # `self.data` would re-enter __getattr__ and recurse until RecursionError.
        data = self.__dict__.get('data')
        if data is not None and name in data:
            return data[name]
        raise AttributeError(f"Attribute missing: {name}")

In [None]:
class BaseTrainer:
    """Shared training machinery: per-split datasets, batching, gradient
    accumulation and checkpointing.

    Subclasses implement ``batch_loss``. They must also assign ``self.optimizer``
    (an object with ``step``, ``zero_grad`` and ``state_dict``, e.g. a
    ``torch.optim`` optimizer) before calling ``train`` or ``save``.
    """

    # Subclasses override this per instance. As a class attribute it is never
    # clobbered by __init__, and a missing optimizer can be detected.
    optimizer = None

    def __init__(self, name, nlp, summarize_model, discriminate_model, dataframe,
                 batch_size, update_every, save_every, loader_workers,
                 probability_of_mask_for_word,
                 lambda_article, lambda_sentence):
        self.name = name

        # Only the training split gets randomly masked ("noisy") embeddings.
        self.datasets = {
            "train": ArticlesDataset(
                dataframe,
                "train",
                transforms=[
                    TextToParsedDoc(nlp),
                    WordsToVectors(nlp),
                    AddNoiseToEmbeddings(
                        probability_of_mask_for_word
                    )
                ]
            ),
            "test":  ArticlesDataset(
                dataframe,
                "test",
                transforms=[
                    TextToParsedDoc(nlp),
                    WordsToVectors(nlp)
                ]
            ),
            "eval":  ArticlesDataset(
                dataframe,
                "eval",
                transforms=[
                    TextToParsedDoc(nlp),
                    WordsToVectors(nlp)
                ]
            )
        }

        self.batch_size = batch_size
        self.update_every = update_every
        # NOTE(review): save_every is stored and checkpointed, but train() never calls save().
        self.save_every = save_every
        self.loader_workers = loader_workers

        self.summarize_model = summarize_model
        self.discriminate_model = discriminate_model

        self.lambda_article = lambda_article
        self.lambda_sentence = lambda_sentence

        # Id of the last batch yielded by batches(); also used in checkpoint paths.
        self.current_batch_id = 0

    @property
    def models(self):
        """The (summarize_model, discriminate_model) pair."""
        return self.summarize_model, self.discriminate_model

    def _require_optimizer(self):
        """Raise a clear error when a subclass has not set ``self.optimizer``."""
        if self.optimizer is None:
            raise NotImplementedError(
                f"{type(self).__name__} must set self.optimizer before training or saving"
            )

    def save(self):
        """Write hyper-parameters plus model and optimizer state to
        ``checkpoints/<name>/batch-#<current_batch_id>/state.pth``."""
        self._require_optimizer()
        checkpoint_path = f"checkpoints/{self.name}/batch-#{self.current_batch_id}"
        os.makedirs(checkpoint_path, exist_ok=True)

        torch.save(
            {
                'current_batch_id': self.current_batch_id,
                'batch_size': self.batch_size,
                'update_every': self.update_every,
                'save_every': self.save_every,
                'lambda_article': self.lambda_article,
                'lambda_sentence': self.lambda_sentence,
                'summarize_model_state': self.summarize_model.state_dict(),
                'discriminate_model_state': self.discriminate_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            },
            f"{checkpoint_path}/state.pth"
        )

    @staticmethod
    def load(name, dataframe):
        """Restore a trainer from a checkpoint (not implemented yet)."""
        raise NotImplementedError

    def batches(self, mode):
        """Yield ArticlesBatch objects for the ``mode`` split forever.

        The data is reshuffled on every pass. ``current_batch_id`` goes up by
        one per batch, and each batch carries the new value as its ``id``.
        """
        dataset = self.datasets[mode]

        while True:
            # ArticlesDataset.__getitem__ takes a list of indices and returns the
            # whole batch. A BatchSampler therefore hands it index lists, and
            # automatic batching is off (batch_size=None), because the default
            # collate function cannot merge DataFrames.
            loader = DataLoader(
                dataset,
                batch_size=None,
                sampler=BatchSampler(
                    RandomSampler(dataset),
                    batch_size=self.batch_size,
                    drop_last=False
                ),
                num_workers=self.loader_workers
            )

            for data in loader:
                self.current_batch_id += 1

                yield ArticlesBatch(data, id=self.current_batch_id)

    def batch_loss(self, batch):
        """Return the (scalar tensor) loss for ``batch``; implemented by subclasses."""
        raise NotImplementedError

    def after_update(self, batch, loss_sum):
        """Hook called after each optimizer step with the accumulated, detached loss."""
        pass

    def train(self):
        """Training loop with gradient accumulation.

        Each batch's loss is scaled by ``1 / (update_every * batch_size)`` and
        back-propagated. On every batch whose id is a multiple of ``update_every``,
        the optimizer steps, the gradients are cleared and ``after_update``
        receives the accumulated loss. The loop runs until ``batches`` is exhausted;
        the default ``batches`` never is.
        """
        self._require_optimizer()
        loss_sum = 0

        for batch in self.batches("train"):
            loss = self.batch_loss(batch) / (self.update_every * self.batch_size)

            loss.backward()
            # Detach so the running total does not keep every batch's autograd graph alive.
            loss_sum = loss_sum + loss.detach()

            # we're doing the accumulated gradients trick to get the gradients variance
            # down while being able to use commodity GPU:
            if batch.id % self.update_every == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                self.after_update(batch, loss_sum)

                loss_sum = 0

    def test(self):
        raise NotImplementedError

    def evaluate(self):
        raise NotImplementedError

In [None]:
class Trainer(BaseTrainer):
    """Trainer combining two terms:

    * denoising reconstruction of articles and headlines ("sentences");
    * a discriminator term on the encoder states.

    It takes BaseTrainer's arguments plus ``probability_of_masking_for_sample``,
    which is stored but not used yet.
    """

    def __init__(self, name, nlp, summarize_model, discriminate_model, dataframe,
                 batch_size, update_every, save_every, loader_workers,
                 probability_of_mask_for_word, probability_of_masking_for_sample,
                 lambda_article, lambda_sentence):
        # Forward every BaseTrainer argument in its declared order, including
        # probability_of_mask_for_word, which BaseTrainer requires.
        super().__init__(
            name, nlp,
            summarize_model, discriminate_model, dataframe,
            batch_size, update_every, save_every, loader_workers,
            probability_of_mask_for_word,
            lambda_article, lambda_sentence
        )
        self.probability_of_masking_for_sample = probability_of_masking_for_sample
        # NOTE(review): no optimizer is created here, but BaseTrainer.train() and
        # save() use self.optimizer. It must be assigned before training.

    @staticmethod
    def _cosine_reconstruction_loss(predicted, target):
        """Mean cosine-embedding loss that pulls every predicted vector toward its target.

        F.cosine_embedding_loss only accepts (N, D) or (D,) inputs, so all leading
        axes are flattened first; (batch, seq, dim) tensors therefore work.
        NOTE(review): zero-padded positions contribute as well; consider masking them.
        """
        predicted = predicted.reshape(-1, predicted.shape[-1])
        target = target.reshape(-1, target.shape[-1])
        # Target +1 for every row, created on predicted's device/dtype.
        return F.cosine_embedding_loss(
            predicted,
            target,
            predicted.new_ones(predicted.shape[0])
        )

    def compute_loss(self, articles_word_embeddings, orig_articles_word_embeddings,
                     sentences_word_embeddings, orig_sentences_word_embeddings,
                     sentences_numbers_in_articles,
                     discriminate_articles_probs,
                     discriminate_sentences_probs
                    ):
        """
        Combined loss for one batch.

        articles_/sentences_word_embeddings           : embeddings predicted by summarize_model
        orig_articles_/orig_sentences_word_embeddings : clean embeddings they should reconstruct
        sentences_numbers_in_articles : divisor applied to the sentence loss
                                        (presumably per-article sentence counts; TODO confirm)
        discriminate_*_probs          : discriminator outputs; binary_cross_entropy
                                        requires values in [0, 1]

        returns : a scalar tensor when sentences_numbers_in_articles is a scalar or 1-D
        """
        articles_loss = self._cosine_reconstruction_loss(
            articles_word_embeddings,
            orig_articles_word_embeddings
        )

        # The target is sized from the sentences tensor, not the articles one.
        sentences_loss = self._cosine_reconstruction_loss(
            sentences_word_embeddings,
            orig_sentences_word_embeddings
        ) / sentences_numbers_in_articles

        # NOTE(review): both discriminator terms use an all-zeros target, so the
        # discriminator is never asked to tell articles from sentences. Confirm intent.
        discriminator_articles_loss = F.binary_cross_entropy(
            discriminate_articles_probs,
            torch.zeros_like(discriminate_articles_probs)
        )

        discriminator_sentences_loss = F.binary_cross_entropy(
            discriminate_sentences_probs,
            torch.zeros_like(discriminate_sentences_probs)
        )

        # Parenthesised so the multi-line sum is one valid expression.
        return (
            (articles_loss * self.lambda_article).sum(dim=0)
            + (sentences_loss * self.lambda_sentence).sum(dim=0)
            + discriminator_articles_loss
            + discriminator_sentences_loss
        )

    def batch_loss(self, batch):
        """Run both denoising passes plus the discriminator and return compute_loss's result."""
        # NOTE(review): the visible dataset pipeline produces 'text', 'title', 'mode',
        # 'word_embeddings' and 'noisy_word_embeddings'. The attributes read below
        # (articles_noisy_word_embeddings, sentences_numbers_in_articles, ...) are not
        # produced anywhere visible. Confirm the intended batch layout.

        # article -> article (de-noising)
        articles_word_embeddings, articles_state = self.summarize_model(
            batch.articles_noisy_word_embeddings,
            generate="article"
        )

        # headline -> headline (de-noising)
        sentences_word_embeddings, sentences_state = self.summarize_model(
            batch.sentences_noisy_word_embeddings,
            generate="sentence"
        )

        # The discriminator guesses which mode (article or sentence) a state was
        # created for. It addresses the "segregation" problem described in the paper:
        discriminate_articles_probs = self.discriminate_model(articles_state)
        discriminate_sentences_probs = self.discriminate_model(sentences_state)

        # We diverge from the paper here by outputting word embeddings instead of
        # probabilities for each word in a vocabulary. The loss function uses the
        # cosine embedding loss coupled with the discriminator loss:
        return self.compute_loss(
            articles_word_embeddings,
            batch.articles_word_embeddings,

            sentences_word_embeddings,
            batch.sentences_word_embeddings,
            batch.sentences_numbers_in_articles,

            discriminate_articles_probs,
            discriminate_sentences_probs
        )

In [None]:
# Trainer takes every hyper-parameter as a required argument. The values below
# are placeholders (TODO tune); 0.2 matches the masking rate used in the demo cell above.
trainer = Trainer(
    name="summarizer",
    nlp=nlp,
    summarize_model=SummarizeNet(hidden_size=256),
    discriminate_model=DiscriminatorNet(),
    dataframe=articles,
    batch_size=8,
    update_every=4,
    save_every=1000,
    loader_workers=0,
    probability_of_mask_for_word=0.2,
    probability_of_masking_for_sample=0.5,
    lambda_article=1.0,
    lambda_sentence=1.0,
)

# train() is a plain method that returns None, not a generator: call it directly.
# NOTE(review): training cannot succeed until SummarizeNet/DiscriminatorNet are
# implemented and an optimizer is assigned to trainer.optimizer.
trainer.train()