-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
amazon_knowledge_base.ts
113 lines (99 loc) · 2.92 KB
/
amazon_knowledge_base.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import {
RetrieveCommand,
BedrockAgentRuntimeClient,
BedrockAgentRuntimeClientConfig,
} from "@aws-sdk/client-bedrock-agent-runtime";
import { BaseRetriever } from "@langchain/core/retrievers";
import { Document } from "@langchain/core/documents";
/**
 * Interface for the arguments required to initialize an
 * AmazonKnowledgeBaseRetriever instance.
 */
export interface AmazonKnowledgeBaseRetrieverArgs {
  /** Identifier of the Bedrock knowledge base that queries are sent to. */
  knowledgeBaseId: string;
  /**
   * Number of results requested from the knowledge base per query.
   * Optional: the retriever's constructor falls back to 10 when omitted.
   */
  topK?: number;
  /**
   * AWS region handed to the Bedrock Agent Runtime client. A `region` set
   * inside `clientOptions` takes precedence, because the client options are
   * spread after this value when the client is constructed.
   */
  region: string;
  /**
   * Additional configuration (for example credentials) merged into the
   * Bedrock Agent Runtime client configuration.
   */
  clientOptions?: BedrockAgentRuntimeClientConfig;
}
/**
 * Class for interacting with Amazon Bedrock Knowledge Bases, a RAG workflow oriented service
 * provided by AWS. Extends the BaseRetriever class.
 *
 * Each query is sent to the configured knowledge base through a
 * `RetrieveCommand`, and every retrieval result is converted into a
 * `Document` whose metadata carries the S3 source URI and relevance score
 * reported by the service.
 * @example
 * ```typescript
 * const retriever = new AmazonKnowledgeBaseRetriever({
 *   topK: 10,
 *   knowledgeBaseId: "YOUR_KNOWLEDGE_BASE_ID",
 *   region: "us-east-2",
 *   clientOptions: {
 *     credentials: {
 *       accessKeyId: "YOUR_ACCESS_KEY_ID",
 *       secretAccessKey: "YOUR_SECRET_ACCESS_KEY",
 *     },
 *   },
 * });
 *
 * const docs = await retriever.getRelevantDocuments("How are clouds formed?");
 * ```
 */
export class AmazonKnowledgeBaseRetriever extends BaseRetriever {
  static lc_name(): string {
    return "AmazonKnowledgeBaseRetriever";
  }

  lc_namespace = ["langchain", "retrievers", "amazon_bedrock_knowledge_base"];

  /** Identifier of the knowledge base every query is sent to. */
  knowledgeBaseId: string;

  /** Number of results requested per query by `_getRelevantDocuments`. */
  topK: number;

  /** Client used to send `RetrieveCommand`s to the Bedrock Agent Runtime. */
  bedrockAgentRuntimeClient: BedrockAgentRuntimeClient;

  /**
   * @param args Retriever configuration. `topK` falls back to 10 when it is
   * not supplied. `clientOptions` is spread after `region`, so a region set
   * inside `clientOptions` overrides the top-level `region`.
   */
  constructor({
    knowledgeBaseId,
    topK = 10,
    clientOptions,
    region,
  }: AmazonKnowledgeBaseRetrieverArgs) {
    super();

    this.topK = topK;
    this.bedrockAgentRuntimeClient = new BedrockAgentRuntimeClient({
      region,
      ...clientOptions,
    });
    this.knowledgeBaseId = knowledgeBaseId;
  }

  /**
   * Cleans the result text by removing ellipses and replacing sequences of
   * whitespace with a single space.
   * @param resText The result text to clean.
   * @returns The cleaned result text.
   */
  cleanResult(resText: string): string {
    // Strip ellipses first: removing them after collapsing whitespace would
    // leave two adjacent spaces wherever an ellipsis sat between spaces.
    return resText.replace(/\.\.\./g, "").replace(/\s+/g, " ");
  }

  /**
   * Sends a retrieval query to the knowledge base and converts the response
   * into documents.
   * @param query The text to search the knowledge base with.
   * @param topK The number of results to request.
   * @returns One `Document` per retrieval result, with the cleaned result
   * text as `pageContent` and `source` (S3 URI) and `score` in `metadata`.
   * Returns an empty array when the response contains no results.
   */
  async queryKnowledgeBase(query: string, topK: number): Promise<Document[]> {
    const retrieveCommand = new RetrieveCommand({
      knowledgeBaseId: this.knowledgeBaseId,
      retrievalQuery: {
        text: query,
      },
      retrievalConfiguration: {
        vectorSearchConfiguration: {
          numberOfResults: topK,
        },
      },
    });

    const retrieveResponse = await this.bedrockAgentRuntimeClient.send(
      retrieveCommand
    );

    // Build real Document instances so callers relying on `instanceof
    // Document` (or on Document's own fields) behave as the declared return
    // type promises.
    return (retrieveResponse.retrievalResults ?? []).map(
      (result) =>
        new Document({
          pageContent: this.cleanResult(result.content?.text ?? ""),
          metadata: {
            source: result.location?.s3Location?.uri,
            score: result.score,
          },
        })
    );
  }

  /**
   * Retrieves the documents relevant to `query`, requesting `this.topK`
   * results from the knowledge base.
   * @param query The text to search the knowledge base with.
   * @returns The documents built from the retrieval results.
   */
  async _getRelevantDocuments(query: string): Promise<Document[]> {
    return this.queryKnowledgeBase(query, this.topK);
  }
}