Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure Cognitive Search connector for Semantic Memory #747

Merged
merged 2 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<ManagePackageVersionsCentrally>true</ManagePackageVersionsCentrally>
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="Azure.Search.Documents" Version="11.5.0-beta.2" />
<PackageVersion Include="Microsoft.Bcl.HashCode" Version="[1.1.0, )" />
<PackageVersion Include="System.Linq.Async" Version="[6.0.1, )" />
<PackageVersion Include="System.Text.Json" Version="[6.0.0, )" />
Expand Down
7 changes: 7 additions & 0 deletions dotnet/SK-dotnet.sln
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Extensions.UnitTests", "src
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Planning.SequentialPlanner", "src\Extensions\Planning.SequentialPlanner\Planning.SequentialPlanner.csproj", "{A350933D-F9D5-4AD3-8C4F-B856B5020297}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Memory.AzureCognitiveSearch", "src\Connectors\Connectors.Memory.AzureCognitiveSearch\Connectors.Memory.AzureCognitiveSearch.csproj", "{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -197,6 +199,10 @@ Global
{A350933D-F9D5-4AD3-8C4F-B856B5020297}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A350933D-F9D5-4AD3-8C4F-B856B5020297}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A350933D-F9D5-4AD3-8C4F-B856B5020297}.Release|Any CPU.Build.0 = Release|Any CPU
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -233,6 +239,7 @@ Global
{994BEF0B-E277-4D10-BB13-FE670D26620D} = {078F96B4-09E1-4E0E-B214-F71A4F4BF633}
{F51017A9-15C8-472D-893C-080046D710A6} = {078F96B4-09E1-4E0E-B214-F71A4F4BF633}
{A350933D-F9D5-4AD3-8C4F-B856B5020297} = {078F96B4-09E1-4E0E-B214-F71A4F4BF633}
{EC3BB6D1-2FB2-4702-84C6-F791DE533ED4} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83}
Expand Down
4 changes: 4 additions & 0 deletions dotnet/SK-dotnet.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
<s:Int64 x:Key="/Default/CodeStyle/CodeFormatting/CSharpFormat/WRAP_LIMIT/@EntryValue">160</s:Int64>
<s:Boolean x:Key="/Default/CodeStyle/CodeFormatting/CSharpFormat/WRAP_LINES/@EntryValue">True</s:Boolean>
<s:String x:Key="/Default/CodeStyle/FileHeader/FileHeaderText/@EntryValue">Copyright (c) Microsoft. All rights reserved.</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=ACS/@EntryIndexedValue">ACS</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=AI/@EntryIndexedValue">AI</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=AIGPT/@EntryIndexedValue">AIGPT</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=AMQP/@EntryIndexedValue">AMQP</s:String>
Expand Down Expand Up @@ -193,8 +194,11 @@ public void It$SOMENAME$()
<s:Boolean x:Key="/Default/UserDictionary/Words/=mergeresults/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=myfile/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Notegen/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=pgvector/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Pinecone/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Qdrant/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Roundtrips/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Reranker/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=sandboxing/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=SK/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=SKHTTP/@EntryIndexedValue">True</s:Boolean>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,333 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Azure;
using Azure.Core;
using Azure.Search.Documents;
using Azure.Search.Documents.Indexes;
using Azure.Search.Documents.Indexes.Models;
using Azure.Search.Documents.Models;
using Microsoft.SemanticKernel.Memory;

namespace Microsoft.SemanticKernel.Connectors.Memory.AzureCognitiveSearch;

/// <summary>
/// Semantic Memory implementation using Azure Cognitive Search.
/// For more information about Azure Cognitive Search see https://learn.microsoft.com/azure/search/search-what-is-azure-search
/// </summary>
public class AzureCognitiveSearchMemory : ISemanticTextMemory
dluc marked this conversation as resolved.
Show resolved Hide resolved
{
/// <summary>
/// Client used for index-level operations (lookup, creation, listing) and to obtain per-index search clients.
/// </summary>
private readonly SearchIndexClient _adminClient;

/// <summary>
/// Search clients obtained so far, keyed by index name.
/// </summary>
private readonly ConcurrentDictionary<string, SearchClient> _clientsByIndex = new ConcurrentDictionary<string, SearchClient>();

/// <summary>
/// Create a semantic memory instance backed by Azure Cognitive Search, authenticating with an API key.
/// </summary>
/// <param name="endpoint">Azure Cognitive Search URI, e.g. "https://contoso.search.windows.net"</param>
/// <param name="apiKey">API Key</param>
public AzureCognitiveSearchMemory(string endpoint, string apiKey)
{
    // The credential is built before the URI, so an invalid key is reported first.
    var keyCredential = new AzureKeyCredential(apiKey);
    var serviceUri = new Uri(endpoint);

    this._adminClient = new SearchIndexClient(serviceUri, keyCredential);
}

/// <summary>
/// Create a semantic memory instance backed by Azure Cognitive Search, authenticating with a token credential.
/// </summary>
/// <param name="endpoint">Azure Cognitive Search URI, e.g. "https://contoso.search.windows.net"</param>
/// <param name="credentials">Token credential used to authenticate against the service</param>
public AzureCognitiveSearchMemory(string endpoint, TokenCredential credentials)
{
    var serviceUri = new Uri(endpoint);

    this._adminClient = new SearchIndexClient(serviceUri, credentials);
}

/// <inheritdoc />
public Task<string> SaveInformationAsync(
    string collection,
    string text,
    string id,
    string? description = null,
    string? additionalMetadata = null,
    CancellationToken cancellationToken = default)
{
    // The collection name doubles as the index name, so it is normalized first.
    collection = NormalizeIndexName(collection);

    AzureCognitiveSearchRecord record = new()
    {
        // The id is encoded before being used as the document key.
        Id = EncodeId(id),
        Text = text,
        Description = description,
        AdditionalMetadata = additionalMetadata,
        IsReference = false,
    };

    return this.UpsertRecordAsync(collection, record, cancellationToken);
}

/// <inheritdoc />
public Task<string> SaveReferenceAsync(
    string collection,
    string text,
    string externalId,
    string externalSourceName,
    string? description = null,
    string? additionalMetadata = null,
    CancellationToken cancellationToken = default)
{
    string indexName = NormalizeIndexName(collection);

    // Same shape as a regular record, flagged as a reference and tagged with its external source.
    var reference = new AzureCognitiveSearchRecord
    {
        Id = EncodeId(externalId),
        Text = text,
        Description = description,
        AdditionalMetadata = additionalMetadata,
        ExternalSourceName = externalSourceName,
        IsReference = true,
    };

    return this.UpsertRecordAsync(indexName, reference, cancellationToken);
}

/// <inheritdoc />
/// <remarks>
/// Returns null when the service reports that the document does not exist (HTTP 404).
/// The <paramref name="withEmbedding"/> flag is not used by this implementation.
/// </remarks>
/// <exception cref="AzureCognitiveSearchMemoryException">The service answered without a document.</exception>
public async Task<MemoryQueryResult?> GetAsync(
    string collection,
    string key,
    bool withEmbedding = false,
    CancellationToken cancellationToken = default)
{
    collection = NormalizeIndexName(collection);

    var client = await this.GetSearchClientAsync(collection, cancellationToken).ConfigureAwait(false);

    Response<AzureCognitiveSearchRecord>? result;
    try
    {
        result = await client.GetDocumentAsync<AzureCognitiveSearchRecord>(
            EncodeId(key), cancellationToken: cancellationToken).ConfigureAwait(false);
    }
    catch (RequestFailedException e) when (e.Status == 404)
    {
        // Unknown key: the nullable return type is how "not found" is reported to the caller.
        return null;
    }

    if (result?.Value is null)
    {
        throw new AzureCognitiveSearchMemoryException("Memory read returned null");
    }

    // A direct lookup by key is reported with relevance 1 and no embedding.
    return new MemoryQueryResult(ToMemoryRecordMetadata(result.Value), 1, null);
}

/// <inheritdoc />
/// <remarks>
/// Runs a semantic query against the index using the "default" semantic configuration.
/// The <paramref name="withEmbeddings"/> flag is not used by this implementation.
/// </remarks>
public async IAsyncEnumerable<MemoryQueryResult> SearchAsync(
    string collection,
    string query,
    int limit = 1,
    double minRelevanceScore = 0.7,
    bool withEmbeddings = false,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    collection = NormalizeIndexName(collection);

    var client = await this.GetSearchClientAsync(collection, cancellationToken).ConfigureAwait(false);

    var options = new SearchOptions
    {
        QueryType = SearchQueryType.Semantic,
        SemanticConfigurationName = "default",
        QueryLanguage = "en-us", // TODO: this shouldn't be required
        Size = limit,
    };

    Response<SearchResults<AzureCognitiveSearchRecord>>? searchResult = await client
        .SearchAsync<AzureCognitiveSearchRecord>(query, options, cancellationToken: cancellationToken)
        .ConfigureAwait(false);

    if (searchResult?.Value is null)
    {
        yield break;
    }

    // Flow the token into the enumeration as well, and stay off the captured context
    // like every other await in this class.
    await foreach (SearchResult<AzureCognitiveSearchRecord>? doc in searchResult.Value
                       .GetResultsAsync()
                       .WithCancellation(cancellationToken)
                       .ConfigureAwait(false))
    {
        // Stops at the first result below the threshold; presumably results arrive ordered by
        // reranker score - TODO confirm. A null score never triggers the break.
        // NOTE(review): the reranker score scale may differ from the 0..1 range suggested by the
        // default threshold - verify against the service documentation.
        if (doc.RerankerScore < minRelevanceScore) { break; }

        yield return new MemoryQueryResult(ToMemoryRecordMetadata(doc.Document), doc.RerankerScore ?? 1, null);
    }
}

/// <inheritdoc />
public async Task RemoveAsync(string collection, string key, CancellationToken cancellationToken = default)
{
    string indexName = NormalizeIndexName(collection);

    // Only the document key is needed to delete: the record carries the encoded id and nothing else.
    var keyOnlyRecord = new AzureCognitiveSearchRecord { Id = EncodeId(key) };
    var toDelete = new List<AzureCognitiveSearchRecord>();
    toDelete.Add(keyOnlyRecord);

    SearchClient searchClient = await this.GetSearchClientAsync(indexName, cancellationToken).ConfigureAwait(false);

    await searchClient.DeleteDocumentsAsync(toDelete, cancellationToken: cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
public async Task<IList<string>> GetCollectionsAsync(CancellationToken cancellationToken = default)
{
    var names = new List<string>();

    // Each index in the service is reported as one collection, by name.
    await foreach (SearchIndex searchIndex in this._adminClient.GetIndexesAsync(cancellationToken).ConfigureAwait(false))
    {
        names.Add(searchIndex.Name);
    }

    return names;
}

#region private ================================================================================

/// <summary>
/// Index names cannot contain special chars. We use this rule to replace a few common ones
/// with a dash and reduce the chance of errors. If other special chars are used, we leave it
/// to the service to throw an error.
/// The rule matches whitespace, backslash, slash, dot, underscore, colon and - because '|' is a
/// literal inside a character class, not an alternation - the pipe symbol.
/// Note:
/// - replacing chars introduces a small chance of conflicts, e.g. "the-user" and "the_user".
/// - we should consider making this optional, leaving it to the developer to handle.
/// </summary>
private static readonly Regex s_replaceIndexNameSymbolsRegex = new Regex(@"[\s|\\|/|.|_|:]");

/// <summary>
/// Return a search client bound to the given index, creating the index first when the service
/// reports it as missing. The existence check is performed on every call; only the client
/// instance itself is cached.
/// </summary>
/// <param name="indexName">Index name</param>
/// <param name="cancellationToken">Task cancellation token</param>
/// <returns>Search client ready to read/write</returns>
private async Task<SearchClient> GetSearchClientAsync(
    string indexName,
    CancellationToken cancellationToken = default)
{
    bool indexExists;
    try
    {
        Response<SearchIndex>? lookup = await this._adminClient.GetIndexAsync(indexName, cancellationToken).ConfigureAwait(false);
        indexExists = lookup?.Value != null;
    }
    catch (RequestFailedException e) when (e.Status == 404)
    {
        // A 404 only means the index is not there yet; it gets created below.
        indexExists = false;
    }

    if (!indexExists)
    {
        await this.CreateIndexAsync(indexName, cancellationToken).ConfigureAwait(false);
    }

    // Reuse a previously obtained client when there is one.
    if (this._clientsByIndex.TryGetValue(indexName, out SearchClient cached))
    {
        return cached;
    }

    SearchClient fresh = this._adminClient.GetSearchClient(indexName);
    this._clientsByIndex[indexName] = fresh;

    return fresh;
}

/// <summary>
/// Create a search index whose fields are derived from <see cref="AzureCognitiveSearchRecord"/>
/// and which carries a semantic configuration named "default".
/// </summary>
/// <param name="indexName">Index name</param>
/// <param name="cancellationToken">Task cancellation token</param>
private Task CreateIndexAsync(
    string indexName,
    CancellationToken cancellationToken = default)
{
    var schema = new FieldBuilder().Build(typeof(AzureCognitiveSearchRecord));

    // Semantic ranking: "Description" acts as the title, "Text" and "AdditionalMetadata" as content.
    var prioritizedFields = new PrioritizedFields
    {
        TitleField = new SemanticField { FieldName = "Description" },
    };
    prioritizedFields.ContentFields.Add(new SemanticField { FieldName = "Text" });
    prioritizedFields.ContentFields.Add(new SemanticField { FieldName = "AdditionalMetadata" });

    var semanticSettings = new SemanticSettings();
    semanticSettings.Configurations.Add(new SemanticConfiguration("default", prioritizedFields));

    var indexDefinition = new SearchIndex(indexName, schema)
    {
        SemanticSettings = semanticSettings,
    };

    return this._adminClient.CreateIndexAsync(indexDefinition, cancellationToken);
}

/// <summary>
/// Write a single record to the given index, merging with an existing document that has the same
/// key or uploading a new one.
/// </summary>
/// <param name="indexName">Index name</param>
/// <param name="record">Record to store</param>
/// <param name="cancellationToken">Task cancellation token</param>
/// <returns>Key of the stored document, as reported by the service</returns>
/// <exception cref="AzureCognitiveSearchMemoryException">The service returned no indexing result.</exception>
private async Task<string> UpsertRecordAsync(
    string indexName,
    AzureCognitiveSearchRecord record,
    CancellationToken cancellationToken = default)
{
    var client = await this.GetSearchClientAsync(indexName, cancellationToken).ConfigureAwait(false);

    Response<IndexDocumentsResult>? result = await client.MergeOrUploadDocumentsAsync(
        new List<AzureCognitiveSearchRecord> { record },
        new IndexDocumentsOptions { ThrowOnAnyError = true },
        cancellationToken).ConfigureAwait(false);

    // Null-safe all the way down: a response without a value is treated like an empty result
    // instead of surfacing as a NullReferenceException.
    var outcomes = result?.Value?.Results;
    if (outcomes is null || outcomes.Count == 0)
    {
        throw new AzureCognitiveSearchMemoryException("Memory write returned null or an empty set");
    }

    return outcomes[0].Key;
}

/// <summary>
/// Convert a stored record into memory metadata, decoding the document key back into the
/// original id and substituting empty strings for missing text, description and additional metadata.
/// </summary>
/// <param name="data">Record read from the index</param>
/// <returns>Metadata describing the record</returns>
private static MemoryRecordMetadata ToMemoryRecordMetadata(AzureCognitiveSearchRecord data)
{
    string originalId = DecodeId(data.Id);
    string text = data.Text ?? string.Empty;
    string description = data.Description ?? string.Empty;
    string additionalMetadata = data.AdditionalMetadata ?? string.Empty;

    return new MemoryRecordMetadata(
        isReference: data.IsReference,
        id: originalId,
        text: text,
        description: description,
        externalSourceName: data.ExternalSourceName,
        additionalMetadata: additionalMetadata);
}

/// <summary>
/// Normalize index name to match ACS rules: surrounding whitespace is removed, the name is
/// lowercased and a few common special chars are replaced with a dash.
/// The method doesn't handle all the error scenarios, leaving it to the service
/// to throw an error for edge cases not handled locally.
/// </summary>
/// <param name="indexName">Value to normalize</param>
/// <returns>Normalized name</returns>
/// <exception cref="AzureCognitiveSearchMemoryException">The trimmed name is longer than 128 chars.</exception>
private static string NormalizeIndexName(string indexName)
{
    // Trim before measuring: surrounding whitespace is discarded anyway and must not
    // count towards the length limit.
    indexName = indexName.Trim();

    if (indexName.Length > 128)
    {
        throw new AzureCognitiveSearchMemoryException("The collection name is too long, it cannot exceed 128 chars");
    }

#pragma warning disable CA1308 // The service expects a lowercase string
    indexName = indexName.ToLowerInvariant();
#pragma warning restore CA1308

    return s_replaceIndexNameSymbolsRegex.Replace(indexName, "-");
}

/// <summary>
/// ACS keys can contain only letters, digits, underscore, dash, equal sign, recommending
/// to encode values with a URL-safe algorithm.
/// The id is encoded as UTF-8 Base64 using the URL-safe alphabet: the '+' and '/' symbols that
/// standard Base64 can produce are emitted as '-' and '_'. Padding ('=') is kept.
/// </summary>
/// <param name="realId">Original Id</param>
/// <returns>Encoded id</returns>
private static string EncodeId(string realId)
{
    var bytes = Encoding.UTF8.GetBytes(realId);
    return Convert.ToBase64String(bytes).Replace('+', '-').Replace('/', '_');
}

/// <summary>
/// Reverse <see cref="EncodeId"/>. Both the URL-safe and the standard Base64 alphabet are
/// accepted, so ids encoded with plain Base64 still decode.
/// </summary>
/// <param name="encodedId">Encoded id</param>
/// <returns>Original Id</returns>
private static string DecodeId(string encodedId)
{
    // Map the URL-safe symbols back to the standard alphabet; a no-op for plain Base64 input.
    var bytes = Convert.FromBase64String(encodedId.Replace('-', '+').Replace('_', '/'));
    return Encoding.UTF8.GetString(bytes);
}

#endregion
}