In [35]:
from biotite.database import uniprot
from qdrant_client import QdrantClient, models as qdmodels
import h5py
import numpy as np

### Load the dataset (ProtT5 embeddings of *Homo sapiens* proteins, one vector per UniProt accession)

In [2]:
# Location of the embedding file; adjust for your machine.
EMBEDDINGS_PATH = "/home/hideki/data/UP000005640_9606.h5"

# Open explicitly read-only: h5py versions before 3.0 defaulted to append mode,
# which could create or modify the file by accident.
dataset = h5py.File(EMBEDDINGS_PATH, mode="r")
dataset

<HDF5 file "UP000005640_9606.h5" (mode r)>

In [3]:
# Number of top-level entries in the file (each entry is one embedding vector).
len(dataset)

20591

### Which is the faster way to read an embedding out of the h5 file?

In [14]:
# Take the first (accession, embedding) entry of the file to benchmark on.
acc_id = next(iter(dataset.keys()))
vector = dataset[acc_id]

In [16]:
# Read the whole dataset into memory with one slice, then convert the NumPy array to Python floats.
%timeit vector[:].astype(float).tolist()

23.3 µs ± 34.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [17]:
# Iterate over h5py's type-converting view element by element
# (presumably one HDF5 read per element — verify), building a list of NumPy scalars.
%timeit list(vector.astype(float))

73.3 ms ± 403 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


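-> Reading the whole dataset into memory with a single slice (`vector[:]`) and converting the resulting NumPy array is roughly 3000× faster than iterating over the h5py dataset element by element.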

### Create a collection and insert points

In [28]:
# Connect to Qdrant. With no arguments the client uses its default connection
# settings (presumably a server on localhost — confirm for your setup).
client = QdrantClient()

In [31]:
collection_name = "UP000005640_9606"
embedding_dim = 1024  # must equal the length of every embedding vector in `dataset`

# Fail fast if the file's vectors do not match the collection's vector size.
assert next(iter(dataset.values())).shape == (embedding_dim,), "embedding size mismatch"

# Drop any previous collection of this name and create a fresh, empty one.
# This replaces the deprecated `recreate_collection` helper, which performed
# the same delete-then-create sequence.
client.delete_collection(collection_name=collection_name)
client.create_collection(
    collection_name=collection_name,
    vectors_config=qdmodels.VectorParams(size=embedding_dim, distance=qdmodels.Distance.COSINE),
)

In [33]:
batch_size = 128  # points sent per upsert request

# Stream the embeddings into Qdrant in batches so that at most `batch_size`
# points are held in memory at once. The enumeration index is used as the
# point id; the UniProt accession is stored in the payload for later lookup.
op_results = None  # stays None for an empty dataset instead of raising NameError later
points = []
for idx, (acc_id, vector) in enumerate(dataset.items()):
    points.append(
        qdmodels.PointStruct(
            id=idx, vector=vector[:].astype(float).tolist(), payload={"accession_id": acc_id}
        )
    )
    if len(points) == batch_size:
        op_results = client.upsert(collection_name=collection_name, points=points)
        points = []
# Flush the final, partially filled batch.
if points:
    op_results = client.upsert(collection_name=collection_name, points=points)
del points

# Sanity check across all batches: every entry in the file became exactly one point.
assert client.count(collection_name=collection_name, exact=True).count == len(dataset)

In [34]:
# Result of the final upsert batch only (earlier batches' results were overwritten).
op_results

UpdateResult(operation_id=160, status=<UpdateStatus.COMPLETED: 'completed'>)

### Querying

In [38]:
# Pick a random protein and retrieve its 10 nearest neighbours by cosine
# similarity. A seeded generator makes the choice reproducible on re-run.
# The query protein is itself in the collection, so expect it as the top hit.
rng = np.random.default_rng(seed=0)
query_acc_id = str(rng.choice(list(dataset.keys())))  # plain str, not np.str_
query_vector = dataset[query_acc_id][:].astype(float).tolist()
hits = client.search(
    collection_name=collection_name,
    query_vector=query_vector,
    limit=10,
)

In [40]:
# Accession id of the randomly chosen query protein.
query_acc_id

'P40259'

In [53]:
def fetch_uniprot_entry(accession_id: str) -> tuple[str, str]:
    """Download a UniProt entry as FASTA and split it into header and sequence.

    Args:
        accession_id: UniProt accession, e.g. "P40259".

    Returns:
        ``(annotation, sequence)`` where ``annotation`` is the FASTA header line
        (including the leading ``>``) and ``sequence`` is the residue string
        with all line breaks removed. An empty response yields ``("", "")``.
    """
    text = uniprot.fetch(accession_id, format="fasta").read()
    # splitlines() also handles "\r\n" endings, which split("\n") would leave as
    # stray "\r" characters inside the header and sequence.
    annotation, *seq_lines = text.splitlines() or [""]
    return annotation, "".join(seq_lines)

In [45]:
# Show the UniProt FASTA header and sequence of the query protein.
fetch_uniprot_entry(query_acc_id)

('>sp|P40259|CD79B_HUMAN B-cell antigen receptor complex-associated protein beta chain OS=Homo sapiens OX=9606 GN=CD79B PE=1 SV=1',
 'MARLALSPVPSHWMVALLLLLSAEPVPAARSEDRYRNPKGSACSRIWQSPRFIARKRGFTVKMHCYMNSASGNVSWLWKQEMDENPQQLKLEKGRMEESQNESLATLTIQGIRFEDNGIYFCQQKCNNTSEVYQGCGTELRVMGFSTLAQLKQRNTLKDGIIMIQTLLIILFIIVPIFLLLDKDDSKAGMEEDHTYEGLDIDQTATYEDIVTLRTGEVKWSVGEHPGQE')

In [55]:
# List each neighbour's similarity score next to its UniProt FASTA header.
for hit in hits:
    header = fetch_uniprot_entry(hit.payload["accession_id"])[0]
    print(hit.score, header)

1.0000001 >sp|P40259|CD79B_HUMAN B-cell antigen receptor complex-associated protein beta chain OS=Homo sapiens OX=9606 GN=CD79B PE=1 SV=1
0.9043885 >sp|P11912|CD79A_HUMAN B-cell antigen receptor complex-associated protein alpha chain OS=Homo sapiens OX=9606 GN=CD79A PE=1 SV=2
0.89464164 >sp|Q8NET5|NFAM1_HUMAN NFAT activation molecule 1 OS=Homo sapiens OX=9606 GN=NFAM1 PE=1 SV=1
0.8810506 >sp|Q9UIB8|SLAF5_HUMAN SLAM family member 5 OS=Homo sapiens OX=9606 GN=CD84 PE=1 SV=1
0.8793238 >sp|Q13291|SLAF1_HUMAN Signaling lymphocytic activation molecule OS=Homo sapiens OX=9606 GN=SLAMF1 PE=1 SV=1
0.8782977 >sp|Q9NQ25|SLAF7_HUMAN SLAM family member 7 OS=Homo sapiens OX=9606 GN=SLAMF7 PE=1 SV=1
0.8766905 >sp|Q15116|PDCD1_HUMAN Programmed cell death protein 1 OS=Homo sapiens OX=9606 GN=PDCD1 PE=1 SV=3
0.875193 >sp|Q96A28|SLAF9_HUMAN SLAM family member 9 OS=Homo sapiens OX=9606 GN=SLAMF9 PE=2 SV=2
0.873362 >sp|Q9BZW8|CD244_HUMAN Natural killer cell receptor 2B4 OS=Homo sapiens OX=9606 GN=CD244 PE=