-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
time_weighted.ts
299 lines (273 loc) Β· 9.44 KB
/
time_weighted.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
import type { VectorStoreInterface } from "@langchain/core/vectorstores";
import type { DocumentInterface } from "@langchain/core/documents";
import { CallbackManagerForRetrieverRun } from "@langchain/core/callbacks/manager";
/**
* Interface for the fields required to initialize a
* TimeWeightedVectorStoreRetriever instance.
*
* Only `vectorStore` is required; every other field falls back to the
* default applied in the TimeWeightedVectorStoreRetriever constructor.
*/
export interface TimeWeightedVectorStoreRetrieverFields
extends BaseRetrieverInput {
/** Backing vector store queried via `similaritySearchWithScore` for salient documents. */
vectorStore: VectorStoreInterface;
/**
 * Despite the name, a plain number: how many documents to fetch from the
 * vector store per query (constructor default: 100).
 */
searchKwargs?: number;
/**
 * Initial memory stream (constructor default: []). Documents are looked up by
 * their `buffer_idx` metadata, so each document's `buffer_idx` is expected to
 * equal its position in this array.
 */
memoryStream?: DocumentInterface[];
/** Decay factor applied as (1 - decayRate) ** hoursSinceLastAccess (constructor default: 0.01). */
decayRate?: number;
/**
 * Controls how many of the most recent memory-stream documents are always
 * considered and how many results are returned per call (constructor default: 4).
 */
k?: number;
/** Extra metadata keys (e.g. "importance") whose values are added to the combined score (constructor default: []). */
otherScoreKeys?: string[];
/**
 * Salience score given to recent memory-stream documents that the vector
 * search did not return (constructor default: null, which scores as 0).
 */
defaultSalience?: number;
}
/**
 * Metadata key holding the document's last-access timestamp. The retriever
 * writes it as Unix time in whole seconds (Math.floor(Date.now() / 1000)) and
 * uses it to compute time decay.
 */
export const LAST_ACCESSED_AT_KEY = "last_accessed_at";
/**
 * Metadata key holding the document's index into the retriever's memory
 * stream; assigned by the retriever when documents are added through it.
 */
export const BUFFER_IDX = "buffer_idx";
/**
 * TimeWeightedVectorStoreRetriever retrieves documents based on their time-weighted relevance.
 *
 * Each candidate document is scored as
 * `(1 - decayRate) ** hoursSinceLastAccess + sum(metadata[otherScoreKeys]) + vectorScore`,
 * and the top `k` documents are returned. Returned documents have their
 * `last_accessed_at` metadata refreshed, so recently retrieved memories decay less.
 *
 * NOTE(review): `vectorScore` is whatever the backing store's
 * `similaritySearchWithScore` returns; for stores that return a distance
 * (lower = closer) rather than a similarity, this term works against relevance.
 * Confirm per store.
 *
 * ref: https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/time_weighted_retriever.py
 * @example
 * ```typescript
 * const retriever = new TimeWeightedVectorStoreRetriever({
 *   vectorStore: new MemoryVectorStore(new OpenAIEmbeddings()),
 *   memoryStream: [],
 *   searchKwargs: 2,
 * });
 * await retriever.addDocuments([
 *   { pageContent: "My name is John.", metadata: {} },
 *   { pageContent: "My favourite food is pizza.", metadata: {} },
 * ]);
 * const results = await retriever.invoke("What is my favourite food?");
 * ```
 */
export class TimeWeightedVectorStoreRetriever extends BaseRetriever {
  static lc_name() {
    return "TimeWeightedVectorStoreRetriever";
  }

  get lc_namespace() {
    return ["langchain", "retrievers", "time_weighted"];
  }

  /**
   * The vectorstore to store documents and determine salience.
   */
  private vectorStore: VectorStoreInterface;

  /**
   * The number of top K most relevant documents to fetch from the vector store when searching.
   */
  private searchKwargs: number;

  /**
   * The memory_stream of documents to search through.
   */
  private memoryStream: DocumentInterface[];

  /**
   * The exponential decay factor used as (1.0-decay_rate)**(hrs_passed).
   */
  private decayRate: number;

  /**
   * The maximum number of documents to retrieve in a given call; also the
   * number of most recent memory-stream documents always considered.
   */
  private k: number;

  /**
   * Other keys in the metadata to factor into the score, e.g. 'importance'.
   */
  private otherScoreKeys: string[];

  /**
   * The salience to assign memories not retrieved from the vector store.
   */
  private defaultSalience: number | null;

  /**
   * Constructor to initialize the required fields
   * @param fields - The fields required for initializing the TimeWeightedVectorStoreRetriever
   */
  constructor(fields: TimeWeightedVectorStoreRetrieverFields) {
    super(fields);
    this.vectorStore = fields.vectorStore;
    this.searchKwargs = fields.searchKwargs ?? 100;
    this.memoryStream = fields.memoryStream ?? [];
    this.decayRate = fields.decayRate ?? 0.01;
    this.k = fields.k ?? 4;
    this.otherScoreKeys = fields.otherScoreKeys ?? [];
    this.defaultSalience = fields.defaultSalience ?? null;
  }

  /**
   * Get the memory stream of documents.
   * @returns The memory stream of documents (the live array, not a copy).
   */
  getMemoryStream(): DocumentInterface[] {
    return this.memoryStream;
  }

  /**
   * Set the memory stream of documents.
   * @param memoryStream The new memory stream of documents.
   */
  setMemoryStream(memoryStream: DocumentInterface[]) {
    this.memoryStream = memoryStream;
  }

  /**
   * Get relevant documents based on time-weighted relevance.
   * @param query - The query to search for
   * @param runManager - Optional callback manager; a child manager is passed to the vector store search
   * @returns At most `k` documents, highest combined score first
   */
  async _getRelevantDocuments(
    query: string,
    runManager?: CallbackManagerForRetrieverRun
  ): Promise<DocumentInterface[]> {
    const now = Math.floor(Date.now() / 1000);
    const memoryDocsAndScores = this.getMemoryDocsAndScores();
    const salientDocsAndScores = await this.getSalientDocuments(
      query,
      runManager
    );
    // Salient entries override the default-salience entries for the same buffer index.
    const docsAndScores = { ...memoryDocsAndScores, ...salientDocsAndScores };
    return this.computeResults(docsAndScores, now);
  }

  /**
   * NOTE: When adding documents to a vector store, use addDocuments
   * via retriever instead of directly to the vector store.
   * This is because it is necessary to process the document
   * in prepareDocuments.
   *
   * @param docs - The documents to add to vector store in the retriever
   */
  async addDocuments(docs: DocumentInterface[]): Promise<void> {
    const now = Math.floor(Date.now() / 1000);
    const savedDocs = this.prepareDocuments(docs, now);
    this.memoryStream.push(...savedDocs);
    await this.vectorStore.addDocuments(savedDocs);
  }

  /**
   * Read a document's memory-stream index from its metadata.
   * @param doc - A document expected to have been added via addDocuments
   * @returns The document's buffer index
   * @throws Error if the document lacks the buffer index metadata
   */
  private static getBufferIdxOrThrow(doc: DocumentInterface): number {
    const bufferIdx = doc.metadata[BUFFER_IDX];
    if (bufferIdx === undefined) {
      throw new Error(
        `Found a document in the vector store that is missing required metadata. This retriever only supports vector stores with documents that have been added through the "addDocuments" method on a TimeWeightedVectorStoreRetriever, not directly added or loaded into the backing vector store.`
      );
    }
    return bufferIdx;
  }

  /**
   * Get the `k` most recent memory documents, each scored with the default salience.
   * @returns An object mapping buffer index to document and score
   */
  private getMemoryDocsAndScores(): Record<
    number,
    { doc: DocumentInterface; score: number }
  > {
    const memoryDocsAndScores: Record<
      number,
      { doc: DocumentInterface; score: number }
    > = {};
    // slice(-0) returns the whole array, so k <= 0 must be handled explicitly.
    const recentDocs = this.k > 0 ? this.memoryStream.slice(-this.k) : [];
    for (const doc of recentDocs) {
      const bufferIdx = TimeWeightedVectorStoreRetriever.getBufferIdxOrThrow(doc);
      memoryDocsAndScores[bufferIdx] = {
        doc,
        score: this.defaultSalience ?? 0,
      };
    }
    return memoryDocsAndScores;
  }

  /**
   * Get salient documents and their vector-store scores for the query.
   * @param query - The query to search for
   * @param runManager - Optional callback manager for the search
   * @returns An object mapping buffer index to memory-stream document and score
   */
  private async getSalientDocuments(
    query: string,
    runManager?: CallbackManagerForRetrieverRun
  ): Promise<Record<number, { doc: DocumentInterface; score: number }>> {
    const docAndScores: [DocumentInterface, number][] =
      await this.vectorStore.similaritySearchWithScore(
        query,
        this.searchKwargs,
        undefined,
        runManager?.getChild()
      );
    const results: Record<number, { doc: DocumentInterface; score: number }> =
      {};
    for (const [fetchedDoc, score] of docAndScores) {
      const bufferIdx =
        TimeWeightedVectorStoreRetriever.getBufferIdxOrThrow(fetchedDoc);
      // Use the memory-stream copy so access-time updates land on the canonical document.
      const doc = this.memoryStream[bufferIdx];
      results[bufferIdx] = { doc, score };
    }
    return results;
  }

  /**
   * Rank candidates by combined score, keep the top `k`, and stamp their last-access time.
   * @param docsAndScores - An object containing documents and their scores
   * @param now - The current timestamp in seconds
   * @returns At most `k` documents, highest combined score first
   */
  private computeResults(
    docsAndScores: Record<number, { doc: DocumentInterface; score: number }>,
    now: number
  ): DocumentInterface[] {
    const rescoredDocs = Object.values(docsAndScores)
      .map(({ doc, score }) => ({
        doc,
        score: this.getCombinedScore(doc, score, now),
      }))
      .sort((a, b) => b.score - a.score);
    // Clamp so a negative k cannot turn slice's end into an offset from the tail.
    const topDocs = rescoredDocs.slice(0, Math.max(this.k, 0));
    const results: DocumentInterface[] = [];
    for (const { doc } of topDocs) {
      const bufferedDoc = this.memoryStream[doc.metadata[BUFFER_IDX]];
      bufferedDoc.metadata[LAST_ACCESSED_AT_KEY] = now;
      results.push(bufferedDoc);
    }
    return results;
  }

  /**
   * Prepare documents with necessary metadata before saving; input documents are not mutated.
   * @param docs - The documents to prepare
   * @param now - The current timestamp in seconds
   * @returns Copies of the documents with access/creation times and buffer index set
   */
  private prepareDocuments(
    docs: DocumentInterface[],
    now: number
  ): DocumentInterface[] {
    return docs.map((doc, i) => ({
      ...doc,
      metadata: {
        ...doc.metadata,
        [LAST_ACCESSED_AT_KEY]: doc.metadata[LAST_ACCESSED_AT_KEY] ?? now,
        created_at: doc.metadata.created_at ?? now,
        // Index the new document will occupy once appended to the memory stream.
        [BUFFER_IDX]: this.memoryStream.length + i,
      },
    }));
  }

  /**
   * Calculate the combined score based on recency, extra metadata keys, and vector relevance.
   * @param doc - The document to calculate the score for
   * @param vectorRelevance - The relevance score from the vector store (null adds nothing)
   * @param nowSec - The current timestamp in seconds
   * @returns The combined score for the document
   */
  private getCombinedScore(
    doc: DocumentInterface,
    vectorRelevance: number | null,
    nowSec: number
  ): number {
    const hoursPassed = this.getHoursPassed(
      nowSec,
      doc.metadata[LAST_ACCESSED_AT_KEY]
    );
    let score = (1.0 - this.decayRate) ** hoursPassed;
    for (const key of this.otherScoreKeys) {
      const value = doc.metadata[key];
      // Skip missing/non-numeric values: adding undefined would make the score
      // NaN (and a string would concatenate), which corrupts the ranking.
      if (typeof value === "number") {
        score += value;
      }
    }
    if (vectorRelevance !== null) {
      score += vectorRelevance;
    }
    return score;
  }

  /**
   * Calculate the hours passed between two time points
   * @param time - The current time in seconds
   * @param refTime - The reference time in seconds
   * @returns The number of hours passed between the two time points
   */
  private getHoursPassed(time: number, refTime: number): number {
    return (time - refTime) / 3600;
  }
}