Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stream support for Bedrock Anthropic #1271

Merged
merged 1 commit into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 48 additions & 31 deletions core/llm/llms/Bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ import {
ModelProvider,
} from "../..";
import { stripImages } from "../countTokens";
import {
BedrockRuntimeClient,
InvokeModelWithResponseStreamCommand,
} from "@aws-sdk/client-bedrock-runtime";

const aws4 = require("aws4");
const readFile = promisify(fs.readFile);

namespace BedrockCommon {
export enum Method {
Chat = "invoke",
Chat = "invoke-with-response-stream",
Completion = "invoke-with-response-stream",
}
export const Service: string = "bedrock";
Expand Down Expand Up @@ -126,10 +130,7 @@ class Bedrock extends BaseLLM {
const path = `/model/${model}/${apiMethod}`;
const opts = {
headers: {
accept:
apiMethod === BedrockCommon.Method.Chat
? "application/json"
: "application/vnd.amazon.eventstream",
accept: "application/vnd.amazon.eventstream",
"content-type": "application/json",
"x-amzn-bedrock-accept": "*/*",
},
Expand Down Expand Up @@ -172,39 +173,55 @@ class Bedrock extends BaseLLM {
for await (const update of this._streamChat(messages, options)) {
yield stripImages(update.content);
}
// TODO: Couldn't seem to get this stream API working yet. Deferring to _streamChat.
// import { streamSse } from "../stream";
// const response = await this._fetchWithAwsAuthSigV4(BedrockCommon.Method.Completion, JSON.stringify({
// ...this._convertArgs(options),
// max_tokens: undefined, // Delete this key in favor of the correct one for the Completions API.
// max_tokens_to_sample: options.maxTokens,
// prompt: `\n\nHuman: ${prompt}\n\nAssistant:`,
// })
// );
// for await (const value of streamSse(response)) {
// if (value.completion) {
// yield value.completion
// }
// }
}

/**
 * Streams a chat completion from Bedrock through the AWS SDK's
 * InvokeModelWithResponseStream command, yielding one assistant message per
 * text delta received.
 *
 * Credentials are read from `~/.aws/credentials` (using `$HOME` when set,
 * otherwise `os.homedir()`) and taken from the `bedrock` entry of the parsed
 * file.
 *
 * @param messages Conversation so far; converted with `_convertMessages`.
 * @param options Completion options; `model`, `maxTokens`, `temperature`,
 *   `topP`, `topK` and `stop` are forwarded in the request body.
 * @returns An async generator of `{ role: "assistant", content }` chunks.
 * @throws Error when the credentials file has no `bedrock` entry; errors from
 *   reading the file, from the SDK call, or from parsing a chunk propagate.
 */
protected async *_streamChat(
  messages: ChatMessage[],
  options: CompletionOptions,
): AsyncGenerator<ChatMessage> {
  const data = await readFile(
    joinPath(process.env.HOME ?? os.homedir(), ".aws", "credentials"),
    "utf8",
  );
  const profile = this._parseCredentialsFile(data).bedrock;
  if (!profile) {
    // Fail with an actionable message instead of an opaque TypeError on
    // property access below.
    throw new Error(
      'No "bedrock" entry found in the AWS credentials file (~/.aws/credentials)',
    );
  }

  const client = new BedrockRuntimeClient({
    region: this.region,
    credentials: {
      accessKeyId: profile.accessKeyId,
      secretAccessKey: profile.secretAccessKey,
      sessionToken: profile.sessionToken || "",
    },
  });

  // Keys whose value is undefined are dropped by JSON.stringify, so unset
  // options are simply omitted from the request body.
  const command = new InvokeModelWithResponseStreamCommand({
    body: new TextEncoder().encode(
      JSON.stringify({
        anthropic_version: "bedrock-2023-05-31",
        max_tokens: options.maxTokens,
        system: this.systemMessage,
        messages: this._convertMessages(messages),
        temperature: options.temperature,
        top_p: options.topP,
        top_k: options.topK,
        stop_sequences: options.stop,
      }),
    ),
    contentType: "application/json",
    modelId: options.model,
  });

  const response = await client.send(command);
  if (!response.body) {
    return;
  }

  // One decoder is enough: every chunk is decoded as a complete JSON document.
  const decoder = new TextDecoder();
  for await (const event of response.body) {
    const bytes = event.chunk?.bytes;
    if (!bytes) {
      // An event without a chunk payload has no text; decoding it would give
      // "" and JSON.parse("") throws, so skip it.
      continue;
    }
    const text = JSON.parse(decoder.decode(bytes)).delta?.text;
    if (text) {
      yield { role: "assistant", content: text };
    }
  }
}
}

Expand Down
1 change: 1 addition & 0 deletions core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"ts-jest": "^29.1.1"
},
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.574.0",
"@mozilla/readability": "^0.5.0",
"@octokit/rest": "^20.0.2",
"@types/jsdom": "^21.1.6",
Expand Down
2 changes: 1 addition & 1 deletion extensions/vscode/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "continue",
"icon": "media/icon.png",
"version": "0.9.93",
"version": "0.9.94",
"repository": {
"type": "git",
"url": "https://github.com/continuedev/continue"
Expand Down