In [1]:
import numpy as np
import faiss

## 场景
1. 256维向量
2. 10w条数据库
3. 1000条待检索数据，检索top10结果

In [2]:
dimension = 256
data_size = 100_000
query_size = 1000
np.random.seed(1234)
retrieval_k = 10

In [3]:
data_vec = np.random.random((data_size, dimension)).astype(np.float32)
query_vec = np.random.random((query_size, dimension)).astype(np.float32)

# Faiss  无优化
100M 内存

In [4]:
%%time
index = faiss.IndexFlatL2(dimension)
index.add(data_vec)

CPU times: user 40.2 ms, sys: 40.2 ms, total: 80.5 ms
Wall time: 77.7 ms


In [22]:
%%time
distance, coresponding_index = index.search(query_vec, retrieval_k)
coresponding_index[0]

CPU times: user 5.87 s, sys: 481 ms, total: 6.35 s
Wall time: 325 ms


array([89650, 58430, 21260, 32710, 26650, 11909, 94663, 52705, 25054,
       23948])

In [23]:
distance[0]

array([30.510437, 30.934174, 31.149734, 31.348007, 31.397034, 31.417542,
       31.426514, 31.442612, 31.478806, 31.484848], dtype=float32)

## Faiss  IVFL2

内存占用  约180MB

In [4]:
%%time
index = faiss.IndexFlatL2(dimension)
index = faiss.IndexIVFFlat(index,dimension, 100) # 100 簇
index.train(data_vec)
index.add(data_vec)

CPU times: user 3.68 s, sys: 230 ms, total: 3.91 s
Wall time: 256 ms


In [52]:
%%time
distance, coresponding_index = index.search(query_vec, retrieval_k)
coresponding_index[0]

CPU times: user 412 ms, sys: 6.41 ms, total: 419 ms
Wall time: 24.6 ms


array([75151, 63477, 47593, 14561, 32060, 78235, 30596, 63458, 82924,
        7903])

In [53]:
distance[0]

array([31.762985, 32.859394, 33.745956, 33.82788 , 34.095734, 34.126045,
       34.21199 , 34.340767, 34.517437, 34.640167], dtype=float32)

In [54]:
%%time
index.nprobe=10
distance, coresponding_index = index.search(query_vec, retrieval_k)
coresponding_index[0]

CPU times: user 3.47 s, sys: 509 µs, total: 3.47 s
Wall time: 180 ms


array([58430, 26650, 11909, 75151,  3693, 54551, 34531, 47054, 27491,
       10108])

In [56]:
%%time
index.nprobe=30
distance, coresponding_index = index.search(query_vec, retrieval_k)
coresponding_index[0]

CPU times: user 8.26 s, sys: 0 ns, total: 8.26 s
Wall time: 415 ms


array([58430, 21260, 32710, 26650, 11909, 94663, 23948, 75151, 62514,
       93035])

In [57]:
%%time
index.nprobe=100
distance, coresponding_index = index.search(query_vec, retrieval_k)
coresponding_index[0]

CPU times: user 26.2 s, sys: 23 ms, total: 26.3 s
Wall time: 1.32 s


array([89650, 58430, 21260, 32710, 26650, 11909, 94663, 52705, 25054,
       23948])

## Faiss  IndexIVFPQ

- 方案1 8个子量化器 约20M 内存
- 方案2 128个子量化器 约40M 内存

In [5]:
%%time
m = 8                            # number of subquantizers
quantizer = faiss.IndexFlatL2(dimension)  # this remains the same
index = faiss.IndexIVFPQ(quantizer, dimension, 100, m, 8)
                                    
index.train(data_vec)
index.add(data_vec)

CPU times: user 1min 8s, sys: 1.79 s, total: 1min 9s
Wall time: 3.55 s


In [6]:
%%time
index.nprobe=10
distance, coresponding_index = index.search(query_vec, retrieval_k)
coresponding_index[0]

CPU times: user 496 ms, sys: 5.73 ms, total: 502 ms
Wall time: 27.6 ms


array([10108, 29174, 39208, 82176, 86363, 30938, 32425, 82501, 53797,
        4353])

In [4]:
%%time
m = 128                          # number of subquantizers
quantizer = faiss.IndexFlatL2(dimension)  # this remains the same
index = faiss.IndexIVFPQ(quantizer, dimension, 100, m, 8)
                                    
index.train(data_vec)
index.add(data_vec)

CPU times: user 11min 17s, sys: 6.24 s, total: 11min 23s
Wall time: 34.3 s


In [5]:
%%time
index.nprobe=10
distance, coresponding_index = index.search(query_vec, retrieval_k)
coresponding_index[0]

CPU times: user 3.11 s, sys: 4.14 ms, total: 3.11 s
Wall time: 159 ms


array([58430, 11909, 75151, 26650,  3693, 88897, 34531, 10108, 54551,
       27491])

## numpy
这里只做100条数据的查询，十分之一的计算量

numpy（这个方法）无法用到多线程，速度是faiss的100+倍
即使不考虑多线程，Faiss速度也是numpy的5倍以上（numpy代码里有循环）

In [8]:
%%time
first = True
for query in query_vec[:100]:
    ls_distances = np.sum(np.square(query-data_vec),axis=1)
    search_index = np.argpartition(ls_distances,retrieval_k)
    if first:
        first =False
        result = search_index
        dis = ls_distances[result]
result[:10]

CPU times: user 7.32 s, sys: 3.07 s, total: 10.4 s
Wall time: 7.32 s


array([89650, 58430, 21260, 32710, 94663, 26650, 11909, 52705, 25054,
       23948])

In [9]:
dis[:10]

array([30.510485, 30.934149, 31.14967 , 31.347977, 31.426548, 31.39702 ,
       31.417484, 31.442694, 31.47879 , 31.484783], dtype=float32)

## numba

numba 也是一种常用的python科学计算加速方式

使用numba 加速

In [10]:
from numba import njit as numba_njit

In [11]:
@numba_njit(parallel=True)
def numba_func(q,d):
    first = True
    for query in q[:100]:
        residual = query-d
        normd_array = np.empty(residual.shape[0])
        for i in range(residual.shape[0]):
            nrm = np.linalg.norm(residual[i])
            normd_array[i] = nrm
        
        search_index = np.argsort(normd_array)
        if first:
            first =False
            result = search_index
    return result[:10]

In [12]:
%%time
numba_func(query_vec,data_vec)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


CPU times: user 2min 38s, sys: 12.7 s, total: 2min 51s
Wall time: 10.5 s


array([89650, 58430, 21260, 32710, 26650, 11909, 94663, 52705, 25054,
       23948])

In [13]:
@numba_njit(parallel=True)
def numba_func2(q,d):
    first = True
    for query in q[:100]:
        sum_of_square = np.sum(np.square(query-d),axis=1)
        search_index = np.argsort(sum_of_square)
        if first:
            first =False
            result = search_index
            dis = sum_of_square[search_index]
    return result[:10] , dis[:10]

In [14]:
%%time
numba_func2(query_vec,data_vec)

CPU times: user 6min 25s, sys: 18.9 s, total: 6min 44s
Wall time: 30 s


(array([89650, 58430, 21260, 32710, 26650, 11909, 94663, 52705, 25054,
        23948]),
 array([30.510483, 30.934147, 31.149668, 31.347973, 31.397024, 31.417475,
        31.426542, 31.442688, 31.478794, 31.484797], dtype=float32))

## Jax CPU

In [4]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import jax
import jax.numpy as jnp
import jax.lax as lax
from jax import jit,vmap
jax.config.update('jax_platform_name', 'cpu')

In [10]:
%%time
jnp.stack([jax.block_until_ready(lax.top_k(-jnp.sum((query-data_vec)**2,axis=1),retrieval_k))[1] for query in query_vec])[0]

2022-07-08 16:50:23.651458: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


CPU times: user 42.1 s, sys: 16.3 s, total: 58.4 s
Wall time: 48 s


DeviceArray([89650, 58430, 21260, 32710, 26650, 11909, 94663, 52705,
             25054, 23948], dtype=int32)

In [5]:
@jit
def jax_func(query_vec,data_vec):
    search_index = jnp.stack([jax.block_until_ready(lax.top_k(-jnp.linalg.norm(query-data_vec,axis=1),retrieval_k))[1] for query in query_vec])
    return search_index

In [None]:
%%time
jax_func(data_vec,query_vec)[0,:]

In [10]:
def get_topk(query):
    return jax.block_until_ready(lax.top_k(-jnp.linalg.norm(query-data_vec,axis=1),retrieval_k))[1]

@jit
def vmap_get_topk(query_vec):
    return vmap(get_topk)(query_vec)

In [18]:
%%time
vmap_get_topk(query_vec[:500])

CPU times: user 16.6 s, sys: 22.9 s, total: 39.6 s
Wall time: 9.67 s


DeviceArray([[89650, 58430, 21260, ..., 52705, 25054, 23948],
             [84304, 62062, 82023, ..., 62775, 50838, 51148],
             [38158, 55219, 64005, ..., 98387, 48624, 40577],
             ...,
             [63914, 29266, 76601, ..., 61901, 95488, 28443],
             [74164,  1430, 25750, ..., 10639, 64621, 55519],
             [66406, 70162, 15525, ..., 20121, 53211,   991]],            dtype=int32)