In [1]:
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import PointStruct
import numpy as np
import random
import pandas as pd
import time
import pickle

import utils

In [2]:
# Connection settings for the Qdrant server reached on localhost.
QDRANT_HOST = 'localhost'
QDRANT_PORT = 6333
QDRANT_TIMEOUT = 10000000  # forwarded unchanged as the client's `timeout` argument

qdrantClient = QdrantClient(
    host=QDRANT_HOST,
    port=QDRANT_PORT,
    timeout=QDRANT_TIMEOUT,
)


In [3]:
import os
from dotenv import load_dotenv
# python-dotenv: presumably loads variables from a local .env file into the
# process environment. NOTE(review): none of the cells shown here read
# os.environ afterwards — confirm whether this call is still needed.
load_dotenv()


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE: unpickling can execute arbitrary code; only load files that come
    # from a trusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)


def read_dataset(base_path='BASEV_WITH_ATTRIBUTES.pkl',
                 query_path='QUERYV_WITH_ATTRIBUTES.pkl',
                 ground_truth_path='AF_gloVe_COSINE_GT.pkl'):
    """Load the base set, the query set and the ground truth from pickle files.

    Called without arguments this reads the same three files as before, so
    existing callers are unaffected; the paths can be overridden to point at
    a different dataset.

    Args:
        base_path: Pickle file holding the base set.
        query_path: Pickle file holding the query set.
        ground_truth_path: Pickle file holding the ground truth.

    Returns:
        tuple: ``(baseV, queryV, groundTruth)``, each exactly as unpickled.

    Raises:
        OSError: If one of the files cannot be opened (for example
            ``FileNotFoundError`` when it does not exist).
    """
    # Files are read in the same order as they are returned.
    return (
        _load_pickle(base_path),
        _load_pickle(query_path),
        _load_pickle(ground_truth_path),
    )

In [5]:
# Size passed to VectorParams for the collection's vectors.
vector_size = 25
collection_name = "aftest"

# recreate_collection already drops an existing collection of the same name
# before creating the new one, so a separate delete_collection call beforehand
# is a redundant round-trip and has been removed.
qdrantClient.recreate_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)

True

In [6]:
# Read the base set, the query set and the ground truth via read_dataset().
baseV, queryV, groundTruth = read_dataset()

# Show the type of the first component of the first base vector as loaded.
first_component = baseV.iloc[0]["vector"][0]
print(type(first_component))

# Disabled alternative that rebuilds the frames from a slice, kept for reference:
# baseV = pd.DataFrame({'vector': baseV[:10000]})
# queryV = pd.DataFrame({'vector': queryV[:100]})


<class 'numpy.float32'>


In [7]:
def _as_float_list(values):
    """Return the elements of ``values`` as a list of built-in Python floats."""
    return [float(component) for component in values]


# Replace every stored vector with a list of Python floats, first in the base
# frame and then in the query frame.
for frame in (baseV, queryV):
    frame['vector'] = frame['vector'].apply(_as_float_list)


In [8]:
# Show the type of the first component of the first base vector again, now
# that the vectors have been rewritten as lists of Python floats.
converted_component = baseV.iloc[0]["vector"][0]
print(type(converted_component))


<class 'float'>


In [9]:
def _row_to_point(point_id, row):
    """Build a PointStruct from one base row.

    The row's "vector" entry becomes the point vector and its "attr1",
    "attr2" and "attr3" entries become the payload.
    """
    payload = {key: row[key] for key in ("attr1", "attr2", "attr3")}
    return PointStruct(id=point_id, vector=row["vector"], payload=payload)


# One point per base row; the row's index label is used as the point id.
batch_points = [_row_to_point(label, row) for label, row in baseV.iterrows()]


In [10]:
# Preview only the first couple of points: displaying the whole list would
# write the repr of every point into the notebook output.
batch_points[:2]

[PointStruct(id=0, vector=[-0.28571999073028564, 1.6030000448226929, -0.23368999361991882, 0.42476001381874084, 0.071834996342659, -1.6633000373840332, -0.6774700284004211, -0.20066000521183014, 0.7255899906158447, -0.722599983215332, 0.09668300300836563, 1.0442999601364136, 1.1964000463485718, -0.2735399901866913, 1.44159996509552, 0.06502100080251694, 0.9345399737358093, -0.40575000643730164, 0.9226999878883362, -0.29600998759269714, -0.5180299878120422, 0.8512099981307983, -1.0339000225067139, 0.050655998289585114, 0.13964000344276428], payload={'attr1': True, 'attr2': True, 'attr3': True}),
 PointStruct(id=1, vector=[-2.3120999336242676, -1.069100022315979, 0.3303000032901764, -0.8492599725723267, -0.450980007648468, -1.1102999448776245, -2.791800022125244, -0.34088999032974243, 1.4888999462127686, 0.06055299937725067, -1.1956000328063965, -0.3486599922180176, 0.2411700040102005, 1.770300030708313, -1.250100016593933, -0.5211899876594543, 0.07655300199985504, -1.1669000387191772, 0

In [11]:
# Upsert the points in consecutive chunks of at most `batch_size` each.
batch_size = 50000
n = len(batch_points)

for start_idx in range(0, n, batch_size):
    # Slicing stops at the end of the list, so the last chunk may be shorter.
    chunk = batch_points[start_idx:start_idx + batch_size]
    operation_info = qdrantClient.upsert(
        collection_name=collection_name,
        points=chunk,
        wait=True,
    )

In [15]:
print('Search function starting')

# Payload attributes that a hit has to share with the query.
FILTER_KEYS = ("attr1", "attr2", "attr3")
total_queries = len(queryV)

# One list of returned point ids per query, in query order.
result_ids = []
for i, query_row in queryV.iterrows():
    print(f'Progress: {i}/{total_queries}', end='\r')

    # Conditions listed under `must` are combined with AND: one exact-match
    # condition per attribute, using the query row's own value.
    attribute_filter = models.Filter(
        must=[
            models.FieldCondition(
                key=key,
                match=models.MatchValue(value=query_row[key]),
            )
            for key in FILTER_KEYS
        ]
    )

    hits = qdrantClient.search(
        collection_name=collection_name,
        query_vector=query_row["vector"],
        query_filter=attribute_filter,
        limit=100,
    )
    result_ids.append([hit.id for hit in hits])

Search function starting
Progress: 99/100

In [17]:
# Recall pooled over all queries: ground-truth ids that were retrieved,
# divided by the number of ground-truth ids. Dividing by the number of
# *retrieved* ids (as before) measures precision and overstates recall
# whenever a search returns fewer hits than the ground truth contains.
# NOTE(review): assumes groundTruth[i] holds exactly the relevant ids for
# query i (no padding, no more entries than the search limit) — confirm
# against how the ground truth file was built.
true_positives = 0
n_classified = 0  # ids returned by the searches; printed for reference
n_relevant = 0    # ground-truth ids; the recall denominator
for i, retrieved_ids in enumerate(result_ids):
    relevant_ids = groundTruth[i]
    true_positives += len(np.intersect1d(relevant_ids, retrieved_ids))
    n_classified += len(retrieved_ids)
    n_relevant += len(relevant_ids)
print(true_positives)
print(n_classified)
# Report NaN instead of raising ZeroDivisionError when there is nothing to score.
average_recall = true_positives / n_relevant if n_relevant else float('nan')
print(f'Average recall: {average_recall}')

10000
10000
Average recall: 1.0
