Skip to content

Commit

Permalink
Merge branch 'preview' of https://github.com/KoStard/continue into pr…
Browse files Browse the repository at this point in the history
…eview
  • Loading branch information
sestinj committed May 20, 2024
2 parents 59ca138 + 8e18355 commit de0dc3d
Show file tree
Hide file tree
Showing 3 changed files with 1,268 additions and 38 deletions.
91 changes: 55 additions & 36 deletions core/llm/llms/Bedrock.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import {
BedrockRuntimeClient,
InvokeModelWithResponseStreamCommand,
} from "@aws-sdk/client-bedrock-runtime";
import * as fs from "fs";
import os from "os";
import { join as joinPath } from "path";
import { promisify } from "util";
import { BaseLLM } from "../index.js";
import {
ChatMessage,
CompletionOptions,
LLMOptions,
ModelProvider,
} from "../../index.js";
import { stripImages } from "../countTokens.js";
import { BaseLLM } from "../index.js";

const aws4 = require("aws4");
const readFile = promisify(fs.readFile);

namespace BedrockCommon {
export enum Method {
Chat = "invoke",
Chat = "invoke-with-response-stream",
Completion = "invoke-with-response-stream",
}
export const Service: string = "bedrock";
Expand Down Expand Up @@ -126,10 +130,7 @@ class Bedrock extends BaseLLM {
const path = `/model/${model}/${apiMethod}`;
const opts = {
headers: {
accept:
apiMethod === BedrockCommon.Method.Chat
? "application/json"
: "application/vnd.amazon.eventstream",
accept: "application/vnd.amazon.eventstream",
"content-type": "application/json",
"x-amzn-bedrock-accept": "*/*",
},
Expand All @@ -147,10 +148,11 @@ class Bedrock extends BaseLLM {
joinPath(process.env.HOME ?? os.homedir(), ".aws", "credentials"),
"utf8",
);
const credentials = this._parseCredentialsFile(data);
accessKeyId = credentials.bedrock.accessKeyId;
secretAccessKey = credentials.bedrock.secretAccessKey;
sessionToken = credentials.bedrock.sessionToken || "";
const credentialsFile = this._parseCredentialsFile(data);
const credentials = credentialsFile.bedrock ?? credentialsFile.default;
accessKeyId = credentials.accessKeyId;
secretAccessKey = credentials.secretAccessKey;
sessionToken = credentials.sessionToken || "";
} catch (err) {
console.error("Error reading AWS credentials", err);
return new Response("403");
Expand All @@ -172,39 +174,56 @@ class Bedrock extends BaseLLM {
for await (const update of this._streamChat(messages, options)) {
yield stripImages(update.content);
}
// TODO: Couldn't seem to get this stream API working yet. Deferring to _streamChat.
// import { streamSse } from "../stream";
// const response = await this._fetchWithAwsAuthSigV4(BedrockCommon.Method.Completion, JSON.stringify({
// ...this._convertArgs(options),
// max_tokens: undefined, // Delete this key in favor of the correct one for the Completions API.
// max_tokens_to_sample: options.maxTokens,
// prompt: `\n\nHuman: ${prompt}\n\nAssistant:`,
// })
// );
// for await (const value of streamSse(response)) {
// if (value.completion) {
// yield value.completion
// }
// }
}

protected async *_streamChat(
messages: ChatMessage[],
options: CompletionOptions,
): AsyncGenerator<ChatMessage> {
const response = await this._fetchWithAwsAuthSigV4(
BedrockCommon.Method.Chat,
JSON.stringify({
...this._convertArgs(options),
messages: this._convertMessages(messages),
anthropic_version: "bedrock-2023-05-31", // Fixed, required parameter for Chat API.
}),
this._convertModelName(options.model),
const data = await readFile(
joinPath(process.env.HOME ?? os.homedir(), ".aws", "credentials"),
"utf8",
);
yield {
role: "assistant",
content: (await response.json()).content[0].text,
};
const credentialsFile = this._parseCredentialsFile(data);
const credentials = credentialsFile.bedrock ?? credentialsFile.default
const accessKeyId = credentials.accessKeyId;
const secretAccessKey = credentials.secretAccessKey;
const sessionToken = credentials.sessionToken || "";
const client = new BedrockRuntimeClient({
region: this.region,
credentials: {
accessKeyId: accessKeyId,
secretAccessKey: secretAccessKey,
sessionToken: sessionToken,
},
});
const command = new InvokeModelWithResponseStreamCommand({
body: new TextEncoder().encode(
JSON.stringify({
anthropic_version: "bedrock-2023-05-31",
max_tokens: options.maxTokens,
system: this.systemMessage,
messages: this._convertMessages(messages),
temperature: options.temperature,
top_p: options.topP,
top_k: options.topK,
stop_sequences: options.stop
}),
),
contentType: "application/json",
modelId: options.model,
});
const response = await client.send(command);
if (response.body) {
for await (const value of response.body) {
const binaryChunk = value.chunk?.bytes;
const textChunk = new TextDecoder().decode(binaryChunk);
const chunk = JSON.parse(textChunk).delta?.text;
if (chunk) {
yield { role: "assistant", content: chunk };
}
}
}
}
}

Expand Down
Loading

0 comments on commit de0dc3d

Please sign in to comment.