In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
mpl.style.use('bmh')

In [17]:
import numpy as np

import ujson
import attr
import random
import torch

from glob import glob
from tqdm import tqdm_notebook
from itertools import islice
from boltons.iterutils import pairwise
from collections import Counter

from gensim.models import KeyedVectors

from torch import nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.nn import functional as F

In [3]:
vectors = KeyedVectors.load_word2vec_format(
    '../data/vectors/GoogleNews-vectors-negative300.bin.gz',
    binary=True,
)

In [4]:
class Corpus:
    
    def __init__(self, pattern, skim=None):
        self.pattern = pattern
        self.skim = skim
        
    def lines(self):
        for path in glob(self.pattern):
            with open(path) as fh:
                for line in fh:
                    yield line.strip()

    def abstracts(self):
        lines = self.lines()
        if self.skim:
            lines = islice(lines, self.skim)
        for line in tqdm_notebook(lines, total=self.skim):
            raw = ujson.loads(line)
            yield Abstract.from_raw(raw)
            
    def xy(self):
        for abstract in self.abstracts():
            yield from abstract.xy()

In [5]:
@attr.s
class Abstract:
    
    sentences = attr.ib()
    
    @classmethod
    def from_raw(cls, raw):
        return cls([Sentence(s['token']) for s in raw['sentences']])
    
    def tensor(self):
        return torch.stack([s.tensor() for s in self.sentences])

In [6]:
@attr.s
class Sentence:
    
    tokens = attr.ib()
    
    def tensor(self, dim=300, pad=50):
        x = [vectors[t] for t in self.tokens if t in vectors]
        x += [np.zeros(dim)] * pad
        x = x[:pad]
        x = np.array(x)
        x = torch.from_numpy(x)
        x = x.float()
        return x

In [30]:
class Model(nn.Module):
    
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.rnn = nn.RNN(embed_dim, hidden_dim, nonlinearity='relu', batch_first=True)
        self.hidden2y = nn.Linear(2*hidden_dim, 1)
        
    def init_hidden(self):
        return Variable(torch.zeros(1, 1, self.hidden_dim))
        
    def forward(self, s1, s2):
        
        h1 = self.init_hidden()
        _, h1 = self.rnn(s1, h1)
        
        h2 = self.init_hidden()
        _, h2 = self.rnn(s2, h2)
        
        h = torch.cat([h1, h2], 2)
        
        y = self.hidden2y(h)
        y = F.sigmoid(y)
        
        return y

In [75]:
train = Corpus('../data/train.json/*.json', 100)

In [76]:
torch.manual_seed(1)

<torch._C.Generator at 0x10f802ca8>

In [77]:
model = Model(300, 150)

In [78]:
criterion = nn.MSELoss()

In [79]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [80]:
def train_pair(model, s1, s2, y):

    s1 = Variable(s1.unsqueeze(0))
    s2 = Variable(s2.unsqueeze(0))

    model.zero_grad()

    y = Variable(torch.FloatTensor([y]))
    y_pred = model(s1, s2)

    loss = criterion(y_pred, y)
    loss.backward()

    optimizer.step()
    
    return loss.data[0]

In [81]:
train_loss = []
for epoch in range(5):
    
    epoch_loss = 0
    for ab in train.abstracts():
        for s1, s2 in pairwise(ab.tensor()):
            epoch_loss += train_pair(model, s1, s2, 1)
            epoch_loss += train_pair(model, s2, s1, 0)

    epoch_loss /= (train.skim*2)
    train_loss.append(epoch_loss)
    print(epoch_loss)




Exception in thread Thread-20:
Traceback (most recent call last):
  File "/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/Users/dclure/Projects/plot-ordering/env/lib/python3.6/site-packages/tqdm/_tqdm.py", line 144, in run
    for instance in self.tqdm_cls._instances:
  File "/usr/local/bin/../Cellar/python3/3.6.2/bin/../Frameworks/Python.framework/Versions/3.6/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration




1.1187777937948704



1.115949678234756



1.1134537059383



1.1130424741363094



1.1129283291752654
