In [None]:
import itertools
import pathlib
import urllib.request

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

import pandas as pd

In [None]:
# Model files are cached in a "models" folder that sits beside the current
# working directory (i.e. one level up from where the notebook runs).
MODEL_DIR = pathlib.Path.cwd().parent / "models"

In [None]:
# Define the device to use, using a CUDA GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained tokenizer and model.
# The trailing index selects the checkpoint: 0 -> base, 1 -> large.
model_name = ['bert-base-uncased', 'bert-large-uncased'][1]
# Both downloads share MODEL_DIR as their cache so the (large) model weights
# end up next to the tokenizer files instead of in the default user cache.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=MODEL_DIR)
model = AutoModel.from_pretrained(model_name, cache_dir=MODEL_DIR).to(device)

In [None]:
# Download the sonnets (free for non-commercial use)
url = "https://flgr.sh/txtfssSontxt"
# Use the response as a context manager so the HTTP connection is closed
# as soon as the body has been read, even if decoding raises.
with urllib.request.urlopen(url) as response:
    # One decoded string per line of the downloaded text (line endings kept).
    document = [b.decode('UTF-8') for b in response.readlines()]

In [None]:
# Skip the leading run of non-blank lines (the header); everything from the
# first blank line onwards is kept.
without_header = list(itertools.dropwhile(lambda raw: raw.strip() != "", document))
# Trim surrounding whitespace (including the line ending) from every kept line.
cleaned = list(map(str.strip, without_header))

In [None]:
# Group the cleaned lines into sonnets with a small two-state scanner:
#   * between sonnets we wait for a purely numeric heading, then for the
#     first non-blank line that follows it;
#   * inside a sonnet every non-blank line is a verse and the first blank
#     line ends the sonnet.
# Result: `sonnets` maps sonnet number -> list of its lines, in order.
sonnet_number = None
sonnets = {}
in_between_sonnets = True

for line in cleaned:
    is_empty = not line

    if not in_between_sonnets:
        # Inside a sonnet body.
        if is_empty:
            in_between_sonnets = True
            sonnet_number = None
        else:
            sonnets[sonnet_number].append(line)
        continue

    # Between sonnets: blank lines carry no information.
    if is_empty:
        continue

    if line.isnumeric():
        # Heading: start collecting a new sonnet under this number.
        sonnet_number = int(line)
        sonnets[sonnet_number] = []
    elif sonnet_number is not None:
        # First verse after a heading switches to the "inside" state.
        in_between_sonnets = False
        sonnets[sonnet_number].append(line)
    # Otherwise: text seen before any heading is ignored.


In [None]:
def canonicalize(s):
    """Normalize a line of text before it is embedded.

    Every character that is neither alphabetic nor a plain space is dropped
    (punctuation, digits, tabs, newlines, ...); the remainder is lower-cased
    and stripped of leading/trailing whitespace. Interior runs of spaces are
    left as they are.
    """
    kept = filter(lambda ch: ch == ' ' or ch.isalpha(), s)
    return ''.join(kept).lower().strip()

In [None]:
def encode(strs):
    """Embed a batch of strings with the module-level tokenizer and model.

    Args:
        strs: the texts to embed, passed to the tokenizer as one padded,
            truncated batch.

    Returns:
        A NumPy array holding, for each input, the model's last-layer hidden
        state at token position 0.
    """
    # Background: the BERT paper (https://arxiv.org/pdf/1810.04805.pdf)
    # describes a leading [CLS] token and [SEP] separators. The original note
    # here said adding them made the scores worse, so none are added by hand.
    # NOTE(review): the tokenizer is called with its defaults, which
    # presumably insert [CLS]/[SEP] automatically (`add_special_tokens`) --
    # in that case position 0 below is the [CLS] vector. Confirm.
    with torch.no_grad():  # inference only, no autograd bookkeeping
        batch = tokenizer(strs, padding=True, truncation=True, return_tensors="pt")
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        hidden_states = model(**on_device).last_hidden_state
    first_token = hidden_states[:, 0, :]
    return first_token.detach().cpu().numpy()

In [None]:
# One row per sonnet line: where it came from (sonnet number, 1-based line
# number), the original text, and the embedding of its canonicalized form.
# Each line is encoded on its own, as a batch of one.
df = pd.DataFrame([
    dict(sonnet_number=number,
         line_number=position,
         text=verse,
         embeddings=encode([canonicalize(verse)])[0])
    for number, verses in sonnets.items()
    for position, verse in enumerate(verses, start=1)
])

In [None]:
# Stack the per-row vectors into a single 2-D matrix (one row per sonnet line)
# and show its shape as a sanity check.
embeddings = np.vstack(df["embeddings"].tolist())
print(embeddings.shape)

In [None]:
# Build the nearest-neighbour index over all line embeddings.
# `d` is the vector dimensionality, i.e. the width of the embedding matrix.
d = embeddings.shape[-1]
# Flat L2 index sized for d-dimensional vectors; row i of `embeddings`
# becomes entry i, which `search` later maps back to `df` by position.
index = faiss.IndexFlatL2(d)
index.add(embeddings)

In [None]:
def search(query, k=10):
    """Return the sonnet lines whose embeddings are closest to ``query``.

    The query goes through the same ``canonicalize`` + ``encode`` pipeline as
    the indexed lines and is then looked up in the module-level Faiss
    ``index``.

    Args:
        query: free-form text to look for.
        k: maximum number of neighbours to return (defaults to 10, the
            previously hard-coded value).

    Returns:
        A new DataFrame with the columns ``sonnet_number``, ``line_number``,
        ``text`` and ``distance``, ordered as returned by the index (nearest
        first). ``df`` itself is not modified.
        NOTE(review): for ``IndexFlatL2`` the reported distance is presumably
        the *squared* L2 distance -- confirm against the Faiss docs.
    """
    xq = encode([canonicalize(query)])
    D, I = index.search(xq, k=k)
    distances, positions = D[0], I[0]

    # When the index holds fewer than k vectors, Faiss pads the result with
    # label -1. `iloc[-1]` would silently return the *last* row, so drop them.
    found = positions >= 0

    result = df.iloc[positions[found]][['sonnet_number', 'line_number', 'text']]
    # `assign` builds a fresh frame; writing the column into the sliced frame
    # in place is what raised pandas' SettingWithCopyWarning before.
    return result.assign(distance=distances[found])

In [None]:
# Free-form query; as the last expression of the cell, the DataFrame returned
# by `search` is shown as the cell's output.
search("rough winds shake the flowers of spring")

In [None]:
# Find the indexed lines most similar to a fully punctuated, capitalized query.
# `search` canonicalizes its input, so the capitals and the trailing comma are
# removed before the query is embedded.
search("Rough winds do shake the darling buds of May,")