diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 40cf60f3dd14..140eb420ce9e 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -7,6 +7,8 @@ assignees: '' --- +**Please state you issue using the following template and, most importantly, in English.** + **Describe the bug** A clear and concise description of what the bug is. diff --git a/.github/ISSUE_TEMPLATE/documentation-request.md b/.github/ISSUE_TEMPLATE/documentation-request.md index 133fb9e1e9b3..503ecd3e64d7 100644 --- a/.github/ISSUE_TEMPLATE/documentation-request.md +++ b/.github/ISSUE_TEMPLATE/documentation-request.md @@ -7,7 +7,7 @@ assignees: '' --- -## Report incorrect documentation +**Please state you issue using the following template and, most importantly, in English.** **Location of incorrect documentation** Provide links and line numbers if applicable. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 01bceb3321d4..33c748bebcef 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -7,6 +7,8 @@ assignees: '' --- +**Please state you issue using the following template and, most importantly, in English.** + **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. E.g. I wish I could use Milvus to do [...] diff --git a/.github/ISSUE_TEMPLATE/general-question.md b/.github/ISSUE_TEMPLATE/general-question.md index d49fce181742..44d288abf835 100644 --- a/.github/ISSUE_TEMPLATE/general-question.md +++ b/.github/ISSUE_TEMPLATE/general-question.md @@ -7,4 +7,6 @@ assignees: '' --- +**Please state you issue using the following template and, most importantly, in English.** + **What is your question?** diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 2976ba9824bd..2f29cafc346b 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,11 +1,10 @@ **What type of PR is this?** -api-change / bug / improvement / documentation / feature +API-change / BUG / Improvement / Documentation / Feature -**Need this PR be picked to master branch?** - -Yes / No +**Which branch you want to cherry-pick to?** +Not Available **Which issue(s) this PR fixes:** @@ -13,9 +12,12 @@ Fixes # **What this PR does / why we need it:** +Not Available **Special notes for your reviewer:** +Not Available **Additional documentation (e.g. design docs, usage docs, etc.):** +Not Available \ No newline at end of file diff --git a/.gitignore b/.gitignore index b6c4247481e8..efc7bc816f7f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ core/thirdparty/knowhere_build .idea/ .ycm_extra_conf.py +core/compile_commands.json +*/.clangd/* __pycache__ # vscode generated files diff --git a/CHANGELOG.md b/CHANGELOG.md index ff4e041d144a..4714b34dd2f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,14 @@ Please mark all changes in change log and use the issue from GitHub - \#2557 Fix random crash of INSERT_DUPLICATE_ID case - \#2578 Result count doesn't match target vectors count - \#2582 CreateHybridIndex.cpp compile error +- \#2585 Support IVF_PQ IP on GPU - \#2598 Fix Milvus docker image report illegal instruction - \#2617 Fix HNSW and RNSG index files size - \#2637 Suit the range of HNSW parameters +- \#2642 Create index failed and server crashed - \#2649 Search parameter of annoy has conflict with document +- \#2690 Remove body parser in show-partitions endpoints +- \#2692 Milvus hangs during multi-thread concurrent search - \#2693 Collection create success if no dimension value provided - \#2694 Collection create success if an invalid field name provided - \#2695 The number of fields should be limited @@ -25,21 +29,36 @@ Please mark all changes in change log and use the issue from GitHub - \#2731 No entity returned with `get_entity_by_id` - \#2732 Server destroyed after `delete by id` - \#2733 The max value of top-k should be limited +- \#2739 Fix mishards start failed +- \#2752 Milvus formats vectors data to double-precision and return to http client - \#2763 Unexpected error when insert binary entities - \#2765 Server crashed when calling `get_entity_by_id` +- \#2767 Fix a bug of getting wrong nprobe limitation in knowhere on GPU version +- \#2768 After building the index, the number of vectors increases +- \#2776 Fix too many data copies during creating IVF index - \#2783 Wrong result returned if searching with tags - \#2790 Distances returned by calling `search` is inaccurate +- \#2813 To implemente RNSG IP - \#2818 Wrong result returned by `get_entity_by_id` - \#2823 Server crashed during inserting, and can not restart - \#2845 Server crashed after calling `delete_entity_by_id` +- \#2852 Fix Prometheus rebuild problem. - \#2869 Create index failed with binary vectors +- \#2890 Fix wrong index size - \#2893 Insert binary data failed +- \#2952 Fix the result merging of IVF_PQ IP +- \#2957 There is no exisitence check of annoy search parameter ## Feature - \#2319 Redo metadata to support MVCC - \#2509 Count up query statistics for debug ease - \#2572 Support structured data index - \#2585 Support IVF_PQ on GPU with using metric_type IP +- \#2689 Construct Knowhere Index Without Data +- \#2798 hnsw-sq8 support +- \#2802 Add new index: IVFSQ8NR +- \#2834 Add C++ sdk support 4 hnsw_sq8nr +- \#2940 Add option to build.sh for cuda arch ## Improvement - \#2543 Remove secondary_path related code @@ -47,6 +66,10 @@ Please mark all changes in change log and use the issue from GitHub - \#2561 Clean util dependencies with other modules - \#2612 Move all APIs in utils into namespace milvus - \#2675 Print out system memory size when report invalid cpu cache size +- \#2828 Let Faiss not compile half float by default +- \#2841 Replace IndexType/EngineType/MetricType +- \#2858 Unify index name in db +- \#2884 Using BlockingQueue in JobMgr ## Task diff --git a/INSTALL.md b/INSTALL.md index 1e3d8b06a2d3..835b209dbdab 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -52,7 +52,7 @@ For GPU-enabled version, you will also need: -- CUDA 10.0 or higher +- CUDA 10.x (10.0, 10.1, 10.2) - NVIDIA driver 418 or higher diff --git a/README.md b/README.md index 418fcd903c58..6cd10cbde494 100644 --- a/README.md +++ b/README.md @@ -12,24 +12,24 @@ ## What is Milvus -As an open source vector similarity search engine, Milvus is easy-to-use, highly reliable, scalable, robust, and blazing fast. Adopted by over 100 organizations and institutions worldwide, Milvus empowers applications in a variety of fields, including image processing, computer vision, natural language processing, voice recognition, recommender systems, drug discovery, etc. +As an open source vector similarity search engine, Milvus is easy-to-use, highly reliable, scalable, robust, and blazing fast. Adopted by over 100 organizations and institutions worldwide, Milvus empowers applications in a variety of fields, including image processing, computer vision, natural language processing, voice recognition, recommender systems, drug discovery, and more. -Milvus has the following architecture: +The following is Milvus architecture: ![arch](https://github.com/milvus-io/docs/blob/v0.9.1/assets/milvus_arch.png) -For more detailed introduction of Milvus and its architecture, see [Milvus overview](https://www.milvus.io/docs/about_milvus/overview.md). Keep up-to-date with newest releases and latest updates by reading Milvus [release notes](https://www.milvus.io/docs/releases/release_notes.md). +For more detailed introduction of Milvus and its architecture, see [Milvus overview](https://www.milvus.io/docs/overview.md). See Milvus [release notes](https://www.milvus.io/docs/release_notes.md) to keep up-to-date with its releases and updates. -Milvus is an [LF AI Foundation](https://lfai.foundation/) incubation project. Learn more at [lfai.foundation](https://lfai.foundation/). +Milvus is an [LF AI Foundation](https://lfai.foundation/) incubation project. ## Get started ### Install Milvus -See the [Milvus install guide](https://www.milvus.io/docs/guides/get_started/install_milvus/install_milvus.md) to install Milvus using Docker. To install Milvus from source code, see [build from source](INSTALL.md). +See [Milvus install guide](https://www.milvus.io/docs/install_milvus.md) to install Milvus using Docker. To install Milvus from source code, see [build from source](INSTALL.md). ### Try example programs -Try an example program with Milvus using [Python](https://www.milvus.io/docs/guides/get_started/example_code.md), [Java](https://github.com/milvus-io/milvus-sdk-java/tree/master/examples), [Go](https://github.com/milvus-io/milvus-sdk-go/tree/master/examples), or [C++ example code](https://github.com/milvus-io/milvus/tree/master/sdk/examples). +Try an example program with Milvus using [Python](https://www.milvus.io/docs/example_code.md), [Java](https://github.com/milvus-io/milvus-sdk-java/tree/master/examples), [Go](https://github.com/milvus-io/milvus-sdk-go/tree/master/examples), or [C++ example code](https://github.com/milvus-io/milvus/tree/master/sdk/examples). ## Supported clients @@ -38,11 +38,11 @@ Try an example program with Milvus using [Python](https://www.milvus.io/docs/gui - [Java](https://github.com/milvus-io/milvus-sdk-java) - [C++](https://github.com/milvus-io/milvus/tree/master/sdk) - [RESTful API](https://github.com/milvus-io/milvus/tree/master/core/src/server/web_impl) -- [Node.js](https://www.npmjs.com/package/@arkie-ai/milvus-client) (Provided by [arkie](https://www.arkie.cn/)) +- [Node.js](https://www.npmjs.com/package/@arkie-ai/milvus-client) (Contributed by [arkie](https://www.arkie.cn/)) ## Application scenarios -You can use Milvus to build intelligent systems in a variety of AI application scenarios. Refer to [Milvus Scenarios](https://milvus.io/scenarios) for live demos. You can also refer to [Milvus Bootcamp](https://github.com/milvus-io/bootcamp) for detailed solutions and application scenarios. +You can use Milvus to build intelligent systems in a variety of AI application scenarios. See [Milvus Scenarios](https://milvus.io/scenarios) for live demos. You can also see [Milvus Bootcamp](https://github.com/milvus-io/bootcamp) for detailed solutions and application scenarios. ## Benchmark @@ -52,11 +52,11 @@ See our [test reports](https://github.com/milvus-io/milvus/tree/master/docs) for To learn what's coming up soon in Milvus, read our [Roadmap](https://github.com/milvus-io/milvus/milestones). -It is a Work in Progress, and is subject to reasonable adjustments when necessary. And we greatly welcome any comments/requirements/suggestions regarding Milvus roadmap.:clap: +It is a Work in Progress, and is subject to reasonable adjustments when necessary. And we greatly appreciate any comments/requirements/suggestions regarding Milvus' roadmap.:clap: ## Contribution guidelines -Contributions are welcomed and greatly appreciated. Please read our [contribution guidelines](CONTRIBUTING.md) for detailed contribution workflow. This project adheres to the [code of conduct](CODE_OF_CONDUCT.md) of Milvus. By participating, you are expected to uphold this code. +Contributions are welcomed and greatly appreciated. Please read our [contribution guidelines](CONTRIBUTING.md) for detailed contribution workflow. This project adheres to the [code of conduct](CODE_OF_CONDUCT.md) of Milvus. You must abide by this code in order to participate. We use [GitHub issues](https://github.com/milvus-io/milvus/issues) to track issues and bugs. For general questions and public discussions, please join our community. @@ -64,7 +64,7 @@ We use [GitHub issues](https://github.com/milvus-io/milvus/issues) to track issu :heart:To connect with other users and contributors, welcome to join our [Slack channel](https://join.slack.com/t/milvusio/shared_invite/zt-e0u4qu3k-bI2GDNys3ZqX1YCJ9OM~GQ). -See our [community](https://github.com/milvus-io/community) repository to learn about our governance and access more community resources. +See our [community](https://github.com/milvus-io/community) repository to learn more about our governance and access more community resources. ## Resources diff --git a/README_CN.md b/README_CN.md index 933c55ffc311..38dcb3365e7e 100644 --- a/README_CN.md +++ b/README_CN.md @@ -23,7 +23,7 @@ Milvus 的架构如下: ![arch](https://github.com/milvus-io/docs/blob/v0.9.1/assets/milvus_arch.png) -若要了解 Milvus 详细介绍和整体架构,请访问 [Milvus 简介](https://www.milvus.io/cn/docs/about_milvus/overview.md)。您可以通过 [版本发布说明](https://www.milvus.io/cn/docs/releases/release_notes.md) 获取最新版本的功能和更新。 +若要了解 Milvus 详细介绍和整体架构,请访问 [Milvus 简介](https://www.milvus.io/docs/overview.md)。你可以通过 [版本发布说明](https://www.milvus.io/docs/release_notes.md) 获取最新版本的功能和更新。 Milvus是一个[LF AI基金会](https://lfai.foundation/)的孵化项目。获取更多,请访问[lfai.foundation](https://lfai.foundation/)。 @@ -31,11 +31,11 @@ Milvus是一个[LF AI基金会](https://lfai.foundation/)的孵化项目。获 ### 安装 Milvus -请参阅 [Milvus 安装指南](https://www.milvus.io/cn/docs/guides/get_started/install_milvus/install_milvus.md) 使用 Docker 容器安装 Milvus。若要基于源码编译,请访问 [源码安装](INSTALL.md)。 +请参阅 [Milvus 安装指南](https://www.milvus.io/docs/install_milvus.md) 使用 Docker 容器安装 Milvus。若要基于源码编译,请访问 [源码安装](INSTALL.md)。 ### 尝试示例代码 -您可以尝试用 [Python](https://www.milvus.io/cn/docs/guides/get_started/example_code.md),[Java](https://github.com/milvus-io/milvus-sdk-java/tree/master/examples),[Go](https://github.com/milvus-io/milvus-sdk-go/tree/master/examples),或者 [C++](https://github.com/milvus-io/milvus/tree/master/sdk/examples) 运行 Milvus 示例代码。 +你可以尝试用 [Python](https://www.milvus.io/docs/example_code.md),[Java](https://github.com/milvus-io/milvus-sdk-java/tree/master/examples),[Go](https://github.com/milvus-io/milvus-sdk-go/tree/master/examples),或者 [C++](https://github.com/milvus-io/milvus/tree/master/sdk/examples) 运行 Milvus 示例代码。 ## 支持的客户端 @@ -48,7 +48,7 @@ Milvus是一个[LF AI基金会](https://lfai.foundation/)的孵化项目。获 ## 应用场景 -Milvus 可以应用于多种 AI 场景。您可以访问 [Milvus 应用场景](https://milvus.io/scenarios) 体验在线场景展示。您也可以访问 [Milvus 训练营](https://github.com/milvus-io/bootcamp) 了解更详细的应用场景和解决方案。 +Milvus 可以应用于多种 AI 场景。你可以访问 [Milvus 应用场景](https://milvus.io/scenarios) 体验在线场景展示。你也可以访问 [Milvus 训练营](https://github.com/milvus-io/bootcamp) 了解更详细的应用场景和解决方案。 ## 性能基准测试 @@ -56,15 +56,15 @@ Milvus 可以应用于多种 AI 场景。您可以访问 [Milvus 应用场景](h ## 路线图 -您可以参考我们的[路线图](https://github.com/milvus-io/milvus/milestones),了解 Milvus 即将实现的新特性。 +你可以参考我们的[路线图](https://github.com/milvus-io/milvus/milestones),了解 Milvus 即将实现的新特性。 路线图尚未完成,并且可能会存在合理改动。我们欢迎各种针对路线图的意见、需求和建议。 ## 贡献者指南 -我们由衷欢迎您推送贡献。关于贡献流程的详细信息,请参阅[贡献者指南](https://github.com/milvus-io/milvus/blob/master/CONTRIBUTING.md)。本项目遵循 Milvus [行为准则](https://github.com/milvus-io/milvus/blob/master/CODE_OF_CONDUCT.md)。如果您希望参与本项目,请遵守该准则的内容。 +我们由衷欢迎你推送贡献。关于贡献流程的详细信息,请参阅[贡献者指南](https://github.com/milvus-io/milvus/blob/master/CONTRIBUTING.md)。本项目遵循 Milvus [行为准则](https://github.com/milvus-io/milvus/blob/master/CODE_OF_CONDUCT.md)。如果你希望参与本项目,请遵守该准则的内容。 -我们使用 [GitHub issues](https://github.com/milvus-io/milvus/issues) 追踪问题和补丁。若您希望提出问题或进行讨论,请加入我们的社区。 +我们使用 [GitHub issues](https://github.com/milvus-io/milvus/issues) 追踪问题和补丁。若你希望提出问题或进行讨论,请加入我们的社区。 ## 加入 Milvus 社区 diff --git a/ci/jenkins/Jenkinsfile b/ci/jenkins/Jenkinsfile index 3464d10262cd..d793c1eb7577 100644 --- a/ci/jenkins/Jenkinsfile +++ b/ci/jenkins/Jenkinsfile @@ -31,7 +31,7 @@ pipeline { LOWER_BUILD_TYPE = params.BUILD_TYPE.toLowerCase() SEMVER = "${BRANCH_NAME.contains('/') ? BRANCH_NAME.substring(BRANCH_NAME.lastIndexOf('/') + 1) : BRANCH_NAME}" PIPELINE_NAME = "milvus-ci" - HELM_BRANCH = "0.10.0" + HELM_BRANCH = "0.10.1" } stages { @@ -84,7 +84,17 @@ pipeline { steps { container("milvus-${BINARY_VERSION}-build-env") { script { - load "${env.WORKSPACE}/ci/jenkins/step/build.groovy" + try{ + boolean isNightlyTest = isTimeTriggeredBuild() + if (isNightlyTest || "${params.IS_MANUAL_TRIGGER_TYPE}" == "True") { + load "${env.WORKSPACE}/ci/jenkins/step/nightlyBuild.groovy" + } else { + load "${env.WORKSPACE}/ci/jenkins/step/build.groovy" + } + } catch (Exception e) { + containerLog "milvus-${BINARY_VERSION}-build-env" + throw e + } } } } diff --git a/ci/jenkins/step/build.groovy b/ci/jenkins/step/build.groovy index 0bdbed342991..a1b76ab64e77 100644 --- a/ci/jenkins/step/build.groovy +++ b/ci/jenkins/step/build.groovy @@ -4,7 +4,7 @@ timeout(time: 120, unit: 'MINUTES') { def checkResult = sh(script: "./check_ccache.sh -l ${params.JFROG_ARTFACTORY_URL}/ccache", returnStatus: true) if ("${BINARY_VERSION}" == "gpu") { - sh "/bin/bash --login -c \". ./before-install.sh && ./build.sh -t ${params.BUILD_TYPE} -j4 -i ${env.MILVUS_INSTALL_PREFIX} --with_fiu --coverage -l -g -u\"" + sh "/bin/bash --login -c \". ./before-install.sh && ./build.sh -t ${params.BUILD_TYPE} -j4 -i ${env.MILVUS_INSTALL_PREFIX} --with_fiu --coverage -l -g -u -s '-gencode=arch=compute_61,code=sm_61;-gencode=arch=compute_75,code=sm_75' \"" } else { sh "/bin/bash --login -c \". ./before-install.sh && ./build.sh -t ${params.BUILD_TYPE} -j4 -i ${env.MILVUS_INSTALL_PREFIX} --with_fiu --coverage -l -u\"" } diff --git a/ci/jenkins/step/nightlyBuild.groovy b/ci/jenkins/step/nightlyBuild.groovy new file mode 100644 index 000000000000..0bdbed342991 --- /dev/null +++ b/ci/jenkins/step/nightlyBuild.groovy @@ -0,0 +1,14 @@ +timeout(time: 120, unit: 'MINUTES') { + dir ("ci/scripts") { + withCredentials([usernamePassword(credentialsId: "${params.JFROG_CREDENTIALS_ID}", usernameVariable: 'USERNAME', passwordVariable: 'PASSWORD')]) { + def checkResult = sh(script: "./check_ccache.sh -l ${params.JFROG_ARTFACTORY_URL}/ccache", returnStatus: true) + + if ("${BINARY_VERSION}" == "gpu") { + sh "/bin/bash --login -c \". ./before-install.sh && ./build.sh -t ${params.BUILD_TYPE} -j4 -i ${env.MILVUS_INSTALL_PREFIX} --with_fiu --coverage -l -g -u\"" + } else { + sh "/bin/bash --login -c \". ./before-install.sh && ./build.sh -t ${params.BUILD_TYPE} -j4 -i ${env.MILVUS_INSTALL_PREFIX} --with_fiu --coverage -l -u\"" + } + sh "./update_ccache.sh -l ${params.JFROG_ARTFACTORY_URL}/ccache -u ${USERNAME} -p ${PASSWORD}" + } + } +} diff --git a/ci/jenkins/step/shardsDevNightlyTest.groovy b/ci/jenkins/step/shardsDevNightlyTest.groovy index 88b5cbdcd25c..e2a86cdc8e3f 100644 --- a/ci/jenkins/step/shardsDevNightlyTest.groovy +++ b/ci/jenkins/step/shardsDevNightlyTest.groovy @@ -8,7 +8,13 @@ timeout(time: 180, unit: 'MINUTES') { retry(3) { try { - sh "helm install --wait --timeout 300s --set cluster.enabled=true --set persistence.enabled=true --set image.repository=registry.zilliz.com/milvus/engine --set mishards.image.tag=test --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP -f ci/db_backend/mysql_${BINARY_VERSION}_values.yaml --namespace milvus ${env.SHARDS_HELM_RELEASE_NAME} ." + dir ('charts/milvus') { + if ("${BINARY_VERSION}" == "CPU") { + sh "helm install --wait --timeout 300s --set cluster.enabled=true --set persistence.enabled=true --set image.repository=registry.zilliz.com/milvus/engine --set mishards.image.tag=test --set mishards.image.pullPolicy=Always --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP --set image.resources.requests.memory=8Gi --set image.resources.requests.cpu=2.0 --set image.resources.limits.memory=12Gi --set image.resources.limits.cpu=4.0 -f ci/db_backend/mysql_${BINARY_VERSION}_values.yaml --namespace milvus ${env.SHARDS_HELM_RELEASE_NAME} ." + } else { + sh "helm install --wait --timeout 300s --set cluster.enabled=true --set persistence.enabled=true --set image.repository=registry.zilliz.com/milvus/engine --set mishards.image.tag=test --set mishards.image.pullPolicy=Always --set gpu.enabled=true --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP -f ci/db_backend/mysql_${BINARY_VERSION}_values.yaml --namespace milvus ${env.SHARDS_HELM_RELEASE_NAME} ." + } + } } catch (exc) { def helmStatusCMD = "helm get manifest --namespace milvus ${env.SHARDS_HELM_RELEASE_NAME} | kubectl describe -n milvus -f - && \ kubectl logs --namespace milvus -l \"app.kubernetes.io/name=milvus,app.kubernetes.io/instance=${env.SHARDS_HELM_RELEASE_NAME},component=writable\" -c milvus && \ @@ -24,6 +30,6 @@ timeout(time: 180, unit: 'MINUTES') { dir ("tests/milvus_python_test") { sh 'python3 -m pip install -r requirements.txt' - sh "pytest . --level=2 --alluredir=\"test_out/dev/shards/\" --ip ${env.SHARDS_HELM_RELEASE_NAME}.milvus.svc.cluster.local >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_shards_dev_test.log" + sh "pytest . --level=2 --alluredir=\"test_out/dev/shards/\" --ip ${env.SHARDS_HELM_RELEASE_NAME}.milvus.svc.cluster.local >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_${BINARY_VERSION}_shards_dev_test.log" } } diff --git a/ci/jenkins/step/singleDevNightlyTest.groovy b/ci/jenkins/step/singleDevNightlyTest.groovy index 5d8a3a06bcb0..923f9307dc74 100644 --- a/ci/jenkins/step/singleDevNightlyTest.groovy +++ b/ci/jenkins/step/singleDevNightlyTest.groovy @@ -8,7 +8,9 @@ timeout(time: 180, unit: 'MINUTES') { retry(3) { try { - sh "helm install --wait --timeout 300s --set image.repository=registry.zilliz.com/milvus/engine --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP -f ci/db_backend/mysql_${BINARY_VERSION}_values.yaml -f ci/filebeat/values.yaml --namespace milvus ${env.HELM_RELEASE_NAME} ." + dir ('charts/milvus') { + sh "helm install --wait --timeout 300s --set image.repository=registry.zilliz.com/milvus/engine --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP --set image.resources.requests.memory=8Gi --set image.resources.requests.cpu=2.0 --set image.resources.limits.memory=12Gi --set image.resources.limits.cpu=4.0 -f ci/db_backend/mysql_${BINARY_VERSION}_values.yaml -f ci/filebeat/values.yaml --namespace milvus ${env.HELM_RELEASE_NAME} ." + } } catch (exc) { def helmStatusCMD = "helm get manifest --namespace milvus ${env.HELM_RELEASE_NAME} | kubectl describe -n milvus -f - && \ kubectl logs --namespace milvus -l \"app.kubernetes.io/name=milvus,app.kubernetes.io/instance=${env.HELM_RELEASE_NAME}\" -c milvus && \ @@ -23,32 +25,34 @@ timeout(time: 180, unit: 'MINUTES') { dir ("tests/milvus_python_test") { // sh 'python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com' sh 'python3 -m pip install -r requirements.txt' - sh "pytest . --level=2 --alluredir=\"test_out/dev/single/mysql\" --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_mysql_dev_test.log" + sh "pytest . --level=2 --alluredir=\"test_out/dev/single/mysql\" --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_${BINARY_VERSION}_mysql_dev_test.log" } // sqlite database backend test load "ci/jenkins/step/cleanupSingleDev.groovy" - if (!fileExists('milvus-helm')) { + if (!fileExists('milvus-helm/charts/milvus')) { dir ("milvus-helm") { checkout([$class: 'GitSCM', branches: [[name:"${env.HELM_BRANCH}"]], userRemoteConfigs: [[url: "https://github.com/milvus-io/milvus-helm.git", name: 'origin', refspec: "+refs/heads/${env.HELM_BRANCH}:refs/remotes/origin/${env.HELM_BRANCH}"]]]) } } - dir ("milvus-helm") { - retry(3) { - try { - sh "helm install --wait --timeout 300s --set image.repository=registry.zilliz.com/milvus/engine --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP -f ci/db_backend/sqlite_${BINARY_VERSION}_values.yaml -f ci/filebeat/values.yaml --namespace milvus ${env.HELM_RELEASE_NAME} ." - } catch (exc) { - def helmStatusCMD = "helm get manifest --namespace milvus ${env.HELM_RELEASE_NAME} | kubectl describe -n milvus -f - && \ - kubectl logs --namespace milvus -l \"app.kubernetes.io/name=milvus,app.kubernetes.io/instance=${env.HELM_RELEASE_NAME}\" -c milvus && \ - helm status -n milvus ${env.HELM_RELEASE_NAME}" - sh script: helmStatusCMD, returnStatus: true - sh script: "helm uninstall -n milvus ${env.HELM_RELEASE_NAME} && sleep 1m", returnStatus: true - throw exc + retry(3) { + try { + dir ("milvus-helm/charts/milvus") { + sh "helm install --wait --timeout 300s --set image.repository=registry.zilliz.com/milvus/engine --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP --set image.resources.requests.memory=8Gi --set image.resources.requests.cpu=2.0 --set image.resources.limits.memory=12Gi --set image.resources.limits.cpu=4.0 -f ci/db_backend/sqlite_${BINARY_VERSION}_values.yaml -f ci/filebeat/values.yaml --namespace milvus ${env.HELM_RELEASE_NAME} ." + } + } catch (exc) { + def helmStatusCMD = "helm get manifest --namespace milvus ${env.HELM_RELEASE_NAME} | kubectl describe -n milvus -f - && \ + kubectl logs --namespace milvus -l \"app=milvus,release=${env.HELM_RELEASE_NAME}\" -c milvus && \ + helm status -n milvus ${env.HELM_RELEASE_NAME}" + def helmResult = sh script: helmStatusCMD, returnStatus: true + if (!helmResult) { + sh "helm uninstall -n milvus ${env.HELM_RELEASE_NAME} && sleep 1m" } + throw exc } } dir ("tests/milvus_python_test") { - sh "pytest . --level=2 --alluredir=\"test_out/dev/single/sqlite\" --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_sqlite_dev_test.log" - sh "pytest . --level=1 --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local --port=19121 --handler=HTTP" + sh "pytest . --level=2 --alluredir=\"test_out/dev/single/sqlite\" --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_${BINARY_VERSION}_sqlite_dev_test.log" + sh "pytest . --level=1 --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local --port=19121 --handler=HTTP >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_${BINARY_VERSION}_sqlite_http_dev_test.log" } } diff --git a/ci/jenkins/step/singleDevTest.groovy b/ci/jenkins/step/singleDevTest.groovy index bfd73500e90a..b4dbfc44fc80 100644 --- a/ci/jenkins/step/singleDevTest.groovy +++ b/ci/jenkins/step/singleDevTest.groovy @@ -8,7 +8,9 @@ timeout(time: 120, unit: 'MINUTES') { retry(3) { try { - sh "helm install --wait --timeout 300s --set image.repository=registry.zilliz.com/milvus/engine --set persistence.enabled=true --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP -f ci/db_backend/mysql_${BINARY_VERSION}_values.yaml -f ci/filebeat/values.yaml --namespace milvus ${env.HELM_RELEASE_NAME} ." + dir ('charts/milvus') { + sh "helm install --wait --timeout 300s --set image.repository=registry.zilliz.com/milvus/engine --set persistence.enabled=true --set image.tag=${DOCKER_VERSION} --set image.pullPolicy=Always --set service.type=ClusterIP -f ci/db_backend/mysql_${BINARY_VERSION}_values.yaml -f ci/filebeat/values.yaml --namespace milvus ${env.HELM_RELEASE_NAME} ." + } } catch (exc) { def helmStatusCMD = "helm get manifest --namespace milvus ${env.HELM_RELEASE_NAME} | kubectl describe -n milvus -f - && \ kubectl logs --namespace milvus -l \"app.kubernetes.io/name=milvus,app.kubernetes.io/instance=${env.HELM_RELEASE_NAME}\" -c milvus && \ @@ -23,7 +25,7 @@ timeout(time: 120, unit: 'MINUTES') { // sh 'python3 -m pip install -r requirements.txt -i http://pypi.douban.com/simple --trusted-host pypi.douban.com' sh 'python3 -m pip install -r requirements_no_pymilvus.txt' sh 'python3 -m pip install git+https://github.com/BossZou/pymilvus.git@api-update' - sh "pytest . --alluredir=\"test_out/dev/single/mysql\" --level=1 --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local --service ${env.HELM_RELEASE_NAME} >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_mysql_dev_test.log" + sh "pytest . --alluredir=\"test_out/dev/single/mysql\" --level=1 --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local --service ${env.HELM_RELEASE_NAME} >> ${WORKSPACE}/${env.DEV_TEST_ARTIFACTS}/milvus_${BINARY_VERSION}_mysql_dev_test.log" // sh "pytest test_restart.py --alluredir=\"test_out/dev/single/mysql\" --level=3 --ip ${env.HELM_RELEASE_NAME}.milvus.svc.cluster.local --service ${env.HELM_RELEASE_NAME}" } } diff --git a/ci/jenkins/step/unittest.groovy b/ci/jenkins/step/unittest.groovy index 9bd85a764ac6..5e59378a8151 100644 --- a/ci/jenkins/step/unittest.groovy +++ b/ci/jenkins/step/unittest.groovy @@ -1,5 +1,5 @@ timeout(time: 30, unit: 'MINUTES') { dir ("ci/scripts") { - sh "./run_unittest.sh -i ${env.MILVUS_INSTALL_PREFIX} --mysql_user=root --mysql_password=123456 --mysql_host=\$POD_IP" + sh "./run_unittest.sh -i ${env.MILVUS_INSTALL_PREFIX} --mysql_user=root --mysql_password=123456 --mysql_host=\"127.0.0.1\"" } } diff --git a/ci/scripts/build.sh b/ci/scripts/build.sh index 3260fb13d98a..fafd5284f556 100755 --- a/ci/scripts/build.sh +++ b/ci/scripts/build.sh @@ -22,6 +22,7 @@ Usage: Install directory used by install. -t [BUILD_TYPE] or --build_type=[BUILD_TYPE] Build type (default: Release) + -s [CUDA_ARCH] Building for the cuda architecture -j[N] or --jobs=[N] Allow N jobs at once; infinite jobs with no arg. -l Run cpplint & check clang-format -n No make and make install step @@ -38,7 +39,7 @@ Usage: Use \"$0 --help\" for more information about a given command. " -ARGS=`getopt -o "i:t:j::lngcupvh" -l "install_prefix::,build_type::,jobs::,with_mkl,with_fiu,coverage,tests,privileges,help" -n "$0" -- "$@"` +ARGS=`getopt -o "i:t:s:j::lngcupvh" -l "install_prefix::,build_type::,jobs::,with_mkl,with_fiu,coverage,tests,privileges,help" -n "$0" -- "$@"` eval set -- "${ARGS}" @@ -72,6 +73,11 @@ while true ; do -p|--privileges) PRIVILEGES="ON" ; shift ;; -v|--verbose) VERBOSE="1" ; shift ;; -h|--help) echo -e "${HELP}" ; exit 0 ;; + -s) + case "$2" in + "") CUDA_ARCH="DEFAULT"; shift 2 ;; + *) CUDA_ARCH=$2 ; shift 2 ;; + esac ;; --) shift ; break ;; *) echo "Internal error!" ; exit 1 ;; esac @@ -86,6 +92,7 @@ BUILD_UNITTEST=${BUILD_UNITTEST:="OFF"} BUILD_COVERAGE=${BUILD_COVERAGE:="OFF"} COMPILE_BUILD=${COMPILE_BUILD:="ON"} GPU_VERSION=${GPU_VERSION:="OFF"} +CUDA_ARCH=${CUDA_ARCH:="DEFAULT"} RUN_CPPLINT=${RUN_CPPLINT:="OFF"} WITH_MKL=${WITH_MKL:="OFF"} FIU_ENABLE=${FIU_ENABLE:="OFF"} @@ -125,6 +132,7 @@ CMAKE_CMD="cmake \ -DFAISS_SOURCE=AUTO \ -DOpenBLAS_SOURCE=AUTO \ -DMILVUS_WITH_FIU=${FIU_ENABLE} \ +-DMILVUS_CUDA_ARCH=${CUDA_ARCH} \ ${MILVUS_CORE_DIR}" echo ${CMAKE_CMD} ${CMAKE_CMD} diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt index b98170fc0512..906fda6da05c 100644 --- a/core/CMakeLists.txt +++ b/core/CMakeLists.txt @@ -11,147 +11,51 @@ # or implied. See the License for the specific language governing permissions and limitations under the License. #------------------------------------------------------------------------------- - cmake_minimum_required(VERSION 3.12) message(STATUS "Building using CMake version: ${CMAKE_VERSION}") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +include(Utils) -# get build time -MACRO(GET_CURRENT_TIME CURRENT_TIME) - execute_process(COMMAND "date" +"%Y-%m-%d %H:%M.%S" OUTPUT_VARIABLE ${CURRENT_TIME}) -ENDMACRO(GET_CURRENT_TIME) - -GET_CURRENT_TIME(BUILD_TIME) -string(REGEX REPLACE "\n" "" BUILD_TIME ${BUILD_TIME}) +# **************************** Build time, type and code version **************************** +get_current_time(BUILD_TIME) message(STATUS "Build time = ${BUILD_TIME}") - -if (NOT DEFINED CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.") -endif () - -# get Milvus version via branch name -set(GIT_BRANCH_NAME_REGEX "[0-9]+\\.[0-9]+\\.[0-9]") - -MACRO(GET_GIT_BRANCH_NAME GIT_BRANCH_NAME) - execute_process(COMMAND sh "-c" "git log --decorate | head -n 1 | sed 's/.*(\\(.*\\))/\\1/' | sed 's/.*, //' | sed 's=[a-zA-Z]*\/==g'" - OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) - if (NOT GIT_BRANCH_NAME MATCHES "${GIT_BRANCH_NAME_REGEX}") - execute_process(COMMAND "git" rev-parse --abbrev-ref HEAD OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) - endif () - if (NOT GIT_BRANCH_NAME MATCHES "${GIT_BRANCH_NAME_REGEX}") - execute_process(COMMAND "git" symbolic-ref --short -q HEAD HEAD OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) - endif () -ENDMACRO(GET_GIT_BRANCH_NAME) - -GET_GIT_BRANCH_NAME(GIT_BRANCH_NAME) -message(STATUS "GIT_BRANCH_NAME = ${GIT_BRANCH_NAME}") -if (NOT GIT_BRANCH_NAME STREQUAL "") - string(REGEX REPLACE "\n" "" GIT_BRANCH_NAME ${GIT_BRANCH_NAME}) -endif () - -set(MILVUS_VERSION "${GIT_BRANCH_NAME}") -string(REGEX MATCH "${GIT_BRANCH_NAME_REGEX}" MILVUS_VERSION "${MILVUS_VERSION}") - -# get last commit id -MACRO(GET_LAST_COMMIT_ID LAST_COMMIT_ID) - execute_process(COMMAND sh "-c" "git log --decorate | head -n 1 | awk '{print $2}'" - OUTPUT_VARIABLE ${LAST_COMMIT_ID}) -ENDMACRO(GET_LAST_COMMIT_ID) - -GET_LAST_COMMIT_ID(LAST_COMMIT_ID) -message(STATUS "LAST_COMMIT_ID = ${LAST_COMMIT_ID}") -if (NOT LAST_COMMIT_ID STREQUAL "") - string(REGEX REPLACE "\n" "" LAST_COMMIT_ID ${LAST_COMMIT_ID}) - set(LAST_COMMIT_ID "${LAST_COMMIT_ID}") -else () - set(LAST_COMMIT_ID "Unknown") -endif () - -# set build type -if (CMAKE_BUILD_TYPE STREQUAL "Release") - set(BUILD_TYPE "Release") -else () - set(BUILD_TYPE "Debug") -endif () +get_build_type(TARGET BUILD_TYPE + DEFAULT "Release") message(STATUS "Build type = ${BUILD_TYPE}") +get_milvus_version(TARGET MILVUS_VERSION + DEFAULT "0.10.0") +message(STATUS "Build version = ${MILVUS_VERSION}") +get_last_commit_id(LAST_COMMIT_ID) +message(STATUS "LAST_COMMIT_ID = ${LAST_COMMIT_ID}") -project(milvus VERSION "${MILVUS_VERSION}") -project(milvus_engine LANGUAGES CXX) +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/version.h.in ${CMAKE_CURRENT_SOURCE_DIR}/src/version.h @ONLY) -unset(CMAKE_EXPORT_COMPILE_COMMANDS CACHE) +# unset(CMAKE_EXPORT_COMPILE_COMMANDS CACHE) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(MILVUS_VERSION_MAJOR "${milvus_VERSION_MAJOR}") -set(MILVUS_VERSION_MINOR "${milvus_VERSION_MINOR}") -set(MILVUS_VERSION_PATCH "${milvus_VERSION_PATCH}") - -if (MILVUS_VERSION_MAJOR STREQUAL "" - OR MILVUS_VERSION_MINOR STREQUAL "" - OR MILVUS_VERSION_PATCH STREQUAL "") - message(WARNING "Failed to determine Milvus version from git branch name") - set(MILVUS_VERSION "0.10.0") -endif () - -message(STATUS "Build version = ${MILVUS_VERSION}") -configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/version.h.in ${CMAKE_CURRENT_SOURCE_DIR}/src/version.h @ONLY) - -message(STATUS "Milvus version: " - "${MILVUS_VERSION_MAJOR}.${MILVUS_VERSION_MINOR}.${MILVUS_VERSION_PATCH} " - "(full: '${MILVUS_VERSION}')") +# **************************** Project **************************** +project(milvus VERSION "${MILVUS_VERSION}") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED on) -if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)") - message(STATUS "Building milvus_engine on x86 architecture") - set(MILVUS_BUILD_ARCH x86_64) -elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "(ppc)") - message(STATUS "Building milvus_engine on ppc architecture") - set(MILVUS_BUILD_ARCH ppc64le) -else () - message(WARNING "Unknown processor type") - message(WARNING "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}") - set(MILVUS_BUILD_ARCH unknown) -endif () - -# Ensure that a default make is set -if ("${MAKE}" STREQUAL "") - if (NOT MSVC) - find_program(MAKE make) - endif () -endif () - -find_path(MYSQL_INCLUDE_DIR - NAMES "mysql.h" - PATH_SUFFIXES "mysql") -if (${MYSQL_INCLUDE_DIR} STREQUAL "MYSQL_INCLUDE_DIR-NOTFOUND") - message(FATAL_ERROR "Could not found MySQL include directory") -else () - include_directories(${MYSQL_INCLUDE_DIR}) -endif () - set(MILVUS_SOURCE_DIR ${PROJECT_SOURCE_DIR}) set(MILVUS_BINARY_DIR ${PROJECT_BINARY_DIR}) set(MILVUS_ENGINE_SRC ${PROJECT_SOURCE_DIR}/src) set(MILVUS_THIRDPARTY_SRC ${PROJECT_SOURCE_DIR}/thirdparty) +# **************************** Dependencies **************************** +include(BuildUtils) +import_mysql_inc() + include(ExternalProject) include(DefineOptions) -include(BuildUtils) include(ThirdPartyPackages) -if (MILVUS_USE_CCACHE) - find_program(CCACHE_FOUND ccache) - if (CCACHE_FOUND) - message(STATUS "Using ccache: ${CCACHE_FOUND}") - set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_FOUND}) - set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_FOUND}) - # let ccache preserve C++ comments, because some of them may be - # meaningful to the compiler - set(ENV{CCACHE_COMMENTS} "1") - endif (CCACHE_FOUND) -endif () +using_ccache_if_defined(MILVUS_USE_CCACHE) + +# **************************** Compiler arguments **************************** if (MILVUS_GPU_VERSION) message(STATUS "Building Milvus GPU version") @@ -163,32 +67,23 @@ else () message(STATUS "Building Milvus CPU version") endif () -if (MILVUS_WITH_PROMETHEUS) - add_compile_definitions("MILVUS_WITH_PROMETHEUS") -endif () +set_milvus_definition(MILVUS_WITH_PROMETHEUS "MILVUS_WITH_PROMETHEUS") +set_milvus_definition(ENABLE_CPU_PROFILING "ENABLE_CPU_PROFILING") +set_milvus_definition(MILVUS_WITH_FIU "FIU_ENABLE") -message("ENABLE_CPU_PROFILING = ${ENABLE_CPU_PROFILING}") -if (ENABLE_CPU_PROFILING STREQUAL "ON") - ADD_DEFINITIONS(-DENABLE_CPU_PROFILING) -endif() - -if (MILVUS_WITH_FIU) - add_compile_definitions("FIU_ENABLE") -endif () +config_summary() if (CMAKE_BUILD_TYPE STREQUAL "Release") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -DELPP_THREAD_SAFE -fopenmp") - if (MILVUS_GPU_VERSION) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3") - endif () -else () + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3") +elseif (CMAKE_BUILD_TYPE STREQUAL "Debug") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g -fPIC -DELPP_THREAD_SAFE -fopenmp") - if (MILVUS_GPU_VERSION) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g") - endif () + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O0 -g") +else () + message(FATAL_ERROR "Unknown CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") endif () -config_summary() +# **************************** Source files **************************** add_subdirectory(src) if (BUILD_UNIT_TEST STREQUAL "ON") @@ -201,21 +96,20 @@ endif () add_custom_target(Clean-All COMMAND ${CMAKE_BUILD_TOOL} clean) +# **************************** Install **************************** + if ("${MILVUS_DB_PATH}" STREQUAL "") set(MILVUS_DB_PATH "${CMAKE_INSTALL_PREFIX}") endif () if (MILVUS_GPU_VERSION) set(GPU_ENABLE "true") - configure_file(${CMAKE_CURRENT_SOURCE_DIR}/conf/server_config.template - ${CMAKE_CURRENT_SOURCE_DIR}/conf/server_config.yaml - @ONLY) else () set(GPU_ENABLE "false") - configure_file(${CMAKE_CURRENT_SOURCE_DIR}/conf/server_config.template - ${CMAKE_CURRENT_SOURCE_DIR}/conf/server_config.yaml - @ONLY) endif () +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/conf/server_config.template + ${CMAKE_CURRENT_SOURCE_DIR}/conf/server_config.yaml + @ONLY) install(DIRECTORY scripts/ DESTINATION scripts @@ -233,6 +127,8 @@ install(FILES DESTINATION conf) + +# **************************** Coding style check tools **************************** find_package(Python COMPONENTS Interpreter Development) find_package(ClangTools) set(BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support") diff --git a/core/build.sh b/core/build.sh index 28ad29cc8b99..ca68e5d7709e 100755 --- a/core/build.sh +++ b/core/build.sh @@ -1,5 +1,5 @@ #!/bin/bash - +# BUILD_OUTPUT_DIR="cmake_build" BUILD_TYPE="Debug" BUILD_UNITTEST="OFF" @@ -16,9 +16,10 @@ FAISS_ROOT="" #FAISS root path FAISS_SOURCE="BUNDLED" WITH_PROMETHEUS="ON" FIU_ENABLE="OFF" +CUDA_ARCH="DEFAULT" # BUILD_OPENBLAS="ON" # not used any more -while getopts "p:d:t:f:ulrcghzmei" arg; do +while getopts "p:d:t:f:s:ulrcghzmei" arg; do case $arg in p) INSTALL_PREFIX=$OPTARG @@ -64,6 +65,9 @@ while getopts "p:d:t:f:ulrcghzmei" arg; do i) FIU_ENABLE="ON" ;; + s) + CUDA_ARCH=$OPTARG + ;; h) # help echo " @@ -83,10 +87,11 @@ parameter: -m: build with MKL(default: OFF) -e: build without prometheus(default: OFF) -i: build FIU_ENABLE(default: OFF) +-s: build with CUDA arch(default:DEFAULT), for example '-gencode=compute_61,code=sm_61;-gencode=compute_75,code=sm_75' -h: help usage: -./build.sh -p \${INSTALL_PREFIX} -t \${BUILD_TYPE} -f \${FAISS_ROOT} [-u] [-l] [-r] [-c] [-z] [-g] [-m] [-e] [-h] +./build.sh -p \${INSTALL_PREFIX} -t \${BUILD_TYPE} -f \${FAISS_ROOT} -s \${CUDA_ARCH}[-u] [-l] [-r] [-c] [-z] [-g] [-m] [-e] [-h] " exit 0 ;; @@ -122,6 +127,7 @@ CMAKE_CMD="cmake \ -DFAISS_WITH_MKL=${WITH_MKL} \ -DMILVUS_WITH_PROMETHEUS=${WITH_PROMETHEUS} \ -DMILVUS_WITH_FIU=${FIU_ENABLE} \ +-DMILVUS_CUDA_ARCH=${CUDA_ARCH} \ ../" echo ${CMAKE_CMD} ${CMAKE_CMD} @@ -139,7 +145,7 @@ if [[ ${RUN_CPPLINT} == "ON" ]]; then fi echo "cpplint check passed!" - # clang-format check + clang-format check make check-clang-format if [ $? -ne 0 ]; then echo "ERROR! clang-format check failed" @@ -157,5 +163,5 @@ if [[ ${RUN_CPPLINT} == "ON" ]]; then else # compile and build - make -j 8 install || exit 1 + make -j 8 install || exit 1 fi diff --git a/core/cmake/BuildUtils.cmake b/core/cmake/BuildUtils.cmake index 6332d29d747b..a739ce243db4 100644 --- a/core/cmake/BuildUtils.cmake +++ b/core/cmake/BuildUtils.cmake @@ -202,3 +202,30 @@ function(ADD_THIRDPARTY_LIB LIB_NAME) message(FATAL_ERROR "No static or shared library provided for ${LIB_NAME}") endif() endfunction() + +MACRO (import_mysql_inc) + find_path (MYSQL_INCLUDE_DIR + NAMES "mysql.h" + PATH_SUFFIXES "mysql") + + if (${MYSQL_INCLUDE_DIR} STREQUAL "MYSQL_INCLUDE_DIR-NOTFOUND") + message(FATAL_ERROR "Could not found MySQL include directory") + else () + include_directories(${MYSQL_INCLUDE_DIR}) + endif () +ENDMACRO (import_mysql_inc) + +MACRO(using_ccache_if_defined MILVUS_USE_CCACHE) + if (MILVUS_USE_CCACHE) + find_program(CCACHE_FOUND ccache) + if (CCACHE_FOUND) + message(STATUS "Using ccache: ${CCACHE_FOUND}") + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_FOUND}) + set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_FOUND}) + # let ccache preserve C++ comments, because some of them may be + # meaningful to the compiler + set(ENV{CCACHE_COMMENTS} "1") + endif (CCACHE_FOUND) + endif () +ENDMACRO(using_ccache_if_defined) + diff --git a/core/cmake/ThirdPartyPackages.cmake b/core/cmake/ThirdPartyPackages.cmake index 09b8fafd11ca..330e70951d8c 100644 --- a/core/cmake/ThirdPartyPackages.cmake +++ b/core/cmake/ThirdPartyPackages.cmake @@ -251,7 +251,7 @@ if (DEFINED ENV{MILVUS_PROMETHEUS_URL}) set(PROMETHEUS_SOURCE_URL "$ENV{PROMETHEUS_OPENBLAS_URL}") else () set(PROMETHEUS_SOURCE_URL - https://github.com/jupp0r/prometheus-cpp.git) + "https://github.com/milvus-io/prometheus-cpp/archive/${PROMETHEUS_VERSION}.zip") endif () if (DEFINED ENV{MILVUS_SQLITE_URL}) @@ -494,18 +494,17 @@ macro(build_prometheus) -DCMAKE_BUILD_TYPE=Release) externalproject_add(prometheus_ep - GIT_REPOSITORY + URL ${PROMETHEUS_SOURCE_URL} - GIT_TAG - ${PROMETHEUS_VERSION} - GIT_SHALLOW - TRUE ${EP_LOG_OPTIONS} CMAKE_ARGS ${PROMETHEUS_CMAKE_ARGS} BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS} + BUILD_COMMAND + ${MAKE} + ${MAKE_BUILD_ARGS} BUILD_IN_SOURCE 1 INSTALL_COMMAND diff --git a/core/cmake/Utils.cmake b/core/cmake/Utils.cmake new file mode 100644 index 000000000000..98deaa207605 --- /dev/null +++ b/core/cmake/Utils.cmake @@ -0,0 +1,80 @@ +# get build time +MACRO(get_current_time CURRENT_TIME) + execute_process(COMMAND "date" "+%Y-%m-%d %H:%M.%S" OUTPUT_VARIABLE ${CURRENT_TIME}) + string(REGEX REPLACE "\n" "" ${CURRENT_TIME} ${${CURRENT_TIME}}) +ENDMACRO(get_current_time) + +# get build type +MACRO(get_build_type) + cmake_parse_arguments(BUILD_TYPE "" "TARGET;DEFAULT" "" ${ARGN}) + if (NOT DEFINED CMAKE_BUILD_TYPE) + set(${BUILD_TYPE_TARGET} ${BUILD_TYPE_DEFAULT}) + elseif (CMAKE_BUILD_TYPE STREQUAL "Release") + set(${BUILD_TYPE_TARGET} "Release") + elseif (CMAKE_BUILD_TYPE STREQUAL "Debug") + set(${BUILD_TYPE_TARGET} "Debug") + else () + set(${BUILD_TYPE_TARGET} ${BUILD_TYPE_DEFAULT}) + endif () +ENDMACRO(get_build_type) + +# get git branch name +MACRO(get_git_branch_name GIT_BRANCH_NAME) + set(GIT_BRANCH_NAME_REGEX "[0-9]+\\.[0-9]+\\.[0-9]") + + execute_process(COMMAND sh "-c" "git log --decorate | head -n 1 | sed 's/.*(\\(.*\\))/\\1/' | sed 's/.*, //' | sed 's=[a-zA-Z]*\/==g'" + OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) + + if (NOT GIT_BRANCH_NAME MATCHES "${GIT_BRANCH_NAME_REGEX}") + execute_process(COMMAND "git" rev-parse --abbrev-ref HEAD OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) + endif () + + if (NOT GIT_BRANCH_NAME MATCHES "${GIT_BRANCH_NAME_REGEX}") + execute_process(COMMAND "git" symbolic-ref -q --short HEAD OUTPUT_VARIABLE ${GIT_BRANCH_NAME}) + endif () + + message(DEBUG "GIT_BRANCH_NAME = ${GIT_BRANCH_NAME}") + + # Some unexpected case + if (NOT GIT_BRANCH_NAME STREQUAL "") + string(REGEX REPLACE "\n" "" GIT_BRANCH_NAME ${GIT_BRANCH_NAME}) + else () + set(GIT_BRANCH_NAME "#") + endif () +ENDMACRO(get_git_branch_name) + +# get last commit id +MACRO(get_last_commit_id LAST_COMMIT_ID) + execute_process(COMMAND sh "-c" "git log --decorate | head -n 1 | awk '{print $2}'" + OUTPUT_VARIABLE ${LAST_COMMIT_ID}) + + message(DEBUG "LAST_COMMIT_ID = ${${LAST_COMMIT_ID}}") + + if (NOT LAST_COMMIT_ID STREQUAL "") + string(REGEX REPLACE "\n" "" ${LAST_COMMIT_ID} ${${LAST_COMMIT_ID}}) + else () + set(LAST_COMMIT_ID "Unknown") + endif () +ENDMACRO(get_last_commit_id) + +# get milvus version +MACRO(get_milvus_version) + cmake_parse_arguments(VER "" "TARGET;DEFAULT" "" ${ARGN}) + + # Step 1: get branch name + get_git_branch_name(GIT_BRANCH_NAME) + message(DEBUG ${GIT_BRANCH_NAME}) + + # Step 2: match MAJOR.MINOR.PATCH format or set DEFAULT value + string(REGEX MATCH "([0-9]+)\\.([0-9]+)\\.([0-9]+)" ${VER_TARGET} ${GIT_BRANCH_NAME}) + if (NOT ${VER_TARGET}) + set(${VER_TARGET} ${VER_DEFAULT}) + endif() +ENDMACRO(get_milvus_version) + +# set definition +MACRO(set_milvus_definition DEF_PASS_CMAKE MILVUS_DEF) + if (${${DEF_PASS_CMAKE}}) + add_compile_definitions(${MILVUS_DEF}) + endif() +ENDMACRO(set_milvus_definition) diff --git a/core/scripts/migration/sqlite_6_to_4.sql b/core/scripts/migration/sqlite_6_to_4.sql index 686d276f461c..19202bdbc6a7 100644 --- a/core/scripts/migration/sqlite_6_to_4.sql +++ b/core/scripts/migration/sqlite_6_to_4.sql @@ -1,6 +1,4 @@ -CREATE TABLE 'TempTables' ( 'id' INTEGER PRIMARY KEY NOT NULL , 'table_id' TEXT UNIQUE NOT NULL , 'state' INTEGER NOT NULL , 'dimension' INTEGER NOT NULL , 'created_on' INTEGER NOT NULL , 'flag' INTEGER DEFAULT 0 NOT NULL , 'index_file_size' INTEGER NOT NULL , 'engine_type' INTEGER NOT NULL , 'nlist' INTEGER NOT NULL , 'metric_type' INTEGER NOT NULL); - -INSERT INTO TempTables SELECT id, table_id, state, dimension, created_on, flag, index_file_size, engine_type, nlist, metric_type FROM Tables; +CREATE TABLE 'TempTables' AS SELECT id, table_id, state, dimension, created_on, flag, index_file_size, engine_type, nlist, metric_type FROM Tables; DROP TABLE Tables; diff --git a/core/src/CMakeLists.txt b/core/src/CMakeLists.txt index 3d87f50764cc..c02e3f047599 100644 --- a/core/src/CMakeLists.txt +++ b/core/src/CMakeLists.txt @@ -38,11 +38,17 @@ aux_source_directory(${MILVUS_ENGINE_SRC}/db db_main_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/attr db_attr_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/engine db_engine_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/insert db_insert_files) -aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/merge db_merge_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/wal db_wal_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/snapshot db_snapshot_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_main_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta/backend db_meta_backend_files) +set(db_meta_files + ${db_meta_main_files} + ${db_meta_backend_files} + ) + set(grpc_service_files ${MILVUS_ENGINE_SRC}/grpc/gen-milvus/milvus.grpc.pb.cc ${MILVUS_ENGINE_SRC}/grpc/gen-milvus/milvus.pb.cc @@ -137,6 +143,7 @@ aux_source_directory(${MILVUS_ENGINE_SRC}/tracing tracing_files) aux_source_directory(${MILVUS_ENGINE_SRC}/codecs codecs_files) aux_source_directory(${MILVUS_ENGINE_SRC}/codecs/default codecs_default_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/codecs/snapshot codecs_snapshot_files) aux_source_directory(${MILVUS_ENGINE_SRC}/segment segment_files) @@ -158,6 +165,7 @@ set(engine_files ${wrapper_files} ${codecs_files} ${codecs_default_files} + ${codecs_snapshot_files} ${segment_files} ) diff --git a/core/src/codecs/Codec.h b/core/src/codecs/Codec.h index 972815e0e625..6b61b929792e 100644 --- a/core/src/codecs/Codec.h +++ b/core/src/codecs/Codec.h @@ -22,8 +22,10 @@ #include "DeletedDocsFormat.h" #include "IdBloomFilterFormat.h" #include "IdIndexFormat.h" +#include "VectorCompressFormat.h" #include "VectorIndexFormat.h" #include "VectorsFormat.h" +#include "utils/Exception.h" namespace milvus { namespace codec { @@ -31,35 +33,39 @@ namespace codec { class Codec { public: virtual VectorsFormatPtr - GetVectorsFormat() = 0; + GetVectorsFormat() { + throw Exception(SERVER_UNSUPPORTED_ERROR, "vectors not supported"); + } virtual AttrsFormatPtr - GetAttrsFormat() = 0; + GetAttrsFormat() { + throw Exception(SERVER_UNSUPPORTED_ERROR, "attr not supported"); + } virtual VectorIndexFormatPtr - GetVectorIndexFormat() = 0; + GetVectorIndexFormat() { + throw Exception(SERVER_UNSUPPORTED_ERROR, "vectors index not supported"); + } virtual AttrsIndexFormatPtr - GetAttrsIndexFormat() = 0; + GetAttrsIndexFormat() { + throw Exception(SERVER_UNSUPPORTED_ERROR, "attr index not supported"); + } virtual DeletedDocsFormatPtr - GetDeletedDocsFormat() = 0; + GetDeletedDocsFormat() { + throw Exception(SERVER_UNSUPPORTED_ERROR, "delete doc index not supported"); + } virtual IdBloomFilterFormatPtr - GetIdBloomFilterFormat() = 0; + GetIdBloomFilterFormat() { + throw Exception(SERVER_UNSUPPORTED_ERROR, "id bloom filter not supported"); + } - // TODO(zhiru) - /* - virtual AttrsFormat - GetAttrsFormat() = 0; - - virtual AttrsIndexFormat - GetAttrsIndexFormat() = 0; - - virtual IdIndexFormat - GetIdIndexFormat() = 0; - - */ + virtual VectorCompressFormatPtr + GetVectorCompressFormat() { + throw Exception(SERVER_UNSUPPORTED_ERROR, "vector compress not supported"); + } }; } // namespace codec diff --git a/core/src/codecs/VectorCompressFormat.h b/core/src/codecs/VectorCompressFormat.h new file mode 100644 index 000000000000..6d5edb24baef --- /dev/null +++ b/core/src/codecs/VectorCompressFormat.h @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "knowhere/common/BinarySet.h" +#include "storage/FSHandler.h" + +namespace milvus { +namespace codec { + +class VectorCompressFormat { + public: + virtual void + read(const storage::FSHandlerPtr& fs_ptr, const std::string& location, knowhere::BinaryPtr& compress) = 0; + + virtual void + write(const storage::FSHandlerPtr& fs_ptr, const std::string& location, const knowhere::BinaryPtr& compress) = 0; +}; + +using VectorCompressFormatPtr = std::shared_ptr; + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/VectorIndexFormat.h b/core/src/codecs/VectorIndexFormat.h index d25fcd712cce..6bc4e648778c 100644 --- a/core/src/codecs/VectorIndexFormat.h +++ b/core/src/codecs/VectorIndexFormat.h @@ -26,10 +26,13 @@ namespace milvus { namespace codec { +enum ExternalData { ExternalData_None, ExternalData_RawData, ExternalData_SQ8 }; + class VectorIndexFormat { public: virtual void - read(const storage::FSHandlerPtr& fs_ptr, const std::string& location, segment::VectorIndexPtr& vector_index) = 0; + read(const storage::FSHandlerPtr& fs_ptr, const std::string& location, ExternalData external_data, + segment::VectorIndexPtr& vector_index) = 0; virtual void write(const storage::FSHandlerPtr& fs_ptr, const std::string& location, diff --git a/core/src/codecs/VectorsFormat.h b/core/src/codecs/VectorsFormat.h index 5227f9a6a00a..150e9309d61b 100644 --- a/core/src/codecs/VectorsFormat.h +++ b/core/src/codecs/VectorsFormat.h @@ -20,6 +20,7 @@ #include #include +#include "index/knowhere/knowhere/common/BinarySet.h" #include "segment/Vectors.h" #include "storage/FSHandler.h" @@ -37,6 +38,9 @@ class VectorsFormat { virtual void read_uids(const storage::FSHandlerPtr& fs_ptr, std::vector& uids) = 0; + virtual void + read_vectors(const storage::FSHandlerPtr& fs_ptr, knowhere::BinaryPtr& raw_vectors) = 0; + virtual void read_vectors(const storage::FSHandlerPtr& fs_ptr, off_t offset, size_t num_bytes, std::vector& raw_vectors) = 0; diff --git a/core/src/codecs/default/DefaultAttrsFormat.cpp b/core/src/codecs/default/DefaultAttrsFormat.cpp index e4493cfb2aca..9c673e712ed1 100644 --- a/core/src/codecs/default/DefaultAttrsFormat.cpp +++ b/core/src/codecs/default/DefaultAttrsFormat.cpp @@ -78,8 +78,6 @@ DefaultAttrsFormat::read_uids_internal(const storage::FSHandlerPtr& fs_ptr, cons void DefaultAttrsFormat::read(const milvus::storage::FSHandlerPtr& fs_ptr, milvus::segment::AttrsPtr& attrs_read) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); auto is_directory = boost::filesystem::is_directory(dir_path); fiu_do_on("read_id_directory_false", is_directory = false); @@ -120,8 +118,6 @@ DefaultAttrsFormat::read(const milvus::storage::FSHandlerPtr& fs_ptr, milvus::se void DefaultAttrsFormat::write(const milvus::storage::FSHandlerPtr& fs_ptr, const milvus::segment::AttrsPtr& attrs_ptr) { - const std::lock_guard lock(mutex_); - TimeRecorder rc("write attributes"); std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); @@ -195,8 +191,6 @@ DefaultAttrsFormat::write(const milvus::storage::FSHandlerPtr& fs_ptr, const mil void DefaultAttrsFormat::read_attrs(const milvus::storage::FSHandlerPtr& fs_ptr, const std::string& field_name, off_t offset, size_t num_bytes, std::vector& raw_attrs) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); if (!boost::filesystem::is_directory(dir_path)) { std::string err_msg = "Directory: " + dir_path + "does not exist"; @@ -222,8 +216,6 @@ DefaultAttrsFormat::read_attrs(const milvus::storage::FSHandlerPtr& fs_ptr, cons void DefaultAttrsFormat::read_uids(const milvus::storage::FSHandlerPtr& fs_ptr, std::vector& uids) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); auto is_directory = boost::filesystem::is_directory(dir_path); fiu_do_on("is_directory_false", is_directory = false); diff --git a/core/src/codecs/default/DefaultAttrsFormat.h b/core/src/codecs/default/DefaultAttrsFormat.h index ed00531b6f77..c078d29dc68c 100644 --- a/core/src/codecs/default/DefaultAttrsFormat.h +++ b/core/src/codecs/default/DefaultAttrsFormat.h @@ -17,7 +17,6 @@ #pragma once -#include #include #include @@ -62,8 +61,6 @@ class DefaultAttrsFormat : public AttrsFormat { read_uids_internal(const storage::FSHandlerPtr& fs_ptr, const std::string&, std::vector&); private: - std::mutex mutex_; - const std::string raw_attr_extension_ = ".ra"; const std::string user_id_extension_ = ".uid"; }; diff --git a/core/src/codecs/default/DefaultAttrsIndexFormat.cpp b/core/src/codecs/default/DefaultAttrsIndexFormat.cpp index 875408fa15e5..faa9c6d3c445 100644 --- a/core/src/codecs/default/DefaultAttrsIndexFormat.cpp +++ b/core/src/codecs/default/DefaultAttrsIndexFormat.cpp @@ -139,8 +139,6 @@ DefaultAttrsIndexFormat::read_internal(const milvus::storage::FSHandlerPtr& fs_p void DefaultAttrsIndexFormat::read(const milvus::storage::FSHandlerPtr& fs_ptr, milvus::segment::AttrsIndexPtr& attrs_index) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); if (!boost::filesystem::is_directory(dir_path)) { std::string err_msg = "Directory: " + dir_path + "does not exist"; @@ -170,8 +168,6 @@ DefaultAttrsIndexFormat::read(const milvus::storage::FSHandlerPtr& fs_ptr, void DefaultAttrsIndexFormat::write(const milvus::storage::FSHandlerPtr& fs_ptr, const milvus::segment::AttrsIndexPtr& attrs_index) { - const std::lock_guard lock(mutex_); - milvus::TimeRecorder recorder("write_index"); recorder.RecordSection("Start"); diff --git a/core/src/codecs/default/DefaultAttrsIndexFormat.h b/core/src/codecs/default/DefaultAttrsIndexFormat.h index ca36f5c8f219..5d62f6fee571 100644 --- a/core/src/codecs/default/DefaultAttrsIndexFormat.h +++ b/core/src/codecs/default/DefaultAttrsIndexFormat.h @@ -18,7 +18,6 @@ #pragma once #include -#include #include #include @@ -56,8 +55,6 @@ class DefaultAttrsIndexFormat : public AttrsIndexFormat { create_structured_index(const engine::meta::hybrid::DataType data_type); private: - std::mutex mutex_; - const std::string attr_index_extension_ = ".idx"; }; diff --git a/core/src/codecs/default/DefaultCodec.cpp b/core/src/codecs/default/DefaultCodec.cpp index b57b4273423d..421ffd71f0c2 100644 --- a/core/src/codecs/default/DefaultCodec.cpp +++ b/core/src/codecs/default/DefaultCodec.cpp @@ -23,12 +23,19 @@ #include "DefaultAttrsIndexFormat.h" #include "DefaultDeletedDocsFormat.h" #include "DefaultIdBloomFilterFormat.h" +#include "DefaultVectorCompressFormat.h" #include "DefaultVectorIndexFormat.h" #include "DefaultVectorsFormat.h" namespace milvus { namespace codec { +DefaultCodec& +DefaultCodec::instance() { + static DefaultCodec s_instance; + return s_instance; +} + DefaultCodec::DefaultCodec() { vectors_format_ptr_ = std::make_shared(); attrs_format_ptr_ = std::make_shared(); @@ -36,6 +43,7 @@ DefaultCodec::DefaultCodec() { attrs_index_format_ptr_ = std::make_shared(); deleted_docs_format_ptr_ = std::make_shared(); id_bloom_filter_format_ptr_ = std::make_shared(); + vector_compress_format_ptr_ = std::make_shared(); } VectorsFormatPtr @@ -68,5 +76,10 @@ DefaultCodec::GetIdBloomFilterFormat() { return id_bloom_filter_format_ptr_; } +VectorCompressFormatPtr +DefaultCodec::GetVectorCompressFormat() { + return vector_compress_format_ptr_; +} + } // namespace codec } // namespace milvus diff --git a/core/src/codecs/default/DefaultCodec.h b/core/src/codecs/default/DefaultCodec.h index 3bf54e36d9e7..6f25de6a371f 100644 --- a/core/src/codecs/default/DefaultCodec.h +++ b/core/src/codecs/default/DefaultCodec.h @@ -24,7 +24,8 @@ namespace codec { class DefaultCodec : public Codec { public: - DefaultCodec(); + static DefaultCodec& + instance(); VectorsFormatPtr GetVectorsFormat() override; @@ -44,6 +45,12 @@ class DefaultCodec : public Codec { IdBloomFilterFormatPtr GetIdBloomFilterFormat() override; + VectorCompressFormatPtr + GetVectorCompressFormat() override; + + private: + DefaultCodec(); + private: VectorsFormatPtr vectors_format_ptr_; AttrsFormatPtr attrs_format_ptr_; @@ -51,6 +58,7 @@ class DefaultCodec : public Codec { AttrsIndexFormatPtr attrs_index_format_ptr_; DeletedDocsFormatPtr deleted_docs_format_ptr_; IdBloomFilterFormatPtr id_bloom_filter_format_ptr_; + VectorCompressFormatPtr vector_compress_format_ptr_; }; } // namespace codec diff --git a/core/src/codecs/default/DefaultDeletedDocsFormat.cpp b/core/src/codecs/default/DefaultDeletedDocsFormat.cpp index 8651a05be79b..c172e00806e3 100644 --- a/core/src/codecs/default/DefaultDeletedDocsFormat.cpp +++ b/core/src/codecs/default/DefaultDeletedDocsFormat.cpp @@ -36,8 +36,6 @@ namespace codec { void DefaultDeletedDocsFormat::read(const storage::FSHandlerPtr& fs_ptr, segment::DeletedDocsPtr& deleted_docs) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); const std::string del_file_path = dir_path + "/" + deleted_docs_filename_; @@ -76,8 +74,6 @@ DefaultDeletedDocsFormat::read(const storage::FSHandlerPtr& fs_ptr, segment::Del void DefaultDeletedDocsFormat::write(const storage::FSHandlerPtr& fs_ptr, const segment::DeletedDocsPtr& deleted_docs) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); const std::string del_file_path = dir_path + "/" + deleted_docs_filename_; @@ -148,8 +144,6 @@ DefaultDeletedDocsFormat::write(const storage::FSHandlerPtr& fs_ptr, const segme void DefaultDeletedDocsFormat::readSize(const storage::FSHandlerPtr& fs_ptr, size_t& size) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); const std::string del_file_path = dir_path + "/" + deleted_docs_filename_; diff --git a/core/src/codecs/default/DefaultDeletedDocsFormat.h b/core/src/codecs/default/DefaultDeletedDocsFormat.h index 06aff4c563fe..d8bc8a9ec8ed 100644 --- a/core/src/codecs/default/DefaultDeletedDocsFormat.h +++ b/core/src/codecs/default/DefaultDeletedDocsFormat.h @@ -17,7 +17,6 @@ #pragma once -#include #include #include "codecs/DeletedDocsFormat.h" @@ -48,8 +47,6 @@ class DefaultDeletedDocsFormat : public DeletedDocsFormat { operator=(DefaultDeletedDocsFormat&&) = delete; private: - std::mutex mutex_; - const std::string deleted_docs_filename_ = "deleted_docs"; }; diff --git a/core/src/codecs/default/DefaultIdBloomFilterFormat.cpp b/core/src/codecs/default/DefaultIdBloomFilterFormat.cpp index fb2eb9e5356f..ad39c905c3fd 100644 --- a/core/src/codecs/default/DefaultIdBloomFilterFormat.cpp +++ b/core/src/codecs/default/DefaultIdBloomFilterFormat.cpp @@ -32,8 +32,6 @@ constexpr double bloom_filter_error_rate = 0.01; void DefaultIdBloomFilterFormat::read(const storage::FSHandlerPtr& fs_ptr, segment::IdBloomFilterPtr& id_bloom_filter_ptr) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); const std::string bloom_filter_file_path = dir_path + "/" + bloom_filter_filename_; scaling_bloom_t* bloom_filter = @@ -51,8 +49,6 @@ DefaultIdBloomFilterFormat::read(const storage::FSHandlerPtr& fs_ptr, segment::I void DefaultIdBloomFilterFormat::write(const storage::FSHandlerPtr& fs_ptr, const segment::IdBloomFilterPtr& id_bloom_filter_ptr) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); const std::string bloom_filter_file_path = dir_path + "/" + bloom_filter_filename_; if (scaling_bloom_flush(id_bloom_filter_ptr->GetBloomFilter()) == -1) { diff --git a/core/src/codecs/default/DefaultIdBloomFilterFormat.h b/core/src/codecs/default/DefaultIdBloomFilterFormat.h index e35daad9ef8f..b6d66e77ca9c 100644 --- a/core/src/codecs/default/DefaultIdBloomFilterFormat.h +++ b/core/src/codecs/default/DefaultIdBloomFilterFormat.h @@ -17,7 +17,6 @@ #pragma once -#include #include #include "codecs/IdBloomFilterFormat.h" @@ -50,8 +49,6 @@ class DefaultIdBloomFilterFormat : public IdBloomFilterFormat { operator=(DefaultIdBloomFilterFormat&&) = delete; private: - std::mutex mutex_; - const std::string bloom_filter_filename_ = "bloom_filter"; }; diff --git a/core/src/codecs/default/DefaultVectorCompressFormat.cpp b/core/src/codecs/default/DefaultVectorCompressFormat.cpp new file mode 100644 index 000000000000..74d8d703b5d3 --- /dev/null +++ b/core/src/codecs/default/DefaultVectorCompressFormat.cpp @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "codecs/default/DefaultVectorCompressFormat.h" +#include "knowhere/common/BinarySet.h" +#include "utils/Exception.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace codec { + +void +DefaultVectorCompressFormat::read(const storage::FSHandlerPtr& fs_ptr, const std::string& location, + knowhere::BinaryPtr& compress) { + const std::string compress_file_path = location + sq8_vector_extension_; + + milvus::TimeRecorder recorder("read_index"); + + recorder.RecordSection("Start"); + if (!fs_ptr->reader_ptr_->open(compress_file_path)) { + LOG_ENGINE_ERROR_ << "Fail to open vector index: " << compress_file_path; + return; + } + + int64_t length = fs_ptr->reader_ptr_->length(); + if (length <= 0) { + LOG_ENGINE_ERROR_ << "Invalid vector index length: " << compress_file_path; + return; + } + + compress = std::make_shared(); + compress->data = std::shared_ptr(new uint8_t[length]); + compress->size = length; + + fs_ptr->reader_ptr_->seekg(0); + fs_ptr->reader_ptr_->read(compress->data.get(), length); + fs_ptr->reader_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = length * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "read_compress(" << compress_file_path << ") rate " << rate << "MB/s"; +} + +void +DefaultVectorCompressFormat::write(const storage::FSHandlerPtr& fs_ptr, const std::string& location, + const knowhere::BinaryPtr& compress) { + const std::string compress_file_path = location + sq8_vector_extension_; + + milvus::TimeRecorder recorder("write_index"); + + recorder.RecordSection("Start"); + if (!fs_ptr->writer_ptr_->open(compress_file_path)) { + LOG_ENGINE_ERROR_ << "Fail to open vector compress: " << compress_file_path; + return; + } + + fs_ptr->writer_ptr_->write(compress->data.get(), compress->size); + fs_ptr->writer_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = compress->size * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "write_compress(" << compress_file_path << ") rate " << rate << "MB/s"; +} + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/default/DefaultVectorCompressFormat.h b/core/src/codecs/default/DefaultVectorCompressFormat.h new file mode 100644 index 000000000000..010ece9a7373 --- /dev/null +++ b/core/src/codecs/default/DefaultVectorCompressFormat.h @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "codecs/VectorCompressFormat.h" + +namespace milvus { +namespace codec { + +class DefaultVectorCompressFormat : public VectorCompressFormat { + public: + DefaultVectorCompressFormat() = default; + + void + read(const storage::FSHandlerPtr& fs_ptr, const std::string& location, knowhere::BinaryPtr& compress) override; + + void + write(const storage::FSHandlerPtr& fs_ptr, const std::string& location, + const knowhere::BinaryPtr& compress) override; + + // No copy and move + DefaultVectorCompressFormat(const DefaultVectorCompressFormat&) = delete; + DefaultVectorCompressFormat(DefaultVectorCompressFormat&&) = delete; + + DefaultVectorCompressFormat& + operator=(const DefaultVectorCompressFormat&) = delete; + DefaultVectorCompressFormat& + operator=(DefaultVectorCompressFormat&&) = delete; + + private: + const std::string sq8_vector_extension_ = ".sq8"; +}; + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/default/DefaultVectorIndexFormat.cpp b/core/src/codecs/default/DefaultVectorIndexFormat.cpp index d543192dc429..33b1cc8cdd5e 100644 --- a/core/src/codecs/default/DefaultVectorIndexFormat.cpp +++ b/core/src/codecs/default/DefaultVectorIndexFormat.cpp @@ -18,6 +18,7 @@ #include #include +#include "codecs/default/DefaultCodec.h" #include "codecs/default/DefaultVectorIndexFormat.h" #include "knowhere/common/BinarySet.h" #include "knowhere/index/vector_index/VecIndex.h" @@ -31,7 +32,8 @@ namespace milvus { namespace codec { knowhere::VecIndexPtr -DefaultVectorIndexFormat::read_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& path) { +DefaultVectorIndexFormat::read_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& path, + const std::string& extern_key, const knowhere::BinaryPtr& extern_data) { milvus::TimeRecorder recorder("read_index"); knowhere::BinarySet load_data_list; @@ -91,8 +93,15 @@ DefaultVectorIndexFormat::read_internal(const storage::FSHandlerPtr& fs_ptr, con auto index = vec_index_factory.CreateVecIndex(knowhere::OldIndexTypeToStr(current_type), knowhere::IndexMode::MODE_CPU); if (index != nullptr) { + if (extern_data != nullptr) { + LOG_ENGINE_DEBUG_ << "load index with " << extern_key << " " << extern_data->size; + load_data_list.Append(extern_key, extern_data); + length += extern_data->size; + } + index->Load(load_data_list); - index->SetIndexSize(length); + index->UpdateIndexSize(); + LOG_ENGINE_DEBUG_ << "index file size " << length << " index size " << index->IndexSize(); } else { LOG_ENGINE_ERROR_ << "Fail to create vector index: " << path; } @@ -102,9 +111,7 @@ DefaultVectorIndexFormat::read_internal(const storage::FSHandlerPtr& fs_ptr, con void DefaultVectorIndexFormat::read(const storage::FSHandlerPtr& fs_ptr, const std::string& location, - segment::VectorIndexPtr& vector_index) { - const std::lock_guard lock(mutex_); - + ExternalData externalData, segment::VectorIndexPtr& vector_index) { std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); if (!boost::filesystem::is_directory(dir_path)) { std::string err_msg = "Directory: " + dir_path + "does not exist"; @@ -112,15 +119,36 @@ DefaultVectorIndexFormat::read(const storage::FSHandlerPtr& fs_ptr, const std::s throw Exception(SERVER_INVALID_ARGUMENT, err_msg); } - knowhere::VecIndexPtr index = read_internal(fs_ptr, location); + knowhere::VecIndexPtr index = nullptr; + switch (externalData) { + case ExternalData_None: { + index = read_internal(fs_ptr, location); + break; + } + case ExternalData_RawData: { + auto& default_codec = codec::DefaultCodec::instance(); + knowhere::BinaryPtr raw_data = nullptr; + default_codec.GetVectorsFormat()->read_vectors(fs_ptr, raw_data); + + index = read_internal(fs_ptr, location, RAW_DATA, raw_data); + break; + } + case ExternalData_SQ8: { + auto& default_codec = codec::DefaultCodec::instance(); + knowhere::BinaryPtr sq8_data = nullptr; + default_codec.GetVectorCompressFormat()->read(fs_ptr, location, sq8_data); + + index = read_internal(fs_ptr, location, SQ8_DATA, sq8_data); + break; + } + } + vector_index->SetVectorIndex(index); } void DefaultVectorIndexFormat::write(const storage::FSHandlerPtr& fs_ptr, const std::string& location, const segment::VectorIndexPtr& vector_index) { - const std::lock_guard lock(mutex_); - milvus::TimeRecorder recorder("write_index"); knowhere::VecIndexPtr index = vector_index->GetVectorIndex(); @@ -128,6 +156,12 @@ DefaultVectorIndexFormat::write(const storage::FSHandlerPtr& fs_ptr, const std:: auto binaryset = index->Serialize(knowhere::Config()); int32_t index_type = knowhere::StrToOldIndexType(index->index_type()); + auto sq8_data = binaryset.Erase(SQ8_DATA); + if (sq8_data != nullptr) { + auto& default_codec = codec::DefaultCodec::instance(); + default_codec.GetVectorCompressFormat()->write(fs_ptr, location, sq8_data); + } + recorder.RecordSection("Start"); if (!fs_ptr->writer_ptr_->open(location)) { LOG_ENGINE_ERROR_ << "Fail to open vector index: " << location; diff --git a/core/src/codecs/default/DefaultVectorIndexFormat.h b/core/src/codecs/default/DefaultVectorIndexFormat.h index 945ff31f4580..fae2a99bf6af 100644 --- a/core/src/codecs/default/DefaultVectorIndexFormat.h +++ b/core/src/codecs/default/DefaultVectorIndexFormat.h @@ -17,7 +17,6 @@ #pragma once -#include #include #include "codecs/VectorIndexFormat.h" @@ -30,7 +29,7 @@ class DefaultVectorIndexFormat : public VectorIndexFormat { DefaultVectorIndexFormat() = default; void - read(const storage::FSHandlerPtr& fs_ptr, const std::string& location, + read(const storage::FSHandlerPtr& fs_ptr, const std::string& location, ExternalData externalData, segment::VectorIndexPtr& vector_index) override; void @@ -48,12 +47,8 @@ class DefaultVectorIndexFormat : public VectorIndexFormat { private: knowhere::VecIndexPtr - read_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& path); - - private: - std::mutex mutex_; - - const std::string vector_index_extension_ = ""; + read_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& path, const std::string& extern_key = "", + const knowhere::BinaryPtr& extern_data = nullptr); }; } // namespace codec diff --git a/core/src/codecs/default/DefaultVectorsFormat.cpp b/core/src/codecs/default/DefaultVectorsFormat.cpp index 9e6e50d2f00d..2b6cba9b2d7e 100644 --- a/core/src/codecs/default/DefaultVectorsFormat.cpp +++ b/core/src/codecs/default/DefaultVectorsFormat.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include @@ -53,6 +54,30 @@ DefaultVectorsFormat::read_vectors_internal(const storage::FSHandlerPtr& fs_ptr, fs_ptr->reader_ptr_->close(); } +void +DefaultVectorsFormat::read_vectors_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + knowhere::BinaryPtr& raw_vectors) { + if (!fs_ptr->reader_ptr_->open(file_path.c_str())) { + std::string err_msg = "Failed to open file: " + file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_OPEN_FILE, err_msg); + } + + size_t num_bytes; + fs_ptr->reader_ptr_->read(&num_bytes, sizeof(size_t)); + + raw_vectors = std::make_shared(); + raw_vectors->size = num_bytes; + raw_vectors->data = std::shared_ptr(new uint8_t[num_bytes]); + + // Beginning of file is num_bytes + fs_ptr->reader_ptr_->seekg(sizeof(size_t)); + + fs_ptr->reader_ptr_->read(raw_vectors->data.get(), num_bytes); + + fs_ptr->reader_ptr_->close(); +} + void DefaultVectorsFormat::read_uids_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, std::vector& uids) { @@ -73,8 +98,6 @@ DefaultVectorsFormat::read_uids_internal(const storage::FSHandlerPtr& fs_ptr, co void DefaultVectorsFormat::read(const storage::FSHandlerPtr& fs_ptr, segment::VectorsPtr& vectors_read) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); if (!boost::filesystem::is_directory(dir_path)) { std::string err_msg = "Directory: " + dir_path + "does not exist"; @@ -102,8 +125,6 @@ DefaultVectorsFormat::read(const storage::FSHandlerPtr& fs_ptr, segment::Vectors void DefaultVectorsFormat::write(const storage::FSHandlerPtr& fs_ptr, const segment::VectorsPtr& vectors) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); const std::string rv_file_path = dir_path + "/" + vectors->GetName() + raw_vector_extension_; @@ -139,8 +160,6 @@ DefaultVectorsFormat::write(const storage::FSHandlerPtr& fs_ptr, const segment:: void DefaultVectorsFormat::read_uids(const storage::FSHandlerPtr& fs_ptr, std::vector& uids) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); if (!boost::filesystem::is_directory(dir_path)) { std::string err_msg = "Directory: " + dir_path + "does not exist"; @@ -157,6 +176,30 @@ DefaultVectorsFormat::read_uids(const storage::FSHandlerPtr& fs_ptr, std::vector const auto& path = it->path(); if (path.extension().string() == user_id_extension_) { read_uids_internal(fs_ptr, path.string(), uids); + break; + } + } +} + +void +DefaultVectorsFormat::read_vectors(const storage::FSHandlerPtr& fs_ptr, knowhere::BinaryPtr& raw_vectors) { + std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); + if (!boost::filesystem::is_directory(dir_path)) { + std::string err_msg = "Directory: " + dir_path + "does not exist"; + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_INVALID_ARGUMENT, err_msg); + } + + boost::filesystem::path target_path(dir_path); + typedef boost::filesystem::directory_iterator d_it; + d_it it_end; + d_it it(target_path); + // for (auto& it : boost::filesystem::directory_iterator(dir_path)) { + for (; it != it_end; ++it) { + const auto& path = it->path(); + if (path.extension().string() == raw_vector_extension_) { + read_vectors_internal(fs_ptr, path.string(), raw_vectors); + break; } } } @@ -164,8 +207,6 @@ DefaultVectorsFormat::read_uids(const storage::FSHandlerPtr& fs_ptr, std::vector void DefaultVectorsFormat::read_vectors(const storage::FSHandlerPtr& fs_ptr, off_t offset, size_t num_bytes, std::vector& raw_vectors) { - const std::lock_guard lock(mutex_); - std::string dir_path = fs_ptr->operation_ptr_->GetDirectory(); if (!boost::filesystem::is_directory(dir_path)) { std::string err_msg = "Directory: " + dir_path + "does not exist"; @@ -182,6 +223,7 @@ DefaultVectorsFormat::read_vectors(const storage::FSHandlerPtr& fs_ptr, off_t of const auto& path = it->path(); if (path.extension().string() == raw_vector_extension_) { read_vectors_internal(fs_ptr, path.string(), offset, num_bytes, raw_vectors); + break; } } } diff --git a/core/src/codecs/default/DefaultVectorsFormat.h b/core/src/codecs/default/DefaultVectorsFormat.h index ac5fc89a5a25..c331b3da0152 100644 --- a/core/src/codecs/default/DefaultVectorsFormat.h +++ b/core/src/codecs/default/DefaultVectorsFormat.h @@ -17,7 +17,6 @@ #pragma once -#include #include #include @@ -40,6 +39,9 @@ class DefaultVectorsFormat : public VectorsFormat { void read_uids(const storage::FSHandlerPtr& fs_ptr, std::vector& uids) override; + void + read_vectors(const storage::FSHandlerPtr& fs_ptr, knowhere::BinaryPtr& raw_vectors) override; + void read_vectors(const storage::FSHandlerPtr& fs_ptr, off_t offset, size_t num_bytes, std::vector& raw_vectors) override; @@ -58,13 +60,15 @@ class DefaultVectorsFormat : public VectorsFormat { read_vectors_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, off_t offset, size_t num, std::vector& raw_vectors); + void + read_vectors_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + knowhere::BinaryPtr& raw_vectors); + void read_uids_internal(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, std::vector& uids); private: - std::mutex mutex_; - const std::string raw_vector_extension_ = ".rv"; const std::string user_id_extension_ = ".uid"; }; diff --git a/core/src/codecs/snapshot/SSBlockFormat.cpp b/core/src/codecs/snapshot/SSBlockFormat.cpp new file mode 100644 index 000000000000..44726e0bbceb --- /dev/null +++ b/core/src/codecs/snapshot/SSBlockFormat.cpp @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "codecs/snapshot/SSBlockFormat.h" + +#include +#include +#include +#include + +#include + +#include "utils/Exception.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace codec { + +void +SSBlockFormat::Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, std::vector& raw) { + if (!fs_ptr->reader_ptr_->open(file_path.c_str())) { + std::string err_msg = "Failed to open file: " + file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_OPEN_FILE, err_msg); + } + + size_t num_bytes; + fs_ptr->reader_ptr_->read(&num_bytes, sizeof(size_t)); + + raw.resize(num_bytes); + fs_ptr->reader_ptr_->read(raw.data(), num_bytes); + + fs_ptr->reader_ptr_->close(); +} + +void +SSBlockFormat::Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, int64_t offset, + int64_t num_bytes, std::vector& raw) { + if (offset < 0 || num_bytes <= 0) { + std::string err_msg = "Invalid input to read: " + file_path; + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_INVALID_ARGUMENT, err_msg); + } + + if (!fs_ptr->reader_ptr_->open(file_path.c_str())) { + std::string err_msg = "Failed to open file: " + file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_OPEN_FILE, err_msg); + } + + size_t total_num_bytes; + fs_ptr->reader_ptr_->read(&total_num_bytes, sizeof(size_t)); + + offset += sizeof(size_t); // Beginning of file is num_bytes + if (offset + num_bytes > total_num_bytes) { + std::string err_msg = "Invalid input to read: " + file_path; + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_INVALID_ARGUMENT, err_msg); + } + + raw.resize(num_bytes); + fs_ptr->reader_ptr_->seekg(offset); + fs_ptr->reader_ptr_->read(raw.data(), num_bytes); + fs_ptr->reader_ptr_->close(); +} + +void +SSBlockFormat::Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, const ReadRanges& read_ranges, + std::vector& raw) { + if (read_ranges.empty()) { + return; + } + + if (!fs_ptr->reader_ptr_->open(file_path.c_str())) { + std::string err_msg = "Failed to open file: " + file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_OPEN_FILE, err_msg); + } + + size_t total_num_bytes; + fs_ptr->reader_ptr_->read(&total_num_bytes, sizeof(size_t)); + + int64_t total_bytes = 0; + for (auto& range : read_ranges) { + int64_t offset = range.offset_ + sizeof(size_t); + if (offset + range.num_bytes_ > total_num_bytes) { + std::string err_msg = "Invalid input to read: " + file_path; + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_INVALID_ARGUMENT, err_msg); + } + + total_bytes += range.num_bytes_; + } + + raw.clear(); + raw.resize(total_bytes); + int64_t poz = 0; + for (auto& range : read_ranges) { + int64_t offset = range.offset_ + sizeof(size_t); + fs_ptr->reader_ptr_->seekg(offset); + fs_ptr->reader_ptr_->read(raw.data() + poz, range.num_bytes_); + poz += range.num_bytes_; + } + + fs_ptr->reader_ptr_->close(); +} + +void +SSBlockFormat::Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const std::vector& raw) { + if (!fs_ptr->writer_ptr_->open(file_path.c_str())) { + std::string err_msg = "Failed to open file: " + file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_CREATE_FILE, err_msg); + } + + size_t num_bytes = raw.size(); + fs_ptr->writer_ptr_->write(&num_bytes, sizeof(size_t)); + fs_ptr->writer_ptr_->write((void*)raw.data(), num_bytes); + fs_ptr->writer_ptr_->close(); +} + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSBlockFormat.h b/core/src/codecs/snapshot/SSBlockFormat.h new file mode 100644 index 000000000000..4abb38caa2c9 --- /dev/null +++ b/core/src/codecs/snapshot/SSBlockFormat.h @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "knowhere/common/BinarySet.h" +#include "storage/FSHandler.h" + +namespace milvus { +namespace codec { + +struct ReadRange { + ReadRange(int64_t offset, int64_t num_bytes) : offset_(offset), num_bytes_(num_bytes) { + } + int64_t offset_; + int64_t num_bytes_; +}; + +using ReadRanges = std::vector; + +class SSBlockFormat { + public: + SSBlockFormat() = default; + + void + Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, std::vector& raw); + + void + Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, int64_t offset, int64_t num_bytes, + std::vector& raw); + + void + Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, const ReadRanges& read_ranges, + std::vector& raw); + + void + Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, const std::vector& raw); + + // No copy and move + SSBlockFormat(const SSBlockFormat&) = delete; + SSBlockFormat(SSBlockFormat&&) = delete; + + SSBlockFormat& + operator=(const SSBlockFormat&) = delete; + SSBlockFormat& + operator=(SSBlockFormat&&) = delete; +}; + +using SSBlockFormatPtr = std::shared_ptr; + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSCodec.cpp b/core/src/codecs/snapshot/SSCodec.cpp new file mode 100644 index 000000000000..a0cf7ba69e8e --- /dev/null +++ b/core/src/codecs/snapshot/SSCodec.cpp @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "codecs/snapshot/SSCodec.h" + +#include + +#include "SSDeletedDocsFormat.h" +#include "SSIdBloomFilterFormat.h" +#include "SSStructuredIndexFormat.h" +#include "SSVectorIndexFormat.h" + +namespace milvus { +namespace codec { + +SSCodec& +SSCodec::instance() { + static SSCodec s_instance; + return s_instance; +} + +SSCodec::SSCodec() { + block_format_ptr_ = std::make_shared(); + structured_index_format_ptr_ = std::make_shared(); + vector_index_format_ptr_ = std::make_shared(); + deleted_docs_format_ptr_ = std::make_shared(); + id_bloom_filter_format_ptr_ = std::make_shared(); + vector_compress_format_ptr_ = std::make_shared(); +} + +SSBlockFormatPtr +SSCodec::GetBlockFormat() { + return block_format_ptr_; +} + +SSVectorIndexFormatPtr +SSCodec::GetVectorIndexFormat() { + return vector_index_format_ptr_; +} + +SSStructuredIndexFormatPtr +SSCodec::GetStructuredIndexFormat() { + return structured_index_format_ptr_; +} + +SSDeletedDocsFormatPtr +SSCodec::GetDeletedDocsFormat() { + return deleted_docs_format_ptr_; +} + +SSIdBloomFilterFormatPtr +SSCodec::GetIdBloomFilterFormat() { + return id_bloom_filter_format_ptr_; +} + +SSVectorCompressFormatPtr +SSCodec::GetVectorCompressFormat() { + return vector_compress_format_ptr_; +} +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSCodec.h b/core/src/codecs/snapshot/SSCodec.h new file mode 100644 index 000000000000..b5ee9cab729c --- /dev/null +++ b/core/src/codecs/snapshot/SSCodec.h @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "codecs/snapshot/SSBlockFormat.h" +#include "codecs/snapshot/SSDeletedDocsFormat.h" +#include "codecs/snapshot/SSIdBloomFilterFormat.h" +#include "codecs/snapshot/SSStructuredIndexFormat.h" +#include "codecs/snapshot/SSVectorCompressFormat.h" +#include "codecs/snapshot/SSVectorIndexFormat.h" + +namespace milvus { +namespace codec { + +class SSCodec { + public: + static SSCodec& + instance(); + + SSBlockFormatPtr + GetBlockFormat(); + + SSVectorIndexFormatPtr + GetVectorIndexFormat(); + + SSStructuredIndexFormatPtr + GetStructuredIndexFormat(); + + SSDeletedDocsFormatPtr + GetDeletedDocsFormat(); + + SSIdBloomFilterFormatPtr + GetIdBloomFilterFormat(); + + SSVectorCompressFormatPtr + GetVectorCompressFormat(); + + private: + SSCodec(); + + private: + SSBlockFormatPtr block_format_ptr_; + SSStructuredIndexFormatPtr structured_index_format_ptr_; + SSVectorIndexFormatPtr vector_index_format_ptr_; + SSDeletedDocsFormatPtr deleted_docs_format_ptr_; + SSIdBloomFilterFormatPtr id_bloom_filter_format_ptr_; + SSVectorCompressFormatPtr vector_compress_format_ptr_; +}; + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSDeletedDocsFormat.cpp b/core/src/codecs/snapshot/SSDeletedDocsFormat.cpp new file mode 100644 index 000000000000..bc998a4f70a5 --- /dev/null +++ b/core/src/codecs/snapshot/SSDeletedDocsFormat.cpp @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "codecs/snapshot/SSDeletedDocsFormat.h" + +#include +#include + +#define BOOST_NO_CXX11_SCOPED_ENUMS + +#include + +#undef BOOST_NO_CXX11_SCOPED_ENUMS + +#include +#include +#include + +#include "segment/Types.h" +#include "utils/Exception.h" +#include "utils/Log.h" + +namespace milvus { +namespace codec { + +const char* DELETED_DOCS_POSTFIX = ".del"; + +std::string +SSDeletedDocsFormat::FilePostfix() { + std::string str = DELETED_DOCS_POSTFIX; + return str; +} + +void +SSDeletedDocsFormat::Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + segment::DeletedDocsPtr& deleted_docs) { + const std::string full_file_path = file_path + DELETED_DOCS_POSTFIX; + + int del_fd = open(full_file_path.c_str(), O_RDONLY, 00664); + if (del_fd == -1) { + std::string err_msg = "Failed to open file: " + full_file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_CREATE_FILE, err_msg); + } + + size_t num_bytes; + if (::read(del_fd, &num_bytes, sizeof(size_t)) == -1) { + std::string err_msg = "Failed to read from file: " + full_file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + + auto deleted_docs_size = num_bytes / sizeof(segment::offset_t); + std::vector deleted_docs_list; + deleted_docs_list.resize(deleted_docs_size); + + if (::read(del_fd, deleted_docs_list.data(), num_bytes) == -1) { + std::string err_msg = "Failed to read from file: " + full_file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + + deleted_docs = std::make_shared(deleted_docs_list); + + if (::close(del_fd) == -1) { + std::string err_msg = "Failed to close file: " + full_file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } +} + +void +SSDeletedDocsFormat::Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const segment::DeletedDocsPtr& deleted_docs) { + const std::string full_file_path = file_path + DELETED_DOCS_POSTFIX; + + // Create a temporary file from the existing file + const std::string temp_path = file_path + ".temp_del"; + bool exists = boost::filesystem::exists(full_file_path); + if (exists) { + boost::filesystem::copy_file(full_file_path, temp_path, boost::filesystem::copy_option::fail_if_exists); + } + + // Write to the temp file, in order to avoid possible race condition with search (concurrent read and write) + int del_fd = open(temp_path.c_str(), O_RDWR | O_CREAT, 00664); + if (del_fd == -1) { + std::string err_msg = "Failed to open file: " + temp_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_CREATE_FILE, err_msg); + } + + size_t old_num_bytes; + if (exists) { + if (::read(del_fd, &old_num_bytes, sizeof(size_t)) == -1) { + std::string err_msg = "Failed to read from file: " + temp_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + } else { + old_num_bytes = 0; + } + + auto deleted_docs_list = deleted_docs->GetDeletedDocs(); + size_t new_num_bytes = old_num_bytes + sizeof(segment::offset_t) * deleted_docs->GetSize(); + + // rewind and overwrite with the new_num_bytes + int off = lseek(del_fd, 0, SEEK_SET); + if (off == -1) { + std::string err_msg = "Failed to seek file: " + temp_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + if (::write(del_fd, &new_num_bytes, sizeof(size_t)) == -1) { + std::string err_msg = "Failed to write to file" + temp_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + + // Move to the end of file and append + off = lseek(del_fd, 0, SEEK_END); + if (off == -1) { + std::string err_msg = "Failed to seek file: " + temp_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + if (::write(del_fd, deleted_docs_list.data(), sizeof(segment::offset_t) * deleted_docs->GetSize()) == -1) { + std::string err_msg = "Failed to write to file" + temp_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + + if (::close(del_fd) == -1) { + std::string err_msg = "Failed to close file: " + temp_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + + // Move temp file to delete file + boost::filesystem::rename(temp_path, full_file_path); +} + +void +SSDeletedDocsFormat::ReadSize(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, size_t& size) { + const std::string full_file_path = file_path + DELETED_DOCS_POSTFIX; + int del_fd = open(full_file_path.c_str(), O_RDONLY, 00664); + if (del_fd == -1) { + std::string err_msg = "Failed to open file: " + full_file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_CREATE_FILE, err_msg); + } + + size_t num_bytes; + if (::read(del_fd, &num_bytes, sizeof(size_t)) == -1) { + std::string err_msg = "Failed to read from file: " + full_file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } + + size = num_bytes / sizeof(segment::offset_t); + + if (::close(del_fd) == -1) { + std::string err_msg = "Failed to close file: " + full_file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_WRITE_ERROR, err_msg); + } +} + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSDeletedDocsFormat.h b/core/src/codecs/snapshot/SSDeletedDocsFormat.h new file mode 100644 index 000000000000..24300ef37a96 --- /dev/null +++ b/core/src/codecs/snapshot/SSDeletedDocsFormat.h @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "segment/DeletedDocs.h" +#include "storage/FSHandler.h" + +namespace milvus { +namespace codec { + +class SSDeletedDocsFormat { + public: + SSDeletedDocsFormat() = default; + + std::string + FilePostfix(); + + void + Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, segment::DeletedDocsPtr& deleted_docs); + + void + Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const segment::DeletedDocsPtr& deleted_docs); + + void + ReadSize(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, size_t& size); + + // No copy and move + SSDeletedDocsFormat(const SSDeletedDocsFormat&) = delete; + SSDeletedDocsFormat(SSDeletedDocsFormat&&) = delete; + + SSDeletedDocsFormat& + operator=(const SSDeletedDocsFormat&) = delete; + SSDeletedDocsFormat& + operator=(SSDeletedDocsFormat&&) = delete; +}; + +using SSDeletedDocsFormatPtr = std::shared_ptr; + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSIdBloomFilterFormat.cpp b/core/src/codecs/snapshot/SSIdBloomFilterFormat.cpp new file mode 100644 index 000000000000..cdc6c2612b61 --- /dev/null +++ b/core/src/codecs/snapshot/SSIdBloomFilterFormat.cpp @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "codecs/snapshot/SSIdBloomFilterFormat.h" + +#include +#include +#include + +#include "utils/Exception.h" +#include "utils/Log.h" + +namespace milvus { +namespace codec { + +const char* BLOOM_FILTER_POSTFIX = ".bf"; + +constexpr unsigned int BLOOM_FILTER_CAPACITY = 500000; +constexpr double BLOOM_FILTER_ERROR_RATE = 0.01; + +std::string +SSIdBloomFilterFormat::FilePostfix() { + std::string str = BLOOM_FILTER_POSTFIX; + return str; +} + +void +SSIdBloomFilterFormat::Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + segment::IdBloomFilterPtr& id_bloom_filter_ptr) { + const std::string full_file_path = file_path + BLOOM_FILTER_POSTFIX; + scaling_bloom_t* bloom_filter = + new_scaling_bloom_from_file(BLOOM_FILTER_CAPACITY, BLOOM_FILTER_ERROR_RATE, full_file_path.c_str()); + fiu_do_on("bloom_filter_nullptr", bloom_filter = nullptr); + if (bloom_filter == nullptr) { + std::string err_msg = "Failed to read bloom filter from file: " + full_file_path + ". " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_UNEXPECTED_ERROR, err_msg); + } + id_bloom_filter_ptr = std::make_shared(bloom_filter); +} + +void +SSIdBloomFilterFormat::Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const segment::IdBloomFilterPtr& id_bloom_filter_ptr) { + const std::string full_file_path = file_path + BLOOM_FILTER_POSTFIX; + if (scaling_bloom_flush(id_bloom_filter_ptr->GetBloomFilter()) == -1) { + std::string err_msg = "Failed to write bloom filter to file: " + full_file_path + ". " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_UNEXPECTED_ERROR, err_msg); + } +} + +void +SSIdBloomFilterFormat::Create(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + segment::IdBloomFilterPtr& id_bloom_filter_ptr) { + const std::string full_file_path = file_path + BLOOM_FILTER_POSTFIX; + scaling_bloom_t* bloom_filter = + new_scaling_bloom(BLOOM_FILTER_CAPACITY, BLOOM_FILTER_ERROR_RATE, full_file_path.c_str()); + if (bloom_filter == nullptr) { + std::string err_msg = "Failed to read bloom filter from file: " + full_file_path + ". " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_UNEXPECTED_ERROR, err_msg); + } + id_bloom_filter_ptr = std::make_shared(bloom_filter); +} + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSIdBloomFilterFormat.h b/core/src/codecs/snapshot/SSIdBloomFilterFormat.h new file mode 100644 index 000000000000..69f95c057448 --- /dev/null +++ b/core/src/codecs/snapshot/SSIdBloomFilterFormat.h @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "segment/IdBloomFilter.h" +#include "storage/FSHandler.h" + +namespace milvus { +namespace codec { + +class SSIdBloomFilterFormat { + public: + SSIdBloomFilterFormat() = default; + + std::string + FilePostfix(); + + void + Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + segment::IdBloomFilterPtr& id_bloom_filter_ptr); + + void + Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const segment::IdBloomFilterPtr& id_bloom_filter_ptr); + + void + Create(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + segment::IdBloomFilterPtr& id_bloom_filter_ptr); + + // No copy and move + SSIdBloomFilterFormat(const SSIdBloomFilterFormat&) = delete; + SSIdBloomFilterFormat(SSIdBloomFilterFormat&&) = delete; + + SSIdBloomFilterFormat& + operator=(const SSIdBloomFilterFormat&) = delete; + SSIdBloomFilterFormat& + operator=(SSIdBloomFilterFormat&&) = delete; +}; + +using SSIdBloomFilterFormatPtr = std::shared_ptr; + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSStructuredIndexFormat.cpp b/core/src/codecs/snapshot/SSStructuredIndexFormat.cpp new file mode 100644 index 000000000000..bf664ef993e1 --- /dev/null +++ b/core/src/codecs/snapshot/SSStructuredIndexFormat.cpp @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "codecs/snapshot/SSStructuredIndexFormat.h" + +#include +#include +#include +#include +#include +#include + +#include "db/meta/MetaTypes.h" +#include "knowhere/index/structured_index/StructuredIndexSort.h" + +#include "utils/Exception.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace codec { + +const char* STRUCTURED_INDEX_POSTFIX = ".ind"; + +std::string +SSStructuredIndexFormat::FilePostfix() { + std::string str = STRUCTURED_INDEX_POSTFIX; + return str; +} + +knowhere::IndexPtr +SSStructuredIndexFormat::CreateStructuredIndex(const milvus::engine::meta::hybrid::DataType data_type) { + knowhere::IndexPtr index = nullptr; + switch (data_type) { + case engine::meta::hybrid::DataType::INT8: { + index = std::make_shared>(); + break; + } + case engine::meta::hybrid::DataType::INT16: { + index = std::make_shared>(); + break; + } + case engine::meta::hybrid::DataType::INT32: { + index = std::make_shared>(); + break; + } + case engine::meta::hybrid::DataType::INT64: { + index = std::make_shared>(); + break; + } + case engine::meta::hybrid::DataType::FLOAT: { + index = std::make_shared>(); + break; + } + case engine::meta::hybrid::DataType::DOUBLE: { + index = std::make_shared>(); + break; + } + default: { + LOG_ENGINE_ERROR_ << "Invalid field type"; + return nullptr; + } + } + return index; +} + +void +SSStructuredIndexFormat::Read(const milvus::storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + knowhere::IndexPtr& index) { + milvus::TimeRecorder recorder("SSStructuredIndexFormat::Read"); + knowhere::BinarySet load_data_list; + + std::string full_file_path = file_path + STRUCTURED_INDEX_POSTFIX; + if (!fs_ptr->reader_ptr_->open(full_file_path)) { + LOG_ENGINE_ERROR_ << "Fail to open structured index: " << full_file_path; + return; + } + int64_t length = fs_ptr->reader_ptr_->length(); + if (length <= 0) { + LOG_ENGINE_ERROR_ << "Invalid structured index length: " << full_file_path; + return; + } + + size_t rp = 0; + fs_ptr->reader_ptr_->seekg(0); + + int32_t data_type = 0; + fs_ptr->reader_ptr_->read(&data_type, sizeof(data_type)); + rp += sizeof(data_type); + fs_ptr->reader_ptr_->seekg(rp); + + LOG_ENGINE_DEBUG_ << "Start to read_index(" << full_file_path << ") length: " << length << " bytes"; + while (rp < length) { + size_t meta_length; + fs_ptr->reader_ptr_->read(&meta_length, sizeof(meta_length)); + rp += sizeof(meta_length); + fs_ptr->reader_ptr_->seekg(rp); + + auto meta = new char[meta_length]; + fs_ptr->reader_ptr_->read(meta, meta_length); + rp += meta_length; + fs_ptr->reader_ptr_->seekg(rp); + + size_t bin_length; + fs_ptr->reader_ptr_->read(&bin_length, sizeof(bin_length)); + rp += sizeof(bin_length); + fs_ptr->reader_ptr_->seekg(rp); + + auto bin = new uint8_t[bin_length]; + fs_ptr->reader_ptr_->read(bin, bin_length); + rp += bin_length; + fs_ptr->reader_ptr_->seekg(rp); + + std::shared_ptr binptr(bin); + load_data_list.Append(std::string(meta, meta_length), binptr, bin_length); + delete[] meta; + } + fs_ptr->reader_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = length * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "SSStructuredIndexFormat::read(" << full_file_path << ") rate " << rate << "MB/s"; + + auto attr_type = static_cast(data_type); + index = CreateStructuredIndex(attr_type); + index->Load(load_data_list); +} + +void +SSStructuredIndexFormat::Write(const milvus::storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + engine::meta::hybrid::DataType data_type, const knowhere::IndexPtr& index) { + milvus::TimeRecorder recorder("SSStructuredIndexFormat::Write"); + + std::string full_file_path = file_path + STRUCTURED_INDEX_POSTFIX; + auto binaryset = index->Serialize(knowhere::Config()); + + if (!fs_ptr->writer_ptr_->open(full_file_path)) { + LOG_ENGINE_ERROR_ << "Fail to open structured index: " << full_file_path; + return; + } + fs_ptr->writer_ptr_->write(&data_type, sizeof(data_type)); + + for (auto& iter : binaryset.binary_map_) { + auto meta = iter.first.c_str(); + size_t meta_length = iter.first.length(); + fs_ptr->writer_ptr_->write(&meta_length, sizeof(meta_length)); + fs_ptr->writer_ptr_->write((void*)meta, meta_length); + + auto binary = iter.second; + int64_t binary_length = binary->size; + fs_ptr->writer_ptr_->write(&binary_length, sizeof(binary_length)); + fs_ptr->writer_ptr_->write((void*)binary->data.get(), binary_length); + } + + fs_ptr->writer_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = fs_ptr->writer_ptr_->length() * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "SSStructuredIndexFormat::write(" << full_file_path << ") rate " << rate << "MB/s"; +} + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSStructuredIndexFormat.h b/core/src/codecs/snapshot/SSStructuredIndexFormat.h new file mode 100644 index 000000000000..447f24cccead --- /dev/null +++ b/core/src/codecs/snapshot/SSStructuredIndexFormat.h @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "db/meta/MetaTypes.h" +#include "knowhere/index/Index.h" +#include "storage/FSHandler.h" + +namespace milvus { +namespace codec { + +class SSStructuredIndexFormat { + public: + SSStructuredIndexFormat() = default; + + std::string + FilePostfix(); + + void + Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, knowhere::IndexPtr& index); + + void + Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, engine::meta::hybrid::DataType data_type, + const knowhere::IndexPtr& index); + + // No copy and move + SSStructuredIndexFormat(const SSStructuredIndexFormat&) = delete; + SSStructuredIndexFormat(SSStructuredIndexFormat&&) = delete; + + SSStructuredIndexFormat& + operator=(const SSStructuredIndexFormat&) = delete; + SSStructuredIndexFormat& + operator=(SSStructuredIndexFormat&&) = delete; + + private: + knowhere::IndexPtr + CreateStructuredIndex(const engine::meta::hybrid::DataType data_type); +}; + +using SSStructuredIndexFormatPtr = std::shared_ptr; + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSVectorCompressFormat.cpp b/core/src/codecs/snapshot/SSVectorCompressFormat.cpp new file mode 100644 index 000000000000..452677eae749 --- /dev/null +++ b/core/src/codecs/snapshot/SSVectorCompressFormat.cpp @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "codecs/snapshot/SSVectorCompressFormat.h" +#include "knowhere/common/BinarySet.h" +#include "utils/Exception.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace codec { + +const char* VECTOR_COMPRESS_POSTFIX = ".cmp"; + +std::string +SSVectorCompressFormat::FilePostfix() { + std::string str = VECTOR_COMPRESS_POSTFIX; + return str; +} + +void +SSVectorCompressFormat::Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + knowhere::BinaryPtr& compress) { + milvus::TimeRecorder recorder("SSVectorCompressFormat::Read"); + + const std::string full_file_path = file_path + VECTOR_COMPRESS_POSTFIX; + if (!fs_ptr->reader_ptr_->open(full_file_path)) { + LOG_ENGINE_ERROR_ << "Fail to open vector compress: " << full_file_path; + return; + } + + int64_t length = fs_ptr->reader_ptr_->length(); + if (length <= 0) { + LOG_ENGINE_ERROR_ << "Invalid vector compress length: " << full_file_path; + return; + } + + compress->data = std::shared_ptr(new uint8_t[length]); + compress->size = length; + + fs_ptr->reader_ptr_->seekg(0); + fs_ptr->reader_ptr_->read(compress->data.get(), length); + fs_ptr->reader_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = length * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "SSVectorCompressFormat::Read(" << full_file_path << ") rate " << rate << "MB/s"; +} + +void +SSVectorCompressFormat::Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const knowhere::BinaryPtr& compress) { + milvus::TimeRecorder recorder("SSVectorCompressFormat::Write"); + + const std::string full_file_path = file_path + VECTOR_COMPRESS_POSTFIX; + if (!fs_ptr->writer_ptr_->open(full_file_path)) { + LOG_ENGINE_ERROR_ << "Fail to open vector compress: " << full_file_path; + return; + } + + fs_ptr->writer_ptr_->write(compress->data.get(), compress->size); + fs_ptr->writer_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = compress->size * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "SSVectorCompressFormat::Write(" << full_file_path << ") rate " << rate << "MB/s"; +} + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSVectorCompressFormat.h b/core/src/codecs/snapshot/SSVectorCompressFormat.h new file mode 100644 index 000000000000..6bbe26f36202 --- /dev/null +++ b/core/src/codecs/snapshot/SSVectorCompressFormat.h @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "knowhere/common/BinarySet.h" +#include "storage/FSHandler.h" + +namespace milvus { +namespace codec { + +class SSVectorCompressFormat { + public: + SSVectorCompressFormat() = default; + + std::string + FilePostfix(); + + void + Read(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, knowhere::BinaryPtr& compress); + + void + Write(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, const knowhere::BinaryPtr& compress); + + // No copy and move + SSVectorCompressFormat(const SSVectorCompressFormat&) = delete; + SSVectorCompressFormat(SSVectorCompressFormat&&) = delete; + + SSVectorCompressFormat& + operator=(const SSVectorCompressFormat&) = delete; + SSVectorCompressFormat& + operator=(SSVectorCompressFormat&&) = delete; +}; + +using SSVectorCompressFormatPtr = std::shared_ptr; + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSVectorIndexFormat.cpp b/core/src/codecs/snapshot/SSVectorIndexFormat.cpp new file mode 100644 index 000000000000..257bceffc313 --- /dev/null +++ b/core/src/codecs/snapshot/SSVectorIndexFormat.cpp @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "codecs/snapshot/SSCodec.h" +#include "codecs/snapshot/SSVectorIndexFormat.h" +#include "knowhere/common/BinarySet.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_index/VecIndexFactory.h" +#include "utils/Exception.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace codec { + +const char* VECTOR_INDEX_POSTFIX = ".idx"; + +std::string +SSVectorIndexFormat::FilePostfix() { + std::string str = VECTOR_INDEX_POSTFIX; + return str; +} + +void +SSVectorIndexFormat::ReadRaw(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + knowhere::BinaryPtr& data) { + milvus::TimeRecorder recorder("SSVectorIndexFormat::ReadRaw"); + + if (!fs_ptr->reader_ptr_->open(file_path.c_str())) { + std::string err_msg = "Failed to open raw file: " + file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_OPEN_FILE, err_msg); + } + + size_t num_bytes; + fs_ptr->reader_ptr_->read(&num_bytes, sizeof(size_t)); + + data = std::make_shared(); + data->size = num_bytes; + data->data = std::shared_ptr(new uint8_t[num_bytes]); + + // Beginning of file is num_bytes + fs_ptr->reader_ptr_->seekg(sizeof(size_t)); + fs_ptr->reader_ptr_->read(data->data.get(), num_bytes); + fs_ptr->reader_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = num_bytes * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "SSVectorIndexFormat::ReadIndex(" << file_path << ") rate " << rate << "MB/s"; +} + +void +SSVectorIndexFormat::ReadIndex(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + knowhere::BinarySet& data) { + milvus::TimeRecorder recorder("SSVectorIndexFormat::ReadIndex"); + + std::string full_file_path = file_path + VECTOR_INDEX_POSTFIX; + if (!fs_ptr->reader_ptr_->open(full_file_path)) { + std::string err_msg = "Failed to open vector index: " + full_file_path + ", error: " + std::strerror(errno); + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_CANNOT_OPEN_FILE, err_msg); + } + + int64_t length = fs_ptr->reader_ptr_->length(); + if (length <= 0) { + LOG_ENGINE_ERROR_ << "Invalid vector index length: " << full_file_path; + return; + } + + int64_t rp = 0; + fs_ptr->reader_ptr_->seekg(0); + + int32_t current_type = 0; + fs_ptr->reader_ptr_->read(¤t_type, sizeof(current_type)); + rp += sizeof(current_type); + fs_ptr->reader_ptr_->seekg(rp); + + LOG_ENGINE_DEBUG_ << "Start to ReadIndex(" << full_file_path << ") length: " << length << " bytes"; + while (rp < length) { + size_t meta_length; + fs_ptr->reader_ptr_->read(&meta_length, sizeof(meta_length)); + rp += sizeof(meta_length); + fs_ptr->reader_ptr_->seekg(rp); + + auto meta = new char[meta_length]; + fs_ptr->reader_ptr_->read(meta, meta_length); + rp += meta_length; + fs_ptr->reader_ptr_->seekg(rp); + + size_t bin_length; + fs_ptr->reader_ptr_->read(&bin_length, sizeof(bin_length)); + rp += sizeof(bin_length); + fs_ptr->reader_ptr_->seekg(rp); + + auto bin = new uint8_t[bin_length]; + fs_ptr->reader_ptr_->read(bin, bin_length); + rp += bin_length; + fs_ptr->reader_ptr_->seekg(rp); + + std::shared_ptr binptr(bin); + data.Append(std::string(meta, meta_length), binptr, bin_length); + delete[] meta; + } + fs_ptr->reader_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = length * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "SSVectorIndexFormat::ReadIndex(" << full_file_path << ") rate " << rate << "MB/s"; +} + +void +SSVectorIndexFormat::ReadCompress(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + knowhere::BinaryPtr& data) { + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetVectorCompressFormat()->Read(fs_ptr, file_path, data); +} + +void +SSVectorIndexFormat::ConvertRaw(const std::vector& raw, knowhere::BinaryPtr& data) { + data = std::make_shared(); + data->size = raw.size(); + data->data = std::shared_ptr(new uint8_t[data->size]); +} + +void +SSVectorIndexFormat::ConstructIndex(const std::string& index_name, knowhere::BinarySet& index_data, + knowhere::BinaryPtr& raw_data, knowhere::BinaryPtr& compress_data, + knowhere::VecIndexPtr& index) { + knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance(); + index = vec_index_factory.CreateVecIndex(index_name, knowhere::IndexMode::MODE_CPU); + if (index != nullptr) { + int64_t length = 0; + for (auto& pair : index_data.binary_map_) { + length += pair.second->size; + } + + if (raw_data != nullptr) { + LOG_ENGINE_DEBUG_ << "load index with " << RAW_DATA << " " << raw_data->size; + index_data.Append(RAW_DATA, raw_data); + length += raw_data->size; + } + + if (compress_data != nullptr) { + LOG_ENGINE_DEBUG_ << "load index with " << SQ8_DATA << " " << compress_data->size; + index_data.Append(SQ8_DATA, compress_data); + length += compress_data->size; + } + + index->Load(index_data); + index->UpdateIndexSize(); + LOG_ENGINE_DEBUG_ << "index file size " << length << " index size " << index->IndexSize(); + } else { + std::string err_msg = "Fail to create vector index"; + LOG_ENGINE_ERROR_ << err_msg; + throw Exception(SERVER_UNEXPECTED_ERROR, err_msg); + } +} + +void +SSVectorIndexFormat::WriteIndex(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const knowhere::VecIndexPtr& index) { + milvus::TimeRecorder recorder("SVectorIndexFormat::WriteIndex"); + + std::string full_file_path = file_path + VECTOR_INDEX_POSTFIX; + auto binaryset = index->Serialize(knowhere::Config()); + int32_t index_type = knowhere::StrToOldIndexType(index->index_type()); + + if (!fs_ptr->writer_ptr_->open(full_file_path)) { + LOG_ENGINE_ERROR_ << "Fail to open vector index: " << full_file_path; + return; + } + + fs_ptr->writer_ptr_->write(&index_type, sizeof(index_type)); + + for (auto& iter : binaryset.binary_map_) { + auto meta = iter.first.c_str(); + size_t meta_length = iter.first.length(); + fs_ptr->writer_ptr_->write(&meta_length, sizeof(meta_length)); + fs_ptr->writer_ptr_->write((void*)meta, meta_length); + + auto binary = iter.second; + int64_t binary_length = binary->size; + fs_ptr->writer_ptr_->write(&binary_length, sizeof(binary_length)); + fs_ptr->writer_ptr_->write((void*)binary->data.get(), binary_length); + } + fs_ptr->writer_ptr_->close(); + + double span = recorder.RecordSection("End"); + double rate = fs_ptr->writer_ptr_->length() * 1000000.0 / span / 1024 / 1024; + LOG_ENGINE_DEBUG_ << "SSVectorIndexFormat::WriteIndex(" << full_file_path << ") rate " << rate << "MB/s"; +} + +void +SSVectorIndexFormat::WriteCompress(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const knowhere::VecIndexPtr& index) { + milvus::TimeRecorder recorder("SSVectorIndexFormat::WriteCompress"); + + auto binaryset = index->Serialize(knowhere::Config()); + + auto sq8_data = binaryset.Erase(SQ8_DATA); + if (sq8_data != nullptr) { + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetVectorCompressFormat()->Write(fs_ptr, file_path, sq8_data); + } +} + +} // namespace codec +} // namespace milvus diff --git a/core/src/codecs/snapshot/SSVectorIndexFormat.h b/core/src/codecs/snapshot/SSVectorIndexFormat.h new file mode 100644 index 000000000000..ecec6e5eb7d3 --- /dev/null +++ b/core/src/codecs/snapshot/SSVectorIndexFormat.h @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "knowhere/index/vector_index/VecIndex.h" +#include "storage/FSHandler.h" + +namespace milvus { +namespace codec { + +class SSVectorIndexFormat { + public: + SSVectorIndexFormat() = default; + + std::string + FilePostfix(); + + void + ReadRaw(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, knowhere::BinaryPtr& data); + + void + ReadIndex(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, knowhere::BinarySet& data); + + void + ReadCompress(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, knowhere::BinaryPtr& data); + + void + ConvertRaw(const std::vector& raw, knowhere::BinaryPtr& data); + + void + ConstructIndex(const std::string& index_name, knowhere::BinarySet& index_data, knowhere::BinaryPtr& raw_data, + knowhere::BinaryPtr& compress_data, knowhere::VecIndexPtr& index); + + void + WriteIndex(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, const knowhere::VecIndexPtr& index); + + void + WriteCompress(const storage::FSHandlerPtr& fs_ptr, const std::string& file_path, + const knowhere::VecIndexPtr& index); + + // No copy and move + SSVectorIndexFormat(const SSVectorIndexFormat&) = delete; + SSVectorIndexFormat(SSVectorIndexFormat&&) = delete; + + SSVectorIndexFormat& + operator=(const SSVectorIndexFormat&) = delete; + SSVectorIndexFormat& + operator=(SSVectorIndexFormat&&) = delete; +}; + +using SSVectorIndexFormatPtr = std::shared_ptr; + +} // namespace codec +} // namespace milvus diff --git a/core/src/config/Config.cpp b/core/src/config/Config.cpp index 63688ae4932b..4da833c1fb62 100644 --- a/core/src/config/Config.cpp +++ b/core/src/config/Config.cpp @@ -113,6 +113,8 @@ const char* CONFIG_ENGINE_OMP_THREAD_NUM = "omp_thread_num"; const char* CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT = "0"; const char* CONFIG_ENGINE_SIMD_TYPE = "simd_type"; const char* CONFIG_ENGINE_SIMD_TYPE_DEFAULT = "auto"; +const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ = "search_combine_nq"; +const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT = "64"; /* gpu resource config */ const char* CONFIG_GPU_RESOURCE = "gpu"; @@ -198,6 +200,9 @@ Config::Config() { std::string node_blas_threshold = std::string(CONFIG_ENGINE) + "." + CONFIG_ENGINE_USE_BLAS_THRESHOLD; config_callback_[node_blas_threshold] = empty_map; + std::string node_search_combine = std::string(CONFIG_ENGINE) + "." + CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ; + config_callback_[node_search_combine] = empty_map; + // gpu resources config std::string node_gpu_enable = std::string(CONFIG_GPU_RESOURCE) + "." + CONFIG_GPU_RESOURCE_ENABLE; config_callback_[node_gpu_enable] = empty_map; @@ -451,6 +456,7 @@ Config::ResetDefaultConfig() { STATUS_CHECK(SetEngineConfigUseBlasThreshold(CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT)); STATUS_CHECK(SetEngineConfigOmpThreadNum(CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT)); STATUS_CHECK(SetEngineConfigSimdType(CONFIG_ENGINE_SIMD_TYPE_DEFAULT)); + STATUS_CHECK(SetEngineSearchCombineMaxNq(CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT)); /* gpu resource config */ #ifdef MILVUS_GPU_VERSION @@ -578,6 +584,8 @@ Config::SetConfigCli(const std::string& parent_key, const std::string& child_key status = SetEngineConfigOmpThreadNum(value); } else if (child_key == CONFIG_ENGINE_SIMD_TYPE) { status = SetEngineConfigSimdType(value); + } else if (child_key == CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ) { + status = SetEngineSearchCombineMaxNq(value); } else { status = Status(SERVER_UNEXPECTED_ERROR, invalid_node_str); } @@ -1344,6 +1352,18 @@ Config::CheckEngineConfigSimdType(const std::string& value) { return Status::OK(); } +Status +Config::CheckEngineSearchCombineMaxNq(const std::string& value) { + fiu_return_on("check_config_search_combine_nq_fail", Status(SERVER_INVALID_ARGUMENT, "")); + + if (!ValidateStringIsNumber(value).ok()) { + std::string msg = "Invalid omp thread num: " + value + + ". Possible reason: engine_config.omp_thread_num is not a positive integer."; + return Status(SERVER_INVALID_ARGUMENT, msg); + } + return Status::OK(); +} + /* gpu resource config */ #ifdef MILVUS_GPU_VERSION Status @@ -1967,6 +1987,15 @@ Config::GetEngineConfigSimdType(std::string& value) { return CheckEngineConfigSimdType(value); } +Status +Config::GetEngineSearchCombineMaxNq(int64_t& value) { + std::string str = + GetConfigStr(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT); + // STATUS_CHECK(CheckEngineSearchCombineMaxNq(str)); + value = std::stoll(str); + return Status::OK(); +} + /* gpu resource config */ #ifdef MILVUS_GPU_VERSION Status @@ -2361,8 +2390,16 @@ Config::SetEngineConfigSimdType(const std::string& value) { return SetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_SIMD_TYPE, value); } +Status +Config::SetEngineSearchCombineMaxNq(const std::string& value) { + STATUS_CHECK(CheckEngineSearchCombineMaxNq(value)); + STATUS_CHECK(SetConfigValueInMem(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, value)); + return ExecCallBacks(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, value); +} + /* gpu resource config */ #ifdef MILVUS_GPU_VERSION + Status Config::SetGpuResourceConfigEnable(const std::string& value) { STATUS_CHECK(CheckGpuResourceConfigEnable(value)); @@ -2407,6 +2444,7 @@ Config::SetGpuResourceConfigBuildIndexResources(const std::string& value) { STATUS_CHECK(SetConfigValueInMem(CONFIG_GPU_RESOURCE, CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES, value)); return ExecCallBacks(CONFIG_GPU_RESOURCE, CONFIG_GPU_RESOURCE_BUILD_INDEX_RESOURCES, value); } + #endif /* tracing config */ diff --git a/core/src/config/Config.h b/core/src/config/Config.h index 1819c0c1e1ab..563b2eb48782 100644 --- a/core/src/config/Config.h +++ b/core/src/config/Config.h @@ -100,6 +100,8 @@ extern const char* CONFIG_ENGINE_OMP_THREAD_NUM; extern const char* CONFIG_ENGINE_OMP_THREAD_NUM_DEFAULT; extern const char* CONFIG_ENGINE_SIMD_TYPE; extern const char* CONFIG_ENGINE_SIMD_TYPE_DEFAULT; +extern const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ; +extern const char* CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT; /* gpu resource config */ extern const char* CONFIG_GPU_RESOURCE; @@ -264,6 +266,8 @@ class Config { CheckEngineConfigOmpThreadNum(const std::string& value); Status CheckEngineConfigSimdType(const std::string& value); + Status + CheckEngineSearchCombineMaxNq(const std::string& value); /* gpu resource config */ #ifdef MILVUS_GPU_VERSION @@ -382,6 +386,8 @@ class Config { GetEngineConfigOmpThreadNum(int64_t& value); Status GetEngineConfigSimdType(std::string& value); + Status + GetEngineSearchCombineMaxNq(int64_t& value); /* gpu resource config */ #ifdef MILVUS_GPU_VERSION @@ -492,6 +498,8 @@ class Config { SetEngineConfigOmpThreadNum(const std::string& value); Status SetEngineConfigSimdType(const std::string& value); + Status + SetEngineSearchCombineMaxNq(const std::string& value); /* gpu resource config */ #ifdef MILVUS_GPU_VERSION diff --git a/core/src/config/Utils.cpp b/core/src/config/Utils.cpp index 93885744f65e..bad3b1f4cb73 100644 --- a/core/src/config/Utils.cpp +++ b/core/src/config/Utils.cpp @@ -251,7 +251,8 @@ ValidateDbURI(const std::string& uri) { if (std::regex_match(uri, pieces_match, uriRegex)) { std::string dialect = pieces_match[1].str(); std::transform(dialect.begin(), dialect.end(), dialect.begin(), ::tolower); - if (dialect.find("mysql") == std::string::npos && dialect.find("sqlite") == std::string::npos) { + if (dialect.find("mysql") == std::string::npos && dialect.find("sqlite") == std::string::npos && + dialect.find("mock") == std::string::npos) { LOG_SERVER_ERROR_ << "Invalid dialect in URI: dialect = " << dialect; okay = false; } diff --git a/core/src/config/handler/EngineConfigHandler.cpp b/core/src/config/handler/EngineConfigHandler.cpp index 51b08e17ba35..e838bc077397 100644 --- a/core/src/config/handler/EngineConfigHandler.cpp +++ b/core/src/config/handler/EngineConfigHandler.cpp @@ -19,10 +19,12 @@ namespace server { EngineConfigHandler::EngineConfigHandler() { auto& config = Config::GetInstance(); config.GetEngineConfigUseBlasThreshold(use_blas_threshold_); + config.GetEngineSearchCombineMaxNq(search_combine_nq_); } EngineConfigHandler::~EngineConfigHandler() { RemoveUseBlasThresholdListener(); + RemoveSearchCombineMaxNqListener(); } //////////////////////////// Listener methods ////////////////////////////////// @@ -48,5 +50,27 @@ EngineConfigHandler::RemoveUseBlasThresholdListener() { config.CancelCallBack(CONFIG_ENGINE, CONFIG_ENGINE_USE_BLAS_THRESHOLD, identity_); } +void +EngineConfigHandler::AddSearchCombineMaxNqListener() { + ConfigCallBackF lambda = [this](const std::string& value) -> Status { + auto& config = server::Config::GetInstance(); + auto status = config.GetEngineSearchCombineMaxNq(search_combine_nq_); + if (status.ok()) { + OnSearchCombineMaxNqChanged(search_combine_nq_); + } + + return status; + }; + + auto& config = Config::GetInstance(); + config.RegisterCallBack(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, identity_, lambda); +} + +void +EngineConfigHandler::RemoveSearchCombineMaxNqListener() { + auto& config = Config::GetInstance(); + config.CancelCallBack(CONFIG_ENGINE, CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ, identity_); +} + } // namespace server } // namespace milvus diff --git a/core/src/config/handler/EngineConfigHandler.h b/core/src/config/handler/EngineConfigHandler.h index ebc055c7c8e5..3fed6c9847bf 100644 --- a/core/src/config/handler/EngineConfigHandler.h +++ b/core/src/config/handler/EngineConfigHandler.h @@ -28,16 +28,27 @@ class EngineConfigHandler : virtual public ConfigHandler { OnUseBlasThresholdChanged(int64_t threshold) { } + virtual void + OnSearchCombineMaxNqChanged(int64_t nq) { + search_combine_nq_ = nq; + } + protected: void AddUseBlasThresholdListener(); - protected: void RemoveUseBlasThresholdListener(); + void + AddSearchCombineMaxNqListener(); + + void + RemoveSearchCombineMaxNqListener(); + protected: int64_t use_blas_threshold_ = std::stoll(CONFIG_ENGINE_USE_BLAS_THRESHOLD_DEFAULT); + int64_t search_combine_nq_ = std::stoll(CONFIG_ENGINE_SEARCH_COMBINE_MAX_NQ_DEFAULT); }; } // namespace server diff --git a/core/src/db/DB.h b/core/src/db/DB.h index adb07ff87ac6..5a346966110a 100644 --- a/core/src/db/DB.h +++ b/core/src/db/DB.h @@ -11,7 +11,6 @@ #pragma once -#include #include #include #include @@ -28,8 +27,6 @@ namespace milvus { namespace engine { -class Env; - class DB { public: DB() = default; diff --git a/core/src/db/DBImpl.cpp b/core/src/db/DBImpl.cpp index 0a6b4a0a3017..2ee9e31072ec 100644 --- a/core/src/db/DBImpl.cpp +++ b/core/src/db/DBImpl.cpp @@ -1353,8 +1353,8 @@ DBImpl::GetEntitiesByID(const std::string& collection_id, const milvus::engine:: } std::unordered_map attr_type; for (const auto& schema : fields_schema.fields_schema_) { - if (schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::FLOAT_VECTOR || - schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::BINARY_VECTOR) { + if (schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::VECTOR_FLOAT || + schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::VECTOR_BINARY) { continue; } for (const auto& name : field_names) { @@ -1865,7 +1865,7 @@ DBImpl::FlushAttrsIndex(const std::string& collection_id) { } for (auto& field_schema : fields_schema.fields_schema_) { - if (field_schema.field_type_ != (int32_t)meta::hybrid::DataType::FLOAT_VECTOR) { + if (field_schema.field_type_ != (int32_t)meta::hybrid::DataType::VECTOR_FLOAT) { attr_types.insert( std::make_pair(field_schema.field_name_, (meta::hybrid::DataType)field_schema.field_type_)); field_names.emplace_back(field_schema.field_name_); @@ -2895,7 +2895,7 @@ DBImpl::GetPartitionsByTags(const std::string& collection_id, const std::vector< Status DBImpl::UpdateCollectionIndexRecursively(const std::string& collection_id, const CollectionIndex& index) { DropIndex(collection_id); - + WaitMergeFileFinish(); // DropIndex called StartMergeTask, need to wait merge thread finish auto status = meta_ptr_->UpdateCollectionIndex(collection_id, index); fiu_do_on("DBImpl.UpdateCollectionIndexRecursively.fail_update_collection_index", status = Status(DB_META_TRANSACTION_FAILED, "")); diff --git a/core/src/db/DBImpl.h b/core/src/db/DBImpl.h index 44e329cc98e0..0158bf8bacb0 100644 --- a/core/src/db/DBImpl.h +++ b/core/src/db/DBImpl.h @@ -12,7 +12,6 @@ #pragma once #include -#include #include #include #include @@ -27,6 +26,7 @@ #include "config/handler/EngineConfigHandler.h" #include "db/DB.h" #include "db/IndexFailedChecker.h" +#include "db/SimpleWaitNotify.h" #include "db/Types.h" #include "db/insert/MemManager.h" #include "db/merge/MergeManager.h" @@ -324,47 +324,6 @@ class DBImpl : public DB, public server::CacheConfigHandler, public server::Engi std::thread bg_metric_thread_; std::thread bg_index_thread_; - struct SimpleWaitNotify { - bool notified_ = false; - std::mutex mutex_; - std::condition_variable cv_; - - void - Wait() { - std::unique_lock lck(mutex_); - if (!notified_) { - cv_.wait(lck); - } - notified_ = false; - } - - void - Wait_Until(const std::chrono::system_clock::time_point& tm_pint) { - std::unique_lock lck(mutex_); - if (!notified_) { - cv_.wait_until(lck, tm_pint); - } - notified_ = false; - } - - void - Wait_For(const std::chrono::system_clock::duration& tm_dur) { - std::unique_lock lck(mutex_); - if (!notified_) { - cv_.wait_for(lck, tm_dur); - } - notified_ = false; - } - - void - Notify() { - std::unique_lock lck(mutex_); - notified_ = true; - lck.unlock(); - cv_.notify_one(); - } - }; - SimpleWaitNotify swn_wal_; SimpleWaitNotify swn_flush_; SimpleWaitNotify swn_metric_; diff --git a/core/src/db/SSDBImpl.cpp b/core/src/db/SSDBImpl.cpp index 5b2be9302bd9..ba4929494c2f 100644 --- a/core/src/db/SSDBImpl.cpp +++ b/core/src/db/SSDBImpl.cpp @@ -10,10 +10,36 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "db/SSDBImpl.h" +#include "cache/CpuCacheMgr.h" +#include "config/Config.h" +#include "db/IDGenerator.h" +#include "db/SnapshotUtils.h" +#include "db/SnapshotVisitor.h" +#include "db/merge/MergeManagerFactory.h" +#include "db/merge/SSMergeTask.h" #include "db/snapshot/CompoundOperations.h" +#include "db/snapshot/EventExecutor.h" +#include "db/snapshot/IterateHandler.h" +#include "db/snapshot/OperationExecutor.h" +#include "db/snapshot/ResourceHelper.h" +#include "db/snapshot/ResourceTypes.h" #include "db/snapshot/Snapshots.h" +#include "insert/MemManagerFactory.h" +#include "knowhere/index/vector_index/helpers/BuilderSuspend.h" +#include "metrics/Metrics.h" +#include "metrics/SystemInfo.h" +#include "scheduler/Definition.h" +#include "scheduler/SchedInst.h" +#include "scheduler/job/SSSearchJob.h" +#include "segment/SSSegmentReader.h" +#include "segment/SSSegmentWriter.h" +#include "utils/Exception.h" +#include "utils/StringHelpFunctions.h" +#include "utils/TimeRecorder.h" #include "wal/WalDefinations.h" +#include +#include #include #include @@ -21,6 +47,10 @@ namespace milvus { namespace engine { namespace { +constexpr uint64_t BACKGROUND_METRIC_INTERVAL = 1; +constexpr uint64_t BACKGROUND_INDEX_INTERVAL = 1; +constexpr uint64_t WAIT_BUILD_INDEX_INTERVAL = 5; + static const Status SHUTDOWN_ERROR = Status(DB_ERROR, "Milvus server is shutdown!"); } // namespace @@ -29,7 +59,11 @@ static const Status SHUTDOWN_ERROR = Status(DB_ERROR, "Milvus server is shutdown return SHUTDOWN_ERROR; \ } -SSDBImpl::SSDBImpl(const DBOptions& options) : options_(options), initialized_(false) { +SSDBImpl::SSDBImpl(const DBOptions& options) + : options_(options), initialized_(false), merge_thread_pool_(1, 1), index_thread_pool_(1, 1) { + mem_mgr_ = MemManagerFactory::SSBuild(options_); + merge_mgr_ptr_ = MergeManagerFactory::SSBuild(options_); + if (options_.wal_enable_) { wal::MXLogConfiguration mxlog_config; mxlog_config.recovery_error_ignore = options_.recovery_error_ignore_; @@ -45,18 +79,85 @@ SSDBImpl::~SSDBImpl() { Stop(); } -/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// external api -/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +// External APIs +//////////////////////////////////////////////////////////////////////////////// Status SSDBImpl::Start() { if (initialized_.load(std::memory_order_acquire)) { return Status::OK(); } + // TODO(yhz): Get storage url + auto& config = server::Config::GetInstance(); + std::string path; + STATUS_CHECK(config.GetStorageConfigPath(path)); + + std::string url; + STATUS_CHECK(config.GetGeneralConfigMetaURI(url)); + + // snapshot + auto store = snapshot::Store::Build(url, path); + snapshot::OperationExecutor::Init(store); + snapshot::OperationExecutor::GetInstance().Start(); + snapshot::EventExecutor::Init(store); + snapshot::EventExecutor::GetInstance().Start(); + snapshot::Snapshots::GetInstance().Init(store); + // LOG_ENGINE_TRACE_ << "DB service start"; initialized_.store(true, std::memory_order_release); + // TODO: merge files + + // wal + if (options_.wal_enable_) { + return Status(SERVER_NOT_IMPLEMENT, "Wal not implemented"); + // auto error_code = DB_ERROR; + // if (wal_mgr_ != nullptr) { + // error_code = wal_mgr_->Init(); + // } + // if (error_code != WAL_SUCCESS) { + // throw Exception(error_code, "Wal init error!"); + // } + // + // // recovery + // while (true) { + // wal::MXLogRecord record; + // auto error_code = wal_mgr_->GetNextRecovery(record); + // if (error_code != WAL_SUCCESS) { + // throw Exception(error_code, "Wal recovery error!"); + // } + // if (record.type == wal::MXLogType::None) { + // break; + // } + // ExecWalRecord(record); + // } + // + // // for distribute version, some nodes are read only + // if (options_.mode_ != DBOptions::MODE::CLUSTER_READONLY) { + // // background wal thread + // bg_wal_thread_ = std::thread(&SSDBImpl::TimingWalThread, this); + // } + } else { + // for distribute version, some nodes are read only + if (options_.mode_ != DBOptions::MODE::CLUSTER_READONLY) { + // background flush thread + bg_flush_thread_ = std::thread(&SSDBImpl::TimingFlushThread, this); + } + } + + // for distribute version, some nodes are read only + if (options_.mode_ != DBOptions::MODE::CLUSTER_READONLY) { + // background build index thread + bg_index_thread_ = std::thread(&SSDBImpl::TimingIndexThread, this); + } + + // background metric thread + fiu_do_on("options_metric_enable", options_.metric_enable_ = true); + if (options_.metric_enable_) { + bg_metric_thread_ = std::thread(&SSDBImpl::TimingMetricThread, this); + } + return Status::OK(); } @@ -68,6 +169,37 @@ SSDBImpl::Stop() { initialized_.store(false, std::memory_order_release); + if (options_.mode_ != DBOptions::MODE::CLUSTER_READONLY) { + if (options_.wal_enable_) { + // // wait wal thread finish + // swn_wal_.Notify(); + // bg_wal_thread_.join(); + } else { + // flush all without merge + wal::MXLogRecord record; + record.type = wal::MXLogType::Flush; + ExecWalRecord(record); + + // wait flush thread finish + swn_flush_.Notify(); + bg_flush_thread_.join(); + } + + WaitMergeFileFinish(); + + swn_index_.Notify(); + bg_index_thread_.join(); + } + + // wait metric thread exit + if (options_.metric_enable_) { + swn_metric_.Notify(); + bg_metric_thread_.join(); + } + + snapshot::EventExecutor::GetInstance().Stop(); + snapshot::OperationExecutor::GetInstance().Stop(); + // LOG_ENGINE_TRACE_ << "DB service stop"; return Status::OK(); } @@ -77,15 +209,31 @@ SSDBImpl::CreateCollection(const snapshot::CreateCollectionContext& context) { CHECK_INITIALIZED; auto ctx = context; + // check uid existence/validation + bool has_uid = false; + for (auto& pair : ctx.fields_schema) { + if (pair.first->GetFtype() == meta::hybrid::DataType::UID) { + has_uid = true; + break; + } + } - if (options_.wal_enable_) { - ctx.lsn = wal_mgr_->CreateCollection(context.collection->GetName()); + // add uid field if not specified + if (!has_uid) { + auto uid_field = std::make_shared(DEFAULT_UID_NAME, 0, milvus::engine::FieldType::UID); + auto bloom_filter_element = std::make_shared( + 0, 0, DEFAULT_BLOOM_FILTER_NAME, milvus::engine::FieldElementType::FET_BLOOM_FILTER); + auto delete_doc_element = std::make_shared( + 0, 0, DEFAULT_DELETED_DOCS_NAME, milvus::engine::FieldElementType::FET_DELETED_DOCS); + + ctx.fields_schema[uid_field] = {bloom_filter_element, delete_doc_element}; } + if (options_.wal_enable_) { + // ctx.lsn = wal_mgr_->CreateCollection(context.collection->GetName()); + } auto op = std::make_shared(ctx); - auto status = op->Push(); - - return status; + return op->Push(); } Status @@ -94,41 +242,34 @@ SSDBImpl::DescribeCollection(const std::string& collection_name, snapshot::Colle CHECK_INITIALIZED; snapshot::ScopedSnapshotT ss; - auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name); - if (!status.ok()) { - return status; - } + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); collection = ss->GetCollection(); - auto& fields = ss->GetResources(); for (auto& kv : fields) { fields_schema[kv.second.Get()] = ss->GetFieldElementsByField(kv.second->GetName()); } - return status; + return Status::OK(); } Status SSDBImpl::DropCollection(const std::string& name) { CHECK_INITIALIZED; - // dates partly delete files of the collection but currently we don't support LOG_ENGINE_DEBUG_ << "Prepare to delete collection " << name; snapshot::ScopedSnapshotT ss; auto& snapshots = snapshot::Snapshots::GetInstance(); - auto status = snapshots.GetSnapshot(ss, name); - if (!status.ok()) { - return status; - } + STATUS_CHECK(snapshots.GetSnapshot(ss, name)); if (options_.wal_enable_) { // SS TODO /* wal_mgr_->DropCollection(ss->GetCollectionId()); */ } - status = snapshots.DropCollection(ss->GetCollectionId(), std::numeric_limits::max()); - return status; + auto status = mem_mgr_->EraseMemVector(ss->GetCollectionId()); // not allow insert + + return snapshots.DropCollection(ss->GetCollectionId(), std::numeric_limits::max()); } Status @@ -150,22 +291,42 @@ SSDBImpl::AllCollections(std::vector& names) { return snapshot::Snapshots::GetInstance().GetCollectionNames(names); } +Status +SSDBImpl::GetCollectionRowCount(const std::string& collection_name, uint64_t& row_count) { + CHECK_INITIALIZED; + + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + + row_count = ss->GetCollectionCommit()->GetRowCount(); + return Status::OK(); +} + +Status +SSDBImpl::LoadCollection(const server::ContextPtr& context, const std::string& collection_name, + const std::vector& field_names, bool force) { + CHECK_INITIALIZED; + + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + + auto handler = std::make_shared(context, ss); + handler->Iterate(); + + return handler->GetStatus(); +} + Status SSDBImpl::CreatePartition(const std::string& collection_name, const std::string& partition_name) { CHECK_INITIALIZED; - uint64_t lsn = 0; snapshot::ScopedSnapshotT ss; - auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name); - if (!status.ok()) { - return status; - } + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + snapshot::LSN_TYPE lsn = 0; if (options_.wal_enable_) { // SS TODO - /* lsn = wal_mgr_->CreatePartition(collection_id, partition_tag); */ - } else { - lsn = ss->GetCollection()->GetLsn(); + /* lsn = wal_mgr_->CreatePartition(collection_name, partition_tag); */ } snapshot::OperationContext context; @@ -175,13 +336,8 @@ SSDBImpl::CreatePartition(const std::string& collection_name, const std::string& snapshot::PartitionContext p_ctx; p_ctx.name = partition_name; snapshot::PartitionPtr partition; - status = op->CommitNewPartition(p_ctx, partition); - if (!status.ok()) { - return status; - } - - status = op->Push(); - return status; + STATUS_CHECK(op->CommitNewPartition(p_ctx, partition)); + return op->Push(); } Status @@ -189,10 +345,7 @@ SSDBImpl::DropPartition(const std::string& collection_name, const std::string& p CHECK_INITIALIZED; snapshot::ScopedSnapshotT ss; - auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name); - if (!status.ok()) { - return status; - } + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); // SS TODO: Is below step needed? Or How to implement it? /* mem_mgr_->EraseMemVector(partition_name); */ @@ -200,9 +353,7 @@ SSDBImpl::DropPartition(const std::string& collection_name, const std::string& p snapshot::PartitionContext context; context.name = partition_name; auto op = std::make_shared(context, ss); - status = op->Push(); - - return status; + return op->Push(); } Status @@ -210,30 +361,840 @@ SSDBImpl::ShowPartitions(const std::string& collection_name, std::vectorGetPartitionNames()); + return Status::OK(); +} + +Status +SSDBImpl::InsertEntities(const std::string& collection_name, const std::string& partition_name, + DataChunkPtr& data_chunk) { + CHECK_INITIALIZED; + + if (data_chunk == nullptr) { + return Status(DB_ERROR, "Null pointer"); + } + + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + + auto partition_ptr = ss->GetPartition(partition_name); + if (partition_ptr == nullptr) { + return Status(DB_NOT_FOUND, "Fail to get partition " + partition_name); + } + + /* Generate id */ + if (data_chunk->fixed_fields_.find(engine::DEFAULT_UID_NAME) == data_chunk->fixed_fields_.end()) { + SafeIDGenerator& id_generator = SafeIDGenerator::GetInstance(); + IDNumbers ids; + STATUS_CHECK(id_generator.GetNextIDNumbers(data_chunk->count_, ids)); + FIXED_FIELD_DATA& id_data = data_chunk->fixed_fields_[engine::DEFAULT_UID_NAME]; + id_data.resize(ids.size() * sizeof(int64_t)); + memcpy(id_data.data(), ids.data(), ids.size() * sizeof(int64_t)); + } + + if (options_.wal_enable_) { + return Status(SERVER_NOT_IMPLEMENT, "Wal not implemented"); + // auto vector_it = entity.vector_data_.begin(); + // if (!vector_it->second.binary_data_.empty()) { + // wal_mgr_->InsertEntities(collection_name, partition_name, entity.id_array_, + // vector_it->second.binary_data_, + // attr_nbytes, attr_data); + // } else if (!vector_it->second.float_data_.empty()) { + // wal_mgr_->InsertEntities(collection_name, partition_name, entity.id_array_, + // vector_it->second.float_data_, + // attr_nbytes, attr_data); + // } + // swn_wal_.Notify(); + } else { + // insert entities: collection_name is field id + wal::MXLogRecord record; + record.lsn = 0; + record.collection_id = collection_name; + record.partition_tag = partition_name; + record.data_chunk = data_chunk; + record.length = data_chunk->count_; + record.type = wal::MXLogType::Entity; + + STATUS_CHECK(ExecWalRecord(record)); + } + + return Status::OK(); +} + +Status +SSDBImpl::DeleteEntities(const std::string& collection_name, engine::IDNumbers entity_ids) { + CHECK_INITIALIZED; + + Status status; + if (options_.wal_enable_) { + return Status(SERVER_NOT_IMPLEMENT, "Wal not implemented"); + // wal_mgr_->DeleteById(collection_name, entity_ids); + // swn_wal_.Notify(); + } else { + wal::MXLogRecord record; + record.lsn = 0; // need to get from meta ? + record.type = wal::MXLogType::Delete; + record.collection_id = collection_name; + record.ids = entity_ids.data(); + record.length = entity_ids.size(); + + status = ExecWalRecord(record); + } + + return status; +} + +Status +SSDBImpl::Flush(const std::string& collection_name) { + if (!initialized_.load(std::memory_order_acquire)) { + return SHUTDOWN_ERROR; + } + + Status status; + bool has_collection; + status = HasCollection(collection_name, has_collection); if (!status.ok()) { return status; } + if (!has_collection) { + LOG_ENGINE_ERROR_ << "Collection to flush does not exist: " << collection_name; + return Status(DB_NOT_FOUND, "Collection to flush does not exist"); + } + + LOG_ENGINE_DEBUG_ << "Begin flush collection: " << collection_name; + + if (options_.wal_enable_) { + return Status(SERVER_NOT_IMPLEMENT, "Wal not implemented"); + // LOG_ENGINE_DEBUG_ << "WAL flush"; + // auto lsn = wal_mgr_->Flush(collection_name); + // if (lsn != 0) { + // swn_wal_.Notify(); + // flush_req_swn_.Wait(); + // } else { + // // no collection flushed, call merge task to cleanup files + // std::set merge_collection_names; + // StartMergeTask(merge_collection_names); + // } + } else { + LOG_ENGINE_DEBUG_ << "MemTable flush"; + InternalFlush(collection_name); + } + + LOG_ENGINE_DEBUG_ << "End flush collection: " << collection_name; - partition_names = std::move(ss->GetPartitionNames()); return status; } Status -SSDBImpl::PreloadCollection(const std::shared_ptr& context, const std::string& collection_name, - bool force) { +SSDBImpl::Flush() { + if (!initialized_.load(std::memory_order_acquire)) { + return SHUTDOWN_ERROR; + } + + LOG_ENGINE_DEBUG_ << "Begin flush all collections"; + + Status status; + fiu_do_on("options_wal_enable_false", options_.wal_enable_ = false); + if (options_.wal_enable_) { + return Status(SERVER_NOT_IMPLEMENT, "Wal not implemented"); + // LOG_ENGINE_DEBUG_ << "WAL flush"; + // auto lsn = wal_mgr_->Flush(); + // if (lsn != 0) { + // swn_wal_.Notify(); + // flush_req_swn_.Wait(); + // } else { + // // no collection flushed, call merge task to cleanup files + // std::set merge_collection_names; + // StartMergeTask(merge_collection_names); + // } + } else { + LOG_ENGINE_DEBUG_ << "MemTable flush"; + InternalFlush(); + } + + LOG_ENGINE_DEBUG_ << "End flush all collections"; + + return status; +} + +Status +SSDBImpl::Compact(const server::ContextPtr& context, const std::string& collection_name, double threshold) { + if (!initialized_.load(std::memory_order_acquire)) { + return SHUTDOWN_ERROR; + } + + LOG_ENGINE_DEBUG_ << "Before compacting, wait for build index thread to finish..."; + const std::lock_guard index_lock(build_index_mutex_); + const std::lock_guard merge_lock(flush_merge_compact_mutex_); + + Status status; + bool has_collection; + status = HasCollection(collection_name, has_collection); + if (!status.ok()) { + return status; + } + if (!has_collection) { + LOG_ENGINE_ERROR_ << "Collection to compact does not exist: " << collection_name; + return Status(DB_NOT_FOUND, "Collection to compact does not exist"); + } + + snapshot::ScopedSnapshotT latest_ss; + status = snapshot::Snapshots::GetInstance().GetSnapshot(latest_ss, collection_name); + if (!status.ok()) { + return status; + } + + auto& segments = latest_ss->GetResources(); + for (auto& kv : segments) { + // client break the connection, no need to continue + if (context && context->IsConnectionBroken()) { + LOG_ENGINE_DEBUG_ << "Client connection broken, stop compact operation"; + break; + } + + snapshot::ID_TYPE segment_id = kv.first; + auto read_visitor = engine::SegmentVisitor::Build(latest_ss, segment_id); + segment::SSSegmentReaderPtr segment_reader = + std::make_shared(options_.meta_.path_, read_visitor); + + segment::DeletedDocsPtr deleted_docs; + status = segment_reader->LoadDeletedDocs(deleted_docs); + if (!status.ok() || deleted_docs == nullptr) { + continue; // no deleted docs, no need to compact + } + + auto segment_commit = latest_ss->GetSegmentCommitBySegmentId(segment_id); + auto row_count = segment_commit->GetRowCount(); + if (row_count == 0) { + continue; + } + + auto deleted_count = deleted_docs->GetSize(); + if (deleted_count / (row_count + deleted_count) < threshold) { + continue; // no need to compact + } + + snapshot::IDS_TYPE ids = {segment_id}; + SSMergeTask merge_task(options_, latest_ss, ids); + status = merge_task.Execute(); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Compact failed for segment " << segment_reader->GetSegmentPath() << ": " + << status.message(); + continue; // skip this file and try compact next one + } + } + + return status; +} + +Status +SSDBImpl::GetEntityByID(const std::string& collection_name, const IDNumbers& id_array, + const std::vector& field_names, DataChunkPtr& data_chunk) { CHECK_INITIALIZED; snapshot::ScopedSnapshotT ss; - auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name); + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + + std::string dir_root = options_.meta_.path_; + auto handler = std::make_shared(nullptr, ss, dir_root, id_array, field_names); + handler->Iterate(); + STATUS_CHECK(handler->GetStatus()); + + data_chunk = handler->data_chunk_; + return Status::OK(); +} + +Status +SSDBImpl::GetEntityIDs(const std::string& collection_name, int64_t segment_id, IDNumbers& entity_ids) { + CHECK_INITIALIZED; + + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + + auto read_visitor = engine::SegmentVisitor::Build(ss, segment_id); + segment::SSSegmentReaderPtr segment_reader = + std::make_shared(options_.meta_.path_, read_visitor); + + STATUS_CHECK(segment_reader->LoadUids(entity_ids)); + + return Status::OK(); +} + +Status +SSDBImpl::CreateIndex(const std::shared_ptr& context, const std::string& collection_name, + const std::string& field_name, const CollectionIndex& index) { + CHECK_INITIALIZED; + + // step 1: wait merge file thread finished to avoid duplicate data bug + auto status = Flush(); + WaitMergeFileFinish(); // let merge file thread finish + + // step 2: compare old index and new index + CollectionIndex new_index = index; + CollectionIndex old_index; + status = DescribeIndex(collection_name, field_name, old_index); if (!status.ok()) { return status; } - auto handler = std::make_shared(context, ss); - handler->Iterate(); + if (old_index.metric_type_ != (int32_t)MetricType::INVALID) { + new_index.metric_type_ = old_index.metric_type_; // dont change metric type, it was defined by CreateCollection + } + if (utils::IsSameIndex(old_index, new_index)) { + return Status::OK(); // same index + } - return handler->GetStatus(); + // step 3: drop old index + DropIndex(collection_name); + WaitMergeFileFinish(); // let merge file thread finish since DropIndex start a merge task + + // step 4: create field element for index + status = SetSnapshotIndex(collection_name, field_name, new_index); + if (!status.ok()) { + return status; + } + + // step 5: start background build index thread + std::vector collection_names = {collection_name}; + WaitBuildIndexFinish(); + StartBuildIndexTask(collection_names); + + // step 6: iterate segments need to be build index, wait until all segments are built + while (true) { + SnapshotVisitor ss_visitor(collection_name); + snapshot::IDS_TYPE segment_ids; + ss_visitor.SegmentsToIndex(field_name, segment_ids); + if (segment_ids.empty()) { + break; + } + + index_req_swn_.Wait_For(std::chrono::seconds(1)); + + // client break the connection, no need to block, check every 1 second + if (context && context->IsConnectionBroken()) { + LOG_ENGINE_DEBUG_ << "Client connection broken, build index in background"; + break; // just break, not return, continue to update partitions files to to_index + } + } + + return Status::OK(); +} + +Status +SSDBImpl::DescribeIndex(const std::string& collection_name, const std::string& field_name, CollectionIndex& index) { + CHECK_INITIALIZED; + + return GetSnapshotIndex(collection_name, field_name, index); +} + +Status +SSDBImpl::DropIndex(const std::string& collection_name, const std::string& field_name) { + CHECK_INITIALIZED; + + LOG_ENGINE_DEBUG_ << "Drop index for collection: " << collection_name; + + STATUS_CHECK(DeleteSnapshotIndex(collection_name, field_name)); + + std::set merge_collection_names = {collection_name}; + StartMergeTask(merge_collection_names, true); + return Status::OK(); +} + +Status +SSDBImpl::DropIndex(const std::string& collection_name) { + CHECK_INITIALIZED; + + LOG_ENGINE_DEBUG_ << "Drop index for collection: " << collection_name; + + std::vector field_names; + { + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + field_names = ss->GetFieldNames(); + } + + snapshot::OperationContext context; + for (auto& field_name : field_names) { + STATUS_CHECK(DeleteSnapshotIndex(collection_name, field_name)); + } + + std::set merge_collection_names = {collection_name}; + StartMergeTask(merge_collection_names, true); + return Status::OK(); +} + +Status +SSDBImpl::Query(const server::ContextPtr& context, const std::string& collection_name, const query::QueryPtr& query_ptr, + engine::QueryResultPtr& result) { + CHECK_INITIALIZED; + + TimeRecorder rc("SSDBImpl::Query"); + + scheduler::SSSearchJobPtr job = std::make_shared(nullptr, options_, query_ptr); + + /* put search job to scheduler and wait job finish */ + scheduler::JobMgrInst::GetInstance()->Put(job); + job->WaitFinish(); + + if (!job->status().ok()) { + return job->status(); + } + + result = job->query_result(); + + rc.ElapseFromBegin("Engine query totally cost"); + + return job->status(); +} + +//////////////////////////////////////////////////////////////////////////////// +// Internal APIs +//////////////////////////////////////////////////////////////////////////////// +void +SSDBImpl::InternalFlush(const std::string& collection_name) { + wal::MXLogRecord record; + record.type = wal::MXLogType::Flush; + record.collection_id = collection_name; + ExecWalRecord(record); +} + +void +SSDBImpl::TimingFlushThread() { + SetThreadName("flush_thread"); + server::SystemInfo::GetInstance().Init(); + while (true) { + if (!initialized_.load(std::memory_order_acquire)) { + LOG_ENGINE_DEBUG_ << "DB background flush thread exit"; + break; + } + + InternalFlush(); + if (options_.auto_flush_interval_ > 0) { + swn_flush_.Wait_For(std::chrono::seconds(options_.auto_flush_interval_)); + } else { + swn_flush_.Wait(); + } + } +} + +void +SSDBImpl::StartMetricTask() { + server::Metrics::GetInstance().KeepingAliveCounterIncrement(BACKGROUND_METRIC_INTERVAL); + int64_t cache_usage = cache::CpuCacheMgr::GetInstance()->CacheUsage(); + int64_t cache_total = cache::CpuCacheMgr::GetInstance()->CacheCapacity(); + fiu_do_on("DBImpl.StartMetricTask.InvalidTotalCache", cache_total = 0); + + if (cache_total > 0) { + double cache_usage_double = cache_usage; + server::Metrics::GetInstance().CpuCacheUsageGaugeSet(cache_usage_double * 100 / cache_total); + } else { + server::Metrics::GetInstance().CpuCacheUsageGaugeSet(0); + } + + server::Metrics::GetInstance().GpuCacheUsageGaugeSet(); + /* SS TODO */ + // uint64_t size; + // Size(size); + // server::Metrics::GetInstance().DataFileSizeGaugeSet(size); + server::Metrics::GetInstance().CPUUsagePercentSet(); + server::Metrics::GetInstance().RAMUsagePercentSet(); + server::Metrics::GetInstance().GPUPercentGaugeSet(); + server::Metrics::GetInstance().GPUMemoryUsageGaugeSet(); + server::Metrics::GetInstance().OctetsSet(); + + server::Metrics::GetInstance().CPUCoreUsagePercentSet(); + server::Metrics::GetInstance().GPUTemperature(); + server::Metrics::GetInstance().CPUTemperature(); + server::Metrics::GetInstance().PushToGateway(); +} + +void +SSDBImpl::TimingMetricThread() { + SetThreadName("metric_thread"); + server::SystemInfo::GetInstance().Init(); + while (true) { + if (!initialized_.load(std::memory_order_acquire)) { + LOG_ENGINE_DEBUG_ << "DB background metric thread exit"; + break; + } + + swn_metric_.Wait_For(std::chrono::seconds(BACKGROUND_METRIC_INTERVAL)); + StartMetricTask(); + meta::FilesHolder::PrintInfo(); + } +} + +void +SSDBImpl::StartBuildIndexTask(const std::vector& collection_names) { + // build index has been finished? + { + std::lock_guard lck(index_result_mutex_); + if (!index_thread_results_.empty()) { + std::chrono::milliseconds span(10); + if (index_thread_results_.back().wait_for(span) == std::future_status::ready) { + index_thread_results_.pop_back(); + } + } + } + + // add new build index task + { + std::lock_guard lck(index_result_mutex_); + if (index_thread_results_.empty()) { + index_thread_results_.push_back( + index_thread_pool_.enqueue(&SSDBImpl::BackgroundBuildIndexTask, this, collection_names)); + } + } +} + +void +SSDBImpl::BackgroundBuildIndexTask(std::vector collection_names) { + std::unique_lock lock(build_index_mutex_); + + for (auto collection_name : collection_names) { + SnapshotVisitor ss_visitor(collection_name); + + snapshot::IDS_TYPE segment_ids; + ss_visitor.SegmentsToIndex("", segment_ids); + + scheduler::SSBuildIndexJobPtr job = + std::make_shared(options_, collection_name, segment_ids); + + scheduler::JobMgrInst::GetInstance()->Put(job); + job->WaitFinish(); + + if (!job->status().ok()) { + LOG_ENGINE_ERROR_ << job->status().message(); + break; + } + } +} + +void +SSDBImpl::TimingIndexThread() { + SetThreadName("index_thread"); + server::SystemInfo::GetInstance().Init(); + while (true) { + if (!initialized_.load(std::memory_order_acquire)) { + WaitMergeFileFinish(); + WaitBuildIndexFinish(); + + LOG_ENGINE_DEBUG_ << "DB background thread exit"; + break; + } + + swn_index_.Wait_For(std::chrono::seconds(BACKGROUND_INDEX_INTERVAL)); + + std::vector collection_names; + snapshot::Snapshots::GetInstance().GetCollectionNames(collection_names); + WaitMergeFileFinish(); + StartBuildIndexTask(collection_names); + } +} + +void +SSDBImpl::WaitBuildIndexFinish() { + // LOG_ENGINE_DEBUG_ << "Begin WaitBuildIndexFinish"; + std::lock_guard lck(index_result_mutex_); + for (auto& iter : index_thread_results_) { + iter.wait(); + } + // LOG_ENGINE_DEBUG_ << "End WaitBuildIndexFinish"; +} + +void +SSDBImpl::TimingWalThread() { + SetThreadName("wal_thread"); + server::SystemInfo::GetInstance().Init(); + + std::chrono::system_clock::time_point next_auto_flush_time; + auto get_next_auto_flush_time = [&]() { + return std::chrono::system_clock::now() + std::chrono::seconds(options_.auto_flush_interval_); + }; + if (options_.auto_flush_interval_ > 0) { + next_auto_flush_time = get_next_auto_flush_time(); + } + + InternalFlush(); + while (true) { + if (options_.auto_flush_interval_ > 0) { + if (std::chrono::system_clock::now() >= next_auto_flush_time) { + InternalFlush(); + next_auto_flush_time = get_next_auto_flush_time(); + } + } + + wal::MXLogRecord record; + auto error_code = wal_mgr_->GetNextRecord(record); + if (error_code != WAL_SUCCESS) { + LOG_ENGINE_ERROR_ << "WAL background GetNextRecord error"; + break; + } + + if (record.type != wal::MXLogType::None) { + ExecWalRecord(record); + if (record.type == wal::MXLogType::Flush) { + // notify flush request to return + flush_req_swn_.Notify(); + + // if user flush all manually, update auto flush also + if (record.collection_id.empty() && options_.auto_flush_interval_ > 0) { + next_auto_flush_time = get_next_auto_flush_time(); + } + } + + } else { + if (!initialized_.load(std::memory_order_acquire)) { + InternalFlush(); + flush_req_swn_.Notify(); + // SS TODO + // WaitMergeFileFinish(); + // WaitBuildIndexFinish(); + LOG_ENGINE_DEBUG_ << "WAL background thread exit"; + break; + } + + if (options_.auto_flush_interval_ > 0) { + swn_wal_.Wait_Until(next_auto_flush_time); + } else { + swn_wal_.Wait(); + } + } + } +} + +Status +SSDBImpl::ExecWalRecord(const wal::MXLogRecord& record) { + auto collections_flushed = [&](const std::string& collection_name, + const std::set& target_collection_names) -> uint64_t { + uint64_t max_lsn = 0; + if (options_.wal_enable_ && !target_collection_names.empty()) { + // uint64_t lsn = 0; + // for (auto& collection_name : target_collection_names) { + // snapshot::ScopedSnapshotT ss; + // snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name); + // lsn = ss->GetMaxLsn(); + // if (lsn > max_lsn) { + // max_lsn = lsn; + // } + // } + // wal_mgr_->CollectionFlushed(collection_name, lsn); + } + + std::set merge_collection_names; + for (auto& collection : target_collection_names) { + merge_collection_names.insert(collection); + } + StartMergeTask(merge_collection_names); + return max_lsn; + }; + + auto force_flush_if_mem_full = [&]() -> uint64_t { + if (mem_mgr_->GetCurrentMem() > options_.insert_buffer_size_) { + LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] ", "insert", 0) << "Insert buffer size exceeds limit. Force flush"; + InternalFlush(); + } + }; + + auto get_collection_partition_id = [&](const wal::MXLogRecord& record, int64_t& col_id, + int64_t& part_id) -> Status { + snapshot::ScopedSnapshotT ss; + auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, record.collection_id); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ", "insert", 0) << "Get snapshot fail: " << status.message(); + return status; + } + col_id = ss->GetCollectionId(); + snapshot::PartitionPtr part = ss->GetPartition(record.partition_tag); + if (part == nullptr) { + LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ", "insert", 0) << "Get partition fail: " << status.message(); + return status; + } + part_id = part->GetID(); + + return Status::OK(); + }; + + Status status; + + switch (record.type) { + case wal::MXLogType::Entity: { + int64_t collection_name = 0, partition_id = 0; + auto status = get_collection_partition_id(record, collection_name, partition_id); + if (!status.ok()) { + LOG_WAL_ERROR_ << LogOut("[%s][%ld] ", "insert", 0) << status.message(); + return status; + } + + status = mem_mgr_->InsertEntities(collection_name, partition_id, record.data_chunk, record.lsn); + force_flush_if_mem_full(); + + // metrics + milvus::server::CollectInsertMetrics metrics(record.length, status); + break; + } + + case wal::MXLogType::Delete: { + snapshot::ScopedSnapshotT ss; + auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, record.collection_id); + if (!status.ok()) { + LOG_WAL_ERROR_ << LogOut("[%s][%ld] ", "delete", 0) << "Get snapshot fail: " << status.message(); + return status; + } + + if (record.length == 1) { + status = mem_mgr_->DeleteEntity(ss->GetCollectionId(), *record.ids, record.lsn); + if (!status.ok()) { + return status; + } + } else { + status = mem_mgr_->DeleteEntities(ss->GetCollectionId(), record.length, record.ids, record.lsn); + if (!status.ok()) { + return status; + } + } + break; + } + + case wal::MXLogType::Flush: { + if (!record.collection_id.empty()) { + // flush one collection + snapshot::ScopedSnapshotT ss; + auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, record.collection_id); + if (!status.ok()) { + LOG_WAL_ERROR_ << LogOut("[%s][%ld] ", "flush", 0) << "Get snapshot fail: " << status.message(); + return status; + } + + const std::lock_guard lock(flush_merge_compact_mutex_); + int64_t collection_name = ss->GetCollectionId(); + status = mem_mgr_->Flush(collection_name); + if (!status.ok()) { + return status; + } + + std::set flushed_collections; + collections_flushed(record.collection_id, flushed_collections); + + } else { + // flush all collections + std::set collection_names; + { + const std::lock_guard lock(flush_merge_compact_mutex_); + status = mem_mgr_->Flush(collection_names); + } + + std::set flushed_collections; + for (auto id : collection_names) { + snapshot::ScopedSnapshotT ss; + auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, id); + if (!status.ok()) { + LOG_WAL_ERROR_ << LogOut("[%s][%ld] ", "flush", 0) << "Get snapshot fail: " << status.message(); + return status; + } + + flushed_collections.insert(ss->GetName()); + } + + uint64_t lsn = collections_flushed("", flushed_collections); + if (options_.wal_enable_) { + // wal_mgr_->RemoveOldFiles(lsn); + } + } + break; + } + + default: + break; + } + + return status; +} + +void +SSDBImpl::StartMergeTask(const std::set& collection_names, bool force_merge_all) { + // LOG_ENGINE_DEBUG_ << "Begin StartMergeTask"; + // merge task has been finished? + { + std::lock_guard lck(merge_result_mutex_); + if (!merge_thread_results_.empty()) { + std::chrono::milliseconds span(10); + if (merge_thread_results_.back().wait_for(span) == std::future_status::ready) { + merge_thread_results_.pop_back(); + } + } + } + + // add new merge task + { + std::lock_guard lck(merge_result_mutex_); + if (merge_thread_results_.empty()) { + // start merge file thread + merge_thread_results_.push_back( + merge_thread_pool_.enqueue(&SSDBImpl::BackgroundMerge, this, collection_names, force_merge_all)); + } + } + + // LOG_ENGINE_DEBUG_ << "End StartMergeTask"; +} + +void +SSDBImpl::BackgroundMerge(std::set collection_names, bool force_merge_all) { + // LOG_ENGINE_TRACE_ << " Background merge thread start"; + + Status status; + for (auto& collection_name : collection_names) { + const std::lock_guard lock(flush_merge_compact_mutex_); + + auto old_strategy = merge_mgr_ptr_->Strategy(); + if (force_merge_all) { + merge_mgr_ptr_->UseStrategy(MergeStrategyType::ADAPTIVE); + } + + auto status = merge_mgr_ptr_->MergeFiles(collection_name); + merge_mgr_ptr_->UseStrategy(old_strategy); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Failed to get merge files for collection: " << collection_name + << " reason:" << status.message(); + } + + if (!initialized_.load(std::memory_order_acquire)) { + LOG_ENGINE_DEBUG_ << "Server will shutdown, skip merge action for collection: " << collection_name; + break; + } + } + + // TODO: cleanup with ttl +} + +void +SSDBImpl::WaitMergeFileFinish() { + // LOG_ENGINE_DEBUG_ << "Begin WaitMergeFileFinish"; + std::lock_guard lck(merge_result_mutex_); + for (auto& iter : merge_thread_results_) { + iter.wait(); + } + // LOG_ENGINE_DEBUG_ << "End WaitMergeFileFinish"; +} + +void +SSDBImpl::SuspendIfFirst() { + std::lock_guard lock(suspend_build_mutex_); + if (++live_search_num_ == 1) { + LOG_ENGINE_TRACE_ << "live_search_num_: " << live_search_num_; + knowhere::BuilderSuspend(); + } +} + +void +SSDBImpl::ResumeIfLast() { + std::lock_guard lock(suspend_build_mutex_); + if (--live_search_num_ == 0) { + LOG_ENGINE_TRACE_ << "live_search_num_: " << live_search_num_; + knowhere::BuildResume(); + } } } // namespace engine diff --git a/core/src/db/SSDBImpl.h b/core/src/db/SSDBImpl.h index 1cc53387c557..6209772ade4f 100644 --- a/core/src/db/SSDBImpl.h +++ b/core/src/db/SSDBImpl.h @@ -12,17 +12,26 @@ #pragma once #include +#include #include #include +#include +#include #include +#include +#include #include #include "db/Options.h" +#include "db/SimpleWaitNotify.h" #include "db/SnapshotHandlers.h" +#include "db/insert/SSMemManager.h" +#include "db/merge/MergeManager.h" #include "db/snapshot/Context.h" #include "db/snapshot/ResourceTypes.h" #include "db/snapshot/Resources.h" #include "utils/Status.h" +#include "utils/ThreadPool.h" #include "wal/WalManager.h" namespace milvus { @@ -32,6 +41,14 @@ class SSDBImpl { public: explicit SSDBImpl(const DBOptions& options); + ~SSDBImpl(); + + Status + Start(); + + Status + Stop(); + Status CreateCollection(const snapshot::CreateCollectionContext& context); @@ -49,8 +66,11 @@ class SSDBImpl { AllCollections(std::vector& names); Status - PreloadCollection(const std::shared_ptr& context, const std::string& collection_name, - bool force = false); + GetCollectionRowCount(const std::string& collection_name, uint64_t& row_count); + + Status + LoadCollection(const server::ContextPtr& context, const std::string& collection_name, + const std::vector& field_names, bool force = false); Status CreatePartition(const std::string& collection_name, const std::string& partition_name); @@ -61,19 +81,130 @@ class SSDBImpl { Status ShowPartitions(const std::string& collection_name, std::vector& partition_names); - ~SSDBImpl(); + Status + InsertEntities(const std::string& collection_name, const std::string& partition_name, DataChunkPtr& data_chunk); Status - Start(); + DeleteEntities(const std::string& collection_name, engine::IDNumbers entity_ids); Status - Stop(); + Flush(const std::string& collection_name); + + Status + Flush(); + + Status + Compact(const server::ContextPtr& context, const std::string& collection_name, double threshold = 0.0); + + Status + GetEntityByID(const std::string& collection_name, const IDNumbers& id_array, + const std::vector& field_names, DataChunkPtr& data_chunk); + + Status + GetEntityIDs(const std::string& collection_name, int64_t segment_id, IDNumbers& entity_ids); + + Status + CreateIndex(const std::shared_ptr& context, const std::string& collection_name, + const std::string& field_name, const CollectionIndex& index); + + Status + DescribeIndex(const std::string& collection_name, const std::string& field_name, CollectionIndex& index); + + Status + DropIndex(const std::string& collection_name, const std::string& field_name); + + Status + DropIndex(const std::string& collection_name); + + Status + Query(const server::ContextPtr& context, const std::string& collection_name, const query::QueryPtr& query_ptr, + engine::QueryResultPtr& result); + + private: + void + InternalFlush(const std::string& collection_name = ""); + + void + TimingFlushThread(); + + void + StartMetricTask(); + + void + TimingMetricThread(); + + void + StartBuildIndexTask(const std::vector& collection_names); + + void + BackgroundBuildIndexTask(std::vector collection_names); + + void + TimingIndexThread(); + + void + WaitBuildIndexFinish(); + + void + TimingWalThread(); + + Status + ExecWalRecord(const wal::MXLogRecord& record); + + void + StartMergeTask(const std::set& collection_names, bool force_merge_all = false); + + void + BackgroundMerge(std::set collection_names, bool force_merge_all); + + void + WaitMergeFileFinish(); + + void + SuspendIfFirst(); + + void + ResumeIfLast(); private: DBOptions options_; std::atomic initialized_; + + SSMemManagerPtr mem_mgr_; + MergeManagerPtr merge_mgr_ptr_; + std::shared_ptr wal_mgr_; + std::thread bg_wal_thread_; + + std::thread bg_flush_thread_; + std::thread bg_metric_thread_; + std::thread bg_index_thread_; + + SimpleWaitNotify swn_wal_; + SimpleWaitNotify swn_flush_; + SimpleWaitNotify swn_metric_; + SimpleWaitNotify swn_index_; + + SimpleWaitNotify flush_req_swn_; + SimpleWaitNotify index_req_swn_; + + ThreadPool merge_thread_pool_; + std::mutex merge_result_mutex_; + std::list> merge_thread_results_; + + ThreadPool index_thread_pool_; + std::mutex index_result_mutex_; + std::list> index_thread_results_; + + std::mutex build_index_mutex_; + + std::mutex flush_merge_compact_mutex_; + + int64_t live_search_num_ = 0; + std::mutex suspend_build_mutex_; }; // SSDBImpl +using SSDBImplPtr = std::shared_ptr; + } // namespace engine } // namespace milvus diff --git a/core/src/db/SimpleWaitNotify.h b/core/src/db/SimpleWaitNotify.h new file mode 100644 index 000000000000..a92a8ccee768 --- /dev/null +++ b/core/src/db/SimpleWaitNotify.h @@ -0,0 +1,62 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +namespace milvus { +namespace engine { + +struct SimpleWaitNotify { + bool notified_ = false; + std::mutex mutex_; + std::condition_variable cv_; + + void + Wait() { + std::unique_lock lck(mutex_); + if (!notified_) { + cv_.wait(lck); + } + notified_ = false; + } + + void + Wait_Until(const std::chrono::system_clock::time_point& tm_pint) { + std::unique_lock lck(mutex_); + if (!notified_) { + cv_.wait_until(lck, tm_pint); + } + notified_ = false; + } + + void + Wait_For(const std::chrono::system_clock::duration& tm_dur) { + std::unique_lock lck(mutex_); + if (!notified_) { + cv_.wait_for(lck, tm_dur); + } + notified_ = false; + } + + void + Notify() { + std::unique_lock lck(mutex_); + notified_ = true; + lck.unlock(); + cv_.notify_one(); + } +}; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/SnapshotHandlers.cpp b/core/src/db/SnapshotHandlers.cpp index f16476a4d0bd..e20e025ed3f2 100644 --- a/core/src/db/SnapshotHandlers.cpp +++ b/core/src/db/SnapshotHandlers.cpp @@ -10,7 +10,18 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "db/SnapshotHandlers.h" +#include "db/SnapshotVisitor.h" +#include "db/Types.h" +#include "db/meta/MetaConsts.h" #include "db/meta/MetaTypes.h" +#include "db/snapshot/ResourceHelper.h" +#include "db/snapshot/Resources.h" +#include "db/snapshot/Snapshot.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "segment/SSSegmentReader.h" + +#include +#include namespace milvus { namespace engine { @@ -23,7 +34,8 @@ LoadVectorFieldElementHandler::LoadVectorFieldElementHandler(const std::shared_p Status LoadVectorFieldElementHandler::Handle(const snapshot::FieldElementPtr& field_element) { - if (field_->GetFtype() != snapshot::FieldType::VECTOR) { + if (field_->GetFtype() != engine::FieldType::VECTOR_FLOAT && + field_->GetFtype() != engine::FieldType::VECTOR_BINARY) { return Status(DB_ERROR, "Should be VECTOR field"); } if (field_->GetID() != field_element->GetFieldId()) { @@ -40,7 +52,7 @@ LoadVectorFieldHandler::LoadVectorFieldHandler(const std::shared_ptrGetFtype() != snapshot::FieldType::VECTOR) { + if (field->GetFtype() != engine::FieldType::VECTOR_FLOAT && field->GetFtype() != engine::FieldType::VECTOR_BINARY) { return Status::OK(); } if (context_ && context_->IsConnectionBroken()) { @@ -62,36 +74,113 @@ LoadVectorFieldHandler::Handle(const snapshot::FieldPtr& field) { return status; } -SegmentsToSearchCollector::SegmentsToSearchCollector(snapshot::ScopedSnapshotT ss, meta::FilesHolder& holder) - : BaseT(ss), holder_(holder) { +/////////////////////////////////////////////////////////////////////////////// +SegmentsToSearchCollector::SegmentsToSearchCollector(snapshot::ScopedSnapshotT ss, snapshot::IDS_TYPE& segment_ids) + : BaseT(ss), segment_ids_(segment_ids) { } Status SegmentsToSearchCollector::Handle(const snapshot::SegmentCommitPtr& segment_commit) { - // SS TODO - meta::SegmentSchema schema; - /* schema.id_ = segment_commit->GetSegmentId(); */ - /* schema.file_type_ = resRow["file_type"]; */ - /* schema.file_size_ = resRow["file_size"]; */ - /* schema.row_count_ = resRow["row_count"]; */ - /* schema.date_ = resRow["date"]; */ - /* schema.engine_type_ = resRow["engine_type"]; */ - /* schema.created_on_ = resRow["created_on"]; */ - /* schema.updated_time_ = resRow["updated_time"]; */ - - /* schema.dimension_ = collection_schema.dimension_; */ - /* schema.index_file_size_ = collection_schema.index_file_size_; */ - /* schema.index_params_ = collection_schema.index_params_; */ - /* schema.metric_type_ = collection_schema.metric_type_; */ - - /* auto status = utils::GetCollectionFilePath(options_, schema); */ - /* if (!status.ok()) { */ - /* ret = status; */ - /* continue; */ - /* } */ - - holder_.MarkFile(schema); + segment_ids_.push_back(segment_commit->GetSegmentId()); +} + +/////////////////////////////////////////////////////////////////////////////// +SegmentsToIndexCollector::SegmentsToIndexCollector(snapshot::ScopedSnapshotT ss, const std::string& field_name, + snapshot::IDS_TYPE& segment_ids) + : BaseT(ss), field_name_(field_name), segment_ids_(segment_ids) { +} + +Status +SegmentsToIndexCollector::Handle(const snapshot::SegmentCommitPtr& segment_commit) { + if (segment_commit->GetRowCount() < meta::BUILD_INDEX_THRESHOLD) { + return Status::OK(); + } + + auto segment_visitor = engine::SegmentVisitor::Build(ss_, segment_commit->GetSegmentId()); + if (field_name_.empty()) { + auto field_visitors = segment_visitor->GetFieldVisitors(); + for (auto& pair : field_visitors) { + auto& field_visitor = pair.second; + auto element_visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (element_visitor != nullptr && element_visitor->GetFile() == nullptr) { + segment_ids_.push_back(segment_commit->GetSegmentId()); + break; + } + } + } else { + auto field_visitor = segment_visitor->GetFieldVisitor(field_name_); + auto element_visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (element_visitor != nullptr && element_visitor->GetFile() == nullptr) { + segment_ids_.push_back(segment_commit->GetSegmentId()); + } + } + + return Status::OK(); } +/////////////////////////////////////////////////////////////////////////////// +GetEntityByIdSegmentHandler::GetEntityByIdSegmentHandler(const std::shared_ptr& context, + engine::snapshot::ScopedSnapshotT ss, + const std::string& dir_root, const IDNumbers& ids, + const std::vector& field_names) + : BaseT(ss), context_(context), dir_root_(dir_root), ids_(ids), field_names_(field_names) { +} + +Status +GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) { + LOG_ENGINE_DEBUG_ << "Get entity by id in segment " << segment->GetID(); + + auto segment_visitor = SegmentVisitor::Build(ss_, segment->GetID()); + if (segment_visitor == nullptr) { + return Status(DB_ERROR, "Fail to build segment visitor with id " + std::to_string(segment->GetID())); + } + segment::SSSegmentReader segment_reader(dir_root_, segment_visitor); + + auto uid_field_visitor = segment_visitor->GetFieldVisitor(DEFAULT_UID_NAME); + + // load UID's bloom filter file + segment::IdBloomFilterPtr id_bloom_filter_ptr; + STATUS_CHECK(segment_reader.LoadBloomFilter(id_bloom_filter_ptr)); + + std::vector uids; + segment::DeletedDocsPtr deleted_docs_ptr; + std::vector offsets; + for (auto id : ids_) { + // fast check using bloom filter + if (!id_bloom_filter_ptr->Check(id)) { + continue; + } + + // check if id really exists in uids + if (uids.empty()) { + STATUS_CHECK(segment_reader.LoadUids(uids)); // lazy load + } + auto found = std::find(uids.begin(), uids.end(), id); + if (found == uids.end()) { + continue; + } + + // check if this id is deleted + auto offset = std::distance(uids.begin(), found); + if (deleted_docs_ptr == nullptr) { + STATUS_CHECK(segment_reader.LoadDeletedDocs(deleted_docs_ptr)); // lazy load + } + if (deleted_docs_ptr) { + auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs(); + auto deleted = std::find(deleted_docs.begin(), deleted_docs.end(), offset); + if (deleted != deleted_docs.end()) { + continue; + } + } + offsets.push_back(offset); + } + + STATUS_CHECK(segment_reader.LoadFieldsEntities(field_names_, offsets, data_chunk_)); + + return Status::OK(); +} + +/////////////////////////////////////////////////////////////////////////////// + } // namespace engine } // namespace milvus diff --git a/core/src/db/SnapshotHandlers.h b/core/src/db/SnapshotHandlers.h index c695954e5a20..58ca17e8207b 100644 --- a/core/src/db/SnapshotHandlers.h +++ b/core/src/db/SnapshotHandlers.h @@ -11,12 +11,18 @@ #pragma once +#include "db/Types.h" #include "db/meta/FilesHolder.h" +#include "db/snapshot/IterateHandler.h" #include "db/snapshot/Snapshot.h" +#include "segment/Segment.h" +#include "segment/Types.h" #include "server/context/Context.h" #include "utils/Log.h" #include +#include +#include namespace milvus { namespace engine { @@ -24,37 +30,70 @@ namespace engine { struct LoadVectorFieldElementHandler : public snapshot::IterateHandler { using ResourceT = snapshot::FieldElement; using BaseT = snapshot::IterateHandler; - LoadVectorFieldElementHandler(const std::shared_ptr& context, snapshot::ScopedSnapshotT ss, + LoadVectorFieldElementHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, const snapshot::FieldPtr& field); Status Handle(const typename ResourceT::Ptr&) override; - const std::shared_ptr& context_; - const snapshot::FieldPtr& field_; + const server::ContextPtr context_; + const snapshot::FieldPtr field_; }; struct LoadVectorFieldHandler : public snapshot::IterateHandler { using ResourceT = snapshot::Field; using BaseT = snapshot::IterateHandler; - LoadVectorFieldHandler(const std::shared_ptr& context, snapshot::ScopedSnapshotT ss); + LoadVectorFieldHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss); Status Handle(const typename ResourceT::Ptr&) override; - const std::shared_ptr& context_; + const server::ContextPtr context_; }; struct SegmentsToSearchCollector : public snapshot::IterateHandler { using ResourceT = snapshot::SegmentCommit; using BaseT = snapshot::IterateHandler; - SegmentsToSearchCollector(snapshot::ScopedSnapshotT ss, meta::FilesHolder& holder); + SegmentsToSearchCollector(snapshot::ScopedSnapshotT ss, snapshot::IDS_TYPE& segment_ids); Status Handle(const typename ResourceT::Ptr&) override; - meta::FilesHolder& holder_; + snapshot::IDS_TYPE& segment_ids_; }; +struct SegmentsToIndexCollector : public snapshot::IterateHandler { + using ResourceT = snapshot::SegmentCommit; + using BaseT = snapshot::IterateHandler; + SegmentsToIndexCollector(snapshot::ScopedSnapshotT ss, const std::string& field_name, + snapshot::IDS_TYPE& segment_ids); + + Status + Handle(const typename ResourceT::Ptr&) override; + + std::string field_name_; + snapshot::IDS_TYPE& segment_ids_; +}; + +/////////////////////////////////////////////////////////////////////////////// +struct GetEntityByIdSegmentHandler : public snapshot::IterateHandler { + using ResourceT = snapshot::Segment; + using BaseT = snapshot::IterateHandler; + GetEntityByIdSegmentHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, + const std::string& dir_root, const IDNumbers& ids, + const std::vector& field_names); + + Status + Handle(const typename ResourceT::Ptr&) override; + + const server::ContextPtr context_; + const std::string dir_root_; + const engine::IDNumbers ids_; + const std::vector field_names_; + engine::DataChunkPtr data_chunk_; +}; + +/////////////////////////////////////////////////////////////////////////////// + } // namespace engine } // namespace milvus diff --git a/core/src/db/SnapshotUtils.cpp b/core/src/db/SnapshotUtils.cpp new file mode 100644 index 000000000000..a34eff0ea249 --- /dev/null +++ b/core/src/db/SnapshotUtils.cpp @@ -0,0 +1,126 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/SnapshotUtils.h" +#include "db/snapshot/CompoundOperations.h" +#include "db/snapshot/Resources.h" +#include "db/snapshot/Snapshots.h" +#include "segment/Segment.h" + +#include +#include +#include + +namespace milvus { +namespace engine { + +Status +SetSnapshotIndex(const std::string& collection_name, const std::string& field_name, + engine::CollectionIndex& index_info) { + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + auto field = ss->GetField(field_name); + if (field == nullptr) { + return Status(DB_ERROR, "Invalid field name"); + } + + snapshot::OperationContext ss_context; + auto ftype = field->GetFtype(); + if (ftype == engine::FIELD_TYPE::VECTOR || ftype == engine::FIELD_TYPE::VECTOR_FLOAT || + ftype == engine::FIELD_TYPE::VECTOR_BINARY) { + std::string index_name = knowhere::OldIndexTypeToStr(index_info.engine_type_); + auto new_element = std::make_shared(ss->GetCollectionId(), field->GetID(), index_name, + milvus::engine::FieldElementType::FET_INDEX); + nlohmann::json json; + json[engine::PARAM_INDEX_METRIC_TYPE] = index_info.metric_type_; + json[engine::PARAM_INDEX_EXTRA_PARAMS] = index_info.extra_params_; + new_element->SetParams(json); + ss_context.new_field_elements.push_back(new_element); + } else { + auto new_element = std::make_shared( + ss->GetCollectionId(), field->GetID(), "structured_index", milvus::engine::FieldElementType::FET_INDEX); + ss_context.new_field_elements.push_back(new_element); + } + + auto op = std::make_shared(ss_context, ss); + auto status = op->Push(); + if (!status.ok()) { + return status; + } + + return Status::OK(); +} + +Status +GetSnapshotIndex(const std::string& collection_name, const std::string& field_name, + engine::CollectionIndex& index_info) { + index_info.engine_type_ = 0; + index_info.metric_type_ = 0; + + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + + auto field = ss->GetField(field_name); + if (field == nullptr) { + return Status(DB_ERROR, "Invalid field name"); + } + + auto field_elements = ss->GetFieldElementsByField(field_name); + auto ftype = field->GetFtype(); + if (ftype == engine::FIELD_TYPE::VECTOR || ftype == engine::FIELD_TYPE::VECTOR_FLOAT || + ftype == engine::FIELD_TYPE::VECTOR_BINARY) { + for (auto& field_element : field_elements) { + if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) { + std::string index_name = field_element->GetName(); + index_info.engine_type_ = knowhere::StrToOldIndexType(index_name); + auto json = field_element->GetParams(); + if (json.find(engine::PARAM_INDEX_METRIC_TYPE) != json.end()) { + index_info.metric_type_ = json[engine::PARAM_INDEX_METRIC_TYPE]; + } + if (json.find(engine::PARAM_INDEX_EXTRA_PARAMS) != json.end()) { + index_info.extra_params_ = json[engine::PARAM_INDEX_EXTRA_PARAMS]; + } + break; + } + } + } else { + for (auto& field_element : field_elements) { + if (field_element->GetFtype() == (int64_t)milvus::engine::FieldElementType::FET_INDEX) { + index_info.engine_type_ = (int32_t)engine::StructuredIndexType::SORTED; + } + } + } + + return Status::OK(); +} + +Status +DeleteSnapshotIndex(const std::string& collection_name, const std::string& field_name) { + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); + + snapshot::OperationContext context; + std::vector elements = ss->GetFieldElementsByField(field_name); + for (auto& element : elements) { + if (element->GetFtype() == engine::FieldElementType::FET_INDEX || + element->GetFtype() == engine::FieldElementType::FET_COMPRESS_SQ8) { + context.stale_field_elements.push_back(element); + } + } + + auto op = std::make_shared(context, ss); + STATUS_CHECK(op->Push()); + + return Status::OK(); +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/SnapshotUtils.h b/core/src/db/SnapshotUtils.h new file mode 100644 index 000000000000..379c8bd50106 --- /dev/null +++ b/core/src/db/SnapshotUtils.h @@ -0,0 +1,33 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "db/Types.h" + +#include + +namespace milvus { +namespace engine { + +Status +SetSnapshotIndex(const std::string& collection_name, const std::string& field_name, + engine::CollectionIndex& index_info); + +Status +GetSnapshotIndex(const std::string& collection_name, const std::string& field_name, + engine::CollectionIndex& index_info); + +Status +DeleteSnapshotIndex(const std::string& collection_name, const std::string& field_name); + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/SnapshotVisitor.cpp b/core/src/db/SnapshotVisitor.cpp index b0300f38f044..ff3262890991 100644 --- a/core/src/db/SnapshotVisitor.cpp +++ b/core/src/db/SnapshotVisitor.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "db/SnapshotVisitor.h" +#include #include "db/SnapshotHandlers.h" #include "db/meta/MetaTypes.h" #include "db/snapshot/Snapshots.h" @@ -29,14 +30,235 @@ SnapshotVisitor::SnapshotVisitor(snapshot::ID_TYPE collection_id) { } Status -SnapshotVisitor::SegmentsToSearch(meta::FilesHolder& files_holder) { +SnapshotVisitor::SegmentsToSearch(snapshot::IDS_TYPE& segment_ids) { STATUS_CHECK(status_); - auto handler = std::make_shared(ss_, files_holder); + auto handler = std::make_shared(ss_, segment_ids); handler->Iterate(); return handler->GetStatus(); } +Status +SnapshotVisitor::SegmentsToIndex(const std::string& field_name, snapshot::IDS_TYPE& segment_ids) { + STATUS_CHECK(status_); + + auto handler = std::make_shared(ss_, field_name, segment_ids); + handler->Iterate(); + + return handler->GetStatus(); +} + +SegmentFieldElementVisitor::Ptr +SegmentFieldElementVisitor::Build(snapshot::ScopedSnapshotT ss, const snapshot::FieldElementPtr& field_element, + const snapshot::SegmentPtr& segment, const snapshot::SegmentFilePtr& segment_file) { + if (!ss || !segment || !field_element) { + return nullptr; + } + + if (segment_file) { + if (segment_file->GetFieldElementId() != field_element->GetID()) { + std::cout << "FieldElement " << segment_file->GetFieldElementId() << " is expected for SegmentFile "; + std::cout << segment_file->GetID() << " while actual is " << field_element->GetID() << std::endl; + return nullptr; + } + if (segment_file->GetSegmentId() != segment->GetID()) { + std::cout << "Segment " << segment_file->GetSegmentId() << " is expected for SegmentFile "; + std::cout << segment_file->GetID() << " while actual is " << segment->GetID() << std::endl; + return nullptr; + } + } + + auto visitor = std::make_shared(); + visitor->SetFieldElement(field_element); + if (segment_file) { + visitor->SetFile(segment_file); + } + + return visitor; +} + +SegmentFieldElementVisitor::Ptr +SegmentFieldElementVisitor::Build(snapshot::ScopedSnapshotT ss, snapshot::ID_TYPE segment_id, + snapshot::ID_TYPE field_element_id) { + if (!ss) { + return nullptr; + } + + auto element = ss->GetResource(field_element_id); + if (!element) { + return nullptr; + } + + auto visitor = std::make_shared(); + visitor->SetFieldElement(element); + auto segment = ss->GetResource(segment_id); + if (!segment) { + return nullptr; + } + + auto file = ss->GetSegmentFile(segment_id, field_element_id); + if (file) { + visitor->SetFile(file); + } + + return visitor; +} + +SegmentFieldVisitor::Ptr +SegmentFieldVisitor::Build(snapshot::ScopedSnapshotT ss, const snapshot::FieldPtr& field, + const snapshot::SegmentPtr& segment, const snapshot::SegmentFile::VecT& segment_files) { + if (!ss || !segment || !field) { + return nullptr; + } + if (ss->GetResource(field->GetID()) != field) { + return nullptr; + } + + auto visitor = std::make_shared(); + visitor->SetField(field); + + std::map files; + for (auto& f : segment_files) { + files[f->GetFieldElementId()] = f; + } + + auto executor = [&](const snapshot::FieldElement::Ptr& field_element, + snapshot::FieldElementIterator* itr) -> Status { + if (field_element->GetFieldId() != field->GetID()) { + return Status::OK(); + } + snapshot::SegmentFilePtr file; + auto it = files.find(field_element->GetID()); + if (it != files.end()) { + file = it->second; + } + auto element_visitor = SegmentFieldElementVisitor::Build(ss, field_element, segment, file); + if (!element_visitor) { + return Status::OK(); + } + visitor->InsertElement(element_visitor); + return Status::OK(); + }; + + auto iterator = std::make_shared(ss, executor); + iterator->Iterate(); + + return visitor; +} + +SegmentFieldVisitor::Ptr +SegmentFieldVisitor::Build(snapshot::ScopedSnapshotT ss, snapshot::ID_TYPE segment_id, snapshot::ID_TYPE field_id) { + if (!ss) { + return nullptr; + } + + auto field = ss->GetResource(field_id); + if (!field) { + return nullptr; + } + + auto visitor = std::make_shared(); + visitor->SetField(field); + + auto executor = [&](const snapshot::FieldElement::Ptr& field_element, + snapshot::FieldElementIterator* itr) -> Status { + if (field_element->GetFieldId() != field_id) { + return Status::OK(); + } + auto element_visitor = SegmentFieldElementVisitor::Build(ss, segment_id, field_element->GetID()); + if (!element_visitor) { + return Status::OK(); + } + visitor->InsertElement(element_visitor); + return Status::OK(); + }; + + auto iterator = std::make_shared(ss, executor); + iterator->Iterate(); + + return visitor; +} + +SegmentVisitor::Ptr +SegmentVisitor::Build(snapshot::ScopedSnapshotT ss, const snapshot::SegmentPtr& segment, + const snapshot::SegmentFile::VecT& segment_files) { + if (!ss || !segment) { + return nullptr; + } + if (!ss->GetResource(segment->GetPartitionId())) { + return nullptr; + } + + auto visitor = std::make_shared(); + visitor->SetSegment(segment); + + auto executor = [&](const snapshot::Field::Ptr& field, snapshot::FieldIterator* itr) -> Status { + auto field_visitor = SegmentFieldVisitor::Build(ss, field, segment, segment_files); + if (!field_visitor) { + return Status::OK(); + } + visitor->InsertField(field_visitor); + + return Status::OK(); + }; + + auto iterator = std::make_shared(ss, executor); + iterator->Iterate(); + + return visitor; +} + +SegmentVisitor::Ptr +SegmentVisitor::Build(snapshot::ScopedSnapshotT ss, snapshot::ID_TYPE segment_id) { + if (!ss) { + return nullptr; + } + auto segment = ss->GetResource(segment_id); + if (!segment) { + return nullptr; + } + + auto visitor = std::make_shared(); + visitor->SetSegment(segment); + + auto executor = [&](const snapshot::Field::Ptr& field, snapshot::FieldIterator* itr) -> Status { + auto field_visitor = SegmentFieldVisitor::Build(ss, segment_id, field->GetID()); + if (!field_visitor) { + return Status::OK(); + } + visitor->InsertField(field_visitor); + + return Status::OK(); + }; + + auto iterator = std::make_shared(ss, executor); + iterator->Iterate(); + + return visitor; +} + +std::string +SegmentVisitor::ToString() const { + std::stringstream ss; + ss << "SegmentVisitor[" << GetSegment()->GetID() << "]: " << (GetSegment()->IsActive() ? "" : "*") << "\n"; + auto& field_visitors = GetFieldVisitors(); + for (auto& fkv : field_visitors) { + ss << " Field[" << fkv.first << "]\n"; + auto& fe_visitors = fkv.second->GetElementVistors(); + for (auto& fekv : fe_visitors) { + ss << " FieldElement[" << fekv.first << "] "; + auto file = fekv.second->GetFile(); + if (file) { + ss << "SegmentFile [" << file->GetID() << "]: " << (file->IsActive() ? "" : "*") << "\n"; + } else { + ss << "No SegmentFile!\n"; + } + } + } + + return ss.str(); +} + } // namespace engine } // namespace milvus diff --git a/core/src/db/SnapshotVisitor.h b/core/src/db/SnapshotVisitor.h index 3cd018896843..8f0aae012646 100644 --- a/core/src/db/SnapshotVisitor.h +++ b/core/src/db/SnapshotVisitor.h @@ -14,7 +14,9 @@ #include "db/meta/FilesHolder.h" #include "db/snapshot/Snapshot.h" +#include #include +#include #include namespace milvus { @@ -27,12 +29,166 @@ class SnapshotVisitor { explicit SnapshotVisitor(snapshot::ID_TYPE collection_id); Status - SegmentsToSearch(meta::FilesHolder& files_holder); + SegmentsToSearch(snapshot::IDS_TYPE& segment_ids); + + Status + SegmentsToIndex(const std::string& field_name, snapshot::IDS_TYPE& segment_ids); protected: snapshot::ScopedSnapshotT ss_; Status status_; }; +class SegmentFieldElementVisitor { + public: + using Ptr = std::shared_ptr; + + static Ptr + Build(snapshot::ScopedSnapshotT ss, snapshot::ID_TYPE segment_id, snapshot::ID_TYPE field_element_id); + static Ptr + Build(snapshot::ScopedSnapshotT ss, const snapshot::FieldElementPtr& field_element, + const snapshot::SegmentPtr& segment, const snapshot::SegmentFilePtr& segment_file); + + SegmentFieldElementVisitor() = default; + + void + SetFieldElement(snapshot::FieldElementPtr field_element) { + field_element_ = field_element; + } + + void + SetFile(snapshot::SegmentFilePtr file) { + file_ = file; + } + + const snapshot::FieldElementPtr + GetElement() const { + return field_element_; + } + + const snapshot::SegmentFilePtr + GetFile() const { + return file_; + } + + protected: + snapshot::FieldElementPtr field_element_; + snapshot::SegmentFilePtr file_; +}; +using SegmentFieldElementVisitorPtr = std::shared_ptr; + +class SegmentFieldVisitor { + public: + using Ptr = std::shared_ptr; + using ElementT = typename SegmentFieldElementVisitor::Ptr; + using ElementsMapT = std::map; + + static Ptr + Build(snapshot::ScopedSnapshotT ss, snapshot::ID_TYPE segment_id, snapshot::ID_TYPE field_id); + static Ptr + Build(snapshot::ScopedSnapshotT ss, const snapshot::FieldPtr& field, const snapshot::SegmentPtr& segment, + const snapshot::SegmentFile::VecT& segment_files); + + SegmentFieldVisitor() = default; + + const ElementsMapT& + GetElementVistors() const { + return elements_map_; + } + const snapshot::FieldPtr& + GetField() const { + return field_; + } + + void + SetField(snapshot::FieldPtr field) { + field_ = field; + } + + void + InsertElement(ElementT element) { + elements_map_[element->GetElement()->GetID()] = element; + } + + const ElementT + GetElementVisitor(const FieldElementType elem_type) const { + for (auto& kv : elements_map_) { + auto& ev = kv.second; + if (ev->GetElement()->GetFtype() == elem_type) { + return ev; + } + } + return nullptr; + } + + protected: + ElementsMapT elements_map_; + snapshot::FieldPtr field_; +}; +using SegmentFieldVisitorPtr = SegmentFieldVisitor::Ptr; + +class SegmentVisitor { + public: + using Ptr = std::shared_ptr; + using FieldVisitorT = typename SegmentFieldVisitor::Ptr; + using IdMapT = std::map; + using NameMapT = std::map; + + static Ptr + Build(snapshot::ScopedSnapshotT ss, snapshot::ID_TYPE segment_id); + static Ptr + Build(snapshot::ScopedSnapshotT ss, const snapshot::SegmentPtr& segment, + const snapshot::SegmentFile::VecT& segment_files); + + SegmentVisitor() = default; + + const IdMapT& + GetFieldVisitors() const { + return id_map_; + } + + FieldVisitorT + GetFieldVisitor(snapshot::ID_TYPE field_id) const { + auto it = id_map_.find(field_id); + if (it == id_map_.end()) { + return nullptr; + } + return it->second; + } + + FieldVisitorT + GetFieldVisitor(const std::string& field_name) const { + auto it = name_map_.find(field_name); + if (it == name_map_.end()) { + return nullptr; + } + return it->second; + } + + const snapshot::SegmentPtr& + GetSegment() const { + return segment_; + } + + void + SetSegment(snapshot::SegmentPtr segment) { + segment_ = segment; + } + void + InsertField(FieldVisitorT field_visitor) { + id_map_[field_visitor->GetField()->GetID()] = field_visitor; + name_map_[field_visitor->GetField()->GetName()] = field_visitor; + } + + std::string + ToString() const; + + protected: + snapshot::SegmentPtr segment_; + IdMapT id_map_; + NameMapT name_map_; +}; +using SegmentVisitorPtr = SegmentVisitor::Ptr; + } // namespace engine } // namespace milvus diff --git a/core/src/db/Types.cpp b/core/src/db/Types.cpp new file mode 100644 index 000000000000..054ac654fd40 --- /dev/null +++ b/core/src/db/Types.cpp @@ -0,0 +1,29 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/Types.h" + +namespace milvus { +namespace engine { + +const char* DEFAULT_UID_NAME = "_uid"; + +const char* DEFAULT_RAW_DATA_NAME = "_raw"; +const char* DEFAULT_BLOOM_FILTER_NAME = "_blf"; +const char* DEFAULT_DELETED_DOCS_NAME = "_del"; +const char* DEFAULT_INDEX_NAME = "_idx"; + +const char* PARAM_COLLECTION_DIMENSION = "dimension"; +const char* PARAM_INDEX_METRIC_TYPE = "metric_type"; +const char* PARAM_INDEX_EXTRA_PARAMS = "extra_params"; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/Types.h b/core/src/db/Types.h index 1a4058f6320b..53c6bb73f3c4 100644 --- a/core/src/db/Types.h +++ b/core/src/db/Types.h @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -33,6 +34,9 @@ typedef segment::doc_id_t IDNumber; typedef IDNumber* IDNumberPtr; typedef std::vector IDNumbers; +typedef faiss::Index::distance_t VectorDistance; +typedef std::vector VectorDistances; + typedef std::vector ResultIds; typedef std::vector ResultDistances; @@ -72,9 +76,32 @@ struct QueryResult { std::vector vectors_; std::vector attrs_; }; +using QueryResultPtr = std::shared_ptr; using File2ErrArray = std::map>; using Table2FileErr = std::map; +extern const char* DEFAULT_UID_NAME; + +extern const char* DEFAULT_RAW_DATA_NAME; +extern const char* DEFAULT_BLOOM_FILTER_NAME; +extern const char* DEFAULT_DELETED_DOCS_NAME; +extern const char* DEFAULT_INDEX_NAME; + +extern const char* PARAM_COLLECTION_DIMENSION; +extern const char* PARAM_INDEX_METRIC_TYPE; +extern const char* PARAM_INDEX_EXTRA_PARAMS; + +using FieldType = meta::hybrid::DataType; + +enum FieldElementType { + FET_NONE = 0, + FET_RAW = 1, + FET_BLOOM_FILTER = 2, + FET_DELETED_DOCS = 3, + FET_INDEX = 4, + FET_COMPRESS_SQ8 = 5, +}; + } // namespace engine } // namespace milvus diff --git a/core/src/db/Utils.cpp b/core/src/db/Utils.cpp index 5a4abcfb6b8e..5d0dc66b17e2 100644 --- a/core/src/db/Utils.cpp +++ b/core/src/db/Utils.cpp @@ -16,14 +16,18 @@ #include #include #include +#include #include #include #include #include "cache/CpuCacheMgr.h" +#include "db/Types.h" + #ifdef MILVUS_GPU_VERSION #include "cache/GpuCacheMgr.h" #endif + #include "config/Config.h" //#include "storage/s3/S3ClientWrapper.h" #include "utils/CommonUtil.h" @@ -252,8 +256,10 @@ GetIndexName(int32_t index_type) { {(int32_t)engine::EngineType::FAISS_IVFSQ8, "IVF_SQ8"}, {(int32_t)engine::EngineType::FAISS_IVFSQ8H, "IVF_SQ8_HYBRID"}, {(int32_t)engine::EngineType::FAISS_PQ, "IVF_PQ"}, +#ifdef MILVUS_SUPPORT_SPTAG {(int32_t)engine::EngineType::SPTAG_KDT, "SPTAG_KDT_RNT"}, {(int32_t)engine::EngineType::SPTAG_BKT, "SPTAG_BKT_RNT"}, +#endif {(int32_t)engine::EngineType::FAISS_BIN_IDMAP, "BIN_FLAT"}, {(int32_t)engine::EngineType::FAISS_BIN_IVFFLAT, "BIN_IVF_FLAT"}, {(int32_t)engine::EngineType::HNSW, "HNSW"}, @@ -299,7 +305,6 @@ EraseFromCache(const std::string& item_key) { } #endif } - } // namespace utils } // namespace engine } // namespace milvus diff --git a/core/src/db/Utils.h b/core/src/db/Utils.h index 193c07e7a46f..7b65fe026505 100644 --- a/core/src/db/Utils.h +++ b/core/src/db/Utils.h @@ -20,6 +20,11 @@ namespace milvus { namespace engine { +namespace snapshot { +class Segment; +class Partition; +class Collection; +} // namespace snapshot namespace utils { int64_t @@ -81,7 +86,6 @@ ExitOnWriteError(Status& status); void EraseFromCache(const std::string& item_key); - } // namespace utils } // namespace engine } // namespace milvus diff --git a/core/src/db/attr/InstanceStructuredIndex.cpp b/core/src/db/attr/InstanceStructuredIndex.cpp index 430cc89907f9..3e6066826b7f 100644 --- a/core/src/db/attr/InstanceStructuredIndex.cpp +++ b/core/src/db/attr/InstanceStructuredIndex.cpp @@ -64,7 +64,7 @@ InstanceStructuredIndex::CreateStructuredIndex(const std::string& collection_id, std::vector field_names; for (auto& field_schema : fields_schema.fields_schema_) { - if (field_schema.field_type_ != (int32_t)engine::meta::hybrid::DataType::FLOAT_VECTOR) { + if (field_schema.field_type_ != (int32_t)engine::meta::hybrid::DataType::VECTOR_FLOAT) { attr_types.insert( std::make_pair(field_schema.field_name_, (engine::meta::hybrid::DataType)field_schema.field_type_)); field_names.emplace_back(field_schema.field_name_); diff --git a/core/src/db/engine/EngineFactory.cpp b/core/src/db/engine/EngineFactory.cpp index f275439d551b..f81527176eef 100644 --- a/core/src/db/engine/EngineFactory.cpp +++ b/core/src/db/engine/EngineFactory.cpp @@ -11,6 +11,8 @@ #include "db/engine/EngineFactory.h" #include "db/engine/ExecutionEngineImpl.h" +#include "db/engine/SSExecutionEngineImpl.h" +#include "db/snapshot/Snapshots.h" #include "utils/Log.h" #include @@ -34,26 +36,16 @@ EngineFactory::Build(uint16_t dimension, const std::string& location, EngineType return execution_engine_ptr; } -// ExecutionEnginePtr -// EngineFactory::Build(uint16_t dimension, -// const std::string& location, -// EngineType index_type, -// MetricType metric_type, -// std::unordered_map& attr_type, -// const milvus::json& index_params) { -// -// if (index_type == EngineType::INVALID) { -// ENGINE_LOG_ERROR << "Unsupported engine type"; -// return nullptr; -// } -// -// ENGINE_LOG_DEBUG << "EngineFactory index type: " << (int)index_type; -// ExecutionEnginePtr execution_engine_ptr = -// std::make_shared(dimension, location, index_type, metric_type, attr_type, index_params); -// -// execution_engine_ptr->Init(); -// return execution_engine_ptr; -//} +SSExecutionEnginePtr +EngineFactory::Build(const std::string& dir_root, const std::string& collection_name, int64_t segment_id) { + snapshot::ScopedSnapshotT ss; + snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name); + auto seg_visitor = engine::SegmentVisitor::Build(ss, segment_id); + + SSExecutionEnginePtr execution_engine_ptr = std::make_shared(dir_root, seg_visitor); + + return execution_engine_ptr; +} } // namespace engine } // namespace milvus diff --git a/core/src/db/engine/EngineFactory.h b/core/src/db/engine/EngineFactory.h index 14334c1522a2..6c4ae604c67d 100644 --- a/core/src/db/engine/EngineFactory.h +++ b/core/src/db/engine/EngineFactory.h @@ -12,6 +12,7 @@ #pragma once #include "ExecutionEngine.h" +#include "SSExecutionEngine.h" #include "utils/Json.h" #include "utils/Status.h" @@ -26,13 +27,8 @@ class EngineFactory { Build(uint16_t dimension, const std::string& location, EngineType index_type, MetricType metric_type, const milvus::json& index_params); - // static ExecutionEnginePtr - // Build(uint16_t dimension, - // const std::string& location, - // EngineType index_type, - // MetricType metric_type, - // std::unordered_map& attr_type, - // const milvus::json& index_params); + static SSExecutionEnginePtr + Build(const std::string& dir_root, const std::string& collection_name, int64_t segment_id); }; } // namespace engine diff --git a/core/src/db/engine/ExecutionEngine.h b/core/src/db/engine/ExecutionEngine.h index f12bdb0b4a5e..1f91a951d0d1 100644 --- a/core/src/db/engine/ExecutionEngine.h +++ b/core/src/db/engine/ExecutionEngine.h @@ -19,6 +19,7 @@ #include +#include "db/meta/MetaTypes.h" #include "query/GeneralQuery.h" #include "utils/Json.h" #include "utils/Status.h" @@ -32,77 +33,6 @@ using SearchJobPtr = std::shared_ptr; namespace engine { -// TODO(linxj): replace with VecIndex::IndexType -enum class EngineType { - INVALID = 0, - FAISS_IDMAP = 1, - FAISS_IVFFLAT, - FAISS_IVFSQ8, - NSG_MIX, - FAISS_IVFSQ8H, - FAISS_PQ, - SPTAG_KDT, - SPTAG_BKT, - FAISS_BIN_IDMAP, - FAISS_BIN_IVFFLAT, - HNSW, - ANNOY, - MAX_VALUE = ANNOY, -}; - -static std::map s_map_engine_type = { - {"FLAT", EngineType::FAISS_IDMAP}, - {"IVF_FLAT", EngineType::FAISS_IVFFLAT}, - {"IVF_SQ8", EngineType::FAISS_IVFSQ8}, - {"NSG", EngineType::NSG_MIX}, - {"IVF_SQ8_HYBRID", EngineType::FAISS_IVFSQ8H}, - {"IVF_PQ", EngineType::FAISS_PQ}, - {"SPTAG_KDT_RNT", EngineType::SPTAG_KDT}, - {"SPTAG_BKT_RNT", EngineType::SPTAG_BKT}, - {"BIN_FLAT", EngineType::FAISS_BIN_IDMAP}, - {"BIN_IVF_FLAT", EngineType::FAISS_BIN_IVFFLAT}, - {"HNSW", EngineType::HNSW}, - {"ANNOY", EngineType::ANNOY}, -}; - -enum class MetricType { - L2 = 1, // Euclidean Distance - IP = 2, // Cosine Similarity - HAMMING = 3, // Hamming Distance - JACCARD = 4, // Jaccard Distance - TANIMOTO = 5, // Tanimoto Distance - SUBSTRUCTURE = 6, // Substructure Distance - SUPERSTRUCTURE = 7, // Superstructure Distance - MAX_VALUE = SUPERSTRUCTURE -}; - -static std::map s_map_metric_type = { - {"L2", MetricType::L2}, - {"IP", MetricType::IP}, - {"HAMMING", MetricType::HAMMING}, - {"JACCARD", MetricType::JACCARD}, - {"TANIMOTO", MetricType::TANIMOTO}, - {"SUBSTRUCTURE", MetricType::SUBSTRUCTURE}, - {"SUPERSTRUCTURE", MetricType::SUPERSTRUCTURE}, -}; - -enum class DataType { - INT8 = 1, - INT16 = 2, - INT32 = 3, - INT64 = 4, - - STRING = 20, - - BOOL = 30, - - FLOAT = 40, - DOUBLE = 41, - - VECTOR = 100, - UNKNOWN = 9999, -}; - class ExecutionEngine { public: virtual Status @@ -154,10 +84,11 @@ class ExecutionEngine { virtual Status ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset, - std::unordered_map& attr_type, std::string& vector_placeholder) = 0; + std::unordered_map& attr_type, + std::string& vector_placeholder) = 0; virtual Status - HybridSearch(scheduler::SearchJobPtr job, std::unordered_map& attr_type, + HybridSearch(scheduler::SearchJobPtr job, std::unordered_map& attr_type, std::vector& distances, std::vector& search_ids, bool hybrid) = 0; virtual Status diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index d14aec7a5e25..4e0dba2d14c4 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -92,6 +92,23 @@ IsBinaryIndexType(knowhere::IndexType type) { return type == knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP || type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT; } +codec::ExternalData +GetIndexDataType(EngineType type) { + switch (type) { + case EngineType::FAISS_IVFFLAT: + case EngineType::HNSW: + case EngineType::NSG_MIX: + return codec::ExternalData::ExternalData_RawData; + + case EngineType::HNSW_SQ8NM: + case EngineType::FAISS_IVFSQ8NR: + return codec::ExternalData::ExternalData_SQ8; + + default: + return codec::ExternalData::ExternalData_None; + } +} + } // namespace #ifdef MILVUS_GPU_VERSION @@ -191,6 +208,10 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, mode); break; } + case EngineType::FAISS_IVFSQ8NR: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR, mode); + break; + } #ifdef MILVUS_GPU_VERSION case EngineType::FAISS_IVFSQ8H: { index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H, mode); @@ -223,6 +244,10 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_HNSW, mode); break; } + case EngineType::HNSW_SQ8NM: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_HNSW_SQ8NM, mode); + break; + } case EngineType::ANNOY: { index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_ANNOY, mode); break; @@ -371,7 +396,7 @@ ExecutionEngineImpl::Serialize() { // here we reset index size by file size, // since some index type(such as SQ8) data size become smaller after serialized - index_->SetIndexSize(CommonUtil::GetFileSize(location_)); + index_->UpdateIndexSize(); LOG_ENGINE_DEBUG_ << "Finish serialize index file: " << location_ << " size: " << index_->Size(); if (index_->Size() == 0) { @@ -448,7 +473,10 @@ ExecutionEngineImpl::Load(bool to_cache) { try { segment::SegmentPtr segment_ptr; segment_reader_ptr->GetSegment(segment_ptr); - auto status = segment_reader_ptr->LoadVectorIndex(location_, segment_ptr->vector_index_ptr_); + + auto external_data = GetIndexDataType(index_type_); + auto status = + segment_reader_ptr->LoadVectorIndex(location_, external_data, segment_ptr->vector_index_ptr_); index_ = segment_ptr->vector_index_ptr_->GetVectorIndex(); if (index_ == nullptr) { @@ -802,29 +830,29 @@ ProcessIndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& Status ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const std::string& field_name, - const DataType& data_type, milvus::json& term_values_json) { + const meta::hybrid::DataType& data_type, milvus::json& term_values_json) { switch (data_type) { - case DataType::INT8: { + case meta::hybrid::DataType::INT8: { ProcessIndexedTermQuery(bitset, attr_index_->attr_index_data().at(field_name), term_values_json); break; } - case DataType::INT16: { + case meta::hybrid::DataType::INT16: { ProcessIndexedTermQuery(bitset, attr_index_->attr_index_data().at(field_name), term_values_json); break; } - case DataType::INT32: { + case meta::hybrid::DataType::INT32: { ProcessIndexedTermQuery(bitset, attr_index_->attr_index_data().at(field_name), term_values_json); break; } - case DataType::INT64: { + case meta::hybrid::DataType::INT64: { ProcessIndexedTermQuery(bitset, attr_index_->attr_index_data().at(field_name), term_values_json); break; } - case DataType::FLOAT: { + case meta::hybrid::DataType::FLOAT: { ProcessIndexedTermQuery(bitset, attr_index_->attr_index_data().at(field_name), term_values_json); break; } - case DataType::DOUBLE: { + case meta::hybrid::DataType::DOUBLE: { ProcessIndexedTermQuery(bitset, attr_index_->attr_index_data().at(field_name), term_values_json); break; } @@ -835,7 +863,7 @@ ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const Status ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, query::TermQueryPtr term_query, - std::unordered_map& attr_type) { + std::unordered_map& attr_type) { auto status = Status::OK(); auto term_query_json = term_query->json_obj; auto term_it = term_query_json.begin(); @@ -876,31 +904,31 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& } Status -ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const DataType& data_type, +ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const meta::hybrid::DataType& data_type, knowhere::IndexPtr& index_ptr, milvus::json& range_values_json) { auto status = Status::OK(); switch (data_type) { - case DataType::INT8: { + case meta::hybrid::DataType::INT8: { ProcessIndexedRangeQuery(bitset, index_ptr, range_values_json); break; } - case DataType::INT16: { + case meta::hybrid::DataType::INT16: { ProcessIndexedRangeQuery(bitset, index_ptr, range_values_json); break; } - case DataType::INT32: { + case meta::hybrid::DataType::INT32: { ProcessIndexedRangeQuery(bitset, index_ptr, range_values_json); break; } - case DataType::INT64: { + case meta::hybrid::DataType::INT64: { ProcessIndexedRangeQuery(bitset, index_ptr, range_values_json); break; } - case DataType::FLOAT: { + case meta::hybrid::DataType::FLOAT: { ProcessIndexedRangeQuery(bitset, index_ptr, range_values_json); break; } - case DataType::DOUBLE: { + case meta::hybrid::DataType::DOUBLE: { ProcessIndexedRangeQuery(bitset, index_ptr, range_values_json); break; } @@ -911,7 +939,7 @@ ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const } Status -ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map& attr_type, +ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map& attr_type, faiss::ConcurrentBitsetPtr& bitset, query::RangeQueryPtr range_query) { auto status = Status::OK(); auto range_query_json = range_query->json_obj; @@ -926,8 +954,8 @@ ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map& attr_type, std::vector& distances, - std::vector& search_ids, bool hybrid) { + std::unordered_map& attr_type, + std::vector& distances, std::vector& search_ids, bool hybrid) { try { faiss::ConcurrentBitsetPtr bitset; std::string vector_placeholder; @@ -979,7 +1007,7 @@ ExecutionEngineImpl::HybridSearch(scheduler::SearchJobPtr search_job, Status ExecutionEngineImpl::ExecBinaryQuery(milvus::query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset, - std::unordered_map& attr_type, + std::unordered_map& attr_type, std::string& vector_placeholder) { Status status = Status::OK(); if (general_query->leaf == nullptr) { @@ -1065,7 +1093,7 @@ ExecutionEngineImpl::Search(std::vector& ids, std::vector& dista uint64_t nq = job->nq(); uint64_t topk = job->topk(); - const engine::VectorsData& vectors = job->vectors(); + const VectorsData& vectors = job->vectors(); ids.resize(topk * nq); distances.resize(topk * nq); diff --git a/core/src/db/engine/ExecutionEngineImpl.h b/core/src/db/engine/ExecutionEngineImpl.h index c7b634cf3884..8ed70d39f0a5 100644 --- a/core/src/db/engine/ExecutionEngineImpl.h +++ b/core/src/db/engine/ExecutionEngineImpl.h @@ -78,10 +78,11 @@ class ExecutionEngineImpl : public ExecutionEngine { Status ExecBinaryQuery(query::GeneralQueryPtr general_query, faiss::ConcurrentBitsetPtr& bitset, - std::unordered_map& attr_type, std::string& vector_placeholder) override; + std::unordered_map& attr_type, + std::string& vector_placeholder) override; Status - HybridSearch(scheduler::SearchJobPtr job, std::unordered_map& attr_type, + HybridSearch(scheduler::SearchJobPtr job, std::unordered_map& attr_type, std::vector& distances, std::vector& search_ids, bool hybrid) override; Status @@ -128,19 +129,19 @@ class ExecutionEngineImpl : public ExecutionEngine { Status ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, query::TermQueryPtr term_query, - std::unordered_map& attr_type); + std::unordered_map& attr_type); Status - IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const std::string& field_name, const DataType& data_type, - milvus::json& term_values_json); + IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const std::string& field_name, + const meta::hybrid::DataType& data_type, milvus::json& term_values_json); Status - ProcessRangeQuery(const std::unordered_map& attr_type, faiss::ConcurrentBitsetPtr& bitset, - query::RangeQueryPtr range_query); + ProcessRangeQuery(const std::unordered_map& attr_type, + faiss::ConcurrentBitsetPtr& bitset, query::RangeQueryPtr range_query); Status - IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const DataType& data_type, knowhere::IndexPtr& index_ptr, - milvus::json& range_values_json); + IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const meta::hybrid::DataType& data_type, + knowhere::IndexPtr& index_ptr, milvus::json& range_values_json); void HybridLoad() const; diff --git a/core/src/db/engine/SSExecutionEngine.h b/core/src/db/engine/SSExecutionEngine.h new file mode 100644 index 000000000000..fb737505f482 --- /dev/null +++ b/core/src/db/engine/SSExecutionEngine.h @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "db/Types.h" +#include "db/meta/MetaTypes.h" +#include "query/GeneralQuery.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +struct ExecutionEngineContext { + query::QueryPtr query_ptr_; + QueryResultPtr query_result_; +}; + +class SSExecutionEngine { + public: + virtual Status + Load(ExecutionEngineContext& context) = 0; + + virtual Status + CopyToGpu(uint64_t device_id) = 0; + + virtual Status + Search(ExecutionEngineContext& context) = 0; + + virtual Status + BuildIndex() = 0; +}; + +using SSExecutionEnginePtr = std::shared_ptr; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/engine/SSExecutionEngineImpl.cpp b/core/src/db/engine/SSExecutionEngineImpl.cpp new file mode 100644 index 000000000000..daf232f77007 --- /dev/null +++ b/core/src/db/engine/SSExecutionEngineImpl.cpp @@ -0,0 +1,349 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/engine/SSExecutionEngineImpl.h" + +#include +#include +#include +#include + +#include "config/Config.h" +#include "db/Utils.h" +#include "segment/SSSegmentReader.h" +#include "segment/SSSegmentWriter.h" +#include "utils/CommonUtil.h" +#include "utils/Error.h" +#include "utils/Exception.h" +#include "utils/Log.h" +#include "utils/Status.h" +#include "utils/TimeRecorder.h" + +#include "knowhere/common/Config.h" +#include "knowhere/index/structured_index/StructuredIndexSort.h" +#include "knowhere/index/vector_index/ConfAdapter.h" +#include "knowhere/index/vector_index/ConfAdapterMgr.h" +#include "knowhere/index/vector_index/IndexBinaryIDMAP.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_index/VecIndexFactory.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" + +#ifdef MILVUS_GPU_VERSION + +#include "knowhere/index/vector_index/gpu/GPUIndex.h" +#include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_index/gpu/Quantizer.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" + +#endif + +namespace milvus { +namespace engine { + +namespace { +Status +GetRequiredIndexFields(const query::QueryPtr& query_ptr, std::vector& field_names) { + return Status::OK(); +} + +Status +MappingMetricType(MetricType metric_type, milvus::json& conf) { + switch (metric_type) { + case MetricType::IP: + conf[knowhere::Metric::TYPE] = knowhere::Metric::IP; + break; + case MetricType::L2: + conf[knowhere::Metric::TYPE] = knowhere::Metric::L2; + break; + case MetricType::HAMMING: + conf[knowhere::Metric::TYPE] = knowhere::Metric::HAMMING; + break; + case MetricType::JACCARD: + conf[knowhere::Metric::TYPE] = knowhere::Metric::JACCARD; + break; + case MetricType::TANIMOTO: + conf[knowhere::Metric::TYPE] = knowhere::Metric::TANIMOTO; + break; + case MetricType::SUBSTRUCTURE: + conf[knowhere::Metric::TYPE] = knowhere::Metric::SUBSTRUCTURE; + break; + case MetricType::SUPERSTRUCTURE: + conf[knowhere::Metric::TYPE] = knowhere::Metric::SUPERSTRUCTURE; + break; + default: + return Status(DB_ERROR, "Unsupported metric type"); + } + + return Status::OK(); +} + +} // namespace + +SSExecutionEngineImpl::SSExecutionEngineImpl(const std::string& dir_root, const SegmentVisitorPtr& segment_visitor) + : segment_visitor_(segment_visitor) { + segment_reader_ = std::make_shared(dir_root, segment_visitor); +} + +knowhere::VecIndexPtr +SSExecutionEngineImpl::CreatetVecIndex(EngineType type) { + knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance(); + knowhere::IndexMode mode = knowhere::IndexMode::MODE_CPU; +#ifdef MILVUS_GPU_VERSION + server::Config& config = server::Config::GetInstance(); + bool gpu_resource_enable = true; + config.GetGpuResourceConfigEnable(gpu_resource_enable); + if (gpu_resource_enable) { + mode = knowhere::IndexMode::MODE_GPU; + } +#endif + + knowhere::VecIndexPtr index = nullptr; + switch (type) { + case EngineType::FAISS_IDMAP: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IDMAP, mode); + break; + } + case EngineType::FAISS_IVFFLAT: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, mode); + break; + } + case EngineType::FAISS_PQ: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, mode); + break; + } + case EngineType::FAISS_IVFSQ8: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, mode); + break; + } + case EngineType::FAISS_IVFSQ8NR: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR, mode); + break; + } +#ifdef MILVUS_GPU_VERSION + case EngineType::FAISS_IVFSQ8H: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H, mode); + break; + } +#endif + case EngineType::FAISS_BIN_IDMAP: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, mode); + break; + } + case EngineType::FAISS_BIN_IVFFLAT: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, mode); + break; + } + case EngineType::NSG_MIX: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_NSG, mode); + break; + } + case EngineType::HNSW: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_HNSW, mode); + break; + } + case EngineType::HNSW_SQ8NM: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_HNSW_SQ8NM, mode); + break; + } + case EngineType::ANNOY: { + index = vec_index_factory.CreateVecIndex(knowhere::IndexEnum::INDEX_ANNOY, mode); + break; + } + default: { + LOG_ENGINE_ERROR_ << "Unsupported index type " << (int)type; + return nullptr; + } + } + if (index == nullptr) { + std::string err_msg = "Invalid index type " + std::to_string((int)type) + " mod " + std::to_string((int)mode); + LOG_ENGINE_ERROR_ << err_msg; + } + return index; +} + +Status +SSExecutionEngineImpl::Load(ExecutionEngineContext& context) { + if (context.query_ptr_ != nullptr) { + return LoadForSearch(context.query_ptr_); + } else { + return LoadForIndex(); + } +} + +Status +SSExecutionEngineImpl::LoadForSearch(const query::QueryPtr& query_ptr) { + SegmentPtr segment_ptr; + segment_reader_->GetSegment(segment_ptr); + + std::vector field_names; + GetRequiredIndexFields(query_ptr, field_names); + + return Load(field_names); +} + +Status +SSExecutionEngineImpl::LoadForIndex() { + std::vector field_names; + + auto field_visitors = segment_visitor_->GetFieldVisitors(); + for (auto& pair : field_visitors) { + auto& field_visitor = pair.second; + auto element_visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (element_visitor != nullptr && element_visitor->GetFile() == nullptr) { + field_names.push_back(field_visitor->GetField()->GetName()); + break; + } + } + + return Load(field_names); +} + +Status +SSExecutionEngineImpl::Load(const std::vector& field_names) { + SegmentPtr segment_ptr; + segment_reader_->GetSegment(segment_ptr); + + for (auto& name : field_names) { + FIELD_TYPE field_type = FIELD_TYPE::NONE; + segment_ptr->GetFieldType(name, field_type); + + bool index_exist = false; + if (field_type == FIELD_TYPE::VECTOR || field_type == FIELD_TYPE::VECTOR_FLOAT || + field_type == FIELD_TYPE::VECTOR_BINARY) { + knowhere::VecIndexPtr index_ptr; + segment_reader_->LoadVectorIndex(name, index_ptr); + index_exist = (index_ptr != nullptr); + } else { + knowhere::IndexPtr index_ptr; + segment_reader_->LoadStructuredIndex(name, index_ptr); + index_exist = (index_ptr != nullptr); + } + + // index not yet build, load raw data + if (!index_exist) { + std::vector raw; + segment_reader_->LoadField(name, raw); + } + } + + return Status::OK(); +} + +Status +SSExecutionEngineImpl::CopyToGpu(uint64_t device_id) { +#ifdef MILVUS_GPU_VERSION + SegmentPtr segment_ptr; + segment_reader_->GetSegment(segment_ptr); + + engine::VECTOR_INDEX_MAP new_map; + engine::VECTOR_INDEX_MAP& indice = segment_ptr->GetVectorIndice(); + for (auto& pair : indice) { + auto gpu_index = knowhere::cloner::CopyCpuToGpu(pair.second, device_id, knowhere::Config()); + new_map.insert(std::make_pair(pair.first, gpu_index)); + } + + indice.swap(new_map); +#endif + return Status::OK(); +} + +Status +SSExecutionEngineImpl::Search(ExecutionEngineContext& context) { + return Status::OK(); +} + +Status +SSExecutionEngineImpl::BuildIndex() { + SegmentPtr segment_ptr; + segment_reader_->GetSegment(segment_ptr); + + auto field_visitors = segment_visitor_->GetFieldVisitors(); + for (auto& pair : field_visitors) { + auto& field_visitor = pair.second; + auto element_visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (element_visitor != nullptr && element_visitor->GetFile() == nullptr) { + break; + } + } + + // knowhere::VecIndexPtr field_raw; + // segment_ptr->GetVectorIndex(field_name, field_raw); + // if (field_raw == nullptr) { + // return Status(DB_ERROR, "Field raw not available"); + // } + // + // auto from_index = std::dynamic_pointer_cast(field_raw); + // auto bin_from_index = std::dynamic_pointer_cast(field_raw); + // if (from_index == nullptr && bin_from_index == nullptr) { + // LOG_ENGINE_ERROR_ << "ExecutionEngineImpl: from_index is null, failed to build index"; + // return Status(DB_ERROR, "Field to build index"); + // } + // + // EngineType engine_type = static_cast(index.engine_type_); + // new_index = CreatetVecIndex(engine_type); + // if (!new_index) { + // return Status(DB_ERROR, "Unsupported index type"); + // } + + // milvus::json conf = index.extra_params_; + // conf[knowhere::meta::DIM] = Dimension(); + // conf[knowhere::meta::ROWS] = Count(); + // conf[knowhere::meta::DEVICEID] = gpu_num_; + // MappingMetricType(metric_type_, conf); + // LOG_ENGINE_DEBUG_ << "Index params: " << conf.dump(); + // auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(to_index->index_type()); + // if (!adapter->CheckTrain(conf, to_index->index_mode())) { + // throw Exception(DB_ERROR, "Illegal index params"); + // } + // LOG_ENGINE_DEBUG_ << "Index config: " << conf.dump(); + // + // std::vector uids; + // faiss::ConcurrentBitsetPtr blacklist; + // if (from_index) { + // auto dataset = + // knowhere::GenDatasetWithIds(Count(), Dimension(), from_index->GetRawVectors(), + // from_index->GetRawIds()); + // to_index->BuildAll(dataset, conf); + // uids = from_index->GetUids(); + // blacklist = from_index->GetBlacklist(); + // } else if (bin_from_index) { + // auto dataset = knowhere::GenDatasetWithIds(Count(), Dimension(), bin_from_index->GetRawVectors(), + // bin_from_index->GetRawIds()); + // to_index->BuildAll(dataset, conf); + // uids = bin_from_index->GetUids(); + // blacklist = bin_from_index->GetBlacklist(); + // } + // + //#ifdef MILVUS_GPU_VERSION + // /* for GPU index, need copy back to CPU */ + // if (to_index->index_mode() == knowhere::IndexMode::MODE_GPU) { + // auto device_index = std::dynamic_pointer_cast(to_index); + // to_index = device_index->CopyGpuToCpu(conf); + // } + //#endif + // + // to_index->SetUids(uids); + // LOG_ENGINE_DEBUG_ << "Set " << to_index->GetUids().size() << "uids for " << location; + // if (blacklist != nullptr) { + // to_index->SetBlacklist(blacklist); + // LOG_ENGINE_DEBUG_ << "Set blacklist for index " << location; + // } + // + // LOG_ENGINE_DEBUG_ << "Finish build index: " << location; + // return std::make_shared(to_index, location, engine_type, metric_type_, index_params_); + + return Status::OK(); +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/engine/SSExecutionEngineImpl.h b/core/src/db/engine/SSExecutionEngineImpl.h new file mode 100644 index 000000000000..9237d66f1795 --- /dev/null +++ b/core/src/db/engine/SSExecutionEngineImpl.h @@ -0,0 +1,61 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "SSExecutionEngine.h" +#include "db/SnapshotVisitor.h" +#include "segment/SSSegmentReader.h" + +namespace milvus { +namespace engine { + +class SSExecutionEngineImpl : public SSExecutionEngine { + public: + SSExecutionEngineImpl(const std::string& dir_root, const SegmentVisitorPtr& segment_visitor); + + Status + Load(ExecutionEngineContext& context) override; + + Status + CopyToGpu(uint64_t device_id) override; + + Status + Search(ExecutionEngineContext& context) override; + + Status + BuildIndex() override; + + private: + knowhere::VecIndexPtr + CreatetVecIndex(EngineType type); + + Status + LoadForSearch(const query::QueryPtr& query_ptr); + + Status + LoadForIndex(); + + Status + Load(const std::vector& field_names); + + private: + SegmentVisitorPtr segment_visitor_; + segment::SSSegmentReaderPtr segment_reader_; +}; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/MemManagerFactory.cpp b/core/src/db/insert/MemManagerFactory.cpp index b75f3605a5ad..fc58fc8711a6 100644 --- a/core/src/db/insert/MemManagerFactory.cpp +++ b/core/src/db/insert/MemManagerFactory.cpp @@ -11,6 +11,7 @@ #include "db/insert/MemManagerFactory.h" #include "MemManagerImpl.h" +#include "SSMemManagerImpl.h" #include "utils/Exception.h" #include "utils/Log.h" @@ -30,5 +31,10 @@ MemManagerFactory::Build(const std::shared_ptr& meta, const DBOption return std::make_shared(meta, options); } +SSMemManagerPtr +MemManagerFactory::SSBuild(const DBOptions& options) { + return std::make_shared(options); +} + } // namespace engine } // namespace milvus diff --git a/core/src/db/insert/MemManagerFactory.h b/core/src/db/insert/MemManagerFactory.h index f77dc6e575af..a3dfe70e011c 100644 --- a/core/src/db/insert/MemManagerFactory.h +++ b/core/src/db/insert/MemManagerFactory.h @@ -12,6 +12,7 @@ #pragma once #include "MemManager.h" +#include "SSMemManager.h" #include "db/meta/Meta.h" #include @@ -23,6 +24,9 @@ class MemManagerFactory { public: static MemManagerPtr Build(const std::shared_ptr& meta, const DBOptions& options); + + static SSMemManagerPtr + SSBuild(const DBOptions& options); }; } // namespace engine diff --git a/core/src/db/insert/SSMemCollection.cpp b/core/src/db/insert/SSMemCollection.cpp new file mode 100644 index 000000000000..b4a410c0b11c --- /dev/null +++ b/core/src/db/insert/SSMemCollection.cpp @@ -0,0 +1,379 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include + +#include "cache/CpuCacheMgr.h" +#include "db/Utils.h" +#include "db/insert/SSMemCollection.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "segment/SegmentReader.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace engine { + +SSMemCollection::SSMemCollection(int64_t collection_id, int64_t partition_id, const DBOptions& options) + : collection_id_(collection_id), partition_id_(partition_id), options_(options) { + SetIdentity("SSMemCollection"); + AddCacheInsertDataListener(); +} + +Status +SSMemCollection::Add(const milvus::engine::SSVectorSourcePtr& source) { + while (!source->AllAdded()) { + SSMemSegmentPtr current_mem_segment; + if (!mem_segment_list_.empty()) { + current_mem_segment = mem_segment_list_.back(); + } + + Status status; + if (mem_segment_list_.empty() || current_mem_segment->IsFull()) { + SSMemSegmentPtr new_mem_segment = std::make_shared(collection_id_, partition_id_, options_); + status = new_mem_segment->Add(source); + if (status.ok()) { + mem_segment_list_.emplace_back(new_mem_segment); + } else { + return status; + } + } else { + status = current_mem_segment->Add(source); + } + + if (!status.ok()) { + std::string err_msg = "Insert failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ", "insert", 0) << err_msg; + return Status(DB_ERROR, err_msg); + } + } + return Status::OK(); +} + +Status +SSMemCollection::Delete(segment::doc_id_t doc_id) { + // Locate which collection file the doc id lands in + for (auto& mem_segment : mem_segment_list_) { + mem_segment->Delete(doc_id); + } + // Add the id to delete list so it can be applied to other segments on disk during the next flush + doc_ids_to_delete_.insert(doc_id); + + return Status::OK(); +} + +Status +SSMemCollection::Delete(const std::vector& doc_ids) { + // Locate which collection file the doc id lands in + for (auto& mem_segment : mem_segment_list_) { + mem_segment->Delete(doc_ids); + } + // Add the id to delete list so it can be applied to other segments on disk during the next flush + for (auto& id : doc_ids) { + doc_ids_to_delete_.insert(id); + } + + return Status::OK(); +} + +void +SSMemCollection::GetCurrentMemSegment(SSMemSegmentPtr& mem_segment) { + mem_segment = mem_segment_list_.back(); +} + +size_t +SSMemCollection::GetTableFileCount() { + return mem_segment_list_.size(); +} + +Status +SSMemCollection::Serialize(uint64_t wal_lsn) { + TimeRecorder recorder("SSMemCollection::Serialize collection " + collection_id_); + + if (!doc_ids_to_delete_.empty()) { + auto status = ApplyDeletes(); + if (!status.ok()) { + return Status(DB_ERROR, status.message()); + } + } + + for (auto mem_segment = mem_segment_list_.begin(); mem_segment != mem_segment_list_.end();) { + auto status = (*mem_segment)->Serialize(wal_lsn); + if (!status.ok()) { + return status; + } + + LOG_ENGINE_DEBUG_ << "Flushed segment " << (*mem_segment)->GetSegmentId(); + + { + std::lock_guard lock(mutex_); + mem_segment = mem_segment_list_.erase(mem_segment); + } + } + + recorder.RecordSection("Finished flushing"); + + return Status::OK(); +} + +bool +SSMemCollection::Empty() { + return mem_segment_list_.empty() && doc_ids_to_delete_.empty(); +} + +int64_t +SSMemCollection::GetCollectionId() const { + return collection_id_; +} + +int64_t +SSMemCollection::GetPartitionId() const { + return partition_id_; +} + +size_t +SSMemCollection::GetCurrentMem() { + std::lock_guard lock(mutex_); + size_t total_mem = 0; + for (auto& mem_table_file : mem_segment_list_) { + total_mem += mem_table_file->GetCurrentMem(); + } + return total_mem; +} + +Status +SSMemCollection::ApplyDeletes() { + // Applying deletes to other segments on disk and their corresponding cache: + // For each segment in collection: + // Load its bloom filter + // For each id in delete list: + // If present, add the uid to segment's uid list + // For each segment + // Get its cache if exists + // Load its uids file. + // Scan the uids, if any uid in segment's uid list exists: + // add its offset to deletedDoc + // remove the id from bloom filter + // set black list in cache + // Serialize segment's deletedDoc TODO(zhiru): append directly to previous file for now, may have duplicates + // Serialize bloom filter + + // LOG_ENGINE_DEBUG_ << "Applying " << doc_ids_to_delete_.size() << " deletes in collection: " << collection_id_; + // + // TimeRecorder recorder("SSMemCollection::ApplyDeletes for collection " + collection_id_); + // + // std::vector file_types{meta::SegmentSchema::FILE_TYPE::RAW, meta::SegmentSchema::FILE_TYPE::TO_INDEX, + // meta::SegmentSchema::FILE_TYPE::BACKUP}; + // meta::FilesHolder files_holder; + // auto status = meta_->FilesByType(collection_id_, file_types, files_holder); + // if (!status.ok()) { + // std::string err_msg = "Failed to apply deletes: " + status.ToString(); + // LOG_ENGINE_ERROR_ << err_msg; + // return Status(DB_ERROR, err_msg); + // } + // + // // attention: here is a copy, not reference, since files_holder.UnmarkFile will change the array internal + // milvus::engine::meta::SegmentsSchema files = files_holder.HoldFiles(); + // + // // which file need to be apply delete + // std::unordered_map> ids_to_check_map; // file id mapping to delete ids + // for (auto& file : files) { + // std::string segment_dir; + // utils::GetParentPath(file.location_, segment_dir); + // + // segment::SegmentReader segment_reader(segment_dir); + // segment::IdBloomFilterPtr id_bloom_filter_ptr; + // segment_reader.LoadBloomFilter(id_bloom_filter_ptr); + // + // for (auto& id : doc_ids_to_delete_) { + // if (id_bloom_filter_ptr->Check(id)) { + // ids_to_check_map[file.id_].emplace_back(id); + // } + // } + // } + // + // // release unused files + // for (auto& file : files) { + // if (ids_to_check_map.find(file.id_) == ids_to_check_map.end()) { + // files_holder.UnmarkFile(file); + // } + // } + // + // // attention: here is a copy, not reference, since files_holder.UnmarkFile will change the array internal + // milvus::engine::meta::SegmentsSchema hold_files = files_holder.HoldFiles(); + // recorder.RecordSection("Found " + std::to_string(hold_files.size()) + " segment to apply deletes"); + // + // meta::SegmentsSchema files_to_update; + // for (auto& file : hold_files) { + // LOG_ENGINE_DEBUG_ << "Applying deletes in segment: " << file.segment_id_; + // + // TimeRecorder rec("handle segment " + file.segment_id_); + // + // std::string segment_dir; + // utils::GetParentPath(file.location_, segment_dir); + // segment::SegmentReader segment_reader(segment_dir); + // + // auto& segment_id = file.segment_id_; + // meta::FilesHolder segment_holder; + // status = meta_->GetCollectionFilesBySegmentId(segment_id, segment_holder); + // if (!status.ok()) { + // break; + // } + // + // // Get all index that contains blacklist in cache + // std::vector indexes; + // std::vector blacklists; + // milvus::engine::meta::SegmentsSchema& segment_files = segment_holder.HoldFiles(); + // for (auto& segment_file : segment_files) { + // auto data_obj_ptr = cache::CpuCacheMgr::GetInstance()->GetIndex(segment_file.location_); + // auto index = std::static_pointer_cast(data_obj_ptr); + // if (index != nullptr) { + // faiss::ConcurrentBitsetPtr blacklist = index->GetBlacklist(); + // if (blacklist != nullptr) { + // indexes.emplace_back(index); + // blacklists.emplace_back(blacklist); + // } + // } + // } + // + // std::vector uids; + // status = segment_reader.LoadUids(uids); + // if (!status.ok()) { + // break; + // } + // segment::IdBloomFilterPtr id_bloom_filter_ptr; + // status = segment_reader.LoadBloomFilter(id_bloom_filter_ptr); + // if (!status.ok()) { + // break; + // } + // + // auto& ids_to_check = ids_to_check_map[file.id_]; + // + // segment::DeletedDocsPtr deleted_docs = std::make_shared(); + // + // rec.RecordSection("Loading uids and deleted docs"); + // + // std::sort(ids_to_check.begin(), ids_to_check.end()); + // + // rec.RecordSection("Sorting " + std::to_string(ids_to_check.size()) + " ids"); + // + // size_t delete_count = 0; + // auto find_diff = std::chrono::duration::zero(); + // auto set_diff = std::chrono::duration::zero(); + // + // for (size_t i = 0; i < uids.size(); ++i) { + // auto find_start = std::chrono::high_resolution_clock::now(); + // + // auto found = std::binary_search(ids_to_check.begin(), ids_to_check.end(), uids[i]); + // + // auto find_end = std::chrono::high_resolution_clock::now(); + // find_diff += (find_end - find_start); + // + // if (found) { + // auto set_start = std::chrono::high_resolution_clock::now(); + // + // delete_count++; + // + // deleted_docs->AddDeletedDoc(i); + // + // if (id_bloom_filter_ptr->Check(uids[i])) { + // id_bloom_filter_ptr->Remove(uids[i]); + // } + // + // for (auto& blacklist : blacklists) { + // if (!blacklist->test(i)) { + // blacklist->set(i); + // } + // } + // + // auto set_end = std::chrono::high_resolution_clock::now(); + // set_diff += (set_end - set_start); + // } + // } + // + // LOG_ENGINE_DEBUG_ << "Finding " << ids_to_check.size() << " uids in " << uids.size() << " uids took " + // << find_diff.count() << " s in total"; + // LOG_ENGINE_DEBUG_ << "Setting deleted docs and bloom filter took " << set_diff.count() << " s in total"; + // + // rec.RecordSection("Find uids and set deleted docs and bloom filter"); + // + // for (size_t i = 0; i < indexes.size(); ++i) { + // indexes[i]->SetBlacklist(blacklists[i]); + // } + // + // segment::Segment tmp_segment; + // segment::SegmentWriter segment_writer(segment_dir); + // status = segment_writer.WriteDeletedDocs(deleted_docs); + // if (!status.ok()) { + // break; + // } + // + // rec.RecordSection("Appended " + std::to_string(deleted_docs->GetSize()) + " offsets to deleted docs"); + // + // status = segment_writer.WriteBloomFilter(id_bloom_filter_ptr); + // if (!status.ok()) { + // break; + // } + // + // rec.RecordSection("Updated bloom filter"); + // + // // Update collection file row count + // for (auto& segment_file : segment_files) { + // if (segment_file.file_type_ == meta::SegmentSchema::RAW || + // segment_file.file_type_ == meta::SegmentSchema::TO_INDEX || + // segment_file.file_type_ == meta::SegmentSchema::INDEX || + // segment_file.file_type_ == meta::SegmentSchema::BACKUP) { + // segment_file.row_count_ -= delete_count; + // files_to_update.emplace_back(segment_file); + // } + // } + // rec.RecordSection("Update collection file row count in vector"); + // } + // + // recorder.RecordSection("Finished " + std::to_string(ids_to_check_map.size()) + " segment to apply deletes"); + // + // status = meta_->UpdateCollectionFilesRowCount(files_to_update); + // + // if (!status.ok()) { + // std::string err_msg = "Failed to apply deletes: " + status.ToString(); + // LOG_ENGINE_ERROR_ << err_msg; + // return Status(DB_ERROR, err_msg); + // } + // + // doc_ids_to_delete_.clear(); + // + // recorder.RecordSection("Update deletes to meta"); + // recorder.ElapseFromBegin("Finished deletes"); + + return Status::OK(); +} + +uint64_t +SSMemCollection::GetLSN() { + return lsn_; +} + +void +SSMemCollection::SetLSN(uint64_t lsn) { + lsn_ = lsn; +} + +void +SSMemCollection::OnCacheInsertDataChanged(bool value) { + options_.insert_cache_immediately_ = value; +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/SSMemCollection.h b/core/src/db/insert/SSMemCollection.h new file mode 100644 index 000000000000..20594ba43f1a --- /dev/null +++ b/core/src/db/insert/SSMemCollection.h @@ -0,0 +1,97 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "config/handler/CacheConfigHandler.h" +#include "db/insert/SSMemSegment.h" +#include "db/insert/SSVectorSource.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +class SSMemCollection : public server::CacheConfigHandler { + public: + using SSMemCollectionFileList = std::vector; + + SSMemCollection(int64_t collection_id, int64_t partition_id, const DBOptions& options); + + Status + Add(const SSVectorSourcePtr& source); + + Status + Delete(segment::doc_id_t doc_id); + + Status + Delete(const std::vector& doc_ids); + + void + GetCurrentMemSegment(SSMemSegmentPtr& mem_segment); + + size_t + GetTableFileCount(); + + Status + Serialize(uint64_t wal_lsn); + + bool + Empty(); + + int64_t + GetCollectionId() const; + + int64_t + GetPartitionId() const; + + size_t + GetCurrentMem(); + + uint64_t + GetLSN(); + + void + SetLSN(uint64_t lsn); + + protected: + void + OnCacheInsertDataChanged(bool value) override; + + private: + Status + ApplyDeletes(); + + private: + int64_t collection_id_; + int64_t partition_id_; + + SSMemCollectionFileList mem_segment_list_; + + DBOptions options_; + + std::mutex mutex_; + + std::set doc_ids_to_delete_; + + std::atomic lsn_; +}; // SSMemCollection + +using SSMemCollectionPtr = std::shared_ptr; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/SSMemManager.h b/core/src/db/insert/SSMemManager.h new file mode 100644 index 000000000000..836347cc3f47 --- /dev/null +++ b/core/src/db/insert/SSMemManager.h @@ -0,0 +1,68 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "db/Types.h" +#include "segment/Segment.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +extern const char* VECTOR_FIELD; + +class SSMemManager { + public: + virtual Status + InsertEntities(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk, uint64_t lsn) = 0; + + virtual Status + DeleteEntity(int64_t collection_id, IDNumber vector_id, uint64_t lsn) = 0; + + virtual Status + DeleteEntities(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) = 0; + + virtual Status + Flush(int64_t collection_id) = 0; + + virtual Status + Flush(std::set& collection_ids) = 0; + + // virtual Status + // Serialize(std::set& table_ids) = 0; + + virtual Status + EraseMemVector(int64_t collection_id) = 0; + + virtual Status + EraseMemVector(int64_t collection_id, int64_t partition_id) = 0; + + virtual size_t + GetCurrentMutableMem() = 0; + + virtual size_t + GetCurrentImmutableMem() = 0; + + virtual size_t + GetCurrentMem() = 0; +}; // MemManagerAbstract + +using SSMemManagerPtr = std::shared_ptr; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/SSMemManagerImpl.cpp b/core/src/db/insert/SSMemManagerImpl.cpp new file mode 100644 index 000000000000..ea3869604dec --- /dev/null +++ b/core/src/db/insert/SSMemManagerImpl.cpp @@ -0,0 +1,395 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/insert/SSMemManagerImpl.h" + +#include +#include + +#include "SSVectorSource.h" +#include "db/Constants.h" +#include "db/snapshot/Snapshots.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "utils/Log.h" + +namespace milvus { +namespace engine { + +const char* VECTOR_FIELD = "vector"; // hard code + +SSMemCollectionPtr +SSMemManagerImpl::GetMemByTable(int64_t collection_id, int64_t partition_id) { + auto mem_collection = mem_map_.find(collection_id); + if (mem_collection != mem_map_.end()) { + auto mem_partition = mem_collection->second.find(partition_id); + if (mem_partition != mem_collection->second.end()) { + return mem_partition->second; + } + } + + auto mem = std::make_shared(collection_id, partition_id, options_); + mem_map_[collection_id][partition_id] = mem; + return mem; +} + +std::vector +SSMemManagerImpl::GetMemByTable(int64_t collection_id) { + std::vector result; + auto mem_collection = mem_map_.find(collection_id); + if (mem_collection != mem_map_.end()) { + for (auto& pair : mem_collection->second) { + result.push_back(pair.second); + } + } + return result; +} + +Status +SSMemManagerImpl::InsertEntities(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk, uint64_t lsn) { + auto status = ValidateChunk(collection_id, partition_id, chunk); + if (!status.ok()) { + return status; + } + + SSVectorSourcePtr source = std::make_shared(chunk); + std::unique_lock lock(mutex_); + return InsertEntitiesNoLock(collection_id, partition_id, source, lsn); +} + +Status +SSMemManagerImpl::ValidateChunk(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk) { + if (chunk == nullptr) { + return Status(DB_ERROR, "Null chunk pointer"); + } + + snapshot::ScopedSnapshotT ss; + auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id); + if (!status.ok()) { + std::string err_msg = "Could not get snapshot: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + + std::vector field_names = ss->GetFieldNames(); + for (auto& name : field_names) { + auto iter = chunk->fixed_fields_.find(name); + if (iter == chunk->fixed_fields_.end()) { + std::string err_msg = "Missed chunk field: " + name; + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + + size_t data_size = iter->second.size(); + + snapshot::FieldPtr field = ss->GetField(name); + meta::hybrid::DataType ftype = static_cast(field->GetFtype()); + std::string err_msg = "Illegal data size for chunk field: "; + switch (ftype) { + case meta::hybrid::DataType::BOOL: + if (data_size != chunk->count_ * sizeof(bool)) { + return Status(DB_ERROR, err_msg + name); + } + break; + case meta::hybrid::DataType::DOUBLE: + if (data_size != chunk->count_ * sizeof(double)) { + return Status(DB_ERROR, err_msg + name); + } + break; + case meta::hybrid::DataType::FLOAT: + if (data_size != chunk->count_ * sizeof(float)) { + return Status(DB_ERROR, err_msg + name); + } + break; + case meta::hybrid::DataType::INT8: + if (data_size != chunk->count_ * sizeof(uint8_t)) { + return Status(DB_ERROR, err_msg + name); + } + break; + case meta::hybrid::DataType::INT16: + if (data_size != chunk->count_ * sizeof(uint16_t)) { + return Status(DB_ERROR, err_msg + name); + } + break; + case meta::hybrid::DataType::INT32: + if (data_size != chunk->count_ * sizeof(uint32_t)) { + return Status(DB_ERROR, err_msg + name); + } + break; + case meta::hybrid::DataType::UID: + case meta::hybrid::DataType::INT64: + if (data_size != chunk->count_ * sizeof(uint64_t)) { + return Status(DB_ERROR, err_msg + name); + } + break; + case meta::hybrid::DataType::VECTOR_FLOAT: + case meta::hybrid::DataType::VECTOR_BINARY: { + json params = field->GetParams(); + if (params.find(knowhere::meta::DIM) == params.end()) { + std::string msg = "Vector field params must contain: dimension"; + LOG_SERVER_ERROR_ << msg; + return Status(DB_ERROR, msg); + } + + int64_t dimension = params[knowhere::meta::DIM]; + int64_t row_size = + (ftype == meta::hybrid::DataType::VECTOR_BINARY) ? dimension / 8 : dimension * sizeof(float); + if (data_size != chunk->count_ * row_size) { + return Status(DB_ERROR, err_msg + name); + } + + break; + } + } + } + + return Status::OK(); +} + +Status +SSMemManagerImpl::InsertEntitiesNoLock(int64_t collection_id, int64_t partition_id, + const milvus::engine::SSVectorSourcePtr& source, uint64_t lsn) { + SSMemCollectionPtr mem = GetMemByTable(collection_id, partition_id); + mem->SetLSN(lsn); + + auto status = mem->Add(source); + return status; +} + +Status +SSMemManagerImpl::DeleteEntity(int64_t collection_id, IDNumber vector_id, uint64_t lsn) { + std::unique_lock lock(mutex_); + std::vector mems = GetMemByTable(collection_id); + + for (auto& mem : mems) { + mem->SetLSN(lsn); + auto status = mem->Delete(vector_id); + if (status.ok()) { + return status; + } + } + + return Status::OK(); +} + +Status +SSMemManagerImpl::DeleteEntities(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) { + std::unique_lock lock(mutex_); + std::vector mems = GetMemByTable(collection_id); + + for (auto& mem : mems) { + mem->SetLSN(lsn); + + IDNumbers ids; + ids.resize(length); + memcpy(ids.data(), vector_ids, length * sizeof(IDNumber)); + + auto status = mem->Delete(ids); + if (!status.ok()) { + return status; + } + } + + return Status::OK(); +} + +Status +SSMemManagerImpl::Flush(int64_t collection_id) { + ToImmutable(collection_id); + // TODO: There is actually only one memTable in the immutable list + MemList temp_immutable_list; + { + std::unique_lock lock(mutex_); + immu_mem_list_.swap(temp_immutable_list); + } + + std::unique_lock lock(serialization_mtx_); + auto max_lsn = GetMaxLSN(temp_immutable_list); + for (auto& mem : temp_immutable_list) { + LOG_ENGINE_DEBUG_ << "Flushing collection: " << mem->GetCollectionId(); + auto status = mem->Serialize(max_lsn); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Flush collection " << mem->GetCollectionId() << " failed"; + return status; + } + LOG_ENGINE_DEBUG_ << "Flushed collection: " << mem->GetCollectionId(); + } + + return Status::OK(); +} + +Status +SSMemManagerImpl::Flush(std::set& collection_ids) { + ToImmutable(); + + MemList temp_immutable_list; + { + std::unique_lock lock(mutex_); + immu_mem_list_.swap(temp_immutable_list); + } + + std::unique_lock lock(serialization_mtx_); + collection_ids.clear(); + auto max_lsn = GetMaxLSN(temp_immutable_list); + for (auto& mem : temp_immutable_list) { + LOG_ENGINE_DEBUG_ << "Flushing collection: " << mem->GetCollectionId(); + auto status = mem->Serialize(max_lsn); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Flush collection " << mem->GetCollectionId() << " failed"; + return status; + } + collection_ids.insert(mem->GetCollectionId()); + LOG_ENGINE_DEBUG_ << "Flushed collection: " << mem->GetCollectionId(); + } + + // TODO: global lsn? + // meta_->SetGlobalLastLSN(max_lsn); + + return Status::OK(); +} + +Status +SSMemManagerImpl::ToImmutable(int64_t collection_id) { + std::unique_lock lock(mutex_); + + auto mem_collection = mem_map_.find(collection_id); + if (mem_collection != mem_map_.end()) { + MemPartitionMap temp_map; + for (auto& mem : mem_collection->second) { + if (mem.second->Empty()) { + temp_map.insert(mem); + } else { + immu_mem_list_.push_back(mem.second); + } + } + + mem_collection->second.swap(temp_map); + if (temp_map.empty()) { + mem_map_.erase(mem_collection); + } + } + + return Status::OK(); +} + +Status +SSMemManagerImpl::ToImmutable() { + std::unique_lock lock(mutex_); + + for (auto& mem_collection : mem_map_) { + MemPartitionMap temp_map; + for (auto& mem : mem_collection.second) { + if (mem.second->Empty()) { + temp_map.insert(mem); + } else { + immu_mem_list_.push_back(mem.second); + } + } + + mem_collection.second.swap(temp_map); + } + + return Status::OK(); +} + +Status +SSMemManagerImpl::EraseMemVector(int64_t collection_id) { + { // erase MemVector from rapid-insert cache + std::unique_lock lock(mutex_); + mem_map_.erase(collection_id); + } + + { // erase MemVector from serialize cache + std::unique_lock lock(serialization_mtx_); + MemList temp_list; + for (auto& mem : immu_mem_list_) { + if (mem->GetCollectionId() != collection_id) { + temp_list.push_back(mem); + } + } + immu_mem_list_.swap(temp_list); + } + + return Status::OK(); +} + +Status +SSMemManagerImpl::EraseMemVector(int64_t collection_id, int64_t partition_id) { + { // erase MemVector from rapid-insert cache + std::unique_lock lock(mutex_); + auto mem_collection = mem_map_.find(collection_id); + if (mem_collection != mem_map_.end()) { + mem_collection->second.erase(partition_id); + if (mem_collection->second.empty()) { + mem_map_.erase(collection_id); + } + } + } + + { // erase MemVector from serialize cache + std::unique_lock lock(serialization_mtx_); + MemList temp_list; + for (auto& mem : immu_mem_list_) { + if (mem->GetCollectionId() != collection_id && mem->GetPartitionId() != partition_id) { + temp_list.push_back(mem); + } + } + immu_mem_list_.swap(temp_list); + } + + return Status::OK(); +} + +size_t +SSMemManagerImpl::GetCurrentMutableMem() { + size_t total_mem = 0; + std::unique_lock lock(mutex_); + for (auto& mem_collection : mem_map_) { + for (auto& mem : mem_collection.second) { + total_mem += mem.second->GetCurrentMem(); + } + } + return total_mem; +} + +size_t +SSMemManagerImpl::GetCurrentImmutableMem() { + size_t total_mem = 0; + std::unique_lock lock(serialization_mtx_); + for (auto& mem_table : immu_mem_list_) { + total_mem += mem_table->GetCurrentMem(); + } + return total_mem; +} + +size_t +SSMemManagerImpl::GetCurrentMem() { + return GetCurrentMutableMem() + GetCurrentImmutableMem(); +} + +uint64_t +SSMemManagerImpl::GetMaxLSN(const MemList& tables) { + uint64_t max_lsn = 0; + for (auto& collection : tables) { + auto cur_lsn = collection->GetLSN(); + if (collection->GetLSN() > max_lsn) { + max_lsn = cur_lsn; + } + } + return max_lsn; +} + +void +SSMemManagerImpl::OnInsertBufferSizeChanged(int64_t value) { + options_.insert_buffer_size_ = value * GB; +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/SSMemManagerImpl.h b/core/src/db/insert/SSMemManagerImpl.h new file mode 100644 index 000000000000..c2a67b306bea --- /dev/null +++ b/core/src/db/insert/SSMemManagerImpl.h @@ -0,0 +1,109 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "config/Config.h" +#include "config/handler/CacheConfigHandler.h" +#include "db/insert/SSMemCollection.h" +#include "db/insert/SSMemManager.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +class SSMemManagerImpl : public SSMemManager, public server::CacheConfigHandler { + public: + using Ptr = std::shared_ptr; + using MemPartitionMap = std::map; + using MemCollectionMap = std::map; + using MemList = std::vector; + + explicit SSMemManagerImpl(const DBOptions& options) : options_(options) { + SetIdentity("SSMemManagerImpl"); + AddInsertBufferSizeListener(); + } + + Status + InsertEntities(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk, uint64_t lsn) override; + + Status + DeleteEntity(int64_t collection_id, IDNumber vector_id, uint64_t lsn) override; + + Status + DeleteEntities(int64_t collection_id, int64_t length, const IDNumber* vector_ids, uint64_t lsn) override; + + Status + Flush(int64_t collection_id) override; + + Status + Flush(std::set& collection_ids) override; + + Status + EraseMemVector(int64_t collection_id) override; + + Status + EraseMemVector(int64_t collection_id, int64_t partition_id) override; + + size_t + GetCurrentMutableMem() override; + + size_t + GetCurrentImmutableMem() override; + + size_t + GetCurrentMem() override; + + protected: + void + OnInsertBufferSizeChanged(int64_t value) override; + + private: + SSMemCollectionPtr + GetMemByTable(int64_t collection_id, int64_t partition_id); + + std::vector + GetMemByTable(int64_t collection_id); + + Status + ValidateChunk(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk); + + Status + InsertEntitiesNoLock(int64_t collection_id, int64_t partition_id, const SSVectorSourcePtr& source, uint64_t lsn); + + Status + ToImmutable(); + + Status + ToImmutable(int64_t collection_id); + + uint64_t + GetMaxLSN(const MemList& tables); + + MemCollectionMap mem_map_; + MemList immu_mem_list_; + + DBOptions options_; + std::mutex mutex_; + std::mutex serialization_mtx_; +}; // NewMemManager + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/SSMemSegment.cpp b/core/src/db/insert/SSMemSegment.cpp new file mode 100644 index 000000000000..3f4cc967c02a --- /dev/null +++ b/core/src/db/insert/SSMemSegment.cpp @@ -0,0 +1,307 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/insert/SSMemSegment.h" + +#include +#include +#include +#include +#include + +#include "db/Constants.h" +#include "db/Utils.h" +#include "db/engine/EngineFactory.h" +#include "db/meta/MetaTypes.h" +#include "db/snapshot/Operations.h" +#include "db/snapshot/Snapshots.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "metrics/Metrics.h" +#include "segment/SegmentReader.h" +#include "utils/Log.h" + +namespace milvus { +namespace engine { + +SSMemSegment::SSMemSegment(int64_t collection_id, int64_t partition_id, const DBOptions& options) + : collection_id_(collection_id), partition_id_(partition_id), options_(options) { + current_mem_ = 0; + CreateSegment(); + + SetIdentity("SSMemSegment"); + AddCacheInsertDataListener(); +} + +Status +SSMemSegment::CreateSegment() { + snapshot::ScopedSnapshotT ss; + auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id_); + if (!status.ok()) { + std::string err_msg = "SSMemSegment::CreateSegment failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + + // create segment + snapshot::OperationContext context; + context.prev_partition = ss->GetResource(partition_id_); + operation_ = std::make_shared(context, ss); + status = operation_->CommitNewSegment(segment_); + if (!status.ok()) { + std::string err_msg = "SSMemSegment::CreateSegment failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + + // create segment raw files (placeholder) + auto names = ss->GetFieldNames(); + for (auto& name : names) { + snapshot::SegmentFileContext sf_context; + sf_context.collection_id = collection_id_; + sf_context.partition_id = partition_id_; + sf_context.segment_id = segment_->GetID(); + sf_context.field_name = name; + sf_context.field_element_name = engine::DEFAULT_RAW_DATA_NAME; + + snapshot::SegmentFilePtr seg_file; + status = operation_->CommitNewSegmentFile(sf_context, seg_file); + if (!status.ok()) { + std::string err_msg = "SSMemSegment::CreateSegment failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + } + + // create deleted_doc and bloom_filter files (placeholder) + { + snapshot::SegmentFileContext sf_context; + sf_context.collection_id = collection_id_; + sf_context.partition_id = partition_id_; + sf_context.segment_id = segment_->GetID(); + sf_context.field_name = engine::DEFAULT_UID_NAME; + sf_context.field_element_name = engine::DEFAULT_DELETED_DOCS_NAME; + + snapshot::SegmentFilePtr delete_doc_file, bloom_filter_file; + status = operation_->CommitNewSegmentFile(sf_context, delete_doc_file); + if (!status.ok()) { + std::string err_msg = "SSMemSegment::CreateSegment failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + + sf_context.field_element_name = engine::DEFAULT_BLOOM_FILTER_NAME; + status = operation_->CommitNewSegmentFile(sf_context, bloom_filter_file); + if (!status.ok()) { + std::string err_msg = "SSMemSegment::CreateSegment failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + } + + auto ctx = operation_->GetContext(); + auto visitor = SegmentVisitor::Build(ss, ctx.new_segment, ctx.new_segment_files); + + // create segment writer + segment_writer_ptr_ = std::make_shared(options_.meta_.path_, visitor); + + return Status::OK(); +} + +Status +SSMemSegment::GetSingleEntitySize(int64_t& single_size) { + snapshot::ScopedSnapshotT ss; + auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id_); + if (!status.ok()) { + std::string err_msg = "SSMemSegment::SingleEntitySize failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + + single_size = 0; + std::vector field_names = ss->GetFieldNames(); + for (auto& name : field_names) { + snapshot::FieldPtr field = ss->GetField(name); + meta::hybrid::DataType ftype = static_cast(field->GetFtype()); + switch (ftype) { + case meta::hybrid::DataType::BOOL: + single_size += sizeof(bool); + break; + case meta::hybrid::DataType::DOUBLE: + single_size += sizeof(double); + break; + case meta::hybrid::DataType::FLOAT: + single_size += sizeof(float); + break; + case meta::hybrid::DataType::INT8: + single_size += sizeof(uint8_t); + break; + case meta::hybrid::DataType::INT16: + single_size += sizeof(uint16_t); + break; + case meta::hybrid::DataType::INT32: + single_size += sizeof(uint32_t); + break; + case meta::hybrid::DataType::UID: + case meta::hybrid::DataType::INT64: + single_size += sizeof(uint64_t); + break; + case meta::hybrid::DataType::VECTOR: + case meta::hybrid::DataType::VECTOR_FLOAT: + case meta::hybrid::DataType::VECTOR_BINARY: { + json params = field->GetParams(); + if (params.find(knowhere::meta::DIM) == params.end()) { + std::string msg = "Vector field params must contain: dimension"; + LOG_SERVER_ERROR_ << msg; + return Status(DB_ERROR, msg); + } + + int64_t dimension = params[knowhere::meta::DIM]; + if (ftype == meta::hybrid::DataType::VECTOR_BINARY) { + single_size += (dimension / 8); + } else { + single_size += (dimension * sizeof(float)); + } + + break; + } + } + } + + return Status::OK(); +} + +Status +SSMemSegment::Add(const SSVectorSourcePtr& source) { + int64_t single_entity_mem_size = 0; + auto status = GetSingleEntitySize(single_entity_mem_size); + if (!status.ok()) { + return status; + } + + size_t mem_left = GetMemLeft(); + if (mem_left >= single_entity_mem_size) { + int64_t num_entities_to_add = std::ceil(mem_left / single_entity_mem_size); + int64_t num_entities_added; + + auto status = source->Add(segment_writer_ptr_, num_entities_to_add, num_entities_added); + + if (status.ok()) { + current_mem_ += (num_entities_added * single_entity_mem_size); + } + return status; + } + return Status::OK(); +} + +Status +SSMemSegment::Delete(segment::doc_id_t doc_id) { + engine::SegmentPtr segment_ptr; + segment_writer_ptr_->GetSegment(segment_ptr); + + // Check wither the doc_id is present, if yes, delete it's corresponding buffer + engine::FIXED_FIELD_DATA raw_data; + auto status = segment_ptr->GetFixedFieldData(engine::DEFAULT_UID_NAME, raw_data); + if (!status.ok()) { + return Status::OK(); + } + + int64_t* uids = reinterpret_cast(raw_data.data()); + int64_t row_count = segment_ptr->GetRowCount(); + for (int64_t i = 0; i < row_count; i++) { + if (doc_id == uids[i]) { + segment_ptr->DeleteEntity(i); + } + } + + return Status::OK(); +} + +Status +SSMemSegment::Delete(const std::vector& doc_ids) { + engine::SegmentPtr segment_ptr; + segment_writer_ptr_->GetSegment(segment_ptr); + + // Check wither the doc_id is present, if yes, delete it's corresponding buffer + std::vector temp; + temp.resize(doc_ids.size()); + memcpy(temp.data(), doc_ids.data(), doc_ids.size() * sizeof(segment::doc_id_t)); + + std::sort(temp.begin(), temp.end()); + + engine::FIXED_FIELD_DATA raw_data; + auto status = segment_ptr->GetFixedFieldData(engine::DEFAULT_UID_NAME, raw_data); + if (!status.ok()) { + return Status::OK(); + } + + int64_t* uids = reinterpret_cast(raw_data.data()); + int64_t row_count = segment_ptr->GetRowCount(); + size_t deleted = 0; + for (int64_t i = 0; i < row_count; ++i) { + if (std::binary_search(temp.begin(), temp.end(), uids[i])) { + segment_ptr->DeleteEntity(i - deleted); + ++deleted; + } + } + + return Status::OK(); +} + +int64_t +SSMemSegment::GetCurrentMem() { + return current_mem_; +} + +int64_t +SSMemSegment::GetMemLeft() { + return (MAX_TABLE_FILE_MEM - current_mem_); +} + +bool +SSMemSegment::IsFull() { + int64_t single_entity_mem_size = 0; + auto status = GetSingleEntitySize(single_entity_mem_size); + if (!status.ok()) { + return true; + } + + return (GetMemLeft() < single_entity_mem_size); +} + +Status +SSMemSegment::Serialize(uint64_t wal_lsn) { + int64_t size = GetCurrentMem(); + server::CollectSerializeMetrics metrics(size); + + auto status = segment_writer_ptr_->Serialize(); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Failed to serialize segment: " << segment_->GetID(); + return status; + } + + status = operation_->CommitRowCount(segment_writer_ptr_->RowCount()); + status = operation_->Push(); + LOG_ENGINE_DEBUG_ << "New segment " << segment_->GetID() << " serialized, lsn = " << wal_lsn; + return status; +} + +int64_t +SSMemSegment::GetSegmentId() const { + return segment_->GetID(); +} + +void +SSMemSegment::OnCacheInsertDataChanged(bool value) { + options_.insert_cache_immediately_ = value; +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/SSMemSegment.h b/core/src/db/insert/SSMemSegment.h new file mode 100644 index 000000000000..2f72d7144ff0 --- /dev/null +++ b/core/src/db/insert/SSMemSegment.h @@ -0,0 +1,87 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "config/handler/CacheConfigHandler.h" +#include "db/engine/ExecutionEngine.h" +#include "db/insert/SSVectorSource.h" +#include "db/snapshot/CompoundOperations.h" +#include "db/snapshot/Resources.h" +#include "segment/SSSegmentWriter.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +class SSMemSegment : public server::CacheConfigHandler { + public: + SSMemSegment(int64_t collection_id, int64_t partition_id, const DBOptions& options); + + ~SSMemSegment() = default; + + public: + Status + Add(const SSVectorSourcePtr& source); + + Status + Delete(segment::doc_id_t doc_id); + + Status + Delete(const std::vector& doc_ids); + + int64_t + GetCurrentMem(); + + int64_t + GetMemLeft(); + + bool + IsFull(); + + Status + Serialize(uint64_t wal_lsn); + + int64_t + GetSegmentId() const; + + protected: + void + OnCacheInsertDataChanged(bool value) override; + + private: + Status + CreateSegment(); + + Status + GetSingleEntitySize(int64_t& single_size); + + private: + int64_t collection_id_; + int64_t partition_id_; + + std::shared_ptr operation_; + snapshot::SegmentPtr segment_; + DBOptions options_; + int64_t current_mem_; + + // ExecutionEnginePtr execution_engine_; + segment::SSSegmentWriterPtr segment_writer_ptr_; +}; // SSMemTableFile + +using SSMemSegmentPtr = std::shared_ptr; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/SSVectorSource.cpp b/core/src/db/insert/SSVectorSource.cpp new file mode 100644 index 000000000000..527c34429935 --- /dev/null +++ b/core/src/db/insert/SSVectorSource.cpp @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/insert/SSVectorSource.h" + +#include +#include + +#include "db/engine/EngineFactory.h" +#include "db/engine/ExecutionEngine.h" +#include "metrics/Metrics.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace engine { + +SSVectorSource::SSVectorSource(const DataChunkPtr& chunk) : chunk_(chunk) { +} + +Status +SSVectorSource::Add(const segment::SSSegmentWriterPtr& segment_writer_ptr, const int64_t& num_entities_to_add, + int64_t& num_entities_added) { + // TODO: n = vectors_.vector_count_;??? + int64_t n = chunk_->count_; + num_entities_added = current_num_added_ + num_entities_to_add <= n ? num_entities_to_add : n - current_num_added_; + + auto status = segment_writer_ptr->AddChunk(chunk_, current_num_added_, num_entities_added); + if (!status.ok()) { + return status; + } + + current_num_added_ += num_entities_added; + return status; +} + +bool +SSVectorSource::AllAdded() { + return (current_num_added_ >= chunk_->count_); +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/insert/SSVectorSource.h b/core/src/db/insert/SSVectorSource.h new file mode 100644 index 000000000000..8f6189b79569 --- /dev/null +++ b/core/src/db/insert/SSVectorSource.h @@ -0,0 +1,51 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "db/IDGenerator.h" +#include "db/engine/ExecutionEngine.h" +#include "db/insert/SSMemManager.h" +#include "segment/SSSegmentWriter.h" +#include "segment/Segment.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +// TODO(zhiru): this class needs to be refactored once attributes are added + +class SSVectorSource { + public: + explicit SSVectorSource(const DataChunkPtr& chunk); + + Status + Add(const segment::SSSegmentWriterPtr& segment_writer_ptr, const int64_t& num_attrs_to_add, + int64_t& num_attrs_added); + + bool + AllAdded(); + + private: + DataChunkPtr chunk_; + + int64_t current_num_added_ = 0; +}; // SSVectorSource + +using SSVectorSourcePtr = std::shared_ptr; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/merge/MergeManagerFactory.cpp b/core/src/db/merge/MergeManagerFactory.cpp index 4f15281e163b..aa3186c99986 100644 --- a/core/src/db/merge/MergeManagerFactory.cpp +++ b/core/src/db/merge/MergeManagerFactory.cpp @@ -11,6 +11,7 @@ #include "db/merge/MergeManagerFactory.h" #include "db/merge/MergeManagerImpl.h" +#include "db/merge/SSMergeManagerImpl.h" #include "utils/Exception.h" #include "utils/Log.h" @@ -22,5 +23,10 @@ MergeManagerFactory::Build(const meta::MetaPtr& meta_ptr, const DBOptions& optio return std::make_shared(meta_ptr, options, MergeStrategyType::LAYERED); } +MergeManagerPtr +MergeManagerFactory::SSBuild(const DBOptions& options) { + return std::make_shared(options, MergeStrategyType::SIMPLE); +} + } // namespace engine } // namespace milvus diff --git a/core/src/db/merge/MergeManagerFactory.h b/core/src/db/merge/MergeManagerFactory.h index 533a3211618d..b7a072aa14d5 100644 --- a/core/src/db/merge/MergeManagerFactory.h +++ b/core/src/db/merge/MergeManagerFactory.h @@ -23,6 +23,9 @@ class MergeManagerFactory { public: static MergeManagerPtr Build(const meta::MetaPtr& meta_ptr, const DBOptions& options); + + static MergeManagerPtr + SSBuild(const DBOptions& options); }; } // namespace engine diff --git a/core/src/db/merge/SSMergeManagerImpl.cpp b/core/src/db/merge/SSMergeManagerImpl.cpp new file mode 100644 index 000000000000..8d117e0477ff --- /dev/null +++ b/core/src/db/merge/SSMergeManagerImpl.cpp @@ -0,0 +1,100 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/merge/SSMergeManagerImpl.h" +#include "db/merge/SSMergeSimpleStrategy.h" +#include "db/merge/SSMergeTask.h" +#include "db/snapshot/Snapshots.h" +#include "utils/Exception.h" +#include "utils/Log.h" + +#include + +namespace milvus { +namespace engine { + +SSMergeManagerImpl::SSMergeManagerImpl(const DBOptions& options, MergeStrategyType type) + : options_(options), strategy_type_(type) { + UseStrategy(type); +} + +Status +SSMergeManagerImpl::UseStrategy(MergeStrategyType type) { + switch (type) { + case MergeStrategyType::SIMPLE: { + strategy_ = std::make_shared(); + break; + } + case MergeStrategyType::LAYERED: + case MergeStrategyType::ADAPTIVE: + default: { + std::string msg = "Unsupported merge strategy type: " + std::to_string((int32_t)type); + LOG_ENGINE_ERROR_ << msg; + throw Exception(DB_ERROR, msg); + } + } + strategy_type_ = type; + + return Status::OK(); +} + +Status +SSMergeManagerImpl::MergeFiles(const std::string& collection_name) { + if (strategy_ == nullptr) { + std::string msg = "No merge strategy specified"; + LOG_ENGINE_ERROR_ << msg; + return Status(DB_ERROR, msg); + } + + int64_t row_count_per_segment = DEFAULT_ROW_COUNT_PER_SEGMENT; + while (true) { + snapshot::ScopedSnapshotT latest_ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(latest_ss, collection_name)); + + Partition2SegmentsMap part2seg; + auto& segments = latest_ss->GetResources(); + for (auto& kv : segments) { + auto segment_commit = latest_ss->GetSegmentCommitBySegmentId(kv.second->GetID()); + part2seg[kv.second->GetPartitionId()].push_back(kv.second->GetID()); + } + + Partition2SegmentsMap::iterator it; + for (it = part2seg.begin(); it != part2seg.end();) { + if (it->second.size() <= 1) { + part2seg.erase(it++); + } else { + it++; + } + } + + if (part2seg.empty()) { + break; + } + + SegmentGroups segment_groups; + auto status = strategy_->RegroupSegments(latest_ss, part2seg, segment_groups); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Failed to regroup segments for: " << collection_name + << ", continue to merge all files into one"; + return status; + } + + for (auto& segments : segment_groups) { + SSMergeTask task(options_, latest_ss, segments); + task.Execute(); + } + } + + return Status::OK(); +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/merge/SSMergeManagerImpl.h b/core/src/db/merge/SSMergeManagerImpl.h new file mode 100644 index 000000000000..bd900b044fb3 --- /dev/null +++ b/core/src/db/merge/SSMergeManagerImpl.h @@ -0,0 +1,53 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "db/merge/MergeManager.h" +#include "db/merge/SSMergeStrategy.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +class SSMergeManagerImpl : public MergeManager { + public: + SSMergeManagerImpl(const DBOptions& options, MergeStrategyType type); + + MergeStrategyType + Strategy() const override { + return strategy_type_; + } + + Status + UseStrategy(MergeStrategyType type) override; + + Status + MergeFiles(const std::string& collection_name) override; + + private: + DBOptions options_; + + MergeStrategyType strategy_type_ = MergeStrategyType::SIMPLE; + SSMergeStrategyPtr strategy_; +}; // MergeManagerImpl + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/merge/SSMergeSimpleStrategy.cpp b/core/src/db/merge/SSMergeSimpleStrategy.cpp new file mode 100644 index 000000000000..83547c169c5b --- /dev/null +++ b/core/src/db/merge/SSMergeSimpleStrategy.cpp @@ -0,0 +1,62 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/merge/SSMergeSimpleStrategy.h" +#include "db/snapshot/Snapshots.h" +#include "utils/Log.h" + +namespace milvus { +namespace engine { + +const char* ROW_COUNT_PER_SEGMENT = "row_count_per_segment"; + +Status +SSMergeSimpleStrategy::RegroupSegments(const snapshot::ScopedSnapshotT& ss, const Partition2SegmentsMap& part2segment, + SegmentGroups& groups) { + auto collection = ss->GetCollection(); + + int64_t row_count_per_segment = DEFAULT_ROW_COUNT_PER_SEGMENT; + const json params = collection->GetParams(); + if (params.find(ROW_COUNT_PER_SEGMENT) != params.end()) { + row_count_per_segment = params[ROW_COUNT_PER_SEGMENT]; + } + + for (auto& kv : part2segment) { + snapshot::IDS_TYPE ids; + int64_t row_count_sum = 0; + for (auto& id : kv.second) { + auto segment_commit = ss->GetSegmentCommitBySegmentId(id); + if (segment_commit == nullptr) { + continue; // maybe stale + } + + ids.push_back(id); + row_count_sum += segment_commit->GetRowCount(); + if (row_count_sum >= row_count_per_segment) { + if (ids.size() >= 2) { + groups.push_back(ids); + } + ids.clear(); + row_count_sum = 0; + continue; + } + } + + if (!ids.empty()) { + groups.push_back(ids); + } + } + + return Status::OK(); +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/merge/SSMergeSimpleStrategy.h b/core/src/db/merge/SSMergeSimpleStrategy.h new file mode 100644 index 000000000000..217e46e44faf --- /dev/null +++ b/core/src/db/merge/SSMergeSimpleStrategy.h @@ -0,0 +1,31 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "db/merge/SSMergeStrategy.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +class SSMergeSimpleStrategy : public SSMergeStrategy { + public: + Status + RegroupSegments(const snapshot::ScopedSnapshotT& ss, const Partition2SegmentsMap& part2segment, + SegmentGroups& groups) override; +}; // MergeSimpleStrategy + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/merge/SSMergeStrategy.h b/core/src/db/merge/SSMergeStrategy.h new file mode 100644 index 000000000000..badadf7b7339 --- /dev/null +++ b/core/src/db/merge/SSMergeStrategy.h @@ -0,0 +1,43 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "db/Types.h" +#include "db/snapshot/ResourceTypes.h" +#include "db/snapshot/Snapshot.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +const int64_t DEFAULT_ROW_COUNT_PER_SEGMENT = 500000; + +using Partition2SegmentsMap = std::map; +using SegmentGroups = std::vector; + +class SSMergeStrategy { + public: + virtual Status + RegroupSegments(const snapshot::ScopedSnapshotT& ss, const Partition2SegmentsMap& part2segment, + SegmentGroups& groups) = 0; +}; // MergeStrategy + +using SSMergeStrategyPtr = std::shared_ptr; + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/merge/SSMergeTask.cpp b/core/src/db/merge/SSMergeTask.cpp new file mode 100644 index 000000000000..7607b5c6f053 --- /dev/null +++ b/core/src/db/merge/SSMergeTask.cpp @@ -0,0 +1,139 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/merge/SSMergeTask.h" +#include "db/Utils.h" +#include "db/snapshot/CompoundOperations.h" +#include "db/snapshot/Operations.h" +#include "db/snapshot/Snapshots.h" +#include "metrics/Metrics.h" +#include "segment/SSSegmentReader.h" +#include "segment/SSSegmentWriter.h" +#include "utils/Log.h" + +#include +#include + +namespace milvus { +namespace engine { + +SSMergeTask::SSMergeTask(const DBOptions& options, const snapshot::ScopedSnapshotT& ss, + const snapshot::IDS_TYPE& segments) + : options_(options), snapshot_(ss), segments_(segments) { +} + +Status +SSMergeTask::Execute() { + if (segments_.size() <= 1) { + return Status::OK(); + } + + snapshot::OperationContext context; + for (auto& id : segments_) { + auto seg = snapshot_->GetResource(id); + if (!seg) { + return Status(DB_ERROR, "Snapshot segment is null"); + } + + context.stale_segments.push_back(seg); + if (!context.prev_partition) { + snapshot::PartitionPtr partition = snapshot_->GetResource(seg->GetPartitionId()); + context.prev_partition = partition; + } + } + + auto op = std::make_shared(context, snapshot_); + snapshot::SegmentPtr new_seg; + auto status = op->CommitNewSegment(new_seg); + if (!status.ok()) { + return status; + } + + // create segment raw files (placeholder) + auto names = snapshot_->GetFieldNames(); + for (auto& name : names) { + snapshot::SegmentFileContext sf_context; + sf_context.collection_id = new_seg->GetCollectionId(); + sf_context.partition_id = new_seg->GetPartitionId(); + sf_context.segment_id = new_seg->GetID(); + sf_context.field_name = name; + sf_context.field_element_name = engine::DEFAULT_RAW_DATA_NAME; + + snapshot::SegmentFilePtr seg_file; + status = op->CommitNewSegmentFile(sf_context, seg_file); + if (!status.ok()) { + std::string err_msg = "SSMergeTask create segment failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + } + + // create deleted_doc and bloom_filter files (placeholder) + { + snapshot::SegmentFileContext sf_context; + sf_context.collection_id = new_seg->GetCollectionId(); + sf_context.partition_id = new_seg->GetPartitionId(); + sf_context.segment_id = new_seg->GetID(); + sf_context.field_name = engine::DEFAULT_UID_NAME; + sf_context.field_element_name = engine::DEFAULT_DELETED_DOCS_NAME; + + snapshot::SegmentFilePtr delete_doc_file, bloom_filter_file; + status = op->CommitNewSegmentFile(sf_context, delete_doc_file); + if (!status.ok()) { + std::string err_msg = "SSMergeTask create bloom filter segment file failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + + sf_context.field_element_name = engine::DEFAULT_BLOOM_FILTER_NAME; + status = op->CommitNewSegmentFile(sf_context, bloom_filter_file); + if (!status.ok()) { + std::string err_msg = "SSMergeTask create deleted-doc segment file failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + } + + auto ctx = op->GetContext(); + auto visitor = SegmentVisitor::Build(snapshot_, ctx.new_segment, ctx.new_segment_files); + + // create segment writer + segment::SSSegmentWriterPtr segment_writer = + std::make_shared(options_.meta_.path_, visitor); + + // merge + for (auto& id : segments_) { + auto seg = snapshot_->GetResource(id); + + auto read_visitor = SegmentVisitor::Build(snapshot_, id); + segment::SSSegmentReaderPtr segment_reader = + std::make_shared(options_.meta_.path_, read_visitor); + status = segment_writer->Merge(segment_reader); + if (!status.ok()) { + std::string err_msg = "SSMergeTask merge failed: " + status.ToString(); + LOG_ENGINE_ERROR_ << err_msg; + return status; + } + } + + status = segment_writer->Serialize(); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Failed to serialize segment: " << new_seg->GetID(); + return status; + } + + status = op->Push(); + + return status; +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/merge/SSMergeTask.h b/core/src/db/merge/SSMergeTask.h new file mode 100644 index 000000000000..bfb22214ccc0 --- /dev/null +++ b/core/src/db/merge/SSMergeTask.h @@ -0,0 +1,39 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include "db/merge/MergeManager.h" +#include "db/meta/MetaTypes.h" +#include "db/snapshot/ResourceTypes.h" +#include "db/snapshot/Snapshot.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { + +class SSMergeTask { + public: + SSMergeTask(const DBOptions& options, const snapshot::ScopedSnapshotT& ss, const snapshot::IDS_TYPE& segments); + + Status + Execute(); + + private: + DBOptions options_; + snapshot::ScopedSnapshotT snapshot_; + snapshot::IDS_TYPE segments_; +}; // SSMergeTask + +} // namespace engine +} // namespace milvus diff --git a/core/src/db/meta/MetaAdapter.h b/core/src/db/meta/MetaAdapter.h new file mode 100644 index 000000000000..538ede1aadce --- /dev/null +++ b/core/src/db/meta/MetaAdapter.h @@ -0,0 +1,110 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "db/meta/MetaSession.h" +#include "db/meta/backend/MockMetaEngine.h" +#include "db/meta/backend/MySqlEngine.h" +#include "db/snapshot/Resources.h" +#include "utils/Exception.h" + +namespace milvus::engine::meta { + +class MetaAdapter { + public: + explicit MetaAdapter(MetaEnginePtr engine) : engine_(engine) { + } + + SessionPtr + CreateSession() { + return std::make_shared(engine_); + } + + template + Status + Select(int64_t id, typename T::Ptr& resource) { + // TODO move select logic to here + auto session = CreateSession(); + std::vector resources; + auto status = session->Select(snapshot::IdField::Name, {id}, {}, resources); + if (status.ok() && !resources.empty()) { + // TODO: may need to check num of resources + resource = resources.at(0); + } + + return status; + } + + template + Status + SelectBy(const std::string& field, const std::vector& values, std::vector& resources) { + auto session = CreateSession(); + return session->Select(field, values, {}, resources); + } + + template + Status + SelectResourceIDs(std::vector& ids, const std::string& filter_field, const std::vector& filter_values) { + std::vector resources; + auto session = CreateSession(); + auto status = session->Select(filter_field, filter_values, {F_ID}, resources); + if (!status.ok()) { + return status; + } + + for (auto& res : resources) { + ids.push_back(res->GetID()); + } + + return Status::OK(); + } + + template + Status + Apply(snapshot::ResourceContextPtr resp, int64_t& result_id) { + result_id = 0; + + auto session = CreateSession(); + session->Apply(resp); + + std::vector result_ids; + auto status = session->Commit(result_ids); + + if (!status.ok()) { + throw Exception(status.code(), status.message()); + } + + if (result_ids.size() != 1) { + throw Exception(1, "Result id is not equal one ... "); + } + + result_id = result_ids.at(0); + return Status::OK(); + } + + Status + TruncateAll() { + return engine_->TruncateAll(); + } + + private: + MetaEnginePtr engine_; +}; + +using MetaAdapterPtr = std::shared_ptr; + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/MetaFields.h b/core/src/db/meta/MetaFields.h new file mode 100644 index 000000000000..ac56d1fa5ec1 --- /dev/null +++ b/core/src/db/meta/MetaFields.h @@ -0,0 +1,35 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/snapshot/Resources.h" + +namespace milvus::engine::meta { + +extern const char* F_MAPPINGS; +extern const char* F_STATE; +extern const char* F_LSN; +extern const char* F_CREATED_ON; +extern const char* F_UPDATED_ON; +extern const char* F_ID; +extern const char* F_COLLECTON_ID; +extern const char* F_SCHEMA_ID; +extern const char* F_NUM; +extern const char* F_FTYPE; +extern const char* F_FIELD_ID; +extern const char* F_FIELD_ELEMENT_ID; +extern const char* F_PARTITION_ID; +extern const char* F_SEGMENT_ID; +extern const char* F_NAME; +extern const char* F_PARAMS; +extern const char* F_SIZE; +extern const char* F_ROW_COUNT; + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/MetaResourceAttrs.cpp b/core/src/db/meta/MetaResourceAttrs.cpp new file mode 100644 index 000000000000..51574e72d1ff --- /dev/null +++ b/core/src/db/meta/MetaResourceAttrs.cpp @@ -0,0 +1,75 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/meta/MetaResourceAttrs.h" + +#include "db/meta/MetaFields.h" +#include "utils/Status.h" + +namespace milvus::engine::meta { +///////////////////////////////////////////////////////////////// +const char* F_MAPPINGS = snapshot::MappingsField::Name; +const char* F_STATE = snapshot::StateField::Name; +const char* F_LSN = snapshot::LsnField::Name; +const char* F_CREATED_ON = snapshot::CreatedOnField::Name; +const char* F_UPDATED_ON = snapshot::UpdatedOnField::Name; +const char* F_ID = snapshot::IdField::Name; +const char* F_COLLECTON_ID = snapshot::CollectionIdField::Name; +const char* F_SCHEMA_ID = snapshot::SchemaIdField::Name; +const char* F_NUM = snapshot::NumField::Name; +const char* F_FTYPE = snapshot::FtypeField::Name; +const char* F_FIELD_ID = snapshot::FieldIdField::Name; +const char* F_FIELD_ELEMENT_ID = snapshot::FieldElementIdField::Name; +const char* F_PARTITION_ID = snapshot::PartitionIdField::Name; +const char* F_SEGMENT_ID = snapshot::SegmentIdField::Name; +const char* F_NAME = snapshot::NameField::Name; +const char* F_PARAMS = snapshot::ParamsField::Name; +const char* F_SIZE = snapshot::SizeField::Name; +const char* F_ROW_COUNT = snapshot::RowCountField::Name; + +/////////////////////////////////////////////////////////////// +Status +ResourceAttrMapOf(const std::string& table, std::vector& attrs) { + static const std::unordered_map> ResourceAttrMap = { + {snapshot::Collection::Name, {F_NAME, F_PARAMS, F_ID, F_LSN, F_STATE, F_CREATED_ON, F_UPDATED_ON}}, + {snapshot::CollectionCommit::Name, + {F_COLLECTON_ID, F_SCHEMA_ID, F_MAPPINGS, F_ROW_COUNT, F_SIZE, F_ID, F_LSN, F_STATE, F_CREATED_ON, + F_UPDATED_ON}}, + {snapshot::Partition::Name, {F_NAME, F_COLLECTON_ID, F_ID, F_LSN, F_STATE, F_CREATED_ON, F_UPDATED_ON}}, + {snapshot::PartitionCommit::Name, + {F_COLLECTON_ID, F_PARTITION_ID, F_MAPPINGS, F_ROW_COUNT, F_SIZE, F_ID, F_LSN, F_STATE, F_CREATED_ON, + F_UPDATED_ON}}, + {snapshot::Segment::Name, + {F_COLLECTON_ID, F_PARTITION_ID, F_NUM, F_ID, F_LSN, F_STATE, F_CREATED_ON, F_UPDATED_ON}}, + {snapshot::SegmentCommit::Name, + {F_SCHEMA_ID, F_PARTITION_ID, F_SEGMENT_ID, F_MAPPINGS, F_ROW_COUNT, F_SIZE, F_ID, F_LSN, F_STATE, + F_CREATED_ON, F_UPDATED_ON}}, + {snapshot::SegmentFile::Name, + {F_COLLECTON_ID, F_PARTITION_ID, F_SEGMENT_ID, F_FIELD_ELEMENT_ID, F_ROW_COUNT, F_SIZE, F_ID, F_LSN, F_STATE, + F_CREATED_ON, F_UPDATED_ON}}, + {snapshot::SchemaCommit::Name, {F_COLLECTON_ID, F_MAPPINGS, F_ID, F_LSN, F_STATE, F_CREATED_ON, F_UPDATED_ON}}, + {snapshot::Field::Name, {F_NAME, F_NUM, F_FTYPE, F_PARAMS, F_ID, F_LSN, F_STATE, F_CREATED_ON, F_UPDATED_ON}}, + {snapshot::FieldCommit::Name, + {F_COLLECTON_ID, F_FIELD_ID, F_MAPPINGS, F_ID, F_LSN, F_STATE, F_CREATED_ON, F_UPDATED_ON}}, + {snapshot::FieldElement::Name, + {F_COLLECTON_ID, F_FIELD_ID, F_NAME, F_FTYPE, F_PARAMS, F_ID, F_LSN, F_STATE, F_CREATED_ON, F_UPDATED_ON}}, + }; + + if (ResourceAttrMap.find(table) == ResourceAttrMap.end()) { + return Status(SERVER_UNEXPECTED_ERROR, "Cannot not found table " + table + " in ResourceAttrMap"); + } + + attrs = ResourceAttrMap.at(table); + + return Status::OK(); +} + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/MetaResourceAttrs.h b/core/src/db/meta/MetaResourceAttrs.h new file mode 100644 index 000000000000..8a5129a0e7ff --- /dev/null +++ b/core/src/db/meta/MetaResourceAttrs.h @@ -0,0 +1,245 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "db/meta/MetaFields.h" +#include "db/snapshot/ResourceContext.h" +#include "db/snapshot/Resources.h" +#include "utils/Json.h" +#include "utils/StringHelpFunctions.h" + +namespace milvus::engine::meta { + +// using namespace snapshot; +using snapshot::MappingT; +using snapshot::ResourceContext; + +/////////////////////////// Macros /////////////////////////////// +#define NULLPTR_CHECK(ptr) \ + if (ptr == nullptr) { \ + return Status(SERVER_UNSUPPORTED_ERROR, "Convert pointer failed."); \ + } + +////////////////////////////////////////////////////////////////// +Status +ResourceAttrMapOf(const std::string& table, std::vector& attrs); + +////////////////////////////////////////////////////////////////// +inline void +int2str(const int64_t& ival, std::string& val) { + val = std::to_string(ival); +} + +inline void +uint2str(const uint64_t& uival, std::string& val) { + val = std::to_string(uival); +} + +inline void +state2str(const snapshot::State& sval, std::string& val) { + val = std::to_string(sval); +} + +inline void +mappings2str(const MappingT& mval, std::string& val) { + auto value_json = json::array(); + for (auto& m : mval) { + value_json.emplace_back(m); + } + + val = "\'" + value_json.dump() + "\'"; +} + +inline void +str2str(const std::string& sval, std::string& val) { + val = "\'" + sval + "\'"; +} + +inline void +json2str(const json& jval, std::string& val) { + val = "\'" + jval.dump() + "\'"; +} + +template +inline Status +AttrValue2Str(typename ResourceContext::ResPtr src, const std::string& attr, std::string& value) { + int64_t int_value; + uint64_t uint_value; + snapshot::State state_value; + MappingT mapping_value; + std::string str_value; + json json_value; + + if (attr == F_ID) { + auto id_field = std::dynamic_pointer_cast(src); + int_value = id_field->GetID(); + int2str(int_value, value); + } else if (F_COLLECTON_ID == attr) { + auto collection_id_field = std::dynamic_pointer_cast(src); + int_value = collection_id_field->GetCollectionId(); + int2str(int_value, value); + } else if (F_CREATED_ON == attr) { + auto created_field = std::dynamic_pointer_cast(src); + int_value = created_field->GetCreatedTime(); + int2str(int_value, value); + } else if (F_UPDATED_ON == attr) { + auto updated_field = std::dynamic_pointer_cast(src); + int_value = updated_field->GetUpdatedTime(); + int2str(int_value, value); + } else if (F_SCHEMA_ID == attr) { + auto schema_id_field = std::dynamic_pointer_cast(src); + int_value = schema_id_field->GetSchemaId(); + int2str(int_value, value); + } else if (F_NUM == attr) { + auto num_field = std::dynamic_pointer_cast(src); + int_value = num_field->GetNum(); + int2str(int_value, value); + } else if (F_FTYPE == attr) { + auto ftype_field = std::dynamic_pointer_cast(src); + int_value = ftype_field->GetFtype(); + int2str(int_value, value); + } else if (F_FIELD_ID == attr) { + auto field_id_field = std::dynamic_pointer_cast(src); + int_value = field_id_field->GetFieldId(); + int2str(int_value, value); + } else if (F_FIELD_ELEMENT_ID == attr) { + auto element_id_field = std::dynamic_pointer_cast(src); + int_value = element_id_field->GetFieldElementId(); + int2str(int_value, value); + } else if (F_PARTITION_ID == attr) { + auto partition_id_field = std::dynamic_pointer_cast(src); + int_value = partition_id_field->GetPartitionId(); + int2str(int_value, value); + } else if (F_SEGMENT_ID == attr) { + auto segment_id_field = std::dynamic_pointer_cast(src); + int_value = segment_id_field->GetSegmentId(); + int2str(int_value, value); + } /* Uint field */ else if (F_LSN == attr) { + auto lsn_field = std::dynamic_pointer_cast(src); + uint_value = lsn_field->GetLsn(); + uint2str(uint_value, value); + } else if (F_SIZE == attr) { + auto size_field = std::dynamic_pointer_cast(src); + uint_value = size_field->GetSize(); + uint2str(uint_value, value); + } else if (F_ROW_COUNT == attr) { + auto row_count_field = std::dynamic_pointer_cast(src); + uint_value = row_count_field->GetRowCount(); + uint2str(uint_value, value); + } else if (F_STATE == attr) { + auto state_field = std::dynamic_pointer_cast(src); + state_value = state_field->GetState(); + state2str(state_value, value); + } else if (F_MAPPINGS == attr) { + auto mappings_field = std::dynamic_pointer_cast(src); + mapping_value = mappings_field->GetMappings(); + mappings2str(mapping_value, value); + } else if (F_NAME == attr) { + auto name_field = std::dynamic_pointer_cast(src); + str_value = name_field->GetName(); + str2str(str_value, value); + } else if (F_PARAMS == attr) { + auto params_field = std::dynamic_pointer_cast(src); + json_value = params_field->GetParams(); + json2str(json_value, value); + } else { + return Status(SERVER_UNSUPPORTED_ERROR, "Unknown field attr: " + attr); + } + + return Status::OK(); +} + +template +inline Status +ResourceContextAddAttrMap(snapshot::ResourceContextPtr src, + std::unordered_map& attr_map) { + std::vector attrs; + auto status = ResourceAttrMapOf(ResourceT::Name, attrs); + if (!status.ok()) { + return status; + } + + for (auto& attr : attrs) { + if (attr == F_ID) { + continue; + } + + std::string value; + AttrValue2Str(src->Resource(), attr, value); + attr_map[attr] = value; + } + + return Status::OK(); +} + +template +inline Status +ResourceContextUpdateAttrMap(snapshot::ResourceContextPtr res, + std::unordered_map& attr_map) { + std::string value; + for (auto& attr : res->Attrs()) { + AttrValue2Str(res->Resource(), attr, value); + attr_map[attr] = value; + } + + return Status::OK(); +} + +///////////////////////////////////////////////////////////////////// +template +inline void +ResourceFieldToSqlStr(const T& t, std::string& val) { + val = ""; +} + +template <> +inline void +ResourceFieldToSqlStr(const int64_t& ival, std::string& val) { + int2str(ival, val); +} + +template <> +inline void +ResourceFieldToSqlStr(const uint64_t& uival, std::string& val) { + uint2str(uival, val); +} + +template <> +inline void +ResourceFieldToSqlStr(const snapshot::State& sval, std::string& val) { + state2str(sval, val); +} + +template <> +inline void +ResourceFieldToSqlStr(const MappingT& mval, std::string& val) { + mappings2str(mval, val); +} + +template <> +inline void +ResourceFieldToSqlStr(const std::string& sval, std::string& val) { + str2str(sval, val); +} + +template <> +inline void +ResourceFieldToSqlStr(const json& jval, std::string& val) { + json2str(jval, val); +} + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/MetaSession.h b/core/src/db/meta/MetaSession.h new file mode 100644 index 000000000000..af8e57cf78db --- /dev/null +++ b/core/src/db/meta/MetaSession.h @@ -0,0 +1,345 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "db/meta/MetaResourceAttrs.h" +#include "db/meta/backend/MetaEngine.h" +#include "db/snapshot/BaseResource.h" +#include "db/snapshot/ResourceContext.h" +#include "db/snapshot/ResourceHelper.h" +#include "db/snapshot/Resources.h" +#include "db/snapshot/Utils.h" +#include "utils/Exception.h" +#include "utils/Json.h" +#include "utils/Status.h" + +namespace milvus::engine::meta { + +class MetaSession { + public: + explicit MetaSession(MetaEnginePtr engine) : db_engine_(engine), pos_(-1) { + } + + ~MetaSession() = default; + + public: + template + Status + Select(const std::string& field, const std::vector& value, const std::vector& target_attrs, + std::vector& resources); + + template + Status + Apply(snapshot::ResourceContextPtr resp); + + Status + ResultPos() { + if (apply_context_.empty()) { + return Status(SERVER_UNEXPECTED_ERROR, "Session is empty"); + } + pos_ = apply_context_.size() - 1; + + return Status::OK(); + } + + Status + Commit(std::vector& result_ids) { + return db_engine_->ExecuteTransaction(apply_context_, result_ids); + } + + Status + Commit(int64_t& result_id) { + if (apply_context_.empty()) { + return Status::OK(); + } + + if (pos_ < 0) { + throw Exception(1, "Result pos is small than 0"); + // return Status(SERVER_UNEXPECTED_ERROR, "Result pos is small than 0"); + } + std::vector result_ids; + auto status = db_engine_->ExecuteTransaction(apply_context_, result_ids); + if (!status.ok()) { + return status; + } + + result_id = result_ids.at(pos_); + return Status::OK(); + } + + private: + std::vector apply_context_; + int64_t pos_; + MetaEnginePtr db_engine_; +}; + +template +Status +MetaSession::Select(const std::string& field, const std::vector& values, + const std::vector& target_attrs, std::vector& resources) { + MetaQueryContext context; + context.table_ = T::Name; + + if (!field.empty()) { + std::vector field_values; + for (auto& v : values) { + std::string field_value; + ResourceFieldToSqlStr(v, field_value); + field_values.push_back(field_value); + } + context.filter_attrs_ = {{field, field_values}}; + } + + if (!target_attrs.empty()) { + context.all_required_ = false; + context.query_fields_ = target_attrs; + } + + AttrsMapList attrs; + auto status = db_engine_->Query(context, attrs); + if (!status.ok()) { + return status; + } + + if (attrs.empty()) { + return Status::OK(); + } + + for (auto raw : attrs) { + auto resource = snapshot::CreateResPtr(); + std::unordered_map::iterator iter; + auto mf_p = std::dynamic_pointer_cast(resource); + if (mf_p != nullptr) { + iter = raw.find(F_MAPPINGS); + if (iter != raw.end()) { + auto mapping_json = nlohmann::json::parse(iter->second); + std::set mappings; + for (auto& ele : mapping_json) { + mappings.insert(ele.get()); + } + mf_p->GetMappings() = mappings; + } + } + + auto sf_p = std::dynamic_pointer_cast(resource); + if (sf_p != nullptr) { + iter = raw.find(F_STATE); + if (iter != raw.end()) { + auto status_int = std::stol(iter->second); + sf_p->ResetStatus(); + switch (static_cast(status_int)) { + case snapshot::PENDING: { + break; + } + case snapshot::ACTIVE: { + sf_p->Activate(); + break; + } + case snapshot::DEACTIVE: { + sf_p->Deactivate(); + break; + } + default: { return Status(SERVER_UNSUPPORTED_ERROR, "Invalid state value"); } + } + } + } + + auto lsn_f = std::dynamic_pointer_cast(resource); + if (lsn_f != nullptr) { + iter = raw.find(F_LSN); + if (iter != raw.end()) { + auto lsn = std::stoul(iter->second); + lsn_f->SetLsn(lsn); + } + } + + auto created_on_f = std::dynamic_pointer_cast(resource); + if (created_on_f != nullptr) { + iter = raw.find(F_CREATED_ON); + if (iter != raw.end()) { + auto created_on = std::stol(iter->second); + created_on_f->SetCreatedTime(created_on); + } + } + + auto update_on_p = std::dynamic_pointer_cast(resource); + if (update_on_p != nullptr) { + iter = raw.find(F_UPDATED_ON); + if (iter != raw.end()) { + auto update_on = std::stol(iter->second); + update_on_p->SetUpdatedTime(update_on); + } + } + + auto id_p = std::dynamic_pointer_cast(resource); + if (id_p != nullptr) { + iter = raw.find(F_ID); + if (iter != raw.end()) { + auto t_id = std::stol(iter->second); + id_p->SetID(t_id); + } + } + + auto cid_p = std::dynamic_pointer_cast(resource); + if (cid_p != nullptr) { + iter = raw.find(F_COLLECTON_ID); + if (iter != raw.end()) { + auto cid = std::stol(iter->second); + cid_p->SetCollectionId(cid); + } + } + + auto sid_p = std::dynamic_pointer_cast(resource); + if (sid_p != nullptr) { + iter = raw.find(F_SCHEMA_ID); + if (iter != raw.end()) { + auto sid = std::stol(iter->second); + sid_p->SetSchemaId(sid); + } + } + + auto num_p = std::dynamic_pointer_cast(resource); + if (num_p != nullptr) { + iter = raw.find(F_NUM); + if (iter != raw.end()) { + auto num = std::stol(iter->second); + num_p->SetNum(num); + } + } + + auto ftype_p = std::dynamic_pointer_cast(resource); + if (ftype_p != nullptr) { + iter = raw.find(F_FTYPE); + if (iter != raw.end()) { + auto ftype = std::stol(iter->second); + ftype_p->SetFtype(ftype); + } + } + + auto fid_p = std::dynamic_pointer_cast(resource); + if (fid_p != nullptr) { + iter = raw.find(F_FIELD_ID); + if (iter != raw.end()) { + auto fid = std::stol(iter->second); + fid_p->SetFieldId(fid); + } + } + + auto feid_p = std::dynamic_pointer_cast(resource); + if (feid_p != nullptr) { + iter = raw.find(F_FIELD_ELEMENT_ID); + if (iter != raw.end()) { + auto feid = std::stol(iter->second); + feid_p->SetFieldElementId(feid); + } + } + + auto pid_p = std::dynamic_pointer_cast(resource); + if (pid_p != nullptr) { + iter = raw.find(F_PARTITION_ID); + if (iter != raw.end()) { + auto p_id = std::stol(iter->second); + pid_p->SetPartitionId(p_id); + } + } + + auto sgid_p = std::dynamic_pointer_cast(resource); + if (sgid_p != nullptr) { + iter = raw.find(F_SEGMENT_ID); + if (iter != raw.end()) { + auto sg_id = std::stol(iter->second); + sgid_p->SetSegmentId(sg_id); + } + } + + auto name_p = std::dynamic_pointer_cast(resource); + if (name_p != nullptr) { + iter = raw.find(F_NAME); + if (iter != raw.end()) { + name_p->SetName(iter->second); + } + } + + auto pf_p = std::dynamic_pointer_cast(resource); + if (pf_p != nullptr) { + iter = raw.find(F_PARAMS); + if (iter != raw.end()) { + auto params = nlohmann::json::parse(iter->second); + pf_p->SetParams(params); + } + } + + auto size_p = std::dynamic_pointer_cast(resource); + if (size_p != nullptr) { + iter = raw.find(F_SIZE); + if (iter != raw.end()) { + auto size = std::stol(iter->second); + size_p->SetSize(size); + } + } + + auto rc_p = std::dynamic_pointer_cast(resource); + if (rc_p != nullptr) { + iter = raw.find(F_ROW_COUNT); + if (iter != raw.end()) { + auto rc = std::stol(iter->second); + rc_p->SetRowCount(rc); + } + } + + resources.push_back(std::move(resource)); + } + + return Status::OK(); +} + +template +Status +MetaSession::Apply(snapshot::ResourceContextPtr resp) { + // TODO: may here not need to store resp + auto status = Status::OK(); + std::string sql; + + MetaApplyContext context; + context.op_ = resp->Op(); + if (context.op_ == oAdd) { + status = ResourceContextAddAttrMap(resp, context.attrs_); + } else if (context.op_ == oUpdate) { + status = ResourceContextUpdateAttrMap(resp, context.attrs_); + context.id_ = resp->Resource()->GetID(); + } else if (context.op_ == oDelete) { + context.id_ = resp->ID(); + } else { + return Status(SERVER_UNEXPECTED_ERROR, "Unknown resource context operation"); + } + + if (!status.ok()) { + return status; + } + + context.table_ = resp->Table(); + apply_context_.push_back(context); + + return Status::OK(); +} + +using SessionPtr = std::shared_ptr; + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/MetaTypes.h b/core/src/db/meta/MetaTypes.h index 029ca40a066a..4766715fa6ba 100644 --- a/core/src/db/meta/MetaTypes.h +++ b/core/src/db/meta/MetaTypes.h @@ -17,16 +17,82 @@ #include #include "db/Constants.h" -#include "db/engine/ExecutionEngine.h" +#include "knowhere/index/IndexType.h" #include "src/version.h" namespace milvus { namespace engine { + +// TODO(linxj): replace with VecIndex::IndexType +enum class EngineType { + INVALID = 0, + FAISS_IDMAP = 1, + FAISS_IVFFLAT = 2, + FAISS_IVFSQ8 = 3, + NSG_MIX = 4, + FAISS_IVFSQ8H = 5, + FAISS_PQ = 6, +#ifdef MILVUS_SUPPORT_SPTAG + SPTAG_KDT = 7, + SPTAG_BKT = 8, +#endif + FAISS_BIN_IDMAP = 9, + FAISS_BIN_IVFFLAT = 10, + HNSW = 11, + ANNOY = 12, + FAISS_IVFSQ8NR = 13, + HNSW_SQ8NM = 14, + MAX_VALUE = HNSW_SQ8NM, +}; + +static std::map s_map_engine_type = { + {knowhere::IndexEnum::INDEX_FAISS_IDMAP, EngineType::FAISS_IDMAP}, + {knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, EngineType::FAISS_IVFFLAT}, + {knowhere::IndexEnum::INDEX_FAISS_IVFPQ, EngineType::FAISS_PQ}, + {knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, EngineType::FAISS_IVFSQ8}, + {knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR, EngineType::FAISS_IVFSQ8NR}, + {knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H, EngineType::FAISS_IVFSQ8H}, + {knowhere::IndexEnum::INDEX_NSG, EngineType::NSG_MIX}, +#ifdef MILVUS_SUPPORT_SPTAG + {knowhere::IndexEnum::INDEX_SPTAG_KDT_RNT, EngineType::SPTAG_KDT}, + {knowhere::IndexEnum::INDEX_SPTAG_BKT_RNT, EngineType::SPTAG_BKT}, +#endif + {knowhere::IndexEnum::INDEX_HNSW, EngineType::HNSW}, + {knowhere::IndexEnum::INDEX_HNSW_SQ8NM, EngineType::HNSW_SQ8NM}, + {knowhere::IndexEnum::INDEX_ANNOY, EngineType::ANNOY}}; + +enum class MetricType { + INVALID = 0, + L2 = 1, // Euclidean Distance + IP = 2, // Cosine Similarity + HAMMING = 3, // Hamming Distance + JACCARD = 4, // Jaccard Distance + TANIMOTO = 5, // Tanimoto Distance + SUBSTRUCTURE = 6, // Substructure Distance + SUPERSTRUCTURE = 7, // Superstructure Distance + MAX_VALUE = SUPERSTRUCTURE +}; + +static std::map s_map_metric_type = { + {"L2", MetricType::L2}, + {"IP", MetricType::IP}, + {"HAMMING", MetricType::HAMMING}, + {"JACCARD", MetricType::JACCARD}, + {"TANIMOTO", MetricType::TANIMOTO}, + {"SUBSTRUCTURE", MetricType::SUBSTRUCTURE}, + {"SUPERSTRUCTURE", MetricType::SUPERSTRUCTURE}, +}; + +enum class StructuredIndexType { + INVALID = 0, + SORTED = 1, +}; + namespace meta { constexpr int32_t DEFAULT_ENGINE_TYPE = (int)EngineType::FAISS_IDMAP; constexpr int32_t DEFAULT_METRIC_TYPE = (int)MetricType::L2; -constexpr int32_t DEFAULT_INDEX_FILE_SIZE = GB; +constexpr int32_t DEFAULT_INDEX_FILE_SIZE = 1024; constexpr char CURRENT_VERSION[] = MILVUS_VERSION; constexpr int64_t FLAG_MASK_NO_USERID = 0x1; @@ -101,22 +167,24 @@ using Table2FileRef = std::map; namespace hybrid { -enum class DataType { - INT8 = 1, - INT16 = 2, - INT32 = 3, - INT64 = 4, +enum DataType { + NONE = 0, + BOOL = 1, + INT8 = 2, + INT16 = 3, + INT32 = 4, + INT64 = 5, - STRING = 20, + FLOAT = 10, + DOUBLE = 11, - BOOL = 30, + STRING = 20, - FLOAT = 40, - DOUBLE = 41, + UID = 30, - FLOAT_VECTOR = 100, - BINARY_VECTOR = 101, - UNKNOWN = 9999, + VECTOR_BINARY = 100, + VECTOR_FLOAT = 101, + VECTOR = 200, }; struct VectorFieldSchema { @@ -134,23 +202,6 @@ struct VectorFieldsSchema { using VectorFieldSchemaPtr = std::shared_ptr; struct FieldSchema { - typedef enum { - INT8 = 1, - INT16 = 2, - INT32 = 3, - INT64 = 4, - - STRING = 20, - - BOOL = 30, - - FLOAT = 40, - DOUBLE = 41, - - VECTOR = 100, - UNKNOWN = 9999, - } FIELD_TYPE; - // TODO(yukun): need field_id? std::string collection_id_; std::string field_name_; diff --git a/core/src/db/meta/MySQLConnectionPool.h b/core/src/db/meta/MySQLConnectionPool.h index 523ffe00eb0e..2252b2002d80 100644 --- a/core/src/db/meta/MySQLConnectionPool.h +++ b/core/src/db/meta/MySQLConnectionPool.h @@ -9,12 +9,15 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. -#include +#pragma once #include + #include #include +#include + #include "utils/Log.h" namespace milvus { diff --git a/core/src/db/meta/backend/MetaContext.h b/core/src/db/meta/backend/MetaContext.h new file mode 100644 index 000000000000..2c7dd9d563d9 --- /dev/null +++ b/core/src/db/meta/backend/MetaContext.h @@ -0,0 +1,39 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace milvus::engine::meta { + +enum MetaContextOp { oAdd = 1, oUpdate, oDelete }; + +struct MetaQueryContext { + std::string table_; + bool all_required_ = true; + std::vector query_fields_; + std::unordered_map> filter_attrs_; +}; + +struct MetaApplyContext { + std::string table_; + MetaContextOp op_; + int64_t id_; + std::unordered_map attrs_; + std::unordered_map filter_attrs_; + std::string sql_; +}; + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/backend/MetaEngine.h b/core/src/db/meta/backend/MetaEngine.h new file mode 100644 index 000000000000..e0b52a81da83 --- /dev/null +++ b/core/src/db/meta/backend/MetaEngine.h @@ -0,0 +1,42 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "db/meta/backend/MetaContext.h" +#include "db/snapshot/ResourceTypes.h" +#include "utils/Status.h" + +namespace milvus::engine::meta { + +using AttrsMap = std::unordered_map; +using AttrsMapList = std::vector; + +class MetaEngine { + public: + virtual Status + Query(const MetaQueryContext& context, AttrsMapList& attrs) = 0; + + virtual Status + ExecuteTransaction(const std::vector& sql_contexts, std::vector& result_ids) = 0; + + virtual Status + TruncateAll() = 0; +}; + +using MetaEnginePtr = std::shared_ptr; + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/backend/MetaHelper.cpp b/core/src/db/meta/backend/MetaHelper.cpp new file mode 100644 index 000000000000..0ef1cf05d9ee --- /dev/null +++ b/core/src/db/meta/backend/MetaHelper.cpp @@ -0,0 +1,98 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/meta/backend/MetaHelper.h" + +#include + +#include "db/meta/backend/MetaContext.h" +#include "utils/StringHelpFunctions.h" + +namespace milvus::engine::meta { + +Status +MetaHelper::MetaQueryContextToSql(const MetaQueryContext& context, std::string& sql) { + sql.clear(); + if (context.all_required_) { + sql = "SELECT * FROM "; + } else { + std::string query_fields; + StringHelpFunctions::MergeStringWithDelimeter(context.query_fields_, ",", query_fields); + sql = "SELECT " + query_fields + " FROM "; + } + sql += context.table_; + + std::vector filter_conditions; + for (auto& attr : context.filter_attrs_) { + std::string filter_str; + if (attr.second.size() < 1) { + return Status(SERVER_UNEXPECTED_ERROR, "Invalid filter attrs. "); + } else if (attr.second.size() == 1) { + filter_conditions.emplace_back(attr.first + "=" + attr.second[0]); + } else { + std::string in_condition; + StringHelpFunctions::MergeStringWithDelimeter(attr.second, ",", in_condition); + in_condition = attr.first + " IN (" + in_condition + ")"; + filter_conditions.emplace_back(in_condition); + } + + StringHelpFunctions::MergeStringWithDelimeter(filter_conditions, " AND ", filter_str); + sql += " WHERE " + filter_str; + } + + sql += ";"; + + return Status::OK(); +} + +Status +MetaHelper::MetaApplyContextToSql(const MetaApplyContext& context, std::string& sql) { + if (!context.sql_.empty()) { + sql = context.sql_; + return Status::OK(); + } + + switch (context.op_) { + case oAdd: { + std::string field_names, values; + std::vector field_list, value_list; + for (auto& kv : context.attrs_) { + field_list.push_back(kv.first); + value_list.push_back(kv.second); + } + StringHelpFunctions::MergeStringWithDelimeter(field_list, ",", field_names); + StringHelpFunctions::MergeStringWithDelimeter(value_list, ",", values); + sql = "INSERT INTO " + context.table_ + "(" + field_names + ") " + "VALUES(" + values + ")"; + break; + } + case oUpdate: { + std::string field_pairs; + std::vector updated_attrs; + for (auto& attr_kv : context.attrs_) { + updated_attrs.emplace_back(attr_kv.first + "=" + attr_kv.second); + } + + StringHelpFunctions::MergeStringWithDelimeter(updated_attrs, ",", field_pairs); + sql = "UPDATE " + context.table_ + " SET " + field_pairs + " WHERE id = " + std::to_string(context.id_); + break; + } + case oDelete: { + sql = "DELETE FROM " + context.table_ + " WHERE id = " + std::to_string(context.id_); + break; + } + default: + return Status(SERVER_UNEXPECTED_ERROR, "Unknown context operation"); + } + + return Status::OK(); +} + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/backend/MetaHelper.h b/core/src/db/meta/backend/MetaHelper.h new file mode 100644 index 000000000000..6d91ddf61af2 --- /dev/null +++ b/core/src/db/meta/backend/MetaHelper.h @@ -0,0 +1,33 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include "db/meta/backend/MetaContext.h" +#include "utils/Status.h" + +namespace milvus::engine::meta { + +class MetaHelper { + private: + MetaHelper() = default; + + public: + static Status + MetaQueryContextToSql(const MetaQueryContext& context, std::string& sql); + + static Status + MetaApplyContextToSql(const MetaApplyContext& context, std::string& sql); +}; + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/backend/MockMetaEngine.cpp b/core/src/db/meta/backend/MockMetaEngine.cpp new file mode 100644 index 000000000000..bcaf5a095ec4 --- /dev/null +++ b/core/src/db/meta/backend/MockMetaEngine.cpp @@ -0,0 +1,255 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/meta/backend/MockMetaEngine.h" + +#include + +#include "db/meta/MetaFields.h" +#include "utils/StringHelpFunctions.h" + +namespace milvus::engine::meta { + +void +MockMetaEngine::Init() { + max_ip_map_.clear(); + resources_.clear(); +} + +Status +MockMetaEngine::QueryNoLock(const MetaQueryContext& context, AttrsMapList& attrs) { + if (resources_.find(context.table_) == resources_.end()) { + return Status(0, "Empty"); + } + + auto select_target_attrs = [](const TableRaw& raw, AttrsMapList& des, + const std::vector& target_attrs) { + if (target_attrs.empty()) { + return; + } + + auto m = std::unordered_map(); + for (auto& attr : target_attrs) { + auto iter = raw.find(attr); + if (iter != raw.end()) { + m.insert(std::make_pair(iter->first, iter->second)); + } + } + if (!m.empty()) { + des.push_back(m); + } + }; + + auto term = [](const std::string& attr, const std::vector& attrs) -> bool { + for (auto& t : attrs) { + if (attr == t) { + return true; + } + } + + return false; + }; + + auto& candidate_raws = resources_[context.table_]; + + bool selected = true; + if (!context.filter_attrs_.empty()) { + for (auto& raw : candidate_raws) { + for (auto& filter_attr : context.filter_attrs_) { + auto iter = raw.find(filter_attr.first); + if (iter == raw.end()) { + selected = false; + break; + } + + if (!term(iter->second, filter_attr.second)) { + selected = false; + break; + } + } + if (selected) { + if (context.all_required_) { + attrs.push_back(raw); + } else { + select_target_attrs(raw, attrs, context.query_fields_); + } + } + selected = true; + } + } else { + if (context.all_required_) { + attrs = candidate_raws; + } else { + for (auto& attr : candidate_raws) { + select_target_attrs(attr, attrs, context.query_fields_); + } + } + } + + for (auto& result_raw : attrs) { + for (auto& kv : result_raw) { + if (*kv.second.begin() == '\'' && *kv.second.rbegin() == '\'') { + std::string v = kv.second; + StringHelpFunctions::TrimStringQuote(v, "\'"); + kv.second = v; + } + } + } + + return Status::OK(); +} + +Status +MockMetaEngine::AddNoLock(const MetaApplyContext& add_context, int64_t& result_id, TableRaw& pre_raw) { + if (max_ip_map_.find(add_context.table_) == max_ip_map_.end() || + resources_.find(add_context.table_) == resources_.end()) { + max_ip_map_[add_context.table_] = 0; + resources_[add_context.table_] = std::vector(); + } + + auto max_id = max_ip_map_[add_context.table_]; + max_ip_map_[add_context.table_] = max_id + 1; + + TableRaw new_raw; + for (auto& attr : add_context.attrs_) { + new_raw.insert(attr); + } + + new_raw[F_ID] = std::to_string(max_id + 1); + resources_[add_context.table_].push_back(new_raw); + pre_raw = new_raw; + result_id = max_id + 1; + + return Status::OK(); +} + +Status +MockMetaEngine::UpdateNoLock(const MetaApplyContext& update_context, int64_t& result_id, TableRaw& pre_raw) { + const std::string id_str = std::to_string(update_context.id_); + + auto& target_collection = resources_[update_context.table_]; + for (auto& attrs : target_collection) { + if (attrs[F_ID] == id_str) { + pre_raw = attrs; + for (auto& kv : update_context.attrs_) { + attrs[kv.first] = kv.second; + } + result_id = update_context.id_; + return Status::OK(); + } + } + + std::string err = "Cannot found resource in " + update_context.table_ + " where id = " + id_str; + return Status(SERVER_UNEXPECTED_ERROR, err); +} + +Status +MockMetaEngine::DeleteNoLock(const MetaApplyContext& delete_context, int64_t& result_id, TableRaw& pre_raw) { + const std::string id_str = std::to_string(delete_context.id_); + auto& target_collection = resources_[delete_context.table_]; + + for (auto iter = target_collection.begin(); iter != target_collection.end(); iter++) { + if ((*iter)[F_ID] == id_str) { + pre_raw = *iter; + result_id = std::stol(iter->at(F_ID)); + target_collection.erase(iter); + return Status::OK(); + } + } + + std::string err = "Cannot found resource in " + delete_context.table_ + " where id = " + id_str; + return Status(SERVER_UNEXPECTED_ERROR, err); +} + +Status +MockMetaEngine::Query(const MetaQueryContext& context, AttrsMapList& attrs) { + std::lock_guard lock(mutex_); + return QueryNoLock(context, attrs); +} + +Status +MockMetaEngine::ExecuteTransaction(const std::vector& sql_contexts, + std::vector& result_ids) { + std::unique_lock lock(mutex_); + + auto status = Status::OK(); + std::vector> pair_entities; + TableRaw raw; + for (auto& context : sql_contexts) { + int64_t id; + if (context.op_ == oAdd) { + status = AddNoLock(context, id, raw); + } else if (context.op_ == oUpdate) { + status = UpdateNoLock(context, id, raw); + } else if (context.op_ == oDelete) { + status = DeleteNoLock(context, id, raw); + } else { + status = Status(SERVER_UNEXPECTED_ERROR, "Unknown resource context"); + } + + if (!status.ok()) { + break; + } + result_ids.push_back(id); + pair_entities.emplace_back(context.op_, TableEntity(context.table_, raw)); + } + + if (!status.ok()) { + RollBackNoLock(pair_entities); + } + + return status; +} + +Status +MockMetaEngine::RollBackNoLock(const std::vector>& pre_entities) { + for (auto& o_e : pre_entities) { + auto table = o_e.second.first; + if (o_e.first == oAdd) { + auto id = std::stol(o_e.second.second.at(F_ID)); + max_ip_map_[table] = id - 1; + auto& table_res = resources_[table]; + for (size_t i = 0; i < table_res.size(); i++) { + auto store_id = std::stol(table_res[i].at(F_ID)); + if (store_id == id) { + table_res.erase(table_res.begin() + i, table_res.begin() + i + 1); + break; + } + } + } else if (o_e.first == oUpdate) { + auto id = std::stol(o_e.second.second.at(F_ID)); + auto& table_res = resources_[table]; + for (size_t j = 0; j < table_res.size(); j++) { + auto store_id = std::stol(table_res[j].at(F_ID)); + if (store_id == id) { + table_res.erase(table_res.begin() + j, table_res.begin() + j + 1); + table_res.push_back(o_e.second.second); + break; + } + } + } else if (o_e.first == oDelete) { + resources_[o_e.second.first].push_back(o_e.second.second); + } else { + continue; + } + } + + return Status::OK(); +} + +Status +MockMetaEngine::TruncateAll() { + max_ip_map_.clear(); + resources_.clear(); + return Status::OK(); +} + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/backend/MockMetaEngine.h b/core/src/db/meta/backend/MockMetaEngine.h new file mode 100644 index 000000000000..af2a400188ea --- /dev/null +++ b/core/src/db/meta/backend/MockMetaEngine.h @@ -0,0 +1,71 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "db/meta/backend/MetaEngine.h" +#include "utils/Status.h" + +namespace milvus::engine::meta { + +class MockMetaEngine : public MetaEngine { + private: + using TableRaw = std::unordered_map; + using TableEntity = std::pair; + + public: + MockMetaEngine() { + Init(); + } + + ~MockMetaEngine() = default; + + Status + Query(const MetaQueryContext& context, AttrsMapList& attrs) override; + + Status + ExecuteTransaction(const std::vector& sql_contexts, std::vector& result_ids) override; + + Status + TruncateAll() override; + + private: + void + Init(); + + Status + QueryNoLock(const MetaQueryContext& context, AttrsMapList& attrs); + + Status + AddNoLock(const MetaApplyContext& add_context, int64_t& result_id, TableRaw& pre_raw); + + Status + UpdateNoLock(const MetaApplyContext& add_context, int64_t& result_id, TableRaw& pre_raw); + + Status + DeleteNoLock(const MetaApplyContext& add_context, int64_t& result_id, TableRaw& pre_raw); + + Status + RollBackNoLock(const std::vector>& pre_raws); + + private: + std::mutex mutex_; + std::unordered_map max_ip_map_; + std::unordered_map> resources_; +}; + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/backend/MySqlEngine.cpp b/core/src/db/meta/backend/MySqlEngine.cpp new file mode 100644 index 000000000000..cb6dea8e4685 --- /dev/null +++ b/core/src/db/meta/backend/MySqlEngine.cpp @@ -0,0 +1,454 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/meta/backend/MySqlEngine.h" + +#include + +#include +#include +#include +#include + +#include + +#include "db/Utils.h" +#include "db/meta/MetaFields.h" +#include "db/meta/backend/MetaHelper.h" +#include "utils/Exception.h" +#include "utils/StringHelpFunctions.h" + +namespace milvus::engine::meta { + +////////// private namespace ////////// +namespace { +class MetaField { + public: + MetaField(const std::string& name, const std::string& type, const std::string& setting) + : name_(name), type_(type), setting_(setting) { + } + + std::string + name() const { + return name_; + } + + std::string + ToString() const { + return name_ + " " + type_ + " " + setting_; + } + + // mysql field type has additional information. for instance, a filed type is defined as 'BIGINT' + // we get the type from sql is 'bigint(20)', so we need to ignore the '(20)' + bool + IsEqual(const MetaField& field) const { + size_t name_len_min = field.name_.length() > name_.length() ? name_.length() : field.name_.length(); + size_t type_len_min = field.type_.length() > type_.length() ? type_.length() : field.type_.length(); + + // only check field type, don't check field width, for example: VARCHAR(255) and VARCHAR(100) is equal + std::vector type_split; + milvus::StringHelpFunctions::SplitStringByDelimeter(type_, "(", type_split); + if (!type_split.empty()) { + type_len_min = type_split[0].length() > type_len_min ? type_len_min : type_split[0].length(); + } + + // field name must be equal, ignore type width + return strncasecmp(field.name_.c_str(), name_.c_str(), name_len_min) == 0 && + strncasecmp(field.type_.c_str(), type_.c_str(), type_len_min) == 0; + } + + private: + std::string name_; + std::string type_; + std::string setting_; +}; + +using MetaFields = std::vector; + +class MetaSchema { + public: + MetaSchema(const std::string& name, const MetaFields& fields) : name_(name), fields_(fields), constraint_fields_() { + } + + MetaSchema(const std::string& name, const MetaFields& fields, const MetaFields& constraints) + : name_(name), fields_(fields), constraint_fields_(constraints) { + } + + std::string + name() const { + return name_; + } + + std::string + ToString() const { + std::string result; + for (auto& field : fields_) { + if (!result.empty()) { + result += ","; + } + result += field.ToString(); + } + + std::string constraints; + for (auto& constraint : constraint_fields_) { + if (!constraints.empty()) { + constraints += ","; + } + constraints += constraint.name(); + } + + if (!constraints.empty()) { + result += ",constraint uq unique(" + constraints + ")"; + } + + return result; + } + + // if the outer fields contains all this MetaSchema fields, return true + // otherwise return false + bool + IsEqual(const MetaFields& fields) const { + std::vector found_field; + for (const auto& this_field : fields_) { + for (const auto& outer_field : fields) { + if (this_field.IsEqual(outer_field)) { + found_field.push_back(this_field.name()); + break; + } + } + } + + return found_field.size() == fields_.size(); + } + + private: + std::string name_; + MetaFields fields_; + MetaFields constraint_fields_; +}; + +static const MetaField MetaIdField = MetaField(F_ID, "BIGINT", "PRIMARY KEY AUTO_INCREMENT"); +static const MetaField MetaCollectionIdField = MetaField(F_COLLECTON_ID, "BIGINT", "NOT NULL"); +static const MetaField MetaPartitionIdField = MetaField(F_PARTITION_ID, "BIGINT", "NOT NULL"); +static const MetaField MetaSchemaIdField = MetaField(F_SCHEMA_ID, "BIGINT", "NOT NULL"); +static const MetaField MetaSegmentIdField = MetaField(F_SEGMENT_ID, "BIGINT", "NOT NULL"); +static const MetaField MetaFieldElementIdField = MetaField(F_FIELD_ELEMENT_ID, "BIGINT", "NOT NULL"); +static const MetaField MetaFieldIdField = MetaField(F_FIELD_ID, "BIGINT", "NOT NULL"); +static const MetaField MetaNameField = MetaField(F_NAME, "VARCHAR(255)", "NOT NULL"); +static const MetaField MetaMappingsField = MetaField(F_MAPPINGS, "JSON", "NOT NULL"); +static const MetaField MetaNumField = MetaField(F_NUM, "BIGINT", "NOT NULL"); +static const MetaField MetaLSNField = MetaField(F_LSN, "BIGINT", "NOT NULL"); +static const MetaField MetaFtypeField = MetaField(F_FTYPE, "BIGINT", "NOT NULL"); +static const MetaField MetaStateField = MetaField(F_STATE, "TINYINT", "NOT NULL"); +static const MetaField MetaCreatedOnField = MetaField(F_CREATED_ON, "BIGINT", "NOT NULL"); +static const MetaField MetaUpdatedOnField = MetaField(F_UPDATED_ON, "BIGINT", "NOT NULL"); +static const MetaField MetaParamsField = MetaField(F_PARAMS, "JSON", "NOT NULL"); +static const MetaField MetaSizeField = MetaField(F_SIZE, "BIGINT", "NOT NULL"); +static const MetaField MetaRowCountField = MetaField(F_ROW_COUNT, "BIGINT", "NOT NULL"); + +// Environment schema +static const MetaSchema COLLECTION_SCHEMA(snapshot::Collection::Name, + {MetaIdField, MetaNameField, MetaLSNField, MetaParamsField, MetaStateField, + MetaCreatedOnField, MetaUpdatedOnField}); + +// Tables schema +static const MetaSchema COLLECTIONCOMMIT_SCHEMA(snapshot::CollectionCommit::Name, + {MetaIdField, MetaCollectionIdField, MetaSchemaIdField, + MetaMappingsField, MetaRowCountField, MetaSizeField, MetaLSNField, + MetaStateField, MetaCreatedOnField, MetaUpdatedOnField}); + +// TableFiles schema +static const MetaSchema PARTITION_SCHEMA(snapshot::Partition::Name, + {MetaIdField, MetaNameField, MetaCollectionIdField, MetaLSNField, + MetaStateField, MetaCreatedOnField, MetaUpdatedOnField}); + +// Fields schema +static const MetaSchema PARTITIONCOMMIT_SCHEMA(snapshot::PartitionCommit::Name, + {MetaIdField, MetaCollectionIdField, MetaPartitionIdField, + MetaMappingsField, MetaRowCountField, MetaSizeField, MetaStateField, + MetaLSNField, MetaCreatedOnField, MetaUpdatedOnField}); + +static const MetaSchema SEGMENT_SCHEMA(snapshot::Segment::Name, { + MetaIdField, + MetaCollectionIdField, + MetaPartitionIdField, + MetaNumField, + MetaLSNField, + MetaStateField, + MetaCreatedOnField, + MetaUpdatedOnField, + }); + +static const MetaSchema SEGMENTCOMMIT_SCHEMA(snapshot::SegmentCommit::Name, { + MetaIdField, + MetaSchemaIdField, + MetaPartitionIdField, + MetaSegmentIdField, + MetaMappingsField, + MetaRowCountField, + MetaSizeField, + MetaLSNField, + MetaStateField, + MetaCreatedOnField, + MetaUpdatedOnField, + }); + +static const MetaSchema SEGMENTFILE_SCHEMA(snapshot::SegmentFile::Name, + {MetaIdField, MetaCollectionIdField, MetaPartitionIdField, + MetaSegmentIdField, MetaFieldElementIdField, MetaRowCountField, + MetaSizeField, MetaLSNField, MetaStateField, MetaCreatedOnField, + MetaUpdatedOnField}); + +static const MetaSchema SCHEMACOMMIT_SCHEMA(snapshot::SchemaCommit::Name, { + MetaIdField, + MetaCollectionIdField, + MetaMappingsField, + MetaLSNField, + MetaStateField, + MetaCreatedOnField, + MetaUpdatedOnField, + }); + +static const MetaSchema FIELD_SCHEMA(snapshot::Field::Name, + {MetaIdField, MetaNameField, MetaNumField, MetaFtypeField, MetaParamsField, + MetaLSNField, MetaStateField, MetaCreatedOnField, MetaUpdatedOnField}); + +static const MetaSchema FIELDCOMMIT_SCHEMA(snapshot::FieldCommit::Name, + {MetaIdField, MetaCollectionIdField, MetaFieldIdField, MetaMappingsField, + MetaLSNField, MetaStateField, MetaCreatedOnField, MetaUpdatedOnField}); + +static const MetaSchema FIELDELEMENT_SCHEMA(snapshot::FieldElement::Name, + {MetaIdField, MetaCollectionIdField, MetaFieldIdField, MetaNameField, + MetaFtypeField, MetaParamsField, MetaLSNField, MetaStateField, + MetaCreatedOnField, MetaUpdatedOnField}); + +} // namespace + +/////////////// MySqlEngine /////////////// +Status +MySqlEngine::Initialize() { + // step 1: create db root path + // if (!boost::filesystem::is_directory(options_.path_)) { + // auto ret = boost::filesystem::create_directory(options_.path_); + // fiu_do_on("MySQLMetaImpl.Initialize.fail_create_directory", ret = false); + // if (!ret) { + // std::string msg = "Failed to create db directory " + options_.path_; + // LOG_ENGINE_ERROR_ << msg; + // throw Exception(DB_META_TRANSACTION_FAILED, msg); + // } + // } + std::string uri = options_.backend_uri_; + + // step 2: parse and check meta uri + utils::MetaUriInfo uri_info; + auto status = utils::ParseMetaUri(uri, uri_info); + if (!status.ok()) { + std::string msg = "Wrong URI format: " + uri; + LOG_ENGINE_ERROR_ << msg; + throw Exception(DB_INVALID_META_URI, msg); + } + + if (strcasecmp(uri_info.dialect_.c_str(), "mysql") != 0) { + std::string msg = "URI's dialect is not MySQL"; + LOG_ENGINE_ERROR_ << msg; + throw Exception(DB_INVALID_META_URI, msg); + } + + // step 3: connect mysql + unsigned int thread_hint = std::thread::hardware_concurrency(); + int max_pool_size = (thread_hint > 8) ? static_cast(thread_hint) : 8; + int port = 0; + if (!uri_info.port_.empty()) { + port = std::stoi(uri_info.port_); + } + + mysql_connection_pool_ = std::make_shared( + uri_info.db_name_, uri_info.username_, uri_info.password_, uri_info.host_, port, max_pool_size); + LOG_ENGINE_DEBUG_ << "MySQL connection pool: maximum pool size = " << std::to_string(max_pool_size); + + // step 4: validate to avoid open old version schema + // ValidateMetaSchema(); + + // step 5: clean shadow files + // if (mode_ != DBOptions::MODE::CLUSTER_READONLY) { + // CleanUpShadowFiles(); + // } + + // step 6: try connect mysql server + mysqlpp::ScopedConnection connectionPtr(*mysql_connection_pool_, safe_grab_); + + if (connectionPtr == nullptr) { + std::string msg = "Failed to connect MySQL meta server: " + uri; + LOG_ENGINE_ERROR_ << msg; + throw Exception(DB_INVALID_META_URI, msg); + } + + bool is_thread_aware = connectionPtr->thread_aware(); + fiu_do_on("MySQLMetaImpl.Initialize.is_thread_aware", is_thread_aware = false); + if (!is_thread_aware) { + std::string msg = + "Failed to initialize MySQL meta backend: MySQL client component wasn't built with thread awareness"; + LOG_ENGINE_ERROR_ << msg; + throw Exception(DB_INVALID_META_URI, msg); + } + + mysqlpp::Query InitializeQuery = connectionPtr->query(); + + auto create_schema = [&](const MetaSchema& schema) { + std::string create_table_str = "CREATE TABLE IF NOT EXISTS " + schema.name() + "(" + schema.ToString() + ");"; + InitializeQuery << create_table_str; + // LOG_ENGINE_DEBUG_ << "Initialize: " << InitializeQuery.str(); + + bool initialize_query_exec = InitializeQuery.exec(); + // fiu_do_on("MySQLMetaImpl.Initialize.fail_create_collection_files", initialize_query_exec = false); + if (!initialize_query_exec) { + std::string msg = "Failed to create meta collection '" + schema.name() + "' in MySQL"; + LOG_ENGINE_ERROR_ << msg; + throw Exception(DB_META_TRANSACTION_FAILED, msg); + } + }; + + create_schema(COLLECTION_SCHEMA); + create_schema(COLLECTIONCOMMIT_SCHEMA); + create_schema(PARTITION_SCHEMA); + create_schema(PARTITIONCOMMIT_SCHEMA); + create_schema(SEGMENT_SCHEMA); + create_schema(SEGMENTCOMMIT_SCHEMA); + create_schema(SEGMENTFILE_SCHEMA); + create_schema(SCHEMACOMMIT_SCHEMA); + create_schema(FIELD_SCHEMA); + create_schema(FIELDCOMMIT_SCHEMA); + create_schema(FIELDELEMENT_SCHEMA); + + return Status::OK(); +} + +Status +MySqlEngine::Query(const MetaQueryContext& context, AttrsMapList& attrs) { + try { + mysqlpp::ScopedConnection connectionPtr(*mysql_connection_pool_, safe_grab_); + + std::string sql; + auto status = MetaHelper::MetaQueryContextToSql(context, sql); + if (!status.ok()) { + return status; + } + + std::lock_guard lock(meta_mutex_); + + mysqlpp::Query query = connectionPtr->query(sql); + auto res = query.store(); + if (!res) { + // TODO: change error behavior + throw Exception(1, "Query res is false"); + } + + auto names = res.field_names(); + for (auto& row : res) { + AttrsMap attrs_map; + for (auto& name : *names) { + attrs_map.insert(std::make_pair(name, row[name.c_str()])); + } + attrs.push_back(attrs_map); + } + } catch (const mysqlpp::BadQuery& er) { + // Handle any query errors + // cerr << "Query error: " << er.what() << endl; + return Status(1, er.what()); + } catch (const mysqlpp::BadConversion& er) { + // Handle bad conversions + // cerr << "Conversion error: " << er.what() << endl << + // "\tretrieved data size: " << er.retrieved << + // ", actual size: " << er.actual_size << endl; + return Status(1, er.what()); + } catch (const mysqlpp::Exception& er) { + // Catch-all for any other MySQL++ exceptions + // cerr << "Error: " << er.what() << endl; + return Status(1, er.what()); + } + + return Status::OK(); +} + +Status +MySqlEngine::ExecuteTransaction(const std::vector& sql_contexts, std::vector& result_ids) { + try { + mysqlpp::ScopedConnection connectionPtr(*mysql_connection_pool_, safe_grab_); + mysqlpp::Transaction trans(*connectionPtr, mysqlpp::Transaction::serializable, mysqlpp::Transaction::session); + + std::lock_guard lock(meta_mutex_); + for (auto& context : sql_contexts) { + std::string sql; + auto status = MetaHelper::MetaApplyContextToSql(context, sql); + if (!status.ok()) { + return status; + } + + auto query = connectionPtr->query(sql); + auto res = query.execute(); + if (context.op_ == oAdd) { + auto id = res.insert_id(); + result_ids.push_back(id); + } else { + result_ids.push_back(context.id_); + } + } + + trans.commit(); + // std::cout << "[DB] Transaction commit " << std::endl; + } catch (const mysqlpp::BadQuery& er) { + // Handle any query errors + // cerr << "Query error: " << er.what() << endl; + // return -1; + // std::cout << "[DB] Error: " << er.what() << std::endl; + return Status(SERVER_UNSUPPORTED_ERROR, er.what()); + } catch (const mysqlpp::BadConversion& er) { + // Handle bad conversions + // cerr << "Conversion error: " << er.what() << endl << + // "\tretrieved data size: " << er.retrieved << + // ", actual size: " << er.actual_size << endl; + // return -1; + // std::cout << "[DB] Error: " << er.what() << std::endl; + return Status(SERVER_UNSUPPORTED_ERROR, er.what()); + } catch (const mysqlpp::Exception& er) { + // Catch-all for any other MySQL++ exceptions + // cerr << "Error: " << er.what() << endl; + // return -1; + // std::cout << "[DB] Error: " << er.what() << std::endl; + return Status(SERVER_UNSUPPORTED_ERROR, er.what()); + } + + return Status::OK(); +} + +Status +MySqlEngine::TruncateAll() { + static std::vector collecton_names = { + COLLECTION_SCHEMA.name(), COLLECTIONCOMMIT_SCHEMA.name(), PARTITION_SCHEMA.name(), + PARTITIONCOMMIT_SCHEMA.name(), SEGMENT_SCHEMA.name(), SEGMENTCOMMIT_SCHEMA.name(), + SEGMENTFILE_SCHEMA.name(), SCHEMACOMMIT_SCHEMA.name(), FIELD_SCHEMA.name(), + FIELDCOMMIT_SCHEMA.name(), FIELDELEMENT_SCHEMA.name(), + }; + + std::vector contexts; + for (auto& name : collecton_names) { + MetaApplyContext context; + context.sql_ = "TRUNCATE " + name + ";"; + context.id_ = 0; + + contexts.push_back(context); + } + + std::vector ids; + return ExecuteTransaction(contexts, ids); +} + +} // namespace milvus::engine::meta diff --git a/core/src/db/meta/backend/MySqlEngine.h b/core/src/db/meta/backend/MySqlEngine.h new file mode 100644 index 000000000000..5cc991e8b62a --- /dev/null +++ b/core/src/db/meta/backend/MySqlEngine.h @@ -0,0 +1,57 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include + +#include "db/Options.h" +#include "db/meta/MySQLConnectionPool.h" +#include "db/meta/backend/MetaEngine.h" + +namespace milvus::engine::meta { + +class MySqlEngine : public MetaEngine { + public: + explicit MySqlEngine(const DBMetaOptions& options) : options_(options) { + Initialize(); + } + + ~MySqlEngine() = default; + + Status + Query(const MetaQueryContext& context, AttrsMapList& attrs) override; + + Status + ExecuteTransaction(const std::vector& sql_contexts, std::vector& result_ids) override; + + Status + TruncateAll() override; + + private: + Status + Initialize(); + + private: + const DBMetaOptions options_; + // const int mode_; + + std::shared_ptr mysql_connection_pool_; + bool safe_grab_ = false; // Safely graps a connection from mysql pool + + std::mutex meta_mutex_; +}; + +} // namespace milvus::engine::meta diff --git a/core/src/db/snapshot/BaseResource.h b/core/src/db/snapshot/BaseResource.h index 8931714e7b7e..5e7da89d2806 100644 --- a/core/src/db/snapshot/BaseResource.h +++ b/core/src/db/snapshot/BaseResource.h @@ -12,20 +12,23 @@ #pragma once #include +#include #include #include "ReferenceProxy.h" namespace milvus::engine::snapshot { +template class BaseResource : public ReferenceProxy { public: virtual std::string ToString() const { - return std::string(); + std::stringstream ss; + const DerivedT& derived = static_cast(*this); + ss << DerivedT::Name << ": id=" << derived.GetID() << " state=" << derived.GetState(); + return ss.str(); } }; -using BaseResourcePtr = std::shared_ptr; - } // namespace milvus::engine::snapshot diff --git a/core/src/db/snapshot/CompoundOperations.cpp b/core/src/db/snapshot/CompoundOperations.cpp index e7e8727ff3bc..0dd09288aa77 100644 --- a/core/src/db/snapshot/CompoundOperations.cpp +++ b/core/src/db/snapshot/CompoundOperations.cpp @@ -10,9 +10,17 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "db/snapshot/CompoundOperations.h" + +#include #include +#include #include +#include + +#include "db/meta/MetaAdapter.h" +#include "db/snapshot/IterateHandler.h" #include "db/snapshot/OperationExecutor.h" +#include "db/snapshot/ResourceContext.h" #include "db/snapshot/Snapshots.h" #include "utils/Status.h" @@ -20,39 +28,74 @@ namespace milvus { namespace engine { namespace snapshot { -BuildOperation::BuildOperation(const OperationContext& context, ScopedSnapshotT prev_ss) : BaseT(context, prev_ss) { +AddSegmentFileOperation::AddSegmentFileOperation(const OperationContext& context, ScopedSnapshotT prev_ss) + : BaseT(context, prev_ss) { } Status -BuildOperation::DoExecute(Store& store) { - STATUS_CHECK(CheckStale(std::bind(&BuildOperation::CheckSegmentStale, this, std::placeholders::_1, +AddSegmentFileOperation::DoExecute(StorePtr store) { + STATUS_CHECK(CheckStale(std::bind(&AddSegmentFileOperation::CheckSegmentStale, this, std::placeholders::_1, context_.new_segment_files[0]->GetSegmentId()))); + auto update_size = [&](SegmentFilePtr& file) { + auto update_ctx = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + update_ctx->AddAttr(SizeField::Name); + AddStepWithLsn(*file, context_.lsn, update_ctx); + }; + + for (auto& new_file : context_.new_segment_files) { + update_size(new_file); + } + SegmentCommitOperation sc_op(context_, GetAdjustedSS()); STATUS_CHECK(sc_op(store)); STATUS_CHECK(sc_op.GetResource(context_.new_segment_commit)); - AddStepWithLsn(*context_.new_segment_commit, context_.lsn); + auto seg_commit_ctx_p = ResourceContextBuilder() + .SetResource(context_.new_segment_commit) + .SetOp(meta::oUpdate) + .CreatePtr(); + if (delta_ != 0) { + auto new_row_cnt = 0; + if (sub_ && context_.new_segment_commit->GetRowCount() < delta_) { + return Status(SS_ERROR, "Invalid row count delta"); + } else if (sub_) { + new_row_cnt = context_.new_segment_commit->GetRowCount() - delta_; + } else { + new_row_cnt = context_.new_segment_commit->GetRowCount() + delta_; + } + context_.new_segment_commit->SetRowCount(new_row_cnt); + seg_commit_ctx_p->AddAttr(RowCountField::Name); + } + AddStepWithLsn(*context_.new_segment_commit, context_.lsn, seg_commit_ctx_p); PartitionCommitOperation pc_op(context_, GetAdjustedSS()); STATUS_CHECK(pc_op(store)); OperationContext cc_context; STATUS_CHECK(pc_op.GetResource(cc_context.new_partition_commit)); - AddStepWithLsn(*cc_context.new_partition_commit, context_.lsn); + auto par_commit_ctx_p = ResourceContextBuilder() + .SetResource(cc_context.new_partition_commit) + .SetOp(meta::oUpdate) + .CreatePtr(); + AddStepWithLsn(*cc_context.new_partition_commit, context_.lsn, par_commit_ctx_p); context_.new_partition_commit = cc_context.new_partition_commit; - STATUS_CHECK(pc_op.GetResource(context_.new_partition_commit)); - AddStepWithLsn(*context_.new_partition_commit, context_.lsn); + // STATUS_CHECK(pc_op.GetResource(context_.new_partition_commit)); + // AddStepWithLsn(*context_.new_partition_commit, context_.lsn); CollectionCommitOperation cc_op(cc_context, GetAdjustedSS()); STATUS_CHECK(cc_op(store)); STATUS_CHECK(cc_op.GetResource(context_.new_collection_commit)); - AddStepWithLsn(*context_.new_collection_commit, context_.lsn); + auto c_commit_ctx_p = ResourceContextBuilder() + .SetResource(context_.new_collection_commit) + .SetOp(meta::oUpdate) + .CreatePtr(); + AddStepWithLsn(*context_.new_collection_commit, context_.lsn, c_commit_ctx_p); return Status::OK(); } Status -BuildOperation::CheckSegmentStale(ScopedSnapshotT& latest_snapshot, ID_TYPE segment_id) const { +AddSegmentFileOperation::CheckSegmentStale(ScopedSnapshotT& latest_snapshot, ID_TYPE segment_id) const { auto segment = latest_snapshot->GetResource(segment_id); if (!segment) { std::stringstream emsg; @@ -63,23 +106,259 @@ BuildOperation::CheckSegmentStale(ScopedSnapshotT& latest_snapshot, ID_TYPE segm } Status -BuildOperation::CommitNewSegmentFile(const SegmentFileContext& context, SegmentFilePtr& created) { - STATUS_CHECK( - CheckStale(std::bind(&BuildOperation::CheckSegmentStale, this, std::placeholders::_1, context.segment_id))); +AddSegmentFileOperation::CommitRowCountDelta(SIZE_TYPE delta, bool sub) { + delta_ = delta; + sub_ = sub; + return Status::OK(); +} + +Status +AddSegmentFileOperation::CommitNewSegmentFile(const SegmentFileContext& context, SegmentFilePtr& created) { + STATUS_CHECK(CheckStale( + std::bind(&AddSegmentFileOperation::CheckSegmentStale, this, std::placeholders::_1, context.segment_id))); auto segment = GetStartedSS()->GetResource(context.segment_id); - if (!segment) { + if (!segment || (context_.new_segment_files.size() > 0 && + (context_.new_segment_files[0]->GetSegmentId() != context.segment_id))) { std::stringstream emsg; emsg << GetRepr() << ". Invalid segment " << context.segment_id << " in context"; return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); } + auto ctx = context; ctx.partition_id = segment->GetPartitionId(); auto new_sf_op = std::make_shared(ctx, GetStartedSS()); STATUS_CHECK(new_sf_op->Push()); STATUS_CHECK(new_sf_op->GetResource(created)); context_.new_segment_files.push_back(created); - AddStepWithLsn(*created, context_.lsn); + auto sf_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*created, context_.lsn, sf_ctx_p); + + return Status::OK(); +} + +AddFieldElementOperation::AddFieldElementOperation(const OperationContext& context, ScopedSnapshotT prev_ss) + : BaseT(context, prev_ss) { +} + +Status +AddFieldElementOperation::PreCheck() { + if (context_.stale_field_elements.size() > 0 || context_.new_field_elements.size() == 0) { + return Status(SS_INVALID_CONTEX_ERROR, "No new field element or at least one stale field element"); + } + + return Status::OK(); +} + +Status +AddFieldElementOperation::DoExecute(StorePtr store) { + OperationContext cc_context; + { + auto context = context_; + context.new_field_elements.clear(); + for (auto& new_fe : context_.new_field_elements) { + if (new_fe->GetCollectionId() != GetAdjustedSS()->GetCollectionId()) { + std::stringstream emsg; + emsg << GetRepr() << ". Invalid collection id " << new_fe->GetCollectionId(); + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } + auto field = GetAdjustedSS()->GetResource(new_fe->GetFieldId()); + if (!field) { + std::stringstream emsg; + emsg << GetRepr() << ". Invalid field id " << new_fe->GetFieldId(); + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } + FieldElementPtr field_element; + auto status = GetAdjustedSS()->GetFieldElement(field->GetName(), new_fe->GetName(), field_element); + if (status.ok()) { + std::stringstream emsg; + emsg << GetRepr() << ". Duplicate field element name " << new_fe->GetName(); + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } + + STATUS_CHECK(store->CreateResource(FieldElement(*new_fe), field_element)); + auto fe_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*field_element, context.lsn, fe_ctx_p); + + context.new_field_elements.push_back(field_element); + } + + FieldCommitOperation fc_op(context, GetAdjustedSS()); + STATUS_CHECK(fc_op(store)); + FieldCommitPtr new_field_commit; + STATUS_CHECK(fc_op.GetResource(new_field_commit)); + auto fc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*new_field_commit, context.lsn, fc_ctx_p); + context.new_field_commits.push_back(new_field_commit); + for (auto& kv : GetAdjustedSS()->GetResources()) { + if (kv.second->GetFieldId() == new_field_commit->GetFieldId()) { + context.stale_field_commits.push_back(kv.second.Get()); + } + } + + SchemaCommitOperation sc_op(context, GetAdjustedSS()); + + STATUS_CHECK(sc_op(store)); + STATUS_CHECK(sc_op.GetResource(cc_context.new_schema_commit)); + auto sc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*cc_context.new_schema_commit, context.lsn, sc_ctx_p); + } + + CollectionCommitOperation cc_op(cc_context, GetAdjustedSS()); + STATUS_CHECK(cc_op(store)); + STATUS_CHECK(cc_op.GetResource(context_.new_collection_commit)); + auto cc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*context_.new_collection_commit, context_.lsn, cc_ctx_p); + + return Status::OK(); +} + +DropAllIndexOperation::DropAllIndexOperation(const OperationContext& context, ScopedSnapshotT prev_ss) + : BaseT(context, prev_ss) { +} + +Status +DropAllIndexOperation::PreCheck() { + if (context_.stale_field_elements.size() == 0) { + std::stringstream emsg; + emsg << GetRepr() << ". Stale field element is empty"; + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } + + std::set field_ids; + + for (auto& stale_field_element : context_.stale_field_elements) { + if (!GetStartedSS()->GetResource(stale_field_element->GetID())) { + std::stringstream emsg; + emsg << GetRepr() << ". Specified field element " << stale_field_element->GetName(); + emsg << " is stale"; + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } + field_ids.insert(stale_field_element->GetFieldId()); + } + + if (field_ids.size() > 1) { + std::stringstream emsg; + emsg << GetRepr() << ". All stale field elements should be of same field"; + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } + + // TODO: Check type + return Status::OK(); +} + +Status +DropAllIndexOperation::DoExecute(StorePtr store) { + /* auto& segment_files = GetAdjustedSS()->GetResources(); */ + + OperationContext cc_context; + { + auto context = context_; + + FieldCommitOperation fc_op(context, GetAdjustedSS()); + STATUS_CHECK(fc_op(store)); + FieldCommitPtr new_field_commit; + STATUS_CHECK(fc_op.GetResource(new_field_commit)); + auto fc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*new_field_commit, context.lsn, fc_ctx_p); + context.new_field_commits.push_back(new_field_commit); + for (auto& kv : GetAdjustedSS()->GetResources()) { + if (kv.second->GetFieldId() == new_field_commit->GetFieldId()) { + context.stale_field_commits.push_back(kv.second.Get()); + } + } + + SchemaCommitOperation sc_op(context, GetAdjustedSS()); + + STATUS_CHECK(sc_op(store)); + STATUS_CHECK(sc_op.GetResource(cc_context.new_schema_commit)); + auto sc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*cc_context.new_schema_commit, context.lsn, sc_ctx_p); + } + + std::map> p_sc_map; + auto executor = [&](const Segment::Ptr& segment, SegmentIterator* handler) -> Status { + auto context = context_; + for (auto& stale_element : context.stale_field_elements) { + auto segment_file = handler->ss_->GetSegmentFile(segment->GetID(), stale_element->GetID()); + if (segment_file) { + context.stale_segment_files.push_back(segment_file); + } + } + if (context.stale_segment_files.size() == 0) { + return Status::OK(); + } + SegmentCommitOperation sc_op(context, GetAdjustedSS()); + STATUS_CHECK(sc_op(store)); + STATUS_CHECK(sc_op.GetResource(context.new_segment_commit)); + auto segc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*context.new_segment_commit, context.lsn, segc_ctx_p); + p_sc_map[context.new_segment_commit->GetPartitionId()].push_back(context.new_segment_commit); + return Status::OK(); + }; + + auto segment_iterator = std::make_shared(GetAdjustedSS(), executor); + segment_iterator->Iterate(); + STATUS_CHECK(segment_iterator->GetStatus()); + + for (auto& kv : p_sc_map) { + auto& partition_id = kv.first; + auto context = context_; + context.new_segment_commits = kv.second; + PartitionCommitOperation pc_op(context, GetAdjustedSS()); + STATUS_CHECK(pc_op(store)); + STATUS_CHECK(pc_op.GetResource(context.new_partition_commit)); + auto pc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*context.new_partition_commit, context.lsn, pc_ctx_p); + cc_context.new_partition_commits.push_back(context.new_partition_commit); + } + + CollectionCommitOperation cc_op(cc_context, GetAdjustedSS()); + STATUS_CHECK(cc_op(store)); + STATUS_CHECK(cc_op.GetResource(context_.new_collection_commit)); + auto cc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*context_.new_collection_commit, context_.lsn, cc_ctx_p); + + return Status::OK(); +} + +DropIndexOperation::DropIndexOperation(const OperationContext& context, ScopedSnapshotT prev_ss) + : BaseT(context, prev_ss) { +} + +Status +DropIndexOperation::PreCheck() { + if (context_.stale_segment_files.size() == 0) { + std::stringstream emsg; + emsg << GetRepr() << ". Stale segment is requried"; + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } + // TODO: Check segment file type + + return Status::OK(); +} + +Status +DropIndexOperation::DoExecute(StorePtr store) { + SegmentCommitOperation sc_op(context_, GetAdjustedSS()); + STATUS_CHECK(sc_op(store)); + STATUS_CHECK(sc_op.GetResource(context_.new_segment_commit)); + auto sc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*context_.new_segment_commit, context_.lsn, sc_ctx_p); + + OperationContext cc_context; + PartitionCommitOperation pc_op(context_, GetAdjustedSS()); + STATUS_CHECK(pc_op(store)); + STATUS_CHECK(pc_op.GetResource(cc_context.new_partition_commit)); + auto pc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*cc_context.new_partition_commit, context_.lsn, pc_ctx_p); + context_.new_partition_commit = cc_context.new_partition_commit; + + CollectionCommitOperation cc_op(cc_context, GetAdjustedSS()); + STATUS_CHECK(cc_op(store)); + STATUS_CHECK(cc_op.GetResource(context_.new_collection_commit)); + auto cc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*context_.new_collection_commit, context_.lsn, cc_ctx_p); return Status::OK(); } @@ -89,17 +368,32 @@ NewSegmentOperation::NewSegmentOperation(const OperationContext& context, Scoped } Status -NewSegmentOperation::DoExecute(Store& store) { +NewSegmentOperation::CommitRowCount(SIZE_TYPE row_cnt) { + row_cnt_ = row_cnt; + return Status::OK(); +} + +Status +NewSegmentOperation::DoExecute(StorePtr store) { // PXU TODO: // 1. Check all requried field elements have related segment files // 2. Check Stale and others /* auto status = PrevSnapshotRequried(); */ /* if (!status.ok()) return status; */ // TODO: Check Context + for (auto& new_file : context_.new_segment_files) { + auto update_ctx = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + update_ctx->AddAttr(SizeField::Name); + AddStepWithLsn(*new_file, context_.lsn, update_ctx); + } + SegmentCommitOperation sc_op(context_, GetAdjustedSS()); STATUS_CHECK(sc_op(store)); STATUS_CHECK(sc_op.GetResource(context_.new_segment_commit)); - AddStepWithLsn(*context_.new_segment_commit, context_.lsn); + context_.new_segment_commit->SetRowCount(row_cnt_); + auto sc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + sc_ctx_p->AddAttr(RowCountField::Name); + AddStepWithLsn(*context_.new_segment_commit, context_.lsn, sc_ctx_p); /* std::cout << GetRepr() << " POST_SC_MAP=("; */ /* for (auto id : context_.new_segment_commit->GetMappings()) { */ /* std::cout << id << ","; */ @@ -110,7 +404,8 @@ NewSegmentOperation::DoExecute(Store& store) { PartitionCommitOperation pc_op(context_, GetAdjustedSS()); STATUS_CHECK(pc_op(store)); STATUS_CHECK(pc_op.GetResource(cc_context.new_partition_commit)); - AddStepWithLsn(*cc_context.new_partition_commit, context_.lsn); + auto pc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*cc_context.new_partition_commit, context_.lsn, pc_ctx_p); context_.new_partition_commit = cc_context.new_partition_commit; /* std::cout << GetRepr() << " POST_PC_MAP=("; */ /* for (auto id : cc_context.new_partition_commit->GetMappings()) { */ @@ -121,7 +416,8 @@ NewSegmentOperation::DoExecute(Store& store) { CollectionCommitOperation cc_op(cc_context, GetAdjustedSS()); STATUS_CHECK(cc_op(store)); STATUS_CHECK(cc_op.GetResource(context_.new_collection_commit)); - AddStepWithLsn(*context_.new_collection_commit, context_.lsn); + auto cc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*context_.new_collection_commit, context_.lsn, cc_ctx_p); return Status::OK(); } @@ -132,7 +428,8 @@ NewSegmentOperation::CommitNewSegment(SegmentPtr& created) { STATUS_CHECK(op->Push()); STATUS_CHECK(op->GetResource(context_.new_segment)); created = context_.new_segment; - AddStepWithLsn(*created, context_.lsn); + auto s_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*created, context_.lsn, s_ctx_p); return Status::OK(); } @@ -141,10 +438,12 @@ NewSegmentOperation::CommitNewSegmentFile(const SegmentFileContext& context, Seg auto ctx = context; ctx.segment_id = context_.new_segment->GetID(); ctx.partition_id = context_.new_segment->GetPartitionId(); + ctx.collection_id = GetStartedSS()->GetCollectionId(); auto new_sf_op = std::make_shared(ctx, GetStartedSS()); STATUS_CHECK(new_sf_op->Push()); STATUS_CHECK(new_sf_op->GetResource(created)); - AddStepWithLsn(*created, context_.lsn); + auto sf_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*created, context_.lsn, sf_ctx_p); context_.new_segment_files.push_back(created); return Status::OK(); } @@ -176,7 +475,8 @@ MergeOperation::CommitNewSegment(SegmentPtr& created) { STATUS_CHECK(op->Push()); STATUS_CHECK(op->GetResource(context_.new_segment)); created = context_.new_segment; - AddStepWithLsn(*created, context_.lsn); + auto seg_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*created, context_.lsn, seg_ctx_p); return Status::OK(); } @@ -192,19 +492,38 @@ MergeOperation::CommitNewSegmentFile(const SegmentFileContext& context, SegmentF STATUS_CHECK(new_sf_op->Push()); STATUS_CHECK(new_sf_op->GetResource(created)); context_.new_segment_files.push_back(created); - AddStepWithLsn(*created, context_.lsn); + auto sf_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*created, context_.lsn, sf_ctx_p); return Status::OK(); } Status -MergeOperation::DoExecute(Store& store) { +MergeOperation::DoExecute(StorePtr store) { + auto row_cnt = 0; + for (auto& stale_seg : context_.stale_segments) { + row_cnt += GetStartedSS()->GetSegmentCommitBySegmentId(stale_seg->GetID())->GetRowCount(); + } + + auto update_size = [&](SegmentFilePtr& file) { + auto update_ctx = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + update_ctx->AddAttr(SizeField::Name); + AddStepWithLsn(*file, context_.lsn, update_ctx); + }; + + for (auto& new_file : context_.new_segment_files) { + update_size(new_file); + } + // PXU TODO: // 1. Check all required field elements have related segment files // 2. Check Stale and others SegmentCommitOperation sc_op(context_, GetAdjustedSS()); STATUS_CHECK(sc_op(store)); STATUS_CHECK(sc_op.GetResource(context_.new_segment_commit)); - AddStepWithLsn(*context_.new_segment_commit, context_.lsn); + auto sc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + context_.new_segment_commit->SetRowCount(row_cnt); + sc_ctx_p->AddAttr(RowCountField::Name); + AddStepWithLsn(*context_.new_segment_commit, context_.lsn, sc_ctx_p); /* std::cout << GetRepr() << " POST_SC_MAP=("; */ /* for (auto id : context_.new_segment_commit->GetMappings()) { */ /* std::cout << id << ","; */ @@ -215,7 +534,8 @@ MergeOperation::DoExecute(Store& store) { STATUS_CHECK(pc_op(store)); OperationContext cc_context; STATUS_CHECK(pc_op.GetResource(cc_context.new_partition_commit)); - AddStepWithLsn(*cc_context.new_partition_commit, context_.lsn); + auto pc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*cc_context.new_partition_commit, context_.lsn, pc_ctx_p); context_.new_partition_commit = cc_context.new_partition_commit; /* std::cout << GetRepr() << " POST_PC_MAP=("; */ @@ -227,7 +547,8 @@ MergeOperation::DoExecute(Store& store) { CollectionCommitOperation cc_op(cc_context, GetAdjustedSS()); STATUS_CHECK(cc_op(store)); STATUS_CHECK(cc_op.GetResource(context_.new_collection_commit)); - AddStepWithLsn(*context_.new_collection_commit, context_.lsn); + auto cc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*context_.new_collection_commit, context_.lsn, cc_ctx_p); return Status::OK(); } @@ -239,8 +560,8 @@ GetSnapshotIDsOperation::GetSnapshotIDsOperation(ID_TYPE collection_id, bool rev } Status -GetSnapshotIDsOperation::DoExecute(Store& store) { - ids_ = store.AllActiveCollectionCommitIds(collection_id_, reversed_); +GetSnapshotIDsOperation::DoExecute(StorePtr store) { + ids_ = store->AllActiveCollectionCommitIds(collection_id_, reversed_); return Status::OK(); } @@ -254,8 +575,8 @@ GetCollectionIDsOperation::GetCollectionIDsOperation(bool reversed) } Status -GetCollectionIDsOperation::DoExecute(Store& store) { - ids_ = store.AllActiveCollectionIds(reversed_); +GetCollectionIDsOperation::DoExecute(StorePtr store) { + ids_ = store->AllActiveCollectionIds(reversed_); return Status::OK(); } @@ -283,7 +604,7 @@ DropPartitionOperation::GetRepr() const { } Status -DropPartitionOperation::DoExecute(Store& store) { +DropPartitionOperation::DoExecute(StorePtr store) { PartitionPtr p; auto id = c_context_.id; if (id == 0) { @@ -303,7 +624,11 @@ DropPartitionOperation::DoExecute(Store& store) { auto cc_op = CollectionCommitOperation(op_ctx, GetAdjustedSS()); STATUS_CHECK(cc_op(store)); STATUS_CHECK(cc_op.GetResource(context_.new_collection_commit)); - AddStepWithLsn(*context_.new_collection_commit, c_context_.lsn); + auto cc_ctx_p = ResourceContextBuilder() + .SetResource(context_.new_collection_commit) + .SetOp(meta::oUpdate) + .CreatePtr(); + AddStepWithLsn(*context_.new_collection_commit, c_context_.lsn, cc_ctx_p); return Status::OK(); } @@ -328,24 +653,34 @@ CreatePartitionOperation::CommitNewPartition(const PartitionContext& context, Pa STATUS_CHECK(op->Push()); STATUS_CHECK(op->GetResource(partition)); context_.new_partition = partition; - AddStepWithLsn(*partition, context_.lsn); + auto par_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*partition, context_.lsn, par_ctx_p); return Status::OK(); } Status -CreatePartitionOperation::DoExecute(Store& store) { +CreatePartitionOperation::DoExecute(StorePtr store) { STATUS_CHECK(CheckStale()); auto collection = GetAdjustedSS()->GetCollection(); auto partition = context_.new_partition; + if (context_.new_partition) { + if (GetAdjustedSS()->GetPartition(context_.new_partition->GetName())) { + std::stringstream emsg; + emsg << GetRepr() << ". Duplicate Partition \"" << context_.new_partition->GetName() << "\""; + return Status(SS_DUPLICATED_ERROR, emsg.str()); + } + } + PartitionCommitPtr pc; OperationContext pc_context; pc_context.new_partition = partition; auto pc_op = PartitionCommitOperation(pc_context, GetAdjustedSS()); STATUS_CHECK(pc_op(store)); STATUS_CHECK(pc_op.GetResource(pc)); - AddStepWithLsn(*pc, context_.lsn); + auto pc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*pc, context_.lsn, pc_ctx_p); OperationContext cc_context; cc_context.new_partition_commit = pc; @@ -354,7 +689,8 @@ CreatePartitionOperation::DoExecute(Store& store) { STATUS_CHECK(cc_op(store)); CollectionCommitPtr cc; STATUS_CHECK(cc_op.GetResource(cc)); - AddStepWithLsn(*cc, context_.lsn); + auto cc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*cc, context_.lsn, cc_ctx_p); context_.new_collection_commit = cc; return Status::OK(); @@ -385,63 +721,81 @@ CreateCollectionOperation::GetRepr() const { } Status -CreateCollectionOperation::DoExecute(Store& store) { - // TODO: Do some checks +CreateCollectionOperation::DoExecute(StorePtr store) { CollectionPtr collection; - auto status = store.CreateCollection(Collection(c_context_.collection->GetName()), collection); + ScopedSnapshotT ss; + Snapshots::GetInstance().GetSnapshot(ss, c_context_.collection->GetName()); + if (ss) { + std::stringstream emsg; + emsg << GetRepr() << ". Duplicated collection " << c_context_.collection->GetName(); + return Status(SS_DUPLICATED_ERROR, emsg.str()); + } + + auto status = store->CreateResource(Collection(c_context_.collection->GetName()), collection); if (!status.ok()) { std::cerr << status.ToString() << std::endl; return status; } - AddStepWithLsn(*collection, c_context_.lsn); + auto c_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*collection, c_context_.lsn, c_ctx_p); context_.new_collection = collection; MappingT field_commit_ids = {}; + ID_TYPE result_id; auto field_idx = 0; for (auto& field_kv : c_context_.fields_schema) { field_idx++; auto& field_schema = field_kv.first; auto& field_elements = field_kv.second; FieldPtr field; - status = - store.CreateResource(Field(field_schema->GetName(), field_idx, field_schema->GetFtype()), field); - AddStepWithLsn(*field, c_context_.lsn); + status = store->CreateResource( + Field(field_schema->GetName(), field_idx, field_schema->GetFtype(), field_schema->GetParams()), field); + auto f_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*field, c_context_.lsn, f_ctx_p); MappingT element_ids = {}; FieldElementPtr raw_element; - status = store.CreateResource( - FieldElement(collection->GetID(), field->GetID(), "RAW", FieldElementType::RAW), raw_element); - AddStepWithLsn(*raw_element, c_context_.lsn); + status = store->CreateResource( + FieldElement(collection->GetID(), field->GetID(), DEFAULT_RAW_DATA_NAME, FieldElementType::FET_RAW), + raw_element); + auto fe_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*raw_element, c_context_.lsn, fe_ctx_p); element_ids.insert(raw_element->GetID()); for (auto& element_schema : field_elements) { FieldElementPtr element; status = - store.CreateResource(FieldElement(collection->GetID(), field->GetID(), - element_schema->GetName(), element_schema->GetFtype()), - element); - AddStepWithLsn(*element, c_context_.lsn); + store->CreateResource(FieldElement(collection->GetID(), field->GetID(), + element_schema->GetName(), element_schema->GetFtype()), + element); + auto t_fe_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*element, c_context_.lsn, t_fe_ctx_p); element_ids.insert(element->GetID()); } FieldCommitPtr field_commit; - status = store.CreateResource(FieldCommit(collection->GetID(), field->GetID(), element_ids), - field_commit); - AddStepWithLsn(*field_commit, c_context_.lsn); + status = store->CreateResource(FieldCommit(collection->GetID(), field->GetID(), element_ids), + field_commit); + auto fc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*field_commit, c_context_.lsn, fc_ctx_p); field_commit_ids.insert(field_commit->GetID()); } SchemaCommitPtr schema_commit; - status = store.CreateResource(SchemaCommit(collection->GetID(), field_commit_ids), schema_commit); - AddStepWithLsn(*schema_commit, c_context_.lsn); + status = store->CreateResource(SchemaCommit(collection->GetID(), field_commit_ids), schema_commit); + auto sc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*schema_commit, c_context_.lsn, sc_ctx_p); PartitionPtr partition; - status = store.CreateResource(Partition("_default", collection->GetID()), partition); - AddStepWithLsn(*partition, c_context_.lsn); + status = store->CreateResource(Partition("_default", collection->GetID()), partition); + auto p_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*partition, c_context_.lsn, p_ctx_p); context_.new_partition = partition; PartitionCommitPtr partition_commit; - status = store.CreateResource(PartitionCommit(collection->GetID(), partition->GetID()), - partition_commit); - AddStepWithLsn(*partition_commit, c_context_.lsn); + status = store->CreateResource(PartitionCommit(collection->GetID(), partition->GetID()), + partition_commit); + auto pc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*partition_commit, c_context_.lsn, pc_ctx_p); context_.new_partition_commit = partition_commit; CollectionCommitPtr collection_commit; - status = store.CreateResource( + status = store->CreateResource( CollectionCommit(collection->GetID(), schema_commit->GetID(), {partition_commit->GetID()}), collection_commit); - AddStepWithLsn(*collection_commit, c_context_.lsn); + auto cc_ctx_p = ResourceContextBuilder().SetOp(meta::oUpdate).CreatePtr(); + AddStepWithLsn(*collection_commit, c_context_.lsn, cc_ctx_p); context_.new_collection_commit = collection_commit; c_context_.collection_commit = collection_commit; context_.new_collection_commit = collection_commit; @@ -463,14 +817,17 @@ CreateCollectionOperation::GetSnapshot(ScopedSnapshotT& ss) const { } Status -DropCollectionOperation::DoExecute(Store& store) { +DropCollectionOperation::DoExecute(StorePtr store) { if (!context_.collection) { std::stringstream emsg; emsg << GetRepr() << ". Collection is missing in context"; return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); } context_.collection->Deactivate(); - AddStepWithLsn(*context_.collection, context_.lsn); + auto c_ctx_p = + ResourceContextBuilder().SetResource(context_.collection).SetOp(meta::oUpdate).CreatePtr(); + c_ctx_p->AddAttr(StateField::Name); + AddStepWithLsn(*context_.collection, context_.lsn, c_ctx_p, false); return Status::OK(); } diff --git a/core/src/db/snapshot/CompoundOperations.h b/core/src/db/snapshot/CompoundOperations.h index b2828239b519..bbdff4314b58 100644 --- a/core/src/db/snapshot/CompoundOperations.h +++ b/core/src/db/snapshot/CompoundOperations.h @@ -44,9 +44,11 @@ class CompoundBaseOperation : public Operations { Status PreCheck() override { // TODO - /* if (GetContextLsn() <= GetStartedSS()->GetMaxLsn()) { */ - /* return Status(SS_INVALID_CONTEX_ERROR, "Invalid LSN found in operation"); */ - /* } */ + if (GetContextLsn() == 0) { + SetContextLsn(GetStartedSS()->GetMaxLsn()); + } else if (GetContextLsn() <= GetStartedSS()->GetMaxLsn()) { + return Status(SS_INVALID_CONTEX_ERROR, "Invalid LSN found in operation"); + } return Status::OK(); } @@ -56,22 +58,66 @@ class CompoundBaseOperation : public Operations { } }; -class BuildOperation : public CompoundBaseOperation { +class AddSegmentFileOperation : public CompoundBaseOperation { public: - using BaseT = CompoundBaseOperation; + using BaseT = CompoundBaseOperation; static constexpr const char* Name = "B"; - BuildOperation(const OperationContext& context, ScopedSnapshotT prev_ss); + AddSegmentFileOperation(const OperationContext& context, ScopedSnapshotT prev_ss); - Status - DoExecute(Store&) override; + Status DoExecute(StorePtr) override; Status CommitNewSegmentFile(const SegmentFileContext& context, SegmentFilePtr& created); + Status + CommitRowCountDelta(SIZE_TYPE delta, bool sub = true); + protected: Status CheckSegmentStale(ScopedSnapshotT& latest_snapshot, ID_TYPE segment_id) const; + + SIZE_TYPE delta_ = 0; + bool sub_; +}; + +class AddFieldElementOperation : public CompoundBaseOperation { + public: + using BaseT = CompoundBaseOperation; + static constexpr const char* Name = "AFE"; + + AddFieldElementOperation(const OperationContext& context, ScopedSnapshotT prev_ss); + + Status + PreCheck() override; + + Status DoExecute(StorePtr) override; +}; + +class DropIndexOperation : public CompoundBaseOperation { + public: + using BaseT = CompoundBaseOperation; + static constexpr const char* Name = "DI"; + + DropIndexOperation(const OperationContext& context, ScopedSnapshotT prev_ss); + + Status + PreCheck() override; + + Status DoExecute(StorePtr) override; +}; + +class DropAllIndexOperation : public CompoundBaseOperation { + public: + using BaseT = CompoundBaseOperation; + static constexpr const char* Name = "DAI"; + + DropAllIndexOperation(const OperationContext& context, ScopedSnapshotT prev_ss); + + Status + PreCheck() override; + + Status DoExecute(StorePtr) override; }; class NewSegmentOperation : public CompoundBaseOperation { @@ -81,14 +127,19 @@ class NewSegmentOperation : public CompoundBaseOperation { NewSegmentOperation(const OperationContext& context, ScopedSnapshotT prev_ss); - Status - DoExecute(Store&) override; + Status DoExecute(StorePtr) override; Status CommitNewSegment(SegmentPtr& created); Status CommitNewSegmentFile(const SegmentFileContext& context, SegmentFilePtr& created); + + Status + CommitRowCount(SIZE_TYPE row_cnt); + + protected: + SIZE_TYPE row_cnt_ = 0; }; class MergeOperation : public CompoundBaseOperation { @@ -98,8 +149,7 @@ class MergeOperation : public CompoundBaseOperation { MergeOperation(const OperationContext& context, ScopedSnapshotT prev_ss); - Status - DoExecute(Store&) override; + Status DoExecute(StorePtr) override; Status CommitNewSegment(SegmentPtr&); @@ -117,8 +167,7 @@ class CreateCollectionOperation : public CompoundBaseOperationGetID(); + first = false; + } + ss << "]"; } if (new_segment_files.size() > 0) { ss << ",N_SF=["; diff --git a/core/src/db/snapshot/Context.h b/core/src/db/snapshot/Context.h index e044e3a90f02..611d2e8e1ec7 100644 --- a/core/src/db/snapshot/Context.h +++ b/core/src/db/snapshot/Context.h @@ -54,17 +54,24 @@ struct OperationContext { ScopedSnapshotT prev_ss; SegmentPtr new_segment = nullptr; SegmentCommitPtr new_segment_commit = nullptr; + std::vector new_segment_commits; PartitionPtr new_partition = nullptr; PartitionCommitPtr new_partition_commit = nullptr; + std::vector new_partition_commits; SchemaCommitPtr new_schema_commit = nullptr; CollectionCommitPtr new_collection_commit = nullptr; CollectionPtr new_collection = nullptr; - SegmentFilePtr stale_segment_file = nullptr; + SegmentFile::VecT stale_segment_files; std::vector stale_segments; FieldPtr prev_field = nullptr; FieldElementPtr prev_field_element = nullptr; + std::vector new_field_elements; + std::vector stale_field_elements; + + std::vector new_field_commits; + std::vector stale_field_commits; SegmentPtr prev_segment = nullptr; SegmentCommitPtr prev_segment_commit = nullptr; diff --git a/core/src/db/snapshot/Event.h b/core/src/db/snapshot/Event.h index b7396f75d651..3fb1d6e07e8f 100644 --- a/core/src/db/snapshot/Event.h +++ b/core/src/db/snapshot/Event.h @@ -18,20 +18,20 @@ #include "db/snapshot/Operations.h" #include "db/snapshot/ResourceHelper.h" +#include "db/snapshot/Store.h" #include "utils/Status.h" namespace milvus { namespace engine { namespace snapshot { -class Event { +class MetaEvent { public: - virtual Status - Process() = 0; + virtual Status Process(StorePtr) = 0; }; template -class ResourceGCEvent : public Event { +class ResourceGCEvent : public MetaEvent { public: using Ptr = std::shared_ptr; @@ -41,25 +41,27 @@ class ResourceGCEvent : public Event { ~ResourceGCEvent() = default; Status - Process() override { - auto& store = Store::GetInstance(); - + Process(StorePtr store) override { /* mark resource as 'deleted' in meta */ auto sd_op = std::make_shared>(res_->GetID()); STATUS_CHECK((*sd_op)(store)); /* TODO: physically clean resource */ - std::vector res_file_list; - STATUS_CHECK(GetResFiles(res_file_list, res_)); - for (auto& res_file : res_file_list) { - if (!boost::filesystem::exists(res_file)) { - continue; - } - if (boost::filesystem::is_directory(res_file)) { - boost::filesystem::remove_all(res_file); - } else { - boost::filesystem::remove(res_file); - } + auto res_prefix = store->GetRootPath(); + std::string res_path = GetResPath(res_prefix, res_); + /* if (!boost::filesystem::exists(res_path)) { */ + /* return Status::OK(); */ + /* } */ + if (res_path.empty()) { + /* std::cout << "[GC] No remove action for " << res_->ToString() << std::endl; */ + } else if (boost::filesystem::is_directory(res_path)) { + auto ok = boost::filesystem::remove_all(res_path); + /* std::cout << "[GC] Remove dir " << res_->ToString() << " " << res_path << " " << ok << std::endl; */ + } else if (boost::filesystem::is_regular_file(res_path)) { + auto ok = boost::filesystem::remove(res_path); + /* std::cout << "[GC] Remove file " << res_->ToString() << " " << res_path << " " << ok << std::endl; */ + } else { + std::cout << "[GC] Remove stale " << res_path << " for " << res_->ToString() << std::endl; } /* remove resource from meta */ diff --git a/core/src/db/snapshot/EventExecutor.h b/core/src/db/snapshot/EventExecutor.h index cb93973fb4e1..da7aa26c8bea 100644 --- a/core/src/db/snapshot/EventExecutor.h +++ b/core/src/db/snapshot/EventExecutor.h @@ -22,23 +22,33 @@ namespace milvus { namespace engine { namespace snapshot { -using EventPtr = std::shared_ptr; +using EventPtr = std::shared_ptr; using ThreadPtr = std::shared_ptr; using EventQueue = BlockingQueue; class EventExecutor { public: - EventExecutor() = default; - EventExecutor(const EventExecutor&) = delete; - ~EventExecutor() { Stop(); } + static void + Init(StorePtr store) { + auto& instance = GetInstanceImpl(); + if (instance.initialized_) { + return; + } + instance.store_ = store; + instance.initialized_ = true; + } + static EventExecutor& GetInstance() { - static EventExecutor inst; - return inst; + auto& instance = GetInstanceImpl(); + if (!instance.initialized_) { + throw std::runtime_error("OperationExecutor should be init"); + } + return instance; } Status @@ -68,15 +78,25 @@ class EventExecutor { } private: + static EventExecutor& + GetInstanceImpl() { + static EventExecutor executor; + return executor; + } + void ThreadMain() { + Status status; while (true) { EventPtr evt = queue_.Take(); if (evt == nullptr) { break; } - std::cout << std::this_thread::get_id() << " Dequeue Event " << std::endl; - evt->Process(); + /* std::cout << std::this_thread::get_id() << " Dequeue Event " << std::endl; */ + status = evt->Process(store_); + if (!status.ok()) { + std::cout << "EventExecutor Handle Event Error: " << status.ToString() << std::endl; + } } } @@ -84,13 +104,17 @@ class EventExecutor { Enqueue(const EventPtr& evt) { queue_.Put(evt); if (evt != nullptr) { - std::cout << std::this_thread::get_id() << " Enqueue Event " << std::endl; + /* std::cout << std::this_thread::get_id() << " Enqueue Event " << std::endl; */ } } - private: + EventExecutor() = default; + EventExecutor(const EventExecutor&) = delete; + ThreadPtr thread_ptr_ = nullptr; EventQueue queue_; + std::atomic_bool initialized_ = false; + StorePtr store_; }; } // namespace snapshot diff --git a/core/src/db/snapshot/HandlerFactory.h b/core/src/db/snapshot/HandlerFactory.h new file mode 100644 index 000000000000..c4e40cbecd70 --- /dev/null +++ b/core/src/db/snapshot/HandlerFactory.h @@ -0,0 +1,101 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "db/snapshot/Event.h" + +namespace milvus { +namespace engine { +namespace snapshot { + +class IEventHandler { + public: + using Ptr = std::shared_ptr; + static constexpr const char* EventName = ""; + virtual const char* + GetEventName() const { + return EventName; + } +}; + +class IEventHandlerRegistrar { + public: + using Ptr = std::shared_ptr; + + virtual IEventHandler::Ptr + GetHandler() = 0; +}; + +template +class HandlerFactory { + public: + using ThisT = HandlerFactory; + + static ThisT& + GetInstance() { + static ThisT factory; + return factory; + } + + IEventHandler::Ptr + GetHandler(const std::string& event_name) { + auto it = registry_.find(event_name); + if (it == registry_.end()) { + return nullptr; + } + return it->second->GetHandler(); + } + + void + Register(IEventHandlerRegistrar* registrar, const std::string& event_name) { + auto it = registry_.find(event_name); + if (it == registry_.end()) { + registry_[event_name] = registrar; + } + } + + private: + std::map registry_; +}; + +template +class EventHandlerRegistrar : public IEventHandlerRegistrar { + public: + using FactoryT = HandlerFactory; + using HandlerPtr = typename HandlerT::Ptr; + explicit EventHandlerRegistrar(const std::string& event_name) : event_name_(event_name) { + auto& factory = FactoryT::GetInstance(); + factory.Register(this, event_name_); + } + + HandlerPtr + GetHandler() { + return std::make_shared(); + } + + protected: + std::string event_name_; +}; + +#define REGISTER_HANDLER(EXECUTOR, HANDLER) \ + namespace { \ + static milvus::engine::snapshot::EventHandlerRegistrar EXECUTOR##HANDLER##_registrar( \ + HANDLER ::EventName); \ + } + +} // namespace snapshot +} // namespace engine +} // namespace milvus diff --git a/core/src/db/snapshot/InActiveResourcesGCEvent.h b/core/src/db/snapshot/InActiveResourcesGCEvent.h new file mode 100644 index 000000000000..b9c67840853c --- /dev/null +++ b/core/src/db/snapshot/InActiveResourcesGCEvent.h @@ -0,0 +1,97 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "db/snapshot/Event.h" +#include "db/snapshot/EventExecutor.h" +#include "db/snapshot/Operations.h" +#include "db/snapshot/ResourceHelper.h" +#include "db/snapshot/Store.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { +namespace snapshot { + +class InActiveResourcesGCEvent : public MetaEvent, public Operations { + public: + using Ptr = std::shared_ptr; + + InActiveResourcesGCEvent() : Operations(OperationContext(), ScopedSnapshotT(), OperationsType::O_Leaf) { + } + + ~InActiveResourcesGCEvent() = default; + + Status + Process(StorePtr store) override { + return store->Apply(*this); + } + + Status + OnExecute(StorePtr store) override { + std::cout << "Executing InActiveResourcesGCEvent" << std::endl; + + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + STATUS_CHECK(ClearInActiveResources(store)); + + return Status::OK(); + } + + private: + template + Status + ClearInActiveResources(StorePtr store) { + std::vector resources; + STATUS_CHECK(store->GetInActiveResources(resources)); + + for (auto& res : resources) { + std::string res_path = GetResPath(dir_root_, res); + if (res_path.empty()) { + /* std::cout << "[GC] No remove action for " << res_->ToString() << std::endl; */ + } else if (boost::filesystem::is_directory(res_path)) { + auto ok = boost::filesystem::remove_all(res_path); + /* std::cout << "[GC] Remove dir " << res_->ToString() << " " << res_path << " " << ok << std::endl; */ + } else if (boost::filesystem::is_regular_file(res_path)) { + auto ok = boost::filesystem::remove(res_path); + /* std::cout << "[GC] Remove file " << res_->ToString() << " " << res_path << " " << ok << std::endl; */ + } else { + std::cout << "[GC] Remove stale " << res_path << " for " << res->ToString() << std::endl; + } + + /* remove resource from meta */ + auto hd_op = std::make_shared>(res->GetID()); + STATUS_CHECK((*hd_op)(store)); + } + + return Status::OK(); + } + + private: + std::string dir_root_; +}; + +} // namespace snapshot +} // namespace engine +} // namespace milvus diff --git a/core/src/db/snapshot/IterateHandler.h b/core/src/db/snapshot/IterateHandler.h new file mode 100644 index 000000000000..034b3299cc10 --- /dev/null +++ b/core/src/db/snapshot/IterateHandler.h @@ -0,0 +1,82 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "db/snapshot/Snapshot.h" +#include "utils/Status.h" + +namespace milvus { +namespace engine { +namespace snapshot { + +template +struct IterateHandler : public std::enable_shared_from_this> { + using ResourceT = T; + using ThisT = IterateHandler; + using Ptr = std::shared_ptr; + using ExecutorT = std::function; + + explicit IterateHandler(ScopedSnapshotT ss, const ExecutorT& executor = {}) : ss_(ss), executor_(executor) { + } + + virtual Status + PreIterate() { + return Status::OK(); + } + virtual Status + Handle(const typename ResourceT::Ptr& resource) { + if (executor_) { + return executor_(resource, this); + } + return Status::OK(); + } + + virtual Status + PostIterate() { + return Status::OK(); + } + + void + SetStatus(Status status) { + std::unique_lock lock(mtx_); + status_ = status; + } + Status + GetStatus() const { + std::unique_lock lock(mtx_); + return status_; + } + + virtual void + Iterate() { + ss_->IterateResources(this->shared_from_this()); + } + + ScopedSnapshotT ss_; + ExecutorT executor_; + Status status_; + mutable std::mutex mtx_; +}; + +using CollectionIterator = IterateHandler; +using PartitionIterator = IterateHandler; +using SegmentIterator = IterateHandler; +using SegmentFileIterator = IterateHandler; +using FieldIterator = IterateHandler; +using FieldElementIterator = IterateHandler; + +} // namespace snapshot +} // namespace engine +} // namespace milvus diff --git a/core/src/db/snapshot/OperationExecutor.h b/core/src/db/snapshot/OperationExecutor.h index 319d52ff90ab..8df2fa25b945 100644 --- a/core/src/db/snapshot/OperationExecutor.h +++ b/core/src/db/snapshot/OperationExecutor.h @@ -25,17 +25,27 @@ using OperationQueue = BlockingQueue; class OperationExecutor { public: - OperationExecutor() = default; - OperationExecutor(const OperationExecutor&) = delete; - ~OperationExecutor() { Stop(); } + static void + Init(StorePtr store) { + auto& instance = GetInstanceImpl(); + if (instance.initialized_) { + return; + } + instance.store_ = store; + instance.initialized_ = true; + } + static OperationExecutor& GetInstance() { - static OperationExecutor executor; - return executor; + auto& instance = GetInstanceImpl(); + if (!instance.initialized_) { + throw std::runtime_error("OperationExecutor should be init"); + } + return instance; } Status @@ -43,8 +53,6 @@ class OperationExecutor { if (!operation) { return Status(SS_INVALID_ARGUMENT_ERROR, "Invalid Operation"); } - /* Store::GetInstance().Apply(*operation); */ - /* return true; */ Enqueue(operation); if (sync) { return operation->WaitToFinish(); @@ -70,6 +78,15 @@ class OperationExecutor { } private: + OperationExecutor() = default; + OperationExecutor(const OperationExecutor&) = delete; + + static OperationExecutor& + GetInstanceImpl() { + static OperationExecutor executor; + return executor; + } + void ThreadMain() { while (true) { @@ -77,20 +94,20 @@ class OperationExecutor { if (!operation) { break; } - /* std::cout << std::this_thread::get_id() << " Dequeue Operation " << operation->GetID() << std::endl; */ - Store::GetInstance().Apply(*operation); + store_->Apply(*operation); } } void Enqueue(const OperationsPtr& operation) { - /* std::cout << std::this_thread::get_id() << " Enqueue Operation " << operation->GetID() << std::endl; */ queue_.Put(operation); } private: ThreadPtr thread_ptr_ = nullptr; OperationQueue queue_; + std::atomic_bool initialized_ = false; + StorePtr store_; }; } // namespace milvus::engine::snapshot diff --git a/core/src/db/snapshot/Operations.cpp b/core/src/db/snapshot/Operations.cpp index e47b3c0764fa..ef2767a8e778 100644 --- a/core/src/db/snapshot/Operations.cpp +++ b/core/src/db/snapshot/Operations.cpp @@ -10,8 +10,12 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "db/snapshot/Operations.h" + #include #include + +#include "db/snapshot/Event.h" +#include "db/snapshot/EventExecutor.h" #include "db/snapshot/OperationExecutor.h" #include "db/snapshot/Snapshots.h" @@ -73,7 +77,7 @@ Operations::GetID() const { } Status -Operations::operator()(Store& store) { +Operations::operator()(StorePtr store) { STATUS_CHECK(PreCheck()); return ApplyToStore(store); } @@ -92,7 +96,7 @@ Operations::WaitToFinish() { } void -Operations::Done(Store& store) { +Operations::Done(StorePtr store) { std::unique_lock lock(finish_mtx_); done_ = true; if (GetType() == OperationsType::W_Compound) { @@ -100,6 +104,13 @@ Operations::Done(Store& store) { Snapshots::GetInstance().LoadSnapshot(store, context_.latest_ss, context_.new_collection_commit->GetCollectionId(), ids_.back()); } + /* if (!context_.latest_ss && context_.new_collection_commit) { */ + /* auto& holder = std::get(holders_); */ + /* if (holder.size() > 0) */ + /* Snapshots::GetInstance().LoadSnapshot(store, context_.latest_ss, */ + /* context_.new_collection_commit->GetCollectionId(), holder.rbegin()->GetID()); */ + /* } */ + /* } */ std::cout << ToString() << std::endl; } finish_cond_.notify_all(); @@ -176,7 +187,7 @@ Operations::GetSnapshot(ScopedSnapshotT& ss) const { } const Status& -Operations::ApplyToStore(Store& store) { +Operations::ApplyToStore(StorePtr store) { if (GetType() == OperationsType::W_Compound) { /* std::cout << ToString() << std::endl; */ } @@ -203,7 +214,7 @@ Operations::OnSnapshotStale() { } Status -Operations::OnExecute(Store& store) { +Operations::OnExecute(StorePtr store) { STATUS_CHECK(PreExecute(store)); STATUS_CHECK(DoExecute(store)); STATUS_CHECK(PostExecute(store)); @@ -211,7 +222,7 @@ Operations::OnExecute(Store& store) { } Status -Operations::PreExecute(Store& store) { +Operations::PreExecute(StorePtr store) { if (GetStartedSS() && type_ == OperationsType::W_Compound) { STATUS_CHECK(Snapshots::GetInstance().GetSnapshot(context_.prev_ss, GetStartedSS()->GetCollectionId())); if (!context_.prev_ss) { @@ -224,30 +235,35 @@ Operations::PreExecute(Store& store) { } Status -Operations::DoExecute(Store& store) { +Operations::DoExecute(StorePtr store) { return Status::OK(); } Status -Operations::PostExecute(Store& store) { - return store.DoCommitOperation(*this); +Operations::PostExecute(StorePtr store) { + return store->ApplyOperation(*this); } -Status -Operations::RollBack() { - // TODO: Implement here - // Spwarn a rollback operation or re-use this operation - return Status::OK(); +template +void +ApplyRollBack(std::set>>& step_context_set) { + for (auto& step_context : step_context_set) { + auto res = step_context->Resource(); + auto evt_ptr = std::make_shared>(res); + EventExecutor::GetInstance().Submit(evt_ptr); + std::cout << "Rollback " << typeid(ResourceT).name() << ": " << res->GetID() << std::endl; + } } -Status -Operations::ApplyRollBack(Store& store) { - // TODO: Implement rollback to remove all resources in steps_ - return Status::OK(); +void +Operations::RollBack() { + std::apply([&](auto&... step_context_set) { ((ApplyRollBack(step_context_set)), ...); }, GetStepHolders()); } Operations::~Operations() { - // TODO: Prefer to submit a rollback operation if status is not ok + if (!status_.ok() || !done_) { + RollBack(); + } } } // namespace snapshot diff --git a/core/src/db/snapshot/Operations.h b/core/src/db/snapshot/Operations.h index 18f65a1f7163..3c245e89784f 100644 --- a/core/src/db/snapshot/Operations.h +++ b/core/src/db/snapshot/Operations.h @@ -17,10 +17,14 @@ #include #include #include +#include #include #include +#include #include -#include "Context.h" + +#include "db/snapshot/Context.h" +#include "db/snapshot/ResourceContext.h" #include "db/snapshot/Snapshot.h" #include "db/snapshot/Store.h" #include "utils/Error.h" @@ -30,8 +34,17 @@ namespace milvus { namespace engine { namespace snapshot { -using StepsT = std::vector; using CheckStaleFunc = std::function; +// using StepsHolderT = std::tuple; +template +using StepsContextSet = std::set::Ptr>; +using StepsHolderT = + std::tuple, StepsContextSet, StepsContextSet, + StepsContextSet, StepsContextSet, StepsContextSet, + StepsContextSet, StepsContextSet, StepsContextSet, + StepsContextSet, StepsContextSet>; enum OperationsType { Invalid, W_Leaf, O_Leaf, W_Compound, O_Compound }; @@ -40,6 +53,11 @@ class Operations : public std::enable_shared_from_this { Operations(const OperationContext& context, ScopedSnapshotT prev_ss, const OperationsType& type = OperationsType::Invalid); + const OperationContext& + GetContext() const { + return context_; + } + const ScopedSnapshotT& GetStartedSS() const { return prev_ss_; @@ -55,6 +73,11 @@ class Operations : public std::enable_shared_from_this { return context_.lsn; } + void + SetContextLsn(LSN_TYPE lsn) { + context_.lsn = lsn; + } + virtual Status CheckStale(const CheckStaleFunc& checker = nullptr) const; virtual Status @@ -62,18 +85,24 @@ class Operations : public std::enable_shared_from_this { template void - AddStep(const StepT& step, bool activate = true); + AddStep(const StepT& step, ResourceContextPtr step_context = nullptr, bool activate = true); template void - AddStepWithLsn(const StepT& step, const LSN_TYPE& lsn, bool activate = true); + AddStepWithLsn(const StepT& step, const LSN_TYPE& lsn, ResourceContextPtr step_context = nullptr, + bool activate = true); void SetStepResult(ID_TYPE id) { ids_.push_back(id); } - StepsT& - GetSteps() { - return steps_; + const size_t + GetPos() const { + return last_pos_; + } + + StepsHolderT& + GetStepHolders() { + return holders_; } ID_TYPE @@ -84,34 +113,29 @@ class Operations : public std::enable_shared_from_this { return type_; } - virtual Status - OnExecute(Store&); - virtual Status - PreExecute(Store&); - virtual Status - DoExecute(Store&); - virtual Status - PostExecute(Store&); + virtual Status OnExecute(StorePtr); + virtual Status PreExecute(StorePtr); + virtual Status DoExecute(StorePtr); + virtual Status PostExecute(StorePtr); virtual Status GetSnapshot(ScopedSnapshotT& ss) const; virtual Status - operator()(Store& store); + operator()(StorePtr store); virtual Status Push(bool sync = true); virtual Status PreCheck(); - virtual const Status& - ApplyToStore(Store& store); + virtual const Status& ApplyToStore(StorePtr); const Status& WaitToFinish(); void - Done(Store& store); + Done(StorePtr store); void SetStatus(const Status& status); @@ -132,9 +156,6 @@ class Operations : public std::enable_shared_from_this { virtual std::string ToString() const; - Status - RollBack(); - virtual Status OnSnapshotStale(); virtual Status @@ -158,12 +179,13 @@ class Operations : public std::enable_shared_from_this { Status CheckPrevSnapshot() const; - Status - ApplyRollBack(Store&); + void + RollBack(); OperationContext context_; ScopedSnapshotT prev_ss_; - StepsT steps_; + StepsHolderT holders_; + size_t last_pos_; std::vector ids_; bool done_ = false; Status status_; @@ -175,21 +197,43 @@ class Operations : public std::enable_shared_from_this { template void -Operations::AddStep(const StepT& step, bool activate) { +Operations::AddStep(const StepT& step, ResourceContextPtr step_context, bool activate) { + if (step_context == nullptr) { + step_context = ResourceContextBuilder().SetOp(meta::oAdd).CreatePtr(); + } + auto s = std::make_shared(step); - if (activate) + step_context->AddResource(s); + if (activate) { s->Activate(); - steps_.push_back(s); + step_context->AddAttr(StateField::Name); + } + + last_pos_ = Index, StepsHolderT>::value; + auto& holder = std::get, StepsHolderT>::value>(holders_); + holder.insert(step_context); } template void -Operations::AddStepWithLsn(const StepT& step, const LSN_TYPE& lsn, bool activate) { +Operations::AddStepWithLsn(const StepT& step, const LSN_TYPE& lsn, ResourceContextPtr step_context, + bool activate) { + if (step_context == nullptr) { + step_context = ResourceContextBuilder().SetOp(meta::oAdd).CreatePtr(); + } + auto s = std::make_shared(step); - if (activate) + step_context->AddResource(s); + if (activate) { s->Activate(); + step_context->AddAttr(StateField::Name); + } s->SetLsn(lsn); - steps_.push_back(s); + step_context->AddAttr(LsnField::Name); + + last_pos_ = Index, StepsHolderT>::value; + auto& holder = std::get, StepsHolderT>::value>(holders_); + holder.insert(step_context); } template @@ -239,12 +283,12 @@ class LoadOperation : public Operations { } const Status& - ApplyToStore(Store& store) override { + ApplyToStore(StorePtr store) override { if (done_) { Done(store); return status_; } - auto status = store.GetResource(context_.id, resource_); + auto status = store->GetResource(context_.id, resource_); SetStatus(status); Done(store); return status_; @@ -297,8 +341,8 @@ class SoftDeleteOperation : public Operations { } Status - DoExecute(Store& store) override { - auto status = store.GetResource(id_, resource_); + DoExecute(StorePtr store) override { + auto status = store->GetResource(id_, resource_); if (!status.ok()) { return status; } @@ -308,7 +352,9 @@ class SoftDeleteOperation : public Operations { return Status(SS_NOT_FOUND_ERROR, emsg.str()); } resource_->Deactivate(); - AddStep(*resource_, false); + auto r_ctx_p = ResourceContextBuilder().SetResource(resource_).SetOp(meta::oUpdate).CreatePtr(); + r_ctx_p->AddAttr(StateField::Name); + AddStep(*resource_, r_ctx_p, false); return status; } @@ -325,33 +371,10 @@ class HardDeleteOperation : public Operations { } const Status& - ApplyToStore(Store& store) override { + ApplyToStore(StorePtr store) override { if (done_) return status_; - auto status = store.RemoveResource(id_); - SetStatus(status); - Done(store); - return status_; - } - - protected: - ID_TYPE id_; -}; - -template <> -class HardDeleteOperation : public Operations { - public: - explicit HardDeleteOperation(ID_TYPE id) - : Operations(OperationContext(), ScopedSnapshotT(), OperationsType::W_Leaf), id_(id) { - } - - const Status& - ApplyToStore(Store& store) override { - if (done_) { - Done(store); - return status_; - } - auto status = store.RemoveCollection(id_); + auto status = store->RemoveResource(id_); SetStatus(status); Done(store); return status_; diff --git a/core/src/db/snapshot/ResourceContext.h b/core/src/db/snapshot/ResourceContext.h new file mode 100644 index 000000000000..4de698883842 --- /dev/null +++ b/core/src/db/snapshot/ResourceContext.h @@ -0,0 +1,157 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "db/meta/backend/MetaContext.h" +#include "db/snapshot/BaseResource.h" +#include "db/snapshot/ResourceTypes.h" + +namespace milvus::engine::snapshot { + +using ResourceContextOp = meta::MetaContextOp; + +template +class ResourceContext { + public: + using ResPtr = typename ResourceT::Ptr; + using Ptr = std::shared_ptr>; + + public: + ResourceContext(const std::string& table, ID_TYPE id, ResourceContextOp op, ResPtr res, std::set attrs) + : table_(table), id_(id), resource_(std::move(res)), op_(op), attrs_(std::move(attrs)) { + } + + ~ResourceContext() = default; + + public: + void + AddResource(ResPtr res) { + table_ = ResourceT::Name; + resource_ = std::shared_ptr(std::move(res)); + } + + void + AddAttr(const std::string& attr) { + attrs_.insert(attr); + } + + void + AddAttrs(const std::set& attrs) { + attrs_.insert(attrs.begin(), attrs.end()); + } + + void + UpdateOp(const ResourceContextOp op) { + op_ = op; + } + + ResPtr + Resource() { + return resource_; + } + + ResourceContextOp + Op() { + return op_; + } + + ID_TYPE + ID() const { + return id_; + } + + std::set& + Attrs() { + return attrs_; + } + + std::string + Table() { + return table_; + } + + private: + std::string table_; + ID_TYPE id_; + ResPtr resource_; + ResourceContextOp op_; + std::set attrs_; +}; + +template +class ResourceContextBuilder { + public: + ResourceContextBuilder() : table_(T::Name), op_(meta::oAdd) { + } + + ResourceContextBuilder& + SetResource(typename T::Ptr res) { + table_ = T::Name; + id_ = res->GetID(); + // resource_ = std::shared_ptr(std::move(res)); + resource_ = std::move(res); + return *this; + } + + ResourceContextBuilder& + SetOp(ResourceContextOp op) { + op_ = op; + return *this; + } + + ResourceContextBuilder& + SetID(ID_TYPE id) { + id_ = id; + return *this; + } + + ResourceContextBuilder& + SetTable(const std::string& table) { + table_ = table; + return *this; + } + + ResourceContextBuilder& + AddAttr(const std::string& attr) { + attrs_.insert(attr); + return *this; + } + + ResourceContextBuilder& + AddAttrs(const std::set& attrs) { + attrs_.insert(attrs.begin(), attrs.end()); + return *this; + } + + public: + typename ResourceContext::Ptr + CreatePtr() { + return std::make_shared>(table_, id_, op_, resource_, attrs_); + } + + private: + std::string table_; + typename ResourceContext::ResPtr resource_; + ID_TYPE id_{}; + ResourceContextOp op_; + std::set attrs_; +}; + +template +using ResourceContextPtr = typename ResourceContext::Ptr; + +} // namespace milvus::engine::snapshot diff --git a/core/src/db/snapshot/ResourceHelper.h b/core/src/db/snapshot/ResourceHelper.h index 995d48a222f6..e42aeb6f2208 100644 --- a/core/src/db/snapshot/ResourceHelper.h +++ b/core/src/db/snapshot/ResourceHelper.h @@ -11,64 +11,143 @@ #pragma once +#include #include -#include #include "db/snapshot/Resources.h" #include "utils/Status.h" namespace milvus::engine::snapshot { +static const char* COLLECTION_PREFIX = "C_"; +static const char* PARTITION_PREFIX = "P_"; +static const char* SEGMENT_PREFIX = "S_"; +static const char* SEGMENT_FILE_PREFIX = "F_"; + template -inline Status -GetResFiles(std::vector& file_list, typename ResourceT::Ptr& res_ptr) { - return Status::OK(); +inline std::string +GetResPath(const std::string& root, const typename ResourceT::Ptr& res_ptr) { + return std::string(); } template <> -inline Status -GetResFiles(std::vector& file_list, Collection::Ptr& res_ptr) { +inline std::string +GetResPath(const std::string& root, const Collection::Ptr& res_ptr) { std::stringstream ss; - ss << res_ptr->GetID(); + ss << root << "/"; + ss << COLLECTION_PREFIX << res_ptr->GetID(); - file_list.push_back(ss.str()); - return Status::OK(); + return ss.str(); } template <> -inline Status -GetResFiles(std::vector& file_list, Partition::Ptr& res_ptr) { +inline std::string +GetResPath(const std::string& root, const Partition::Ptr& res_ptr) { std::stringstream ss; - ss << res_ptr->GetCollectionId() << "/"; - ss << res_ptr->GetID(); + ss << root << "/"; + ss << COLLECTION_PREFIX << res_ptr->GetCollectionId() << "/"; + ss << PARTITION_PREFIX << res_ptr->GetID(); - file_list.push_back(ss.str()); - return Status::OK(); + return ss.str(); } template <> -inline Status -GetResFiles(std::vector& file_list, Segment::Ptr& res_ptr) { +inline std::string +GetResPath(const std::string& root, const Segment::Ptr& res_ptr) { std::stringstream ss; - ss << res_ptr->GetCollectionId() << "/"; - ss << res_ptr->GetPartitionId() << "/"; - ss << res_ptr->GetID(); + ss << root << "/"; + ss << COLLECTION_PREFIX << res_ptr->GetCollectionId() << "/"; + ss << PARTITION_PREFIX << res_ptr->GetPartitionId() << "/"; + ss << SEGMENT_PREFIX << res_ptr->GetID(); - file_list.push_back(ss.str()); - return Status::OK(); + return ss.str(); } template <> -inline Status -GetResFiles(std::vector& file_list, SegmentFile::Ptr& res_ptr) { +inline std::string +GetResPath(const std::string& root, const SegmentFile::Ptr& res_ptr) { std::stringstream ss; - ss << res_ptr->GetCollectionId() << "/"; - ss << res_ptr->GetPartitionId() << "/"; - ss << res_ptr->GetSegmentId() << "/"; - ss << res_ptr->GetID(); + ss << root << "/"; + ss << COLLECTION_PREFIX << res_ptr->GetCollectionId() << "/"; + ss << PARTITION_PREFIX << res_ptr->GetPartitionId() << "/"; + ss << SEGMENT_PREFIX << res_ptr->GetSegmentId() << "/"; + ss << SEGMENT_FILE_PREFIX << res_ptr->GetID(); + + return ss.str(); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// Default resource creator +template +inline typename T::Ptr +CreateResPtr() { + return nullptr; +} + +template <> +inline Collection::Ptr +CreateResPtr() { + return std::make_shared(""); +} + +template <> +inline CollectionCommit::Ptr +CreateResPtr() { + return std::make_shared(0, 0); +} + +template <> +inline Partition::Ptr +CreateResPtr() { + return std::make_shared("", 0); +} + +template <> +inline PartitionCommit::Ptr +CreateResPtr() { + return std::make_shared(0, 0); +} + +template <> +inline Segment::Ptr +CreateResPtr() { + return std::make_shared(0, 0); +} + +template <> +inline SegmentCommit::Ptr +CreateResPtr() { + return std::make_shared(0, 0, 0); +} + +template <> +inline SegmentFile::Ptr +CreateResPtr() { + return std::make_shared(0, 0, 0, 0); +} + +template <> +inline SchemaCommit::Ptr +CreateResPtr() { + return std::make_shared(0); +} + +template <> +inline Field::Ptr +CreateResPtr() { + return std::make_shared("", 0, 0); +} - file_list.push_back(ss.str()); - return Status::OK(); +template <> +inline FieldCommit::Ptr +CreateResPtr() { + return std::make_shared(0, 0); +} + +template <> +inline FieldElement::Ptr +CreateResPtr() { + return std::make_shared(0, 0, "", 0); } } // namespace milvus::engine::snapshot diff --git a/core/src/db/snapshot/ResourceHolder.h b/core/src/db/snapshot/ResourceHolder.h index 39660ff6e673..aceb63357161 100644 --- a/core/src/db/snapshot/ResourceHolder.h +++ b/core/src/db/snapshot/ResourceHolder.h @@ -17,6 +17,7 @@ #include #include #include + #include "db/snapshot/Event.h" #include "db/snapshot/EventExecutor.h" #include "db/snapshot/Operations.h" @@ -48,11 +49,12 @@ class ResourceHolder { } ScopedT - GetResource(ID_TYPE id, bool scoped = true) { - // TODO: Temp to use Load here. Will be removed when resource is loaded just post Compound - // Operations. - return Load(Store::GetInstance(), id, scoped); + GetResource(StorePtr store, ID_TYPE id, bool scoped = true) { + return Load(store, id, scoped); + } + ScopedT + GetResource(ID_TYPE id, bool scoped = true) { { std::unique_lock lock(mutex_); auto cit = id_map_.find(id); @@ -75,13 +77,15 @@ class ResourceHolder { return ReleaseNoLock(id); } - // TODO: Resource should be loaded into holder in OperationExecutor thread ScopedT - Load(Store& store, ID_TYPE id, bool scoped = true) { + Load(StorePtr store, ID_TYPE id, bool scoped = true) { { std::unique_lock lock(mutex_); auto cit = id_map_.find(id); if (cit != id_map_.end()) { + if (!cit->second->IsActive()) { + return ScopedT(); + } return ScopedT(cit->second, scoped); } } @@ -92,14 +96,6 @@ class ResourceHolder { return ScopedT(ret, scoped); } - virtual bool - HardDelete(ID_TYPE id) { - auto op = std::make_shared>(id); - // TODO: - (*op)(Store::GetInstance()); - return true; - } - virtual void Reset() { id_map_.clear(); @@ -142,20 +138,22 @@ class ResourceHolder { virtual void OnNoRefCallBack(ResourcePtr resource) { + resource->Deactivate(); + Release(resource->GetID()); auto evt_ptr = std::make_shared>(resource); EventExecutor::GetInstance().Submit(evt_ptr); - Release(resource->GetID()); } virtual ResourcePtr - DoLoad(Store& store, ID_TYPE id) { + DoLoad(StorePtr store, ID_TYPE id) { LoadOperationContext context; context.id = id; auto op = std::make_shared>(context); (*op)(store); typename ResourceT::Ptr c; auto status = op->GetResource(c); - if (status.ok()) { + if (status.ok() && c->IsActive()) { + /* if (status.ok()) { */ Add(c); return c; } diff --git a/core/src/db/snapshot/ResourceOperations.cpp b/core/src/db/snapshot/ResourceOperations.cpp index 3fa41d7299af..f23adc605b86 100644 --- a/core/src/db/snapshot/ResourceOperations.cpp +++ b/core/src/db/snapshot/ResourceOperations.cpp @@ -17,8 +17,10 @@ namespace engine { namespace snapshot { Status -CollectionCommitOperation::DoExecute(Store& store) { +CollectionCommitOperation::DoExecute(StorePtr store) { auto prev_resource = GetPrevResource(); + auto row_cnt = 0; + auto size = 0; if (!prev_resource) { std::stringstream emsg; emsg << GetRepr() << ". Cannot find prev collection commit resource"; @@ -26,19 +28,39 @@ CollectionCommitOperation::DoExecute(Store& store) { } resource_ = std::make_shared(*prev_resource); resource_->ResetStatus(); + row_cnt = resource_->GetRowCount(); + size = resource_->GetSize(); + + auto handle_new_pc = [&](PartitionCommitPtr& pc) { + auto prev_partition_commit = GetStartedSS()->GetPartitionCommitByPartitionId(pc->GetPartitionId()); + if (prev_partition_commit) { + resource_->GetMappings().erase(prev_partition_commit->GetID()); + row_cnt -= prev_partition_commit->GetRowCount(); + size -= prev_partition_commit->GetSize(); + } + resource_->GetMappings().insert(pc->GetID()); + row_cnt += pc->GetRowCount(); + size += pc->GetSize(); + }; + if (context_.stale_partition_commit) { resource_->GetMappings().erase(context_.stale_partition_commit->GetID()); + row_cnt -= context_.stale_partition_commit->GetRowCount(); + size -= context_.stale_partition_commit->GetSize(); } else if (context_.new_partition_commit) { - auto prev_partition_commit = - GetStartedSS()->GetPartitionCommitByPartitionId(context_.new_partition_commit->GetPartitionId()); - if (prev_partition_commit) - resource_->GetMappings().erase(prev_partition_commit->GetID()); - resource_->GetMappings().insert(context_.new_partition_commit->GetID()); - } else if (context_.new_schema_commit) { + handle_new_pc(context_.new_partition_commit); + } else if (context_.new_partition_commits.size() > 0) { + for (auto& pc : context_.new_partition_commits) { + handle_new_pc(pc); + } + } + if (context_.new_schema_commit) { resource_->SetSchemaId(context_.new_schema_commit->GetID()); } resource_->SetID(0); - AddStep(*BaseT::resource_, false); + resource_->SetRowCount(row_cnt); + resource_->SetSize(size); + AddStep(*BaseT::resource_, nullptr, false); return Status::OK(); } @@ -52,12 +74,12 @@ PartitionOperation::PreCheck() { } Status -PartitionOperation::DoExecute(Store& store) { +PartitionOperation::DoExecute(StorePtr store) { auto status = CheckStale(); if (!status.ok()) return status; resource_ = std::make_shared(context_.name, GetStartedSS()->GetCollection()->GetID()); - AddStep(*resource_, false); + AddStep(*resource_, nullptr, false); return status; } @@ -72,23 +94,41 @@ PartitionCommitOperation::PreCheck() { PartitionCommitPtr PartitionCommitOperation::GetPrevResource() const { - auto& segment_commit = context_.new_segment_commit; - if (!segment_commit) - return nullptr; - return GetStartedSS()->GetPartitionCommitByPartitionId(segment_commit->GetPartitionId()); + if (context_.new_segment_commit) { + return GetStartedSS()->GetPartitionCommitByPartitionId(context_.new_segment_commit->GetPartitionId()); + } else if (context_.new_segment_commits.size() > 0) { + return GetStartedSS()->GetPartitionCommitByPartitionId(context_.new_segment_commits[0]->GetPartitionId()); + } + return nullptr; } Status -PartitionCommitOperation::DoExecute(Store& store) { +PartitionCommitOperation::DoExecute(StorePtr store) { auto prev_resource = GetPrevResource(); + auto row_cnt = 0; + auto size = 0; if (prev_resource) { resource_ = std::make_shared(*prev_resource); resource_->SetID(0); resource_->ResetStatus(); - auto prev_segment_commit = - GetStartedSS()->GetSegmentCommitBySegmentId(context_.new_segment_commit->GetSegmentId()); - if (prev_segment_commit) - resource_->GetMappings().erase(prev_segment_commit->GetID()); + row_cnt = resource_->GetRowCount(); + size = resource_->GetSize(); + auto erase_sc = [&](SegmentCommitPtr& sc) { + if (!sc) + return; + auto prev_sc = GetStartedSS()->GetSegmentCommitBySegmentId(sc->GetSegmentId()); + if (prev_sc) { + resource_->GetMappings().erase(prev_sc->GetID()); + row_cnt -= prev_sc->GetRowCount(); + size -= prev_sc->GetSize(); + } + }; + + erase_sc(context_.new_segment_commit); + for (auto& sc : context_.new_segment_commits) { + erase_sc(sc); + } + if (context_.stale_segments.size() > 0) { for (auto& stale_segment : context_.stale_segments) { if (stale_segment->GetPartitionId() != prev_resource->GetPartitionId()) { @@ -99,6 +139,8 @@ PartitionCommitOperation::DoExecute(Store& store) { } auto stale_segment_commit = GetStartedSS()->GetSegmentCommitBySegmentId(stale_segment->GetID()); resource_->GetMappings().erase(stale_segment_commit->GetID()); + row_cnt -= stale_segment_commit->GetRowCount(); + size -= stale_segment_commit->GetSize(); } } } else { @@ -113,23 +155,21 @@ PartitionCommitOperation::DoExecute(Store& store) { if (context_.new_segment_commit) { resource_->GetMappings().insert(context_.new_segment_commit->GetID()); + row_cnt += context_.new_segment_commit->GetRowCount(); + size += context_.new_segment_commit->GetSize(); + } else if (context_.new_segment_commits.size() > 0) { + for (auto& sc : context_.new_segment_commits) { + resource_->GetMappings().insert(sc->GetID()); + row_cnt += sc->GetRowCount(); + size += sc->GetSize(); + } } - AddStep(*resource_, false); + resource_->SetRowCount(row_cnt); + resource_->SetSize(size); + AddStep(*resource_, nullptr, false); return Status::OK(); } -SegmentCommitOperation::SegmentCommitOperation(const OperationContext& context, ScopedSnapshotT prev_ss) - : BaseT(context, prev_ss) { -} - -SegmentCommit::Ptr -SegmentCommitOperation::GetPrevResource() const { - if (context_.new_segment_files.size() > 0) { - return GetStartedSS()->GetSegmentCommitBySegmentId(context_.new_segment_files[0]->GetSegmentId()); - } - return nullptr; -} - SegmentOperation::SegmentOperation(const OperationContext& context, ScopedSnapshotT prev_ss) : BaseT(context, prev_ss) { } @@ -144,7 +184,7 @@ SegmentOperation::PreCheck() { } Status -SegmentOperation::DoExecute(Store& store) { +SegmentOperation::DoExecute(StorePtr store) { if (!context_.prev_partition) { std::stringstream emsg; emsg << GetRepr() << ". prev_partition should be specified in context"; @@ -153,21 +193,38 @@ SegmentOperation::DoExecute(Store& store) { auto prev_num = GetStartedSS()->GetMaxSegmentNumByPartition(context_.prev_partition->GetID()); resource_ = std::make_shared(context_.prev_partition->GetCollectionId(), context_.prev_partition->GetID(), prev_num + 1); - AddStep(*resource_, false); + AddStep(*resource_, nullptr, false); return Status::OK(); } +SegmentCommitOperation::SegmentCommitOperation(const OperationContext& context, ScopedSnapshotT prev_ss) + : BaseT(context, prev_ss) { +} + +SegmentCommit::Ptr +SegmentCommitOperation::GetPrevResource() const { + if (context_.new_segment_files.size() > 0) { + return GetStartedSS()->GetSegmentCommitBySegmentId(context_.new_segment_files[0]->GetSegmentId()); + } else if (context_.stale_segment_files.size() != 0) { + return GetStartedSS()->GetSegmentCommitBySegmentId(context_.stale_segment_files[0]->GetSegmentId()); + } + return nullptr; +} + Status -SegmentCommitOperation::DoExecute(Store& store) { +SegmentCommitOperation::DoExecute(StorePtr store) { auto prev_resource = GetPrevResource(); - + auto size = 0; if (prev_resource) { resource_ = std::make_shared(*prev_resource); resource_->SetID(0); resource_->ResetStatus(); - if (context_.stale_segment_file) { - resource_->GetMappings().erase(context_.stale_segment_file->GetID()); + size = resource_->GetSize(); + for (auto& stale_file : context_.stale_segment_files) { + resource_->GetMappings().erase(stale_file->GetID()); + size -= stale_file->GetSize(); } + } else { resource_ = std::make_shared(GetStartedSS()->GetLatestSchemaCommitId(), context_.new_segment_files[0]->GetPartitionId(), @@ -175,31 +232,117 @@ SegmentCommitOperation::DoExecute(Store& store) { } for (auto& new_segment_file : context_.new_segment_files) { resource_->GetMappings().insert(new_segment_file->GetID()); + size += new_segment_file->GetSize(); } - AddStep(*resource_, false); + resource_->SetSize(size); + AddStep(*resource_, nullptr, false); return Status::OK(); } Status SegmentCommitOperation::PreCheck() { - if (context_.new_segment_files.size() == 0) { + if (context_.stale_segment_files.size() == 0 && context_.new_segment_files.size() == 0) { std::stringstream emsg; emsg << GetRepr() << ". new_segment_files should not be empty in context"; return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); + } else if (context_.stale_segment_files.size() > 0 && context_.new_segment_files.size() > 0) { + std::stringstream emsg; + emsg << GetRepr() << ". new_segment_files should be empty in context"; + return Status(SS_INVALID_CONTEX_ERROR, emsg.str()); } return Status::OK(); } +FieldCommitOperation::FieldCommitOperation(const OperationContext& context, ScopedSnapshotT prev_ss) + : BaseT(context, prev_ss) { +} + +FieldCommit::Ptr +FieldCommitOperation::GetPrevResource() const { + auto get_resource = [&](FieldElementPtr fe) -> FieldCommitPtr { + auto& field_commits = GetStartedSS()->GetResources(); + for (auto& kv : field_commits) { + if (kv.second->GetFieldId() == fe->GetFieldId()) { + return kv.second.Get(); + } + } + return nullptr; + }; + + if (context_.new_field_elements.size() > 0) { + return get_resource(context_.new_field_elements[0]); + } else if (context_.stale_field_elements.size() > 0) { + return get_resource(context_.stale_field_elements[0]); + } + return nullptr; +} + +Status +FieldCommitOperation::DoExecute(StorePtr store) { + auto prev_resource = GetPrevResource(); + + if (prev_resource) { + resource_ = std::make_shared(*prev_resource); + resource_->SetID(0); + resource_->ResetStatus(); + for (auto& fe : context_.stale_field_elements) { + resource_->GetMappings().erase(fe->GetID()); + } + } else { + // TODO + } + + for (auto& fe : context_.new_field_elements) { + resource_->GetMappings().insert(fe->GetID()); + } + + AddStep(*resource_, nullptr, false); + return Status::OK(); +} + +SchemaCommitOperation::SchemaCommitOperation(const OperationContext& context, ScopedSnapshotT prev_ss) + : BaseT(context, prev_ss) { +} + +SchemaCommit::Ptr +SchemaCommitOperation::GetPrevResource() const { + return GetStartedSS()->GetSchemaCommit(); +} + +Status +SchemaCommitOperation::DoExecute(StorePtr store) { + auto prev_resource = GetPrevResource(); + if (!prev_resource) { + return Status(SS_INVALID_CONTEX_ERROR, "Cannot get schema commit"); + } + + resource_ = std::make_shared(*prev_resource); + resource_->SetID(0); + resource_->ResetStatus(); + for (auto& fc : context_.stale_field_commits) { + resource_->GetMappings().erase(fc->GetID()); + } + + for (auto& fc : context_.new_field_commits) { + resource_->GetMappings().insert(fc->GetID()); + } + + AddStep(*resource_, nullptr, false); + return Status::OK(); +} + SegmentFileOperation::SegmentFileOperation(const SegmentFileContext& sc, ScopedSnapshotT prev_ss) : BaseT(OperationContext(), prev_ss), context_(sc) { } Status -SegmentFileOperation::DoExecute(Store& store) { - auto field_element_id = GetStartedSS()->GetFieldElementId(context_.field_name, context_.field_element_name); - resource_ = std::make_shared(context_.collection_id, context_.partition_id, context_.segment_id, - field_element_id); - AddStep(*resource_, false); +SegmentFileOperation::DoExecute(StorePtr store) { + FieldElementPtr fe; + STATUS_CHECK(GetStartedSS()->GetFieldElement(context_.field_name, context_.field_element_name, fe)); + resource_ = + std::make_shared(context_.collection_id, context_.partition_id, context_.segment_id, fe->GetID()); + // auto seg_ctx_p = ResourceContextBuilder().SetResource(resource_).SetOp(oAdd).CreatePtr(); + AddStep(*resource_, nullptr, false); return Status::OK(); } diff --git a/core/src/db/snapshot/ResourceOperations.h b/core/src/db/snapshot/ResourceOperations.h index 924ab66d0e3c..d9496b15df38 100644 --- a/core/src/db/snapshot/ResourceOperations.h +++ b/core/src/db/snapshot/ResourceOperations.h @@ -28,8 +28,7 @@ class CollectionCommitOperation : public CommitOperation { return prev_ss_->GetCollectionCommit(); } - Status - DoExecute(Store&) override; + Status DoExecute(StorePtr) override; }; class PartitionCommitOperation : public CommitOperation { @@ -40,8 +39,7 @@ class PartitionCommitOperation : public CommitOperation { PartitionCommitPtr GetPrevResource() const override; - Status - DoExecute(Store&) override; + Status DoExecute(StorePtr) override; Status PreCheck() override; @@ -52,8 +50,7 @@ class PartitionOperation : public CommitOperation { using BaseT = CommitOperation; PartitionOperation(const PartitionContext& context, ScopedSnapshotT prev_ss); - Status - DoExecute(Store& store) override; + Status DoExecute(StorePtr) override; Status PreCheck() override; @@ -70,8 +67,7 @@ class SegmentCommitOperation : public CommitOperation { SegmentCommit::Ptr GetPrevResource() const override; - Status - DoExecute(Store&) override; + Status DoExecute(StorePtr) override; Status PreCheck() override; @@ -82,8 +78,7 @@ class SegmentOperation : public CommitOperation { using BaseT = CommitOperation; SegmentOperation(const OperationContext& context, ScopedSnapshotT prev_ss); - Status - DoExecute(Store& store) override; + Status DoExecute(StorePtr) override; Status PreCheck() override; @@ -94,13 +89,34 @@ class SegmentFileOperation : public CommitOperation { using BaseT = CommitOperation; SegmentFileOperation(const SegmentFileContext& sc, ScopedSnapshotT prev_ss); - Status - DoExecute(Store& store) override; + Status DoExecute(StorePtr) override; protected: SegmentFileContext context_; }; +class FieldCommitOperation : public CommitOperation { + public: + using BaseT = CommitOperation; + FieldCommitOperation(const OperationContext& context, ScopedSnapshotT prev_ss); + + FieldCommit::Ptr + GetPrevResource() const override; + + Status DoExecute(StorePtr) override; +}; + +class SchemaCommitOperation : public CommitOperation { + public: + using BaseT = CommitOperation; + SchemaCommitOperation(const OperationContext& context, ScopedSnapshotT prev_ss); + + SchemaCommit::Ptr + GetPrevResource() const override; + + Status DoExecute(StorePtr) override; +}; + template <> class LoadOperation : public Operations { public: @@ -109,16 +125,16 @@ class LoadOperation : public Operations { } const Status& - ApplyToStore(Store& store) override { + ApplyToStore(StorePtr store) override { if (done_) { Done(store); return status_; } Status status; if (context_.id == 0 && context_.name != "") { - status = store.GetCollection(context_.name, resource_); + status = store->GetCollection(context_.name, resource_); } else { - status = store.GetResource(context_.id, resource_); + status = store->GetResource(context_.id, resource_); } SetStatus(status); Done(store); diff --git a/core/src/db/snapshot/ResourceTypes.h b/core/src/db/snapshot/ResourceTypes.h index 21b3ce508689..19e2e2528903 100644 --- a/core/src/db/snapshot/ResourceTypes.h +++ b/core/src/db/snapshot/ResourceTypes.h @@ -15,6 +15,8 @@ #include #include +#include "db/meta/MetaTypes.h" + namespace milvus { namespace engine { namespace snapshot { @@ -23,13 +25,9 @@ using ID_TYPE = int64_t; using NUM_TYPE = int64_t; using FTYPE_TYPE = int64_t; using TS_TYPE = int64_t; -using LSN_TYPE = uint64_t; +using LSN_TYPE = int64_t; using SIZE_TYPE = uint64_t; using MappingT = std::set; - -enum FieldType { VECTOR, INT32 }; -enum FieldElementType { RAW, IVFSQ8 }; - using IDS_TYPE = std::vector; enum State { PENDING = 0, ACTIVE = 1, DEACTIVE = 2, INVALID = 999 }; diff --git a/core/src/db/snapshot/Resources.cpp b/core/src/db/snapshot/Resources.cpp index 2fb5e54a0c37..5867b63b0502 100644 --- a/core/src/db/snapshot/Resources.cpp +++ b/core/src/db/snapshot/Resources.cpp @@ -16,7 +16,7 @@ namespace milvus::engine::snapshot { -Collection::Collection(const std::string& name, const std::string& params, ID_TYPE id, LSN_TYPE lsn, State state, +Collection::Collection(const std::string& name, const json& params, ID_TYPE id, LSN_TYPE lsn, State state, TS_TYPE created_on, TS_TYPE updated_on) : NameField(name), ParamsField(params), @@ -27,11 +27,13 @@ Collection::Collection(const std::string& name, const std::string& params, ID_TY UpdatedOnField(updated_on) { } -CollectionCommit::CollectionCommit(ID_TYPE collection_id, ID_TYPE schema_id, const MappingT& mappings, SIZE_TYPE size, - ID_TYPE id, LSN_TYPE lsn, State state, TS_TYPE created_on, TS_TYPE updated_on) +CollectionCommit::CollectionCommit(ID_TYPE collection_id, ID_TYPE schema_id, const MappingT& mappings, + SIZE_TYPE row_cnt, SIZE_TYPE size, ID_TYPE id, LSN_TYPE lsn, State state, + TS_TYPE created_on, TS_TYPE updated_on) : CollectionIdField(collection_id), SchemaIdField(schema_id), MappingsField(mappings), + RowCountField(row_cnt), SizeField(size), IdField(id), LsnField(lsn), @@ -51,11 +53,13 @@ Partition::Partition(const std::string& name, ID_TYPE collection_id, ID_TYPE id, UpdatedOnField(updated_on) { } -PartitionCommit::PartitionCommit(ID_TYPE collection_id, ID_TYPE partition_id, const MappingT& mappings, SIZE_TYPE size, - ID_TYPE id, LSN_TYPE lsn, State state, TS_TYPE created_on, TS_TYPE updated_on) +PartitionCommit::PartitionCommit(ID_TYPE collection_id, ID_TYPE partition_id, const MappingT& mappings, + SIZE_TYPE row_cnt, SIZE_TYPE size, ID_TYPE id, LSN_TYPE lsn, State state, + TS_TYPE created_on, TS_TYPE updated_on) : CollectionIdField(collection_id), PartitionIdField(partition_id), MappingsField(mappings), + RowCountField(row_cnt), SizeField(size), IdField(id), LsnField(lsn), @@ -102,12 +106,13 @@ Segment::ToString() const { } SegmentCommit::SegmentCommit(ID_TYPE schema_id, ID_TYPE partition_id, ID_TYPE segment_id, const MappingT& mappings, - SIZE_TYPE size, ID_TYPE id, LSN_TYPE lsn, State state, TS_TYPE created_on, - TS_TYPE updated_on) + SIZE_TYPE row_cnt, SIZE_TYPE size, ID_TYPE id, LSN_TYPE lsn, State state, + TS_TYPE created_on, TS_TYPE updated_on) : SchemaIdField(schema_id), PartitionIdField(partition_id), SegmentIdField(segment_id), MappingsField(mappings), + RowCountField(row_cnt), SizeField(size), IdField(id), LsnField(lsn), @@ -128,11 +133,13 @@ SegmentCommit::ToString() const { } SegmentFile::SegmentFile(ID_TYPE collection_id, ID_TYPE partition_id, ID_TYPE segment_id, ID_TYPE field_element_id, - SIZE_TYPE size, ID_TYPE id, LSN_TYPE lsn, State state, TS_TYPE created_on, TS_TYPE updated_on) + SIZE_TYPE row_cnt, SIZE_TYPE size, ID_TYPE id, LSN_TYPE lsn, State state, TS_TYPE created_on, + TS_TYPE updated_on) : CollectionIdField(collection_id), PartitionIdField(partition_id), SegmentIdField(segment_id), FieldElementIdField(field_element_id), + RowCountField(row_cnt), SizeField(size), IdField(id), LsnField(lsn), @@ -152,8 +159,8 @@ SchemaCommit::SchemaCommit(ID_TYPE collection_id, const MappingT& mappings, ID_T UpdatedOnField(updated_on) { } -Field::Field(const std::string& name, NUM_TYPE num, FTYPE_TYPE ftype, const std::string& params, ID_TYPE id, - LSN_TYPE lsn, State state, TS_TYPE created_on, TS_TYPE updated_on) +Field::Field(const std::string& name, NUM_TYPE num, FTYPE_TYPE ftype, const json& params, ID_TYPE id, LSN_TYPE lsn, + State state, TS_TYPE created_on, TS_TYPE updated_on) : NameField(name), NumField(num), FtypeField(ftype), @@ -178,7 +185,7 @@ FieldCommit::FieldCommit(ID_TYPE collection_id, ID_TYPE field_id, const MappingT } FieldElement::FieldElement(ID_TYPE collection_id, ID_TYPE field_id, const std::string& name, FTYPE_TYPE ftype, - const std::string& params, ID_TYPE id, LSN_TYPE lsn, State state, TS_TYPE created_on, + const json& params, ID_TYPE id, LSN_TYPE lsn, State state, TS_TYPE created_on, TS_TYPE updated_on) : CollectionIdField(collection_id), FieldIdField(field_id), diff --git a/core/src/db/snapshot/Resources.h b/core/src/db/snapshot/Resources.h index 108a2dac18d1..cb9351248633 100644 --- a/core/src/db/snapshot/Resources.h +++ b/core/src/db/snapshot/Resources.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -29,10 +30,12 @@ using milvus::engine::utils::GetMicroSecTimeStamp; namespace milvus::engine::snapshot { -static constexpr const char* JEmpty = "{}"; +static const json JEmpty = {}; class MappingsField { public: + static constexpr const char* Name = "mappings"; + explicit MappingsField(MappingT mappings = {}) : mappings_(std::move(mappings)) { } @@ -51,6 +54,8 @@ class MappingsField { class StateField { public: + static constexpr const char* Name = "state"; + explicit StateField(State state = PENDING) : state_(state) { } @@ -93,6 +98,8 @@ class StateField { class LsnField { public: + static constexpr const char* Name = "lsn"; + explicit LsnField(LSN_TYPE lsn = 0) : lsn_(lsn) { } @@ -112,6 +119,8 @@ class LsnField { class CreatedOnField { public: + static constexpr const char* Name = "created_on"; + explicit CreatedOnField(TS_TYPE created_on = GetMicroSecTimeStamp()) : created_on_(created_on) { } @@ -120,12 +129,19 @@ class CreatedOnField { return created_on_; } + void + SetCreatedTime(const TS_TYPE& time) { + created_on_ = time; + } + protected: TS_TYPE created_on_; }; class UpdatedOnField { public: + static constexpr const char* Name = "updated_on"; + explicit UpdatedOnField(TS_TYPE updated_on = GetMicroSecTimeStamp()) : updated_on_(updated_on) { } @@ -134,12 +150,19 @@ class UpdatedOnField { return updated_on_; } + void + SetUpdatedTime(const TS_TYPE& time) { + updated_on_ = time; + } + protected: TS_TYPE updated_on_; }; class IdField { public: + static constexpr const char* Name = "id"; + explicit IdField(ID_TYPE id) : id_(id) { } @@ -162,6 +185,8 @@ class IdField { class CollectionIdField { public: + static constexpr const char* Name = "collection_id"; + explicit CollectionIdField(ID_TYPE id) : collection_id_(id) { } @@ -170,12 +195,19 @@ class CollectionIdField { return collection_id_; } + void + SetCollectionId(ID_TYPE id) { + collection_id_ = id; + } + protected: ID_TYPE collection_id_; }; class SchemaIdField { public: + static constexpr const char* Name = "schema_id"; + explicit SchemaIdField(ID_TYPE id) : schema_id_(id) { } @@ -194,6 +226,8 @@ class SchemaIdField { class NumField { public: + static constexpr const char* Name = "num"; + explicit NumField(NUM_TYPE num) : num_(num) { } @@ -212,6 +246,8 @@ class NumField { class FtypeField { public: + static constexpr const char* Name = "ftype"; + explicit FtypeField(FTYPE_TYPE type) : ftype_(type) { } @@ -220,12 +256,19 @@ class FtypeField { return ftype_; } + void + SetFtype(FTYPE_TYPE type) { + ftype_ = type; + } + protected: FTYPE_TYPE ftype_; }; class FieldIdField { public: + static constexpr const char* Name = "field_id"; + explicit FieldIdField(ID_TYPE id) : field_id_(id) { } @@ -234,12 +277,19 @@ class FieldIdField { return field_id_; } + void + SetFieldId(ID_TYPE id) { + field_id_ = id; + } + protected: ID_TYPE field_id_; }; class FieldElementIdField { public: + static constexpr const char* Name = "field_element_id"; + explicit FieldElementIdField(ID_TYPE id) : field_element_id_(id) { } @@ -248,12 +298,19 @@ class FieldElementIdField { return field_element_id_; } + void + SetFieldElementId(ID_TYPE id) { + field_element_id_ = id; + } + protected: ID_TYPE field_element_id_; }; class PartitionIdField { public: + static constexpr const char* Name = "partition_id"; + explicit PartitionIdField(ID_TYPE id) : partition_id_(id) { } @@ -262,12 +319,19 @@ class PartitionIdField { return partition_id_; } + void + SetPartitionId(ID_TYPE id) { + partition_id_ = id; + } + protected: ID_TYPE partition_id_; }; class SegmentIdField { public: + static constexpr const char* Name = "segment_id"; + explicit SegmentIdField(ID_TYPE id) : segment_id_(id) { } @@ -276,12 +340,19 @@ class SegmentIdField { return segment_id_; } + void + SetSegmentId(ID_TYPE id) { + segment_id_ = id; + } + protected: ID_TYPE segment_id_; }; class NameField { public: + static constexpr const char* Name = "name"; + explicit NameField(std::string name) : name_(std::move(name)) { } @@ -290,32 +361,40 @@ class NameField { return name_; } + void + SetName(const std::string& name) { + name_ = name; + } + protected: std::string name_; }; class ParamsField { public: - explicit ParamsField(std::string params) : params_(std::move(params)), json_params_(json::parse(params_)) { + static constexpr const char* Name = "params"; + + explicit ParamsField(const json& params) : params_(params) { } - const std::string& + const json& GetParams() const { return params_; } - const json& - GetParamsJson() const { - return json_params_; + void + SetParams(const json& params) { + params_ = params; } protected: - std::string params_; - json json_params_; + json params_; }; class SizeField { public: + static constexpr const char* Name = "size"; + explicit SizeField(SIZE_TYPE size) : size_(size) { } @@ -324,13 +403,39 @@ class SizeField { return size_; } + void + SetSize(SIZE_TYPE size) { + size_ = size; + } + + protected: + SIZE_TYPE size_; +}; + +class RowCountField { + public: + static constexpr const char* Name = "row_count"; + + explicit RowCountField(SIZE_TYPE size) : size_(size) { + } + + SIZE_TYPE + GetRowCount() const { + return size_; + } + + void + SetRowCount(SIZE_TYPE row_cnt) { + size_ = row_cnt; + } + protected: SIZE_TYPE size_; }; /////////////////////////////////////////////////////////////////////////////// -class Collection : public BaseResource, +class Collection : public BaseResource, public NameField, public ParamsField, public IdField, @@ -341,21 +446,23 @@ class Collection : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "Collection"; - explicit Collection(const std::string& name, const std::string& params = JEmpty, ID_TYPE id = 0, LSN_TYPE lsn = 0, + explicit Collection(const std::string& name, const json& params = JEmpty, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, TS_TYPE created_on = GetMicroSecTimeStamp(), TS_TYPE UpdatedOnField = GetMicroSecTimeStamp()); }; using CollectionPtr = Collection::Ptr; -class CollectionCommit : public BaseResource, +class CollectionCommit : public BaseResource, public CollectionIdField, public SchemaIdField, public MappingsField, + public RowCountField, public SizeField, public IdField, public LsnField, @@ -366,10 +473,11 @@ class CollectionCommit : public BaseResource, static constexpr const char* Name = "CollectionCommit"; using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; - CollectionCommit(ID_TYPE collection_id, ID_TYPE schema_id, const MappingT& mappings = {}, SIZE_TYPE size = 0, - ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, + CollectionCommit(ID_TYPE collection_id, ID_TYPE schema_id, const MappingT& mappings = {}, SIZE_TYPE row_cnt = 0, + SIZE_TYPE size = 0, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, TS_TYPE created_on = GetMicroSecTimeStamp(), TS_TYPE UpdatedOnField = GetMicroSecTimeStamp()); }; @@ -377,7 +485,7 @@ using CollectionCommitPtr = CollectionCommit::Ptr; /////////////////////////////////////////////////////////////////////////////// -class Partition : public BaseResource, +class Partition : public BaseResource, public NameField, public CollectionIdField, public IdField, @@ -388,9 +496,10 @@ class Partition : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; - static constexpr const char* Name = "Partition"; + static constexpr const char* Name = "Partitions"; Partition(const std::string& name, ID_TYPE collection_id, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, TS_TYPE created_on = GetMicroSecTimeStamp(), TS_TYPE UpdatedOnField = GetMicroSecTimeStamp()); @@ -398,10 +507,11 @@ class Partition : public BaseResource, using PartitionPtr = Partition::Ptr; -class PartitionCommit : public BaseResource, +class PartitionCommit : public BaseResource, public CollectionIdField, public PartitionIdField, public MappingsField, + public RowCountField, public SizeField, public IdField, public LsnField, @@ -411,12 +521,13 @@ class PartitionCommit : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "PartitionCommit"; - PartitionCommit(ID_TYPE collection_id, ID_TYPE partition_id, const MappingT& mappings = {}, SIZE_TYPE size = 0, - ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, + PartitionCommit(ID_TYPE collection_id, ID_TYPE partition_id, const MappingT& mappings = {}, SIZE_TYPE row_cnt = 0, + SIZE_TYPE size = 0, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, TS_TYPE created_on = GetMicroSecTimeStamp(), TS_TYPE UpdatedOnField = GetMicroSecTimeStamp()); std::string @@ -427,7 +538,7 @@ using PartitionCommitPtr = PartitionCommit::Ptr; /////////////////////////////////////////////////////////////////////////////// -class Segment : public BaseResource, +class Segment : public BaseResource, public CollectionIdField, public PartitionIdField, public NumField, @@ -439,6 +550,7 @@ class Segment : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "Segment"; @@ -453,11 +565,12 @@ class Segment : public BaseResource, using SegmentPtr = Segment::Ptr; -class SegmentCommit : public BaseResource, +class SegmentCommit : public BaseResource, public SchemaIdField, public PartitionIdField, public SegmentIdField, public MappingsField, + public RowCountField, public SizeField, public IdField, public LsnField, @@ -467,12 +580,13 @@ class SegmentCommit : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "SegmentCommit"; SegmentCommit(ID_TYPE schema_id, ID_TYPE partition_id, ID_TYPE segment_id, const MappingT& mappings = {}, - SIZE_TYPE size = 0, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, + SIZE_TYPE row_cnt = 0, SIZE_TYPE size = 0, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, TS_TYPE created_on = GetMicroSecTimeStamp(), TS_TYPE UpdatedOnField = GetMicroSecTimeStamp()); std::string @@ -483,11 +597,12 @@ using SegmentCommitPtr = SegmentCommit::Ptr; /////////////////////////////////////////////////////////////////////////////// -class SegmentFile : public BaseResource, +class SegmentFile : public BaseResource, public CollectionIdField, public PartitionIdField, public SegmentIdField, public FieldElementIdField, + public RowCountField, public SizeField, public IdField, public LsnField, @@ -497,12 +612,13 @@ class SegmentFile : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "SegmentFile"; SegmentFile(ID_TYPE collection_id, ID_TYPE partition_id, ID_TYPE segment_id, ID_TYPE field_element_id, - SIZE_TYPE size = 0, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, + SIZE_TYPE row_cnt = 0, SIZE_TYPE size = 0, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, TS_TYPE created_on = GetMicroSecTimeStamp(), TS_TYPE UpdatedOnField = GetMicroSecTimeStamp()); }; @@ -510,7 +626,7 @@ using SegmentFilePtr = SegmentFile::Ptr; /////////////////////////////////////////////////////////////////////////////// -class SchemaCommit : public BaseResource, +class SchemaCommit : public BaseResource, public CollectionIdField, public MappingsField, public IdField, @@ -521,6 +637,7 @@ class SchemaCommit : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "SchemaCommit"; @@ -534,7 +651,7 @@ using SchemaCommitPtr = SchemaCommit::Ptr; /////////////////////////////////////////////////////////////////////////////// -class Field : public BaseResource, +class Field : public BaseResource, public NameField, public NumField, public FtypeField, @@ -547,18 +664,19 @@ class Field : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "Field"; - Field(const std::string& name, NUM_TYPE num, FTYPE_TYPE ftype, const std::string& params = JEmpty, ID_TYPE id = 0, + Field(const std::string& name, NUM_TYPE num, FTYPE_TYPE ftype, const json& params = JEmpty, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, TS_TYPE created_on = GetMicroSecTimeStamp(), TS_TYPE UpdatedOnField = GetMicroSecTimeStamp()); }; using FieldPtr = Field::Ptr; -class FieldCommit : public BaseResource, +class FieldCommit : public BaseResource, public CollectionIdField, public FieldIdField, public MappingsField, @@ -570,6 +688,7 @@ class FieldCommit : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "FieldCommit"; @@ -583,7 +702,7 @@ using FieldCommitPtr = FieldCommit::Ptr; /////////////////////////////////////////////////////////////////////////////// -class FieldElement : public BaseResource, +class FieldElement : public BaseResource, public CollectionIdField, public FieldIdField, public NameField, @@ -597,11 +716,12 @@ class FieldElement : public BaseResource, public: using Ptr = std::shared_ptr; using MapT = std::map; + using SetT = std::set; using ScopedMapT = std::map>; using VecT = std::vector; static constexpr const char* Name = "FieldElement"; FieldElement(ID_TYPE collection_id, ID_TYPE field_id, const std::string& name, FTYPE_TYPE ftype, - const std::string& params = JEmpty, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, + const json& params = JEmpty, ID_TYPE id = 0, LSN_TYPE lsn = 0, State status = PENDING, TS_TYPE created_on = GetMicroSecTimeStamp(), TS_TYPE UpdatedOnField = GetMicroSecTimeStamp()); }; diff --git a/core/src/db/snapshot/Snapshot.cpp b/core/src/db/snapshot/Snapshot.cpp index 0ac5ea01d9ec..a4f3b542ad1a 100644 --- a/core/src/db/snapshot/Snapshot.cpp +++ b/core/src/db/snapshot/Snapshot.cpp @@ -29,91 +29,100 @@ Snapshot::UnRefAll() { std::apply([this](auto&... resource) { ((DoUnRef(resource)), ...); }, resources_); } -Snapshot::Snapshot(ID_TYPE id) { - auto collection_commit = CollectionCommitsHolder::GetInstance().GetResource(id, false); - AddResource(collection_commit); - max_lsn_ = collection_commit->GetLsn(); - auto& schema_holder = SchemaCommitsHolder::GetInstance(); - auto current_schema = schema_holder.GetResource(collection_commit->GetSchemaId(), false); - AddResource(current_schema); - current_schema_id_ = current_schema->GetID(); +Snapshot::Snapshot(StorePtr store, ID_TYPE ss_id) { + auto& collection_commits_holder = CollectionCommitsHolder::GetInstance(); + auto& collections_holder = CollectionsHolder::GetInstance(); + auto& schema_commits_holder = SchemaCommitsHolder::GetInstance(); auto& field_commits_holder = FieldCommitsHolder::GetInstance(); auto& fields_holder = FieldsHolder::GetInstance(); auto& field_elements_holder = FieldElementsHolder::GetInstance(); - - auto collection = CollectionsHolder::GetInstance().GetResource(collection_commit->GetCollectionId(), false); - AddResource(collection); - auto& mappings = collection_commit->GetMappings(); auto& partition_commits_holder = PartitionCommitsHolder::GetInstance(); auto& partitions_holder = PartitionsHolder::GetInstance(); - auto& segments_holder = SegmentsHolder::GetInstance(); auto& segment_commits_holder = SegmentCommitsHolder::GetInstance(); + auto& segments_holder = SegmentsHolder::GetInstance(); auto& segment_files_holder = SegmentFilesHolder::GetInstance(); - auto ssid = id; - for (auto& id : mappings) { - auto partition_commit = partition_commits_holder.GetResource(id, false); - auto partition = partitions_holder.GetResource(partition_commit->GetPartitionId(), false); + auto collection_commit = collection_commits_holder.GetResource(store, ss_id, false); + AddResource(collection_commit); + + max_lsn_ = collection_commit->GetLsn(); + auto schema_commit = schema_commits_holder.GetResource(store, collection_commit->GetSchemaId(), false); + AddResource(schema_commit); + + current_schema_id_ = schema_commit->GetID(); + auto collection = collections_holder.GetResource(store, collection_commit->GetCollectionId(), false); + AddResource(collection); + + auto& collection_commit_mappings = collection_commit->GetMappings(); + for (auto p_c_id : collection_commit_mappings) { + auto partition_commit = partition_commits_holder.GetResource(store, p_c_id, false); + auto partition_id = partition_commit->GetPartitionId(); + auto partition = partitions_holder.GetResource(store, partition_id, false); + auto partition_name = partition->GetName(); AddResource(partition_commit); - p_pc_map_[partition_commit->GetPartitionId()] = partition_commit->GetID(); + + p_pc_map_[partition_id] = partition_commit->GetID(); AddResource(partition); - partition_names_map_[partition->GetName()] = partition->GetID(); - p_max_seg_num_[partition->GetID()] = 0; - auto& s_c_mappings = partition_commit->GetMappings(); - /* std::cout << "SS-" << ssid << "PC_MAP=("; */ + partition_names_map_[partition_name] = partition_id; + p_max_seg_num_[partition_id] = 0; + /* std::cout << "SS-" << ss_id << "PC_MAP=("; */ /* for (auto id : s_c_mappings) { */ /* std::cout << id << ","; */ /* } */ /* std::cout << ")" << std::endl; */ - for (auto& s_c_id : s_c_mappings) { - auto segment_commit = segment_commits_holder.GetResource(s_c_id, false); - auto segment = segments_holder.GetResource(segment_commit->GetSegmentId(), false); - auto schema = schema_holder.GetResource(segment_commit->GetSchemaId(), false); - AddResource(schema); + auto& partition_commit_mappings = partition_commit->GetMappings(); + for (auto s_c_id : partition_commit_mappings) { + auto segment_commit = segment_commits_holder.GetResource(store, s_c_id, false); + auto segment_id = segment_commit->GetSegmentId(); + auto segment = segments_holder.GetResource(store, segment_id, false); + auto segment_schema_id = segment_commit->GetSchemaId(); + auto segment_schema = schema_commits_holder.GetResource(store, segment_schema_id, false); + auto segment_partition_id = segment->GetPartitionId(); + AddResource(segment_schema); AddResource(segment_commit); - if (segment->GetNum() > p_max_seg_num_[segment->GetPartitionId()]) { - p_max_seg_num_[segment->GetPartitionId()] = segment->GetNum(); + if (segment->GetNum() > p_max_seg_num_[segment_partition_id]) { + p_max_seg_num_[segment_partition_id] = segment->GetNum(); } AddResource(segment); - seg_segc_map_[segment->GetID()] = segment_commit->GetID(); - auto& s_f_mappings = segment_commit->GetMappings(); - for (auto& s_f_id : s_f_mappings) { - auto segment_file = segment_files_holder.GetResource(s_f_id, false); - auto field_element = field_elements_holder.GetResource(segment_file->GetFieldElementId(), false); + + seg_segc_map_[segment_id] = segment_commit->GetID(); + auto& segment_commit_mappings = segment_commit->GetMappings(); + for (auto s_f_id : segment_commit_mappings) { + auto segment_file = segment_files_holder.GetResource(store, s_f_id, false); + auto segment_file_id = segment_file->GetID(); + auto field_element_id = segment_file->GetFieldElementId(); + auto field_element = field_elements_holder.GetResource(store, field_element_id, false); AddResource(field_element); AddResource(segment_file); - auto entry = element_segfiles_map_.find(segment_file->GetFieldElementId()); - if (entry == element_segfiles_map_.end()) { - element_segfiles_map_[segment_file->GetFieldElementId()] = { - {segment_file->GetSegmentId(), segment_file->GetID()}}; - } else { - entry->second[segment_file->GetSegmentId()] = segment_file->GetID(); - } + element_segfiles_map_[field_element_id][segment_id] = segment_file_id; + seg_segfiles_map_[segment_id].insert(segment_file_id); } } } - for (auto& kv : GetResources()) { - if (kv.first > latest_schema_commit_id_) + auto& schema_commit_mappings = schema_commit->GetMappings(); + auto& schema_commits = GetResources(); + for (auto& kv : schema_commits) { + if (kv.first > latest_schema_commit_id_) { latest_schema_commit_id_ = kv.first; + } auto& schema_commit = kv.second; - auto& s_c_m = current_schema->GetMappings(); - for (auto field_commit_id : s_c_m) { - auto field_commit = field_commits_holder.GetResource(field_commit_id, false); + for (auto field_commit_id : schema_commit_mappings) { + auto field_commit = field_commits_holder.GetResource(store, field_commit_id, false); AddResource(field_commit); - auto field = fields_holder.GetResource(field_commit->GetFieldId(), false); + + auto field_id = field_commit->GetFieldId(); + auto field = fields_holder.GetResource(store, field_id, false); + auto field_name = field->GetName(); AddResource(field); - field_names_map_[field->GetName()] = field->GetID(); - auto& f_c_m = field_commit->GetMappings(); - for (auto field_element_id : f_c_m) { - auto field_element = field_elements_holder.GetResource(field_element_id, false); + + field_names_map_[field_name] = field_id; + auto& field_commit_mappings = field_commit->GetMappings(); + for (auto field_element_id : field_commit_mappings) { + auto field_element = field_elements_holder.GetResource(store, field_element_id, false); AddResource(field_element); - auto entry = field_element_names_map_.find(field->GetName()); - if (entry == field_element_names_map_.end()) { - field_element_names_map_[field->GetName()] = {{field_element->GetName(), field_element->GetID()}}; - } else { - entry->second[field_element->GetName()] = field_element->GetID(); - } + auto field_element_name = field_element->GetName(); + field_element_names_map_[field_name][field_element_name] = field_element_id; } } } @@ -121,6 +130,55 @@ Snapshot::Snapshot(ID_TYPE id) { RefAll(); } +FieldPtr +Snapshot::GetField(const std::string& name) const { + auto it = field_names_map_.find(name); + if (it == field_names_map_.end()) { + return nullptr; + } + + return GetResource(it->second); +} + +Status +Snapshot::GetFieldElement(const std::string& field_name, const std::string& field_element_name, + FieldElementPtr& field_element) const { + field_element = nullptr; + auto itf = field_element_names_map_.find(field_name); + if (itf == field_element_names_map_.end()) { + std::stringstream emsg; + emsg << "Snapshot::GetFieldElement: Specified field \"" << field_name; + emsg << "\" not found"; + return Status(SS_NOT_FOUND_ERROR, emsg.str()); + } + + auto itfe = itf->second.find(field_element_name); + if (itfe == itf->second.end()) { + std::stringstream emsg; + emsg << "Snapshot::GetFieldElement: Specified field element \"" << field_element_name; + emsg << "\" not found"; + return Status(SS_NOT_FOUND_ERROR, emsg.str()); + } + + field_element = GetResource(itfe->second); + return Status::OK(); +} + +SegmentFilePtr +Snapshot::GetSegmentFile(ID_TYPE segment_id, ID_TYPE field_element_id) const { + auto it = element_segfiles_map_.find(field_element_id); + if (it == element_segfiles_map_.end()) { + return nullptr; + } + + auto its = it->second.find(segment_id); + if (its == it->second.end()) { + return nullptr; + } + + return GetResource(its->second); +} + const std::string Snapshot::ToString() const { auto to_matrix_string = [](const MappingT& mappings, int line_length, size_t ident = 0) -> std::string { @@ -159,15 +217,35 @@ Snapshot::ToString() const { ss << "****************************** Snapshot " << GetID() << " ******************************"; ss << "\nCollection: id=" << GetCollectionId() << ",name=\"" << GetName() << "\""; ss << ", CollectionCommit: id=" << GetCollectionCommit()->GetID(); - ss << ",mappings="; + ss << ",size=" << GetCollectionCommit()->GetSize(); + ss << ",rows=" << GetCollectionCommit()->GetRowCount() << ",mappings="; auto& cc_m = GetCollectionCommit()->GetMappings(); ss << to_matrix_string(cc_m, row_element_size, 2); + + auto& schema_m = GetSchemaCommit()->GetMappings(); + ss << "\nSchemaCommit: id=" << GetSchemaCommit()->GetID() << ",mappings="; + ss << to_matrix_string(schema_m, row_element_size, 2); + for (auto& fc_id : schema_m) { + auto fc = GetResource(fc_id); + auto f = GetResource(fc->GetFieldId()); + ss << "\n Field: id=" << f->GetID() << ",name=\"" << f->GetName() << "\""; + ss << ", FieldCommit: id=" << fc->GetID(); + ss << ",mappings="; + auto& fc_m = fc->GetMappings(); + ss << to_matrix_string(fc_m, row_element_size, 2); + for (auto& fe_id : fc_m) { + auto fe = GetResource(fe_id); + ss << "\n\tFieldElement: id=" << fe_id << ",name=" << fe->GetName() << " CID=" << fe->GetCollectionId(); + } + } + for (auto& p_c_id : cc_m) { auto p_c = GetResource(p_c_id); auto p = GetResource(p_c->GetPartitionId()); ss << "\nPartition: id=" << p->GetID() << ",name=\"" << p->GetName() << "\""; ss << ", PartitionCommit: id=" << p_c->GetID(); - ss << ",mappings="; + ss << ",size=" << p_c->GetSize(); + ss << ",rows=" << p_c->GetRowCount() << ",mappings="; auto& pc_m = p_c->GetMappings(); ss << to_matrix_string(pc_m, row_element_size, 2); for (auto& sc_id : pc_m) { @@ -175,12 +253,14 @@ Snapshot::ToString() const { auto se = GetResource(sc->GetSegmentId()); ss << "\n Segment: id=" << se->GetID(); ss << ", SegmentCommit: id=" << sc->GetID(); - ss << ",mappings="; + ss << ",size=" << sc->GetSize(); + ss << ",rows=" << sc->GetRowCount() << ",mappings="; auto& sc_m = sc->GetMappings(); ss << to_matrix_string(sc_m, row_element_size, 2); for (auto& sf_id : sc_m) { auto sf = GetResource(sf_id); ss << "\n\tSegmentFile: id=" << sf_id << ",field_element_id=" << sf->GetFieldElementId(); + ss << ",size=" << sf->GetSize(); } } } diff --git a/core/src/db/snapshot/Snapshot.h b/core/src/db/snapshot/Snapshot.h index 534547ab916d..45250f6ec424 100644 --- a/core/src/db/snapshot/Snapshot.h +++ b/core/src/db/snapshot/Snapshot.h @@ -20,12 +20,14 @@ #include #include #include +#include #include #include #include #include #include #include + #include "db/snapshot/Store.h" #include "db/snapshot/Utils.h" #include "db/snapshot/WrappedTypes.h" @@ -43,14 +45,14 @@ using ScopedResourcesT = class Snapshot : public ReferenceProxy { public: using Ptr = std::shared_ptr; - explicit Snapshot(ID_TYPE id); + Snapshot(StorePtr, ID_TYPE); ID_TYPE GetID() const { return GetCollectionCommit()->GetID(); } - [[nodiscard]] ID_TYPE + ID_TYPE GetCollectionId() const { auto it = GetResources().cbegin(); return it->first; @@ -67,17 +69,17 @@ class Snapshot : public ReferenceProxy { return GetResource(id); } - [[nodiscard]] const std::string& + const std::string& GetName() const { return GetResources().cbegin()->second->GetName(); } - [[nodiscard]] size_t + size_t NumberOfPartitions() const { return GetResources().size(); } - [[nodiscard]] const LSN_TYPE& + const LSN_TYPE& GetMaxLsn() const { return max_lsn_; } @@ -94,7 +96,8 @@ class Snapshot : public ReferenceProxy { Status GetPartitionId(const std::string& name, ID_TYPE& id) const { - auto it = partition_names_map_.find(name); + std::string real_name = name.empty() ? DEFAULT_PARTITON_TAG : name; + auto it = partition_names_map_.find(real_name); if (it == partition_names_map_.end()) { return Status(SS_NOT_FOUND_ERROR, "Specified partition name not found"); } @@ -107,11 +110,27 @@ class Snapshot : public ReferenceProxy { return GetResources().cbegin()->second.Get(); } - [[nodiscard]] ID_TYPE + const std::set& + GetSegmentFileIds(ID_TYPE segment_id) const { + auto it = seg_segfiles_map_.find(segment_id); + if (it == seg_segfiles_map_.end()) { + return empty_set_; + } + return it->second; + } + + SegmentFilePtr + GetSegmentFile(ID_TYPE segment_id, ID_TYPE field_element_id) const; + + ID_TYPE GetLatestSchemaCommitId() const { return latest_schema_commit_id_; } + Status + GetFieldElement(const std::string& field_name, const std::string& field_element_name, + FieldElementPtr& field_element) const; + // PXU TODO: add const. Need to change Scopedxxxx::Get SegmentCommitPtr GetSegmentCommitBySegmentId(ID_TYPE segment_id) const { @@ -163,7 +182,7 @@ class Snapshot : public ReferenceProxy { handler->SetStatus(status); } - [[nodiscard]] std::vector + std::vector GetFieldNames() const { std::vector names; for (auto& kv : field_names_map_) { @@ -172,19 +191,22 @@ class Snapshot : public ReferenceProxy { return std::move(names); } - [[nodiscard]] bool + bool HasField(const std::string& name) const { auto it = field_names_map_.find(name); return it != field_names_map_.end(); } - [[nodiscard]] bool + FieldPtr + GetField(const std::string& name) const; + + bool HasFieldElement(const std::string& field_name, const std::string& field_element_name) const { auto id = GetFieldElementId(field_name, field_element_name); return id > 0; } - [[nodiscard]] ID_TYPE + ID_TYPE GetSegmentFileId(const std::string& field_name, const std::string& field_element_name, ID_TYPE segment_id) const { auto field_element_id = GetFieldElementId(field_name, field_element_name); auto it = element_segfiles_map_.find(field_element_id); @@ -198,17 +220,17 @@ class Snapshot : public ReferenceProxy { return its->second; } - [[nodiscard]] bool + bool HasSegmentFile(const std::string& field_name, const std::string& field_element_name, ID_TYPE segment_id) const { auto id = GetSegmentFileId(field_name, field_element_name, segment_id); return id > 0; } - [[nodiscard]] ID_TYPE + ID_TYPE GetFieldElementId(const std::string& field_name, const std::string& field_element_name) const { auto itf = field_element_names_map_.find(field_name); if (itf == field_element_names_map_.end()) - return false; + return 0; auto itfe = itf->second.find(field_element_name); if (itfe == itf->second.end()) { return 0; @@ -280,7 +302,7 @@ class Snapshot : public ReferenceProxy { } template - [[nodiscard]] const typename ResourceT::ScopedMapT& + const typename ResourceT::ScopedMapT& GetResources() const { return std::get::value>(resources_); } @@ -293,7 +315,6 @@ class Snapshot : public ReferenceProxy { if (it == resources.end()) { return nullptr; } - return it->second.Get(); } @@ -319,57 +340,18 @@ class Snapshot : public ReferenceProxy { std::map partition_names_map_; std::map> field_element_names_map_; std::map> element_segfiles_map_; + std::map> seg_segfiles_map_; std::map seg_segc_map_; std::map p_pc_map_; ID_TYPE latest_schema_commit_id_ = 0; std::map p_max_seg_num_; LSN_TYPE max_lsn_; + std::set empty_set_; }; using GCHandler = std::function; using ScopedSnapshotT = ScopedResource; -template -struct IterateHandler : public std::enable_shared_from_this> { - using ResourceT = T; - using ThisT = IterateHandler; - using Ptr = std::shared_ptr; - - explicit IterateHandler(ScopedSnapshotT ss) : ss_(ss) { - } - - virtual Status - PreIterate() { - return Status::OK(); - } - virtual Status - Handle(const typename ResourceT::Ptr& resource) = 0; - virtual Status - PostIterate() { - return Status::OK(); - } - - void - SetStatus(Status status) { - std::unique_lock lock(mtx_); - status_ = status; - } - Status - GetStatus() const { - std::unique_lock lock(mtx_); - return status_; - } - - virtual void - Iterate() { - ss_->IterateResources(this->shared_from_this()); - } - - ScopedSnapshotT ss_; - Status status_; - mutable std::mutex mtx_; -}; - } // namespace snapshot } // namespace engine } // namespace milvus diff --git a/core/src/db/snapshot/SnapshotHolder.cpp b/core/src/db/snapshot/SnapshotHolder.cpp index 230353e98428..ba07e2cae51d 100644 --- a/core/src/db/snapshot/SnapshotHolder.cpp +++ b/core/src/db/snapshot/SnapshotHolder.cpp @@ -36,14 +36,14 @@ SnapshotHolder::~SnapshotHolder() { } Status -SnapshotHolder::Load(Store& store, ScopedSnapshotT& ss, ID_TYPE id, bool scoped) { +SnapshotHolder::Load(StorePtr store, ScopedSnapshotT& ss, ID_TYPE id, bool scoped) { Status status; if (id > max_id_) { CollectionCommitPtr cc; status = LoadNoLock(id, cc, store); if (!status.ok()) return status; - status = Add(id); + status = Add(store, id); if (!status.ok()) return status; } @@ -121,7 +121,7 @@ SnapshotHolder::IsActive(Snapshot::Ptr& ss) { } Status -SnapshotHolder::Add(ID_TYPE id) { +SnapshotHolder::Add(StorePtr store, ID_TYPE id) { Status status; { std::unique_lock lock(mutex_); @@ -140,7 +140,7 @@ SnapshotHolder::Add(ID_TYPE id) { } Snapshot::Ptr oldest_ss; { - auto ss = std::make_shared(id); + auto ss = std::make_shared(store, id); std::unique_lock lock(mutex_); if (!IsActive(ss)) { @@ -173,7 +173,7 @@ SnapshotHolder::Add(ID_TYPE id) { } Status -SnapshotHolder::LoadNoLock(ID_TYPE collection_commit_id, CollectionCommitPtr& cc, Store& store) { +SnapshotHolder::LoadNoLock(ID_TYPE collection_commit_id, CollectionCommitPtr& cc, StorePtr store) { assert(collection_commit_id > max_id_); LoadOperationContext context; context.id = collection_commit_id; diff --git a/core/src/db/snapshot/SnapshotHolder.h b/core/src/db/snapshot/SnapshotHolder.h index 47dab8bd88d5..696cf8dacb8c 100644 --- a/core/src/db/snapshot/SnapshotHolder.h +++ b/core/src/db/snapshot/SnapshotHolder.h @@ -31,13 +31,12 @@ class SnapshotHolder { GetID() const { return collection_id_; } - Status - Add(ID_TYPE id); + Status Add(StorePtr, ID_TYPE); Status Get(ScopedSnapshotT& ss, ID_TYPE id = 0, bool scoped = true) const; Status - Load(Store& store, ScopedSnapshotT& ss, ID_TYPE id = 0, bool scoped = true); + Load(StorePtr store, ScopedSnapshotT& ss, ID_TYPE id = 0, bool scoped = true); Status SetGCHandler(GCHandler gc_handler) { @@ -54,7 +53,7 @@ class SnapshotHolder { /* Status */ /* LoadNoLock(ID_TYPE collection_commit_id, CollectionCommitPtr& cc); */ Status - LoadNoLock(ID_TYPE collection_commit_id, CollectionCommitPtr& cc, Store& store); + LoadNoLock(ID_TYPE collection_commit_id, CollectionCommitPtr& cc, StorePtr store); void ReadyForRelease(Snapshot::Ptr ss) { diff --git a/core/src/db/snapshot/Snapshots.cpp b/core/src/db/snapshot/Snapshots.cpp index 243e71331c2f..718af6a830f5 100644 --- a/core/src/db/snapshot/Snapshots.cpp +++ b/core/src/db/snapshot/Snapshots.cpp @@ -11,11 +11,17 @@ #include "db/snapshot/Snapshots.h" #include "db/snapshot/CompoundOperations.h" +#include "db/snapshot/EventExecutor.h" +#include "db/snapshot/InActiveResourcesGCEvent.h" namespace milvus { namespace engine { namespace snapshot { +/* Status */ +/* Snapshots::DropAll() { */ +/* } */ + Status Snapshots::DropCollection(ID_TYPE collection_id, const LSN_TYPE& lsn) { ScopedSnapshotT ss; @@ -76,7 +82,7 @@ Snapshots::DropPartition(const ID_TYPE& collection_id, const ID_TYPE& partition_ } Status -Snapshots::LoadSnapshot(Store& store, ScopedSnapshotT& ss, ID_TYPE collection_id, ID_TYPE id, bool scoped) { +Snapshots::LoadSnapshot(StorePtr store, ScopedSnapshotT& ss, ID_TYPE collection_id, ID_TYPE id, bool scoped) { SnapshotHolderPtr holder; auto status = LoadHolder(store, collection_id, holder); if (!status.ok()) @@ -124,7 +130,7 @@ Snapshots::GetCollectionNames(std::vector& names) const { } Status -Snapshots::LoadNoLock(Store& store, ID_TYPE collection_id, SnapshotHolderPtr& holder) { +Snapshots::LoadNoLock(StorePtr store, ID_TYPE collection_id, SnapshotHolderPtr& holder) { auto op = std::make_shared(collection_id, false); /* op->Push(); */ (*op)(store); @@ -137,23 +143,24 @@ Snapshots::LoadNoLock(Store& store, ID_TYPE collection_id, SnapshotHolderPtr& ho holder = std::make_shared(collection_id, std::bind(&Snapshots::SnapshotGCCallback, this, std::placeholders::_1)); for (auto c_c_id : collection_commit_ids) { - holder->Add(c_c_id); + holder->Add(store, c_c_id); } return Status::OK(); } -void -Snapshots::Init() { +Status +Snapshots::Init(StorePtr store) { + auto event = std::make_shared(); + EventExecutor::GetInstance().Submit(event); + STATUS_CHECK(event->WaitToFinish()); auto op = std::make_shared(); - op->Push(); + STATUS_CHECK((*op)(store)); auto& collection_ids = op->GetIDs(); SnapshotHolderPtr holder; - // TODO - for (auto collection_id : collection_ids) { - /* GetHolder(collection_id, holder); */ - auto& store = Store::GetInstance(); - LoadHolder(store, collection_id, holder); + for (auto& collection_id : collection_ids) { + STATUS_CHECK(LoadHolder(store, collection_id, holder)); } + return Status::OK(); } Status @@ -181,7 +188,7 @@ Snapshots::GetHolder(const ID_TYPE& collection_id, SnapshotHolderPtr& holder) co } Status -Snapshots::LoadHolder(Store& store, const ID_TYPE& collection_id, SnapshotHolderPtr& holder) { +Snapshots::LoadHolder(StorePtr store, const ID_TYPE& collection_id, SnapshotHolderPtr& holder) { Status status; { std::shared_lock lock(mutex_); diff --git a/core/src/db/snapshot/Snapshots.h b/core/src/db/snapshot/Snapshots.h index 3c41b6615574..6fcda518d486 100644 --- a/core/src/db/snapshot/Snapshots.h +++ b/core/src/db/snapshot/Snapshots.h @@ -39,14 +39,14 @@ class Snapshots { Status GetHolder(const std::string& name, SnapshotHolderPtr& holder) const; Status - LoadHolder(Store& store, const ID_TYPE& collection_id, SnapshotHolderPtr& holder); + LoadHolder(StorePtr store, const ID_TYPE& collection_id, SnapshotHolderPtr& holder); Status GetSnapshot(ScopedSnapshotT& ss, ID_TYPE collection_id, ID_TYPE id = 0, bool scoped = true) const; Status GetSnapshot(ScopedSnapshotT& ss, const std::string& name, ID_TYPE id = 0, bool scoped = true) const; Status - LoadSnapshot(Store& store, ScopedSnapshotT& ss, ID_TYPE collection_id, ID_TYPE id, bool scoped = true); + LoadSnapshot(StorePtr store, ScopedSnapshotT& ss, ID_TYPE collection_id, ID_TYPE id, bool scoped = true); Status GetCollectionIds(IDS_TYPE& ids) const; @@ -64,20 +64,17 @@ class Snapshots { Status Reset(); - void - Init(); + Status Init(StorePtr); private: void SnapshotGCCallback(Snapshot::Ptr ss_ptr); - Snapshots() { - Init(); - } + Snapshots() = default; Status DoDropCollection(ScopedSnapshotT& ss, const LSN_TYPE& lsn); Status - LoadNoLock(Store& store, ID_TYPE collection_id, SnapshotHolderPtr& holder); + LoadNoLock(StorePtr store, ID_TYPE collection_id, SnapshotHolderPtr& holder); Status GetHolderNoLock(ID_TYPE collection_id, SnapshotHolderPtr& holder) const; diff --git a/core/src/db/snapshot/Store.h b/core/src/db/snapshot/Store.h index d4c33d7aad4f..468ac268999a 100644 --- a/core/src/db/snapshot/Store.h +++ b/core/src/db/snapshot/Store.h @@ -11,9 +11,14 @@ #pragma once +#include "db/Utils.h" +#include "db/meta/MetaAdapter.h" +#include "db/snapshot/ResourceContext.h" #include "db/snapshot/ResourceTypes.h" #include "db/snapshot/Resources.h" #include "db/snapshot/Utils.h" +#include "utils/Exception.h" +#include "utils/Log.h" #include "utils/Status.h" #include @@ -24,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -39,227 +45,198 @@ namespace milvus { namespace engine { namespace snapshot { -class Store { +class Store : public std::enable_shared_from_this { public: - using MockIDST = - std::tuple; - using MockResourcesT = std::tuple; - - static Store& - GetInstance() { - static Store store; - return store; + using Ptr = typename std::shared_ptr; + + explicit Store(meta::MetaAdapterPtr adapter, const std::string& root_path) + : adapter_(adapter), root_path_(root_path) { } - template - bool - DoCommit(ResourceT&&... resources) { - auto t = std::make_tuple(std::forward(resources)...); - auto& t_size = std::tuple_size::value; - if (t_size == 0) { - return false; + static Store::Ptr + Build(const std::string& uri, const std::string& root_path) { + utils::MetaUriInfo uri_info; + LOG_ENGINE_DEBUG_ << "MetaUri: " << uri << std::endl; + auto status = utils::ParseMetaUri(uri, uri_info); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Wrong URI format: URI = " << uri; + throw InvalidArgumentException("Wrong URI format "); } - StartTransaction(); - std::apply([this](auto&&... resource) { ((std::cout << CommitResource(resource) << "\n"), ...); }, t); - FinishTransaction(); - return true; - } - template - Status - DoCommitOperation(OpT& op) { - for (auto& step_v : op.GetSteps()) { - auto id = ProcessOperationStep(step_v); - op.SetStepResult(id); + if (strcasecmp(uri_info.dialect_.c_str(), "mysql") == 0) { + LOG_ENGINE_INFO_ << "Using MySQL"; + DBMetaOptions options; + /* options.backend_uri_ = "mysql://root:12345678@127.0.0.1:3307/milvus"; */ + options.backend_uri_ = uri; + auto engine = std::make_shared(options); + auto adapter = std::make_shared(engine); + return std::make_shared(adapter, root_path); + } else if (strcasecmp(uri_info.dialect_.c_str(), "mock") == 0) { + LOG_ENGINE_INFO_ << "Using Mock. Should only be used in test environment"; + auto engine = std::make_shared(); + auto adapter = std::make_shared(engine); + return std::make_shared(adapter, root_path); + } else { + LOG_ENGINE_ERROR_ << "Invalid dialect in URI: dialect = " << uri_info.dialect_; + throw InvalidArgumentException("URI dialect is not mysql / sqlite / mock"); } - return Status::OK(); } - template - void - Apply(OpT& op) { - op.ApplyToStore(*this); + std::string + GetRootPath() const { + return root_path_ + "/tables"; } - void - StartTransaction() { + template + Status + ApplyOperation(OpT& op) { + auto session = adapter_->CreateSession(); + std::apply( + [&](auto&... step_context_set) { + std::size_t n{0}; + ((ApplyOpStep(op, n++, step_context_set, session)), ...); + }, + op.GetStepHolders()); + + ID_TYPE result_id; + auto status = session->Commit(result_id); + if (status.ok()) { + op.SetStepResult(result_id); + } + + return status; } + template void - FinishTransaction() { + ApplyOpStep(OpT& op, size_t pos, std::set>>& step_context_set, + const meta::SessionPtr& session) { + for (auto& step_context : step_context_set) { + session->Apply(step_context); + } + if (pos == op.GetPos()) { + session->ResultPos(); + } } - template - bool - CommitResource(ResourceT&& resource) { - std::cout << "Commit " << resource.Name << " " << resource.GetID() << std::endl; - auto res = CreateResource::type>(std::move(resource)); - if (!res) - return false; - return true; + template + Status + Apply(OpT& op) { + return op.ApplyToStore(this->shared_from_this()); } template Status GetResource(ID_TYPE id, typename ResourceT::Ptr& return_v) { - std::shared_lock lock(mutex_); - auto& resources = std::get::value>(resources_); - auto it = resources.find(id); - if (it == resources.end()) { - /* std::cout << "Can't find " << ResourceT::Name << " " << id << " in ("; */ - /* for (auto& i : resources) { */ - /* std::cout << i.first << ","; */ - /* } */ - /* std::cout << ")"; */ - return Status(SS_NOT_FOUND_ERROR, "DB resource not found"); + auto status = adapter_->Select(id, return_v); + + if (!status.ok()) { + return status; + } + + if (return_v == nullptr) { + std::string err = "Cannot select resource " + std::string(ResourceT::Name) + + " from DB: No resource which id = " + std::to_string(id); + return Status(SS_NOT_FOUND_ERROR, err); } - auto& c = it->second; - return_v = std::make_shared(*c); - /* std::cout << "<<< [Load] " << ResourceT::Name << " " << id - * << " IsActive=" << return_v->IsActive() << std::endl; */ + return Status::OK(); } Status GetCollection(const std::string& name, CollectionPtr& return_v) { - std::shared_lock lock(mutex_); - auto it = name_ids_.find(name); - if (it == name_ids_.end()) { - return Status(SS_NOT_FOUND_ERROR, "DB resource not found"); + // TODO: Get active collection + std::vector resources; + auto status = adapter_->SelectBy(NameField::Name, {name}, resources); + if (!status.ok()) { + return status; } - auto& id = it->second; - lock.unlock(); - return GetResource(id, return_v); - } - Status - RemoveCollection(ID_TYPE id) { - std::unique_lock lock(mutex_); - auto& resources = std::get(resources_); - auto it = resources.find(id); - if (it == resources.end()) { - return Status(SS_NOT_FOUND_ERROR, "DB resource not found"); + for (auto& res : resources) { + if (res->IsActive()) { + return_v = res; + return Status::OK(); + } } - auto name = it->second->GetName(); - resources.erase(it); - name_ids_.erase(name); - /* std::cout << ">>> [Remove] Collection " << id << std::endl; */ - return Status::OK(); + return Status(SS_NOT_FOUND_ERROR, "DB resource not found"); + } + + template + Status + GetInActiveResources(std::vector& return_vs) { + std::vector filter_states = {State::PENDING, State::DEACTIVE}; + return adapter_->SelectBy(StateField::Name, filter_states, return_vs); } template Status RemoveResource(ID_TYPE id) { - std::unique_lock lock(mutex_); - auto& resources = std::get::value>(resources_); - auto it = resources.find(id); - if (it == resources.end()) { - return Status(SS_NOT_FOUND_ERROR, "DB resource not found"); - } + auto rc_ctx_p = + ResourceContextBuilder().SetTable(ResourceT::Name).SetOp(meta::oDelete).SetID(id).CreatePtr(); - resources.erase(it); - /* std::cout << ">>> [Remove] " << ResourceT::Name << " " << id << std::endl; */ - return Status::OK(); + int64_t result_id; + return adapter_->Apply(rc_ctx_p, result_id); } IDS_TYPE AllActiveCollectionIds(bool reversed = true) const { - std::shared_lock lock(mutex_); IDS_TYPE ids; - auto& resources = std::get(resources_); + IDS_TYPE selected_ids; + adapter_->SelectResourceIDs(selected_ids, "", {""}); + if (!reversed) { - for (auto& kv : resources) { - ids.push_back(kv.first); - } + ids = selected_ids; } else { - for (auto kv = resources.rbegin(); kv != resources.rend(); ++kv) { - ids.push_back(kv->first); + for (auto it = selected_ids.rbegin(); it != selected_ids.rend(); ++it) { + ids.push_back(*it); } } + return ids; } IDS_TYPE AllActiveCollectionCommitIds(ID_TYPE collection_id, bool reversed = true) const { - std::shared_lock lock(mutex_); - IDS_TYPE ids; - auto& resources = std::get(resources_); + IDS_TYPE ids, selected_ids; + adapter_->SelectResourceIDs(selected_ids, meta::F_COLLECTON_ID, {collection_id}); + if (!reversed) { - for (auto& kv : resources) { - if (kv.second->GetCollectionId() == collection_id) { - ids.push_back(kv.first); - } - } + ids = selected_ids; } else { - for (auto kv = resources.rbegin(); kv != resources.rend(); ++kv) { - if (kv->second->GetCollectionId() == collection_id) { - ids.push_back(kv->first); - } + for (auto it = selected_ids.rbegin(); it != selected_ids.rend(); ++it) { + ids.push_back(*it); } } - return ids; - } - - Status - CreateCollection(Collection&& collection, CollectionPtr& return_v) { - std::unique_lock lock(mutex_); - auto& resources = std::get(resources_); - if (!collection.HasAssigned() && (name_ids_.find(collection.GetName()) != name_ids_.end()) && - (resources[name_ids_[collection.GetName()]]->IsActive()) && !collection.IsDeactive()) { - return Status(SS_DUPLICATED_ERROR, "Duplicated"); - } - auto c = std::make_shared(collection); - auto& id = std::get::value>(ids_); - c->SetID(++id); - c->ResetCnt(); - resources[c->GetID()] = c; - name_ids_[c->GetName()] = c->GetID(); - lock.unlock(); - GetResource(c->GetID(), return_v); - return Status::OK(); - } - template - Status - UpdateResource(ResourceT&& resource, typename ResourceT::Ptr& return_v) { - std::unique_lock lock(mutex_); - auto& resources = std::get(resources_); - auto res = std::make_shared(resource); - auto& id = std::get::value>(ids_); - res->ResetCnt(); - resources[res->GetID()] = res; - lock.unlock(); - GetResource(res->GetID(), return_v); - return Status::OK(); + return ids; } template Status CreateResource(ResourceT&& resource, typename ResourceT::Ptr& return_v) { - if (resource.HasAssigned()) { - return UpdateResource(std::move(resource), return_v); + auto res_p = std::make_shared(resource); + auto res_ctx_p = ResourceContextBuilder().SetOp(meta::oAdd).SetResource(res_p).CreatePtr(); + + int64_t result_id; + auto status = adapter_->Apply(res_ctx_p, result_id); + if (!status.ok()) { + return status; } - std::unique_lock lock(mutex_); - auto& resources = std::get(resources_); - auto res = std::make_shared(resource); - auto& id = std::get::value>(ids_); - res->SetID(++id); - res->ResetCnt(); - resources[res->GetID()] = res; - lock.unlock(); - auto status = GetResource(res->GetID(), return_v); - /* std::cout << ">>> [Create] " << ResourceT::Name << " " << id << std::endl; */ + + return_v = std::make_shared(resource); + return_v->SetID(result_id); + return_v->ResetCnt(); + return Status::OK(); } void DoReset() { - ids_ = MockIDST(); - resources_ = MockResourcesT(); - name_ids_.clear(); + auto status = adapter_->TruncateAll(); + if (!status.ok()) { + std::cout << "TruncateAll failed: " << status.ToString() << std::endl; + } } void @@ -269,220 +246,142 @@ class Store { } private: - ID_TYPE - ProcessOperationStep(const std::any& step_v) { - if (const auto it = any_flush_vistors_.find(std::type_index(step_v.type())); it != any_flush_vistors_.cend()) { - return it->second(step_v); - } else { - std::cerr << "Unregisted step type " << std::quoted(step_v.type().name()); - return 0; - } - } - - template - inline std::pair> - to_any_visitor(F const& f) { - return {std::type_index(typeid(T)), [g = f](std::any const& a) -> ID_TYPE { - if constexpr (std::is_void_v) - return g(); - else - return g(std::any_cast(a)); - }}; - } - - template - inline void - register_any_visitor(F const& f) { - /* std::cout << "Register visitor for type " << std::quoted(typeid(T).name()) << '\n'; */ - any_flush_vistors_.insert(to_any_visitor(f)); - } - - Store() { - register_any_visitor([this](auto c) { - CollectionPtr n; - CreateResource(Collection(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = CollectionCommit; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = SchemaCommit; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = FieldCommit; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = Field; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = FieldElement; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = PartitionCommit; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = Partition; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = Segment; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = SegmentCommit; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - register_any_visitor([this](auto c) { - using T = SegmentFile; - using PtrT = typename T::Ptr; - PtrT n; - CreateResource(T(*c), n); - return n->GetID(); - }); - } - void DoMock() { Status status; + ID_TYPE result_id; unsigned int seed = 123; auto random = rand_r(&seed) % 2 + 4; std::vector all_records; + std::unordered_map field_commit_records; + std::unordered_map id_map = { + {Collection::Name, 0}, {Field::Name, 0}, {FieldElement::Name, 0}, {Partition::Name, 0}}; + for (auto i = 1; i <= random; i++) { std::stringstream name; - name << "c_" << std::get::value>(ids_) + 1; + name << "c_" << ++id_map[Collection::Name]; auto tc = Collection(name.str()); tc.Activate(); CollectionPtr c; - CreateCollection(std::move(tc), c); + CreateResource(std::move(tc), c); all_records.push_back(c); MappingT schema_c_m; auto random_fields = rand_r(&seed) % 2 + 1; for (auto fi = 1; fi <= random_fields; ++fi) { std::stringstream fname; - fname << "f_" << fi << "_" << std::get::value>(ids_) + 1; + fname << "f_" << fi << "_" << ++id_map[FieldElement::Name]; + + Field temp_f(fname.str(), fi, FieldType::VECTOR); FieldPtr field; - CreateResource(Field(fname.str(), fi, FieldType::VECTOR), field); + + temp_f.Activate(); + CreateResource(std::move(temp_f), field); all_records.push_back(field); MappingT f_c_m = {}; auto random_elements = rand_r(&seed) % 2 + 2; for (auto fei = 1; fei <= random_elements; ++fei) { std::stringstream fename; - fename << "fe_" << fei << "_"; - fename << std::get::value>(ids_) + 1; + fename << "fe_" << field->GetID() << "_" << ++id_map[FieldElement::Name]; FieldElementPtr element; - CreateResource(FieldElement(c->GetID(), field->GetID(), fename.str(), fei), element); + FieldElement temp_fe(c->GetID(), field->GetID(), fename.str(), fei); + temp_fe.Activate(); + CreateResource(std::move(temp_fe), element); all_records.push_back(element); f_c_m.insert(element->GetID()); } FieldCommitPtr f_c; - CreateResource(FieldCommit(c->GetID(), field->GetID(), f_c_m), f_c); + CreateResource(FieldCommit(c->GetID(), field->GetID(), f_c_m, 0, 0, ACTIVE), f_c); all_records.push_back(f_c); + field_commit_records.insert(std::pair(f_c->GetID(), f_c)); schema_c_m.insert(f_c->GetID()); } SchemaCommitPtr schema; - CreateResource(SchemaCommit(c->GetID(), schema_c_m), schema); + CreateResource(SchemaCommit(c->GetID(), schema_c_m, 0, 0, ACTIVE), schema); all_records.push_back(schema); auto random_partitions = rand_r(&seed) % 2 + 1; MappingT c_c_m; for (auto pi = 1; pi <= random_partitions; ++pi) { std::stringstream pname; - pname << "p_" << i << "_" << std::get::value>(ids_) + 1; + pname << "p_" << i << "_" << ++id_map[Partition::Name]; PartitionPtr p; - CreateResource(Partition(pname.str(), c->GetID()), p); + CreateResource(Partition(pname.str(), c->GetID(), 0, 0, ACTIVE), p); all_records.push_back(p); auto random_segments = rand_r(&seed) % 2 + 1; MappingT p_c_m; for (auto si = 1; si <= random_segments; ++si) { SegmentPtr s; - CreateResource(Segment(c->GetID(), p->GetID(), si), s); + CreateResource(Segment(c->GetID(), p->GetID(), si, 0, 0, ACTIVE), s); all_records.push_back(s); auto& schema_m = schema->GetMappings(); MappingT s_c_m; for (auto field_commit_id : schema_m) { - auto& field_commit = std::get(resources_)[field_commit_id]; + auto& field_commit = field_commit_records.at(field_commit_id); auto& f_c_m = field_commit->GetMappings(); - for (auto field_element_id : f_c_m) { + for (auto& field_element_id : f_c_m) { SegmentFilePtr sf; CreateResource( - SegmentFile(c->GetID(), p->GetID(), s->GetID(), field_commit_id), sf); + SegmentFile(c->GetID(), p->GetID(), s->GetID(), field_element_id, 0, 0, 0, 0, ACTIVE), + sf); all_records.push_back(sf); s_c_m.insert(sf->GetID()); } } SegmentCommitPtr s_c; - CreateResource(SegmentCommit(schema->GetID(), p->GetID(), s->GetID(), s_c_m), s_c); + CreateResource( + SegmentCommit(schema->GetID(), p->GetID(), s->GetID(), s_c_m, 0, 0, 0, 0, ACTIVE), s_c); all_records.push_back(s_c); p_c_m.insert(s_c->GetID()); } PartitionCommitPtr p_c; - CreateResource(PartitionCommit(c->GetID(), p->GetID(), p_c_m), p_c); + CreateResource(PartitionCommit(c->GetID(), p->GetID(), p_c_m, 0, 0, 0, 0, ACTIVE), + p_c); all_records.push_back(p_c); c_c_m.insert(p_c->GetID()); } CollectionCommitPtr c_c; - CreateResource(CollectionCommit(c->GetID(), schema->GetID(), c_c_m), c_c); + CollectionCommit temp_cc(c->GetID(), schema->GetID(), c_c_m); + temp_cc.Activate(); + CreateResource(std::move(temp_cc), c_c); all_records.push_back(c_c); } for (auto& record : all_records) { if (record.type() == typeid(std::shared_ptr)) { const auto& r = std::any_cast>(record); r->Activate(); + auto t_c_p = ResourceContextBuilder() + .SetOp(meta::oUpdate) + .SetResource(r) + .AddAttr(meta::F_STATE) + .CreatePtr(); + + adapter_->Apply(t_c_p, result_id); } else if (record.type() == typeid(std::shared_ptr)) { const auto& r = std::any_cast>(record); r->Activate(); + auto t_cc_p = ResourceContextBuilder() + .SetOp(meta::oUpdate) + .SetResource(r) + .AddAttr(meta::F_STATE) + .CreatePtr(); + adapter_->Apply(t_cc_p, result_id); } } } - MockResourcesT resources_; - MockIDST ids_; - std::map name_ids_; - std::unordered_map> any_flush_vistors_; - mutable std::shared_timed_mutex mutex_; + meta::MetaAdapterPtr adapter_; + std::string root_path_; }; +using StorePtr = Store::Ptr; + } // namespace snapshot } // namespace engine } // namespace milvus diff --git a/core/src/db/wal/WalDefinations.h b/core/src/db/wal/WalDefinations.h index 385c83023b85..8e38282f6e66 100644 --- a/core/src/db/wal/WalDefinations.h +++ b/core/src/db/wal/WalDefinations.h @@ -18,6 +18,7 @@ #include "db/Types.h" #include "db/meta/MetaTypes.h" +#include "segment/Segment.h" namespace milvus { namespace engine { @@ -41,10 +42,14 @@ struct MXLogRecord { const IDNumber* ids; uint32_t data_size; const void* data; - std::vector field_names; - std::unordered_map attr_nbytes; - std::unordered_map attr_data_size; - std::unordered_map> attr_data; + std::vector field_names; // will be removed + // std::vector attrs_size; + // std::vector attrs_data; + std::unordered_map attr_nbytes; // will be removed + std::unordered_map attr_data_size; // will be removed + std::unordered_map> attr_data; // will be removed + + engine::DataChunkPtr data_chunk; // for hybird data transfer }; struct MXLogConfiguration { diff --git a/core/src/db/wal/WalManager.cpp b/core/src/db/wal/WalManager.cpp index f6ce17e54dc0..42a8f335ff2b 100644 --- a/core/src/db/wal/WalManager.cpp +++ b/core/src/db/wal/WalManager.cpp @@ -18,6 +18,7 @@ #include #include "config/Config.h" +#include "db/snapshot/Snapshots.h" #include "utils/CommonUtil.h" #include "utils/Exception.h" #include "utils/Log.h" @@ -140,6 +141,101 @@ WalManager::Init(const meta::MetaPtr& meta) { return error_code; } +ErrorCode +WalManager::Init() { + uint64_t applied_lsn = 0; + p_meta_handler_ = std::make_shared(mxlog_config_.mxlog_path); + if (p_meta_handler_ != nullptr) { + p_meta_handler_->GetMXLogInternalMeta(applied_lsn); + } + + uint64_t recovery_start = 0; + std::vector collection_names; + auto status = snapshot::Snapshots::GetInstance().GetCollectionNames(collection_names); + if (!status.ok()) { + return WAL_META_ERROR; + } + + if (!collection_names.empty()) { + u_int64_t min_flushed_lsn = ~(u_int64_t)0; + u_int64_t max_flushed_lsn = 0; + auto update_limit_lsn = [&](u_int64_t lsn) { + if (min_flushed_lsn > lsn) { + min_flushed_lsn = lsn; + } + if (max_flushed_lsn < lsn) { + max_flushed_lsn = lsn; + } + }; + + for (auto& col_name : collection_names) { + auto& collection = collections_[col_name]; + auto& default_part = collection[""]; + + snapshot::ScopedSnapshotT ss; + status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, col_name); + if (!status.ok()) { + return WAL_META_ERROR; + } + + default_part.flush_lsn = ss->GetMaxLsn(); + update_limit_lsn(default_part.flush_lsn); + + std::vector partition_names = ss->GetPartitionNames(); + for (auto& part_name : partition_names) { + auto& partition = collection[part_name]; + + snapshot::PartitionPtr ss_part = ss->GetPartition(part_name); + if (ss_part == nullptr) { + return WAL_META_ERROR; + } + + partition.flush_lsn = ss_part->GetLsn(); + update_limit_lsn(partition.flush_lsn); + } + } + + if (applied_lsn < max_flushed_lsn) { + // a new WAL folder? + applied_lsn = max_flushed_lsn; + } + if (recovery_start < min_flushed_lsn) { + // not flush all yet + recovery_start = min_flushed_lsn; + } + + for (auto& col : collections_) { + for (auto& part : col.second) { + part.second.wal_lsn = applied_lsn; + } + } + } + + // all tables are droped and a new wal path? + if (applied_lsn < recovery_start) { + applied_lsn = recovery_start; + } + + ErrorCode error_code = WAL_ERROR; + p_buffer_ = std::make_shared(mxlog_config_.mxlog_path, mxlog_config_.buffer_size); + if (p_buffer_ != nullptr) { + if (p_buffer_->Init(recovery_start, applied_lsn)) { + error_code = WAL_SUCCESS; + } else if (mxlog_config_.recovery_error_ignore) { + p_buffer_->Reset(applied_lsn); + error_code = WAL_SUCCESS; + } else { + error_code = WAL_FILE_ERROR; + } + } + + // buffer size may changed + mxlog_config_.buffer_size = p_buffer_->GetBufferSize(); + + last_applied_lsn_ = applied_lsn; + return error_code; +} + ErrorCode WalManager::GetNextRecovery(MXLogRecord& record) { ErrorCode error_code = WAL_SUCCESS; diff --git a/core/src/db/wal/WalManager.h b/core/src/db/wal/WalManager.h index 04fd84957597..ab4aabebdcb9 100644 --- a/core/src/db/wal/WalManager.h +++ b/core/src/db/wal/WalManager.h @@ -41,6 +41,9 @@ class WalManager { ErrorCode Init(const meta::MetaPtr& meta); + ErrorCode + Init(); + /* * Get next recovery * @param record[out]: record diff --git a/core/src/grpc/gen-milvus/milvus.pb.cc b/core/src/grpc/gen-milvus/milvus.pb.cc index 0e0b1bd36d43..866e9fb02f94 100644 --- a/core/src/grpc/gen-milvus/milvus.pb.cc +++ b/core/src/grpc/gen-milvus/milvus.pb.cc @@ -1341,62 +1341,62 @@ const char descriptor_table_protodef_milvus_2eproto[] PROTOBUF_SECTION_VARIABLE( "ollection_name\030\001 \001(\t\022\033\n\023partition_tag_ar" "ray\030\002 \003(\t\0220\n\rgeneral_query\030\003 \001(\0132\031.milvu" "s.grpc.GeneralQuery\022/\n\014extra_params\030\004 \003(" - "\0132\031.milvus.grpc.KeyValuePair*\237\001\n\010DataTyp" - "e\022\010\n\004NULL\020\000\022\010\n\004INT8\020\001\022\t\n\005INT16\020\002\022\t\n\005INT3" - "2\020\003\022\t\n\005INT64\020\004\022\n\n\006STRING\020\024\022\010\n\004BOOL\020\036\022\t\n\005" - "FLOAT\020(\022\n\n\006DOUBLE\020)\022\020\n\014FLOAT_VECTOR\020d\022\021\n" - "\rBINARY_VECTOR\020e\022\014\n\007UNKNOWN\020\217N*C\n\017Compar" - "eOperator\022\006\n\002LT\020\000\022\007\n\003LTE\020\001\022\006\n\002EQ\020\002\022\006\n\002GT" - "\020\003\022\007\n\003GTE\020\004\022\006\n\002NE\020\005*8\n\005Occur\022\013\n\007INVALID\020" - "\000\022\010\n\004MUST\020\001\022\n\n\006SHOULD\020\002\022\014\n\010MUST_NOT\020\0032\360\016" - "\n\rMilvusService\022\?\n\020CreateCollection\022\024.mi" - "lvus.grpc.Mapping\032\023.milvus.grpc.Status\"\000" - "\022F\n\rHasCollection\022\033.milvus.grpc.Collecti" - "onName\032\026.milvus.grpc.BoolReply\"\000\022I\n\022Desc" - "ribeCollection\022\033.milvus.grpc.CollectionN" - "ame\032\024.milvus.grpc.Mapping\"\000\022Q\n\017CountColl" - "ection\022\033.milvus.grpc.CollectionName\032\037.mi" - "lvus.grpc.CollectionRowCount\"\000\022J\n\017ShowCo" - "llections\022\024.milvus.grpc.Command\032\037.milvus" - ".grpc.CollectionNameList\"\000\022P\n\022ShowCollec" - "tionInfo\022\033.milvus.grpc.CollectionName\032\033." - "milvus.grpc.CollectionInfo\"\000\022D\n\016DropColl" - "ection\022\033.milvus.grpc.CollectionName\032\023.mi" - "lvus.grpc.Status\"\000\022=\n\013CreateIndex\022\027.milv" - "us.grpc.IndexParam\032\023.milvus.grpc.Status\"" - "\000\022G\n\rDescribeIndex\022\033.milvus.grpc.Collect" - "ionName\032\027.milvus.grpc.IndexParam\"\000\022;\n\tDr" - "opIndex\022\027.milvus.grpc.IndexParam\032\023.milvu" - "s.grpc.Status\"\000\022E\n\017CreatePartition\022\033.mil" - "vus.grpc.PartitionParam\032\023.milvus.grpc.St" - "atus\"\000\022E\n\014HasPartition\022\033.milvus.grpc.Par" - "titionParam\032\026.milvus.grpc.BoolReply\"\000\022K\n" - "\016ShowPartitions\022\033.milvus.grpc.Collection" - "Name\032\032.milvus.grpc.PartitionList\"\000\022C\n\rDr" - "opPartition\022\033.milvus.grpc.PartitionParam" - "\032\023.milvus.grpc.Status\"\000\022<\n\006Insert\022\030.milv" - "us.grpc.InsertParam\032\026.milvus.grpc.Entity" - "Ids\"\000\022E\n\rGetEntityByID\022\033.milvus.grpc.Ent" - "ityIdentity\032\025.milvus.grpc.Entities\"\000\022H\n\014" - "GetEntityIDs\022\036.milvus.grpc.GetEntityIDsP" - "aram\032\026.milvus.grpc.EntityIds\"\000\022>\n\006Search" - "\022\030.milvus.grpc.SearchParam\032\030.milvus.grpc" - ".QueryResult\"\000\022F\n\nSearchByID\022\034.milvus.gr" - "pc.SearchByIDParam\032\030.milvus.grpc.QueryRe" - "sult\"\000\022L\n\rSearchInFiles\022\037.milvus.grpc.Se" - "archInFilesParam\032\030.milvus.grpc.QueryResu" - "lt\"\000\0227\n\003Cmd\022\024.milvus.grpc.Command\032\030.milv" - "us.grpc.StringReply\"\000\022A\n\nDeleteByID\022\034.mi" - "lvus.grpc.DeleteByIDParam\032\023.milvus.grpc." - "Status\"\000\022G\n\021PreloadCollection\022\033.milvus.g" - "rpc.CollectionName\032\023.milvus.grpc.Status\"" - "\000\022I\n\016ReloadSegments\022 .milvus.grpc.ReLoad" - "SegmentsParam\032\023.milvus.grpc.Status\"\000\0227\n\005" - "Flush\022\027.milvus.grpc.FlushParam\032\023.milvus." - "grpc.Status\"\000\022=\n\007Compact\022\033.milvus.grpc.C" - "ollectionName\032\023.milvus.grpc.Status\"\000\022B\n\010" - "SearchPB\022\032.milvus.grpc.SearchParamPB\032\030.m" - "ilvus.grpc.QueryResult\"\000b\006proto3" + "\0132\031.milvus.grpc.KeyValuePair*\236\001\n\010DataTyp" + "e\022\010\n\004NONE\020\000\022\010\n\004BOOL\020\001\022\010\n\004INT8\020\002\022\t\n\005INT16" + "\020\003\022\t\n\005INT32\020\004\022\t\n\005INT64\020\005\022\t\n\005FLOAT\020\n\022\n\n\006D" + "OUBLE\020\013\022\n\n\006STRING\020\024\022\021\n\rVECTOR_BINARY\020d\022\020" + "\n\014VECTOR_FLOAT\020e\022\013\n\006VECTOR\020\310\001*C\n\017Compare" + "Operator\022\006\n\002LT\020\000\022\007\n\003LTE\020\001\022\006\n\002EQ\020\002\022\006\n\002GT\020" + "\003\022\007\n\003GTE\020\004\022\006\n\002NE\020\005*8\n\005Occur\022\013\n\007INVALID\020\000" + "\022\010\n\004MUST\020\001\022\n\n\006SHOULD\020\002\022\014\n\010MUST_NOT\020\0032\360\016\n" + "\rMilvusService\022\?\n\020CreateCollection\022\024.mil" + "vus.grpc.Mapping\032\023.milvus.grpc.Status\"\000\022" + "F\n\rHasCollection\022\033.milvus.grpc.Collectio" + "nName\032\026.milvus.grpc.BoolReply\"\000\022I\n\022Descr" + "ibeCollection\022\033.milvus.grpc.CollectionNa" + "me\032\024.milvus.grpc.Mapping\"\000\022Q\n\017CountColle" + "ction\022\033.milvus.grpc.CollectionName\032\037.mil" + "vus.grpc.CollectionRowCount\"\000\022J\n\017ShowCol" + "lections\022\024.milvus.grpc.Command\032\037.milvus." + "grpc.CollectionNameList\"\000\022P\n\022ShowCollect" + "ionInfo\022\033.milvus.grpc.CollectionName\032\033.m" + "ilvus.grpc.CollectionInfo\"\000\022D\n\016DropColle" + "ction\022\033.milvus.grpc.CollectionName\032\023.mil" + "vus.grpc.Status\"\000\022=\n\013CreateIndex\022\027.milvu" + "s.grpc.IndexParam\032\023.milvus.grpc.Status\"\000" + "\022G\n\rDescribeIndex\022\033.milvus.grpc.Collecti" + "onName\032\027.milvus.grpc.IndexParam\"\000\022;\n\tDro" + "pIndex\022\027.milvus.grpc.IndexParam\032\023.milvus" + ".grpc.Status\"\000\022E\n\017CreatePartition\022\033.milv" + "us.grpc.PartitionParam\032\023.milvus.grpc.Sta" + "tus\"\000\022E\n\014HasPartition\022\033.milvus.grpc.Part" + "itionParam\032\026.milvus.grpc.BoolReply\"\000\022K\n\016" + "ShowPartitions\022\033.milvus.grpc.CollectionN" + "ame\032\032.milvus.grpc.PartitionList\"\000\022C\n\rDro" + "pPartition\022\033.milvus.grpc.PartitionParam\032" + "\023.milvus.grpc.Status\"\000\022<\n\006Insert\022\030.milvu" + "s.grpc.InsertParam\032\026.milvus.grpc.EntityI" + "ds\"\000\022E\n\rGetEntityByID\022\033.milvus.grpc.Enti" + "tyIdentity\032\025.milvus.grpc.Entities\"\000\022H\n\014G" + "etEntityIDs\022\036.milvus.grpc.GetEntityIDsPa" + "ram\032\026.milvus.grpc.EntityIds\"\000\022>\n\006Search\022" + "\030.milvus.grpc.SearchParam\032\030.milvus.grpc." + "QueryResult\"\000\022F\n\nSearchByID\022\034.milvus.grp" + "c.SearchByIDParam\032\030.milvus.grpc.QueryRes" + "ult\"\000\022L\n\rSearchInFiles\022\037.milvus.grpc.Sea" + "rchInFilesParam\032\030.milvus.grpc.QueryResul" + "t\"\000\0227\n\003Cmd\022\024.milvus.grpc.Command\032\030.milvu" + "s.grpc.StringReply\"\000\022A\n\nDeleteByID\022\034.mil" + "vus.grpc.DeleteByIDParam\032\023.milvus.grpc.S" + "tatus\"\000\022G\n\021PreloadCollection\022\033.milvus.gr" + "pc.CollectionName\032\023.milvus.grpc.Status\"\000" + "\022I\n\016ReloadSegments\022 .milvus.grpc.ReLoadS" + "egmentsParam\032\023.milvus.grpc.Status\"\000\0227\n\005F" + "lush\022\027.milvus.grpc.FlushParam\032\023.milvus.g" + "rpc.Status\"\000\022=\n\007Compact\022\033.milvus.grpc.Co" + "llectionName\032\023.milvus.grpc.Status\"\000\022B\n\010S" + "earchPB\022\032.milvus.grpc.SearchParamPB\032\030.mi" + "lvus.grpc.QueryResult\"\000b\006proto3" ; static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_milvus_2eproto_deps[1] = { &::descriptor_table_status_2eproto, @@ -1446,7 +1446,7 @@ static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_mil static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_milvus_2eproto_once; static bool descriptor_table_milvus_2eproto_initialized = false; const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_milvus_2eproto = { - &descriptor_table_milvus_2eproto_initialized, descriptor_table_protodef_milvus_2eproto, "milvus.proto", 6552, + &descriptor_table_milvus_2eproto_initialized, descriptor_table_protodef_milvus_2eproto, "milvus.proto", 6551, &descriptor_table_milvus_2eproto_once, descriptor_table_milvus_2eproto_sccs, descriptor_table_milvus_2eproto_deps, 40, 1, schemas, file_default_instances, TableStruct_milvus_2eproto::offsets, file_level_metadata_milvus_2eproto, 41, file_level_enum_descriptors_milvus_2eproto, file_level_service_descriptors_milvus_2eproto, @@ -1467,13 +1467,13 @@ bool DataType_IsValid(int value) { case 2: case 3: case 4: + case 5: + case 10: + case 11: case 20: - case 30: - case 40: - case 41: case 100: case 101: - case 9999: + case 200: return true; default: return false; diff --git a/core/src/grpc/gen-milvus/milvus.pb.h b/core/src/grpc/gen-milvus/milvus.pb.h index 8a884cea9869..d5cb33dab4ba 100644 --- a/core/src/grpc/gen-milvus/milvus.pb.h +++ b/core/src/grpc/gen-milvus/milvus.pb.h @@ -230,24 +230,24 @@ namespace milvus { namespace grpc { enum DataType : int { - NULL_ = 0, - INT8 = 1, - INT16 = 2, - INT32 = 3, - INT64 = 4, + NONE = 0, + BOOL = 1, + INT8 = 2, + INT16 = 3, + INT32 = 4, + INT64 = 5, + FLOAT = 10, + DOUBLE = 11, STRING = 20, - BOOL = 30, - FLOAT = 40, - DOUBLE = 41, - FLOAT_VECTOR = 100, - BINARY_VECTOR = 101, - UNKNOWN = 9999, + VECTOR_BINARY = 100, + VECTOR_FLOAT = 101, + VECTOR = 200, DataType_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), DataType_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() }; bool DataType_IsValid(int value); -constexpr DataType DataType_MIN = NULL_; -constexpr DataType DataType_MAX = UNKNOWN; +constexpr DataType DataType_MIN = NONE; +constexpr DataType DataType_MAX = VECTOR; constexpr int DataType_ARRAYSIZE = DataType_MAX + 1; const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* DataType_descriptor(); diff --git a/core/src/grpc/milvus.proto b/core/src/grpc/milvus.proto index ba8830bc09fe..e567b3ea7b05 100644 --- a/core/src/grpc/milvus.proto +++ b/core/src/grpc/milvus.proto @@ -8,22 +8,21 @@ package milvus.grpc; * @brief field data type */ enum DataType { - NULL = 0; - INT8 = 1; - INT16 = 2; - INT32 = 3; - INT64 = 4; + NONE = 0; + BOOL = 1; + INT8 = 2; + INT16 = 3; + INT32 = 4; + INT64 = 5; - STRING = 20; - - BOOL = 30; + FLOAT = 10; + DOUBLE = 11; - FLOAT = 40; - DOUBLE = 41; + STRING = 20; - FLOAT_VECTOR = 100; - BINARY_VECTOR = 101; - UNKNOWN = 9999; + VECTOR_BINARY = 100; + VECTOR_FLOAT = 101; + VECTOR = 200; } /** diff --git a/core/src/index/cmake/DefineOptionsCore.cmake b/core/src/index/cmake/DefineOptionsCore.cmake index ab4559fdd009..e5bebf9022c4 100644 --- a/core/src/index/cmake/DefineOptionsCore.cmake +++ b/core/src/index/cmake/DefineOptionsCore.cmake @@ -84,6 +84,8 @@ define_option(KNOWHERE_WITH_FAISS_GPU_VERSION "Build with FAISS GPU version" ON) define_option(FAISS_WITH_MKL "Build FAISS with MKL" OFF) +define_option(MILVUS_CUDA_ARCH "Build with CUDA arch" "DEFAULT") + #---------------------------------------------------------------------- set_option_category("Test and benchmark") diff --git a/core/src/index/cmake/ThirdPartyPackagesCore.cmake b/core/src/index/cmake/ThirdPartyPackagesCore.cmake index 35b9f49b019d..73dd466e8c9d 100644 --- a/core/src/index/cmake/ThirdPartyPackagesCore.cmake +++ b/core/src/index/cmake/ThirdPartyPackagesCore.cmake @@ -535,16 +535,26 @@ macro(build_faiss) endif () if (MILVUS_GPU_VERSION) - set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} + if (MILVUS_CUDA_ARCH STREQUAL "DEFAULT") + set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} "--with-cuda=${CUDA_TOOLKIT_ROOT_DIR}" "--with-cuda-arch=-gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75" ) + else() + STRING(REPLACE ";" " " MILVUS_CUDA_ARCH "${MILVUS_CUDA_ARCH}") + set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} + "--with-cuda=${CUDA_TOOLKIT_ROOT_DIR}" + "--with-cuda-arch=${MILVUS_CUDA_ARCH}" + ) + endif () else () set(FAISS_CONFIGURE_ARGS ${FAISS_CONFIGURE_ARGS} "CPPFLAGS=-DUSE_CPU" --without-cuda) endif () + message(STATUS "Building FAISS with configure args -${FAISS_CONFIGURE_ARGS}") + if (DEFINED ENV{FAISS_SOURCE_URL}) set(FAISS_SOURCE_URL "$ENV{FAISS_SOURCE_URL}") externalproject_add(faiss_ep diff --git a/core/src/index/knowhere/CMakeLists.txt b/core/src/index/knowhere/CMakeLists.txt index 282dd076a06c..91bf9dad26b2 100644 --- a/core/src/index/knowhere/CMakeLists.txt +++ b/core/src/index/knowhere/CMakeLists.txt @@ -42,7 +42,7 @@ set(external_srcs knowhere/common/Timer.cpp ) -set(index_srcs +set(vector_index_srcs knowhere/index/vector_index/adapter/VectorAdapter.cpp knowhere/index/vector_index/helpers/FaissIO.cpp knowhere/index/vector_index/helpers/IndexParameter.cpp @@ -56,23 +56,30 @@ set(index_srcs knowhere/index/vector_index/FaissBaseIndex.cpp knowhere/index/vector_index/IndexBinaryIDMAP.cpp knowhere/index/vector_index/IndexBinaryIVF.cpp - knowhere/index/vector_index/IndexHNSW.cpp knowhere/index/vector_index/IndexIDMAP.cpp knowhere/index/vector_index/IndexIVF.cpp knowhere/index/vector_index/IndexIVFPQ.cpp knowhere/index/vector_index/IndexIVFSQ.cpp - knowhere/index/vector_index/IndexNSG.cpp - knowhere/index/vector_index/IndexType.cpp + knowhere/index/IndexType.cpp knowhere/index/vector_index/VecIndexFactory.cpp knowhere/index/vector_index/IndexAnnoy.cpp ) +set(vector_offset_index_srcs + knowhere/index/vector_offset_index/OffsetBaseIndex.cpp + knowhere/index/vector_offset_index/IndexIVF_NM.cpp + knowhere/index/vector_offset_index/IndexIVFSQNR_NM.cpp + knowhere/index/vector_offset_index/IndexHNSW_NM.cpp + knowhere/index/vector_offset_index/IndexNSG_NM.cpp + knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.cpp + ) + if (MILVUS_SUPPORT_SPTAG) - set(index_srcs + set(vector_index_srcs knowhere/index/vector_index/adapter/SptagAdapter.cpp knowhere/index/vector_index/helpers/SPTAGParameterMgr.cpp knowhere/index/vector_index/IndexSPTAG.cpp - ${index_srcs} + ${vector_index_srcs} ) endif () @@ -117,7 +124,7 @@ if (MILVUS_GPU_VERSION) ${cuda_lib} ) - set(index_srcs ${index_srcs} + set(vector_index_srcs ${vector_index_srcs} knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp knowhere/index/vector_index/gpu/IndexGPUIVF.cpp knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp @@ -126,13 +133,19 @@ if (MILVUS_GPU_VERSION) knowhere/index/vector_index/helpers/Cloner.cpp knowhere/index/vector_index/helpers/FaissGpuResourceMgr.cpp ) + + set(vector_offset_index_srcs ${vector_offset_index_srcs} + knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp + knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.cpp + ) endif () if (NOT TARGET knowhere) add_library( knowhere STATIC ${external_srcs} - ${index_srcs} + ${vector_index_srcs} + ${vector_offset_index_srcs} ) endif () diff --git a/core/src/index/knowhere/knowhere/common/BinarySet.h b/core/src/index/knowhere/knowhere/common/BinarySet.h index ca9df76ab622..90930b1d4955 100644 --- a/core/src/index/knowhere/knowhere/common/BinarySet.h +++ b/core/src/index/knowhere/knowhere/common/BinarySet.h @@ -63,6 +63,17 @@ class BinarySet { // binary_map_[name] = binary; //} + BinaryPtr + Erase(const std::string& name) { + BinaryPtr result = nullptr; + auto it = binary_map_.find(name); + if (it != binary_map_.end()) { + result = it->second; + binary_map_.erase(it); + } + return result; + } + void clear() { binary_map_.clear(); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexType.cpp b/core/src/index/knowhere/knowhere/index/IndexType.cpp similarity index 90% rename from core/src/index/knowhere/knowhere/index/vector_index/IndexType.cpp rename to core/src/index/knowhere/knowhere/index/IndexType.cpp index 25d3184d9879..92506a6c5457 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexType.cpp +++ b/core/src/index/knowhere/knowhere/index/IndexType.cpp @@ -9,9 +9,10 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License -#include "knowhere/index/vector_index/IndexType.h" #include + #include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" namespace milvus { namespace knowhere { @@ -37,6 +38,8 @@ static std::unordered_map old_index_type_str_map = { #endif {(int32_t)OldIndexType::HNSW, IndexEnum::INDEX_HNSW}, {(int32_t)OldIndexType::ANNOY, IndexEnum::INDEX_ANNOY}, + {(int32_t)OldIndexType::HNSW_SQ8NM, IndexEnum::INDEX_HNSW_SQ8NM}, + {(int32_t)OldIndexType::FAISS_IVFSQ8NR, IndexEnum::INDEX_FAISS_IVFSQ8NR}, {(int32_t)OldIndexType::FAISS_BIN_IDMAP, IndexEnum::INDEX_FAISS_BIN_IDMAP}, {(int32_t)OldIndexType::FAISS_BIN_IVFLAT_CPU, IndexEnum::INDEX_FAISS_BIN_IVFFLAT}, }; @@ -55,6 +58,8 @@ static std::unordered_map str_old_index_type_map = { #endif {IndexEnum::INDEX_HNSW, (int32_t)OldIndexType::HNSW}, {IndexEnum::INDEX_ANNOY, (int32_t)OldIndexType::ANNOY}, + {IndexEnum::INDEX_FAISS_IVFSQ8NR, (int32_t)OldIndexType::FAISS_IVFSQ8NR}, + {IndexEnum::INDEX_HNSW_SQ8NM, (int32_t)OldIndexType::HNSW_SQ8NM}, {IndexEnum::INDEX_FAISS_BIN_IDMAP, (int32_t)OldIndexType::FAISS_BIN_IDMAP}, {IndexEnum::INDEX_FAISS_BIN_IVFFLAT, (int32_t)OldIndexType::FAISS_BIN_IVFLAT_CPU}, }; @@ -62,10 +67,11 @@ static std::unordered_map str_old_index_type_map = { /* used in 0.8.0 */ namespace IndexEnum { const char* INVALID = ""; -const char* INDEX_FAISS_IDMAP = "IDMAP"; +const char* INDEX_FAISS_IDMAP = "FLAT"; const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT"; const char* INDEX_FAISS_IVFPQ = "IVF_PQ"; const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8"; +const char* INDEX_FAISS_IVFSQ8NR = "IVF_SQ8NR"; const char* INDEX_FAISS_IVFSQ8H = "IVF_SQ8_HYBRID"; const char* INDEX_FAISS_BIN_IDMAP = "BIN_IDMAP"; const char* INDEX_FAISS_BIN_IVFFLAT = "BIN_IVF_FLAT"; @@ -76,6 +82,7 @@ const char* INDEX_SPTAG_BKT_RNT = "SPTAG_BKT_RNT"; #endif const char* INDEX_HNSW = "HNSW"; const char* INDEX_ANNOY = "ANNOY"; +const char* INDEX_HNSW_SQ8NM = "HNSW_SQ8NM"; } // namespace IndexEnum std::string diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexType.h b/core/src/index/knowhere/knowhere/index/IndexType.h similarity index 94% rename from core/src/index/knowhere/knowhere/index/vector_index/IndexType.h rename to core/src/index/knowhere/knowhere/index/IndexType.h index 0e8c8b6eeebb..5e1184a1ed33 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexType.h +++ b/core/src/index/knowhere/knowhere/index/IndexType.h @@ -35,6 +35,8 @@ enum class OldIndexType { SPTAG_BKT_RNT_CPU, HNSW, ANNOY, + FAISS_IVFSQ8NR, + HNSW_SQ8NM, FAISS_BIN_IDMAP = 100, FAISS_BIN_IVFLAT_CPU = 101, }; @@ -48,6 +50,7 @@ extern const char* INDEX_FAISS_IDMAP; extern const char* INDEX_FAISS_IVFFLAT; extern const char* INDEX_FAISS_IVFPQ; extern const char* INDEX_FAISS_IVFSQ8; +extern const char* INDEX_FAISS_IVFSQ8NR; extern const char* INDEX_FAISS_IVFSQ8H; extern const char* INDEX_FAISS_BIN_IDMAP; extern const char* INDEX_FAISS_BIN_IVFFLAT; @@ -58,6 +61,7 @@ extern const char* INDEX_SPTAG_BKT_RNT; #endif extern const char* INDEX_HNSW; extern const char* INDEX_ANNOY; +extern const char* INDEX_HNSW_SQ8NM; } // namespace IndexEnum enum class IndexMode { MODE_CPU = 0, MODE_GPU = 1 }; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp index 77480989482e..8fd99bda43c7 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.cpp @@ -10,23 +10,20 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "knowhere/index/vector_index/ConfAdapter.h" - #include +#include #include #include #include - #include "knowhere/index/vector_index/helpers/IndexParameter.h" +#ifdef MILVUS_GPU_VERSION +#include "faiss/gpu/utils/DeviceUtils.h" +#endif + namespace milvus { namespace knowhere { -#if CUDA_VERSION > 9000 -#define GPU_MAX_NRPOBE 2048 -#else -#define GPU_MAX_NRPOBE 1024 -#endif - #define DEFAULT_MAX_DIM 32768 #define DEFAULT_MIN_DIM 1 #define DEFAULT_MAX_K 16384 @@ -116,11 +113,12 @@ IVFConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMod static int64_t MAX_NPROBE = 999999; // todo(linxj): [1, nlist] if (mode == IndexMode::MODE_GPU) { - CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, GPU_MAX_NRPOBE); +#ifdef MILVUS_GPU_VERSION + CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, faiss::gpu::getMaxKSelection()); +#endif } else { CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, MAX_NPROBE); } - CheckIntByRange(knowhere::IndexParams::nprobe, MIN_NPROBE, MAX_NPROBE); return ConfAdapter::CheckSearch(oricfg, type, mode); } @@ -133,6 +131,14 @@ IVFSQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { return IVFConfAdapter::CheckTrain(oricfg, mode); } +bool +IVFSQ8NRConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t DEFAULT_NBITS = 8; + oricfg[knowhere::IndexParams::nbits] = DEFAULT_NBITS; + + return IVFConfAdapter::CheckTrain(oricfg, mode); +} + bool IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { static int64_t DEFAULT_NBITS = 8; @@ -251,6 +257,29 @@ HNSWConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMo return ConfAdapter::CheckSearch(oricfg, type, mode); } +bool +HNSWSQ8NRConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { + static int64_t MIN_EFCONSTRUCTION = 8; + static int64_t MAX_EFCONSTRUCTION = 512; + static int64_t MIN_M = 4; + static int64_t MAX_M = 64; + + CheckIntByRange(knowhere::meta::ROWS, DEFAULT_MIN_ROWS, DEFAULT_MAX_ROWS); + CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); + CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); + + return ConfAdapter::CheckTrain(oricfg, mode); +} + +bool +HNSWSQ8NRConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { + static int64_t MAX_EF = 4096; + + CheckIntByRange(knowhere::IndexParams::ef, oricfg[knowhere::meta::TOPK], MAX_EF); + + return ConfAdapter::CheckSearch(oricfg, type, mode); +} + bool BinIDMAPConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { static std::vector METRICS{knowhere::Metric::HAMMING, knowhere::Metric::JACCARD, @@ -299,9 +328,8 @@ ANNOYConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { bool ANNOYConfAdapter::CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) { - static int64_t MIN_SEARCH_K = 1; - static int64_t MAX_SEARCH_K = 999999; - CheckIntByRange(knowhere::IndexParams::search_k, MIN_SEARCH_K, MAX_SEARCH_K); + CheckIntByRange(knowhere::IndexParams::search_k, std::numeric_limits::min(), + std::numeric_limits::max()); return ConfAdapter::CheckSearch(oricfg, type, mode); } diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h index d4ecd6053782..7ac1454292f8 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapter.h @@ -16,7 +16,7 @@ #include #include "knowhere/common/Config.h" -#include "knowhere/index/vector_index/IndexType.h" +#include "knowhere/index/IndexType.h" namespace milvus { namespace knowhere { @@ -94,5 +94,20 @@ class ANNOYConfAdapter : public ConfAdapter { CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; }; +class HNSWSQ8NRConfAdapter : public ConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; + + bool + CheckSearch(Config& oricfg, const IndexType type, const IndexMode mode) override; +}; + +class IVFSQ8NRConfAdapter : public IVFConfAdapter { + public: + bool + CheckTrain(Config& oricfg, const IndexMode mode) override; +}; + } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp index 3dc231b81bea..21abe6dc1851 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.cpp @@ -49,6 +49,8 @@ AdapterMgr::RegisterAdapter() { #endif REGISTER_CONF_ADAPTER(HNSWConfAdapter, IndexEnum::INDEX_HNSW, hnsw_adapter); REGISTER_CONF_ADAPTER(ANNOYConfAdapter, IndexEnum::INDEX_ANNOY, annoy_adapter); + REGISTER_CONF_ADAPTER(HNSWSQ8NRConfAdapter, IndexEnum::INDEX_HNSW_SQ8NM, hnswsq8nr_adapter); + REGISTER_CONF_ADAPTER(IVFSQ8NRConfAdapter, IndexEnum::INDEX_FAISS_IVFSQ8NR, ivfsq8nr_adapter); } } // namespace knowhere diff --git a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h index 5d8c24f3226a..83b9d0f5845e 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/ConfAdapterMgr.h @@ -15,8 +15,8 @@ #include #include +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/ConfAdapter.h" -#include "knowhere/index/vector_index/IndexType.h" namespace milvus { namespace knowhere { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.h b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.h index f4d8784749fa..5db3fb4982bb 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseBinaryIndex.h @@ -18,7 +18,7 @@ #include "knowhere/common/BinarySet.h" #include "knowhere/common/Dataset.h" -#include "knowhere/index/vector_index/IndexType.h" +#include "knowhere/index/IndexType.h" namespace milvus { namespace knowhere { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp index f0a6facefcf4..996933664819 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.cpp @@ -13,8 +13,8 @@ #include #include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/FaissBaseIndex.h" -#include "knowhere/index/vector_index/IndexType.h" #include "knowhere/index/vector_index/helpers/FaissIO.h" namespace milvus { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h index 0da012387772..53a9c3a307cf 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/FaissBaseIndex.h @@ -17,7 +17,7 @@ #include #include "knowhere/common/BinarySet.h" -#include "knowhere/index/vector_index/IndexType.h" +#include "knowhere/index/IndexType.h" namespace milvus { namespace knowhere { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp index 990f8ebfa518..aa9a04eef69c 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.cpp @@ -18,9 +18,6 @@ #include #include -#include "hnswlib/hnswalg.h" -#include "hnswlib/space_ip.h" -#include "hnswlib/space_l2.h" #include "knowhere/common/Exception.h" #include "knowhere/common/Log.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" @@ -89,7 +86,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) { return; } - GETTENSORWITHIDS(dataset_ptr) + GET_TENSOR(dataset_ptr) metric_type_ = config[Metric::TYPE]; if (metric_type_ == Metric::L2) { @@ -113,7 +110,7 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) { KNOWHERE_THROW_MSG("index not initialize or trained"); } - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) auto k = config[meta::TOPK].get(); auto search_k = config[IndexParams::search_k].get(); auto all_num = rows * k; @@ -162,5 +159,13 @@ IndexAnnoy::Dim() { return index_->get_dim(); } +void +IndexAnnoy::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->cal_size(); +} + } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h index fa78743e23ed..7b86dc531d40 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexAnnoy.h @@ -62,6 +62,9 @@ class IndexAnnoy : public VecIndex { int64_t Dim() override; + void + UpdateIndexSize() override; + private: MetricType metric_type_; std::shared_ptr> index_ = nullptr; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp index a38fba01b461..5e3531a51bdf 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.cpp @@ -44,7 +44,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize"); } - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA(dataset_ptr) int64_t k = config[meta::TOPK].get(); auto elems = rows * k; @@ -112,6 +112,22 @@ BinaryIDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { } #endif +int64_t +BinaryIDMAP::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +BinaryIDMAP::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + void BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) { if (!index_) { @@ -119,7 +135,7 @@ BinaryIDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) { } std::lock_guard lk(mutex_); - GETTENSORWITHIDS(dataset_ptr) + GET_TENSOR_DATA_ID(dataset_ptr) index_->add_with_ids(rows, (uint8_t*)p_data, p_ids); } @@ -161,7 +177,7 @@ BinaryIDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) } std::lock_guard lk(mutex_); - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA(dataset_ptr) std::vector new_ids(rows); for (int i = 0; i < rows; ++i) { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h index bf3b57808363..ce7da9bf0404 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIDMAP.h @@ -56,14 +56,10 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { #endif int64_t - Count() override { - return index_->ntotal; - } + Count() override; int64_t - Dim() override { - return index_->d; - } + Dim() override; int64_t IndexSize() override { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp index 616f4c8d23fb..33fc722d6ad9 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.cpp @@ -48,7 +48,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { KNOWHERE_THROW_MSG("index not initialize or trained"); } - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA(dataset_ptr) try { int64_t k = config[meta::TOPK].get(); @@ -129,9 +129,39 @@ BinaryIVF::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { } #endif +int64_t +BinaryIVF::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +BinaryIVF::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +void +BinaryIVF::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto bin_ivf_index = dynamic_cast(index_.get()); + auto nb = bin_ivf_index->invlists->compute_ntotal(); + auto nlist = bin_ivf_index->nlist; + auto code_size = bin_ivf_index->code_size; + + // binary ivf codes, ids and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size; +} + void BinaryIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSORWITHIDS(dataset_ptr) + GET_TENSOR(dataset_ptr) int64_t nlist = config[IndexParams::nlist]; faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h index e46a64e8d6c1..8069704436fa 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexBinaryIVF.h @@ -68,14 +68,13 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { #endif int64_t - Count() override { - return index_->ntotal; - } + Count() override; int64_t - Dim() override { - return index_->d; - } + Dim() override; + + void + UpdateIndexSize() override; #if 0 DatasetPtr diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp index 758ef74a7766..dd33384602f8 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.cpp @@ -78,13 +78,14 @@ IndexHNSW::Load(const BinarySet& index_binary) { void IndexHNSW::Train(const DatasetPtr& dataset_ptr, const Config& config) { try { - GETTENSOR(dataset_ptr) + int64_t dim = dataset_ptr->Get(meta::DIM); + int64_t rows = dataset_ptr->Get(meta::ROWS); hnswlib::SpaceInterface* space; if (config[Metric::TYPE] == Metric::L2) { - space = new hnswlib::L2Space(dim); + space = new hnswlib_nm::L2Space(dim); } else if (config[Metric::TYPE] == Metric::IP) { - space = new hnswlib::InnerProductSpace(dim); + space = new hnswlib_nm::InnerProductSpace(dim); normalize = true; } index_ = std::make_shared>(space, rows, config[IndexParams::M].get(), @@ -102,7 +103,7 @@ IndexHNSW::Add(const DatasetPtr& dataset_ptr, const Config& config) { std::lock_guard lk(mutex_); - GETTENSORWITHIDS(dataset_ptr) + GET_TENSOR_DATA_ID(dataset_ptr) // if (normalize) { // std::vector ep_norm_vector(Dim()); @@ -135,7 +136,7 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize or trained"); } - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA(dataset_ptr) size_t k = config[meta::TOPK].get(); size_t id_size = sizeof(int64_t) * k; @@ -205,5 +206,13 @@ IndexHNSW::Dim() { return (*(size_t*)index_->dist_func_param_); } +void +IndexHNSW::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->cal_size(); +} + } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h index 9a21797f0f67..576f2c484ae7 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexHNSW.h @@ -54,6 +54,9 @@ class IndexHNSW : public VecIndex { int64_t Dim() override; + void + UpdateIndexSize() override; + private: bool normalize = false; std::mutex mutex_; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp index 8f1babc840e4..01af86211503 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.cpp @@ -68,7 +68,7 @@ IDMAP::Add(const DatasetPtr& dataset_ptr, const Config& config) { } std::lock_guard lk(mutex_); - GETTENSORWITHIDS(dataset_ptr) + GET_TENSOR_DATA_ID(dataset_ptr) index_->add_with_ids(rows, (float*)p_data, p_ids); } @@ -96,7 +96,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) { if (!index_) { KNOWHERE_THROW_MSG("index not initialize"); } - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA(dataset_ptr) int64_t k = config[meta::TOPK].get(); auto elems = rows * k; @@ -142,6 +142,22 @@ IDMAP::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { } #endif +int64_t +IDMAP::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +IDMAP::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + VecIndexPtr IDMAP::CopyCpuToGpu(const int64_t device_id, const Config& config) { #ifdef MILVUS_GPU_VERSION diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h index f15c665d40b7..128d9f99b468 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIDMAP.h @@ -54,14 +54,10 @@ class IDMAP : public VecIndex, public FaissBaseIndex { #endif int64_t - Count() override { - return index_->ntotal; - } + Count() override; int64_t - Dim() override { - return index_->d; - } + Dim() override; int64_t IndexSize() override { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp index 94133acc2fb6..5f756db22809 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.cpp @@ -64,15 +64,13 @@ IVF::Load(const BinarySet& binary_set) { void IVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) - faiss::Index* coarse_quantizer = new faiss::IndexFlatL2(dim); - int64_t nlist = config[IndexParams::nlist].get(); faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); - auto index = std::make_shared(coarse_quantizer, dim, nlist, metric_type); - index->train(rows, (float*)p_data); - - index_.reset(faiss::clone_index(index.get())); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + int64_t nlist = config[IndexParams::nlist].get(); + index_ = std::shared_ptr(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type)); + index_->train(rows, (float*)p_data); } void @@ -82,7 +80,7 @@ IVF::Add(const DatasetPtr& dataset_ptr, const Config& config) { } std::lock_guard lk(mutex_); - GETTENSORWITHIDS(dataset_ptr) + GET_TENSOR_DATA_ID(dataset_ptr) index_->add_with_ids(rows, (float*)p_data, p_ids); } @@ -93,7 +91,7 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { } std::lock_guard lk(mutex_); - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA(dataset_ptr) index_->add(rows, (float*)p_data); } @@ -103,7 +101,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) { KNOWHERE_THROW_MSG("index not initialize or trained"); } - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA(dataset_ptr) try { fiu_do_on("IVF.Search.throw_std_exception", throw std::exception()); @@ -217,6 +215,22 @@ IVF::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { } #endif +int64_t +IVF::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +IVF::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + void IVF::Seal() { if (!index_ || !index_->is_trained) { @@ -225,6 +239,19 @@ IVF::Seal() { SealImpl(); } +void +IVF::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivf_index = dynamic_cast(index_.get()); + auto nb = ivf_index->invlists->compute_ntotal(); + auto nlist = ivf_index->nlist; + auto code_size = ivf_index->code_size; + // ivf codes, ivf ids and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size; +} + VecIndexPtr IVF::CopyCpuToGpu(const int64_t device_id, const Config& config) { #ifdef MILVUS_GPU_VERSION diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h index 612abc6bd11b..d33481ed4b13 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVF.h @@ -59,14 +59,13 @@ class IVF : public VecIndex, public FaissBaseIndex { #endif int64_t - Count() override { - return index_->ntotal; - } + Count() override; int64_t - Dim() override { - return index_->d; - } + Dim() override; + + void + UpdateIndexSize() override; #if 0 DatasetPtr diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp index c6a7e4de5ac8..2d7de729f8aa 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp @@ -33,15 +33,15 @@ namespace knowhere { void IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) - faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, GetMetricType(config[Metric::TYPE].get())); - auto index = std::make_shared(coarse_quantizer, dim, config[IndexParams::nlist].get(), - config[IndexParams::m].get(), - config[IndexParams::nbits].get()); - index->train(rows, (float*)p_data); + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + index_ = std::shared_ptr(new faiss::IndexIVFPQ( + coarse_quantizer, dim, config[IndexParams::nlist].get(), config[IndexParams::m].get(), + config[IndexParams::nbits].get(), metric_type)); - index_.reset(faiss::clone_index(index.get())); + index_->train(rows, (float*)p_data); } VecIndexPtr @@ -73,5 +73,28 @@ IVFPQ::GenParams(const Config& config) { return params; } +void +IVFPQ::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivfpq_index = dynamic_cast(index_.get()); + auto nb = ivfpq_index->invlists->compute_ntotal(); + auto code_size = ivfpq_index->code_size; + auto pq = ivfpq_index->pq; + auto nlist = ivfpq_index->nlist; + auto d = ivfpq_index->d; + + // ivf codes, ivf ids and quantizer + auto capacity = nb * code_size + nb * sizeof(int64_t) + nlist * d * sizeof(float); + auto centroid_table = pq.M * pq.ksub * pq.dsub * sizeof(float); + auto precomputed_table = nlist * pq.M * pq.ksub * sizeof(float); + if (precomputed_table > ivfpq_index->precomputed_table_max_bytes) { + // will not precompute table + precomputed_table = 0; + } + index_size_ = capacity + centroid_table + precomputed_table; +} + } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h index 8582a6ea936e..aed407209999 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFPQ.h @@ -35,6 +35,9 @@ class IVFPQ : public IVF { VecIndexPtr CopyCpuToGpu(const int64_t, const Config&) override; + void + UpdateIndexSize() override; + protected: std::shared_ptr GenParams(const Config& config) override; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp index 39d26ad8da75..fefd5ee7738d 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp @@ -16,6 +16,8 @@ #include #include #endif +#include +#include #include #include @@ -33,16 +35,20 @@ namespace knowhere { void IVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) - std::stringstream index_type; - index_type << "IVF" << config[IndexParams::nlist] << "," - << "SQ" << config[IndexParams::nbits]; - auto build_index = - faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get())); - build_index->train(rows, (float*)p_data); + // std::stringstream index_type; + // index_type << "IVF" << config[IndexParams::nlist] << "," + // << "SQ" << config[IndexParams::nbits]; + // index_ = std::shared_ptr( + // faiss::index_factory(dim, index_type.str().c_str(), GetMetricType(config[Metric::TYPE].get()))); - index_.reset(faiss::clone_index(build_index)); + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + index_ = std::shared_ptr(new faiss::IndexIVFScalarQuantizer( + coarse_quantizer, dim, config[IndexParams::nlist].get(), faiss::QuantizerType::QT_8bit, metric_type)); + + index_->train(rows, (float*)p_data); } VecIndexPtr @@ -64,5 +70,19 @@ IVFSQ::CopyCpuToGpu(const int64_t device_id, const Config& config) { #endif } +void +IVFSQ::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivfsq_index = dynamic_cast(index_.get()); + auto nb = ivfsq_index->invlists->compute_ntotal(); + auto code_size = ivfsq_index->code_size; + auto nlist = ivfsq_index->nlist; + auto d = ivfsq_index->d; + // ivf codes, ivf ids, sq trained vectors and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + 2 * d * sizeof(float) + nlist * d * sizeof(float); +} + } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.h index 927ceb90f107..0c33eda5692e 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexIVFSQ.h @@ -34,6 +34,9 @@ class IVFSQ : public IVF { VecIndexPtr CopyCpuToGpu(const int64_t, const Config&) override; + + void + UpdateIndexSize() override; }; using IVFSQPtr = std::shared_ptr; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp index 301306b6459b..476dee4f57be 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexNSG.cpp @@ -14,10 +14,10 @@ #include "knowhere/common/Exception.h" #include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexNSG.h" -#include "knowhere/index/vector_index/IndexType.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "knowhere/index/vector_index/impl/nsg/NSG.h" #include "knowhere/index/vector_index/impl/nsg/NSGIO.h" @@ -78,7 +78,7 @@ NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) { KNOWHERE_THROW_MSG("index not initialize or trained"); } - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) try { auto elems = rows * config[meta::TOPK].get(); @@ -94,8 +94,8 @@ NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) { s_params.k = config[meta::TOPK]; { std::lock_guard lk(mutex_); - index_->Search((float*)p_data, rows, dim, config[meta::TOPK].get(), p_dist, p_id, s_params, - blacklist); + index_->Search((float*)p_data, nullptr, rows, dim, config[meta::TOPK].get(), p_dist, p_id, + s_params, blacklist); } auto ret_ds = std::make_shared(); @@ -139,23 +139,46 @@ NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) { b_params.out_degree = config[IndexParams::out_degree]; b_params.search_length = config[IndexParams::search_length]; - auto p_ids = dataset_ptr->Get(meta::IDS); + GET_TENSOR(dataset_ptr) - GETTENSOR(dataset_ptr) - index_ = std::make_shared(dim, rows, config[Metric::TYPE].get()); + impl::NsgIndex::Metric_Type metric; + auto metric_str = config[Metric::TYPE].get(); + if (metric_str == knowhere::Metric::IP) { + metric = impl::NsgIndex::Metric_Type::Metric_Type_IP; + } else if (metric_str == knowhere::Metric::L2) { + metric = impl::NsgIndex::Metric_Type::Metric_Type_L2; + } else { + KNOWHERE_THROW_MSG("Metric is not supported"); + } + + index_ = std::make_shared(dim, rows, metric); index_->SetKnnGraph(knng); index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params); } int64_t NSG::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } return index_->ntotal; } int64_t NSG::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } return index_->dimension; } +void +NSG::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->GetSize(); +} + } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp index ddbc37ec80b5..2dc86678f5aa 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.cpp @@ -200,14 +200,28 @@ CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config) { int64_t CPUSPTAGRNG::Count() { + if (!index_ptr_) { + KNOWHERE_THROW_MSG("index not initialize"); + } return index_ptr_->GetNumSamples(); } int64_t CPUSPTAGRNG::Dim() { + if (!index_ptr_) { + KNOWHERE_THROW_MSG("index not initialize"); + } return index_ptr_->GetFeatureDim(); } +void +CPUSPTAGRNG::UpdateIndexSize() { + if (!index_ptr_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_ptr_->GetIndexSize(); +} + // void // CPUSPTAGRNG::Add(const DatasetPtr& origin, const Config& add_config) { // SetParameters(add_config); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h b/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h index 25361ae85702..945f3369b823 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/IndexSPTAG.h @@ -60,6 +60,9 @@ class CPUSPTAGRNG : public VecIndex { int64_t Dim() override; + void + UpdateIndexSize() override; + private: void SetParameters(const Config& config); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h b/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h index 9f49166f5ef2..d5e74f90ce69 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/VecIndex.h @@ -20,11 +20,15 @@ #include "knowhere/common/Exception.h" #include "knowhere/common/Typedef.h" #include "knowhere/index/Index.h" -#include "knowhere/index/vector_index/IndexType.h" +#include "knowhere/index/IndexType.h" namespace milvus { namespace knowhere { +#define INDEX_DATA "INDEX_DATA" +#define RAW_DATA "RAW_DATA" +#define SQ8_DATA "SQ8_DATA" + class VecIndex : public Index { public: virtual void @@ -129,6 +133,10 @@ class VecIndex : public Index { index_size_ = size; } + virtual void + UpdateIndexSize() { + } + int64_t Size() override { return BlacklistSize() + UidsSize() + IndexSize(); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp index ae4400660eec..aead15349a12 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.cpp @@ -16,12 +16,15 @@ #include "knowhere/index/vector_index/IndexAnnoy.h" #include "knowhere/index/vector_index/IndexBinaryIDMAP.h" #include "knowhere/index/vector_index/IndexBinaryIVF.h" -#include "knowhere/index/vector_index/IndexHNSW.h" #include "knowhere/index/vector_index/IndexIDMAP.h" #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexIVFSQ.h" -#include "knowhere/index/vector_index/IndexNSG.h" +#include "knowhere/index/vector_offset_index/IndexHNSW_NM.h" +#include "knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.h" +#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" +#include "knowhere/index/vector_offset_index/IndexNSG_NM.h" #ifdef MILVUS_SUPPORT_SPTAG #include "knowhere/index/vector_index/IndexSPTAG.h" #endif @@ -34,6 +37,8 @@ #include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" #include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" #include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" #endif namespace milvus { @@ -47,10 +52,10 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) { } else if (type == IndexEnum::INDEX_FAISS_IVFFLAT) { #ifdef MILVUS_GPU_VERSION if (mode == IndexMode::MODE_GPU) { - return std::make_shared(gpu_device); + return std::make_shared(gpu_device); } #endif - return std::make_shared(); + return std::make_shared(); } else if (type == IndexEnum::INDEX_FAISS_IVFPQ) { #ifdef MILVUS_GPU_VERSION if (mode == IndexMode::MODE_GPU) { @@ -74,7 +79,7 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) { } else if (type == IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { return std::make_shared(); } else if (type == IndexEnum::INDEX_NSG) { - return std::make_shared(-1); + return std::make_shared(-1); #ifdef MILVUS_SUPPORT_SPTAG } else if (type == IndexEnum::INDEX_SPTAG_KDT_RNT) { return std::make_shared("KDT"); @@ -82,9 +87,13 @@ VecIndexFactory::CreateVecIndex(const IndexType& type, const IndexMode mode) { return std::make_shared("BKT"); #endif } else if (type == IndexEnum::INDEX_HNSW) { - return std::make_shared(); + return std::make_shared(); } else if (type == IndexEnum::INDEX_ANNOY) { return std::make_shared(); + } else if (type == IndexEnum::INDEX_FAISS_IVFSQ8NR) { + return std::make_shared(); + } else if (type == IndexEnum::INDEX_HNSW_SQ8NM) { + return std::make_shared(); } else { return nullptr; } diff --git a/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.h b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.h index 7a09b54bff70..c96bd1dc7cf3 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/VecIndexFactory.h @@ -13,7 +13,7 @@ #include -#include "knowhere/index/vector_index/IndexType.h" +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/VecIndex.h" namespace milvus { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp b/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp index 030aa8665f33..cf07bd237deb 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/adapter/SptagAdapter.cpp @@ -32,7 +32,7 @@ ConvertToMetadataSet(const DatasetPtr& dataset_ptr) { std::shared_ptr ConvertToVectorSet(const DatasetPtr& dataset_ptr) { - GETTENSOR(dataset_ptr); + GET_TENSOR_DATA_DIM(dataset_ptr) size_t num_bytes = rows * dim * sizeof(float); SPTAG::ByteArray byte_array((uint8_t*)p_data, num_bytes, false); @@ -42,7 +42,7 @@ ConvertToVectorSet(const DatasetPtr& dataset_ptr) { std::vector ConvertToQueryResult(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSOR(dataset_ptr); + GET_TENSOR_DATA_DIM(dataset_ptr); int64_t k = config[meta::TOPK].get(); std::vector query_results(rows, SPTAG::QueryResult(nullptr, k, true)); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h b/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h index 44086796b593..9fe4e5b93b15 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.h @@ -18,15 +18,21 @@ namespace milvus { namespace knowhere { -#define GETTENSOR(dataset_ptr) \ - int64_t dim = dataset_ptr->Get(meta::DIM); \ +#define GET_TENSOR_DATA(dataset_ptr) \ int64_t rows = dataset_ptr->Get(meta::ROWS); \ const void* p_data = dataset_ptr->Get(meta::TENSOR); -#define GETTENSORWITHIDS(dataset_ptr) \ - int64_t dim = dataset_ptr->Get(meta::DIM); \ - int64_t rows = dataset_ptr->Get(meta::ROWS); \ - const void* p_data = dataset_ptr->Get(meta::TENSOR); \ +#define GET_TENSOR_DATA_DIM(dataset_ptr) \ + GET_TENSOR_DATA(dataset_ptr) \ + int64_t dim = dataset_ptr->Get(meta::DIM); + +#define GET_TENSOR_DATA_ID(dataset_ptr) \ + GET_TENSOR_DATA(dataset_ptr) \ + const int64_t* p_ids = dataset_ptr->Get(meta::IDS); + +#define GET_TENSOR(dataset_ptr) \ + GET_TENSOR_DATA(dataset_ptr) \ + int64_t dim = dataset_ptr->Get(meta::DIM); \ const int64_t* p_ids = dataset_ptr->Get(meta::IDS); extern DatasetPtr diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp index 5bcc288c32c0..2585f88ba177 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp @@ -19,8 +19,8 @@ #include #include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/IndexIDMAP.h" -#include "knowhere/index/vector_index/IndexType.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" #include "knowhere/index/vector_index/helpers/FaissIO.h" diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp index 3552a1904535..22e606e0a847 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp @@ -30,7 +30,7 @@ namespace knowhere { void GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) gpu_id_ = config[knowhere::meta::DEVICEID]; auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp index d9465761d92c..3c742d791148 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp @@ -27,7 +27,7 @@ namespace knowhere { void GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) gpu_id_ = config[knowhere::meta::DEVICEID]; auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp index 36e5dc13bded..66df444a9814 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp @@ -9,6 +9,8 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include +#include #include #include @@ -26,14 +28,19 @@ namespace knowhere { void GPUIVFSQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) gpu_id_ = config[knowhere::meta::DEVICEID]; - std::stringstream index_type; - index_type << "IVF" << config[IndexParams::nlist] << "," - << "SQ" << config[IndexParams::nbits]; + // std::stringstream index_type; + // index_type << "IVF" << config[IndexParams::nlist] << "," + // << "SQ" << config[IndexParams::nbits]; + // faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + // auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type); + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); - auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + auto build_index = new faiss::IndexIVFScalarQuantizer( + coarse_quantizer, dim, config[IndexParams::nlist].get(), faiss::QuantizerType::QT_8bit, metric_type); auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); if (gpu_res != nullptr) { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQNR.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQNR.cpp new file mode 100644 index 000000000000..4e23bda32585 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQNR.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQNR.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" + +namespace milvus { +namespace knowhere { + +void +GPUIVFSQNR::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + gpu_id_ = config[knowhere::meta::DEVICEID]; + + // std::stringstream index_type; + // index_type << "IVF" << config[IndexParams::nlist] << "," + // << "SQ" << config[IndexParams::nbits]; + // faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + // auto build_index = faiss::index_factory(dim, index_type.str().c_str(), metric_type); + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + auto build_index = + new faiss::IndexIVFScalarQuantizer(coarse_quantizer, dim, config[IndexParams::nlist].get(), + faiss::QuantizerType::QT_8bit, metric_type, false); + + auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (gpu_res != nullptr) { + ResScope rs(gpu_res, gpu_id_, true); + auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index); + device_index->train(rows, (float*)p_data); + + index_.reset(device_index); + res_ = gpu_res; + } else { + KNOWHERE_THROW_MSG("Build IVFSQ can't get gpu resource"); + } +} + +VecIndexPtr +GPUIVFSQNR::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); + + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu(device_index); + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQNR.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQNR.h new file mode 100644 index 000000000000..f1603e2252e2 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQNR.h @@ -0,0 +1,43 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" + +namespace milvus { +namespace knowhere { + +class GPUIVFSQNR : public GPUIVF { + public: + explicit GPUIVFSQNR(const int& device_id) : GPUIVF(device_id) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8; + } + + explicit GPUIVFSQNR(std::shared_ptr index, const int64_t device_id, ResPtr& res) + : GPUIVF(std::move(index), device_id, res) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8; + } + + void + Train(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyGpuToCpu(const Config&) override; +}; + +using GPUIVFSQNRPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp index dde0fcd2e896..5a485ecf97c3 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp @@ -10,6 +10,7 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include #include #include #include @@ -30,7 +31,7 @@ namespace knowhere { void IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) { - GETTENSOR(dataset_ptr) + GET_TENSOR_DATA_DIM(dataset_ptr) gpu_id_ = config[knowhere::meta::DEVICEID]; std::stringstream index_type; @@ -261,6 +262,20 @@ IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances } } +void +IVFSQHybrid::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivfsqh_index = dynamic_cast(index_.get()); + auto nb = ivfsqh_index->invlists->compute_ntotal(); + auto code_size = ivfsqh_index->code_size; + auto nlist = ivfsqh_index->nlist; + auto d = ivfsqh_index->d; + // ivf codes, ivf ids, sq trained vectors and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + 2 * d * sizeof(float) + nlist * d * sizeof(float); +} + FaissIVFQuantizer::~FaissIVFQuantizer() { if (quantizer != nullptr) { delete quantizer; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h index 9d092618066c..4aeb7f68675c 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h @@ -77,6 +77,9 @@ class IVFSQHybrid : public GPUIVFSQ { void UnsetQuantizer(); + void + UpdateIndexSize() override; + protected: BinarySet SerializeImpl(const IndexType&) override; diff --git a/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp b/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp index 2ba189c5133c..2d44a9d3561e 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp @@ -19,6 +19,8 @@ #include "knowhere/index/vector_index/gpu/GPUIndex.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" #include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" namespace milvus { namespace knowhere { @@ -29,7 +31,6 @@ CopyIndexData(const VecIndexPtr& dst_index, const VecIndexPtr& src_index) { /* do real copy */ auto uids = src_index->GetUids(); dst_index->SetUids(uids); - dst_index->SetBlacklist(src_index->GetBlacklist()); dst_index->SetIndexSize(src_index->IndexSize()); } @@ -50,10 +51,14 @@ CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& co VecIndexPtr result; if (auto device_index = std::dynamic_pointer_cast(index)) { result = device_index->CopyCpuToGpu(device_id, config); + } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { + result = cpu_index->CopyCpuToGpu(device_id, config); } else if (auto device_index = std::dynamic_pointer_cast(index)) { result = device_index->CopyGpuToGpu(device_id, config); } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { result = cpu_index->CopyCpuToGpu(device_id, config); + } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { + result = cpu_index->CopyCpuToGpu(device_id, config); } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { result = cpu_index->CopyCpuToGpu(device_id, config); } else if (auto cpu_index = std::dynamic_pointer_cast(index)) { diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.cpp b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.cpp index 95e5f6992258..77897c7bef8c 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/Distance.cpp @@ -237,7 +237,7 @@ DistanceL2::Compare(const float* a, const float* b, unsigned size) const { float DistanceIP::Compare(const float* a, const float* b, unsigned size) const { - return faiss::fvec_inner_product(a, b, (size_t)size); + return -(faiss::fvec_inner_product(a, b, (size_t)size)); } #endif diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp index dbc4cd5bc5f9..3e7d52b7ac41 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.cpp @@ -31,27 +31,27 @@ namespace impl { unsigned int seed = 100; -NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, std::string metric) +NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric) : dimension(dimension), ntotal(n), metric_type(metric) { - if (metric == knowhere::Metric::L2) { + if (metric == Metric_Type::Metric_Type_L2) { distance_ = new DistanceL2; - } else if (metric == knowhere::Metric::IP) { + } else if (metric == Metric_Type::Metric_Type_IP) { distance_ = new DistanceIP; } } NsgIndex::~NsgIndex() { - delete[] ori_data_; + // delete[] ori_data_; delete[] ids_; delete distance_; } void -NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const BuildParams& parameters) { +NsgIndex::Build_with_ids(size_t nb, float* data, const int64_t* ids, const BuildParams& parameters) { ntotal = nb; - ori_data_ = new float[ntotal * dimension]; + // ori_data_ = new float[ntotal * dimension]; ids_ = new int64_t[ntotal]; - memcpy((void*)ori_data_, (void*)data, sizeof(float) * ntotal * dimension); + // memcpy((void*)ori_data_, (void*)data, sizeof(float) * ntotal * dimension); memcpy((void*)ids_, (void*)ids, sizeof(int64_t) * ntotal); search_length = parameters.search_length; @@ -59,13 +59,13 @@ NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const candidate_pool_size = parameters.candidate_pool_size; TimeRecorder rc("NSG", 1); - InitNavigationPoint(); + InitNavigationPoint(data); rc.RecordSection("init"); - Link(); + Link(data); rc.RecordSection("Link"); - CheckConnectivity(); + CheckConnectivity(data); rc.RecordSection("Connect"); rc.ElapseFromBegin("finish"); @@ -89,14 +89,14 @@ NsgIndex::Build_with_ids(size_t nb, const float* data, const int64_t* ids, const } void -NsgIndex::InitNavigationPoint() { +NsgIndex::InitNavigationPoint(float* data) { // calculate the center of vectors auto center = new float[dimension]; memset(center, 0, sizeof(float) * dimension); for (size_t i = 0; i < ntotal; i++) { for (size_t j = 0; j < dimension; j++) { - center[j] += ori_data_[i * dimension + j]; + center[j] += data[i * dimension + j]; } } for (size_t j = 0; j < dimension; j++) { @@ -106,7 +106,7 @@ NsgIndex::InitNavigationPoint() { // select navigation point std::vector resset; navigation_point = rand_r(&seed) % ntotal; // random initialize navigating point - GetNeighbors(center, resset, knng); + GetNeighbors(center, data, resset, knng); navigation_point = resset[0].id; // Debug code @@ -124,7 +124,7 @@ NsgIndex::InitNavigationPoint() { // Specify Link void -NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::vector& fullset, +NsgIndex::GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset, boost::dynamic_bitset<>& has_calculated_dist) { auto& graph = knng; size_t buffer_size = search_length; @@ -174,7 +174,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v continue; } - float dist = distance_->Compare(ori_data_ + dimension * id, query, dimension); + float dist = distance_->Compare(data + dimension * id, query, dimension); resset[i] = Neighbor(id, dist, false); //// difference from other GetNeighbors @@ -199,7 +199,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v continue; has_calculated_dist[id] = true; - float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension); + float dist = distance_->Compare(query, data + dimension * id, dimension); Neighbor nn(id, dist, false); fullset.push_back(nn); @@ -226,7 +226,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v // FindUnconnectedNode void -NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::vector& fullset) { +NsgIndex::GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset) { auto& graph = nsg; size_t buffer_size = search_length; @@ -276,7 +276,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v continue; } - float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension); + float dist = distance_->Compare(data + id * dimension, query, dimension); resset[i] = Neighbor(id, dist, false); } std::sort(resset.begin(), resset.end()); // sort by distance @@ -297,7 +297,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v continue; has_calculated_dist[id] = true; - float dist = distance_->Compare(ori_data_ + dimension * id, query, dimension); + float dist = distance_->Compare(data + dimension * id, query, dimension); Neighbor nn(id, dist, false); fullset.push_back(nn); @@ -323,7 +323,8 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, std::v } void -NsgIndex::GetNeighbors(const float* query, std::vector& resset, Graph& graph, SearchParams* params) { +NsgIndex::GetNeighbors(const float* query, float* data, std::vector& resset, Graph& graph, + SearchParams* params) { size_t buffer_size = params ? params->search_length : search_length; if (buffer_size > ntotal) { @@ -367,7 +368,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, Graph& KNOWHERE_THROW_MSG("Build Index Error, id > ntotal"); } - float dist = distance_->Compare(ori_data_ + id * dimension, query, dimension); + float dist = distance_->Compare(data + id * dimension, query, dimension); resset[i] = Neighbor(id, dist, false); } std::sort(resset.begin(), resset.end()); // sort by distance @@ -388,7 +389,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, Graph& continue; has_calculated_dist[id] = true; - float dist = distance_->Compare(query, ori_data_ + dimension * id, dimension); + float dist = distance_->Compare(query, data + dimension * id, dimension); if (dist >= resset[buffer_size - 1].distance) continue; @@ -406,7 +407,6 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, Graph& // std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " << // nearest_updated_pos << std::endl; ///// - // trick: avoid search query search_length < init_ids.size() ... if (buffer_size + 1 < resset.size()) ++buffer_size; @@ -422,7 +422,7 @@ NsgIndex::GetNeighbors(const float* query, std::vector& resset, Graph& } void -NsgIndex::Link() { +NsgIndex::Link(float* data) { float* cut_graph_dist = new float[ntotal * out_degree]; nsg.resize(ntotal); @@ -437,8 +437,8 @@ NsgIndex::Link() { fullset.clear(); temp.clear(); flags.reset(); - GetNeighbors(ori_data_ + dimension * n, temp, fullset, flags); - SyncPrune(n, fullset, flags, cut_graph_dist); + GetNeighbors(data + dimension * n, data, temp, fullset, flags); + SyncPrune(data, n, fullset, flags, cut_graph_dist); } // Debug code @@ -464,20 +464,20 @@ NsgIndex::Link() { #pragma omp for schedule(dynamic, 100) for (unsigned n = 0; n < ntotal; ++n) { faiss::BuilderSuspend::check_wait(); - InterInsert(n, mutex_vec, cut_graph_dist); + InterInsert(data, n, mutex_vec, cut_graph_dist); } delete[] cut_graph_dist; } void -NsgIndex::SyncPrune(size_t n, std::vector& pool, boost::dynamic_bitset<>& has_calculated, +NsgIndex::SyncPrune(float* data, size_t n, std::vector& pool, boost::dynamic_bitset<>& has_calculated, float* cut_graph_dist) { // avoid lose nearest neighbor in knng for (size_t i = 0; i < knng[n].size(); ++i) { auto id = knng[n][i]; if (has_calculated[id]) continue; - float dist = distance_->Compare(ori_data_ + dimension * n, ori_data_ + dimension * id, dimension); + float dist = distance_->Compare(data + dimension * n, data + dimension * id, dimension); pool.emplace_back(Neighbor(id, dist, true)); } @@ -490,7 +490,7 @@ NsgIndex::SyncPrune(size_t n, std::vector& pool, boost::dynamic_bitset } result.push_back(pool[cursor]); // init result with nearest neighbor - SelectEdge(cursor, pool, result, true); + SelectEdge(data, cursor, pool, result, true); // filling the cut_graph auto& des_id_pool = nsg[n]; @@ -507,7 +507,7 @@ NsgIndex::SyncPrune(size_t n, std::vector& pool, boost::dynamic_bitset //>> Optimize: remove read-lock void -NsgIndex::InterInsert(unsigned n, std::vector& mutex_vec, float* cut_graph_dist) { +NsgIndex::InterInsert(float* data, unsigned n, std::vector& mutex_vec, float* cut_graph_dist) { auto& current = n; auto& neighbor_id_pool = nsg[current]; @@ -555,7 +555,7 @@ NsgIndex::InterInsert(unsigned n, std::vector& mutex_vec, float* cut std::sort(wait_for_link_pool.begin(), wait_for_link_pool.end()); result.push_back(wait_for_link_pool[start]); - SelectEdge(start, wait_for_link_pool, result); + SelectEdge(data, start, wait_for_link_pool, result); { LockGuard lk(mutex_vec[current_neighbor]); @@ -580,7 +580,8 @@ NsgIndex::InterInsert(unsigned n, std::vector& mutex_vec, float* cut } void -NsgIndex::SelectEdge(unsigned& cursor, std::vector& sort_pool, std::vector& result, bool limit) { +NsgIndex::SelectEdge(float* data, unsigned& cursor, std::vector& sort_pool, std::vector& result, + bool limit) { auto& pool = sort_pool; /* @@ -594,8 +595,7 @@ NsgIndex::SelectEdge(unsigned& cursor, std::vector& sort_pool, std::ve auto& p = pool[cursor]; bool should_link = true; for (size_t t = 0; t < result.size(); ++t) { - float dist = - distance_->Compare(ori_data_ + dimension * result[t].id, ori_data_ + dimension * p.id, dimension); + float dist = distance_->Compare(data + dimension * result[t].id, data + dimension * p.id, dimension); if (dist < p.distance) { should_link = false; @@ -608,7 +608,7 @@ NsgIndex::SelectEdge(unsigned& cursor, std::vector& sort_pool, std::ve } void -NsgIndex::CheckConnectivity() { +NsgIndex::CheckConnectivity(float* data) { auto root = navigation_point; boost::dynamic_bitset<> has_linked{ntotal, 0}; int64_t linked_count = 0; @@ -619,7 +619,7 @@ NsgIndex::CheckConnectivity() { if (linked_count >= static_cast(ntotal)) { break; } - FindUnconnectedNode(has_linked, root); + FindUnconnectedNode(data, has_linked, root); } } @@ -657,7 +657,7 @@ NsgIndex::DFS(size_t root, boost::dynamic_bitset<>& has_linked, int64_t& linked_ } void -NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<>& has_linked, int64_t& root) { +NsgIndex::FindUnconnectedNode(float* data, boost::dynamic_bitset<>& has_linked, int64_t& root) { // find any of unlinked-node size_t id = ntotal; for (size_t i = 0; i < ntotal; i++) { // find not link @@ -672,7 +672,7 @@ NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<>& has_linked, int64_t& root // search unlinked-node's neighbor std::vector tmp, pool; - GetNeighbors(ori_data_ + dimension * id, tmp, pool); + GetNeighbors(data + dimension * id, data, tmp, pool); std::sort(pool.begin(), pool.end()); size_t found = 0; @@ -831,21 +831,23 @@ NsgIndex::FindUnconnectedNode(boost::dynamic_bitset<>& has_linked, int64_t& root // } void -NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, - int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset) { +NsgIndex::Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k, + float* dist, int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset) { std::vector> resset(nq); TimeRecorder rc("NsgIndex::search", 1); if (nq == 1) { - GetNeighbors(query, resset[0], nsg, ¶ms); + GetNeighbors(query, data, resset[0], nsg, ¶ms); } else { #pragma omp parallel for for (unsigned int i = 0; i < nq; ++i) { const float* single_query = query + i * dim; - GetNeighbors(single_query, resset[i], nsg, ¶ms); + GetNeighbors(single_query, data, resset[i], nsg, ¶ms); } } rc.RecordSection("search"); + + bool is_ip = (metric_type == Metric_Type::Metric_Type_IP); for (unsigned int i = 0; i < nq; ++i) { unsigned int pos = 0; for (unsigned int j = 0; j < resset[i].size(); ++j) { @@ -853,7 +855,7 @@ NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, co break; // already top k if (!bitset || !bitset->test((faiss::ConcurrentBitset::id_type_t)resset[i][j].id)) { ids[i * k + pos] = ids_[resset[i][j].id]; - dist[i * k + pos] = resset[i][j].distance; + dist[i * k + pos] = is_ip ? -resset[i][j].distance : resset[i][j].distance; ++pos; } } @@ -871,6 +873,22 @@ NsgIndex::SetKnnGraph(Graph& g) { knng = std::move(g); } +int64_t +NsgIndex::GetSize() { + int64_t ret = 0; + ret += sizeof(*this); + ret += ntotal * dimension * sizeof(float); + ret += ntotal * sizeof(int64_t); + ret += sizeof(*distance_); + for (auto i = 0; i < nsg.size(); ++i) { + ret += nsg[i].size() * sizeof(node_t); + } + for (auto i = 0; i < knng.size(); ++i) { + ret += knng[i].size() * sizeof(node_t); + } + return ret; +} + } // namespace impl } // namespace knowhere } // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h index 603af1417d73..c7524c08ba86 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSG.h @@ -43,12 +43,17 @@ using Graph = std::vector>; class NsgIndex { public: + enum Metric_Type { + Metric_Type_L2, + Metric_Type_IP, + }; + size_t dimension; - size_t ntotal; // totabl nb of indexed vectors - std::string metric_type; // L2 | IP + size_t ntotal; // totabl nb of indexed vectors + int32_t metric_type; // enum Metric_Type Distance* distance_; - float* ori_data_; + // float* ori_data_; int64_t* ids_; Graph nsg; // final graph Graph knng; // reset after build @@ -65,7 +70,7 @@ class NsgIndex { size_t out_degree; public: - explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = knowhere::Metric::L2); + explicit NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric); NsgIndex() = default; @@ -74,12 +79,15 @@ class NsgIndex { void SetKnnGraph(Graph& knng); - virtual void - Build_with_ids(size_t nb, const float* data, const int64_t* ids, const BuildParams& parameters); + void + Build_with_ids(size_t nb, float* data, const int64_t* ids, const BuildParams& parameters); void - Search(const float* query, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, int64_t* ids, - SearchParams& params, faiss::ConcurrentBitsetPtr bitset = nullptr); + Search(const float* query, float* data, const unsigned& nq, const unsigned& dim, const unsigned& k, float* dist, + int64_t* ids, SearchParams& params, faiss::ConcurrentBitsetPtr bitset = nullptr); + + int64_t + GetSize(); // Not support yet. // virtual void Add() = 0; @@ -95,46 +103,49 @@ class NsgIndex { // const BuildParam ¶meters); protected: - virtual void - InitNavigationPoint(); + void + InitNavigationPoint(float* data); // link specify void - GetNeighbors(const float* query, std::vector& resset, std::vector& fullset, + GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset, boost::dynamic_bitset<>& has_calculated_dist); // FindUnconnectedNode void - GetNeighbors(const float* query, std::vector& resset, std::vector& fullset); + GetNeighbors(const float* query, float* data, std::vector& resset, std::vector& fullset); // navigation-point void - GetNeighbors(const float* query, std::vector& resset, Graph& graph, SearchParams* param = nullptr); + GetNeighbors(const float* query, float* data, std::vector& resset, Graph& graph, + SearchParams* param = nullptr); // only for search // void // GetNeighbors(const float* query, node_t* I, float* D, SearchParams* params); void - Link(); + Link(float* data); void - SyncPrune(size_t q, std::vector& pool, boost::dynamic_bitset<>& has_calculated, float* cut_graph_dist); + SyncPrune(float* data, size_t q, std::vector& pool, boost::dynamic_bitset<>& has_calculated, + float* cut_graph_dist); void - SelectEdge(unsigned& cursor, std::vector& sort_pool, std::vector& result, bool limit = false); + SelectEdge(float* data, unsigned& cursor, std::vector& sort_pool, std::vector& result, + bool limit = false); void - InterInsert(unsigned n, std::vector& mutex_vec, float* dist); + InterInsert(float* data, unsigned n, std::vector& mutex_vec, float* dist); void - CheckConnectivity(); + CheckConnectivity(float* data); void DFS(size_t root, boost::dynamic_bitset<>& flags, int64_t& count); void - FindUnconnectedNode(boost::dynamic_bitset<>& flags, int64_t& root); + FindUnconnectedNode(float* data, boost::dynamic_bitset<>& flags, int64_t& root); }; } // namespace impl diff --git a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.cpp b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.cpp index 880cca71ea4e..ff9e021134df 100644 --- a/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.cpp +++ b/core/src/index/knowhere/knowhere/index/vector_index/impl/nsg/NSGIO.cpp @@ -19,10 +19,11 @@ namespace impl { void write_index(NsgIndex* index, MemoryIOWriter& writer) { + writer(&index->metric_type, sizeof(int32_t), 1); writer(&index->ntotal, sizeof(index->ntotal), 1); writer(&index->dimension, sizeof(index->dimension), 1); writer(&index->navigation_point, sizeof(index->navigation_point), 1); - writer(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1); + // writer(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1); writer(index->ids_, sizeof(int64_t) * index->ntotal, 1); for (unsigned i = 0; i < index->ntotal; ++i) { @@ -36,14 +37,16 @@ NsgIndex* read_index(MemoryIOReader& reader) { size_t ntotal; size_t dimension; + int32_t metric; + reader(&metric, sizeof(int32_t), 1); reader(&ntotal, sizeof(size_t), 1); reader(&dimension, sizeof(size_t), 1); - auto index = new NsgIndex(dimension, ntotal); + auto index = new NsgIndex(dimension, ntotal, (impl::NsgIndex::Metric_Type)metric); reader(&index->navigation_point, sizeof(index->navigation_point), 1); - index->ori_data_ = new float[index->ntotal * index->dimension]; + // index->ori_data_ = new float[index->ntotal * index->dimension]; index->ids_ = new int64_t[index->ntotal]; - reader(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1); + // reader(index->ori_data_, sizeof(float) * index->ntotal * index->dimension, 1); reader(index->ids_, sizeof(int64_t) * index->ntotal, 1); index->nsg.reserve(index->ntotal); diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_NM.cpp new file mode 100644 index 000000000000..1f4055909db2 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_NM.cpp @@ -0,0 +1,199 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_offset_index/IndexHNSW_NM.h" + +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "hnswlib/space_ip.h" +#include "hnswlib/space_l2.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +// void +// normalize_vector(float* data, float* norm_array, size_t dim) { +// float norm = 0.0f; +// for (int i = 0; i < dim; i++) norm += data[i] * data[i]; +// norm = 1.0f / (sqrtf(norm) + 1e-30f); +// for (int i = 0; i < dim; i++) norm_array[i] = data[i] * norm; +// } + +BinarySet +IndexHNSW_NM::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + MemoryIOWriter writer; + index_->saveIndex(writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("HNSW", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW_NM::Load(const BinarySet& index_binary) { + try { + auto binary = index_binary.GetByName("HNSW"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + hnswlib_nm::SpaceInterface* space; + index_ = std::make_shared>(space); + index_->loadIndex(reader); + + normalize = (index_->metric_type_ == 1); // 1 == InnerProduct + + data_ = index_binary.GetByName(RAW_DATA)->data; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + int64_t dim = dataset_ptr->Get(meta::DIM); + int64_t rows = dataset_ptr->Get(meta::ROWS); + + hnswlib_nm::SpaceInterface* space; + if (config[Metric::TYPE] == Metric::L2) { + space = new hnswlib_nm::L2Space(dim); + } else if (config[Metric::TYPE] == Metric::IP) { + space = new hnswlib_nm::InnerProductSpace(dim); + normalize = true; + } + index_ = std::make_shared>( + space, rows, config[IndexParams::M].get(), config[IndexParams::efConstruction].get()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW_NM::Add(const DatasetPtr& dataset_ptr, const Config& config) { + // It will not call Query() just after Add() + // So, not to set 'data_' is allowed. + + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + + GET_TENSOR_DATA_ID(dataset_ptr) + + auto base = index_->getCurrentElementCount(); + auto pp_data = const_cast(p_data); + index_->addPoint(pp_data, p_ids[0], base, 0); +#pragma omp parallel for + for (int i = 1; i < rows; ++i) { + faiss::BuilderSuspend::check_wait(); + index_->addPoint(pp_data, p_ids[i], base, i); + } +} + +DatasetPtr +IndexHNSW_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + GET_TENSOR_DATA_DIM(dataset_ptr) + + size_t k = config[meta::TOPK].get(); + size_t id_size = sizeof(int64_t) * k; + size_t dist_size = sizeof(float) * k; + auto p_id = (int64_t*)malloc(id_size * rows); + auto p_dist = (float*)malloc(dist_size * rows); + + index_->setEf(config[IndexParams::ef]); + + using P = std::pair; + auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; }; + + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); +#pragma omp parallel for + for (unsigned int i = 0; i < rows; ++i) { + std::vector

ret; + const float* single_query = (float*)p_data + i * dim; + + ret = index_->searchKnn_NM((void*)single_query, k, compare, blacklist, (float*)(data_.get())); + + while (ret.size() < k) { + ret.emplace_back(std::make_pair(-1, -1)); + } + std::vector dist; + std::vector ids; + + if (normalize) { + std::transform(ret.begin(), ret.end(), std::back_inserter(dist), + [](const std::pair& e) { return float(1 - e.first); }); + } else { + std::transform(ret.begin(), ret.end(), std::back_inserter(dist), + [](const std::pair& e) { return e.first; }); + } + std::transform(ret.begin(), ret.end(), std::back_inserter(ids), + [](const std::pair& e) { return e.second; }); + + memcpy(p_dist + i * k, dist.data(), dist_size); + memcpy(p_id + i * k, ids.data(), id_size); + } + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} + +int64_t +IndexHNSW_NM::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->cur_element_count; +} + +int64_t +IndexHNSW_NM::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return (*(size_t*)index_->dist_func_param_); +} + +void +IndexHNSW_NM::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->cal_size() + Dim() * Count() * sizeof(float); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_NM.h new file mode 100644 index 000000000000..a9c13dcb4d0e --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_NM.h @@ -0,0 +1,69 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "hnswlib/hnswalg_nm.h" +#include "hnswlib/hnswlib_nm.h" + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class IndexHNSW_NM : public VecIndex { + public: + IndexHNSW_NM() { + index_type_ = IndexEnum::INDEX_HNSW; + } + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Add(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + + private: + bool normalize = false; + std::mutex mutex_; + std::shared_ptr> index_ = nullptr; + std::shared_ptr data_ = nullptr; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.cpp new file mode 100644 index 000000000000..119663c546d1 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.cpp @@ -0,0 +1,195 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.h" + +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "hnswlib/space_ip.h" +#include "hnswlib/space_l2.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" + +namespace milvus { +namespace knowhere { + +BinarySet +IndexHNSW_SQ8NM::Serialize(const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + MemoryIOWriter writer; + index_->saveIndex(writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("HNSW_SQ8", data, writer.rp); + res_set.Append(SQ8_DATA, data_, Dim() * (2 * sizeof(float) + Count())); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW_SQ8NM::Load(const BinarySet& index_binary) { + try { + auto binary = index_binary.GetByName("HNSW_SQ8"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + hnswlib_nm::SpaceInterface* space; + index_ = std::make_shared>(space); + index_->loadIndex(reader); + + normalize = (index_->metric_type_ == 1); // 1 == InnerProduct + + data_ = index_binary.GetByName(SQ8_DATA)->data; + index_->SetSq8((float*)(data_.get() + Dim() * Count())); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW_SQ8NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + try { + GET_TENSOR_DATA_DIM(dataset_ptr) + + hnswlib_nm::SpaceInterface* space; + if (config[Metric::TYPE] == Metric::L2) { + space = new hnswlib_nm::L2Space(dim); + } else if (config[Metric::TYPE] == Metric::IP) { + space = new hnswlib_nm::InnerProductSpace(dim); + normalize = true; + } + index_ = std::make_shared>( + space, rows, config[IndexParams::M].get(), config[IndexParams::efConstruction].get()); + auto data_space = new uint8_t[dim * (rows + 2 * sizeof(float))]; + index_->sq_train(rows, (const float*)p_data, data_space); + data_ = std::shared_ptr(data_space); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +IndexHNSW_SQ8NM::Add(const DatasetPtr& dataset_ptr, const Config& config) { + // It will not call Query() just after Add() + // So, not to set 'data_' is allowed. + + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + + std::lock_guard lk(mutex_); + + GET_TENSOR_DATA_ID(dataset_ptr) + + auto base = index_->getCurrentElementCount(); + auto pp_data = const_cast(p_data); + index_->addPoint(pp_data, p_ids[0], base, 0); +#pragma omp parallel for + for (int i = 1; i < rows; ++i) { + faiss::BuilderSuspend::check_wait(); + index_->addPoint(pp_data, p_ids[i], base, i); + } +} + +DatasetPtr +IndexHNSW_SQ8NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + GET_TENSOR_DATA_DIM(dataset_ptr) + + size_t k = config[meta::TOPK].get(); + size_t id_size = sizeof(int64_t) * k; + size_t dist_size = sizeof(float) * k; + auto p_id = (int64_t*)malloc(id_size * rows); + auto p_dist = (float*)malloc(dist_size * rows); + + index_->setEf(config[IndexParams::ef]); + + using P = std::pair; + auto compare = [](const P& v1, const P& v2) { return v1.first < v2.first; }; + + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); +#pragma omp parallel for + for (unsigned int i = 0; i < rows; ++i) { + std::vector

ret; + const float* single_query = (float*)p_data + i * dim; + + ret = index_->searchKnn_NM((void*)single_query, k, compare, blacklist, (float*)(data_.get())); + + while (ret.size() < k) { + ret.emplace_back(std::make_pair(-1, -1)); + } + std::vector dist; + std::vector ids; + + if (normalize) { + std::transform(ret.begin(), ret.end(), std::back_inserter(dist), + [](const std::pair& e) { return float(1 - e.first); }); + } else { + std::transform(ret.begin(), ret.end(), std::back_inserter(dist), + [](const std::pair& e) { return e.first; }); + } + std::transform(ret.begin(), ret.end(), std::back_inserter(ids), + [](const std::pair& e) { return e.second; }); + + memcpy(p_dist + i * k, dist.data(), dist_size); + memcpy(p_id + i * k, ids.data(), id_size); + } + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; +} + +int64_t +IndexHNSW_SQ8NM::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->cur_element_count; +} + +int64_t +IndexHNSW_SQ8NM::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return (*(size_t*)index_->dist_func_param_); +} + +void +IndexHNSW_SQ8NM::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->cal_size() + Dim() * (2 * sizeof(float) + Count()); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.h new file mode 100644 index 000000000000..dfe1eecc8fe5 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.h @@ -0,0 +1,68 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include + +#include "hnswlib/hnswalg_nm.h" + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +class IndexHNSW_SQ8NM : public VecIndex { + public: + IndexHNSW_SQ8NM() { + index_type_ = IndexEnum::INDEX_HNSW_SQ8NM; + } + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet& index_binary) override; + + void + Train(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + Add(const DatasetPtr& dataset_ptr, const Config& config) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + DatasetPtr + Query(const DatasetPtr& dataset_ptr, const Config& config) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + + private: + bool normalize = false; + std::mutex mutex_; + std::shared_ptr> index_ = nullptr; + std::shared_ptr data_ = nullptr; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVFSQNR_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVFSQNR_NM.cpp new file mode 100644 index 000000000000..78dc1dc3025a --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVFSQNR_NM.cpp @@ -0,0 +1,223 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#include +#endif +#include +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +namespace milvus { +namespace knowhere { + +BinarySet +IVFSQNR_NM::Serialize(const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + BinarySet res_set = SerializeImpl(index_type_); + + size_t d = index_->d; + auto ivfsq_index = dynamic_cast(index_.get()); + auto invlists = ivfsq_index->invlists; + auto rows = invlists->compute_ntotal(); + auto sq = ivfsq_index->sq; + auto code_size = ivfsq_index->code_size; + auto arranged_data = new uint8_t[code_size * rows + 2 * d * sizeof(float)]; + size_t curr_index = 0; + + // convert arranged sq8 data to sq8 data + auto ails = dynamic_cast(invlists); + for (size_t i = 0; i < invlists->nlist; i++) { + auto list_size = ails->ids[i].size(); + for (size_t j = 0; j < list_size; j++) { + memcpy(arranged_data + code_size * ails->ids[i][j], data_.get() + code_size * (curr_index + j), code_size); + } + curr_index += list_size; + } + + memcpy(arranged_data + code_size * curr_index, sq.trained.data(), 2 * d * sizeof(float)); + + res_set.Append(SQ8_DATA, std::shared_ptr(arranged_data), + code_size * rows * sizeof(uint8_t) + 2 * d * sizeof(float)); + return res_set; +} + +void +IVFSQNR_NM::Load(const BinarySet& binary_set) { + std::lock_guard lk(mutex_); + data_ = binary_set.GetByName(SQ8_DATA)->data; + LoadImpl(binary_set, index_type_); + // arrange sq8 data + auto ivfsq_index = dynamic_cast(index_.get()); + auto invlists = ivfsq_index->invlists; + auto rows = invlists->compute_ntotal(); + auto sq = ivfsq_index->sq; + auto code_size = sq.code_size; + auto d = sq.d; + auto arranged_data = new uint8_t[code_size * rows + 2 * d * sizeof(float)]; + prefix_sum.resize(invlists->nlist); + size_t curr_index = 0; + +#ifndef MILVUS_GPU_VERSION + auto ails = dynamic_cast(invlists); + for (size_t i = 0; i < invlists->nlist; i++) { + auto list_size = ails->ids[i].size(); + for (size_t j = 0; j < list_size; j++) { + memcpy(arranged_data + code_size * (curr_index + j), data_.get() + code_size * ails->ids[i][j], code_size); + } + prefix_sum[i] = curr_index; + curr_index += list_size; + } +#else + auto rol = dynamic_cast(invlists); + auto lengths = rol->readonly_length; + auto rol_ids = (const int64_t*)rol->pin_readonly_ids->data; + for (size_t i = 0; i < invlists->nlist; i++) { + auto list_size = lengths[i]; + for (size_t j = 0; j < list_size; j++) { + memcpy(arranged_data + code_size * (curr_index + j), data_.get() + code_size * rol_ids[curr_index + j], + code_size); + } + prefix_sum[i] = curr_index; + curr_index += list_size; + } +#endif + memcpy(arranged_data + code_size * curr_index, sq.trained.data(), 2 * d * sizeof(float)); + data_ = std::shared_ptr(arranged_data); +} + +void +IVFSQNR_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + index_ = std::shared_ptr( + new faiss::IndexIVFScalarQuantizer(coarse_quantizer, dim, config[IndexParams::nlist].get(), + faiss::QuantizerType::QT_8bit, metric_type, false)); + + index_->train(rows, (float*)p_data); +} + +void +IVFSQNR_NM::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA_ID(dataset_ptr) + index_->add_with_ids_without_codes(rows, (float*)p_data, p_ids); + + ArrangeCodes(dataset_ptr, config); +} + +void +IVFSQNR_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA(dataset_ptr) + index_->add_without_codes(rows, (float*)p_data); + + ArrangeCodes(dataset_ptr, config); +} + +VecIndexPtr +IVFSQNR_NM::CopyCpuToGpu(const int64_t device_id, const Config& config) { +#ifdef MILVUS_GPU_VERSION + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + + auto gpu_index = + faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(), data_.get()); + + std::shared_ptr device_index; + device_index.reset(gpu_index); + return std::make_shared(device_index, device_id, res); + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); + } +#else + KNOWHERE_THROW_MSG("Calling IVFSQNR_NM::CopyCpuToGpu when we are using CPU version"); +#endif +} + +void +IVFSQNR_NM::ArrangeCodes(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + + // Construct arranged sq8 data from original data + const float* original_data = (const float*)p_data; + auto ivfsq_index = dynamic_cast(index_.get()); + auto sq = ivfsq_index->sq; + auto invlists = ivfsq_index->invlists; + auto code_size = sq.code_size; + auto arranged_data = new uint8_t[code_size * rows + 2 * dim * sizeof(float)]; + std::unique_ptr squant(sq.select_quantizer()); + std::vector one_code(code_size); + size_t curr_index = 0; + + auto ails = dynamic_cast(invlists); + for (size_t i = 0; i < invlists->nlist; i++) { + auto list_size = ails->ids[i].size(); + for (size_t j = 0; j < list_size; j++) { + const float* x_j = original_data + dim * ails->ids[i][j]; + + memset(one_code.data(), 0, code_size); + squant->encode_vector(x_j, one_code.data()); + memcpy(arranged_data + code_size * (curr_index + j), one_code.data(), code_size); + } + curr_index += list_size; + } + + memcpy(arranged_data + code_size * curr_index, sq.trained.data(), 2 * dim * sizeof(float)); + data_ = std::shared_ptr(arranged_data); +} + +void +IVFSQNR_NM::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivfsq_index = dynamic_cast(index_.get()); + auto nb = ivfsq_index->invlists->compute_ntotal(); + auto code_size = ivfsq_index->code_size; + auto nlist = ivfsq_index->nlist; + auto d = ivfsq_index->d; + // ivf codes, ivf ids, sq trained vectors and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + 2 * d * sizeof(float) + nlist * d * sizeof(float); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h new file mode 100644 index 000000000000..c622337672e3 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h @@ -0,0 +1,65 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" + +namespace milvus { +namespace knowhere { + +class IVFSQNR_NM : public IVF_NM { + public: + IVFSQNR_NM() : IVF_NM() { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8NR; + } + + explicit IVFSQNR_NM(std::shared_ptr index) : IVF_NM(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8NR; + } + + explicit IVFSQNR_NM(std::shared_ptr index, uint8_t* data) : IVF_NM(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8NR; + data_ = std::shared_ptr(data, [&](uint8_t*) {}); + } + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet&) override; + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyCpuToGpu(const int64_t, const Config&) override; + + void + ArrangeCodes(const DatasetPtr&, const Config&); + + void + UpdateIndexSize() override; +}; + +using IVFSQNRNMPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp new file mode 100644 index 000000000000..aa7d9ea0bb38 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp @@ -0,0 +1,368 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef MILVUS_GPU_VERSION +#include +#include +#endif + +#include +#include +#include +#include +#include +#include + +#include "faiss/BuilderSuspend.h" +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +namespace milvus { +namespace knowhere { + +using stdclock = std::chrono::high_resolution_clock; + +BinarySet +IVF_NM::Serialize(const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + return SerializeImpl(index_type_); +} + +void +IVF_NM::Load(const BinarySet& binary_set) { + std::lock_guard lk(mutex_); + LoadImpl(binary_set, index_type_); + + // Construct arranged data from original data + auto binary = binary_set.GetByName(RAW_DATA); + const float* original_data = (const float*)binary->data.get(); + auto ivf_index = dynamic_cast(index_.get()); + auto invlists = ivf_index->invlists; + auto d = ivf_index->d; + auto nb = (size_t)(binary->size / invlists->code_size); + auto arranged_data = new uint8_t[d * sizeof(float) * nb]; + prefix_sum.resize(invlists->nlist); + size_t curr_index = 0; + +#ifndef MILVUS_GPU_VERSION + auto ails = dynamic_cast(invlists); + for (size_t i = 0; i < invlists->nlist; i++) { + auto list_size = ails->ids[i].size(); + for (size_t j = 0; j < list_size; j++) { + memcpy(arranged_data + d * sizeof(float) * (curr_index + j), original_data + d * ails->ids[i][j], + d * sizeof(float)); + } + prefix_sum[i] = curr_index; + curr_index += list_size; + } +#else + auto rol = dynamic_cast(invlists); + auto lengths = rol->readonly_length; + auto rol_ids = (const int64_t*)rol->pin_readonly_ids->data; + for (size_t i = 0; i < invlists->nlist; i++) { + auto list_size = lengths[i]; + for (size_t j = 0; j < list_size; j++) { + memcpy(arranged_data + d * sizeof(float) * (curr_index + j), original_data + d * rol_ids[curr_index + j], + d * sizeof(float)); + } + prefix_sum[i] = curr_index; + curr_index += list_size; + } +#endif + data_ = std::shared_ptr(arranged_data); +} + +void +IVF_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + int64_t nlist = config[IndexParams::nlist].get(); + index_ = std::shared_ptr(new faiss::IndexIVFFlat(coarse_quantizer, dim, nlist, metric_type)); + index_->train(rows, (float*)p_data); +} + +void +IVF_NM::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA_ID(dataset_ptr) + index_->add_with_ids_without_codes(rows, (float*)p_data, p_ids); +} + +void +IVF_NM::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + std::lock_guard lk(mutex_); + GET_TENSOR_DATA(dataset_ptr) + index_->add_without_codes(rows, (float*)p_data); +} + +DatasetPtr +IVF_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_TENSOR_DATA(dataset_ptr) + + try { + fiu_do_on("IVF_NM.Search.throw_std_exception", throw std::exception()); + fiu_do_on("IVF_NM.Search.throw_faiss_exception", throw faiss::FaissException("")); + int64_t k = config[meta::TOPK].get(); + auto elems = rows * k; + + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + + QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +#if 0 +DatasetPtr +IVF_NM::QueryById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + auto rows = dataset_ptr->Get(meta::ROWS); + auto p_data = dataset_ptr->Get(meta::IDS); + + try { + int64_t k = config[meta::TOPK].get(); + auto elems = rows * k; + + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + + // todo: enable search by id (zhiru) + // auto blacklist = dataset_ptr->Get("bitset"); + auto index_ivf = std::static_pointer_cast(index_); + index_ivf->search_by_id(rows, p_data, k, p_dist, p_id, bitset_); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +DatasetPtr +IVF_NM::GetVectorById(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + auto p_data = dataset_ptr->Get(meta::IDS); + auto elems = dataset_ptr->Get(meta::DIM); + + try { + size_t p_x_size = sizeof(float) * elems; + auto p_x = (float*)malloc(p_x_size); + + auto index_ivf = std::static_pointer_cast(index_); + index_ivf->get_vector_by_id(1, p_data, p_x, bitset_); + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::TENSOR, p_x); + return ret_ds; + } catch (faiss::FaissException& e) { + KNOWHERE_THROW_MSG(e.what()); + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} +#endif + +void +IVF_NM::Seal() { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + SealImpl(); +} + +VecIndexPtr +IVF_NM::CopyCpuToGpu(const int64_t device_id, const Config& config) { +#ifdef MILVUS_GPU_VERSION + if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { + ResScope rs(res, device_id, false); + auto gpu_index = + faiss::gpu::index_cpu_to_gpu_without_codes(res->faiss_res.get(), device_id, index_.get(), data_.get()); + + std::shared_ptr device_index; + device_index.reset(gpu_index); + return std::make_shared(device_index, device_id, res); + } else { + KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); + } + +#else + KNOWHERE_THROW_MSG("Calling IVF_NM::CopyCpuToGpu when we are using CPU version"); +#endif +} + +void +IVF_NM::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config) { + int64_t K = k + 1; + auto ntotal = Count(); + + size_t dim = config[meta::DIM]; + auto batch_size = 1000; + auto tail_batch_size = ntotal % batch_size; + auto batch_search_count = ntotal / batch_size; + auto total_search_count = tail_batch_size == 0 ? batch_search_count : batch_search_count + 1; + + std::vector res_dis(K * batch_size); + graph.resize(ntotal); + GraphType res_vec(total_search_count); + for (int i = 0; i < total_search_count; ++i) { + // it is usually used in NSG::train, to check BuilderSuspend + faiss::BuilderSuspend::check_wait(); + + auto b_size = (i == (total_search_count - 1)) && tail_batch_size != 0 ? tail_batch_size : batch_size; + + auto& res = res_vec[i]; + res.resize(K * b_size); + + auto xq = data + batch_size * dim * i; + QueryImpl(b_size, (float*)xq, K, res_dis.data(), res.data(), config); + + for (int j = 0; j < b_size; ++j) { + auto& node = graph[batch_size * i + j]; + node.resize(k); + auto start_pos = j * K + 1; + for (int m = 0, cursor = start_pos; m < k && cursor < start_pos + k; ++m, ++cursor) { + node[m] = res[cursor]; + } + } + } +} + +std::shared_ptr +IVF_NM::GenParams(const Config& config) { + auto params = std::make_shared(); + params->nprobe = config[IndexParams::nprobe]; + // params->max_codes = config["max_codes"]; + return params; +} + +void +IVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + auto params = GenParams(config); + auto ivf_index = dynamic_cast(index_.get()); + ivf_index->nprobe = params->nprobe; + stdclock::time_point before = stdclock::now(); + if (params->nprobe > 1 && n <= 4) { + ivf_index->parallel_mode = 1; + } else { + ivf_index->parallel_mode = 0; + } + bool is_sq8 = + (index_type_ == IndexEnum::INDEX_FAISS_IVFSQ8 || index_type_ == IndexEnum::INDEX_FAISS_IVFSQ8NR) ? true : false; + ivf_index->search_without_codes(n, (float*)data, (const uint8_t*)data_.get(), prefix_sum, is_sq8, k, distances, + labels, bitset_); + stdclock::time_point after = stdclock::now(); + double search_cost = (std::chrono::duration(after - before)).count(); + LOG_KNOWHERE_DEBUG_ << "IVF_NM search cost: " << search_cost + << ", quantization cost: " << faiss::indexIVF_stats.quantization_time + << ", data search cost: " << faiss::indexIVF_stats.search_time; + faiss::indexIVF_stats.quantization_time = 0; + faiss::indexIVF_stats.search_time = 0; +} + +void +IVF_NM::SealImpl() { +#ifdef MILVUS_GPU_VERSION + faiss::Index* index = index_.get(); + auto idx = dynamic_cast(index); + if (idx != nullptr) { + idx->to_readonly_without_codes(); + } +#endif +} + +int64_t +IVF_NM::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +IVF_NM::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->d; +} + +void +IVF_NM::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + auto ivf_index = dynamic_cast(index_.get()); + auto nb = ivf_index->invlists->compute_ntotal(); + auto nlist = ivf_index->nlist; + auto code_size = ivf_index->code_size; + // ivf codes, ivf ids and quantizer + index_size_ = nb * code_size + nb * sizeof(int64_t) + nlist * code_size; +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h new file mode 100644 index 000000000000..f9702f0e6783 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.h @@ -0,0 +1,103 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include +#include +#include + +#include + +#include "knowhere/common/Typedef.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_offset_index/OffsetBaseIndex.h" + +namespace milvus { +namespace knowhere { + +class IVF_NM : public VecIndex, public OffsetBaseIndex { + public: + IVF_NM() : OffsetBaseIndex(nullptr) { + index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT; + } + + explicit IVF_NM(std::shared_ptr index) : OffsetBaseIndex(std::move(index)) { + index_type_ = IndexEnum::INDEX_FAISS_IVFFLAT; + } + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet&) override; + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + void + AddWithoutIds(const DatasetPtr&, const Config&) override; + + DatasetPtr + Query(const DatasetPtr&, const Config&) override; + +#if 0 + DatasetPtr + QueryById(const DatasetPtr& dataset, const Config& config) override; +#endif + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + +#if 0 + DatasetPtr + GetVectorById(const DatasetPtr& dataset, const Config& config) override; +#endif + + virtual void + Seal(); + + virtual VecIndexPtr + CopyCpuToGpu(const int64_t, const Config&); + + virtual void + GenGraph(const float* data, const int64_t k, GraphType& graph, const Config& config); + + protected: + virtual std::shared_ptr + GenParams(const Config&); + + virtual void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&); + + void + SealImpl() override; + + protected: + std::mutex mutex_; + std::shared_ptr data_ = nullptr; + std::vector prefix_sum; +}; + +using IVFNMPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp new file mode 100644 index 000000000000..0adeac662c9f --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp @@ -0,0 +1,185 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIDMAP.h" +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/impl/nsg/NSGIO.h" +#include "knowhere/index/vector_offset_index/IndexNSG_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" +#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#endif + +namespace milvus { +namespace knowhere { + +BinarySet +NSG_NM::Serialize(const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + fiu_do_on("NSG_NM.Serialize.throw_exception", throw std::exception()); + std::lock_guard lk(mutex_); + impl::NsgIndex* index = index_.get(); + + MemoryIOWriter writer; + impl::write_index(index, writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("NSG_NM", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +NSG_NM::Load(const BinarySet& index_binary) { + try { + fiu_do_on("NSG_NM.Load.throw_exception", throw std::exception()); + std::lock_guard lk(mutex_); + auto binary = index_binary.GetByName("NSG_NM"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + auto index = impl::read_index(reader); + index_.reset(index); + + data_ = index_binary.GetByName(RAW_DATA)->data; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +DatasetPtr +NSG_NM::Query(const DatasetPtr& dataset_ptr, const Config& config) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + GET_TENSOR_DATA_DIM(dataset_ptr) + + try { + auto topK = config[meta::TOPK].get(); + auto elems = rows * topK; + size_t p_id_size = sizeof(int64_t) * elems; + size_t p_dist_size = sizeof(float) * elems; + auto p_id = (int64_t*)malloc(p_id_size); + auto p_dist = (float*)malloc(p_dist_size); + + faiss::ConcurrentBitsetPtr blacklist = GetBlacklist(); + + impl::SearchParams s_params; + s_params.search_length = config[IndexParams::search_length]; + s_params.k = config[meta::TOPK]; + { + std::lock_guard lk(mutex_); + // index_->ori_data_ = (float*) data_.get(); + index_->Search((float*)p_data, (float*)data_.get(), rows, dim, topK, p_dist, p_id, s_params, blacklist); + } + + auto ret_ds = std::make_shared(); + ret_ds->Set(meta::IDS, p_id); + ret_ds->Set(meta::DISTANCE, p_dist); + return ret_ds; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +NSG_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + auto idmap = std::make_shared(); + idmap->Train(dataset_ptr, config); + idmap->AddWithoutIds(dataset_ptr, config); + impl::Graph knng; + const float* raw_data = idmap->GetRawVectors(); + const int64_t k = config[IndexParams::knng].get(); +#ifdef MILVUS_GPU_VERSION + const int64_t device_id = config[knowhere::meta::DEVICEID].get(); + if (device_id == -1) { + auto preprocess_index = std::make_shared(); + preprocess_index->Train(dataset_ptr, config); + preprocess_index->AddWithoutIds(dataset_ptr, config); + preprocess_index->GenGraph(raw_data, k, knng, config); + } else { + auto gpu_idx = cloner::CopyCpuToGpu(idmap, device_id, config); + auto gpu_idmap = std::dynamic_pointer_cast(gpu_idx); + gpu_idmap->GenGraph(raw_data, k, knng, config); + } +#else + auto preprocess_index = std::make_shared(); + preprocess_index->Train(dataset_ptr, config); + preprocess_index->AddWithoutIds(dataset_ptr, config); + preprocess_index->GenGraph(raw_data, k, knng, config); +#endif + + impl::BuildParams b_params; + b_params.candidate_pool_size = config[IndexParams::candidate]; + b_params.out_degree = config[IndexParams::out_degree]; + b_params.search_length = config[IndexParams::search_length]; + + auto p_ids = dataset_ptr->Get(meta::IDS); + + GET_TENSOR_DATA_DIM(dataset_ptr) + impl::NsgIndex::Metric_Type metric_type_nsg; + if (config[Metric::TYPE].get() == "IP") { + metric_type_nsg = impl::NsgIndex::Metric_Type::Metric_Type_IP; + } else if (config[Metric::TYPE].get() == "L2") { + metric_type_nsg = impl::NsgIndex::Metric_Type::Metric_Type_L2; + } else { + KNOWHERE_THROW_MSG("either IP or L2"); + } + index_ = std::make_shared(dim, rows, metric_type_nsg); + index_->SetKnnGraph(knng); + index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params); +} + +int64_t +NSG_NM::Count() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->ntotal; +} + +int64_t +NSG_NM::Dim() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + return index_->dimension; +} + +void +NSG_NM::UpdateIndexSize() { + if (!index_) { + KNOWHERE_THROW_MSG("index not initialize"); + } + index_size_ = index_->GetSize() + Dim() * Count() * sizeof(float); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h new file mode 100644 index 000000000000..5ae9ddade965 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.h @@ -0,0 +1,83 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" +#include "knowhere/index/vector_index/VecIndex.h" + +namespace milvus { +namespace knowhere { + +namespace impl { +class NsgIndex; +} + +class NSG_NM : public VecIndex { + public: + explicit NSG_NM(const int64_t gpu_num = -1) : gpu_(gpu_num) { + if (gpu_ >= 0) { + index_mode_ = IndexMode::MODE_GPU; + } + index_type_ = IndexEnum::INDEX_NSG; + } + + BinarySet + Serialize(const Config& config = Config()) override; + + void + Load(const BinarySet&) override; + + void + BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override { + Train(dataset_ptr, config); + } + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Incremental index is not supported"); + } + + void + AddWithoutIds(const DatasetPtr&, const Config&) override { + KNOWHERE_THROW_MSG("Addwithoutids is not supported"); + } + + DatasetPtr + Query(const DatasetPtr&, const Config&) override; + + int64_t + Count() override; + + int64_t + Dim() override; + + void + UpdateIndexSize() override; + + private: + std::mutex mutex_; + int64_t gpu_; + std::shared_ptr index_ = nullptr; + std::shared_ptr data_ = nullptr; +}; + +using NSG_NMIndexPtr = std::shared_ptr(); + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp new file mode 100644 index 000000000000..186eac4e4345 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "knowhere/index/vector_offset_index/OffsetBaseIndex.h" + +namespace milvus { +namespace knowhere { + +BinarySet +OffsetBaseIndex::SerializeImpl(const IndexType& type) { + try { + fiu_do_on("OffsetBaseIndex.SerializeImpl.throw_exception", throw std::exception()); + faiss::Index* index = index_.get(); + + MemoryIOWriter writer; + faiss::write_index_nm(index, &writer); + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("IVF", data, writer.rp); + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +OffsetBaseIndex::LoadImpl(const BinarySet& binary_set, const IndexType& type) { + auto binary = binary_set.GetByName("IVF"); + + MemoryIOReader reader; + reader.total = binary->size; + reader.data_ = binary->data.get(); + + faiss::Index* index = faiss::read_index_nm(&reader); + index_.reset(index); + + SealImpl(); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.h new file mode 100644 index 000000000000..029d8f9f6947 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.h @@ -0,0 +1,45 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include + +#include "knowhere/common/BinarySet.h" +#include "knowhere/index/IndexType.h" + +namespace milvus { +namespace knowhere { + +class OffsetBaseIndex { + protected: + explicit OffsetBaseIndex(std::shared_ptr index) : index_(std::move(index)) { + } + + virtual BinarySet + SerializeImpl(const IndexType& type); + + virtual void + LoadImpl(const BinarySet&, const IndexType& type); + + virtual void + SealImpl() { /* do nothing */ + } + + public: + std::shared_ptr index_ = nullptr; +}; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.cpp new file mode 100644 index 000000000000..23a3b304e84c --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.cpp @@ -0,0 +1,75 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.h" + +namespace milvus { +namespace knowhere { + +void +GPUIVFSQNR_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + gpu_id_ = config[knowhere::meta::DEVICEID]; + + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + faiss::Index* coarse_quantizer = new faiss::IndexFlat(dim, metric_type); + auto build_index = std::shared_ptr( + new faiss::IndexIVFScalarQuantizer(coarse_quantizer, dim, config[IndexParams::nlist].get(), + faiss::QuantizerType::QT_8bit, metric_type, false)); + + auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (gpu_res != nullptr) { + ResScope rs(gpu_res, gpu_id_, true); + auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index.get()); + device_index->train(rows, (float*)p_data); + + index_.reset(device_index); + res_ = gpu_res; + } else { + KNOWHERE_THROW_MSG("Build IVFSQNR can't get gpu resource"); + } +} + +VecIndexPtr +GPUIVFSQNR_NM::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); + + faiss::Index* device_index = index_.get(); + faiss::Index* host_index_codes = faiss::gpu::index_gpu_to_cpu(device_index); + auto ivfsq_index = dynamic_cast(host_index_codes); + auto ail = dynamic_cast(ivfsq_index->invlists); + auto rol = dynamic_cast(ail->to_readonly()); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu_without_codes(device_index); + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index, (uint8_t*)rol->get_all_codes()); +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.h new file mode 100644 index 000000000000..1cf724404313 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.h @@ -0,0 +1,43 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" + +namespace milvus { +namespace knowhere { + +class GPUIVFSQNR_NM : public GPUIVF_NM { + public: + explicit GPUIVFSQNR_NM(const int& device_id) : GPUIVF_NM(device_id) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8; + } + + explicit GPUIVFSQNR_NM(std::shared_ptr index, const int64_t device_id, ResPtr& res) + : GPUIVF_NM(std::move(index), device_id, res) { + index_type_ = IndexEnum::INDEX_FAISS_IVFSQ8; + } + + void + Train(const DatasetPtr&, const Config&) override; + + VecIndexPtr + CopyGpuToCpu(const Config&) override; +}; + +using GPUIVFSQNRNMPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp new file mode 100644 index 000000000000..0f697bd7e635 --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp @@ -0,0 +1,141 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include +#include +#include +#include +#include +#include + +#include "knowhere/common/Exception.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" + +namespace milvus { +namespace knowhere { + +void +GPUIVF_NM::Train(const DatasetPtr& dataset_ptr, const Config& config) { + GET_TENSOR_DATA_DIM(dataset_ptr) + gpu_id_ = config[knowhere::meta::DEVICEID]; + + auto gpu_res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id_); + if (gpu_res != nullptr) { + ResScope rs(gpu_res, gpu_id_, true); + faiss::gpu::GpuIndexIVFFlatConfig idx_config; + idx_config.device = gpu_id_; + int32_t nlist = config[IndexParams::nlist]; + faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get()); + auto device_index = + new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config); + device_index->train(rows, (float*)p_data); + + index_.reset(device_index); + res_ = gpu_res; + } else { + KNOWHERE_THROW_MSG("Build IVF can't get gpu resource"); + } +} + +void +GPUIVF_NM::Add(const DatasetPtr& dataset_ptr, const Config& config) { + if (auto spt = res_.lock()) { + ResScope rs(res_, gpu_id_); + IVF::Add(dataset_ptr, config); + } else { + KNOWHERE_THROW_MSG("Add IVF can't get gpu resource"); + } +} + +void +GPUIVF_NM::Load(const BinarySet& binary_set) { + // not supported +} + +VecIndexPtr +GPUIVF_NM::CopyGpuToCpu(const Config& config) { + std::lock_guard lk(mutex_); + + if (auto device_idx = std::dynamic_pointer_cast(index_)) { + faiss::Index* device_index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu_without_codes(device_index); + + std::shared_ptr new_index; + new_index.reset(host_index); + return std::make_shared(new_index); + } else { + return std::make_shared(index_); + } +} + +VecIndexPtr +GPUIVF_NM::CopyGpuToGpu(const int64_t device_id, const Config& config) { + auto host_index = CopyGpuToCpu(config); + return std::static_pointer_cast(host_index)->CopyCpuToGpu(device_id, config); +} + +BinarySet +GPUIVF_NM::SerializeImpl(const IndexType& type) { + if (!index_ || !index_->is_trained) { + KNOWHERE_THROW_MSG("index not initialize or trained"); + } + + try { + fiu_do_on("GPUIVF_NM.SerializeImpl.throw_exception", throw std::exception()); + MemoryIOWriter writer; + { + faiss::Index* index = index_.get(); + faiss::Index* host_index = faiss::gpu::index_gpu_to_cpu_without_codes(index); + faiss::write_index_nm(host_index, &writer); + delete host_index; + } + std::shared_ptr data(writer.data_); + + BinarySet res_set; + res_set.Append("IVF", data, writer.rp); + + return res_set; + } catch (std::exception& e) { + KNOWHERE_THROW_MSG(e.what()); + } +} + +void +GPUIVF_NM::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) { + std::lock_guard lk(mutex_); + + auto device_index = std::dynamic_pointer_cast(index_); + fiu_do_on("GPUIVF_NM.search_impl.invald_index", device_index = nullptr); + if (device_index) { + device_index->nprobe = config[IndexParams::nprobe]; + ResScope rs(res_, gpu_id_); + + // if query size > 2048 we search by blocks to avoid malloc issue + const int64_t block_size = 2048; + int64_t dim = device_index->d; + for (int64_t i = 0; i < n; i += block_size) { + int64_t search_size = (n - i > block_size) ? block_size : (n - i); + device_index->search(search_size, (float*)data + i * dim, k, distances + i * k, labels + i * k, bitset_); + } + } else { + KNOWHERE_THROW_MSG("Not a GpuIndexIVF type."); + } +} + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h new file mode 100644 index 000000000000..7b4254f200ab --- /dev/null +++ b/core/src/index/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h @@ -0,0 +1,63 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once + +#include +#include + +#include "knowhere/index/vector_index/IndexIVF.h" +#include "knowhere/index/vector_index/gpu/GPUIndex.h" + +namespace milvus { +namespace knowhere { + +class GPUIVF_NM : public IVF, public GPUIndex { + public: + explicit GPUIVF_NM(const int& device_id) : IVF(), GPUIndex(device_id) { + index_mode_ = IndexMode::MODE_GPU; + } + + explicit GPUIVF_NM(std::shared_ptr index, const int64_t device_id, ResPtr& res) + : IVF(std::move(index)), GPUIndex(device_id, res) { + index_mode_ = IndexMode::MODE_GPU; + } + + void + Train(const DatasetPtr&, const Config&) override; + + void + Add(const DatasetPtr&, const Config&) override; + + void + Load(const BinarySet&) override; + + VecIndexPtr + CopyGpuToCpu(const Config&) override; + + VecIndexPtr + CopyGpuToGpu(const int64_t, const Config&) override; + + protected: + BinarySet + SerializeImpl(const IndexType&) override; + + void + QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override; + + protected: + uint8_t* arranged_data; +}; + +using GPUIVFNMPtr = std::shared_ptr; + +} // namespace knowhere +} // namespace milvus diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/Index.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/Index.h index 0722afc1a8c1..e4c52586c7ab 100644 --- a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/Index.h +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/BKT/Index.h @@ -79,6 +79,7 @@ namespace SPTAG ~Index() {} inline SizeType GetNumSamples() const { return m_pSamples.R(); } + inline SizeType GetIndexSize() const { return sizeof(*this); } inline DimensionType GetFeatureDim() const { return m_pSamples.C(); } inline int GetCurrMaxCheck() const { return m_iMaxCheck; } diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/Index.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/Index.h index 668d423b5240..f3240ebdb211 100644 --- a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/Index.h +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/KDT/Index.h @@ -79,6 +79,7 @@ namespace SPTAG ~Index() {} inline SizeType GetNumSamples() const { return m_pSamples.R(); } + inline SizeType GetIndexSize() const { return sizeof(*this); } inline DimensionType GetFeatureDim() const { return m_pSamples.C(); } inline int GetCurrMaxCheck() const { return m_iMaxCheck; } diff --git a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorIndex.h b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorIndex.h index 49475794d5f2..b93caf0a9e71 100644 --- a/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorIndex.h +++ b/core/src/index/thirdparty/SPTAG/AnnService/inc/Core/VectorIndex.h @@ -37,6 +37,7 @@ class VectorIndex virtual DimensionType GetFeatureDim() const = 0; virtual SizeType GetNumSamples() const = 0; + virtual SizeType GetIndexSize() const = 0; virtual DistCalcMethod GetDistCalcMethod() const = 0; virtual IndexAlgoType GetIndexAlgoType() const = 0; diff --git a/core/src/index/thirdparty/annoy/src/annoylib.h b/core/src/index/thirdparty/annoy/src/annoylib.h index 00058099c9a0..605137a8ed85 100644 --- a/core/src/index/thirdparty/annoy/src/annoylib.h +++ b/core/src/index/thirdparty/annoy/src/annoylib.h @@ -850,6 +850,7 @@ class AnnoyIndexInterface { virtual void get_item(S item, T* v) const = 0; virtual void set_seed(int q) = 0; virtual bool on_disk_build(const char* filename, char** error=nullptr) = 0; + virtual int64_t cal_size() = 0; }; template @@ -1396,6 +1397,14 @@ template result->push_back(nns_dist[i].second); } } + + int64_t cal_size() { + int64_t ret = 0; + ret += sizeof(*this); + ret += _roots.size() * sizeof(S); + ret += std::max(_n_nodes, _nodes_size) * _s; + return ret; + } }; #endif diff --git a/core/src/index/thirdparty/faiss/AutoTune.h b/core/src/index/thirdparty/faiss/AutoTune.h index d7eff14e64b9..d755844d6d92 100644 --- a/core/src/index/thirdparty/faiss/AutoTune.h +++ b/core/src/index/thirdparty/faiss/AutoTune.h @@ -28,7 +28,7 @@ struct AutoTuneCriterion { typedef Index::idx_t idx_t; idx_t nq; ///< nb of queries this criterion is evaluated on idx_t nnn; ///< nb of NNs that the query should request - idx_t gt_nnn; ///< nb of GT NNs required to evaluate crterion + idx_t gt_nnn; ///< nb of GT NNs required to evaluate criterion std::vector gt_D; ///< Ground-truth distances (size nq * gt_nnn) std::vector gt_I; ///< Ground-truth indexes (size nq * gt_nnn) diff --git a/core/src/index/thirdparty/faiss/Index.cpp b/core/src/index/thirdparty/faiss/Index.cpp index d5748f719f62..b11cfb2683d8 100644 --- a/core/src/index/thirdparty/faiss/Index.cpp +++ b/core/src/index/thirdparty/faiss/Index.cpp @@ -52,6 +52,15 @@ void Index::add_with_ids( FAISS_THROW_MSG ("add_with_ids not implemented for this type of index"); } + +void Index::add_without_codes(idx_t n, const float* x) { + FAISS_THROW_MSG ("add_without_codes not implemented for this type of index"); +} + +void Index::add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) { + FAISS_THROW_MSG ("add_with_ids_without_codes not implemented for this type of index"); +} + #if 0 void Index::get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset) { FAISS_THROW_MSG ("get_vector_by_id not implemented for this type of index"); diff --git a/core/src/index/thirdparty/faiss/Index.h b/core/src/index/thirdparty/faiss/Index.h index 9a0967962e08..9e8d22dba471 100644 --- a/core/src/index/thirdparty/faiss/Index.h +++ b/core/src/index/thirdparty/faiss/Index.h @@ -94,6 +94,13 @@ struct Index { */ virtual void add (idx_t n, const float *x) = 0; + /** Same as add, but only add ids, not codes + * + * @param n nb of training vectors + * @param x training vecors, size n * d + */ + virtual void add_without_codes(idx_t n, const float* x); + /** Same as add, but stores xids instead of sequential ids. * * The default implementation fails with an assertion, as it is @@ -103,6 +110,12 @@ struct Index { */ virtual void add_with_ids (idx_t n, const float * x, const idx_t *xids); + /** Same as add_with_ids, but only add ids, not codes + * + * @param xids if non-null, ids to store for the vectors (size n) + */ + virtual void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids); + /** query n vectors of dimension d to the index. * * return at most k vectors. If there are not enough results for a diff --git a/core/src/index/thirdparty/faiss/IndexIVF.cpp b/core/src/index/thirdparty/faiss/IndexIVF.cpp index fa6f050ca6ee..5b4fc9c8b71b 100644 --- a/core/src/index/thirdparty/faiss/IndexIVF.cpp +++ b/core/src/index/thirdparty/faiss/IndexIVF.cpp @@ -196,6 +196,16 @@ void IndexIVF::add (idx_t n, const float * x) add_with_ids (n, x, nullptr); } +void IndexIVF::add_without_codes (idx_t n, const float * x) +{ + add_with_ids_without_codes (n, x, nullptr); +} + +void IndexIVF::add_with_ids_without_codes (idx_t n, const float * x, const idx_t *xids) +{ + // will be overriden + FAISS_THROW_MSG ("add_with_ids_without_codes not implemented for this type of index"); +} void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids) { @@ -268,6 +278,13 @@ void IndexIVF::to_readonly() { this->replace_invlists(readonly_lists, true); } +void IndexIVF::to_readonly_without_codes() { + if (is_readonly()) return; + auto readonly_lists = this->invlists->to_readonly_without_codes(); + if (!readonly_lists) return; + this->replace_invlists(readonly_lists, true); +} + bool IndexIVF::is_readonly() const { return this->invlists->is_readonly(); } @@ -316,6 +333,26 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k, indexIVF_stats.search_time += getmillisecs() - t0; } +void IndexIVF::search_without_codes (idx_t n, const float *x, + const uint8_t *arranged_codes, std::vector prefix_sum, + bool is_sq8, idx_t k, float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset) +{ + std::unique_ptr idx(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new float[n * nprobe]); + + double t0 = getmillisecs(); + quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get()); + indexIVF_stats.quantization_time += getmillisecs() - t0; + + t0 = getmillisecs(); + invlists->prefetch_lists (idx.get(), n * nprobe); + + search_preassigned_without_codes (n, x, arranged_codes, prefix_sum, is_sq8, k, idx.get(), coarse_dis.get(), + distances, labels, false, nullptr, bitset); + indexIVF_stats.search_time += getmillisecs() - t0; +} + #if 0 void IndexIVF::get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset) { make_direct_map(true); @@ -545,7 +582,212 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k, } +void IndexIVF::search_preassigned_without_codes (idx_t n, const float *x, + const uint8_t *arranged_codes, + std::vector prefix_sum, + bool is_sq8, idx_t k, + const idx_t *keys, + const float *coarse_dis , + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params, + ConcurrentBitsetPtr bitset) +{ + long nprobe = params ? params->nprobe : this->nprobe; + long max_codes = params ? params->max_codes : this->max_codes; + + size_t nlistv = 0, ndis = 0, nheap = 0; + + using HeapForIP = CMin; + using HeapForL2 = CMax; + + bool interrupt = false; + + int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT; + bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT); + + // don't start parallel section if single query + bool do_parallel = + pmode == 0 ? n > 1 : + pmode == 1 ? nprobe > 1 : + nprobe * n > 1; + +#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap) + { + InvertedListScanner *scanner = get_InvertedListScanner(store_pairs); + ScopeDeleter1 del(scanner); + + /***************************************************** + * Depending on parallel_mode, there are two possible ways + * to organize the search. Here we define local functions + * that are in common between the two + ******************************************************/ + + // intialize + reorder a result heap + + auto init_result = [&](float *simi, idx_t *idxi) { + if (!do_heap_init) return; + if (metric_type == METRIC_INNER_PRODUCT) { + heap_heapify (k, simi, idxi); + } else { + heap_heapify (k, simi, idxi); + } + }; + auto reorder_result = [&] (float *simi, idx_t *idxi) { + if (!do_heap_init) return; + if (metric_type == METRIC_INNER_PRODUCT) { + heap_reorder (k, simi, idxi); + } else { + heap_reorder (k, simi, idxi); + } + }; + + // single list scan using the current scanner (with query + // set porperly) and storing results in simi and idxi + auto scan_one_list = [&] (idx_t key, float coarse_dis_i, const uint8_t *arranged_codes, + float *simi, idx_t *idxi, ConcurrentBitsetPtr bitset) { + + if (key < 0) { + // not enough centroids for multiprobe + return (size_t)0; + } + FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist, + "Invalid key=%ld nlist=%ld\n", + key, nlist); + + size_t list_size = invlists->list_size(key); + size_t offset = prefix_sum[key]; + + // don't waste time on empty lists + if (list_size == 0) { + return (size_t)0; + } + + scanner->set_list (key, coarse_dis_i); + + nlistv++; + + InvertedLists::ScopedCodes scodes (invlists, key, arranged_codes); + + std::unique_ptr sids; + const Index::idx_t * ids = nullptr; + + if (!store_pairs) { + sids.reset (new InvertedLists::ScopedIds (invlists, key)); + ids = sids->get(); + } + + size_t size = is_sq8 ? sizeof(uint8_t) : sizeof(float); + nheap += scanner->scan_codes (list_size, (const uint8_t *) (scodes.get() + d * offset * size), + ids, simi, idxi, k, bitset); + + return list_size; + }; + + /**************************************************** + * Actual loops, depending on parallel_mode + ****************************************************/ + + if (pmode == 0) { + +#pragma omp for + for (size_t i = 0; i < n; i++) { + + if (interrupt) { + continue; + } + + // loop over queries + scanner->set_query (x + i * d); + float * simi = distances + i * k; + idx_t * idxi = labels + i * k; + + init_result (simi, idxi); + + long nscan = 0; + + // loop over probes + for (size_t ik = 0; ik < nprobe; ik++) { + + nscan += scan_one_list ( + keys [i * nprobe + ik], + coarse_dis[i * nprobe + ik], + arranged_codes, + simi, idxi, bitset + ); + + if (max_codes && nscan >= max_codes) { + break; + } + } + + ndis += nscan; + reorder_result (simi, idxi); + + if (InterruptCallback::is_interrupted ()) { + interrupt = true; + } + + } // parallel for + } else if (pmode == 1) { + std::vector local_idx (k); + std::vector local_dis (k); + + for (size_t i = 0; i < n; i++) { + scanner->set_query (x + i * d); + init_result (local_dis.data(), local_idx.data()); + +#pragma omp for schedule(dynamic) + for (size_t ik = 0; ik < nprobe; ik++) { + ndis += scan_one_list + (keys [i * nprobe + ik], + coarse_dis[i * nprobe + ik], + arranged_codes, + local_dis.data(), local_idx.data(), bitset); + + // can't do the test on max_codes + } + // merge thread-local results + + float * simi = distances + i * k; + idx_t * idxi = labels + i * k; +#pragma omp single + init_result (simi, idxi); + +#pragma omp barrier +#pragma omp critical + { + if (metric_type == METRIC_INNER_PRODUCT) { + heap_addn + (k, simi, idxi, + local_dis.data(), local_idx.data(), k); + } else { + heap_addn + (k, simi, idxi, + local_dis.data(), local_idx.data(), k); + } + } +#pragma omp barrier +#pragma omp single + reorder_result (simi, idxi); + } + } else { + FAISS_THROW_FMT ("parallel_mode %d not supported\n", + pmode); + } + } // parallel section + + if (interrupt) { + FAISS_THROW_MSG ("computation interrupted"); + } + + indexIVF_stats.nq += n; + indexIVF_stats.nlist += nlistv; + indexIVF_stats.ndis += ndis; + indexIVF_stats.nheap_updates += nheap; + +} void IndexIVF::range_search (idx_t nx, const float *x, float radius, RangeSearchResult *result, diff --git a/core/src/index/thirdparty/faiss/IndexIVF.h b/core/src/index/thirdparty/faiss/IndexIVF.h index 744f27f333f0..a7d2af1f8aba 100644 --- a/core/src/index/thirdparty/faiss/IndexIVF.h +++ b/core/src/index/thirdparty/faiss/IndexIVF.h @@ -139,9 +139,15 @@ struct IndexIVF: Index, Level1Quantizer { /// Calls add_with_ids with NULL ids void add(idx_t n, const float* x) override; + /// Calls add_with_ids_without_codes + void add_without_codes(idx_t n, const float* x) override; + /// default implementation that calls encode_vectors void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + /// Implementation for adding without original vector data + void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) override; + /** Encodes a set of vectors as they would appear in the inverted lists * * @param list_nos inverted list ids as returned by the @@ -187,11 +193,29 @@ struct IndexIVF: Index, Level1Quantizer { ConcurrentBitsetPtr bitset = nullptr ) const; + /** Similar to search_preassigned, but does not store codes **/ + virtual void search_preassigned_without_codes (idx_t n, const float *x, + const uint8_t *arranged_codes, + std::vector prefix_sum, + bool is_sq8, idx_t k, + const idx_t *assign, + const float *centroid_dis, + float *distances, idx_t *labels, + bool store_pairs, + const IVFSearchParameters *params = nullptr, + ConcurrentBitsetPtr bitset = nullptr); + /** assign the vectors, then call search_preassign */ void search (idx_t n, const float *x, idx_t k, float *distances, idx_t *labels, ConcurrentBitsetPtr bitset = nullptr) const override; + /** Similar to search, but does not store codes **/ + void search_without_codes (idx_t n, const float *x, + const uint8_t *arranged_codes, std::vector prefix_sum, + bool is_sq8, idx_t k, float *distances, idx_t *labels, + ConcurrentBitsetPtr bitset = nullptr); + #if 0 /** get raw vectors by ids */ void get_vector_by_id (idx_t n, const idx_t *xid, float *x, ConcurrentBitsetPtr bitset = nullptr) override; @@ -286,6 +310,7 @@ struct IndexIVF: Index, Level1Quantizer { idx_t a1, idx_t a2) const; virtual void to_readonly(); + virtual void to_readonly_without_codes(); virtual bool is_readonly() const; virtual void backup_quantizer(); diff --git a/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp b/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp index 2846990f9f70..147263750f53 100644 --- a/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp +++ b/core/src/index/thirdparty/faiss/IndexIVFFlat.cpp @@ -39,6 +39,40 @@ void IndexIVFFlat::add_with_ids (idx_t n, const float * x, const idx_t *xids) add_core (n, x, xids, nullptr); } +// Add ids only, vectors not added to Index. +void IndexIVFFlat::add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) +{ + FAISS_THROW_IF_NOT (is_trained); + assert (invlists); + direct_map.check_can_add (xids); + const int64_t * idx; + ScopeDeleter del; + + int64_t * idx0 = new int64_t [n]; + del.set (idx0); + quantizer->assign (n, x, idx0); + idx = idx0; + + int64_t n_add = 0; + for (size_t i = 0; i < n; i++) { + idx_t id = xids ? xids[i] : ntotal + i; + idx_t list_no = idx [i]; + size_t offset; + + if (list_no >= 0) { + const float *xi = x + i * d; + offset = invlists->add_entry_without_codes ( + list_no, id); + n_add++; + } else { + offset = 0; + } + direct_map.add_single_id (id, list_no, offset); + } + + ntotal += n; +} + void IndexIVFFlat::add_core (idx_t n, const float * x, const int64_t *xids, const int64_t *precomputed_idx) diff --git a/core/src/index/thirdparty/faiss/IndexIVFFlat.h b/core/src/index/thirdparty/faiss/IndexIVFFlat.h index 3c5777a1c29c..74b0b4c0ec8a 100644 --- a/core/src/index/thirdparty/faiss/IndexIVFFlat.h +++ b/core/src/index/thirdparty/faiss/IndexIVFFlat.h @@ -35,6 +35,9 @@ struct IndexIVFFlat: IndexIVF { /// implemented for all IndexIVF* classes void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + /// implemented for all IndexIVF* classes + void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) override; + void encode_vectors(idx_t n, const float* x, const idx_t *list_nos, uint8_t * codes, diff --git a/core/src/index/thirdparty/faiss/IndexScalarQuantizer.cpp b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.cpp index d96612daefe8..38cd28a6cd3a 100644 --- a/core/src/index/thirdparty/faiss/IndexScalarQuantizer.cpp +++ b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.cpp @@ -294,7 +294,41 @@ void IndexIVFScalarQuantizer::add_with_ids } +void IndexIVFScalarQuantizer::add_with_ids_without_codes + (idx_t n, const float * x, const idx_t *xids) +{ + FAISS_THROW_IF_NOT (is_trained); + std::unique_ptr idx (new int64_t [n]); + quantizer->assign (n, x, idx.get()); + size_t nadd = 0; + std::unique_ptr squant(sq.select_quantizer ()); + + DirectMapAdd dm_add (direct_map, n, xids); + +#pragma omp parallel reduction(+: nadd) + { + int nt = omp_get_num_threads(); + int rank = omp_get_thread_num(); + + // each thread takes care of a subset of lists + for (size_t i = 0; i < n; i++) { + int64_t list_no = idx [i]; + if (list_no >= 0 && list_no % nt == rank) { + int64_t id = xids ? xids[i] : ntotal + i; + size_t ofs = invlists->add_entry_without_codes (list_no, id); + + dm_add.add (i, list_no, ofs); + nadd++; + + } else if (rank == 0 && list_no == -1) { + dm_add.add (i, -1, 0); + } + } + } + + ntotal += n; +} InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner diff --git a/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h index feb0e8314f52..03b9abf37dec 100644 --- a/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h +++ b/core/src/index/thirdparty/faiss/IndexScalarQuantizer.h @@ -107,6 +107,8 @@ struct IndexIVFScalarQuantizer: IndexIVF { bool include_listnos=false) const override; void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; + + void add_with_ids_without_codes(idx_t n, const float* x, const idx_t* xids) override; InvertedListScanner *get_InvertedListScanner (bool store_pairs) const override; diff --git a/core/src/index/thirdparty/faiss/InvertedLists.cpp b/core/src/index/thirdparty/faiss/InvertedLists.cpp index 59f5d1e7cba2..cc79ab47dc61 100644 --- a/core/src/index/thirdparty/faiss/InvertedLists.cpp +++ b/core/src/index/thirdparty/faiss/InvertedLists.cpp @@ -108,6 +108,15 @@ size_t InvertedLists::add_entry (size_t list_no, idx_t theid, return add_entries (list_no, 1, &theid, code); } +size_t InvertedLists::add_entry_without_codes (size_t list_no, idx_t theid) +{ + return add_entries_without_codes (list_no, 1, &theid); +} + +size_t InvertedLists::add_entries_without_codes (size_t list_no, size_t n_entry, + const idx_t* ids) +{} + void InvertedLists::update_entry (size_t list_no, size_t offset, idx_t id, const uint8_t *code) { @@ -118,6 +127,10 @@ InvertedLists* InvertedLists::to_readonly() { return nullptr; } +InvertedLists* InvertedLists::to_readonly_without_codes() { + return nullptr; +} + bool InvertedLists::is_readonly() const { return false; } @@ -210,6 +223,18 @@ size_t ArrayInvertedLists::add_entries ( return o; } +size_t ArrayInvertedLists::add_entries_without_codes ( + size_t list_no, size_t n_entry, + const idx_t* ids_in) +{ + if (n_entry == 0) return 0; + assert (list_no < nlist); + size_t o = ids [list_no].size(); + ids [list_no].resize (o + n_entry); + memcpy (&ids[list_no][o], ids_in, sizeof (ids_in[0]) * n_entry); + return o; +} + size_t ArrayInvertedLists::list_size(size_t list_no) const { assert (list_no < nlist); @@ -250,6 +275,11 @@ InvertedLists* ArrayInvertedLists::to_readonly() { return readonly; } +InvertedLists* ArrayInvertedLists::to_readonly_without_codes() { + ReadOnlyArrayInvertedLists* readonly = new ReadOnlyArrayInvertedLists(*this, true); + return readonly; +} + ArrayInvertedLists::~ArrayInvertedLists () {} @@ -325,26 +355,43 @@ ReadOnlyArrayInvertedLists::ReadOnlyArrayInvertedLists(const ArrayInvertedLists& valid = true; } -//ReadOnlyArrayInvertedLists::ReadOnlyArrayInvertedLists(const ReadOnlyArrayInvertedLists &other) -// : InvertedLists (other.nlist, other.code_size) { -// readonly_length = other.readonly_length; -// readonly_offset = other.readonly_offset; -// pin_readonly_codes = std::make_shared(*other.pin_readonly_codes); -// pin_readonly_ids = std::make_shared(*other.pin_readonly_ids); -// valid = true; -//} - -//ReadOnlyArrayInvertedLists::ReadOnlyArrayInvertedLists(ReadOnlyArrayInvertedLists &&other) -// : InvertedLists (other.nlist, other.code_size) { -// readonly_length = std::move(other.readonly_length); -// readonly_offset = std::move(other.readonly_offset); -// pin_readonly_codes = other.pin_readonly_codes; -// pin_readonly_ids = other.pin_readonly_ids; -// -// other.pin_readonly_codes = nullptr; -// other.pin_readonly_ids = nullptr; -// valid = true; -//} +ReadOnlyArrayInvertedLists::ReadOnlyArrayInvertedLists(const ArrayInvertedLists& other, bool offset_only) + : InvertedLists (other.nlist, other.code_size) { + readonly_length.resize(nlist); + readonly_offset.resize(nlist); + size_t offset = 0; + for (auto i = 0; i < other.ids.size(); i++) { + auto& list_ids = other.ids[i]; + readonly_length[i] = list_ids.size(); + readonly_offset[i] = offset; + offset += list_ids.size(); + } + +#ifdef USE_CPU + for (auto i = 0; i < other.ids.size(); i++) { + auto& list_ids = other.ids[i]; + readonly_ids.insert(readonly_ids.end(), list_ids.begin(), list_ids.end()); + } +#else + size_t ids_size = offset * sizeof(idx_t); + size_t codes_size = offset * (this->code_size) * sizeof(uint8_t); + pin_readonly_codes = std::make_shared(codes_size); + pin_readonly_ids = std::make_shared(ids_size); + + offset = 0; + for (auto i = 0; i < other.ids.size(); i++) { + auto& list_ids = other.ids[i]; + + uint8_t* ids_ptr = (uint8_t*)(pin_readonly_ids->data) + offset * sizeof(idx_t); + memcpy(ids_ptr, list_ids.data(), list_ids.size() * sizeof(idx_t)); + + offset += list_ids.size(); + } +#endif + + valid = true; +} + ReadOnlyArrayInvertedLists::~ReadOnlyArrayInvertedLists() { } @@ -361,6 +408,13 @@ size_t ReadOnlyArrayInvertedLists::add_entries ( FAISS_THROW_MSG ("not implemented"); } +size_t ReadOnlyArrayInvertedLists::add_entries_without_codes ( + size_t , size_t , + const idx_t*) +{ + FAISS_THROW_MSG ("not implemented"); +} + void ReadOnlyArrayInvertedLists::update_entries (size_t, size_t , size_t , const idx_t *, const uint8_t *) { @@ -440,6 +494,13 @@ size_t ReadOnlyInvertedLists::add_entries ( FAISS_THROW_MSG ("not implemented"); } +size_t ReadOnlyInvertedLists::add_entries_without_codes ( + size_t , size_t , + const idx_t*) +{ + FAISS_THROW_MSG ("not implemented"); +} + void ReadOnlyInvertedLists::update_entries (size_t, size_t , size_t , const idx_t *, const uint8_t *) { diff --git a/core/src/index/thirdparty/faiss/InvertedLists.h b/core/src/index/thirdparty/faiss/InvertedLists.h index ec77d2cb186b..c57b7b6961ed 100644 --- a/core/src/index/thirdparty/faiss/InvertedLists.h +++ b/core/src/index/thirdparty/faiss/InvertedLists.h @@ -111,6 +111,12 @@ struct InvertedLists { size_t list_no, size_t n_entry, const idx_t* ids, const uint8_t *code) = 0; + /// add one entry to an inverted list without codes + virtual size_t add_entry_without_codes (size_t list_no, idx_t theid); + + virtual size_t add_entries_without_codes ( size_t list_no, size_t n_entry, + const idx_t* ids); + virtual void update_entry (size_t list_no, size_t offset, idx_t id, const uint8_t *code); @@ -123,6 +129,8 @@ struct InvertedLists { virtual InvertedLists* to_readonly(); + virtual InvertedLists* to_readonly_without_codes(); + virtual bool is_readonly() const; /// move all entries from oivf (empty on output) @@ -197,6 +205,11 @@ struct InvertedLists { list_no (list_no) {} + // For codes outside + ScopedCodes (const InvertedLists *il, size_t list_no, const uint8_t *original_codes): + il (il), codes (original_codes), list_no (list_no) + {} + const uint8_t *get() {return codes; } ~ScopedCodes () { @@ -223,6 +236,10 @@ struct ArrayInvertedLists: InvertedLists { size_t list_no, size_t n_entry, const idx_t* ids, const uint8_t *code) override; + size_t add_entries_without_codes ( + size_t list_no, size_t n_entry, + const idx_t* ids) override; + void update_entries (size_t list_no, size_t offset, size_t n_entry, const idx_t *ids, const uint8_t *code) override; @@ -230,6 +247,8 @@ struct ArrayInvertedLists: InvertedLists { InvertedLists* to_readonly() override; + InvertedLists* to_readonly_without_codes() override; + virtual ~ArrayInvertedLists (); }; @@ -248,6 +267,7 @@ struct ReadOnlyArrayInvertedLists: InvertedLists { ReadOnlyArrayInvertedLists(size_t nlist, size_t code_size, const std::vector& list_length); explicit ReadOnlyArrayInvertedLists(const ArrayInvertedLists& other); + explicit ReadOnlyArrayInvertedLists(const ArrayInvertedLists& other, bool offset); // Use default copy construct, just copy pointer, DON'T COPY pin_readonly_codes AND pin_readonly_ids // explicit ReadOnlyArrayInvertedLists(const ReadOnlyArrayInvertedLists &); @@ -266,6 +286,10 @@ struct ReadOnlyArrayInvertedLists: InvertedLists { size_t list_no, size_t n_entry, const idx_t* ids, const uint8_t *code) override; + size_t add_entries_without_codes ( + size_t list_no, size_t n_entry, + const idx_t* ids) override; + void update_entries (size_t list_no, size_t offset, size_t n_entry, const idx_t *ids, const uint8_t *code) override; @@ -292,6 +316,10 @@ struct ReadOnlyInvertedLists: InvertedLists { size_t list_no, size_t n_entry, const idx_t* ids, const uint8_t *code) override; + size_t add_entries_without_codes ( + size_t list_no, size_t n_entry, + const idx_t* ids) override; + void update_entries (size_t list_no, size_t offset, size_t n_entry, const idx_t *ids, const uint8_t *code) override; diff --git a/core/src/index/thirdparty/faiss/gpu/GpuCloner.cpp b/core/src/index/thirdparty/faiss/gpu/GpuCloner.cpp index 192c02db42c9..c273d7ad8f3b 100644 --- a/core/src/index/thirdparty/faiss/gpu/GpuCloner.cpp +++ b/core/src/index/thirdparty/faiss/gpu/GpuCloner.cpp @@ -108,12 +108,33 @@ Index *ToCPUCloner::clone_Index(const Index *index) } } +Index *ToCPUCloner::clone_Index_Without_Codes(const Index *index) +{ + if(auto ifl = dynamic_cast(index)) { + IndexIVFFlat *res = new IndexIVFFlat(); + ifl->copyToWithoutCodes(res); + return res; + } else if(auto ifl = + dynamic_cast(index)) { + IndexIVFScalarQuantizer *res = new IndexIVFScalarQuantizer(); + ifl->copyToWithoutCodes(res); + return res; + } else { + return Cloner::clone_Index(index); + } +} + faiss::Index * index_gpu_to_cpu(const faiss::Index *gpu_index) { ToCPUCloner cl; return cl.clone_Index(gpu_index); } +faiss::Index * index_gpu_to_cpu_without_codes(const faiss::Index *gpu_index) +{ + ToCPUCloner cl; + return cl.clone_Index_Without_Codes(gpu_index); +} @@ -256,6 +277,60 @@ Index *ToGpuCloner::clone_Index(const Index *index) return res; } else { return Cloner::clone_Index(index); + + } + +} + + +Index *ToGpuCloner::clone_Index_Without_Codes(const Index *index, const uint8_t *arranged_data) +{ + auto ivf_sqh = dynamic_cast(index); + if(ivf_sqh) { + // should not happen + } else if(auto ifl = dynamic_cast(index)) { + GpuIndexIVFFlatConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + + GpuIndexIVFFlat *res = + new GpuIndexIVFFlat(resources, + ifl->d, + ifl->nlist, + ifl->metric_type, + config); + if(reserveVecs > 0 && ifl->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + res->copyFromWithoutCodes(ifl, arranged_data); + return res; + } else if(auto ifl = + dynamic_cast(index)) { + GpuIndexIVFScalarQuantizerConfig config; + config.device = device; + config.indicesOptions = indicesOptions; + config.flatConfig.useFloat16 = useFloat16CoarseQuantizer; + config.flatConfig.storeTransposed = storeTransposed; + + GpuIndexIVFScalarQuantizer *res = + new GpuIndexIVFScalarQuantizer(resources, + ifl->d, + ifl->nlist, + ifl->sq.qtype, + ifl->metric_type, + ifl->by_residual, + config); + if(reserveVecs > 0 && ifl->ntotal == 0) { + res->reserveMemory(reserveVecs); + } + + res->copyFromWithoutCodes(ifl, arranged_data); + return res; + } else { + return Cloner::clone_Index(index); } } @@ -270,6 +345,17 @@ faiss::Index * index_cpu_to_gpu( return cl.clone_Index(index); } +faiss::Index * index_cpu_to_gpu_without_codes( + GpuResources* resources, int device, + const faiss::Index *index, + const uint8_t *arranged_data, + const GpuClonerOptions *options) +{ + GpuClonerOptions defaults; + ToGpuCloner cl(resources, device, options ? *options : defaults); + return cl.clone_Index_Without_Codes(index, arranged_data); +} + faiss::Index * index_cpu_to_gpu( GpuResources* resources, int device, IndexComposition* index_composition, diff --git a/core/src/index/thirdparty/faiss/gpu/GpuCloner.h b/core/src/index/thirdparty/faiss/gpu/GpuCloner.h index f2c5388d937d..c01029279e44 100644 --- a/core/src/index/thirdparty/faiss/gpu/GpuCloner.h +++ b/core/src/index/thirdparty/faiss/gpu/GpuCloner.h @@ -23,7 +23,10 @@ class GpuResources; /// Cloner specialized for GPU -> CPU struct ToCPUCloner: faiss::Cloner { void merge_index(Index *dst, Index *src, bool successive_ids); + Index *clone_Index(const Index *index) override; + + Index *clone_Index_Without_Codes(const Index *index); }; @@ -38,6 +41,8 @@ struct ToGpuCloner: faiss::Cloner, GpuClonerOptions { Index *clone_Index(const Index *index) override; Index *clone_Index (IndexComposition* index_composition) override; + + Index *clone_Index_Without_Codes(const Index *index, const uint8_t *arranged_data); }; /// Cloner specialized for CPU -> multiple GPUs @@ -66,12 +71,20 @@ struct ToGpuClonerMultiple: faiss::Cloner, GpuMultipleClonerOptions { /// converts any GPU index inside gpu_index to a CPU index faiss::Index * index_gpu_to_cpu(const faiss::Index *gpu_index); +faiss::Index * index_gpu_to_cpu_without_codes(const faiss::Index *gpu_index); + /// converts any CPU index that can be converted to GPU faiss::Index * index_cpu_to_gpu( GpuResources* resources, int device, const faiss::Index *index, const GpuClonerOptions *options = nullptr); +faiss::Index * index_cpu_to_gpu_without_codes( + GpuResources* resources, int device, + const faiss::Index *index, + const uint8_t *arranged_data, + const GpuClonerOptions *options = nullptr); + faiss::Index * index_cpu_to_gpu( GpuResources* resources, int device, IndexComposition* index_composition, diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.cu index 8e873c191473..130e95f86605 100644 --- a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.cu +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVF.cu @@ -27,6 +27,7 @@ GpuIndexIVF::GpuIndexIVF(GpuResources* resources, nlist(nlistIn), nprobe(1), quantizer(nullptr) { + init_(); // Only IP and L2 are supported for now diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.cu index 6ca7c70ffb98..938ac989aec1 100644 --- a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.cu +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.cu @@ -6,258 +6,325 @@ */ -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace faiss { namespace gpu { - -GpuIndexIVFFlat::GpuIndexIVFFlat(GpuResources* resources, - const faiss::IndexIVFFlat* index, - GpuIndexIVFFlatConfig config) : - GpuIndexIVF(resources, - index->d, - index->metric_type, - index->metric_arg, - index->nlist, - config), - ivfFlatConfig_(config), - reserveMemoryVecs_(0), - index_(nullptr) { - copyFrom(index); -} - -GpuIndexIVFFlat::GpuIndexIVFFlat(GpuResources* resources, - int dims, - int nlist, - faiss::MetricType metric, - GpuIndexIVFFlatConfig config) : - GpuIndexIVF(resources, dims, metric, 0, nlist, config), - ivfFlatConfig_(config), - reserveMemoryVecs_(0), - index_(nullptr) { - - // faiss::Index params - this->is_trained = false; - - // We haven't trained ourselves, so don't construct the IVFFlat - // index yet -} - -GpuIndexIVFFlat::~GpuIndexIVFFlat() { - delete index_; -} - -void -GpuIndexIVFFlat::reserveMemory(size_t numVecs) { - reserveMemoryVecs_ = numVecs; - if (index_) { - DeviceScope scope(device_); - index_->reserveMemory(numVecs); - } -} - -void -GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) { - DeviceScope scope(device_); - - GpuIndexIVF::copyFrom(index); - - // Clear out our old data - delete index_; - index_ = nullptr; - - // The other index might not be trained - if (!index->is_trained) { - FAISS_ASSERT(!is_trained); - return; - } - - // Otherwise, we can populate ourselves from the other index - FAISS_ASSERT(is_trained); - - // Copy our lists as well - index_ = new IVFFlat(resources_, - quantizer->getGpuData(), - index->metric_type, - index->metric_arg, - false, // no residual - nullptr, // no scalar quantizer - ivfFlatConfig_.indicesOptions, - memorySpace_); - InvertedLists *ivf = index->invlists; - - if (ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { - index_->copyCodeVectorsFromCpu((const float* )(rol->pin_readonly_codes->data), - (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); - /* double t0 = getmillisecs(); */ - /* std::cout << "Readonly Takes " << getmillisecs() - t0 << " ms" << std::endl; */ - } else { - for (size_t i = 0; i < ivf->nlist; ++i) { - auto numVecs = ivf->list_size(i); - - // GPU index can only support max int entries per list - FAISS_THROW_IF_NOT_FMT(numVecs <= - (size_t) std::numeric_limits::max(), - "GPU inverted list can only support " - "%zu entries; %zu found", - (size_t) std::numeric_limits::max(), - numVecs); - - index_->addCodeVectorsFromCpu(i, - (const unsigned char*)(ivf->get_codes(i)), - ivf->get_ids(i), - numVecs); - } - } -} - -void -GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const { - DeviceScope scope(device_); - - // We must have the indices in order to copy to ourselves - FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF, - "Cannot copy to CPU as GPU index doesn't retain " - "indices (INDICES_IVF)"); - - GpuIndexIVF::copyTo(index); - index->code_size = this->d * sizeof(float); - - InvertedLists *ivf = new ArrayInvertedLists(nlist, index->code_size); - index->replace_invlists(ivf, true); - - // Copy the inverted lists - if (index_) { - for (int i = 0; i < nlist; ++i) { - auto listIndices = index_->getListIndices(i); - auto listData = index_->getListVectors(i); - - ivf->add_entries(i, - listIndices.size(), - listIndices.data(), - (const uint8_t*) listData.data()); - } - } -} - -size_t -GpuIndexIVFFlat::reclaimMemory() { - if (index_) { - DeviceScope scope(device_); - - return index_->reclaimMemory(); - } - - return 0; -} - -void -GpuIndexIVFFlat::reset() { - if (index_) { - DeviceScope scope(device_); - - index_->reset(); - this->ntotal = 0; - } else { - FAISS_ASSERT(this->ntotal == 0); - } -} - -void -GpuIndexIVFFlat::train(Index::idx_t n, const float* x) { - DeviceScope scope(device_); - - if (this->is_trained) { - FAISS_ASSERT(quantizer->is_trained); - FAISS_ASSERT(quantizer->ntotal == nlist); - FAISS_ASSERT(index_); - return; - } - - FAISS_ASSERT(!index_); - - trainQuantizer_(n, x); - - // The quantizer is now trained; construct the IVF index - index_ = new IVFFlat(resources_, - quantizer->getGpuData(), - this->metric_type, - this->metric_arg, - false, // no residual - nullptr, // no scalar quantizer - ivfFlatConfig_.indicesOptions, - memorySpace_); - - if (reserveMemoryVecs_) { - index_->reserveMemory(reserveMemoryVecs_); - } - - this->is_trained = true; -} - -void -GpuIndexIVFFlat::addImpl_(int n, - const float* x, - const Index::idx_t* xids) { - // Device is already set in GpuIndex::add - FAISS_ASSERT(index_); - FAISS_ASSERT(n > 0); - - auto stream = resources_->getDefaultStream(device_); - - // Data is already resident on the GPU - Tensor data(const_cast(x), {n, (int) this->d}); - - static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); - Tensor labels(const_cast(xids), {n}); - - // Not all vectors may be able to be added (some may contain NaNs etc) - index_->classifyAndAddVectors(data, labels); - - // but keep the ntotal based on the total number of vectors that we attempted - // to add - ntotal += n; -} - -void -GpuIndexIVFFlat::searchImpl_(int n, - const float* x, - int k, - float* distances, - Index::idx_t* labels, - ConcurrentBitsetPtr bitset) const { - // Device is already set in GpuIndex::search - FAISS_ASSERT(index_); - FAISS_ASSERT(n > 0); - - auto stream = resources_->getDefaultStream(device_); - - // Data is already resident on the GPU - Tensor queries(const_cast(x), {n, (int) this->d}); - Tensor outDistances(distances, {n, k}); - - static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); - Tensor outLabels(const_cast(labels), {n, k}); - - if (!bitset) { - auto bitsetDevice = toDevice(resources_, device_, nullptr, stream, {0}); - index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); - } else { - auto bitsetDevice = toDevice(resources_, device_, - const_cast(bitset->data()), stream, - {(int) bitset->size()}); - index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); - } -} - - -} } // namespace + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + + namespace faiss { namespace gpu { + + GpuIndexIVFFlat::GpuIndexIVFFlat(GpuResources* resources, + const faiss::IndexIVFFlat* index, + GpuIndexIVFFlatConfig config) : + GpuIndexIVF(resources, + index->d, + index->metric_type, + index->metric_arg, + index->nlist, + config), + ivfFlatConfig_(config), + reserveMemoryVecs_(0), + index_(nullptr) { + + copyFrom(index); + } + + GpuIndexIVFFlat::GpuIndexIVFFlat(GpuResources* resources, + int dims, + int nlist, + faiss::MetricType metric, + GpuIndexIVFFlatConfig config) : + GpuIndexIVF(resources, dims, metric, 0, nlist, config), + ivfFlatConfig_(config), + reserveMemoryVecs_(0), + index_(nullptr) { + + // faiss::Index params + this->is_trained = false; + + // We haven't trained ourselves, so don't construct the IVFFlat + // index yet + } + + GpuIndexIVFFlat::~GpuIndexIVFFlat() { + delete index_; + } + + void + GpuIndexIVFFlat::reserveMemory(size_t numVecs) { + reserveMemoryVecs_ = numVecs; + if (index_) { + DeviceScope scope(device_); + index_->reserveMemory(numVecs); + } + } + + void + GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) { + DeviceScope scope(device_); + + GpuIndexIVF::copyFrom(index); + + // Clear out our old data + delete index_; + index_ = nullptr; + + // The other index might not be trained + if (!index->is_trained) { + FAISS_ASSERT(!is_trained); + return; + } + + // Otherwise, we can populate ourselves from the other index + FAISS_ASSERT(is_trained); + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + false, // no residual + nullptr, // no scalar quantizer + ivfFlatConfig_.indicesOptions, + memorySpace_); + InvertedLists *ivf = index->invlists; + + if (ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float* )(rol->pin_readonly_codes->data), + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + /* double t0 = getmillisecs(); */ + /* std::cout << "Readonly Takes " << getmillisecs() - t0 << " ms" << std::endl; */ + } else { + for (size_t i = 0; i < ivf->nlist; ++i) { + auto numVecs = ivf->list_size(i); + + // GPU index can only support max int entries per list + FAISS_THROW_IF_NOT_FMT(numVecs <= + (size_t) std::numeric_limits::max(), + "GPU inverted list can only support " + "%zu entries; %zu found", + (size_t) std::numeric_limits::max(), + numVecs); + + index_->addCodeVectorsFromCpu(i, + (const unsigned char*)(ivf->get_codes(i)), + ivf->get_ids(i), + numVecs); + } + } + } + + void + GpuIndexIVFFlat::copyFromWithoutCodes(const faiss::IndexIVFFlat* index, const uint8_t* arranged_data) { + DeviceScope scope(device_); + + GpuIndexIVF::copyFrom(index); + + // Clear out our old data + delete index_; + index_ = nullptr; + + // The other index might not be trained + if (!index->is_trained) { + FAISS_ASSERT(!is_trained); + return; + } + + // Otherwise, we can populate ourselves from the other index + FAISS_ASSERT(is_trained); + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + false, // no residual + nullptr, // no scalar quantizer + ivfFlatConfig_.indicesOptions, + memorySpace_); + InvertedLists *ivf = index->invlists; + + if (ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float *) arranged_data, + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + } else { + // should not happen + } + } + + void + GpuIndexIVFFlat::copyTo(faiss::IndexIVFFlat* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + index->code_size = this->d * sizeof(float); + + InvertedLists *ivf = new ArrayInvertedLists(nlist, index->code_size); + index->replace_invlists(ivf, true); + + // Copy the inverted lists + if (index_) { + for (int i = 0; i < nlist; ++i) { + auto listIndices = index_->getListIndices(i); + auto listData = index_->getListVectors(i); + + ivf->add_entries(i, + listIndices.size(), + listIndices.data(), + (const uint8_t*) listData.data()); + } + } + } + + void + GpuIndexIVFFlat::copyToWithoutCodes(faiss::IndexIVFFlat* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG(ivfFlatConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + index->code_size = this->d * sizeof(float); + + InvertedLists *ivf = new ArrayInvertedLists(nlist, index->code_size); + index->replace_invlists(ivf, true); + + // Copy the inverted lists + if (index_) { + for (int i = 0; i < nlist; ++i) { + auto listIndices = index_->getListIndices(i); + + ivf->add_entries_without_codes(i, + listIndices.size(), + listIndices.data()); + } + } + } + + size_t + GpuIndexIVFFlat::reclaimMemory() { + if (index_) { + DeviceScope scope(device_); + + return index_->reclaimMemory(); + } + + return 0; + } + + void + GpuIndexIVFFlat::reset() { + if (index_) { + DeviceScope scope(device_); + + index_->reset(); + this->ntotal = 0; + } else { + FAISS_ASSERT(this->ntotal == 0); + } + } + + void + GpuIndexIVFFlat::train(Index::idx_t n, const float* x) { + DeviceScope scope(device_); + + if (this->is_trained) { + FAISS_ASSERT(quantizer->is_trained); + FAISS_ASSERT(quantizer->ntotal == nlist); + FAISS_ASSERT(index_); + return; + } + + FAISS_ASSERT(!index_); + + trainQuantizer_(n, x); + + // The quantizer is now trained; construct the IVF index + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + this->metric_type, + this->metric_arg, + false, // no residual + nullptr, // no scalar quantizer + ivfFlatConfig_.indicesOptions, + memorySpace_); + + if (reserveMemoryVecs_) { + index_->reserveMemory(reserveMemoryVecs_); + } + + this->is_trained = true; + } + + void + GpuIndexIVFFlat::addImpl_(int n, + const float* x, + const Index::idx_t* xids) { + // Device is already set in GpuIndex::add + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor data(const_cast(x), {n, (int) this->d}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor labels(const_cast(xids), {n}); + + // Not all vectors may be able to be added (some may contain NaNs etc) + index_->classifyAndAddVectors(data, labels); + + // but keep the ntotal based on the total number of vectors that we attempted + // to add + ntotal += n; + } + + void + GpuIndexIVFFlat::searchImpl_(int n, + const float* x, + int k, + float* distances, + Index::idx_t* labels, + ConcurrentBitsetPtr bitset) const { + // Device is already set in GpuIndex::search + FAISS_ASSERT(index_); + FAISS_ASSERT(n > 0); + + auto stream = resources_->getDefaultStream(device_); + + // Data is already resident on the GPU + Tensor queries(const_cast(x), {n, (int) this->d}); + Tensor outDistances(distances, {n, k}); + + static_assert(sizeof(long) == sizeof(Index::idx_t), "size mismatch"); + Tensor outLabels(const_cast(labels), {n, k}); + + if (!bitset) { + auto bitsetDevice = toDevice(resources_, device_, nullptr, stream, {0}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } else { + auto bitsetDevice = toDevice(resources_, device_, + const_cast(bitset->data()), stream, + {(int) bitset->size()}); + index_->query(queries, bitsetDevice, nprobe, k, outDistances, outLabels); + } + } + + + } } // namespace + \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.h index a7328c31e312..e0b79aaee125 100644 --- a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.h +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFFlat.h @@ -48,10 +48,14 @@ class GpuIndexIVFFlat : public GpuIndexIVF { /// all data in ourselves void copyFrom(const faiss::IndexIVFFlat* index); + void copyFromWithoutCodes(const faiss::IndexIVFFlat* index, const uint8_t* arranged_data); + /// Copy ourselves to the given CPU index; will overwrite all data /// in the index instance void copyTo(faiss::IndexIVFFlat* index) const; + void copyToWithoutCodes(faiss::IndexIVFFlat* index) const; + /// After adding vectors, one can call this to reclaim device memory /// to exactly the amount needed. Returns space reclaimed in bytes size_t reclaimMemory(); diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.cu index 254c0c410463..d6095a58e8ef 100644 --- a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.cu +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFPQ.cu @@ -34,6 +34,10 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResources* resources, bitsPerCode_(0), reserveMemoryVecs_(0), index_(nullptr) { +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!ivfpqConfig_.useFloat16LookupTables); +#endif + copyFrom(index); } @@ -55,6 +59,10 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResources* resources, bitsPerCode_(bitsPerCode), reserveMemoryVecs_(0), index_(nullptr) { +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!config.useFloat16LookupTables); +#endif + verifySettings_(); // We haven't trained ourselves, so don't construct the PQ index yet @@ -424,9 +432,11 @@ GpuIndexIVFPQ::verifySettings_() const { // We must have enough shared memory on the current device to store // our lookup distances int lookupTableSize = sizeof(float); +#ifdef FAISS_USE_FLOAT16 if (ivfpqConfig_.useFloat16LookupTables) { lookupTableSize = sizeof(half); } +#endif // 64 bytes per code is only supported with usage of float16, at 2^8 // codes per subquantizer diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.cu b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.cu index 6be3bb1f7902..9a3d908b9ce3 100644 --- a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.cu +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.cu @@ -134,6 +134,49 @@ GpuIndexIVFScalarQuantizer::copyFrom( } } +void +GpuIndexIVFScalarQuantizer::copyFromWithoutCodes( + const faiss::IndexIVFScalarQuantizer* index, const uint8_t* arranged_data) { + DeviceScope scope(device_); + + // Clear out our old data + delete index_; + index_ = nullptr; + + // Copy what we need from the CPU index + GpuIndexIVF::copyFrom(index); + sq = index->sq; + by_residual = index->by_residual; + + // The other index might not be trained, in which case we don't need to copy + // over the lists + if (!index->is_trained) { + return; + } + + // Otherwise, we can populate ourselves from the other index + this->is_trained = true; + + // Copy our lists as well + index_ = new IVFFlat(resources_, + quantizer->getGpuData(), + index->metric_type, + index->metric_arg, + by_residual, + &sq, + ivfSQConfig_.indicesOptions, + memorySpace_); + + InvertedLists* ivf = index->invlists; + + if(ReadOnlyArrayInvertedLists* rol = dynamic_cast(ivf)) { + index_->copyCodeVectorsFromCpu((const float *)arranged_data, + (const long *)(rol->pin_readonly_ids->data), rol->readonly_length); + } else { + // should not happen + } +} + void GpuIndexIVFScalarQuantizer::copyTo( faiss::IndexIVFScalarQuantizer* index) const { @@ -168,6 +211,38 @@ GpuIndexIVFScalarQuantizer::copyTo( } } +void +GpuIndexIVFScalarQuantizer::copyToWithoutCodes( + faiss::IndexIVFScalarQuantizer* index) const { + DeviceScope scope(device_); + + // We must have the indices in order to copy to ourselves + FAISS_THROW_IF_NOT_MSG( + ivfSQConfig_.indicesOptions != INDICES_IVF, + "Cannot copy to CPU as GPU index doesn't retain " + "indices (INDICES_IVF)"); + + GpuIndexIVF::copyTo(index); + index->sq = sq; + index->code_size = sq.code_size; + index->by_residual = by_residual; + index->code_size = sq.code_size; + + InvertedLists* ivf = new ArrayInvertedLists(nlist, index->code_size); + index->replace_invlists(ivf, true); + + // Copy the inverted lists + if (index_) { + for (int i = 0; i < nlist; ++i) { + auto listIndices = index_->getListIndices(i); + + ivf->add_entries_without_codes(i, + listIndices.size(), + listIndices.data()); + } + } +} + size_t GpuIndexIVFScalarQuantizer::reclaimMemory() { if (index_) { diff --git a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.h b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.h index 47b8de249fdf..427d9e670206 100644 --- a/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.h +++ b/core/src/index/thirdparty/faiss/gpu/GpuIndexIVFScalarQuantizer.h @@ -52,10 +52,14 @@ class GpuIndexIVFScalarQuantizer : public GpuIndexIVF { /// all data in ourselves void copyFrom(const faiss::IndexIVFScalarQuantizer* index); + void copyFromWithoutCodes(const faiss::IndexIVFScalarQuantizer* index, const uint8_t* arranged_data); + /// Copy ourselves to the given CPU index; will overwrite all data /// in the index instance void copyTo(faiss::IndexIVFScalarQuantizer* index) const; + void copyToWithoutCodes(faiss::IndexIVFScalarQuantizer* index) const; + /// After adding vectors, one can call this to reclaim device memory /// to exactly the amount needed. Returns space reclaimed in bytes size_t reclaimMemory(); diff --git a/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cu b/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cu index 364200c3e4ed..e9f7548e25c5 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cu @@ -262,11 +262,13 @@ void runSumAlongColumns(Tensor& input, runSumAlongColumns(input, output, stream); } +#ifdef FAISS_USE_FLOAT16 void runSumAlongColumns(Tensor& input, Tensor& output, cudaStream_t stream) { runSumAlongColumns(input, output, stream); } +#endif template void runAssignAlongColumns(Tensor& input, @@ -310,11 +312,13 @@ void runAssignAlongColumns(Tensor& input, runAssignAlongColumns(input, output, stream); } +#ifdef FAISS_USE_FLOAT16 void runAssignAlongColumns(Tensor& input, Tensor& output, cudaStream_t stream) { runAssignAlongColumns(input, output, stream); } +#endif template void runSumAlongRows(Tensor& input, @@ -344,11 +348,13 @@ void runSumAlongRows(Tensor& input, runSumAlongRows(input, output, zeroClamp, stream); } +#ifdef FAISS_USE_FLOAT16 void runSumAlongRows(Tensor& input, Tensor& output, bool zeroClamp, cudaStream_t stream) { runSumAlongRows(input, output, zeroClamp, stream); } +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cuh b/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cuh index 8c4b27452c8f..6641aadd40c1 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cuh +++ b/core/src/index/thirdparty/faiss/gpu/impl/BroadcastSum.cuh @@ -17,18 +17,22 @@ void runSumAlongColumns(Tensor& input, Tensor& output, cudaStream_t stream); +#ifdef FAISS_USE_FLOAT16 void runSumAlongColumns(Tensor& input, Tensor& output, cudaStream_t stream); +#endif // output[x][i] = input[i] for all x void runAssignAlongColumns(Tensor& input, Tensor& output, cudaStream_t stream); +#ifdef FAISS_USE_FLOAT16 void runAssignAlongColumns(Tensor& input, Tensor& output, cudaStream_t stream); +#endif // output[i][x] += input[i] for all x // If zeroClamp, output[i][x] = max(output[i][x] + input[i], 0) for all x @@ -37,9 +41,11 @@ void runSumAlongRows(Tensor& input, bool zeroClamp, cudaStream_t stream); +#ifdef FAISS_USE_FLOAT16 void runSumAlongRows(Tensor& input, Tensor& output, bool zeroClamp, cudaStream_t stream); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/Distance.cu b/core/src/index/thirdparty/faiss/gpu/impl/Distance.cu index e4aa3af1fc7b..0856396cc19e 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/Distance.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/Distance.cu @@ -370,6 +370,7 @@ runIPDistance(GpuResources* resources, outIndices); } +#ifdef FAISS_USE_FLOAT16 void runIPDistance(GpuResources* resources, Tensor& vectors, @@ -390,6 +391,7 @@ runIPDistance(GpuResources* resources, outDistances, outIndices); } +#endif void runL2Distance(GpuResources* resources, @@ -416,6 +418,7 @@ runL2Distance(GpuResources* resources, ignoreOutDistances); } +#ifdef FAISS_USE_FLOAT16 void runL2Distance(GpuResources* resources, Tensor& vectors, @@ -440,5 +443,6 @@ runL2Distance(GpuResources* resources, outIndices, ignoreOutDistances); } +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/Distance.cuh b/core/src/index/thirdparty/faiss/gpu/impl/Distance.cuh index 844d420aea42..3430ddf87f95 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/Distance.cuh +++ b/core/src/index/thirdparty/faiss/gpu/impl/Distance.cuh @@ -57,6 +57,7 @@ void runIPDistance(GpuResources* resources, Tensor& outDistances, Tensor& outIndices); + void runL2Distance(GpuResources* resources, Tensor& vectors, bool vectorsRowMajor, diff --git a/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cu b/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cu index e7545df767c7..29480fa84fd3 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cu @@ -29,6 +29,9 @@ FlatIndex::FlatIndex(GpuResources* res, space_(space), num_(0), rawData_(space) { +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!useFloat16_); +#endif } bool @@ -38,28 +41,40 @@ FlatIndex::getUseFloat16() const { /// Returns the number of vectors we contain int FlatIndex::getSize() const { +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { return vectorsHalf_.getSize(0); } else { return vectors_.getSize(0); } +#else + return vectors_.getSize(0); +#endif } int FlatIndex::getDim() const { +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { return vectorsHalf_.getSize(1); } else { return vectors_.getSize(1); } +#else + return vectors_.getSize(1); +#endif } void FlatIndex::reserve(size_t numVecs, cudaStream_t stream) { +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { rawData_.reserve(numVecs * dim_ * sizeof(half), stream); } else { rawData_.reserve(numVecs * dim_ * sizeof(float), stream); } +#else + rawData_.reserve(numVecs * dim_ * sizeof(float), stream); +#endif } template <> @@ -70,6 +85,7 @@ FlatIndex::getVectorsRef() { return getVectorsFloat32Ref(); } +#ifdef FAISS_USE_FLOAT16 template <> Tensor& FlatIndex::getVectorsRef() { @@ -77,6 +93,7 @@ FlatIndex::getVectorsRef() { FAISS_ASSERT(useFloat16_); return getVectorsFloat16Ref(); } +#endif Tensor& FlatIndex::getVectorsFloat32Ref() { @@ -86,6 +103,7 @@ FlatIndex::getVectorsFloat32Ref() { return vectors_; } +#ifdef FAISS_USE_FLOAT16 Tensor& FlatIndex::getVectorsFloat16Ref() { // Should not call this unless we are in float16 mode @@ -93,6 +111,7 @@ FlatIndex::getVectorsFloat16Ref() { return vectorsHalf_; } +#endif DeviceTensor FlatIndex::getVectorsFloat32Copy(cudaStream_t stream) { @@ -103,12 +122,16 @@ DeviceTensor FlatIndex::getVectorsFloat32Copy(int from, int num, cudaStream_t stream) { DeviceTensor vecFloat32({num, dim_}, space_); +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { auto halfNarrow = vectorsHalf_.narrowOutermost(from, num); convertTensor(stream, halfNarrow, vecFloat32); } else { vectors_.copyTo(vecFloat32, stream); } +#else + vectors_.copyTo(vecFloat32, stream); +#endif return vecFloat32; } @@ -125,13 +148,16 @@ FlatIndex::query(Tensor& input, auto stream = resources_->getDefaultStreamCurrentDevice(); auto& mem = resources_->getMemoryManagerCurrentDevice(); +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { // We need to convert the input to float16 for comparison to ourselves + auto inputHalf = convertTensor(resources_, stream, input); query(inputHalf, bitset, k, metric, metricArg, outDistances, outIndices, exactDistance); + } else { bfKnnOnDevice(resources_, getCurrentDevice(), @@ -149,8 +175,26 @@ FlatIndex::query(Tensor& input, outIndices, !exactDistance); } +#else + bfKnnOnDevice(resources_, + getCurrentDevice(), + stream, + storeTransposed_ ? vectorsTransposed_ : vectors_, + !storeTransposed_, // is vectors row major? + &norms_, + input, + true, // input is row major + bitset, + k, + metric, + metricArg, + outDistances, + outIndices, + !exactDistance); +#endif } +#ifdef FAISS_USE_FLOAT16 void FlatIndex::query(Tensor& input, Tensor& bitset, @@ -178,11 +222,13 @@ FlatIndex::query(Tensor& input, outIndices, !exactDistance); } +#endif void FlatIndex::computeResidual(Tensor& vecs, Tensor& listIds, Tensor& residuals) { +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { runCalcResidual(vecs, getVectorsFloat16Ref(), @@ -196,11 +242,19 @@ FlatIndex::computeResidual(Tensor& vecs, residuals, resources_->getDefaultStreamCurrentDevice()); } +#else + runCalcResidual(vecs, + getVectorsFloat32Ref(), + listIds, + residuals, + resources_->getDefaultStreamCurrentDevice()); +#endif } void FlatIndex::reconstruct(Tensor& listIds, Tensor& vecs) { +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { runReconstruct(listIds, getVectorsFloat16Ref(), @@ -212,8 +266,13 @@ FlatIndex::reconstruct(Tensor& listIds, vecs, resources_->getDefaultStreamCurrentDevice()); } +#else + runReconstruct(listIds, + getVectorsFloat32Ref(), + vecs, + resources_->getDefaultStreamCurrentDevice()); +#endif } - void FlatIndex::reconstruct(Tensor& listIds, Tensor& vecs) { @@ -229,6 +288,7 @@ FlatIndex::add(const float* data, int numVecs, cudaStream_t stream) { return; } +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { // Make sure that `data` is on our device; we'll run the // conversion on our device @@ -252,8 +312,15 @@ FlatIndex::add(const float* data, int numVecs, cudaStream_t stream) { true /* reserve exactly */); } +#else + rawData_.append((char*) data, + (size_t) dim_ * numVecs * sizeof(float), + stream, + true /* reserve exactly */); +#endif num_ += numVecs; +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { DeviceTensor vectorsHalf( (half*) rawData_.data(), {(int) num_, dim_}, space_); @@ -263,8 +330,14 @@ FlatIndex::add(const float* data, int numVecs, cudaStream_t stream) { (float*) rawData_.data(), {(int) num_, dim_}, space_); vectors_ = std::move(vectors); } +#else + DeviceTensor vectors( + (float*) rawData_.data(), {(int) num_, dim_}, space_); + vectors_ = std::move(vectors); +#endif if (storeTransposed_) { +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { vectorsHalfTransposed_ = std::move(DeviceTensor({dim_, (int) num_}, space_)); @@ -274,9 +347,15 @@ FlatIndex::add(const float* data, int numVecs, cudaStream_t stream) { std::move(DeviceTensor({dim_, (int) num_}, space_)); runTransposeAny(vectors_, 0, 1, vectorsTransposed_, stream); } +#else + vectorsTransposed_ = + std::move(DeviceTensor({dim_, (int) num_}, space_)); + runTransposeAny(vectors_, 0, 1, vectorsTransposed_, stream); +#endif } // Precompute L2 norms of our database +#ifdef FAISS_USE_FLOAT16 if (useFloat16_) { DeviceTensor norms({(int) num_}, space_); runL2Norm(vectorsHalf_, true, norms, true, stream); @@ -286,6 +365,11 @@ FlatIndex::add(const float* data, int numVecs, cudaStream_t stream) { runL2Norm(vectors_, true, norms, true, stream); norms_ = std::move(norms); } +#else + DeviceTensor norms({(int) num_}, space_); + runL2Norm(vectors_, true, norms, true, stream); + norms_ = std::move(norms); +#endif } void @@ -293,8 +377,10 @@ FlatIndex::reset() { rawData_.clear(); vectors_ = std::move(DeviceTensor()); vectorsTransposed_ = std::move(DeviceTensor()); +#ifdef FAISS_USE_FLOAT16 vectorsHalf_ = std::move(DeviceTensor()); vectorsHalfTransposed_ = std::move(DeviceTensor()); +#endif norms_ = std::move(DeviceTensor()); num_ = 0; } diff --git a/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cuh b/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cuh index 5bc97441c436..eef07df24c00 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cuh +++ b/core/src/index/thirdparty/faiss/gpu/impl/FlatIndex.cuh @@ -47,7 +47,9 @@ class FlatIndex { Tensor& getVectorsFloat32Ref(); /// Returns a reference to our vectors currently in use (useFloat16 mode) +#ifdef FAISS_USE_FLOAT16 Tensor& getVectorsFloat16Ref(); +#endif /// Performs a copy of the vectors on the given device, converting /// as needed from float16 @@ -67,6 +69,7 @@ class FlatIndex { Tensor& outIndices, bool exactDistance); +#ifdef FAISS_USE_FLOAT16 void query(Tensor& vecs, Tensor& bitset, int k, @@ -75,6 +78,7 @@ class FlatIndex { Tensor& outDistances, Tensor& outIndices, bool exactDistance); +#endif /// Compute residual for set of vectors void computeResidual(Tensor& vecs, @@ -123,8 +127,10 @@ class FlatIndex { DeviceTensor vectorsTransposed_; /// Vectors currently in rawData_, float16 form +#ifdef FAISS_USE_FLOAT16 DeviceTensor vectorsHalf_; DeviceTensor vectorsHalfTransposed_; +#endif /// Precomputed L2 norms DeviceTensor norms_; diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cu b/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cu index 01c91f9c3fb1..48254c1f5b26 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cu @@ -60,6 +60,10 @@ IVFPQ::IVFPQ(GpuResources* resources, FAISS_ASSERT(dim_ % numSubQuantizers_ == 0); FAISS_ASSERT(isSupportedPQCodeLength(bytesPerVector_)); +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!useFloat16LookupTables_); +#endif + setPQCentroids_(pqCentroidData); } @@ -112,7 +116,9 @@ IVFPQ::setPrecomputedCodes(bool enable) { } else { // Clear out old precomputed code data precomputedCode_ = std::move(DeviceTensor()); +#ifdef FAISS_USE_FLOAT16 precomputedCodeHalf_ = std::move(DeviceTensor()); +#endif } } } @@ -156,6 +162,7 @@ IVFPQ::classifyAndAddVectors(Tensor& vecs, DeviceTensor residuals( mem, {vecs.getSize(0), vecs.getSize(1)}, stream); +#ifdef FAISS_USE_FLOAT16 if (quantizer_->getUseFloat16()) { auto& coarseCentroids = quantizer_->getVectorsFloat16Ref(); runCalcResidual(vecs, coarseCentroids, listIds, residuals, stream); @@ -163,6 +170,10 @@ IVFPQ::classifyAndAddVectors(Tensor& vecs, auto& coarseCentroids = quantizer_->getVectorsFloat32Ref(); runCalcResidual(vecs, coarseCentroids, listIds, residuals, stream); } +#else + auto& coarseCentroids = quantizer_->getVectorsFloat32Ref(); + runCalcResidual(vecs, coarseCentroids, listIds, residuals, stream); +#endif // Residuals are in the form // (vec x numSubQuantizer x dimPerSubQuantizer) @@ -519,6 +530,7 @@ IVFPQ::precomputeCodesT_() { // We added into the view, so `coarsePQProductTransposed` is now our // precomputed term 2. +#ifdef FAISS_USE_FLOAT16 if (useFloat16LookupTables_) { precomputedCodeHalf_ = convertTensor(resources_, @@ -527,15 +539,23 @@ IVFPQ::precomputeCodesT_() { } else { precomputedCode_ = std::move(coarsePQProductTransposed); } +#else + precomputedCode_ = std::move(coarsePQProductTransposed); +#endif + } void IVFPQ::precomputeCodes_() { +#ifdef FAISS_USE_FLOAT16 if (quantizer_->getUseFloat16()) { precomputeCodesT_(); } else { precomputeCodesT_(); } +#else + precomputeCodesT_(); +#endif } void @@ -678,6 +698,7 @@ IVFPQ::runPQPrecomputedCodes_( NoTypeTensor<3, true> term2; NoTypeTensor<3, true> term3; +#ifdef FAISS_USE_FLOAT16 DeviceTensor term3Half; if (useFloat16LookupTables_) { @@ -686,7 +707,10 @@ IVFPQ::runPQPrecomputedCodes_( term2 = NoTypeTensor<3, true>(precomputedCodeHalf_); term3 = NoTypeTensor<3, true>(term3Half); - } else { + } +#endif + + if (!useFloat16LookupTables_) { term2 = NoTypeTensor<3, true>(precomputedCode_); term3 = NoTypeTensor<3, true>(term3Transposed); } @@ -754,6 +778,7 @@ IVFPQ::runPQNoPrecomputedCodes_( int k, Tensor& outDistances, Tensor& outIndices) { +#ifdef FAISS_USE_FLOAT16 if (quantizer_->getUseFloat16()) { runPQNoPrecomputedCodesT_(queries, bitset, @@ -770,7 +795,17 @@ IVFPQ::runPQNoPrecomputedCodes_( k, outDistances, outIndices); - } + } +#else + runPQNoPrecomputedCodesT_(queries, + bitset, + coarseDistances, + coarseIndices, + k, + outDistances, + outIndices); +#endif + } } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cuh b/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cuh index db8cbb68aa61..ad03fb4f8901 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cuh +++ b/core/src/index/thirdparty/faiss/gpu/impl/IVFPQ.cuh @@ -153,7 +153,9 @@ class IVFPQ : public IVFBase { DeviceTensor precomputedCode_; /// Precomputed term 2 in half form +#ifdef FAISS_USE_FLOAT16 DeviceTensor precomputedCodeHalf_; +#endif }; } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cu b/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cu index 96bcd8e95b9a..bdf812524e20 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cu @@ -309,6 +309,7 @@ void runL2Norm(Tensor& input, } } +#ifdef FAISS_USE_FLOAT16 void runL2Norm(Tensor& input, bool inputRowMajor, Tensor& output, @@ -325,5 +326,6 @@ void runL2Norm(Tensor& input, inputCast, inputRowMajor, outputCast, normSquared, stream); } } +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cuh b/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cuh index c4d585080290..6df3dcea583e 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cuh +++ b/core/src/index/thirdparty/faiss/gpu/impl/L2Norm.cuh @@ -18,10 +18,12 @@ void runL2Norm(Tensor& input, bool normSquared, cudaStream_t stream); +#ifdef FAISS_USE_FLOAT16 void runL2Norm(Tensor& input, bool inputRowMajor, Tensor& output, bool normSquared, cudaStream_t stream); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh index c3ef87f2e7b7..520a8bcafb67 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances-inl.cuh @@ -438,6 +438,7 @@ runPQCodeDistancesMM(Tensor& pqCentroids, runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream); +#ifdef FAISS_USE_FLOAT16 if (useFloat16Lookup) { // Need to convert back auto outCodeDistancesH = outCodeDistances.toTensor(); @@ -445,6 +446,7 @@ runPQCodeDistancesMM(Tensor& pqCentroids, outCodeDistancesF, outCodeDistancesH); } +#endif } template @@ -477,6 +479,7 @@ runPQCodeDistances(Tensor& pqCentroids, auto smem = (3 * dimsPerSubQuantizer) * sizeof(float) + topQueryToCentroid.getSize(1) * sizeof(int); +#ifdef FAISS_USE_FLOAT16 #define RUN_CODE(DIMS, L2) \ do { \ if (useFloat16Lookup) { \ @@ -495,6 +498,16 @@ runPQCodeDistances(Tensor& pqCentroids, topQueryToCentroid, outCodeDistancesT); \ } \ } while (0) +#else +#define RUN_CODE(DIMS, L2) \ + do { \ + auto outCodeDistancesT = outCodeDistances.toTensor(); \ + pqCodeDistances<<>>( \ + queries, kQueriesPerBlock, \ + coarseCentroids, pqCentroids, \ + topQueryToCentroid, outCodeDistancesT); \ + } while (0) +#endif #define CODE_L2(DIMS) \ do { \ diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cu b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cu index 817990b4a6ed..eec88523104c 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQCodeDistances.cu @@ -26,10 +26,12 @@ template struct Converter { }; +#ifdef FAISS_USE_FLOAT16 template <> struct Converter { inline static __device__ half to(float v) { return __float2half(v); } }; +#endif template <> struct Converter { @@ -394,6 +396,7 @@ runPQCodeDistancesMM(Tensor& pqCentroids, Tensor outCodeDistancesF; DeviceTensor outCodeDistancesFloatMem; +#ifdef FAISS_USE_FLOAT16 if (useFloat16Lookup) { outCodeDistancesFloatMem = DeviceTensor( mem, {outCodeDistances.getSize(0), @@ -406,6 +409,9 @@ runPQCodeDistancesMM(Tensor& pqCentroids, } else { outCodeDistancesF = outCodeDistances.toTensor(); } +#else + outCodeDistancesF = outCodeDistances.toTensor(); +#endif // Transpose -2(sub q)(q * c)(code) to -2(q * c)(sub q)(code) (which // is where we build our output distances) @@ -445,6 +451,7 @@ runPQCodeDistancesMM(Tensor& pqCentroids, runSumAlongColumns(pqCentroidsNorm, outDistancesCodeViewCols, stream); +#ifdef FAISS_USE_FLOAT16 if (useFloat16Lookup) { // Need to convert back auto outCodeDistancesH = outCodeDistances.toTensor(); @@ -452,6 +459,7 @@ runPQCodeDistancesMM(Tensor& pqCentroids, outCodeDistancesF, outCodeDistancesH); } +#endif } void @@ -483,6 +491,7 @@ runPQCodeDistances(Tensor& pqCentroids, auto smem = (3 * dimsPerSubQuantizer) * sizeof(float) + topQueryToCentroid.getSize(1) * sizeof(int); +#ifdef FAISS_USE_FLOAT16 #define RUN_CODE(DIMS, L2) \ do { \ if (useFloat16Lookup) { \ @@ -492,7 +501,19 @@ runPQCodeDistances(Tensor& pqCentroids, queries, kQueriesPerBlock, \ coarseCentroids, pqCentroids, \ topQueryToCentroid, outCodeDistancesT); \ - } else { \ + } else { \ + auto outCodeDistancesT = outCodeDistances.toTensor(); \ + \ + pqCodeDistances<<>>( \ + queries, kQueriesPerBlock, \ + coarseCentroids, pqCentroids, \ + topQueryToCentroid, outCodeDistancesT); \ + } \ + } while (0) +#else +#define RUN_CODE(DIMS, L2) \ + do { \ + if(!useFloat16Lookup){ \ auto outCodeDistancesT = outCodeDistances.toTensor(); \ \ pqCodeDistances<<>>( \ @@ -501,6 +522,7 @@ runPQCodeDistances(Tensor& pqCentroids, topQueryToCentroid, outCodeDistancesT); \ } \ } while (0) +#endif #define CODE_L2(DIMS) \ do { \ diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh index ffc81b1f8c09..a77e783d0967 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed-inl.cuh @@ -275,7 +275,12 @@ runMultiPassTile(Tensor& queries, auto block = dim3(kThreadsPerBlock); // pq centroid distances - auto smem = useFloat16Lookup ? sizeof(half) : sizeof(float); + +#ifdef FAISS_USE_FLOAT16 + auto smem = (sizeof(float)== useFloat16Lookup) ? sizeof(half) : sizeof(float); +#else + auto smem = sizeof(float); +#endif smem *= numSubQuantizers * numSubQuantizerCodes; FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice()); @@ -296,6 +301,7 @@ runMultiPassTile(Tensor& queries, allDistances); \ } while (0) +#ifdef FAISS_USE_FLOAT16 #define RUN_PQ(NUM_SUB_Q) \ do { \ if (useFloat16Lookup) { \ @@ -304,6 +310,12 @@ runMultiPassTile(Tensor& queries, RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ } \ } while (0) +#else +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } while (0) +#endif switch (bytesPerCode) { case 1: @@ -499,7 +511,12 @@ runPQScanMultiPassNoPrecomputed(Tensor& queries, sizeof(int), stream)); - int codeDistanceTypeSize = useFloat16Lookup ? sizeof(half) : sizeof(float); + int codeDistanceTypeSize = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + codeDistanceTypeSize = sizeof(half); + } +#endif int totalCodeDistancesSize = queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes * diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cu b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cu index ecf35fffdb80..b4934382cba3 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassNoPrecomputed.cu @@ -248,8 +248,11 @@ runMultiPassTile(Tensor& queries, metric == MetricType::METRIC_L2); bool l2Distance = metric == MetricType::METRIC_L2; - // Calculate offset lengths, so we know where to write out +#ifndef FAISS_USE_FLOAT16 + FAISS_ASSERT(!useFloat16Lookup); +#endif + // Calculate offset lengths, so we know where to write out // intermediate results runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets, thrustMem, stream); @@ -275,7 +278,13 @@ runMultiPassTile(Tensor& queries, auto block = dim3(kThreadsPerBlock); // pq centroid distances - auto smem = useFloat16Lookup ? sizeof(half) : sizeof(float); + //auto smem = useFloat16Lookup ? sizeof(half) : sizeof(float); + auto smem = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + smem = sizeof(half); + } +#endif smem *= numSubQuantizers * numSubQuantizerCodes; FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice()); @@ -296,6 +305,7 @@ runMultiPassTile(Tensor& queries, allDistances); \ } while (0) +#ifdef FAISS_USE_FLOAT16 #define RUN_PQ(NUM_SUB_Q) \ do { \ if (useFloat16Lookup) { \ @@ -304,6 +314,12 @@ runMultiPassTile(Tensor& queries, RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ } \ } while (0) +#else +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } while (0) +#endif // FAISS_USE_FLOAT16 switch (bytesPerCode) { case 1: @@ -497,7 +513,14 @@ void runPQScanMultiPassNoPrecomputed(Tensor& queries, sizeof(int), stream)); - int codeDistanceTypeSize = useFloat16Lookup ? sizeof(half) : sizeof(float); + int codeDistanceTypeSize = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + codeDistanceTypeSize = sizeof(half); + } +#else + FAISS_ASSERT(!useFloat16Lookup); +#endif int totalCodeDistancesSize = queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes * diff --git a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu index 583ee477dc49..02e65ff32a34 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/PQScanMultiPassPrecomputed.cu @@ -252,7 +252,12 @@ runMultiPassTile(Tensor& queries, auto block = dim3(kThreadsPerBlock); // pq precomputed terms (2 + 3) - auto smem = useFloat16Lookup ? sizeof(half) : sizeof(float); + auto smem = sizeof(float); +#ifdef FAISS_USE_FLOAT16 + if (useFloat16Lookup) { + smem = sizeof(half); + } +#endif smem *= numSubQuantizers * numSubQuantizerCodes; FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice()); @@ -275,6 +280,7 @@ runMultiPassTile(Tensor& queries, allDistances); \ } while (0) +#ifdef FAISS_USE_FLOAT16 #define RUN_PQ(NUM_SUB_Q) \ do { \ if (useFloat16Lookup) { \ @@ -283,6 +289,12 @@ runMultiPassTile(Tensor& queries, RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ } \ } while (0) +#else +#define RUN_PQ(NUM_SUB_Q) \ + do { \ + RUN_PQ_OPT(NUM_SUB_Q, float, float4); \ + } while (0) +#endif switch (bytesPerCode) { case 1: diff --git a/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cu b/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cu index 078e6604178e..980b3c39790f 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cu +++ b/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cu @@ -119,6 +119,7 @@ void runCalcResidual(Tensor& vecs, calcResidual(vecs, centroids, vecToCentroid, residuals, stream); } +#ifdef FAISS_USE_FLOAT16 void runCalcResidual(Tensor& vecs, Tensor& centroids, Tensor& vecToCentroid, @@ -126,6 +127,7 @@ void runCalcResidual(Tensor& vecs, cudaStream_t stream) { calcResidual(vecs, centroids, vecToCentroid, residuals, stream); } +#endif void runReconstruct(Tensor& listIds, Tensor& vecs, @@ -134,11 +136,13 @@ void runReconstruct(Tensor& listIds, gatherReconstruct(listIds, vecs, out, stream); } +#ifdef FAISS_USE_FLOAT16 void runReconstruct(Tensor& listIds, Tensor& vecs, Tensor& out, cudaStream_t stream) { gatherReconstruct(listIds, vecs, out, stream); } +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cuh b/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cuh index ca7bcaa0b612..8e8cd2e75625 100644 --- a/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cuh +++ b/core/src/index/thirdparty/faiss/gpu/impl/VectorResidual.cuh @@ -31,9 +31,11 @@ void runReconstruct(Tensor& listIds, Tensor& out, cudaStream_t stream); +#ifdef FAISS_USE_FLOAT16 void runReconstruct(Tensor& listIds, Tensor& vecs, Tensor& out, cudaStream_t stream); +# endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectHalf.cu b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectHalf.cu index 4f642a0ca833..f6989fc084c7 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectHalf.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectHalf.cu @@ -10,6 +10,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 + // warp Q to thread Q: // 1, 1 // 32, 2 @@ -143,4 +145,6 @@ void runBlockSelectPair(Tensor& inK, } } +#endif // FAISS_USE_FLOAT16 + } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectKernel.cuh b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectKernel.cuh index 238909d4b07e..f787335cdfe6 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectKernel.cuh +++ b/core/src/index/thirdparty/faiss/gpu/utils/BlockSelectKernel.cuh @@ -241,6 +241,7 @@ void runBlockSelectPair(Tensor& inKeys, Tensor& outIndices, bool dir, int k, cudaStream_t stream); +#ifdef FAISS_USE_FLOAT16 void runBlockSelect(Tensor& in, Tensor& bitset, Tensor& outKeys, @@ -253,5 +254,6 @@ void runBlockSelectPair(Tensor& inKeys, Tensor& outKeys, Tensor& outIndices, bool dir, int k, cudaStream_t stream); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/ConversionOperators.cuh b/core/src/index/thirdparty/faiss/gpu/utils/ConversionOperators.cuh index ddc30af17381..cf9b74c971db 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/ConversionOperators.cuh +++ b/core/src/index/thirdparty/faiss/gpu/utils/ConversionOperators.cuh @@ -29,6 +29,7 @@ struct Convert { } }; +#ifdef FAISS_USE_FLOAT16 template <> struct Convert { inline __device__ half operator()(float v) const { @@ -42,6 +43,7 @@ struct Convert { return __half2float(v); } }; +#endif template struct ConvertTo { @@ -50,38 +52,50 @@ struct ConvertTo { template <> struct ConvertTo { static inline __device__ float to(float v) { return v; } +#ifdef FAISS_USE_FLOAT16 static inline __device__ float to(half v) { return __half2float(v); } +#endif }; template <> struct ConvertTo { static inline __device__ float2 to(float2 v) { return v; } +#ifdef FAISS_USE_FLOAT16 static inline __device__ float2 to(half2 v) { return __half22float2(v); } +#endif }; template <> struct ConvertTo { static inline __device__ float4 to(float4 v) { return v; } +#ifdef FAISS_USE_FLOAT16 static inline __device__ float4 to(Half4 v) { return half4ToFloat4(v); } +#endif }; +#ifdef FAISS_USE_FLOAT16 template <> struct ConvertTo { static inline __device__ half to(float v) { return __float2half(v); } static inline __device__ half to(half v) { return v; } }; +#endif +#ifdef FAISS_USE_FLOAT16 template <> struct ConvertTo { static inline __device__ half2 to(float2 v) { return __float22half2_rn(v); } static inline __device__ half2 to(half2 v) { return v; } }; +#endif +#ifdef FAISS_USE_FLOAT16 template <> struct ConvertTo { static inline __device__ Half4 to(float4 v) { return float4ToHalf4(v); } static inline __device__ Half4 to(Half4 v) { return v; } }; +#endif // Tensor conversion template diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Float16.cu b/core/src/index/thirdparty/faiss/gpu/utils/Float16.cu index 52d54df309f2..e1f5c09b9fc2 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/Float16.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/Float16.cu @@ -12,6 +12,8 @@ #include #include +#ifdef FAISS_USE_FLOAT16 + namespace faiss { namespace gpu { bool getDeviceSupportsFloat16Math(int device) { @@ -36,3 +38,5 @@ __half hostFloat2Half(float a) { } } } // namespace + +#endif // FAISS_USE_FLOAT16 diff --git a/core/src/index/thirdparty/faiss/gpu/utils/Float16.cuh b/core/src/index/thirdparty/faiss/gpu/utils/Float16.cuh index 09566eaa94b9..0af798ba80fb 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/Float16.cuh +++ b/core/src/index/thirdparty/faiss/gpu/utils/Float16.cuh @@ -22,10 +22,14 @@ #define FAISS_USE_FULL_FLOAT16 1 #endif // __CUDA_ARCH__ types +#ifdef FAISS_USE_FLOAT16 #include +#endif namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 + // 64 bytes containing 4 half (float16) values struct Half4 { half2 a; @@ -72,4 +76,6 @@ bool getDeviceSupportsFloat16Math(int device); __half hostFloat2Half(float v); +#endif // FAISS_USE_FLOAT16 + } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/LoadStoreOperators.cuh b/core/src/index/thirdparty/faiss/gpu/utils/LoadStoreOperators.cuh index b0bb8b533074..b49d634461b2 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/LoadStoreOperators.cuh +++ b/core/src/index/thirdparty/faiss/gpu/utils/LoadStoreOperators.cuh @@ -35,6 +35,8 @@ struct LoadStore { } }; +#ifdef FAISS_USE_FLOAT16 + template <> struct LoadStore { static inline __device__ Half4 load(void* p) { @@ -87,4 +89,6 @@ struct LoadStore { } }; +#endif + } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/MathOperators.cuh b/core/src/index/thirdparty/faiss/gpu/utils/MathOperators.cuh index 68ccbd5686ad..7e9f25a2a01f 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/MathOperators.cuh +++ b/core/src/index/thirdparty/faiss/gpu/utils/MathOperators.cuh @@ -217,6 +217,7 @@ struct Math { } }; +#ifdef FAISS_USE_FLOAT16 template <> struct Math { typedef half ScalarType; @@ -555,5 +556,6 @@ struct Math { return h; } }; +#endif // FAISS_USE_FLOAT16 } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectHalf.cu b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectHalf.cu index 54e10be1e544..d700ecaee7f1 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectHalf.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectHalf.cu @@ -10,6 +10,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 + // warp Q to thread Q: // 1, 1 // 32, 2 @@ -91,4 +93,6 @@ void runWarpSelect(Tensor& in, } } +#endif // FAISS_USE_FLOAT16 + } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectKernel.cuh b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectKernel.cuh index 3c122e88617c..1b690b03061f 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectKernel.cuh +++ b/core/src/index/thirdparty/faiss/gpu/utils/WarpSelectKernel.cuh @@ -62,9 +62,11 @@ void runWarpSelect(Tensor& in, Tensor& outIndices, bool dir, int k, cudaStream_t stream); +#ifdef FAISS_USE_FLOAT16 void runWarpSelect(Tensor& in, Tensor& outKeys, Tensor& outIndices, bool dir, int k, cudaStream_t stream); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/WarpShuffles.cuh b/core/src/index/thirdparty/faiss/gpu/utils/WarpShuffles.cuh index 504c73f79a5c..ec2e5b618c54 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/WarpShuffles.cuh +++ b/core/src/index/thirdparty/faiss/gpu/utils/WarpShuffles.cuh @@ -91,6 +91,7 @@ inline __device__ T* shfl_xor(T* const val, return (T*) shfl_xor(v, laneMask, width); } +#ifdef FAISS_USE_FLOAT16 // CUDA 9.0+ has half shuffle #if CUDA_VERSION < 9000 inline __device__ half shfl(half v, @@ -113,5 +114,6 @@ inline __device__ half shfl_xor(half v, return h; } #endif // CUDA_VERSION +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf1.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf1.cu index 88f1d21b57e2..d2525935c22d 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf1.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf1.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, true, 1, 1); BLOCK_SELECT_IMPL(half, false, 1, 1); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf128.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf128.cu index b38c00b83ee1..3759af93428f 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf128.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf128.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, true, 128, 3); BLOCK_SELECT_IMPL(half, false, 128, 3); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf256.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf256.cu index 2cea11ace24f..a8a5cf13e923 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf256.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf256.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, true, 256, 4); BLOCK_SELECT_IMPL(half, false, 256, 4); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf32.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf32.cu index 6045a52feaaf..18907c51196e 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf32.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf32.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, true, 32, 2); BLOCK_SELECT_IMPL(half, false, 32, 2); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf64.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf64.cu index ea4b0bf64be4..81a9a84a9faa 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf64.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalf64.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, true, 64, 3); BLOCK_SELECT_IMPL(half, false, 64, 3); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF1024.cu index 710e8c846068..e83b615193d5 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF1024.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF1024.cu @@ -9,6 +9,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, false, 1024, 8); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF2048.cu index 5f7f4d4f6b2e..e06c334481cc 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF2048.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF2048.cu @@ -11,7 +11,9 @@ namespace faiss { namespace gpu { #if GPU_MAX_SELECTION_K >= 2048 +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, false, 2048, 8); #endif +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF512.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF512.cu index 07ea1f9f6bdc..c1b67bd3de26 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF512.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfF512.cu @@ -9,6 +9,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, false, 512, 8); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT1024.cu index 6dc37accf707..2fd0dffa37e0 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT1024.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT1024.cu @@ -9,6 +9,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, true, 1024, 8); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT2048.cu index dd38b8d6a5e6..f91b6787e273 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT2048.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT2048.cu @@ -11,7 +11,9 @@ namespace faiss { namespace gpu { #if GPU_MAX_SELECTION_K >= 2048 +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, true, 2048, 8); #endif +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT512.cu b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT512.cu index ff2a9903faee..a2877db6ed62 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT512.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/blockselect/BlockSelectHalfT512.cu @@ -9,6 +9,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 BLOCK_SELECT_IMPL(half, true, 512, 8); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf1.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf1.cu index 79876207f7f0..da3206d45499 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf1.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf1.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, true, 1, 1); WARP_SELECT_IMPL(half, false, 1, 1); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf128.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf128.cu index 150c9507dae5..8705e593c542 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf128.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf128.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, true, 128, 3); WARP_SELECT_IMPL(half, false, 128, 3); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf256.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf256.cu index cd8b49b18fe6..a7af219582fd 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf256.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf256.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, true, 256, 4); WARP_SELECT_IMPL(half, false, 256, 4); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf32.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf32.cu index ce1b7e4c74b0..d7ed389aec5f 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf32.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf32.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, true, 32, 2); WARP_SELECT_IMPL(half, false, 32, 2); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf64.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf64.cu index 9d4311ec016e..fea6c40b9c86 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf64.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalf64.cu @@ -9,7 +9,9 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, true, 64, 3); WARP_SELECT_IMPL(half, false, 64, 3); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF1024.cu index 02413001419a..d99eea9c7c56 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF1024.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF1024.cu @@ -9,6 +9,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, false, 1024, 8); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF2048.cu index 1a16ee45c9de..030d28e17ff3 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF2048.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF2048.cu @@ -11,7 +11,9 @@ namespace faiss { namespace gpu { #if GPU_MAX_SELECTION_K >= 2048 +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, false, 2048, 8); #endif +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF512.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF512.cu index 4cb138837b65..651d7275801a 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF512.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfF512.cu @@ -9,6 +9,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, false, 512, 8); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT1024.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT1024.cu index 6a95007ff857..5a576d7c486c 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT1024.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT1024.cu @@ -9,6 +9,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, true, 1024, 8); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT2048.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT2048.cu index 94586d0100a7..b5bd1f9e5368 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT2048.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT2048.cu @@ -11,7 +11,9 @@ namespace faiss { namespace gpu { #if GPU_MAX_SELECTION_K >= 2048 +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, true, 2048, 8); #endif +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT512.cu b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT512.cu index 6ca08a16abb7..21b86602736f 100644 --- a/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT512.cu +++ b/core/src/index/thirdparty/faiss/gpu/utils/warpselect/WarpSelectHalfT512.cu @@ -9,6 +9,8 @@ namespace faiss { namespace gpu { +#ifdef FAISS_USE_FLOAT16 WARP_SELECT_IMPL(half, true, 512, 8); +#endif } } // namespace diff --git a/core/src/index/thirdparty/faiss/impl/index_read.cpp b/core/src/index/thirdparty/faiss/impl/index_read.cpp index 85cec7d39f0d..24556606ee95 100644 --- a/core/src/index/thirdparty/faiss/impl/index_read.cpp +++ b/core/src/index/thirdparty/faiss/impl/index_read.cpp @@ -343,6 +343,89 @@ static void read_InvertedLists ( ivf->own_invlists = true; } +InvertedLists *read_InvertedLists_nm (IOReader *f, int io_flags) { + uint32_t h; + READ1 (h); + if (h == fourcc ("il00")) { + fprintf(stderr, "read_InvertedLists:" + " WARN! inverted lists not stored with IVF object\n"); + return nullptr; + } else if (h == fourcc ("iloa") && !(io_flags & IO_FLAG_MMAP)) { + // not going to happen + return nullptr; + } else if (h == fourcc ("ilar") && !(io_flags & IO_FLAG_MMAP)) { + auto ails = new ArrayInvertedLists (0, 0); + READ1 (ails->nlist); + READ1 (ails->code_size); + ails->ids.resize (ails->nlist); + std::vector sizes (ails->nlist); + read_ArrayInvertedLists_sizes (f, sizes); + for (size_t i = 0; i < ails->nlist; i++) { + ails->ids[i].resize (sizes[i]); + } + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + READANDCHECK (ails->ids[i].data(), n); + } + } + return ails; + } else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) { + // then we load it as an OnDiskInvertedLists + FileIOReader *reader = dynamic_cast(f); + FAISS_THROW_IF_NOT_MSG(reader, "mmap only supported for File objects"); + FILE *fdesc = reader->f; + + auto ails = new OnDiskInvertedLists (); + READ1 (ails->nlist); + READ1 (ails->code_size); + ails->read_only = true; + ails->lists.resize (ails->nlist); + std::vector sizes (ails->nlist); + read_ArrayInvertedLists_sizes (f, sizes); + size_t o0 = ftell(fdesc), o = o0; + { // do the mmap + struct stat buf; + int ret = fstat (fileno(fdesc), &buf); + FAISS_THROW_IF_NOT_FMT (ret == 0, + "fstat failed: %s", strerror(errno)); + ails->totsize = buf.st_size; + ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize, + PROT_READ, MAP_SHARED, + fileno(fdesc), 0); + FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED, + "could not mmap: %s", + strerror(errno)); + } + + for (size_t i = 0; i < ails->nlist; i++) { + OnDiskInvertedLists::List & l = ails->lists[i]; + l.size = l.capacity = sizes[i]; + l.offset = o; + o += l.size * (sizeof(OnDiskInvertedLists::idx_t) + + ails->code_size); + } + FAISS_THROW_IF_NOT(o <= ails->totsize); + // resume normal reading of file + fseek (fdesc, o, SEEK_SET); + return ails; + } else if (h == fourcc ("ilod")) { + // not going to happen + return nullptr; + } else { + FAISS_THROW_MSG ("read_InvertedLists: unsupported invlist type"); + } +} + +static void read_InvertedLists_nm ( + IndexIVF *ivf, IOReader *f, int io_flags) { + InvertedLists *ils = read_InvertedLists_nm (f, io_flags); + FAISS_THROW_IF_NOT (!ils || (ils->nlist == ivf->nlist && + ils->code_size == ivf->code_size)); + ivf->invlists = ils; + ivf->own_invlists = true; +} + static void read_ProductQuantizer (ProductQuantizer *pq, IOReader *f) { READ1 (pq->d); READ1 (pq->M); @@ -736,6 +819,52 @@ Index *read_index (const char *fname, int io_flags) { return idx; } +// read offset-only index +Index *read_index_nm (IOReader *f, int io_flags) { + Index * idx = nullptr; + uint32_t h; + READ1 (h); + if (h == fourcc ("IwFl")) { + IndexIVFFlat * ivfl = new IndexIVFFlat (); + read_ivf_header (ivfl, f); + ivfl->code_size = ivfl->d * sizeof(float); + read_InvertedLists_nm (ivfl, f, io_flags); + idx = ivfl; + } else if(h == fourcc ("IwSq")) { + IndexIVFScalarQuantizer * ivsc = new IndexIVFScalarQuantizer(); + read_ivf_header (ivsc, f); + read_ScalarQuantizer (&ivsc->sq, f); + READ1 (ivsc->code_size); + READ1 (ivsc->by_residual); + read_InvertedLists_nm (ivsc, f, io_flags); + idx = ivsc; + } else if (h == fourcc("ISqH")) { + IndexIVFSQHybrid *ivfsqhbyrid = new IndexIVFSQHybrid(); + read_ivf_header(ivfsqhbyrid, f); + read_ScalarQuantizer(&ivfsqhbyrid->sq, f); + READ1 (ivfsqhbyrid->code_size); + READ1 (ivfsqhbyrid->by_residual); + read_InvertedLists_nm(ivfsqhbyrid, f, io_flags); + idx = ivfsqhbyrid; + } else { + FAISS_THROW_FMT("Index type 0x%08x not supported\n", h); + idx = nullptr; + } + return idx; +} + + +Index *read_index_nm (FILE * f, int io_flags) { + FileIOReader reader(f); + return read_index_nm(&reader, io_flags); +} + +Index *read_index_nm (const char *fname, int io_flags) { + FileIOReader reader(fname); + Index *idx = read_index_nm (&reader, io_flags); + return idx; +} + VectorTransform *read_VectorTransform (const char *fname) { FileIOReader reader(fname); VectorTransform *vt = read_VectorTransform (&reader); @@ -917,4 +1046,4 @@ IndexBinary *read_index_binary (const char *fname, int io_flags) { } -} // namespace faiss +} // namespace faiss \ No newline at end of file diff --git a/core/src/index/thirdparty/faiss/impl/index_write.cpp b/core/src/index/thirdparty/faiss/impl/index_write.cpp index 54fce2fc460b..ef7720a2736a 100644 --- a/core/src/index/thirdparty/faiss/impl/index_write.cpp +++ b/core/src/index/thirdparty/faiss/impl/index_write.cpp @@ -286,6 +286,63 @@ void write_InvertedLists (const InvertedLists *ils, IOWriter *f) { } } +// write inverted lists for offset-only index +void write_InvertedLists_nm (const InvertedLists *ils, IOWriter *f) { + if (ils == nullptr) { + uint32_t h = fourcc ("il00"); + WRITE1 (h); + } else if (const auto & ails = + dynamic_cast(ils)) { + uint32_t h = fourcc ("ilar"); + WRITE1 (h); + WRITE1 (ails->nlist); + WRITE1 (ails->code_size); + // here we store either as a full or a sparse data buffer + size_t n_non0 = 0; + for (size_t i = 0; i < ails->nlist; i++) { + if (ails->ids[i].size() > 0) + n_non0++; + } + if (n_non0 > ails->nlist / 2) { + uint32_t list_type = fourcc("full"); + WRITE1 (list_type); + std::vector sizes; + for (size_t i = 0; i < ails->nlist; i++) { + sizes.push_back (ails->ids[i].size()); + } + WRITEVECTOR (sizes); + } else { + int list_type = fourcc("sprs"); // sparse + WRITE1 (list_type); + std::vector sizes; + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + sizes.push_back (i); + sizes.push_back (n); + } + } + WRITEVECTOR (sizes); + } + // make a single contiguous data buffer (useful for mmapping) + for (size_t i = 0; i < ails->nlist; i++) { + size_t n = ails->ids[i].size(); + if (n > 0) { + // WRITEANDCHECK (ails->codes[i].data(), n * ails->code_size); + WRITEANDCHECK (ails->ids[i].data(), n); + } + } + } else if (const auto & oa = + dynamic_cast(ils)) { + // not going to happen + } else { + fprintf(stderr, "WARN! write_InvertedLists: unsupported invlist type, " + "saving null invlist\n"); + uint32_t h = fourcc ("il00"); + WRITE1 (h); + } +} + void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname) { FileIOWriter writer(fname); @@ -518,6 +575,47 @@ void write_index (const Index *idx, const char *fname) { write_index (idx, &writer); } +// write index for offset-only index +void write_index_nm (const Index *idx, IOWriter *f) { + if(const IndexIVFFlat * ivfl = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IwFl"); + WRITE1 (h); + write_ivf_header (ivfl, f); + write_InvertedLists_nm (ivfl->invlists, f); + } else if(const IndexIVFScalarQuantizer * ivsc = + dynamic_cast (idx)) { + uint32_t h = fourcc ("IwSq"); + WRITE1 (h); + write_ivf_header (ivsc, f); + write_ScalarQuantizer (&ivsc->sq, f); + WRITE1 (ivsc->code_size); + WRITE1 (ivsc->by_residual); + write_InvertedLists_nm (ivsc->invlists, f); + } else if(const IndexIVFSQHybrid *ivfsqhbyrid = + dynamic_cast(idx)) { + uint32_t h = fourcc ("ISqH"); + WRITE1 (h); + write_ivf_header (ivfsqhbyrid, f); + write_ScalarQuantizer (&ivfsqhbyrid->sq, f); + WRITE1 (ivfsqhbyrid->code_size); + WRITE1 (ivfsqhbyrid->by_residual); + write_InvertedLists_nm (ivfsqhbyrid->invlists, f); + } else { + FAISS_THROW_MSG ("don't know how to serialize this type of index"); + } +} + +void write_index_nm (const Index *idx, FILE *f) { + FileIOWriter writer(f); + write_index_nm (idx, &writer); +} + +void write_index_nm (const Index *idx, const char *fname) { + FileIOWriter writer(fname); + write_index_nm (idx, &writer); +} + void write_VectorTransform (const VectorTransform *vt, const char *fname) { FileIOWriter writer(fname); write_VectorTransform (vt, &writer); diff --git a/core/src/index/thirdparty/faiss/index_io.h b/core/src/index/thirdparty/faiss/index_io.h index 5aef62c87b4f..ac686da71cb1 100644 --- a/core/src/index/thirdparty/faiss/index_io.h +++ b/core/src/index/thirdparty/faiss/index_io.h @@ -37,6 +37,10 @@ void write_index (const Index *idx, const char *fname); void write_index (const Index *idx, FILE *f); void write_index (const Index *idx, IOWriter *writer); +void write_index_nm (const Index *idx, const char *fname); +void write_index_nm (const Index *idx, FILE *f); +void write_index_nm (const Index *idx, IOWriter *writer); + void write_index_binary (const IndexBinary *idx, const char *fname); void write_index_binary (const IndexBinary *idx, FILE *f); void write_index_binary (const IndexBinary *idx, IOWriter *writer); @@ -52,6 +56,10 @@ Index *read_index (const char *fname, int io_flags = 0); Index *read_index (FILE * f, int io_flags = 0); Index *read_index (IOReader *reader, int io_flags = 0); +Index *read_index_nm (const char *fname, int io_flags = 0); +Index *read_index_nm (FILE * f, int io_flags = 0); +Index *read_index_nm (IOReader *reader, int io_flags = 0); + IndexBinary *read_index_binary (const char *fname, int io_flags = 0); IndexBinary *read_index_binary (FILE * f, int io_flags = 0); IndexBinary *read_index_binary (IOReader *reader, int io_flags = 0); @@ -68,6 +76,9 @@ void write_ProductQuantizer (const ProductQuantizer*pq, IOWriter *f); void write_InvertedLists (const InvertedLists *ils, IOWriter *f); InvertedLists *read_InvertedLists (IOReader *reader, int io_flags = 0); +void write_InvertedLists_nm (const InvertedLists *ils, IOWriter *f); +InvertedLists *read_InvertedLists_nm (IOReader *reader, int io_flags = 0); + } // namespace faiss diff --git a/core/src/index/thirdparty/hnswlib/bruteforce.h b/core/src/index/thirdparty/hnswlib/bruteforce.h index ae2fa6a8f6ef..ebdda36d4345 100644 --- a/core/src/index/thirdparty/hnswlib/bruteforce.h +++ b/core/src/index/thirdparty/hnswlib/bruteforce.h @@ -4,7 +4,7 @@ #include #include -namespace hnswlib { +namespace hnswlib_nm { template class BruteforceSearch : public AlgorithmInterface { diff --git a/core/src/index/thirdparty/hnswlib/hnswalg.h b/core/src/index/thirdparty/hnswlib/hnswalg.h index 8c54c8c39f34..4bd49030f05e 100644 --- a/core/src/index/thirdparty/hnswlib/hnswalg.h +++ b/core/src/index/thirdparty/hnswlib/hnswalg.h @@ -28,9 +28,9 @@ class HierarchicalNSW : public AlgorithmInterface { link_list_locks_(max_elements), element_levels_(max_elements) { // linxj space = s; - if (auto x = dynamic_cast(s)) { + if (auto x = dynamic_cast(s)) { metric_type_ = 0; - } else if (auto x = dynamic_cast(s)) { + } else if (auto x = dynamic_cast(s)) { metric_type_ = 1; } else { metric_type_ = 100; @@ -62,7 +62,7 @@ class HierarchicalNSW : public AlgorithmInterface { cur_element_count = 0; - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_ = new hnswlib_nm::VisitedListPool(1, max_elements); @@ -75,7 +75,6 @@ class HierarchicalNSW : public AlgorithmInterface { throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); mult_ = 1 / log(1.0 * M_); - revSize_ = 1.0 / mult_; } struct CompareByFirst { @@ -113,11 +112,11 @@ class HierarchicalNSW : public AlgorithmInterface { size_t maxM0_; size_t ef_construction_; - double mult_, revSize_; + double mult_; int maxlevel_; - VisitedListPool *visited_list_pool_; + hnswlib_nm::VisitedListPool *visited_list_pool_; std::mutex cur_element_count_guard_; std::vector link_list_locks_; @@ -170,9 +169,9 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, const void *data_point, int layer) { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; + hnswlib_nm::VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + hnswlib_nm::vl_type *visited_array = vl->mass; + hnswlib_nm::vl_type visited_array_tag = vl->curV; std::priority_queue, std::vector>, CompareByFirst> top_candidates; std::priority_queue, std::vector>, CompareByFirst> candidateSet; @@ -253,9 +252,9 @@ class HierarchicalNSW : public AlgorithmInterface { template std::priority_queue, std::vector>, CompareByFirst> searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, faiss::ConcurrentBitsetPtr bitset) const { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; + hnswlib_nm::VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + hnswlib_nm::vl_type *visited_array = vl->mass; + hnswlib_nm::vl_type visited_array_tag = vl->curV; std::priority_queue, std::vector>, CompareByFirst> top_candidates; std::priority_queue, std::vector>, CompareByFirst> candidate_set; @@ -556,7 +555,7 @@ class HierarchicalNSW : public AlgorithmInterface { delete visited_list_pool_; - visited_list_pool_ = new VisitedListPool(1, new_max_elements); + visited_list_pool_ = new hnswlib_nm::VisitedListPool(1, new_max_elements); @@ -624,9 +623,9 @@ class HierarchicalNSW : public AlgorithmInterface { readBinaryPOD(input, data_size_); readBinaryPOD(input, dim); if (metric_type_ == 0) { - space = new hnswlib::L2Space(dim); + space = new hnswlib_nm::L2Space(dim); } else if (metric_type_ == 1) { - space = new hnswlib::InnerProductSpace(dim); + space = new hnswlib_nm::InnerProductSpace(dim); } else { // throw exception } @@ -702,14 +701,13 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector(max_elements).swap(link_list_locks_); - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_ = new hnswlib_nm::VisitedListPool(1, max_elements); linkLists_ = (char **) malloc(sizeof(void *) * max_elements); if (linkLists_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); element_levels_ = std::vector(max_elements); - revSize_ = 1.0 / mult_; ef_ = 10; for (size_t i = 0; i < cur_element_count; i++) { label_lookup_[getExternalLabel(i)]=i; @@ -840,13 +838,12 @@ class HierarchicalNSW : public AlgorithmInterface { size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); std::vector(max_elements).swap(link_list_locks_); - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_ = new hnswlib_nm::VisitedListPool(1, max_elements); linkLists_ = (char **) malloc(sizeof(void *) * max_elements); if (linkLists_ == nullptr) throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); element_levels_ = std::vector(max_elements); - revSize_ = 1.0 / mult_; ef_ = 10; for (size_t i = 0; i < cur_element_count; i++) { label_lookup_[getExternalLabel(i)]=i; @@ -1130,6 +1127,30 @@ class HierarchicalNSW : public AlgorithmInterface { return result; } + + void addPoint(void *datapoint, labeltype label, size_t base, size_t offset) { + return; + } + + std::priority_queue> searchKnn_NM(const void* query_data, size_t k, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) const { + std::priority_queue> ret; + return ret; + } + + int64_t cal_size() { + int64_t ret = 0; + ret += sizeof(*this); + ret += sizeof(*space); + ret += visited_list_pool_->GetSize(); + ret += link_list_locks_.size() * sizeof(std::mutex); + ret += element_levels_.size() * sizeof(int); + ret += max_elements_ * size_data_per_element_; + ret += max_elements_ * sizeof(void*); + for (auto i = 0; i < max_elements_; ++ i) { + ret += linkLists_[i] ? size_links_per_element_ * element_levels_[i] : 0; + } + return ret; + } }; } diff --git a/core/src/index/thirdparty/hnswlib/hnswalg_nm.h b/core/src/index/thirdparty/hnswlib/hnswalg_nm.h new file mode 100644 index 000000000000..ffdd1985a2cd --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/hnswalg_nm.h @@ -0,0 +1,1227 @@ +#pragma once + +#include "visited_list_pool.h" +#include "hnswlib_nm.h" +#include +#include +#include +#include + +#include "knowhere/index/vector_index/helpers/FaissIO.h" +#include "faiss/impl/ScalarQuantizer.h" +#include "faiss/impl/ScalarQuantizerCodec.h" + +namespace hnswlib_nm { + + typedef unsigned int tableint; + typedef unsigned int linklistsizeint; + + using QuantizerClass = faiss::QuantizerTemplate; + using DCClassIP = faiss::DCTemplate, 1>; + using DCClassL2 = faiss::DCTemplate, 1>; + + template + class HierarchicalNSW_NM : public AlgorithmInterface { + public: + HierarchicalNSW_NM(SpaceInterface *s) { + } + + HierarchicalNSW_NM(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { + loadIndex(location, s, max_elements); + } + + HierarchicalNSW_NM(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : + link_list_locks_(max_elements), element_levels_(max_elements) { + // linxj + space = s; + if (auto x = dynamic_cast(s)) { + metric_type_ = 0; + } else if (auto x = dynamic_cast(s)) { + metric_type_ = 1; + } else { + metric_type_ = 100; + } + + max_elements_ = max_elements; + + has_deletions_=false; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + M_ = M; + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction,M_); + ef_ = 10; + + is_sq8_ = false; + sq_ = nullptr; + + level_generator_.seed(random_seed); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_; // + sizeof(labeltype); + data_size_;; +// label_offset_ = size_links_level0_; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = new hnswlib_nm::VisitedListPool(1, max_elements); + + + + //initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW_NM failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } + + struct CompareByFirst { + constexpr bool operator()(std::pair const &a, + std::pair const &b) const noexcept { + return a.first < b.first; + } + }; + + ~HierarchicalNSW_NM() { + + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); + } + free(linkLists_); + delete visited_list_pool_; + + if (sq_) delete sq_; + + // linxj: delete + delete space; + } + + // linxj: use for free resource + SpaceInterface *space; + size_t metric_type_; // 0:l2, 1:ip + + size_t max_elements_; + size_t cur_element_count; + size_t size_data_per_element_; + size_t size_links_per_element_; + + size_t M_; + size_t maxM_; + size_t maxM0_; + size_t ef_construction_; + + bool is_sq8_ = false; + faiss::ScalarQuantizer *sq_ = nullptr; + + double mult_, revSize_; + int maxlevel_; + + + VisitedListPool *visited_list_pool_; + std::mutex cur_element_count_guard_; + + std::vector link_list_locks_; + tableint enterpoint_node_; + + + size_t size_links_level0_; + + + char *data_level0_memory_; + char **linkLists_; + std::vector element_levels_; + + size_t data_size_; + + bool has_deletions_; + + + DISTFUNC fstdistfunc_; + void *dist_func_param_; + + std::default_random_engine level_generator_; + + inline char *getDataByInternalId(void *pdata, tableint offset) const { + return ((char*)pdata + offset * data_size_); + } + + void SetSq8(const float *trained) { + if (!trained) + throw std::runtime_error("trained sq8 data cannot be null in SetSq8!"); + if (sq_) delete sq_; + is_sq8_ = true; + sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code + sq_->trained.resize((sq_->d) << 1); + memcpy(sq_->trained.data(), trained, sq_->trained.size() * sizeof(float)); + } + + void sq_train(size_t nb, const float *xb, uint8_t *p_codes) { + if (!p_codes) + throw std::runtime_error("p_codes cannot be null in sq_train!"); + if (!xb) + throw std::runtime_error("base vector cannot be null in sq_train!"); + if (sq_) delete sq_; + is_sq8_ = true; + sq_ = new faiss::ScalarQuantizer(*(size_t*)dist_func_param_, faiss::QuantizerType::QT_8bit); // hard code + sq_->train(nb, xb); + sq_->compute_codes(xb, p_codes, nb); + memcpy(p_codes + *(size_t*)dist_func_param_ * nb, sq_->trained.data(), *(size_t*)dist_func_param_ * sizeof(float) * 2); + } + + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } + + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer, void *pdata) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; + + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(pdata, ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; + + while (!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound) { + break; + } + candidateSet.pop(); + + tableint curNodeNum = curr_el_pair.second; + + std::unique_lock lock(link_list_locks_[curNodeNum]); + + int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); + // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *(datal + 1)), _MM_HINT_T0); +#endif + + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); + // if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *(datal + j + 1)), _MM_HINT_T0); +#endif + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(pdata, candidate_id)); + + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(pdata, candidateSet.top().second), _MM_HINT_T0); +#endif + + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); + + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + visited_list_pool_->releaseVisitedList(vl); + + return top_candidates; + } + + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, faiss::ConcurrentBitsetPtr bitset, void *pdata) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + faiss::SQDistanceComputer *sqdc = nullptr; + if (is_sq8_) { + if (metric_type_ == 0) { // L2 + sqdc = new DCClassL2(sq_->d, sq_->trained); + } else if (metric_type_ == 1) { // IP + sqdc = new DCClassIP(sq_->d, sq_->trained); + } else { + throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); + } + sqdc->code_size = sq_->code_size; + sqdc->codes = (uint8_t*)pdata; + sqdc->set_query((const float*)data_point); + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; +// if (!has_deletions || !isMarkedDeleted(ep_id)) { + if (!has_deletions || !bitset->test((faiss::ConcurrentBitset::id_type_t)(ep_id))) { + dist_t dist; + if (is_sq8_) { + dist = (*sqdc)(ep_id); + } else { + dist = fstdistfunc_(data_point, getDataByInternalId(pdata, ep_id), dist_func_param_); + } + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); + } + + visited_array[ep_id] = visited_array_tag; + + while (!candidate_set.empty()) { + + std::pair current_node_pair = candidate_set.top(); + + if ((-current_node_pair.first) > lowerBound) { + break; + } + candidate_set.pop(); + + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); + // bool cur_node_deleted = isMarkedDeleted(current_node_id); + +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); +// _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); +#endif + + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); + // if (candidate_id == 0) continue; +#ifdef USE_SSE + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(pdata, *(data + j + 1)), + _MM_HINT_T0);//////////// +#endif + if (!(visited_array[candidate_id] == visited_array_tag)) { + + visited_array[candidate_id] = visited_array_tag; + + dist_t dist; + if (is_sq8_) { + dist = (*sqdc)(candidate_id); + } else { + char *currObj1 = (getDataByInternalId(pdata, candidate_id)); + dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + } + + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); +#ifdef USE_SSE + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_,/////////// + _MM_HINT_T0);//////////////////////// +#endif + +// if (!has_deletions || !isMarkedDeleted(candidate_id)) + if (!has_deletions || (!bitset->test((faiss::ConcurrentBitset::id_type_t)(candidate_id)))) + top_candidates.emplace(dist, candidate_id); + + if (top_candidates.size() > ef) + top_candidates.pop(); + + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; + } + } + } + } + + visited_list_pool_->releaseVisitedList(vl); + if (is_sq8_) delete sqdc; + return top_candidates; + } + + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M, tableint *ret, size_t &ret_len, void *pdata) { + if (top_candidates.size() < M) { + while (top_candidates.size() > 0) { + ret[ret_len ++] = top_candidates.top().second; + top_candidates.pop(); + } + return; + } + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(pdata, second_pair.second), + getDataByInternalId(pdata, curent_pair.second), + dist_func_param_);; + if (curdist < dist_to_query) { + good = false; + break; + } + } + if (good) { + return_list.push_back(curent_pair); + ret[ret_len ++] = curent_pair.second; + } + + + } + +// for (std::pair curent_pair : return_list) { +// +// top_candidates.emplace(-curent_pair.first, curent_pair.second); +// } + } + + + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_); + }; + + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_); + }; + + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + }; + + void mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, void *pdata) { + + size_t Mcurmax = level ? maxM_ : maxM0_; +// std::vector selectedNeighbors; +// selectedNeighbors.reserve(M_); + tableint *selectedNeighbors = (tableint*)malloc(sizeof(tableint) * M_); + size_t selectedNeighbors_size = 0; + getNeighborsByHeuristic2(top_candidates, M_, selectedNeighbors, selectedNeighbors_size, pdata); + if (selectedNeighbors_size > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); + +// while (top_candidates.size() > 0) { +// selectedNeighbors.push_back(top_candidates.top().second); +// top_candidates.pop(); +// } + + { + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); + + if (*ll_cur) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur,(unsigned short)selectedNeighbors_size); + tableint *data = (tableint *) (ll_cur + 1); + + + for (size_t idx = 0; idx < selectedNeighbors_size; idx++) { + if (data[idx]) + throw std::runtime_error("Possible memory corruption"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + data[idx] = selectedNeighbors[idx]; + + } + } + for (size_t idx = 0; idx < selectedNeighbors_size; idx++) { + + std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); + + + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); + + size_t sz_link_list_other = getListCount(ll_other); + + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); + + tableint *data = (tableint *) (ll_other + 1); + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(pdata, cur_c), getDataByInternalId(pdata, selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(pdata, data[j]), getDataByInternalId(pdata, selectedNeighbors[idx]), + dist_func_param_), data[j]); + } + + size_t indx = 0; + getNeighborsByHeuristic2(candidates, Mcurmax, data, indx, pdata); + +// while (candidates.size() > 0) { +// data[indx] = candidates.top().second; +// candidates.pop(); +// indx++; +// } + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } + } + if (indx >= 0) { + data[indx] = cur_c; + } */ + } + + } + } + + std::mutex global; + size_t ef_; + + void setEf(size_t ef) { + ef_ = ef; + } + + + std::priority_queue> searchKnnInternal(void *query_data, int k, dist_t *pdata) { + std::priority_queue> top_candidates; + if (cur_element_count == 0) return top_candidates; + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(pdata, enterpoint_node_), dist_func_param_); + + for (size_t level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + int *data; + data = (int *) get_linklist(currObj,level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(pdata, cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + if (has_deletions_) { + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_, pdata); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, + ef_, pdata); + top_candidates.swap(top_candidates1); + } + + while (top_candidates.size() > k) { + top_candidates.pop(); + } + return top_candidates; + }; + + void resizeIndex(size_t new_max_elements){ + if (new_max_elements(new_max_elements).swap(link_list_locks_); + + // Reallocate base layer + char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); + free(data_level0_memory_); + data_level0_memory_=data_level0_memory_new; + + // Reallocate all other layers + char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); + free(linkLists_); + linkLists_=linkLists_new; + + max_elements_=new_max_elements; + + } + + void saveIndex(milvus::knowhere::MemoryIOWriter& output) { + // write l2/ip calculator + writeBinaryPOD(output, metric_type_); + writeBinaryPOD(output, data_size_); + writeBinaryPOD(output, *((size_t *) dist_func_param_)); + +// writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); +// writeBinaryPOD(output, label_offset_); +// writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + // output.close(); + } + + void loadIndex(milvus::knowhere::MemoryIOReader& input, size_t max_elements_i = 0) { + // linxj: init with metrictype + size_t dim = 100; + readBinaryPOD(input, metric_type_); + readBinaryPOD(input, data_size_); + readBinaryPOD(input, dim); + if (metric_type_ == 0) { + space = new L2Space(dim); + } else if (metric_type_ == 1) { + space = new InnerProductSpace(dim); + } else { + // throw exception + } + fstdistfunc_ = space->get_dist_func(); + dist_func_param_ = space->get_dist_func_param(); + +// readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); +// readBinaryPOD(input, label_offset_); +// readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + + // data_size_ = s->get_data_size(); + // fstdistfunc_ = s->get_dist_func(); + // dist_func_param_ = s->get_dist_func_param(); + + // auto pos= input.rp; + + + // /// Optional - check if index is ok: + // + // input.seekg(cur_element_count * size_data_per_element_,input.cur); + // for (size_t i = 0; i < cur_element_count; i++) { + // if(input.tellg() < 0 || input.tellg()>=total_filesize){ + // throw std::runtime_error("Index seems to be corrupted or unsupported"); + // } + // + // unsigned int linkListSize; + // readBinaryPOD(input, linkListSize); + // if (linkListSize != 0) { + // input.seekg(linkListSize,input.cur); + // } + // } + // + // // throw exception if it either corrupted or old index + // if(input.tellg()!=total_filesize) + // throw std::runtime_error("Index seems to be corrupted or unsupported"); + // + // input.clear(); + // + // /// Optional check end + // + // input.seekg(pos,input.beg); + + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + + + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { +// label_lookup_[getExternalLabel(i)]=i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + return; + } + + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; + +// writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); +// writeBinaryPOD(output, label_offset_); +// writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); + + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); + + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + // get file size: + input.seekg(0,input.end); + std::streampos total_filesize=input.tellg(); + input.seekg(0,input.beg); + +// readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements=max_elements_i; + if(max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); +// readBinaryPOD(input, label_offset_); +// readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos=input.tellg(); + + /// Optional - check if index is ok: + + input.seekg(cur_element_count * size_data_per_element_,input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if(input.tellg() < 0 || input.tellg()>=total_filesize){ + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } + + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize,input.cur); + } + } + + // throw exception if it either corrupted or old index + if(input.tellg()!=total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + + /// Optional check end + + input.seekg(pos,input.beg); + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { +// label_lookup_[getExternalLabel(i)]=i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + if(isMarkedDeleted(i)) + has_deletions_=true; + } + + input.close(); + return; + } + + /* + template + std::vector getDataByLabel(tableint internal_id, dist_t *pdata) { + // tableint label_c; + // auto search = label_lookup_.find(label); + // if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + // throw std::runtime_error("Label not found"); + // } + // label_c = search->second; + + char* data_ptrv = getDataByInternalId(pdata, internal_id); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } + */ + + static const unsigned char DELETE_MARK = 0x01; + // static const unsigned char REUSE_MARK = 0x10; + /** + * Marks an element with the given label deleted, does NOT really change the current graph. + * @param label + */ + void markDelete(labeltype label) + { + has_deletions_=true; +// auto search = label_lookup_.find(label); +// if (search == label_lookup_.end()) { +// throw std::runtime_error("Label not found"); +// } +// markDeletedInternal(search->second); + markDeletedInternal(label); + } + + /** + * Uses the first 8 bits of the memory for the linked list to store the mark, + * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. + * @param internalId + */ + void markDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + } + + /** + * Remove the deleted mark of the node. + * @param internalId + */ + void unmarkDeletedInternal(tableint internalId) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur &= ~DELETE_MARK; + } + + /** + * Checks the first 8 bits of the memory to see if the element is marked deleted. + * @param internalId + * @return + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; + return *ll_cur & DELETE_MARK; + } + + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } + + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } + + size_t getCurrentElementCount() { + return cur_element_count; + } + + void addPoint(void *data_point, labeltype label, size_t base, size_t offset) { + addPoint(data_point, label,-1, base, offset); + } + + tableint addPoint(void *data_point, labeltype label, int level, size_t base, size_t offset) { + tableint cur_c = 0; + { + std::unique_lock lock(cur_element_count_guard_); + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + }; + +// cur_c = cur_element_count; + cur_c = tableint(base + offset); + cur_element_count++; + +// auto search = label_lookup_.find(label); +// if (search != label_lookup_.end()) { +// std::unique_lock lock_el(link_list_locks_[search->second]); +// has_deletions_ = true; +// markDeletedInternal(search->second); +// } +// label_lookup_[label] = cur_c; + } + + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; + + element_levels_[cur_c] = curlevel; + + // prepose non-concurrent operation + memset(data_level0_memory_ + cur_c * size_data_per_element_, 0, size_data_per_element_); +// setExternalLabel(cur_c, label); +// memcpy(getDataByInternalId(cur_c), data_point, data_size_); + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } + + + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; + + if ((signed)currObj != -1) { + + if (curlevel < maxlevelcopy) { + + dist_t curdist = fstdistfunc_(getDataByInternalId(data_point, (tableint)offset), getDataByInternalId(data_point, currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist(currObj,level); + int size = getListCount(data); + + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(getDataByInternalId(data_point, tableint(offset)), getDataByInternalId(data_point, cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, getDataByInternalId(data_point, (tableint)offset), level, data_point); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(getDataByInternalId(data_point, (tableint)offset), getDataByInternalId(data_point, enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); + } + currObj = top_candidates.top().second; + + mutuallyConnectNewElement(getDataByInternalId(data_point, (tableint)offset), cur_c, top_candidates, level, data_point); + } + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + } + + //Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + }; + + std::priority_queue> + searchKnn_NM(const void *query_data, size_t k, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; + + tableint currObj = enterpoint_node_; + dist_t curdist; + faiss::SQDistanceComputer *sqdc = nullptr; + if (is_sq8_) { + if (metric_type_ == 0) { // L2 + sqdc = new DCClassL2(sq_->d, sq_->trained); + } else if (metric_type_ == 1) { // IP + sqdc = new DCClassIP(sq_->d, sq_->trained); + } else { + throw std::runtime_error("unsupported metric_type, it must be 0(L2) or 1(IP)!"); + } + sqdc->code_size = sq_->code_size; + sqdc->set_query((const float*)query_data); + sqdc->codes = (uint8_t*)pdata; + curdist = (*sqdc)(currObj); + } else { + curdist = fstdistfunc_(query_data, getDataByInternalId(pdata, enterpoint_node_), dist_func_param_); + } + + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d; + if (is_sq8_) { + d = (*sqdc)(cand); + } else { + d = fstdistfunc_(query_data, getDataByInternalId(pdata, cand), dist_func_param_); + } + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + if (bitset != nullptr) { + std::priority_queue, std::vector>, CompareByFirst> + top_candidates1 = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, pdata); + top_candidates.swap(top_candidates1); + } + else{ + std::priority_queue, std::vector>, CompareByFirst> + top_candidates1 = searchBaseLayerST(currObj, query_data, std::max(ef_, k), bitset, pdata); + top_candidates.swap(top_candidates1); + } + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); +// result.push(std::pair(rez.first, getExternalLabel(rez.second))); + result.push(std::pair(rez.first, rez.second)); + top_candidates.pop(); + } + if (is_sq8_) delete sqdc; + return result; + }; + + template + std::vector> + searchKnn_NM(const void* query_data, size_t k, Comp comp, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) { + std::vector> result; + if (cur_element_count == 0) return result; + + auto ret = searchKnn_NM(query_data, k, bitset, pdata); + + while (!ret.empty()) { + result.push_back(ret.top()); + ret.pop(); + } + + std::sort(result.begin(), result.end(), comp); + + return result; + } + + int64_t cal_size() { + int64_t ret = 0; + ret += sizeof(*this); + ret += sizeof(*space); + ret += visited_list_pool_->GetSize(); + ret += link_list_locks_.size() * sizeof(std::mutex); + ret += element_levels_.size() * sizeof(int); + ret += max_elements_ * size_data_per_element_; + ret += max_elements_ * sizeof(void*); + for (auto i = 0; i < max_elements_; ++ i) { + ret += linkLists_[i] ? size_links_per_element_ * element_levels_[i] : 0; + } + return ret; + } + }; + +} \ No newline at end of file diff --git a/core/src/index/thirdparty/hnswlib/hnswlib.h b/core/src/index/thirdparty/hnswlib/hnswlib.h index 89d8d423a7ae..8c94295d3b7f 100644 --- a/core/src/index/thirdparty/hnswlib/hnswlib.h +++ b/core/src/index/thirdparty/hnswlib/hnswlib.h @@ -82,10 +82,15 @@ namespace hnswlib { class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label)=0; +// virtual void addPoint(void *datapoint, labeltype label, size_t base, size_t offset)=0; virtual std::priority_queue> searchKnn(const void *, size_t, faiss::ConcurrentBitsetPtr bitset) const = 0; template std::vector> searchKnn(const void*, size_t, Comp, faiss::ConcurrentBitsetPtr bitset) { } +// virtual std::priority_queue> searchKnn_NM(const void *, size_t, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) const = 0; +// template +// std::vector> searchKnn_NM(const void*, size_t, Comp, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) { +// } virtual void saveIndex(const std::string &location)=0; virtual ~AlgorithmInterface(){ } diff --git a/core/src/index/thirdparty/hnswlib/hnswlib_nm.h b/core/src/index/thirdparty/hnswlib/hnswlib_nm.h new file mode 100644 index 000000000000..31142568ee25 --- /dev/null +++ b/core/src/index/thirdparty/hnswlib/hnswlib_nm.h @@ -0,0 +1,98 @@ +#pragma once +#ifndef NO_MANUAL_VECTORIZATION +#ifdef __SSE__ +#define USE_SSE +#ifdef __AVX__ +#define USE_AVX +#endif +#endif +#endif + +#if defined(USE_AVX) || defined(USE_SSE) +#ifdef _MSC_VER +#include +#include +#else +#include +#endif + +#if defined(__GNUC__) +#define PORTABLE_ALIGN32 __attribute__((aligned(32))) +#else +#define PORTABLE_ALIGN32 __declspec(align(32)) +#endif +#endif + +#include +#include +#include + +#include +#include + +namespace hnswlib_nm { + typedef int64_t labeltype; + + template + class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; + } + }; + + template + static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + static void writeBinaryPOD(W &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); + } + + template + static void readBinaryPOD(R &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); + } + + template + using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); + + + template + class SpaceInterface { + public: + //virtual void search(void *); + virtual size_t get_data_size() = 0; + + virtual DISTFUNC get_dist_func() = 0; + + virtual void *get_dist_func_param() = 0; + + virtual ~SpaceInterface() {} + }; + + template + class AlgorithmInterface { + public: + virtual void addPoint(void *datapoint, labeltype label, size_t base, size_t offset)=0; + virtual std::priority_queue> searchKnn_NM(const void *, size_t, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) const = 0; + template + std::vector> searchKnn_NM(const void*, size_t, Comp, faiss::ConcurrentBitsetPtr bitset, dist_t *pdata) { + } + virtual void saveIndex(const std::string &location)=0; + virtual ~AlgorithmInterface(){ + } + }; +} + +#include "space_l2.h" +#include "space_ip.h" +#include "bruteforce.h" +#include "hnswalg_nm.h" \ No newline at end of file diff --git a/core/src/index/thirdparty/hnswlib/space_ip.h b/core/src/index/thirdparty/hnswlib/space_ip.h index fc25485e7db9..87ca28ec7b7e 100644 --- a/core/src/index/thirdparty/hnswlib/space_ip.h +++ b/core/src/index/thirdparty/hnswlib/space_ip.h @@ -1,8 +1,8 @@ #pragma once -#include "hnswlib.h" +#include "hnswlib_nm.h" #include -namespace hnswlib { +namespace hnswlib_nm { static float InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { diff --git a/core/src/index/thirdparty/hnswlib/space_l2.h b/core/src/index/thirdparty/hnswlib/space_l2.h index 3fd0d2da2c4b..e9fdcb9af986 100644 --- a/core/src/index/thirdparty/hnswlib/space_l2.h +++ b/core/src/index/thirdparty/hnswlib/space_l2.h @@ -1,8 +1,8 @@ #pragma once -#include "hnswlib.h" +#include "hnswlib_nm.h" #include -namespace hnswlib { +namespace hnswlib_nm { static float L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) { diff --git a/core/src/index/thirdparty/hnswlib/visited_list_pool.h b/core/src/index/thirdparty/hnswlib/visited_list_pool.h index 457f73433df6..65e411852e7e 100644 --- a/core/src/index/thirdparty/hnswlib/visited_list_pool.h +++ b/core/src/index/thirdparty/hnswlib/visited_list_pool.h @@ -2,8 +2,9 @@ #include #include +#include -namespace hnswlib { +namespace hnswlib_nm { typedef unsigned short int vl_type; class VisitedList { @@ -26,6 +27,7 @@ class VisitedList { } }; + ~VisitedList() { delete[] mass; } }; @@ -74,6 +76,12 @@ class VisitedListPool { delete rez; } }; + + int64_t GetSize() { + auto visit_list_size = sizeof(VisitedList) + numelements * sizeof(vl_type); + auto pool_size = pool.size() * (sizeof(VisitedList *) + visit_list_size); + return pool_size + sizeof(*this); + } }; } diff --git a/core/src/index/unittest/CMakeLists.txt b/core/src/index/unittest/CMakeLists.txt index f29ca1af063b..72e2f3e67a21 100644 --- a/core/src/index/unittest/CMakeLists.txt +++ b/core/src/index/unittest/CMakeLists.txt @@ -31,7 +31,7 @@ set(util_srcs ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/adapter/VectorAdapter.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/FaissIO.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/IndexParameter.cpp - ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexType.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/IndexType.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Exception.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Log.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/common/Timer.cpp @@ -62,6 +62,9 @@ set(faiss_srcs ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVF.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFSQ.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexIVFPQ.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/OffsetBaseIndex.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/IndexIVF_NM.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/IndexIVFSQNR_NM.cpp ) if (MILVUS_GPU_VERSION) set(faiss_srcs ${faiss_srcs} @@ -71,6 +74,8 @@ set(faiss_srcs ${faiss_srcs} ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFSQ.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVFPQ.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexIVFSQHybrid.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.cpp ) endif () @@ -120,6 +125,38 @@ endif () target_link_libraries(test_ivf ${depend_libs} ${unittest_libs} ${basic_libs}) install(TARGETS test_ivf DESTINATION unittest) +################################################################################ +# +if (NOT TARGET test_ivf_cpu_nm) + add_executable(test_ivf_cpu_nm test_ivf_cpu_nm.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ivf_cpu_nm ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ivf_cpu_nm DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_ivfsq_cpu_nm) + add_executable(test_ivfsq_cpu_nm test_ivfsq_cpu_nm.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ivfsq_cpu_nm ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ivfsq_cpu_nm DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_ivf_gpu_nm) + add_executable(test_ivf_gpu_nm test_ivf_gpu_nm.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ivf_gpu_nm ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ivf_gpu_nm DESTINATION unittest) + +################################################################################ +# +if (NOT TARGET test_ivfsq_gpu_nm) + add_executable(test_ivfsq_gpu_nm test_ivfsq_gpu_nm.cpp ${faiss_srcs} ${util_srcs}) +endif () +target_link_libraries(test_ivfsq_gpu_nm ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_ivfsq_gpu_nm DESTINATION unittest) + ################################################################################ # if (NOT TARGET test_binaryidmap) @@ -152,7 +189,7 @@ endif () include_directories(${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/impl/nsg) aux_source_directory(${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/impl/nsg nsg_src) set(interface_src - ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexNSG.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/IndexNSG_NM.cpp ) if (NOT TARGET test_nsg) add_executable(test_nsg test_nsg.cpp ${interface_src} ${nsg_src} ${util_srcs} ${faiss_srcs}) @@ -163,7 +200,7 @@ install(TARGETS test_nsg DESTINATION unittest) ################################################################################ # set(hnsw_srcs - ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/IndexHNSW.cpp + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/IndexHNSW_NM.cpp ) if (NOT TARGET test_hnsw) add_executable(test_hnsw test_hnsw.cpp ${hnsw_srcs} ${util_srcs}) @@ -171,6 +208,17 @@ endif () target_link_libraries(test_hnsw ${depend_libs} ${unittest_libs} ${basic_libs}) install(TARGETS test_hnsw DESTINATION unittest) +################################################################################ +# +set(hnsw_sq8nm_srcs + ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_offset_index/IndexHNSW_SQ8NM.cpp + ) +if (NOT TARGET test_hnsw_sq8nm) + add_executable(test_hnsw_sq8nm test_hnsw_sq8nm.cpp ${hnsw_sq8nm_srcs} ${util_srcs}) +endif () +target_link_libraries(test_hnsw_sq8nm ${depend_libs} ${unittest_libs} ${basic_libs}) +install(TARGETS test_hnsw_sq8nm DESTINATION unittest) + ################################################################################ # if (MILVUS_SUPPORT_SPTAG) diff --git a/core/src/index/unittest/Helper.h b/core/src/index/unittest/Helper.h index 7ab2810875fe..2cf57fe9828d 100644 --- a/core/src/index/unittest/Helper.h +++ b/core/src/index/unittest/Helper.h @@ -12,17 +12,21 @@ #include #include +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexIVFSQ.h" -#include "knowhere/index/vector_index/IndexType.h" #include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" #ifdef MILVUS_GPU_VERSION #include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" #include "knowhere/index/vector_index/gpu/IndexIVFSQHybrid.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" #endif int DEVICEID = 0; @@ -56,6 +60,8 @@ IndexFactory(const milvus::knowhere::IndexType& type, const milvus::knowhere::In return std::make_shared(DEVICEID); } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8) { return std::make_shared(DEVICEID); + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR) { + return std::make_shared(DEVICEID); } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) { return std::make_shared(DEVICEID); } else { @@ -66,6 +72,20 @@ IndexFactory(const milvus::knowhere::IndexType& type, const milvus::knowhere::In return nullptr; } +milvus::knowhere::IVFNMPtr +IndexFactoryNM(const milvus::knowhere::IndexType& type, const milvus::knowhere::IndexMode mode) { + if (mode == milvus::knowhere::IndexMode::MODE_CPU) { + if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { + return std::make_shared(); + } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR) { + return std::make_shared(); + } else { + std::cout << "Invalid IndexType " << type << std::endl; + } + } + return nullptr; +} + class ParamGenerator { public: static ParamGenerator& @@ -97,6 +117,7 @@ class ParamGenerator { {milvus::knowhere::meta::DEVICEID, DEVICEID}, }; } else if (type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 || + type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR || type == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) { return milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, DIM}, diff --git a/core/src/index/unittest/test_customized_index.cpp b/core/src/index/unittest/test_customized_index.cpp index 03c55d4509ce..0c877f064402 100644 --- a/core/src/index/unittest/test_customized_index.cpp +++ b/core/src/index/unittest/test_customized_index.cpp @@ -15,7 +15,7 @@ #include #include "knowhere/common/Timer.h" -#include "knowhere/index/vector_index/IndexType.h" +#include "knowhere/index/IndexType.h" #include "unittest/Helper.h" #include "unittest/utils.h" diff --git a/core/src/index/unittest/test_gpuresource.cpp b/core/src/index/unittest/test_gpuresource.cpp index e70404e17654..2c7a30a5653e 100644 --- a/core/src/index/unittest/test_gpuresource.cpp +++ b/core/src/index/unittest/test_gpuresource.cpp @@ -20,10 +20,10 @@ #include "knowhere/common/Exception.h" #include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexIVFSQ.h" -#include "knowhere/index/vector_index/IndexType.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVFSQ.h" diff --git a/core/src/index/unittest/test_hnsw.cpp b/core/src/index/unittest/test_hnsw.cpp index 48f458fca9b4..f780d9f13b73 100644 --- a/core/src/index/unittest/test_hnsw.cpp +++ b/core/src/index/unittest/test_hnsw.cpp @@ -10,7 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include -#include +#include #include #include #include @@ -28,7 +28,7 @@ class HNSWTest : public DataGen, public TestWithParam { IndexType = GetParam(); std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 - index_ = std::make_shared(); + index_ = std::make_shared(); conf = milvus::knowhere::Config{ {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, @@ -38,7 +38,7 @@ class HNSWTest : public DataGen, public TestWithParam { protected: milvus::knowhere::Config conf; - std::shared_ptr index_ = nullptr; + std::shared_ptr index_ = nullptr; std::string IndexType; }; @@ -48,6 +48,7 @@ TEST_P(HNSWTest, HNSW_basic) { assert(!xb.empty()); // null faiss index + /* { ASSERT_ANY_THROW(index_->Serialize()); ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); @@ -56,12 +57,26 @@ TEST_P(HNSWTest, HNSW_basic) { ASSERT_ANY_THROW(index_->Count()); ASSERT_ANY_THROW(index_->Dim()); } + */ index_->Train(base_dataset, conf); index_->Add(base_dataset, conf); EXPECT_EQ(index_->Count(), nb); EXPECT_EQ(index_->Dim(), dim); + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + auto result = index_->Query(query_dataset, conf); AssertAnns(result, nq, k); } @@ -78,6 +93,20 @@ TEST_P(HNSWTest, HNSW_delete) { for (auto i = 0; i < nq; ++i) { bitset->set(i); } + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + auto result1 = index_->Query(query_dataset, conf); AssertAnns(result1, nq, k); @@ -107,6 +136,7 @@ TEST_P(HNSWTest, HNSW_delete) { */ } +/* TEST_P(HNSWTest, HNSW_serialize) { auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { { @@ -138,7 +168,7 @@ TEST_P(HNSWTest, HNSW_serialize) { auto result = index_->Query(query_dataset, conf); AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); } -} +}*/ /* * faiss style test @@ -181,7 +211,7 @@ main() { int k = 4; int m = 16; int ef = 200; - milvus::knowhere::IndexHNSW index; + milvus::knowhere::IndexHNSW_NM index; milvus::knowhere::DatasetPtr base_dataset = generate_dataset(nb, d, (const void*)xb, ids); // base_dataset->Set(milvus::knowhere::meta::ROWS, nb); // base_dataset->Set(milvus::knowhere::meta::DIM, d); diff --git a/core/src/index/unittest/test_hnsw_sq8nm.cpp b/core/src/index/unittest/test_hnsw_sq8nm.cpp new file mode 100644 index 000000000000..829a405929e9 --- /dev/null +++ b/core/src/index/unittest/test_hnsw_sq8nm.cpp @@ -0,0 +1,304 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include +#include +#include "knowhere/common/Exception.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class HNSWSQ8NRTest : public DataGen, public TestWithParam { + protected: + void + SetUp() override { + IndexType = GetParam(); + std::cout << "IndexType from GetParam() is: " << IndexType << std::endl; + Generate(64, 10000, 10); // dim = 64, nb = 10000, nq = 10 + // Generate(2, 10, 2); // dim = 64, nb = 10000, nq = 10 + index_ = std::make_shared(); + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 64}, {milvus::knowhere::meta::TOPK, 10}, + {milvus::knowhere::IndexParams::M, 16}, {milvus::knowhere::IndexParams::efConstruction, 200}, + {milvus::knowhere::IndexParams::ef, 200}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + /* + conf = milvus::knowhere::Config{ + {milvus::knowhere::meta::DIM, 2}, {milvus::knowhere::meta::TOPK, 2}, + {milvus::knowhere::IndexParams::M, 2}, {milvus::knowhere::IndexParams::efConstruction, 4}, + {milvus::knowhere::IndexParams::ef, 7}, {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + */ + } + + protected: + milvus::knowhere::Config conf; + std::shared_ptr index_ = nullptr; + std::string IndexType; +}; + +INSTANTIATE_TEST_CASE_P(HNSWParameters, HNSWSQ8NRTest, Values("HNSWSQ8NR")); + +TEST_P(HNSWSQ8NRTest, HNSW_basic) { + assert(!xb.empty()); + + // null faiss index + /* + { + ASSERT_ANY_THROW(index_->Serialize()); + ASSERT_ANY_THROW(index_->Query(query_dataset, conf)); + ASSERT_ANY_THROW(index_->Add(nullptr, conf)); + ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf)); + ASSERT_ANY_THROW(index_->Count()); + ASSERT_ANY_THROW(index_->Dim()); + } + */ + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + // int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + // int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + // auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + // milvus::knowhere::BinaryPtr bptr = std::make_shared(); + // bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + // bptr->size = dim * rows * sizeof(float); + // bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, k); +} + +TEST_P(HNSWSQ8NRTest, HNSW_delete) { + assert(!xb.empty()); + + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + for (auto i = 0; i < nq; ++i) { + bitset->set(i); + } + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + // int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + // int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + // auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + // milvus::knowhere::BinaryPtr bptr = std::make_shared(); + // bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + // bptr->size = dim * rows * sizeof(float); + // bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + + auto result1 = index_->Query(query_dataset, conf); + AssertAnns(result1, nq, k); + + index_->SetBlacklist(bitset); + auto result2 = index_->Query(query_dataset, conf); + AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL); + + /* + * delete result checked by eyes + auto ids1 = result1->Get(milvus::knowhere::meta::IDS); + auto ids2 = result2->Get(milvus::knowhere::meta::IDS); + std::cout << std::endl; + for (int i = 0; i < nq; ++ i) { + std::cout << "ids1: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids1 + i * k + j) << " "; + } + std::cout << "ids2: "; + for (int j = 0; j < k; ++ j) { + std::cout << *(ids2 + i * k + j) << " "; + } + std::cout << std::endl; + for (int j = 0; j < std::min(5, k>>1); ++ j) { + ASSERT_EQ(*(ids1 + i * k + j + 1), *(ids2 + i * k + j)); + } + } + */ +} + +TEST_P(HNSWSQ8NRTest, HNSW_serialize) { + auto serialize = [](const std::string& filename, milvus::knowhere::BinaryPtr& bin, uint8_t* ret) { + { + FileIOWriter writer(filename); + writer(static_cast(bin->data.get()), bin->size); + } + + FileIOReader reader(filename); + reader(ret, bin->size); + }; + + { + index_->Train(base_dataset, conf); + index_->Add(base_dataset, conf); + auto binaryset = index_->Serialize(); + auto bin_index = binaryset.GetByName("HNSW_SQ8"); + auto bin_sq8 = binaryset.GetByName(SQ8_DATA); + + std::string filename = "/tmp/HNSW_SQ8NM_test_serialize_index.bin"; + std::string filename2 = "/tmp/HNSW_SQ8NM_test_serialize_sq8.bin"; + auto load_index_data = new uint8_t[bin_index->size]; + serialize(filename, bin_index, load_index_data); + auto load_sq8_data = new uint8_t[bin_sq8->size]; + serialize(filename2, bin_sq8, load_sq8_data); + + binaryset.clear(); + std::shared_ptr data_index(load_index_data); + binaryset.Append("HNSW_SQ8", data_index, bin_index->size); + std::shared_ptr sq8_index(load_sq8_data); + binaryset.Append(SQ8_DATA, sq8_index, bin_sq8->size); + + index_->Load(binaryset); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + auto result = index_->Query(query_dataset, conf); + AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); + } +} + +/* + * faiss style test + * keep it +int +main() { + int64_t d = 64; // dimension + int64_t nb = 10000; // database size + int64_t nq = 10; // 10000; // nb of queries + faiss::ConcurrentBitsetPtr bitset = std::make_shared(nb); + + int64_t* ids = new int64_t[nb]; + float* xb = new float[d * nb]; + float* xq = new float[d * nq]; + // int64_t *ids = (int64_t*)malloc(nb * sizeof(int64_t)); + // float* xb = (float*)malloc(d * nb * sizeof(float)); + // float* xq = (float*)malloc(d * nq * sizeof(float)); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < d; j++) xb[d * i + j] = drand48(); + xb[d * i] += i / 1000.; + ids[i] = i; + } +// printf("gen xb and ids done! \n"); + + // srand((unsigned)time(nullptr)); + auto random_seed = (unsigned)time(nullptr); +// printf("delete ids: \n"); + for (int i = 0; i < nq; i++) { + auto tmp = rand_r(&random_seed) % nb; +// printf("%ld\n", tmp); + // std::cout << "before delete, test result: " << bitset->test(tmp) << std::endl; + bitset->set(tmp); + // std::cout << "after delete, test result: " << bitset->test(tmp) << std::endl; + for (int j = 0; j < d; j++) xq[d * i + j] = xb[d * tmp + j]; + // xq[d * i] += i / 1000.; + } +// printf("\n"); + + int k = 4; + int m = 16; + int ef = 200; + milvus::knowhere::IndexHNSW_NM index; + milvus::knowhere::DatasetPtr base_dataset = generate_dataset(nb, d, (const void*)xb, ids); +// base_dataset->Set(milvus::knowhere::meta::ROWS, nb); +// base_dataset->Set(milvus::knowhere::meta::DIM, d); +// base_dataset->Set(milvus::knowhere::meta::TENSOR, (const void*)xb); +// base_dataset->Set(milvus::knowhere::meta::IDS, (const int64_t*)ids); + + milvus::knowhere::Config base_conf{ + {milvus::knowhere::meta::DIM, d}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::M, m}, + {milvus::knowhere::IndexParams::efConstruction, ef}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + milvus::knowhere::DatasetPtr query_dataset = generate_query_dataset(nq, d, (const void*)xq); + milvus::knowhere::Config query_conf{ + {milvus::knowhere::meta::DIM, d}, + {milvus::knowhere::meta::TOPK, k}, + {milvus::knowhere::IndexParams::M, m}, + {milvus::knowhere::IndexParams::ef, ef}, + {milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2}, + }; + + index.Train(base_dataset, base_conf); + index.Add(base_dataset, base_conf); + +// printf("------------sanity check----------------\n"); + { // sanity check + auto res = index.Query(query_dataset, query_conf); +// printf("Query done!\n"); + const int64_t* I = res->Get(milvus::knowhere::meta::IDS); +// float* D = res->Get(milvus::knowhere::meta::DISTANCE); + +// printf("I=\n"); +// for (int i = 0; i < 5; i++) { +// for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); +// printf("\n"); +// } + +// printf("D=\n"); +// for (int i = 0; i < 5; i++) { +// for (int j = 0; j < k; j++) printf("%7g ", D[i * k + j]); +// printf("\n"); +// } + } + +// printf("---------------search xq-------------\n"); + { // search xq + auto res = index.Query(query_dataset, query_conf); + const int64_t* I = res->Get(milvus::knowhere::meta::IDS); + + printf("I=\n"); + for (int i = 0; i < nq; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + } + + printf("----------------search xq with delete------------\n"); + { // search xq with delete + index.SetBlacklist(bitset); + auto res = index.Query(query_dataset, query_conf); + auto I = res->Get(milvus::knowhere::meta::IDS); + + printf("I=\n"); + for (int i = 0; i < nq; i++) { + for (int j = 0; j < k; j++) printf("%5ld ", I[i * k + j]); + printf("\n"); + } + } + + delete[] xb; + delete[] xq; + delete[] ids; + + return 0; +} +*/ diff --git a/core/src/index/unittest/test_idmap.cpp b/core/src/index/unittest/test_idmap.cpp index 2abbd8e27036..39a73553af59 100644 --- a/core/src/index/unittest/test_idmap.cpp +++ b/core/src/index/unittest/test_idmap.cpp @@ -17,8 +17,8 @@ #include #include "knowhere/common/Exception.h" +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/IndexIDMAP.h" -#include "knowhere/index/vector_index/IndexType.h" #ifdef MILVUS_GPU_VERSION #include #include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" diff --git a/core/src/index/unittest/test_ivf.cpp b/core/src/index/unittest/test_ivf.cpp index 1827d0b4d1d9..c36ef5e880e6 100644 --- a/core/src/index/unittest/test_ivf.cpp +++ b/core/src/index/unittest/test_ivf.cpp @@ -22,10 +22,10 @@ #include "knowhere/common/Exception.h" #include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/IndexIVF.h" #include "knowhere/index/vector_index/IndexIVFPQ.h" #include "knowhere/index/vector_index/IndexIVFSQ.h" -#include "knowhere/index/vector_index/IndexType.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h" #ifdef MILVUS_GPU_VERSION @@ -81,12 +81,10 @@ INSTANTIATE_TEST_CASE_P( IVFParameters, IVFTest, Values( #ifdef MILVUS_GPU_VERSION - std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, milvus::knowhere::IndexMode::MODE_GPU), std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, milvus::knowhere::IndexMode::MODE_GPU), std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, milvus::knowhere::IndexMode::MODE_GPU), std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H, milvus::knowhere::IndexMode::MODE_GPU), #endif - std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, milvus::knowhere::IndexMode::MODE_CPU), std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ, milvus::knowhere::IndexMode::MODE_CPU), std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, milvus::knowhere::IndexMode::MODE_CPU))); @@ -245,26 +243,6 @@ TEST_P(IVFTest, clone_test) { } }; - // { - // // clone in place - // std::vector support_idx_vec{"IVF", "GPUIVF", "IVFPQ", "IVFSQ", "GPUIVFSQ"}; - // auto finder = std::find(support_idx_vec.cbegin(), support_idx_vec.cend(), index_type); - // if (finder != support_idx_vec.cend()) { - // EXPECT_NO_THROW({ - // auto clone_index = index_->Clone(); - // auto clone_result = clone_index->Search(query_dataset, conf); - // //AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]); - // AssertEqual(result, clone_result); - // std::cout << "inplace clone [" << index_type << "] success" << std::endl; - // }); - // } else { - // EXPECT_THROW({ - // std::cout << "inplace clone [" << index_type << "] failed" << std::endl; - // auto clone_index = index_->Clone(); - // }, KnowhereException); - // } - // } - { // copy from gpu to cpu if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) { @@ -352,11 +330,11 @@ TEST_P(IVFTest, invalid_gpu_source) { auto invalid_conf = ParamGenerator::GetInstance().Gen(index_type_); invalid_conf[milvus::knowhere::meta::DEVICEID] = -1; - if (index_type_ == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { - // null faiss index - index_->SetIndexSize(0); - milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); - } + // if (index_type_ == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT) { + // null faiss index + // index_->SetIndexSize(0); + // milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + // } index_->Train(base_dataset, conf_); diff --git a/core/src/index/unittest/test_ivf_cpu_nm.cpp b/core/src/index/unittest/test_ivf_cpu_nm.cpp new file mode 100644 index 000000000000..8e44ee22dd47 --- /dev/null +++ b/core/src/index/unittest/test_ivf_cpu_nm.cpp @@ -0,0 +1,131 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" +#endif + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class IVFNMCPUTest : public DataGen, + public TestWithParam<::std::tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + std::tie(index_type_, index_mode_) = GetParam(); + Generate(DIM, NB, NQ); + index_ = IndexFactoryNM(index_type_, index_mode_); + conf_ = ParamGenerator::GetInstance().Gen(index_type_); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::Config conf_; + milvus::knowhere::IVFNMPtr index_ = nullptr; +}; + +INSTANTIATE_TEST_CASE_P(IVFParameters, IVFNMCPUTest, + Values(std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + milvus::knowhere::IndexMode::MODE_CPU))); + +TEST_P(IVFNMCPUTest, ivf_basic_cpu) { + assert(!xb.empty()); + + if (index_mode_ != milvus::knowhere::IndexMode::MODE_CPU) { + return; + } + + // null faiss index + ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); + + index_->Train(base_dataset, conf_); + index_->AddWithoutIds(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + index_->SetIndexSize(nq * dim * sizeof(float)); + + milvus::knowhere::BinarySet bs = index_->Serialize(conf_); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + index_->Load(bs); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, k); + +#ifdef MILVUS_GPU_VERSION + // copy from cpu to gpu + { + EXPECT_NO_THROW({ + auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf_); + auto clone_result = clone_index->Query(query_dataset, conf_); + AssertAnns(clone_result, nq, k); + std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl; + }); + EXPECT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, milvus::knowhere::Config())); + } +#endif + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf_); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); +#endif +} diff --git a/core/src/index/unittest/test_ivf_gpu_nm.cpp b/core/src/index/unittest/test_ivf_gpu_nm.cpp new file mode 100644 index 000000000000..5486fb0ddcfd --- /dev/null +++ b/core/src/index/unittest/test_ivf_gpu_nm.cpp @@ -0,0 +1,138 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" +#endif + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +#define SERIALIZE_AND_LOAD(index_) \ + milvus::knowhere::BinarySet bs = index_->Serialize(conf_); \ + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); \ + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); \ + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); \ + milvus::knowhere::BinaryPtr bptr = std::make_shared(); \ + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); \ + bptr->size = dim * rows * sizeof(float); \ + bs.Append(RAW_DATA, bptr); \ + index_->Load(bs); + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class IVFNMGPUTest : public DataGen, + public TestWithParam<::std::tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + index_mode_ = milvus::knowhere::IndexMode::MODE_GPU; + Generate(DIM, NB, NQ); +#ifdef MILVUS_GPU_VERSION + index_ = std::make_shared(DEVICEID); +#endif + conf_ = ParamGenerator::GetInstance().Gen(index_type_); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::Config conf_; + milvus::knowhere::IVFPtr index_ = nullptr; +}; + +#ifdef MILVUS_GPU_VERSION +TEST_F(IVFNMGPUTest, ivf_basic_gpu) { + assert(!xb.empty()); + + if (index_mode_ != milvus::knowhere::IndexMode::MODE_GPU) { + return; + } + + // null faiss index + ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); + + index_->BuildAll(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + index_->SetIndexSize(nq * dim * sizeof(float)); + + SERIALIZE_AND_LOAD(index_); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, k); + + auto AssertEqual = [&](milvus::knowhere::DatasetPtr p1, milvus::knowhere::DatasetPtr p2) { + auto ids_p1 = p1->Get(milvus::knowhere::meta::IDS); + auto ids_p2 = p2->Get(milvus::knowhere::meta::IDS); + + for (int i = 0; i < nq * k; ++i) { + EXPECT_EQ(*((int64_t*)(ids_p2) + i), *((int64_t*)(ids_p1) + i)); + } + }; + + // copy from gpu to cpu + { + EXPECT_NO_THROW({ + auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, conf_); + SERIALIZE_AND_LOAD(clone_index); + auto clone_result = clone_index->Query(query_dataset, conf_); + AssertEqual(result, clone_result); + std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl; + }); + } + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf_); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); +} +#endif diff --git a/core/src/index/unittest/test_ivfsq_cpu_nm.cpp b/core/src/index/unittest/test_ivfsq_cpu_nm.cpp new file mode 100644 index 000000000000..efc91c4b1e05 --- /dev/null +++ b/core/src/index/unittest/test_ivfsq_cpu_nm.cpp @@ -0,0 +1,133 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/IndexIVFSQ.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#endif + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class IVFSQNMCPUTest : public DataGen, + public TestWithParam<::std::tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + std::tie(index_type_, index_mode_) = GetParam(); + Generate(DIM, NB, NQ); + index_ = std::make_shared(); + conf_ = ParamGenerator::GetInstance().Gen(index_type_); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::Config conf_; + milvus::knowhere::IVFSQNRNMPtr index_ = nullptr; +}; + +INSTANTIATE_TEST_CASE_P(IVFParameters, IVFSQNMCPUTest, + Values(std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8NR, + milvus::knowhere::IndexMode::MODE_CPU))); + +TEST_P(IVFSQNMCPUTest, ivf_basic_cpu) { + assert(!xb.empty()); + + if (index_mode_ != milvus::knowhere::IndexMode::MODE_CPU) { + return; + } + + // null faiss index + ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); + + index_->Train(base_dataset, conf_); + index_->AddWithoutIds(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + index_->SetIndexSize(nq * dim * sizeof(float)); + + milvus::knowhere::BinarySet bs = index_->Serialize(); + index_->Load(bs); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, k); + + auto AssertEqual = [&](milvus::knowhere::DatasetPtr p1, milvus::knowhere::DatasetPtr p2) { + auto ids_p1 = p1->Get(milvus::knowhere::meta::IDS); + auto ids_p2 = p2->Get(milvus::knowhere::meta::IDS); + + for (int i = 0; i < nq * k; ++i) { + EXPECT_EQ(*((int64_t*)(ids_p2) + i), *((int64_t*)(ids_p1) + i)); + } + }; + +#ifdef MILVUS_GPU_VERSION + // copy from cpu to gpu + { + EXPECT_NO_THROW({ + auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, milvus::knowhere::Config()); + auto clone_result = clone_index->Query(query_dataset, conf_); + AssertEqual(result, clone_result); + std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl; + }); + EXPECT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, milvus::knowhere::Config())); + } +#endif + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf_); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); +#endif +} diff --git a/core/src/index/unittest/test_ivfsq_gpu_nm.cpp b/core/src/index/unittest/test_ivfsq_gpu_nm.cpp new file mode 100644 index 000000000000..22037440b519 --- /dev/null +++ b/core/src/index/unittest/test_ivfsq_gpu_nm.cpp @@ -0,0 +1,136 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include +#include + +#ifdef MILVUS_GPU_VERSION +#include +#endif + +#include "knowhere/common/Exception.h" +#include "knowhere/common/Timer.h" +#include "knowhere/index/IndexType.h" +#include "knowhere/index/vector_index/adapter/VectorAdapter.h" +#include "knowhere/index/vector_offset_index/IndexIVFSQNR_NM.h" +#include "knowhere/index/vector_offset_index/IndexIVF_NM.h" + +#ifdef MILVUS_GPU_VERSION +#include "knowhere/index/vector_index/helpers/Cloner.h" +#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVFSQNR_NM.h" +#include "knowhere/index/vector_offset_index/gpu/IndexGPUIVF_NM.h" +#endif + +#include "unittest/Helper.h" +#include "unittest/utils.h" + +using ::testing::Combine; +using ::testing::TestWithParam; +using ::testing::Values; + +class IVFSQNMGPUTest : public DataGen, + public TestWithParam<::std::tuple> { + protected: + void + SetUp() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM); +#endif + index_type_ = milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8; + index_mode_ = milvus::knowhere::IndexMode::MODE_GPU; + Generate(DIM, NB, NQ); +#ifdef MILVUS_GPU_VERSION + index_ = std::make_shared(DEVICEID); +#endif + conf_ = ParamGenerator::GetInstance().Gen(index_type_); + } + + void + TearDown() override { +#ifdef MILVUS_GPU_VERSION + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Free(); +#endif + } + + protected: + milvus::knowhere::IndexType index_type_; + milvus::knowhere::IndexMode index_mode_; + milvus::knowhere::Config conf_; + milvus::knowhere::IVFPtr index_ = nullptr; +}; + +INSTANTIATE_TEST_CASE_P(IVFParameters, IVFSQNMGPUTest, + Values(std::make_tuple(milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, + milvus::knowhere::IndexMode::MODE_GPU))); + +#ifdef MILVUS_GPU_VERSION +TEST_P(IVFSQNMGPUTest, ivf_basic_gpu) { + assert(!xb.empty()); + + if (index_mode_ != milvus::knowhere::IndexMode::MODE_GPU) { + return; + } + + // null faiss index + ASSERT_ANY_THROW(index_->Add(base_dataset, conf_)); + ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf_)); + + index_->Train(base_dataset, conf_); + index_->AddWithoutIds(base_dataset, conf_); + EXPECT_EQ(index_->Count(), nb); + EXPECT_EQ(index_->Dim(), dim); + + index_->SetIndexSize(nq * dim * sizeof(float)); + + milvus::knowhere::BinarySet bs = index_->Serialize(); + index_->Load(bs); + + auto result = index_->Query(query_dataset, conf_); + AssertAnns(result, nq, k); + + auto AssertEqual = [&](milvus::knowhere::DatasetPtr p1, milvus::knowhere::DatasetPtr p2) { + auto ids_p1 = p1->Get(milvus::knowhere::meta::IDS); + auto ids_p2 = p2->Get(milvus::knowhere::meta::IDS); + + for (int i = 0; i < nq * k; ++i) { + EXPECT_EQ(*((int64_t*)(ids_p2) + i), *((int64_t*)(ids_p1) + i)); + } + }; + + // copy from gpu to cpu + { + EXPECT_NO_THROW({ + auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config()); + milvus::knowhere::BinarySet bs = clone_index->Serialize(); + clone_index->Load(bs); + auto clone_result = clone_index->Query(query_dataset, conf_); + AssertEqual(result, clone_result); + std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl; + }); + } + + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(nb); + for (int64_t i = 0; i < nq; ++i) { + concurrent_bitset_ptr->set(i); + } + index_->SetBlacklist(concurrent_bitset_ptr); + + auto result_bs_1 = index_->Query(query_dataset, conf_); + AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL); + + milvus::knowhere::FaissGpuResourceMgr::GetInstance().Dump(); +} +#endif diff --git a/core/src/index/unittest/test_nsg.cpp b/core/src/index/unittest/test_nsg.cpp index c4dc46a86981..350e4096b5ea 100644 --- a/core/src/index/unittest/test_nsg.cpp +++ b/core/src/index/unittest/test_nsg.cpp @@ -15,9 +15,8 @@ #include #include "knowhere/common/Exception.h" -#include "knowhere/index/vector_index/FaissBaseIndex.h" -#include "knowhere/index/vector_index/IndexNSG.h" #include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "knowhere/index/vector_offset_index/IndexNSG_NM.h" #ifdef MILVUS_GPU_VERSION #include "knowhere/index/vector_index/gpu/IndexGPUIDMAP.h" #include "knowhere/index/vector_index/helpers/Cloner.h" @@ -45,7 +44,7 @@ class NSGInterfaceTest : public DataGen, public ::testing::Test { #endif int nsg_dim = 256; Generate(nsg_dim, 20000, nq); - index_ = std::make_shared(); + index_ = std::make_shared(); train_conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, 256}, {milvus::knowhere::IndexParams::nlist, 163}, @@ -70,7 +69,7 @@ class NSGInterfaceTest : public DataGen, public ::testing::Test { } protected: - std::shared_ptr index_; + std::shared_ptr index_; milvus::knowhere::Config train_conf; milvus::knowhere::Config search_conf; }; @@ -88,34 +87,43 @@ TEST_F(NSGInterfaceTest, basic_test) { train_conf[milvus::knowhere::meta::DEVICEID] = -1; index_->BuildAll(base_dataset, train_conf); + + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + auto result = index_->Query(query_dataset, search_conf); AssertAnns(result, nq, k); - auto binaryset = index_->Serialize(); - { - fiu_enable("NSG.Serialize.throw_exception", 1, nullptr, 0); - ASSERT_ANY_THROW(index_->Serialize()); - fiu_disable("NSG.Serialize.throw_exception"); - } - /* test NSG GPU train */ - auto new_index_1 = std::make_shared(DEVICE_GPU0); + auto new_index_1 = std::make_shared(DEVICE_GPU0); train_conf[milvus::knowhere::meta::DEVICEID] = DEVICE_GPU0; new_index_1->BuildAll(base_dataset, train_conf); - auto new_result_1 = new_index_1->Query(query_dataset, search_conf); - AssertAnns(new_result_1, nq, k); - /* test NSG index load */ - auto new_index_2 = std::make_shared(); - new_index_2->Load(binaryset); - { - fiu_enable("NSG.Load.throw_exception", 1, nullptr, 0); - ASSERT_ANY_THROW(new_index_2->Load(binaryset)); - fiu_disable("NSG.Load.throw_exception"); - } + // Serialize and Load before Query + bs = new_index_1->Serialize(); + + dim = base_dataset->Get(milvus::knowhere::meta::DIM); + rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + new_index_1->Load(bs); - auto new_result_2 = new_index_2->Query(query_dataset, search_conf); - AssertAnns(new_result_2, nq, k); + auto new_result_1 = new_index_1->Query(query_dataset, search_conf); + AssertAnns(new_result_1, nq, k); ASSERT_EQ(index_->Count(), nb); ASSERT_EQ(index_->Dim(), dim); @@ -142,6 +150,19 @@ TEST_F(NSGInterfaceTest, delete_test) { train_conf[milvus::knowhere::meta::DEVICEID] = DEVICE_GPU0; index_->Train(base_dataset, train_conf); + // Serialize and Load before Query + milvus::knowhere::BinarySet bs = index_->Serialize(); + + int64_t dim = base_dataset->Get(milvus::knowhere::meta::DIM); + int64_t rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + auto raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + milvus::knowhere::BinaryPtr bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); + auto result = index_->Query(query_dataset, search_conf); AssertAnns(result, nq, k); @@ -157,6 +178,19 @@ TEST_F(NSGInterfaceTest, delete_test) { // search xq with delete index_->SetBlacklist(bitset); + + // Serialize and Load before Query + bs = index_->Serialize(); + + dim = base_dataset->Get(milvus::knowhere::meta::DIM); + rows = base_dataset->Get(milvus::knowhere::meta::ROWS); + raw_data = base_dataset->Get(milvus::knowhere::meta::TENSOR); + bptr = std::make_shared(); + bptr->data = std::shared_ptr((uint8_t*)raw_data, [&](uint8_t*) {}); + bptr->size = dim * rows * sizeof(float); + bs.Append(RAW_DATA, bptr); + + index_->Load(bs); auto result_after = index_->Query(query_dataset, search_conf); AssertAnns(result_after, nq, k, CheckMode::CHECK_NOT_EQUAL); auto I_after = result_after->Get(milvus::knowhere::meta::IDS); diff --git a/core/src/index/unittest/test_vecindex.cpp b/core/src/index/unittest/test_vecindex.cpp index dc231f640362..713e9d7988a4 100644 --- a/core/src/index/unittest/test_vecindex.cpp +++ b/core/src/index/unittest/test_vecindex.cpp @@ -11,7 +11,7 @@ #include -#include "knowhere/index/vector_index/IndexType.h" +#include "knowhere/index/IndexType.h" #include "knowhere/index/vector_index/VecIndex.h" #include "knowhere/index/vector_index/VecIndexFactory.h" #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" diff --git a/core/src/query/GeneralQuery.h b/core/src/query/GeneralQuery.h index a2e3274cfd35..23fd5c659239 100644 --- a/core/src/query/GeneralQuery.h +++ b/core/src/query/GeneralQuery.h @@ -109,6 +109,10 @@ struct BinaryQuery { struct Query { BinaryQueryPtr root; std::unordered_map vectors; + + std::string collection_id; + std::vector partitions; + std::vector field_names; }; using QueryPtr = std::shared_ptr; diff --git a/core/src/scheduler/Definition.h b/core/src/scheduler/Definition.h index 10526804c918..ed5fcb89e5a6 100644 --- a/core/src/scheduler/Definition.h +++ b/core/src/scheduler/Definition.h @@ -35,6 +35,7 @@ using ExecutionEnginePtr = engine::ExecutionEnginePtr; using EngineFactory = engine::EngineFactory; using EngineType = engine::EngineType; using MetricType = engine::MetricType; +using DataType = engine::meta::hybrid::DataType; constexpr uint64_t TASK_TABLE_MAX_COUNT = 1ULL << 16ULL; diff --git a/core/src/scheduler/JobMgr.cpp b/core/src/scheduler/JobMgr.cpp index 0856c749be33..78931daa66ea 100644 --- a/core/src/scheduler/JobMgr.cpp +++ b/core/src/scheduler/JobMgr.cpp @@ -9,21 +9,19 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. -#include "scheduler/JobMgr.h" - -#include "src/db/Utils.h" -#include "src/segment/SegmentReader.h" - #include +#include #include -#include "SchedInst.h" -#include "TaskCreator.h" +#include "db/Utils.h" #include "scheduler/Algorithm.h" #include "scheduler/CPUBuilder.h" +#include "scheduler/JobMgr.h" +#include "scheduler/SchedInst.h" +#include "scheduler/TaskCreator.h" +#include "scheduler/selector/Optimizer.h" +#include "scheduler/task/Task.h" #include "scheduler/tasklabel/SpecResLabel.h" -#include "selector/Optimizer.h" -#include "task/Task.h" namespace milvus { namespace scheduler { @@ -33,99 +31,50 @@ JobMgr::JobMgr(ResourceMgrPtr res_mgr) : res_mgr_(std::move(res_mgr)) { void JobMgr::Start() { - if (not running_) { - running_ = true; - worker_thread_ = std::thread(&JobMgr::worker_function, this); + if (worker_thread_ == nullptr) { + worker_thread_ = std::make_shared(&JobMgr::worker_function, this); } } void JobMgr::Stop() { - if (running_) { + if (worker_thread_ != nullptr) { this->Put(nullptr); - worker_thread_.join(); - running_ = false; + worker_thread_->join(); + worker_thread_ = nullptr; } } json JobMgr::Dump() const { json ret{ - {"running", running_}, - {"event_queue_length", queue_.size()}, + {"running", (worker_thread_ != nullptr ? true : false)}, + {"event_queue_length", queue_.Size()}, }; return ret; } void JobMgr::Put(const JobPtr& job) { - { - std::lock_guard lock(mutex_); - queue_.push(job); - } - cv_.notify_one(); + queue_.Put(job); } void JobMgr::worker_function() { SetThreadName("jobmgr_thread"); - while (running_) { - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return !queue_.empty(); }); - auto job = queue_.front(); - queue_.pop(); - lock.unlock(); + while (true) { + auto job = queue_.Take(); if (job == nullptr) { break; } - auto tasks = build_task(job); - - // TODO(zhiru): if the job is search by ids, pass any task where the ids don't exist auto search_job = std::dynamic_pointer_cast(job); if (search_job != nullptr) { search_job->GetResultIds().resize(search_job->nq(), -1); search_job->GetResultDistances().resize(search_job->nq(), std::numeric_limits::max()); - - if (search_job->vectors().float_data_.empty() && search_job->vectors().binary_data_.empty() && - !search_job->vectors().id_array_.empty()) { - for (auto task = tasks.begin(); task != tasks.end();) { - auto search_task = std::static_pointer_cast(*task); - auto location = search_task->GetLocation(); - - // Load bloom filter - std::string segment_dir; - engine::utils::GetParentPath(location, segment_dir); - segment::SegmentReader segment_reader(segment_dir); - segment::IdBloomFilterPtr id_bloom_filter_ptr; - segment_reader.LoadBloomFilter(id_bloom_filter_ptr); - - // Check if the id is present. - bool pass = true; - for (auto& id : search_job->vectors().id_array_) { - if (id_bloom_filter_ptr->Check(id)) { - pass = false; - break; - } - } - - if (pass) { - // std::cout << search_task->GetIndexId() << std::endl; - search_job->SearchDone(search_task->GetIndexId()); - task = tasks.erase(task); - } else { - task++; - } - } - } } - // for (auto &task : tasks) { - // if ... - // search_job->SearchDone(task->id); - // tasks.erase(task); - // } - + auto tasks = build_task(job); for (auto& task : tasks) { OptimizerInst::GetInstance()->Run(task); } diff --git a/core/src/scheduler/JobMgr.h b/core/src/scheduler/JobMgr.h index 0e9a21f8eb67..fadb15dfe80f 100644 --- a/core/src/scheduler/JobMgr.h +++ b/core/src/scheduler/JobMgr.h @@ -10,21 +10,15 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #pragma once -#include -#include -#include #include -#include -#include -#include #include -#include #include -#include "ResourceMgr.h" -#include "interface/interfaces.h" -#include "job/Job.h" -#include "task/Task.h" +#include "scheduler/ResourceMgr.h" +#include "scheduler/interface/interfaces.h" +#include "scheduler/job/Job.h" +#include "scheduler/task/Task.h" +#include "utils/BlockingQueue.h" namespace milvus { namespace scheduler { @@ -58,14 +52,8 @@ class JobMgr : public interface::dumpable { calculate_path(const ResourceMgrPtr& res_mgr, const TaskPtr& task); private: - bool running_ = false; - std::queue queue_; - - std::thread worker_thread_; - - std::mutex mutex_; - std::condition_variable cv_; - + BlockingQueue queue_; + std::shared_ptr worker_thread_ = nullptr; ResourceMgrPtr res_mgr_ = nullptr; }; diff --git a/core/src/scheduler/TaskCreator.cpp b/core/src/scheduler/TaskCreator.cpp index 7196af63cee4..fa9b41ea26e1 100644 --- a/core/src/scheduler/TaskCreator.cpp +++ b/core/src/scheduler/TaskCreator.cpp @@ -10,9 +10,14 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "scheduler/TaskCreator.h" -#include "SchedInst.h" -#include "tasklabel/BroadcastLabel.h" -#include "tasklabel/SpecResLabel.h" +#include "scheduler/SchedInst.h" +#include "scheduler/task/BuildIndexTask.h" +#include "scheduler/task/DeleteTask.h" +#include "scheduler/task/SSBuildIndexTask.h" +#include "scheduler/task/SSSearchTask.h" +#include "scheduler/task/SearchTask.h" +#include "scheduler/tasklabel/BroadcastLabel.h" +#include "scheduler/tasklabel/SpecResLabel.h" namespace milvus { namespace scheduler { @@ -29,6 +34,12 @@ TaskCreator::Create(const JobPtr& job) { case JobType::BUILD: { return Create(std::static_pointer_cast(job)); } + case JobType::SS_SEARCH: { + return Create(std::static_pointer_cast(job)); + } + case JobType::SS_BUILD: { + return Create(std::static_pointer_cast(job)); + } default: { // TODO(wxyu): error return std::vector(); @@ -70,5 +81,28 @@ TaskCreator::Create(const BuildIndexJobPtr& job) { return tasks; } +std::vector +TaskCreator::Create(const SSSearchJobPtr& job) { + std::vector tasks; + for (auto& id : job->segment_ids()) { + auto task = std::make_shared(job->GetContext(), job->options(), job->query_ptr(), id, nullptr); + task->job_ = job; + tasks.emplace_back(task); + } + return tasks; +} + +std::vector +TaskCreator::Create(const SSBuildIndexJobPtr& job) { + std::vector tasks; + const std::string& collection_name = job->collection_name(); + for (auto& id : job->segment_ids()) { + auto task = std::make_shared(job->options(), collection_name, id, nullptr); + task->job_ = job; + tasks.emplace_back(task); + } + return tasks; +} + } // namespace scheduler } // namespace milvus diff --git a/core/src/scheduler/TaskCreator.h b/core/src/scheduler/TaskCreator.h index a1dd8bb8a462..ef2b8997bd65 100644 --- a/core/src/scheduler/TaskCreator.h +++ b/core/src/scheduler/TaskCreator.h @@ -21,12 +21,11 @@ #include #include +#include "job/BuildIndexJob.h" #include "job/DeleteJob.h" -#include "job/Job.h" +#include "job/SSBuildIndexJob.h" +#include "job/SSSearchJob.h" #include "job/SearchJob.h" -#include "task/BuildIndexTask.h" -#include "task/DeleteTask.h" -#include "task/SearchTask.h" #include "task/Task.h" namespace milvus { @@ -46,6 +45,12 @@ class TaskCreator { static std::vector Create(const BuildIndexJobPtr& job); + + static std::vector + Create(const SSSearchJobPtr& job); + + static std::vector + Create(const SSBuildIndexJobPtr& job); }; } // namespace scheduler diff --git a/core/src/scheduler/job/Job.h b/core/src/scheduler/job/Job.h index d350928bf395..ec71fc4e73f7 100644 --- a/core/src/scheduler/job/Job.h +++ b/core/src/scheduler/job/Job.h @@ -21,21 +21,26 @@ #include #include +#include "db/SnapshotVisitor.h" +#include "db/snapshot/ResourceTypes.h" #include "scheduler/interface/interfaces.h" - #include "server/context/Context.h" namespace milvus { namespace scheduler { enum class JobType { - INVALID, - SEARCH, - DELETE, - BUILD, + INVALID = -1, + SEARCH = 0, + DELETE = 1, + BUILD = 2, + + SS_SEARCH = 10, + SS_BUILD = 11, }; using JobId = std::uint64_t; +using SegmentVisitorMap = std::unordered_map; class Job : public interface::dumpable { public: diff --git a/core/src/scheduler/job/SSBuildIndexJob.cpp b/core/src/scheduler/job/SSBuildIndexJob.cpp new file mode 100644 index 000000000000..6153e4e0f7c3 --- /dev/null +++ b/core/src/scheduler/job/SSBuildIndexJob.cpp @@ -0,0 +1,62 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "scheduler/job/SSBuildIndexJob.h" + +#include + +#include "utils/Log.h" + +namespace milvus { +namespace scheduler { + +SSBuildIndexJob::SSBuildIndexJob(engine::DBOptions options, const std::string& collection_name, + const engine::snapshot::IDS_TYPE& segment_ids) + : Job(JobType::SS_BUILD), + options_(std::move(options)), + collection_name_(collection_name), + segment_ids_(segment_ids) { +} + +void +SSBuildIndexJob::WaitFinish() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return segment_ids_.empty(); }); + LOG_SERVER_DEBUG_ << LogOut("[%s][%ld] BuildIndexJob %ld all done", "build index", 0, id()); +} + +void +SSBuildIndexJob::BuildIndexDone(const engine::snapshot::ID_TYPE seg_id) { + std::unique_lock lock(mutex_); + for (engine::snapshot::IDS_TYPE::iterator iter = segment_ids_.begin(); iter != segment_ids_.end(); ++iter) { + if (*iter == seg_id) { + segment_ids_.erase(iter); + break; + } + } + if (segment_ids_.empty()) { + cv_.notify_all(); + } + LOG_SERVER_DEBUG_ << LogOut("[%s][%ld] BuildIndexJob %ld finish segment: %ld", "build index", 0, id(), seg_id); +} + +json +SSBuildIndexJob::Dump() const { + json ret{ + {"number_of_to_index_segment", segment_ids_.size()}, + }; + auto base = Job::Dump(); + ret.insert(base.begin(), base.end()); + return ret; +} + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/job/SSBuildIndexJob.h b/core/src/scheduler/job/SSBuildIndexJob.h new file mode 100644 index 000000000000..01d235d0f471 --- /dev/null +++ b/core/src/scheduler/job/SSBuildIndexJob.h @@ -0,0 +1,83 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "db/snapshot/ResourceTypes.h" +#include "scheduler/Definition.h" +#include "scheduler/job/Job.h" + +namespace milvus { +namespace scheduler { + +class SSBuildIndexJob : public Job { + public: + explicit SSBuildIndexJob(engine::DBOptions options, const std::string& collection_name, + const engine::snapshot::IDS_TYPE& segment_ids); + + ~SSBuildIndexJob() = default; + + public: + void + WaitFinish(); + + void + BuildIndexDone(const engine::snapshot::ID_TYPE seg_id); + + json + Dump() const override; + + public: + engine::DBOptions + options() const { + return options_; + } + + const std::string& + collection_name() { + return collection_name_; + } + + const engine::snapshot::IDS_TYPE& + segment_ids() { + return segment_ids_; + } + + Status& + status() { + return status_; + } + + private: + engine::DBOptions options_; + std::string collection_name_; + engine::snapshot::IDS_TYPE segment_ids_; + + Status status_; + std::mutex mutex_; + std::condition_variable cv_; +}; + +using SSBuildIndexJobPtr = std::shared_ptr; + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/job/SSSearchJob.cpp b/core/src/scheduler/job/SSSearchJob.cpp new file mode 100644 index 000000000000..ef4651e2b07a --- /dev/null +++ b/core/src/scheduler/job/SSSearchJob.cpp @@ -0,0 +1,61 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "scheduler/job/SSSearchJob.h" +#include "utils/Log.h" + +namespace milvus { +namespace scheduler { + +SSSearchJob::SSSearchJob(const server::ContextPtr& context, engine::DBOptions options, const query::QueryPtr& query_ptr) + : Job(JobType::SS_SEARCH), context_(context), options_(options), query_ptr_(query_ptr) { + GetSegmentsFromQuery(query_ptr, segment_ids_); +} + +void +SSSearchJob::WaitFinish() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return segment_ids_.empty(); }); + LOG_SERVER_DEBUG_ << LogOut("[%s][%ld] SearchJob %ld all done", "search", 0, id()); +} + +void +SSSearchJob::SearchDone(const engine::snapshot::ID_TYPE seg_id) { + std::unique_lock lock(mutex_); + for (engine::snapshot::IDS_TYPE::iterator iter = segment_ids_.begin(); iter != segment_ids_.end(); ++iter) { + if (*iter == seg_id) { + segment_ids_.erase(iter); + break; + } + } + if (segment_ids_.empty()) { + cv_.notify_all(); + } + LOG_SERVER_DEBUG_ << LogOut("[%s][%ld] SearchJob %ld finish segment: %ld", "search", 0, id(), seg_id); +} + +json +SSSearchJob::Dump() const { + json ret{ + {"number_of_search_segment", segment_ids_.size()}, + }; + auto base = Job::Dump(); + ret.insert(base.begin(), base.end()); + return ret; +} + +void +SSSearchJob::GetSegmentsFromQuery(const query::QueryPtr& query_ptr, engine::snapshot::IDS_TYPE& segment_ids) { + // TODO +} + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/job/SSSearchJob.h b/core/src/scheduler/job/SSSearchJob.h new file mode 100644 index 000000000000..54b6ac0e0362 --- /dev/null +++ b/core/src/scheduler/job/SSSearchJob.h @@ -0,0 +1,115 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Job.h" +#include "db/SnapshotVisitor.h" +#include "db/Types.h" +//#include "db/meta/MetaTypes.h" + +#include "server/context/Context.h" + +namespace milvus { +namespace scheduler { + +// struct SearchTimeStat { +// double query_time = 0.0; +// double map_uids_time = 0.0; +// double reduce_time = 0.0; +//}; + +class SSSearchJob : public Job { + public: + SSSearchJob(const server::ContextPtr& context, engine::DBOptions options, const query::QueryPtr& query_ptr); + + public: + void + AddSegmentVisitor(const engine::SegmentVisitorPtr& visitor); + + void + WaitFinish(); + + void + SearchDone(const engine::snapshot::ID_TYPE seg_id); + + json + Dump() const override; + + public: + const server::ContextPtr& + GetContext() const { + return context_; + } + + engine::DBOptions + options() const { + return options_; + } + + const query::QueryPtr + query_ptr() const { + return query_ptr_; + } + + engine::QueryResultPtr& + query_result() { + return query_result_; + } + + const engine::snapshot::IDS_TYPE& + segment_ids() { + return segment_ids_; + } + + Status& + status() { + return status_; + } + + std::mutex& + mutex() { + return mutex_; + } + + private: + void + GetSegmentsFromQuery(const query::QueryPtr& query_ptr, engine::snapshot::IDS_TYPE& segment_ids); + + private: + const server::ContextPtr context_; + + engine::DBOptions options_; + + query::QueryPtr query_ptr_; + engine::QueryResultPtr query_result_; + engine::snapshot::IDS_TYPE segment_ids_; + + Status status_; + std::mutex mutex_; + std::condition_variable cv_; +}; + +using SSSearchJobPtr = std::shared_ptr; + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/job/SearchJob.h b/core/src/scheduler/job/SearchJob.h index 01d44cf28e3a..5609497d887c 100644 --- a/core/src/scheduler/job/SearchJob.h +++ b/core/src/scheduler/job/SearchJob.h @@ -23,9 +23,9 @@ #include #include "Job.h" +#include "db/SnapshotVisitor.h" #include "db/Types.h" #include "db/meta/MetaTypes.h" - #include "query/GeneralQuery.h" #include "server/context/Context.h" diff --git a/core/src/scheduler/task/BuildIndexTask.cpp b/core/src/scheduler/task/BuildIndexTask.cpp index 68cab7c2b3ee..2a3be7255304 100644 --- a/core/src/scheduler/task/BuildIndexTask.cpp +++ b/core/src/scheduler/task/BuildIndexTask.cpp @@ -220,9 +220,10 @@ XBuildIndexTask::Execute() { LOG_ENGINE_DEBUG_ << "New index file " << table_file.file_id_ << " of size " << table_file.file_size_ << " bytes" << " from file " << origin_file.file_id_; - if (build_index_job->options().insert_cache_immediately_) { - index->Cache(); - } + // XXX_Index_NM doesn't support it now. + // if (build_index_job->options().insert_cache_immediately_) { + // index->Cache(); + // } } else { // failed to update meta, mark the new file as to_delete, don't delete old file origin_file.file_type_ = engine::meta::SegmentSchema::TO_INDEX; diff --git a/core/src/scheduler/task/SSBuildIndexTask.cpp b/core/src/scheduler/task/SSBuildIndexTask.cpp new file mode 100644 index 000000000000..8f803419f402 --- /dev/null +++ b/core/src/scheduler/task/SSBuildIndexTask.cpp @@ -0,0 +1,115 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include + +#include "db/Utils.h" +#include "db/engine/EngineFactory.h" +#include "scheduler/job/SSBuildIndexJob.h" +#include "scheduler/task/SSBuildIndexTask.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace scheduler { + +SSBuildIndexTask::SSBuildIndexTask(const engine::DBOptions& options, const std::string& collection_name, + engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label) + : Task(TaskType::BuildIndexTask, std::move(label)), + options_(options), + collection_name_(collection_name), + segment_id_(segment_id) { + CreateExecEngine(); +} + +void +SSBuildIndexTask::CreateExecEngine() { + if (execution_engine_ == nullptr) { + execution_engine_ = engine::EngineFactory::Build(options_.meta_.path_, collection_name_, segment_id_); + } +} + +void +SSBuildIndexTask::Load(milvus::scheduler::LoadType type, uint8_t device_id) { + TimeRecorder rc("SSBuildIndexTask::Load"); + Status stat = Status::OK(); + std::string error_msg; + std::string type_str; + + if (auto job = job_.lock()) { + try { + if (type == LoadType::DISK2CPU) { + engine::ExecutionEngineContext context; + stat = execution_engine_->Load(context); + type_str = "DISK2CPU"; + } else if (type == LoadType::CPU2GPU) { + stat = execution_engine_->CopyToGpu(device_id); + type_str = "CPU2GPU:" + std::to_string(device_id); + } else { + error_msg = "Wrong load type"; + stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + fiu_do_on("XSSBuildIndexTask.Load.throw_std_exception", throw std::exception()); + } catch (std::exception& ex) { + // typical error: out of disk space or permission denied + error_msg = "Failed to load to_index file: " + std::string(ex.what()); + LOG_ENGINE_ERROR_ << error_msg; + stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + + if (!stat.ok()) { + Status s; + if (stat.ToString().find("out of memory") != std::string::npos) { + error_msg = "out of memory: " + type_str; + s = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } else { + error_msg = "Failed to load to_index file: " + type_str; + s = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + + LOG_ENGINE_ERROR_ << s.message(); + + auto build_index_job = std::static_pointer_cast(job); + build_index_job->status() = s; + build_index_job->BuildIndexDone(segment_id_); + } + } +} + +void +SSBuildIndexTask::Execute() { + TimeRecorderAuto rc("XSSBuildIndexTask::Execute " + std::to_string(segment_id_)); + + if (auto job = job_.lock()) { + auto build_index_job = std::static_pointer_cast(job); + if (execution_engine_ == nullptr) { + build_index_job->BuildIndexDone(segment_id_); + build_index_job->status() = Status(DB_ERROR, "execution engine is null"); + return; + } + + auto status = execution_engine_->BuildIndex(); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << "Failed to create collection file: " << status.ToString(); + build_index_job->BuildIndexDone(segment_id_); + build_index_job->status() = status; + execution_engine_ = nullptr; + return; + } + + build_index_job->BuildIndexDone(segment_id_); + } +} + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/task/SSBuildIndexTask.h b/core/src/scheduler/task/SSBuildIndexTask.h new file mode 100644 index 000000000000..f7e91e28785c --- /dev/null +++ b/core/src/scheduler/task/SSBuildIndexTask.h @@ -0,0 +1,49 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include "db/engine/SSExecutionEngine.h" +#include "db/snapshot/ResourceTypes.h" +#include "scheduler/Definition.h" +#include "scheduler/job/SSBuildIndexJob.h" +#include "scheduler/task/Task.h" + +namespace milvus { +namespace scheduler { + +class SSBuildIndexTask : public Task { + public: + explicit SSBuildIndexTask(const engine::DBOptions& options, const std::string& collection_name, + engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label); + + void + Load(LoadType type, uint8_t device_id) override; + + void + Execute() override; + + private: + void + CreateExecEngine(); + + public: + const engine::DBOptions& options_; + std::string collection_name_; + engine::snapshot::ID_TYPE segment_id_; + + engine::SSExecutionEnginePtr execution_engine_; +}; + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/task/SSSearchTask.cpp b/core/src/scheduler/task/SSSearchTask.cpp new file mode 100644 index 000000000000..765f32f62d91 --- /dev/null +++ b/core/src/scheduler/task/SSSearchTask.cpp @@ -0,0 +1,165 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "scheduler/task/SSSearchTask.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "db/Utils.h" +#include "db/engine/SSExecutionEngineImpl.h" +#include "scheduler/SchedInst.h" +#include "scheduler/job/SSSearchJob.h" +#include "segment/SegmentReader.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace scheduler { + +SSSearchTask::SSSearchTask(const server::ContextPtr& context, const engine::DBOptions& options, + const query::QueryPtr& query_ptr, engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label) + : Task(TaskType::SearchTask, std::move(label)), + context_(context), + options_(options), + query_ptr_(query_ptr), + segment_id_(segment_id) { + CreateExecEngine(); +} + +void +SSSearchTask::CreateExecEngine() { + if (execution_engine_ == nullptr && query_ptr_ != nullptr) { + execution_engine_ = engine::EngineFactory::Build(options_.meta_.path_, query_ptr_->collection_id, segment_id_); + } +} + +void +SSSearchTask::Load(LoadType type, uint8_t device_id) { + TimeRecorder rc(LogOut("[%s][%ld]", "search", segment_id_)); + Status stat = Status::OK(); + std::string error_msg; + std::string type_str; + + if (auto job = job_.lock()) { + try { + fiu_do_on("XSearchTask.Load.throw_std_exception", throw std::exception()); + if (type == LoadType::DISK2CPU) { + engine::ExecutionEngineContext context; + context.query_ptr_ = query_ptr_; + stat = execution_engine_->Load(context); + type_str = "DISK2CPU"; + } else if (type == LoadType::CPU2GPU) { + stat = execution_engine_->CopyToGpu(device_id); + type_str = "CPU2GPU" + std::to_string(device_id); + } else if (type == LoadType::GPU2CPU) { + // stat = engine_->CopyToCpu(); + type_str = "GPU2CPU"; + } else { + error_msg = "Wrong load type"; + stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + } catch (std::exception& ex) { + // typical error: out of disk space or permition denied + error_msg = "Failed to load index file: " + std::string(ex.what()); + LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] Encounter exception: %s", "search", 0, error_msg.c_str()); + stat = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + fiu_do_on("XSearchTask.Load.out_of_memory", stat = Status(SERVER_UNEXPECTED_ERROR, "out of memory")); + } + + if (!stat.ok()) { + Status s; + if (stat.ToString().find("out of memory") != std::string::npos) { + error_msg = "out of memory: " + type_str + " : " + stat.message(); + s = Status(SERVER_OUT_OF_MEMORY, error_msg); + } else { + error_msg = "Failed to load index file: " + type_str + " : " + stat.message(); + s = Status(SERVER_UNEXPECTED_ERROR, error_msg); + } + + if (auto job = job_.lock()) { + auto search_job = std::static_pointer_cast(job); + search_job->SearchDone(segment_id_); + search_job->status() = s; + } + + return; + } + + std::string info = "Search task load segment id: " + std::to_string(segment_id_) + " " + type_str + " totally cost"; + rc.ElapseFromBegin(info); +} + +void +SSSearchTask::Execute() { + milvus::server::ContextFollower tracer(context_, "XSearchTask::Execute " + std::to_string(segment_id_)); + TimeRecorder rc(LogOut("[%s][%ld] DoSearch file id:%ld", "search", 0, segment_id_)); + + if (auto job = job_.lock()) { + auto search_job = std::static_pointer_cast(job); + + if (execution_engine_ == nullptr) { + search_job->SearchDone(segment_id_); + search_job->status() = Status(DB_ERROR, "execution engine is null"); + return; + } + + fiu_do_on("XSearchTask.Execute.throw_std_exception", throw std::exception()); + + try { + /* step 2: search */ + engine::ExecutionEngineContext context; + context.query_ptr_ = query_ptr_; + context.query_result_ = std::make_shared(); + auto status = execution_engine_->Search(context); + + fiu_do_on("XSearchTask.Execute.search_fail", status = Status(SERVER_UNEXPECTED_ERROR, "")); + if (!status.ok()) { + search_job->SearchDone(segment_id_); + search_job->status() = status; + return; + } + + rc.RecordSection("search done"); + + /* step 3: pick up topk result */ + // auto spec_k = file_->row_count_ < topk ? file_->row_count_ : topk; + // if (spec_k == 0) { + // LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty file. file location = %s", + // "search", 0, + // file_->location_.c_str()); + // } else { + // std::unique_lock lock(search_job->mutex()); + // XSearchTask::MergeTopkToResultSet(result, spec_k, nq, topk, ascending_, search_job->GetQueryResult()); + // } + + rc.RecordSection("reduce topk done"); + } catch (std::exception& ex) { + LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] SearchTask encounter exception: %s", "search", 0, ex.what()); + search_job->status() = Status(SERVER_UNEXPECTED_ERROR, ex.what()); + } + + /* step 4: notify to send result to client */ + search_job->SearchDone(segment_id_); + } + + rc.ElapseFromBegin("totally cost"); +} + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/task/SSSearchTask.h b/core/src/scheduler/task/SSSearchTask.h new file mode 100644 index 000000000000..65947985e88d --- /dev/null +++ b/core/src/scheduler/task/SSSearchTask.h @@ -0,0 +1,53 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "db/SnapshotVisitor.h" +#include "db/engine/SSExecutionEngine.h" +#include "scheduler/Definition.h" +#include "scheduler/job/SSSearchJob.h" +#include "scheduler/task/Task.h" + +namespace milvus { +namespace scheduler { + +class SSSearchTask : public Task { + public: + explicit SSSearchTask(const server::ContextPtr& context, const engine::DBOptions& options, + const query::QueryPtr& query_ptr, engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label); + + void + Load(LoadType type, uint8_t device_id) override; + + void + Execute() override; + + private: + void + CreateExecEngine(); + + public: + const std::shared_ptr context_; + + const engine::DBOptions& options_; + query::QueryPtr query_ptr_; + engine::snapshot::ID_TYPE segment_id_; + + engine::SSExecutionEnginePtr execution_engine_; +}; + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/task/SSTestTask.cpp b/core/src/scheduler/task/SSTestTask.cpp new file mode 100644 index 000000000000..0d4c197b8040 --- /dev/null +++ b/core/src/scheduler/task/SSTestTask.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "cache/GpuCacheMgr.h" +#include "scheduler/task/SSTestTask.h" + +namespace milvus { +namespace scheduler { + +SSTestTask::SSTestTask(const server::ContextPtr& context, const engine::DBOptions& options, + const query::QueryPtr& query_ptr, engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label) + : SSSearchTask(context, options, query_ptr, segment_id, std::move(label)) { +} + +void +SSTestTask::Load(LoadType type, uint8_t device_id) { + load_count_++; +} + +void +SSTestTask::Execute() { + { + std::lock_guard lock(mutex_); + exec_count_++; + done_ = true; + } + cv_.notify_one(); +} + +void +SSTestTask::Wait() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return done_; }); +} + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/task/SSTestTask.h b/core/src/scheduler/task/SSTestTask.h new file mode 100644 index 000000000000..6cddd3b762cb --- /dev/null +++ b/core/src/scheduler/task/SSTestTask.h @@ -0,0 +1,46 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include + +#include "SSSearchTask.h" + +namespace milvus { +namespace scheduler { + +class SSTestTask : public SSSearchTask { + public: + explicit SSTestTask(const server::ContextPtr& context, const engine::DBOptions& options, + const query::QueryPtr& query_ptr, engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label); + + public: + void + Load(LoadType type, uint8_t device_id) override; + + void + Execute() override; + + void + Wait(); + + public: + int64_t load_count_ = 0; + int64_t exec_count_ = 0; + + bool done_ = false; + std::mutex mutex_; + std::condition_variable cv_; +}; + +} // namespace scheduler +} // namespace milvus diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index d224f07f9230..e6d70f622d0c 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -103,9 +103,8 @@ XSearchTask::XSearchTask(const std::shared_ptr& context, Segmen : Task(TaskType::SearchTask, std::move(label)), context_(context), file_(file) { if (file_) { // distance -- value 0 means two vectors equal, ascending reduce, L2/HAMMING/JACCARD/TONIMOTO ... - // similarity -- infinity value means two vectors equal, descending reduce, IP - if (file_->metric_type_ == static_cast(MetricType::IP) && - file_->engine_type_ != static_cast(EngineType::FAISS_PQ)) { + // similarity -- value 1 means two vectors equal, descending reduce, IP + if (file_->metric_type_ == static_cast(MetricType::IP)) { ascending_reduce = false; } @@ -235,11 +234,11 @@ XSearchTask::Execute() { } Status s; if (general_query != nullptr) { - std::unordered_map types; + std::unordered_map types; auto attr_type = search_job->attr_type(); auto type_it = attr_type.begin(); for (; type_it != attr_type.end(); type_it++) { - types.insert(std::make_pair(type_it->first, (engine::DataType)(type_it->second))); + types.insert(std::make_pair(type_it->first, (DataType)(type_it->second))); } auto query_ptr = search_job->query_ptr(); @@ -281,7 +280,7 @@ XSearchTask::Execute() { search_job->time_stat().reduce_time += span / 1000; } catch (std::exception& ex) { LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] SearchTask encounter exception: %s", "search", 0, ex.what()); - // search_job->IndexSearchDone(index_id_); //mark as done avoid dead lock, even search failed + search_job->GetStatus() = Status(SERVER_UNEXPECTED_ERROR, ex.what()); } /* step 4: notify to send result to client */ @@ -308,7 +307,6 @@ XSearchTask::MergeTopkToResultSet(const scheduler::ResultIds& src_ids, const sch scheduler::ResultIds buf_ids(nq * buf_k, -1); scheduler::ResultDistances buf_distances(nq * buf_k, 0.0); - for (uint64_t i = 0; i < nq; i++) { size_t buf_k_j = 0, src_k_j = 0, tar_k_j = 0; size_t buf_idx, src_idx, tar_idx; diff --git a/core/src/scheduler/task/Task.h b/core/src/scheduler/task/Task.h index b8cc1ef45ee2..28261cd9211c 100644 --- a/core/src/scheduler/task/Task.h +++ b/core/src/scheduler/task/Task.h @@ -24,17 +24,17 @@ namespace milvus { namespace scheduler { enum class LoadType { - DISK2CPU, - CPU2GPU, - GPU2CPU, - TEST, + DISK2CPU = 0, + CPU2GPU = 1, + GPU2CPU = 2, + TEST = 99, }; enum class TaskType { - SearchTask, - DeleteTask, - BuildIndexTask, - TestTask, + SearchTask = 0, + DeleteTask = 1, + BuildIndexTask = 2, + TestTask = 99, }; class Task; diff --git a/core/src/segment/SSSegmentReader.cpp b/core/src/segment/SSSegmentReader.cpp new file mode 100644 index 000000000000..fc3949ed2eec --- /dev/null +++ b/core/src/segment/SSSegmentReader.cpp @@ -0,0 +1,445 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "segment/SSSegmentReader.h" + +#include +#include +#include + +#include "Vectors.h" +#include "codecs/snapshot/SSCodec.h" +#include "db/Types.h" +#include "db/snapshot/ResourceHelper.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "storage/disk/DiskIOReader.h" +#include "storage/disk/DiskIOWriter.h" +#include "storage/disk/DiskOperation.h" +#include "utils/Log.h" + +namespace milvus { +namespace segment { + +SSSegmentReader::SSSegmentReader(const std::string& dir_root, const engine::SegmentVisitorPtr& segment_visitor) + : dir_root_(dir_root), segment_visitor_(segment_visitor) { + Initialize(); +} + +Status +SSSegmentReader::Initialize() { + std::string directory = + engine::snapshot::GetResPath(dir_root_, segment_visitor_->GetSegment()); + + storage::IOReaderPtr reader_ptr = std::make_shared(); + storage::IOWriterPtr writer_ptr = std::make_shared(); + storage::OperationPtr operation_ptr = std::make_shared(directory); + fs_ptr_ = std::make_shared(reader_ptr, writer_ptr, operation_ptr); + + segment_ptr_ = std::make_shared(); + + const engine::SegmentVisitor::IdMapT& field_map = segment_visitor_->GetFieldVisitors(); + for (auto& iter : field_map) { + const engine::snapshot::FieldPtr& field = iter.second->GetField(); + std::string name = field->GetName(); + engine::FIELD_TYPE ftype = static_cast(field->GetFtype()); + if (ftype == engine::FIELD_TYPE::VECTOR || ftype == engine::FIELD_TYPE::VECTOR_FLOAT || + ftype == engine::FIELD_TYPE::VECTOR_BINARY) { + json params = field->GetParams(); + if (params.find(knowhere::meta::DIM) == params.end()) { + std::string msg = "Vector field params must contain: dimension"; + LOG_SERVER_ERROR_ << msg; + return Status(DB_ERROR, msg); + } + + int64_t field_width = 0; + int64_t dimension = params[knowhere::meta::DIM]; + if (ftype == engine::FIELD_TYPE::VECTOR_BINARY) { + field_width += (dimension / 8); + } else { + field_width += (dimension * sizeof(float)); + } + segment_ptr_->AddField(name, ftype, field_width); + } else { + segment_ptr_->AddField(name, ftype); + } + } + + return Status::OK(); +} + +Status +SSSegmentReader::Load() { + STATUS_CHECK(LoadFields()); + + segment::IdBloomFilterPtr id_bloom_filter_ptr; + STATUS_CHECK(LoadBloomFilter(id_bloom_filter_ptr)); + + segment::DeletedDocsPtr deleted_docs_ptr; + STATUS_CHECK(LoadDeletedDocs(deleted_docs_ptr)); + + STATUS_CHECK(LoadVectorIndice()); + + return Status::OK(); +} + +Status +SSSegmentReader::LoadField(const std::string& field_name, std::vector& raw) { + try { + engine::FIXEDX_FIELD_MAP& field_map = segment_ptr_->GetFixedFields(); + auto pair = field_map.find(field_name); + if (pair != field_map.end()) { + raw = pair->second; + return Status::OK(); // alread exist + } + + auto field_visitor = segment_visitor_->GetFieldVisitor(field_name); + auto raw_visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_RAW); + std::string file_path = + engine::snapshot::GetResPath(dir_root_, raw_visitor->GetFile()); + + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetBlockFormat()->Read(fs_ptr_, file_path, raw); + + field_map.insert(std::make_pair(field_name, raw)); + } catch (std::exception& e) { + std::string err_msg = "Failed to load raw vectors: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + return Status::OK(); +} + +Status +SSSegmentReader::LoadFields() { + auto& field_visitors_map = segment_visitor_->GetFieldVisitors(); + for (auto& iter : field_visitors_map) { + const engine::snapshot::FieldPtr& field = iter.second->GetField(); + std::string name = field->GetName(); + engine::FIXED_FIELD_DATA raw_data; + auto status = segment_ptr_->GetFixedFieldData(name, raw_data); + + if (!status.ok() || raw_data.empty()) { + auto element_visitor = iter.second->GetElementVisitor(engine::FieldElementType::FET_RAW); + std::string file_path = + engine::snapshot::GetResPath(dir_root_, element_visitor->GetFile()); + STATUS_CHECK(LoadField(file_path, raw_data)); + } + } + + return Status::OK(); +} + +Status +SSSegmentReader::LoadEntities(const std::string& field_name, const std::vector& offsets, + std::vector& raw) { + try { + auto field_visitor = segment_visitor_->GetFieldVisitor(field_name); + auto raw_visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_RAW); + std::string file_path = + engine::snapshot::GetResPath(dir_root_, raw_visitor->GetFile()); + + int64_t field_width = 0; + segment_ptr_->GetFixedFieldWidth(field_name, field_width); + if (field_width <= 0) { + return Status(DB_ERROR, "Invalid field width"); + } + + codec::ReadRanges ranges; + for (auto offset : offsets) { + ranges.push_back(codec::ReadRange(offset, field_width)); + } + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetBlockFormat()->Read(fs_ptr_, file_path, ranges, raw); + } catch (std::exception& e) { + std::string err_msg = "Failed to load raw vectors: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + + return Status::OK(); +} + +Status +SSSegmentReader::LoadFieldsEntities(const std::vector& fields_name, const std::vector& offsets, + engine::DataChunkPtr& data_chunk) { + data_chunk = std::make_shared(); + data_chunk->count_ = offsets.size(); + for (auto& name : fields_name) { + engine::FIXED_FIELD_DATA raw_data; + auto status = LoadEntities(name, offsets, raw_data); + if (!status.ok()) { + return status; + } + + data_chunk->fixed_fields_[name] = raw_data; + } + + return Status::OK(); +} + +Status +SSSegmentReader::LoadUids(std::vector& uids) { + std::vector raw; + auto status = LoadField(engine::DEFAULT_UID_NAME, raw); + if (!status.ok()) { + LOG_ENGINE_ERROR_ << status.message(); + return status; + } + + if (raw.size() % sizeof(int64_t) != 0) { + std::string err_msg = "Failed to load uids: illegal file size"; + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + + uids.clear(); + uids.resize(raw.size() / sizeof(int64_t)); + memcpy(uids.data(), raw.data(), raw.size()); + + return Status::OK(); +} + +Status +SSSegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndexPtr& index_ptr) { + try { + segment_ptr_->GetVectorIndex(field_name, index_ptr); + if (index_ptr != nullptr) { + return Status::OK(); // already exist + } + + auto& ss_codec = codec::SSCodec::instance(); + auto field_visitor = segment_visitor_->GetFieldVisitor(field_name); + knowhere::BinarySet index_data; + knowhere::BinaryPtr raw_data, compress_data; + + // if index file doesn't exist, return null + auto index_visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (index_visitor == nullptr || index_visitor->GetFile() == nullptr) { + return Status(DB_ERROR, "index not available"); + } + + // read index file + std::string file_path = + engine::snapshot::GetResPath(dir_root_, index_visitor->GetFile()); + ss_codec.GetVectorIndexFormat()->ReadIndex(fs_ptr_, file_path, index_data); + + auto index_type = knowhere::StrToOldIndexType(index_visitor->GetElement()->GetName()); + + // for some kinds index(IVF), read raw file + if (index_type == (int32_t)engine::EngineType::FAISS_IVFFLAT || + index_type == (int32_t)engine::EngineType::NSG_MIX || index_type == (int32_t)engine::EngineType::HNSW) { + engine::FIXED_FIELD_DATA fixed_data; + auto status = segment_ptr_->GetFixedFieldData(field_name, fixed_data); + if (status.ok()) { + ss_codec.GetVectorIndexFormat()->ConvertRaw(fixed_data, raw_data); + } else if (auto visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_RAW)) { + file_path = engine::snapshot::GetResPath(dir_root_, visitor->GetFile()); + ss_codec.GetVectorIndexFormat()->ReadRaw(fs_ptr_, file_path, raw_data); + } + } + + // for some kinds index(SQ8), read compress file + if (index_type == (int32_t)engine::EngineType::FAISS_IVFSQ8NR || + index_type == (int32_t)engine::EngineType::HNSW_SQ8NM) { + if (auto visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_COMPRESS_SQ8)) { + file_path = engine::snapshot::GetResPath(dir_root_, visitor->GetFile()); + ss_codec.GetVectorIndexFormat()->ReadCompress(fs_ptr_, file_path, compress_data); + } + } + + std::string index_name = index_visitor->GetElement()->GetName(); + ss_codec.GetVectorIndexFormat()->ConstructIndex(index_name, index_data, raw_data, compress_data, index_ptr); + + segment_ptr_->SetVectorIndex(field_name, index_ptr); + } catch (std::exception& e) { + std::string err_msg = "Failed to load vector index: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + + return Status::OK(); +} + +Status +SSSegmentReader::LoadStructuredIndex(const std::string& field_name, knowhere::IndexPtr& index_ptr) { + try { + segment_ptr_->GetStructuredIndex(field_name, index_ptr); + if (index_ptr != nullptr) { + return Status::OK(); // already exist + } + + auto& ss_codec = codec::SSCodec::instance(); + auto field_visitor = segment_visitor_->GetFieldVisitor(field_name); + + auto index_visitor = field_visitor->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (index_visitor) { + std::string file_path = + engine::snapshot::GetResPath(dir_root_, index_visitor->GetFile()); + ss_codec.GetStructuredIndexFormat()->Read(fs_ptr_, file_path, index_ptr); + + segment_ptr_->SetStructuredIndex(field_name, index_ptr); + } + } catch (std::exception& e) { + std::string err_msg = "Failed to load vector index: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + + return Status::OK(); +} + +Status +SSSegmentReader::LoadVectorIndice() { + auto& field_visitors_map = segment_visitor_->GetFieldVisitors(); + for (auto& iter : field_visitors_map) { + const engine::snapshot::FieldPtr& field = iter.second->GetField(); + std::string name = field->GetName(); + + auto element_visitor = iter.second->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (element_visitor == nullptr) { + continue; + } + + std::string file_path = + engine::snapshot::GetResPath(dir_root_, element_visitor->GetFile()); + if (field->GetFtype() == engine::FIELD_TYPE::VECTOR || field->GetFtype() == engine::FIELD_TYPE::VECTOR_FLOAT || + field->GetFtype() == engine::FIELD_TYPE::VECTOR_BINARY) { + knowhere::VecIndexPtr index_ptr; + STATUS_CHECK(LoadVectorIndex(name, index_ptr)); + } else { + knowhere::IndexPtr index_ptr; + STATUS_CHECK(LoadStructuredIndex(name, index_ptr)); + } + } + + return Status::OK(); +} + +Status +SSSegmentReader::LoadBloomFilter(segment::IdBloomFilterPtr& id_bloom_filter_ptr) { + try { + id_bloom_filter_ptr = segment_ptr_->GetBloomFilter(); + if (id_bloom_filter_ptr != nullptr) { + return Status::OK(); // already exist + } + + auto uid_field_visitor = segment_visitor_->GetFieldVisitor(engine::DEFAULT_UID_NAME); + auto visitor = uid_field_visitor->GetElementVisitor(engine::FieldElementType::FET_BLOOM_FILTER); + std::string file_path = + engine::snapshot::GetResPath(dir_root_, visitor->GetFile()); + + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetIdBloomFilterFormat()->Read(fs_ptr_, file_path, id_bloom_filter_ptr); + + if (id_bloom_filter_ptr) { + segment_ptr_->SetBloomFilter(id_bloom_filter_ptr); + } + } catch (std::exception& e) { + std::string err_msg = "Failed to load bloom filter: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + return Status::OK(); +} + +Status +SSSegmentReader::LoadDeletedDocs(segment::DeletedDocsPtr& deleted_docs_ptr) { + try { + deleted_docs_ptr = segment_ptr_->GetDeletedDocs(); + if (deleted_docs_ptr != nullptr) { + return Status::OK(); // already exist + } + + auto uid_field_visitor = segment_visitor_->GetFieldVisitor(engine::DEFAULT_UID_NAME); + auto visitor = uid_field_visitor->GetElementVisitor(engine::FieldElementType::FET_DELETED_DOCS); + std::string file_path = + engine::snapshot::GetResPath(dir_root_, visitor->GetFile()); + if (!boost::filesystem::exists(file_path)) { + return Status::OK(); // file doesn't exist + } + + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetDeletedDocsFormat()->Read(fs_ptr_, file_path, deleted_docs_ptr); + + if (deleted_docs_ptr) { + segment_ptr_->SetDeletedDocs(deleted_docs_ptr); + } + } catch (std::exception& e) { + std::string err_msg = "Failed to load deleted docs: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + return Status::OK(); +} + +Status +SSSegmentReader::ReadDeletedDocsSize(size_t& size) { + try { + size = 0; + auto deleted_docs_ptr = segment_ptr_->GetDeletedDocs(); + if (deleted_docs_ptr != nullptr) { + size = deleted_docs_ptr->GetSize(); + return Status::OK(); // already exist + } + + auto uid_field_visitor = segment_visitor_->GetFieldVisitor(engine::DEFAULT_UID_NAME); + auto visitor = uid_field_visitor->GetElementVisitor(engine::FieldElementType::FET_DELETED_DOCS); + std::string file_path = + engine::snapshot::GetResPath(dir_root_, visitor->GetFile()); + if (!boost::filesystem::exists(file_path)) { + return Status::OK(); // file doesn't exist + } + + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetDeletedDocsFormat()->ReadSize(fs_ptr_, file_path, size); + } catch (std::exception& e) { + std::string err_msg = "Failed to read deleted docs size: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + return Status(DB_ERROR, err_msg); + } + return Status::OK(); +} + +Status +SSSegmentReader::GetSegment(engine::SegmentPtr& segment_ptr) { + segment_ptr = segment_ptr_; + return Status::OK(); +} + +Status +SSSegmentReader::GetSegmentID(int64_t& id) { + if (segment_visitor_) { + auto segment = segment_visitor_->GetSegment(); + if (segment) { + id = segment->GetID(); + return Status::OK(); + } + } + + return Status(DB_ERROR, "SSSegmentWriter::GetSegmentID: null pointer"); +} + +std::string +SSSegmentReader::GetSegmentPath() { + std::string seg_path = + engine::snapshot::GetResPath(dir_root_, segment_visitor_->GetSegment()); + return seg_path; +} + +} // namespace segment +} // namespace milvus diff --git a/core/src/segment/SSSegmentReader.h b/core/src/segment/SSSegmentReader.h new file mode 100644 index 000000000000..d579c83997a9 --- /dev/null +++ b/core/src/segment/SSSegmentReader.h @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "codecs/Codec.h" +#include "db/SnapshotVisitor.h" +#include "segment/Segment.h" +#include "storage/FSHandler.h" +#include "utils/Status.h" + +namespace milvus { +namespace segment { + +class SSSegmentReader { + public: + explicit SSSegmentReader(const std::string& dir_root, const engine::SegmentVisitorPtr& segment_visitor); + + Status + Load(); + + Status + LoadField(const std::string& field_name, std::vector& raw); + + Status + LoadFields(); + + Status + LoadEntities(const std::string& field_name, const std::vector& offsets, std::vector& raw); + + Status + LoadFieldsEntities(const std::vector& fields_name, const std::vector& offsets, + engine::DataChunkPtr& data_chunk); + + Status + LoadUids(std::vector& uids); + + Status + LoadVectorIndex(const std::string& field_name, knowhere::VecIndexPtr& index_ptr); + + Status + LoadStructuredIndex(const std::string& field_name, knowhere::IndexPtr& index_ptr); + + Status + LoadVectorIndice(); + + Status + LoadBloomFilter(segment::IdBloomFilterPtr& id_bloom_filter_ptr); + + Status + LoadDeletedDocs(segment::DeletedDocsPtr& deleted_docs_ptr); + + Status + ReadDeletedDocsSize(size_t& size); + + Status + GetSegment(engine::SegmentPtr& segment_ptr); + + Status + GetSegmentID(int64_t& id); + + std::string + GetSegmentPath(); + + private: + Status + Initialize(); + + private: + engine::SegmentVisitorPtr segment_visitor_; + storage::FSHandlerPtr fs_ptr_; + engine::SegmentPtr segment_ptr_; + std::string dir_root_; +}; + +using SSSegmentReaderPtr = std::shared_ptr; + +} // namespace segment +} // namespace milvus diff --git a/core/src/segment/SSSegmentWriter.cpp b/core/src/segment/SSSegmentWriter.cpp new file mode 100644 index 000000000000..495e1c724c4b --- /dev/null +++ b/core/src/segment/SSSegmentWriter.cpp @@ -0,0 +1,437 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "segment/SSSegmentWriter.h" + +#include +#include +#include + +#include "SSSegmentReader.h" +#include "Vectors.h" +#include "codecs/snapshot/SSCodec.h" +#include "db/Utils.h" +#include "db/snapshot/ResourceHelper.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "storage/disk/DiskIOReader.h" +#include "storage/disk/DiskIOWriter.h" +#include "storage/disk/DiskOperation.h" +#include "utils/Log.h" +#include "utils/TimeRecorder.h" + +namespace milvus { +namespace segment { + +SSSegmentWriter::SSSegmentWriter(const std::string& dir_root, const engine::SegmentVisitorPtr& segment_visitor) + : dir_root_(dir_root), segment_visitor_(segment_visitor) { + Initialize(); +} + +Status +SSSegmentWriter::Initialize() { + std::string directory = + engine::snapshot::GetResPath(dir_root_, segment_visitor_->GetSegment()); + + storage::IOReaderPtr reader_ptr = std::make_shared(); + storage::IOWriterPtr writer_ptr = std::make_shared(); + storage::OperationPtr operation_ptr = std::make_shared(directory); + fs_ptr_ = std::make_shared(reader_ptr, writer_ptr, operation_ptr); + fs_ptr_->operation_ptr_->CreateDirectory(); + + segment_ptr_ = std::make_shared(); + + const engine::SegmentVisitor::IdMapT& field_map = segment_visitor_->GetFieldVisitors(); + for (auto& iter : field_map) { + const engine::snapshot::FieldPtr& field = iter.second->GetField(); + std::string name = field->GetName(); + engine::FIELD_TYPE ftype = static_cast(field->GetFtype()); + if (ftype == engine::FIELD_TYPE::VECTOR || ftype == engine::FIELD_TYPE::VECTOR_FLOAT || + ftype == engine::FIELD_TYPE::VECTOR_BINARY) { + json params = field->GetParams(); + if (params.find(knowhere::meta::DIM) == params.end()) { + std::string msg = "Vector field params must contain: dimension"; + LOG_SERVER_ERROR_ << msg; + return Status(DB_ERROR, msg); + } + + int64_t field_width = 0; + int64_t dimension = params[knowhere::meta::DIM]; + if (ftype == engine::FIELD_TYPE::VECTOR_BINARY) { + field_width += (dimension / 8); + } else { + field_width += (dimension * sizeof(float)); + } + segment_ptr_->AddField(name, ftype, field_width); + } else { + segment_ptr_->AddField(name, ftype); + } + } + + return Status::OK(); +} + +Status +SSSegmentWriter::AddChunk(const engine::DataChunkPtr& chunk_ptr) { + return segment_ptr_->AddChunk(chunk_ptr); +} + +Status +SSSegmentWriter::AddChunk(const engine::DataChunkPtr& chunk_ptr, int64_t from, int64_t to) { + return segment_ptr_->AddChunk(chunk_ptr, from, to); +} + +Status +SSSegmentWriter::Serialize() { + // write fields raw data + STATUS_CHECK(WriteFields()); + + // write empty UID's deleted docs + STATUS_CHECK(WriteDeletedDocs()); + + // write UID's bloom filter + STATUS_CHECK(WriteBloomFilter()); + + return Status::OK(); +} + +Status +SSSegmentWriter::WriteField(const std::string& file_path, const engine::FIXED_FIELD_DATA& raw) { + try { + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetBlockFormat()->Write(fs_ptr_, file_path, raw); + } catch (std::exception& e) { + std::string err_msg = "Failed to write field: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + + engine::utils::SendExitSignal(); + return Status(SERVER_WRITE_ERROR, err_msg); + } + return Status::OK(); +} + +Status +SSSegmentWriter::WriteFields() { + auto& field_visitors_map = segment_visitor_->GetFieldVisitors(); + for (auto& iter : field_visitors_map) { + const engine::snapshot::FieldPtr& field = iter.second->GetField(); + std::string name = field->GetName(); + engine::FIXED_FIELD_DATA raw_data; + segment_ptr_->GetFixedFieldData(name, raw_data); + + auto element_visitor = iter.second->GetElementVisitor(engine::FieldElementType::FET_RAW); + std::string file_path = + engine::snapshot::GetResPath(dir_root_, element_visitor->GetFile()); + STATUS_CHECK(WriteField(file_path, raw_data)); + } + + return Status::OK(); +} + +Status +SSSegmentWriter::WriteBloomFilter() { + try { + TimeRecorder recorder("SSSegmentWriter::WriteBloomFilter"); + + engine::FIXED_FIELD_DATA uid_data; + auto status = segment_ptr_->GetFixedFieldData(engine::DEFAULT_UID_NAME, uid_data); + if (!status.ok()) { + return status; + } + + auto& field_visitors_map = segment_visitor_->GetFieldVisitors(); + auto uid_field_visitor = segment_visitor_->GetFieldVisitor(engine::DEFAULT_UID_NAME); + auto uid_blf_visitor = uid_field_visitor->GetElementVisitor(engine::FieldElementType::FET_BLOOM_FILTER); + std::string uid_blf_path = + engine::snapshot::GetResPath(dir_root_, uid_blf_visitor->GetFile()); + + auto& ss_codec = codec::SSCodec::instance(); + segment::IdBloomFilterPtr bloom_filter_ptr; + ss_codec.GetIdBloomFilterFormat()->Create(fs_ptr_, uid_blf_path, bloom_filter_ptr); + + int64_t* uids = (int64_t*)(uid_data.data()); + int64_t row_count = segment_ptr_->GetRowCount(); + for (int64_t i = 0; i < row_count; i++) { + bloom_filter_ptr->Add(uids[i]); + } + segment_ptr_->SetBloomFilter(bloom_filter_ptr); + + recorder.RecordSection("Initialize bloom filter"); + + return WriteBloomFilter(uid_blf_path, segment_ptr_->GetBloomFilter()); + } catch (std::exception& e) { + std::string err_msg = "Failed to write vectors: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + + engine::utils::SendExitSignal(); + return Status(SERVER_WRITE_ERROR, err_msg); + } +} + +Status +SSSegmentWriter::WriteBloomFilter(const std::string& file_path, const IdBloomFilterPtr& id_bloom_filter_ptr) { + if (id_bloom_filter_ptr == nullptr) { + return Status(DB_ERROR, "WriteBloomFilter: null pointer"); + } + + try { + TimeRecorder recorder("SSSegmentWriter::WriteBloomFilter"); + + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetIdBloomFilterFormat()->Write(fs_ptr_, file_path, id_bloom_filter_ptr); + + recorder.RecordSection("Write bloom filter file"); + } catch (std::exception& e) { + std::string err_msg = "Failed to write bloom filter: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + + engine::utils::SendExitSignal(); + return Status(SERVER_WRITE_ERROR, err_msg); + } + return Status::OK(); +} + +Status +SSSegmentWriter::WriteDeletedDocs() { + auto& field_visitors_map = segment_visitor_->GetFieldVisitors(); + auto uid_field_visitor = segment_visitor_->GetFieldVisitor(engine::DEFAULT_UID_NAME); + auto del_doc_visitor = uid_field_visitor->GetElementVisitor(engine::FieldElementType::FET_DELETED_DOCS); + std::string file_path = + engine::snapshot::GetResPath(dir_root_, del_doc_visitor->GetFile()); + + return WriteDeletedDocs(file_path, segment_ptr_->GetDeletedDocs()); +} + +Status +SSSegmentWriter::WriteDeletedDocs(const std::string& file_path, const DeletedDocsPtr& deleted_docs) { + if (deleted_docs == nullptr) { + return Status::OK(); + } + + try { + TimeRecorderAuto recorder("SSSegmentWriter::WriteDeletedDocs"); + + auto& ss_codec = codec::SSCodec::instance(); + ss_codec.GetDeletedDocsFormat()->Write(fs_ptr_, file_path, deleted_docs); + } catch (std::exception& e) { + std::string err_msg = "Failed to write deleted docs: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + + engine::utils::SendExitSignal(); + return Status(SERVER_WRITE_ERROR, err_msg); + } + return Status::OK(); +} + +Status +SSSegmentWriter::Merge(const SSSegmentReaderPtr& segment_reader) { + if (segment_reader == nullptr) { + return Status(DB_ERROR, "Segment reader is null"); + } + + // check conflict + int64_t src_id, target_id; + auto status = GetSegmentID(target_id); + if (!status.ok()) { + return status; + } + status = segment_reader->GetSegmentID(src_id); + if (!status.ok()) { + return status; + } + if (src_id == target_id) { + return Status(DB_ERROR, "Cannot Merge Self"); + } + + LOG_ENGINE_DEBUG_ << "Merging from " << segment_reader->GetSegmentPath() << " to " << GetSegmentPath(); + + TimeRecorder recorder("SSSegmentWriter::Merge"); + + // merge deleted docs (Note: this step must before merge raw data) + segment::DeletedDocsPtr src_deleted_docs; + status = segment_reader->LoadDeletedDocs(src_deleted_docs); + if (!status.ok()) { + return status; + } + + engine::SegmentPtr src_segment; + status = segment_reader->GetSegment(src_segment); + if (!status.ok()) { + return status; + } + + if (src_deleted_docs) { + const std::vector& delete_ids = src_deleted_docs->GetDeletedDocs(); + for (auto offset : delete_ids) { + src_segment->DeleteEntity(offset); + } + } + + // merge filed raw data + engine::DataChunkPtr chunk = std::make_shared(); + auto& field_visitors_map = segment_visitor_->GetFieldVisitors(); + for (auto& iter : field_visitors_map) { + const engine::snapshot::FieldPtr& field = iter.second->GetField(); + std::string name = field->GetName(); + engine::FIXED_FIELD_DATA raw_data; + segment_reader->LoadField(name, raw_data); + chunk->fixed_fields_[name] = raw_data; + } + + auto& uid_data = chunk->fixed_fields_[engine::DEFAULT_UID_NAME]; + chunk->count_ = uid_data.size() / sizeof(int64_t); + status = AddChunk(chunk); + if (!status.ok()) { + return status; + } + + // Note: no need to merge bloom filter, the bloom filter will be created during serialize + + return Status::OK(); +} + +size_t +SSSegmentWriter::Size() { + return 0; +} + +size_t +SSSegmentWriter::RowCount() { + return segment_ptr_->GetRowCount(); +} + +Status +SSSegmentWriter::SetVectorIndex(const std::string& field_name, const milvus::knowhere::VecIndexPtr& index) { + return segment_ptr_->SetVectorIndex(field_name, index); +} + +Status +SSSegmentWriter::WriteVectorIndex(const std::string& field_name) { + try { + knowhere::VecIndexPtr index; + auto status = segment_ptr_->GetVectorIndex(field_name, index); + if (!status.ok() || index == nullptr) { + return Status(DB_ERROR, "Index doesn't exist: " + status.message()); + } + + auto& field_visitors_map = segment_visitor_->GetFieldVisitors(); + auto field = segment_visitor_->GetFieldVisitor(field_name); + if (field == nullptr) { + return Status(DB_ERROR, "Invalid filed name: " + field_name); + } + + auto element_visitor = field->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (element_visitor == nullptr) { + return Status(DB_ERROR, "Invalid filed name: " + field_name); + } + + auto& ss_codec = codec::SSCodec::instance(); + fs_ptr_->operation_ptr_->CreateDirectory(); + + std::string file_path = + engine::snapshot::GetResPath(dir_root_, element_visitor->GetFile()); + ss_codec.GetVectorIndexFormat()->WriteIndex(fs_ptr_, file_path, index); + + element_visitor = field->GetElementVisitor(engine::FieldElementType::FET_COMPRESS_SQ8); + if (element_visitor != nullptr) { + file_path = + engine::snapshot::GetResPath(dir_root_, element_visitor->GetFile()); + ss_codec.GetVectorIndexFormat()->WriteCompress(fs_ptr_, file_path, index); + } + } catch (std::exception& e) { + std::string err_msg = "Failed to write vector index: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + + engine::utils::SendExitSignal(); + return Status(SERVER_WRITE_ERROR, err_msg); + } + + return Status::OK(); +} + +Status +SSSegmentWriter::SetStructuredIndex(const std::string& field_name, const knowhere::IndexPtr& index) { + return segment_ptr_->SetStructuredIndex(field_name, index); +} + +Status +SSSegmentWriter::WriteStructuredIndex(const std::string& field_name) { + try { + knowhere::IndexPtr index; + auto status = segment_ptr_->GetStructuredIndex(field_name, index); + if (!status.ok() || index == nullptr) { + return Status(DB_ERROR, "Structured index doesn't exist: " + status.message()); + } + + auto& field_visitors_map = segment_visitor_->GetFieldVisitors(); + auto field = segment_visitor_->GetFieldVisitor(field_name); + if (field == nullptr) { + return Status(DB_ERROR, "Invalid filed name: " + field_name); + } + + auto element_visitor = field->GetElementVisitor(engine::FieldElementType::FET_INDEX); + if (element_visitor == nullptr) { + return Status(DB_ERROR, "Invalid filed name: " + field_name); + } + + auto& ss_codec = codec::SSCodec::instance(); + fs_ptr_->operation_ptr_->CreateDirectory(); + + engine::FIELD_TYPE field_type; + segment_ptr_->GetFieldType(field_name, field_type); + + std::string file_path = + engine::snapshot::GetResPath(dir_root_, element_visitor->GetFile()); + ss_codec.GetStructuredIndexFormat()->Write(fs_ptr_, file_path, field_type, index); + } catch (std::exception& e) { + std::string err_msg = "Failed to write vector index: " + std::string(e.what()); + LOG_ENGINE_ERROR_ << err_msg; + + engine::utils::SendExitSignal(); + return Status(SERVER_WRITE_ERROR, err_msg); + } + + return Status::OK(); +} + +Status +SSSegmentWriter::GetSegment(engine::SegmentPtr& segment_ptr) { + segment_ptr = segment_ptr_; + return Status::OK(); +} + +Status +SSSegmentWriter::GetSegmentID(int64_t& id) { + if (segment_visitor_) { + auto segment = segment_visitor_->GetSegment(); + if (segment) { + id = segment->GetID(); + return Status::OK(); + } + } + + return Status(DB_ERROR, "SSSegmentWriter::GetSegmentID: null pointer"); +} + +std::string +SSSegmentWriter::GetSegmentPath() { + std::string seg_path = + engine::snapshot::GetResPath(dir_root_, segment_visitor_->GetSegment()); + return seg_path; +} + +} // namespace segment +} // namespace milvus diff --git a/core/src/segment/SSSegmentWriter.h b/core/src/segment/SSSegmentWriter.h new file mode 100644 index 000000000000..61c5833794c2 --- /dev/null +++ b/core/src/segment/SSSegmentWriter.h @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "db/SnapshotVisitor.h" +#include "segment/SSSegmentReader.h" +#include "segment/Segment.h" +#include "storage/FSHandler.h" +#include "utils/Status.h" + +namespace milvus { +namespace segment { + +class SSSegmentWriter { + public: + explicit SSSegmentWriter(const std::string& dir_root, const engine::SegmentVisitorPtr& segment_visitor); + + Status + AddChunk(const engine::DataChunkPtr& chunk_ptr); + + Status + AddChunk(const engine::DataChunkPtr& chunk_ptr, int64_t from, int64_t to); + + Status + WriteBloomFilter(const std::string& file_path, const IdBloomFilterPtr& bloom_filter_ptr); + + Status + WriteDeletedDocs(const std::string& file_path, const DeletedDocsPtr& deleted_docs); + + Status + Serialize(); + + Status + Merge(const SSSegmentReaderPtr& segment_reader); + + size_t + Size(); + + size_t + RowCount(); + + Status + SetVectorIndex(const std::string& field_name, const knowhere::VecIndexPtr& index); + + Status + WriteVectorIndex(const std::string& field_name); + + Status + SetStructuredIndex(const std::string& field_name, const knowhere::IndexPtr& index); + + Status + WriteStructuredIndex(const std::string& field_name); + + Status + GetSegment(engine::SegmentPtr& segment_ptr); + + Status + GetSegmentID(int64_t& id); + + std::string + GetSegmentPath(); + + private: + Status + Initialize(); + + Status + WriteField(const std::string& file_path, const engine::FIXED_FIELD_DATA& raw); + + Status + WriteFields(); + + Status + WriteBloomFilter(); + + Status + WriteDeletedDocs(); + + private: + engine::SegmentVisitorPtr segment_visitor_; + storage::FSHandlerPtr fs_ptr_; + engine::SegmentPtr segment_ptr_; + std::string dir_root_; +}; + +using SSSegmentWriterPtr = std::shared_ptr; + +} // namespace segment +} // namespace milvus diff --git a/core/src/segment/Segment.cpp b/core/src/segment/Segment.cpp new file mode 100644 index 000000000000..661fc0c1b61c --- /dev/null +++ b/core/src/segment/Segment.cpp @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "segment/Segment.h" +#include "utils/Log.h" + +#include + +namespace milvus { +namespace engine { + +Status +Segment::AddField(const std::string& field_name, FIELD_TYPE field_type, int64_t field_width) { + if (field_types_.find(field_name) != field_types_.end()) { + return Status(DB_ERROR, "duplicate field: " + field_name); + } + + int64_t real_field_width = 0; + switch (field_type) { + case FIELD_TYPE::BOOL: + real_field_width = sizeof(bool); + break; + case FIELD_TYPE::DOUBLE: + real_field_width = sizeof(double); + break; + case FIELD_TYPE::FLOAT: + real_field_width = sizeof(float); + break; + case FIELD_TYPE::INT8: + real_field_width = sizeof(uint8_t); + break; + case FIELD_TYPE::INT16: + real_field_width = sizeof(uint16_t); + break; + case FIELD_TYPE::INT32: + real_field_width = sizeof(uint32_t); + break; + case FIELD_TYPE::UID: + case FIELD_TYPE::INT64: + real_field_width = sizeof(uint64_t); + break; + case FIELD_TYPE::VECTOR: + case FIELD_TYPE::VECTOR_FLOAT: + case FIELD_TYPE::VECTOR_BINARY: { + if (field_width <= 0) { + std::string msg = "vecor field dimension required: " + field_name; + LOG_SERVER_ERROR_ << msg; + return Status(DB_ERROR, msg); + } + + real_field_width = field_width; + break; + } + } + + field_types_.insert(std::make_pair(field_name, field_type)); + fixed_fields_width_.insert(std::make_pair(field_name, real_field_width)); + + return Status::OK(); +} + +Status +Segment::AddChunk(const DataChunkPtr& chunk_ptr) { + if (chunk_ptr == nullptr || chunk_ptr->count_ == 0) { + return Status(DB_ERROR, "invalid input"); + } + + return AddChunk(chunk_ptr, 0, chunk_ptr->count_); +} + +Status +Segment::AddChunk(const DataChunkPtr& chunk_ptr, int64_t from, int64_t to) { + if (chunk_ptr == nullptr || from < 0 || to < 0 || from > chunk_ptr->count_ || to > chunk_ptr->count_ || + from >= to) { + return Status(DB_ERROR, "invalid input"); + } + + // check input + for (auto& iter : chunk_ptr->fixed_fields_) { + auto width_iter = fixed_fields_width_.find(iter.first); + if (width_iter == fixed_fields_width_.end()) { + return Status(DB_ERROR, "field not yet defined: " + iter.first); + } + + if (iter.second.size() != width_iter->second * chunk_ptr->count_) { + return Status(DB_ERROR, "illegal field: " + iter.first); + } + } + + // consume + int64_t add_count = to - from; + for (auto& width_iter : fixed_fields_width_) { + auto input = chunk_ptr->fixed_fields_.find(width_iter.first); + auto& data = fixed_fields_[width_iter.first]; + size_t origin_bytes = data.size(); + int64_t add_bytes = add_count * width_iter.second; + int64_t previous_bytes = row_count_ * width_iter.second; + int64_t target_bytes = previous_bytes + add_bytes; + data.resize(target_bytes); + if (input == chunk_ptr->fixed_fields_.end()) { + // this field is not provided, complicate by 0 + memset(data.data() + origin_bytes, 0, target_bytes - origin_bytes); + } else { + // complicate by 0 + if (origin_bytes < previous_bytes) { + memset(data.data() + origin_bytes, 0, previous_bytes - origin_bytes); + } + // copy input into this field + memcpy(data.data() + previous_bytes, input->second.data() + from * width_iter.second, add_bytes); + } + } + + row_count_ += add_count; + + return Status::OK(); +} + +Status +Segment::DeleteEntity(int64_t offset) { + for (auto& pair : fixed_fields_) { + int64_t width = fixed_fields_width_[pair.first]; + if (width != 0) { + auto step = offset * width; + FIXED_FIELD_DATA& data = pair.second; + data.erase(data.begin() + step, data.begin() + step + width); + } + } + + return Status::OK(); +} + +Status +Segment::GetFieldType(const std::string& field_name, FIELD_TYPE& type) { + auto iter = field_types_.find(field_name); + if (iter == field_types_.end()) { + return Status(DB_ERROR, "invalid field name: " + field_name); + } + + type = iter->second; + return Status::OK(); +} + +Status +Segment::GetFixedFieldWidth(const std::string& field_name, int64_t& width) { + auto iter = fixed_fields_width_.find(field_name); + if (iter == fixed_fields_width_.end()) { + return Status(DB_ERROR, "invalid field name: " + field_name); + } + + width = iter->second; + return Status::OK(); +} + +Status +Segment::GetFixedFieldData(const std::string& field_name, FIXED_FIELD_DATA& data) { + auto iter = fixed_fields_.find(field_name); + if (iter == fixed_fields_.end()) { + return Status(DB_ERROR, "invalid field name: " + field_name); + } + + data = iter->second; + return Status::OK(); +} + +Status +Segment::GetVectorIndex(const std::string& field_name, knowhere::VecIndexPtr& index) { + index = nullptr; + auto iter = vector_indice_.find(field_name); + if (iter == vector_indice_.end()) { + return Status(DB_ERROR, "invalid field name: " + field_name); + } + + index = iter->second; + return Status::OK(); +} + +Status +Segment::SetVectorIndex(const std::string& field_name, const knowhere::VecIndexPtr& index) { + vector_indice_[field_name] = index; + return Status::OK(); +} + +Status +Segment::GetStructuredIndex(const std::string& field_name, knowhere::IndexPtr& index) { + index = nullptr; + auto iter = structured_indice_.find(field_name); + if (iter == structured_indice_.end()) { + return Status(DB_ERROR, "invalid field name: " + field_name); + } + + index = iter->second; + return Status::OK(); +} + +Status +Segment::SetStructuredIndex(const std::string& field_name, const knowhere::IndexPtr& index) { + structured_indice_[field_name] = index; + return Status::OK(); +} + +} // namespace engine +} // namespace milvus diff --git a/core/src/segment/Segment.h b/core/src/segment/Segment.h new file mode 100644 index 000000000000..45ef7679657b --- /dev/null +++ b/core/src/segment/Segment.h @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "db/Types.h" +#include "db/meta/MetaTypes.h" +#include "knowhere/index/vector_index/VecIndex.h" +#include "segment/DeletedDocs.h" +#include "segment/IdBloomFilter.h" + +namespace milvus { +namespace engine { + +using FIELD_TYPE = engine::meta::hybrid::DataType; +using FIELD_TYPE_MAP = std::unordered_map; +using FIELD_WIDTH_MAP = std::unordered_map; +using FIXED_FIELD_DATA = std::vector; +using FIXEDX_FIELD_MAP = std::unordered_map; +using VARIABLE_FIELD_DATA = std::vector; +using VARIABLE_FIELD_MAP = std::unordered_map; +using VECTOR_INDEX_MAP = std::unordered_map; +using STRUCTURED_INDEX_MAP = std::unordered_map; + +struct DataChunk { + int64_t count_ = 0; + FIXEDX_FIELD_MAP fixed_fields_; + VARIABLE_FIELD_MAP variable_fields_; +}; + +using DataChunkPtr = std::shared_ptr; + +class Segment { + public: + Status + AddField(const std::string& field_name, FIELD_TYPE field_type, int64_t field_width = 0); + + Status + AddChunk(const DataChunkPtr& chunk_ptr); + + Status + AddChunk(const DataChunkPtr& chunk_ptr, int64_t from, int64_t to); + + Status + DeleteEntity(int64_t offset); + + Status + GetFieldType(const std::string& field_name, FIELD_TYPE& type); + + Status + GetFixedFieldWidth(const std::string& field_name, int64_t& width); + + Status + GetFixedFieldData(const std::string& field_name, FIXED_FIELD_DATA& data); + + Status + GetVectorIndex(const std::string& field_name, knowhere::VecIndexPtr& index); + + Status + SetVectorIndex(const std::string& field_name, const knowhere::VecIndexPtr& index); + + Status + GetStructuredIndex(const std::string& field_name, knowhere::IndexPtr& index); + + Status + SetStructuredIndex(const std::string& field_name, const knowhere::IndexPtr& index); + + FIELD_TYPE_MAP& + GetFieldTypes() { + return field_types_; + } + FIXEDX_FIELD_MAP& + GetFixedFields() { + return fixed_fields_; + } + VARIABLE_FIELD_MAP& + GetVariableFields() { + return variable_fields_; + } + VECTOR_INDEX_MAP& + GetVectorIndice() { + return vector_indice_; + } + + STRUCTURED_INDEX_MAP& + GetStructuredIndice() { + return structured_indice_; + } + + int64_t + GetRowCount() const { + return row_count_; + } + + segment::DeletedDocsPtr + GetDeletedDocs() const { + return deleted_docs_ptr_; + } + + void + SetDeletedDocs(const segment::DeletedDocsPtr& ptr) { + deleted_docs_ptr_ = ptr; + } + + segment::IdBloomFilterPtr + GetBloomFilter() const { + return id_bloom_filter_ptr_; + } + + void + SetBloomFilter(const segment::IdBloomFilterPtr& ptr) { + id_bloom_filter_ptr_ = ptr; + } + + private: + FIELD_TYPE_MAP field_types_; + FIELD_WIDTH_MAP fixed_fields_width_; + FIXEDX_FIELD_MAP fixed_fields_; + VARIABLE_FIELD_MAP variable_fields_; + VECTOR_INDEX_MAP vector_indice_; + STRUCTURED_INDEX_MAP structured_indice_; + + int64_t row_count_ = 0; + + segment::DeletedDocsPtr deleted_docs_ptr_ = nullptr; + segment::IdBloomFilterPtr id_bloom_filter_ptr_ = nullptr; +}; + +using SegmentPtr = std::shared_ptr; + +} // namespace engine +} // namespace milvus diff --git a/core/src/segment/SegmentReader.cpp b/core/src/segment/SegmentReader.cpp index 62a1358b1791..a2bdbb7865bc 100644 --- a/core/src/segment/SegmentReader.cpp +++ b/core/src/segment/SegmentReader.cpp @@ -21,6 +21,7 @@ #include "Vectors.h" #include "codecs/default/DefaultCodec.h" +#include "knowhere/index/vector_index/VecIndex.h" #include "storage/disk/DiskIOReader.h" #include "storage/disk/DiskIOWriter.h" #include "storage/disk/DiskOperation.h" @@ -45,9 +46,8 @@ SegmentReader::LoadCache(bool& in_cache) { Status SegmentReader::Load() { - // TODO(zhiru) - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetVectorsFormat()->read(fs_ptr_, segment_ptr_->vectors_ptr_); default_codec.GetAttrsFormat()->read(fs_ptr_, segment_ptr_->attrs_ptr_); @@ -62,8 +62,8 @@ SegmentReader::Load() { Status SegmentReader::LoadVectors(off_t offset, size_t num_bytes, std::vector& raw_vectors) { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetVectorsFormat()->read_vectors(fs_ptr_, offset, num_bytes, raw_vectors); } catch (std::exception& e) { @@ -77,8 +77,8 @@ SegmentReader::LoadVectors(off_t offset, size_t num_bytes, std::vector& Status SegmentReader::LoadAttrs(const std::string& field_name, off_t offset, size_t num_bytes, std::vector& raw_attrs) { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetAttrsFormat()->read_attrs(fs_ptr_, field_name, offset, num_bytes, raw_attrs); } catch (std::exception& e) { @@ -91,8 +91,8 @@ SegmentReader::LoadAttrs(const std::string& field_name, off_t offset, size_t num Status SegmentReader::LoadUids(std::vector& uids) { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetVectorsFormat()->read_uids(fs_ptr_, uids); } catch (std::exception& e) { @@ -110,11 +110,12 @@ SegmentReader::GetSegment(SegmentPtr& segment_ptr) { } Status -SegmentReader::LoadVectorIndex(const std::string& location, segment::VectorIndexPtr& vector_index_ptr) { - codec::DefaultCodec default_codec; +SegmentReader::LoadVectorIndex(const std::string& location, codec::ExternalData external_data, + segment::VectorIndexPtr& vector_index_ptr) { try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); - default_codec.GetVectorIndexFormat()->read(fs_ptr_, location, vector_index_ptr); + default_codec.GetVectorIndexFormat()->read(fs_ptr_, location, external_data, vector_index_ptr); } catch (std::exception& e) { std::string err_msg = "Failed to load vector index: " + std::string(e.what()); LOG_ENGINE_ERROR_ << err_msg; @@ -125,8 +126,8 @@ SegmentReader::LoadVectorIndex(const std::string& location, segment::VectorIndex Status SegmentReader::LoadBloomFilter(segment::IdBloomFilterPtr& id_bloom_filter_ptr) { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetIdBloomFilterFormat()->read(fs_ptr_, id_bloom_filter_ptr); } catch (std::exception& e) { @@ -139,8 +140,8 @@ SegmentReader::LoadBloomFilter(segment::IdBloomFilterPtr& id_bloom_filter_ptr) { Status SegmentReader::LoadDeletedDocs(segment::DeletedDocsPtr& deleted_docs_ptr) { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetDeletedDocsFormat()->read(fs_ptr_, deleted_docs_ptr); } catch (std::exception& e) { @@ -153,8 +154,8 @@ SegmentReader::LoadDeletedDocs(segment::DeletedDocsPtr& deleted_docs_ptr) { Status SegmentReader::ReadDeletedDocsSize(size_t& size) { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetDeletedDocsFormat()->readSize(fs_ptr_, size); } catch (std::exception& e) { diff --git a/core/src/segment/SegmentReader.h b/core/src/segment/SegmentReader.h index a69b1f3685ca..151b17257d70 100644 --- a/core/src/segment/SegmentReader.h +++ b/core/src/segment/SegmentReader.h @@ -21,6 +21,7 @@ #include #include +#include "codecs/Codec.h" #include "segment/Types.h" #include "storage/FSHandler.h" #include "utils/Status.h" @@ -49,7 +50,8 @@ class SegmentReader { LoadUids(std::vector& uids); Status - LoadVectorIndex(const std::string& location, segment::VectorIndexPtr& vector_index_ptr); + LoadVectorIndex(const std::string& location, codec::ExternalData external_data, + segment::VectorIndexPtr& vector_index_ptr); Status LoadBloomFilter(segment::IdBloomFilterPtr& id_bloom_filter_ptr); diff --git a/core/src/segment/SegmentWriter.cpp b/core/src/segment/SegmentWriter.cpp index 0865e720be88..2fd61de9af98 100644 --- a/core/src/segment/SegmentWriter.cpp +++ b/core/src/segment/SegmentWriter.cpp @@ -150,8 +150,8 @@ SegmentWriter::Serialize() { Status SegmentWriter::WriteVectors() { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetVectorsFormat()->write(fs_ptr_, segment_ptr_->vectors_ptr_); } catch (std::exception& e) { @@ -166,8 +166,8 @@ SegmentWriter::WriteVectors() { Status SegmentWriter::WriteAttrs() { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetAttrsFormat()->write(fs_ptr_, segment_ptr_->attrs_ptr_); } catch (std::exception& e) { @@ -186,8 +186,8 @@ SegmentWriter::WriteVectorIndex(const std::string& location) { return Status(SERVER_WRITE_ERROR, "Invalid parameter of WriteVectorIndex"); } - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetVectorIndexFormat()->write(fs_ptr_, location, segment_ptr_->vector_index_ptr_); } catch (std::exception& e) { @@ -202,8 +202,8 @@ SegmentWriter::WriteVectorIndex(const std::string& location) { Status SegmentWriter::WriteAttrsIndex() { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetAttrsIndexFormat()->write(fs_ptr_, segment_ptr_->attrs_index_ptr_); } catch (std::exception& e) { @@ -218,8 +218,9 @@ SegmentWriter::WriteAttrsIndex() { Status SegmentWriter::WriteBloomFilter() { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); + fs_ptr_->operation_ptr_->CreateDirectory(); TimeRecorder recorder("SegmentWriter::WriteBloomFilter"); @@ -250,8 +251,8 @@ SegmentWriter::WriteBloomFilter() { Status SegmentWriter::WriteDeletedDocs() { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); DeletedDocsPtr deleted_docs_ptr = std::make_shared(); default_codec.GetDeletedDocsFormat()->write(fs_ptr_, deleted_docs_ptr); @@ -267,8 +268,8 @@ SegmentWriter::WriteDeletedDocs() { Status SegmentWriter::WriteDeletedDocs(const DeletedDocsPtr& deleted_docs) { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetDeletedDocsFormat()->write(fs_ptr_, deleted_docs); } catch (std::exception& e) { @@ -283,8 +284,8 @@ SegmentWriter::WriteDeletedDocs(const DeletedDocsPtr& deleted_docs) { Status SegmentWriter::WriteBloomFilter(const IdBloomFilterPtr& id_bloom_filter_ptr) { - codec::DefaultCodec default_codec; try { + auto& default_codec = codec::DefaultCodec::instance(); fs_ptr_->operation_ptr_->CreateDirectory(); default_codec.GetIdBloomFilterFormat()->write(fs_ptr_, id_bloom_filter_ptr); } catch (std::exception& e) { @@ -389,5 +390,10 @@ SegmentWriter::VectorCount() { return segment_ptr_->vectors_ptr_->GetCount(); } +void +SegmentWriter::SetSegmentName(const std::string& name) { + segment_ptr_->vectors_ptr_->SetName(name); +} + } // namespace segment } // namespace milvus diff --git a/core/src/segment/SegmentWriter.h b/core/src/segment/SegmentWriter.h index 508656a6c653..0bebfbe0f468 100644 --- a/core/src/segment/SegmentWriter.h +++ b/core/src/segment/SegmentWriter.h @@ -82,6 +82,9 @@ class SegmentWriter { Status WriteAttrsIndex(); + void + SetSegmentName(const std::string& name); + private: Status WriteVectors(); diff --git a/core/src/server/ValidationUtil.cpp b/core/src/server/ValidationUtil.cpp index a3a8bb6ee747..fe6d8f2a1f6f 100644 --- a/core/src/server/ValidationUtil.cpp +++ b/core/src/server/ValidationUtil.cpp @@ -264,6 +264,7 @@ ValidateIndexParams(const milvus::json& index_params, const engine::meta::Collec } case (int32_t)engine::EngineType::FAISS_IVFFLAT: case (int32_t)engine::EngineType::FAISS_IVFSQ8: + case (int32_t)engine::EngineType::FAISS_IVFSQ8NR: case (int32_t)engine::EngineType::FAISS_IVFSQ8H: case (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT: { auto status = CheckParameterRange(index_params, knowhere::IndexParams::nlist, 1, 999999); @@ -329,6 +330,7 @@ ValidateIndexParams(const milvus::json& index_params, const engine::meta::Collec } break; } + case (int32_t)engine::EngineType::HNSW_SQ8NM: case (int32_t)engine::EngineType::HNSW: { auto status = CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64); if (!status.ok()) { @@ -361,6 +363,7 @@ ValidateSearchParams(const milvus::json& search_params, const engine::meta::Coll } case (int32_t)engine::EngineType::FAISS_IVFFLAT: case (int32_t)engine::EngineType::FAISS_IVFSQ8: + case (int32_t)engine::EngineType::FAISS_IVFSQ8NR: case (int32_t)engine::EngineType::FAISS_IVFSQ8H: case (int32_t)engine::EngineType::FAISS_BIN_IVFFLAT: case (int32_t)engine::EngineType::FAISS_PQ: { @@ -377,6 +380,7 @@ ValidateSearchParams(const milvus::json& search_params, const engine::meta::Coll } break; } + case (int32_t)engine::EngineType::HNSW_SQ8NM: case (int32_t)engine::EngineType::HNSW: { auto status = CheckParameterRange(search_params, knowhere::IndexParams::ef, topk, 4096); if (!status.ok()) { diff --git a/core/src/server/delivery/hybrid_request/CreateHybridCollectionRequest.cpp b/core/src/server/delivery/hybrid_request/CreateHybridCollectionRequest.cpp index 972c89f373c9..14e46aa8c9f8 100644 --- a/core/src/server/delivery/hybrid_request/CreateHybridCollectionRequest.cpp +++ b/core/src/server/delivery/hybrid_request/CreateHybridCollectionRequest.cpp @@ -93,8 +93,8 @@ CreateHybridCollectionRequest::OnExecute() { if (!field_params_.at(field_name).empty()) { auto field_param = field_params_.at(field_name); schema.field_params_ = field_param; - if (field_type.second == engine::meta::hybrid::DataType::FLOAT_VECTOR || - field_type.second == engine::meta::hybrid::DataType::BINARY_VECTOR) { + if (field_type.second == engine::meta::hybrid::DataType::VECTOR_FLOAT || + field_type.second == engine::meta::hybrid::DataType::VECTOR_BINARY) { vector_param = milvus::json::parse(field_param); if (vector_param.contains("dimension")) { dimension = vector_param["dimension"].get(); diff --git a/core/src/server/delivery/hybrid_request/InsertEntityRequest.cpp b/core/src/server/delivery/hybrid_request/InsertEntityRequest.cpp index 6320a30cd923..dee8129c8118 100644 --- a/core/src/server/delivery/hybrid_request/InsertEntityRequest.cpp +++ b/core/src/server/delivery/hybrid_request/InsertEntityRequest.cpp @@ -111,13 +111,13 @@ InsertEntityRequest::OnExecute() { } for (const auto& schema : fields_schema.fields_schema_) { - if (schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::FLOAT_VECTOR && + if (schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::VECTOR_FLOAT && vector_datas_it->second.float_data_.empty()) { return Status{ SERVER_INVALID_ROWRECORD_ARRAY, "The vector field is defined as float vector. Make sure you have entered float vector records"}; } - if (schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::BINARY_VECTOR && + if (schema.field_type_ == (int32_t)engine::meta::hybrid::DataType::VECTOR_BINARY && vector_datas_it->second.binary_data_.empty()) { return Status{ SERVER_INVALID_ROWRECORD_ARRAY, diff --git a/core/src/server/delivery/request/BaseRequest.cpp b/core/src/server/delivery/request/BaseRequest.cpp index d915a0400610..5edcbfdb32eb 100644 --- a/core/src/server/delivery/request/BaseRequest.cpp +++ b/core/src/server/delivery/request/BaseRequest.cpp @@ -129,6 +129,7 @@ BaseRequest::OnPostExecute() { void BaseRequest::Done() { + std::unique_lock lock(finish_mtx_); done_ = true; finish_cond_.notify_all(); } diff --git a/core/src/server/delivery/request/BaseRequest.h b/core/src/server/delivery/request/BaseRequest.h index f74fa9756ad4..526d6fe1f458 100644 --- a/core/src/server/delivery/request/BaseRequest.h +++ b/core/src/server/delivery/request/BaseRequest.h @@ -225,15 +225,16 @@ class BaseRequest { protected: const std::shared_ptr context_; - mutable std::mutex finish_mtx_; - std::condition_variable finish_cond_; - RequestType type_; std::string request_group_; bool async_; - bool done_; Status status_; + private: + mutable std::mutex finish_mtx_; + std::condition_variable finish_cond_; + bool done_; + public: const std::shared_ptr& Context() const { diff --git a/core/src/server/delivery/request/DeleteByIDRequest.cpp b/core/src/server/delivery/request/DeleteByIDRequest.cpp index 56b1c55eb7ff..647fbe931fba 100644 --- a/core/src/server/delivery/request/DeleteByIDRequest.cpp +++ b/core/src/server/delivery/request/DeleteByIDRequest.cpp @@ -68,6 +68,7 @@ DeleteByIDRequest::OnExecute() { } // Check collection's index type supports delete +#ifdef MILVUS_SUPPORT_SPTAG if (collection_schema.engine_type_ == (int32_t)engine::EngineType::SPTAG_BKT || collection_schema.engine_type_ == (int32_t)engine::EngineType::SPTAG_KDT) { std::string err_msg = @@ -75,6 +76,7 @@ DeleteByIDRequest::OnExecute() { LOG_SERVER_ERROR_ << err_msg; return Status(SERVER_UNSUPPORTED_ERROR, err_msg); } +#endif rc.RecordSection("check validation"); diff --git a/core/src/server/delivery/request/SearchCombineRequest.cpp b/core/src/server/delivery/request/SearchCombineRequest.cpp index 6dca3dd96bf8..d53e3e639464 100644 --- a/core/src/server/delivery/request/SearchCombineRequest.cpp +++ b/core/src/server/delivery/request/SearchCombineRequest.cpp @@ -27,7 +27,6 @@ namespace server { namespace { constexpr int64_t MAX_TOPK_GAP = 200; -constexpr uint64_t MAX_NQ = 200; void GetUniqueList(const std::vector& list, std::set& unique_list) { @@ -93,7 +92,8 @@ class TracingContextList { } // namespace -SearchCombineRequest::SearchCombineRequest() : BaseRequest(nullptr, BaseRequest::kSearchCombine) { +SearchCombineRequest::SearchCombineRequest(int64_t max_nq) + : BaseRequest(nullptr, BaseRequest::kSearchCombine), combine_max_nq_(max_nq) { } Status @@ -133,6 +133,8 @@ SearchCombineRequest::Combine(const SearchRequestPtr& request) { } request_list_.push_back(request); + vectors_data_.vector_count_ += request->VectorsData().vector_count_; + return Status::OK(); } @@ -152,11 +154,11 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& request) { } // sum of nq must less-equal than MAX_NQ - if (vectors_data_.vector_count_ > MAX_NQ || request->VectorsData().vector_count_ > MAX_NQ) { + if (vectors_data_.vector_count_ > combine_max_nq_ || request->VectorsData().vector_count_ > combine_max_nq_) { return false; } uint64_t total_nq = vectors_data_.vector_count_ + request->VectorsData().vector_count_; - if (total_nq > MAX_NQ) { + if (total_nq > combine_max_nq_) { return false; } @@ -178,7 +180,7 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& request) { } bool -SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right) { +SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right, int64_t max_nq) { if (left->CollectionName() != right->CollectionName()) { return false; } @@ -193,11 +195,11 @@ SearchCombineRequest::CanCombine(const SearchRequestPtr& left, const SearchReque } // sum of nq must less-equal than MAX_NQ - if (left->VectorsData().vector_count_ > MAX_NQ || right->VectorsData().vector_count_ > MAX_NQ) { + if (left->VectorsData().vector_count_ > max_nq || right->VectorsData().vector_count_ > max_nq) { return false; } uint64_t total_nq = left->VectorsData().vector_count_ + right->VectorsData().vector_count_; - if (total_nq > MAX_NQ) { + if (total_nq > max_nq) { return false; } diff --git a/core/src/server/delivery/request/SearchCombineRequest.h b/core/src/server/delivery/request/SearchCombineRequest.h index 3aa24bb92880..d71c4c70c7f8 100644 --- a/core/src/server/delivery/request/SearchCombineRequest.h +++ b/core/src/server/delivery/request/SearchCombineRequest.h @@ -22,9 +22,11 @@ namespace milvus { namespace server { +constexpr int64_t COMBINE_MAX_NQ = 64; + class SearchCombineRequest : public BaseRequest { public: - SearchCombineRequest(); + explicit SearchCombineRequest(int64_t max_nq = COMBINE_MAX_NQ); Status Combine(const SearchRequestPtr& request); @@ -33,7 +35,7 @@ class SearchCombineRequest : public BaseRequest { CanCombine(const SearchRequestPtr& request); static bool - CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right); + CanCombine(const SearchRequestPtr& left, const SearchRequestPtr& right, int64_t max_nq = COMBINE_MAX_NQ); protected: Status @@ -54,6 +56,8 @@ class SearchCombineRequest : public BaseRequest { std::set file_id_list_; std::vector request_list_; + + int64_t combine_max_nq_ = COMBINE_MAX_NQ; }; using SearchCombineRequestPtr = std::shared_ptr; diff --git a/core/src/server/delivery/strategy/SearchReqStrategy.cpp b/core/src/server/delivery/strategy/SearchReqStrategy.cpp index 3b49ed6964ac..0b66ca7b5725 100644 --- a/core/src/server/delivery/strategy/SearchReqStrategy.cpp +++ b/core/src/server/delivery/strategy/SearchReqStrategy.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "server/delivery/strategy/SearchReqStrategy.h" +#include "config/Config.h" #include "server/delivery/request/SearchCombineRequest.h" #include "server/delivery/request/SearchRequest.h" #include "utils/CommonUtil.h" @@ -24,6 +25,8 @@ namespace milvus { namespace server { SearchReqStrategy::SearchReqStrategy() { + SetIdentity("SearchReqStrategy"); + AddSearchCombineMaxNqListener(); } Status @@ -34,15 +37,21 @@ SearchReqStrategy::ReScheduleQueue(const BaseRequestPtr& request, std::queue(request); BaseRequestPtr last_req = queue.back(); if (last_req->GetRequestType() == BaseRequest::kSearch) { SearchRequestPtr last_search_req = std::static_pointer_cast(last_req); - if (SearchCombineRequest::CanCombine(last_search_req, new_search_req)) { + if (SearchCombineRequest::CanCombine(last_search_req, new_search_req, search_combine_nq_)) { // combine request - SearchCombineRequestPtr combine_request = std::make_shared(); + SearchCombineRequestPtr combine_request = std::make_shared(search_combine_nq_); combine_request->Combine(last_search_req); combine_request->Combine(new_search_req); queue.back() = combine_request; // replace the last request to combine request diff --git a/core/src/server/delivery/strategy/SearchReqStrategy.h b/core/src/server/delivery/strategy/SearchReqStrategy.h index 20093c66c24b..3d2c3de03bcb 100644 --- a/core/src/server/delivery/strategy/SearchReqStrategy.h +++ b/core/src/server/delivery/strategy/SearchReqStrategy.h @@ -11,6 +11,7 @@ #pragma once +#include "config/handler/EngineConfigHandler.h" #include "server/delivery/strategy/RequestStrategy.h" #include "utils/Status.h" @@ -20,7 +21,7 @@ namespace milvus { namespace server { -class SearchReqStrategy : public RequestStrategy { +class SearchReqStrategy : public RequestStrategy, public EngineConfigHandler { public: SearchReqStrategy(); diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index ba38277cc106..1d7f4d13def1 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -387,7 +387,7 @@ ConstructEntityResults(const std::vector& attrs, const std::v if (not set_valid_row) { response->add_valid_row(true); } - grpc_field->set_type(::milvus::grpc::DataType::FLOAT_VECTOR); + grpc_field->set_type(::milvus::grpc::DataType::VECTOR_FLOAT); grpc_data->mutable_float_data()->Resize(vector.float_data_.size(), 0); memcpy(grpc_data->mutable_float_data()->mutable_data(), vector.float_data_.data(), vector.float_data_.size() * sizeof(float)); @@ -395,7 +395,7 @@ ConstructEntityResults(const std::vector& attrs, const std::v if (not set_valid_row) { response->add_valid_row(true); } - grpc_field->set_type(::milvus::grpc::DataType::BINARY_VECTOR); + grpc_field->set_type(::milvus::grpc::DataType::VECTOR_BINARY); grpc_data->mutable_binary_data()->resize(vector.binary_data_.size()); memcpy(grpc_data->mutable_binary_data()->data(), vector.binary_data_.data(), vector.binary_data_.size() * sizeof(uint8_t)); diff --git a/core/src/server/grpc_impl/GrpcServer.cpp b/core/src/server/grpc_impl/GrpcServer.cpp index 81f1be47b8f9..8717739e5cee 100644 --- a/core/src/server/grpc_impl/GrpcServer.cpp +++ b/core/src/server/grpc_impl/GrpcServer.cpp @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -49,7 +50,10 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption { void UpdateArguments(::grpc::ChannelArguments* args) override { args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0); - args->SetInt(GRPC_ARG_MAX_CONCURRENT_STREAMS, 20); + int grpc_concurrency = 4 * std::thread::hardware_concurrency(); + grpc_concurrency = std::max(32, grpc_concurrency); + grpc_concurrency = std::min(256, grpc_concurrency); + args->SetInt(GRPC_ARG_MAX_CONCURRENT_STREAMS, grpc_concurrency); } void diff --git a/core/src/server/web_impl/Constants.cpp b/core/src/server/web_impl/Constants.cpp index b975b44b6895..5dc200e52c61 100644 --- a/core/src/server/web_impl/Constants.cpp +++ b/core/src/server/web_impl/Constants.cpp @@ -23,6 +23,8 @@ const char* NAME_ENGINE_TYPE_RNSG = "RNSG"; const char* NAME_ENGINE_TYPE_IVFPQ = "IVFPQ"; const char* NAME_ENGINE_TYPE_HNSW = "HNSW"; const char* NAME_ENGINE_TYPE_ANNOY = "ANNOY"; +const char* NAME_ENGINE_TYPE_IVFSQ8NR = "IVFSQ8NR"; +const char* NAME_ENGINE_TYPE_HNSWSQ8NM = "HNSWSQ8NM"; const char* NAME_METRIC_TYPE_L2 = "L2"; const char* NAME_METRIC_TYPE_IP = "IP"; @@ -54,7 +56,8 @@ const std::unordered_map IndexMap = { {engine::EngineType::FAISS_PQ, NAME_ENGINE_TYPE_IVFPQ}, {engine::EngineType::HNSW, NAME_ENGINE_TYPE_HNSW}, {engine::EngineType::ANNOY, NAME_ENGINE_TYPE_ANNOY}, -}; + {engine::EngineType::FAISS_IVFSQ8NR, NAME_ENGINE_TYPE_IVFSQ8NR}, + {engine::EngineType::HNSW_SQ8NM, NAME_ENGINE_TYPE_HNSWSQ8NM}}; const std::unordered_map IndexNameMap = { {NAME_ENGINE_TYPE_FLAT, engine::EngineType::FAISS_IDMAP}, @@ -65,7 +68,8 @@ const std::unordered_map IndexNameMap = { {NAME_ENGINE_TYPE_IVFPQ, engine::EngineType::FAISS_PQ}, {NAME_ENGINE_TYPE_HNSW, engine::EngineType::HNSW}, {NAME_ENGINE_TYPE_ANNOY, engine::EngineType::ANNOY}, -}; + {NAME_ENGINE_TYPE_IVFSQ8NR, engine::EngineType::FAISS_IVFSQ8NR}, + {NAME_ENGINE_TYPE_HNSWSQ8NM, engine::EngineType::HNSW_SQ8NM}}; const std::unordered_map MetricMap = { {engine::MetricType::L2, NAME_METRIC_TYPE_L2}, diff --git a/core/src/server/web_impl/Constants.h b/core/src/server/web_impl/Constants.h index 474b2e4e6727..bf1c026a20d1 100644 --- a/core/src/server/web_impl/Constants.h +++ b/core/src/server/web_impl/Constants.h @@ -23,10 +23,12 @@ namespace web { extern const char* NAME_ENGINE_TYPE_FLAT; extern const char* NAME_ENGINE_TYPE_IVFFLAT; extern const char* NAME_ENGINE_TYPE_IVFSQ8; +extern const char* NAME_ENGINE_TYPE_IVFSQ8NR; extern const char* NAME_ENGINE_TYPE_IVFSQ8H; extern const char* NAME_ENGINE_TYPE_RNSG; extern const char* NAME_ENGINE_TYPE_IVFPQ; extern const char* NAME_ENGINE_TYPE_HNSW; +extern const char* NAME_ENGINE_TYPE_HNSW_SQ8NM; extern const char* NAME_ENGINE_TYPE_ANNOY; extern const char* NAME_METRIC_TYPE_L2; diff --git a/core/src/server/web_impl/controller/WebController.hpp b/core/src/server/web_impl/controller/WebController.hpp index 3d982ba9444f..efa6f408ccf6 100644 --- a/core/src/server/web_impl/controller/WebController.hpp +++ b/core/src/server/web_impl/controller/WebController.hpp @@ -510,7 +510,7 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(ShowPartitions) ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions, PATH(String, collection_name), - QUERIES(const QueryParams&, query_params), BODY_STRING(String, body)) { + QUERIES(const QueryParams&, query_params)) { TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/partitions\'"); tr.RecordSection("Received request."); @@ -522,7 +522,7 @@ class WebController : public oatpp::web::server::api::ApiController { auto handler = WebRequestHandler(); std::shared_ptr response; - auto status_dto = handler.ShowPartitions(collection_name, query_params, body, partition_list_dto); + auto status_dto = handler.ShowPartitions(collection_name, query_params, partition_list_dto); switch (status_dto->code->getValue()) { case StatusCode::SUCCESS: response = createDtoResponse(Status::CODE_200, partition_list_dto); diff --git a/core/src/server/web_impl/handler/WebRequestHandler.cpp b/core/src/server/web_impl/handler/WebRequestHandler.cpp index 75934ad88b0b..7ad40f65fe9c 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.cpp +++ b/core/src/server/web_impl/handler/WebRequestHandler.cpp @@ -78,6 +78,8 @@ WebErrorMap(ErrorCode code) { } } +using FloatJson = nlohmann::basic_json; + /////////////////////////////////// Private methods /////////////////////////////////////// void WebRequestHandler::AddStatusToJson(nlohmann::json& json, int64_t code, const std::string& msg) { @@ -1473,7 +1475,7 @@ WebRequestHandler::CreatePartition(const OString& collection_name, const Partiti } StatusDto::ObjectWrapper -WebRequestHandler::ShowPartitions(const OString& collection_name, const OQueryParams& query_params, const OString& body, +WebRequestHandler::ShowPartitions(const OString& collection_name, const OQueryParams& query_params, PartitionListDto::ObjectWrapper& partition_list_dto) { int64_t offset = 0; auto status = ParseQueryInteger(query_params, "offset", offset); @@ -1492,35 +1494,6 @@ WebRequestHandler::ShowPartitions(const OString& collection_name, const OQueryPa Status(SERVER_UNEXPECTED_ERROR, "Query param 'offset' or 'page_size' should equal or bigger than 0")); } - if (nullptr != body.get() && body->getSize() > 0) { - auto body_json = nlohmann::json::parse(body->c_str()); - if (!body_json.contains("filter")) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'filter\' is required.") - } - auto filter_json = body_json["filter"]; - if (filter_json.contains("partition_tag")) { - std::string tag = filter_json["partition_tag"]; - bool exists = false; - status = request_handler_.HasPartition(context_ptr_, collection_name->std_str(), tag, exists); - if (!status.ok()) { - ASSIGN_RETURN_STATUS_DTO(status) - } - auto partition_dto = PartitionFieldsDto::createShared(); - if (exists) { - partition_list_dto->count = 1; - partition_dto->partition_tag = tag.c_str(); - } else { - partition_list_dto->count = 0; - } - partition_list_dto->partitions = partition_list_dto->partitions->createShared(); - partition_list_dto->partitions->pushBack(partition_dto); - - ASSIGN_RETURN_STATUS_DTO(status) - } else { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Unknown field.") - } - } - bool all_required = false; auto required = query_params.get("all_required"); if (nullptr != required.get()) { @@ -1819,7 +1792,7 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se attr_values.emplace_back(attr_value); break; } - case engine::meta::hybrid::DataType::FLOAT_VECTOR: { + case engine::meta::hybrid::DataType::VECTOR_FLOAT: { bool bin_flag; status = IsBinaryCollection(collection_name->c_str(), bin_flag); if (!status.ok()) { @@ -1923,8 +1896,9 @@ WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& ASSIGN_RETURN_STATUS_DTO(status) } - nlohmann::json json; - AddStatusToJson(json, status.code(), status.message()); + FloatJson json; + json["code"] = (int64_t)status.code(); + json["message"] = status.message(); if (vectors_json.empty()) { json["vectors"] = std::vector(); } else { diff --git a/core/src/server/web_impl/handler/WebRequestHandler.h b/core/src/server/web_impl/handler/WebRequestHandler.h index 406cc4ad8379..cc314f169b20 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.h +++ b/core/src/server/web_impl/handler/WebRequestHandler.h @@ -196,7 +196,7 @@ class WebRequestHandler { CreatePartition(const OString& collection_name, const PartitionRequestDto::ObjectWrapper& param); StatusDto::ObjectWrapper - ShowPartitions(const OString& collection_name, const OQueryParams& query_params, const OString& body, + ShowPartitions(const OString& collection_name, const OQueryParams& query_params, PartitionListDto::ObjectWrapper& partition_list_dto); StatusDto::ObjectWrapper diff --git a/core/src/server/web_impl/utils/Util.cpp b/core/src/server/web_impl/utils/Util.cpp index f97851323f8c..878be0174f59 100644 --- a/core/src/server/web_impl/utils/Util.cpp +++ b/core/src/server/web_impl/utils/Util.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "server/web_impl/utils/Util.h" +#include #include "config/Utils.h" @@ -20,6 +21,7 @@ namespace web { Status ParseQueryInteger(const OQueryParams& query_params, const std::string& key, int64_t& value, bool nullable) { auto query = query_params.get(key.c_str()); + fiu_do_on("WebUtils.ParseQueryInteger.null_query_get", query = ""); if (nullptr != query.get() && query->getSize() > 0) { std::string value_str = query->std_str(); if (!ValidateStringIsNumber(value_str).ok()) { @@ -38,6 +40,7 @@ ParseQueryInteger(const OQueryParams& query_params, const std::string& key, int6 Status ParseQueryStr(const OQueryParams& query_params, const std::string& key, std::string& value, bool nullable) { auto query = query_params.get(key.c_str()); + fiu_do_on("WebUtils.ParseQueryStr.null_query_get", query = ""); if (nullptr != query.get() && query->getSize() > 0) { value = query->std_str(); } else if (!nullable) { @@ -50,6 +53,7 @@ ParseQueryStr(const OQueryParams& query_params, const std::string& key, std::str Status ParseQueryBool(const OQueryParams& query_params, const std::string& key, bool& value, bool nullable) { auto query = query_params.get(key.c_str()); + fiu_do_on("WebUtils.ParseQueryBool.null_query_get", query = ""); if (nullptr != query.get() && query->getSize() > 0) { std::string value_str = query->std_str(); if (!ValidateStringIsBool(value_str).ok()) { diff --git a/core/src/storage/disk/DiskOperation.cpp b/core/src/storage/disk/DiskOperation.cpp index 8ce5af10ce9c..eb3ea0b38355 100644 --- a/core/src/storage/disk/DiskOperation.cpp +++ b/core/src/storage/disk/DiskOperation.cpp @@ -34,7 +34,8 @@ DiskOperation::CreateDirectory() { bool is_dir = boost::filesystem::is_directory(dir_path_); fiu_do_on("DiskOperation.CreateDirectory.is_directory", is_dir = false); if (!is_dir) { - auto ret = boost::filesystem::create_directory(dir_path_); + /* create directories recursively */ + auto ret = boost::filesystem::create_directories(dir_path_); fiu_do_on("DiskOperation.CreateDirectory.create_directory", ret = false); if (!ret) { std::string err_msg = "Failed to create directory: " + dir_path_; diff --git a/core/src/utils/BlockingQueue.h b/core/src/utils/BlockingQueue.h index 55919f8aaff1..1b489f53e6c1 100644 --- a/core/src/utils/BlockingQueue.h +++ b/core/src/utils/BlockingQueue.h @@ -68,13 +68,13 @@ class BlockingQueue { } size_t - Size() { + Size() const { std::lock_guard lock(mtx); return queue_.size(); } bool - Empty() { + Empty() const { std::unique_lock lock(mtx); return queue_.empty(); } diff --git a/core/src/version.h.in b/core/src/version.h.in index 32c30fbe455c..0f0b3c1edf81 100644 --- a/core/src/version.h.in +++ b/core/src/version.h.in @@ -11,5 +11,5 @@ #cmakedefine MILVUS_VERSION "@MILVUS_VERSION@" #cmakedefine BUILD_TYPE "@BUILD_TYPE@" -#cmakedefine BUILD_TIME @BUILD_TIME@ +#cmakedefine BUILD_TIME "@BUILD_TIME@" #cmakedefine LAST_COMMIT_ID "@LAST_COMMIT_ID@" diff --git a/core/thirdparty/versions.txt b/core/thirdparty/versions.txt index ba5adc8cf98f..f9c7ba21a261 100644 --- a/core/thirdparty/versions.txt +++ b/core/thirdparty/versions.txt @@ -1,7 +1,7 @@ EASYLOGGINGPP_VERSION=v9.96.7 GTEST_VERSION=1.8.1 MYSQLPP_VERSION=3.2.4 -PROMETHEUS_VERSION=v0.7.0 +PROMETHEUS_VERSION=0.7.0 SQLITE_VERSION=3280000 SQLITE_ORM_VERSION=master YAMLCPP_VERSION=0.6.2 diff --git a/core/unittest/CMakeLists.txt b/core/unittest/CMakeLists.txt index 3d03ed4c063f..21e37aca15f2 100644 --- a/core/unittest/CMakeLists.txt +++ b/core/unittest/CMakeLists.txt @@ -31,13 +31,19 @@ aux_source_directory(${MILVUS_ENGINE_SRC}/db db_main_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/attr db_attr_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/engine db_engine_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/insert db_insert_files) -aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/merge db_merge_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/wal db_wal_files) aux_source_directory(${MILVUS_ENGINE_SRC}/db/snapshot db_snapshot_files) aux_source_directory(${MILVUS_ENGINE_SRC}/search search_files) aux_source_directory(${MILVUS_ENGINE_SRC}/query query_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta db_meta_main_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/db/meta/backend db_meta_backend_files) +set(db_meta_files + ${db_meta_main_files} + ${db_meta_backend_files} + ) + set(grpc_service_files ${MILVUS_ENGINE_SRC}/grpc/gen-milvus/milvus.grpc.pb.cc ${MILVUS_ENGINE_SRC}/grpc/gen-milvus/milvus.pb.cc @@ -121,6 +127,7 @@ set(storage_files aux_source_directory(${MILVUS_ENGINE_SRC}/codecs codecs_files) aux_source_directory(${MILVUS_ENGINE_SRC}/codecs/default codecs_default_files) +aux_source_directory(${MILVUS_ENGINE_SRC}/codecs/snapshot codecs_snapshot_files) aux_source_directory(${MILVUS_ENGINE_SRC}/segment segment_files) diff --git a/core/unittest/db/CMakeLists.txt b/core/unittest/db/CMakeLists.txt index 082fa2549c52..ae3faa94ca50 100644 --- a/core/unittest/db/CMakeLists.txt +++ b/core/unittest/db/CMakeLists.txt @@ -36,6 +36,7 @@ add_executable(test_db ${cache_files} ${codecs_files} ${codecs_default_files} + ${codecs_snapshot_files} ${config_files} ${config_handler_files} ${db_main_files} diff --git a/core/unittest/db/test_engine.cpp b/core/unittest/db/test_engine.cpp index 250dc8dddb19..47092d3019b3 100644 --- a/core/unittest/db/test_engine.cpp +++ b/core/unittest/db/test_engine.cpp @@ -116,6 +116,7 @@ TEST_F(EngineTest, FACTORY_TEST) { ASSERT_TRUE(engine_ptr != nullptr); } +#ifdef MILVUS_SUPPORT_SPTAG { auto engine_ptr = milvus::engine::EngineFactory::Build( 512, "/tmp/milvus_index_1", milvus::engine::EngineType::SPTAG_KDT, @@ -150,6 +151,7 @@ TEST_F(EngineTest, FACTORY_TEST) { milvus::engine::MetricType::L2, index_params)); fiu_disable("ExecutionEngineImpl.throw_exception"); } +#endif } TEST_F(EngineTest, ENGINE_IMPL_TEST) { diff --git a/core/unittest/db/test_hybrid_db.cpp b/core/unittest/db/test_hybrid_db.cpp index cd851c79efb7..a95eee7e2b4a 100644 --- a/core/unittest/db/test_hybrid_db.cpp +++ b/core/unittest/db/test_hybrid_db.cpp @@ -219,142 +219,142 @@ TEST_F(DBTest, HYBRID_DB_TEST) { ASSERT_TRUE(stat.ok()); } -TEST_F(DBTest, HYBRID_SEARCH_TEST) { - milvus::engine::meta::CollectionSchema collection_info; - milvus::engine::meta::hybrid::FieldsSchema fields_info; - std::unordered_map attr_type; - BuildCollectionSchema(collection_info, fields_info, attr_type); - - auto stat = db_->CreateHybridCollection(collection_info, fields_info); - ASSERT_TRUE(stat.ok()); - milvus::engine::meta::CollectionSchema collection_info_get; - milvus::engine::meta::hybrid::FieldsSchema fields_info_get; - collection_info_get.collection_id_ = COLLECTION_NAME; - stat = db_->DescribeHybridCollection(collection_info_get, fields_info_get); - ASSERT_TRUE(stat.ok()); - ASSERT_EQ(collection_info_get.dimension_, COLLECTION_DIM); - - uint64_t qb = 1000; - milvus::engine::Entity entity; - BuildEntity(qb, 0, entity); - - std::vector field_names = {"field_0", "field_1", "field_2"}; - - stat = db_->InsertEntities(COLLECTION_NAME, "", field_names, entity, attr_type); - ASSERT_TRUE(stat.ok()); - - stat = db_->Flush(COLLECTION_NAME); - ASSERT_TRUE(stat.ok()); - - // Construct general query - auto general_query = std::make_shared(); - auto query_ptr = std::make_shared(); - ConstructGeneralQuery(general_query, query_ptr); - - std::vector tags; - milvus::engine::QueryResult result; - stat = db_->HybridQuery(dummy_context_, COLLECTION_NAME, tags, general_query, query_ptr, field_names, attr_type, - result); - ASSERT_TRUE(stat.ok()); - ASSERT_EQ(result.row_num_, NQ); - ASSERT_EQ(result.result_ids_.size(), NQ * TOPK); -} - -TEST_F(DBTest, COMPACT_TEST) { - milvus::engine::meta::CollectionSchema collection_info; - milvus::engine::meta::hybrid::FieldsSchema fields_info; - std::unordered_map attr_type; - BuildCollectionSchema(collection_info, fields_info, attr_type); - - auto stat = db_->CreateHybridCollection(collection_info, fields_info); - ASSERT_TRUE(stat.ok()); - milvus::engine::meta::CollectionSchema collection_info_get; - milvus::engine::meta::hybrid::FieldsSchema fields_info_get; - collection_info_get.collection_id_ = COLLECTION_NAME; - stat = db_->DescribeHybridCollection(collection_info_get, fields_info_get); - ASSERT_TRUE(stat.ok()); - ASSERT_EQ(collection_info_get.dimension_, COLLECTION_DIM); - - uint64_t vector_count = 1000; - milvus::engine::Entity entity; - BuildEntity(vector_count, 0, entity); - - std::vector field_names = {"field_0", "field_1", "field_2"}; - - stat = db_->InsertEntities(COLLECTION_NAME, "", field_names, entity, attr_type); - ASSERT_TRUE(stat.ok()); - - stat = db_->Flush(); - ASSERT_TRUE(stat.ok()); - - std::vector ids_to_delete; - ids_to_delete.emplace_back(entity.id_array_.front()); - ids_to_delete.emplace_back(entity.id_array_.back()); - stat = db_->DeleteVectors(collection_info.collection_id_, ids_to_delete); - ASSERT_TRUE(stat.ok()); - - stat = db_->Flush(); - ASSERT_TRUE(stat.ok()); - - stat = db_->Compact(dummy_context_, collection_info.collection_id_); - ASSERT_TRUE(stat.ok()); - - const int topk = 1, nprobe = 1; - milvus::json json_params = {{"nprobe", nprobe}}; - - std::vector tags; - milvus::engine::ResultIds result_ids; - milvus::engine::ResultDistances result_distances; - - stat = db_->QueryByIDs(dummy_context_, collection_info.collection_id_, tags, topk, json_params, ids_to_delete, - result_ids, result_distances); - ASSERT_TRUE(stat.ok()); - ASSERT_EQ(result_ids[0], -1); - ASSERT_EQ(result_distances[0], std::numeric_limits::max()); -} - -TEST_F(DBTest2, GET_ENTITY_BY_ID_TEST) { - milvus::engine::meta::CollectionSchema collection_schema; - milvus::engine::meta::hybrid::FieldsSchema fields_schema; - std::unordered_map attr_type; - BuildCollectionSchema(collection_schema, fields_schema, attr_type); - - auto stat = db_->CreateHybridCollection(collection_schema, fields_schema); - ASSERT_TRUE(stat.ok()); - - uint64_t vector_count = 1000; - milvus::engine::Entity entity; - BuildEntity(vector_count, 0, entity); - - std::vector field_names = {"field_0", "field_1", "field_2"}; - - stat = db_->InsertEntities(COLLECTION_NAME, "", field_names, entity, attr_type); - ASSERT_TRUE(stat.ok()); - - stat = db_->Flush(); - ASSERT_TRUE(stat.ok()); - - std::vector attrs; - std::vector vectors; - stat = db_->GetEntitiesByID(COLLECTION_NAME, entity.id_array_, field_names, vectors, attrs); - ASSERT_TRUE(stat.ok()); - ASSERT_EQ(vectors.size(), entity.id_array_.size()); - ASSERT_EQ(vectors[0].float_data_.size(), COLLECTION_DIM); - ASSERT_EQ(attrs[0].attr_data_.at("field_0").size(), sizeof(int32_t)); - - for (int64_t i = 0; i < COLLECTION_DIM; i++) { - ASSERT_FLOAT_EQ(vectors[0].float_data_[i], entity.vector_data_.at("field_3").float_data_[i]); - } - - std::vector empty_array; - vectors.clear(); - attrs.clear(); - field_names.clear(); - stat = db_->GetEntitiesByID(COLLECTION_NAME, empty_array, field_names, vectors, attrs); - ASSERT_TRUE(stat.ok()); - for (auto& vector : vectors) { - ASSERT_EQ(vector.vector_count_, 0); - ASSERT_TRUE(vector.float_data_.empty()); - ASSERT_TRUE(vector.binary_data_.empty()); - } -} +// TEST_F(DBTest, HYBRID_SEARCH_TEST) { +// milvus::engine::meta::CollectionSchema collection_info; +// milvus::engine::meta::hybrid::FieldsSchema fields_info; +// std::unordered_map attr_type; +// BuildCollectionSchema(collection_info, fields_info, attr_type); +// +// auto stat = db_->CreateHybridCollection(collection_info, fields_info); +// ASSERT_TRUE(stat.ok()); +// milvus::engine::meta::CollectionSchema collection_info_get; +// milvus::engine::meta::hybrid::FieldsSchema fields_info_get; +// collection_info_get.collection_id_ = COLLECTION_NAME; +// stat = db_->DescribeHybridCollection(collection_info_get, fields_info_get); +// ASSERT_TRUE(stat.ok()); +// ASSERT_EQ(collection_info_get.dimension_, COLLECTION_DIM); +// +// uint64_t qb = 1000; +// milvus::engine::Entity entity; +// BuildEntity(qb, 0, entity); +// +// std::vector field_names = {"field_0", "field_1", "field_2"}; +// +// stat = db_->InsertEntities(COLLECTION_NAME, "", field_names, entity, attr_type); +// ASSERT_TRUE(stat.ok()); +// +// stat = db_->Flush(COLLECTION_NAME); +// ASSERT_TRUE(stat.ok()); +// +// // Construct general query +// auto general_query = std::make_shared(); +// auto query_ptr = std::make_shared(); +// ConstructGeneralQuery(general_query, query_ptr); +// +// std::vector tags; +// milvus::engine::QueryResult result; +// stat = db_->HybridQuery(dummy_context_, COLLECTION_NAME, tags, general_query, query_ptr, field_names, attr_type, +// result); +// ASSERT_TRUE(stat.ok()); +// ASSERT_EQ(result.row_num_, NQ); +// ASSERT_EQ(result.result_ids_.size(), NQ * TOPK); +//} +// +// TEST_F(DBTest, COMPACT_TEST) { +// milvus::engine::meta::CollectionSchema collection_info; +// milvus::engine::meta::hybrid::FieldsSchema fields_info; +// std::unordered_map attr_type; +// BuildCollectionSchema(collection_info, fields_info, attr_type); +// +// auto stat = db_->CreateHybridCollection(collection_info, fields_info); +// ASSERT_TRUE(stat.ok()); +// milvus::engine::meta::CollectionSchema collection_info_get; +// milvus::engine::meta::hybrid::FieldsSchema fields_info_get; +// collection_info_get.collection_id_ = COLLECTION_NAME; +// stat = db_->DescribeHybridCollection(collection_info_get, fields_info_get); +// ASSERT_TRUE(stat.ok()); +// ASSERT_EQ(collection_info_get.dimension_, COLLECTION_DIM); +// +// uint64_t vector_count = 1000; +// milvus::engine::Entity entity; +// BuildEntity(vector_count, 0, entity); +// +// std::vector field_names = {"field_0", "field_1", "field_2"}; +// +// stat = db_->InsertEntities(COLLECTION_NAME, "", field_names, entity, attr_type); +// ASSERT_TRUE(stat.ok()); +// +// stat = db_->Flush(); +// ASSERT_TRUE(stat.ok()); +// +// std::vector ids_to_delete; +// ids_to_delete.emplace_back(entity.id_array_.front()); +// ids_to_delete.emplace_back(entity.id_array_.back()); +// stat = db_->DeleteVectors(collection_info.collection_id_, ids_to_delete); +// ASSERT_TRUE(stat.ok()); +// +// stat = db_->Flush(); +// ASSERT_TRUE(stat.ok()); +// +// stat = db_->Compact(dummy_context_, collection_info.collection_id_); +// ASSERT_TRUE(stat.ok()); +// +// const int topk = 1, nprobe = 1; +// milvus::json json_params = {{"nprobe", nprobe}}; +// +// std::vector tags; +// milvus::engine::ResultIds result_ids; +// milvus::engine::ResultDistances result_distances; +// +// stat = db_->QueryByIDs(dummy_context_, collection_info.collection_id_, tags, topk, json_params, ids_to_delete, +// result_ids, result_distances); +// ASSERT_TRUE(stat.ok()); +// ASSERT_EQ(result_ids[0], -1); +// ASSERT_EQ(result_distances[0], std::numeric_limits::max()); +//} +// +// TEST_F(DBTest2, GET_ENTITY_BY_ID_TEST) { +// milvus::engine::meta::CollectionSchema collection_schema; +// milvus::engine::meta::hybrid::FieldsSchema fields_schema; +// std::unordered_map attr_type; +// BuildCollectionSchema(collection_schema, fields_schema, attr_type); +// +// auto stat = db_->CreateHybridCollection(collection_schema, fields_schema); +// ASSERT_TRUE(stat.ok()); +// +// uint64_t vector_count = 1000; +// milvus::engine::Entity entity; +// BuildEntity(vector_count, 0, entity); +// +// std::vector field_names = {"field_0", "field_1", "field_2"}; +// +// stat = db_->InsertEntities(COLLECTION_NAME, "", field_names, entity, attr_type); +// ASSERT_TRUE(stat.ok()); +// +// stat = db_->Flush(); +// ASSERT_TRUE(stat.ok()); +// +// std::vector attrs; +// std::vector vectors; +// stat = db_->GetEntitiesByID(COLLECTION_NAME, entity.id_array_, field_names, vectors, attrs); +// ASSERT_TRUE(stat.ok()); +// ASSERT_EQ(vectors.size(), entity.id_array_.size()); +// ASSERT_EQ(vectors[0].float_data_.size(), COLLECTION_DIM); +// ASSERT_EQ(attrs[0].attr_data_.at("field_0").size(), sizeof(int32_t)); +// +// for (int64_t i = 0; i < COLLECTION_DIM; i++) { +// ASSERT_FLOAT_EQ(vectors[0].float_data_[i], entity.vector_data_.at("field_3").float_data_[i]); +// } +// +// std::vector empty_array; +// vectors.clear(); +// attrs.clear(); +// field_names.clear(); +// stat = db_->GetEntitiesByID(COLLECTION_NAME, empty_array, field_names, vectors, attrs); +// ASSERT_TRUE(stat.ok()); +// for (auto& vector : vectors) { +// ASSERT_EQ(vector.vector_count_, 0); +// ASSERT_TRUE(vector.float_data_.empty()); +// ASSERT_TRUE(vector.binary_data_.empty()); +// } +//} diff --git a/core/unittest/db/test_rpc.cpp b/core/unittest/db/test_rpc.cpp index bdce35178da9..5ad015857226 100644 --- a/core/unittest/db/test_rpc.cpp +++ b/core/unittest/db/test_rpc.cpp @@ -58,7 +58,7 @@ ConstructMapping(::milvus::grpc::Mapping* request, const std::string& collection auto field2 = request->add_fields(); field2->set_name("field_2"); - field2->set_type(::milvus::grpc::DataType::FLOAT_VECTOR); + field2->set_type(::milvus::grpc::DataType::VECTOR_FLOAT); } void @@ -190,7 +190,7 @@ class RpcHandlerTest : public testing::Test { auto field_2 = request.add_fields(); field_2->set_name("field_2"); - field_2->set_type(::milvus::grpc::DataType::FLOAT_VECTOR); + field_2->set_type(::milvus::grpc::DataType::VECTOR_FLOAT); auto grpc_index_param_2 = field_2->add_index_params(); grpc_index_param_2->set_key("name"); grpc_index_param_2->set_value("index_2"); diff --git a/core/unittest/db/test_web.cpp b/core/unittest/db/test_web.cpp index 9306e4e9f529..08442347e7d8 100644 --- a/core/unittest/db/test_web.cpp +++ b/core/unittest/db/test_web.cpp @@ -33,6 +33,7 @@ #include "server/web_impl/dto/StatusDto.hpp" #include "server/web_impl/dto/VectorDto.hpp" #include "server/web_impl/handler/WebRequestHandler.h" +#include "server/web_impl/utils/Util.h" #include "src/version.h" #include "utils/CommonUtil.h" #include "utils/StringHelpFunctions.h" @@ -954,29 +955,6 @@ TEST_F(WebControllerTest, PARTITION) { ASSERT_EQ(OStatus::CODE_404.code, response->getStatusCode()); } -TEST_F(WebControllerTest, PARTITION_FILTER) { - const OString collection_name = "test_controller_partition_" + OString(RandomName().c_str()); - GenCollection(client_ptr, conncetion_ptr, collection_name, 64, 100, "L2"); - - nlohmann::json body_json; - body_json["filter"]["partition_tag"] = "tag_not_exists_"; - auto response = client_ptr->showPartitions(collection_name, "0", "10", body_json.dump().c_str()); - ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); - auto result_dto = response->readBodyToDto(object_mapper.get()); - ASSERT_EQ(result_dto->count->getValue(), 0); - - auto par_param = milvus::server::web::PartitionRequestDto::createShared(); - par_param->partition_tag = "tag01"; - response = client_ptr->createPartition(collection_name, par_param); - ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); - - body_json["filter"]["partition_tag"] = "tag01"; - response = client_ptr->showPartitions(collection_name, "0", "10", body_json.dump().c_str()); - ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); - result_dto = response->readBodyToDto(object_mapper.get()); - ASSERT_EQ(result_dto->count->getValue(), 1); -} - TEST_F(WebControllerTest, SHOW_SEGMENTS) { OString collection_name = OString("test_milvus_web_segments_test_") + RandomName().c_str(); @@ -1203,6 +1181,7 @@ TEST_F(WebControllerTest, SEARCH_BIN) { ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); } +#if 0 TEST_F(WebControllerTest, SEARCH_BY_IDS) { #ifdef MILVUS_GPU_VERSION auto& config = milvus::server::Config::GetInstance(); @@ -1249,6 +1228,7 @@ TEST_F(WebControllerTest, SEARCH_BY_IDS) { // ASSERT_EQ(std::to_string(ids.at(j)), id.get()); // } } +#endif TEST_F(WebControllerTest, GET_VECTORS_BY_IDS) { const OString collection_name = "test_milvus_web_get_vector_by_id_test_" + OString(RandomName().c_str()); @@ -1577,3 +1557,112 @@ TEST_F(WebControllerTest, LOAD) { response = client_ptr->op("task", load_json.dump().c_str(), conncetion_ptr); ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); } + +class WebUtilTest : public ::testing::Test { + public: + std::string key; + OString value; + int64_t intValue; + std::string stringValue; + bool boolValue; + OQueryParams params; + + void + SetUp() override { + key = "offset"; + } + + void + TearDown() override { + }; +}; + + +TEST_F(WebUtilTest, ParseQueryInteger) { + value = "5"; + + params.put("offset", value); + milvus::Status status = milvus::server::web::ParseQueryInteger(params, key, intValue); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(5, intValue); +} + +TEST_F(WebUtilTest, ParQueryIntegerIllegalQueryParam) { + value = "-5"; + + params.put("offset", value); + milvus::Status status = milvus::server::web::ParseQueryInteger(params, key, intValue); + ASSERT_EQ(status.code(), milvus::server::web::ILLEGAL_QUERY_PARAM); + ASSERT_STREQ(status.message().c_str(), + "Query param \'offset\' is illegal, only non-negative integer supported"); +} + +TEST_F(WebUtilTest, ParQueryIntegerQueryParamLoss) { + value = "5"; + + params.put("offset", value); + fiu_enable("WebUtils.ParseQueryInteger.null_query_get", 1, nullptr, 0); + milvus::Status status = milvus::server::web::ParseQueryInteger(params, key, intValue, false); + ASSERT_EQ(status.code(), milvus::server::web::QUERY_PARAM_LOSS); + std::string msg = "Query param \"" + key + "\" is required"; + ASSERT_STREQ(status.message().c_str(), msg.c_str()); +} + + +TEST_F(WebUtilTest, ParseQueryBoolTrue) { + value = "True"; + + params.put("offset", value); + milvus::Status status = milvus::server::web::ParseQueryBool(params, key, boolValue); + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(boolValue); +} + +TEST_F(WebUtilTest, ParQueryBoolFalse) { + value = "False"; + + params.put("offset", value); + milvus::Status status = milvus::server::web::ParseQueryBool(params, key, boolValue); + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(!boolValue); +} + +TEST_F(WebUtilTest, ParQueryBoolIllegalQuery) { + value = "Hello"; + + params.put("offset", value); + milvus::Status status = milvus::server::web::ParseQueryBool(params, key, boolValue); + ASSERT_EQ(status.code(), milvus::server::web::ILLEGAL_QUERY_PARAM); + ASSERT_STREQ(status.message().c_str(), "Query param \'all_required\' must be a bool"); +} + +TEST_F(WebUtilTest, ParQueryBoolQueryParamLoss) { + value = "Hello"; + + params.put("offset", value); + fiu_enable("WebUtils.ParseQueryBool.null_query_get", 1, nullptr, 0); + milvus::Status status = milvus::server::web::ParseQueryBool(params, key, boolValue, false); + ASSERT_EQ(status.code(), milvus::server::web::QUERY_PARAM_LOSS); + std::string msg = "Query param \"" + key + "\" is required"; + ASSERT_STREQ(status.message().c_str(), msg.c_str()); +} + +TEST_F(WebUtilTest, ParseQueryStr) { + value = "Are you ok?"; + + params.put("offset", value); + milvus::Status status = milvus::server::web::ParseQueryStr(params, key, stringValue); + ASSERT_TRUE(status.ok()); + ASSERT_STREQ(value->c_str(), stringValue.c_str()); +} + +TEST_F(WebUtilTest, ParQueryStrQueryParamLoss) { + value = "Are you ok?"; + + params.put("offset", value); + fiu_enable("WebUtils.ParseQueryStr.null_query_get", 1, nullptr, 0); + milvus::Status status = milvus::server::web::ParseQueryStr(params, key, stringValue, false); + ASSERT_EQ(status.code(), milvus::server::web::QUERY_PARAM_LOSS); + std::string msg = "Query param \"" + key + "\" is required"; + ASSERT_STREQ(status.message().c_str(), msg.c_str()); +} diff --git a/core/unittest/db/utils.cpp b/core/unittest/db/utils.cpp index 97458572e11c..d4a47c8bbb22 100644 --- a/core/unittest/db/utils.cpp +++ b/core/unittest/db/utils.cpp @@ -189,7 +189,7 @@ DBTest::SetUp() { milvus::scheduler::CPUBuilderInst::GetInstance()->Start(); auto options = GetOptions(); - options.insert_cache_immediately_ = true; + // options.insert_cache_immediately_ = true; BuildDB(options); std::string config_path(options.meta_.path_ + CONFIG_FILE); diff --git a/core/unittest/scheduler/CMakeLists.txt b/core/unittest/scheduler/CMakeLists.txt index d99bf9e85bbd..33e3caa891d1 100644 --- a/core/unittest/scheduler/CMakeLists.txt +++ b/core/unittest/scheduler/CMakeLists.txt @@ -30,6 +30,7 @@ add_executable(test_scheduler ${cache_files} ${codecs_files} ${codecs_default_files} + ${codecs_snapshot_files} ${config_files} ${config_handler_files} ${db_main_files} diff --git a/core/unittest/scheduler/test_algorithm.cpp b/core/unittest/scheduler/test_algorithm.cpp index 1a75c5451438..8b390e314e74 100644 --- a/core/unittest/scheduler/test_algorithm.cpp +++ b/core/unittest/scheduler/test_algorithm.cpp @@ -10,6 +10,8 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include +#include +#include #include "scheduler/Algorithm.h" #include "scheduler/ResourceFactory.h" @@ -57,42 +59,129 @@ class AlgorithmTest : public testing::Test { ResourceMgrPtr res_mgr_; }; +TEST_F(AlgorithmTest, SHORTESTPATH_INVALID_PATH_TEST) { + std::vector sp; + uint64_t cost; + // disk to disk is invalid + cost = ShortestPath(disk_.lock(), disk_.lock(), res_mgr_, sp); + ASSERT_TRUE(sp.empty()); + + // cpu_0 to disk is invalid + cost = ShortestPath(cpu_0_.lock(), disk_.lock(), res_mgr_, sp); + ASSERT_TRUE(sp.empty()); + + // cpu2 to gpu0 is invalid + cost = ShortestPath(cpu_2_.lock(), gpu_0_.lock(), res_mgr_, sp); + ASSERT_TRUE(sp.empty()); + + + // gpu0 to gpu1 is invalid + cost = ShortestPath(gpu_0_.lock(), gpu_1_.lock(), res_mgr_, sp); + ASSERT_TRUE(sp.empty()); +} + TEST_F(AlgorithmTest, SHORTESTPATH_TEST) { std::vector sp; uint64_t cost; + + //disk to gpu0 + //disk -> cpu0 -> cpu1 -> gpu0 + std::cout << "************************************\n"; cost = ShortestPath(disk_.lock(), gpu_0_.lock(), res_mgr_, sp); + ASSERT_EQ(sp.size(), 4); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << " "; + sp.pop_back(); + } + std::cout << std::endl; + + //disk to gpu1 + //disk -> cpu0 -> cpu2 -> gpu1 + std::cout << "************************************\n"; + cost = ShortestPath(disk_.lock(), gpu_1_.lock(), res_mgr_, sp); + ASSERT_EQ(sp.size(), 4); while (!sp.empty()) { - std::cout << sp[sp.size() - 1] << std::endl; + std::cout << sp[sp.size() - 1] << " "; sp.pop_back(); } + std::cout << std::endl; + // disk to cpu0 + // disk -> cpu0 + std::cout << "************************************\n"; + cost = ShortestPath(disk_.lock(), cpu_0_.lock(), res_mgr_, sp); + ASSERT_EQ(sp.size(), 2); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << " "; + sp.pop_back(); + } + std::cout << std::endl; + + // disk to cpu1 + // disk -> cpu0 -> cpu1 + std::cout << "************************************\n"; + cost = ShortestPath(disk_.lock(), cpu_1_.lock(), res_mgr_, sp); + ASSERT_EQ(sp.size(), 3); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << " "; + sp.pop_back(); + } + std::cout << std::endl; + + // disk to cpu2 + // disk -> cpu0 -> cpu2 + std::cout << "************************************\n"; + cost = ShortestPath(disk_.lock(), cpu_2_.lock(), res_mgr_, sp); + ASSERT_EQ(sp.size(), 3); + while (!sp.empty()) { + std::cout << sp[sp.size() - 1] << " "; + sp.pop_back(); + } + std::cout << std::endl; + + // cpu0 to gpu0 + // cpu0 -> cpu1 -> gpu0 std::cout << "************************************\n"; cost = ShortestPath(cpu_0_.lock(), gpu_0_.lock(), res_mgr_, sp); + ASSERT_EQ(sp.size(), 3); while (!sp.empty()) { - std::cout << sp[sp.size() - 1] << std::endl; + std::cout << sp[sp.size() - 1] << " "; sp.pop_back(); } + std::cout << std::endl; + // cpu0 to cpu1 + // cpu0 -> cpu1 std::cout << "************************************\n"; - cost = ShortestPath(disk_.lock(), disk_.lock(), res_mgr_, sp); + cost = ShortestPath(cpu_0_.lock(), cpu_1_.lock(), res_mgr_, sp); + ASSERT_EQ(sp.size(), 2); while (!sp.empty()) { - std::cout << sp[sp.size() - 1] << std::endl; + std::cout << sp[sp.size() - 1] << " "; sp.pop_back(); } + std::cout << std::endl; + // cpu0 to cpu2 + // cpu0 -> cpu2 std::cout << "************************************\n"; - cost = ShortestPath(cpu_0_.lock(), disk_.lock(), res_mgr_, sp); + cost = ShortestPath(cpu_0_.lock(), cpu_2_.lock(), res_mgr_, sp); + // ASSERT_EQ(sp.size(), 2); while (!sp.empty()) { - std::cout << sp[sp.size() - 1] << std::endl; + std::cout << sp[sp.size() - 1] << " "; sp.pop_back(); } + std::cout << std::endl; + // cpu0 to gpu1 + // cpu0 -> cpu2 -> gpu1 std::cout << "************************************\n"; - cost = ShortestPath(cpu_2_.lock(), gpu_0_.lock(), res_mgr_, sp); + cost = ShortestPath(cpu_0_.lock(), gpu_1_.lock(), res_mgr_, sp); + // ASSERT_EQ(sp.size(), 3); while (!sp.empty()) { - std::cout << sp[sp.size() - 1] << std::endl; + std::cout << sp[sp.size() - 1] << " "; sp.pop_back(); } + std::cout << std::endl; } } // namespace scheduler diff --git a/core/unittest/server/test_cache.cpp b/core/unittest/server/test_cache.cpp index 7517f537b075..369c91f45504 100644 --- a/core/unittest/server/test_cache.cpp +++ b/core/unittest/server/test_cache.cpp @@ -38,8 +38,7 @@ class LessItemCacheMgr : public milvus::cache::CacheMgr #include "ssdb/utils.h" -#include "db/snapshot/CompoundOperations.h" -#include "db/snapshot/Context.h" -#include "db/snapshot/EventExecutor.h" -#include "db/snapshot/OperationExecutor.h" -#include "db/snapshot/ReferenceProxy.h" -#include "db/snapshot/ResourceHolders.h" -#include "db/snapshot/ScopedResource.h" -#include "db/snapshot/Snapshots.h" -#include "db/snapshot/Store.h" -#include "db/snapshot/WrappedTypes.h" - -using ID_TYPE = milvus::engine::snapshot::ID_TYPE; -using IDS_TYPE = milvus::engine::snapshot::IDS_TYPE; -using LSN_TYPE = milvus::engine::snapshot::LSN_TYPE; -using MappingT = milvus::engine::snapshot::MappingT; -using LoadOperationContext = milvus::engine::snapshot::LoadOperationContext; -using CreateCollectionContext = milvus::engine::snapshot::CreateCollectionContext; -using SegmentFileContext = milvus::engine::snapshot::SegmentFileContext; -using OperationContext = milvus::engine::snapshot::OperationContext; -using PartitionContext = milvus::engine::snapshot::PartitionContext; -using BuildOperation = milvus::engine::snapshot::BuildOperation; -using MergeOperation = milvus::engine::snapshot::MergeOperation; -using CreateCollectionOperation = milvus::engine::snapshot::CreateCollectionOperation; -using NewSegmentOperation = milvus::engine::snapshot::NewSegmentOperation; -using DropPartitionOperation = milvus::engine::snapshot::DropPartitionOperation; -using CreatePartitionOperation = milvus::engine::snapshot::CreatePartitionOperation; -using DropCollectionOperation = milvus::engine::snapshot::DropCollectionOperation; -using CollectionCommitsHolder = milvus::engine::snapshot::CollectionCommitsHolder; -using CollectionsHolder = milvus::engine::snapshot::CollectionsHolder; -using CollectionScopedT = milvus::engine::snapshot::CollectionScopedT; -using Collection = milvus::engine::snapshot::Collection; -using CollectionPtr = milvus::engine::snapshot::CollectionPtr; -using Partition = milvus::engine::snapshot::Partition; -using PartitionPtr = milvus::engine::snapshot::PartitionPtr; -using Segment = milvus::engine::snapshot::Segment; -using SegmentPtr = milvus::engine::snapshot::SegmentPtr; -using SegmentFile = milvus::engine::snapshot::SegmentFile; -using SegmentFilePtr = milvus::engine::snapshot::SegmentFilePtr; -using Field = milvus::engine::snapshot::Field; -using FieldElement = milvus::engine::snapshot::FieldElement; -using Snapshots = milvus::engine::snapshot::Snapshots; -using ScopedSnapshotT = milvus::engine::snapshot::ScopedSnapshotT; -using ReferenceProxy = milvus::engine::snapshot::ReferenceProxy; -using Queue = milvus::BlockingQueue; -using TQueue = milvus::BlockingQueue>; -using SoftDeleteCollectionOperation = milvus::engine::snapshot::SoftDeleteOperation; -using ParamsField = milvus::engine::snapshot::ParamsField; -using IteratePartitionHandler = milvus::engine::snapshot::IterateHandler; -using SSDBImpl = milvus::engine::SSDBImpl; +#include "db/SnapshotVisitor.h" +#include "db/snapshot/IterateHandler.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +using SegmentVisitor = milvus::engine::SegmentVisitor; + +namespace { milvus::Status CreateCollection(std::shared_ptr db, const std::string& collection_name, const LSN_TYPE& lsn) { CreateCollectionContext context; @@ -75,17 +32,103 @@ CreateCollection(std::shared_ptr db, const std::string& collection_nam auto collection_schema = std::make_shared(collection_name); context.collection = collection_schema; auto vector_field = std::make_shared("vector", 0, - milvus::engine::snapshot::FieldType::VECTOR); + milvus::engine::FieldType::VECTOR); auto vector_field_element = std::make_shared(0, 0, "ivfsq8", - milvus::engine::snapshot::FieldElementType::IVFSQ8); + milvus::engine::FieldElementType::FET_INDEX); auto int_field = std::make_shared("int", 0, - milvus::engine::snapshot::FieldType::INT32); + milvus::engine::FieldType::INT32); context.fields_schema[vector_field] = {vector_field_element}; context.fields_schema[int_field] = {}; return db->CreateCollection(context); } +static constexpr int64_t COLLECTION_DIM = 128; + +milvus::Status +CreateCollection2(std::shared_ptr db, const std::string& collection_name, const LSN_TYPE& lsn) { + CreateCollectionContext context; + context.lsn = lsn; + auto collection_schema = std::make_shared(collection_name); + context.collection = collection_schema; + + nlohmann::json params; + params[milvus::knowhere::meta::DIM] = COLLECTION_DIM; + auto vector_field = std::make_shared("vector", 0, milvus::engine::FieldType::VECTOR, params); + context.fields_schema[vector_field] = {}; + + std::unordered_map attr_type = { + {"field_0", milvus::engine::FieldType::INT32}, + {"field_1", milvus::engine::FieldType::INT64}, + {"field_2", milvus::engine::FieldType::DOUBLE}, + }; + + std::vector field_names; + for (auto& pair : attr_type) { + auto field = std::make_shared(pair.first, 0, pair.second); + context.fields_schema[field] = {}; + field_names.push_back(pair.first); + } + + return db->CreateCollection(context); +} + +void +BuildEntities(uint64_t n, uint64_t batch_index, milvus::engine::DataChunkPtr& data_chunk) { + data_chunk = std::make_shared(); + data_chunk->count_ = n; + + milvus::engine::VectorsData vectors; + vectors.vector_count_ = n; + vectors.float_data_.clear(); + vectors.float_data_.resize(n * COLLECTION_DIM); + float* data = vectors.float_data_.data(); + for (uint64_t i = 0; i < n; i++) { + for (int64_t j = 0; j < COLLECTION_DIM; j++) data[COLLECTION_DIM * i + j] = drand48(); + data[COLLECTION_DIM * i] += i / 2000.; + + vectors.id_array_.push_back(n * batch_index + i); + } + + milvus::engine::FIXED_FIELD_DATA& raw = data_chunk->fixed_fields_["vector"]; + raw.resize(vectors.float_data_.size() * sizeof(float)); + memcpy(raw.data(), vectors.float_data_.data(), vectors.float_data_.size() * sizeof(float)); + + std::vector value_0; + std::vector value_1; + std::vector value_2; + value_0.resize(n); + value_1.resize(n); + value_2.resize(n); + + std::default_random_engine e; + std::uniform_real_distribution u(0, 1); + for (uint64_t i = 0; i < n; ++i) { + value_0[i] = i; + value_1[i] = i + n; + value_2[i] = u(e); + } + + { + milvus::engine::FIXED_FIELD_DATA& raw = data_chunk->fixed_fields_["field_0"]; + raw.resize(value_0.size() * sizeof(int32_t)); + memcpy(raw.data(), value_0.data(), value_0.size() * sizeof(int32_t)); + } + + { + milvus::engine::FIXED_FIELD_DATA& raw = data_chunk->fixed_fields_["field_1"]; + raw.resize(value_1.size() * sizeof(int64_t)); + memcpy(raw.data(), value_1.data(), value_1.size() * sizeof(int64_t)); + } + + { + milvus::engine::FIXED_FIELD_DATA& raw = data_chunk->fixed_fields_["field_2"]; + raw.resize(value_2.size() * sizeof(double)); + memcpy(raw.data(), value_2.data(), value_2.size() * sizeof(double)); + } +} +} // namespace + TEST_F(SSDBTest, CollectionTest) { LSN_TYPE lsn = 0; auto next_lsn = [&]() -> decltype(lsn) { @@ -106,6 +149,12 @@ TEST_F(SSDBTest, CollectionTest) { ASSERT_TRUE(has); ASSERT_TRUE(status.ok()); + ASSERT_EQ(ss->GetCollectionCommit()->GetRowCount(), 0); + milvus::engine::snapshot::SIZE_TYPE row_cnt = 0; + status = db_->GetCollectionRowCount(c1, row_cnt); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(row_cnt, 0); + std::vector names; status = db_->AllCollections(names); ASSERT_TRUE(status.ok()); @@ -135,3 +184,265 @@ TEST_F(SSDBTest, CollectionTest) { status = db_->DropCollection(c1); ASSERT_FALSE(status.ok()); } + +TEST_F(SSDBTest, PartitionTest) { + LSN_TYPE lsn = 0; + auto next_lsn = [&]() -> decltype(lsn) { + return ++lsn; + }; + std::string c1 = "c1"; + auto status = CreateCollection(db_, c1, next_lsn()); + ASSERT_TRUE(status.ok()); + + std::vector partition_names; + status = db_->ShowPartitions(c1, partition_names); + ASSERT_EQ(partition_names.size(), 1); + ASSERT_EQ(partition_names[0], "_default"); + + std::string p1 = "p1"; + std::string c2 = "c2"; + status = db_->CreatePartition(c2, p1); + ASSERT_FALSE(status.ok()); + + status = db_->CreatePartition(c1, p1); + ASSERT_TRUE(status.ok()); + + status = db_->ShowPartitions(c1, partition_names); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(partition_names.size(), 2); + + status = db_->CreatePartition(c1, p1); + ASSERT_FALSE(status.ok()); + + status = db_->DropPartition(c1, "p3"); + ASSERT_FALSE(status.ok()); + + status = db_->DropPartition(c1, p1); + ASSERT_TRUE(status.ok()); + status = db_->ShowPartitions(c1, partition_names); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(partition_names.size(), 1); +} + +TEST_F(SSDBTest, IndexTest) { + LSN_TYPE lsn = 0; + auto next_lsn = [&]() -> decltype(lsn) { + return ++lsn; + }; + + std::string c1 = "c1"; + auto status = CreateCollection(db_, c1, next_lsn()); + ASSERT_TRUE(status.ok()); + + std::stringstream p_name; + auto num = RandomInt(3, 5); + for (auto i = 0; i < num; ++i) { + p_name.str(""); + p_name << "partition_" << i; + status = db_->CreatePartition(c1, p_name.str()); + ASSERT_TRUE(status.ok()); + } + + ScopedSnapshotT ss; + status = Snapshots::GetInstance().GetSnapshot(ss, c1); + ASSERT_TRUE(status.ok()); + + SegmentFileContext sf_context; + SFContextBuilder(sf_context, ss); + + auto new_total = 0; + auto& partitions = ss->GetResources(); + for (auto& kv : partitions) { + num = RandomInt(2, 5); + auto row_cnt = 100; + for (auto i = 0; i < num; ++i) { + ASSERT_TRUE(CreateSegment(ss, kv.first, next_lsn(), sf_context, row_cnt).ok()); + } + new_total += num; + } + + auto field_element_id = ss->GetFieldElementId(sf_context.field_name, sf_context.field_element_name); + ASSERT_NE(field_element_id, 0); + + auto filter1 = [&](SegmentFile::Ptr segment_file) -> bool { + if (segment_file->GetFieldElementId() == field_element_id) { + return true; + } + return false; + }; + + status = Snapshots::GetInstance().GetSnapshot(ss, c1); + ASSERT_TRUE(status.ok()); + auto sf_collector = std::make_shared(ss, filter1); + sf_collector->Iterate(); + ASSERT_EQ(new_total, sf_collector->segment_files_.size()); + + status = db_->DropIndex(c1, sf_context.field_name); +// ASSERT_TRUE(status.ok()); + +// status = Snapshots::GetInstance().GetSnapshot(ss, c1); +// ASSERT_TRUE(status.ok()); +// sf_collector = std::make_shared(ss, filter1); +// sf_collector->Iterate(); +// ASSERT_EQ(0, sf_collector->segment_files_.size()); +// +// { +// auto& field_elements = ss->GetResources(); +// for (auto& kv : field_elements) { +// ASSERT_NE(kv.second->GetID(), field_element_id); +// } +// } +} + +TEST_F(SSDBTest, VisitorTest) { + LSN_TYPE lsn = 0; + auto next_lsn = [&]() -> decltype(lsn) { + return ++lsn; + }; + + std::string c1 = "c1"; + auto status = CreateCollection(db_, c1, next_lsn()); + ASSERT_TRUE(status.ok()); + + std::stringstream p_name; + auto num = RandomInt(1, 3); + for (auto i = 0; i < num; ++i) { + p_name.str(""); + p_name << "partition_" << i; + status = db_->CreatePartition(c1, p_name.str()); + ASSERT_TRUE(status.ok()); + } + + ScopedSnapshotT ss; + status = Snapshots::GetInstance().GetSnapshot(ss, c1); + ASSERT_TRUE(status.ok()); + + SegmentFileContext sf_context; + SFContextBuilder(sf_context, ss); + + auto new_total = 0; + auto& partitions = ss->GetResources(); + ID_TYPE partition_id; + for (auto& kv : partitions) { + num = RandomInt(1, 3); + auto row_cnt = 100; + for (auto i = 0; i < num; ++i) { + ASSERT_TRUE(CreateSegment(ss, kv.first, next_lsn(), sf_context, row_cnt).ok()); + } + new_total += num; + partition_id = kv.first; + } + + status = Snapshots::GetInstance().GetSnapshot(ss, c1); + ASSERT_TRUE(status.ok()); + + auto executor = [&](const Segment::Ptr& segment, SegmentIterator* handler) -> Status { + auto visitor = SegmentVisitor::Build(ss, segment->GetID()); + if (!visitor) { + return Status(milvus::SS_ERROR, "Cannot build segment visitor"); + } + std::cout << visitor->ToString() << std::endl; + return Status::OK(); + }; + + auto segment_handler = std::make_shared(ss, executor); + segment_handler->Iterate(); + std::cout << segment_handler->GetStatus().ToString() << std::endl; + ASSERT_TRUE(segment_handler->GetStatus().ok()); + + auto row_cnt = ss->GetCollectionCommit()->GetRowCount(); + auto new_segment_row_cnt = 1024; + { + OperationContext context; + context.lsn = next_lsn(); + context.prev_partition = ss->GetResource(partition_id); + auto op = std::make_shared(context, ss); + SegmentPtr new_seg; + status = op->CommitNewSegment(new_seg); + ASSERT_TRUE(status.ok()); + SegmentFilePtr seg_file; + auto nsf_context = sf_context; + nsf_context.segment_id = new_seg->GetID(); + nsf_context.partition_id = new_seg->GetPartitionId(); + status = op->CommitNewSegmentFile(nsf_context, seg_file); + ASSERT_TRUE(status.ok()); + auto ctx = op->GetContext(); + ASSERT_TRUE(ctx.new_segment); + auto visitor = SegmentVisitor::Build(ss, ctx.new_segment, ctx.new_segment_files); + ASSERT_TRUE(visitor); + ASSERT_EQ(visitor->GetSegment(), new_seg); + ASSERT_FALSE(visitor->GetSegment()->IsActive()); + + int file_num = 0; + auto field_visitors = visitor->GetFieldVisitors(); + for (auto& kv : field_visitors) { + auto& field_visitor = kv.second; + auto field_element_visitors = field_visitor->GetElementVistors(); + for (auto& kkvv : field_element_visitors) { + auto& field_element_visitor = kkvv.second; + auto file = field_element_visitor->GetFile(); + if (file) { + file_num++; + ASSERT_FALSE(file->IsActive()); + } + } + } + ASSERT_EQ(file_num, 1); + + std::cout << visitor->ToString() << std::endl; + status = op->CommitRowCount(new_segment_row_cnt); + status = op->Push(); + ASSERT_TRUE(status.ok()); + } + status = Snapshots::GetInstance().GetSnapshot(ss, c1); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(ss->GetCollectionCommit()->GetRowCount(), row_cnt + new_segment_row_cnt); + std::cout << ss->ToString() << std::endl; +} + +TEST_F(SSDBTest, InsertTest) { + std::string collection_name = "MERGE_TEST"; + auto status = CreateCollection2(db_, collection_name, 0); + ASSERT_TRUE(status.ok()); + + const uint64_t entity_count = 100; + milvus::engine::DataChunkPtr data_chunk; + BuildEntities(entity_count, 0, data_chunk); + + status = db_->InsertEntities(collection_name, "", data_chunk); + ASSERT_TRUE(status.ok()); + + status = db_->Flush(); + ASSERT_TRUE(status.ok()); + + uint64_t row_count = 0; + status = db_->GetCollectionRowCount(collection_name, row_count); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(row_count, entity_count); +} + +TEST_F(SSDBTest, MergeTest) { + std::string collection_name = "MERGE_TEST"; + auto status = CreateCollection2(db_, collection_name, 0); + ASSERT_TRUE(status.ok()); + + const uint64_t entity_count = 100; + milvus::engine::DataChunkPtr data_chunk; + BuildEntities(entity_count, 0, data_chunk); + + int64_t repeat = 2; + for (int32_t i = 0; i < repeat; i++) { + status = db_->InsertEntities(collection_name, "", data_chunk); + ASSERT_TRUE(status.ok()); + + status = db_->Flush(); + ASSERT_TRUE(status.ok()); + } + + sleep(2); // wait to merge + + uint64_t row_count = 0; + status = db_->GetCollectionRowCount(collection_name, row_count); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(row_count, entity_count * repeat); +} diff --git a/core/unittest/ssdb/test_segment.cpp b/core/unittest/ssdb/test_segment.cpp new file mode 100644 index 000000000000..82ce7d94c8fa --- /dev/null +++ b/core/unittest/ssdb/test_segment.cpp @@ -0,0 +1,159 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include + +#include + +#include "ssdb/utils.h" +#include "db/SnapshotVisitor.h" +#include "db/Types.h" +#include "db/snapshot/IterateHandler.h" +#include "db/snapshot/Resources.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "segment/SSSegmentReader.h" +#include "segment/SSSegmentWriter.h" +#include "segment/Types.h" +#include "utils/Json.h" + +using SegmentVisitor = milvus::engine::SegmentVisitor; + +namespace { +milvus::Status +CreateCollection(std::shared_ptr db, const std::string& collection_name, const LSN_TYPE& lsn) { + CreateCollectionContext context; + context.lsn = lsn; + auto collection_schema = std::make_shared(collection_name); + context.collection = collection_schema; + + int64_t collection_id = 0; + int64_t field_id = 0; + /* field uid */ + auto uid_field = std::make_shared(milvus::engine::DEFAULT_UID_NAME, 0, + milvus::engine::FieldType::UID, milvus::engine::snapshot::JEmpty, field_id); + auto uid_field_element_blt = std::make_shared(collection_id, field_id, + milvus::engine::DEFAULT_BLOOM_FILTER_NAME, milvus::engine::FieldElementType::FET_BLOOM_FILTER); + auto uid_field_element_del = std::make_shared(collection_id, field_id, + milvus::engine::DEFAULT_DELETED_DOCS_NAME, milvus::engine::FieldElementType::FET_DELETED_DOCS); + + field_id++; + /* field vector */ + milvus::json vector_param = {{milvus::knowhere::meta::DIM, 4}}; + auto vector_field = std::make_shared("vector", 0, milvus::engine::FieldType::VECTOR_FLOAT, vector_param, + field_id); + auto vector_field_element_index = std::make_shared(collection_id, field_id, + milvus::engine::DEFAULT_INDEX_NAME, milvus::engine::FieldElementType::FET_INDEX); + + context.fields_schema[uid_field] = {uid_field_element_blt, uid_field_element_del}; + context.fields_schema[vector_field] = {vector_field_element_index}; + + return db->CreateCollection(context); +} +} // namespace + +TEST_F(SSSegmentTest, SegmentTest) { + LSN_TYPE lsn = 0; + auto next_lsn = [&]() -> decltype(lsn) { + return ++lsn; + }; + + std::string db_root = "/tmp/milvus_test/db/table"; + std::string c1 = "c1"; + auto status = CreateCollection(db_, c1, next_lsn()); + ASSERT_TRUE(status.ok()); + + ScopedSnapshotT ss; + status = Snapshots::GetInstance().GetSnapshot(ss, c1); + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(ss); + ASSERT_EQ(ss->GetName(), c1); + + SegmentFileContext sf_context; + SFContextBuilder(sf_context, ss); + + std::vector contexts; + SFContextsBuilder(contexts, ss); + + + // std::cout << ss->ToString() << std::endl; + + auto& partitions = ss->GetResources(); + ID_TYPE partition_id; + for (auto& kv : partitions) { + /* select the first partition */ + partition_id = kv.first; + break; + } + + std::vector raw_uids = {123}; + std::vector raw_vectors = {1, 2, 3, 4}; + + { + /* commit new segment */ + OperationContext context; + context.lsn = next_lsn(); + context.prev_partition = ss->GetResource(partition_id); + auto op = std::make_shared(context, ss); + SegmentPtr new_seg; + status = op->CommitNewSegment(new_seg); + ASSERT_TRUE(status.ok()); + + /* commit new segment file */ + for (auto& cctx : contexts) { + SegmentFilePtr seg_file; + auto nsf_context = cctx; + nsf_context.segment_id = new_seg->GetID(); + nsf_context.partition_id = new_seg->GetPartitionId(); + status = op->CommitNewSegmentFile(nsf_context, seg_file); + } + + /* build segment visitor */ + auto ctx = op->GetContext(); + ASSERT_TRUE(ctx.new_segment); + auto visitor = SegmentVisitor::Build(ss, ctx.new_segment, ctx.new_segment_files); + ASSERT_TRUE(visitor); + ASSERT_EQ(visitor->GetSegment(), new_seg); + ASSERT_FALSE(visitor->GetSegment()->IsActive()); + // std::cout << visitor->ToString() << std::endl; + // std::cout << ss->ToString() << std::endl; + + /* write data */ + milvus::segment::SSSegmentWriter segment_writer(db_root, visitor); + +// status = segment_writer.AddChunk("test", raw_vectors, raw_uids); +// ASSERT_TRUE(status.ok()) +// +// status = segment_writer.Serialize(); +// ASSERT_TRUE(status.ok()); + + /* read data */ +// milvus::segment::SSSegmentReader segment_reader(db_root, visitor); +// +// status = segment_reader.Load(); +// ASSERT_TRUE(status.ok()); +// +// milvus::segment::SegmentPtr segment_ptr; +// status = segment_reader.GetSegment(segment_ptr); +// ASSERT_TRUE(status.ok()); +// +// auto& out_uids = segment_ptr->vectors_ptr_->GetUids(); +// ASSERT_EQ(raw_uids.size(), out_uids.size()); +// ASSERT_EQ(raw_uids[0], out_uids[0]); +// auto& out_vectors = segment_ptr->vectors_ptr_->GetData(); +// ASSERT_EQ(raw_vectors.size(), out_vectors.size()); +// ASSERT_EQ(raw_vectors[0], out_vectors[0]); + } + + status = db_->DropCollection(c1); + ASSERT_TRUE(status.ok()); +} diff --git a/core/unittest/ssdb/test_snapshot.cpp b/core/unittest/ssdb/test_snapshot.cpp index b41a6a591325..27206afc91f9 100644 --- a/core/unittest/ssdb/test_snapshot.cpp +++ b/core/unittest/ssdb/test_snapshot.cpp @@ -18,163 +18,26 @@ #include #include "ssdb/utils.h" -#include "db/snapshot/CompoundOperations.h" -#include "db/snapshot/Context.h" -#include "db/snapshot/EventExecutor.h" -#include "db/snapshot/OperationExecutor.h" -#include "db/snapshot/ReferenceProxy.h" -#include "db/snapshot/ResourceHolders.h" -#include "db/snapshot/ScopedResource.h" -#include "db/snapshot/Snapshots.h" -#include "db/snapshot/Store.h" -#include "db/snapshot/WrappedTypes.h" - -using ID_TYPE = milvus::engine::snapshot::ID_TYPE; -using IDS_TYPE = milvus::engine::snapshot::IDS_TYPE; -using LSN_TYPE = milvus::engine::snapshot::LSN_TYPE; -using MappingT = milvus::engine::snapshot::MappingT; -using LoadOperationContext = milvus::engine::snapshot::LoadOperationContext; -using CreateCollectionContext = milvus::engine::snapshot::CreateCollectionContext; -using SegmentFileContext = milvus::engine::snapshot::SegmentFileContext; -using OperationContext = milvus::engine::snapshot::OperationContext; -using PartitionContext = milvus::engine::snapshot::PartitionContext; -using BuildOperation = milvus::engine::snapshot::BuildOperation; -using MergeOperation = milvus::engine::snapshot::MergeOperation; -using CreateCollectionOperation = milvus::engine::snapshot::CreateCollectionOperation; -using NewSegmentOperation = milvus::engine::snapshot::NewSegmentOperation; -using DropPartitionOperation = milvus::engine::snapshot::DropPartitionOperation; -using CreatePartitionOperation = milvus::engine::snapshot::CreatePartitionOperation; -using DropCollectionOperation = milvus::engine::snapshot::DropCollectionOperation; -using CollectionCommitsHolder = milvus::engine::snapshot::CollectionCommitsHolder; -using CollectionsHolder = milvus::engine::snapshot::CollectionsHolder; -using CollectionScopedT = milvus::engine::snapshot::CollectionScopedT; -using Collection = milvus::engine::snapshot::Collection; -using CollectionPtr = milvus::engine::snapshot::CollectionPtr; -using Partition = milvus::engine::snapshot::Partition; -using PartitionPtr = milvus::engine::snapshot::PartitionPtr; -using Segment = milvus::engine::snapshot::Segment; -using SegmentPtr = milvus::engine::snapshot::SegmentPtr; -using SegmentFile = milvus::engine::snapshot::SegmentFile; -using SegmentFilePtr = milvus::engine::snapshot::SegmentFilePtr; -using Field = milvus::engine::snapshot::Field; -using FieldElement = milvus::engine::snapshot::FieldElement; -using Snapshots = milvus::engine::snapshot::Snapshots; -using ScopedSnapshotT = milvus::engine::snapshot::ScopedSnapshotT; -using ReferenceProxy = milvus::engine::snapshot::ReferenceProxy; -using Queue = milvus::BlockingQueue; -using TQueue = milvus::BlockingQueue>; -using SoftDeleteCollectionOperation = milvus::engine::snapshot::SoftDeleteOperation; -using ParamsField = milvus::engine::snapshot::ParamsField; -using IteratePartitionHandler = milvus::engine::snapshot::IterateHandler; - -struct PartitionCollector : public IteratePartitionHandler { - using ResourceT = Partition; - using BaseT = IteratePartitionHandler; - explicit PartitionCollector(ScopedSnapshotT ss) : BaseT(ss) {} - - milvus::Status - PreIterate() override { - partition_names_.clear(); - return milvus::Status::OK(); +#include "db/snapshot/HandlerFactory.h" + +Status +GetFirstCollectionID(ID_TYPE& result_id) { + std::vector ids; + auto status = Snapshots::GetInstance().GetCollectionIds(ids); + if (status.ok()) { + result_id = ids.at(0); } - milvus::Status - Handle(const typename ResourceT::Ptr& partition) override { - partition_names_.push_back(partition->GetName()); - return milvus::Status::OK(); - } - - std::vector partition_names_; -}; - -struct WaitableObj { - bool notified_ = false; - std::mutex mutex_; - std::condition_variable cv_; - - void - Wait() { - std::unique_lock lck(mutex_); - if (!notified_) { - cv_.wait(lck); - } - notified_ = false; - } - - void - Notify() { - std::unique_lock lck(mutex_); - notified_ = true; - lck.unlock(); - cv_.notify_one(); - } -}; - -ScopedSnapshotT -CreateCollection(const std::string& collection_name, const LSN_TYPE& lsn) { - CreateCollectionContext context; - context.lsn = lsn; - auto collection_schema = std::make_shared(collection_name); - context.collection = collection_schema; - auto vector_field = std::make_shared("vector", 0, - milvus::engine::snapshot::FieldType::VECTOR); - auto vector_field_element = std::make_shared(0, 0, "ivfsq8", - milvus::engine::snapshot::FieldElementType::IVFSQ8); - auto int_field = std::make_shared("int", 0, - milvus::engine::snapshot::FieldType::INT32); - context.fields_schema[vector_field] = {vector_field_element}; - context.fields_schema[int_field] = {}; - - auto op = std::make_shared(context); - op->Push(); - ScopedSnapshotT ss; - auto status = op->GetSnapshot(ss); - return ss; -} - -ScopedSnapshotT -CreatePartition(const std::string& collection_name, const PartitionContext& p_context, const LSN_TYPE& lsn) { - ScopedSnapshotT curr_ss; - ScopedSnapshotT ss; - auto status = Snapshots::GetInstance().GetSnapshot(ss, collection_name); - if (!status.ok()) { - std::cout << status.ToString() << std::endl; - return curr_ss; - } - - OperationContext context; - context.lsn = lsn; - auto op = std::make_shared(context, ss); - - PartitionPtr partition; - status = op->CommitNewPartition(p_context, partition); - if (!status.ok()) { - std::cout << status.ToString() << std::endl; - return curr_ss; - } - - status = op->Push(); - if (!status.ok()) { - std::cout << status.ToString() << std::endl; - return curr_ss; - } - - status = op->GetSnapshot(curr_ss); - if (!status.ok()) { - std::cout << status.ToString() << std::endl; - return curr_ss; - } - return curr_ss; + return status; } TEST_F(SnapshotTest, ResourcesTest) { int nprobe = 16; milvus::json params = {{"nprobe", nprobe}}; - ParamsField p_field(params.dump()); - ASSERT_EQ(params.dump(), p_field.GetParams()); - ASSERT_EQ(params, p_field.GetParamsJson()); + ParamsField p_field(params); + ASSERT_EQ(params, p_field.GetParams()); - auto nprobe_real = p_field.GetParamsJson().at("nprobe").get(); + auto nprobe_real = p_field.GetParams().at("nprobe").get(); ASSERT_EQ(nprobe, nprobe_real); } @@ -263,7 +126,8 @@ TEST_F(SnapshotTest, ScopedResourceTest) { } TEST_F(SnapshotTest, ResourceHoldersTest) { - ID_TYPE collection_id = 1; + ID_TYPE collection_id; + ASSERT_TRUE(GetFirstCollectionID(collection_id).ok()); auto collection = CollectionsHolder::GetInstance().GetResource(collection_id, false); auto prev_cnt = collection->ref_count(); { @@ -278,6 +142,8 @@ TEST_F(SnapshotTest, ResourceHoldersTest) { ASSERT_EQ(collection_3->ref_count(), 1+prev_cnt); } + std::this_thread::sleep_for(std::chrono::milliseconds(80)); + if (prev_cnt == 0) { auto collection_4 = CollectionsHolder::GetInstance().GetResource(collection_id, false); ASSERT_TRUE(!collection_4); @@ -405,6 +271,7 @@ TEST_F(SnapshotTest, DropCollectionTest) { auto ss_2 = CreateCollection(collection_name, ++lsn); status = Snapshots::GetInstance().GetSnapshot(lss, collection_name); +// EXPECT_DEATH({assert(1 == 2);}, "nullptr") ASSERT_TRUE(status.ok()); ASSERT_EQ(ss_2->GetID(), lss->GetID()); ASSERT_NE(prev_ss_id, ss_2->GetID()); @@ -421,7 +288,7 @@ TEST_F(SnapshotTest, ConCurrentCollectionOperation) { ID_TYPE stale_ss_id; auto worker1 = [&]() { - milvus::Status status; + Status status; auto ss = CreateCollection(collection_name, ++lsn); ASSERT_TRUE(ss); ASSERT_EQ(ss->GetName(), collection_name); @@ -563,48 +430,263 @@ TEST_F(SnapshotTest, PartitionTest) { } } -// TODO: Open this test later -/* TEST_F(SnapshotTest, PartitionTest2) { */ -/* std::string collection_name("c1"); */ -/* LSN_TYPE lsn = 1; */ -/* milvus::Status status; */ - -/* auto ss = CreateCollection(collection_name, ++lsn); */ -/* ASSERT_TRUE(ss); */ -/* ASSERT_EQ(lsn, ss->GetMaxLsn()); */ - -/* OperationContext context; */ -/* context.lsn = lsn; */ -/* auto cp_op = std::make_shared(context, ss); */ -/* std::string partition_name("p1"); */ -/* PartitionContext p_ctx; */ -/* p_ctx.name = partition_name; */ -/* PartitionPtr partition; */ -/* status = cp_op->CommitNewPartition(p_ctx, partition); */ -/* ASSERT_TRUE(status.ok()); */ -/* ASSERT_TRUE(partition); */ -/* ASSERT_EQ(partition->GetName(), partition_name); */ -/* ASSERT_FALSE(partition->IsActive()); */ -/* ASSERT_TRUE(partition->HasAssigned()); */ - -/* status = cp_op->Push(); */ -/* ASSERT_FALSE(status.ok()); */ -/* } */ +TEST_F(SnapshotTest, PartitionTest2) { + std::string collection_name("c1"); + LSN_TYPE lsn = 1; + milvus::Status status; + + auto ss = CreateCollection(collection_name, ++lsn); + ASSERT_TRUE(ss); + ASSERT_EQ(lsn, ss->GetMaxLsn()); + + OperationContext context; + context.lsn = lsn; + auto cp_op = std::make_shared(context, ss); + std::string partition_name("p1"); + PartitionContext p_ctx; + p_ctx.name = partition_name; + PartitionPtr partition; + status = cp_op->CommitNewPartition(p_ctx, partition); + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(partition); + ASSERT_EQ(partition->GetName(), partition_name); + ASSERT_FALSE(partition->IsActive()); + ASSERT_TRUE(partition->HasAssigned()); + + status = cp_op->Push(); + ASSERT_FALSE(status.ok()); +} + +TEST_F(SnapshotTest, IndexTest) { + LSN_TYPE lsn = 0; + auto next_lsn = [&]() -> decltype(lsn) { + return ++lsn; + }; + + std::vector ids; + auto status = Snapshots::GetInstance().GetCollectionIds(ids); + ASSERT_TRUE(status.ok()) << status.message(); + + auto collection_id = ids.at(0); + + ScopedSnapshotT ss; + status = Snapshots::GetInstance().GetSnapshot(ss, collection_id); + ASSERT_TRUE(status.ok()) << status.message(); + + SegmentFileContext sf_context; + SFContextBuilder(sf_context, ss); + + OperationContext context; + context.lsn = next_lsn(); + context.prev_partition = ss->GetResource(sf_context.partition_id); + auto build_op = std::make_shared(context, ss); + SegmentFilePtr seg_file; + status = build_op->CommitNewSegmentFile(sf_context, seg_file); + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(seg_file); + auto op_ctx = build_op->GetContext(); + ASSERT_EQ(seg_file, op_ctx.new_segment_files[0]); + + build_op->Push(); + status = build_op->GetSnapshot(ss); + ASSERT_TRUE(status.ok()) << status.message(); + + auto filter = [&](SegmentFile::Ptr segment_file) -> bool { + return segment_file->GetSegmentId() == seg_file->GetSegmentId(); + }; + + auto filter2 = [&](SegmentFile::Ptr segment_file) -> bool { + return true; + }; + + auto sf_collector = std::make_shared(ss, filter); + sf_collector->Iterate(); + + auto it_found = sf_collector->segment_files_.find(seg_file->GetID()); + ASSERT_NE(it_found, sf_collector->segment_files_.end()); + + status = Snapshots::GetInstance().GetSnapshot(ss, collection_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + + OperationContext drop_ctx; + drop_ctx.lsn = next_lsn(); + drop_ctx.stale_segment_files.push_back(seg_file); + auto drop_op = std::make_shared(drop_ctx, ss); + status = drop_op->Push(); + ASSERT_TRUE(status.ok()); + + status = drop_op->GetSnapshot(ss); + ASSERT_TRUE(status.ok()); + + sf_collector = std::make_shared(ss, filter); + sf_collector->Iterate(); + + it_found = sf_collector->segment_files_.find(seg_file->GetID()); + ASSERT_EQ(it_found, sf_collector->segment_files_.end()); + + PartitionContext pp_ctx; + std::stringstream p_name_stream; + + auto num = RandomInt(3, 5); + for (auto i = 0; i < num; ++i) { + p_name_stream.str(""); + p_name_stream << "partition_" << i; + pp_ctx.name = p_name_stream.str(); + ss = CreatePartition(ss->GetName(), pp_ctx, next_lsn()); + ASSERT_TRUE(ss); + } + ASSERT_EQ(ss->NumberOfPartitions(), num + 1); + + sf_collector = std::make_shared(ss, filter2); + sf_collector->Iterate(); + auto prev_total = sf_collector->segment_files_.size(); + + auto new_total = 0; + auto partitions = ss->GetResources(); + for (auto& kv : partitions) { + num = RandomInt(2, 5); + auto row_cnt = 1024; + for (auto i = 0; i < num; ++i) { + ASSERT_TRUE(CreateSegment(ss, kv.first, next_lsn(), sf_context, row_cnt).ok()); + } + new_total += num; + } + + status = Snapshots::GetInstance().GetSnapshot(ss, ss->GetName()); + ASSERT_TRUE(status.ok()); + + sf_collector = std::make_shared(ss, filter2); + sf_collector->Iterate(); + auto total = sf_collector->segment_files_.size(); + ASSERT_EQ(total, prev_total + new_total); + + auto field_element_id = ss->GetFieldElementId(sf_context.field_name, + sf_context.field_element_name); + ASSERT_NE(field_element_id, 0); + + auto filter3 = [&](SegmentFile::Ptr segment_file) -> bool { + return segment_file->GetFieldElementId() == field_element_id; + }; + sf_collector = std::make_shared(ss, filter3); + sf_collector->Iterate(); + auto specified_segment_files_cnt = sf_collector->segment_files_.size(); + + OperationContext d_a_i_ctx; + d_a_i_ctx.lsn = next_lsn(); + d_a_i_ctx.stale_field_elements.push_back(ss->GetResource(field_element_id)); + + FieldElement::Ptr fe; + status = ss->GetFieldElement(sf_context.field_name, sf_context.field_element_name, + fe); + + ASSERT_TRUE(status.ok()); + ASSERT_EQ(fe, d_a_i_ctx.stale_field_elements[0]); + + std::cout << ss->ToString() << std::endl; + auto drop_all_index_op = std::make_shared(d_a_i_ctx, ss); + status = drop_all_index_op->Push(); + std::cout << status.ToString() << std::endl; + ASSERT_TRUE(status.ok()); + + status = drop_all_index_op->GetSnapshot(ss); + ASSERT_TRUE(status.ok()); + + sf_collector = std::make_shared(ss, filter2); + sf_collector->Iterate(); + ASSERT_EQ(sf_collector->segment_files_.size(), total - specified_segment_files_cnt); + + { + auto& field_elements = ss->GetResources(); + for (auto& kv : field_elements) { + ASSERT_NE(kv.second->GetID(), field_element_id); + } + } + + { + auto& fields = ss->GetResources(); + OperationContext dai_ctx; + for (auto& field : fields) { + auto elements = ss->GetFieldElementsByField(field.second->GetName()); + ASSERT_GE(elements.size(), 1); + dai_ctx.stale_field_elements.push_back(elements[0]); + } + ASSERT_GT(dai_ctx.stale_field_elements.size(), 1); + auto op = std::make_shared(dai_ctx, ss); + status = op->Push(); + ASSERT_FALSE(status.ok()); + } + + { + auto& fields = ss->GetResources(); + ASSERT_GT(fields.size(), 0); + OperationContext dai_ctx; + std::string field_name; + std::set stale_element_ids; + for (auto& field : fields) { + field_name = field.second->GetName(); + auto elements = ss->GetFieldElementsByField(field_name); + ASSERT_GE(elements.size(), 2); + for (auto& element : elements) { + stale_element_ids.insert(element->GetID()); + } + dai_ctx.stale_field_elements = std::move(elements); + break; + } + + std::set stale_segment_ids; + auto& segment_files = ss->GetResources(); + for (auto& kv : segment_files) { + auto& id = kv.first; + auto& segment_file = kv.second; + auto it = stale_element_ids.find(segment_file->GetFieldElementId()); + if (it != stale_element_ids.end()) { + stale_segment_ids.insert(id); + } + } + + auto prev_segment_file_cnt = segment_files.size(); + + ASSERT_GT(dai_ctx.stale_field_elements.size(), 1); + auto op = std::make_shared(dai_ctx, ss); + status = op->Push(); + ASSERT_TRUE(status.ok()); + status = op->GetSnapshot(ss); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(ss->GetResources().size() + stale_segment_ids.size(), prev_segment_file_cnt); + } +} TEST_F(SnapshotTest, OperationTest) { - milvus::Status status; std::string to_string; LSN_TYPE lsn; + Status status; + + /* ID_TYPE collection_id; */ + /* ASSERT_TRUE(GetFirstCollectionID(collection_id).ok()); */ + std::string collection_name("c1"); + auto ss = CreateCollection(collection_name, ++lsn); + ASSERT_TRUE(ss); + SegmentFileContext sf_context; - sf_context.field_name = "f_1_1"; - sf_context.field_element_name = "fe_1_1"; - sf_context.segment_id = 1; - sf_context.partition_id = 1; + SFContextBuilder(sf_context, ss); + + auto& partitions = ss->GetResources(); + auto total_row_cnt = 0; + for (auto& kv : partitions) { + auto num = RandomInt(2, 5); + for (auto i = 0; i < num; ++i) { + auto row_cnt = RandomInt(100, 200); + ASSERT_TRUE(CreateSegment(ss, kv.first, ++lsn, sf_context, row_cnt).ok()); + total_row_cnt += row_cnt; + } + } - ScopedSnapshotT ss; - status = Snapshots::GetInstance().GetSnapshot(ss, 1); - std::cout << status.ToString() << std::endl; + status = Snapshots::GetInstance().GetSnapshot(ss, collection_name); ASSERT_TRUE(status.ok()); + ASSERT_EQ(total_row_cnt, ss->GetCollectionCommit()->GetRowCount()); + + auto total_size = ss->GetCollectionCommit()->GetSize(); + auto ss_id = ss->GetID(); lsn = ss->GetMaxLsn() + 1; @@ -613,11 +695,16 @@ TEST_F(SnapshotTest, OperationTest) { auto collection_commit = CollectionCommitsHolder::GetInstance().GetResource(ss_id, false); /* snapshot::SegmentCommitsHolder::GetInstance().GetResource(prev_segment_commit->GetID()); */ ASSERT_TRUE(collection_commit); - ASSERT_TRUE(collection_commit->ToString().empty()); + std::cout << collection_commit->ToString() << std::endl; } OperationContext merge_ctx; + auto merge_segment_row_cnt = 0; + std::set stale_segment_commit_ids; + SFContextBuilder(sf_context, ss); + + std::cout << ss->ToString() << std::endl; ID_TYPE new_seg_id; ScopedSnapshotT new_ss; @@ -625,18 +712,31 @@ TEST_F(SnapshotTest, OperationTest) { { OperationContext context; context.lsn = ++lsn; - auto build_op = std::make_shared(context, ss); + auto build_op = std::make_shared(context, ss); SegmentFilePtr seg_file; status = build_op->CommitNewSegmentFile(sf_context, seg_file); + std::cout << status.ToString() << std::endl; ASSERT_TRUE(status.ok()); ASSERT_TRUE(seg_file); auto prev_segment_commit = ss->GetSegmentCommitBySegmentId(seg_file->GetSegmentId()); auto prev_segment_commit_mappings = prev_segment_commit->GetMappings(); ASSERT_FALSE(prev_segment_commit->ToString().empty()); + auto new_size = RandomInt(1000, 20000); + seg_file->SetSize(new_size); + total_size += new_size; + + auto delta = prev_segment_commit->GetRowCount() / 2; + build_op->CommitRowCountDelta(delta); + total_row_cnt -= delta; + build_op->Push(); status = build_op->GetSnapshot(ss); + ASSERT_TRUE(status.ok()); ASSERT_GT(ss->GetID(), ss_id); + ASSERT_EQ(ss->GetCollectionCommit()->GetRowCount(), total_row_cnt); + ASSERT_EQ(ss->GetCollectionCommit()->GetSize(), total_size); + std::cout << ss->ToString() << std::endl; auto segment_commit = ss->GetSegmentCommitBySegmentId(seg_file->GetSegmentId()); auto segment_commit_mappings = segment_commit->GetMappings(); @@ -647,22 +747,36 @@ TEST_F(SnapshotTest, OperationTest) { auto seg = ss->GetResource(seg_file->GetSegmentId()); ASSERT_TRUE(seg); merge_ctx.stale_segments.push_back(seg); + merge_segment_row_cnt += ss->GetSegmentCommitBySegmentId(seg->GetID())->GetRowCount(); + stale_segment_commit_ids.insert(segment_commit->GetID()); } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - // Check stale snapshot has been deleted from store - { - auto collection_commit = CollectionCommitsHolder::GetInstance().GetResource(ss_id, false); - ASSERT_FALSE(collection_commit); - } +// std::this_thread::sleep_for(std::chrono::milliseconds(100)); +// // Check stale snapshot has been deleted from store +// { +// auto collection_commit = CollectionCommitsHolder::GetInstance().GetResource(ss_id, false); +// ASSERT_FALSE(collection_commit); +// } ss_id = ss->GetID(); ID_TYPE partition_id; { + std::vector partitions; + auto executor = [&](const PartitionPtr& partition, + PartitionIterator* itr) -> Status { + if (partition->GetCollectionId() != ss->GetCollectionId()) { + return Status::OK(); + } + partitions.push_back(partition); + return Status::OK(); + }; + auto iterator = std::make_shared(ss, executor); + iterator->Iterate(); + OperationContext context; context.lsn = ++lsn; - context.prev_partition = ss->GetResource(1); + context.prev_partition = ss->GetResource(partitions[0]->GetID()); auto op = std::make_shared(context, ss); SegmentPtr new_seg; status = op->CommitNewSegment(new_seg); @@ -671,19 +785,34 @@ TEST_F(SnapshotTest, OperationTest) { SegmentFilePtr seg_file; status = op->CommitNewSegmentFile(sf_context, seg_file); ASSERT_TRUE(status.ok()); + + auto new_segment_row_cnt = RandomInt(100, 200); + status = op->CommitRowCount(new_segment_row_cnt); + ASSERT_TRUE(status.ok()); + total_row_cnt += new_segment_row_cnt; + + auto new_size = new_segment_row_cnt * 5; + seg_file->SetSize(new_size); + total_size += new_size; + status = op->Push(); ASSERT_TRUE(status.ok()); status = op->GetSnapshot(ss); ASSERT_GT(ss->GetID(), ss_id); ASSERT_TRUE(status.ok()); + ASSERT_EQ(ss->GetCollectionCommit()->GetRowCount(), total_row_cnt); + ASSERT_EQ(ss->GetCollectionCommit()->GetSize(), total_size); auto segment_commit = ss->GetSegmentCommitBySegmentId(seg_file->GetSegmentId()); auto segment_commit_mappings = segment_commit->GetMappings(); MappingT expected_segment_mappings; expected_segment_mappings.insert(seg_file->GetID()); ASSERT_EQ(expected_segment_mappings, segment_commit_mappings); + merge_ctx.stale_segments.push_back(new_seg); + merge_segment_row_cnt += ss->GetSegmentCommitBySegmentId(new_seg->GetID())->GetRowCount(); + partition_id = segment_commit->GetPartitionId(); stale_segment_commit_ids.insert(segment_commit->GetID()); auto partition = ss->GetResource(partition_id); @@ -707,14 +836,29 @@ TEST_F(SnapshotTest, OperationTest) { SegmentFilePtr seg_file; status = op->CommitNewSegmentFile(sf_context, seg_file); ASSERT_TRUE(status.ok()); + + auto new_size = RandomInt(1000, 20000); + seg_file->SetSize(new_size); + auto stale_size = 0; + for (auto& stale_seg : merge_ctx.stale_segments) { + stale_size += ss->GetSegmentCommitBySegmentId(stale_seg->GetID())->GetSize(); + } + total_size += new_size; + total_size -= stale_size; + status = op->Push(); - ASSERT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()) << status.ToString(); std::cout << op->ToString() << std::endl; status = op->GetSnapshot(ss); ASSERT_GT(ss->GetID(), ss_id); ASSERT_TRUE(status.ok()); + ASSERT_EQ(total_size, ss->GetCollectionCommit()->GetSize()); auto segment_commit = ss->GetSegmentCommitBySegmentId(new_seg->GetID()); + ASSERT_EQ(segment_commit->GetRowCount(), merge_segment_row_cnt); + ASSERT_EQ(segment_commit->GetSize(), new_size); + ASSERT_EQ(ss->GetCollectionCommit()->GetRowCount(), total_row_cnt); + auto new_partition_commit = ss->GetPartitionCommitByPartitionId(partition_id); auto new_mappings = new_partition_commit->GetMappings(); auto prev_mappings = prev_partition_commit->GetMappings(); @@ -736,7 +880,7 @@ TEST_F(SnapshotTest, OperationTest) { { OperationContext context; context.lsn = ++lsn; - auto build_op = std::make_shared(context, new_ss); + auto build_op = std::make_shared(context, new_ss); SegmentFilePtr seg_file; auto new_sf_context = sf_context; new_sf_context.segment_id = new_seg_id; @@ -744,6 +888,68 @@ TEST_F(SnapshotTest, OperationTest) { ASSERT_FALSE(status.ok()); } + { + OperationContext context; + context.lsn = ++lsn; + auto op = std::make_shared(context, ss); + SegmentFilePtr seg_file; + auto new_sf_context = sf_context; + new_sf_context.segment_id = merge_seg->GetID(); + status = op->CommitNewSegmentFile(new_sf_context, seg_file); + ASSERT_TRUE(status.ok()); + auto prev_sc = ss->GetSegmentCommitBySegmentId(merge_seg->GetID()); + ASSERT_TRUE(prev_sc); + auto delta = prev_sc->GetRowCount() + 1; + op->CommitRowCountDelta(delta); + status = op->Push(); + std::cout << status.ToString() << std::endl; + ASSERT_FALSE(status.ok()); + } + + std::string new_fe_name = "fe_index"; + { + status = Snapshots::GetInstance().GetSnapshot(ss, collection_name); + ASSERT_TRUE(status.ok()); + + auto field = ss->GetField(sf_context.field_name); + ASSERT_TRUE(field); + auto new_fe = std::make_shared(ss->GetCollectionId(), + field->GetID(), new_fe_name, milvus::engine::FieldElementType::FET_INDEX); + + OperationContext context; + context.lsn = ++lsn; + context.new_field_elements.push_back(new_fe); + auto op = std::make_shared(context, ss); + status = op->Push(); + ASSERT_TRUE(status.ok()); + + status = op->GetSnapshot(ss); + ASSERT_TRUE(status.ok()); + + std::cout << ss->ToString() << std::endl; + } + + { + auto snapshot_id = ss->GetID(); + auto field = ss->GetField(sf_context.field_name); + ASSERT_TRUE(field); + FieldElementPtr new_fe; + status = ss->GetFieldElement(sf_context.field_name, new_fe_name, new_fe); + ASSERT_TRUE(status.ok()); + ASSERT_TRUE(new_fe); + + OperationContext context; + context.lsn = ++lsn; + context.new_field_elements.push_back(new_fe); + auto op = std::make_shared(context, ss); + status = op->Push(); + ASSERT_FALSE(status.ok()); + + status = Snapshots::GetInstance().GetSnapshot(ss, collection_name); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(snapshot_id, ss->GetID()); + } + // 1. Build start // 2. Commit new seg file of build operation // 3. Drop collection @@ -751,7 +957,7 @@ TEST_F(SnapshotTest, OperationTest) { { OperationContext context; context.lsn = ++lsn; - auto build_op = std::make_shared(context, ss); + auto build_op = std::make_shared(context, ss); SegmentFilePtr seg_file; auto new_sf_context = sf_context; new_sf_context.segment_id = merge_seg->GetID(); @@ -767,16 +973,17 @@ TEST_F(SnapshotTest, OperationTest) { ASSERT_FALSE(build_op->GetStatus().ok()); std::cout << build_op->ToString() << std::endl; } + Snapshots::GetInstance().Reset(); } TEST_F(SnapshotTest, CompoundTest1) { - milvus::Status status; + Status status; std::atomic lsn = 0; auto next_lsn = [&]() -> decltype(lsn) { return ++lsn; }; - LSN_TYPE pid = 0; + std::atomic pid = 0; auto next_pid = [&]() -> decltype(pid) { return ++pid; }; @@ -820,7 +1027,7 @@ TEST_F(SnapshotTest, CompoundTest1) { OperationContext context; context.lsn = next_lsn(); - auto build_op = std::make_shared(context, latest_ss); + auto build_op = std::make_shared(context, latest_ss); SegmentFilePtr seg_file; build_sf_context.segment_id = seg_id; status = build_op->CommitNewSegmentFile(build_sf_context, seg_file); @@ -1112,12 +1319,12 @@ TEST_F(SnapshotTest, CompoundTest1) { TEST_F(SnapshotTest, CompoundTest2) { - milvus::Status status; + Status status; LSN_TYPE lsn = 0; auto next_lsn = [&]() -> LSN_TYPE& { return ++lsn; }; - LSN_TYPE pid = 0; + std::atomic pid = 0; auto next_pid = [&]() -> LSN_TYPE { return ++pid; }; @@ -1164,7 +1371,7 @@ TEST_F(SnapshotTest, CompoundTest2) { OperationContext context; context.lsn = next_lsn(); - auto build_op = std::make_shared(context, latest_ss); + auto build_op = std::make_shared(context, latest_ss); SegmentFilePtr seg_file; build_sf_context.segment_id = seg_id; status = build_op->CommitNewSegmentFile(build_sf_context, seg_file); @@ -1529,7 +1736,6 @@ TEST_F(SnapshotTest, CompoundTest2) { if (it == stale_partitions.end()) { continue; } - /* std::cout << "stale Segment " << seg_p.first << std::endl; */ expect_segments.erase(seg_p.first); } @@ -1541,3 +1747,47 @@ TEST_F(SnapshotTest, CompoundTest2) { ASSERT_EQ(final_segments, expect_segments); // TODO: Check Total Segment Files Cnt } + +struct GCSchedule { + static constexpr const char* Name = "GCSchedule"; +}; + +struct FlushSchedule { + static constexpr const char* Name = "FlushSchedule"; +}; + +using IEventHandler = milvus::engine::snapshot::IEventHandler; +/* struct SampleHandler : public IEventHandler { */ +/* static constexpr const char* EventName = "SampleHandler"; */ +/* const char* */ +/* GetEventName() const override { */ +/* return EventName; */ +/* } */ +/* }; */ + +REGISTER_HANDLER(GCSchedule, IEventHandler); +/* REGISTER_HANDLER(GCSchedule, SampleHandler); */ +REGISTER_HANDLER(FlushSchedule, IEventHandler); +/* REGISTER_HANDLER(FlushSchedule, SampleHandler); */ + +using GCScheduleFactory = milvus::engine::snapshot::HandlerFactory; +using FlushScheduleFactory = milvus::engine::snapshot::HandlerFactory; + +TEST_F(SnapshotTest, RegistryTest) { + { + auto& factory = GCScheduleFactory::GetInstance(); + auto ihandler = factory.GetHandler(IEventHandler::EventName); + ASSERT_TRUE(ihandler); + /* auto sihandler = factory.GetHandler(SampleHandler::EventName); */ + /* ASSERT_TRUE(sihandler); */ + /* ASSERT_EQ(SampleHandler::EventName, sihandler->GetEventName()); */ + } + { + /* auto& factory = FlushScheduleFactory::GetInstance(); */ + /* auto ihandler = factory.GetHandler(IEventHandler::EventName); */ + /* ASSERT_TRUE(ihandler); */ + /* auto sihandler = factory.GetHandler(SampleHandler::EventName); */ + /* ASSERT_TRUE(sihandler); */ + /* ASSERT_EQ(SampleHandler::EventName, sihandler->GetEventName()); */ + } +} diff --git a/core/unittest/ssdb/test_ss_event.cpp b/core/unittest/ssdb/test_ss_event.cpp new file mode 100644 index 000000000000..2d66cff6c157 --- /dev/null +++ b/core/unittest/ssdb/test_ss_event.cpp @@ -0,0 +1,154 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "db/snapshot/InActiveResourcesGCEvent.h" +#include "ssdb/utils.h" + +using CollectionCommit = milvus::engine::snapshot::CollectionCommit; +using CollectionCommitPtr = milvus::engine::snapshot::CollectionCommitPtr; +using PartitionCommit = milvus::engine::snapshot::PartitionCommit; +using PartitionCommitPtr = milvus::engine::snapshot::PartitionCommitPtr; +using SegmentCommit = milvus::engine::snapshot::SegmentCommit; +using SegmentCommitPtr = milvus::engine::snapshot::SegmentCommitPtr; +using SchemaCommit = milvus::engine::snapshot::SchemaCommit; +using FieldCommit = milvus::engine::snapshot::FieldCommit; + +using FType = milvus::engine::FieldType; +using FEType = milvus::engine::FieldElementType; + +using InActiveResourcesGCEvent = milvus::engine::snapshot::InActiveResourcesGCEvent; + +TEST_F(SSEventTest, TestInActiveResGcEvent) { + CollectionPtr collection; + auto status = store_->CreateResource(Collection("test_gc_c1"), collection); + ASSERT_TRUE(status.ok()) << status.ToString(); + + CollectionPtr inactive_collection; + auto c = Collection("test_gc_c2"); + c.Deactivate(); + status = store_->CreateResource(std::move(c), inactive_collection); + ASSERT_TRUE(status.ok()) << status.ToString(); + + CollectionPtr active_collection; + auto c_2 = Collection("test_gc_c3"); + c_2.Activate(); + status = store_->CreateResource(std::move(c_2), active_collection); + ASSERT_TRUE(status.ok()) << status.ToString(); + + PartitionPtr partition; + status = store_->CreateResource(Partition("test_gc_c1_p1", collection->GetID()), partition); + ASSERT_TRUE(status.ok()) << status.ToString(); + + PartitionPtr inactive_partition; + auto p = Partition("test_gc_c1_p2", collection->GetID()); + p.Deactivate(); + status = store_->CreateResource(std::move(p), inactive_partition); + ASSERT_TRUE(status.ok()) << status.ToString(); + + PartitionCommitPtr partition_commit; + status = store_->CreateResource(PartitionCommit(collection->GetID(), partition->GetID()), + partition_commit); + ASSERT_TRUE(status.ok()) << status.ToString(); + + CollectionCommitPtr collection_commit; + status = store_->CreateResource(CollectionCommit(0, 0), collection_commit); + ASSERT_TRUE(status.ok()) << status.ToString(); + + SegmentPtr s; + status = store_->CreateResource(Segment(collection->GetID(), partition->GetCollectionId()), s); + ASSERT_TRUE(status.ok()) << status.ToString(); + + Field::Ptr field; + status = store_->CreateResource(Field("f_0", 0, FType::INT64), field); + ASSERT_TRUE(status.ok()) << status.ToString(); + + FieldElementPtr field_element; + status = store_->CreateResource( + FieldElement(collection->GetID(), field->GetID(), "fe_0", FEType::FET_INDEX), field_element); + ASSERT_TRUE(status.ok()) << status.ToString(); + + FieldCommit::Ptr field_commit; + status = store_->CreateResource(FieldCommit(collection->GetID(), field->GetID()), field_commit); + ASSERT_TRUE(status.ok()) << status.ToString(); + + SchemaCommit::Ptr schema; + status = store_->CreateResource(SchemaCommit(collection->GetID(), {}), schema); + ASSERT_TRUE(status.ok()) << status.ToString(); + + SegmentFilePtr seg_file; + status = store_->CreateResource( + SegmentFile(collection->GetID(), partition->GetID(), s->GetID(), field_element->GetID()), seg_file); + ASSERT_TRUE(status.ok()) << status.ToString(); + + SegmentCommitPtr sc; + status = store_->CreateResource(SegmentCommit(schema->GetID(), partition->GetID(), s->GetID()), sc); + ASSERT_TRUE(status.ok()) << status.ToString(); + + CollectionCommitPtr inactive_collection_commit; + auto cc = CollectionCommit(collection->GetID(), schema->GetID()); + cc.Deactivate(); + status = store_->CreateResource(std::move(cc), inactive_collection_commit); + ASSERT_TRUE(status.ok()) << status.ToString(); + + // TODO(yhz): Check if disk file has been deleted + + auto event = std::make_shared(); + status = event->Process(store_); +// milvus::engine::snapshot::EventExecutor::GetInstance().Submit(event); +// status = event->WaitToFinish(); + ASSERT_TRUE(status.ok()) << status.ToString(); + + std::vector field_elements; + ASSERT_TRUE(store_->GetInActiveResources(field_elements).ok()); + ASSERT_TRUE(field_elements.empty()); + + std::vector fields; + ASSERT_TRUE(store_->GetInActiveResources(fields).ok()); + ASSERT_TRUE(fields.empty()); + + std::vector field_commits; + ASSERT_TRUE(store_->GetInActiveResources(field_commits).ok()); + ASSERT_TRUE(field_commits.empty()); + + std::vector seg_files; + ASSERT_TRUE(store_->GetInActiveResources(seg_files).ok()); + ASSERT_TRUE(seg_files.empty()); + + std::vector seg_commits; + ASSERT_TRUE(store_->GetInActiveResources(seg_commits).ok()); + ASSERT_TRUE(seg_commits.empty()); + + std::vector segs; + ASSERT_TRUE(store_->GetInActiveResources(segs).ok()); + ASSERT_TRUE(segs.empty()); + + std::vector schemas; + ASSERT_TRUE(store_->GetInActiveResources(schemas).ok()); + ASSERT_TRUE(schemas.empty()); + + std::vector partition_commits; + ASSERT_TRUE(store_->GetInActiveResources(partition_commits).ok()); + ASSERT_TRUE(partition_commits.empty()); + + std::vector partitions; + ASSERT_TRUE(store_->GetInActiveResources(partitions).ok()); + ASSERT_TRUE(partitions.empty()); + + std::vector collections; + ASSERT_TRUE(store_->GetInActiveResources(collections).ok()); + ASSERT_TRUE(collections.empty()); + + std::vector collection_commits; + ASSERT_TRUE(store_->GetInActiveResources(collection_commits).ok()); + ASSERT_TRUE(collection_commits.empty()); +} diff --git a/core/unittest/ssdb/test_ss_job.cpp b/core/unittest/ssdb/test_ss_job.cpp new file mode 100644 index 000000000000..aa4a5d7ae9e0 --- /dev/null +++ b/core/unittest/ssdb/test_ss_job.cpp @@ -0,0 +1,110 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include "db/SnapshotVisitor.h" +#include "knowhere/index/vector_index/helpers/IndexParameter.h" +#include "scheduler/SchedInst.h" +#include "scheduler/job/SSBuildIndexJob.h" +#include "scheduler/job/SSSearchJob.h" +#include "ssdb/utils.h" + +using SegmentVisitor = milvus::engine::SegmentVisitor; + +namespace { +milvus::Status +CreateCollection(std::shared_ptr db, const std::string& collection_name, const LSN_TYPE& lsn) { + CreateCollectionContext context; + context.lsn = lsn; + auto collection_schema = std::make_shared(collection_name); + context.collection = collection_schema; + auto vector_field = std::make_shared("vector", 0, + milvus::engine::FieldType::VECTOR); + auto vector_field_element = std::make_shared(0, 0, "ivfsq8", + milvus::engine::FieldElementType::FET_INDEX); + auto int_field = std::make_shared("int", 0, + milvus::engine::FieldType::INT32); + context.fields_schema[vector_field] = {vector_field_element}; + context.fields_schema[int_field] = {}; + + return db->CreateCollection(context); +} +} // namespace + +TEST_F(SSSchedulerTest, SSJobTest) { + LSN_TYPE lsn = 0; + auto next_lsn = [&]() -> decltype(lsn) { + return ++lsn; + }; + + std::string c1 = "c1"; + auto status = CreateCollection(db_, c1, next_lsn()); + ASSERT_TRUE(status.ok()); + + status = db_->CreatePartition(c1, "p_0"); + ASSERT_TRUE(status.ok()); + + ScopedSnapshotT ss; + status = Snapshots::GetInstance().GetSnapshot(ss, c1); + ASSERT_TRUE(status.ok()); + + SegmentFileContext sf_context; + SFContextBuilder(sf_context, ss); + + auto& partitions = ss->GetResources(); + ASSERT_EQ(partitions.size(), 2); + for (auto& kv : partitions) { + int64_t row_cnt = 100; + ASSERT_TRUE(CreateSegment(ss, kv.first, next_lsn(), sf_context, row_cnt).ok()); + } + + status = Snapshots::GetInstance().GetSnapshot(ss, c1); + ASSERT_TRUE(status.ok()); + + /* collect all valid segment */ + std::vector segment_visitors; + auto executor = [&](const SegmentPtr& segment, SegmentIterator* handler) -> Status { + auto visitor = SegmentVisitor::Build(ss, segment->GetID()); + if (visitor == nullptr) { + return Status(milvus::SS_ERROR, "Cannot build segment visitor"); + } + segment_visitors.push_back(visitor); + return Status::OK(); + }; + + auto segment_iter = std::make_shared(ss, executor); + segment_iter->Iterate(); + ASSERT_TRUE(segment_iter->GetStatus().ok()); + ASSERT_EQ(segment_visitors.size(), 2); + + /* create BuildIndexJob */ +// milvus::scheduler::SSBuildIndexJobPtr build_index_job = +// std::make_shared(""); +// for (auto& sv : segment_visitors) { +// build_index_job->AddSegmentVisitor(sv); +// } + + /* put search job to scheduler and wait result */ +// milvus::scheduler::JobMgrInst::GetInstance()->Put(build_index_job); +// build_index_job->WaitFinish(); + +// /* create SearchJob */ +// milvus::scheduler::SSSearchJobPtr search_job = +// std::make_shared(nullptr, "", nullptr); +// for (auto& sv : segment_visitors) { +// search_job->AddSegmentVisitor(sv); +// } +// +// /* put search job to scheduler and wait result */ +// milvus::scheduler::JobMgrInst::GetInstance()->Put(search_job); +// search_job->WaitFinish(); +} diff --git a/core/unittest/ssdb/test_ss_meta.cpp b/core/unittest/ssdb/test_ss_meta.cpp new file mode 100644 index 000000000000..2b95bafec614 --- /dev/null +++ b/core/unittest/ssdb/test_ss_meta.cpp @@ -0,0 +1,189 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "db/meta/MetaFields.h" +#include "db/meta/backend/MetaContext.h" +#include "db/snapshot/ResourceContext.h" +#include "ssdb/utils.h" + +template +using ResourceContext = milvus::engine::snapshot::ResourceContext; +template +using ResourceContextBuilder = milvus::engine::snapshot::ResourceContextBuilder; + +using FType = milvus::engine::FieldType; +using FEType = milvus::engine::FieldElementType; +using Op = milvus::engine::meta::MetaContextOp; + +TEST_F(SSMetaTest, ApplyTest) { + ID_TYPE result_id; + + auto collection = std::make_shared("meta_test_c1"); + auto c_ctx = ResourceContextBuilder().SetResource(collection).CreatePtr(); + auto status = meta_->Apply(c_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + collection->SetID(result_id); + + collection->Activate(); + auto c2_ctx = ResourceContextBuilder().SetResource(collection) + .SetOp(Op::oUpdate).AddAttr(milvus::engine::meta::F_STATE).CreatePtr(); + status = meta_->Apply(c2_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + ASSERT_EQ(result_id, collection->GetID()); + + auto c3_ctx = ResourceContextBuilder().SetID(result_id).SetOp(Op::oDelete).CreatePtr(); + status = meta_->Apply(c3_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + ASSERT_EQ(result_id, collection->GetID()); +} + +TEST_F(SSMetaTest, SessionTest) { + ID_TYPE result_id; + + auto collection = std::make_shared("meta_test_c1"); + auto c_ctx = ResourceContextBuilder().SetResource(collection).CreatePtr(); + auto status = meta_->Apply(c_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + collection->SetID(result_id); + + auto partition = std::make_shared("meta_test_p1", result_id); + auto p_ctx = ResourceContextBuilder().SetResource(partition).CreatePtr(); + status = meta_->Apply(p_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + partition->SetID(result_id); + + auto field = std::make_shared("meta_test_f1", 1, FType::INT64); + auto f_ctx = ResourceContextBuilder().SetResource(field).CreatePtr(); + status = meta_->Apply(f_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + field->SetID(result_id); + + auto field_element = std::make_shared(collection->GetID(), field->GetID(), + "meta_test_f1_fe1", FEType::FET_RAW); + auto fe_ctx = ResourceContextBuilder().SetResource(field_element).CreatePtr(); + status = meta_->Apply(fe_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + field_element->SetID(result_id); + + auto session = meta_->CreateSession(); + ASSERT_TRUE(collection->Activate()); + auto c2_ctx = ResourceContextBuilder().SetResource(collection) + .SetOp(Op::oUpdate).AddAttr(milvus::engine::meta::F_STATE).CreatePtr(); + status = session->Apply(c2_ctx); + ASSERT_TRUE(status.ok()) << status.ToString(); + + ASSERT_TRUE(partition->Activate()); + auto p2_ctx = ResourceContextBuilder().SetResource(partition) + .SetOp(Op::oUpdate).AddAttr(milvus::engine::meta::F_STATE).CreatePtr(); + status = session->Apply(p2_ctx); + ASSERT_TRUE(status.ok()) << status.ToString(); + + ASSERT_TRUE(field->Activate()); + auto f2_ctx = ResourceContextBuilder().SetResource(field) + .SetOp(Op::oUpdate).AddAttr(milvus::engine::meta::F_STATE).CreatePtr(); + status = session->Apply(f2_ctx); + ASSERT_TRUE(status.ok()) << status.ToString(); + + ASSERT_TRUE(field_element->Activate()); + auto fe2_ctx = ResourceContextBuilder().SetResource(field_element) + .SetOp(Op::oUpdate).AddAttr(milvus::engine::meta::F_STATE).CreatePtr(); + status = session->Apply(fe2_ctx); + ASSERT_TRUE(status.ok()) << status.ToString(); + + std::vector result_ids; + status = session->Commit(result_ids); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_EQ(result_ids.size(), 4); + ASSERT_EQ(result_ids.at(0), collection->GetID()); + ASSERT_EQ(result_ids.at(1), partition->GetID()); + ASSERT_EQ(result_ids.at(2), field->GetID()); + ASSERT_EQ(result_ids.at(3), field_element->GetID()); +} + +TEST_F(SSMetaTest, SelectTest) { + ID_TYPE result_id; + + auto collection = std::make_shared("meta_test_c1"); + ASSERT_TRUE(collection->Activate()); + auto c_ctx = ResourceContextBuilder().SetResource(collection).CreatePtr(); + auto status = meta_->Apply(c_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + collection->SetID(result_id); + + Collection::Ptr return_collection; + status = meta_->Select(collection->GetID(), return_collection); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_EQ(collection->GetID(), return_collection->GetID()); + ASSERT_EQ(collection->GetName(), return_collection->GetName()); + + auto collection2 = std::make_shared("meta_test_c2"); + ASSERT_TRUE(collection2->Activate()); + auto c2_ctx = ResourceContextBuilder().SetResource(collection2).CreatePtr(); + status = meta_->Apply(c2_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + collection2->SetID(result_id); + + ASSERT_GT(collection2->GetID(), collection->GetID()); + + std::vector return_collections; + status = meta_->SelectBy(milvus::engine::meta::F_ID, + {collection2->GetID()}, return_collections); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_EQ(return_collections.size(), 1); + ASSERT_EQ(return_collections.at(0)->GetID(), collection2->GetID()); + ASSERT_EQ(return_collections.at(0)->GetName(), collection2->GetName()); + return_collections.clear(); + + status = meta_->SelectBy(milvus::engine::meta::F_STATE, {State::ACTIVE}, return_collections); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_EQ(return_collections.size(), 2); + + std::vector ids; + status = meta_->SelectResourceIDs(ids, "", {""}); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_EQ(ids.size(), 2); + + ids.clear(); + status = meta_->SelectResourceIDs(ids, milvus::engine::meta::F_NAME, + {collection->GetName()}); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_EQ(ids.size(), 1); + ASSERT_EQ(ids.at(0), collection->GetID()); +} + +TEST_F(SSMetaTest, TruncateTest) { + ID_TYPE result_id; + + auto collection = std::make_shared("meta_test_c1"); + ASSERT_TRUE(collection->Activate()); + auto c_ctx = ResourceContextBuilder().SetResource(collection).CreatePtr(); + auto status = meta_->Apply(c_ctx, result_id); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_GT(result_id, 0); + collection->SetID(result_id); + + status = meta_->TruncateAll(); + ASSERT_TRUE(status.ok()) << status.ToString(); + + Collection::Ptr return_collection; + status = meta_->Select(collection->GetID(), return_collection); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_EQ(return_collection, nullptr); +} diff --git a/core/unittest/ssdb/test_ss_task.cpp b/core/unittest/ssdb/test_ss_task.cpp new file mode 100644 index 000000000000..4503636cb3f4 --- /dev/null +++ b/core/unittest/ssdb/test_ss_task.cpp @@ -0,0 +1,182 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include +#include +#include +#include + +//#include "db/meta/SqliteMetaImpl.h" +#include "db/DBFactory.h" +#include "scheduler/SchedInst.h" +#include "scheduler/job/SSBuildIndexJob.h" +#include "scheduler/job/SSSearchJob.h" +#include "scheduler/resource/CpuResource.h" +#include "scheduler/tasklabel/BroadcastLabel.h" +#include "scheduler/task/SSBuildIndexTask.h" +#include "scheduler/task/SSSearchTask.h" +#include "scheduler/task/SSTestTask.h" + +namespace milvus { +namespace scheduler { + +TEST(SSTaskTest, INVALID_INDEX) { + auto dummy_context = std::make_shared("dummy_request_id"); + opentracing::mocktracer::MockTracerOptions tracer_options; + auto mock_tracer = + std::shared_ptr{new opentracing::mocktracer::MockTracer{std::move(tracer_options)}}; + auto mock_span = mock_tracer->StartSpan("mock_span"); + auto trace_context = std::make_shared(mock_span); + dummy_context->SetTraceContext(trace_context); +} + +TEST(SSTaskTest, TEST_TASK) { + auto dummy_context = std::make_shared("dummy_request_id"); + +// auto file = std::make_shared(); +// file->index_params_ = "{ \"nlist\": 16384 }"; +// file->dimension_ = 64; + auto label = std::make_shared(); + +// SSTestTask task(dummy_context, nullptr, label); +// task.Load(LoadType::CPU2GPU, 0); +// auto th = std::thread([&]() { +// task.Execute(); +// }); +// task.Wait(); +// +// if (th.joinable()) { +// th.join(); +// } + +// static const char* CONFIG_PATH = "/tmp/milvus_test"; +// auto options = milvus::engine::DBFactory::BuildOption(); +// options.meta_.path_ = CONFIG_PATH; +// options.meta_.backend_uri_ = "sqlite://:@:/"; +// options.insert_cache_immediately_ = true; +// +// file->collection_id_ = "111"; +// file->location_ = "/tmp/milvus_test/index_file1.txt"; +// auto build_index_job = std::make_shared(options); +// XSSBuildIndexTask build_index_task(nullptr, label); +// build_index_task.job_ = build_index_job; +// +// build_index_task.Load(LoadType::TEST, 0); +// +// fiu_init(0); +// fiu_enable("XBuildIndexTask.Load.throw_std_exception", 1, NULL, 0); +// build_index_task.Load(LoadType::TEST, 0); +// fiu_disable("XBuildIndexTask.Load.throw_std_exception"); +// +// fiu_enable("XBuildIndexTask.Load.out_of_memory", 1, NULL, 0); +// build_index_task.Load(LoadType::TEST, 0); +// fiu_disable("XBuildIndexTask.Load.out_of_memory"); +// +// build_index_task.Execute(); +// // always enable 'create_table_success' +// fiu_enable("XBuildIndexTask.Execute.create_table_success", 1, NULL, 0); +// +// milvus::json json = {{"nlist", 16384}}; +// build_index_task.to_index_engine_ = +// EngineFactory::Build(file->dimension_, file->location_, (EngineType)file->engine_type_, +// (MetricType)file->metric_type_, json); +// +// build_index_task.Execute(); +// +// fiu_enable("XBuildIndexTask.Execute.build_index_fail", 1, NULL, 0); +// build_index_task.to_index_engine_ = +// EngineFactory::Build(file->dimension_, file->location_, (EngineType)file->engine_type_, +// (MetricType)file->metric_type_, json); +// build_index_task.Execute(); +// fiu_disable("XBuildIndexTask.Execute.build_index_fail"); +// +// // always enable 'has_collection' +// fiu_enable("XBuildIndexTask.Execute.has_collection", 1, NULL, 0); +// build_index_task.to_index_engine_ = +// EngineFactory::Build(file->dimension_, file->location_, (EngineType)file->engine_type_, +// (MetricType)file->metric_type_, json); +// build_index_task.Execute(); +// +// fiu_enable("XBuildIndexTask.Execute.throw_std_exception", 1, NULL, 0); +// build_index_task.to_index_engine_ = +// EngineFactory::Build(file->dimension_, file->location_, (EngineType)file->engine_type_, +// (MetricType)file->metric_type_, json); +// build_index_task.Execute(); +// fiu_disable("XBuildIndexTask.Execute.throw_std_exception"); +// +// fiu_enable("XBuildIndexTask.Execute.update_table_file_fail", 1, NULL, 0); +// build_index_task.to_index_engine_ = +// EngineFactory::Build(file->dimension_, file->location_, (EngineType)file->engine_type_, +// (MetricType)file->metric_type_, json); +// build_index_task.Execute(); +// fiu_disable("XBuildIndexTask.Execute.update_table_file_fail"); +// +// fiu_disable("XBuildIndexTask.Execute.throw_std_exception"); +// fiu_disable("XBuildIndexTask.Execute.has_collection"); +// fiu_disable("XBuildIndexTask.Execute.create_table_success"); +// build_index_task.Execute(); +// +// // search task +// engine::VectorsData vector; +// auto search_job = std::make_shared(dummy_context, 1, 1, vector); +// file->metric_type_ = static_cast(MetricType::IP); +// file->engine_type_ = static_cast(engine::EngineType::FAISS_IVFSQ8H); +// opentracing::mocktracer::MockTracerOptions tracer_options; +// auto mock_tracer = +// std::shared_ptr{new opentracing::mocktracer::MockTracer{std::move(tracer_options)}}; +// auto mock_span = mock_tracer->StartSpan("mock_span"); +// auto trace_context = std::make_shared(mock_span); +// dummy_context->SetTraceContext(trace_context); +// XSearchTask search_task(dummy_context, file, label); +// search_task.job_ = search_job; +// std::string cpu_resouce_name = "cpu_name1"; +// std::vector path = {cpu_resouce_name}; +// search_task.task_path_ = Path(path, 0); +// ResMgrInst::GetInstance()->Add(std::make_shared(cpu_resouce_name, 1, true)); +// +// search_task.Load(LoadType::CPU2GPU, 0); +// search_task.Load(LoadType::GPU2CPU, 0); +// +// fiu_enable("XSearchTask.Load.throw_std_exception", 1, NULL, 0); +// search_task.Load(LoadType::GPU2CPU, 0); +// fiu_disable("XSearchTask.Load.throw_std_exception"); +// +// fiu_enable("XSearchTask.Load.out_of_memory", 1, NULL, 0); +// search_task.Load(LoadType::GPU2CPU, 0); +// fiu_disable("XSearchTask.Load.out_of_memory"); +// +// fiu_enable("XSearchTask.Execute.search_fail", 1, NULL, 0); +// search_task.Execute(); +// fiu_disable("XSearchTask.Execute.search_fail"); +// +// fiu_enable("XSearchTask.Execute.throw_std_exception", 1, NULL, 0); +// search_task.Execute(); +// fiu_disable("XSearchTask.Execute.throw_std_exception"); +// +// search_task.Execute(); +// +// scheduler::ResultIds ids, tar_ids; +// scheduler::ResultDistances distances, tar_distances; +// XSearchTask::MergeTopkToResultSet(ids, distances, 1, 1, 1, true, tar_ids, tar_distances); +} + +TEST(SSTaskTest, TEST_PATH) { + Path path; + auto empty_path = path.Current(); + ASSERT_TRUE(empty_path.empty()); + empty_path = path.Next(); + ASSERT_TRUE(empty_path.empty()); + empty_path = path.Last(); + ASSERT_TRUE(empty_path.empty()); +} + +} // namespace scheduler +} // namespace milvus diff --git a/core/unittest/ssdb/utils.cpp b/core/unittest/ssdb/utils.cpp index 627243c72ac8..d12b9b2bb3ca 100644 --- a/core/unittest/ssdb/utils.cpp +++ b/core/unittest/ssdb/utils.cpp @@ -30,11 +30,11 @@ #include "db/snapshot/OperationExecutor.h" #include "db/snapshot/Snapshots.h" #include "db/snapshot/ResourceHolders.h" - #ifdef MILVUS_GPU_VERSION #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #endif - +#include "scheduler/ResourceFactory.h" +#include "scheduler/SchedInst.h" #include "utils/CommonUtil.h" @@ -51,7 +51,7 @@ static const char* CONFIG_STR = "\n" "general:\n" " timezone: UTC+8\n" - " meta_uri: sqlite://:@:/\n" + " meta_uri: mock://:@:/\n" "\n" "network:\n" " bind.address: 0.0.0.0\n" @@ -138,20 +138,20 @@ BaseTest::InitLog() { } void -BaseTest::SetUp() { - InitLog(); -} - -void -BaseTest::TearDown() { -} - -///////////////////////////////////////////////////////////////////////////////////////////////////////////////// -void -SnapshotTest::SetUp() { - BaseTest::SetUp(); +BaseTest::SnapshotStart(bool mock_store) { + /* auto uri = "mysql://root:12345678@127.0.0.1:3307/milvus"; */ + auto uri = "mock://:@:/"; + auto& config = milvus::server::Config::GetInstance(); + config.SetGeneralConfigMetaURI(uri); + std::string path = "/tmp/milvus_ss/db"; + config.SetStorageConfigPath(path); + auto store = Store::Build(uri, path); + + milvus::engine::snapshot::OperationExecutor::Init(store); milvus::engine::snapshot::OperationExecutor::GetInstance().Start(); + milvus::engine::snapshot::EventExecutor::Init(store); milvus::engine::snapshot::EventExecutor::GetInstance().Start(); + milvus::engine::snapshot::CollectionCommitsHolder::GetInstance().Reset(); milvus::engine::snapshot::CollectionsHolder::GetInstance().Reset(); milvus::engine::snapshot::SchemaCommitsHolder::GetInstance().Reset(); @@ -164,41 +164,78 @@ SnapshotTest::SetUp() { milvus::engine::snapshot::SegmentCommitsHolder::GetInstance().Reset(); milvus::engine::snapshot::SegmentFilesHolder::GetInstance().Reset(); + if (mock_store) { + store->Mock(); + } else { + store->DoReset(); + } + milvus::engine::snapshot::Snapshots::GetInstance().Reset(); - milvus::engine::snapshot::Store::GetInstance().Mock(); - milvus::engine::snapshot::Snapshots::GetInstance().Init(); + milvus::engine::snapshot::Snapshots::GetInstance().Init(store); } void -SnapshotTest::TearDown() { +BaseTest::SnapshotStop() { // TODO: Temp to delay some time. OperationExecutor should wait all resources be destructed before stop std::this_thread::sleep_for(std::chrono::milliseconds(20)); milvus::engine::snapshot::EventExecutor::GetInstance().Stop(); milvus::engine::snapshot::OperationExecutor::GetInstance().Stop(); +} + +void +BaseTest::SetUp() { + InitLog(); +} + +void +BaseTest::TearDown() { +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void +SnapshotTest::SetUp() { + BaseTest::SetUp(); + BaseTest::SnapshotStart(true); +} + +void +SnapshotTest::TearDown() { + BaseTest::SnapshotStop(); BaseTest::TearDown(); } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// +milvus::engine::DBOptions +SSDBTest::GetOptions() { + auto options = milvus::engine::DBOptions(); + options.meta_.path_ = "/tmp/milvus_ss"; + options.meta_.backend_uri_ = "sqlite://:@:/"; + options.wal_enable_ = false; + return options; +} + void SSDBTest::SetUp() { BaseTest::SetUp(); - milvus::engine::snapshot::OperationExecutor::GetInstance().Start(); - milvus::engine::snapshot::EventExecutor::GetInstance().Start(); - milvus::engine::snapshot::CollectionCommitsHolder::GetInstance().Reset(); - milvus::engine::snapshot::CollectionsHolder::GetInstance().Reset(); - milvus::engine::snapshot::SchemaCommitsHolder::GetInstance().Reset(); - milvus::engine::snapshot::FieldCommitsHolder::GetInstance().Reset(); - milvus::engine::snapshot::FieldsHolder::GetInstance().Reset(); - milvus::engine::snapshot::FieldElementsHolder::GetInstance().Reset(); - milvus::engine::snapshot::PartitionsHolder::GetInstance().Reset(); - milvus::engine::snapshot::PartitionCommitsHolder::GetInstance().Reset(); - milvus::engine::snapshot::SegmentsHolder::GetInstance().Reset(); - milvus::engine::snapshot::SegmentCommitsHolder::GetInstance().Reset(); - milvus::engine::snapshot::SegmentFilesHolder::GetInstance().Reset(); + BaseTest::SnapshotStart(false); + db_ = std::make_shared(GetOptions()); +} - milvus::engine::snapshot::Store::GetInstance().DoReset(); - milvus::engine::snapshot::Snapshots::GetInstance().Reset(); - milvus::engine::snapshot::Snapshots::GetInstance().Init(); +void +SSDBTest::TearDown() { + BaseTest::SnapshotStop(); + db_ = nullptr; + auto options = GetOptions(); + boost::filesystem::remove_all(options.meta_.path_); + + BaseTest::TearDown(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void +SSSegmentTest::SetUp() { + BaseTest::SetUp(); + BaseTest::SnapshotStart(false); auto options = milvus::engine::DBOptions(); options.wal_enable_ = false; @@ -206,16 +243,78 @@ SSDBTest::SetUp() { } void -SSDBTest::TearDown() { +SSSegmentTest::TearDown() { + BaseTest::SnapshotStop(); db_ = nullptr; - // TODO: Temp to delay some time. OperationExecutor should wait all resources be destructed before stop - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - milvus::engine::snapshot::EventExecutor::GetInstance().Stop(); - milvus::engine::snapshot::OperationExecutor::GetInstance().Stop(); + BaseTest::TearDown(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void +SSMetaTest::SetUp() { + auto engine = std::make_shared(); +// milvus::engine::DBMetaOptions options; +// options.backend_uri_ = "mysql://root:12345678@127.0.0.1:3307/milvus"; +// auto engine = std::make_shared(options); + meta_ = std::make_shared(engine); + meta_->TruncateAll(); +} + +void +SSMetaTest::TearDown() { +} +///////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void +SSSchedulerTest::SetUp() { + BaseTest::SetUp(); + BaseTest::SnapshotStart(true); + auto options = milvus::engine::DBOptions(); + options.wal_enable_ = false; + db_ = std::make_shared(options); + + auto res_mgr = milvus::scheduler::ResMgrInst::GetInstance(); + res_mgr->Clear(); + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("disk", "DISK", 0, false)); + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("cpu", "CPU", 0)); + + auto default_conn = milvus::scheduler::Connection("IO", 500.0); + auto PCIE = milvus::scheduler::Connection("IO", 11000.0); + res_mgr->Connect("disk", "cpu", default_conn); +#ifdef MILVUS_GPU_VERSION + res_mgr->Add(milvus::scheduler::ResourceFactory::Create("0", "GPU", 0)); + res_mgr->Connect("cpu", "0", PCIE); +#endif + res_mgr->Start(); + milvus::scheduler::SchedInst::GetInstance()->Start(); + milvus::scheduler::JobMgrInst::GetInstance()->Start(); + milvus::scheduler::CPUBuilderInst::GetInstance()->Start(); +} + +void +SSSchedulerTest::TearDown() { + milvus::scheduler::JobMgrInst::GetInstance()->Stop(); + milvus::scheduler::SchedInst::GetInstance()->Stop(); + milvus::scheduler::CPUBuilderInst::GetInstance()->Stop(); + milvus::scheduler::ResMgrInst::GetInstance()->Stop(); + milvus::scheduler::ResMgrInst::GetInstance()->Clear(); + + db_ = nullptr; + BaseTest::SnapshotStop(); BaseTest::TearDown(); } +void +SSEventTest::SetUp() { + auto uri = "mock://:@:/"; + store_ = Store::Build(uri, "/tmp/milvus_ss/db"); + store_->DoReset(); +} + +void +SSEventTest::TearDown() { +} + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char **argv) { diff --git a/core/unittest/ssdb/utils.h b/core/unittest/ssdb/utils.h index 6415048ccd44..23bb594354ad 100644 --- a/core/unittest/ssdb/utils.h +++ b/core/unittest/ssdb/utils.h @@ -14,8 +14,76 @@ #include #include #include +#include +#include +#include +#include #include "db/SSDBImpl.h" +#include "db/meta/MetaAdapter.h" +#include "db/snapshot/CompoundOperations.h" +#include "db/snapshot/Context.h" +#include "db/snapshot/EventExecutor.h" +#include "db/snapshot/OperationExecutor.h" +#include "db/snapshot/ReferenceProxy.h" +#include "db/snapshot/ResourceHolders.h" +#include "db/snapshot/ScopedResource.h" +#include "db/snapshot/Snapshots.h" +#include "db/snapshot/Store.h" +#include "db/snapshot/WrappedTypes.h" + +using ID_TYPE = milvus::engine::snapshot::ID_TYPE; +using IDS_TYPE = milvus::engine::snapshot::IDS_TYPE; +using LSN_TYPE = milvus::engine::snapshot::LSN_TYPE; +using SIZE_TYPE = milvus::engine::snapshot::SIZE_TYPE; +using MappingT = milvus::engine::snapshot::MappingT; +using State = milvus::engine::snapshot::State; +using LoadOperationContext = milvus::engine::snapshot::LoadOperationContext; +using CreateCollectionContext = milvus::engine::snapshot::CreateCollectionContext; +using SegmentFileContext = milvus::engine::snapshot::SegmentFileContext; +using OperationContext = milvus::engine::snapshot::OperationContext; +using PartitionContext = milvus::engine::snapshot::PartitionContext; +using DropIndexOperation = milvus::engine::snapshot::DropIndexOperation; +using AddFieldElementOperation = milvus::engine::snapshot::AddFieldElementOperation; +using DropAllIndexOperation = milvus::engine::snapshot::DropAllIndexOperation; +using AddSegmentFileOperation = milvus::engine::snapshot::AddSegmentFileOperation; +using MergeOperation = milvus::engine::snapshot::MergeOperation; +using CreateCollectionOperation = milvus::engine::snapshot::CreateCollectionOperation; +using NewSegmentOperation = milvus::engine::snapshot::NewSegmentOperation; +using DropPartitionOperation = milvus::engine::snapshot::DropPartitionOperation; +using CreatePartitionOperation = milvus::engine::snapshot::CreatePartitionOperation; +using DropCollectionOperation = milvus::engine::snapshot::DropCollectionOperation; +using CollectionCommitsHolder = milvus::engine::snapshot::CollectionCommitsHolder; +using CollectionsHolder = milvus::engine::snapshot::CollectionsHolder; +using CollectionScopedT = milvus::engine::snapshot::CollectionScopedT; +using Collection = milvus::engine::snapshot::Collection; +using CollectionPtr = milvus::engine::snapshot::CollectionPtr; +using Partition = milvus::engine::snapshot::Partition; +using PartitionPtr = milvus::engine::snapshot::PartitionPtr; +using Segment = milvus::engine::snapshot::Segment; +using SegmentPtr = milvus::engine::snapshot::SegmentPtr; +using SegmentFile = milvus::engine::snapshot::SegmentFile; +using SegmentFilePtr = milvus::engine::snapshot::SegmentFilePtr; +using Field = milvus::engine::snapshot::Field; +using FieldElement = milvus::engine::snapshot::FieldElement; +using FieldElementPtr = milvus::engine::snapshot::FieldElementPtr; +using Snapshots = milvus::engine::snapshot::Snapshots; +using ScopedSnapshotT = milvus::engine::snapshot::ScopedSnapshotT; +using ReferenceProxy = milvus::engine::snapshot::ReferenceProxy; +using Queue = milvus::BlockingQueue; +using TQueue = milvus::BlockingQueue>; +using SoftDeleteCollectionOperation = milvus::engine::snapshot::SoftDeleteOperation; +using ParamsField = milvus::engine::snapshot::ParamsField; +using IteratePartitionHandler = milvus::engine::snapshot::IterateHandler; +using IterateSegmentFileHandler = milvus::engine::snapshot::IterateHandler; +using PartitionIterator = milvus::engine::snapshot::PartitionIterator; +using SegmentIterator = milvus::engine::snapshot::SegmentIterator; +using SSDBImpl = milvus::engine::SSDBImpl; +using Status = milvus::Status; +using Store = milvus::engine::snapshot::Store; + +using StorePtr = milvus::engine::snapshot::Store::Ptr; +using MetaAdapterPtr = milvus::engine::meta::MetaAdapterPtr; inline int RandomInt(int start, int end) { @@ -25,10 +93,204 @@ RandomInt(int start, int end) { return dist(rng); } +inline void +SFContextBuilder(SegmentFileContext& ctx, ScopedSnapshotT sss) { + auto field = sss->GetResources().begin()->second; + ctx.field_name = field->GetName(); + for (auto& kv : sss->GetResources()) { + ctx.field_element_name = kv.second->GetName(); + break; + } + auto& segments = sss->GetResources(); + if (segments.size() == 0) { + return; + } + + ctx.segment_id = sss->GetResources().begin()->second->GetID(); + ctx.partition_id = sss->GetResources().begin()->second->GetPartitionId(); +} + +inline void +SFContextsBuilder(std::vector& contexts, ScopedSnapshotT sss) { + auto fields = sss->GetResources(); + for (auto& field_kv : fields) { + for (auto& kv : sss->GetResources()) { + if (kv.second->GetFieldId() != field_kv.first) { + continue; + } + SegmentFileContext ctx; + ctx.field_name = field_kv.second->GetName(); + ctx.field_element_name = kv.second->GetName(); + contexts.push_back(ctx); + } + } + auto& segments = sss->GetResources(); + if (segments.size() == 0) { + return; + } + + for (auto& ctx : contexts) { + ctx.segment_id = sss->GetResources().begin()->second->GetID(); + ctx.partition_id = sss->GetResources().begin()->second->GetPartitionId(); + } +} + +struct PartitionCollector : public IteratePartitionHandler { + using ResourceT = Partition; + using BaseT = IteratePartitionHandler; + explicit PartitionCollector(ScopedSnapshotT ss) : BaseT(ss) {} + + Status + PreIterate() override { + partition_names_.clear(); + return Status::OK(); + } + + Status + Handle(const typename ResourceT::Ptr& partition) override { + partition_names_.push_back(partition->GetName()); + return Status::OK(); + } + + std::vector partition_names_; +}; + +using FilterT = std::function; +struct SegmentFileCollector : public IterateSegmentFileHandler { + using ResourceT = SegmentFile; + using BaseT = IterateSegmentFileHandler; + explicit SegmentFileCollector(ScopedSnapshotT ss, const FilterT& filter) + : filter_(filter), BaseT(ss) {} + + Status + PreIterate() override { + segment_files_.clear(); + return Status::OK(); + } + + Status + Handle(const typename ResourceT::Ptr& segment_file) override { + if (!filter_(segment_file)) { + return Status::OK(); + } + segment_files_.insert(segment_file->GetID()); + return Status::OK(); + } + + FilterT filter_; + std::set segment_files_; +}; + +struct WaitableObj { + bool notified_ = false; + std::mutex mutex_; + std::condition_variable cv_; + + void + Wait() { + std::unique_lock lck(mutex_); + if (!notified_) { + cv_.wait(lck); + } + notified_ = false; + } + + void + Notify() { + std::unique_lock lck(mutex_); + notified_ = true; + lck.unlock(); + cv_.notify_one(); + } +}; + +inline ScopedSnapshotT +CreateCollection(const std::string& collection_name, const LSN_TYPE& lsn) { + CreateCollectionContext context; + context.lsn = lsn; + auto collection_schema = std::make_shared(collection_name); + context.collection = collection_schema; + auto vector_field = std::make_shared("vector", 0, + milvus::engine::FieldType::VECTOR_FLOAT); + auto vector_field_element = std::make_shared(0, 0, "ivfsq8", + milvus::engine::FieldElementType::FET_INDEX); + auto int_field = std::make_shared("int", 0, + milvus::engine::FieldType::INT32); + context.fields_schema[vector_field] = {vector_field_element}; + context.fields_schema[int_field] = {}; + + auto op = std::make_shared(context); + op->Push(); + ScopedSnapshotT ss; + auto status = op->GetSnapshot(ss); + return ss; +} + +inline ScopedSnapshotT +CreatePartition(const std::string& collection_name, const PartitionContext& p_context, const LSN_TYPE& lsn) { + ScopedSnapshotT curr_ss; + ScopedSnapshotT ss; + auto status = Snapshots::GetInstance().GetSnapshot(ss, collection_name); + if (!status.ok()) { + std::cout << status.ToString() << std::endl; + return curr_ss; + } + + OperationContext context; + context.lsn = lsn; + auto op = std::make_shared(context, ss); + + PartitionPtr partition; + status = op->CommitNewPartition(p_context, partition); + if (!status.ok()) { + std::cout << status.ToString() << std::endl; + return curr_ss; + } + + status = op->Push(); + if (!status.ok()) { + std::cout << status.ToString() << std::endl; + return curr_ss; + } + + status = op->GetSnapshot(curr_ss); + if (!status.ok()) { + std::cout << status.ToString() << std::endl; + return curr_ss; + } + return curr_ss; +} + +inline Status +CreateSegment(ScopedSnapshotT ss, ID_TYPE partition_id, LSN_TYPE lsn, const SegmentFileContext& sf_context, + SIZE_TYPE row_cnt) { + OperationContext context; + context.lsn = lsn; + context.prev_partition = ss->GetResource(partition_id); + auto op = std::make_shared(context, ss); + SegmentPtr new_seg; + STATUS_CHECK(op->CommitNewSegment(new_seg)); + SegmentFilePtr seg_file; + auto nsf_context = sf_context; + nsf_context.segment_id = new_seg->GetID(); + nsf_context.partition_id = new_seg->GetPartitionId(); + STATUS_CHECK(op->CommitNewSegmentFile(nsf_context, seg_file)); + op->CommitRowCount(row_cnt); + seg_file->SetSize(row_cnt * 10); + STATUS_CHECK(op->Push()); + + return op->GetSnapshot(ss); +} + +/////////////////////////////////////////////////////////////////////////////// class BaseTest : public ::testing::Test { protected: void InitLog(); + void + SnapshotStart(bool mock_store); + void + SnapshotStop(); void SetUp() override; @@ -36,6 +298,7 @@ class BaseTest : public ::testing::Test { TearDown() override; }; +/////////////////////////////////////////////////////////////////////////////// class SnapshotTest : public BaseTest { protected: void @@ -44,10 +307,59 @@ class SnapshotTest : public BaseTest { TearDown() override; }; +/////////////////////////////////////////////////////////////////////////////// class SSDBTest : public BaseTest { protected: - std::shared_ptr db_; + std::shared_ptr db_; + + milvus::engine::DBOptions + GetOptions(); + + void + SetUp() override; + void + TearDown() override; +}; + +/////////////////////////////////////////////////////////////////////////////// +class SSSegmentTest : public BaseTest { + protected: + std::shared_ptr db_; + + void + SetUp() override; + void + TearDown() override; +}; + +/////////////////////////////////////////////////////////////////////////////// +class SSMetaTest : public BaseTest { + protected: + MetaAdapterPtr meta_; + protected: + void + SetUp() override; + void + TearDown() override; +}; + +/////////////////////////////////////////////////////////////////////////////// +class SSSchedulerTest : public BaseTest { + protected: + std::shared_ptr db_; + + void + SetUp() override; + void + TearDown() override; +}; + +class SSEventTest : public BaseTest { + protected: + StorePtr store_; + + protected: void SetUp() override; void diff --git a/sdk/examples/binary_vector/src/ClientTest.cpp b/sdk/examples/binary_vector/src/ClientTest.cpp index 439fcaf284e3..980e18327c5e 100644 --- a/sdk/examples/binary_vector/src/ClientTest.cpp +++ b/sdk/examples/binary_vector/src/ClientTest.cpp @@ -164,7 +164,7 @@ ClientTest::Test(const std::string& address, const std::string& port) { field_ptr1->index_params = index_param_1.dump(); milvus::FieldPtr field_ptr2 = std::make_shared(); - field_ptr2->field_type = milvus::DataType::BINARY_VECTOR; + field_ptr2->field_type = milvus::DataType::VECTOR_BINARY; field_ptr2->field_name = "field_vec"; JSON index_param_2; index_param_2["name"] = "index_3"; diff --git a/sdk/examples/simple/src/ClientTest.cpp b/sdk/examples/simple/src/ClientTest.cpp index 2373be4c519b..84f981cc5eeb 100644 --- a/sdk/examples/simple/src/ClientTest.cpp +++ b/sdk/examples/simple/src/ClientTest.cpp @@ -113,7 +113,7 @@ ClientTest::CreateCollection(const std::string& collection_name) { field_ptr3->index_params = index_param_3.dump(); field_ptr4->field_name = "field_vec"; - field_ptr4->field_type = milvus::DataType::FLOAT_VECTOR; + field_ptr4->field_type = milvus::DataType::VECTOR_FLOAT; JSON index_param_4; index_param_4["name"] = "index_3"; field_ptr4->index_params = index_param_4.dump(); diff --git a/sdk/examples/simple/src/ClientTest.h b/sdk/examples/simple/src/ClientTest.h index 3e832d188aa9..8d1650f15c70 100644 --- a/sdk/examples/simple/src/ClientTest.h +++ b/sdk/examples/simple/src/ClientTest.h @@ -87,4 +87,4 @@ class ClientTest { std::shared_ptr conn_; std::vector> search_entity_array_; std::vector search_id_array_; -}; +}; \ No newline at end of file diff --git a/sdk/examples/utils/Utils.cpp b/sdk/examples/utils/Utils.cpp index 895d838ab44e..db1775269b94 100644 --- a/sdk/examples/utils/Utils.cpp +++ b/sdk/examples/utils/Utils.cpp @@ -115,8 +115,12 @@ Utils::IndexTypeName(const milvus::IndexType& index_type) { return "SPTAGBKT"; case milvus::IndexType::HNSW: return "HNSW"; + case milvus::IndexType::HNSW_SQ8NM: + return "HNSW_SQ8NM"; case milvus::IndexType::ANNOY: return "ANNOY"; + case milvus::IndexType::IVFSQ8NR: + return "IVFSQ8NR"; default: return "Unknown index type"; } @@ -429,3 +433,4 @@ Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) { } } // namespace milvus_sdk + diff --git a/sdk/examples/utils/Utils.h b/sdk/examples/utils/Utils.h index a0133e01fd39..38e1260fb530 100644 --- a/sdk/examples/utils/Utils.h +++ b/sdk/examples/utils/Utils.h @@ -88,4 +88,4 @@ class Utils { PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result); }; -} // namespace milvus_sdk +} // namespace milvus_sdk \ No newline at end of file diff --git a/sdk/grpc-gen/gen-milvus/milvus.pb.cc b/sdk/grpc-gen/gen-milvus/milvus.pb.cc index 0e0b1bd36d43..866e9fb02f94 100644 --- a/sdk/grpc-gen/gen-milvus/milvus.pb.cc +++ b/sdk/grpc-gen/gen-milvus/milvus.pb.cc @@ -1341,62 +1341,62 @@ const char descriptor_table_protodef_milvus_2eproto[] PROTOBUF_SECTION_VARIABLE( "ollection_name\030\001 \001(\t\022\033\n\023partition_tag_ar" "ray\030\002 \003(\t\0220\n\rgeneral_query\030\003 \001(\0132\031.milvu" "s.grpc.GeneralQuery\022/\n\014extra_params\030\004 \003(" - "\0132\031.milvus.grpc.KeyValuePair*\237\001\n\010DataTyp" - "e\022\010\n\004NULL\020\000\022\010\n\004INT8\020\001\022\t\n\005INT16\020\002\022\t\n\005INT3" - "2\020\003\022\t\n\005INT64\020\004\022\n\n\006STRING\020\024\022\010\n\004BOOL\020\036\022\t\n\005" - "FLOAT\020(\022\n\n\006DOUBLE\020)\022\020\n\014FLOAT_VECTOR\020d\022\021\n" - "\rBINARY_VECTOR\020e\022\014\n\007UNKNOWN\020\217N*C\n\017Compar" - "eOperator\022\006\n\002LT\020\000\022\007\n\003LTE\020\001\022\006\n\002EQ\020\002\022\006\n\002GT" - "\020\003\022\007\n\003GTE\020\004\022\006\n\002NE\020\005*8\n\005Occur\022\013\n\007INVALID\020" - "\000\022\010\n\004MUST\020\001\022\n\n\006SHOULD\020\002\022\014\n\010MUST_NOT\020\0032\360\016" - "\n\rMilvusService\022\?\n\020CreateCollection\022\024.mi" - "lvus.grpc.Mapping\032\023.milvus.grpc.Status\"\000" - "\022F\n\rHasCollection\022\033.milvus.grpc.Collecti" - "onName\032\026.milvus.grpc.BoolReply\"\000\022I\n\022Desc" - "ribeCollection\022\033.milvus.grpc.CollectionN" - "ame\032\024.milvus.grpc.Mapping\"\000\022Q\n\017CountColl" - "ection\022\033.milvus.grpc.CollectionName\032\037.mi" - "lvus.grpc.CollectionRowCount\"\000\022J\n\017ShowCo" - "llections\022\024.milvus.grpc.Command\032\037.milvus" - ".grpc.CollectionNameList\"\000\022P\n\022ShowCollec" - "tionInfo\022\033.milvus.grpc.CollectionName\032\033." - "milvus.grpc.CollectionInfo\"\000\022D\n\016DropColl" - "ection\022\033.milvus.grpc.CollectionName\032\023.mi" - "lvus.grpc.Status\"\000\022=\n\013CreateIndex\022\027.milv" - "us.grpc.IndexParam\032\023.milvus.grpc.Status\"" - "\000\022G\n\rDescribeIndex\022\033.milvus.grpc.Collect" - "ionName\032\027.milvus.grpc.IndexParam\"\000\022;\n\tDr" - "opIndex\022\027.milvus.grpc.IndexParam\032\023.milvu" - "s.grpc.Status\"\000\022E\n\017CreatePartition\022\033.mil" - "vus.grpc.PartitionParam\032\023.milvus.grpc.St" - "atus\"\000\022E\n\014HasPartition\022\033.milvus.grpc.Par" - "titionParam\032\026.milvus.grpc.BoolReply\"\000\022K\n" - "\016ShowPartitions\022\033.milvus.grpc.Collection" - "Name\032\032.milvus.grpc.PartitionList\"\000\022C\n\rDr" - "opPartition\022\033.milvus.grpc.PartitionParam" - "\032\023.milvus.grpc.Status\"\000\022<\n\006Insert\022\030.milv" - "us.grpc.InsertParam\032\026.milvus.grpc.Entity" - "Ids\"\000\022E\n\rGetEntityByID\022\033.milvus.grpc.Ent" - "ityIdentity\032\025.milvus.grpc.Entities\"\000\022H\n\014" - "GetEntityIDs\022\036.milvus.grpc.GetEntityIDsP" - "aram\032\026.milvus.grpc.EntityIds\"\000\022>\n\006Search" - "\022\030.milvus.grpc.SearchParam\032\030.milvus.grpc" - ".QueryResult\"\000\022F\n\nSearchByID\022\034.milvus.gr" - "pc.SearchByIDParam\032\030.milvus.grpc.QueryRe" - "sult\"\000\022L\n\rSearchInFiles\022\037.milvus.grpc.Se" - "archInFilesParam\032\030.milvus.grpc.QueryResu" - "lt\"\000\0227\n\003Cmd\022\024.milvus.grpc.Command\032\030.milv" - "us.grpc.StringReply\"\000\022A\n\nDeleteByID\022\034.mi" - "lvus.grpc.DeleteByIDParam\032\023.milvus.grpc." - "Status\"\000\022G\n\021PreloadCollection\022\033.milvus.g" - "rpc.CollectionName\032\023.milvus.grpc.Status\"" - "\000\022I\n\016ReloadSegments\022 .milvus.grpc.ReLoad" - "SegmentsParam\032\023.milvus.grpc.Status\"\000\0227\n\005" - "Flush\022\027.milvus.grpc.FlushParam\032\023.milvus." - "grpc.Status\"\000\022=\n\007Compact\022\033.milvus.grpc.C" - "ollectionName\032\023.milvus.grpc.Status\"\000\022B\n\010" - "SearchPB\022\032.milvus.grpc.SearchParamPB\032\030.m" - "ilvus.grpc.QueryResult\"\000b\006proto3" + "\0132\031.milvus.grpc.KeyValuePair*\236\001\n\010DataTyp" + "e\022\010\n\004NONE\020\000\022\010\n\004BOOL\020\001\022\010\n\004INT8\020\002\022\t\n\005INT16" + "\020\003\022\t\n\005INT32\020\004\022\t\n\005INT64\020\005\022\t\n\005FLOAT\020\n\022\n\n\006D" + "OUBLE\020\013\022\n\n\006STRING\020\024\022\021\n\rVECTOR_BINARY\020d\022\020" + "\n\014VECTOR_FLOAT\020e\022\013\n\006VECTOR\020\310\001*C\n\017Compare" + "Operator\022\006\n\002LT\020\000\022\007\n\003LTE\020\001\022\006\n\002EQ\020\002\022\006\n\002GT\020" + "\003\022\007\n\003GTE\020\004\022\006\n\002NE\020\005*8\n\005Occur\022\013\n\007INVALID\020\000" + "\022\010\n\004MUST\020\001\022\n\n\006SHOULD\020\002\022\014\n\010MUST_NOT\020\0032\360\016\n" + "\rMilvusService\022\?\n\020CreateCollection\022\024.mil" + "vus.grpc.Mapping\032\023.milvus.grpc.Status\"\000\022" + "F\n\rHasCollection\022\033.milvus.grpc.Collectio" + "nName\032\026.milvus.grpc.BoolReply\"\000\022I\n\022Descr" + "ibeCollection\022\033.milvus.grpc.CollectionNa" + "me\032\024.milvus.grpc.Mapping\"\000\022Q\n\017CountColle" + "ction\022\033.milvus.grpc.CollectionName\032\037.mil" + "vus.grpc.CollectionRowCount\"\000\022J\n\017ShowCol" + "lections\022\024.milvus.grpc.Command\032\037.milvus." + "grpc.CollectionNameList\"\000\022P\n\022ShowCollect" + "ionInfo\022\033.milvus.grpc.CollectionName\032\033.m" + "ilvus.grpc.CollectionInfo\"\000\022D\n\016DropColle" + "ction\022\033.milvus.grpc.CollectionName\032\023.mil" + "vus.grpc.Status\"\000\022=\n\013CreateIndex\022\027.milvu" + "s.grpc.IndexParam\032\023.milvus.grpc.Status\"\000" + "\022G\n\rDescribeIndex\022\033.milvus.grpc.Collecti" + "onName\032\027.milvus.grpc.IndexParam\"\000\022;\n\tDro" + "pIndex\022\027.milvus.grpc.IndexParam\032\023.milvus" + ".grpc.Status\"\000\022E\n\017CreatePartition\022\033.milv" + "us.grpc.PartitionParam\032\023.milvus.grpc.Sta" + "tus\"\000\022E\n\014HasPartition\022\033.milvus.grpc.Part" + "itionParam\032\026.milvus.grpc.BoolReply\"\000\022K\n\016" + "ShowPartitions\022\033.milvus.grpc.CollectionN" + "ame\032\032.milvus.grpc.PartitionList\"\000\022C\n\rDro" + "pPartition\022\033.milvus.grpc.PartitionParam\032" + "\023.milvus.grpc.Status\"\000\022<\n\006Insert\022\030.milvu" + "s.grpc.InsertParam\032\026.milvus.grpc.EntityI" + "ds\"\000\022E\n\rGetEntityByID\022\033.milvus.grpc.Enti" + "tyIdentity\032\025.milvus.grpc.Entities\"\000\022H\n\014G" + "etEntityIDs\022\036.milvus.grpc.GetEntityIDsPa" + "ram\032\026.milvus.grpc.EntityIds\"\000\022>\n\006Search\022" + "\030.milvus.grpc.SearchParam\032\030.milvus.grpc." + "QueryResult\"\000\022F\n\nSearchByID\022\034.milvus.grp" + "c.SearchByIDParam\032\030.milvus.grpc.QueryRes" + "ult\"\000\022L\n\rSearchInFiles\022\037.milvus.grpc.Sea" + "rchInFilesParam\032\030.milvus.grpc.QueryResul" + "t\"\000\0227\n\003Cmd\022\024.milvus.grpc.Command\032\030.milvu" + "s.grpc.StringReply\"\000\022A\n\nDeleteByID\022\034.mil" + "vus.grpc.DeleteByIDParam\032\023.milvus.grpc.S" + "tatus\"\000\022G\n\021PreloadCollection\022\033.milvus.gr" + "pc.CollectionName\032\023.milvus.grpc.Status\"\000" + "\022I\n\016ReloadSegments\022 .milvus.grpc.ReLoadS" + "egmentsParam\032\023.milvus.grpc.Status\"\000\0227\n\005F" + "lush\022\027.milvus.grpc.FlushParam\032\023.milvus.g" + "rpc.Status\"\000\022=\n\007Compact\022\033.milvus.grpc.Co" + "llectionName\032\023.milvus.grpc.Status\"\000\022B\n\010S" + "earchPB\022\032.milvus.grpc.SearchParamPB\032\030.mi" + "lvus.grpc.QueryResult\"\000b\006proto3" ; static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_milvus_2eproto_deps[1] = { &::descriptor_table_status_2eproto, @@ -1446,7 +1446,7 @@ static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_mil static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_milvus_2eproto_once; static bool descriptor_table_milvus_2eproto_initialized = false; const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_milvus_2eproto = { - &descriptor_table_milvus_2eproto_initialized, descriptor_table_protodef_milvus_2eproto, "milvus.proto", 6552, + &descriptor_table_milvus_2eproto_initialized, descriptor_table_protodef_milvus_2eproto, "milvus.proto", 6551, &descriptor_table_milvus_2eproto_once, descriptor_table_milvus_2eproto_sccs, descriptor_table_milvus_2eproto_deps, 40, 1, schemas, file_default_instances, TableStruct_milvus_2eproto::offsets, file_level_metadata_milvus_2eproto, 41, file_level_enum_descriptors_milvus_2eproto, file_level_service_descriptors_milvus_2eproto, @@ -1467,13 +1467,13 @@ bool DataType_IsValid(int value) { case 2: case 3: case 4: + case 5: + case 10: + case 11: case 20: - case 30: - case 40: - case 41: case 100: case 101: - case 9999: + case 200: return true; default: return false; diff --git a/sdk/grpc-gen/gen-milvus/milvus.pb.h b/sdk/grpc-gen/gen-milvus/milvus.pb.h index 8a884cea9869..d5cb33dab4ba 100644 --- a/sdk/grpc-gen/gen-milvus/milvus.pb.h +++ b/sdk/grpc-gen/gen-milvus/milvus.pb.h @@ -230,24 +230,24 @@ namespace milvus { namespace grpc { enum DataType : int { - NULL_ = 0, - INT8 = 1, - INT16 = 2, - INT32 = 3, - INT64 = 4, + NONE = 0, + BOOL = 1, + INT8 = 2, + INT16 = 3, + INT32 = 4, + INT64 = 5, + FLOAT = 10, + DOUBLE = 11, STRING = 20, - BOOL = 30, - FLOAT = 40, - DOUBLE = 41, - FLOAT_VECTOR = 100, - BINARY_VECTOR = 101, - UNKNOWN = 9999, + VECTOR_BINARY = 100, + VECTOR_FLOAT = 101, + VECTOR = 200, DataType_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), DataType_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() }; bool DataType_IsValid(int value); -constexpr DataType DataType_MIN = NULL_; -constexpr DataType DataType_MAX = UNKNOWN; +constexpr DataType DataType_MIN = NONE; +constexpr DataType DataType_MAX = VECTOR; constexpr int DataType_ARRAYSIZE = DataType_MAX + 1; const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* DataType_descriptor(); diff --git a/sdk/grpc/ClientProxy.cpp b/sdk/grpc/ClientProxy.cpp index a84e07a6f0c3..7d924b1f840f 100644 --- a/sdk/grpc/ClientProxy.cpp +++ b/sdk/grpc/ClientProxy.cpp @@ -327,7 +327,7 @@ CopyEntityToJson(::milvus::grpc::Entities& grpc_entities, JSON& json_entity) { double_data.insert(std::make_pair(grpc_field.field_name(), data)); break; } - case ::milvus::grpc::FLOAT_VECTOR: { + case ::milvus::grpc::VECTOR_FLOAT: { std::vector data(row_num); for (int j = 0; j < row_num; j++) { size_t dim = grpc_vector_record.records(j).float_data_size(); @@ -338,7 +338,7 @@ CopyEntityToJson(::milvus::grpc::Entities& grpc_entities, JSON& json_entity) { vector_data.insert(std::make_pair(grpc_field.field_name(), data)); break; } - case ::milvus::grpc::BINARY_VECTOR: { + case ::milvus::grpc::VECTOR_BINARY: { // TODO (yukun) } default: {} diff --git a/sdk/include/Field.h b/sdk/include/Field.h index 11f55ef87871..542520aa05b1 100644 --- a/sdk/include/Field.h +++ b/sdk/include/Field.h @@ -20,20 +20,21 @@ namespace milvus { enum class DataType { - INT8 = 1, - INT16 = 2, - INT32 = 3, - INT64 = 4, + NONE = 0, + BOOL = 1, + INT8 = 2, + INT16 = 3, + INT32 = 4, + INT64 = 5, - STRING = 20, - - BOOL = 30, + FLOAT = 10, + DOUBLE = 11, - FLOAT = 40, - DOUBLE = 41, + STRING = 20, - FLOAT_VECTOR = 100, - BINARY_VECTOR = 101, + VECTOR_BINARY = 100, + VECTOR_FLOAT = 101, + VECTOR = 200, UNKNOWN = 9999, }; diff --git a/sdk/include/MilvusApi.h b/sdk/include/MilvusApi.h index 148b402482da..2a2e6ca27a9c 100644 --- a/sdk/include/MilvusApi.h +++ b/sdk/include/MilvusApi.h @@ -40,6 +40,8 @@ enum class IndexType { SPTAGBKT = 8, HNSW = 11, ANNOY = 12, + IVFSQ8NR = 13, + HNSW_SQ8NM = 14, }; enum class MetricType { diff --git a/tests/milvus-java-test/pom.xml b/tests/milvus-java-test/pom.xml index f1df09e2fce6..c89f17474bfe 100644 --- a/tests/milvus-java-test/pom.xml +++ b/tests/milvus-java-test/pom.xml @@ -97,7 +97,7 @@ com.alibaba fastjson - 1.2.70 + 1.2.72 diff --git a/tests/milvus_python_test/collection/test_create_collection.py b/tests/milvus_python_test/collection/test_create_collection.py index 4054105465c3..5ff8bb77a517 100644 --- a/tests/milvus_python_test/collection/test_create_collection.py +++ b/tests/milvus_python_test/collection/test_create_collection.py @@ -300,7 +300,8 @@ def test_create_collection_no_segment_size(self, connect): logging.getLogger().info(res) assert res["segment_size"] == default_segment_size - def test_create_collection_no_metric_type(self, connect): + # TODO: + def _test_create_collection_no_metric_type(self, connect): ''' target: test create collection with no metric_type params method: create collection with corrent params @@ -308,11 +309,11 @@ def test_create_collection_no_metric_type(self, connect): ''' collection_name = gen_unique_str(collection_id) fields = copy.deepcopy(default_fields) - fields["fields"][-1].pop("metric_type") + fields["fields"][-1]["params"].pop("metric_type") connect.create_collection(collection_name, fields) res = connect.get_collection_info(collection_name) logging.getLogger().info(res) - assert result["metric_type"] == "L2" + assert res["metric_type"] == "L2" # TODO: assert exception def test_create_collection_limit_fields(self, connect): diff --git a/tests/milvus_python_test/entity/test_delete.py b/tests/milvus_python_test/entity/test_delete.py index 0713fb1a4ac1..d709475f9bc7 100644 --- a/tests/milvus_python_test/entity/test_delete.py +++ b/tests/milvus_python_test/entity/test_delete.py @@ -6,7 +6,6 @@ import logging from multiprocessing import Pool, Process import pytest -from milvus import IndexType, MetricType from utils import * diff --git a/tests/milvus_python_test/entity/test_get_entity_by_id.py b/tests/milvus_python_test/entity/test_get_entity_by_id.py index 02fb7adaa35c..66264457a35e 100644 --- a/tests/milvus_python_test/entity/test_get_entity_by_id.py +++ b/tests/milvus_python_test/entity/test_get_entity_by_id.py @@ -584,9 +584,9 @@ def test_get_entities_with_invalid_collection_name(self, connect, get_collection res = connect.get_entity_by_id(collection_name, ids) @pytest.mark.level(2) - def test_get_entities_with_invalid_field_name(self, connect, get_field_name): + def test_get_entities_with_invalid_field_name(self, connect, collection, get_field_name): field_name = get_field_name ids = [1] fields = [field_name] with pytest.raises(Exception): - res = connect.get_entity_by_id(collection_name, ids, fields=fields) \ No newline at end of file + res = connect.get_entity_by_id(collection, ids, fields=fields) \ No newline at end of file diff --git a/tests/milvus_python_test/entity/test_insert.py b/tests/milvus_python_test/entity/test_insert.py index 917699facc85..d4af3af1cd4e 100644 --- a/tests/milvus_python_test/entity/test_insert.py +++ b/tests/milvus_python_test/entity/test_insert.py @@ -5,6 +5,7 @@ import logging from multiprocessing import Pool, Process import pytest +from milvus import DataType from utils import * dim = 128 @@ -62,6 +63,26 @@ def get_filter_field(self, request): def get_vector_field(self, request): yield request.param + def test_add_vector_with_empty_vector(self, connect, collection): + ''' + target: test add vectors with empty vectors list + method: set empty vectors list as add method params + expected: raises a Exception + ''' + vector = [] + with pytest.raises(Exception) as e: + status, ids = connect.insert(collection, vector) + + def test_add_vector_with_None(self, connect, collection): + ''' + target: test add vectors with None + method: set None as add method params + expected: raises a Exception + ''' + vector = None + with pytest.raises(Exception) as e: + status, ids = connect.insert(collection, vector) + @pytest.mark.timeout(ADD_TIMEOUT) def test_insert_collection_not_existed(self, connect): ''' @@ -195,8 +216,8 @@ def test_insert_ids_fields(self, connect, get_filter_field, get_vector_field): entities = gen_entities_by_fields(fields["fields"], nb, dim) res_ids = connect.insert(collection_name, entities, ids) assert res_ids == ids - connect.flush([collection]) - res_count = connect.count_entities(collection) + connect.flush([collection_name]) + res_count = connect.count_entities(collection_name) assert res_count == nb # TODO: assert exception @@ -372,9 +393,10 @@ def test_insert_with_field_name_not_match(self, connect, collection): ''' tmp_entity = update_field_name(copy.deepcopy(entity), "int8", "int8new") with pytest.raises(Exception): - connect.insert(collection_name, tmp_entity) + connect.insert(collection, tmp_entity) - def test_insert_with_field_type_not_match(self, connect, collection): + # TODO: Python sdk needs to do check + def _test_insert_with_field_type_not_match(self, connect, collection): ''' target: test insert entities, with the entity field type updated method: update entity field type @@ -382,9 +404,10 @@ def test_insert_with_field_type_not_match(self, connect, collection): ''' tmp_entity = update_field_type(copy.deepcopy(entity), DataType.INT8, DataType.FLOAT) with pytest.raises(Exception): - connect.insert(collection_name, tmp_entity) + connect.insert(collection, tmp_entity) - def test_insert_with_field_value_not_match(self, connect, collection): + # TODO: Python sdk needs to do check + def _test_insert_with_field_value_not_match(self, connect, collection): ''' target: test insert entities, with the entity field value updated method: update entity field value @@ -392,7 +415,7 @@ def test_insert_with_field_value_not_match(self, connect, collection): ''' tmp_entity = update_field_value(copy.deepcopy(entity), 'int8', 's') with pytest.raises(Exception): - connect.insert(collection_name, tmp_entity) + connect.insert(collection, tmp_entity) def test_insert_with_field_more(self, connect, collection): ''' @@ -556,8 +579,9 @@ def test_insert_async_callback(self, connect, collection, insert_count): future = connect.insert(collection, gen_entities(nb), _async=True, _callback=self.check_status) future.done() + # TODO: @pytest.mark.level(2) - def test_insert_async_long(self, connect, collection): + def _test_insert_async_long(self, connect, collection): ''' target: test insert vectors with different length of vectors method: set different vectors as insert method params diff --git a/tests/milvus_python_test/entity/test_search.py b/tests/milvus_python_test/entity/test_search.py index 5475e391be83..5491b24ad704 100644 --- a/tests/milvus_python_test/entity/test_search.py +++ b/tests/milvus_python_test/entity/test_search.py @@ -5,7 +5,9 @@ import logging from multiprocessing import Pool, Process import pytest -from milvus import IndexType, MetricType +import numpy as np + +from milvus import DataType from utils import * dim = 128 @@ -366,11 +368,6 @@ def _test_search_index_partitions_B(self, connect, collection, get_simple_index, # @pytest.mark.level(2) def test_search_ip_flat(self, connect, ip_collection, get_simple_index, get_top_k, get_nq): - ''' - target: test basic search fuction, all the search params is corrent, test all index params, and build - method: search with the given vectors, check the result - expected: the length of the result is top_k - ''' ''' target: test basic search fuction, all the search params is corrent, change top-k value method: search with the given vectors, check the result diff --git a/tests/milvus_python_test/stability/test_mysql.py b/tests/milvus_python_test/stability/test_mysql.py new file mode 100644 index 000000000000..dd9726d10ff9 --- /dev/null +++ b/tests/milvus_python_test/stability/test_mysql.py @@ -0,0 +1,53 @@ +import time +import random +import pdb +import threading +import logging +from multiprocessing import Pool, Process +import pytest +from milvus import IndexType, MetricType +from utils import * + + +dim = 128 +index_file_size = 10 +collection_id = "mysql_failure" +nprobe = 1 +tag = "1970-01-01" + + +class TestMysql: + + """ + ****************************************************************** + The following cases are used to test mysql failure + ****************************************************************** + """ + @pytest.fixture(scope="function", autouse=True) + def skip_check(self, connect, args): + if args["service_name"].find("shards") != -1: + reason = "Skip restart cases in shards mode" + logging.getLogger().info(reason) + pytest.skip(reason) + + def _test_kill_mysql_during_index(self, connect, collection, args): + big_nb = 20000 + index_param = {"nlist": 1024, "m": 16} + index_type = IndexType.IVF_PQ + vectors = gen_vectors(big_nb, dim) + status, ids = connect.insert(collection, vectors, ids=[i for i in range(big_nb)]) + status = connect.flush([collection]) + assert status.OK() + status, res_count = connect.count_entities(collection) + logging.getLogger().info(res_count) + assert status.OK() + assert res_count == big_nb + logging.getLogger().info("Start create index async") + status = connect.create_index(collection, index_type, index_param, _async=True) + time.sleep(2) + logging.getLogger().info("Start play mysql failure") + # pass + new_connect = get_milvus(args["ip"], args["port"], handler=args["handler"]) + status, res_count = new_connect.count_entities(collection) + assert status.OK() + assert res_count == big_nb diff --git a/tests/milvus_python_test/test_config.py b/tests/milvus_python_test/test_config.py index c7f718269d32..fe056b840f00 100644 --- a/tests/milvus_python_test/test_config.py +++ b/tests/milvus_python_test/test_config.py @@ -5,7 +5,6 @@ import logging from multiprocessing import Pool, Process import pytest -from milvus import IndexType, MetricType from utils import * import ujson @@ -187,9 +186,9 @@ def test_set_cache_size_valid(self, connect, collection): expected: status ok, set successfully ''' self.reset_configs(connect) - relpy = connect.set_config("cache", "cache_size", '8GB') + relpy = connect.set_config("cache", "cache_size", '2GB') config_value = connect.get_config("cache", "cache_size") - assert config_value == '8GB' + assert config_value == '2GB' @pytest.mark.level(2) def test_set_cache_size_valid_multiple_times(self, connect, collection): @@ -204,9 +203,9 @@ def test_set_cache_size_valid_multiple_times(self, connect, collection): config_value = connect.get_config("cache", "cache_size") assert config_value == '4GB' for i in range(20): - relpy = connect.set_config("cache", "cache_size", '8GB') + relpy = connect.set_config("cache", "cache_size", '2GB') config_value = connect.get_config("cache", "cache_size") - assert config_value == '8GB' + assert config_value == '2GB' @pytest.mark.level(2) def test_set_insert_buffer_size_invalid_parent_key(self, connect, collection): diff --git a/tests/milvus_python_test/test_mix.py b/tests/milvus_python_test/test_mix.py index 488f13852c14..e5432b39fcc1 100644 --- a/tests/milvus_python_test/test_mix.py +++ b/tests/milvus_python_test/test_mix.py @@ -13,7 +13,7 @@ dim = 128 index_file_size = 10 collection_id = "test_mix" -add_interval_time = 2 +add_interval_time = 5 vectors = gen_vectors(10000, dim) vectors = sklearn.preprocessing.normalize(vectors, axis=1, norm='l2') vectors = vectors.tolist() diff --git a/tests/milvus_python_test/test_ping.py b/tests/milvus_python_test/test_ping.py index 8bd47decc4d3..e7a06f68c402 100644 --- a/tests/milvus_python_test/test_ping.py +++ b/tests/milvus_python_test/test_ping.py @@ -9,7 +9,7 @@ def test_server_version(self, connect): ''' target: test get the server version method: call the server_version method after connected - expected: version should be the pymilvus version + expected: version should be the milvus version ''' res = connect.server_version() assert res == __version__ @@ -48,6 +48,56 @@ def test_connected(self, connect): assert connect +class TestPingWithTimeout: + def test_server_version_legal_timeout(self, connect): + ''' + target: test get the server version with legal timeout + method: call the server_version method after connected with altering timeout + expected: version should be the milvus version + ''' + res = connect.server_version(20) + assert res == __version__ + + def test_server_version_negative_timeout(self, connect): + ''' + target: test get the server version with negative timeout + method: call the server_version method after connected with altering timeout + expected: when timeout is illegal raises an error; + ''' + with pytest.raises(Exception) as e: + res = connect.server_version(-1) + + def test_server_cmd_with_params_version_with_legal_timeout(self, connect): + ''' + target: test cmd: version and timeout + method: cmd = "version" , timeout=10 + expected: when cmd = 'version', return version of server; + ''' + cmd = "version" + msg = connect._cmd(cmd, 10) + logging.getLogger().info(msg) + assert msg == __version__ + + def test_server_cmd_with_params_version_with_illegal_timeout(self, connect): + ''' + target: test cmd: version and timeout + method: cmd = "version" , timeout=-1 + expected: when timeout is illegal raises an error; + ''' + with pytest.raises(Exception) as e: + res = connect.server_version(-1) + + def test_server_cmd_with_params_others_with_illegal_timeout(self, connect): + ''' + target: test cmd: lalala, timeout = -1 + method: cmd = "lalala", timeout = -1 + expected: when timeout is illegal raises an error; + ''' + cmd = "rm -rf test" + with pytest.raises(Exception) as e: + res = connect.server_version(-1) + + class TestPingDisconnect: def test_server_version(self, dis_connect): ''' @@ -66,3 +116,13 @@ def test_server_status(self, dis_connect): ''' with pytest.raises(Exception) as e: res = dis_connect.server_status() + + def test_server_version_with_timeout(self, connect): + ''' + target: test get the server status with timeout settings after disconnect + method: call the server_status method after connected + expected: status returned should be not ok + ''' + status = None + with pytest.raises(Exception) as e: + res = connect.server_status(100) diff --git a/tests/milvus_python_test/utils.py b/tests/milvus_python_test/utils.py index 3c14ded3bfef..766b8a7a31da 100644 --- a/tests/milvus_python_test/utils.py +++ b/tests/milvus_python_test/utils.py @@ -200,10 +200,10 @@ def gen_single_filter_fields(): def gen_single_vector_fields(): fields = [] for metric_type in ['HAMMING', 'IP', 'JACCARD', 'L2', 'SUBSTRUCTURE', 'SUPERSTRUCTURE', 'TANIMOTO']: - for data_type in [DataType.VECTOR, DataType.BINARY_VECTOR]: + for data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: if metric_type in ["L2", "IP"] and data_type == DataType.BINARY_VECTOR: continue - if metric_type not in ["L2", "IP"] and data_type == DataType.VECTOR: + if metric_type not in ["L2", "IP"] and data_type == DataType.FLOAT_VECTOR: continue field = {"field": data_type.name, "type": data_type, "params": {"metric_type": metric_type, "dimension": dimension}} fields.append(field) @@ -216,7 +216,7 @@ def gen_default_fields(): {"field": "int8", "type": DataType.INT8}, {"field": "int64", "type": DataType.INT64}, {"field": "float", "type": DataType.FLOAT}, - {"field": "vector", "type": DataType.VECTOR, "params": {"metric_type": "L2", "dimension": dimension}} + {"field": "vector", "type": DataType.FLOAT_VECTOR, "params": {"metric_type": "L2", "dimension": dimension}} ], "segment_size": segment_size } @@ -229,7 +229,7 @@ def gen_entities(nb, is_normal=False): {"field": "int8", "type": DataType.INT8, "values": [1 for i in range(nb)]}, {"field": "int64", "type": DataType.INT64, "values": [2 for i in range(nb)]}, {"field": "float", "type": DataType.FLOAT, "values": [3.0 for i in range(nb)]}, - {"field": "vector", "type": DataType.VECTOR, "values": vectors} + {"field": "vector", "type": DataType.FLOAT_VECTOR, "values": vectors} ] return entities @@ -254,7 +254,7 @@ def gen_entities_by_fields(fields, nb, dimension): field_value = [3.0 for i in range(nb)] elif field["type"] == DataType.BINARY_VECTOR: field_value = gen_binary_vectors(nb, dimension)[1] - elif field["type"] == DataType.VECTOR: + elif field["type"] == DataType.FLOAT_VECTOR: field_value = gen_vectors(nb, dimension) field.update({"values": field_value}) entities.append(field) @@ -307,7 +307,7 @@ def add_vector_field(entities, is_normal=False): vectors = gen_vectors(nb, dimension, is_normal) field = { "field": gen_unique_str(), - "type": DataType.VECTOR, + "type": DataType.FLOAT_VECTOR, "values": vectors } entities.append(field) @@ -317,7 +317,7 @@ def add_vector_field(entities, is_normal=False): def update_fields_metric_type(fields, metric_type): tmp_fields = copy.deepcopy(fields) if metric_type in ["L2", "IP"]: - tmp_fields["fields"][-1]["type"] = DataType.VECTOR + tmp_fields["fields"][-1]["type"] = DataType.FLOAT_VECTOR else: tmp_fields["fields"][-1]["type"] = DataType.BINARY_VECTOR tmp_fields["fields"][-1]["params"]["metric_type"] = metric_type @@ -360,7 +360,7 @@ def add_vector_field(nb, dimension=dimension): field_name = gen_unique_str() field = { "field": field_name, - "type": DataType.VECTOR, + "type": DataType.FLOAT_VECTOR, "values": gen_vectors(nb, dimension) } return field_name @@ -449,9 +449,9 @@ def gen_invalid_strs(): def gen_invalid_field_types(): field_types = [ - 1, + # 1, "=c", - 0, + # 0, None, "", "a".join("a" for i in range(256))