In [5]:
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import PointStruct
import numpy as np
import random
import pandas as pd
import time
import pickle

import utils

In [6]:
# Client used by every later cell to talk to the Qdrant server on localhost.
# NOTE(review): the timeout value is extremely large — presumably so long-running
# bulk requests are never cut off; confirm this is intentional.
qdrantClient = QdrantClient(
    host='localhost',
    port=6333,
    timeout=10000000,
)


In [7]:
import os
from dotenv import load_dotenv
# Load environment configuration before the rest of the notebook runs.
# NOTE(review): presumably this reads a local `.env` file (python-dotenv); none of the
# visible cells read environment variables, so confirm this call is still needed.
load_dotenv()


def read_dataset(
    base_path='BASEV_WITH_ATTRIBUTES.pkl',
    query_path='QUERYV_WITH_ATTRIBUTES.pkl',
    ground_truth_path='AF_SIFT_COSINE_GT.pkl',
):
    """Load the base set, the query set and the ground truth from pickle files.

    Called without arguments, this reads the same three files as before; the
    paths are parameters so a different dataset or ground-truth file can be
    loaded without editing the function.

    Args:
        base_path: Pickle file holding the base (indexed) vectors.
        query_path: Pickle file holding the query vectors.
        ground_truth_path: Pickle file holding the ground truth.

    Returns:
        A ``(baseV, queryV, groundTruth)`` tuple containing the unpickled
        objects, in that order.

    Raises:
        FileNotFoundError: If one of the files does not exist.

    Warning:
        ``pickle.load`` can execute arbitrary code while deserialising; only
        point this function at files from a trusted source.
    """
    loaded = []
    for path in (base_path, query_path, ground_truth_path):
        # Each file is opened and closed on its own so no handle outlives its load.
        with open(path, 'rb') as f:
            loaded.append(pickle.load(f))
    baseV, queryV, groundTruth = loaded
    return baseV, queryV, groundTruth

In [8]:
# Dimensionality of the stored vectors and the name of the Qdrant collection used below.
vector_size = 128
collection_name = "testSIFTT"

# Start from an empty collection on every run. `recreate_collection` already drops an
# existing collection of the same name before creating it, so a separate
# `delete_collection` call beforehand is redundant and has been removed.
# The call stays the last expression so the cell still displays its result.
qdrantClient.recreate_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)

True

In [11]:
# Load the base and query sets; the third value returned by read_dataset() is discarded
# because the ground truth is read from a different pickle just below.
# To experiment on a subset, slice here (e.g. baseV[:10000], queryV[:100]).
baseV, queryV, _ = read_dataset()

# NOTE: pickle can execute arbitrary code on load — only open trusted files.
# NOTE(review): a later cell calls read_dataset() again and rebinds groundTruth, which
# supersedes the value loaded here — confirm which ground truth is intended.
with open('ANN_SIFT_COSINE_GT.pkl', 'rb') as gt_file:
    groundTruth = pickle.load(gt_file)

Loading file: siftsmall_base.fvecs
    The dimension of the vectors in the file is: 128
    The final shape of the loaded dataset siftsmall_base.fvecs is (10000, 128).
Loading file: siftsmall_query.fvecs
    The dimension of the vectors in the file is: 128
    The final shape of the loaded dataset siftsmall_query.fvecs is (100, 128).
 Loading file: siftsmall_groundtruth.ivecs
    The dimension of the vectors in the file is: 100
    The final shape of the loaded dataset is (100, 100).


In [9]:
# Reload all three objects, then print the Python type of the first component of the
# first row's "vector" entry as a quick sanity check.
baseV, queryV, groundTruth = read_dataset()
first_vector = baseV.iloc[0]["vector"]
print(type(first_vector[0]))


<class 'float'>


In [10]:
# Build one PointStruct per baseV row: the id is the index label yielded by iterrows(),
# the vector comes from the "vector" field, and the three attribute fields are copied
# into the payload under their own names.
batch_points = [
    PointStruct(
        id=i,
        vector=row["vector"],
        payload={key: row[key] for key in ("attr1", "attr2", "attr3")},
    )
    for i, row in baseV.iterrows()
]


In [11]:
# Upload the prepared points in fixed-size chunks instead of one single request.
batch_size = 50000
n = len(batch_points)
num_batches = (n + batch_size - 1) // batch_size  # ceiling division

for batch_idx in range(num_batches):
    offset = batch_idx * batch_size
    # Slicing clamps at the end of the list, so the final chunk may be shorter.
    operation_info = qdrantClient.upsert(
        collection_name=collection_name,
        wait=True,
        points=batch_points[offset:offset + batch_size],
    )

In [12]:
# Run one filtered nearest-neighbour search per query row and collect the returned ids.
print(f'Search function starting')
filter_keys = ("attr1", "attr2", "attr3")
result_ids = []
for i, query_row in queryV.iterrows():
    # `i` is the index label of the row, printed as a lightweight progress indicator.
    print(f'Progress: {i}/{len(queryV)}', end='\r')
    query_vector = query_row["vector"]
    # Every condition sits under `must`, i.e. a hit has to match all three attributes
    # of the query row (logical AND).
    conditions = [
        models.FieldCondition(
            key=name,
            match=models.MatchValue(value=query_row[name]),
        )
        for name in filter_keys
    ]
    search_result = qdrantClient.search(
        collection_name=collection_name,
        query_vector=query_vector,
        query_filter=models.Filter(must=conditions),
        limit=100,
    )
    # Keep only the ids of the hits, one list per query, in the order returned.
    result_ids.append([hit.id for hit in search_result])

Search function starting
Progress: 99/100

In [13]:
# Compare each query's returned ids with its ground-truth ids and aggregate the overlap.
# NOTE(review): the denominator is the number of ids returned by the search, not the
# number of ground-truth ids, so the ratio equals recall only when both counts match
# for every query — confirm this is the intended metric.
true_positives = 0
n_classified = 0
for i, elem in enumerate(result_ids):
    # Ids present in both the ground truth and the search result for query i.
    true_positives += len(np.intersect1d(groundTruth[i], elem))
    n_classified += len(elem)
print(true_positives)
print(n_classified)
if n_classified:
    print(f'Average recall: {true_positives/n_classified}')
else:
    # No ids were returned at all; report that instead of raising ZeroDivisionError.
    print('Average recall: undefined (no search results returned)')

10000
10000
Average recall: 1.0
