-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
bedrock.ts
142 lines (126 loc) Β· 4.09 KB
/
bedrock.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import {
BedrockRuntimeClient,
InvokeModelCommand,
} from "@aws-sdk/client-bedrock-runtime";
import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
import type { CredentialType } from "../utils/bedrock.js";
/**
 * Interface that extends EmbeddingsParams and defines additional
 * parameters specific to the BedrockEmbeddings class.
 */
export interface BedrockEmbeddingsParams extends EmbeddingsParams {
  /**
   * Model name (Bedrock model ID) to use. Defaults to
   * `amazon.titan-embed-text-v1` if not provided.
   */
  model?: string;

  /**
   * A client provided by the user that allows them to customize any
   * SDK configuration options. When supplied, `region` and `credentials`
   * are ignored.
   */
  client?: BedrockRuntimeClient;

  /**
   * AWS region passed to the `BedrockRuntimeClient` that is created when
   * no `client` is supplied.
   */
  region?: string;

  /**
   * AWS credentials passed to the `BedrockRuntimeClient` that is created
   * when no `client` is supplied.
   */
  credentials?: CredentialType;
}
/**
 * Class that extends the Embeddings class and provides methods for
 * generating embeddings using the Bedrock API.
 * @example
 * ```typescript
 * const embeddings = new BedrockEmbeddings({
 *   region: "your-aws-region",
 *   credentials: {
 *     accessKeyId: "your-access-key-id",
 *     secretAccessKey: "your-secret-access-key",
 *   },
 *   model: "amazon.titan-embed-text-v1",
 * });
 *
 * // Embed a query and log the result
 * const res = await embeddings.embedQuery(
 *   "What would be a good company name for a company that makes colorful socks?"
 * );
 * console.log({ res });
 * ```
 */
export class BedrockEmbeddings
  extends Embeddings
  implements BedrockEmbeddingsParams
{
  /** Bedrock model ID sent as `modelId` on every request. */
  model: string;

  /** Runtime client used to send `InvokeModelCommand` requests. */
  client: BedrockRuntimeClient;

  /**
   * Kept for API compatibility. It is not read by the methods of this
   * class, which send each text in its own request.
   */
  batchSize = 512;

  /**
   * @param fields Optional configuration. If `fields.client` is given it is
   * used as-is; otherwise a new `BedrockRuntimeClient` is built from
   * `fields.region` and `fields.credentials`.
   */
  constructor(fields?: BedrockEmbeddingsParams) {
    super(fields ?? {});
    this.model = fields?.model ?? "amazon.titan-embed-text-v1";
    this.client =
      fields?.client ??
      new BedrockRuntimeClient({
        region: fields?.region,
        credentials: fields?.credentials,
      });
  }

  /**
   * Sends a single text to the Bedrock `InvokeModel` API and returns its
   * embedding. The request runs through `this.caller`, which applies the
   * configured retry and concurrency limits.
   * @param text Text to embed. Newlines are replaced with spaces before sending.
   * @returns Promise that resolves to the `embedding` array from the response body.
   * @throws Error prefixed with "An error occurred while embedding documents
   * with Bedrock" if the request fails or the response has no `embedding` array.
   */
  protected async _embedText(text: string): Promise<number[]> {
    // Replace newlines, which can negatively affect performance. Computed once
    // here rather than inside the retried closure, since it never changes.
    const cleanedText = text.replace(/\n/g, " ");

    return this.caller.call(async () => {
      try {
        const res = await this.client.send(
          new InvokeModelCommand({
            modelId: this.model,
            body: JSON.stringify({
              inputText: cleanedText,
            }),
            contentType: "application/json",
            accept: "application/json",
          })
        );

        const body = new TextDecoder().decode(res.body);
        const parsed: unknown = JSON.parse(body);
        const embedding =
          typeof parsed === "object" && parsed !== null
            ? (parsed as { embedding?: unknown }).embedding
            : undefined;

        // Fail loudly instead of returning `undefined` typed as number[]
        // (e.g. when a model uses a different response schema).
        if (!Array.isArray(embedding)) {
          throw new Error(
            "Bedrock response did not contain an `embedding` array"
          );
        }
        return embedding;
      } catch (e) {
        console.error({
          error: e,
        });
        // eslint-disable-next-line no-instanceof/no-instanceof
        if (e instanceof Error) {
          throw new Error(
            `An error occurred while embedding documents with Bedrock: ${e.message}`
          );
        }
        throw new Error(
          "An error occurred while embedding documents with Bedrock"
        );
      }
    });
  }

  /**
   * Method that takes a document as input and returns a promise that
   * resolves to an embedding for the document.
   * @param document Document for which to generate an embedding.
   * @returns Promise that resolves to an embedding for the input document.
   */
  embedQuery(document: string): Promise<number[]> {
    // `_embedText` already goes through `this.caller`. Wrapping it in a second
    // caller invocation would nest the retry loops (multiplying attempts) and
    // hold two concurrency slots per query, which can stall a caller
    // configured with `maxConcurrency: 1`.
    return this._embedText(document);
  }

  /**
   * Method to generate embeddings for an array of texts. Each text is sent
   * as its own request via `_embedText`. Requests are started together, with
   * retries and concurrency governed by `this.caller`.
   * @param documents Array of texts for which to generate embeddings.
   * @returns Promise that resolves to a 2D array of embeddings, in the same
   * order as the input documents.
   */
  async embedDocuments(documents: string[]): Promise<number[][]> {
    return Promise.all(documents.map((document) => this._embedText(document)));
  }
}