-
Notifications
You must be signed in to change notification settings - Fork 3k
/
semantic_text_memory.py
163 lines (136 loc) · 6.46 KB
/
semantic_text_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright (c) Microsoft. All rights reserved.
from typing import Any, Dict, List, Optional
from pydantic import PrivateAttr
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.memory.memory_query_result import MemoryQueryResult
from semantic_kernel.memory.memory_record import MemoryRecord
from semantic_kernel.memory.memory_store_base import MemoryStoreBase
from semantic_kernel.memory.semantic_text_memory_base import SemanticTextMemoryBase
from semantic_kernel.utils.experimental_decorator import experimental_class
@experimental_class
class SemanticTextMemory(SemanticTextMemoryBase):
    """Text memory that combines a memory store with an embeddings generator.

    Text is turned into an embedding by the embeddings generator and the
    resulting record is written to, or looked up in, the memory store.
    """

    _storage: MemoryStoreBase = PrivateAttr()
    # TODO: replace with kernel and service_selector pattern
    _embeddings_generator: EmbeddingGeneratorBase = PrivateAttr()

    def __init__(self, storage: MemoryStoreBase, embeddings_generator: EmbeddingGeneratorBase) -> None:
        """Initialize a new instance of SemanticTextMemory.

        Arguments:
            storage {MemoryStoreBase} -- The MemoryStoreBase to use for storage.
            embeddings_generator {EmbeddingGeneratorBase} -- The EmbeddingGeneratorBase
                to use for generating embeddings.

        Returns:
            None -- None.
        """
        super().__init__()
        self._storage = storage
        self._embeddings_generator = embeddings_generator

    async def save_information(
        self,
        collection: str,
        text: str,
        id: str,
        description: Optional[str] = None,
        additional_metadata: Optional[str] = None,
        embeddings_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save information to the memory (calls the memory store's upsert method).

        The collection is created first if the store reports that it does not exist.

        Arguments:
            collection {str} -- The collection to save the information to.
            text {str} -- The text to save.
            id {str} -- The id of the information.
            description {Optional[str]} -- The description of the information.
            additional_metadata {Optional[str]} -- Additional metadata stored with the record.
            embeddings_kwargs {Optional[Dict[str, Any]]} -- Extra keyword arguments forwarded
                to the embeddings generator; None is treated as no extra arguments.

        Returns:
            None -- None.
        """
        # TODO: not the best place to create collection, but will address this behavior together with .NET SK
        if not await self._storage.does_collection_exist(collection_name=collection):
            await self._storage.create_collection(collection_name=collection)

        # `or {}` avoids a shared mutable default and makes an explicit None safe to unpack.
        embedding = (await self._embeddings_generator.generate_embeddings([text], **(embeddings_kwargs or {})))[0]
        data = MemoryRecord.local_record(
            id=id,
            text=text,
            description=description,
            additional_metadata=additional_metadata,
            embedding=embedding,
        )

        await self._storage.upsert(collection_name=collection, record=data)

    async def save_reference(
        self,
        collection: str,
        text: str,
        external_id: str,
        external_source_name: str,
        description: Optional[str] = None,
        additional_metadata: Optional[str] = None,
        embeddings_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Save a reference to the memory (calls the memory store's upsert method).

        The collection is created first if the store reports that it does not exist.

        Arguments:
            collection {str} -- The collection to save the reference to.
            text {str} -- The text used to generate the embedding; it is not
                passed to the reference record itself.
            external_id {str} -- The external id of the reference.
            external_source_name {str} -- The external source name of the reference.
            description {Optional[str]} -- The description of the reference.
            additional_metadata {Optional[str]} -- Additional metadata stored with the record.
            embeddings_kwargs {Optional[Dict[str, Any]]} -- Extra keyword arguments forwarded
                to the embeddings generator; None is treated as no extra arguments.

        Returns:
            None -- None.
        """
        # TODO: not the best place to create collection, but will address this behavior together with .NET SK
        if not await self._storage.does_collection_exist(collection_name=collection):
            await self._storage.create_collection(collection_name=collection)

        # `or {}` avoids a shared mutable default and makes an explicit None safe to unpack.
        embedding = (await self._embeddings_generator.generate_embeddings([text], **(embeddings_kwargs or {})))[0]
        data = MemoryRecord.reference_record(
            external_id=external_id,
            source_name=external_source_name,
            description=description,
            additional_metadata=additional_metadata,
            embedding=embedding,
        )

        await self._storage.upsert(collection_name=collection, record=data)

    async def get(
        self,
        collection: str,
        key: str,
    ) -> Optional[MemoryQueryResult]:
        """Get information from the memory (calls the memory store's get method).

        Arguments:
            collection {str} -- The collection to get the information from.
            key {str} -- The key of the information.

        Returns:
            Optional[MemoryQueryResult] -- The MemoryQueryResult if found, None otherwise.
        """
        record = await self._storage.get(collection_name=collection, key=key)
        # A record fetched directly by key is converted with a fixed score of 1.0.
        return MemoryQueryResult.from_memory_record(record, 1.0) if record else None

    async def search(
        self,
        collection: str,
        query: str,
        limit: int = 1,
        min_relevance_score: float = 0.0,
        with_embeddings: bool = False,
        embeddings_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[MemoryQueryResult]:
        """Search the memory (calls the memory store's get_nearest_matches method).

        Arguments:
            collection {str} -- The collection to search in.
            query {str} -- The query to search for.
            limit {int} -- The maximum number of results to return. (default: {1})
            min_relevance_score {float} -- The minimum relevance score to return. (default: {0.0})
            with_embeddings {bool} -- Whether to return the embeddings of the results. (default: {False})
            embeddings_kwargs {Optional[Dict[str, Any]]} -- Extra keyword arguments forwarded
                to the embeddings generator; None is treated as no extra arguments.

        Returns:
            List[MemoryQueryResult] -- The list of MemoryQueryResult found.
        """
        # `or {}` avoids a shared mutable default and makes an explicit None safe to unpack.
        query_embedding = (
            await self._embeddings_generator.generate_embeddings([query], **(embeddings_kwargs or {}))
        )[0]
        results = await self._storage.get_nearest_matches(
            collection_name=collection,
            embedding=query_embedding,
            limit=limit,
            min_relevance_score=min_relevance_score,
            with_embeddings=with_embeddings,
        )

        # Each match is indexed as (record, score).
        return [MemoryQueryResult.from_memory_record(match[0], match[1]) for match in results]

    async def get_collections(self) -> List[str]:
        """Get the list of collections in the memory (calls the memory store's get_collections method).

        Returns:
            List[str] -- The list of all the memory collection names.
        """
        return await self._storage.get_collections()