In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt

# Render matplotlib figures inline in the notebook output and apply the
# built-in 'bmh' style sheet to every subsequent plot.
%matplotlib inline
mpl.style.use('bmh')

In [83]:
import torch
import attr
import pandas as pd

from textblob import TextBlob
from cached_property import cached_property
from annoy import AnnoyIndex
from itertools import combinations
from tqdm import tqdm_notebook
from scipy.spatial.distance import cosine

from sent_order.models.kt_regression import Sentence, SentenceEncoder

In [3]:
# Load the trained sentence encoder checkpoint for CPU-only use.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# checkpoints that come from a trusted source.
# NOTE(review): if the encoder contains dropout / batch-norm layers it should
# probably be switched to eval mode before encoding — confirm against
# SentenceEncoder.
sent_encoder = torch.load(
    '../../data/models/new/kt-reg/sent_encoder.366.bin',
    # Keep every storage on the CPU no matter which CUDA device it was saved
    # from; a {'cuda:0': 'cpu'} mapping only covers device 0 and fails for
    # checkpoints written on any other GPU.
    map_location=lambda storage, loc: storage,
)

In [7]:
@attr.s
class Text:
    """A raw text document that can be split into sentences for encoding.

    The text is parsed with TextBlob lazily (on first access of `blob`) and
    each sentence can be converted to the input form expected by the
    sentence encoder via `sentence_variables`.
    """

    # The full, unprocessed text of the document.
    raw = attr.ib()

    @classmethod
    def from_path(cls, path, encoding='utf-8'):
        """Build a Text from the contents of a file.

        Args:
            path: Location of the text file to read.
            encoding: Text encoding of the file. Defaults to UTF-8 so that
                non-ASCII characters (e.g. curly quotes) decode the same way
                on every platform instead of depending on the locale default.

        Returns:
            A new instance whose `raw` is the whole file content.
        """
        with open(path, encoding=encoding) as fh:
            return cls(fh.read())

    @cached_property
    def blob(self):
        """TextBlob parse of `raw`; computed once and cached on the instance."""
        return TextBlob(self.raw)

    def sentence_variables(self):
        """Yield one encoder input per TextBlob sentence, in document order.

        Each sentence's tokens are wrapped in a project `Sentence` and
        converted with its `variable()` method.
        NOTE(review): presumably `variable()` returns a torch Variable/tensor
        for the sentence — confirm in sent_order.models.kt_regression.
        """
        for blob_sent in self.blob.sentences:
            tokens = list(blob_sent.tokens)
            yield Sentence(tokens).variable()

In [8]:
# Read the Great Gatsby text file into a Text wrapper (parsing happens lazily).
gg = Text.from_path('../../data/novels/great-gatsby.txt')

In [14]:
# Encode every Gatsby sentence in a single encoder call; later cells index the
# result by sentence position (gg_sents[i]).
# NOTE(review): a generator is passed here — assumes the encoder accepts any
# iterable of sentence variables; confirm against SentenceEncoder.
gg_sents = sent_encoder(gg.sentence_variables())

In [15]:
# Read the Sun Also Rises text file into a Text wrapper (parsing happens lazily).
sar = Text.from_path('../../data/novels/sun-also-rises.txt')

In [16]:
# Encode every Sun Also Rises sentence in a single encoder call; later cells
# index the result by sentence position (sar_sents[i]).
# NOTE(review): a generator is passed here — assumes the encoder accepts any
# iterable of sentence variables; confirm against SentenceEncoder.
sar_sents = sent_encoder(sar.sentence_variables())

In [53]:
# Approximate nearest-neighbor index over the Gatsby sentence embeddings.
# 1000 is the vector length passed to add_item for every sentence
# (assumes the encoder emits 1000-dimensional vectors — TODO confirm).
# The metric is spelled out explicitly: 'angular' is Annoy's historical
# default, and newer Annoy releases warn when it is omitted.
gg_idx = AnnoyIndex(1000, 'angular')

In [54]:
# Register each Gatsby sentence embedding in the index, keyed by its
# position in gg_sents so ids line up with gg.blob.sentences.
for sent_id, sent_vec in enumerate(gg_sents):
    gg_idx.add_item(sent_id, sent_vec.data.tolist())

In [142]:
# Finalize the index so it can be queried. Per the Annoy API the argument is
# the number of trees to build (10 here); verify against the installed
# Annoy version's docs.
gg_idx.build(10)

True

In [143]:
# Sanity check: query the index with Gatsby sentence 160's own vector and
# return the 10 closest item ids together with their distances.
gg_idx.get_nns_by_vector(
    gg_sents[160].data.tolist(),
    10,
    include_distances=True,
)

([160, 2540, 915, 2322, 2548, 1342, 1613, 1213, 378, 1507],
 [0.0,
  1.1178419589996338,
  1.1187474727630615,
  1.118809700012207,
  1.1280927658081055,
  1.1366569995880127,
  1.1381558179855347,
  1.138733983039856,
  1.143604040145874,
  1.1484378576278687])

In [77]:
# Display Gatsby sentence 160, the one used as the query vector in the
# nearest-neighbor lookup cell.
gg.blob.sentences[160]

Sentence("As if his absence quickened something within her, Daisy leaned forward again, her voice glowing and singing.")

In [69]:
# Display Gatsby sentence 610.
# NOTE(review): the recorded output shows two separate quoted utterances
# joined into a single Sentence, so TextBlob's segmentation looks unreliable
# around dialogue — worth keeping in mind when reading matches.
gg.blob.sentences[610]

Sentence("I was in the ninth machine-gun battalion.”

“I was in the Seventh Infantry until June nineteen-eighteen.")

In [89]:
# For every Sun Also Rises sentence, look up its 10 nearest Gatsby sentences
# and collect (sar_id, gg_id, distance) triples, skipping zero-distance hits.
matches = []
for sar_id in tqdm_notebook(range(len(sar_sents))):
    query_vec = sar_sents[sar_id].data.tolist()
    gg_ids, ds = gg_idx.get_nns_by_vector(query_vec, 10, include_distances=True)
    matches.extend(
        (sar_id, neighbor_id, dist)
        for neighbor_id, dist in zip(gg_ids, ds)
        if dist > 0
    )




In [90]:
# Tabulate the collected matches: one row per
# (Sun Also Rises sentence id, Gatsby sentence id, distance) triple.
df = pd.DataFrame(
    data=matches,
    columns=['sar_id', 'gg_id', 'd'],
)

In [136]:
# Show the 1000 closest cross-novel sentence pairs, smallest distance first.
df.sort_values(by='d', ascending=True).head(n=1000)

Unnamed: 0,sar_id,gg_id,d
30732,3102,1857,0.504626
43876,4417,318,0.590269
54133,5446,1685,0.601002
14204,1443,1956,0.602994
41486,4178,2571,0.608421
55383,5571,1991,0.617864
67023,6742,318,0.633026
31062,3135,1109,0.635736
64833,6520,2459,0.640016
74733,7513,198,0.642554


In [137]:
# Print the text of the 100 closest pairs side by side:
# Sun Also Rises sentence | Gatsby sentence, with a separator line after each.
closest_ids = df.sort_values('d').head(100)[['sar_id', 'gg_id']]
for sar_i, gg_i in closest_ids.itertuples(index=False, name=None):
    print(sar.blob.sentences[sar_i], '|', gg.blob.sentences[gg_i])
    print('---')

“Come on, Michael. | “Come on, Tom.
---
What do you think of that?”

“I don’t know.”

“That’s it. | What kind do you want, lady?”

“I’d like to get one of those police dogs; I don’t suppose you got that kind?”

The man peered doubtfully into the basket, plunged in his hand and drew one up, wriggling, by the back of the neck.
---
That was it all right. | That was it.
---
“Really?”

“No. | “What?”

“Want any?”

“No .
---
Do you care?”

“No,” Edna said. | Do you object to shaking hands with me?”

“Yes.
---
“It’s good. | “That’s good.
---
What do you think of that?”

“I don’t know.”

“That’s the way. | What kind do you want, lady?”

“I’d like to get one of those police dogs; I don’t suppose you got that kind?”

The man peered doubtfully into the basket, plunged in his hand and drew one up, wriggling, by the back of the neck.
---
They’re all right.”

“How did your friends like them?”

“Fine.”

“Good,” Montoya said. | They’re fine!” and he added hollowly, “.
---
He was fairly happy, except t