In [24]:
import torch
import random

from itertools import islice
from tqdm import tqdm_notebook

from sent_order.conll import Document

In [4]:
model = torch.load(
    '../../data/coref.9.bin',
    map_location={'cpu': 'cpu'},
)



In [6]:
model.eval()

Coref(
  (encode_doc): DocEncoder(
    (embeddings): WordEmbedding(39414, 300)
    (lstm): LSTM(300, 200, batch_first=True, bidirectional=True)
    (dropout): Dropout(p=0.5)
  )
  (score_spans): SpanScorer(
    (attention): Scorer(
      (score): Sequential(
        (0): Linear(in_features=400, out_features=150, bias=True)
        (1): ReLU()
        (2): Linear(in_features=150, out_features=150, bias=True)
        (3): ReLU()
        (4): Linear(in_features=150, out_features=1, bias=True)
      )
    )
    (width_embeddings): DistanceEmbedding(9, 20)
    (sm): Scorer(
      (score): Sequential(
        (0): Linear(in_features=1120, out_features=150, bias=True)
        (1): ReLU()
        (2): Linear(in_features=150, out_features=150, bias=True)
        (3): ReLU()
        (4): Linear(in_features=150, out_features=1, bias=True)
      )
    )
  )
  (score_pairs): PairScorer(
    (dist_embeddings): DistanceEmbedding(9, 20)
    (sa): Scorer(
      (score): Sequential(
        (0): Linear(

In [18]:
def score_doc(text):
    doc = Document.from_text(text)
    spans = model(doc)
    return sum([sum(F.softmax(s.sij, 0)[:-1]) for s in spans]).item()

In [14]:
class Corpus:

    def __init__(self, path):
        self.path = path

    def lines(self):
        with open(self.path) as fh:
            for line in fh:
                yield line.strip()

    def abstract_lines(self):
        """Generate abstract line groups.
        """
        lines = []
        for line in self.lines():
            if line:
                lines.append(line)
            else:
                yield lines[2:]
                lines = []

In [26]:
c = Corpus('../../data/abstracts/test.txt')

In [29]:
correct, total = 0, 0
for sents in tqdm_notebook(islice(c.abstract_lines(), 100)):
    
    shuffled_sents = sorted(sents, key=lambda x: random.random())
    
    s1 = score_doc(' '.join(sents))
    s2 = score_doc(' '.join(shuffled_sents))
    
    if s1 > s2:
        correct += 1
        
    total += 1




In [30]:
correct / total

0.47