-
Notifications
You must be signed in to change notification settings - Fork 3k
/
SemanticTextMemory.cs
134 lines (114 loc) · 4.99 KB
/
SemanticTextMemory.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.AI.Embeddings;
namespace Microsoft.SemanticKernel.Memory;
/// <summary>
/// Implementation of <see cref="ISemanticTextMemory"/> that generates embeddings with an
/// <see cref="ITextEmbeddingGeneration"/> service and stores and queries records through an <see cref="IMemoryStore"/>.
/// </summary>
public sealed class SemanticTextMemory : ISemanticTextMemory, IDisposable
{
    private readonly ITextEmbeddingGeneration _embeddingGenerator;
    private readonly IMemoryStore _storage;

    /// <summary>
    /// Initializes a new instance of the <see cref="SemanticTextMemory"/> class.
    /// </summary>
    /// <param name="storage">Memory store used to create collections and to upsert, read, remove and search records.</param>
    /// <param name="embeddingGenerator">Service used to generate embeddings for saved text and for search queries.</param>
    public SemanticTextMemory(
        IMemoryStore storage,
        ITextEmbeddingGeneration embeddingGenerator)
    {
        this._embeddingGenerator = embeddingGenerator;
        this._storage = storage;
    }

    /// <inheritdoc/>
    public async Task<string> SaveInformationAsync(
        string collection,
        string text,
        string id,
        string? description = null,
        string? additionalMetadata = null,
        CancellationToken cancellationToken = default)
    {
        // The embedding is generated first, so a failure here leaves the store untouched.
        var embedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false);

        MemoryRecord data = MemoryRecord.LocalRecord(
            id: id,
            text: text,
            description: description,
            additionalMetadata: additionalMetadata,
            embedding: embedding);

        await this.EnsureCollectionExistsAsync(collection, cancellationToken).ConfigureAwait(false);

        return await this._storage.UpsertAsync(collection, data, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task<string> SaveReferenceAsync(
        string collection,
        string text,
        string externalId,
        string externalSourceName,
        string? description = null,
        string? additionalMetadata = null,
        CancellationToken cancellationToken = default)
    {
        // The text is only used to compute the embedding; it is not passed to the reference record itself.
        var embedding = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false);

        var data = MemoryRecord.ReferenceRecord(
            externalId: externalId,
            sourceName: externalSourceName,
            description: description,
            additionalMetadata: additionalMetadata,
            embedding: embedding);

        await this.EnsureCollectionExistsAsync(collection, cancellationToken).ConfigureAwait(false);

        return await this._storage.UpsertAsync(collection, data, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task<MemoryQueryResult?> GetAsync(
        string collection,
        string key,
        bool withEmbedding = false,
        CancellationToken cancellationToken = default)
    {
        MemoryRecord? record = await this._storage.GetAsync(collection, key, withEmbedding, cancellationToken).ConfigureAwait(false);

        if (record is null)
        {
            return null;
        }

        // A record fetched directly by key is reported with a relevance of 1.
        return MemoryQueryResult.FromMemoryRecord(record, 1);
    }

    /// <inheritdoc/>
    public async Task RemoveAsync(
        string collection,
        string key,
        CancellationToken cancellationToken = default)
    {
        await this._storage.RemoveAsync(collection, key, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<MemoryQueryResult> SearchAsync(
        string collection,
        string query,
        int limit = 1,
        double minRelevanceScore = 0.0,
        bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        Embedding<float> queryEmbedding = await this._embeddingGenerator.GenerateEmbeddingAsync(query, cancellationToken).ConfigureAwait(false);

        IAsyncEnumerable<(MemoryRecord, double)> results = this._storage.GetNearestMatchesAsync(
            collectionName: collection,
            embedding: queryEmbedding,
            limit: limit,
            minRelevanceScore: minRelevanceScore,
            withEmbeddings: withEmbeddings,
            cancellationToken: cancellationToken);

        // ConfigureAwait(false) keeps the enumeration awaits consistent with every other await in this
        // class: the synchronization context of the caller is not captured while iterating.
        await foreach ((MemoryRecord record, double relevance) in results.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            yield return MemoryQueryResult.FromMemoryRecord(record, relevance);
        }
    }

    /// <inheritdoc/>
    public async Task<IList<string>> GetCollectionsAsync(CancellationToken cancellationToken = default)
    {
        return await this._storage.GetCollectionsAsync(cancellationToken).ToListAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Disposes the embedding generator and the memory store supplied to the constructor,
    /// for whichever of them implements <see cref="IDisposable"/>.
    /// </summary>
    public void Dispose()
    {
        // ReSharper disable once SuspiciousTypeConversion.Global
        if (this._embeddingGenerator is IDisposable emb)
        {
            emb.Dispose();
        }

        // ReSharper disable once SuspiciousTypeConversion.Global
        if (this._storage is IDisposable storage)
        {
            storage.Dispose();
        }
    }

    /// <summary>
    /// Creates the collection in the memory store when the store reports that it does not exist yet.
    /// </summary>
    /// <param name="collection">Name of the collection to check and, if missing, create.</param>
    /// <param name="cancellationToken">Token forwarded to the memory store calls.</param>
    private async Task EnsureCollectionExistsAsync(string collection, CancellationToken cancellationToken)
    {
        if (!(await this._storage.DoesCollectionExistAsync(collection, cancellationToken).ConfigureAwait(false)))
        {
            await this._storage.CreateCollectionAsync(collection, cancellationToken).ConfigureAwait(false);
        }
    }
}