From 5cf5f8799cd510a40d1e674c4751bd941ea4a1a2 Mon Sep 17 00:00:00 2001 From: Kevin Pilch Date: Sun, 19 May 2024 14:10:16 -0700 Subject: [PATCH] Use QueryDefinition to avoid injection --- .../AzureCosmosDBNoSQLMemoryStore.cs | 94 ++++++------------- 1 file changed, 31 insertions(+), 63 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs index d3df09f8350c..70d6210fc355 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLMemoryStore.cs @@ -14,10 +14,6 @@ using Microsoft.SemanticKernel.Http; using Microsoft.SemanticKernel.Memory; -#if NET6_0_OR_GREATER -using System.Globalization; -#endif - namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; /// @@ -131,10 +127,12 @@ public class AzureCosmosDBNoSQLMemoryStore : IMemoryStore, IDisposable string collectionName, CancellationToken cancellationToken = default) { + var queryDefinition = new QueryDefinition("SELECT VALUE(c.id) FROM c WHERE c.id = @collectionName"); + queryDefinition.WithParameter("@collectionName", collectionName); using var feedIterator = this. _cosmosClient .GetDatabase(this._databaseName) - .GetContainerQueryIterator($"SELECT VALUE(c.id) FROM c WHERE c.id = '{collectionName}'"); + .GetContainerQueryIterator(queryDefinition); while (feedIterator.HasMoreResults) { @@ -212,20 +210,36 @@ await this._cosmosClient [EnumeratorCancellation] CancellationToken cancellationToken = default) { const string OR = " OR "; - - // Optimistically create the entire query string. - var whereClause = string.Join(OR, keys.Select(k => $"(x.id = \"{k}\" AND x.key = \"{k}\")")); - var queryDefinition = new QueryDefinition($""" + var queryStart = $""" SELECT x.id,x.key,x.metadata,x.timestamp{(withEmbeddings ? ",x.embedding" : "")} FROM x - WHERE {whereClause} - """); - - // NOTE: Cosmos DB queries are limited to 512kB, so if this is larger than that, break it into segments. - var byteCount = Encoding.UTF8.GetByteCount(whereClause); - var ratio = byteCount / ((float)(512 * 1024)); - if (ratio < 1) + WHERE + """; + // NOTE: Cosmos DB queries are limited to 512kB, so we'll break this into chunks + // of around 500kB. We don't go all the way to 512kB so that we don't have to + // remove the last clause we added once we go over. + int keyIndex = 0; + var keyList = keys.ToList(); + while (keyIndex < keyList.Count) { + var length = queryStart.Length; + var countThisBatch = 0; + var whereClauses = new StringBuilder(); + for (int i = keyIndex; i < keyList.Count && length <= 500 * 1024; i++, countThisBatch++) + { + string keyId = $"@key{i:D}"; + var clause = $"(x.id = {keyId} AND x.key = {keyId})"; + whereClauses.Append(clause).Append(OR); + length += clause.Length + OR.Length + 4 + keyId.Length + Encoding.UTF8.GetByteCount(keyList[keyIndex]); + } + whereClauses.Length -= OR.Length; + + var queryDefinition = new QueryDefinition(queryStart + whereClauses); + for (int i = keyIndex; i < keyIndex + countThisBatch; i++) + { + queryDefinition.WithParameter($"@key{i:D}", keyList[i]); + } + var feedIterator = this._cosmosClient .GetDatabase(this._databaseName) .GetContainer(collectionName) @@ -238,54 +252,8 @@ FROM x yield return memoryRecord; } } - } - else - { - // We're in the very large case, we'll need to split this into multiple queries. - // We add one to catch any fractional piece left in the last segment - var segments = (int)(ratio + 1); - var keyList = keys.ToList(); - var keysPerQuery = keyList.Count / segments; - // Make a guess as to how long this query will be. We need at least 26 chars for each "OR" block, so - // put a few extra for the values of the keys. - var estimatedWhereLength = 30 * keysPerQuery; - var localWhere = new StringBuilder(estimatedWhereLength); - for (var i = 0; i < segments; i++) - { - localWhere.Clear(); - for (var q = i * keysPerQuery; q < (i + 1) * keysPerQuery && q < keyList.Count; q++) - { - var k = keyList[q]; -#if NET6_0_OR_GREATER - localWhere.Append(CultureInfo.InvariantCulture, $"(x.id = \"{k}\" AND x.key = \"{k}\")").Append(OR); -#else - localWhere.Append($"(x.id = \"{k}\" AND x.key = \"{k}\")").Append(OR); -#endif - } - if (localWhere.Length >= OR.Length) - { - localWhere.Length -= OR.Length; - - var localQueryDefinition = new QueryDefinition($""" - SELECT x.id,x.key,x.metadata,x.timestamp{(withEmbeddings ? ",x.embedding" : "")} - FROM x - WHERE {localWhere} - """); - var feedIterator = this._cosmosClient - .GetDatabase(this._databaseName) - .GetContainer(collectionName) - .GetItemQueryIterator(localQueryDefinition); - - while (feedIterator.HasMoreResults) - { - foreach (var memoryRecord in await feedIterator.ReadNextAsync(cancellationToken).ConfigureAwait(false)) - { - yield return memoryRecord; - } - } - } - } + keyIndex += countThisBatch; } }