diff --git a/app/config/settings.py b/app/config/settings.py index 856f5ff..b5417fe 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -33,6 +33,13 @@ class ElasticsearchSettings(BaseModel): """Elasticsearch 相关配置""" url: str + number_of_shards: int + number_of_replicas: int + index_max_result_window: int + index_refresh_interval: str + index_option_type: str + index_option_m: int + index_option_ef_construction: int metadata_index_suffix: str chunk_index_suffix: str request_timeout: int = 15 @@ -44,7 +51,6 @@ class EmbedderSettings(BaseModel): model_name: str dimensions: int similarity_metric: str - index_type: str class RerankerSettings(BaseModel): @@ -82,9 +88,16 @@ class RetrievalSettings(BaseModel): multiplier: int = Field(5, description="召回倍数配置") vector_weight: float = Field(2.0, description="向量搜索权重") + vector_similarity: float = Field(0.7, description="相似度") text_weight: float = Field(1.0, description="文本搜索权重") +class SearchSettings(BaseModel): + """搜索相关配置""" + + max_top_k: int = Field(50, description="最大top_k值限制") + + class TencentOssSettings(BaseModel): """ 腾讯云对象存储相关配置。 @@ -114,6 +127,7 @@ class Settings(BaseSettings): storage: StorageSettings upload: UploadSettings retrieval: RetrievalSettings + search: SearchSettings @property def cos_config(self) -> CosConfig: diff --git a/app/service/elasticsearch.py b/app/service/elasticsearch.py index 8bbd3e9..5fc4e2f 100644 --- a/app/service/elasticsearch.py +++ b/app/service/elasticsearch.py @@ -97,11 +97,11 @@ def _ensure_metadata_index_exists(self, metadata_index: str) -> None: if not self._client.indices.exists(index=metadata_index): body = { "settings": { - "number_of_shards": 1, - "number_of_replicas": 0, + "number_of_shards": self._settings.elasticsearch.number_of_shards, + "number_of_replicas": self._settings.elasticsearch.number_of_replicas, "index": { - "max_result_window": 10000, - "refresh_interval": "1s", + "max_result_window": self._settings.elasticsearch.index_max_result_window, + "refresh_interval": self._settings.elasticsearch.index_refresh_interval, }, }, "mappings": { @@ -138,11 +138,11 @@ def _ensure_chunk_index_exists(self, chunk_index: str) -> None: if not self._client.indices.exists(index=chunk_index): body = { "settings": { - "number_of_shards": 1, - "number_of_replicas": 0, + "number_of_shards": self._settings.elasticsearch.number_of_shards, + "number_of_replicas": self._settings.elasticsearch.number_of_replicas, "index": { - "max_result_window": 10000, - "refresh_interval": "1s", + "max_result_window": self._settings.elasticsearch.index_max_result_window, + "refresh_interval": self._settings.elasticsearch.index_refresh_interval, }, }, "mappings": { @@ -159,9 +159,9 @@ def _ensure_chunk_index_exists(self, chunk_index: str) -> None: "similarity": self._embedder.similarity_metric, "index": True, "index_options": { - "type": self._settings.embedder.index_type, - "m": 32, - "ef_construction": 100, + "type": self._settings.elasticsearch.index_option_type, + "m": self._settings.elasticsearch.index_option_m, + "ef_construction": self._settings.elasticsearch.index_option_ef_construction, }, }, "chunk_index": {"type": "integer"}, @@ -207,7 +207,9 @@ def store_for_vector_hybrid_search(self, document: Document) -> str: metadata_index, chunk_index = self._ensure_indexes_exist( document.index_prefix ) - + logger.info( + f"向量混合搜索: 元数据索引名={metadata_index} 分片索引名={chunk_index}" + ) metadata_id = self._create_metadata(metadata_index, document) document.id = metadata_id # 确保 document 对象持有 ID logger.info(f"元数据占位符创建成功,ID: {metadata_id}") @@ -359,10 +361,11 @@ def search(self, parameters: SearchParameters) -> SearchResult: ) # 执行ES搜索 + logger.info(f"在 {parameters.index_name} 上执行查询: {search_body}") response = self._client.search( index=parameters.index_name, body=search_body ) - + logger.info(f"查询结果: {response}") # 计算搜索耗时 search_time_ms = int((time.time() - start_time) * 1000) @@ -416,57 +419,57 @@ def _build_hybrid_search_body( ES查询体 """ # 获取文本查询进行向量化 - text_query: str | None = None - for condition in search_conditions["vector"]: - if isinstance(condition.value, str): - text_query = condition.value - if not text_query: - raise ValueError("向量混合搜索需要文本查询内容") + text_query = cast("str", search_conditions["vector"][0].value) # 生成查询向量 query_vector = self._embedder.embed_documents([text_query])[0] # 计算召回数量(用于后续重排序) - retrieval_size = parameters.limit * self._settings.retrieval.multiplier + k = parameters.limit * self._settings.retrieval.multiplier + vector_similarity = self._settings.retrieval.vector_similarity # 获取权重配置 vector_weight = self._settings.retrieval.vector_weight text_weight = self._settings.retrieval.text_weight + # # 确保 num_candidates 至少为 k 的 2 倍或 100,取较大值 + num_candidates = max(k * 2, 100) + # 构建混合搜索查询体 search_body: dict[str, Any] = { - "size": retrieval_size, + "size": parameters.limit, "_source": ["content", "file_metadata_id"], # 只返回需要的字段 "knn": { "field": "content_vector", # 固定向量字段 "query_vector": query_vector, - "k": retrieval_size, - "num_candidates": 100, + "k": k, + "num_candidates": num_candidates, "boost": vector_weight, + "similarity": vector_similarity, }, "query": { "bool": { - "should": [ - # 普通匹配 + "must": [ { "match": { "content": { "query": text_query, - "boost": text_weight * 0.5, + "boost": text_weight * 0.7, # 基础匹配权重 } } - }, - # 短语匹配 + } + ], + "should": [ { "match_phrase": { "content": { "query": text_query, - "boost": text_weight * 0.3, + "boost": text_weight * 0.3, # 短语匹配加分 } } - }, + } ], - "minimum_should_match": 0, + "minimum_should_match": 0, # should是纯加分项 } }, } @@ -554,7 +557,7 @@ def _convert_to_search_result( # 根据搜索类型处理结果 if is_hybrid_search: documents = self._process_hybrid_search_results( - cast("str", search_conditions["vector"][0].value), hits, limit + cast("str", search_conditions["vector"][0].value), hits ) else: documents = self._process_structured_search_results(hits) @@ -569,7 +572,6 @@ def _process_hybrid_search_results( self, text_query: str, hits: list[dict[str, Any]], - limit: int, ) -> list[DocumentResult]: """ 处理混合搜索结果:去重 + 重排序 @@ -603,7 +605,7 @@ def _process_hybrid_search_results( unique_chunks.append(chunk) # 重排 - return self._reranker.rerank(text_query, unique_chunks)[:limit] + return self._reranker.rerank(text_query, unique_chunks) @staticmethod def _process_structured_search_results( diff --git a/app/utils/converters/search.py b/app/utils/converters/search.py index 94ffb77..3f82719 100644 --- a/app/utils/converters/search.py +++ b/app/utils/converters/search.py @@ -58,9 +58,11 @@ def request_vo_to_domain(request: SearchRequest) -> SearchParameters: conditions = [ SearchCondition( field_name=cond.field, - mode=SearchMode.TERM - if cond.op == ConditionOperator.TERM - else SearchMode.MATCH, + mode=( + SearchMode.TERM + if cond.op == ConditionOperator.TERM + else SearchMode.MATCH + ), value=cond.value, ) for cond in request.query.conditions @@ -83,7 +85,7 @@ def result_domain_to_vo( if search_type == SearchType.VECTOR_HYBRID: results = [ VectorHybridSearchResult( - text=doc.content.get("text", ""), + text=doc.content.get("content", ""), file_metadata_id=doc.content.get("file_metadata_id", ""), score=doc.score, ) @@ -100,4 +102,4 @@ def result_domain_to_vo( if doc.id ] - return SearchResponse(type=search_type, results=results) + return SearchResponse(results=results) diff --git a/app/web/document.py b/app/web/document.py index 391e629..dd8f559 100644 --- a/app/web/document.py +++ b/app/web/document.py @@ -19,6 +19,7 @@ from pathlib import Path from urllib.parse import urlparse +from elasticsearch import NotFoundError from fastapi import ( APIRouter, BackgroundTasks, @@ -426,6 +427,10 @@ async def search(self, request: SearchRequest) -> SearchResponse: f"✅ 搜索完成, 返回{len(domain_response.documents)}条结果" ) return resp + except NotFoundError as e: + raise HTTPException( + status_code=404, detail=f"索引 {request.query.index} 不存在" + ) from e except Exception as e: logger.error(f"❌ 搜索失败: {e}", exc_info=True) raise HTTPException(status_code=500, detail="搜索处理失败") from e diff --git a/app/web/vo.py b/app/web/vo.py index 8071bff..0754e78 100644 --- a/app/web/vo.py +++ b/app/web/vo.py @@ -20,6 +20,8 @@ from pydantic import BaseModel, Field, HttpUrl, Json, field_validator from pydantic_core.core_schema import ValidationInfo +from app.config.settings import settings + class FileUploadResponse(BaseModel): """文件上传后的标准响应模型""" @@ -75,6 +77,16 @@ class Condition(BaseModel): ..., description="字段值,支持多种类型" ) + @field_validator("value") + @classmethod + def validate_value_not_empty_string( + cls, v: str | int | float | bool + ) -> str | int | float | bool: + """验证字符串值不能为空""" + if isinstance(v, str) and v.strip() == "": + raise ValueError("字符串类型的查询值不能为空") + return v + class Query(BaseModel): """查询对象""" @@ -93,13 +105,16 @@ class SearchRequest(BaseModel): type: SearchType = Field(..., description="搜索类型") query: Query = Field(..., description="查询条件") - top_k: int = Field(..., ge=1, description="返回结果数量,至少为1") + top_k: int = Field( + ..., + ge=1, + le=settings.search.max_top_k, + description="返回结果数量 1 <= top_k <= 配置文件中的max_top_k", + ) @field_validator("query") @classmethod - def validate_query_for_search_type( - cls, v: Query, info: ValidationInfo - ) -> Query: + def validate_query(cls, v: Query, info: ValidationInfo) -> Query: """根据搜索类型验证查询条件""" search_type = info.data.get("type") @@ -134,7 +149,6 @@ class StructuredSearchResult(BaseModel): class SearchResponse(BaseModel): """搜索响应""" - type: SearchType = Field(..., description="搜索类型") # 保持一致性 results: list[VectorHybridSearchResult | StructuredSearchResult] = Field( default_factory=list, description="搜索结果" ) diff --git a/config.yaml b/config.yaml index d6ddca4..12d1e57 100644 --- a/config.yaml +++ b/config.yaml @@ -1,5 +1,12 @@ elasticsearch: url: "http://localhost:9200" + number_of_shards: 1 + number_of_replicas: 0 + index_max_result_window: 10000 + index_refresh_interval: 1s + index_option_type: "int8_hnsw" + index_option_m: 32 # 控制HNSW图中每个节点可以连接的最大邻居节点数量 + index_option_ef_construction: 100 # 索引构建时每个节点考虑的候选邻居数量,影响索引质量。 metadata_index_suffix: "_metadatas" chunk_index_suffix: "_chunks" request_timeout: 60 @@ -26,7 +33,12 @@ upload: - ".pdf" - ".md" - ".txt" + retrieval: multiplier: 5 # 召回倍数配置 vector_weight: 2.0 # 向量搜索权重 - text_weight: 1.0 # 文本搜索权重 \ No newline at end of file + vector_similarity: 0.1 # 向量搜索相似度阈值 + text_weight: 1.0 # 文本搜索权重 + +search: + max_top_k: 50 # 最大top_k值限制 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 35f3c8b..9430a46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,9 +17,9 @@ """ import logging -from collections.abc import Generator +from collections.abc import Callable, Generator from pathlib import Path -from typing import Any +from typing import Any, overload import pytest from elasticsearch import Elasticsearch @@ -48,6 +48,41 @@ def cos_client() -> CosS3Client: return client +@pytest.fixture(scope="class") +def user_upload_dir() -> Path: + """提供 '用户准备上传' 的文件目录路径。""" + path = Path(__file__).parent / "fixtures" / "files" / "user" + path.mkdir(exist_ok=True, parents=True) + return path + + +@pytest.fixture(scope="class") +def get_user_upload_file( + user_upload_dir: Path, +) -> Callable[[str | list[str]], Path | list[Path]]: + """从 'user' 目录轻松获取文件路径。""" + + @overload + def _builder(file_names: str) -> Path: ... + + @overload # noqa: F811 + def _builder(file_names: list[str]) -> list[Path]: ... + + def _builder(file_names: str | list[str]) -> Path | list[Path]: # noqa: F811 + if isinstance(file_names, str): + path = user_upload_dir / file_names + if not path.exists(): + raise FileNotFoundError(f"源文件不存在: {path}") + return path + + paths = [user_upload_dir / name for name in file_names] + if not all(p.exists() for p in paths): + raise FileNotFoundError("一个或多个源文件不存在。") + return paths + + return _builder # type: ignore[return-value] + + @pytest.fixture(scope="module") def client() -> Generator[TestClient, Any, None]: """V2 API 测试客户端""" diff --git a/tests/fixtures/config.yaml b/tests/fixtures/config.yaml index f9df47c..14db2ab 100644 --- a/tests/fixtures/config.yaml +++ b/tests/fixtures/config.yaml @@ -1,5 +1,12 @@ elasticsearch: url: "http://localhost:9200" + number_of_shards: 1 + number_of_replicas: 0 + index_max_result_window: 10000 + index_refresh_interval: 1s + index_option_type: "int8_hnsw" + index_option_m: 32 # 控制HNSW图中每个节点可以连接的最大邻居节点数量 + index_option_ef_construction: 100 # 索引构建时每个节点考虑的候选邻居数量,影响索引质量。 metadata_index_suffix: "_metadatas" chunk_index_suffix: "_chunks" request_timeout: 15 @@ -8,7 +15,7 @@ embedder: model_name: "shibing624/text2vec-base-chinese" dimensions: 768 similarity_metric: "cosine" - index_type: "int8_hnsw" + reranker: model_name: "BAAI/bge-reranker-base" @@ -30,4 +37,8 @@ upload: retrieval: multiplier: 5 vector_weight: 2.0 # 向量搜索权重 + vector_similarity: 0.1 # 向量搜索相似度阈值 text_weight: 1.0 # 文本搜索权重 + +search: + max_top_k: 50 # 最大top_k值限制 \ No newline at end of file diff --git a/tests/web/document/structured_search_test.py b/tests/web/document/structured_search_test.py index 925237c..342e10e 100644 --- a/tests/web/document/structured_search_test.py +++ b/tests/web/document/structured_search_test.py @@ -78,7 +78,7 @@ def _prepare_test_data( "data": { "role": "后端", "level": "高级", - "content": "MySQL 性能优化是后端开发的重要技能,包括索引优化、查询优化等", + "content": "MySQL设计及性能优化是后端开发的重要技能,包括索引优化、查询优化等", "department": "技术部", "status": "active", "tags": "database", @@ -114,7 +114,7 @@ def _prepare_test_data( "data": { "role": "后端", "level": "高级", - "content": "分布式系统设计与微服务架构实践指南", + "content": "分布式系统设计微服务架构实践指南", "department": "技术部", "status": "active", "tags": "architecture", @@ -174,7 +174,6 @@ def test_single_exact_match(self, client: TestClient) -> None: data = response.json() # 验证响应结构和数据 - assert data["type"] == "structured" assert len(data["results"]) == 3 # 3个后端开发者 # 验证具体结果 @@ -194,7 +193,7 @@ def test_single_exact_match(self, client: TestClient) -> None: assert result["document"]["role"] == "后端" def test_multiple_exact_match(self, client: TestClient) -> None: - """测试多个精确匹配 - 角色+级别(AND关系)""" + """测试多个精确匹配""" response = client.post( "/api/v1/search", json={ @@ -213,7 +212,6 @@ def test_multiple_exact_match(self, client: TestClient) -> None: assert response.status_code == 200 data = response.json() - assert data["type"] == "structured" assert len(data["results"]) == 2 # 2个高级后端开发者 # 验证具体结果 @@ -247,9 +245,7 @@ def test_single_full_text_match(self, client: TestClient) -> None: assert response.status_code == 200 data = response.json() - - assert data["type"] == "structured" - assert len(data["results"]) >= 1 + assert len(data["results"]) == 2 # 验证结果包含相关内容 found_mysql_doc = False @@ -260,7 +256,7 @@ def test_single_full_text_match(self, client: TestClient) -> None: assert found_mysql_doc, "应该找到包含MySQL的文档" def test_multiple_full_text_match(self, client: TestClient) -> None: - """测试多个全文搜索 - 系统+架构(AND关系)""" + """测试多个全文搜索""" response = client.post( "/api/v1/search", json={ @@ -269,7 +265,7 @@ def test_multiple_full_text_match(self, client: TestClient) -> None: "index": self.TEST_INDEX, "conditions": [ {"field": "content", "op": "match", "value": "系统"}, - {"field": "content", "op": "match", "value": "架构"}, + {"field": "content", "op": "match", "value": "模式"}, ], }, "top_k": 3, @@ -279,16 +275,14 @@ def test_multiple_full_text_match(self, client: TestClient) -> None: assert response.status_code == 200 data = response.json() - assert data["type"] == "structured" assert len(data["results"]) == 1 - assert data["results"][0]["id"] == "backend_senior_2" # 验证AND关系:必须同时包含两个关键词 for result in data["results"]: content = result["document"]["content"] - assert "系统" in content and "架构" in content, ( - f"文档内容应同时包含'系统'和'架构': {content}" - ) + assert ("系" in content or "统" in content) and ( + "模" in content or "式" in content + ), f"文档内容应同时包含'系''统''模''式': {content}" def test_mixed_exact_and_full_text_match(self, client: TestClient) -> None: """测试混合搜索 - 精确匹配+全文搜索""" @@ -301,7 +295,7 @@ def test_mixed_exact_and_full_text_match(self, client: TestClient) -> None: "conditions": [ {"field": "role", "op": "term", "value": "后端"}, {"field": "status", "op": "term", "value": "active"}, - {"field": "content", "op": "match", "value": "优化"}, + {"field": "content", "op": "match", "value": "设计"}, ], }, "top_k": 3, @@ -310,16 +304,13 @@ def test_mixed_exact_and_full_text_match(self, client: TestClient) -> None: assert response.status_code == 200 data = response.json() - - assert data["type"] == "structured" - assert len(data["results"]) == 1 - assert data["results"][0]["id"] == "backend_senior_1" + assert len(data["results"]) == 2 for result in data["results"]: doc = result["document"] assert doc["role"] == "后端" assert doc["status"] == "active" - assert "优化" in doc["content"] + assert "设计" in doc["content"] def test_complex_mixed_conditions(self, client: TestClient) -> None: """测试复杂混合条件 - 部门+级别+内容""" @@ -354,8 +345,7 @@ def test_complex_mixed_conditions(self, client: TestClient) -> None: assert response.status_code == 200 data = response.json() - assert data["type"] == "structured" - assert len(data["results"]) >= 1 + assert len(data["results"]) == 3 for result in data["results"]: doc = result["document"] @@ -376,8 +366,8 @@ def test_with_range_filters(self, client: TestClient) -> None: ], "filters": { "range": { - "salary": {"gte": 20000} # 薪资>=20000的过滤条件 - } + "salary": {"gte": 20000} + } # 薪资>=20000的过滤条件 }, }, "top_k": 5, @@ -387,8 +377,7 @@ def test_with_range_filters(self, client: TestClient) -> None: assert response.status_code == 200 data = response.json() - assert data["type"] == "structured" - assert len(data["results"]) >= 1 + assert len(data["results"]) == 2 # 验证具体地期望结果 expected_ids = {"backend_senior_1", "backend_senior_2"} @@ -421,7 +410,6 @@ def test_top_k_limit(self, client: TestClient) -> None: assert response.status_code == 200 data = response.json() - assert data["type"] == "structured" assert len(data["results"]) <= 2 # 最多返回2个结果 assert ( len(data["results"]) == 2 @@ -450,14 +438,13 @@ def test_no_results_found(self, client: TestClient) -> None: assert response.status_code == 200 data = response.json() - assert data["type"] == "structured" assert len(data["results"]) == 0 # ===== 参数验证测试 ===== # 注:这些测试复用结构化搜索的测试环境 # 如果向量混合搜索需要不同的校验规则,在vector_hybrid_search_test.py中单独添加 - def test_invalid_search_type(self, client: TestClient) -> None: + def test_invalid_type(self, client: TestClient) -> None: """测试无效搜索类型""" response = client.post( "/api/v1/search", @@ -493,7 +480,7 @@ def test_invalid_operator(self, client: TestClient) -> None: assert response.status_code == 422 - def test_missing_required_fields(self, client: TestClient) -> None: + def test_missing_field_value(self, client: TestClient) -> None: """测试缺少必需字段""" response = client.post( "/api/v1/search", @@ -501,7 +488,7 @@ def test_missing_required_fields(self, client: TestClient) -> None: "query": { "index": self.TEST_INDEX, "conditions": [ - {"field": "role", "op": "term", "value": "后端"} + {"field": "role", "op": "term", "value": ""} ], }, "top_k": 3, diff --git a/tests/web/document/upload_endpoint_test.py b/tests/web/document/upload_endpoint_test.py index 1815899..07c79cd 100644 --- a/tests/web/document/upload_endpoint_test.py +++ b/tests/web/document/upload_endpoint_test.py @@ -15,7 +15,6 @@ import time from collections.abc import Callable from pathlib import Path -from typing import overload import pytest from elasticsearch import Elasticsearch @@ -34,41 +33,6 @@ class TestUploadEndpoint: FILE_UPLOAD_INDEX_PREFIX = "test_file_upload" URL_UPLOAD_INDEX_PREFIX = "test_url_upload" - @pytest.fixture(scope="class") - def user_upload_dir(self) -> Path: - """提供 '用户准备上传' 的文件目录路径。""" - path = ( - Path(__file__).parent.parent.parent / "fixtures" / "files" / "user" - ) - path.mkdir(exist_ok=True, parents=True) - return path - - @pytest.fixture(scope="class") - def get_user_upload_file( - self, user_upload_dir: Path - ) -> Callable[[str | list[str]], Path | list[Path]]: - """从 'user' 目录轻松获取文件路径。""" - - @overload - def _builder(file_names: str) -> Path: ... - - @overload # noqa: F811 - def _builder(file_names: list[str]) -> list[Path]: ... - - def _builder(file_names: str | list[str]) -> Path | list[Path]: # noqa: F811 - if isinstance(file_names, str): - path = user_upload_dir / file_names - if not path.exists(): - raise FileNotFoundError(f"源文件不存在: {path}") - return path - - paths = [user_upload_dir / name for name in file_names] - if not all(p.exists() for p in paths): - raise FileNotFoundError("一个或多个源文件不存在。") - return paths - - return _builder # type: ignore[return-value] - @staticmethod def _get_metadata_index_name(index_prefix: str) -> str: """获取metadata索引名""" diff --git a/tests/web/document/vector_hybrid_search_test.py b/tests/web/document/vector_hybrid_search_test.py new file mode 100644 index 0000000..cec9c1e --- /dev/null +++ b/tests/web/document/vector_hybrid_search_test.py @@ -0,0 +1,414 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from collections.abc import Generator +from pprint import pprint +from typing import Any + +import pytest +from elasticsearch import Elasticsearch +from fastapi.testclient import TestClient +from httpx import Response + +from app.config.settings import settings + + +class TestVectorHybridSearch: + """向量混合搜索测试 + + 包含: + 1. 向量混合搜索功能测试 (基于upload接口上传的数据) + 2. 混合搜索特有参数验证测试 + 3. 错误处理和边界条件测试 + + 使用场景: + - 通过upload接口上传的文档,使用search接口查询时,要用type=vector_hybrid + """ + + INDEX_PREFIX = "test_vector_hybrid_url" + + @pytest.fixture(scope="class", autouse=True) + def setup_environment( + self, + client: TestClient, + es_client: Elasticsearch, + ) -> Generator[None, Any, None]: + """准备测试环境(索引+数据)""" + + # 1. 清理已存在的索引 + self._cleanup_indexes(es_client, self.INDEX_PREFIX) + + # 2. 准备测试数据 + self._prepare_test_data(client, es_client, self.INDEX_PREFIX) + + # 3. 执行所有测试 + yield + + # 4. 清理测试索引 + self._cleanup_indexes(es_client, self.INDEX_PREFIX) + + def _cleanup_indexes( + self, es_client: Elasticsearch, index_prefix: str + ) -> None: + """清理测试索引""" + metadata_index = self._get_metadata_index_name(index_prefix) + chunk_index = self._get_chunk_index_name(index_prefix) + + try: + if es_client.indices.exists(index=metadata_index): + es_client.indices.delete(index=metadata_index) + if es_client.indices.exists(index=chunk_index): + es_client.indices.delete(index=chunk_index) + except Exception as e: + print(f"⚠️ 清理索引时出错: {e}") + + @staticmethod + def _get_metadata_index_name(index_prefix: str) -> str: + """获取metadata索引名""" + return index_prefix + settings.elasticsearch.metadata_index_suffix + + @staticmethod + def _get_chunk_index_name(index_prefix: str) -> str: + """获取chunk索引名""" + return index_prefix + settings.elasticsearch.chunk_index_suffix + + def _prepare_test_data( + self, + client: TestClient, + es_client: Elasticsearch, + index_prefix: str, + ) -> None: + """准备向量混合搜索测试数据""" + + print("\n🚀 开始准备向量混合搜索测试环境") + print(f"🔗 URL上传索引前缀: {self.INDEX_PREFIX}") + + # 通过URL上传准备数据 + self._upload_test_url(client) + + # 等待数据处理完成 + self._wait_for_test_data_ready(index_prefix, es_client) + + print("✅ 向量搜索测试数据准备完成") + + def _upload_test_url(self, client: TestClient) -> None: + """通过URL上传准备测试数据""" + bucket_name = settings.tencent_oss.bucket + cos_url = f"https://{bucket_name}.cos.{settings.tencent_oss.region}.myqcloud.com/kbase-temp/02_test.pdf" + + response = client.post( + "/api/v1/documents/upload-from-url", + json={ + "url": cos_url, + "index_prefix": self.INDEX_PREFIX, + }, + ) + + assert response.status_code == 200, f"URL上传失败: {response.json()}" + task_id = response.json()["task_id"] + print(f"🔗 URL上传任务创建成功: {task_id}") + self._wait_for_task_completion(client, task_id) + + @staticmethod + def _wait_for_task_completion( + client: TestClient, task_id: str, max_wait: int = 30 + ) -> None: + """等待后台任务完成""" + for _ in range(max_wait): + response = client.get(f"/api/v1/tasks/{task_id}") + if response.status_code == 200: + status = response.json()["status"] + if status == "completed": + return + elif status.startswith("failed"): + pytest.fail(f"任务处理失败: {status}") + time.sleep(1) + pytest.fail(f"任务处理超时: {task_id}") + + def _wait_for_test_data_ready( + self, index_prefix: str, es_client: Elasticsearch + ) -> None: + """等待测试数据就绪""" + print("⏳ 等待上传任务完成和数据索引...") + + max_wait_time = 30 + start_time = time.time() + + while time.time() - start_time < max_wait_time: + # 检查URL上传索引 + url_metadata_count = self._get_index_doc_count( + es_client, self._get_metadata_index_name(index_prefix) + ) + url_chunk_count = self._get_index_doc_count( + es_client, self._get_chunk_index_name(index_prefix) + ) + + print( + f"📊 当前数据统计: URL上传(metadata: {url_metadata_count}, chunks: {url_chunk_count}) " + ) + + # 检查是否都有数据了 + if url_metadata_count > 0 and url_chunk_count > 0: + print("✅ 所有上传任务完成") + return + + time.sleep(2) + + # 超时但尽量继续测试 + print("⚠️ 等待上传超时,但继续进行测试") + + @staticmethod + def _get_index_doc_count(es_client: Elasticsearch, index_name: str) -> int: + """获取索引中的文档数量""" + if not es_client.indices.exists(index=index_name): + return 0 + + # 刷新索引确保数据可见 + es_client.indices.refresh(index=index_name) + + try: + response = es_client.count(index=index_name) + return int(response["count"]) + except Exception as e: + print(f"获取索引中文档总数失败:{e}") + return 0 + + # ===== 向量混合搜索功能测试 ===== + + def test_hybrid_search(self, client: TestClient) -> None: + """测试基础向量混合搜索 - 基于URL上传的数据""" + value = "统计型数据" + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": self._get_chunk_index_name(self.INDEX_PREFIX), + "conditions": [ + { + "field": "content", + "op": "match", + "value": value, + } + ], + }, + "top_k": 3, + }, + ) + pprint(f"查询字段:{value}") + self._assert_response(response) + + @staticmethod + def _assert_response(response: Response) -> None: + assert response.status_code == 200 + data = response.json() + + # 验证响应结构 + assert isinstance(data["results"], list) + + # 验证结果格式 + for result in data["results"]: + assert "text" in result # VectorHybridSearchResult格式 + assert "file_metadata_id" in result + assert "score" in result + assert isinstance(result["score"], int | float) + # 不应该包含StructuredSearchResult的字段 + assert "id" not in result + assert "document" not in result + pprint(result) + + def test_semantic_similarity(self, client: TestClient) -> None: + """测试语义相似性搜索""" + value = "用户中心并发过高" + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": self._get_chunk_index_name(self.INDEX_PREFIX), + "conditions": [ + { + "field": "content", + "op": "match", + "value": value, + } + ], + }, + "top_k": 2, + }, + ) + pprint(f"查询字段:{value}") + self._assert_response(response) + + # ===== 参数验证测试 ===== + + def test_invalid_multiple_conditions(self, client: TestClient) -> None: + """测试向量混合搜索不允许多个条件""" + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": self._get_chunk_index_name(self.INDEX_PREFIX), + "conditions": [ + {"field": "content", "op": "match", "value": "Python"}, + { + "field": "content", + "op": "match", + "value": "机器学习", + }, + ], + }, + "top_k": 3, + }, + ) + + assert response.status_code == 422 # Validation error + + def test_invalid_term_condition(self, client: TestClient) -> None: + """测试向量混合搜索不允许term操作符""" + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": self._get_chunk_index_name(self.INDEX_PREFIX), + "conditions": [ + {"field": "content", "op": "term", "value": "Python"} + ], + }, + "top_k": 3, + }, + ) + + assert response.status_code == 422 # Validation error + + def test_empty_condition_value(self, client: TestClient) -> None: + """测试空查询处理""" + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": self._get_chunk_index_name(self.INDEX_PREFIX), + "conditions": [ + {"field": "content", "op": "match", "value": ""} + ], + }, + "top_k": 5, + }, + ) + + # 空查询应该返回 422 验证错误 + assert response.status_code == 422 + + def test_nonexistent_index(self, client: TestClient) -> None: + """测试不存在的索引""" + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": "不存在的索引_chunks", + "conditions": [ + {"field": "content", "op": "match", "value": "测试"} + ], + }, + "top_k": 3, + }, + ) + + assert response.status_code == 404 + + def test_search_with_filters( + self, es_client: Elasticsearch, client: TestClient + ) -> None: + """测试带过滤条件的向量混合搜索""" + value = "缓存" + chunk_index_number = self._get_index_doc_count( + es_client, self._get_chunk_index_name(self.INDEX_PREFIX) + ) + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": self._get_chunk_index_name(self.INDEX_PREFIX), + "conditions": [ + {"field": "content", "op": "match", "value": value} + ], + "filters": { + "range": { + "chunk_index": {"gte": chunk_index_number - 1} + } + }, # 包含所有chunk + }, + "top_k": 5, + }, + ) + + pprint(f"查询字段:{value}") + assert response.status_code == 200 + data = response.json() + assert isinstance(data["results"], list) + assert len(data["results"]) == 1 + pprint(f"{data['results']}") + + def test_top_k_limit(self, client: TestClient) -> None: + """测试top_k参数限制""" + for top_k in [1, 2, 3]: + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": self._get_chunk_index_name(self.INDEX_PREFIX), + "conditions": [ + {"field": "content", "op": "match", "value": "中心"} + ], + }, + "top_k": top_k, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert isinstance(data["results"], list) + assert len(data["results"]) <= top_k + print( + f"📊 Top-K={top_k} 限制测试通过: 返回 {len(data['results'])} 个结果" + ) + + def test_score_ordering(self, client: TestClient) -> None: + """测试结果按分数排序""" + response = client.post( + "/api/v1/search", + json={ + "type": "vector_hybrid", + "query": { + "index": self._get_chunk_index_name(self.INDEX_PREFIX), + "conditions": [ + {"field": "content", "op": "match", "value": "缓存"} + ], + }, + "top_k": 3, + }, + ) + + assert response.status_code == 200 + data = response.json() + + # 验证分数降序排列 + scores = [result["score"] for result in data["results"]] + assert scores == sorted(scores, reverse=True), "结果应该按分数降序排列" + print(f"📊 分数排序验证通过: {scores}")