In [1]:
import torch

from sent_order.conll import Corpus, Document

from torch.nn import functional as F
from itertools import islice
from tqdm import tqdm_notebook

In [2]:
model = torch.load(
    '../../data/coref.4.bin',
    map_location={'cpu': 'cpu'},
)

In [3]:
model

Coref(
  (encode_doc): DocEncoder(
    (embeddings): WordEmbedding(39414, 300)
    (lstm): LSTM(300, 200, batch_first=True, bidirectional=True)
    (dropout): Dropout(p=0.5)
  )
  (score_spans): SpanScorer(
    (attention): Scorer(
      (score): Sequential(
        (0): Linear(in_features=400, out_features=150, bias=True)
        (1): ReLU()
        (2): Linear(in_features=150, out_features=150, bias=True)
        (3): ReLU()
        (4): Linear(in_features=150, out_features=1, bias=True)
      )
    )
    (width_embeddings): DistanceEmbedding(9, 20)
    (sm): Scorer(
      (score): Sequential(
        (0): Linear(in_features=1120, out_features=150, bias=True)
        (1): ReLU()
        (2): Linear(in_features=150, out_features=150, bias=True)
        (3): ReLU()
        (4): Linear(in_features=150, out_features=1, bias=True)
      )
    )
  )
  (score_pairs): PairScorer(
    (dist_embeddings): DistanceEmbedding(9, 20)
    (sa): Scorer(
      (score): Sequential(
        (0): Linear(

In [4]:
c = Corpus.from_combined_file('../../data/test.conll')

In [5]:
def score(spans):
    return sum([F.softmax(span.sij, 0)[-1] for span in spans])

In [7]:
correct, total = 0, 0
for tokens1, tokens2 in tqdm_notebook(islice(c.sent_pair_tokens(), 200)):
    
    d1 = Document(tokens1+tokens2)
    d2 = Document(tokens2+tokens1)
    
    pred1 = model.predict(d1)
    pred2 = model.predict(d2)
    
    print(pred1, pred2)

[{(16, 16), (3, 3)}] []
[] []
[] []
[] []
[] []
[] []
[] []
[] []
[] [{(0, 0), (25, 25)}]
[] []
[] [{(3, 4), (38, 39)}]
[] []
[] []
[{(15, 16), (19, 20)}] [{(1, 2), (30, 31)}]
[] []
[] []
[{(9, 10), (26, 27)}] [{(7, 8), (35, 36)}]
[] []
[] []
[] []
[] []
[] []
[{(36, 36), (45, 45), (2, 2)}] [{(25, 25), (19, 19), (10, 10)}]
[{(5, 5), (19, 19), (10, 10)}] [{(12, 12), (7, 7), (21, 21)}]
[] []
[] []
[{(36, 36), (25, 25), (3, 3)}, {(21, 22), (32, 32)}] [{(27, 27), (7, 7), (18, 18)}]
[{(3, 4), (14, 14)}, {(27, 27), (7, 7), (18, 18)}] [{(28, 28), (17, 17), (3, 3)}, {(13, 14), (24, 24)}]
[] []
[] []
[] []
[] []
[] []
[{(5, 5), (17, 17)}, {(22, 23), (2, 3)}] [{(15, 15), (3, 3)}, {(20, 21), (0, 1)}]
[{(15, 15), (3, 3)}, {(20, 21), (0, 1), (46, 47)}] [{(20, 21), (40, 41), (17, 18)}, {(23, 23), (35, 35)}]
[] []
[] []
[{(5, 6), (8, 8)}] [{(3, 4), (6, 6)}]
[{(30, 31), (15, 18)}, {(41, 42), (47, 48), (55, 57)}] [{(27, 28), (35, 37)}, {(42, 43), (45, 45)}]
[{(27, 28), (35, 37)}] [{(29, 30), (37, 39), 

In [8]:
correct, total = 0, 0
for tokens1, tokens2 in tqdm_notebook(islice(c.sent_pair_tokens(), 200)):
    
    d1 = Document(tokens1+tokens2)
    d2 = Document(tokens2+tokens1)
    
    try:
        
        pred1 = model.predict(d1)
        pred2 = model.predict(d2)
        
        if not pred1 or not pred2:
            continue
    
        spans1 = model(d1)
        spans2 = model(d2)
        
        if score(spans1) > score(spans2):
            correct += 1

        total += 1
        
    except:
        pass




In [10]:
correct / total

0.5526315789473685

In [12]:
d = Document.from_text('d1', 0, 'After Mr. Trump’s election, while the F.B.I. was investigating whether his campaign helped Russian efforts to put Mr. Trump in the Oval Office, Mr. Cohen visited his boss in the White House. During that February 2017 visit, Mr. Cohen left for the national security adviser, Michael Flynn, a plan to lift sanctions against Russia, which had been imposed for its attacks on Ukraine. These sanctions had squeezed the sorts of people Mr. Cohen dealt with. The plan was proposed by Mr. Sater and a Ukrainian politician with ties to Paul Manafort, a former Trump campaign chairman.')

In [21]:
model.eval()

Coref(
  (encode_doc): DocEncoder(
    (embeddings): WordEmbedding(39414, 300)
    (lstm): LSTM(300, 200, batch_first=True, bidirectional=True)
    (dropout): Dropout(p=0.5)
  )
  (score_spans): SpanScorer(
    (attention): Scorer(
      (score): Sequential(
        (0): Linear(in_features=400, out_features=150, bias=True)
        (1): ReLU()
        (2): Linear(in_features=150, out_features=150, bias=True)
        (3): ReLU()
        (4): Linear(in_features=150, out_features=1, bias=True)
      )
    )
    (width_embeddings): DistanceEmbedding(9, 20)
    (sm): Scorer(
      (score): Sequential(
        (0): Linear(in_features=1120, out_features=150, bias=True)
        (1): ReLU()
        (2): Linear(in_features=150, out_features=150, bias=True)
        (3): ReLU()
        (4): Linear(in_features=150, out_features=1, bias=True)
      )
    )
  )
  (score_pairs): PairScorer(
    (dist_embeddings): DistanceEmbedding(9, 20)
    (sa): Scorer(
      (score): Sequential(
        (0): Linear(

In [25]:
model.predict(d)

[{(1, 4), (14, 14), (28, 29), (31, 31), (44, 45), (82, 83), (92, 93)},
 {(1, 2), (21, 22)}]

In [27]:
for c in model.predict(d):
    for i1, i2 in c:
        print([t.text for t in d.tokens[i1:i2+1]])
    print('---')

['Mr.', 'Cohen']
['Mr.', 'Cohen']
['Mr.', 'Sater']
['Mr.', 'Trump', '’', 's']
['Mr.', 'Cohen']
['his']
['his']
---
['Mr.', 'Trump']
['Mr.', 'Trump']
---
