In [167]:
%load_ext cython
%timeit

import Cython

The cython extension is already loaded. To reload it, use:
  %reload_ext cython


In [1]:
import numpy as np

#### Asymmetric Distance computation

Currently the code performs

```
dists = np.sum(self.dtable[range(M), codes], axis=1)
```
which is equivalent to 

```python
dists = np.zeros((N, )).astype(np.float32)
for n in range(N):
    for m in range(M):
        dists[n] += self.dtable[m][codes[n][m]]

```

let us consider 

M = 32

In [159]:
M = 32
np.random.seed(123)

n_cluster = 256
dtable = np.array(np.random.random((M,n_cluster)), 'float32')

np.random.seed(123)
pq_codes_batch = np.array([np.random.randint([M]*M)])
N, M = pq_codes_batch.shape


In [160]:
pq_codes_batch

array([[30, 13, 30,  2, 28,  2,  6, 17, 19, 10, 27, 25, 22,  1,  0, 17,
        30, 15,  9,  0, 14,  0, 15, 25, 19, 14, 29,  4,  0, 16,  4, 17]])

In [161]:
dtable[range(n_subvectors),pq_codes_batch].sum(axis=1)

array([17.402649], dtype=float32)

In [162]:
def distances_loop_py(N,M, dtable):
    dists = np.zeros((N, )).astype(np.float32)
    for n in range(N):
        for m in range(M):
            dists[n] += dtable[m, pq_codes_batch[n,m]]
    return dists

In [163]:
distances_loop_py(1,M,dtable)

array([17.402647], dtype=float32)

In [166]:
dtable.shape

(32, 256)

In [164]:
%timeit distances_loop_py(1,M,dtable)

13.9 µs ± 307 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [168]:
%timeit dtable[range(n_subvectors),pq_codes_batch].sum(axis=1)

8.03 µs ± 213 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [177]:
%%cython -a
cimport numpy as cnp
cimport cython
             
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef distances_loop_cy(long M,float[:,:] dtable,long[:] pq_code):
    cdef float dist = 0
    for m in range(M):
        dist += dtable[m, pq_code[m]]

    return dist

In [178]:
pq_code = pq_codes_batch.flatten()

In [180]:
%timeit distances_loop_cy(M, dtable, pq_code)

532 ns ± 1.95 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [181]:
%timeit dtable[range(n_subvectors),pq_codes_batch].sum(axis=1)

7.94 µs ± 385 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
