In [1]:
import h5py
import numpy as np
from biotite.database import uniprot
from qdrant_client import QdrantClient, models as qdmodels
from scipy.spatial.distance import cdist

### Load dataset (ProtT5 embeddings of Homo sapiens proteins)

In [2]:
# Open the embedding file read-only. Each top-level key is a UniProt accession id
# mapped to its embedding vector. The explicit mode="r" guarantees the notebook
# never creates or modifies the file (older h5py releases defaulted to mode "a").
dataset = h5py.File("../data/UP000005640_9606.h5", mode="r")
dataset

<HDF5 file "UP000005640_9606.h5" (mode r)>

In [3]:
# Number of embeddings stored in the file (one per top-level key).
len(dataset.keys())

20591

### (Preliminary check) Which way of reading an embedding vector from the HDF5 file is faster?

In [4]:
# Take the first entry of the file as the subject of the timing comparison.
acc_id = next(iter(dataset))
vector = dataset[acc_id]

In [5]:
# Read the whole dataset into an array with `[:]` first, then convert to a list of Python floats.
%timeit vector[:].astype(float).tolist()

13.1 µs ± 317 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [6]:
# Call `astype` on the h5py dataset object itself (no `[:]` slice) and build a list by iterating over it.
%timeit list(vector.astype(float))

22.8 ms ± 850 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


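→ Reading the whole vector into a NumPy array first (`vector[:]`) is far faster (~13 µs vs. ~23 ms), presumably because iterating over the h5py dataset object reads it element by element.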

### Create a collection and insert points

In [7]:
# Client for the Qdrant server. No connection arguments are passed, so the
# library's defaults apply (presumably a local server on localhost:6333 —
# confirm against the installed qdrant-client version).
client = QdrantClient()

In [8]:
collection_name = "UP000005640_9606"

# The collection's vector size must equal the embedding dimension. Read it from
# the data instead of hard-coding it (the ProtT5 embeddings used here are 1024-d).
embedding_dim = next(iter(dataset.values())).shape[0]

# (Re)create an empty collection that scores points by cosine similarity.
# NOTE(review): newer qdrant-client releases deprecate `recreate_collection` in
# favour of `delete_collection` + `create_collection`; check the installed version.
client.recreate_collection(
    collection_name=collection_name,
    vectors_config=qdmodels.VectorParams(size=embedding_dim, distance=qdmodels.Distance.COSINE),
)

In [9]:
batch_size = 128


def _upsert_batch(batch):
    """Send one batch of PointStructs to the collection and return the upsert result."""
    return client.upsert(collection_name=collection_name, points=batch)


# Stream the HDF5 entries into Qdrant in chunks of `batch_size` points. The point
# id is the entry's position in `dataset`; the accession id goes into the payload
# so that search hits can be mapped back to UniProt entries.
points = []
for point_id, (acc_id, vector) in enumerate(dataset.items()):
    points.append(
        qdmodels.PointStruct(
            id=point_id,
            vector=vector[:].astype(float).tolist(),
            payload={"accession_id": acc_id},
        )
    )
    if len(points) < batch_size:
        continue
    op_results = _upsert_batch(points)
    points = []
# Flush the final, partially filled batch.
if points:
    op_results = _upsert_batch(points)
del points

In [10]:
# Result of the most recent upsert call in the loading loop (earlier batches' results are overwritten).
op_results

UpdateResult(operation_id=160, status=<UpdateStatus.COMPLETED: 'completed'>)

### Querying

In [11]:
# Pick a random protein as the query. A seeded Generator makes the choice
# reproducible under Restart & Run All; the unseeded global np.random state drew a
# different protein on every run (so the saved outputs below may change once).
QUERY_SEED = 0
query_rng = np.random.default_rng(QUERY_SEED)
query_acc_id = str(query_rng.choice(list(dataset.keys())))
query_vector = dataset[query_acc_id][:].astype(float).tolist()

In [12]:
# Default search: the 10 stored points most similar to the query vector
# under the collection's cosine distance.
search_request = {
    "collection_name": collection_name,
    "query_vector": query_vector,
    "limit": 10,
}
hits = client.search(**search_request)

In [13]:
%%timeit
# Latency of the default search call (same arguments as the `hits` query).
client.search(
    collection_name=collection_name,
    query_vector=query_vector,
    limit=10,
)

7.93 ms ± 250 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
%%timeit
# Same query, but requesting exact (non-approximate) search via SearchParams(exact=True).
client.search(
    collection_name=collection_name,
    search_params=qdmodels.SearchParams(exact=True),
    query_vector=query_vector,
    limit=10,
)

10.8 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
def fetch_uniprot_entry(accession_id: str) -> tuple[str, str]:
    """Fetch a UniProt entry in FASTA format and split it into header and sequence.

    Parameters
    ----------
    accession_id : str
        UniProt accession id, e.g. ``"Q02543"``.

    Returns
    -------
    tuple[str, str]
        ``(annotation, sequence)``: ``annotation`` is the FASTA header line
        (including the leading ``>``), and ``sequence`` is the residue string
        with all line breaks removed.
    """
    # splitlines() (unlike split("\n")) also strips "\r" from CRLF-terminated
    # responses and does not produce a trailing empty string.
    annotation, *seq_lines = uniprot.fetch(accession_id, format="fasta").read().splitlines()
    return annotation, "".join(seq_lines)

In [16]:
# Show which protein was drawn as the query, followed by its UniProt header and sequence.
print(query_acc_id)
fetch_uniprot_entry(query_acc_id)

Q02543


('>sp|Q02543|RL18A_HUMAN 60S ribosomal protein L18a OS=Homo sapiens OX=9606 GN=RPL18A PE=1 SV=2',
 'MKASGTLREYKVVGRCLPTPKCHTPPLYRMRIFAPNHVVAKSRFWYFVSQLKKMKKSSGEIVYCGQVFEKSPLRVKNFGIWLRYDSRSGTHNMYREYRDLTTAGAVTQCYRDMGARHRARAHSIQIMKVEEIAASKCRRPAVKQFHDSKIKFPLPHRVLRRQHKPRFTTKRPNTFF')

In [17]:
# Label each Qdrant hit with its UniProt header (one network request per hit).
# Only the header is needed, so take element [0] instead of binding an unused
# sequence variable (binding `_` would shadow IPython's output cache).
for hit in hits:
    annotation = fetch_uniprot_entry(hit.payload["accession_id"])[0]
    print(hit.score, annotation)

1.0000001 >sp|Q02543|RL18A_HUMAN 60S ribosomal protein L18a OS=Homo sapiens OX=9606 GN=RPL18A PE=1 SV=2
0.915057 >sp|P61313|RL15_HUMAN 60S ribosomal protein L15 OS=Homo sapiens OX=9606 GN=RPL15 PE=1 SV=2
0.9046874 >sp|P62910|RL32_HUMAN 60S ribosomal protein L32 OS=Homo sapiens OX=9606 GN=RPL32 PE=1 SV=2
0.9021665 >sp|Q07020|RL18_HUMAN 60S ribosomal protein L18 OS=Homo sapiens OX=9606 GN=RPL18 PE=1 SV=2
0.9017502 >sp|P40429|RL13A_HUMAN 60S ribosomal protein L13a OS=Homo sapiens OX=9606 GN=RPL13A PE=1 SV=2
0.90106285 >sp|P18077|RL35A_HUMAN 60S ribosomal protein L35a OS=Homo sapiens OX=9606 GN=RPL35A PE=1 SV=2
0.9007937 >sp|P62241|RS8_HUMAN 40S ribosomal protein S8 OS=Homo sapiens OX=9606 GN=RPS8 PE=1 SV=2
0.8986009 >sp|P62280|RS11_HUMAN 40S ribosomal protein S11 OS=Homo sapiens OX=9606 GN=RPS11 PE=1 SV=3
0.89437604 >sp|P46778|RL21_HUMAN 60S ribosomal protein L21 OS=Homo sapiens OX=9606 GN=RPL21 PE=1 SV=2
0.88958335 >sp|P18124|RL7_HUMAN 60S ribosomal protein L7 OS=Homo sapiens OX=9606 GN=

### How fast is Qdrant? Comparison with a brute-force distance calculation in SciPy

In [18]:
# Create a "vector database" in memory: one float64 row per HDF5 entry, in the
# same order as dataset.values(), plus the query embedding as an array.
embedding_rows = [emb[:].astype(float) for emb in dataset.values()]
all_embeddings = np.stack(embedding_rows, axis=0)
del embedding_rows
query_embedding = dataset[query_acc_id][:].astype(float)

In [19]:
# Accession ids in storage order; row i of all_embeddings belongs to all_acc_ids[i].
all_acc_ids = list(dataset)

In [20]:
# "Querying": brute-force cosine similarity of the query against every stored
# embedding. cdist returns cosine *distances*, so 1 - distance is the similarity
# (the same quantity Qdrant reports as the hit score above).
similarities = 1 - cdist(query_embedding.reshape(1, -1), all_embeddings, metric="cosine")[0]
# Indices of the 10 most similar embeddings, highest similarity first.
hit_ids = np.argsort(similarities)[::-1][:10]
hit_acc_ids = [all_acc_ids[idx] for idx in hit_ids]
list(zip(similarities[hit_ids], hit_acc_ids))

[(0.9999999999999998, 'Q02543'),
 (0.9150568741108658, 'P61313'),
 (0.9046873686191571, 'P62910'),
 (0.9021664299651811, 'Q07020'),
 (0.9017501718261806, 'P40429'),
 (0.9010627598826844, 'P18077'),
 (0.9007935517352109, 'P62241'),
 (0.8986008819611099, 'P62280'),
 (0.8943759676859638, 'P46778'),
 (0.8895832426555201, 'P18124')]

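The ranking is identical to Qdrant's, and the scores agree to about six decimal places (the small differences come from numeric precision, e.g. Qdrant reports 1.0000001 for the self-match).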

In [21]:
%%timeit
# Time the brute-force SciPy query end to end: the 1 x N cosine computation, a full
# argsort of all N scores, and the accession-id lookup.
# Note: `distances` here actually holds cosine similarities (1 - cosine distance).
distances = 1 - cdist(query_embedding.reshape(1, -1), all_embeddings, metric="cosine")
hit_ids = np.argsort(distances[0])[::-1][:10]
hit_acc_ids = [all_acc_ids[idx] for idx in hit_ids]

39.9 ms ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


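Qdrant search (~7.9 ms by default, ~10.8 ms with `exact=True`) is roughly 4–5× faster than the naive kNN calculation with SciPy (~40 ms), and possibly more memory-efficient, even when Qdrant search is exact. Caveat: the SciPy version works in float64 and fully sorts all 20,591 similarities, so a tuned NumPy implementation (e.g. `np.argpartition` for the top 10) could narrow the gap.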