From 8389df2ecf91d32b4be34d3c62ba2ced7974cbc8 Mon Sep 17 00:00:00 2001
From: Buqian Zheng
Date: Tue, 7 May 2024 18:45:14 +0800
Subject: [PATCH] remove scipy dependency for sparse while still supporting it

Signed-off-by: Buqian Zheng
---
 examples/hello_sparse.py         |  17 +++--
 pymilvus/client/abstract.py      |   4 +-
 pymilvus/client/entity_helper.py | 111 ++++++++-----------------------
 pymilvus/client/grpc_handler.py  |   4 +-
 pymilvus/client/prepare.py       |   5 +-
 pymilvus/client/utils.py         | 100 +++++++++++++++++++++++++++-
 pymilvus/orm/collection.py       |  12 ++--
 pymilvus/orm/iterator.py         |   4 +-
 pymilvus/orm/partition.py        |   8 +--
 pymilvus/orm/prepare.py          |   4 +-
 pyproject.toml                   |   1 -
 requirements.txt                 |   1 -
 12 files changed, 157 insertions(+), 114 deletions(-)

diff --git a/examples/hello_sparse.py b/examples/hello_sparse.py
index b6ac8f732..eab0c558e 100644
--- a/examples/hello_sparse.py
+++ b/examples/hello_sparse.py
@@ -10,7 +10,7 @@
 import time

 import numpy as np
-from scipy.sparse import rand
+import random
 from pymilvus import (
     connections,
     utility,
@@ -20,7 +20,9 @@

 fmt = "=== {:30} ==="
 search_latency_fmt = "search latency = {:.4f}s"
-num_entities, dim, density = 1000, 3000, 0.005
+num_entities, dim = 1000, 3000
+# non zero count of randomly generated sparse vectors
+nnz = 30

 def log(msg):
     print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " " + msg)
@@ -54,11 +56,16 @@ def log(msg):
 # insert
 log(fmt.format("Start creating entities to insert"))
 rng = np.random.default_rng(seed=19530)
-# this step is so damn slow
-matrix_csr = rand(num_entities, dim, density=density, format='csr')
+
+def generate_sparse_vector(dimension: int, non_zero_count: int) -> dict:
+    """Build one random sparse row as an {index: value} dict with non_zero_count entries."""
+    picked = random.sample(range(dimension), non_zero_count)
+    weights = [random.random() for _ in range(non_zero_count)]
+    return dict(zip(picked, weights))
+
 entities = [
     rng.random(num_entities).tolist(),
-    matrix_csr,
+    [generate_sparse_vector(dim, nnz) for _ in 
range(num_entities)], ] log(fmt.format("Start inserting entities")) diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 1f0f5c2b5..6e78963fc 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -7,7 +7,7 @@ from pymilvus.grpc_gen import schema_pb2 from pymilvus.settings import Config -from . import entity_helper +from . import entity_helper, utils from .constants import DEFAULT_CONSISTENCY_LEVEL, RANKER_TYPE_RRF, RANKER_TYPE_WEIGHTED from .types import DataType @@ -327,7 +327,7 @@ def dict(self): class AnnSearchRequest: def __init__( self, - data: Union[List, entity_helper.SparseMatrixInputType], + data: Union[List, utils.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, diff --git a/pymilvus/client/entity_helper.py b/pymilvus/client/entity_helper.py index 4d493e445..a946d52b6 100644 --- a/pymilvus/client/entity_helper.py +++ b/pymilvus/client/entity_helper.py @@ -1,10 +1,9 @@ import math import struct -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional import numpy as np import ujson -from scipy import sparse from pymilvus.exceptions import ( DataNotMatchException, @@ -16,67 +15,13 @@ from pymilvus.settings import Config from .types import DataType +from .utils import SciPyHelper, SparseMatrixInputType, SparseRowOutputType CHECK_STR_ARRAY = True -# in search results, if output fields includes a sparse float vector field, we -# will return a SparseRowOutputType for each entity. Using Dict for readability. -# TODO(SPARSE): to allow the user to specify output format. -SparseRowOutputType = Dict[int, float] - -# we accept the following types as input for sparse matrix in user facing APIs -# such as insert, search, etc.: -# - scipy sparse array/matrix family: csr, csc, coo, bsr, dia, dok, lil -# - iterable of iterables, each element(iterable) is a sparse vector with index -# as key and value as float. 
-# dict example: [{2: 0.33, 98: 0.72, ...}, {4: 0.45, 198: 0.52, ...}, ...] -# list of tuple example: [[(2, 0.33), (98, 0.72), ...], [(4, 0.45), ...], ...] -# both index/value can be str numbers: {'2': '3.1'} -SparseMatrixInputType = Union[ - Iterable[ - Union[ - SparseRowOutputType, - Iterable[Tuple[int, float]], # only type hint, we accept int/float like types - ] - ], - sparse.csc_array, - sparse.coo_array, - sparse.bsr_array, - sparse.dia_array, - sparse.dok_array, - sparse.lil_array, - sparse.csr_array, - sparse.spmatrix, -] - - -def sparse_is_scipy_matrix(data: Any): - return isinstance(data, sparse.spmatrix) - - -def sparse_is_scipy_array(data: Any): - # sparse.sparray, the common superclass of sparse.*_array, is introduced in - # scipy 1.11.0, which requires python 3.9, higher than pymilvus's current requirement. - return isinstance( - data, - ( - sparse.bsr_array, - sparse.coo_array, - sparse.csc_array, - sparse.csr_array, - sparse.dia_array, - sparse.dok_array, - sparse.lil_array, - ), - ) - - -def sparse_is_scipy_format(data: Any): - return sparse_is_scipy_matrix(data) or sparse_is_scipy_array(data) - def entity_is_sparse_matrix(entity: Any): - if sparse_is_scipy_format(entity): + if SciPyHelper.is_scipy_sparse(entity): return True try: @@ -143,34 +88,30 @@ def sparse_float_row_to_bytes(indices: Iterable[int], values: Iterable[float]): data += struct.pack("f", v) return data - def unify_sparse_input(data: SparseMatrixInputType) -> sparse.csr_array: - if isinstance(data, sparse.csr_array): - return data - if sparse_is_scipy_array(data): - return data.tocsr() - if sparse_is_scipy_matrix(data): - return sparse.csr_array(data.tocsr()) - row_indices = [] - col_indices = [] - values = [] - for row_id, row_data in enumerate(data): - row = row_data.items() if isinstance(row_data, dict) else row_data - row_indices.extend([row_id] * len(row)) - col_indices.extend( - [int(col_id) if isinstance(col_id, str) else col_id for col_id, _ in row] - ) - 
values.extend([float(value) if isinstance(value, str) else value for _, value in row])
-        return sparse.csr_array((values, (row_indices, col_indices)))
-
     if not entity_is_sparse_matrix(data):
         raise ParamError(message="input must be a sparse matrix in supported format")
-    csr = unify_sparse_input(data)
+
     result = schema_types.SparseFloatArray()
-    result.dim = csr.shape[1]
-    for start, end in zip(csr.indptr[:-1], csr.indptr[1:]):
-        result.contents.append(
-            sparse_float_row_to_bytes(csr.indices[start:end], csr.data[start:end])
-        )
+
+    if SciPyHelper.is_scipy_sparse(data):
+        csr = data.tocsr()
+        result.dim = csr.shape[1]
+        for start, end in zip(csr.indptr[:-1], csr.indptr[1:]):
+            result.contents.append(
+                sparse_float_row_to_bytes(csr.indices[start:end], csr.data[start:end])
+            )
+    else:
+        dim = 0
+        for row_data in data:
+            row = row_data.items() if isinstance(row_data, dict) else row_data
+            # index/value may be str numbers such as {'2': '3.1'}: normalize them
+            pairs = [(int(index), float(value)) for index, value in row]
+            indices = [index for index, _ in pairs]
+            values = [value for _, value in pairs]
+            result.contents.append(sparse_float_row_to_bytes(indices, values))
+            # rows may be empty or unsorted: size by the largest index, not the last
+            dim = max(dim, max(indices, default=-1) + 1)
+        result.dim = dim
     return result


@@ -186,7 +127,7 @@ def sparse_proto_to_rows(


 def get_input_num_rows(entity: Any) -> int:
-    if sparse_is_scipy_format(entity):
+    if SciPyHelper.is_scipy_sparse(entity):
         return entity.shape[0]
     return len(entity)

@@ -354,7 +295,7 @@ def pack_field_value_to_field_data(
         field_data.vectors.bfloat16_vector += v_bytes
     elif field_type == DataType.SPARSE_FLOAT_VECTOR:
         # field_value is a single row of sparse float vector in user provided format
-        if not sparse_is_scipy_format(field_value):
+        if not SciPyHelper.is_scipy_sparse(field_value):
             field_value = [field_value]
         elif field_value.shape[0] != 1:
             raise ParamError(message="invalid input for sparse float vector: expect 1 row")
diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py
index 7d37aedde..49e86a4ef 100644
--- a/pymilvus/client/grpc_handler.py
+++ 
b/pymilvus/client/grpc_handler.py @@ -23,7 +23,7 @@ from pymilvus.grpc_gen import milvus_pb2 as milvus_types from pymilvus.settings import Config -from . import entity_helper, interceptor, ts_utils +from . import entity_helper, interceptor, ts_utils, utils from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, MutationResult, SearchResult from .asynch import ( CreateIndexFuture, @@ -761,7 +761,7 @@ def _execute_hybrid_search( def search( self, collection_name: str, - data: Union[List[List[float]], entity_helper.SparseMatrixInputType], + data: Union[List[List[float]], utils.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index a756a1dc7..df14161e6 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -4,14 +4,13 @@ import numpy as np -from pymilvus.client import __version__, entity_helper from pymilvus.exceptions import DataNotMatchException, ExceptionsMessage, ParamError from pymilvus.grpc_gen import common_pb2 as common_types from pymilvus.grpc_gen import milvus_pb2 as milvus_types from pymilvus.grpc_gen import schema_pb2 as schema_types from pymilvus.orm.schema import CollectionSchema -from . import blob, ts_utils, utils +from . 
import __version__, blob, entity_helper, ts_utils, utils
 from .check import check_pass_param, is_legal_collection_properties
 from .constants import (
     DEFAULT_CONSISTENCY_LEVEL,
@@ -626,7 +625,7 @@ def _prepare_placeholder_str(cls, data: Any):
     def search_requests_with_expr(
         cls,
         collection_name: str,
-        data: Union[List, entity_helper.SparseMatrixInputType],
+        data: Union[List, utils.SparseMatrixInputType],
         anns_field: str,
         param: Dict,
         limit: int,
diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py
index 7ecc35670..4168fde9c 100644
--- a/pymilvus/client/utils.py
+++ b/pymilvus/client/utils.py
@@ -1,6 +1,7 @@
 import datetime
+import importlib.util
 from datetime import timedelta
-from typing import Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

 import ujson

@@ -270,3 +271,100 @@ def get_server_type(host: str):

 def dumps(v: Union[dict, str]) -> str:
     return ujson.dumps(v) if isinstance(v, dict) else str(v)
+
+
+class SciPyHelper:
+    _checked = False
+
+    # whether scipy.sparse.*_matrix classes exist
+    _matrix_available = False
+    # whether scipy.sparse.*_array classes exist
+    _array_available = False
+
+    @classmethod
+    def _init(cls):
+        """Probe scipy.sparse once and cache which class families it offers."""
+        if cls._checked:
+            return
+        try:
+            # a regular import reuses sys.modules instead of executing the
+            # package a second time; no usable scipy just means "unavailable".
+            scipy_sparse = importlib.import_module("scipy.sparse")
+        except ImportError:
+            scipy_sparse = None
+        # each class family (*_matrix, *_array) was introduced in a single
+        # scipy release, so probing one member per family is enough.
+        cls._matrix_available = hasattr(scipy_sparse, "csr_matrix")
+        cls._array_available = hasattr(scipy_sparse, "csr_array")
+
+        cls._checked = True
+
+    @classmethod
+    def is_spmatrix(cls, data: Any):
+        """Return scipy's isspmatrix(data), or False when scipy is unavailable."""
+        cls._init()
+        if not cls._matrix_available:
+            return False
+        from scipy.sparse import isspmatrix
+
+        return isspmatrix(data)
+
+    @classmethod
+    def is_sparray(cls, data: Any):
+        """Return True if data is scipy-sparse but not an spmatrix (a sparse array)."""
+        cls._init()
+        if not cls._array_available:
+            return False
+        from scipy.sparse import issparse, isspmatrix
+
+        return issparse(data) and not isspmatrix(data)
+
+    @classmethod
+    def is_scipy_sparse(cls, data: Any):
+        """Return True if data is any scipy sparse matrix or array."""
+        return cls.is_spmatrix(data) or cls.is_sparray(data)
+
+
+# in search results, if output fields includes a sparse float vector field, we
+# will return a SparseRowOutputType for each entity. Using Dict for readability.
+# TODO(SPARSE): to allow the user to specify output format.
+SparseRowOutputType = Dict[int, float]
+
+
+# this import will be called only during static type checking
+if TYPE_CHECKING:
+    from scipy.sparse import (
+        bsr_array,
+        coo_array,
+        csc_array,
+        csr_array,
+        dia_array,
+        dok_array,
+        lil_array,
+        spmatrix,
+    )
+
+# we accept the following types as input for sparse matrix in user facing APIs
+# such as insert, search, etc.:
+# - scipy sparse array/matrix family: csr, csc, coo, bsr, dia, dok, lil
+# - iterable of iterables, each element(iterable) is a sparse vector with index
+# as key and value as float.
+# dict example: [{2: 0.33, 98: 0.72, ...}, {4: 0.45, 198: 0.52, ...}, ...]
+# list of tuple example: [[(2, 0.33), (98, 0.72), ...], [(4, 0.45), ...], ...]
+# both index/value can be str numbers: {'2': '3.1'} +SparseMatrixInputType = Union[ + Iterable[ + Union[ + SparseRowOutputType, + Iterable[Tuple[int, float]], # only type hint, we accept int/float like types + ] + ], + "csc_array", + "coo_array", + "bsr_array", + "dia_array", + "dok_array", + "lil_array", + "csr_array", + "spmatrix", +] diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 299642c2e..6926c5d8d 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -16,7 +16,7 @@ import pandas as pd -from pymilvus.client import entity_helper +from pymilvus.client import entity_helper, utils from pymilvus.client.abstract import BaseRanker, SearchResult from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL from pymilvus.client.types import ( @@ -454,7 +454,7 @@ def release(self, timeout: Optional[float] = None, **kwargs): def insert( self, - data: Union[List, pd.DataFrame, Dict, entity_helper.SparseMatrixInputType], + data: Union[List, pd.DataFrame, Dict, utils.SparseMatrixInputType], partition_name: Optional[str] = None, timeout: Optional[float] = None, **kwargs, @@ -581,7 +581,7 @@ def delete( def upsert( self, - data: Union[List, pd.DataFrame, Dict, entity_helper.SparseMatrixInputType], + data: Union[List, pd.DataFrame, Dict, utils.SparseMatrixInputType], partition_name: Optional[str] = None, timeout: Optional[float] = None, **kwargs, @@ -655,7 +655,7 @@ def upsert( def search( self, - data: Union[List, entity_helper.SparseMatrixInputType], + data: Union[List, utils.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, @@ -790,7 +790,7 @@ def search( if expr is not None and not isinstance(expr, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) - empty_scipy_sparse = entity_helper.sparse_is_scipy_format(data) and (data.shape[0] == 0) + empty_scipy_sparse = utils.SciPyHelper.is_scipy_sparse(data) and (data.shape[0] == 0) if (isinstance(data, list) and len(data) == 0) 
or empty_scipy_sparse: resp = SearchResult(schema_pb2.SearchResultData()) return SearchFuture(None) if kwargs.get("_async", False) else resp @@ -957,7 +957,7 @@ def hybrid_search( def search_iterator( self, - data: Union[List, entity_helper.SparseMatrixInputType], + data: Union[List, utils.SparseMatrixInputType], anns_field: str, param: Dict, batch_size: Optional[int] = 1000, diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index ea25bf5f5..687601624 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, TypeVar, Union -from pymilvus.client import entity_helper +from pymilvus.client import entity_helper, utils from pymilvus.client.abstract import Hits, LoopBase from pymilvus.exceptions import ( MilvusException, @@ -283,7 +283,7 @@ def __init__( self, connection: Connections, collection_name: str, - data: Union[List, entity_helper.SparseMatrixInputType], + data: Union[List, utils.SparseMatrixInputType], ann_field: str, param: Dict, batch_size: Optional[int] = 1000, diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py index 3b678c2d9..8b61f17bd 100644 --- a/pymilvus/orm/partition.py +++ b/pymilvus/orm/partition.py @@ -15,7 +15,7 @@ import pandas as pd import ujson -from pymilvus.client import entity_helper +from pymilvus.client import utils from pymilvus.client.abstract import BaseRanker, SearchResult from pymilvus.client.types import Replica from pymilvus.exceptions import MilvusException @@ -239,7 +239,7 @@ def release(self, timeout: Optional[float] = None, **kwargs): def insert( self, - data: Union[List, pd.DataFrame, entity_helper.SparseMatrixInputType], + data: Union[List, pd.DataFrame, utils.SparseMatrixInputType], timeout: Optional[float] = None, **kwargs, ) -> MutationResult: @@ -317,7 +317,7 @@ def delete(self, expr: str, timeout: Optional[float] = None, **kwargs): def upsert( self, - data: Union[List, pd.DataFrame, 
entity_helper.SparseMatrixInputType], + data: Union[List, pd.DataFrame, utils.SparseMatrixInputType], timeout: Optional[float] = None, **kwargs, ) -> MutationResult: @@ -357,7 +357,7 @@ def upsert( def search( self, - data: Union[List, entity_helper.SparseMatrixInputType], + data: Union[List, utils.SparseMatrixInputType], anns_field: str, param: Dict, limit: int, diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py index 799fcad04..9145add64 100644 --- a/pymilvus/orm/prepare.py +++ b/pymilvus/orm/prepare.py @@ -16,7 +16,7 @@ import numpy as np import pandas as pd -from pymilvus.client import entity_helper +from pymilvus.client import utils from pymilvus.client.types import DataType from pymilvus.exceptions import ( DataNotMatchException, @@ -156,7 +156,7 @@ def prepare_insert_data( @classmethod def prepare_upsert_data( cls, - data: Union[List, Tuple, pd.DataFrame, entity_helper.SparseMatrixInputType], + data: Union[List, Tuple, pd.DataFrame, utils.SparseMatrixInputType], schema: CollectionSchema, ) -> List: if schema.auto_id: diff --git a/pyproject.toml b/pyproject.toml index c237b5dff..9bfaa3f16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ dependencies=[ "minio>=7.0.0", "pyarrow>=12.0.0", "azure-storage-blob", - "scipy", ] classifiers=[ diff --git a/requirements.txt b/requirements.txt index 84648e5fa..57ca601ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,6 @@ toml==0.10.2 ujson>=2.0.0 urllib3==1.26.18 m2r==0.3.1 -scipy>=1.9.3 Sphinx==4.0.0 sphinx-copybutton sphinx-rtd-theme