Skip to content

Commit

Permalink
Azure Cognitive Search connector for Semantic Memory
Browse files Browse the repository at this point in the history
  • Loading branch information
dluc committed May 1, 2023
1 parent 826adf1 commit f7fac42
Show file tree
Hide file tree
Showing 12 changed files with 606 additions and 138 deletions.
1 change: 1 addition & 0 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<ManagePackageVersionsCentrally>true</ManagePackageVersionsCentrally>
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="Azure.Search.Documents" Version="11.5.0-beta.2" />
<PackageVersion Include="Microsoft.Bcl.HashCode" Version="[1.1.0, )" />
<PackageVersion Include="System.Linq.Async" Version="[6.0.1, )" />
<PackageVersion Include="System.Text.Json" Version="[6.0.0, )" />
Expand Down
7 changes: 7 additions & 0 deletions dotnet/SK-dotnet.sln
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Extensions.UnitTests", "src
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Planning.SequentialPlanner", "src\Extensions\Planning.SequentialPlanner\Planning.SequentialPlanner.csproj", "{A350933D-F9D5-4AD3-8C4F-B856B5020297}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Memory.AzureCognitiveSearch", "src\Connectors\Connectors.Memory.AzureCognitiveSearch\Connectors.Memory.AzureCognitiveSearch.csproj", "{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -197,6 +199,10 @@ Global
{A350933D-F9D5-4AD3-8C4F-B856B5020297}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A350933D-F9D5-4AD3-8C4F-B856B5020297}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A350933D-F9D5-4AD3-8C4F-B856B5020297}.Release|Any CPU.Build.0 = Release|Any CPU
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -233,6 +239,7 @@ Global
{994BEF0B-E277-4D10-BB13-FE670D26620D} = {078F96B4-09E1-4E0E-B214-F71A4F4BF633}
{F51017A9-15C8-472D-893C-080046D710A6} = {078F96B4-09E1-4E0E-B214-F71A4F4BF633}
{A350933D-F9D5-4AD3-8C4F-B856B5020297} = {078F96B4-09E1-4E0E-B214-F71A4F4BF633}
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83}
Expand Down
4 changes: 4 additions & 0 deletions dotnet/SK-dotnet.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
<s:Int64 x:Key="/Default/CodeStyle/CodeFormatting/CSharpFormat/WRAP_LIMIT/@EntryValue">160</s:Int64>
<s:Boolean x:Key="/Default/CodeStyle/CodeFormatting/CSharpFormat/WRAP_LINES/@EntryValue">True</s:Boolean>
<s:String x:Key="/Default/CodeStyle/FileHeader/FileHeaderText/@EntryValue">Copyright (c) Microsoft. All rights reserved.</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=ACS/@EntryIndexedValue">ACS</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=AI/@EntryIndexedValue">AI</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=AIGPT/@EntryIndexedValue">AIGPT</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=AMQP/@EntryIndexedValue">AMQP</s:String>
Expand Down Expand Up @@ -193,8 +194,11 @@ public void It$SOMENAME$()
<s:Boolean x:Key="/Default/UserDictionary/Words/=mergeresults/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=myfile/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Notegen/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=pgvector/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Pinecone/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Qdrant/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Roundtrips/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Reranker/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=sandboxing/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=SK/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=SKHTTP/@EntryIndexedValue">True</s:Boolean>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Azure;
using Azure.Core;
using Azure.Search.Documents;
using Azure.Search.Documents.Indexes;
using Azure.Search.Documents.Indexes.Models;
using Azure.Search.Documents.Models;
using Microsoft.SemanticKernel.Memory;

namespace Microsoft.SemanticKernel.Connectors.Memory.AzureCognitiveSearch;

/// <summary>
/// Semantic Memory implementation using Azure Cognitive Search.
/// For more information about Azure Cognitive Search see https://learn.microsoft.com/azure/search/search-what-is-azure-search
/// </summary>
public class AzureCognitiveSearchMemory : ISemanticTextMemory
{
private readonly SearchIndexClient _adminClient;

private readonly ConcurrentDictionary<string, SearchClient> _clientsByIndex = new();

/// <summary>
/// Create a new instance of semantic memory using Azure Cognitive Search.
/// </summary>
/// <param name="endpoint">Azure Cognitive Search URI, e.g. "https://contoso.search.windows.net"</param>
/// <param name="apiKey">API Key</param>
public AzureCognitiveSearchMemory(string endpoint, string apiKey)
{
AzureKeyCredential credentials = new(apiKey);
this._adminClient = new SearchIndexClient(new Uri(endpoint), credentials);
}

/// <summary>
/// Create a new instance of semantic memory using Azure Cognitive Search.
/// </summary>
/// <param name="endpoint">Azure Cognitive Search URI, e.g. "https://contoso.search.windows.net"</param>
/// <param name="credentials">Azure service</param>
public AzureCognitiveSearchMemory(string endpoint, TokenCredential credentials)
{
this._adminClient = new SearchIndexClient(new Uri(endpoint), credentials);
}

/// <inheritdoc />
public Task<string> SaveInformationAsync(
string collection,
string text,
string id,
string? description = null,
string? additionalMetadata = null,
CancellationToken cancellationToken = default)
{
collection = NormalizeIndexName(collection);

AzureCognitiveSearchRecord record = new()
{
Id = EncodeId(id),
Text = text,
Description = description,
AdditionalMetadata = additionalMetadata,
IsReference = false,
};

return this.UpsertRecordAsync(collection, record, cancellationToken);
}

/// <inheritdoc />
public Task<string> SaveReferenceAsync(
string collection,
string text,
string externalId,
string externalSourceName,
string? description = null,
string? additionalMetadata = null,
CancellationToken cancellationToken = default)
{
collection = NormalizeIndexName(collection);

AzureCognitiveSearchRecord record = new()
{
Id = EncodeId(externalId),
Text = text,
Description = description,
AdditionalMetadata = additionalMetadata,
ExternalSourceName = externalSourceName,
IsReference = true,
};

return this.UpsertRecordAsync(collection, record, cancellationToken);
}

/// <inheritdoc />
public async Task<MemoryQueryResult?> GetAsync(
string collection,
string key,
bool withEmbedding = false,
CancellationToken cancellationToken = default)
{
collection = NormalizeIndexName(collection);

var client = await this.GetSearchClientAsync(collection, cancellationToken).ConfigureAwait(false);

Response<AzureCognitiveSearchRecord>? result = await client.GetDocumentAsync<AzureCognitiveSearchRecord>(
EncodeId(key), cancellationToken: cancellationToken).ConfigureAwait(false);

if (result == null || result.Value == null)
{
throw new AzureCognitiveSearchMemoryException("Memory read returned null");
}

return new MemoryQueryResult(ToMemoryRecordMetadata(result.Value), 1, null);
}

/// <inheritdoc />
public async IAsyncEnumerable<MemoryQueryResult> SearchAsync(
string collection,
string query,
int limit = 1,
double minRelevanceScore = 0.7,
bool withEmbeddings = false,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
collection = NormalizeIndexName(collection);

var client = await this.GetSearchClientAsync(collection, cancellationToken).ConfigureAwait(false);

var options = new SearchOptions
{
QueryType = SearchQueryType.Semantic,
SemanticConfigurationName = "default",
QueryLanguage = "en-us", // TODO: this shouldn't be required
Size = limit,
};

Response<SearchResults<AzureCognitiveSearchRecord>>? searchResult = await client
.SearchAsync<AzureCognitiveSearchRecord>(query, options, cancellationToken: cancellationToken)
.ConfigureAwait(false);

await foreach (SearchResult<AzureCognitiveSearchRecord>? doc in searchResult.Value.GetResultsAsync())
{
if (doc.RerankerScore < minRelevanceScore) { break; }

yield return new MemoryQueryResult(ToMemoryRecordMetadata(doc.Document), doc.RerankerScore ?? 1, null);
}
}

/// <inheritdoc />
public async Task RemoveAsync(string collection, string key, CancellationToken cancellationToken = default)
{
collection = NormalizeIndexName(collection);

var records = new List<AzureCognitiveSearchRecord> { new() { Id = EncodeId(key) } };

var client = await this.GetSearchClientAsync(collection, cancellationToken).ConfigureAwait(false);

await client.DeleteDocumentsAsync(records, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
public async Task<IList<string>> GetCollectionsAsync(CancellationToken cancellationToken = default)
{
ConfiguredCancelableAsyncEnumerable<SearchIndex> indexes = this._adminClient.GetIndexesAsync(cancellationToken).ConfigureAwait(false);

var result = new List<string>();
await foreach (var index in indexes)
{
result.Add(index.Name);
}

return result;
}

#region private ================================================================================

/// <summary>
/// Index names cannot contain special chars. We use this rule to replace a few common ones
/// with an underscore and reduce the chance of errors. If other special chars are used, we leave it
/// to the service to throw an error.
/// Note:
/// - replacing chars introduces a small chance of conflicts, e.g. "the-user" and "the_user".
/// - we should consider whether making this optional and leave it to the developer to handle.
/// </summary>
private static readonly Regex s_replaceIndexNameSymbolsRegex = new(@"[\s|\\|/|.|_|:]");

/// <summary>
/// Get a search client for the index specified.
/// </summary>
/// <param name="indexName">Index name</param>
/// <param name="cancellationToken">Task cancellation token</param>
/// <returns>Search client ready to read/write</returns>
private async Task<SearchClient> GetSearchClientAsync(
string indexName,
CancellationToken cancellationToken = default)
{
Response<SearchIndex>? existingIndex = null;
try
{
// Search the index
existingIndex = await this._adminClient.GetIndexAsync(indexName, cancellationToken).ConfigureAwait(false);
}
catch (RequestFailedException e) when (e.Status == 404)
{
}

// Create the index if it doesn't exist
if (existingIndex == null || existingIndex.Value == null)
{
await this.CreateIndexAsync(indexName, cancellationToken).ConfigureAwait(false);
}

// Search an available client from the local cache
if (!this._clientsByIndex.TryGetValue(indexName, out SearchClient? client) || client == null)
{
client = this._adminClient.GetSearchClient(indexName);
this._clientsByIndex[indexName] = client;
}

return client;
}

/// <summary>
/// Create a new search index.
/// </summary>
/// <param name="indexName">Index name</param>
/// <param name="cancellationToken">Task cancellation token</param>
private Task CreateIndexAsync(
string indexName,
CancellationToken cancellationToken = default)
{
var fieldBuilder = new FieldBuilder();
var fields = fieldBuilder.Build(typeof(AzureCognitiveSearchRecord));
var newIndex = new SearchIndex(indexName, fields)
{
SemanticSettings = new SemanticSettings
{
Configurations =
{
new SemanticConfiguration("default", new PrioritizedFields
{
TitleField = new SemanticField { FieldName = "Description" },
ContentFields =
{
new SemanticField { FieldName = "Text" },
new SemanticField { FieldName = "AdditionalMetadata" },
}
})
}
}
};

return this._adminClient.CreateIndexAsync(newIndex, cancellationToken);
}

private async Task<string> UpsertRecordAsync(
string indexName,
AzureCognitiveSearchRecord record,
CancellationToken cancellationToken = default)
{
var client = await this.GetSearchClientAsync(indexName, cancellationToken).ConfigureAwait(false);

Response<IndexDocumentsResult>? result = await client.MergeOrUploadDocumentsAsync(
new List<AzureCognitiveSearchRecord> { record },
new IndexDocumentsOptions { ThrowOnAnyError = true },
cancellationToken).ConfigureAwait(false);

if (result == null || result.Value.Results.Count == 0)
{
throw new AzureCognitiveSearchMemoryException("Memory write returned null or an empty set");
}

return result.Value.Results[0].Key;
}

private static MemoryRecordMetadata ToMemoryRecordMetadata(AzureCognitiveSearchRecord data)
{
return new MemoryRecordMetadata(
isReference: data.IsReference,
id: DecodeId(data.Id),
text: data.Text ?? string.Empty,
description: data.Description ?? string.Empty,
externalSourceName: data.ExternalSourceName,
additionalMetadata: data.AdditionalMetadata ?? string.Empty);
}

/// <summary>
/// Normalize index name to match ACS rules.
/// The method doesn't handle all the error scenarios, leaving it to the service
/// to throw an error for edge cases not handled locally.
/// </summary>
/// <param name="indexName">Value to normalize</param>
/// <returns>Normalized name</returns>
private static string NormalizeIndexName(string indexName)
{
if (indexName.Length > 128)
{
throw new AzureCognitiveSearchMemoryException("The collection name is too long, it cannot exceed 128 chars");
}

#pragma warning disable CA1308 // The service expects a lowercase string
indexName = indexName.ToLowerInvariant();
#pragma warning restore CA1308

return s_replaceIndexNameSymbolsRegex.Replace(indexName.Trim(), "-");
}

/// <summary>
/// ACS keys can contain only letters, digits, underscore, dash, equal sign, recommending
/// to encode values with a URL-safe algorithm.
/// </summary>
/// <param name="realId">Original Id</param>
/// <returns>Encoded id</returns>
private static string EncodeId(string realId)
{
var bytes = Encoding.UTF8.GetBytes(realId);
return Convert.ToBase64String(bytes);
}

private static string DecodeId(string encodedId)
{
var bytes = Convert.FromBase64String(encodedId);
return Encoding.UTF8.GetString(bytes);
}

#endregion
}

0 comments on commit f7fac42

Please sign in to comment.