In [23]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sentence_transformers.quantization import quantize_embeddings

# 1. Embedding size to truncate to. all-MiniLM-L6-v2 produces 384-d vectors (the packed
#    binary embedding further down is 48 bytes = 384 bits), so truncate_dim only has an
#    effect below 384; any larger value (e.g. 512) is silently a no-op.
dimensions = 384

# 2. Load the model; its outputs are truncated to the first `dimensions` dimensions.
model = SentenceTransformer("all-MiniLM-L6-v2", truncate_dim=dimensions)

# NOTE(review): this instruction prefix appears to be the retrieval prompt from the
# mxbai-embed-large-v1 / BGE examples — confirm all-MiniLM-L6-v2 actually benefits
# from it. It is kept as-is so the recorded outputs stay reproducible.
query = 'Represent this sentence for searching relevant passages: A man is eating a piece of bread'

# docs[0] is the query itself; the remaining entries are the candidate passages.
docs = [
    query,
    "A man is eating food.",
    "A man is eating pasta.",
    "The girl is carrying a baby.",
    "A man is riding a horse.",
]

# 3. Encode: float embeddings, one row per entry in `docs`.
embeddings = model.encode(docs)

# 4. Optional: quantize to packed unsigned binary (8 dimensions per byte); the bitwise
#    comparisons later in the notebook operate on these.
binary_embeddings = quantize_embeddings(embeddings, precision="ubinary")

# 5. Cosine similarity of the query against every other doc, on the float embeddings.
similarities = cos_sim(embeddings[0], embeddings[1:])
print('similarities:', similarities)


similarities: tensor([[ 0.4510,  0.2982, -0.0929,  0.0698]])


In [30]:
# Bytes used by one packed binary embedding: element count times per-element width.
binary_embeddings[0].size * binary_embeddings.itemsize

48

In [70]:
import numpy as np

# Bitwise agreement between the query's packed binary embedding and each document's:
# unpack every byte into its 8 bits and count the positions where the bits are equal.
# The query bits are loop-invariant, so unpack them once.
query_bits = np.unpackbits(binary_embeddings[0])
for doc_embedding in binary_embeddings[1:]:
    print((query_bits == np.unpackbits(doc_embedding)).sum())

260
234
177
205


We can unpack these bits with `np.unpackbits`, but we can also extract each bit arithmetically with `%` and `//` — and that arithmetic works inside a numba-compiled function too.

In [78]:
import numba

@numba.jit(fastmath=True)
def sparse_overlap(x, y):
    """Count the bit positions at which two packed binary vectors agree.

    Each element of ``x`` and ``y`` is a byte holding 8 bits (like the ``ubinary``
    embeddings above, which ``np.unpackbits`` expands into individual bits). Bit ``b``
    of a byte ``v`` is extracted arithmetically as ``(v // 2**b) % 2``. Despite the
    name, nothing here is sparse: the result is ``8 * len(x)`` minus the Hamming
    distance between the two bit strings.

    Parameters
    ----------
    x, y : 1-D integer arrays of equal length (one packed byte per element).

    Returns
    -------
    int
        Number of matching bits, between 0 and ``8 * len(x)``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different lengths.
    """
    n = x.shape[0]
    # numba does not bounds-check array reads by default, so a shorter ``y`` would
    # otherwise be read past its end silently instead of raising.
    if y.shape[0] != n:
        raise ValueError("x and y must have the same length")
    count = 0
    for bit in range(8):
        divisor = 2**bit
        for i in range(n):
            # Dividing before taking % 2 keeps every operand <= 128, so this also stays
            # within uint8 when run without the JIT (x % 256 would not fit in uint8).
            # Working element-wise avoids allocating temporary bit-plane arrays.
            count += (x[i] // divisor) % 2 == (y[i] // divisor) % 2
    return count

In [85]:
# Same bit-agreement counts, computed by the numba kernel. Each count is cross-checked
# against the straightforward np.unpackbits comparison before it is printed.
for doc_embedding in binary_embeddings[1:]:
    overlap = sparse_overlap(binary_embeddings[0], doc_embedding)
    assert overlap == (np.unpackbits(binary_embeddings[0]) == np.unpackbits(doc_embedding)).sum()
    print(overlap)

260
234
177
205


In [86]:
# Plain-NumPy version of the same per-bit arithmetic: row m of each (8, n_bytes) array
# holds bit m of every byte, computed as (v // 2**m) % 2. Dividing first keeps every
# operand <= 128, so it fits uint8; the form v % 2**(m+1) needs 256 at m = 7, which
# NumPy >= 2 rejects for uint8 arrays with an OverflowError.
# The query's bit planes do not change across documents, so build them once.
main = np.array([binary_embeddings[0] // 2**m % 2 for m in range(8)])
for doc_embedding in binary_embeddings[1:]:
    other = np.array([doc_embedding // 2**m % 2 for m in range(8)])
    print((main == other).sum())

260
234
177
205


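The numba version is much faster: about 10 µs versus 173 µs per loop of four comparisons in the timings below. Part of that gap comes from the NumPy baseline rebuilding Python lists and temporary arrays for every pair, not only from compilation.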

In [83]:
%%timeit 

# Time the numba kernel over every query/document pair (the query is row 0).
for doc_embedding in binary_embeddings[1:]:
    _ = sparse_overlap(binary_embeddings[0], doc_embedding)

10.4 µs ± 325 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [63]:
%%timeit 

# Per-pair NumPy baseline for the timing comparison: extract the 8 bit planes of both
# embeddings with (v // 2**m) % 2 and count agreeing bits. Both sides are recomputed
# for every pair, mirroring the work sparse_overlap does per call. Dividing first keeps
# operands within uint8 (v % 256 raises OverflowError on uint8 arrays under NumPy >= 2).
for doc_embedding in binary_embeddings[1:]:
    main = np.array([binary_embeddings[0] // 2**m % 2 for m in range(8)])
    other = np.array([doc_embedding // 2**m % 2 for m in range(8)])
    _ = (main == other).sum()

173 µs ± 2.21 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
