In [None]:
import glob
import json
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE

from redbox.models.file import Chunk, File

# Sentence-embedding model used to vectorise each chunk's text.
# NOTE(review): the sentence-transformers model card marks the bert-base-nli-*
# models as deprecated (low-quality embeddings) — consider a newer model.
EMBEDDING_MODEL_NAME = "bert-base-nli-mean-tokens"
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

In [None]:
# Directories holding the serialised File and Chunk JSON documents.
FILE_DIR = "../data/dev/file/"
CHUNK_DIR = "../data/dev/chunks/"

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# that selecting a file by position (file_paths[index]) is reproducible.
file_names = sorted(os.listdir(FILE_DIR))
file_paths = [os.path.join(FILE_DIR, file_name) for file_name in file_names]

chunk_names = sorted(os.listdir(CHUNK_DIR))
chunk_paths = [os.path.join(CHUNK_DIR, chunk_name) for chunk_name in chunk_names]

In [None]:
# Position of the file to inspect within `file_paths`.
index = 2
# JSON is UTF-8 by specification; don't rely on the platform default encoding.
with open(file_paths[index], "r", encoding="utf-8") as f:
    file = File(**json.load(f))

file_basename = os.path.basename(file.name)

# glob.escape: the basename may contain glob metacharacters ("[", "?", "*").
child_chunks = glob.glob(f"../data/dev/chunks/{glob.escape(file_basename)}.*.json")
chunks = []

for child_chunk in child_chunks:
    with open(child_chunk, encoding="utf-8") as f:
        chunk = Chunk(**json.load(f))
        chunks.append(chunk)

# glob() returns matches in filesystem-dependent order; sort by the chunk's own
# index so every downstream cell (embeddings, plot lines) sees document order.
chunks.sort(key=lambda c: c.index)

assert chunks, f"no chunk files found for {file_basename!r}"

In [None]:
chunk_texts = [chunk.text for chunk in chunks]

# The pool starts separate worker processes; shut it down even if encoding
# fails, otherwise the workers outlive this cell.
pool = model.start_multi_process_pool()
try:
    chunk_vectors = model.encode_multi_process(chunk_texts, pool=pool)
finally:
    model.stop_multi_process_pool(pool)

In [None]:
# Sanity check: the encoder should return exactly one embedding row per chunk.
assert chunk_vectors.shape[0] == len(chunks), "expected one embedding per chunk"
chunk_vectors.shape

In [None]:
# Project the chunk embeddings to 2-D. t-SNE is stochastic, so pin the seed to
# keep the plot identical across Restart & Run All.
TSNE_SEED = 42
tsne = TSNE(
    n_components=2,
    # scikit-learn requires perplexity < n_samples; the default (30) raises
    # ValueError for documents with 30 or fewer chunks.
    perplexity=min(30.0, len(chunk_vectors) - 1),
    random_state=TSNE_SEED,
    verbose=1,
    # NOTE(review): `n_iter` is renamed to `max_iter` in newer scikit-learn
    # releases — update this keyword if the installed version rejects it.
    n_iter=300,
)

X = tsne.fit_transform(chunk_vectors)

In [None]:
# 2-D t-SNE coordinates alongside each chunk's position in the document.
df = pd.DataFrame(X, columns=["x", "y"]).assign(
    chunk_index=[c.index for c in chunks]
)

In [None]:
%matplotlib widget

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
ax.scatter(df["x"], df["y"], df["chunk_index"], c=df["chunk_index"], cmap="inferno")

# add lines to show sequence of chunks
for i in range(len(df) - 1):
    color_mapped = sns.color_palette("inferno", len(df["chunk_index"]))[
        int(df["chunk_index"][i])
    ]
    ax.plot(
        [df["x"][i], df["x"][i + 1]],
        [df["y"][i], df["y"][i + 1]],
        [df["chunk_index"][i], df["chunk_index"][i + 1]],
        c=color_mapped,
    )

plt.show()