Skip to content

Commit

Permalink
Add pq dist table support (#158)
Browse files Browse the repository at this point in the history
* feat: add dist table compute and tests

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* feat: add dist table parameter

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* feat(cpp): add local state param

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* fix: add fool local state

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* fix: dtable for multi-metrics

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* fix: ensure dtable memory layout

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* feat(cpp): support pq local data

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* feat(cpp): add pq batch dtable

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* fix(cpp): remove local stage param

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* test: add linear PQ results

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* test: fix path error

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* feat: cpython for dist mat

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* fix: cpython return

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* style: move into kwargs

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>

* fix: update executor parameter (#180)

* fix: update parameter

* fix: fix import

* fix: fix error

* fix: fix limit

* fix: fix match_args

* refactor: use rocksdb as the docs storage engine (#178)

* refactor: use rocksdb as doc store

* fix: unittest

* fix: close rocksdb in test

* fix: add warning comments

* fix: group bench scripts

* fix: revert bench

* fix(ci): tests

* fix: update include/hnswlib/space_ip.h

Co-authored-by: felix-wang <35718120+numb3r3@users.noreply.github.com>

Signed-off-by: Jianbai Ye <jianbaiye@outlook.com>
Co-authored-by: YangXiuyu <gzzyyxy@gmail.com>
Co-authored-by: felix-wang <35718120+numb3r3@users.noreply.github.com>
  • Loading branch information
3 people committed Sep 27, 2022
1 parent e8c5990 commit 02857ec
Show file tree
Hide file tree
Showing 11 changed files with 758 additions and 398 deletions.
86 changes: 85 additions & 1 deletion annlite/core/codec/pq.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from argparse import ArgumentError

import numpy as np
from scipy.cluster.vq import vq

from annlite import pq_bind

from ...enums import Metric
from ...math import l2_normalize
from ...profile import time_profile
from .base import BaseCodec

# from pqlite.pq_bind import precompute_adc_table, dist_pqcodes_to_codebooks
Expand Down Expand Up @@ -231,7 +234,7 @@ def get_codebook(self) -> 'np.ndarray':
Expect a 3-dimensional matrix is returned,
with shape (`n_subvectors`, `n_clusters`, `d_subvector`) and dtype float32
"""
return self.codebooks
return np.ascontiguousarray(self.codebooks, dtype='float32')

def get_subspace_splitting(self):
"""Return subspace splitting setting
Expand All @@ -240,6 +243,87 @@ def get_subspace_splitting(self):
"""
return (self.n_subvectors, self.n_clusters, self.d_subvector)

# def get_dist_mat(self, x: np.ndarray):
# """Return the distance tables in form of matrix for multiple queries

# :param query: shape('N', 'D'),

# :return: ndarray with shape('N', `n_subvectors`, `n_clusters`)

# .. note::
# _description_
# """
# assert x.dtype == np.float32
# assert x.ndim == 2
# N, D = x.shape
# assert (
# D == self.d_subvector * self.n_subvectors
# ), 'input dimension must be Ds * M'
# if self.normalize_input:
# x = l2_normalize(x)

# x = x.reshape(
# N,
# self.n_subvectors,
# 1,
# self.d_subvector,
# )
# if self.metric == Metric.EUCLIDEAN:
# # (1, n_subvectors, n_clusters, d_subvector)
# codebook = self.codebooks[np.newaxis, ...]

# # broadcast to (N, n_subvectors, n_clusters, d_subvector)
# dist_vector = (x - codebook) ** 2

# # reduce to (N, n_subvectors, n_clusters)
# dist_mat = np.sum(dist_vector, axis=3)
# elif self.metric in [Metric.INNER_PRODUCT, Metric.COSINE]:
# # (1, n_subvectors, n_clusters, d_subvector)
# codebook = self.codebooks[np.newaxis, ...]

# # broadcast to (N, n_subvectors, n_clusters, d_subvector)
# dist_vector = x * codebook

# # reduce to (N, n_subvectors, n_clusters)
# dist_mat = 1 / self.n_clusters - np.sum(dist_vector, axis=3)
# else:
# raise ArgumentError(f'Unable support metrics {self.metric}')
# return np.ascontiguousarray(dist_mat, dtype='float32')

def get_dist_mat(self, x: np.ndarray):
    """Return the distance tables, in matrix form, for a batch of queries.

    The per-query tables are computed by the native ``pq_bind`` helpers
    against ``self.codebooks``; the result is always handed back as a
    C-contiguous float32 array.

    :param x: query matrix with shape (`N`, `D`) and dtype float32, where
        `D` must equal `d_subvector * n_subvectors`
    :return: C-contiguous float32 ndarray, expected shape
        (`N`, `n_subvectors`, `n_clusters`)
    :raises AssertionError: if `x` is not a 2-dimensional float32 array of
        the expected width
    :raises ArgumentError: if `self.metric` is not one of EUCLIDEAN,
        INNER_PRODUCT or COSINE

    .. note::
        For INNER_PRODUCT and COSINE every table entry is
        `1 / n_clusters` minus the value produced by
        `pq_bind.batch_precompute_adc_table_ip` (presumably the inner
        product between a query sub-vector and a codeword -- confirm
        against the binding).
    """
    assert x.dtype == np.float32
    assert x.ndim == 2
    dim = x.shape[1]
    assert (
        dim == self.d_subvector * self.n_subvectors
    ), 'input dimension must be Ds * M'
    if self.normalize_input:
        x = l2_normalize(x)

    if self.metric == Metric.EUCLIDEAN:
        dist_mat = pq_bind.batch_precompute_adc_table(
            x, self.d_subvector, self.n_clusters, self.codebooks
        )
    elif self.metric in [Metric.INNER_PRODUCT, Metric.COSINE]:
        ip_tables = np.array(
            pq_bind.batch_precompute_adc_table_ip(
                x, self.d_subvector, self.n_clusters, self.codebooks
            ),
            dtype='float32',
        )
        dist_mat = 1 / self.n_clusters - ip_tables
    else:
        # argparse.ArgumentError takes (argument, message). Calling it with a
        # single positional value raises TypeError instead of the intended
        # error, so pass None for the (non-existent) argparse action.
        raise ArgumentError(None, f'Unable support metrics {self.metric}')
    return np.ascontiguousarray(dist_mat, dtype='float32')

# -------------------------------------


Expand Down
26 changes: 21 additions & 5 deletions annlite/core/index/hnsw/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ def pre_processed(self: 'HnswIndex', x: np.ndarray, *args, **kwargs):
elif not self._set_backend_pq:
self._index.loadPQ(self.pq_codec)
self._set_backend_pq = True
kwargs['pre_process_dtables'] = self.pq_codec.get_dist_mat(x)
x = self.pq_codec.encode(x)
return f(self, x, *args, **kwargs)

assert kwargs['pre_process_dtables'].dtype == 'float32'
assert kwargs['pre_process_dtables'].flags['C_CONTIGUOUS']
return f(self, x, *args, **kwargs)
else:
return f(self, x, *args, **kwargs)

return pre_processed

Expand Down Expand Up @@ -116,20 +122,28 @@ def dump(self, index_file: Union[str, Path]):
self._index.save_index(str(index_file))

@pre_process
def add_with_ids(self, x: 'np.ndarray', ids: List[int]):
def add_with_ids(
self,
x: 'np.ndarray',
ids: List[int],
# kwargs may be used by pre_process
pre_process_dtables=None,
):
max_id = max(ids) + 1
if max_id > self.capacity:
expand_steps = math.ceil(max_id / self.expand_step_size)
self._expand_capacity(expand_steps * self.expand_step_size)

self._index.add_items(x, ids=ids)
self._index.add_items(x, ids=ids, dtables=pre_process_dtables)

@pre_process
def search(
self,
query: 'np.ndarray',
limit: int = 10,
indices: Optional['np.ndarray'] = None,
# kwargs may be used by pre_process
pre_process_dtables=None,
):
ef_search = max(self.ef_search, limit)
self._index.set_ef(ef_search)
Expand All @@ -139,10 +153,12 @@ def search(
if len(indices) < limit:
limit = len(indices)
ids, dists = self._index.knn_query_with_filter(
query, filters=indices, k=limit
query, filters=indices, k=limit, dtables=pre_process_dtables
)
else:
ids, dists = self._index.knn_query(query, k=limit)
ids, dists = self._index.knn_query(
query, k=limit, dtables=pre_process_dtables
)

# convert squared l2 into euclidean distance
if self.metric == Metric.EUCLIDEAN:
Expand Down
Loading

0 comments on commit 02857ec

Please sign in to comment.