In [1]:
%env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [2]:
import itertools
import math
import numpy
import tenseal

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

In [3]:
EMBEDDING_SIZE = 384

In [4]:
context = tenseal.context(
    tenseal.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60]
)

context.generate_galois_keys()
context.global_scale = 2**40

In [5]:
dataset = load_dataset("imdb")
dataset = dataset["unsupervised"][:100]["text"]

len(dataset)

100

In [6]:
sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
embedded = sentence_transformer.encode(dataset)

embedded.shape

(100, 384)

In [7]:
n_clusters = math.ceil(math.sqrt(len(embedded)))
index = KMeans(n_clusters=n_clusters, n_init="auto").fit(embedded)

centroids = index.cluster_centers_
centroids.shape

(10, 384)

In [8]:
matrix = [[] for _ in range(n_clusters)]

for cluster, embedding in zip(index.labels_.tolist(), embedded.tolist()):
    matrix[cluster].append(embedding)

filler = numpy.ones(EMBEDDING_SIZE) * 0
max_size = max([len(val) for val in matrix])

[len(col) for col in matrix]

[6, 19, 11, 14, 14, 9, 4, 3, 3, 17]

In [9]:
for cluster, embedding in enumerate(matrix):
    cluster_size = len(embedding)

    matrix[cluster].extend([filler] * (max_size - cluster_size))

matrix = numpy.array(matrix)
matrix.shape

(10, 19, 384)

In [10]:
query = "sci-fi"
query = sentence_transformer.encode(query)

query.shape

(384,)

In [11]:
cluster = numpy.argmax(centroids @ query)

vector = numpy.zeros(n_clusters)
vector[cluster] = 1

secure = tenseal.ckks_vector(context, vector.tolist())
secure

<tenseal.tensors.ckksvector.CKKSVector at 0x10622fcd0>

In [12]:
result = secure.matmul(
    matrix.reshape(n_clusters, max_size * EMBEDDING_SIZE).tolist()
)

result

<tenseal.tensors.ckksvector.CKKSVector at 0x1061d8850>

In [13]:
result = result.decrypt()
result = numpy.array(result).reshape(max_size, EMBEDDING_SIZE)

result.shape

(19, 384)

In [14]:
query @ result.T

array([ 2.38827589e-01,  2.06322503e-01,  2.33940425e-01,  2.47272666e-01,
        1.44581811e-01,  2.88038064e-01, -1.01220650e-09,  1.66964467e-09,
       -6.86774740e-10,  1.08107844e-09,  1.01582327e-10,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00])