In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
mpl.style.use('bmh')

In [17]:
import numpy as np

import ujson
import attr
import random
import torch

from glob import glob
from tqdm import tqdm_notebook
from itertools import islice
from boltons.iterutils import pairwise
from collections import Counter

from gensim.models import KeyedVectors

from torch import nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn import functional as F

In [3]:
vectors = KeyedVectors.load_word2vec_format(
    '../data/vectors/GoogleNews-vectors-negative300.bin.gz',
    binary=True,
)

In [4]:
class Corpus:
    
    def __init__(self, pattern, skim=None):
        self.pattern = pattern
        self.skim = skim
        
    def lines(self):
        for path in glob(self.pattern):
            with open(path) as fh:
                for line in fh:
                    yield line.strip()

    def abstracts(self):
        lines = self.lines()
        if self.skim:
            lines = islice(lines, self.skim)
        for line in tqdm_notebook(lines, total=self.skim):
            raw = ujson.loads(line)
            yield Abstract.from_raw(raw)
            
    def xy(self):
        for abstract in self.abstracts():
            yield from abstract.xy()

In [5]:
@attr.s
class Abstract:
    
    sentences = attr.ib()
    
    @classmethod
    def from_raw(cls, raw):
        return cls([Sentence(s['token']) for s in raw['sentences']])
    
    def tensor(self):
        return torch.stack([s.tensor() for s in self.sentences])

In [6]:
@attr.s
class Sentence:
    
    tokens = attr.ib()
    
    def tensor(self, dim=300, pad=50):
        x = [vectors[t] for t in self.tokens if t in vectors]
        x += [np.zeros(dim)] * pad
        x = x[:pad]
        x = np.array(x)
        x = torch.from_numpy(x)
        x = x.float()
        return x

In [182]:
class SentenceEncoder(nn.Module):
    
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.rnn = nn.RNN(embed_dim, hidden_dim, nonlinearity='relu', batch_first=True)
        
    def init_hidden(self):
        return Variable(torch.zeros(1, 1, self.hidden_dim))
        
    def forward(self, x):
        hidden = self.init_hidden()
        rnn_out, hidden = self.rnn(x, hidden)
        return hidden

In [183]:
class Model(nn.Module):
    
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.i2h = nn.Linear(2*input_dim, hidden_dim)
        self.h2o = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        y = F.relu(self.i2h(x))
        y = F.sigmoid(self.h2o(y))
        return y

In [184]:
train = Corpus('../data/train.json/*.json', 100)

In [185]:
torch.manual_seed(1)

<torch._C.Generator at 0x10f802ca8>

In [186]:
HIDDEN_DIM = 150

In [187]:
sent_encoder = SentenceEncoder(300, HIDDEN_DIM)

In [190]:
model = Model(HIDDEN_DIM, 100)

In [191]:
criterion = nn.BCELoss()

In [192]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [193]:
def train_pair(model, s1, s2, y):
    
    x = torch.cat([s1, s2])
    
    model.zero_grad()

    y = Variable(torch.FloatTensor([y]))
    y_pred = model(x)

    loss = criterion(y_pred, y)
    loss.backward(retain_graph=True)

    optimizer.step()
    
    return loss.data[0]

In [194]:
train_loss = []
for epoch in range(10):
    
    epoch_loss = 0
    for ab in train.abstracts():
        
        sent_encoder.zero_grad()
        
        sents = ab.tensor()
        sents = Variable(sents)
        sents = sent_encoder(sents)
        
        for s1, s2 in pairwise(sents[0]):
            epoch_loss += train_pair(model, s1, s2, 1)
            epoch_loss += train_pair(model, s2, s1, 0)

    epoch_loss /= (train.skim*2)
    train_loss.append(epoch_loss)
    print(epoch_loss)


3.1018488770723343



3.1005013370513916



3.0993597680330276



3.099168309867382



3.09919187232852



3.098306267410517



3.097184216082096



3.0961393562704327



3.104404819570482



3.0942999129742383
