From 842ef7a912a24b0d3ec30f1b3cc3f378d26d34d2 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Fri, 21 Feb 2025 15:13:46 +0800 Subject: [PATCH 1/5] support FTS & vector new feature Signed-off-by: shanhaikang.shk --- pyobvector/__init__.py | 10 + pyobvector/client/__init__.py | 5 + pyobvector/client/fts_index_param.py | 44 +++ pyobvector/client/index_param.py | 108 ++++-- pyobvector/client/ob_vec_client.py | 49 +++ pyobvector/client/ob_vec_json_table_client.py | 2 +- pyobvector/schema/__init__.py | 8 + pyobvector/schema/full_text_index.py | 59 +++ pyobvector/schema/match_against_func.py | 33 ++ pyobvector/schema/reflection.py | 6 +- tests/test_fts_index.py | 99 +++++ tests/test_ob_vec_more_algorithm.py | 337 ++++++++++++++++++ tests/test_reflection.py | 4 +- 13 files changed, 730 insertions(+), 34 deletions(-) create mode 100644 pyobvector/client/fts_index_param.py create mode 100644 pyobvector/schema/full_text_index.py create mode 100644 pyobvector/schema/match_against_func.py create mode 100644 tests/test_fts_index.py create mode 100644 tests/test_ob_vec_more_algorithm.py diff --git a/pyobvector/__init__.py b/pyobvector/__init__.py index 4477273..085c4a2 100644 --- a/pyobvector/__init__.py +++ b/pyobvector/__init__.py @@ -15,6 +15,7 @@ * DataType Specify field type in collection schema for MilvusLikeClient * VECTOR An extended data type in SQLAlchemy for ObVecClient * VectorIndex An extended index type in SQLAlchemy for ObVecClient +* FtsIndex Full Text Search Index * FieldSchema Clas to define field schema in collection for MilvusLikeClient * CollectionSchema Class to define collection schema for MilvusLikeClient * PartType Specify partition type of table or collection @@ -34,6 +35,9 @@ * st_distance GIS function: calculate distance between Points * st_dwithin GIS function: check if the distance between two points * st_astext GIS function: return a Point in human-readable format +* FtsParser Text Parser Type for Full Text Search +* FtsIndexParam Full Text Search index parameter +* MatchAgainst Full Text Search clause """ from .client import * from .schema import ( @@ -50,6 +54,8 @@ st_distance, st_dwithin, st_astext, + FtsIndex, + MatchAgaint, ) from .json_table import OceanBase @@ -64,6 +70,7 @@ "VECTOR", "POINT", "VectorIndex", + "FtsIndex", "OceanBaseDialect", "AsyncOceanBaseDialect", "FieldSchema", @@ -88,4 +95,7 @@ "st_dwithin", "st_astext", "OceanBase", + "FtsParser", + "FtsIndexParam", + "MatchAgaint", ] diff --git a/pyobvector/client/__init__.py b/pyobvector/client/__init__.py index 64d5dab..cdac26a 100644 --- a/pyobvector/client/__init__.py +++ b/pyobvector/client/__init__.py @@ -27,6 +27,8 @@ * ObSubHashPartition Specify Hash subpartition info * ObKeyPartition Specify Key partition info * ObSubKeyPartition Specify Key subpartition info +* FtsParser Text Parser Type for Full Text Search +* FtsIndexParam Full Text Search index parameter """ from .ob_vec_client import ObVecClient from .milvus_like_client import MilvusLikeClient @@ -35,6 +37,7 @@ from .schema_type import DataType from .collection_schema import FieldSchema, CollectionSchema from .partitions import * +from .fts_index_param import FtsParser, FtsIndexParam __all__ = [ "ObVecClient", @@ -57,4 +60,6 @@ "ObSubHashPartition", "ObKeyPartition", "ObSubKeyPartition", + "FtsParser", + "FtsIndexParam", ] diff --git a/pyobvector/client/fts_index_param.py b/pyobvector/client/fts_index_param.py new file mode 100644 index 0000000..87882e5 --- /dev/null +++ b/pyobvector/client/fts_index_param.py @@ -0,0 +1,44 @@ +"""A module to specify fts index parameters""" +from enum import Enum +from typing import List, Optional + +class FtsParser(Enum): + IK = 0 + NGRAM = 1 + + +class FtsIndexParam: + def __init__( + self, + index_name: str, + field_names: List[str], + parser_type: Optional[FtsParser], + ): + self.index_name = index_name + self.field_names = field_names + self.parser_type = parser_type + + def param_str(self) -> str: + if self.parser_type is None: + return None + if self.parser_type == FtsParser.IK: + return "ik" + if self.parser_type == FtsParser.NGRAM: + return "ngram" + + def __iter__(self): + yield "index_name", self.index_name + yield "field_names", self.field_names + if self.parser_type: + yield "parser_type", self.parser_type + + def __str__(self): + return str(dict(self)) + + def __eq__(self, other: None): + if isinstance(other, self.__class__): + return dict(self) == dict(other) + + if isinstance(other, dict): + return dict(self) == other + return False diff --git a/pyobvector/client/index_param.py b/pyobvector/client/index_param.py index 0343a93..40ba6fb 100644 --- a/pyobvector/client/index_param.py +++ b/pyobvector/client/index_param.py @@ -5,7 +5,10 @@ class VecIndexType(Enum): """Vector index algorithm type""" HNSW = 0 - # IVFFLAT = 1 + HNSW_SQ = 1 + IVFFLAT = 2 + IVFSQ = 3 + IVFPQ = 4 class IndexParam: @@ -23,6 +26,11 @@ class IndexParam: HNSW_DEFAULT_EF_CONSTRUCTION = 200 HNSW_DEFAULT_EF_SEARCH = 40 OCEANBASE_DEFAULT_ALGO_LIB = 'vsag' + HNSW_ALGO_NAME = "hnsw" + HNSW_SQ_ALGO_NAME = "hnsw_sq" + IVFFLAT_ALGO_NAME = "ivf_flat" + IVFSQ_ALGO_NAME = "ivf_sq8" + IVFPQ_ALGO_NAME = "ivf_pq" def __init__( self, index_name: str, field_name: str, index_type: Union[VecIndexType, str], **kwargs @@ -33,47 +41,89 @@ def __init__( self.index_type = self._get_vector_index_type_str() self.kwargs = kwargs + def is_index_type_hnsw_serial(self): + return self.index_type in [ + IndexParam.HNSW_ALGO_NAME, IndexParam.HNSW_SQ_ALGO_NAME + ] + + def is_index_type_ivf_serial(self): + return self.index_type in [ + IndexParam.IVFFLAT_ALGO_NAME, + IndexParam.IVFSQ_ALGO_NAME, + IndexParam.IVFPQ_ALGO_NAME, + ] + + def is_index_type_product_quantization(self): + return self.index_type in [ + IndexParam.IVFPQ_ALGO_NAME, + ] + def _get_vector_index_type_str(self): """Parse vector index type to string.""" if isinstance(self.index_type, VecIndexType): if self.index_type == VecIndexType.HNSW: - return "hnsw" - # elif self.index_type == VecIndexType.IVFFLAT: - # return "ivfflat" + return IndexParam.HNSW_ALGO_NAME + elif self.index_type == VecIndexType.HNSW_SQ: + return IndexParam.HNSW_SQ_ALGO_NAME + elif self.index_type == VecIndexType.IVFFLAT: + return IndexParam.IVFFLAT_ALGO_NAME + elif self.index_type == VecIndexType.IVFSQ: + return IndexParam.IVFSQ_ALGO_NAME + elif self.index_type == VecIndexType.IVFPQ: + return IndexParam.IVFPQ_ALGO_NAME raise ValueError(f"unsupported vector index type: {self.index_type}") assert isinstance(self.index_type, str) - if self.index_type.lower() == "hnsw": - return "hnsw" - raise ValueError(f"unsupported vector index type: {self.index_type}") + index_type = self.index_type.lower() + if index_type not in [ + IndexParam.HNSW_ALGO_NAME, + IndexParam.HNSW_SQ_ALGO_NAME, + IndexParam.IVFFLAT_ALGO_NAME, + IndexParam.IVFSQ_ALGO_NAME, + IndexParam.IVFPQ_ALGO_NAME, + ]: + raise ValueError(f"unsupported vector index type: {self.index_type}") + return index_type def _parse_kwargs(self): ob_params = {} + # handle lib + if self.is_index_type_hnsw_serial(): + ob_params['lib'] = 'vsag' + else: + ob_params['lib'] = 'OB' # handle metric_type + ob_params['distance'] = "l2" if 'metric_type' in self.kwargs: ob_params['distance'] = self.kwargs['metric_type'] - elif self.index_type == "hnsw": - ob_params['distance'] = 'l2' - else: - raise ValueError(f"unsupported vector index type: {self.index_type}") # handle param - if 'params' in self.kwargs: - for k, v in self.kwargs['params'].items(): - if k == 'M': - ob_params['m'] = v - elif k == 'efConstruction': - ob_params['ef_construction'] = v - elif k == 'efSearch': - ob_params['ef_search'] = v - else: - ob_params[k] = v - elif self.index_type == "hnsw": - ob_params['m'] = IndexParam.HNSW_DEFAULT_M - ob_params['ef_construction'] = IndexParam.HNSW_DEFAULT_EF_CONSTRUCTION - ob_params['ef_search'] = IndexParam.HNSW_DEFAULT_EF_SEARCH - else: - raise ValueError(f"unsupported vector index type: {self.index_type}") - # Append OceanBase parameters. - ob_params['lib'] = IndexParam.OCEANBASE_DEFAULT_ALGO_LIB + if self.is_index_type_ivf_serial(): + if (self.is_index_type_product_quantization() and + 'params' not in self.kwargs): + raise ValueError('params must be configured for IVF index type') + + if 'params' not in self.kwargs: + params = {} + else: + params = self.kwargs['params'] + + if self.is_index_type_product_quantization(): + if 'm' not in params: + raise ValueError('m must be configured for IVFSQ or IVFPQ') + ob_params['m'] = params['m'] + if 'nlist' in params: + ob_params['nlist'] = params['nlist'] + if 'samples_per_nlist' in params: + ob_params['samples_per_nlist'] = params['samples_per_nlist'] + + if self.is_index_type_hnsw_serial(): + if 'params' in self.kwargs: + params = self.kwargs['params'] + if 'M' in params: + ob_params['m'] = params['M'] + if 'efConstruction' in params: + ob_params['ef_construction'] = params['efConstruction'] + if 'efSearch' in params: + ob_params['ef_search'] = params['efSearch'] return ob_params def param_str(self): diff --git a/pyobvector/client/ob_vec_client.py b/pyobvector/client/ob_vec_client.py index 8847715..c525bd7 100644 --- a/pyobvector/client/ob_vec_client.py +++ b/pyobvector/client/ob_vec_client.py @@ -21,6 +21,7 @@ import sqlalchemy.sql.functions as func_mod import numpy as np from .index_param import IndexParams, IndexParam +from .fts_index_param import FtsIndexParam from ..schema import ( ObTable, VectorIndex, @@ -33,6 +34,7 @@ st_dwithin, st_astext, ReplaceStmt, + FtsIndex, ) from ..util import ObVersion from .partitions import * @@ -158,6 +160,7 @@ def create_table_with_index_params( columns: List[Column], indexes: Optional[List[Index]] = None, vidxs: Optional[IndexParams] = None, + fts_idxs: Optional[List[FtsIndexParam]] = None, partitions: Optional[ObPartition] = None, ): """Create table with optional index_params. @@ -202,6 +205,16 @@ def create_table_with_index_params( params=vidx.param_str(), ) vidx.create(self.engine, checkfirst=True) + # create fts indexes + if fts_idxs is not None: + for fts_idx in fts_idxs: + idx_cols = [table.c[field_name] for field_name in fts_idx.field_names] + fts_idx = FtsIndex( + fts_idx.index_name, + fts_idx.param_str(), + *idx_cols, + ) + fts_idx.create(self.engine, checkfirst=True) def create_index( self, @@ -254,6 +267,28 @@ def create_vidx_with_vec_index_param( ) vidx.create(self.engine, checkfirst=True) + def create_fts_idx_with_fts_index_param( + self, + table_name: str, + fts_idx_param: FtsIndexParam, + ): + """Create fts index with fts index parameter. + + Args: + table_name (string) : table name + fts_idx_param (FtsIndexParam) : fts index parameter + """ + table = Table(table_name, self.metadata_obj, autoload_with=self.engine) + with self.engine.connect() as conn: + with conn.begin(): + idx_cols = [table.c[field_name] for field_name in fts_idx_param.field_names] + fts_idx = FtsIndex( + fts_idx_param.index_name, + fts_idx_param.param_str(), + *idx_cols, + ) + fts_idx.create(self.engine, checkfirst=True) + def drop_table_if_exist(self, table_name: str): """Drop table if exists.""" try: @@ -556,6 +591,7 @@ def ann_search( extra_output_cols: Optional[List] = None, where_clause=None, partition_names: Optional[List[str]] = None, + idx_name_hint: Optional[List[str]] = None, **kwargs, ): # pylint: disable=unused-argument """perform ann search. @@ -569,6 +605,8 @@ def ann_search( topk (int) : top K output_column_names (Optional[List[str]]) : output fields where_clause : do ann search with filter + idx_name_hint : post-filtering enabled if vector index name is specified + Or pre-filtering enabled """ table = Table(table_name, self.metadata_obj, autoload_with=self.engine) @@ -587,6 +625,13 @@ def ann_search( "[" + ",".join([str(np.float32(v)) for v in vec_data]) + "]", ) ) + # if idx_name_hint is not None: + # stmt = select(*columns).with_hint( + # table, + # f"index(%(name)s {idx_name_hint})", + # "oracle" + # ) + # else: stmt = select(*columns) if where_clause is not None: @@ -607,6 +652,10 @@ def ann_search( ) with self.engine.connect() as conn: with conn.begin(): + if idx_name_hint is not None: + idx = stmt_str.find("SELECT ") + stmt_str = f"SELECT /*+ index({table_name} {idx_name_hint}) */ " + stmt_str[idx + len("SELECT "):] + if partition_names is None: return conn.execute(text(stmt_str)) stmt_str = self._insert_partition_hint_for_query_sql( diff --git a/pyobvector/client/ob_vec_json_table_client.py b/pyobvector/client/ob_vec_json_table_client.py index 17defb2..c0a744a 100644 --- a/pyobvector/client/ob_vec_json_table_client.py +++ b/pyobvector/client/ob_vec_json_table_client.py @@ -78,7 +78,7 @@ def _parse_col_type(cls, col_type: str): if col_type == 'DECIMAL': factory = JsonTableDecimalFactory(10, 0) else: - decimal_pattern = r'DECIMAL\((\d+),\s*(\d+)\)' + decimal_pattern = r'DECIMAL\s*\((\d+),\s*(\d+)\)' decimal_matches = re.findall(decimal_pattern, col_type) x, y = decimal_matches[0] factory = JsonTableDecimalFactory(int(x), int(y)) diff --git a/pyobvector/schema/__init__.py b/pyobvector/schema/__init__.py index c2e746b..6f7096c 100644 --- a/pyobvector/schema/__init__.py +++ b/pyobvector/schema/__init__.py @@ -14,6 +14,9 @@ * st_dwithin GIS function: check if the distance between two points * st_astext GIS function: return a Point in human-readable format * ReplaceStmt Replace into statement based on the extension of SQLAlchemy.Insert +* FtsIndex Full Text Search Index +* CreateFtsIndex Full Text Search Index Creation statement clause +* MatchAgainst Full Text Search clause """ from .vector import VECTOR from .geo_srid_point import POINT @@ -23,6 +26,8 @@ from .gis_func import ST_GeomFromText, st_distance, st_dwithin, st_astext from .replace_stmt import ReplaceStmt from .dialect import OceanBaseDialect, AsyncOceanBaseDialect +from .full_text_index import FtsIndex, CreateFtsIndex +from .match_against_func import MatchAgaint __all__ = [ "VECTOR", @@ -41,4 +46,7 @@ "ReplaceStmt", "OceanBaseDialect", "AsyncOceanBaseDialect", + "FtsIndex", + "CreateFtsIndex", + "MatchAgaint", ] diff --git a/pyobvector/schema/full_text_index.py b/pyobvector/schema/full_text_index.py new file mode 100644 index 0000000..950d25f --- /dev/null +++ b/pyobvector/schema/full_text_index.py @@ -0,0 +1,59 @@ +"""FullTextIndex: full text search index type""" +from sqlalchemy import Index +from sqlalchemy.schema import DDLElement +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.ddl import SchemaGenerator + + +class CreateFtsIndex(DDLElement): + """A new statement clause to create fts index. + + Attributes: + index : fts index schema + """ + def __init__(self, index): + self.index = index + + +class ObFtsSchemaGenerator(SchemaGenerator): + """A new schema generator to handle create fts index statement.""" + def visit_fts_index(self, index, create_ok=False): + """Handle create fts index statement compiling. + + Args: + index: fts index schema + create_ok: the schema is created or not + """ + if not create_ok and not self._can_create_index(index): + return + with self.with_ddl_events(index): + CreateFtsIndex(index)._invoke_with(self.connection) + +class FtsIndex(Index): + """Fts Index schema.""" + __visit_name__ = "fts_index" + + def __init__(self, name, fts_parser: str, *column_names, **kw): + self.fts_parser = fts_parser + super().__init__(name, *column_names, **kw) + + def create(self, bind, checkfirst: bool = False) -> None: + """Create fts index. + + Args: + bind: SQL engine or connection. + checkfirst: check the index exists or not. + """ + bind._run_ddl_visitor(ObFtsSchemaGenerator, self, checkfirst=checkfirst) + + +@compiles(CreateFtsIndex) +def compile_create_fts_index(element, compiler, **kw): # pylint: disable=unused-argument + """A decorator function to compile create fts index statement.""" + index = element.index + table_name = index.table.name + column_list = ", ".join([column.name for column in index.columns]) + fts_parser = index.fts_parser + if fts_parser is not None: + return f"CREATE FULLTEXT INDEX {index.name} ON {table_name} ({column_list}) WITH PARSER {fts_parser}" + return f"CREATE FULLTEXT INDEX {index.name} ON {table_name} ({column_list})" diff --git a/pyobvector/schema/match_against_func.py b/pyobvector/schema/match_against_func.py new file mode 100644 index 0000000..81d3f82 --- /dev/null +++ b/pyobvector/schema/match_against_func.py @@ -0,0 +1,33 @@ +"""match_against_func: An extend system function in FTS.""" + +import logging + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import FunctionElement +from sqlalchemy import BOOLEAN + +logger = logging.getLogger(__name__) + +class MatchAgaint(FunctionElement): + """MatchAgaint: match clause for full text search. + + Attributes: + type : result type + """ + inherit_cache = True + + def __init__(self, *args): + super().__init__() + self.args = args + +@compiles(MatchAgaint) +def complie_MatchAgaint(element, compiler, **kwargs): # pylint: disable=unused-argument + """Compile MatchAgaint function.""" + args = element.args + if len(args) < 2: + raise ValueError( + f"MatchAgaints should take a string expression and " \ + f"at least one column name string as parameters." + ) + cols = ", ".join(args[1:]) + return f"MATCH ({cols}) AGAINST ('{args[0]}' IN NATURAL LANGUAGE MODE)" diff --git a/pyobvector/schema/reflection.py b/pyobvector/schema/reflection.py index ef80103..f76d7b9 100644 --- a/pyobvector/schema/reflection.py +++ b/pyobvector/schema/reflection.py @@ -33,7 +33,7 @@ def _prep_regexes(self): self._re_key = _re_compile( r" " - r"(?:(SPATIAL|VECTOR|(?P\S+)) )?KEY" + r"(?:(FULLTEXT|SPATIAL|VECTOR|(?P\S+)) )?KEY" # r"(?:(?P\S+) )?KEY" r"(?: +{iq}(?P(?:{esc_fq}|[^{fq}])+){fq})?" r"(?: +USING +(?P\S+))?" @@ -69,10 +69,12 @@ def _parse_constraints(self, line): ret = super()._parse_constraints(line) if ret: tp, spec = ret + + if tp is None: + return ret if tp == "partition": # do not handle partition return ret - # logger.info(f"{tp} {spec}") if tp == "fk_constraint": if len(spec["table"]) == 2 and spec["table"][0] == self.default_schema: spec["table"] = spec["table"][1:] diff --git a/tests/test_fts_index.py b/tests/test_fts_index.py new file mode 100644 index 0000000..0df2576 --- /dev/null +++ b/tests/test_fts_index.py @@ -0,0 +1,99 @@ +import unittest +from pyobvector import * +from sqlalchemy import Column, Integer, text +from sqlalchemy.dialects.mysql import TEXT +import logging + +logger = logging.getLogger(__name__) + +class ObFtsIndexTest(unittest.TestCase): + def setUp(self) -> None: + self.client = ObVecClient() + + def test_fts_index(self): + test_collection_name = "fts_simple_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", Integer, primary_key=True, autoincrement=False), + Column("doc", TEXT), + ] + self.client.create_table( + test_collection_name, + columns=cols, + ) + fts_index_param = FtsIndexParam( + index_name="fts_idx", + field_names=["doc"], + parser_type=FtsParser.NGRAM, + ) + self.client.create_fts_idx_with_fts_index_param( + test_collection_name, + fts_idx_param=fts_index_param, + ) + + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", Integer, primary_key=True, autoincrement=False), + Column("doc", TEXT), + ] + fts_index_param = FtsIndexParam( + index_name="fts_idx", + field_names=["doc"], + parser_type=FtsParser.NGRAM, + ) + self.client.create_table_with_index_params( + table_name=test_collection_name, + columns=cols, + fts_idxs=[fts_index_param], + ) + + def test_fts_insert_and_search(self): + test_collection_name = "fts_data_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", Integer, primary_key=True), + Column("doc", TEXT), + ] + fts_index_param = FtsIndexParam( + index_name="fts_idx", + field_names=["doc"], + parser_type=FtsParser.NGRAM, + ) + self.client.create_table_with_index_params( + table_name=test_collection_name, + columns=cols, + fts_idxs=[fts_index_param], + ) + + datas = [ + { "id": 1, "doc": "pLease porridge in the pot", }, + { "id": 2, "doc": "please say sorry", }, + { "id": 3, "doc": "nine years old", }, + { "id": 4, "doc": "some like it hot, some like it cold", }, + { "id": 5, "doc": "i like coding", }, + { "id": 6, "doc": "i like my company", }, + ] + self.client.insert( + test_collection_name, + data = datas + ) + + res = self.client.get( + test_collection_name, + ids=None, + where_clause=[MatchAgaint('like', 'doc'), text("id > 4")], + output_column_name=["id", "doc"], + ) + self.assertEqual( + set(res.fetchall()), + set( + [ + (5, 'i like coding'), + (6, 'i like my company'), + ] + ) + ) + \ No newline at end of file diff --git a/tests/test_ob_vec_more_algorithm.py b/tests/test_ob_vec_more_algorithm.py new file mode 100644 index 0000000..6f3cd27 --- /dev/null +++ b/tests/test_ob_vec_more_algorithm.py @@ -0,0 +1,337 @@ +import unittest +from pyobvector import * +from sqlalchemy import Column, Integer, JSON, text +from sqlalchemy.dialects.mysql import TEXT +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class ObVecMoreAlgorithmTest(unittest.TestCase): + def setUp(self) -> None: + self.client = ObVecClient() + + def test_hnswsq(self): + test_collection_name = "hnswsq_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", Integer, primary_key=True, autoincrement=False), + Column("embedding", VECTOR(3)), + Column("meta", JSON), + ] + self.client.create_table( + test_collection_name, columns=cols + ) + + vidx_param = IndexParam( + index_name="vidx", + field_name="embedding", + index_type=VecIndexType.HNSW_SQ, + ) + self.client.create_vidx_with_vec_index_param( + test_collection_name, + vidx_param, + ) + + vector_value1 = [0.748479, 0.276979, 0.555195] + vector_value2 = [0, 0, 0] + data1 = [{"id": i, "embedding": vector_value1} for i in range(10)] + data1.extend([{"id": i, "embedding": vector_value2} for i in range(10, 13)]) + data1.extend([{"id": i, "embedding": vector_value2} for i in range(111, 113)]) + self.client.insert(test_collection_name, data=data1) + + res = self.client.ann_search( + test_collection_name, + vec_data=[0, 0, 0], + vec_column_name="embedding", + distance_func=l2_distance, + with_dist=True, + topk=5, + output_column_names=["id"], + ) + self.assertEqual(set(res.fetchall()), set([(112,0.0), (111,0.0), (10,0.0), (11,0.0), (12,0.0)])) + + def test_ivfflat(self): + test_collection_name = "ivf_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", Integer, primary_key=True, autoincrement=False), + Column("embedding", VECTOR(3)), + Column("meta", JSON), + ] + self.client.create_table( + test_collection_name, columns=cols + ) + + vector_value1 = [0.748479, 0.276979, 0.555195] + vector_value2 = [0, 0, 0] + data1 = [{"id": i, "embedding": vector_value1} for i in range(10)] + data1.extend([{"id": i, "embedding": vector_value2} for i in range(10, 13)]) + data1.extend([{"id": i, "embedding": vector_value2} for i in range(111, 113)]) + self.client.insert(test_collection_name, data=data1) + + vidx_param = IndexParam( + index_name="vidx", + field_name="embedding", + index_type=VecIndexType.IVFFLAT, + ) + self.client.create_vidx_with_vec_index_param( + test_collection_name, + vidx_param, + ) + + res = self.client.ann_search( + test_collection_name, + vec_data=[0, 0, 0], + vec_column_name="embedding", + distance_func=l2_distance, + with_dist=True, + topk=5, + output_column_names=["id"], + ) + self.assertEqual(set(res.fetchall()), set([(112,0.0), (111,0.0), (10,0.0), (11,0.0), (12,0.0)])) + + def test_ivfsq(self): + test_collection_name = "ivfsq_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", Integer, primary_key=True, autoincrement=False), + Column("embedding", VECTOR(3)), + Column("meta", JSON), + ] + self.client.create_table( + test_collection_name, columns=cols + ) + + vector_value1 = [0.748479, 0.276979, 0.555195] + vector_value2 = [0, 0, 0] + data1 = [{"id": i, "embedding": vector_value1} for i in range(10)] + data1.extend([{"id": i, "embedding": vector_value2} for i in range(10, 13)]) + data1.extend([{"id": i, "embedding": vector_value2} for i in range(111, 113)]) + self.client.insert(test_collection_name, data=data1) + + vidx_param = IndexParam( + index_name="vidx", + field_name="embedding", + index_type=VecIndexType.IVFSQ, + ) + self.client.create_vidx_with_vec_index_param( + test_collection_name, + vidx_param, + ) + + res = self.client.ann_search( + test_collection_name, + vec_data=[0, 0, 0], + vec_column_name="embedding", + distance_func=l2_distance, + with_dist=True, + topk=5, + output_column_names=["id"], + ) + self.assertEqual(set(res.fetchall()), set([(112,0.0), (111,0.0), (10,0.0), (11,0.0), (12,0.0)])) + + def test_ivfpq(self): + test_collection_name = "ivfpq_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", Integer, primary_key=True, autoincrement=False), + Column("embedding", VECTOR(4)), + Column("meta", JSON), + ] + self.client.create_table( + test_collection_name, columns=cols + ) + + vector_value1 = [0.748479, 0.276979, 0.555195, 0.13234] + vector_value2 = [0, 0, 0, 0] + data1 = [{"id": i, "embedding": vector_value1} for i in range(10)] + data1.extend([{"id": i, "embedding": vector_value2} for i in range(10, 13)]) + data1.extend([{"id": i, "embedding": vector_value2} for i in range(111, 113)]) + self.client.insert(test_collection_name, data=data1) + + vidx_param = IndexParam( + index_name="vidx", + field_name="embedding", + index_type=VecIndexType.IVFPQ, + params={ + "m": 2, + } + ) + self.client.create_vidx_with_vec_index_param( + test_collection_name, + vidx_param, + ) + + res = self.client.ann_search( + test_collection_name, + vec_data=[0, 0, 0, 0], + vec_column_name="embedding", + distance_func=l2_distance, + with_dist=True, + topk=5, + output_column_names=["id"], + ) + self.assertEqual(set(res.fetchall()), set([(112,0.0), (111,0.0), (10,0.0), (11,0.0), (12,0.0)])) + + def test_vec_fts_hybrid(self): + test_collection_name = "vec_fts_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", Integer, primary_key=True, autoincrement=False), + Column("embedding", VECTOR(3)), + Column("doc", TEXT), + ] + self.client.create_table( + test_collection_name, columns=cols + ) + + vidx_param = IndexParam( + index_name="vidx", + field_name="embedding", + index_type=VecIndexType.HNSW, + ) + self.client.create_vidx_with_vec_index_param( + test_collection_name, + vidx_param, + ) + fts_param = FtsIndexParam( + index_name="fts_idx", + field_names=["doc"], + parser_type=None, + ) + self.client.create_fts_idx_with_fts_index_param( + test_collection_name, + fts_param + ) + + datas = [ + { "id": 1, "embedding":[1,2,3], "doc": "pLease porridge in the pot", }, + { "id": 2, "embedding":[0,0,0], "doc": "please say sorry", }, + { "id": 3, "embedding":[1,1,1], "doc": "nine years old", }, + { "id": 4, "embedding":[0,1,0], "doc": "some like it hot, some like it cold", }, + { "id": 5, "embedding":[100,100,100], "doc": "i like coding", }, + { "id": 6, "embedding":[0,0,0], "doc": "i like my company", }, + ] + self.client.insert( + test_collection_name, + data=datas + ) + + # res = self.client.ann_search( + # test_collection_name, + # vec_data=[0, 0, 0], + # vec_column_name="embedding", + # distance_func=l2_distance, + # with_dist=True, + # topk=5, + # output_column_names=["id", "doc"], + # where_clause=[MatchAgaint('like', 'doc')] + # ) + # for r in res.fetchall(): + # logger.info(f"{r[0]} {r[1]}") + + def test_pre_post_filtering(self): + test_collection_name = "pre_post_filtering_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("c1", Integer, primary_key=True, autoincrement=False), + Column("c2", Integer), + Column("c3", Integer), + Column("v", VECTOR(3)), + Column("doc", TEXT), + ] + self.client.create_table( + test_collection_name, columns=cols + ) + vidx_param = IndexParam( + index_name="idx3", + field_name="v", + index_type=VecIndexType.HNSW, + ) + self.client.create_vidx_with_vec_index_param( + test_collection_name, + vidx_param, + ) + self.client.create_index( + test_collection_name, + is_vec_index=False, + index_name="idx1", + column_names=["c2"], + ) + self.client.create_index( + test_collection_name, + is_vec_index=False, + index_name="idx2", + column_names=["c3"], + ) + + datas = [ + { "c1": 1, "c2": 1, "c3": 10, "v": [0.203846,0.205289,0.880265] }, + { "c1": 2, "c2": 2, "c3": 9, "v": [0.226980,0.579658,0.933939] }, + { "c1": 3, "c2": 3, "c3": 8, "v": [0.181664,0.013905,0.628127] }, + { "c1": 4, "c2": 4, "c3": 7, "v": [0.442633,0.637534,0.633993] }, + { "c1": 5, "c2": 5, "c3": 6, "v": [0.190118,0.959676,0.796483] }, + { "c1": 6, "c2": 6, "c3": 5, "v": [0.710370,0.007130,0.710913] }, + { "c1": 7, "c2": 7, "c3": 4, "v": [0.238120,0.289662,0.970101] }, + { "c1": 8, "c2": 8, "c3": 3, "v": [0.168794,0.567442,0.062338] }, + { "c1": 9, "c2": 9, "c3": 2, "v": [0.901419,0.676738,0.122339] }, + { "c1": 10, "c2": 10, "c3": 1, "v": [0.563644,0.811224,0.175574] }, + ] + self.client.insert( + test_collection_name, + data=datas + ) + + res = self.client.ann_search( + test_collection_name, + vec_data=[0.712338,0.603321,0.133444], + vec_column_name="v", + distance_func=l2_distance, + topk=5, + output_column_names=["c1", "c2", "c3"], + where_clause=[text("c2 > 5 and c3 < 6")], + idx_name_hint="idx3" + ) + self.assertEqual( + set(res.fetchall()), + set( + [ + (9, 9, 2), + (10, 10, 1), + (8, 8, 3), + (6, 6, 5), + (7, 7, 4), + ] + ) + ) + + res = self.client.ann_search( + test_collection_name, + vec_data=[0.712338,0.603321,0.133444], + vec_column_name="v", + distance_func=l2_distance, + topk=5, + output_column_names=["c1", "c2", "c3"], + where_clause=[text("c2 > 5 and c3 < 6")], + idx_name_hint="idx1" + ) + self.assertEqual( + set(res.fetchall()), + set( + [ + (9, 9, 2), + (10, 10, 1), + (8, 8, 3), + (6, 6, 5), + (7, 7, 4), + ] + ) + ) diff --git a/tests/test_reflection.py b/tests/test_reflection.py index ddc1ebc..66fd9ff 100644 --- a/tests/test_reflection.py +++ b/tests/test_reflection.py @@ -1,6 +1,5 @@ import unittest from pyobvector import * -from pyobvector import VECTOR import logging logger = logging.getLogger(__name__) @@ -15,7 +14,8 @@ def test_reflection(self): `embeddings` VECTOR(1024) DEFAULT NULL, `metadata` json DEFAULT NULL, PRIMARY KEY (`id`), - VECTOR KEY `vidx` (`embeddings`) WITH (DISTANCE=L2,M=16,EF_CONSTRUCTION=256,LIB=VSAG,TYPE=HNSW, EF_SEARCH=64) BLOCK_SIZE 16384 + VECTOR KEY `vidx` (`embeddings`) WITH (DISTANCE=L2,M=16,EF_CONSTRUCTION=256,LIB=VSAG,TYPE=HNSW, EF_SEARCH=64) BLOCK_SIZE 16384, + FULLTEXT KEY `idx_content_fts` (`content`) WITH PARSER ik PARSER_PROPERTIES=(ik_mode="smart") BLOCK_SIZE 16384 ) DEFAULT CHARSET = utf8mb4 ROW_FORMAT = DYNAMIC COMPRESSION = 'zstd_1.3.8' REPLICA_NUM = 1 BLOCK_SIZE = 16384 USE_BLOOM_FILTER = FALSE TABLET_SIZE = 134217728 PCTFREE = 0 """ dialect._tabledef_parser.parse(ddl, "utf8") From e3e5bb13f8bb67b1e2208764d5f082abd513f994 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Thu, 13 Mar 2025 11:39:29 +0800 Subject: [PATCH 2/5] support limit in ObVecClient::get Signed-off-by: shanhaikang.shk --- pyobvector/client/ob_vec_client.py | 4 ++++ tests/test_fts_index.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pyobvector/client/ob_vec_client.py b/pyobvector/client/ob_vec_client.py index c525bd7..15852e1 100644 --- a/pyobvector/client/ob_vec_client.py +++ b/pyobvector/client/ob_vec_client.py @@ -516,6 +516,7 @@ def get( where_clause = None, output_column_name: Optional[List[str]] = None, partition_names: Optional[List[str]] = None, + n_limits: Optional[int] = None, ): """get records with specified primary field `ids`. @@ -549,6 +550,9 @@ def get( stmt = stmt.where(*where_clause) elif where_in_clause is not None and where_clause is not None: stmt = stmt.where(and_(where_in_clause, *where_clause)) + + if n_limits is not None: + stmt = stmt.limit(n_limits) with self.engine.connect() as conn: with conn.begin(): diff --git a/tests/test_fts_index.py b/tests/test_fts_index.py index 0df2576..d3607a5 100644 --- a/tests/test_fts_index.py +++ b/tests/test_fts_index.py @@ -86,13 +86,13 @@ def test_fts_insert_and_search(self): ids=None, where_clause=[MatchAgaint('like', 'doc'), text("id > 4")], output_column_name=["id", "doc"], + n_limits=1, ) self.assertEqual( set(res.fetchall()), set( [ (5, 'i like coding'), - (6, 'i like my company'), ] ) ) From e54498ff474209080223187474e930d811148822 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Mon, 17 Mar 2025 10:36:43 +0800 Subject: [PATCH 3/5] add more test case Signed-off-by: shanhaikang.shk --- tests/test_json_table.py | 20 ++++++++++++++++++++ tests/test_ob_vec_client.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/tests/test_json_table.py b/tests/test_json_table.py index 21c5e9a..a2808fc 100644 --- a/tests/test_json_table.py +++ b/tests/test_json_table.py @@ -364,3 +364,23 @@ def test_col_name_conflict(self): (2, 'bob'), ] ) + + def test_timestamp_datatype(self): + self.root_client._reset() + self.client.refresh_metadata() + self.client.perform_json_table_sql( + "create table `t1` (c1 int DEFAULT NULL, c2 TIMESTAMP);" + ) + + self.client.perform_json_table_sql( + "insert into t1 values (1, CURRENT_DATE - INTERVAL '1' MONTH);" + ) + + self.client.perform_json_table_sql( + "select * from t1" + ) + + def test_online_cases(self): + self.root_client._reset() + self.client.refresh_metadata() + diff --git a/tests/test_ob_vec_client.py b/tests/test_ob_vec_client.py index fbd81bb..843ebe8 100644 --- a/tests/test_ob_vec_client.py +++ b/tests/test_ob_vec_client.py @@ -119,6 +119,36 @@ def test_delete_get(self): def test_set_variable(self): self.client.set_ob_hnsw_ef_search(100) self.assertEqual(self.client.get_ob_hnsw_ef_search(), 100) + + def test_create_index_dup(self): + test_collection_name = "ob_create_index_dup_test" + self.client.drop_table_if_exist(test_collection_name) + + cols = [ + Column("id", String(64), primary_key=True, autoincrement=False), + Column("embedding", VECTOR(3)), + Column("meta", JSON), + ] + self.client.create_table( + test_collection_name, columns=cols + ) + + # create vector index + self.client.create_index( + test_collection_name, + is_vec_index=True, + index_name="vidx", + column_names=["embedding"], + vidx_params="distance=l2, type=hnsw, lib=vsag", + ) + + self.client.create_index( + test_collection_name, + is_vec_index=True, + index_name="vidx", + column_names=["embedding"], + vidx_params="distance=l2, type=hnsw, lib=vsag", + ) if __name__ == "__main__": unittest.main() From 54e8ac635c9eb04f32b9e59578495c5f66196736 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Sun, 23 Mar 2025 11:06:03 +0800 Subject: [PATCH 4/5] update ob docker version Signed-off-by: shanhaikang.shk --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 46cbbc8..45500f8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,7 +31,7 @@ jobs: - name: Run Docker container run: | - docker run --name ob433 -e MODE=slim -p 2881:2881 -d quay.io/oceanbase/oceanbase-ce:4.3.3.0-100000142024101215 + docker run --name ob435 -e MODE=slim -p 2881:2881 -d oceanbase/oceanbase-ce:4.3.5.1-101000042025031818 - name: Wait for container to be ready run: | From cae4445b807804c6d9c72fd5c148c2b5193ca2ea Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Sun, 23 Mar 2025 11:12:59 +0800 Subject: [PATCH 5/5] fix docker name Signed-off-by: shanhaikang.shk --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 45500f8..e7c83ed 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,7 +37,7 @@ jobs: run: | timeout=300 while [ $timeout -gt 0 ]; do - if docker logs ob433 | grep -q 'boot success!'; then + if docker logs ob435 | grep -q 'boot success!'; then echo "Container is ready." break fi @@ -56,7 +56,7 @@ jobs: OCEANBASE_USER: 'root@test' OCEANBASE_PASS: '' run: | - docker exec ob433 obclient -h $OCEANBASE_HOST -P $OCEANBASE_PORT -u $OCEANBASE_USER -p$OCEANBASE_PASS -e "ALTER SYSTEM ob_vector_memory_limit_percentage = 30; create user 'jtuser'@'%'; GRANT SELECT, INSERT, UPDATE, DELETE ON test.* TO 'jtuser'@'%'; FLUSH PRIVILEGES;" + docker exec ob435 obclient -h $OCEANBASE_HOST -P $OCEANBASE_PORT -u $OCEANBASE_USER -p$OCEANBASE_PASS -e "ALTER SYSTEM ob_vector_memory_limit_percentage = 30; create user 'jtuser'@'%'; GRANT SELECT, INSERT, UPDATE, DELETE ON test.* TO 'jtuser'@'%'; FLUSH PRIVILEGES;" - name: Run tests run: |