In [None]:
# Set before the imports cell below: turn off the HuggingFace `tokenizers`
# library's internal parallelism (presumably to silence its fork/parallelism
# warning when the sentence-transformer runs — confirm).
%env TOKENIZERS_PARALLELISM=false

In [None]:
import math
import numpy
import tenseal

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

In [None]:
# Width of every sentence embedding. It must equal the output dimension of the
# sentence-transformer loaded below (all-MiniLM-L6-v2), because it sizes the
# zero padding rows and the reshape of the decrypted result.
EMBEDDING_SIZE: int = 384

### SETUP ENCRYPTION CONTEXT

In [None]:
# CKKS context (holding the secret key) used by the client to encrypt its
# cluster selector and later decrypt the server's answer.
# The scale 2**40 lines up with the 40-bit middle primes of the modulus chain.
# Galois keys are presumably needed for the rotations inside the encrypted
# vector-matrix product further down.
# NOTE(review): a CKKS ciphertext packs poly_modulus_degree / 2 = 4096 values,
# which caps the size of the encrypted query result — confirm it is enough.
context = tenseal.context(
    tenseal.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
context.global_scale = 2**40
context.generate_galois_keys()

### EMBED DATASET

In [None]:
# First 100 raw reviews of the IMDB "unsupervised" split. Slicing the split with
# [:100] yields its rows column-wise, so ["text"] is the list of review strings.
dataset = load_dataset("imdb")["unsupervised"][:100]["text"]

len(dataset)

In [None]:
# Encode every review with a small sentence-embedding model: one row per review.
sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
embedded = sentence_transformer.encode(dataset)

# EMBEDDING_SIZE is hard-coded earlier and reused for the zero padding and for
# reshaping the decrypted result, so fail fast here if the model disagrees.
assert embedded.shape == (len(dataset), EMBEDDING_SIZE), (
    f"expected ({len(dataset)}, {EMBEDDING_SIZE}) embeddings, got {embedded.shape}"
)

embedded.shape

### CLUSTER EMBEDDINGS

In [None]:
# SERVER
# Partition the N document embeddings into ceil(sqrt(N)) clusters.
# A fixed random_state makes the clustering, and every cell derived from it
# (padding, selector, decrypted result), reproducible on Restart & Run All.
n_clusters = math.ceil(math.sqrt(len(embedded)))
index = KMeans(n_clusters=n_clusters, n_init="auto", random_state=0).fit(embedded)

In [None]:
# SEND TO CLIENT
# The cluster centroids (one row per cluster) are the part of the index the
# client receives; the client cell below uses them to choose which cluster to fetch.
centroids = index.cluster_centers_
centroids.shape

### CREATE MATRIX

In [None]:
# SERVER
# Group the document embeddings by KMeans label: matrix[c] is the list of
# embeddings (as plain lists) assigned to cluster c.
matrix = [[] for _ in range(n_clusters)]

for cluster, embedding in zip(index.labels_.tolist(), embedded.tolist()):
    matrix[cluster].append(embedding)

# Clusters differ in size. The next cell pads shorter ones with all-zero rows up
# to the largest cluster, so every cluster occupies the same number of rows.
filler = numpy.zeros(EMBEDDING_SIZE)
max_size = max(len(col) for col in matrix)

[len(col) for col in matrix]

In [None]:
# SERVER
# Pad every cluster with zero rows up to `max_size` so the clusters stack into a
# rectangular array of shape (n_clusters, max_size, EMBEDDING_SIZE).
# New lists are built instead of `.extend`-ing in place. `list(rows)` also accepts
# the already-converted ndarray, so re-running this cell leaves `matrix` unchanged.
matrix = numpy.array(
    [list(rows) + [filler] * (max_size - len(rows)) for rows in matrix]
)
matrix.shape

### QUERY

In [None]:
# Embed the free-text search query with the same model used for the reviews.
query = sentence_transformer.encode("sci-fi")

query.shape

In [None]:
# CLIENT
# Route the query to the cluster KMeans would assign it to: the nearest centroid
# by Euclidean distance, which is the metric KMeans used to group the documents.
# Taking the argmax of raw dot products instead favours centroids with a larger norm.
cluster = int(numpy.argmin(numpy.linalg.norm(centroids - query, axis=1)))

# One-hot selector sized from the centroids the client actually received, so the
# client does not depend on the server-side `n_clusters`. The server cell below
# only ever sees it encrypted.
vector = numpy.zeros(len(centroids))
vector[cluster] = 1

secure = tenseal.ckks_vector(context, vector.tolist())
secure

In [None]:
# SERVER
# Flatten each cluster's padded (max_size, EMBEDDING_SIZE) block into one row, so
# the encrypted one-hot selector picks out a single cluster via a vector-matrix product.
flat_size = max_size * EMBEDDING_SIZE

# NOTE(review): a CKKSVector is a single ciphertext holding at most
# poly_modulus_degree / 2 values (8192 / 2 for the context created above), so the
# selected block must fit. Keep this bound in sync with the context parameters.
ckks_slot_count = 8192 // 2
assert flat_size <= ckks_slot_count, (
    f"a cluster block of {flat_size} values exceeds the {ckks_slot_count} CKKS slots; "
    "raise poly_modulus_degree or use smaller clusters"
)

result = secure.matmul(matrix.reshape(n_clusters, flat_size).tolist())

result

In [None]:
# CLIENT
# Decrypt the selected cluster block and fold it back into one embedding per row.
# The row count is derived from the decrypted length, so the client does not need
# the server-side `max_size`. Rows beyond the cluster's real size are zero padding.
result = numpy.array(result.decrypt()).reshape(-1, EMBEDDING_SIZE)

result.shape

In [None]:
# Score the query against every row of the decrypted cluster block (dot product).
# NOTE(review): the zero-padding rows score approximately 0 (CKKS decryption is
# approximate) and should be ignored when ranking results.
query @ result.T