Skip to content

Commit

Permalink
Add Azure AI Search hybrid search support (#428)
Browse files Browse the repository at this point in the history
## Motivation and Context (Why the change? What's the scenario?)
Already described in issue #159 
The main idea is to support Azure AI search Hybrid search.

## High level description (Approach, Design)

The idea is to have a new Config property in the AzureAISearchConfig class, so Hybrid is only enabled explicitly.
When enabled, the CosineSimilarity is not calculated and the minDistance is set to the minRelevance parameter (passed from the top SearchAsync method).

---------

Co-authored-by: “luismanez” <“luis.manez@outlook.com”>
Co-authored-by: Devis Lucato <devis@microsoft.com>
  • Loading branch information
3 people committed Apr 25, 2024
1 parent 713cbb2 commit c631a64
Show file tree
Hide file tree
Showing 11 changed files with 247 additions and 18 deletions.
7 changes: 6 additions & 1 deletion KernelMemory.sln
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{C2D3
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Service.AspNetCore", "service\Service.AspNetCore\Service.AspNetCore.csproj", "{A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "111-dotnet-azure-ai-hybrid-search", "examples\111-dotnet-azure-ai-hybrid-search\111-dotnet-azure-ai-hybrid-search.csproj", "{28534545-CB39-446A-9EB9-A5ABBFE0CFD3}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -332,6 +334,7 @@ Global
{8FB12876-013D-44CB-9F0D-E926D9F0F4E3} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841}
{C2D3A947-B6F9-4306-BD42-21D8D1F42750} = {B488168B-AD86-4CC5-9D89-324B6EB743D9}
{A46B0BE1-03F2-4520-A3DA-FD845BA1FD69} = {87DEAE8D-138C-4FDD-B4C9-11C3A7817E8F}
{28534545-CB39-446A-9EB9-A5ABBFE0CFD3} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841}
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{8A9FA587-7EBA-4D43-BE47-38D798B1C74C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
Expand Down Expand Up @@ -532,10 +535,12 @@ Global
{8FB12876-013D-44CB-9F0D-E926D9F0F4E3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{8FB12876-013D-44CB-9F0D-E926D9F0F4E3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{8FB12876-013D-44CB-9F0D-E926D9F0F4E3}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8FB12876-013D-44CB-9F0D-E926D9F0F4E3}.Release|Any CPU.Build.0 = Release|Any CPU
{A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A46B0BE1-03F2-4520-A3DA-FD845BA1FD69}.Release|Any CPU.Build.0 = Release|Any CPU
{28534545-CB39-446A-9EB9-A5ABBFE0CFD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{28534545-CB39-446A-9EB9-A5ABBFE0CFD3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{28534545-CB39-446A-9EB9-A5ABBFE0CFD3}.Release|Any CPU.ActiveCfg = Release|Any CPU
EndGlobalSection
EndGlobal
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<Project Sdk="Microsoft.NET.Sdk.Web">

<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<ManagePackageVersionsCentrally>false</ManagePackageVersionsCentrally>
<NoWarn>$(NoWarn);CA1050;CA2000;CA1707;CA1303;CA2007;CA1724;CA1861;CA1859;</NoWarn>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\service\Core\Core.csproj" />
</ItemGroup>

</Project>
83 changes: 83 additions & 0 deletions examples/111-dotnet-azure-ai-hybrid-search/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft. All rights reserved.

// ReSharper disable InconsistentNaming

using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI.OpenAI;

public static class Program
{
private const string indexName = "acronyms";

public static async Task Main()
{
var azureOpenAITextConfig = new AzureOpenAIConfig();
var azureOpenAIEmbeddingConfig = new AzureOpenAIConfig();
var azureAISearchConfigWithHybridSearch = new AzureAISearchConfig();
var azureAISearchConfigWithoutHybridSearch = new AzureAISearchConfig();

new ConfigurationBuilder()
.AddJsonFile("appsettings.json")
.AddJsonFile("appsettings.Development.json", optional: true)
.Build()
.BindSection("KernelMemory:Services:AzureOpenAIText", azureOpenAITextConfig)
.BindSection("KernelMemory:Services:AzureOpenAIEmbedding", azureOpenAIEmbeddingConfig)
.BindSection("KernelMemory:Services:AzureAISearch", azureAISearchConfigWithHybridSearch)
.BindSection("KernelMemory:Services:AzureAISearch", azureAISearchConfigWithoutHybridSearch);

azureAISearchConfigWithHybridSearch.UseHybridSearch = true;
azureAISearchConfigWithoutHybridSearch.UseHybridSearch = false;

var memoryNoHybridSearch = new KernelMemoryBuilder()
.WithAzureOpenAITextGeneration(azureOpenAITextConfig, new DefaultGPTTokenizer())
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig, new DefaultGPTTokenizer())
.WithAzureAISearchMemoryDb(azureAISearchConfigWithoutHybridSearch)
.WithSearchClientConfig(new SearchClientConfig { MaxMatchesCount = 2, Temperature = 0, TopP = 0 })
.Build<MemoryServerless>();

var memoryWithHybridSearch = new KernelMemoryBuilder()
.WithAzureOpenAITextGeneration(azureOpenAITextConfig, new DefaultGPTTokenizer())
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig, new DefaultGPTTokenizer())
.WithAzureAISearchMemoryDb(azureAISearchConfigWithHybridSearch)
.WithSearchClientConfig(new SearchClientConfig { MaxMatchesCount = 2, Temperature = 0, TopP = 0 })
.Build<MemoryServerless>();

await CreateIndexAndImportData(memoryWithHybridSearch);

const string question = "abc";

Console.WriteLine("Answer without hybrid search:");
await AskQuestion(memoryNoHybridSearch, question);
// Output: INFO NOT FOUND

Console.WriteLine("Answer using hybrid search:");
await AskQuestion(memoryWithHybridSearch, question);
// Output: 'Aliens Brewing Coffee'
}

private static async Task AskQuestion(IKernelMemory memory, string question)
{
var answer = await memory.AskAsync(question, index: indexName);
Console.WriteLine(answer.Result);
}

private static async Task CreateIndexAndImportData(IKernelMemory memory)
{
await memory.DeleteIndexAsync(indexName);

var data = """
aaa bbb ccc 000000000
C B A .......
ai bee cee Something else
XY. abc means 'Aliens Brewing Coffee'
abeec abecedario
A B C D first 4 letters
""";

var rows = data.Split("\n");
foreach (var acronym in rows)
{
await memory.ImportTextAsync(acronym, index: indexName);
}
}
}
77 changes: 77 additions & 0 deletions examples/111-dotnet-azure-ai-hybrid-search/appsettings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
{
"Logging": {
"LogLevel": {
"Default": "Warning",
// Examples: how to handle logs differently by class
// "Microsoft.KernelMemory.Handlers.TextExtractionHandler": "Information",
// "Microsoft.KernelMemory.Handlers.TextPartitioningHandler": "Information",
// "Microsoft.KernelMemory.Handlers.GenerateEmbeddingsHandler": "Information",
// "Microsoft.KernelMemory.Handlers.SaveEmbeddingsHandler": "Information",
// "Microsoft.KernelMemory.ContentStorage.AzureBlobs": "Information",
// "Microsoft.KernelMemory.Pipeline.Queue.AzureQueues": "Information",
"Microsoft.AspNetCore": "Warning"
}
},
"KernelMemory": {
"Services": {
"AzureAISearch": {
// "ApiKey" or "AzureIdentity". For other options see <AzureAISearchConfig>.
// AzureIdentity: use automatic AAD authentication mechanism. You can test locally
// using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET.
"Auth": "AzureIdentity",
"Endpoint": "https://<...>",
"APIKey": ""
},
"AzureOpenAIText": {
// "ApiKey" or "AzureIdentity"
// AzureIdentity: use automatic AAD authentication mechanism. You can test locally
// using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET.
"Auth": "AzureIdentity",
"Endpoint": "https://<...>.openai.azure.com/",
"APIKey": "",
"Deployment": "",
// The max number of tokens supported by model deployed
// See https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
"MaxTokenTotal": 16384,
// "ChatCompletion" or "TextCompletion"
"APIType": "ChatCompletion",
"MaxRetries": 10
},
"AzureOpenAIEmbedding": {
// "ApiKey" or "AzureIdentity"
// AzureIdentity: use automatic AAD authentication mechanism. You can test locally
// using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET.
"Auth": "AzureIdentity",
"Endpoint": "https://<...>.openai.azure.com/",
"APIKey": "",
"Deployment": "",
// The max number of tokens supported by model deployed
// See https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
"MaxTokenTotal": 8191
},
"OpenAI": {
// Name of the model used to generate text (text completion or chat completion)
"TextModel": "gpt-3.5-turbo-16k",
// The max number of tokens supported by the text model.
"TextModelMaxTokenTotal": 16384,
// What type of text generation, by default autodetect using the model name.
// Possible values: "Auto", "TextCompletion", "Chat"
"TextGenerationType": "Auto",
// Name of the model used to generate text embeddings
"EmbeddingModel": "text-embedding-ada-002",
// The max number of tokens supported by the embedding model
// See https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
"EmbeddingModelMaxTokenTotal": 8191,
// OpenAI API Key
"APIKey": "",
// OpenAI Organization ID (usually empty, unless you have multiple accounts on different orgs)
"OrgId": "",
// Endpoint to use. By default the system uses 'https://api.openai.com/v1'.
// Change this to use proxies or services compatible with OpenAI HTTP protocol like LM Studio.
"Endpoint": "",
// How many times to retry in case of throttling
"MaxRetries": 10
}
}
}
}
6 changes: 6 additions & 0 deletions extensions/AzureAISearch/AzureAISearch/AzureAISearchConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ public enum AuthTypes
public string Endpoint { get; set; } = string.Empty;
public string APIKey { get; set; } = string.Empty;

/// <summary>
/// Important: when using hybrid search, relevance scores a very
/// different from when using just vector search.
/// </summary>
public bool UseHybridSearch { get; set; } = false;

public void SetCredential(TokenCredential credential)
{
this.Auth = AuthTypes.ManualTokenCredential;
Expand Down
14 changes: 11 additions & 3 deletions extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class AzureAISearchMemory : IMemoryDb
{
private readonly ITextEmbeddingGenerator _embeddingGenerator;
private readonly ILogger<AzureAISearchMemory> _log;
private readonly bool _useHybridSearch;

/// <summary>
/// Create a new instance
Expand All @@ -48,6 +49,7 @@ public class AzureAISearchMemory : IMemoryDb
{
this._embeddingGenerator = embeddingGenerator;
this._log = log ?? DefaultLogger<AzureAISearchMemory>.Instance;
this._useHybridSearch = config.UseHybridSearch;

if (string.IsNullOrEmpty(config.Endpoint))
{
Expand Down Expand Up @@ -184,8 +186,9 @@ public async Task<string> UpsertAsync(string index, MemoryRecord record, Cancell
Response<SearchResults<AzureAISearchMemoryRecord>>? searchResult = null;
try
{
var keyword = this._useHybridSearch ? text : null;
searchResult = await client
.SearchAsync<AzureAISearchMemoryRecord>(null, options, cancellationToken: cancellationToken)
.SearchAsync<AzureAISearchMemoryRecord>(keyword, options, cancellationToken: cancellationToken)
.ConfigureAwait(false);
}
catch (RequestFailedException e) when (e.Status == 404)
Expand All @@ -196,14 +199,19 @@ public async Task<string> UpsertAsync(string index, MemoryRecord record, Cancell

if (searchResult == null) { yield break; }

var minDistance = CosineSimilarityToScore(minRelevance);
var minDistance = this._useHybridSearch ? minRelevance : CosineSimilarityToScore(minRelevance);
var count = 0;
await foreach (SearchResult<AzureAISearchMemoryRecord>? doc in searchResult.Value.GetResultsAsync().ConfigureAwait(false))
{
if (doc == null || doc.Score < minDistance) { continue; }

// In cases where Azure Search is returning too many records
if (++count > limit) { break; }

MemoryRecord memoryRecord = doc.Document.ToMemoryRecord(withEmbeddings);

yield return (memoryRecord, ScoreToCosineSimilarity(doc.Score ?? 0));
var documentScore = this._useHybridSearch ? doc.Score ?? 0 : ScoreToCosineSimilarity(doc.Score ?? 0);
yield return (memoryRecord, documentScore);
}
}

Expand Down
12 changes: 12 additions & 0 deletions service/Core/Search/SearchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ await foreach (MemoryRecord memory in matches.ConfigureAwait(false))
LastUpdate = memory.GetLastUpdate(),
Tags = memory.Tags,
});

// In cases where a buggy storage connector is returning too many records
if (result.Results.Count >= this._config.MaxMatchesCount)
{
break;
}
}

if (result.Results.Count == 0)
Expand Down Expand Up @@ -284,6 +290,12 @@ await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait
LastUpdate = memory.GetLastUpdate(),
Tags = memory.Tags,
});

// In cases where a buggy storage connector is returning too many records
if (factsUsedCount >= this._config.MaxMatchesCount)
{
break;
}
}

if (factsAvailableCount > 0 && factsUsedCount == 0)
Expand Down
5 changes: 4 additions & 1 deletion service/Service/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@
// using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET.
"Auth": "AzureIdentity",
"Endpoint": "https://<...>",
"APIKey": ""
"APIKey": "",
// Hybrid search is not enabled by default. Note that when using hybrid search
// relevance scores are different, usually lower, than when using just vector search
"UseHybridSearch": false
},
"AzureAIDocIntel": {
// "APIKey" or "AzureIdentity".
Expand Down
2 changes: 2 additions & 0 deletions tools/InteractiveSetup/Services/AzureAISearch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public static void Setup(Context ctx, bool force = false)
{ "Auth", "ApiKey" },
{ "Endpoint", "" },
{ "APIKey", "" },
{ "UseHybridSearch", false },
};
}

Expand All @@ -29,6 +30,7 @@ public static void Setup(Context ctx, bool force = false)
{ "Auth", "ApiKey" },
{ "Endpoint", SetupUI.AskOpenQuestion("Azure AI Search <endpoint>", config["Endpoint"].ToString()) },
{ "APIKey", SetupUI.AskPassword("Azure AI Search <API Key>", config["APIKey"].ToString()) },
{ "UseHybridSearch", SetupUI.AskBoolean("Use hybrid search (yes/no)?", (bool)config["UseHybridSearch"]) },
});
}
}
18 changes: 18 additions & 0 deletions tools/InteractiveSetup/UI/DictionaryExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;

namespace Microsoft.KernelMemory.InteractiveSetup.UI;

internal static class DictionaryExtensions
{
public static string TryGet(this Dictionary<string, object> data, string key)
{
return data.TryGetValue(key, out object? value) ? value.ToString() ?? string.Empty : string.Empty;
}

public static string TryGetOr(this Dictionary<string, object> data, string key, string fallbackValue)
{
return data.TryGetValue(key, out object? value) ? value.ToString() ?? string.Empty : fallbackValue;
}
}
26 changes: 13 additions & 13 deletions tools/InteractiveSetup/UI/SetupUI.cs
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.KernelMemory.InteractiveSetup.UI;

public static class DictionaryExtensions
internal static class SetupUI
{
public static string TryGet(this Dictionary<string, object> data, string key)
public static string AskPassword(string question, string? defaultValue, bool trim = true, bool optional = false)
{
return data.TryGetValue(key, out object? value) ? value.ToString() ?? string.Empty : string.Empty;
return AskOpenQuestion(question: question, defaultValue: defaultValue, trim: trim, optional: optional, isPassword: true);
}

public static string TryGetOr(this Dictionary<string, object> data, string key, string fallbackValue)
public static bool AskBoolean(string question, bool defaultValue)
{
return data.TryGetValue(key, out object? value) ? value.ToString() ?? string.Empty : fallbackValue;
}
}
string[] yes = { "YES", "Y" };
string[] no = { "NO", "N" };
while (true)
{
var answer = AskOpenQuestion(question: question, defaultValue: defaultValue ? "Yes" : "No", optional: false).ToUpperInvariant();
if (yes.Contains(answer)) { return true; }

public static class SetupUI
{
public static string AskPassword(string question, string? defaultValue, bool trim = true, bool optional = false)
{
return AskOpenQuestion(question: question, defaultValue: defaultValue, trim: trim, optional: optional, isPassword: true);
if (no.Contains(answer)) { return false; }
}
}

public static string AskOptionalOpenQuestion(string question, string? defaultValue)
Expand Down

0 comments on commit c631a64

Please sign in to comment.