# Product Quantization

In [1]:
import numpy as np

In [2]:
def ivecs_read(fname):
    a = np.fromfile(fname, dtype='int32')
    d = a[0]
    return a.reshape(-1, d + 1)[:, 1:].copy()

def fvecs_read(fname):
    return ivecs_read(fname).view('float32')

def load_siftsmall():
    xt = fvecs_read("siftsmall/siftsmall_learn.fvecs")
    xb = fvecs_read("siftsmall/siftsmall_base.fvecs")
    xq = fvecs_read("siftsmall/siftsmall_query.fvecs")
    gt = ivecs_read("siftsmall/siftsmall_groundtruth.ivecs")

    return xb, xq, xt, gt

In [3]:
xb, xq, xt, gt = load_siftsmall()

In [4]:
print("Base vectors shape: ", xb.shape)
print(f"Base vectors range: [{xb.min()}, {xb.max()}]") # are degrees
print("Query vectors shape: ", xq.shape)
print("Ground truth shape: ", gt.shape)
print("Learn vectors shape: ", xt.shape)
print("Query example:\n", xq[0])

Base vectors shape:  (10000, 128)
Base vectors range: [0.0, 180.0]
Query vectors shape:  (100, 128)
Ground truth shape:  (100, 100)
Learn vectors shape:  (25000, 128)
Query example:
 [  1.   3.  11. 110.  62.  22.   4.   0.  43.  21.  22.  18.   6.  28.
  64.   9.  11.   1.   0.   0.   1.  40. 101.  21.  20.   2.   4.   2.
   2.   9.  18.  35.   1.   1.   7.  25. 108. 116.  63.   2.   0.   0.
  11.  74.  40. 101. 116.   3.  33.   1.   1.  11.  14.  18. 116. 116.
  68.  12.   5.   4.   2.   2.   9. 102.  17.   3.  10.  18.   8.  15.
  67.  63.  15.   0.  14. 116.  80.   0.   2.  22.  96.  37.  28.  88.
  43.   1.   4.  18. 116.  51.   5.  11.  32.  14.   8.  23.  44.  17.
  12.   9.   0.   0.  19.  37.  85.  18.  16. 104.  22.   6.   2.  26.
  12.  58.  67.  82.  25.  12.   2.   2.  25.  18.   8.   2.  19.  42.
  48.  11.]


In [5]:
class ExactSearch:
    def __init__(self, data):
        self.data = data

    def search(self, query, sort=True):
        dist = np.sum((self.data - query)**2, axis=1)   

        if sort:
            return dist, np.argsort(dist)
        
        return dist, None

In [6]:
data = xb
query = xq[0]
es = ExactSearch(data)
exact_dists, exact_ranking = es.search(query)

In [7]:
exact_ranking[0:1000]

array([2176, 3752,  882, 4009, 2837,  190, 3615,  816, 1045, 1884,  224,
       3013,  292, 1272, 5307, 4938, 1295,  492, 9211, 3625, 1254, 1292,
       1625, 3553, 1156,  146,  107, 5231, 1995, 9541, 3543, 9758, 9806,
       1064, 9701, 4064, 2456, 2763, 3237, 1317, 3530,  641, 1710, 8887,
       4263, 1756,  598,  370, 2776,  121, 4058, 7245, 1895,  124, 8731,
        696, 4320, 4527, 4050, 2648, 1682, 2154, 1689, 2436, 2005, 3210,
       4002, 2774,  924, 6630, 3449, 9814, 3515, 5375,  287, 1038, 4096,
       4094,  942, 4321,  123, 3814,   97, 4293,  420, 9734, 1916, 2791,
        149, 6139, 9576, 6837, 2952, 3138, 2890, 3066, 2852,  348, 3043,
       3687, 9825, 4125,  106, 1370, 3056, 3591, 2301,  286, 8271, 4063,
       4241, 3249, 9582,  396, 8799,  613, 1705, 3998, 1848, 1543,  145,
        953, 3079, 3137, 2908,  174,  421, 3997, 9805, 4246, 9885, 9712,
       4801, 4447, 1454, 2877, 3246,  402, 3683, 8581, 9594,  579, 2306,
       1619, 2969, 3829, 9843, 2330, 3919, 9651, 20

In [8]:
gt.reshape(-1)[0:1000]

array([2176, 3752,  882, 4009, 2837,  190, 3615,  816, 1045, 1884,  224,
       3013,  292, 1272, 5307, 4938, 1295,  492, 9211, 3625, 1254, 1292,
       1625, 3553, 1156,  146,  107, 5231, 1995, 9541, 3543, 9758, 9806,
       1064, 9701, 4064, 2456, 2763, 3237, 1317, 3530,  641, 1710, 8887,
       4263, 1756,  598,  370, 2776,  121, 4058, 7245, 1895,  124, 8731,
        696, 4320, 4527, 4050, 2648, 1682, 2154, 1689, 2436, 2005, 3210,
       4002, 2774,  924, 6630, 3449, 9814, 3515, 5375,  287, 1038, 4096,
       4094,  942, 4321,  123, 3814,   97, 4293,  420, 9734, 1916, 2791,
        149, 6139, 9576, 6837, 2952, 3138, 2890, 3066, 2852,  348, 3043,
       3687, 2781, 9574, 2492, 1322, 3136, 1038, 9564,  925, 3998, 2183,
       1533,  145, 1150, 4097, 9814, 9520, 9576, 3013, 1467,  909, 3568,
       3683,  833, 9536, 3530, 2388, 9936, 8643, 3408, 3676, 2078, 3138,
         97, 1543, 2755, 3210, 2111, 2908, 3567, 1116, 9807,  800,  462,
       9824, 9842,  280, 9715, 3229, 1993,  349,  5

In [9]:
gt.reshape(-1)[:20]

array([2176, 3752,  882, 4009, 2837,  190, 3615,  816, 1045, 1884,  224,
       3013,  292, 1272, 5307, 4938, 1295,  492, 9211, 3625], dtype=int32)

In [10]:
exact_ranking[:20]

array([2176, 3752,  882, 4009, 2837,  190, 3615,  816, 1045, 1884,  224,
       3013,  292, 1272, 5307, 4938, 1295,  492, 9211, 3625])

In [11]:
gt.reshape(-1)[-20:]

array([8028, 2692, 2016, 1477,  711, 4334, 9494, 1088, 8852, 1365, 7344,
       1076, 5776, 5096, 3434, 5327, 3730,   11, 2482, 3631], dtype=int32)

In [12]:
exact_ranking[-20:]

array([7222, 4603, 4716, 5861, 8062, 5131, 4690, 7192, 5085, 4492, 4853,
       7695, 8492, 6804, 6488, 5452, 6041, 6400, 4683, 5802])

In [13]:
exact_dists[exact_ranking[-20:]]

array([450294., 450330., 450380., 450407., 450449., 450601., 450727.,
       450950., 451212., 451284., 451312., 451543., 451937., 452259.,
       452420., 452688., 453152., 454336., 454779., 456471.],
      dtype=float32)

In [14]:
# check if there are repeated values in exact_dists
len(exact_dists) == len(np.unique(exact_dists))

False

In [15]:
len(np.unique(gt.reshape(-1))) # should be 10000

5115

In [16]:
len(np.unique(exact_ranking))

10000

In [17]:
# find the first element that is different between gt and exact_ranking
diff = np.where(gt.reshape(-1) != exact_ranking)[0]
diff

array([ 100,  101,  102, ..., 9997, 9998, 9999])

In [18]:
gt.reshape(-1)[100]

2781

In [19]:
exact_ranking[100]

9825

In [20]:
exact_dists[exact_ranking[98:102]]

array([121870., 121964., 122029., 122065.], dtype=float32)