In [1]:
import torch

In [37]:
from fairseq.models import FairseqEncoder, FairseqEncoderDecoderModel

from fairseq.models.fconv import (
    Embedding,
    PositionalEmbedding,
    FConvDecoder
)
from torch.distributions import Normal


from restorant_dataset import RestDataset, lines_generator
from torchtext import data

In [29]:
class NoEncoder(FairseqEncoder):
    """
    "Encoder" that does not encode a token sequence. It looks up a learned,
    sample-specific latent code and adds a sentiment embedding to it.

    Each input row contains two indices:
        a class index (0 = negative, 1 = positive)
        a sample index selecting that sample's latent embedding

    The sample embedding is reshaped into `ntokens` latent tokens of size
    `embed_dim`, perturbed with Gaussian noise and summed with the positional
    embedding that matches the sentiment.
    """
    def __init__(self, sample_size, padding_index, ntokens=5, embed_dim=512, noise_std=0.1, dropout=0.1):
        """
        The number of latent-space tokens is constant.

        Args:
            sample_size: number of rows in the sample-specific embedding table.
            padding_index: padding index handed to every embedding table.
            ntokens: number of latent tokens produced per sample.
            embed_dim: size of each latent token.
            noise_std: standard deviation of the noise added to the content.
            dropout: dropout probability applied to the final output.
        """
        super().__init__(None)  # no source dictionary: inputs are raw indices
        self.dropout = dropout
        self.dim = embed_dim
        self.ntokens = ntokens

        # Sample-specific content codes: one flat vector of ntokens * embed_dim
        # per sample, split into tokens in forward().
        # NOTE(review): padding_index is also applied to this table, so the
        # sample whose index equals padding_index is presumably treated as
        # padding -- confirm this is intended.
        self.content_embeddings = Embedding(sample_size, embed_dim * ntokens, padding_index)

        self.negative_embedding = PositionalEmbedding(num_embeddings=ntokens,
                                                      embedding_dim=embed_dim,
                                                      padding_idx=padding_index)

        self.positive_embedding = PositionalEmbedding(num_embeddings=ntokens,
                                                      embedding_dim=embed_dim,
                                                      padding_idx=padding_index)

        self.noise = Normal(loc=0.0, scale=noise_std)

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens: (batch, 2) index tensor; column 0 is the sentiment
                (0 or 1) and column 1 is the sample index [0 .. sample_size).
            src_lengths: (batch)-sized tensor full of 2; unused here.

        Returns:
            dict with 'encoder_out' = (x, x), where x is batch x ntokens x dim,
            and 'encoder_padding_mask' = None.
        """
        # Content embedding, split into latent tokens.
        content = self.content_embeddings(src_tokens[:, 1])   # batch x (ntokens * dim)
        content = content.view(-1, self.ntokens, self.dim)    # batch x ntokens x dim

        # Normal.sample() allocates on the CPU; move the noise to the dtype
        # and device of the embeddings before adding it.
        noise = self.noise.sample(content.size()).to(content)
        content = content + noise

        # Sentiment positional embedding.
        # NOTE(review): fairseq positional embeddings derive positions from
        # their input relative to padding_idx -- confirm that feeding an
        # arange selects the intended rows.
        positions = torch.arange(self.ntokens, device=src_tokens.device).unsqueeze(0)  # 1 x ntokens
        # Cast the 0/1 flag to the embedding dtype so it can gate the two tables.
        sentiment = src_tokens[:, 0].reshape(-1, 1, 1).to(content.dtype)  # batch x 1 x 1
        sentiment = (self.positive_embedding(positions) * sentiment
                     + self.negative_embedding(positions) * (1 - sentiment))  # batch x ntokens x dim

        x = content + sentiment
        x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
        return {
            'encoder_out': (x, x),
            'encoder_padding_mask': None
        }
        


In [30]:
class NoEncoderFConvDecoderModel(FairseqEncoderDecoderModel):
    """
    Encoder-decoder model meant to pair the no-encoder (as encoder) with the
    fconv decoder (as decoder). Modelled after fconv.py.
    """
    def __init__(self, encoder, decoder):
        # Nothing model-specific to set up; the base class stores both parts.
        super(NoEncoderFConvDecoderModel, self).__init__(encoder, decoder)

In [32]:
def get_dataset(max_examples):
    """
    Build the review dataset and the vocabulary of its review field.

    Args:
        max_examples: forwarded to RestDataset as its last argument.

    Returns:
        (dataset, vocab): the RestDataset and the vocab built from it.
    """
    # Id and stars are kept as raw values; only the review field gets a vocab.
    id_field = data.Field(use_vocab=False, sequential=False)
    stars_field = data.Field(use_vocab=False, sequential=False)
    review_field = data.Field(use_vocab=True, sequential=True)

    reviews = RestDataset(lines_generator(), id_field, stars_field, review_field, max_examples)
    review_field.build_vocab(reviews)

    return reviews, review_field.vocab

# Upper bound on the number of examples to load; reused below so the dataset
# size and everything derived from it cannot drift apart.
nsamples = 50000
dataset, vocab = get_dataset(nsamples)

49190it [00:00, 10.97it/s]


In [34]:
batch_size = 64
# Batches per pass, based on the requested sample count (nsamples).
iterations_per_epoch = nsamples // batch_size
print(iterations_per_epoch)
print(len(vocab))

# Endless (repeat=True) iterator that sorts examples by review length.
train_iter = data.BucketIterator(
    dataset=dataset,
    batch_size=batch_size,
    sort_key=lambda example: len(example.review),
    sort=True,
    repeat=True,
)

781
7590


In [35]:
# Padding index: one past the last index of the review vocabulary.
# NOTE(review): this value is later passed to NoEncoder as `padding_index`,
# where it pads the sample/positional embedding tables rather than vocabulary
# tokens -- confirm that this is the intended index space.
PAD = len(vocab)

In [36]:
def checkpoint(model, path='/cs/labs/dshahaf/omribloch/train/text_lord/checkpoint.txt'):
    """
    Save the model's state_dict to `path`, replacing any previous checkpoint.

    The weights are first written to a temporary sibling file and then moved
    over `path`, so a failure while saving cannot destroy the previous
    checkpoint.

    Args:
        model: module whose `state_dict()` is saved.
        path: destination file of the checkpoint.
    """
    import os  # local import: `os` is not imported at the top of the notebook

    tmp_path = path + '.tmp'
    torch.save(model.state_dict(), tmp_path)
    os.replace(tmp_path, path)  # overwrites an existing checkpoint in one step
    print('saved checkpoint!')

def load_checkpoint(path='/cs/labs/dshahaf/omribloch/train/text_lord/checkpoint.txt', model=None):
    """
    Load a state_dict saved by `checkpoint` into `model`.

    Args:
        path: checkpoint file to read.
        model: already-constructed module that receives the weights. It must
            have the same architecture as the model that was saved.

    Returns:
        `model`, with the weights loaded and switched to eval mode.

    Raises:
        ValueError: if no `model` is given.
    """
    if model is None:
        raise ValueError('load_checkpoint needs a `model` instance to load the weights into')

    # Load onto the CPU first; load_state_dict copies each tensor to the
    # device of the matching parameter, so this also works without a GPU.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model

In [None]:
from fairseq.data import Dictionary

# Mirror the review vocabulary into a fairseq Dictionary for the decoder.
# NOTE(review): a fairseq Dictionary starts with its own special symbols, so
# indices here presumably differ from the torchtext vocab indices -- confirm
# the target tokens are mapped through this dictionary before training.
decoder_dictionary = Dictionary()
for token in vocab.itos:  # itos: index-to-string list of the torchtext vocab
    decoder_dictionary.add_symbol(token)

In [None]:
# One latent code per training sample; PAD is the padding index for its tables.
encoder = NoEncoder(nsamples, PAD)
# FConvDecoder requires the target dictionary as its first argument.
decoder = FConvDecoder(decoder_dictionary)