-
Notifications
You must be signed in to change notification settings - Fork 2k
/
entity_memory.ts
216 lines (194 loc) Β· 6.29 KB
/
entity_memory.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
import {
BaseChatMemory,
BaseChatMemoryInput,
} from "@langchain/community/memory/chat_memory";
import { PromptTemplate } from "@langchain/core/prompts";
import {
InputValues,
MemoryVariables,
OutputValues,
getPromptInputKey,
} from "@langchain/core/memory";
import { getBufferString } from "@langchain/core/messages";
import { InMemoryEntityStore } from "./stores/entity/in_memory.js";
import { LLMChain } from "../chains/llm_chain.js";
import {
ENTITY_EXTRACTION_PROMPT,
ENTITY_SUMMARIZATION_PROMPT,
} from "./prompt.js";
import { BaseEntityStore } from "./stores/entity/base.js";
/**
 * Constructor options accepted by the EntityMemory class.
 *
 * Only `llm` is required; every other option falls back to a default chosen
 * by EntityMemory itself.
 */
export interface EntityMemoryInput extends BaseChatMemoryInput {
  /** Model used by both the entity-extraction and entity-summarization chains. */
  llm: BaseLanguageModelInterface;

  /** Store holding per-entity summaries (EntityMemory defaults to an in-memory store). */
  entityStore?: BaseEntityStore;

  /** Prompt that replaces the default entity-extraction prompt. */
  entityExtractionPrompt?: PromptTemplate;

  /** Prompt that replaces the default entity-summarization prompt. */
  entitySummarizationPrompt?: PromptTemplate;

  /** Initial list of entities whose summaries are refreshed on the next save. */
  entityCache?: string[];

  /** Number of recent conversation turns (two messages per turn) used as context. */
  k?: number;

  /** Key under which the recent chat history is returned (EntityMemory default: "history"). */
  chatHistoryKey?: string;

  /** Key under which the entity-to-summary map is returned (EntityMemory default: "entities"). */
  entitiesKey?: string;

  /** Label for human turns when serializing history; getBufferString's own default applies when omitted. */
  humanPrefix?: string;

  /** Label for AI turns when serializing history; getBufferString's own default applies when omitted. */
  aiPrefix?: string;
}
// Entity memory: extracts entities from recent conversation turns and keeps
// LLM-maintained summaries of them in an entity store.

/** Summary handed to the prompts when the store has nothing for an entity. */
const DEFAULT_ENTITY_SUMMARY = "No current information known.";

/**
 * Class for managing entity extraction and summarization to memory in
 * chatbot applications. Extends the BaseChatMemory class and implements
 * the EntityMemoryInput interface.
 *
 * On every `loadMemoryVariables` call the last `k` turns are sent to an
 * extraction chain to find the entities mentioned. Their stored summaries
 * are returned alongside the history. On `saveContext` each of those
 * entities gets its summary refreshed by a summarization chain.
 * @example
 * ```typescript
 * const memory = new EntityMemory({
 *   llm: new ChatOpenAI({ temperature: 0 }),
 *   chatHistoryKey: "history",
 *   entitiesKey: "entities",
 * });
 * const model = new ChatOpenAI({ temperature: 0.9 });
 * const chain = new LLMChain({
 *   llm: model,
 *   prompt: ENTITY_MEMORY_CONVERSATION_TEMPLATE,
 *   memory,
 * });
 *
 * const res1 = await chain.call({ input: "Hi! I'm Jim." });
 * console.log({
 *   res1,
 *   memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
 * });
 *
 * const res2 = await chain.call({
 *   input: "I work in construction. What about you?",
 * });
 * console.log({
 *   res2,
 *   memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
 * });
 *
 * ```
 */
export class EntityMemory extends BaseChatMemory implements EntityMemoryInput {
  private entityExtractionChain: LLMChain;

  private entitySummarizationChain: LLMChain;

  /** Backing store for entity summaries (defaults to an InMemoryEntityStore). */
  entityStore: BaseEntityStore;

  /**
   * Entities found by the most recent `loadMemoryVariables` call; these are
   * the entities whose summaries `saveContext` refreshes.
   */
  entityCache: string[] = [];

  /** Number of recent turns (two messages each) used as context; `<= 0` means none. */
  k = 3;

  /** Key under which the recent chat history is returned. */
  chatHistoryKey = "history";

  /** Model shared by the extraction and summarization chains. */
  llm: BaseLanguageModelInterface;

  /** Key under which the entity-to-summary map is returned. */
  entitiesKey = "entities";

  /** Label for human turns in serialized history (getBufferString default when unset). */
  humanPrefix?: string;

  /** Label for AI turns in serialized history (getBufferString default when unset). */
  aiPrefix?: string;

  constructor(fields: EntityMemoryInput) {
    super({
      chatHistory: fields.chatHistory,
      returnMessages: fields.returnMessages ?? false,
      inputKey: fields.inputKey,
      outputKey: fields.outputKey,
    });
    this.llm = fields.llm;
    this.humanPrefix = fields.humanPrefix;
    this.aiPrefix = fields.aiPrefix;
    this.chatHistoryKey = fields.chatHistoryKey ?? this.chatHistoryKey;
    this.entitiesKey = fields.entitiesKey ?? this.entitiesKey;
    this.entityExtractionChain = new LLMChain({
      llm: this.llm,
      prompt: fields.entityExtractionPrompt ?? ENTITY_EXTRACTION_PROMPT,
    });
    this.entitySummarizationChain = new LLMChain({
      llm: this.llm,
      prompt: fields.entitySummarizationPrompt ?? ENTITY_SUMMARIZATION_PROMPT,
    });
    this.entityStore = fields.entityStore ?? new InMemoryEntityStore();
    this.entityCache = fields.entityCache ?? this.entityCache;
    this.k = fields.k ?? this.k;
  }

  /**
   * Memory keys reported to chains.
   *
   * NOTE(review): only the history key is listed, although
   * `loadMemoryVariables` also returns `entitiesKey`. This is kept unchanged
   * because chains may validate prompt/input keys against this list.
   * Confirm whether `entitiesKey` belongs here.
   */
  get memoryKeys() {
    return [this.chatHistoryKey];
  }

  /** Every variable this memory produces: the entity map and the history. */
  get memoryVariables(): string[] {
    return [this.entitiesKey, this.chatHistoryKey];
  }

  /**
   * Extracts the entities referenced by the current input and recent history,
   * then returns their stored summaries together with the history.
   * @param inputs Chain input values; the prompt input is read from them.
   * @returns Promise resolving to `{ [chatHistoryKey]: history, [entitiesKey]: summaries }`.
   * The history is a message array when `returnMessages` is set, and a
   * serialized transcript otherwise.
   */
  async loadMemoryVariables(inputs: InputValues): Promise<MemoryVariables> {
    const promptInputKey = this.resolvePromptInputKey(inputs);
    const { recentMessages, serializedMessages } =
      await this.getRecentHistory();
    const output = await this.entityExtractionChain.predict({
      history: serializedMessages,
      input: inputs[promptInputKey],
    });
    const entities = EntityMemory.parseEntityList(output);

    // Built from entries (not index assignment) so an entity literally named
    // "__proto__" becomes an own key instead of hitting the prototype setter.
    const summaryEntries: Array<[string, string | undefined]> = [];
    for (const entity of entities) {
      summaryEntries.push([
        entity,
        await this.entityStore.get(entity, DEFAULT_ENTITY_SUMMARY),
      ]);
    }
    this.entityCache = entities;

    return {
      [this.chatHistoryKey]: this.returnMessages
        ? recentMessages
        : serializedMessages,
      [this.entitiesKey]: Object.fromEntries(summaryEntries),
    };
  }

  /**
   * Saves the latest exchange to chat history, then asks the summarization
   * chain to update the summary of every entity in `entityCache`.
   * A summary is only written back when the model does not answer "UNCHANGED".
   * @param inputs Input values of the exchange.
   * @param outputs Output values of the exchange.
   * @returns Promise resolving to void.
   */
  async saveContext(inputs: InputValues, outputs: OutputValues): Promise<void> {
    await super.saveContext(inputs, outputs);
    const promptInputKey = this.resolvePromptInputKey(inputs);
    const { serializedMessages } = await this.getRecentHistory();
    const inputData = inputs[promptInputKey];
    // Deduplicate so a repeated entity does not trigger repeated LLM calls.
    for (const entity of new Set(this.entityCache)) {
      const existingSummary = await this.entityStore.get(
        entity,
        DEFAULT_ENTITY_SUMMARY
      );
      const output = await this.entitySummarizationChain.predict({
        summary: existingSummary,
        entity,
        history: serializedMessages,
        input: inputData,
      });
      const summary = output.trim();
      if (summary !== "UNCHANGED") {
        await this.entityStore.set(entity, summary);
      }
    }
  }

  /**
   * Clears both the chat history and the entity store.
   * @returns Promise resolving to void.
   */
  async clear() {
    await super.clear();
    await this.entityStore.clear();
  }

  /** Picks the input key: the configured `inputKey`, else the one inferred from `inputs`. */
  private resolvePromptInputKey(inputs: InputValues): string {
    return this.inputKey ?? getPromptInputKey(inputs, this.memoryVariables);
  }

  /**
   * Loads the chat history and returns its last `k` turns (two messages per
   * turn), both as message objects and as a serialized transcript.
   */
  private async getRecentHistory() {
    const messages = await this.chatHistory.getMessages();
    // Guard k <= 0: slice(-0) is slice(0), which would return every message,
    // and a negative k would drop the oldest messages instead.
    const recentMessages = this.k > 0 ? messages.slice(-this.k * 2) : [];
    const serializedMessages = getBufferString(
      recentMessages,
      this.humanPrefix,
      this.aiPrefix
    );
    return { recentMessages, serializedMessages };
  }

  /**
   * Parses the extraction chain's comma-separated answer into entity names.
   * The answer "NONE" means no entities. Blank and duplicate names are dropped.
   */
  private static parseEntityList(output: string): string[] {
    const trimmed = output.trim();
    if (trimmed === "NONE") {
      return [];
    }
    const names = trimmed
      .split(",")
      .map((name) => name.trim())
      .filter((name) => name.length > 0);
    return [...new Set(names)];
  }
}