diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 51227d7..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..a1cf0c0 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,18 @@ +[run] +source = app +omit = + */__init__.py + */tests/* + */__pycache__/* + app/main.py + app/config/settings.py + +[report] +exclude_lines = + pragma: no cover + def __repr__ + if self.debug: + raise AssertionError + raise NotImplementedError + if __name__ == .__main__.: + if TYPE_CHECKING: \ No newline at end of file diff --git a/.gitignore b/.gitignore index 9108296..c703d8d 100644 --- a/.gitignore +++ b/.gitignore @@ -205,4 +205,12 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ +# macOS .DS_Store +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2322890..0d36bde 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,20 +33,19 @@ repos: types: [python] # 4. 重要检查(仅在push时) - # - repo: local - # hooks: - # - id: pytest - # name: pytest - # language: system - # entry: uv - # args: ["run", "pytest", "-q"] - # pass_filenames: false - # stages: [pre-push] - - # - id: pip-audit - # name: pip-audit - # language: system - # entry: uv - # args: ["run", "pip-audit", "--strict"] - # pass_filenames: false - # stages: [pre-push] \ No newline at end of file +# - repo: local +# hooks: +# - id: pytest +# name: pytest +# language: system +# entry: uv +# args: ["run", "pytest", "-q"] +# pass_filenames: false +# stages: [pre-push] +# - id: pip-audit +# name: pip-audit +# language: system +# entry: uv +# args: ["run", "pip-audit", "--strict"] +# pass_filenames: false +# stages: [pre-push] \ No newline at end of file diff --git a/Makefile b/Makefile index abb2919..d79419a 100644 --- a/Makefile +++ b/Makefile @@ -16,6 +16,7 @@ help: @echo " lint ✨ 检查代码并自动修复问题" @echo " type 🔍 类型检查" @echo " test 🧪 运行测试并生成覆盖率报告" + @echo " cov 🧪 运行测试并打开覆盖率报告" @echo " run ▶️ 启动开发服务器" @echo " pre-commit 🔄 运行预提交检查" @echo " audit 🛡️ 扫描依赖中的安全漏洞" @@ -65,7 +66,12 @@ check: fmt lint type .PHONY: test test: - @uv run pytest -q --cov=app --cov-report=term-missing --cov-report=xml + @uv run pytest -v --cov=app --cov-report=term-missing --cov-report=xml + +.PHONY: cov +cov: test + @echo "🌐 打开覆盖率报告..."
+ @open htmlcov/index.html || xdg-open htmlcov/index.html || echo "请打开: htmlcov/index.html" .PHONY: audit audit: diff --git a/README.md b/README.md index 0d2d797..c4d09c4 100644 --- a/README.md +++ b/README.md @@ -80,17 +80,18 @@ make run ## 📋 开发命令 -| 命令 | 描述 | -|------|------| -| `make setup` | 🚀 一键设置完整开发环境 | -| `make check` | ✅ 运行所有代码质量检查 | -| `make test` | 🧪 运行测试并生成覆盖率报告 | -| `make run` | ▶️ 启动开发服务器 | -| `make fmt` | 🎨 格式化代码 | -| `make lint` | ✨ 检查代码并自动修复 | -| `make type` | 🔍 类型检查 | -| `make audit` | 🛡️ 扫描安全漏洞 | -| `make clean` | 🧹 清理临时文件 | +| 命令 | 描述 | +|--------------|-----------------| +| `make setup` | 🚀 一键设置完整开发环境 | +| `make check` | ✅ 运行所有代码质量检查 | +| `make test` | 🧪 运行测试并生成覆盖率报告 | +| `make cov` | 🧪 运行测试并打开覆盖率报告 | +| `make run` | ▶️ 启动开发服务器 | +| `make fmt` | 🎨 格式化代码 | +| `make lint` | ✨ 检查代码并自动修复 | +| `make type` | 🔍 类型检查 | +| `make audit` | 🛡️ 扫描安全漏洞 | +| `make clean` | 🧹 清理临时文件 | ## 🔧 API 接口 @@ -103,14 +104,15 @@ make run ### 主要端点 -| 端点 | 方法 | 描述 | -|------|------|------| -| `/` | GET | API 根路径和信息 | -| `/api/v1/health` | GET | 健康检查 | -| `/api/v1/documents/upload-file` | POST | 本地文件上传 | -| `/api/v1/documents/upload-from-url` | POST | 从COS URL上传 | -| `/api/v1/search` | POST | 文档搜索 | -| `/api/v1/tasks/{task_id}` | GET | 查询任务状态 | +| 端点 | 方法 | 描述 | +|-------------------------------------|------|----------------| +| `/` | GET | API 根路径和信息 | +| `/api/v1/health` | GET | 健康检查 | +| `/api/v1/documents/upload-file` | POST | 本地文件上传 | +| `/api/v1/documents/upload-from-url` | POST | 从COS URL上传 | +| `/api/v1/documents/save` | POST | 以JSON格式字符串上传文档 | +| `/api/v1/search` | POST | 文档搜索 | +| `/api/v1/tasks/{task_id}` | GET | 查询任务状态 | ### 健康检查 ```bash diff --git a/app/config/settings.py b/app/config/settings.py index fcbfd30..856f5ff 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -33,8 +33,8 @@ class ElasticsearchSettings(BaseModel): """Elasticsearch 相关配置""" url: str - metadata_index: str = "file_metadatas" - chunk_index: str = "file_chunks" + metadata_index_suffix: str + chunk_index_suffix: str request_timeout: int = 15 diff --git a/app/domain/document.py b/app/domain/document.py index d43a1ff..520afb0 100644 --- a/app/domain/document.py +++ b/app/domain/document.py @@ -1,9 +1,22 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass, field from typing import Any @dataclass class Document: + index_prefix: str path: str size: int category: str | None = None diff --git a/app/domain/search.py b/app/domain/search.py index 3390d9d..bf57047 100644 --- a/app/domain/search.py +++ b/app/domain/search.py @@ -12,30 +12,54 @@ # See the License for the specific language governing permissions and # limitations under the License. 
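Note on the index-naming change above: the fixed `metadata_index`/`chunk_index` settings are replaced by configurable suffixes, and every `Document` now carries an `index_prefix`, so the service derives full index names per request. A minimal sketch of the scheme, with the suffix values taken from `config.yaml`; the `kb` prefix is a hypothetical caller-supplied value:

```python
# Sketch of the prefix + suffix naming used by ElasticsearchService.
# "kb" is a hypothetical index_prefix supplied by the caller per document.
metadata_index_suffix = "_metadatas"  # from config.yaml
chunk_index_suffix = "_chunks"        # from config.yaml

index_prefix = "kb"
metadata_index = index_prefix + metadata_index_suffix  # "kb_metadatas"
chunk_index = index_prefix + chunk_index_suffix        # "kb_chunks"
```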
-from dataclasses import dataclass, field +from dataclasses import dataclass +from enum import Enum from typing import Any -@dataclass -class SearchRequest: - """封装搜索请求,新增 mode 和 filters。""" +class SearchMode(str, Enum): + """基础查询模式""" + + VECTOR = "vector" # 向量搜索 + TERM = "term" # 精确匹配 + MATCH = "match" # 模糊匹配 + + +@dataclass(frozen=True) +class SearchCondition: + """搜索条件 - 值对象""" - query: str - top_k: int = 5 - filters: dict[str, Any] | None = field(default_factory=dict) + field_name: str + mode: SearchMode + value: str | int | float | bool + + +@dataclass(frozen=True) +class SearchParameters: + """搜索参数 - 值对象""" + + index_name: str + conditions: list[SearchCondition] + limit: int = 10 + filters: dict[str, Any] | None = None @dataclass -class ContextChunk: - """定义一个上下文块,用于最终返回结果。""" +class DocumentResult: + """文档结果 - 值对象""" - text: str - file_metadata_id: str + content: dict[str, Any] score: float + id: str | None = None @dataclass -class SearchResponse: - """定义最终的搜索响应格式。""" +class SearchResult: + """搜索结果 - 聚合根""" + + documents: list[DocumentResult] + total_count: int + search_time_ms: int - context: list[ContextChunk] + def is_empty(self) -> bool: + return len(self.documents) == 0 diff --git a/app/main.py b/app/main.py index 3aa3c94..f2ae188 100644 --- a/app/main.py +++ b/app/main.py @@ -30,7 +30,7 @@ from app.utils.loaders.dispatcher import DispatcherLoader from app.utils.rerankers.bge import BgeReranker from app.utils.splitters import RecursiveCharacterTextSplitter -from app.web.handler import DocumentHandler +from app.web.document import DocumentHandler # 配置标准日志 logging.basicConfig( @@ -62,8 +62,6 @@ splitter=splitter, embedder=embedder, reranker=reranker, - metadata_index=settings.elasticsearch.metadata_index, - chunk_index=settings.elasticsearch.chunk_index, settings=settings, ) logger.info("✅ 核心服务组件初始化成功。") diff --git a/app/service/elasticsearch.py b/app/service/elasticsearch.py index daeb75d..8bbd3e9 100644 --- a/app/service/elasticsearch.py +++ b/app/service/elasticsearch.py @@ -14,16 +14,24 @@ import logging import os +import time from datetime import UTC, datetime -from typing import Any, Protocol +from typing import Any, Protocol, cast +from elastic_transport import ObjectApiResponse from elasticsearch import Elasticsearch from elasticsearch.helpers import bulk from langchain_core.documents import Document as LangChainDocument from app.config.settings import Settings from app.domain.document import Document -from app.domain.search import ContextChunk, SearchRequest, SearchResponse +from app.domain.search import ( + DocumentResult, + SearchCondition, + SearchMode, + SearchParameters, + SearchResult, +) logger = logging.getLogger(__name__) @@ -57,8 +65,8 @@ class Reranker(Protocol): """重排器接口,负责对初步检索结果进行精排。""" def rerank( - self, query: str, results: list[ContextChunk] - ) -> list[ContextChunk]: ... + self, query: str, results: list[DocumentResult] + ) -> list[DocumentResult]: ... 
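The rewritten domain model above separates what to match (a list of `SearchCondition`s, each tagged with a `SearchMode`) from where and how much (`index_name`, `limit`). A hedged construction example using only the dataclasses defined above; the index and field values are illustrative:

```python
from app.domain.search import SearchCondition, SearchMode, SearchParameters

# "kb_chunks" and the field names below are examples, not fixed values.
params = SearchParameters(
    index_name="kb_chunks",
    conditions=[
        # Full-text relevance on the chunk content
        SearchCondition(
            field_name="content", mode=SearchMode.MATCH, value="数据库并发"
        ),
        # Same text routed to the vector branch of the hybrid search
        SearchCondition(
            field_name="content_vector", mode=SearchMode.VECTOR, value="数据库并发"
        ),
    ],
    limit=5,
)
assert params.filters is None  # filters stay optional
```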
class ElasticsearchService: @@ -69,24 +77,24 @@ def __init__( splitter: Splitter, embedder: Embedder, reranker: Reranker, - metadata_index: str, - chunk_index: str, settings: Settings, ) -> None: - self.client = client - self.loader = loader - self.splitter = splitter - self.embedder = embedder - self.reranker = reranker - self.metadata_index = metadata_index - self.chunk_index = chunk_index - self.settings = settings - self._ensure_metadata_index_exists() - self._ensure_chunk_index_exists() - - def _ensure_metadata_index_exists(self) -> None: + self._client = client + self._loader = loader + self._splitter = splitter + self._embedder = embedder + self._reranker = reranker + self._settings = settings + + def _metadata_index_name(self, index_prefix: str) -> str: + return index_prefix + self._settings.elasticsearch.metadata_index_suffix + + def _chunk_index_name(self, index_prefix: str) -> str: + return index_prefix + self._settings.elasticsearch.chunk_index_suffix + + def _ensure_metadata_index_exists(self, metadata_index: str) -> None: """确保索引 metadata_index 存在""" - if not self.client.indices.exists(index=self.metadata_index): + if not self._client.indices.exists(index=metadata_index): body = { "settings": { "number_of_shards": 1, @@ -121,13 +129,13 @@ def _ensure_metadata_index_exists(self) -> None: } try: - self.client.indices.create(index=self.metadata_index, body=body) + self._client.indices.create(index=metadata_index, body=body) except Exception as e: raise e - def _ensure_chunk_index_exists(self) -> None: + def _ensure_chunk_index_exists(self, chunk_index: str) -> None: """确保索引chunk_index存在""" - if not self.client.indices.exists(index=self.chunk_index): + if not self._client.indices.exists(index=chunk_index): body = { "settings": { "number_of_shards": 1, @@ -147,11 +155,11 @@ def _ensure_chunk_index_exists(self) -> None: }, "content_vector": { "type": "dense_vector", - "dims": self.embedder.dimensions, - "similarity": self.embedder.similarity_metric, + "dims": self._embedder.dimensions, + "similarity": self._embedder.similarity_metric, "index": True, "index_options": { - "type": self.settings.embedder.index_type, + "type": self._settings.embedder.index_type, "m": 32, "ef_construction": 100, }, @@ -167,30 +175,52 @@ def _ensure_chunk_index_exists(self) -> None: }, } try: - self.client.indices.create(index=self.chunk_index, body=body) + self._client.indices.create(index=chunk_index, body=body) except Exception as e: raise e - def store(self, document: Document) -> str: + def _ensure_indexes_exist(self, index_prefix: str) -> tuple[str, str]: + """确保索引存在并返回索引名 + + Args: + index_prefix: 索引前缀 + + Returns: + tuple: (metadata_index, chunk_index) + """ + metadata_index = self._metadata_index_name(index_prefix) + chunk_index = self._chunk_index_name(index_prefix) + + self._ensure_metadata_index_exists(metadata_index) + self._ensure_chunk_index_exists(chunk_index) + + return metadata_index, chunk_index + + def store_for_vector_hybrid_search(self, document: Document) -> str: """ 将文档存入双索引系统。 1. 存储文件元数据到 file_metadatas。 2. 
切分文件为chunks,并将 chunks相关信息 及对应的源文件的 file_metadata_id 存入 file_chunks。 :return: 在 file_metadatas 中生成的文档 ID。 """ - metadata_id = self._create_metadata(document) + + metadata_index, chunk_index = self._ensure_indexes_exist( + document.index_prefix + ) + + metadata_id = self._create_metadata(metadata_index, document) document.id = metadata_id # 确保 document 对象持有 ID logger.info(f"元数据占位符创建成功,ID: {metadata_id}") try: # 尝试创建和存储 chunks。失败会抛出异常。 - created_chunks_count = self._create_chunks(document) + created_chunks_count = self._create_chunks(chunk_index, document) logger.info(f"成功存储 {created_chunks_count} 个文档块。") # Chunks 存储成功后,才更新元数据中的 total_chunks now_millis = int(datetime.now(UTC).timestamp() * 1000) - self.client.update( - index=self.metadata_index, + self._client.update( + index=metadata_index, id=metadata_id, body={ "doc": { @@ -203,20 +233,20 @@ def store(self, document: Document) -> str: logger.info( f"元数据更新成功,total_chunks 已写入: {created_chunks_count}。" ) - self.client.indices.refresh(index=self.chunk_index) + self._client.indices.refresh(index=chunk_index) return metadata_id except Exception as e: # 如果上述 try 块中任何一步失败,执行回滚操作 logger.error(f"文档处理失败,错误: {e}。正在回滚元数据...") - self.client.delete( - index=self.metadata_index, id=metadata_id, refresh=True + self._client.delete( + index=metadata_index, id=metadata_id, refresh=True ) logger.info(f"元数据 {metadata_id} 已被成功删除。") # 重新抛出异常,让上层调用者知道操作失败 raise RuntimeError("文档存储失败,已回滚。") from e - def _create_metadata(self, document: Document) -> str: + def _create_metadata(self, metadata_index: str, document: Document) -> str: """ 根据文档创建并存储元数据索引数据 :return: 在 file_metadatas 中生成的文档 ID。 @@ -232,21 +262,21 @@ def _create_metadata(self, document: Document) -> str: "created_at": now_millis, "updated_at": now_millis, } - meta_response = self.client.index( - index=self.metadata_index, document=doc, refresh="wait_for" + meta_response = self._client.index( + index=metadata_index, document=doc, refresh="wait_for" ) return str(meta_response["_id"]) - def _create_chunks(self, document: Document) -> int: + def _create_chunks(self, chunk_index: str, document: Document) -> int: """增强错误处理,支持部分回滚""" if not document.id: raise ValueError("文档ID未设置") - chunks = self.splitter.split_documents(self.loader.load(document)) + chunks = self._splitter.split_documents(self._loader.load(document)) if not chunks: raise RuntimeError("未提取出任何文本块") - content_vectors = self.embedder.embed_documents( + content_vectors = self._embedder.embed_documents( [chunk.page_content for chunk in chunks] ) @@ -258,7 +288,7 @@ def _create_chunks(self, document: Document) -> int: zip(chunks, content_vectors, strict=True) ): doc = { - "_index": self.chunk_index, + "_index": chunk_index, "_id": f"{document.id}_{i}", "file_metadata_id": document.id, "content": chunk.page_content, @@ -273,7 +303,7 @@ def _create_chunks(self, document: Document) -> int: chunk_ids.append(str(doc["_id"])) success, failed = bulk( - client=self.client, + client=self._client, actions=chunk_docs, stats_only=False, raise_on_error=False, @@ -281,7 +311,7 @@ def _create_chunks(self, document: Document) -> int: if failed: # 清理已成功写入的chunks - self._cleanup_chunks(chunk_ids[:success]) + self._cleanup_chunks(chunk_index, chunk_ids[:success]) raise RuntimeError(f"批量写入失败: {failed}") return success @@ -289,62 +319,126 @@ def _create_chunks(self, document: Document) -> int: except Exception: # 确保清理所有可能已写入的chunks if chunk_ids: - self._cleanup_chunks(chunk_ids) + self._cleanup_chunks(chunk_index, chunk_ids) raise - def _cleanup_chunks(self, chunk_ids: 
list[str]) -> None: + def _cleanup_chunks(self, chunk_index: str, chunk_ids: list[str]) -> None: """清理指定的chunks""" for chunk_id in chunk_ids: try: - self.client.delete(index=self.chunk_index, id=chunk_id) + self._client.delete(index=chunk_index, id=chunk_id) except Exception as e: logger.error(f"删除文档分失败,错误: {e}。") pass - # 在 ElasticsearchService 中添加 - def delete_document(self, metadata_id: str) -> bool: - """删除文档及其所有chunks""" - try: - # 先删除所有相关chunks - self.client.delete_by_query( - index=self.chunk_index, - body={"query": {"term": {"file_metadata_id": metadata_id}}}, - refresh=True, + def search(self, parameters: SearchParameters) -> SearchResult: + """ + 执行搜索 - 支持多种搜索模式 + + Args: + parameters: 搜索参数,包含索引、条件、限制等 + + Returns: + SearchResult: 统一的搜索结果 + """ + start_time = time.time() + + # 按搜索模式分类条件 + search_conditions = self._classify_conditions(parameters.conditions) + + # 根据条件类型构建查询 + if search_conditions["vector"] and search_conditions["match"]: + # 向量+全文混合搜索(兼容旧版本) + search_body = self._build_hybrid_search_body( + parameters, search_conditions ) - # 删除元数据 - self.client.delete( - index=self.metadata_index, id=metadata_id, refresh=True + else: + # 纯结构化搜索(新版本) + search_body = self._build_structured_search_body( + parameters, search_conditions ) - return True - except Exception as e: - logger.error(f"删除文档失败: {e}") - return False - def search(self, request: SearchRequest) -> SearchResponse: - """在 file_chunks 索引中执行搜索。""" - if not request.query: - # 如果查询为空,可以考虑返回最近的文件等,这里暂时返回空 - return SearchResponse(context=[]) + # 执行ES搜索 + response = self._client.search( + index=parameters.index_name, body=search_body + ) - standard_query, filter_clause = self._build_filtered_queries( - request.query, request.filters + # 计算搜索耗时 + search_time_ms = int((time.time() - start_time) * 1000) + + # 转换为Domain对象并返回 + return self._convert_to_search_result( + response, search_time_ms, parameters.limit, search_conditions ) - # 定义召回阶段要获取的文档数量,应大于最终的 top_k - # 这是一个超参数,可以根据需求调整 - retrieval_size = request.top_k * self.settings.retrieval.multiplier - query_vector = self.embedder.embed_documents([request.query])[0] + @staticmethod + def _classify_conditions( + conditions: list[SearchCondition], + ) -> dict[str, list[SearchCondition]]: + """ + 按搜索模式分类条件 + + Args: + conditions: 搜索条件列表 + + Returns: + 分类后的条件字典 + """ + classified: dict[str, list[SearchCondition]] = { + "vector": [], + "match": [], + "term": [], + } + + for condition in conditions: + if condition.mode == SearchMode.VECTOR: + classified["vector"].append(condition) + elif condition.mode == SearchMode.MATCH: + classified["match"].append(condition) + elif condition.mode == SearchMode.TERM: + classified["term"].append(condition) + + return classified + + def _build_hybrid_search_body( + self, + parameters: SearchParameters, + search_conditions: dict[str, list[SearchCondition]], + ) -> dict[str, Any]: + """ + 构建向量+全文混合搜索查询体(兼容旧版本) + + Args: + parameters: 搜索参数 + search_conditions: 分类后的搜索条件 - # 使用实用的混合搜索语法:knn + query 组合(兼容 ES 9.x) - vector_weight = self.settings.retrieval.vector_weight - text_weight = self.settings.retrieval.text_weight + Returns: + ES查询体 + """ + # 获取文本查询进行向量化 + text_query: str | None = None + for condition in search_conditions["vector"]: + if isinstance(condition.value, str): + text_query = condition.value + if not text_query: + raise ValueError("向量混合搜索需要文本查询内容") + + # 生成查询向量 + query_vector = self._embedder.embed_documents([text_query])[0] + + # 计算召回数量(用于后续重排序) + retrieval_size = parameters.limit * self._settings.retrieval.multiplier - # 构建搜索体 + # 获取权重配置 
+ vector_weight = self._settings.retrieval.vector_weight + text_weight = self._settings.retrieval.text_weight + + # 构建混合搜索查询体 search_body: dict[str, Any] = { "size": retrieval_size, - "_source": ["content", "file_metadata_id"], + "_source": ["content", "file_metadata_id"], # 只返回需要的字段 "knn": { - "field": "content_vector", + "field": "content_vector", # 固定向量字段 "query_vector": query_vector, "k": retrieval_size, "num_candidates": 100, @@ -353,18 +447,20 @@ def search(self, request: SearchRequest) -> SearchResponse: "query": { "bool": { "should": [ + # 普通匹配 { "match": { "content": { - "query": request.query, + "query": text_query, "boost": text_weight * 0.5, } } }, + # 短语匹配 { "match_phrase": { "content": { - "query": request.query, + "query": text_query, "boost": text_weight * 0.3, } } @@ -375,78 +471,193 @@ def search(self, request: SearchRequest) -> SearchResponse: }, } - # 如果有过滤条件,添加到搜索体中 - if filter_clause: - # 为 knn 添加过滤器 - search_body["knn"]["filter"] = filter_clause - # 为 query 添加过滤器 - search_body["query"]["bool"]["filter"] = filter_clause - - # 执行混合搜索 - response = self.client.search(index=self.chunk_index, body=search_body) - - # 格式化召回结果 - retrieved_chunks = [ - ContextChunk( - text=hit["_source"]["content"], - file_metadata_id=hit["_source"]["file_metadata_id"], + # 添加过滤条件 + if parameters.filters: + # 为knn查询添加过滤器 + search_body["knn"]["filter"] = parameters.filters + # 为全文查询添加过滤器 + search_body["query"]["bool"]["filter"] = parameters.filters + + return search_body + + @staticmethod + def _build_structured_search_body( + parameters: SearchParameters, + search_conditions: dict[str, list[SearchCondition]], + ) -> dict[str, Any]: + """ + 构建结构化搜索查询体(新版本) + + Args: + parameters: 搜索参数 + search_conditions: 分类后的搜索条件 + + Returns: + ES查询体 + """ + bool_query: dict[str, Any] = {"bool": {"must": []}} + + # 添加MATCH查询条件 + for condition in search_conditions["match"]: + bool_query["bool"]["must"].append( + {"match": {condition.field_name: {"query": condition.value}}} + ) + + # 添加TERM查询条件 + for condition in search_conditions["term"]: + bool_query["bool"]["must"].append( + {"term": {condition.field_name: condition.value}} + ) + + search_body: dict[str, Any] = { + "size": parameters.limit, + "query": bool_query, + } + + # 添加过滤条件 + if parameters.filters: + bool_query["bool"]["filter"] = parameters.filters + + return search_body + + def _convert_to_search_result( + self, + response: ObjectApiResponse[Any], + search_time_ms: int, + limit: int, + search_conditions: dict[str, list[SearchCondition]], + ) -> SearchResult: + """ + 将ES响应转换为Domain搜索结果 + + Args: + response: ES查询响应 + search_time_ms: 搜索耗时(毫秒) + limit: 限制返回结果的个数 + search_conditions: 分类后的搜索条件 + + Returns: + SearchResult: Domain层搜索结果 + """ + hits = response["hits"]["hits"] + + # 获取总数 + total_count: int = 0 + if isinstance(response["hits"]["total"], dict): + total_count = response["hits"]["total"]["value"] + + # 判断是否为混合搜索 + is_hybrid_search = bool( + search_conditions["vector"] and search_conditions["match"] + ) + + # 根据搜索类型处理结果 + if is_hybrid_search: + documents = self._process_hybrid_search_results( + cast("str", search_conditions["vector"][0].value), hits, limit + ) + else: + documents = self._process_structured_search_results(hits) + + return SearchResult( + documents=documents, + total_count=total_count, + search_time_ms=search_time_ms, + ) + + def _process_hybrid_search_results( + self, + text_query: str, + hits: list[dict[str, Any]], + limit: int, + ) -> list[DocumentResult]: + """ + 处理混合搜索结果:去重 + 重排序 + + Args: + hits: ES查询命中结果 + + Returns: + 
处理后的文档结果列表 + """ + + chunks = [ + DocumentResult( + content=hit["_source"], score=hit["_score"] if hit["_score"] is not None else 0.0, ) - for hit in response["hits"]["hits"] + for hit in hits ] - # 去重 + # 去重处理 seen = set() unique_chunks = [] - for chunk in retrieved_chunks: - identifier = (chunk.text, chunk.file_metadata_id) + + for chunk in chunks: + identifier = ( + chunk.content["content"], + chunk.content["file_metadata_id"], + ) if identifier not in seen: seen.add(identifier) unique_chunks.append(chunk) # 重排 - reranked_chunks = self.reranker.rerank(request.query, unique_chunks) + return self._reranker.rerank(text_query, unique_chunks)[:limit] - # 截取最终的 top_k - final_context = reranked_chunks[: request.top_k] + @staticmethod + def _process_structured_search_results( + hits: list[dict[str, Any]], + ) -> list[DocumentResult]: + """ + 处理结构化搜索结果:直接转换 - return SearchResponse(context=final_context) + Args: + hits: ES查询命中结果 - @staticmethod - def _build_filtered_queries( - query: str, filters: dict[str, Any] | None - ) -> tuple[dict[str, Any], list[dict[str, Any]]]: + Returns: + 文档结果列表 """ - 根据用户查询和过滤器,动态构建用于 standard 和 knn 检索器的查询。 + documents = [] + for hit in hits: + documents.append( + DocumentResult( + id=hit["_id"], # 使用ES文档ID + content=hit["_source"], # 完整文档内容 + score=hit["_score"] if hit["_score"] is not None else 0.0, + ) + ) + return documents - :param query: 用户的查询字符串。 - :param filters: 一个包含字段和期望值的字典,用于过滤。 - :return: 一个元组,包含: - - standard_query (dict): 用于 standard retriever 的完整查询体。 - - knn_filter (list): 用于 knn retriever 的过滤器列表。 + def save_for_structured_search( + self, index_name: str, doc_id: str, doc_dict: dict[str, Any] + ) -> None: """ - # 准备 standard retriever 的核心 query 部分 (match query) - standard_query_part = { - "match": {"content": {"query": query, "boost": 0.5}} - } + 保存文档到Elasticsearch索引,如果文档已存在则整体覆盖 - # 根据 filters 构建 filter_clause - filter_clause: list[dict[str, Any]] = [] - if filters: - for field, value in filters.items(): - if isinstance(value, list): - filter_clause.append({"terms": {field: value}}) - else: - filter_clause.append({"term": {field: value}}) - - # 简单情况:如果没有过滤器,直接返回最简单的查询 - if not filter_clause: - # 此处 standard_query_part 就是最终的查询 - return standard_query_part, filter_clause - - # 构建复杂的 bool 查询并返回 - standard_query = { - "bool": {"must": [standard_query_part], "filter": filter_clause} - } + Args: + index_name: ES索引名称 + doc_id: 文档ID + doc_dict: 文档内容字典 - return standard_query, filter_clause + Raises: + RuntimeError: 文档存储失败时抛出 + """ + try: + # 插入或完整覆盖 + response = self._client.index( + index=index_name, + id=doc_id, + document=doc_dict, + refresh="wait_for", + ) + # 记录操作结果 + operation = ( + "创建" if response.get("result") == "created" else "覆盖" + ) + logger.info(f"文档 {doc_id} 在索引 {index_name} 中{operation}成功") + + except Exception as e: + error_msg = f"文档存储失败 - 索引: {index_name}, 文档ID: {doc_id}" + logger.error(f"{error_msg},错误: {e}") + raise RuntimeError(f"{error_msg}: {str(e)}") from e diff --git a/app/utils/converters/__init__.py b/app/utils/converters/__init__.py new file mode 100644 index 0000000..7986741 --- /dev/null +++ b/app/utils/converters/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2021 ecodeclub +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +转换器包 +""" + +from .search import SearchConverter + +__all__ = ["SearchConverter"] diff --git a/app/utils/converters/search.py b/app/utils/converters/search.py new file mode 100644 index 0000000..94ffb77 --- /dev/null +++ b/app/utils/converters/search.py @@ -0,0 +1,103 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from app.domain.search import ( + SearchCondition, + SearchMode, + SearchParameters, + SearchResult, +) +from app.web.vo import ( + ConditionOperator, + SearchRequest, + SearchResponse, + SearchType, + StructuredSearchResult, + VectorHybridSearchResult, +) + + +class SearchConverter: + """搜索数据转换器""" + + @staticmethod + def request_vo_to_domain(request: SearchRequest) -> SearchParameters: + """VO转Domain""" + + if request.type == SearchType.VECTOR_HYBRID: + # 向量混合:生成双条件 + conditions: list[SearchCondition] = [] + for cond in request.query.conditions: + # 文本搜索条件 + conditions.append( + SearchCondition( + field_name=cond.field, + mode=SearchMode.MATCH, + value=cond.value, + ) + ) + # 向量搜索条件(value还是文本) + conditions.append( + SearchCondition( + field_name=f"{cond.field}_vector", + mode=SearchMode.VECTOR, + value=cond.value, + ) + ) + else: + # 结构化搜索:直接映射 + conditions = [ + SearchCondition( + field_name=cond.field, + mode=SearchMode.TERM + if cond.op == ConditionOperator.TERM + else SearchMode.MATCH, + value=cond.value, + ) + for cond in request.query.conditions + ] + + return SearchParameters( + index_name=request.query.index, + conditions=conditions, + limit=request.top_k, + filters=request.query.filters, + ) + + @staticmethod + def result_domain_to_vo( + search_result: SearchResult, search_type: SearchType + ) -> SearchResponse: + """Domain转VO""" + results: list[VectorHybridSearchResult | StructuredSearchResult] + + if search_type == SearchType.VECTOR_HYBRID: + results = [ + VectorHybridSearchResult( + text=doc.content.get("content", ""), + file_metadata_id=doc.content.get("file_metadata_id", ""), + score=doc.score, + ) + for doc in search_result.documents + ] + else: + results = [ + StructuredSearchResult( + id=doc.id, + document=doc.content, + score=doc.score, + ) + for doc in search_result.documents + if doc.id + ] + + return SearchResponse(type=search_type, results=results) diff --git a/app/utils/rerankers/bge.py b/app/utils/rerankers/bge.py index b97b57a..c05f296 100644 --- a/app/utils/rerankers/bge.py +++ b/app/utils/rerankers/bge.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
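`SearchConverter` above is the bridge between the web VOs and the domain model: the single `match` condition of a `vector_hybrid` request fans out into a MATCH + VECTOR condition pair, while `structured` conditions map one-to-one. For reference, the two `POST /api/v1/search` payload shapes it expects; index names and values are illustrative, drawn from the tests later in this diff:

```python
# Hypothetical request bodies for POST /api/v1/search.
vector_hybrid_request = {
    "type": "vector_hybrid",
    "query": {
        "index": "kb_chunks",  # full chunk index name (prefix + suffix)
        "conditions": [
            {"field": "content", "op": "match", "value": "数据库并发"}
        ],  # exactly one match condition, enforced by the VO validator
    },
    "top_k": 3,
}

structured_request = {
    "type": "structured",
    "query": {
        "index": "test_search_structured",
        "conditions": [  # AND-combined term/match conditions
            {"field": "role", "op": "term", "value": "后端"},
            {"field": "level", "op": "term", "value": "高级"},
        ],
    },
    "top_k": 5,
}
```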
-from app.domain.search import ContextChunk +from app.domain.search import DocumentResult class BgeReranker: @@ -32,31 +32,32 @@ def __init__(self, model_name: str = "BAAI/bge-reranker-base") -> None: print(f"BGE Reranker loaded with model: {model_name}") def rerank( - self, query: str, results: list[ContextChunk] - ) -> list[ContextChunk]: + self, query: str, results: list[DocumentResult] + ) -> list[DocumentResult]: if not query or not results: return results # 创建副本,避免修改原始对象 results_copy = [ - ContextChunk( - text=chunk.text, - file_metadata_id=chunk.file_metadata_id, - score=chunk.score, + DocumentResult( + id=doc.id, + content=doc.content, + score=doc.score, ) - for chunk in results + for doc in results ] # 使用模型计算得分 - # show_progress_bar=False 在生产环境中是很好的实践 - sentence_pairs = [[query, chunk.text] for chunk in results_copy] + sentence_pairs = [ + [query, chunk.content.get("content", "")] + for chunk in results_copy # 安全获取字段 + ] scores = self.model.predict(sentence_pairs, show_progress_bar=False) - # 将新的 rerank 分数附加到每个 chunk 上 - for chunk, score in zip(results, scores, strict=True): - chunk.score = float(score) - - # 根据新的 rerank 分数降序排序 - results.sort(key=lambda x: x.score, reverse=True) + # 将新的rerank分数赋给副本 + for doc, score in zip(results_copy, scores, strict=True): + doc.score = float(score) - return results + # 根据新的rerank分数降序排序并返回副本 + results_copy.sort(key=lambda x: x.score, reverse=True) + return results_copy diff --git a/app/web/handler.py b/app/web/document.py similarity index 81% rename from app/web/handler.py rename to app/web/document.py index 2b7b729..391e629 100644 --- a/app/web/handler.py +++ b/app/web/document.py @@ -34,10 +34,14 @@ from app.config.settings import Settings from app.domain.document import Document -from app.domain.search import SearchRequest, SearchResponse from app.service.elasticsearch import ElasticsearchService +from app.utils.converters import SearchConverter from app.web.vo import ( FileUploadResponse, + SaveRequest, + SaveResponse, + SearchRequest, + SearchResponse, UrlUploadRequest, UrlUploadResponse, ) @@ -82,16 +86,24 @@ def __init__( def register_routes(self) -> None: """将本处理器中的所有API端点注册到构造时传入的路由器上。""" + + self._router.get("/health", summary="健康检查")(DocumentHandler.health) + + self._router.get( + "/tasks/{task_id}", + summary="查询任务状态", + )(self.get_task_status) + self._router.post( "/documents/upload-file", response_model=FileUploadResponse, - summary="通过文件上传进行索引", + summary="通过文件上传进行索引,可以假定索引已提前建好,只需要用前后缀拼接得到完整索引名称即可", )(self.upload_file) self._router.post( "/documents/upload-from-url", response_model=UrlUploadResponse, - summary="通过腾讯云COS URL下载并进行索引", + summary="通过腾讯云COS URL下载并进行索引,可以假定索引已提前建好,只需要用前后缀拼接得到完整索引名称即可", )(self.upload_from_url) self._router.post( @@ -100,12 +112,21 @@ def register_routes(self) -> None: summary="在知识库中进行搜索", )(self.search) - self._router.get("/health", summary="健康检查")(DocumentHandler.health) + self._router.post( + "/documents/save", + response_model=SaveResponse, + summary="保存JSON格式文档到指定的Elasticsearch索引", + )(self.save) - self._router.get( - "/tasks/{task_id}", - summary="查询任务状态", - )(self.get_task_status) + @staticmethod + async def health() -> dict[str, str]: + """健康检查接口。""" + return {"status": "healthy"} + + async def get_task_status(self, task_id: str) -> dict[str, str]: + """查询任务状态""" + status = self._task_status.get(task_id, "not_found") + return {"task_id": task_id, "status": status} def _process_and_cleanup( self, task_id: str, temp_dir: Path, document: Document @@ -114,7 +135,7 @@ def _process_and_cleanup( 
self._task_status[task_id] = "processing" try: logger.info(f"后台任务开始处理: {document.path}") - self._service.store(document) + self._service.store_for_vector_hybrid_search(document) logger.info(f"✅ 后台任务成功处理文件: {document.path}") self._task_status[task_id] = "completed" except Exception as e: @@ -142,9 +163,12 @@ async def _cleanup_task_status( async def upload_file( self, background_tasks: BackgroundTasks, - file: UploadFile = File(...), - category: str | None = Form(None), - tags: str | None = Form(None), + index_prefix: str = Form( + ..., min_length=1, description="索引完整名称前缀" + ), + file: UploadFile = File(..., description="上传的文件"), + category: str | None = Form(None, description="分类"), + tags: str | None = Form(None, description="标签"), ) -> FileUploadResponse: """从用户上传的文件创建并索引文档""" if not file.filename: @@ -179,6 +203,11 @@ async def upload_file( try: content = await file.read() + + if len(content) == 0: + shutil.rmtree(temp_dir) + raise HTTPException(status_code=400, detail="不能上传空文件") + # 双重检查(防止file.size不准确的情况) if len(content) > self._max_file_size_bytes: shutil.rmtree(temp_dir) @@ -199,6 +228,7 @@ async def upload_file( tag_list = [tag.strip() for tag in tags.split(",")] if tags else [] document = Document( + index_prefix=index_prefix, path=str(file_path.resolve()), size=file_size, category=category, @@ -293,32 +323,11 @@ async def _get_cos_object_metadata( except HTTPException: raise except Exception as e: - error_str = str(e) logger.error(f"获取COS对象元数据失败: {e}", exc_info=True) - - # 如果是权限问题,优雅降级:返回默认值,稍后从下载的文件获取实际大小 - if "AccessDenied" in error_str: - logger.warning( - f"COS权限不足,无法获取对象元数据: {cos_key},将在下载后获取文件信息" - ) - return 0, None, [] # 返回占位符:大小=0, 无category, 无tags - - # 其他错误照常处理 - elif "NoSuchKey" in error_str or "NoSuchBucket" in error_str: - raise HTTPException( - status_code=404, - detail=f"COS中未找到指定对象: {cos_key}", - ) from e - elif "timeout" in error_str.lower(): - raise HTTPException( - status_code=502, - detail=f"连接COS超时: {cos_key}", - ) from e - else: - raise HTTPException( - status_code=500, - detail=f"获取COS对象元数据失败: {cos_key}", - ) from e + raise HTTPException( + status_code=500, + detail=f"获取COS对象元数据失败: {cos_key}", + ) from e async def _download_cos_file( self, cos_key: str, file_path: Path, temp_dir: Path @@ -340,29 +349,10 @@ async def _download_cos_file( ) except Exception as e: shutil.rmtree(temp_dir) - error_str = str(e) logger.error(f"从COS下载文件失败: {e}", exc_info=True) - - # 根据具体错误类型返回不同的HTTP状态码 - if "AccessDenied" in error_str: - raise HTTPException( - status_code=403, - detail=f"无权限下载COS对象,请检查访问密钥权限: {cos_key}", - ) from e - elif "NoSuchKey" in error_str: - raise HTTPException( - status_code=404, - detail=f"COS中未找到指定对象: {cos_key}", - ) from e - elif "timeout" in error_str.lower(): - raise HTTPException( - status_code=502, - detail=f"下载COS对象超时: {cos_key}", - ) from e - else: - raise HTTPException( - status_code=500, detail=f"从COS下载文件失败: {cos_key}" - ) from e + raise HTTPException( + status_code=500, detail=f"从COS下载文件失败:{cos_key}" + ) from e async def upload_from_url( self, request: UrlUploadRequest, background_tasks: BackgroundTasks @@ -407,6 +397,7 @@ async def upload_from_url( # 7. 
创建Document并添加后台任务 document = Document( + index_prefix=request.index_prefix, path=str(file_path.resolve()), size=file_size, category=category, @@ -421,23 +412,49 @@ async def search(self, request: SearchRequest) -> SearchResponse: """文档搜索接口""" try: logger.info( - f"🔍 收到搜索请求: query='{request.query}', top_k={request.top_k}, filters={request.filters}" + f"🔍 收到搜索请求: type='{request.type}', query='{request.query}', top_k={request.top_k}" + ) + + domain_response = self._service.search( + SearchConverter.request_vo_to_domain(request) + ) + + resp = SearchConverter.result_domain_to_vo( + domain_response, request.type ) - domain_response = self._service.search(request) logger.info( - f"✅ 搜索完成, 返回{len(domain_response.context)}条结果" + f"✅ 搜索完成, 返回{len(domain_response.documents)}条结果" ) - return domain_response + return resp except Exception as e: logger.error(f"❌ 搜索失败: {e}", exc_info=True) raise HTTPException(status_code=500, detail="搜索处理失败") from e - @staticmethod - async def health() -> dict[str, str]: - """健康检查接口。""" - return {"status": "healthy"} + async def save(self, request: SaveRequest) -> SaveResponse: + """保存JSON格式文档到指定的Elasticsearch索引""" + try: + self._service.save_for_structured_search( + index_name=request.index, + doc_id=request.key, + doc_dict=request.doc_json, + ) + return SaveResponse(message="ok") - async def get_task_status(self, task_id: str) -> dict[str, str]: - """查询任务状态""" - status = self._task_status.get(task_id, "not_found") - return {"task_id": task_id, "status": status} + except ValueError as e: + # JSON格式验证错误(由Pydantic自动处理) + logger.error(f"JSON格式验证失败: {e}") + raise HTTPException( + status_code=400, detail=f"JSON格式错误: {str(e)}" + ) from e + + except RuntimeError as e: + # service层抛出的存储错误 + logger.error(f"文档存储失败: {e}") + raise HTTPException( + status_code=500, detail="文档存储失败,请稍后重试" + ) from e + + except Exception as e: + # 其他未预期的异常 + logger.error(f"保存文档时发生未知错误: {e}") + raise HTTPException(status_code=500, detail="服务内部错误") from e diff --git a/app/web/vo.py b/app/web/vo.py index f702d91..8071bff 100644 --- a/app/web/vo.py +++ b/app/web/vo.py @@ -14,13 +14,11 @@ """Web层VO模型定义""" -from pydantic import BaseModel, Field, HttpUrl +from enum import Enum +from typing import Any - -class UrlUploadRequest(BaseModel): - """从URL上传的请求体模型""" - - url: HttpUrl = Field(..., description="要下载和索引的文件的完整URL") +from pydantic import BaseModel, Field, HttpUrl, Json, field_validator +from pydantic_core.core_schema import ValidationInfo class FileUploadResponse(BaseModel): @@ -34,6 +32,13 @@ class FileUploadResponse(BaseModel): ) +class UrlUploadRequest(BaseModel): + """从URL上传的请求体模型""" + + url: HttpUrl = Field(..., description="要下载和索引的文件的完整URL") + index_prefix: str = Field(..., min_length=1, description="索引完整名称前缀") + + class UrlUploadResponse(BaseModel): """从URL上传后的标准响应模型""" @@ -43,3 +48,119 @@ class UrlUploadResponse(BaseModel): message: str = Field( "URL已接收,正在后台下载和处理中...", description="操作结果信息" ) + + +class SearchType(str, Enum): + """搜索类型""" + + VECTOR_HYBRID = "vector_hybrid" # 向量+全文混合搜索 + STRUCTURED = "structured" # 结构化条件搜索 + + +class ConditionOperator(str, Enum): + """搜索条件操作符""" + + TERM = "term" # 精确匹配 + MATCH = "match" # 全文搜索匹配 + + +class Condition(BaseModel): + """搜索条件""" + + field: str = Field(..., min_length=1, description="文档中的字段名称") + op: ConditionOperator = Field( + ..., description="操作符:term(精确匹配) 或 match(全文搜索)" + ) + value: str | int | float | bool = Field( + ..., description="字段值,支持多种类型" + ) + + +class Query(BaseModel): + """查询对象""" + + index: str = Field(..., min_length=1, 
description="ES索引名称") + conditions: list[Condition] = Field( + ..., min_length=1, description="搜索条件列表,至少需要一个条件" + ) + filters: dict[str, Any] | None = Field( + None, description="过滤条件,用于精确过滤不参与计分" + ) + + +class SearchRequest(BaseModel): + """搜索请求""" + + type: SearchType = Field(..., description="搜索类型") + query: Query = Field(..., description="查询条件") + top_k: int = Field(..., ge=1, description="返回结果数量,至少为1") + + @field_validator("query") + @classmethod + def validate_query_for_search_type( + cls, v: Query, info: ValidationInfo + ) -> Query: + """根据搜索类型验证查询条件""" + search_type = info.data.get("type") + + if search_type == SearchType.VECTOR_HYBRID: + if ( + len(v.conditions) != 1 + or v.conditions[0].op != ConditionOperator.MATCH + ): + raise ValueError( + "vector_hybrid 模式只能有一个 match 类型的搜索条件" + ) + + return v + + +class VectorHybridSearchResult(BaseModel): + """向量+全文混合搜索结果""" + + text: str = Field(..., description="文档内容") + file_metadata_id: str = Field(..., description="文件元数据ID") + score: float = Field(..., description="相关度分数") + + +class StructuredSearchResult(BaseModel): + """结构化搜索结果""" + + id: str = Field(..., description="文档唯一标识符") + document: dict[str, Any] = Field(..., description="文档数据") + score: float = Field(..., description="相关度分数") + + +class SearchResponse(BaseModel): + """搜索响应""" + + type: SearchType = Field(..., description="搜索类型") # 保持一致性 + results: list[VectorHybridSearchResult | StructuredSearchResult] = Field( + default_factory=list, description="搜索结果" + ) + + +class SaveRequest(BaseModel): + """ + Elasticsearch文档保存请求模型 + + 以key为_id向名为index的索引中插入doc_json。 + 如果key不存在则直接插入,如果key已存在则完整覆盖。 + + Attributes: + index: ES中索引的完整名称,假定mappings已建立好 + key: 文档的唯一标识,将作为ES中的_id使用 + doc_json: 满足JSON格式的字符串文档内容,会自动解析为字典 + """ + + index: str = Field(..., min_length=1, description="ES索引名称") + key: str = Field(..., min_length=1, description="文档唯一标识") + doc_json: Json[dict[str, Any]] = Field( + ..., description="JSON格式的文档内容" + ) + + +class SaveResponse(BaseModel): + """文档保存操作的响应模型""" + + message: str = Field(default="操作成功", description="操作结果信息") diff --git a/config.yaml b/config.yaml index 92bc11f..d6ddca4 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,7 @@ elasticsearch: url: "http://localhost:9200" - metadata_index: "knowledge_base_metadatas" - chunk_index: "knowledge_base_chunks" + metadata_index_suffix: "_metadatas" + chunk_index_suffix: "_chunks" request_timeout: 60 embedder: diff --git a/pyproject.toml b/pyproject.toml index 22d65ba..d7654a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,12 +89,18 @@ module = [ ignore_missing_imports = true [tool.pytest.ini_options] -addopts = "-q --cov=app --cov-report=term-missing --cov-report=xml --cov-report=html" +addopts = [ + "-q", + "--cov-report=term-missing", + "--cov-report=xml:coverage.xml", # 指定XML报告位置 + "--cov-report=html:htmlcov", # 明确指定HTML报告位置 + "--cov-branch", # 启用分支覆盖率 +] testpaths = ["tests"] -minversion = "8.0" # ← 添加最小版本要求 -python_files = "test_*.py" # ← 明确测试文件模式 -python_classes = "Test*" # ← 明确测试类模式 -python_functions = "test_*" # ← 明确测试函数模式 +minversion = "8.0" +python_files = "*_test.py" +python_classes = "Test*" +python_functions = "test_*" [tool.uv] required-environments = [ diff --git a/tests/.DS_Store b/tests/.DS_Store deleted file mode 100644 index 3ebc335..0000000 Binary files a/tests/.DS_Store and /dev/null differ diff --git a/tests/conftest.py b/tests/conftest.py index 112803a..35f3c8b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ from typing import Any import pytest +from 
elasticsearch import Elasticsearch from fastapi.testclient import TestClient from qcloud_cos import CosS3Client # type: ignore[import-untyped] @@ -52,3 +53,28 @@ def client() -> Generator[TestClient, Any, None]: """V2 API 测试客户端""" with TestClient(app) as test_client: yield test_client + + +@pytest.fixture(scope="session") +def es_client() -> Generator[Elasticsearch, Any, None]: + """ + Elasticsearch客户端 - 用于测试ES相关操作 + """ + client = Elasticsearch( + hosts=[{"host": "localhost", "port": 9200, "scheme": "http"}], + request_timeout=100, + retry_on_timeout=True, + # 如果有认证信息也要添加 + ) + # 测试连接 + try: + info = client.info() + logger.info( + f"✅ 成功连接到Elasticsearch服务: {info['version']['number']}" + ) + yield client + except Exception as e: + logger.error(f"❌ 无法连接到Elasticsearch: {e}") + pytest.fail("Elasticsearch服务不可用") + finally: + client.close() diff --git a/tests/fixtures/config.yaml b/tests/fixtures/config.yaml index 74f3ee4..f9df47c 100644 --- a/tests/fixtures/config.yaml +++ b/tests/fixtures/config.yaml @@ -1,7 +1,7 @@ elasticsearch: url: "http://localhost:9200" - metadata_index: "test_knowledge_base_metadatas" - chunk_index: "test_knowledge_base_chunks" + metadata_index_suffix: "_metadatas" + chunk_index_suffix: "_chunks" request_timeout: 15 embedder: diff --git a/tests/test_api.py b/tests/test_api.py deleted file mode 100644 index f5d29f5..0000000 --- a/tests/test_api.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2021 ecodeclub -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http:#www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -from collections.abc import Callable -from pathlib import Path -from pprint import pprint -from typing import overload - -import pytest -from fastapi.testclient import TestClient - -from app.config.settings import settings - - -class TestAPI: - """ - 端到端测试类,测试路径为 /api/v1/。 - """ - - @pytest.fixture(scope="class") - def user_upload_dir(self) -> Path: - """提供 '用户准备上传' 的文件目录路径。""" - path = Path(__file__).parent / "fixtures" / "files" / "user" - path.mkdir(exist_ok=True, parents=True) - return path - - @pytest.fixture(scope="class") - def backend_uploaded_dir(self) -> Path: - """提供 '后端已接收' 的文件根目录路径。""" - path = Path(__file__).parent / "fixtures" / "files" / "uploaded" - path.mkdir(exist_ok=True, parents=True) - return path - - @pytest.fixture(scope="class") - def get_user_upload_file( - self, user_upload_dir: Path - ) -> Callable[[str | list[str]], Path | list[Path]]: - """从 'user' 目录轻松获取文件路径。""" - - # 为 mypy 准备的类型重载,用于精确类型推断 - @overload - def _builder(file_names: str) -> Path: ... - - @overload # noqa: F811 - def _builder(file_names: list[str]) -> list[Path]: ... 
- - def _builder(file_names: str | list[str]) -> Path | list[Path]: # noqa: F811 - if isinstance(file_names, str): - path = user_upload_dir / file_names - if not path.exists(): - raise FileNotFoundError(f"源文件不存在: {path}") - return path - - paths = [user_upload_dir / name for name in file_names] - if not all(p.exists() for p in paths): - raise FileNotFoundError("一个或多个源文件不存在。") - return paths - - return _builder # type: ignore[return-value] - - def test_upload_file_and_search_flow( - self, - client: TestClient, - get_user_upload_file: Callable[[str], Path], - backend_uploaded_dir: Path, - ) -> None: - """测试本地上传 -> 验证文件在UUID子目录中 -> 搜索。""" - test_file_name = "03_test.pdf" - print(f"\n▶️ 测试本地上传,使用文件: {test_file_name}") - - test_file = get_user_upload_file(test_file_name) - - with test_file.open("rb") as f: - response = client.post( - "/api/v1/documents/upload-file", - files={"file": (test_file.name, f, "application/pdf")}, - ) - assert response.status_code == 200 - - # 验证上传响应 - upload_result = response.json() - assert "task_id" in upload_result - print(f"✅ 文件上传成功,任务ID: {upload_result['task_id']}") - - # 文件上传后会被异步处理,然后正确清理 - # 我们通过上传响应和后续的索引结果来验证处理成功 - print("✅ 文件上传成功,后台正在异步处理(处理完成后会自动清理)") - - import time - - print("⏳ 等待索引...") - time.sleep(5) - - query = {"query": "数据库", "top_k": 3} - print(f"查询参数:{query}\n") - - response = client.post( - "/api/v1/search", json={"query": "数据库并发", "top_k": 3} - ) - assert response.status_code == 200 - search_results = response.json() - assert len(search_results["context"]) > 0 - print("✅ 搜索成功!结果详情:") - pprint(search_results) - - def test_health_check_endpoint(self, client: TestClient) -> None: - """测试API健康检查端点""" - print("\n🏥 测试健康检查端点...") - response = client.get("/api/v1/health") - assert response.status_code == 200 - - health_data = response.json() - assert health_data["status"] == "healthy" - print("✅ 健康检查端点正常") - - def test_upload_performance_monitoring( - self, - client: TestClient, - get_user_upload_file: Callable[[str], Path], - ) -> None: - """测试文件上传的性能表现""" - test_file_name = "01_test.pdf" - print(f"\n⏱️ 测试上传性能: {test_file_name}") - - test_file = get_user_upload_file(test_file_name) - file_size = test_file.stat().st_size - print(f" 文件大小: {file_size / 1024:.1f} KB") - - with test_file.open("rb") as f: - start_time = time.time() - response = client.post( - "/api/v1/documents/upload-file", - files={"file": (test_file.name, f, "application/pdf")}, - ) - upload_duration = time.time() - start_time - - assert response.status_code == 200 - upload_result = response.json() - assert "task_id" in upload_result - - print("✅ 性能测试完成") - print(f" - 上传耗时: {upload_duration:.2f}秒") - print(f" - 上传速度: {file_size / 1024 / upload_duration:.1f} KB/s") - - def test_search_robustness(self, client: TestClient) -> None: - """测试搜索功能的鲁棒性""" - print("\n🛡️ 测试搜索鲁棒性...") - - # 定义各种边界情况 - robustness_cases = [ - { - "query": "不存在的专业术语xyz123", - "description": "不存在内容", - "expect_results": False, - }, - {"query": "数", "description": "极短查询", "expect_results": True}, - { - "query": "数据库性能优化并发处理" * 10, - "description": "超长查询", - "expect_results": True, - }, - {"query": "", "description": "空字符串", "expect_results": False}, - {"query": " ", "description": "纯空格", "expect_results": False}, - { - "query": "!@#$%", - "description": "特殊字符", - "expect_results": False, - }, - ] - - for case in robustness_cases: - print(f"🔍 测试场景: {case['description']}") - search_data = {"query": case["query"], "top_k": 3} - response = client.post("/api/v1/search", json=search_data) - - # 所有情况都应该返回 200,不应该崩溃 - assert 
response.status_code == 200 - search_result = response.json() - assert "context" in search_result - - result_count = len(search_result["context"]) - query_text = str(case["query"]) - display_query = ( - query_text[:20] + "..." if len(query_text) > 20 else query_text - ) - print(f" 查询: '{display_query}' -> {result_count} 条结果") - - # 验证预期行为 - if case["expect_results"]: - # 对于有意义的查询,应该尽量有结果(但不强制) - print(" ✅ 查询处理正常") - else: - # 对于无意义查询,无结果是正常的 - print(" ✅ 边界情况处理正常") - - def test_upload_url_and_search_flow( - self, - client: TestClient, - ) -> None: - """测试从COS URL上传 -> 搜索流程(使用预上传的文件)""" - - object_key = "kbase-temp/02_test.pdf" - bucket_name = settings.tencent_oss.bucket - cos_url = f"https://{bucket_name}.cos.{settings.tencent_oss.region}.myqcloud.com/{object_key}" - - print(f"🔗 测试COS URL: {cos_url}") - print(f" 存储桶: {bucket_name}") - print(f" 对象键: {object_key}") - - try: - # 步骤1: 通过API从COS URL上传 - print("📥 步骤1: 通过API从COS URL上传...") - response = client.post( - "/api/v1/documents/upload-from-url", json={"url": cos_url} - ) - assert response.status_code == 200 - upload_result = response.json() - assert "task_id" in upload_result - print(f"✅ URL上传API调用成功,任务ID: {upload_result['task_id']}") - - # 步骤2: 等待索引完成 - print("⏳ 步骤2: 等待索引完成...") - time.sleep(5) - - # 步骤3: 测试搜索 - print("🔍 步骤3: 测试搜索...") - response = client.post( - "/api/v1/search", json={"query": "读多写少", "top_k": 3} - ) - assert response.status_code == 200 - search_results = response.json() - - print(f"搜索结果数量: {len(search_results['context'])}") - if len(search_results["context"]) > 0: - print("✅ COS URL上传内容的搜索成功!结果详情:") - pprint(search_results) - else: - print("⚠️ 搜索结果为空,可能文件还在处理中或内容不匹配") - # 不强制要求搜索结果,因为索引可能需要更多时间 - - print("🎉 COS URL上传和搜索测试完成!") - - except Exception as e: - print(f"❌ 测试过程中出现异常: {e}") - raise diff --git a/tests/web/__init__.py b/tests/web/__init__.py new file mode 100644 index 0000000..eb8dd5f --- /dev/null +++ b/tests/web/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 ecodeclub +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""web 集成测试""" diff --git a/tests/web/document/__init__.py b/tests/web/document/__init__.py new file mode 100644 index 0000000..2a2902e --- /dev/null +++ b/tests/web/document/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 ecodeclub +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Document 相关端点测试""" diff --git a/tests/web/document/health_endpoint_test.py b/tests/web/document/health_endpoint_test.py new file mode 100644 index 0000000..0011e0e --- /dev/null +++ b/tests/web/document/health_endpoint_test.py @@ -0,0 +1,25 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastapi.testclient import TestClient + + +class TestHealthEndpoint: + def test_health_check_endpoint(self, client: TestClient) -> None: + """测试API健康检查端点""" + print("\n🏥 测试健康检查端点...") + response = client.get("/api/v1/health") + assert response.status_code == 200 + + health_data = response.json() + assert health_data["status"] == "healthy" + print("✅ 健康检查端点正常") diff --git a/tests/web/document/save_endpoint_test.py b/tests/web/document/save_endpoint_test.py new file mode 100644 index 0000000..bd5e200 --- /dev/null +++ b/tests/web/document/save_endpoint_test.py @@ -0,0 +1,256 @@ +# Copyright 2021 ecodeclub +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from collections.abc import Generator +from typing import Any, cast + +import pytest +from elasticsearch import Elasticsearch +from elasticsearch.exceptions import NotFoundError, TransportError +from fastapi.testclient import TestClient + + +class TestSaveEndpoint: + """测试文档保存接口""" + + TEST_INDEX = "test_save_index" + + @pytest.fixture(scope="class", autouse=True) + def setup_test_index( + self, es_client: Elasticsearch + ) -> Generator[None, Any, None]: + """设置测试索引""" + # 如果索引存在则删除(清理之前的测试) + if es_client.indices.exists(index=self.TEST_INDEX): + es_client.indices.delete(index=self.TEST_INDEX) + + # 创建测试索引 + es_client.indices.create( + index=self.TEST_INDEX, + body={ + "mappings": { + "properties": { + "role": { + "type": "keyword" + }, # 精确匹配:后端、前端、测试等 + "level": { + "type": "keyword" + }, # 精确匹配:初级、中级、高级等 + "content": {"type": "text"}, # 模糊匹配:详细描述内容 + } + } + }, + ) + + yield + + # 清理:删除测试索引 + if es_client.indices.exists(index=self.TEST_INDEX): + es_client.indices.delete(index=self.TEST_INDEX) + + def _get_document_from_es( + self, es_client: Elasticsearch, doc_id: str + ) -> dict[str, Any] | None: + """从ES获取文档""" + try: + response = es_client.get(index=self.TEST_INDEX, id=doc_id) + return cast("dict[str, Any]", response["_source"]) + except NotFoundError: + return None + except TransportError as e: + print( + f"ES获取文档失败 - 索引: {self.TEST_INDEX}, 文档ID: {doc_id}, 错误: {e}" + ) + raise + + def _document_exists_in_es( + self, es_client: Elasticsearch, doc_id: str + ) -> bool: + """检查文档是否存在于ES中""" + try: + es_client.get(index=self.TEST_INDEX, id=doc_id) + return True + except NotFoundError: + return False + except TransportError as e: + print( + f"ES检查文档存在性失败 - 索引: {self.TEST_INDEX}, 文档ID: {doc_id}, 错误: {e}" + ) + raise + + def test_save_new_document( + self, client: TestClient, es_client: Elasticsearch + ) -> None: + """测试保存新文档""" + + doc_id = "backend_dev_1" + doc_data = { + "role": "后端", + "level": "高级", + "content": "负责用户管理系统的设计和开发,熟练掌握Python、Django等技术栈,具备丰富的微服务架构经验", + } + + # 验证文档不存在 + assert not self._document_exists_in_es(es_client, doc_id), ( + "文档不应该预先存在" + ) + + # 调用保存接口并验证响应 + response = client.post( + "/api/v1/documents/save", + json={ + "index": self.TEST_INDEX, + "key": doc_id, + "doc_json": json.dumps(doc_data), + }, + ) + assert response.status_code == 200 + response_data = response.json() + assert "ok" in response_data["message"] + + # 验证文档已插入ES + assert self._document_exists_in_es(es_client, doc_id), ( + "文档应该已经插入" + ) + saved_doc = self._get_document_from_es(es_client, doc_id) + assert saved_doc == doc_data + + def test_save_update_existing_document( + self, client: TestClient, es_client: Elasticsearch + ) -> None: + """测试覆盖已存在的文档""" + + doc_id = "frontend_dev_1" + original_doc = { + "role": "前端", + "level": "中级", + "content": "负责企业级Web应用的前端开发,熟悉React和Vue框架", + } + updated_doc = { + "role": "架构师", + "level": "高级", + "content": "负责企业级Web应用和移动端H5的前端开发,精通React、Vue、TypeScript,具备跨平台开发经验", + } + + # 验证文档不存在 + assert not self._document_exists_in_es(es_client, doc_id), ( + "文档不应该预先存在" + ) + + # 先插入一个文档并验证原文档存在 + es_client.index( + index=self.TEST_INDEX, + id=doc_id, + document=original_doc, + refresh="wait_for", + ) + assert self._document_exists_in_es(es_client, doc_id), "原文档应该存在" + original_saved = self._get_document_from_es(es_client, doc_id) + assert original_saved == original_doc + + # 调用保存接口进行覆盖 + response = client.post( + "/api/v1/documents/save", + json={ + "index": self.TEST_INDEX, + "key": doc_id, + "doc_json": json.dumps(updated_doc), + }, + ) + + # 验证响应 
+        assert response.status_code == 200
+        response_data = response.json()
+        assert "ok" in response_data["message"]
+
+        # 验证文档已完全覆盖
+        updated_saved = self._get_document_from_es(es_client, doc_id)
+        assert updated_saved == updated_doc
+
+    def test_save_invalid_json_format(self, client: TestClient) -> None:
+        """测试无效的JSON格式"""
+
+        response = client.post(
+            "/api/v1/documents/save",
+            json={
+                "index": self.TEST_INDEX,
+                "key": "test_invalid",
+                "doc_json": '{"role": "后端", "level": "高级"',  # 无效JSON,缺少闭合括号
+            },
+        )
+
+        assert response.status_code == 422  # Pydantic验证错误
+        error_data = response.json()
+        assert "detail" in error_data
+
+    def test_save_empty_fields(self, client: TestClient) -> None:
+        """测试空字段"""
+
+        test_cases = [
+            {
+                "index": "",
+                "key": "test",
+                "doc_json": '{"role": "后端", "level": "高级", "content": "测试内容"}',
+            },  # 空索引
+            {
+                "index": self.TEST_INDEX,
+                "key": "",
+                "doc_json": '{"role": "后端", "level": "高级", "content": "测试内容"}',
+            },  # 空key
+        ]
+
+        for case in test_cases:
+            response = client.post("/api/v1/documents/save", json=case)
+            assert response.status_code == 422  # 字段验证错误
+            error_data = response.json()
+            assert "detail" in error_data
+
+    def test_save_missing_fields(self, client: TestClient) -> None:
+        """测试缺失字段"""
+
+        test_cases = [
+            {
+                "key": "test",
+                "doc_json": '{"role": "后端", "level": "高级", "content": "测试内容"}',
+            },  # 缺失index
+            {
+                "index": self.TEST_INDEX,
+                "doc_json": '{"role": "后端", "level": "高级", "content": "测试内容"}',
+            },  # 缺失key
+            {"index": self.TEST_INDEX, "key": "test"},  # 缺失doc_json
+        ]
+
+        for case in test_cases:
+            response = client.post("/api/v1/documents/save", json=case)
+            assert response.status_code == 422  # 缺失字段错误
+            error_data = response.json()
+            assert "detail" in error_data
+
+    def test_save_invalid_index_name(self, client: TestClient) -> None:
+        """测试无效索引名(模拟ES错误)"""
+
+        response = client.post(
+            "/api/v1/documents/save",
+            json={
+                "index": "INVALID_INDEX_NAME_WITH_UPPERCASE",  # ES不允许大写索引名
+                "key": "test",
+                "doc_json": '{"role": "后端", "level": "高级", "content": "测试ES错误处理"}',
+            },
+        )
+
+        # 这应该触发service层的异常,返回500错误
+        assert response.status_code == 500
+        error_data = response.json()
+        assert "文档存储失败" in error_data["detail"]
diff --git a/tests/web/document/structured_search_test.py b/tests/web/document/structured_search_test.py
new file mode 100644
index 0000000..925237c
--- /dev/null
+++ b/tests/web/document/structured_search_test.py
@@ -0,0 +1,545 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from collections.abc import Generator
+from typing import Any
+
+import pytest
+from elasticsearch import Elasticsearch
+from fastapi.testclient import TestClient
+
+
+class TestStructuredSearch:
+    """结构化搜索测试
+
+    包含:
+    1. 结构化搜索功能测试
+    2. 通用参数验证测试(复用测试环境,后续可能调整)
+    """
+
+    TEST_INDEX = "test_search_structured"
+
+    @pytest.fixture(scope="class", autouse=True)
+    def setup_environment(
+        self, client: TestClient, es_client: Elasticsearch
+    ) -> Generator[None, Any, None]:
+        """准备测试环境(索引+数据)"""
+
+        # 1. 
清理已存在的索引 + if es_client.indices.exists(index=self.TEST_INDEX): + es_client.indices.delete(index=self.TEST_INDEX) + + # 2. 创建结构化搜索测试索引 + es_client.indices.create( + index=self.TEST_INDEX, + body={ + "mappings": { + "properties": { + "role": {"type": "keyword"}, # 精确匹配 + "level": {"type": "keyword"}, # 精确匹配 + "content": {"type": "text"}, # 全文搜索 + "department": {"type": "keyword"}, # 精确匹配 + "status": {"type": "keyword"}, # 精确匹配 + "tags": {"type": "keyword"}, # 精确匹配 + "salary": {"type": "integer"}, # 数值类型 + } + } + }, + ) + + # 3. 准备测试数据 + self._prepare_test_data(client, es_client) + + # 4. 执行所有测试 + yield + + # 5. 清理测试索引 + if es_client.indices.exists(index=self.TEST_INDEX): + es_client.indices.delete(index=self.TEST_INDEX) + + def _prepare_test_data( + self, client: TestClient, es_client: Elasticsearch + ) -> None: + """准备结构化搜索测试数据(在fixture中调用,只执行一次)""" + + test_documents = [ + { + "id": "backend_senior_1", + "data": { + "role": "后端", + "level": "高级", + "content": "MySQL 性能优化是后端开发的重要技能,包括索引优化、查询优化等", + "department": "技术部", + "status": "active", + "tags": "database", + "salary": 25000, + }, + }, + { + "id": "backend_junior_1", + "data": { + "role": "后端", + "level": "初级", + "content": "Python 基础语法学习,包括变量、函数、类等概念", + "department": "技术部", + "status": "active", + "tags": "programming", + "salary": 15000, + }, + }, + { + "id": "frontend_senior_1", + "data": { + "role": "前端", + "level": "高级", + "content": "React 组件设计模式,状态管理最佳实践", + "department": "技术部", + "status": "inactive", + "tags": "frontend", + "salary": 22000, + }, + }, + { + "id": "backend_senior_2", + "data": { + "role": "后端", + "level": "高级", + "content": "分布式系统设计与微服务架构实践指南", + "department": "技术部", + "status": "active", + "tags": "architecture", + "salary": 30000, + }, + }, + { + "id": "qa_middle_1", + "data": { + "role": "测试", + "level": "中级", + "content": "自动化测试框架设计,性能测试和优化经验", + "department": "质量部", + "status": "active", + "tags": "testing", + "salary": 18000, + }, + }, + ] + + # 使用save接口插入所有测试数据 + for doc in test_documents: + response = client.post( + "/api/v1/documents/save", + json={ + "index": self.TEST_INDEX, + "key": doc["id"], + "doc_json": json.dumps(doc["data"]), + }, + ) + assert response.status_code == 200, ( + f"保存文档失败: {response.json()}" + ) + + # 刷新索引确保数据可搜索 + es_client.indices.refresh(index=self.TEST_INDEX) + + # ===== 结构化搜索功能测试 ===== + + def test_single_exact_match(self, client: TestClient) -> None: + """测试单个精确匹配 - 按角色搜索后端开发者""" + response = client.post( + "/api/v1/search", + json={ + "type": "structured", + "query": { + "index": self.TEST_INDEX, + "conditions": [ + {"field": "role", "op": "term", "value": "后端"} + ], + }, + "top_k": 5, + }, + ) + + assert response.status_code == 200 + data = response.json() + + # 验证响应结构和数据 + assert data["type"] == "structured" + assert len(data["results"]) == 3 # 3个后端开发者 + + # 验证具体结果 + backend_ids = { + "backend_senior_1", + "backend_junior_1", + "backend_senior_2", + } + actual_ids = {result["id"] for result in data["results"]} + assert actual_ids == backend_ids + + # 验证结果格式 + for result in data["results"]: + assert "id" in result + assert "document" in result + assert "score" in result + assert result["document"]["role"] == "后端" + + def test_multiple_exact_match(self, client: TestClient) -> None: + """测试多个精确匹配 - 角色+级别(AND关系)""" + response = client.post( + "/api/v1/search", + json={ + "type": "structured", + "query": { + "index": self.TEST_INDEX, + "conditions": [ + {"field": "role", "op": "term", "value": "后端"}, + {"field": "level", "op": "term", "value": "高级"}, + ], + }, + "top_k": 3, + 
}, + ) + + assert response.status_code == 200 + data = response.json() + + assert data["type"] == "structured" + assert len(data["results"]) == 2 # 2个高级后端开发者 + + # 验证具体结果 + expected_ids = {"backend_senior_1", "backend_senior_2"} + actual_ids = {result["id"] for result in data["results"]} + assert actual_ids == expected_ids + + for result in data["results"]: + assert result["document"]["role"] == "后端" + assert result["document"]["level"] == "高级" + + def test_single_full_text_match(self, client: TestClient) -> None: + """测试单个全文搜索 - 内容匹配""" + response = client.post( + "/api/v1/search", + json={ + "type": "structured", + "query": { + "index": self.TEST_INDEX, + "conditions": [ + { + "field": "content", + "op": "match", + "value": "MySQL 性能优化", + } + ], + }, + "top_k": 3, + }, + ) + + assert response.status_code == 200 + data = response.json() + + assert data["type"] == "structured" + assert len(data["results"]) >= 1 + + # 验证结果包含相关内容 + found_mysql_doc = False + for result in data["results"]: + if "MySQL" in result["document"]["content"]: + found_mysql_doc = True + break + assert found_mysql_doc, "应该找到包含MySQL的文档" + + def test_multiple_full_text_match(self, client: TestClient) -> None: + """测试多个全文搜索 - 系统+架构(AND关系)""" + response = client.post( + "/api/v1/search", + json={ + "type": "structured", + "query": { + "index": self.TEST_INDEX, + "conditions": [ + {"field": "content", "op": "match", "value": "系统"}, + {"field": "content", "op": "match", "value": "架构"}, + ], + }, + "top_k": 3, + }, + ) + + assert response.status_code == 200 + data = response.json() + + assert data["type"] == "structured" + assert len(data["results"]) == 1 + assert data["results"][0]["id"] == "backend_senior_2" + + # 验证AND关系:必须同时包含两个关键词 + for result in data["results"]: + content = result["document"]["content"] + assert "系统" in content and "架构" in content, ( + f"文档内容应同时包含'系统'和'架构': {content}" + ) + + def test_mixed_exact_and_full_text_match(self, client: TestClient) -> None: + """测试混合搜索 - 精确匹配+全文搜索""" + response = client.post( + "/api/v1/search", + json={ + "type": "structured", + "query": { + "index": self.TEST_INDEX, + "conditions": [ + {"field": "role", "op": "term", "value": "后端"}, + {"field": "status", "op": "term", "value": "active"}, + {"field": "content", "op": "match", "value": "优化"}, + ], + }, + "top_k": 3, + }, + ) + + assert response.status_code == 200 + data = response.json() + + assert data["type"] == "structured" + assert len(data["results"]) == 1 + assert data["results"][0]["id"] == "backend_senior_1" + + for result in data["results"]: + doc = result["document"] + assert doc["role"] == "后端" + assert doc["status"] == "active" + assert "优化" in doc["content"] + + def test_complex_mixed_conditions(self, client: TestClient) -> None: + """测试复杂混合条件 - 部门+级别+内容""" + response = client.post( + "/api/v1/search", + json={ + "type": "structured", + "query": { + "index": self.TEST_INDEX, + "conditions": [ + { + "field": "department", + "op": "term", + "value": "技术部", + }, # 精确匹配 + { + "field": "level", + "op": "term", + "value": "高级", + }, # 精确匹配 + { + "field": "content", + "op": "match", + "value": "设计", + }, # 全文匹配 + ], + }, + "top_k": 5, + }, + ) + + assert response.status_code == 200 + data = response.json() + + assert data["type"] == "structured" + assert len(data["results"]) >= 1 + + for result in data["results"]: + doc = result["document"] + assert doc["department"] == "技术部" + assert doc["level"] == "高级" + assert "设计" in doc["content"] + + def test_with_range_filters(self, client: TestClient) -> None: + """测试范围过滤 - 
薪资范围"""
+        response = client.post(
+            "/api/v1/search",
+            json={
+                "type": "structured",
+                "query": {
+                    "index": self.TEST_INDEX,
+                    "conditions": [
+                        {"field": "role", "op": "term", "value": "后端"}
+                    ],
+                    "filters": {
+                        "range": {
+                            "salary": {"gte": 20000}  # 薪资>=20000的过滤条件
+                        }
+                    },
+                },
+                "top_k": 5,
+            },
+        )
+
+        assert response.status_code == 200
+        data = response.json()
+
+        assert data["type"] == "structured"
+        assert len(data["results"]) >= 1
+
+        # 验证具体的期望结果
+        expected_ids = {"backend_senior_1", "backend_senior_2"}
+        actual_ids = {result["id"] for result in data["results"]}
+        assert actual_ids == expected_ids, (
+            f"期望 {expected_ids},实际 {actual_ids}"
+        )
+
+        # 验证过滤条件生效
+        for result in data["results"]:
+            assert result["document"]["role"] == "后端"
+            assert result["document"]["salary"] >= 20000
+
+    def test_top_k_limit(self, client: TestClient) -> None:
+        """测试结果数量限制"""
+        response = client.post(
+            "/api/v1/search",
+            json={
+                "type": "structured",
+                "query": {
+                    "index": self.TEST_INDEX,
+                    "conditions": [
+                        {"field": "role", "op": "term", "value": "后端"}
+                    ],
+                },
+                "top_k": 2,
+            },
+        )
+
+        assert response.status_code == 200
+        data = response.json()
+
+        assert data["type"] == "structured"
+        # 共有3个后端、top_k=2,应该正好返回2个结果
+        assert len(data["results"]) == 2
+
+        # 验证返回的都是后端
+        for result in data["results"]:
+            assert result["document"]["role"] == "后端"
+
+    def test_no_results_found(self, client: TestClient) -> None:
+        """测试无匹配结果"""
+        response = client.post(
+            "/api/v1/search",
+            json={
+                "type": "structured",
+                "query": {
+                    "index": self.TEST_INDEX,
+                    "conditions": [
+                        {"field": "role", "op": "term", "value": "不存在的角色"}
+                    ],
+                },
+                "top_k": 3,
+            },
+        )
+
+        assert response.status_code == 200
+        data = response.json()
+
+        assert data["type"] == "structured"
+        assert len(data["results"]) == 0
+
+    # ===== 参数验证测试 =====
+    # 注:这些测试复用结构化搜索的测试环境
+    # 如果向量混合搜索需要不同的校验规则,在vector_hybrid_search_test.py中单独添加
+
+    def test_invalid_search_type(self, client: TestClient) -> None:
+        """测试无效搜索类型"""
+        response = client.post(
+            "/api/v1/search",
+            json={
+                "type": "invalid_type",
+                "query": {
+                    "index": self.TEST_INDEX,
+                    "conditions": [
+                        {"field": "role", "op": "term", "value": "后端"}
+                    ],
+                },
+                "top_k": 3,
+            },
+        )
+
+        assert response.status_code == 422
+
+    def test_invalid_operator(self, client: TestClient) -> None:
+        """测试无效操作符"""
+        response = client.post(
+            "/api/v1/search",
+            json={
+                "type": "structured",
+                "query": {
+                    "index": self.TEST_INDEX,
+                    "conditions": [
+                        {"field": "role", "op": "invalid_op", "value": "后端"}
+                    ],
+                },
+                "top_k": 3,
+            },
+        )
+
+        assert response.status_code == 422
+
+    def test_missing_required_fields(self, client: TestClient) -> None:
+        """测试缺少必需字段"""
+        response = client.post(
+            "/api/v1/search",
+            json={
+                "query": {
+                    "index": self.TEST_INDEX,
+                    "conditions": [
+                        {"field": "role", "op": "term", "value": "后端"}
+                    ],
+                },
+                "top_k": 3,
+            },
+        )
+
+        assert response.status_code == 422
+
+    def test_empty_conditions(self, client: TestClient) -> None:
+        """测试空条件列表"""
+        response = client.post(
+            "/api/v1/search",
+            json={
+                "type": "structured",
+                "query": {
+                    "index": self.TEST_INDEX,
+                    "conditions": [],
+                },
+                "top_k": 3,
+            },
+        )
+
+        assert response.status_code == 422
+
+    def test_nonexistent_index(self, client: TestClient) -> None:
+        """测试不存在的索引"""
+        response = client.post(
+            "/api/v1/search",
+            json={
+                "type": "structured",
+                "query": {
+                    "index": "不存在的索引",
+                    "conditions": [
+                        {"field": "role", "op": "term", "value": "后端"}
+                    ],
+                },
+                "top_k": 3,
+            },
+        )
+
+        assert response.status_code >= 400
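
审阅补充(示意代码):上面的结构化搜索测试反复验证同一语义:多个 conditions 之间是 AND 关系,term 作用于 keyword 字段、match 作用于 text 字段,range 过滤附加在其上。下面用一个极简草图演示这组参数大致如何映射为 ES 的 bool 查询;函数名 build_structured_query 与内部组织方式均为假设,并非服务端的真实实现,仅用于说明测试所覆盖的查询语义。

from typing import Any


def build_structured_query(
    conditions: list[dict[str, Any]],
    filters: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """把结构化条件映射为 ES bool 查询的示意实现(AND 语义)"""
    must: list[dict[str, Any]] = []
    for cond in conditions:
        if cond["op"] == "term":
            # keyword 字段精确匹配,如 {"field": "role", "op": "term", "value": "后端"}
            must.append({"term": {cond["field"]: cond["value"]}})
        elif cond["op"] == "match":
            # text 字段全文匹配;多个 match 同时出现时仍是 AND 关系
            must.append({"match": {cond["field"]: cond["value"]}})
        else:
            raise ValueError(f"不支持的操作符: {cond['op']}")

    query: dict[str, Any] = {"bool": {"must": must}}
    if filters and "range" in filters:
        # 例如 {"range": {"salary": {"gte": 20000}}},作为不参与评分的 filter 子句
        query["bool"]["filter"] = [{"range": filters["range"]}]
    return query
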
diff --git a/tests/web/document/upload_endpoint_test.py b/tests/web/document/upload_endpoint_test.py
new file mode 100644
index 0000000..1815899
--- /dev/null
+++ b/tests/web/document/upload_endpoint_test.py
@@ -0,0 +1,439 @@
+# Copyright 2021 ecodeclub
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from collections.abc import Callable
+from pathlib import Path
+from typing import overload
+
+import pytest
+from elasticsearch import Elasticsearch
+from fastapi.testclient import TestClient
+
+from app.config.settings import settings
+
+
+class TestUploadEndpoint:
+    """
+    上传接口端到端测试类,测试路径为 /api/v1/documents/。
+    使用不同的index_prefix实现数据隔离,用ES Client直接验证ES存储结果。
+    """
+
+    # 数据隔离:两个接口使用不同的索引前缀
+    FILE_UPLOAD_INDEX_PREFIX = "test_file_upload"
+    URL_UPLOAD_INDEX_PREFIX = "test_url_upload"
+
+    @pytest.fixture(scope="class")
+    def user_upload_dir(self) -> Path:
+        """提供 '用户准备上传' 的文件目录路径。"""
+        path = (
+            Path(__file__).parent.parent.parent / "fixtures" / "files" / "user"
+        )
+        path.mkdir(exist_ok=True, parents=True)
+        return path
+
+    @pytest.fixture(scope="class")
+    def get_user_upload_file(
+        self, user_upload_dir: Path
+    ) -> Callable[[str | list[str]], Path | list[Path]]:
+        """从 'user' 目录轻松获取文件路径。"""
+
+        @overload
+        def _builder(file_names: str) -> Path: ...
+
+        @overload  # noqa: F811
+        def _builder(file_names: list[str]) -> list[Path]: ...
+ + def _builder(file_names: str | list[str]) -> Path | list[Path]: # noqa: F811 + if isinstance(file_names, str): + path = user_upload_dir / file_names + if not path.exists(): + raise FileNotFoundError(f"源文件不存在: {path}") + return path + + paths = [user_upload_dir / name for name in file_names] + if not all(p.exists() for p in paths): + raise FileNotFoundError("一个或多个源文件不存在。") + return paths + + return _builder # type: ignore[return-value] + + @staticmethod + def _get_metadata_index_name(index_prefix: str) -> str: + """获取metadata索引名""" + return index_prefix + settings.elasticsearch.metadata_index_suffix + + @staticmethod + def _get_chunk_index_name(index_prefix: str) -> str: + """获取chunk索引名""" + return index_prefix + settings.elasticsearch.chunk_index_suffix + + def _verify_es_data_exists( + self, + es_client: Elasticsearch, + index_prefix: str, + expected_filename: str, + ) -> tuple[int, int]: + """ + 验证ES中的数据存在 + + Args: + es_client: ES客户端 + index_prefix: 索引前缀 + expected_filename: 期望的文件名 + + Returns: + tuple: (metadata_count, chunk_count) + """ + metadata_index = self._get_metadata_index_name(index_prefix) + chunk_index = self._get_chunk_index_name(index_prefix) + + # 刷新索引确保数据可见 + es_client.indices.refresh(index=[metadata_index, chunk_index]) + + # 检查metadata索引 + metadata_response = es_client.search( + index=metadata_index, + body={"query": {"match": {"name": expected_filename}}, "size": 10}, + ) + metadata_count = len(metadata_response["hits"]["hits"]) + + # 检查chunk索引 + chunk_response = es_client.search( + index=chunk_index, body={"query": {"match_all": {}}, "size": 100} + ) + chunk_count = len(chunk_response["hits"]["hits"]) + + return metadata_count, chunk_count + + def _cleanup_es_indexes( + self, es_client: Elasticsearch, index_prefix: str + ) -> None: + """ + 清理测试索引 + + Args: + es_client: ES客户端 + index_prefix: 索引前缀 + """ + metadata_index = self._get_metadata_index_name(index_prefix) + chunk_index = self._get_chunk_index_name(index_prefix) + + try: + if es_client.indices.exists(index=metadata_index): + es_client.indices.delete(index=metadata_index) + if es_client.indices.exists(index=chunk_index): + es_client.indices.delete(index=chunk_index) + print(f"✅ 已清理索引: {metadata_index}, {chunk_index}") + except Exception as e: + print(f"⚠️ 清理索引时出错: {e}") + + def _wait_for_es_data( + self, + es_client: Elasticsearch, + index_prefix: str, + expected_filename: str, + max_wait_time: int = 15, + wait_interval: int = 2, + ) -> tuple[int, int]: + """ + 等待ES数据就绪 + + Args: + es_client: ES客户端 + index_prefix: 索引前缀 + expected_filename: 期望的文件名 + max_wait_time: 最大等待时间(秒) + wait_interval: 检查间隔(秒) + + Returns: + tuple: (metadata_count, chunk_count) + """ + print(f"⏳ 等待数据处理完成... 
(最多等待{max_wait_time}秒)")
+
+        for attempt in range(max_wait_time // wait_interval):
+            time.sleep(wait_interval)
+
+            try:
+                metadata_count, chunk_count = self._verify_es_data_exists(
+                    es_client, index_prefix, expected_filename
+                )
+
+                if metadata_count > 0 and chunk_count > 0:
+                    print("✅ ES验证成功!")
+                    print(f"   📝 Metadata记录: {metadata_count}")
+                    print(f"   📄 Chunk记录: {chunk_count}")
+                    return metadata_count, chunk_count
+
+            except Exception as e:
+                print(f"⏳ 第{attempt + 1}次检查: 数据还未就绪 ({e})")
+
+        # 最后一次验证
+        metadata_count, chunk_count = self._verify_es_data_exists(
+            es_client, index_prefix, expected_filename
+        )
+
+        return metadata_count, chunk_count
+
+    def test_upload_file(
+        self,
+        client: TestClient,
+        es_client: Elasticsearch,
+        get_user_upload_file: Callable[[str], Path],
+    ) -> None:
+        """测试文件上传功能并验证ES存储结果"""
+
+        test_file_name = "03_test.pdf"
+        index_prefix = self.FILE_UPLOAD_INDEX_PREFIX
+
+        print(f"\n📂 测试文件上传: {test_file_name}")
+        print(f"📋 索引前缀: {index_prefix}")
+
+        try:
+            test_file_path = get_user_upload_file(test_file_name)
+            if test_file_path.stat().st_size == 0:
+                pytest.skip(f"请确保 {test_file_name} 是一个非空pdf文件\n")
+        except FileNotFoundError:
+            pytest.skip(
+                f"测试文件 {test_file_name} 不存在,请放入 tests/fixtures/files/user/ 目录"
+            )
+
+        # 清理可能存在的旧数据
+        self._cleanup_es_indexes(es_client, index_prefix)
+
+        try:
+            # 步骤1: 上传文件
+            test_file = get_user_upload_file(test_file_name)
+
+            with test_file.open("rb") as f:
+                response = client.post(
+                    "/api/v1/documents/upload-file",
+                    files={"file": (test_file.name, f, "application/pdf")},
+                    data={
+                        "index_prefix": index_prefix,
+                        "category": "test_document",
+                        "tags": "pdf,test",
+                    },
+                )
+
+            # 验证上传响应
+            assert response.status_code == 200
+            upload_result = response.json()
+            assert "task_id" in upload_result
+            assert "message" in upload_result
+            print(f"✅ 文件上传成功,任务ID: {upload_result['task_id']}")
+
+            # 步骤2: 等待异步处理完成并验证ES数据
+            metadata_count, chunk_count = self._wait_for_es_data(
+                es_client, index_prefix, test_file_name, max_wait_time=15
+            )
+
+            # 断言数据存在
+            assert metadata_count > 0, (
+                f"未找到metadata记录,索引: {self._get_metadata_index_name(index_prefix)}"
+            )
+            assert chunk_count > 0, (
+                f"未找到chunk记录,索引: {self._get_chunk_index_name(index_prefix)}"
+            )
+
+            self._assert_task_endpoint(client, upload_result["task_id"])
+            print("🎉 文件上传测试完成!")
+
+        finally:
+            # 清理测试数据
+            self._cleanup_es_indexes(es_client, index_prefix)
+
+    @staticmethod
+    def _assert_task_endpoint(client: TestClient, task_id: str) -> None:
+        """查询任务状态接口并校验基本响应结构"""
+        # 注意:这里必须用 f-string 拼接真实的 task_id,否则请求的是字面量路径
+        resp = client.get(f"/api/v1/tasks/{task_id}")
+        assert resp.status_code == 200
+        task_result = resp.json()
+        assert "task_id" in task_result
+        assert "status" in task_result
+
+    def test_upload_file_missing_index_prefix(self, client: TestClient) -> None:
+        """测试文件上传缺少index_prefix参数的验证"""
+
+        print("\n🛡️ 测试文件上传缺少index_prefix参数...")
+
+        response = client.post(
+            "/api/v1/documents/upload-file",
+            files={"file": ("test.txt", "test content", "text/plain")},
+            data={"category": "test"},  # 缺少index_prefix
+        )
+
+        # 验证返回:422错误(参数验证失败)
+        assert response.status_code == 422
+        error_detail = response.json()
+        assert "detail" in error_detail
+
+        print("✅ 文件上传缺少index_prefix参数验证通过!")
+
+    def test_upload_file_empty_index_prefix(self, client: TestClient) -> None:
+        """测试文件上传空index_prefix参数的验证"""
+
+        print("\n🛡️ 测试文件上传空index_prefix参数...")
+
+        response = client.post(
+            "/api/v1/documents/upload-file",
+            files={"file": ("test.txt", "test content", "text/plain")},
+            data={"index_prefix": ""},  # 空字符串
+        )
+
+        # 验证返回422错误(参数验证失败)
+        assert response.status_code == 422
+        error_detail = response.json()
+        assert "detail" in error_detail
+
+        print("✅ 文件上传空index_prefix参数验证通过!")
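+
+    # 审阅补充(示意代码):_assert_task_endpoint 目前只校验响应结构;若需要
+    # 等待任务真正结束,可以用类似下面的轮询写法。其中 status 的取值
+    # "pending"/"completed"/"failed" 属于假设,并非本服务的权威约定。
+    @staticmethod
+    def _poll_task_until_done_example(
+        client: TestClient, task_id: str, timeout: int = 15
+    ) -> str:
+        """轮询任务状态直到终态或超时(示意实现,终态取值为假设)"""
+        deadline = time.time() + timeout
+        status = "pending"
+        while time.time() < deadline:
+            resp = client.get(f"/api/v1/tasks/{task_id}")
+            assert resp.status_code == 200
+            status = resp.json().get("status", "pending")
+            if status in ("completed", "failed"):  # 假设的终态集合
+                break
+            time.sleep(1)
+        return status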
"detail" in error_detail + + print("✅ 文件上传空index_prefix参数验证通过!") + + def test_upload_file_invalid_file(self, client: TestClient) -> None: + """测试文件上传无效文件的处理""" + + print("\n🛡️ 测试文件上传无效文件...") + + # 发送空文件 + response = client.post( + "/api/v1/documents/upload-file", + files={"file": ("empty.txt", "", "text/plain")}, + data={"index_prefix": "test_invalid"}, + ) + + # 验证返回400(立即拒绝) + assert response.status_code == 400 + error_detail = response.json() + assert "detail" in error_detail + + print("✅ 无效文件上传处理验证通过!") + + def test_upload_from_url( + self, + client: TestClient, + es_client: Elasticsearch, + ) -> None: + """测试URL上传功能并验证ES存储结果""" + + index_prefix = self.URL_UPLOAD_INDEX_PREFIX + object_key = "kbase-temp/02_test.pdf" + bucket_name = settings.tencent_oss.bucket + cos_url = f"https://{bucket_name}.cos.{settings.tencent_oss.region}.myqcloud.com/{object_key}" + expected_filename = object_key.split("/")[-1] # 从URL提取文件名 + + print(f"\n🔗 测试URL上传: {cos_url}") + print(f"📋 索引前缀: {index_prefix}") + + # 清理可能存在的旧数据 + self._cleanup_es_indexes(es_client, index_prefix) + + try: + # 步骤1: 通过URL上传 + response = client.post( + "/api/v1/documents/upload-from-url", + json={ + "url": cos_url, + "index_prefix": index_prefix, + "category": "test_url_document", + "tags": "pdf,url_test", + }, + ) + + # 验证上传响应 + assert response.status_code == 200 + upload_result = response.json() + assert "task_id" in upload_result + assert "message" in upload_result + print(f"✅ URL上传成功,任务ID: {upload_result['task_id']}") + + # 步骤2: 等待异步处理完成并验证ES数据 (URL下载需要更长时间) + metadata_count, chunk_count = self._wait_for_es_data( + es_client, + index_prefix, + expected_filename, + max_wait_time=20, + wait_interval=3, + ) + + # 断言数据存在 + assert metadata_count > 0, ( + f"未找到metadata记录,索引: {self._get_metadata_index_name(index_prefix)}" + ) + assert chunk_count > 0, ( + f"未找到chunk记录,索引: {self._get_chunk_index_name(index_prefix)}" + ) + + self._assert_task_endpoint(client) + + print("🎉 URL上传测试完成!") + + finally: + # 清理测试数据 + self._cleanup_es_indexes(es_client, index_prefix) + + def test_upload_from_url_missing_index_prefix( + self, client: TestClient + ) -> None: + """测试URL上传缺少index_prefix参数的验证""" + + print("\n🛡️ 测试URL上传缺少index_prefix参数...") + + response = client.post( + "/api/v1/documents/upload-from-url", + json={"url": "https://example.com/test.pdf"}, # 缺少index_prefix + ) + + # 验证返回:422错误(参数验证失败) + assert response.status_code == 422 + error_detail = response.json() + assert "detail" in error_detail + + print("✅ URL上传缺少index_prefix参数验证通过!") + + def test_upload_from_url_empty_index_prefix( + self, client: TestClient + ) -> None: + """测试URL上传空index_prefix参数的验证""" + + print("\n🛡️ 测试URL上传空index_prefix参数...") + + response = client.post( + "/api/v1/documents/upload-from-url", + json={ + "url": "https://example.com/test.pdf", + "index_prefix": "", # 空字符串 + }, + ) + + # 验证返回422错误(参数验证失败) + assert response.status_code == 422 + error_detail = response.json() + assert "detail" in error_detail + + print("✅ URL上传空index_prefix参数验证通过!") + + def test_upload_from_url_invalid_url(self, client: TestClient) -> None: + """测试URL上传无效URL的验证""" + + print("\n🛡️ 测试URL上传无效URL...") + + response = client.post( + "/api/v1/documents/upload-from-url", + json={"url": "not-a-valid-url", "index_prefix": "test_invalid"}, + ) + + # 验证返回422错误(URL格式验证失败) + assert response.status_code == 422 + error_detail = response.json() + assert "detail" in error_detail + + print("✅ 无效URL验证通过!")