In [44]:
import time
import faiss
import numpy as np
import sys


class ProductQuantizer:
    """Convenience wrapper around ``faiss.ProductQuantizer``.

    A single ``subquantizer_size`` value configures the quantizer: it is the
    number of components in each sub-vector, and its base-2 logarithm
    (rounded down) is the number of bits stored per sub-vector code.

    Attributes:
        vector_size: Length of the vectors handed to the quantizer.
        num_subquantizers: ``vector_size // subquantizer_size``.
        num_bits: Bits per sub-vector code, ``floor(log2(subquantizer_size))``.
        num_centroids: ``2 ** num_bits``; equal to ``subquantizer_size``
            whenever that value is a power of two.
        pq: The underlying ``faiss.ProductQuantizer`` instance.
    """

    def __init__(self, vector_size, subquantizer_size):
        """Build the underlying faiss quantizer.

        Args:
            vector_size (int): length of each input vector.
            subquantizer_size (int): number of components per sub-vector;
                must divide ``vector_size`` exactly.

        Raises:
            ValueError: if either size is not positive, or if
                ``vector_size`` is not a multiple of ``subquantizer_size``.
        """
        if vector_size <= 0 or subquantizer_size <= 0:
            raise ValueError(
                "vector_size and subquantizer_size must be positive, got "
                "{} and {}".format(vector_size, subquantizer_size)
            )
        # Reject sizes that do not divide evenly instead of silently
        # truncating to a different sub-vector layout than requested.
        if vector_size % subquantizer_size != 0:
            raise ValueError(
                "vector_size ({}) must be a multiple of subquantizer_size "
                "({})".format(vector_size, subquantizer_size)
            )

        self.vector_size = vector_size
        self.num_subquantizers = int(vector_size // subquantizer_size)
        # bit_length() - 1 is floor(log2(n)) in exact integer arithmetic.
        self.num_bits = int(subquantizer_size).bit_length() - 1
        # Report the centroid count that num_bits can actually address.
        self.num_centroids = 1 << self.num_bits
        self.pq = faiss.ProductQuantizer(
            vector_size, self.num_subquantizers, self.num_bits
        )

    @staticmethod
    def _as_float32(vectors):
        """Return ``vectors`` as a C-contiguous float32 array."""
        return np.ascontiguousarray(vectors, dtype=np.float32)

    def train(self, vectors):
        """Fit the quantizer on ``vectors`` and print the elapsed time.

        Args:
            vectors (array-like): training vectors with shape
                [batch_size, vector_size]; converted to float32 before
                being handed to faiss.
        """
        start = time.perf_counter()
        self.pq.train(self._as_float32(vectors))
        print("Completed in {} secs".format(time.perf_counter() - start))

    def encode(self, vectors):
        """Return the codes computed by the underlying quantizer.

        Args:
            vectors (array-like): vectors with shape
                [batch_size, vector_size]; converted to float32 first.
        """
        return self.pq.compute_codes(self._as_float32(vectors))

    def decode(self, encoded_vectors):
        """Return the vectors reconstructed from ``encoded_vectors``.

        Args:
            encoded_vectors (array-like): codes as produced by ``encode``;
                converted to a contiguous uint8 array first.
        """
        codes = np.ascontiguousarray(encoded_vectors, dtype=np.uint8)
        return self.pq.decode(codes)

In [45]:
# dataset with shape [batch, sentence_length, vector_length] -> [batch, sentence_length * vector_length]
batch = 4372
sentence_length = 65
vector_length = 768

# Fixed seed so the synthetic data is identical on every run.
SEED = 42
rng = np.random.default_rng(SEED)

# Uniform float64 samples in [0, 1); one flattened sentence per row.
dataset = rng.random((batch, sentence_length * vector_length))
sys.getsizeof(dataset)

1746002048

In [46]:
# Quantizer over the flattened rows, split into sub-vectors of 128 components.
pq = ProductQuantizer(
    subquantizer_size=128,
    vector_size=sentence_length * vector_length,
)

In [47]:
# Fit the underlying faiss quantizer on the full dataset; train() prints the
# elapsed time once it finishes.
pq.train(dataset)

Completed in 26.18165874481201 secs


In [48]:
# Encode every row of the dataset into its PQ code, then show the size in
# bytes that Python reports for the encoded array.
encoded = pq.encode(dataset)
sys.getsizeof(encoded)

1495352

In [49]:
# Reconstruct approximate vectors from the codes produced by encode().
decoded = pq.decode(encoded)

In [50]:
# Euclidean norm of the residual between the first original row and its
# reconstruction.
np.linalg.norm(np.subtract(dataset[0], decoded[0]))

61.3063492475266