From 3e197853bf54e7cfc42fc22c019ea1157db6eae1 Mon Sep 17 00:00:00 2001 From: Kevin Pilch Date: Sun, 19 May 2024 18:24:33 -0700 Subject: [PATCH] .Net: Adds a memory connector for Azure Cosmos DB for NoSQL (#6148) ### Motivation and Context Azure Cosmos DB is adding Vector Similarity APIs to the NoSQL project, and would like Semantic Kernel users to be able to leverage them. ### Description This adds a Memory Connector implementation for Azure Cosmos DB's, including support for the new vector search functionality coming soon in Cosmos DB. It is mostly based off the existing connector for Azure Cosmos DB for Mongo DB vCore. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --------- Co-authored-by: Stephen Toub --- dotnet/Directory.Packages.props | 2 +- dotnet/SK-dotnet.sln | 10 + .../AssemblyInfo.cs | 6 + .../AzureCosmosDBNoSQLMemoryStore.cs | 430 ++++++++++++++++++ ...onnectors.Memory.AzureCosmosDBNoSQL.csproj | 30 ++ .../CosmosSystemTextJSonSerializer.cs | 130 ++++++ .../AzureCosmosDBNoSQLMemoryStoreTests.cs | 150 ++++++ ...ureCosmosDBNoSQLMemoryStoreTestsFixture.cs | 78 ++++ .../Memory/AzureCosmosDBNoSQL/DataHelper.cs | 36 ++ .../IntegrationTests/IntegrationTests.csproj | 5 +- 10 files changed, 874 insertions(+), 3 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AssemblyInfo.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/Connectors.Memory.AzureCosmosDBNoSQL.csproj create mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/CosmosSystemTextJSonSerializer.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/DataHelper.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 0f45264e4068..0a78b2c0332f 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -75,7 +75,6 @@ - @@ -87,6 +86,7 @@ + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 8b58bb93f4aa..6320eeb19832 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -310,6 +310,7 @@ EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "QualityCheckWithFilters", "samples\Demos\QualityCheck\QualityCheckWithFilters\QualityCheckWithFilters.csproj", "{1D3EEB5B-0E06-4700-80D5-164956E43D0A}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TimePlugin", "samples\Demos\TimePlugin\TimePlugin.csproj", "{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.AzureCosmosDBNoSQL", "src\Connectors\Connectors.Memory.AzureCosmosDBNoSQL\Connectors.Memory.AzureCosmosDBNoSQL.csproj", "{B0B3901E-AF56-432B-8FAA-858468E5D0DF}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -762,6 +763,12 @@ Global {F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Publish|Any CPU.Build.0 = Debug|Any CPU {F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Release|Any CPU.ActiveCfg = Release|Any CPU {F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Release|Any CPU.Build.0 = Release|Any CPU + {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Publish|Any CPU.Build.0 = Publish|Any CPU + {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B0B3901E-AF56-432B-8FAA-858468E5D0DF}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -867,6 +874,9 @@ Global {3ED53702-0E53-473A-A0F4-645DB33541C2} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {1D3EEB5B-0E06-4700-80D5-164956E43D0A} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {F312FCE1-12D7-4DEF-BC29-2FF6618509F3} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} + {6EF9663D-976C-4A27-B8D3-8B1E63BA3BF2} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} + {925B1185-8B58-4E2D-95C9-4CA0BA9364E5} = {FA3720F1-C99A-49B2-9577-A940257098BF} + {B0B3901E-AF56-432B-8FAA-858468E5D0DF} = {24503383-A8C4-4255-9998-28D70FE8E99A} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AssemblyInfo.cs new file mode 100644 index 000000000000..d174fc92303c --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +// This assembly is currently experimental. +[assembly: Experimental("SKEXP0020")] diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs new file mode 100644 index 000000000000..70d6210fc355 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs @@ -0,0 +1,430 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Azure.Cosmos; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.Memory; + +namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; + +/// +/// An implementation of backed by a Azure Cosmos DB database. +/// Get more details about Azure Cosmos DB vector search https://learn.microsoft.com/en-us/azure/cosmos-db/ +/// +public class AzureCosmosDBNoSQLMemoryStore : IMemoryStore, IDisposable +{ + private readonly CosmosClient _cosmosClient; + private readonly VectorEmbeddingPolicy _vectorEmbeddingPolicy; + private readonly IndexingPolicy _indexingPolicy; + private readonly string _databaseName; + + /// + /// Initiates a AzureCosmosDBNoSQLMemoryStore instance using a Azure Cosmos DB connection string + /// and other properties required for vector search. + /// + /// Connection string required to connect to Azure Cosmos DB. + /// The database name to connect to. + /// The to use if a collection is created. NOTE that embeddings will be stored in a property named 'embedding'. + /// The to use if a collection is created. NOTE that embeddings will be stored in a property named 'embedding'. + /// The application name to use in requests. + public AzureCosmosDBNoSQLMemoryStore( + string connectionString, + string databaseName, + VectorEmbeddingPolicy vectorEmbeddingPolicy, + IndexingPolicy indexingPolicy, + string? applicationName = null) + : this( + new CosmosClient( + connectionString, + new CosmosClientOptions + { + ApplicationName = applicationName ?? HttpHeaderConstant.Values.UserAgent, + Serializer = new CosmosSystemTextJsonSerializer(JsonSerializerOptions.Default), + }), + databaseName, + vectorEmbeddingPolicy, + indexingPolicy) + { + } + + /// + /// Initiates a AzureCosmosDBNoSQLMemoryStore instance using a instance + /// and other properties required for vector search. + /// + /// An existing to use. NOTE: This must support serializing with + /// System.Text.Json, not the default Cosmos serializer. + /// The database name to operate against. + /// The to use if a collection is created. NOTE that embeddings will be stored in a property named 'embedding'. + /// The to use if a collection is created. NOTE that embeddings will be stored in a property named 'embedding'. + public AzureCosmosDBNoSQLMemoryStore( + CosmosClient cosmosClient, + string databaseName, + VectorEmbeddingPolicy vectorEmbeddingPolicy, + IndexingPolicy indexingPolicy) + { + if (!vectorEmbeddingPolicy.Embeddings.Any(e => e.Path == "/embedding")) + { + throw new InvalidOperationException($""" + In order for {nameof(GetNearestMatchAsync)} to function, {nameof(vectorEmbeddingPolicy)} should + contain an embedding path at /embedding. It's also recommended to include a that path in the + {nameof(indexingPolicy)} to improve performance and reduce cost for searches. + """); + } + this._cosmosClient = cosmosClient; + this._databaseName = databaseName; + this._vectorEmbeddingPolicy = vectorEmbeddingPolicy; + this._indexingPolicy = indexingPolicy; + } + + /// + public async Task CreateCollectionAsync( + string collectionName, + CancellationToken cancellationToken = default) + { + var databaseResponse = await this._cosmosClient.CreateDatabaseIfNotExistsAsync( + this._databaseName, cancellationToken: cancellationToken).ConfigureAwait(false); + + var containerProperties = new ContainerProperties(collectionName, "/key") + { + VectorEmbeddingPolicy = this._vectorEmbeddingPolicy, + IndexingPolicy = this._indexingPolicy, + }; + var containerResponse = await databaseResponse.Database.CreateContainerIfNotExistsAsync( + containerProperties, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + + /// + public async IAsyncEnumerable GetCollectionsAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + using var feedIterator = this. + _cosmosClient + .GetDatabase(this._databaseName) + .GetContainerQueryIterator("SELECT VALUE(c.id) FROM c"); + + while (feedIterator.HasMoreResults) + { + var next = await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false); + foreach (var containerName in next.Resource) + { + yield return containerName; + } + } + } + + /// + public async Task DoesCollectionExistAsync( + string collectionName, + CancellationToken cancellationToken = default) + { + var queryDefinition = new QueryDefinition("SELECT VALUE(c.id) FROM c WHERE c.id = @collectionName"); + queryDefinition.WithParameter("@collectionName", collectionName); + using var feedIterator = this. + _cosmosClient + .GetDatabase(this._databaseName) + .GetContainerQueryIterator(queryDefinition); + + while (feedIterator.HasMoreResults) + { + var next = await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false); + foreach (var containerName in next.Resource) + { + return true; + } + } + + return false; + } + + /// + public async Task DeleteCollectionAsync( + string collectionName, + CancellationToken cancellationToken = default) + { + await this._cosmosClient + .GetDatabase(this._databaseName) + .GetContainer(collectionName) + .DeleteContainerAsync(cancellationToken: cancellationToken) + .ConfigureAwait(false); + } + + /// + public async Task UpsertAsync( + string collectionName, + MemoryRecord record, + CancellationToken cancellationToken = default) + { + var result = await this._cosmosClient + .GetDatabase(this._databaseName) + .GetContainer(collectionName) + .UpsertItemAsync(new MemoryRecordWithId(record), new PartitionKey(record.Key), cancellationToken: cancellationToken) + .ConfigureAwait(false); + + return record.Key; + } + + /// + public async IAsyncEnumerable UpsertBatchAsync( + string collectionName, + IEnumerable records, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var record in records) + { + yield return await this.UpsertAsync(collectionName, record, cancellationToken) + .ConfigureAwait(false); + } + } + + /// + public async Task GetAsync( + string collectionName, + string key, + bool withEmbedding = false, + CancellationToken cancellationToken = default) + { + var result = await this._cosmosClient + .GetDatabase(this._databaseName) + .GetContainer(collectionName) + .ReadItemAsync(key, new PartitionKey(key), cancellationToken: cancellationToken) + .ConfigureAwait(false); + + return result.Resource; + } + + /// + public async IAsyncEnumerable GetBatchAsync( + string collectionName, + IEnumerable keys, + bool withEmbeddings = false, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + const string OR = " OR "; + var queryStart = $""" + SELECT x.id,x.key,x.metadata,x.timestamp{(withEmbeddings ? ",x.embedding" : "")} + FROM x + WHERE + """; + // NOTE: Cosmos DB queries are limited to 512kB, so we'll break this into chunks + // of around 500kB. We don't go all the way to 512kB so that we don't have to + // remove the last clause we added once we go over. + int keyIndex = 0; + var keyList = keys.ToList(); + while (keyIndex < keyList.Count) + { + var length = queryStart.Length; + var countThisBatch = 0; + var whereClauses = new StringBuilder(); + for (int i = keyIndex; i < keyList.Count && length <= 500 * 1024; i++, countThisBatch++) + { + string keyId = $"@key{i:D}"; + var clause = $"(x.id = {keyId} AND x.key = {keyId})"; + whereClauses.Append(clause).Append(OR); + length += clause.Length + OR.Length + 4 + keyId.Length + Encoding.UTF8.GetByteCount(keyList[keyIndex]); + } + whereClauses.Length -= OR.Length; + + var queryDefinition = new QueryDefinition(queryStart + whereClauses); + for (int i = keyIndex; i < keyIndex + countThisBatch; i++) + { + queryDefinition.WithParameter($"@key{i:D}", keyList[i]); + } + + var feedIterator = this._cosmosClient + .GetDatabase(this._databaseName) + .GetContainer(collectionName) + .GetItemQueryIterator(queryDefinition); + + while (feedIterator.HasMoreResults) + { + foreach (var memoryRecord in await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false)) + { + yield return memoryRecord; + } + } + + keyIndex += countThisBatch; + } + } + + /// + public async Task RemoveAsync( + string collectionName, + string key, + CancellationToken cancellationToken = default) + { + var response = await this._cosmosClient + .GetDatabase(this._databaseName) + .GetContainer(collectionName) + .DeleteItemAsync(key, new PartitionKey(key), cancellationToken: cancellationToken) + .ConfigureAwait(false); + } + + /// + public async Task RemoveBatchAsync( + string collectionName, + IEnumerable keys, + CancellationToken cancellationToken = default) + { + foreach (var key in keys) + { + var response = await this._cosmosClient + .GetDatabase(this._databaseName) + .GetContainer(collectionName) + .DeleteItemAsync(key, new PartitionKey(key), cancellationToken: cancellationToken) + .ConfigureAwait(false); + } + } + + /// + public async Task<(MemoryRecord, double)?> GetNearestMatchAsync( + string collectionName, + ReadOnlyMemory embedding, + double minRelevanceScore = 0, + bool withEmbedding = false, + CancellationToken cancellationToken = default) + { + await foreach (var item in this.GetNearestMatchesAsync(collectionName, embedding, limit: 1, minRelevanceScore, withEmbedding, cancellationToken).ConfigureAwait(false)) + { + return item; + } + + return null; + } + + /// + public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync( + string collectionName, + ReadOnlyMemory embedding, + int limit, + double minRelevanceScore = 0, + bool withEmbeddings = false, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // It would be nice to "WHERE" on the similarity score to stay above the `minRelevanceScore`, but alas + // queries don't support that. + var queryDefinition = new QueryDefinition($""" + SELECT TOP @limit x.id,x.key,x.metadata,x.timestamp,{(withEmbeddings ? "x.embedding," : "")}VectorDistance(x.embedding, @embedding) AS SimilarityScore + FROM x + ORDER BY VectorDistance(x.embedding, @embedding) + """); + queryDefinition.WithParameter("@embedding", embedding); + queryDefinition.WithParameter("@limit", limit); + + var feedIterator = this._cosmosClient + .GetDatabase(this._databaseName) + .GetContainer(collectionName) + .GetItemQueryIterator(queryDefinition); + + while (feedIterator.HasMoreResults) + { + foreach (var memoryRecord in await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false)) + { + if (memoryRecord.SimilarityScore >= minRelevanceScore) + { + yield return (memoryRecord, memoryRecord.SimilarityScore); + } + } + } + } + + /// + /// Disposes the instance. + /// + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + /// + /// Disposes the resources used by the instance. + /// + /// True to release both managed and unmanaged resources; false to release only unmanaged resources. + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + this._cosmosClient.Dispose(); + } + } +} + +/// +/// Creates a new record with a similarity score. +/// +/// +/// +/// +/// +[DebuggerDisplay("{GetDebuggerDisplay()}")] +#pragma warning disable CA1812 // 'MemoryRecordWithSimilarityScore' is an internal class that is apparently never instantiated. If so, remove the code from the assembly. If this class is intended to contain only static members, make it 'static' (Module in Visual Basic). (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1812) +internal sealed class MemoryRecordWithSimilarityScore( +#pragma warning restore CA1812 + MemoryRecordMetadata metadata, + ReadOnlyMemory embedding, + string? key, + DateTimeOffset? timestamp = null) : MemoryRecord(metadata, embedding, key, timestamp) +{ + /// + /// The similarity score returned. + /// + public double SimilarityScore { get; set; } + + private string GetDebuggerDisplay() + { + return $"{this.Key} - {this.SimilarityScore}"; + } +} + +/// +/// Creates a new record that also serializes an "id" property. +/// +[DebuggerDisplay("{GetDebuggerDisplay()}")] +internal sealed class MemoryRecordWithId : MemoryRecord +{ + /// + /// Creates a new record that also serializes an "id" property. + /// + public MemoryRecordWithId(MemoryRecord source) + : base(source.Metadata, source.Embedding, source.Key, source.Timestamp) + { + } + + /// + /// Creates a new record that also serializes an "id" property. + /// + [JsonConstructor] + public MemoryRecordWithId( + MemoryRecordMetadata metadata, + ReadOnlyMemory embedding, + string? key, + DateTimeOffset? timestamp = null) + : base(metadata, embedding, key, timestamp) + { + } + + /// + /// Serializes the property as "id". + /// We do this because Azure Cosmos DB requires a property named "id" for + /// each item. + /// + [JsonInclude] + [JsonPropertyName("id")] + public string Id => this.Key; + + private string GetDebuggerDisplay() + { + return this.Key; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/Connectors.Memory.AzureCosmosDBNoSQL.csproj b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/Connectors.Memory.AzureCosmosDBNoSQL.csproj new file mode 100644 index 000000000000..0ffb5b602e05 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/Connectors.Memory.AzureCosmosDBNoSQL.csproj @@ -0,0 +1,30 @@ + + + + + Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL + $(AssemblyName) + net8.0;netstandard2.0 + $(NoWarn);NU5104;SKEXP0001,SKEXP0010 + alpha + + + + + + + + + Semantic Kernel - Azure CosmosDB NoSQL Connector + Azure CosmosDB NoSQL connector for Semantic Kernel plugins and semantic memory + + + + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/CosmosSystemTextJSonSerializer.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/CosmosSystemTextJSonSerializer.cs new file mode 100644 index 000000000000..0737ce09c120 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/CosmosSystemTextJSonSerializer.cs @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft. All rights reserved. + +// Taken from https://github.com/Azure/azure-cosmos-dotnet-v3/pull/4332 + +using System; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Azure.Cosmos; + +/// +/// This class provides a default implementation of System.Text.Json Cosmos Linq Serializer. +/// +internal sealed class CosmosSystemTextJsonSerializer : CosmosLinqSerializer +{ + /// + /// A read-only instance of . + /// + private readonly JsonSerializerOptions _jsonSerializerOptions; + + /// + /// Creates an instance of + /// with the default values for the Cosmos SDK + /// + /// An instance of containing the json serialization options. + public CosmosSystemTextJsonSerializer( + JsonSerializerOptions jsonSerializerOptions) + { + this._jsonSerializerOptions = jsonSerializerOptions; + } + + /// + [return: MaybeNull] + public override T FromStream(Stream stream) + { + if (stream == null) + { + throw new ArgumentNullException(nameof(stream)); + } + + if (stream.CanSeek && stream.Length == 0) + { + return default; + } + + if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + return (T)(object)stream; + } + + using (stream) + { + return JsonSerializer.Deserialize(stream, this._jsonSerializerOptions); + } + } + + /// + public override Stream ToStream(T input) + { + MemoryStream streamPayload = new(); + JsonSerializer.Serialize( + utf8Json: streamPayload, + value: input, + options: this._jsonSerializerOptions); + + streamPayload.Position = 0; + return streamPayload; + } + + /// + /// Convert a MemberInfo to a string for use in LINQ query translation. + /// + /// Any MemberInfo used in the query. + /// A serialized representation of the member. + /// + /// Note that this is just a default implementation which handles the basic scenarios. Any passed in + /// here are not going to be reflected in SerializeMemberName(). For example, if customers passed in a JsonSerializerOption such as below + /// + /// + /// + /// This would not be honored by SerializeMemberName() unless it included special handling for this, for example. + /// + /// (true); + /// if (jsonExtensionDataAttribute != null) + /// { + /// return null; + /// } + /// JsonPropertyNameAttribute jsonPropertyNameAttribute = memberInfo.GetCustomAttribute(true); + /// if (!string.IsNullOrEmpty(jsonPropertyNameAttribute?.Name)) + /// { + /// return jsonPropertyNameAttribute.Name; + /// } + /// return System.Text.Json.JsonNamingPolicy.CamelCase.ConvertName(memberInfo.Name); + /// } + /// ]]> + /// + /// To handle such scenarios, please create a custom serializer which inherits from the and overrides the + /// SerializeMemberName to add any special handling. + /// + public override string? SerializeMemberName(MemberInfo memberInfo) + { + JsonExtensionDataAttribute? jsonExtensionDataAttribute = + memberInfo.GetCustomAttribute(true); + + if (jsonExtensionDataAttribute != null) + { + return null; + } + + JsonPropertyNameAttribute? jsonPropertyNameAttribute = memberInfo.GetCustomAttribute(true); + if (jsonPropertyNameAttribute is { } && !string.IsNullOrEmpty(jsonPropertyNameAttribute.Name)) + { + return jsonPropertyNameAttribute.Name; + } + + return memberInfo.Name; + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs new file mode 100644 index 000000000000..0e8aee320856 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTests.cs @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; +using Microsoft.SemanticKernel.Memory; +using MongoDB.Driver; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBNoSQL; + +/// +/// Integration tests of . +/// +public class AzureCosmosDBNoSQLMemoryStoreTests : IClassFixture +{ + private const string? SkipReason = "Azure Cosmos DB Account with Vector indexing enabled required"; + + private readonly AzureCosmosDBNoSQLMemoryStoreTestsFixture _fixture; + + public AzureCosmosDBNoSQLMemoryStoreTests(AzureCosmosDBNoSQLMemoryStoreTestsFixture fixture) + { + this._fixture = fixture; + } + + [Fact(Skip = SkipReason)] + public async Task ItCanCreateGetCheckAndDeleteCollectionAsync() + { + var collectionName = this._fixture.CollectionName; + var memoryStore = this._fixture.MemoryStore; + + await memoryStore.CreateCollectionAsync(collectionName); + var collectionNames = memoryStore.GetCollectionsAsync(); + + Assert.True(await collectionNames.ContainsAsync(collectionName)); + Assert.True(await memoryStore.DoesCollectionExistAsync(collectionName)); + + await memoryStore.DeleteCollectionAsync(collectionName); + Assert.False(await memoryStore.DoesCollectionExistAsync(collectionName)); + } + + [Theory(Skip = SkipReason)] + [InlineData(true)] + [InlineData(false)] + public async Task ItCanBatchUpsertGetRemoveAsync(bool withEmbeddings) + { + const int Count = 10; + var collectionName = this._fixture.CollectionName; + var memoryStore = this._fixture.MemoryStore; + var records = DataHelper.CreateBatchRecords(Count); + + await memoryStore.CreateCollectionAsync(collectionName); + var keys = await memoryStore.UpsertBatchAsync(collectionName, records).ToListAsync(); + var actualRecords = await memoryStore + .GetBatchAsync(collectionName, keys, withEmbeddings: withEmbeddings) + .ToListAsync(); + + Assert.NotNull(keys); + Assert.NotNull(actualRecords); + Assert.Equal(keys, actualRecords.Select(obj => obj.Key).ToList()); + Console.WriteLine(actualRecords); + + var actualRecordsOrdered = actualRecords.OrderBy(r => r.Key).ToArray(); + for (int i = 0; i < Count; i++) + { + AssertMemoryRecordEqual( + records[i], + actualRecordsOrdered[i], + assertEmbeddingEqual: withEmbeddings + ); + } + + await memoryStore.RemoveBatchAsync(collectionName, keys); + var ids = await memoryStore.GetBatchAsync(collectionName, keys).ToListAsync(); + Assert.Empty(ids); + + await memoryStore.DeleteCollectionAsync(collectionName); + } + + [Theory(Skip = SkipReason)] + [InlineData(1, false)] + [InlineData(1, true)] + [InlineData(5, false)] + [InlineData(8, false)] + public async Task ItCanGetNearestMatchesAsync(int limit, bool withEmbeddings) + { + var collectionName = this._fixture.CollectionName; + var memoryStore = this._fixture.MemoryStore; + var searchEmbedding = DataHelper.VectorSearchTestEmbedding; + var nearestMatchesExpected = DataHelper.VectorSearchExpectedResults; + + await memoryStore.CreateCollectionAsync(collectionName); + var keys = await memoryStore.UpsertBatchAsync(collectionName, DataHelper.VectorSearchTestRecords).ToListAsync(); + + var nearestMatchesActual = await memoryStore + .GetNearestMatchesAsync( + collectionName, + searchEmbedding, + limit, + withEmbeddings: withEmbeddings + ) + .ToListAsync(); + + Assert.NotNull(nearestMatchesActual); + Assert.Equal(limit, nearestMatchesActual.Count); + + for (int i = 0; i < limit; i++) + { + AssertMemoryRecordEqual( + nearestMatchesExpected[i], + nearestMatchesActual[i].Item1, + withEmbeddings + ); + } + + await memoryStore.DeleteCollectionAsync(collectionName); + } + + private static void AssertMemoryRecordEqual( + MemoryRecord expectedRecord, + MemoryRecord actualRecord, + bool assertEmbeddingEqual = true + ) + { + Assert.Equal(expectedRecord.Key, actualRecord.Key); + Assert.Equal(expectedRecord.Timestamp, actualRecord.Timestamp); + Assert.Equal(expectedRecord.Metadata.Id, actualRecord.Metadata.Id); + Assert.Equal(expectedRecord.Metadata.Text, actualRecord.Metadata.Text); + Assert.Equal(expectedRecord.Metadata.Description, actualRecord.Metadata.Description); + Assert.Equal( + expectedRecord.Metadata.AdditionalMetadata, + actualRecord.Metadata.AdditionalMetadata + ); + Assert.Equal(expectedRecord.Metadata.IsReference, actualRecord.Metadata.IsReference); + Assert.Equal( + expectedRecord.Metadata.ExternalSourceName, + actualRecord.Metadata.ExternalSourceName + ); + + if (assertEmbeddingEqual) + { + Assert.True(expectedRecord.Embedding.Span.SequenceEqual(actualRecord.Embedding.Span)); + } + else + { + Assert.True(actualRecord.Embedding.Span.IsEmpty); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs new file mode 100644 index 000000000000..93cbea170f40 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStoreTestsFixture.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.ObjectModel; +using System.Threading.Tasks; +using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBNoSQL; + +public class AzureCosmosDBNoSQLMemoryStoreTestsFixture : IAsyncLifetime +{ + public AzureCosmosDBNoSQLMemoryStore MemoryStore { get; } + public string DatabaseName { get; } + public string CollectionName { get; } + + public AzureCosmosDBNoSQLMemoryStoreTestsFixture() + { + // Load Configuration + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile( + path: "testsettings.development.json", + optional: false, + reloadOnChange: true + ) + .AddEnvironmentVariables() + .Build(); + + var connectionString = GetSetting(configuration, "ConnectionString"); + this.DatabaseName = "DotNetSKTestDB"; + this.CollectionName = "DotNetSKTestCollection"; + this.MemoryStore = new AzureCosmosDBNoSQLMemoryStore( + connectionString, + this.DatabaseName, + new VectorEmbeddingPolicy( + new Collection + { + new() + { + DataType = VectorDataType.Float32, + Dimensions = 3, + DistanceFunction = DistanceFunction.Cosine, + Path = "/embedding" + } + }), + new() + { + VectorIndexes = new Collection { + new() + { + Path = "/embedding", + Type = VectorIndexType.Flat, + }, + }, + } + ); + } + + public Task InitializeAsync() + => Task.CompletedTask; + + public Task DisposeAsync() + => Task.CompletedTask; + + private static string GetSetting(IConfigurationRoot configuration, string settingName) + { + var settingValue = configuration[$"AzureCosmosDB:{settingName}"]; + if (string.IsNullOrWhiteSpace(settingValue)) + { + throw new ArgumentNullException($"{settingValue} string is not configured"); + } + + return settingValue; + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/DataHelper.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/DataHelper.cs new file mode 100644 index 000000000000..476142430d6a --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/DataHelper.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Numerics.Tensors; +using Microsoft.SemanticKernel.Memory; + +namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBNoSQL; + +internal static class DataHelper +{ + public static MemoryRecord[] VectorSearchExpectedResults { get; } + public static MemoryRecord[] VectorSearchTestRecords { get; } + public static float[] VectorSearchTestEmbedding { get; } + + static DataHelper() + { + VectorSearchTestRecords = CreateBatchRecords(8); + VectorSearchTestEmbedding = new[] { 1, 0.699f, 0.701f }; + VectorSearchExpectedResults = VectorSearchTestRecords + .OrderByDescending(r => TensorPrimitives.CosineSimilarity(r.Embedding.Span, VectorSearchTestEmbedding)) + .ToArray(); + } + + public static MemoryRecord[] CreateBatchRecords(int count) => + Enumerable + .Range(0, count) + .Select(i => MemoryRecord.LocalRecord( + id: $"test_{i}", + text: $"text_{i}", + description: $"description_{i}", + embedding: new[] { 1, (float)Math.Cos(Math.PI * i / count), (float)Math.Sin(Math.PI * i / count) }, + key: $"test_{i}", + timestamp: DateTimeOffset.Now)) + .ToArray(); +} diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index ac04125bc9fa..8f6e3a652d43 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -53,16 +53,17 @@ - + + + -