In [None]:
import pathlib
import urllib.request
import itertools

import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch

import matplotlib_inline
import matplotlib.pyplot as plt

from sklearn.neighbors import NearestNeighbors
from sklearn.manifold import TSNE
import networkx as nx

In [None]:
# Hugging Face download cache: a "models" folder next to the kernel's
# current working directory.
MODEL_DIR = pathlib.Path.cwd().parent / "models"

In [None]:
# Use a CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and model from the same checkpoint, and cache both in
# MODEL_DIR so every downloaded artifact for this notebook lives in one place.
MODEL_NAME = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_DIR)
model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=MODEL_DIR).to(device)
# This notebook only runs inference: make the evaluation mode (dropout off)
# explicit rather than relying on from_pretrained's default.
model.eval()

In [None]:
# Download the sonnets (free for non-commercial use).
# The response is used as a context manager so the HTTP connection is closed
# as soon as the body has been read; the timeout keeps a stalled server from
# hanging the kernel indefinitely.
url = "https://flgr.sh/txtfssSontxt"
with urllib.request.urlopen(url, timeout=60) as response:
    document = [raw_line.decode('UTF-8') for raw_line in response.readlines()]

In [None]:
# The file opens with a header block that ends at the first blank line:
# drop everything before that blank line (the blank line itself is kept).
without_header = list(itertools.dropwhile(lambda raw: bool(raw.strip()), document))
# Trim surrounding whitespace and line endings from every remaining line.
cleaned = [raw.strip() for raw in map(str, without_header)]

In [None]:
def parse_sonnets(lines):
    """Group stripped text lines into sonnets keyed by their number.

    A sonnet starts at a line that is purely numeric (its number). It consists
    of the non-blank lines that follow, up to the next blank line. Blank
    lines between sonnets are skipped. Non-numeric text found between
    sonnets, before the next sonnet number, is ignored.

    Args:
        lines: iterable of already-stripped text lines.

    Returns:
        dict mapping sonnet number (int) -> list of its verse lines (str).
        A number that is never followed by verse maps to an empty list.
    """
    parsed = {}
    current = None   # number of the sonnet currently being collected, if any
    between = True   # True while scanning for the start of the next sonnet
    for line in lines:
        if between:
            if not line:
                continue
            if line.isnumeric():
                current = int(line)
                parsed[current] = []
            elif current is not None:
                # First verse line after a sonnet number: start collecting.
                between = False
                parsed[current].append(line)
            # Otherwise: stray text with no preceding number -- keep waiting.
        elif not line:
            # A blank line terminates the current sonnet.
            between = True
            current = None
        else:
            parsed[current].append(line)
    return parsed


sonnets = parse_sonnets(cleaned)


In [None]:
# One document per sonnet, in ascending sonnet-number order: a "Sonnet N"
# heading line followed by the verse lines, CRLF-joined and lower-cased.
sentences = [
    "Sonnet {}\r\n{}".format(number, "\r\n".join(sonnets[number])).lower()
    for number in sorted(sonnets)
]
print(sentences[17])  # 0-based: the 18th sonnet in sorted order

In [None]:
def encode(strs):
    """Embed text with the loaded model, one vector per input.

    Args:
        strs: a single string or a list of strings. A list is padded to its
            longest member, and inputs are truncated by the tokenizer.

    Returns:
        numpy array of shape (batch, hidden) holding the final-layer hidden
        state at the first token position of each input (BERT's [CLS] slot).
    """
    # Inference only: without no_grad() autograd records the whole forward
    # graph, which inflates memory use and is never needed here.
    with torch.no_grad():
        encoded_input = tokenizer(strs, padding=True, truncation=True, return_tensors="pt")
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        model_output = model(**encoded_input)
    return model_output.last_hidden_state[:, 0, :].detach().cpu().numpy()

In [None]:
# Embed each sonnet with its own encode() call instead of one large padded
# batch, keeping peak memory low; the single-row results are stacked so that
# row i corresponds to sentences[i].
sentence_embeddings = np.concatenate([encode(text) for text in sentences], axis=0)
sentence_embeddings.shape

In [None]:
# Exact (brute-force) index over Euclidean (L2) distance, sized to the
# embedding width (number of columns).
d = np.shape(sentence_embeddings)[1]
index = faiss.IndexFlatL2(d)

In [None]:
# IndexFlatL2 accumulates vectors on every add(); reset first so re-running
# this cell does not insert duplicate copies of every sonnet (which would make
# the k-NN search return the same sonnet several times).
index.reset()
index.add(sentence_embeddings)
assert index.ntotal == sentence_embeddings.shape[0]

In [None]:
# Retrieve the k sonnets closest (in L2 distance) to a query line.
k = 3
query_text = "profitless usurer why dost thou use so great a sum of sums yet canst not live"
xq = encode([query_text])
D, I = index.search(xq, k)
print(D, I)

# I[0] holds the result ids for the single query; each id is a row of sentences.
for hit_id in I[0]:
    print(f"** SENTENCE={hit_id}:", sentences[hit_id], "\r\n")

# Plot the embedding space
Each sonnet embedding is 768-dimensional, so we reduce the embeddings to two dimensions with t-SNE before plotting them.

In [None]:
# Theme labels are embedded alongside the sonnets so they land in the same
# t-SNE map. A finer-grained alternative label set to try:
#   "youth", "old age", "death", "decay", "time", "poetry", "arts",
#   "children", "parenthood", "man", "woman", "anger", "jealousy"
labels = [
    "time passing, youth, old age, death and decay",
]
label_embeddings = encode(labels)

# t-SNE is stochastic; a fixed seed makes Restart & Run All reproduce the map.
TSNE_SEED = 42
n_sonnet_rows = sentence_embeddings.shape[0]
all_embeddings = np.vstack([sentence_embeddings, label_embeddings])
X_tsne = TSNE(n_components=2, random_state=TSNE_SEED).fit_transform(all_embeddings)
# Split the projection back into sonnet rows and label rows.
S_tsne = X_tsne[:n_sonnet_rows]
L_tsne = X_tsne[n_sonnet_rows:]

fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(S_tsne[:, 0], S_tsne[:, 1], label="sonnets")
ax.scatter(L_tsne[:, 0], L_tsne[:, 1], c="red", label="theme labels")
for i, label in enumerate(labels):
    ax.annotate(label, (L_tsne[i, 0], L_tsne[i, 1]))
ax.set(title="t-SNE projection of sonnet and label embeddings",
       xlabel="t-SNE dimension 1", ylabel="t-SNE dimension 2")
ax.legend()

plt.show()

In [None]:
# k-NN over the raw embeddings (not the t-SNE projection). Each row's first
# neighbour is normally the point itself at distance 0, so n_neighbors=10
# yields nine other sonnets per row.
nbrs = NearestNeighbors(n_neighbors=10, algorithm='ball_tree')
nbrs.fit(sentence_embeddings)
distances, indices = nbrs.kneighbors(sentence_embeddings)

print(distances[:3, :3])
print(indices[:3, :3])

In [None]:
# Undirected k-NN similarity graph: one node per sonnet (labelled with the
# start of its text) and an edge from each sonnet to each of its neighbours.
G = nx.Graph()
for i, sentence in enumerate(sentences):
    G.add_node(i, label=sentence[:20])

# Floor on the squared distance so identical embeddings do not produce an
# infinite weight (which would break the force-directed layout).
MIN_SQ_DIST = 1e-12

for i in range(len(sentences)):
    # Row i of indices/distances lists sonnet i's neighbours, nearest first.
    # The nearest is normally i itself, but with duplicate embeddings another
    # sonnet can come first, so skip self by identity rather than by position.
    for nbr, dist in zip(indices[i], distances[i]):
        if nbr == i:
            continue
        # 'weight' is a similarity (inverse squared distance); 'length' keeps
        # the raw distance for layouts that expect edge lengths.
        G.add_edge(i, nbr, weight=1.0 / max(dist * dist, MIN_SQ_DIST), length=dist)


In [None]:

# Two layouts of the similarity graph G, side by side.
fig, (ax_spring, ax_kk) = plt.subplots(1, 2, figsize=(16, 8))

# spring_layout: 'weight' is an attraction strength (larger = pulled closer),
# so the inverse-square similarity stored in 'weight' is the right attribute.
# A fixed seed makes the layout reproducible across runs.
pos_spring = nx.spring_layout(G, weight='weight', k=0.1, iterations=50, seed=42)
nx.draw_networkx(G, pos_spring, ax=ax_spring, node_size=10, font_size=10,
                 width=0.1, alpha=0.5, with_labels=True)
ax_spring.set_title("Spring layout (attraction = 1 / distance^2)")

# kamada_kawai_layout: 'weight' is used as an edge *length* when computing
# shortest-path distances, so pass the raw embedding distance ('length').
# Using the inverse-square 'weight' would push the most similar sonnets apart.
pos = nx.kamada_kawai_layout(G, weight='length')
nx.draw_networkx(G, pos, ax=ax_kk, node_size=10, font_size=10,
                 width=0.1, alpha=0.5, with_labels=True)
ax_kk.set_title("Kamada-Kawai layout (edge length = embedding distance)")

plt.show()



In [None]:
print(indices[:3][:, :3])  # neighbour ids (first three) of the first three sonnets

In [None]:
# k-NN graph whose edges carry the raw embedding distance as 'dist', for use
# as an edge length in the Kamada-Kawai layout.
# Row r of indices/distances describes sonnet r (kneighbors was queried in
# order), so iterate over row numbers rather than over the first column,
# and skip the self-match explicitly instead of assuming it sits in column 0.
G = nx.Graph()
G.add_nodes_from(range(len(indices)))
for u in range(len(indices)):
    for v, dist in zip(indices[u], distances[u]):
        if v == u:
            continue  # no self-loops
        G.add_edge(u, int(v), dist=dist)

In [None]:
# Kamada-Kawai needs a target distance for *every* pair of nodes. A dict built
# from the edge list alone covers only adjacent pairs, and only in one
# direction (dx[u][v] but not dx[v][u]); networkx then fills the missing pairs
# with a very large default, distorting the layout. Use all-pairs shortest-path
# lengths over the embedding distances instead.
dx = dict(nx.shortest_path_length(G, weight='dist'))

pos = nx.kamada_kawai_layout(G, dist=dx)
fig, ax = plt.subplots(figsize=(10, 10))
nx.draw_networkx(G, pos, ax=ax, node_size=10, font_size=10, width=0.1, alpha=0.5, with_labels=True)
ax.set_title("Kamada-Kawai layout (shortest-path embedding distances)")
plt.show()