diff --git a/python/poetry.lock b/python/poetry.lock index 12794851718b..1f76a4e56b4e 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -3321,6 +3321,44 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pyarrow" +version = "12.0.1" +description = "Python library for Apache Arrow" +category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyarrow-12.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:6d288029a94a9bb5407ceebdd7110ba398a00412c5b0155ee9813a40d246c5df"}, + {file = "pyarrow-12.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:345e1828efdbd9aa4d4de7d5676778aba384a2c3add896d995b23d368e60e5af"}, + {file = "pyarrow-12.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d6009fdf8986332b2169314da482baed47ac053311c8934ac6651e614deacd6"}, + {file = "pyarrow-12.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d3c4cbbf81e6dd23fe921bc91dc4619ea3b79bc58ef10bce0f49bdafb103daf"}, + {file = "pyarrow-12.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:cdacf515ec276709ac8042c7d9bd5be83b4f5f39c6c037a17a60d7ebfd92c890"}, + {file = "pyarrow-12.0.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:749be7fd2ff260683f9cc739cb862fb11be376de965a2a8ccbf2693b098db6c7"}, + {file = "pyarrow-12.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6895b5fb74289d055c43db3af0de6e16b07586c45763cb5e558d38b86a91e3a7"}, + {file = "pyarrow-12.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1887bdae17ec3b4c046fcf19951e71b6a619f39fa674f9881216173566c8f718"}, + {file = "pyarrow-12.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2c9cb8eeabbadf5fcfc3d1ddea616c7ce893db2ce4dcef0ac13b099ad7ca082"}, + {file = "pyarrow-12.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:ce4aebdf412bd0eeb800d8e47db854f9f9f7e2f5a0220440acf219ddfddd4f63"}, + {file = "pyarrow-12.0.1-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:e0d8730c7f6e893f6db5d5b86eda42c0a130842d101992b581e2138e4d5663d3"}, + {file = "pyarrow-12.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43364daec02f69fec89d2315f7fbfbeec956e0d991cbbef471681bd77875c40f"}, + {file = "pyarrow-12.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:051f9f5ccf585f12d7de836e50965b3c235542cc896959320d9776ab93f3b33d"}, + {file = "pyarrow-12.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:be2757e9275875d2a9c6e6052ac7957fbbfc7bc7370e4a036a9b893e96fedaba"}, + {file = "pyarrow-12.0.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:cf812306d66f40f69e684300f7af5111c11f6e0d89d6b733e05a3de44961529d"}, + {file = "pyarrow-12.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:459a1c0ed2d68671188b2118c63bac91eaef6fc150c77ddd8a583e3c795737bf"}, + {file = "pyarrow-12.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85e705e33eaf666bbe508a16fd5ba27ca061e177916b7a317ba5a51bee43384c"}, + {file = "pyarrow-12.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9120c3eb2b1f6f516a3b7a9714ed860882d9ef98c4b17edcdc91d95b7528db60"}, + {file = "pyarrow-12.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:c780f4dc40460015d80fcd6a6140de80b615349ed68ef9adb653fe351778c9b3"}, + {file = "pyarrow-12.0.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:a3c63124fc26bf5f95f508f5d04e1ece8cc23a8b0af2a1e6ab2b1ec3fdc91b24"}, + {file = "pyarrow-12.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b13329f79fa4472324f8d32dc1b1216616d09bd1e77cfb13104dec5463632c36"}, + {file = "pyarrow-12.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb656150d3d12ec1396f6dde542db1675a95c0cc8366d507347b0beed96e87ca"}, + {file = "pyarrow-12.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6251e38470da97a5b2e00de5c6a049149f7b2bd62f12fa5dbb9ac674119ba71a"}, + {file = "pyarrow-12.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:3de26da901216149ce086920547dfff5cd22818c9eab67ebc41e863a5883bac7"}, + {file = "pyarrow-12.0.1.tar.gz", hash = "sha256:cce317fc96e5b71107bf1f9f184d5e54e2bd14bbf3f9a3d62819961f0af86fec"}, +] + +[package.dependencies] +numpy = ">=1.16.6" + [[package]] name = "pyasn1" version = "0.5.0" @@ -4844,6 +4882,38 @@ files = [ {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, ] +[[package]] +name = "ucall" +version = "0.5.1" +description = "Up to 100x Faster FastAPI. JSON-RPC with io_uring, SIMD-acceleration, and pure CPython bindings" +category = "dev" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ucall-0.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:15cd5ca5b15d198775d0b80532f41579f4ae7bf3693b86b4ac5f5ff1ed0be1d8"}, + {file = "ucall-0.5.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:577018de6f01651ba53ac7c8867ddd9b92cc79f98fbb4c0513fcc22a8d58e007"}, + {file = "ucall-0.5.1-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:a18ac9f297ef08e928b59c55ad75cba34511e7d4816af2fcb986043a8ecf719d"}, + {file = "ucall-0.5.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef678d300edcb8d1d6d3af65b63034f2b09873b75b9fcb323eaec0d824cadff7"}, + {file = "ucall-0.5.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5f1d3709b4a977c9bdefba3098edd8bc8ae37855f40ccf29f6580f195d7e2b09"}, + {file = "ucall-0.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:fcc1df86e1129bacdbb17662892e02f20189e74f827c0162da56fb2490df87ed"}, + {file = "ucall-0.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:406131c6d0c74035ee7b13131e2e674ca571607bc4c7b3f47c4758f9a5b8724d"}, + {file = "ucall-0.5.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:9a61bf176ce73df006507bdf2a30098b3519e534969886b2caf3d2dc479fda0c"}, + {file = "ucall-0.5.1-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:29bfcd675c458d23c492e86cf703443f1f4ba266a92880bd943de1ead8e27ddd"}, + {file = "ucall-0.5.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:26c79e67dbc7ecf6d925c8ce2f291db281961a8db13d3470f4dafb1c32d72a4b"}, + {file = "ucall-0.5.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fba80f2094597dfa182da47cd5a71a5420da2d171aa5067a7a6efbc196eeb86e"}, + {file = "ucall-0.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:562632065fd36968ec92cc8b51033276449610465ef54916ec05bf9505be6b8b"}, + {file = "ucall-0.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2d9b1cb358b6967023dcdbf1e9744501a687be75b3dc9b5fd5b3f177f18714a0"}, + {file = "ucall-0.5.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:e99eb1c64837b596281a41a641404bcb618c5037d75d652fc1f2a9b8c38aaed1"}, + {file = "ucall-0.5.1-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:9b3e346a93c1ff7d2eef9cf03d7f515be4fc6fed195939cfa37cda6ac36a2514"}, + {file = "ucall-0.5.1-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:358ae439f17fd05e70baa8809b38c5ff1146cc3fe77e91d5d50288d9154484af"}, + {file = "ucall-0.5.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6012d8d374535b87931dc48cd8ad34529e0e8ae5f30f8a301a8784c39fe7d013"}, + {file = "ucall-0.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:1a96b1df8d1e74afaca42a9b58d6b27a9b4a8444f83e17c0dfd76f9b4c3f3b20"}, +] + +[package.dependencies] +numpy = "*" +pillow = "*" + [[package]] name = "ujson" version = "5.8.0" @@ -4932,6 +5002,50 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "usearch" +version = "1.1.1" +description = "Smaller & Faster Single-File Vector Search Engine from Unum" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "usearch-1.1.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1a68a223be42573a036c76e516f30c076b16dd38d8cfe9ca79a1cc0e4d60e8a8"}, + {file = "usearch-1.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bf2d8246a23e5a232a9389f4efd25e0bd10624a96f0f95d0cd415945a6be84ee"}, + {file = "usearch-1.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8692dbd0b66874e6b01e2dd7c50216347f52d175bc7e072712a5e0646ec9295b"}, + {file = "usearch-1.1.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5be4ede1592b056714e3f529cd17e69907364e3c0ee6eee5cf1498f946f0c2ec"}, + {file = "usearch-1.1.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6230b0583779f43dba2da3865dd398f8cf88daa6427d60afff3348bbdea6652f"}, + {file = "usearch-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:eb731f74a7a8208e0fa5b04d9488d1dfc177e253b9c761687cb51d38138d5b93"}, + {file = "usearch-1.1.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7fbb82767109d03c807678664ab02383e31db778adb1d3602da347613fdbf15e"}, + {file = "usearch-1.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7e1f81a92f3fcc091400f2997b7b12b6d53f7abf4edf87e8a17b5eede350431"}, + {file = "usearch-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7027cdc4c733d6926fc2a58e77cb9b14528a3f585b5d738ad6c5f14dc6e027ca"}, + {file = "usearch-1.1.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:e1bb1b238390cc990d051a07fe2a0f173e60bc9e82b7f0f34eb9ddf5bef2b1f8"}, + {file = "usearch-1.1.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:42711bad96f487f5d31ac7be1685797fb4b26904328bc855182e8d6c83b9e538"}, + {file = "usearch-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:ebc34eb7cf0b9f7e52b0f176c48d374f19668ad9653533bdd2e5be1463435d66"}, + {file = "usearch-1.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:36059015e49f9ea303a1d823b80113ce96833330563a54ceac447e4218d63a2c"}, + {file = "usearch-1.1.1-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:918830e1876064227f0a6a17bd4f92c985d8df4856b0370c7729b6e43721b3cc"}, + {file = "usearch-1.1.1-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:a59dfd5d848c484448470e613514f636cf42acac3eab1a9fb9b98d9511de2a97"}, + {file = "usearch-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:70d1f5148a1032da5b0e6651371d29643bf302c0d24a2896d6969d504fccac15"}, + {file = "usearch-1.1.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5112ebd545ad63b7a63d68838da8a56cfcd313c9ade86bfbe30061c946cbc5dc"}, + {file = "usearch-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0f1e58d11d9dfe1d499e86c108a21f7deb85fe4f40e54b042e057b5df5ead036"}, + {file = "usearch-1.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:fbd08ecbf2225f16b9f4b8190cff53de372baddc173e5ba7854425392552013b"}, + {file = "usearch-1.1.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:b8040aa9f448ddfaac5528ec1a1c216351cf7a17af35ddf169b239282f7fa4c4"}, + {file = "usearch-1.1.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:479fcf8b884d1a822b83c7cfb853c090f0db4386e86ef790f2c820f96de70140"}, + {file = "usearch-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:b0338b054fde34ab0a42a786bae43ae1793412995f6a87122850fc0318cb5955"}, + {file = "usearch-1.1.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:12121e7ac82868877ae9e6b513a61c1113afc8a28d793f9350719ef94ac33091"}, + {file = "usearch-1.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:33255b29bd7fc1feb6584887f6489bf9f896bd9d79b9ce423ff7185b2c2059e5"}, + {file = "usearch-1.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9bb5464473e8ceeef6237285fc0e86a0b77a75304397db3365cb011761fd6abe"}, + {file = "usearch-1.1.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6d84c5771aa37584a335f4b3392185782da785733aab4c3a4ae9949434cbe679"}, + {file = "usearch-1.1.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:02167b0c03062a6d28926535ee862401669b6d6f303e99d2cd1232dc610d2a25"}, + {file = "usearch-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:8300ba31fcc3ace452429781f517273e1297a5881cff629e2f1c6a3a411a48fc"}, +] + +[package.dependencies] +numpy = "*" +pandas = "*" +tqdm = "*" +ucall = {version = "*", markers = "python_version >= \"3.9\""} + [[package]] name = "uvicorn" version = "0.23.2" @@ -5340,4 +5454,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "2b5b32caa34ffa4783b7381d12ab03bb23737883f5af22149ec502349c5b726c" +content-hash = "36b5293ce1687c9cad6e548c73b691aff0015c2de8e3abeeaa6f6a1edfd92a1b" diff --git a/python/pyproject.toml b/python/pyproject.toml index 48f13379bd3f..64c8f4a76fe7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -61,6 +61,10 @@ azure-search-documents = {version = "11.4.0b8", allow-prereleases = true} azure-core = "^1.28.0" azure-identity = "^1.13.0" +[tool.poetry.group.usearch.dependencies] +usearch = "^1.1.1" +pyarrow = "^12.0.1" + [tool.isort] profile = "black" diff --git a/python/semantic_kernel/connectors/memory/usearch/__init__.py b/python/semantic_kernel/connectors/memory/usearch/__init__.py new file mode 100644 index 000000000000..f74403f6441f --- /dev/null +++ b/python/semantic_kernel/connectors/memory/usearch/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft. All rights reserved. + +from semantic_kernel.connectors.memory.usearch.usearch_memory_store import ( + USearchMemoryStore, +) + +__all__ = ["USearchMemoryStore"] diff --git a/python/semantic_kernel/connectors/memory/usearch/usearch_memory_store.py b/python/semantic_kernel/connectors/memory/usearch/usearch_memory_store.py new file mode 100644 index 000000000000..cb59d1c953f8 --- /dev/null +++ b/python/semantic_kernel/connectors/memory/usearch/usearch_memory_store.py @@ -0,0 +1,638 @@ +# Copyright (c) Microsoft. All rights reserved. + +import itertools +import os +from dataclasses import dataclass +from enum import Enum +from logging import Logger +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from numpy import ndarray + +from semantic_kernel.memory.memory_record import MemoryRecord +from semantic_kernel.memory.memory_store_base import MemoryStoreBase +from semantic_kernel.utils.null_logger import NullLogger +from usearch.index import ( + BatchMatches, + CompiledMetric, + Index, + Matches, + MetricKind, + ScalarKind, +) + + +@dataclass +class _USearchCollection: + """Represents a collection for USearch with embeddings and related data. + + Attributes: + embeddings_index (Index): The index of embeddings. + embeddings_data_table (pa.Table): The PyArrow table holding embeddings data. + embeddings_id_to_label (Dict[str, int]): Mapping of embeddings ID to label. + """ + + embeddings_index: Index + embeddings_data_table: pa.Table + embeddings_id_to_label: Dict[str, int] + + @staticmethod + def create_default(embeddings_index: Index) -> "_USearchCollection": + """Create a default `_USearchCollection` using a given embeddings index. + + Args: + embeddings_index (Index): The index of embeddings to be used for the default collection. + + Returns: + _USearchCollection: A default `_USearchCollection` initialized with the given embeddings index. + """ + return _USearchCollection( + embeddings_index, + pa.Table.from_pandas( + pd.DataFrame(columns=_embeddings_data_schema.names), + schema=_embeddings_data_schema, + ), + {}, + ) + + +# PyArrow Schema definition for the embeddings data from `MemoryRecord`. +_embeddings_data_schema = pa.schema( + [ + pa.field("key", pa.string()), + pa.field("timestamp", pa.timestamp("us")), + pa.field("is_reference", pa.bool_()), + pa.field("external_source_name", pa.string()), + pa.field("id", pa.string()), + pa.field("description", pa.string()), + pa.field("text", pa.string()), + pa.field("additional_metadata", pa.string()), + ] +) + + +class _CollectionFileType(Enum): + """Enumeration of file types used for storing collections.""" + + USEARCH = 0 + PARQUET = 1 + + +# Mapping of collection file types to their file extensions. +_collection_file_extensions: Dict[_CollectionFileType, str] = { + _CollectionFileType.USEARCH: ".usearch", + _CollectionFileType.PARQUET: ".parquet", +} + + +def memoryrecords_to_pyarrow_table(records: List[MemoryRecord]) -> pa.Table: + """Convert a list of `MemoryRecord` to a PyArrow Table""" + records_pylist = [ + {attr: getattr(record, "_" + attr) for attr in _embeddings_data_schema.names} + for record in records + ] + return pa.Table.from_pylist(records_pylist, schema=_embeddings_data_schema) + + +def pyarrow_table_to_memoryrecords( + table: pa.Table, vectors: Optional[ndarray] = None +) -> List[MemoryRecord]: + """Convert a PyArrow Table to a list of MemoryRecords. + + Args: + table (pa.Table): The PyArrow Table to convert. + vectors (Optional[ndarray], optional): An array of vectors to include as embeddings in the MemoryRecords. + The length and order of the vectors should match the rows in the table. Defaults to None. + + Returns: + List[MemoryRecord]: List of MemoryRecords constructed from the table. + """ + result_memory_records = [ + MemoryRecord( + **row.to_dict(), embedding=vectors[index] if vectors is not None else None + ) + for index, row in table.to_pandas().iterrows() + ] + + return result_memory_records + + +class USearchMemoryStore(MemoryStoreBase): + def __init__( + self, + persist_directory: Optional[os.PathLike] = None, + logger: Optional[Logger] = None, + ) -> None: + """ + Create a USearchMemoryStore instance. + + This store helps searching embeddings with USearch, keeping collections in memory. + To save collections to disk, provide the `persist_directory` param. + Collections are saved when `close_async` is called. + + To both save collections and free up memory, call `close_async`. + When `USearchMemoryStore` is used with a context manager, this will happen automatically. + Otherwise, it should be called explicitly. + + Args: + persist_directory (Optional[os.PathLike], default=None): Directory for loading and saving collections. + If None, collections are not loaded nor saved. + logger (Optional[Logger], default=None): Logger for diagnostics. If None, a NullLogger is used. + """ + self._logger = logger or NullLogger() + self._persist_directory = ( + Path(persist_directory) if persist_directory is not None else None + ) + + self._collections: Dict[str, _USearchCollection] = {} + if self._persist_directory: + self._collections = self._read_collections_from_dir() + + def _get_collection_path( + self, collection_name: str, *, file_type: _CollectionFileType + ) -> Path: + """ + Get the path for the given collection name and file type. + + Args: + collection_name (str): Name of the collection. + file_type (_CollectionFileType): The file type. + + Returns: + Path: Path to the collection file. + + Raises: + ValueError: If persist directory path is not set. + """ + collection_name = collection_name.lower() + if self._persist_directory is None: + raise ValueError("Path of persist directory is not set") + + return self._persist_directory / ( + collection_name + _collection_file_extensions[file_type] + ) + + async def create_collection_async( + self, + collection_name: str, + ndim: int = 0, + metric: Union[str, MetricKind, CompiledMetric] = MetricKind.IP, + dtype: Optional[Union[str, ScalarKind]] = None, + connectivity: Optional[int] = None, + expansion_add: Optional[int] = None, + expansion_search: Optional[int] = None, + view: bool = False, + ) -> None: + """Create a new collection. + + Args: + collection_name (str): Name of the collection. Case-insensitive. + Must have name that is valid file name for the current OS environment. + ndim (int, optional): Number of dimensions. Defaults to 0. + metric (Union[str, MetricKind, CompiledMetric], optional): Metric kind. Defaults to MetricKind.IP. + dtype (Optional[Union[str, ScalarKind]], optional): Data type. Defaults to None. + connectivity (int, optional): Connectivity parameter. Defaults to None. + expansion_add (int, optional): Expansion add parameter. Defaults to None. + expansion_search (int, optional): Expansion search parameter. Defaults to None. + view (bool, optional): Viewing flag. Defaults to False. + + Raises: + ValueError: If collection with the given name already exists. + ValueError: If collection name is empty string. + """ + collection_name = collection_name.lower() + if not collection_name: + raise ValueError("Collection name can not be empty.") + if collection_name in self._collections: + raise ValueError(f"Collection with name {collection_name} already exists.") + + embeddings_index_path = ( + self._get_collection_path( + collection_name, file_type=_CollectionFileType.USEARCH + ) + if self._persist_directory + else None + ) + + embeddings_index = Index( + path=embeddings_index_path, + ndim=ndim, + metric=metric, + dtype=dtype, + connectivity=connectivity, + expansion_add=expansion_add, + expansion_search=expansion_search, + view=view, + ) + + self._collections[collection_name] = _USearchCollection.create_default( + embeddings_index + ) + + return None + + def _read_embeddings_table( + self, path: os.PathLike + ) -> Tuple[pa.Table, Dict[str, int]]: + """Read embeddings from the provided path and generate an ID to label mapping. + + Args: + path (os.PathLike): Path to the embeddings. + + Returns: + Tuple of embeddings table and a dictionary mapping from record ID to its label. + """ + embeddings_table = pq.read_table(path, schema=_embeddings_data_schema) + embeddings_id_to_label: Dict[str, int] = { + record_id: idx + for idx, record_id in enumerate(embeddings_table.column("id").to_pylist()) + } + return embeddings_table, embeddings_id_to_label + + def _read_embeddings_index(self, path: Path) -> Index: + """Read embeddings index.""" + # str cast is temporarily fix for https://github.com/unum-cloud/usearch/issues/196 + return Index.restore(str(path), view=False) + + def _read_collections_from_dir(self) -> Dict[str, _USearchCollection]: + """Read all collections from directory to memory. + + Raises: + ValueError: If files for a collection do not match expected amount. + + Returns: + Dict[str, _USearchCollection]: Dictionary with collection names as keys and + their _USearchCollection as values. + """ + collections: Dict[str, _USearchCollection] = {} + + for collection_name, collection_files in self._get_all_storage_files().items(): + expected_storage_files = len(_CollectionFileType) + if len(collection_files) != expected_storage_files: + raise ValueError( + f"Expected {expected_storage_files} files for collection {collection_name}" + ) + parquet_file, usearch_file = collection_files + if ( + parquet_file.suffix + == _collection_file_extensions[_CollectionFileType.USEARCH] + ): + parquet_file, usearch_file = usearch_file, parquet_file + + embeddings_table, embeddings_id_to_label = self._read_embeddings_table( + parquet_file + ) + embeddings_index = self._read_embeddings_index(usearch_file) + + collections[collection_name] = _USearchCollection( + embeddings_index, + embeddings_table, + embeddings_id_to_label, + ) + + return collections + + async def get_collections_async(self) -> List[str]: + """Get list of existing collections. + + Returns: + List[str]: List of collection names. + """ + return list(self._collections.keys()) + + async def delete_collection_async(self, collection_name: str) -> None: + collection_name = collection_name.lower() + collection = self._collections.pop(collection_name, None) + if collection: + collection.embeddings_index.reset() + return None + + async def does_collection_exist_async(self, collection_name: str) -> bool: + collection_name = collection_name.lower() + return collection_name in self._collections + + async def upsert_async(self, collection_name: str, record: MemoryRecord) -> str: + """Upsert single MemoryRecord and return its ID.""" + collection_name = collection_name.lower() + res = await self.upsert_batch_async( + collection_name=collection_name, records=[record] + ) + return res[0] + + async def upsert_batch_async( + self, + collection_name: str, + records: List[MemoryRecord], + *, + compact: bool = False, + copy: bool = True, + threads: int = 0, + log: Union[str, bool] = False, + batch_size: int = 0, + ) -> List[str]: + """Upsert a batch of MemoryRecords and return their IDs. + + Args: + collection_name (str): Name of the collection to search within. + records (List[MemoryRecord]): Records to upsert. + compact (bool, optional): Removes links to removed nodes (expensive). Defaults to False. + copy (bool, optional): Should the index store a copy of vectors. Defaults to True. + threads (int, optional): Optimal number of cores to use. Defaults to 0. + log (Union[str, bool], optional): Whether to print the progress bar. Defaults to False. + batch_size (int, optional): Number of vectors to process at once. Defaults to 0. + + Raises: + KeyError: If collection not exist + + Returns: + List[str]: List of IDs. + """ + collection_name = collection_name.lower() + if collection_name not in self._collections: + raise KeyError( + f"Collection {collection_name} does not exist, cannot insert." + ) + + ucollection = self._collections[collection_name] + all_records_id = [record._id for record in records] + + # Remove vectors from index + remove_labels = [ + ucollection.embeddings_id_to_label[id] + for id in all_records_id + if id in ucollection.embeddings_id_to_label + ] + ucollection.embeddings_index.remove( + remove_labels, compact=compact, threads=threads + ) + + # Determine label insertion points + table_num_rows = ucollection.embeddings_data_table.num_rows + insert_labels = np.arange(table_num_rows, table_num_rows + len(records)) + + # Add embeddings to index + ucollection.embeddings_index.add( + keys=insert_labels, + vectors=np.stack([record.embedding for record in records]), + copy=copy, + threads=threads, + log=log, + batch_size=batch_size, + ) + + # Update embeddings_table + ucollection.embeddings_data_table = pa.concat_tables( + [ucollection.embeddings_data_table, memoryrecords_to_pyarrow_table(records)] + ) + + # Update embeddings_id_to_label + for index, record_id in enumerate(all_records_id): + ucollection.embeddings_id_to_label[record_id] = insert_labels[index] + + return all_records_id + + async def get_async( + self, + collection_name: str, + key: str, + with_embedding: bool, + dtype: ScalarKind = ScalarKind.F32, + ) -> MemoryRecord: + """Retrieve a single MemoryRecord using its key.""" + collection_name = collection_name.lower() + result = await self.get_batch_async( + collection_name=collection_name, + keys=[key], + with_embeddings=with_embedding, + dtype=dtype, + ) + if not result: + raise KeyError(f"Key '{key}' not found in collection '{collection_name}'") + return result[0] + + async def get_batch_async( + self, + collection_name: str, + keys: List[str], + with_embeddings: bool, + dtype: ScalarKind = ScalarKind.F32, + ) -> List[MemoryRecord]: + """Retrieve a batch of MemoryRecords using their keys.""" + collection_name = collection_name.lower() + if collection_name not in self._collections: + raise KeyError(f"Collection {collection_name} does not exist") + + ucollection = self._collections[collection_name] + labels = [ + ucollection.embeddings_id_to_label[key] + for key in keys + if key in ucollection.embeddings_id_to_label + ] + if not labels: + return [] + vectors = ( + ucollection.embeddings_index.get_vectors(labels, dtype) + if with_embeddings + else None + ) + + return pyarrow_table_to_memoryrecords( + ucollection.embeddings_data_table.take(pa.array(labels)), vectors + ) + + async def remove_async(self, collection_name: str, key: str) -> None: + """Remove a single MemoryRecord using its key.""" + collection_name = collection_name.lower() + await self.remove_batch_async(collection_name=collection_name, keys=[key]) + return None + + async def remove_batch_async(self, collection_name: str, keys: List[str]) -> None: + """Remove a batch of MemoryRecords using their keys.""" + collection_name = collection_name.lower() + if collection_name not in self._collections: + raise KeyError( + f"Collection {collection_name} does not exist, cannot insert." + ) + + ucollection = self._collections[collection_name] + + labels = [ucollection.embeddings_id_to_label[key] for key in keys] + ucollection.embeddings_index.remove(labels) + for key in keys: + del ucollection.embeddings_id_to_label[key] + + return None + + async def get_nearest_match_async( + self, + collection_name: str, + embedding: ndarray, + min_relevance_score: float = 0.0, + with_embedding: bool = True, + exact: bool = False, + ) -> Tuple[MemoryRecord, float]: + """Retrieve the nearest matching MemoryRecord for the provided embedding. + + By default it is approximately search, see `exact` param description. + + Measure of similarity between vectors is relevance score. It is from 0 to 1. + USearch returns distances for vectors. Distance is converted to relevance score by inverse function. + + Args: + collection_name (str): Name of the collection to search within. + embedding (ndarray): The embedding vector to search for. + min_relevance_score (float, optional): The minimum relevance score for vectors. Supposed to be from 0 to 1. + Only vectors with greater or equal relevance score are returned. Defaults to 0.0. + with_embedding (bool, optional): If True, include the embedding in the result. Defaults to True. + exact (bool, optional): Perform exhaustive linear-time exact search. Defaults to False. + + Returns: + Tuple[MemoryRecord, float]: The nearest matching record and its relevance score. + """ + collection_name = collection_name.lower() + results = await self.get_nearest_matches_async( + collection_name=collection_name, + embedding=embedding, + limit=1, + min_relevance_score=min_relevance_score, + with_embeddings=with_embedding, + exact=exact, + ) + return results[0] + + async def get_nearest_matches_async( + self, + collection_name: str, + embedding: ndarray, + limit: int, + min_relevance_score: float = 0.0, + with_embeddings: bool = True, + *, + threads: int = 0, + exact: bool = False, + log: Union[str, bool] = False, + batch_size: int = 0, + ) -> List[Tuple[MemoryRecord, float]]: + """Get the nearest matches to a given embedding. + + By default it is approximately search, see `exact` param description. + + Measure of similarity between vectors is relevance score. It is from 0 to 1. + USearch returns distances for vectors. Distance is converted to relevance score by inverse function. + + Args: + collection_name (str): Name of the collection to search within. + embedding (ndarray): The embedding vector to search for. + limit (int): maximum amount of embeddings to search for. + min_relevance_score (float, optional): The minimum relevance score for vectors. Supposed to be from 0 to 1. + Only vectors with greater or equal relevance score are returned. Defaults to 0.0. + with_embedding (bool, optional): If True, include the embedding in the result. Defaults to True. + threads (int, optional): Optimal number of cores to use. Defaults to 0. + exact (bool, optional): Perform exhaustive linear-time exact search. Defaults to False. + log (Union[str, bool], optional): Whether to print the progress bar. Defaults to False. + batch_size (int, optional): Number of vectors to process at once. Defaults to 0. + + Raises: + KeyError: if a collection with specified name does not exist + + Returns: + List[Tuple[MemoryRecord, float]]: The nearest matching records and their relevance score. + """ + collection_name = collection_name.lower() + ucollection = self._collections[collection_name] + + result: Union[Matches, BatchMatches] = ucollection.embeddings_index.search( + vectors=embedding, + k=limit, + threads=threads, + exact=exact, + log=log, + batch_size=batch_size, + ) + + assert isinstance(result, Matches) + + relevance_score = 1 / (result.distances + 1) + filtered_labels = result.keys[ + np.where(relevance_score >= min_relevance_score)[0] + ] + + filtered_vectors: Optional[np.ndarray] = None + if with_embeddings: + filtered_vectors = ucollection.embeddings_index.get_vectors(filtered_labels) + + return [ + (mem_rec, relevance_score[index].item()) + for index, mem_rec in enumerate( + pyarrow_table_to_memoryrecords( + ucollection.embeddings_data_table.take(pa.array(filtered_labels)), + filtered_vectors, + ) + ) + ] + + def _get_all_storage_files(self) -> Dict[str, List[Path]]: + """Return storage files for each collection in `self._persist_directory`. + + Collection name is derived from file name and converted to lowercase. Files with extensions that + do not match storage extensions are discarded. + + Raises: + ValueError: If persist directory is not set. + + Returns: + Dict[str, List[Path]]: Dictionary of collection names mapped to their respective files. + """ + if self._persist_directory is None: + raise ValueError("Persist directory is not set") + + storage_exts = _collection_file_extensions.values() + collection_storage_files: Dict[str, List[Path]] = {} + for path in self._persist_directory.iterdir(): + if path.is_file() and (path.suffix in storage_exts): + collection_name = path.stem.lower() + if collection_name in collection_storage_files: + collection_storage_files[collection_name].append(path) + else: + collection_storage_files[collection_name] = [path] + return collection_storage_files + + def _dump_collections(self) -> None: + collection_storage_files = self._get_all_storage_files() + for file_path in itertools.chain.from_iterable( + collection_storage_files.values() + ): + file_path.unlink() + + for collection_name, ucollection in self._collections.items(): + ucollection.embeddings_index.save( + self._get_collection_path( + collection_name, file_type=_CollectionFileType.USEARCH + ) + ) + pq.write_table( + ucollection.embeddings_data_table, + self._get_collection_path( + collection_name, file_type=_CollectionFileType.PARQUET + ), + ) + + return None + + async def close_async(self) -> None: + """Persist collection, clear. + + Returns: + None + """ + if self._persist_directory: + self._dump_collections() + + for collection_name in await self.get_collections_async(): + await self.delete_collection_async(collection_name) + self._collections = {} diff --git a/python/tests/integration/connectors/memory/test_usearch.py b/python/tests/integration/connectors/memory/test_usearch.py new file mode 100644 index 000000000000..30e24295b42e --- /dev/null +++ b/python/tests/integration/connectors/memory/test_usearch.py @@ -0,0 +1,389 @@ +# Copyright (c) Microsoft. All rights reserved. + +from datetime import datetime +from typing import List + +import numpy as np +import pytest + +from semantic_kernel.connectors.memory.usearch import USearchMemoryStore +from semantic_kernel.memory.memory_record import MemoryRecord + +try: + import pyarrow # noqa: F401 + + pyarrow_installed = True +except ImportError: + pyarrow_installed = False + +try: + import usearch # noqa: F401 + + usearch_installed = True +except ImportError: + usearch_installed = False + + +pytestmark = [ + pytest.mark.skipif(not usearch_installed, reason="`USearch` is not installed"), + pytest.mark.skipif( + not pyarrow_installed, + reason="`USearch` dependency `pyarrow` is not installed", + ), +] + + +@pytest.fixture +def memory_record1(): + return MemoryRecord( + id="test_id1", + text="sample text1", + is_reference=False, + embedding=np.array([0.5, 0.5], dtype=np.float32), + description="description", + additional_metadata="additional metadata", + external_source_name="external source", + timestamp=datetime.now(), + ) + + +@pytest.fixture +def memory_record1_with_collision(): + return MemoryRecord( + id="test_id1", + text="sample text2", + is_reference=False, + embedding=np.array([1, 0.6], dtype=np.float32), + description="description_2", + additional_metadata="additional metadata_2", + external_source_name="external source", + timestamp=datetime.now(), + ) + + +@pytest.fixture +def memory_record2(): + return MemoryRecord( + id="test_id2", + text="sample text2", + is_reference=False, + embedding=np.array([0.25, 0.75], dtype=np.float32), + description="description", + additional_metadata="additional metadata", + external_source_name="external source", + timestamp=datetime.now(), + ) + + +@pytest.fixture +def memory_record3(): + return MemoryRecord( + id="test_id3", + text="sample text3", + is_reference=False, + embedding=np.array([0.25, 0.80], dtype=np.float32), + description="description", + additional_metadata="additional metadata", + external_source_name="external source", + timestamp=datetime.now(), + ) + + +def gen_memory_records( + count: int, ndim: int, start_index: int = 0 +) -> List[MemoryRecord]: + return [ + MemoryRecord( + is_reference=False, + text="random text", + additional_metadata="additional", + external_source_name="external_name", + description="something descriptive", + timestamp=datetime.datetime.now(), + id=f":{start_index + index}", + embedding=np.random.uniform(0, 0.3, (ndim)).astype(np.float32), + ) + for index in range(count) + ] + + +def compare_memory_records( + record1: MemoryRecord, record2: MemoryRecord, with_embedding: bool +): + """Compare two MemoryRecord instances and assert they are the same.""" + + assert ( + record1._key == record2._key + ), f"_key mismatch: {record1._key} != {record2._key}" + assert ( + record1._timestamp == record2._timestamp + ), f"_timestamp mismatch: {record1._timestamp} != {record2._timestamp}" + assert ( + record1._is_reference == record2._is_reference + ), f"_is_reference mismatch: {record1._is_reference} != {record2._is_reference}" + assert ( + record1._external_source_name == record2._external_source_name + ), f"_external_source_name mismatch: {record1._external_source_name} != {record2._external_source_name}" + assert record1._id == record2._id, f"_id mismatch: {record1._id} != {record2._id}" + assert ( + record1._description == record2._description + ), f"_description mismatch: {record1._description} != {record2._description}" + assert ( + record1._text == record2._text + ), f"_text mismatch: {record1._text} != {record2._text}" + assert ( + record1._additional_metadata == record2._additional_metadata + ), f"_additional_metadata mismatch: {record1._additional_metadata} != {record2._additional_metadata}" + if with_embedding is True: + assert np.array_equal( + record1._embedding, record2._embedding + ), "_embedding arrays are not equal" + + +@pytest.mark.asyncio +async def test_create_and_get_collection_async(): + memory = USearchMemoryStore() + + await memory.create_collection_async("test_collection1") + await memory.create_collection_async("test_collection2") + await memory.create_collection_async("test_collection3") + result = await memory.get_collections_async() + + assert len(result) == 3 + assert result == ["test_collection1", "test_collection2", "test_collection3"] + + +@pytest.mark.asyncio +async def test_delete_collection_async(): + memory = USearchMemoryStore() + + await memory.create_collection_async("test_collection") + await memory.delete_collection_async("test_collection") + result = await memory.get_collections_async() + assert len(result) == 0 + + await memory.create_collection_async("test_collection") + await memory.delete_collection_async("TEST_COLLECTION") + result = await memory.get_collections_async() + assert len(result) == 0 + + +@pytest.mark.asyncio +async def test_does_collection_exist_async(): + memory = USearchMemoryStore() + await memory.create_collection_async("test_collection") + result = await memory.does_collection_exist_async("test_collection") + assert result is True + + result = await memory.does_collection_exist_async("TEST_COLLECTION") + assert result is True + + +@pytest.mark.asyncio +async def test_upsert_and_get_async_with_no_embedding(memory_record1: MemoryRecord): + memory = USearchMemoryStore() + await memory.create_collection_async("test_collection", ndim=2) + await memory.upsert_async("test_collection", memory_record1) + + result = await memory.get_async("test_collection", "test_id1", False) + compare_memory_records(result, memory_record1, False) + + +@pytest.mark.asyncio +async def test_upsert_and_get_async_with_embedding(memory_record1: MemoryRecord): + memory = USearchMemoryStore() + await memory.create_collection_async("test_collection", ndim=2) + await memory.upsert_async("test_collection", memory_record1) + + result = await memory.get_async("test_collection", "test_id1", True) + compare_memory_records(result, memory_record1, True) + + +@pytest.mark.asyncio +async def test_upsert_and_get_batch_async( + memory_record1: MemoryRecord, memory_record2: MemoryRecord +): + memory = USearchMemoryStore() + await memory.create_collection_async( + "test_collection", ndim=memory_record1.embedding.shape[0] + ) + + await memory.upsert_batch_async("test_collection", [memory_record1, memory_record2]) + + result = await memory.get_batch_async( + "test_collection", ["test_id1", "test_id2"], True + ) + assert len(result) == 2 + + compare_memory_records(result[0], memory_record1, True) + compare_memory_records(result[1], memory_record2, True) + + +@pytest.mark.asyncio +async def test_remove_async(memory_record1): + memory = USearchMemoryStore() + await memory.create_collection_async( + "test_collection", ndim=memory_record1.embedding.shape[0] + ) + + await memory.upsert_async("test_collection", memory_record1) + await memory.remove_async("test_collection", "test_id1") + + # memory.get_async should raise Exception if record is not found + with pytest.raises(KeyError): + await memory.get_async("test_collection", "test_id1", True) + + +@pytest.mark.asyncio +async def test_remove_batch_async( + memory_record1: MemoryRecord, memory_record2: MemoryRecord +): + memory = USearchMemoryStore() + await memory.create_collection_async( + "test_collection", ndim=memory_record1.embedding.shape[0] + ) + + await memory.upsert_batch_async("test_collection", [memory_record1, memory_record2]) + await memory.remove_batch_async("test_collection", ["test_id1", "test_id2"]) + + result = await memory.get_batch_async( + "test_collection", ["test_id1", "test_id2"], True + ) + assert len(result) == 0 + + +@pytest.mark.asyncio +async def test_get_nearest_match_async( + memory_record1: MemoryRecord, memory_record2: MemoryRecord +): + memory = USearchMemoryStore() + + collection_name = "test_collection" + await memory.create_collection_async( + collection_name, ndim=memory_record1.embedding.shape[0], metric="cos" + ) + + await memory.upsert_batch_async(collection_name, [memory_record1, memory_record2]) + + result = await memory.get_nearest_match_async( + collection_name, np.array([0.5, 0.5]), exact=True + ) + + assert len(result) == 2 + assert isinstance(result[0], MemoryRecord) + assert result[1] == pytest.approx(1, abs=1e-5) + + +@pytest.mark.asyncio +async def test_get_nearest_matches_async( + memory_record1: MemoryRecord, memory_record2: MemoryRecord +): + memory = USearchMemoryStore() + + collection_name = "test_collection" + await memory.create_collection_async( + collection_name, ndim=memory_record1.embedding.shape[0], metric="cos" + ) + + await memory.upsert_batch_async(collection_name, [memory_record1, memory_record2]) + + results = await memory.get_nearest_matches_async( + collection_name, np.array([0.5, 0.5]), limit=2, exact=True + ) + + assert len(results) == 2 + assert isinstance(results[0][0], MemoryRecord) + assert results[0][1] == pytest.approx(1, abs=1e-5) + assert results[1][1] == pytest.approx(0.90450, abs=1e-5) + + +@pytest.mark.asyncio +async def test_create_and_save_collection_async( + tmpdir, memory_record1, memory_record2, memory_record3 +): + memory = USearchMemoryStore(tmpdir) + + await memory.create_collection_async("test_collection1", ndim=2) + await memory.create_collection_async("test_collection2", ndim=2) + await memory.create_collection_async("test_collection3", ndim=2) + await memory.upsert_batch_async( + "test_collection1", [memory_record1, memory_record2] + ) + await memory.upsert_batch_async( + "test_collection2", [memory_record2, memory_record3] + ) + await memory.upsert_batch_async( + "test_collection3", [memory_record1, memory_record3] + ) + await memory.close_async() + + assert (tmpdir / "test_collection1.parquet").exists() + assert (tmpdir / "test_collection1.usearch").exists() + assert (tmpdir / "test_collection2.parquet").exists() + assert (tmpdir / "test_collection2.usearch").exists() + assert (tmpdir / "test_collection3.parquet").exists() + assert (tmpdir / "test_collection3.usearch").exists() + + memory = USearchMemoryStore(tmpdir) + result = await memory.get_collections_async() + assert len(result) == 3 + assert set(result) == {"test_collection1", "test_collection2", "test_collection3"} + await memory.delete_collection_async("test_collection1") + await memory.delete_collection_async("test_collection3") + await memory.close_async() + + memory = USearchMemoryStore(tmpdir) + result = await memory.get_collections_async() + assert len(result) == 1 + assert set(result) == {"test_collection2"} + await memory.delete_collection_async("test_collection2") + await memory.close_async() + + memory = USearchMemoryStore(tmpdir) + result = await memory.get_collections_async() + assert len(result) == 0 + + +@pytest.mark.asyncio +async def test_upsert_and_get_async_with_embedding_with_persist( + tmpdir, memory_record1: MemoryRecord, memory_record1_with_collision: MemoryRecord +): + memory = USearchMemoryStore(tmpdir) + assert len(await memory.get_collections_async()) == 0 + await memory.create_collection_async("test_collection", ndim=2) + await memory.upsert_async("test_collection", memory_record1) + await memory.close_async() + + memory = USearchMemoryStore(tmpdir) + assert len(await memory.get_collections_async()) == 1 + result = await memory.get_async("test_collection", "test_id1", True) + compare_memory_records(result, memory_record1, True) + + await memory.upsert_async("test_collection", memory_record1_with_collision) + result = await memory.get_async("test_collection", "test_id1", True) + compare_memory_records(result, memory_record1_with_collision, True) + await memory.close_async() + + memory = USearchMemoryStore(tmpdir) + assert len(await memory.get_collections_async()) == 1 + result = await memory.get_async("test_collection", "test_id1", True) + compare_memory_records(result, memory_record1_with_collision, True) + + +@pytest.mark.asyncio +async def test_remove_get_async( + memory_record1: MemoryRecord, memory_record2: MemoryRecord +): + memory = USearchMemoryStore() + await memory.create_collection_async( + "test_collection", ndim=memory_record1.embedding.shape[0] + ) + + await memory.upsert_batch_async("test_collection", [memory_record1, memory_record2]) + await memory.remove_async("test_collection", "test_id1") + + result = await memory.get_batch_async( + "test_collection", ["test_id1", "test_id2"], True + ) + assert len(result) == 1 + compare_memory_records(result[0], memory_record2, True)