# AQ Recall

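The following is not true additive quantization: the codebooks are learned by running k-means repeatedly on the residuals of the previous stage. Only the resulting codebooks and codes have the same structure as in additive quantization (M codebooks of K codewords each, and one code per codebook for every vector).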

In [1]:
import numpy as np
from scipy.cluster.vq import kmeans2

# Problem size: n database vectors and nq query vectors, all of dimension D.
n, nq, D = 10000, 2000, 128
# M codebooks of K codewords each.
M, K = 8, 256

# Seed NumPy's global RNG before the first random draw of this cell.
np.random.seed(15)
X = np.random.randn(n, D).astype(np.float32)
queries = np.random.randn(nq, D).astype(np.float32)

# Stage 0: k-means on the raw vectors, initial centroids sampled from the data.
stage_centroids, stage_labels = kmeans2(X, K, minit='points')
codebook_stages = [stage_centroids]  # every entry has shape (K, D) = (256, 128)
code_columns = [stage_labels]

# Stages 1..M-1: subtract the codeword selected in the previous stage and
# cluster what is left (kmeans2's default initialisation is used here).
residual = X
for _ in range(M - 1):
    residual = residual - stage_centroids[stage_labels]
    stage_centroids, stage_labels = kmeans2(residual, K)
    codebook_stages.append(stage_centroids)
    code_columns.append(stage_labels)

# Stack the per-stage results: codebooks row-wise, codes as one column per stage.
codebooks = np.vstack(codebook_stages)
codes = np.column_stack(code_columns)
print(codebooks.shape)
print(codes.shape)



(2048, 128)
(10000, 8)


## Compute recall

In [2]:
from evaluationRecall import SearchNeighbors_AQ, recall_atN

# Arguments expected by SearchNeighbors_AQ:
#   M            -- int, number of codebooks
#   K            -- int, number of codewords in each codebook
#   D            -- int, dimension of every vector
#   aq_codebooks -- np.ndarray of shape (M*K, D) and dtype np.float32;
#                   rows (m-1)*K : m*K hold the K codewords of the m-th codebook
#                   (so rows 0:K are the first codebook)
#   aq_codes     -- np.ndarray of AQ codes with shape (n, M) and an integer dtype
#                   (originally documented as np.int, an alias NumPy 1.24 removed),
#                   where n is the number of encoded datapoints;
#                   every entry aq_codes[i, j] is in {0, 1, ..., K-1}
#   metric       -- str, either dot_product or l2_distance
raq = SearchNeighbors_AQ(
    M=M,
    K=K,
    D=D,
    aq_codebooks=codebooks,
    aq_codes=codes,
    metric="dot_product",
)

# True nearest neighbor of each query, obtained by brute-force search over X.
ground_truth = raq.brute_force_search(X, queries, metric="dot_product")

In [3]:
# Fetch the top-k neighbors of every query (the original note refers to this
# result as raq.neighbors_matrix), then compute the recall against the
# brute-force ground truth.
neighbors_matrix = raq.par_neighbors(
    queries=queries,
    topk=512,
    njobs=4,
)
recall_atN(neighbors_matrix, ground_truth)

par_neighbors took 3.4909019470214844 seconds
recall 1@1 = 0.09
recall 1@2 = 0.136
recall 1@4 = 0.2045
recall 1@8 = 0.2885
recall 1@10 = 0.318
recall 1@16 = 0.385
recall 1@20 = 0.4235
recall 1@32 = 0.5185
recall 1@64 = 0.662
recall 1@100 = 0.7315
recall 1@128 = 0.7735
recall 1@256 = 0.8785
recall 1@512 = 0.949


N=[1, 2, 4, 8, 10, 16, 20, 32, 64, 100, 128, 256, 512]
recall1@N:[0.09, 0.136, 0.2045, 0.2885, 0.318, 0.385, 0.4235, 0.5185, 0.662, 0.7315, 0.7735, 0.8785, 0.949]
