In [None]:
import itertools
import pathlib
import urllib.request

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

import pandas as pd

import networkx as nx
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearn.manifold import TSNE

In [None]:
# Model cache directory: a "models" folder next to the current working directory.
MODEL_DIR = pathlib.Path.cwd().parent / "models"

In [None]:
# Run on a CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Candidate pre-trained checkpoints; the trailing index selects the one to load.
model_name = ['bert-base-uncased',
              'bert-large-uncased',
              'facebook/bart-large-mnli'
              ][1]

# Load the tokenizer and the model through the same cache directory so that
# both downloads end up under MODEL_DIR (previously only the tokenizer did).
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=MODEL_DIR)
model = AutoModel.from_pretrained(model_name, cache_dir=MODEL_DIR).to(device)

In [None]:
# Download the sonnets (free for non-commercial use)
url = "https://flgr.sh/txtfssSontxt"
# The context manager closes the HTTP response as soon as all lines are read,
# instead of leaving the connection open until garbage collection.
with urllib.request.urlopen(url) as response:
    document = [b.decode('UTF-8') for b in response.readlines()]

In [None]:
# Drop the leading run of non-blank lines (the header); the result starts at
# the first blank line of the document.
without_header = list(itertools.dropwhile(lambda raw: bool(raw.strip()), document))
# Remove surrounding whitespace (including the trailing newline) from each line.
cleaned = [raw.strip() for raw in without_header]

In [None]:
# Parse the cleaned lines into {sonnet number: [lines of that sonnet]}.
# A tiny two-state machine: we are either between sonnets (looking for a
# numeric heading and then the first text line) or inside a sonnet body
# (collecting lines until the next blank line).
sonnets = {}
sonnet_number = None
in_between_sonnets = True

for line in cleaned:
    if not in_between_sonnets:
        # Inside a sonnet: a blank line ends it, anything else is a verse.
        if line:
            sonnets[sonnet_number].append(line)
        else:
            in_between_sonnets = True
            sonnet_number = None
        continue

    # Between sonnets: blank lines are skipped.
    if not line:
        continue

    if line.isnumeric():
        # A numeric line is the heading of the next sonnet.
        sonnet_number = int(line)
        sonnets[sonnet_number] = []
    elif sonnet_number is not None:
        # First text line after a heading: the sonnet body starts here.
        in_between_sonnets = False
        sonnets[sonnet_number].append(line)
    # Any other text seen before a heading is ignored.


In [None]:
def canonicalize(s):
    """Normalize a line of text before it is embedded.

    Keeps only alphabetic characters and the space character, lower-cases
    the result and strips leading/trailing whitespace. Everything else
    (punctuation, digits, tabs, ...) is removed without a replacement, so
    ``"summer's"`` becomes ``"summers"``.
    """
    kept = (ch for ch in s if ch == ' ' or ch.isalpha())
    return ''.join(kept).lower().strip()

In [None]:
def encode(strs):
    """Embed a list of strings with the globally loaded tokenizer and model.

    The strings are tokenized as one padded, truncated batch, moved to
    ``device`` and run through ``model`` with gradients disabled. For each
    input, the hidden state at token position 0 of the last layer is
    returned, as a NumPy array on the CPU (indexed per input string by the
    callers in this notebook).

    Background: the BERT paper (https://arxiv.org/pdf/1810.04805.pdf)
    describes prepending a [CLS] token and separating sentences with [SEP].
    The original author noted that adding these explicitly seemed to make
    the scores worse, so no such tokens are added to the text here.
    NOTE(review): this call does not disable the tokenizer's own special
    tokens, so position 0 is presumably the [CLS] token the tokenizer
    inserts by itself - confirm against the tokenizer defaults.
    """
    with torch.no_grad():
        batch = tokenizer(strs, padding=True, truncation=True, return_tensors="pt")
        batch_on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        output = model(**batch_on_device)
    first_token_states = output.last_hidden_state[:, 0, :]
    return first_token_states.detach().cpu().numpy()

In [None]:
# One row per sonnet line: its sonnet, its 1-based position inside the sonnet,
# the original text, and the embedding of the canonicalized text.
# Each line is encoded on its own (batch of one).
df = pd.DataFrame(
    [
        {
            'sonnet_number': number,
            'line_number': position,
            'text': verse,
            'embeddings': encode([canonicalize(verse)])[0],
        }
        for number, verses in sonnets.items()
        for position, verse in enumerate(verses, start=1)
    ]
)

In [None]:
# Stack the per-line embedding vectors into one 2-D matrix (one row per line).
embeddings = np.vstack(df['embeddings'].to_numpy())
print(embeddings.shape)

In [None]:
# Build a flat (exhaustive) L2 index over all line embeddings.
# d is the embedding dimensionality, i.e. the number of matrix columns.
_, d = embeddings.shape
index = faiss.IndexFlatL2(d)
index.add(embeddings)

In [None]:
def search(query, k=10):
    """Return the ``k`` indexed sonnet lines closest to ``query``.

    The query is canonicalized and embedded the same way as the indexed
    lines, then looked up in the faiss index.

    Parameters
    ----------
    query : str
        Free text to search for.
    k : int, default 10
        Number of nearest neighbours to request from the index.

    Returns
    -------
    pandas.DataFrame
        Columns ``sonnet_number``, ``line_number``, ``text`` and
        ``distance`` (the distance reported by the index), in the order
        returned by the index. Has fewer than ``k`` rows when the index
        holds fewer than ``k`` vectors.
    """
    xq = encode([canonicalize(query)])
    distances, indices = index.search(xq, k=k)
    # faiss pads the result with label -1 when it finds fewer than k
    # neighbours; passing -1 to iloc would silently select the last row.
    found = indices[0] >= 0
    hits = df.iloc[indices[0][found]][['sonnet_number', 'line_number', 'text']]
    # assign() returns a new frame, so no column is written into a slice of
    # df (the original in-place assignment could raise SettingWithCopyWarning).
    return hits.assign(distance=distances[0][found])

In [None]:
# Free-text query: show the indexed sonnet lines closest to this sentence.
search("rough winds shake the flowers of spring")

In [None]:
# Find the indexed lines most similar to this one. The capital letter and the
# trailing comma do not matter: search() canonicalizes the query first.
search("Rough winds do shake the darling buds of May,")

# Clustering
Let's cluster the line embeddings with k-means and see whether the resulting groups correspond to recognizable topics.

In [None]:
# k-means settings. n_centroids is also read later, when one graph node is
# created per cluster.
n_centroids = 10
n_iter = 100
verbose = True

# Fit k-means on the line embeddings (d is the embedding dimensionality).
k_means = faiss.Kmeans(d, n_centroids, niter=n_iter, verbose=verbose)
k_means.train(embeddings)
# Assign every embedding to a centroid. Downstream code uses element [1] of
# the result as the per-line cluster id; presumably the result is a
# (distances, cluster ids) pair - confirm against the faiss version in use.
assignments = k_means.assign(embeddings)


In [None]:
# Copy of the line table with an extra column holding each line's cluster id
# (element [1] of the k-means assignment result). df itself is not modified.
df_clusters = df.assign(cluster=assignments[1])

In [None]:
# Inspect the lines that were assigned to cluster 0.
df_clusters.loc[df_clusters['cluster'] == 0]

In [None]:
# Number of distinct sonnets present in the clustered table.
n_sonnets = df_clusters['sonnet_number'].nunique()

# Graph linking sonnets to the clusters their lines fall into.
# See docs: https://networkx.org/documentation/stable/auto_examples/drawing/plot_multipartite_graph.html#sphx-glr-auto-examples-drawing-plot-multipartite-graph-py
G = nx.Graph()

def cluster_node(i):
    """Return the graph node id for cluster ``i``: the letter C followed by ``i``."""
    node_id = f'C{i}'
    return node_id

def sonnet_node(sn):
    """Return the graph node id for sonnet ``sn``: the letter S followed by ``sn``."""
    node_id = f'S{sn}'
    return node_id

# One node per distinct sonnet (layer 1) ...
G.add_nodes_from(
    (sonnet_node(sn), {'layer': 1, 'type': 'sonnet', 'label': f'Sonnet {sn}'})
    for sn in set(df_clusters.sonnet_number.values)
)

# ... and one node per cluster (layer 0).
G.add_nodes_from(
    (cluster_node(c), {'layer': 0, 'type': 'cluster', 'label': f'Cluster {c}'})
    for c in range(n_centroids)
)

# Connect every sonnet to each cluster that contains at least one of its
# lines. Repeated (cluster, sonnet) pairs simply re-set the same edge.
G.add_edges_from(
    (
        (cluster_node(cluster_id), sonnet_node(number))
        for cluster_id, number in zip(df_clusters.cluster, df_clusters.sonnet_number)
    ),
    weight=10,  # weight is the attractive force
)

In [None]:
# Force-directed layout; edge weights act as the attractive force.
# spring_layout starts from random positions, so the seed is pinned to get the
# same picture on every run.
pos = nx.spring_layout(G, weight='weight', seed=42)

# Colour nodes by kind: sonnets red, clusters blue.
type_to_col = {'sonnet': 'red', 'cluster': 'blue'}
# The attribute dict is named attrs so that it does not read like the global
# embedding dimension d.
cols = [type_to_col[attrs['type']] for _, attrs in G.nodes(data=True)]
nx.draw_networkx(G, pos=pos, node_color=cols, with_labels=True)

In [None]:
# Show the lines of cluster 0 again, next to the graph.
df_clusters.loc[df_clusters['cluster'] == 0]

# Graphing by mean embedding

In [None]:
# One row per sonnet with the mean of its line embeddings.
sonnet_numbers = df.sonnet_number.unique()
df_means = pd.DataFrame({'sonnet_number': sonnet_numbers})
# Stack each sonnet's line vectors into a 2-D matrix and average over the
# rows explicitly (axis=0), rather than relying on np.mean reducing an
# object-dtype array of arrays.
df_means['mean_embedding'] = [
    np.vstack(df.loc[df['sonnet_number'] == sn, 'embeddings'].values).mean(axis=0)
    for sn in sonnet_numbers
]

df_means

In [None]:
# Number of sonnets = number of rows in the per-sonnet table.
n_sonnets = len(df_means)

# Stack the per-sonnet mean vectors into one matrix (one row per sonnet).
mean_embeddings = np.vstack(df_means['mean_embedding'].values)
mean_embeddings.shape


In [None]:
# Project the per-sonnet mean embeddings to 2-D. t-SNE is stochastic, so the
# random state is pinned to make the figure reproducible across runs.
X_tsne = TSNE(n_components=2, random_state=42).fit_transform(mean_embeddings)

plt.scatter(X_tsne[:, 0], X_tsne[:, 1])

# Label every point with its sonnet number; rows of X_tsne follow the row
# order of df_means.
for i, sn in enumerate(df_means.sonnet_number.values):
    plt.annotate(f'S {sn}', (X_tsne[i, 0], X_tsne[i, 1]))

plt.title('t-SNE of mean sonnet embeddings')
plt.show()