In [1]:
%env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [22]:
import itertools
import math
import numpy

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

In [23]:
EMBEDDING_SIZE = 384

In [24]:
dataset = load_dataset("imdb")
dataset = dataset["unsupervised"][:100]["text"]

len(dataset)

100

In [25]:
sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
embedded = sentence_transformer.encode(dataset)

embedded.shape

(100, 384)

In [26]:
n_clusters = math.ceil(math.sqrt(len(embedded)))
index = KMeans(n_clusters=n_clusters, n_init="auto").fit(embedded)

centroids = index.cluster_centers_
centroids.shape

(10, 384)

In [146]:
matrix = [[] for _ in range(n_clusters)]

for cluster, embedding in zip(index.labels_.tolist(), embedded.tolist()):
    matrix[cluster].append(embedding)

filler = numpy.ones(EMBEDDING_SIZE) * -10_000
max_size = max([len(val) for val in matrix])

[len(col) for col in matrix]

[6, 8, 11, 7, 20, 14, 16, 6, 9, 3]

In [147]:
for cluster, embedding in enumerate(matrix):
    cluster_size = len(embedding)

    matrix[cluster].extend([filler] * (max_size - cluster_size))

matrix = numpy.array(matrix)
matrix.shape

(10, 20, 384)

In [148]:
query = "sci-fi"
query = sentence_transformer.encode(query)

query.shape

(384,)

In [149]:
cluster = numpy.argmax(centroids @ query)

vector = numpy.zeros(n_clusters)
vector[cluster] = 1

vector

array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

In [150]:
result = vector @ matrix.reshape(n_clusters, max_size * EMBEDDING_SIZE)
result = result.reshape(max_size, EMBEDDING_SIZE)

result.shape

(20, 384)

In [151]:
query @ result.T

array([ 3.07303688e-01,  2.42736120e-01,  2.07748113e-01,  2.95851877e-01,
        2.98658304e-01,  1.79303235e-01,  2.83564630e-01, -3.36159450e+03,
       -3.36159450e+03, -3.36159450e+03, -3.36159450e+03, -3.36159450e+03,
       -3.36159450e+03, -3.36159450e+03, -3.36159450e+03, -3.36159450e+03,
       -3.36159450e+03, -3.36159450e+03, -3.36159450e+03, -3.36159450e+03])