Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Meta Llama2 support for BedrockChat #3260

Merged
merged 2 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions langchain/src/chat_models/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
super(fields ?? {});

this.model = fields?.model ?? this.model;
const allowedModels = ["ai21", "anthropic", "amazon", "cohere"];
const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"];
if (!allowedModels.includes(this.model.split(".")[0])) {
throw new Error(
`Unknown model: '${this.model}', only these are supported: ${allowedModels}`
Expand Down Expand Up @@ -300,7 +300,7 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;

const bedrockMethod =
provider === "anthropic" || provider === "cohere"
provider === "anthropic" || provider === "cohere" || provider === "meta"
? "invoke-with-response-stream"
: "invoke";

Expand All @@ -318,7 +318,11 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
);
}

if (provider === "anthropic" || provider === "cohere") {
if (
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta"
) {
const reader = response.body?.getReader();
const decoder = new TextDecoder();
for await (const chunk of this._readChunks(reader)) {
Expand Down
214 changes: 146 additions & 68 deletions langchain/src/chat_models/tests/chatbedrock.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,85 +2,162 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR introduces changes that access environment variables via process.env or getEnvironmentVariable. Please review these changes to ensure they are handled correctly and securely.


import { test, expect } from "@jest/globals";
import { ChatBedrock } from "../bedrock/web.js";
import { BedrockChat } from "../bedrock/web.js";
import { HumanMessage } from "../../schema/index.js";

test("Test Bedrock chat model: Claude-v2", async () => {
const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
const model = "anthropic.claude-v2";
void testChatModel(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do something like test.each below?

"Test Bedrock chat model: Llama2 13B v1",
"us-east-1",
"meta.llama2-13b-chat-v1",
"What is your name?"
);
void testChatStreamingModel(
"Test Bedrock streaming chat model: Llama2 13B v1",
"us-east-1",
"meta.llama2-13b-chat-v1",
"What is your name and something about yourself?"
);

const bedrock = new ChatBedrock({
maxTokens: 20,
region,
model,
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
});
void testChatModel(
"Test Bedrock chat model: Claude-v2",
"us-east-1",
"anthropic.claude-v2",
"What is your name?"
);
void testChatStreamingModel(
"Test Bedrock chat model streaming: Claude-v2",
"us-east-1",
"anthropic.claude-v2",
"What is your name and something about yourself?"
);

const res = await bedrock.call([new HumanMessage("What is your name?")]);
console.log(res);
});
void testChatHandleLLMNewToken(
"Test Bedrock chat model HandleLLMNewToken: Claude-v2",
"us-east-1",
"anthropic.claude-v2",
"What is your name and something about yourself?"
);
void testChatHandleLLMNewToken(
"Test Bedrock chat model HandleLLMNewToken: Llama2 13B v1",
"us-east-1",
"meta.llama2-13b-chat-v1",
"What is your name and something about yourself?"
);

test("Test Bedrock chat model streaming: Claude-v2", async () => {
const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
const model = "anthropic.claude-v2";
/**
* Tests a BedrockChat model
* @param title The name of the test to run
* @param defaultRegion The AWS region to default back to if not set via environment
* @param model The model string to test
 * @param message The prompt text to send to the LLM
*/
async function testChatModel(
title: string,
defaultRegion: string,
model: string,
message: string
) {
test(title, async () => {
const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;

const bedrock = new ChatBedrock({
maxTokens: 200,
region,
model,
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
const bedrock = new BedrockChat({
maxTokens: 20,
region,
model,
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
},
});

const res = await bedrock.call([new HumanMessage(message)]);
console.log(res);
});
}
/**
* Tests a BedrockChat model with a streaming response
* @param title The name of the test to run
* @param defaultRegion The AWS region to default back to if not set via environment
* @param model The model string to test
 * @param message The prompt text to send to the LLM
*/
async function testChatStreamingModel(
title: string,
defaultRegion: string,
model: string,
message: string
) {
test(title, async () => {
const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;

const stream = await bedrock.stream([
new HumanMessage({
content: "What is your name and something about yourself?",
}),
]);
const chunks = [];
for await (const chunk of stream) {
console.log(chunk);
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(1);
});
const bedrock = new BedrockChat({
maxTokens: 200,
region,
model,
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
},
});

test("Test Bedrock chat model handleLLMNewToken: Claude-v2", async () => {
const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
const model = "anthropic.claude-v2";
const tokens: string[] = [];
const stream = await bedrock.stream([
new HumanMessage({
content: message,
}),
]);
const chunks = [];
for await (const chunk of stream) {
console.log(chunk);
chunks.push(chunk);
}
expect(chunks.length).toBeGreaterThan(1);
});
}
/**
* Tests a BedrockChat model with a streaming response using a new token callback
* @param title The name of the test to run
* @param defaultRegion The AWS region to default back to if not set via environment
* @param model The model string to test
 * @param message The prompt text to send to the LLM
*/
async function testChatHandleLLMNewToken(
title: string,
defaultRegion: string,
model: string,
message: string
) {
test(title, async () => {
const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
const tokens: string[] = [];

const bedrock = new ChatBedrock({
maxTokens: 200,
region,
model,
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
},
streaming: true,
callbacks: [
{
handleLLMNewToken: (token) => {
tokens.push(token);
},
const bedrock = new BedrockChat({
maxTokens: 200,
region,
model,
maxRetries: 0,
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
},
],
streaming: true,
callbacks: [
{
handleLLMNewToken: (token) => {
tokens.push(token);
},
},
],
});
const stream = await bedrock.call([new HumanMessage(message)]);
expect(tokens.length).toBeGreaterThan(1);
expect(stream.content).toEqual(tokens.join(""));
});
const stream = await bedrock.call([
new HumanMessage("What is your name and something about yourself?"),
]);
expect(tokens.length).toBeGreaterThan(1);
expect(stream.content).toEqual(tokens.join(""));
});
}

test.skip.each([
"amazon.titan-text-express-v1",
Expand All @@ -90,14 +167,15 @@ test.skip.each([
])("Test Bedrock base chat model: %s", async (model) => {
const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";

const bedrock = new ChatBedrock({
const bedrock = new BedrockChat({
region,
model,
maxRetries: 0,
modelKwargs: {},
credentials: {
secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
},
});

Expand Down
10 changes: 7 additions & 3 deletions langchain/src/llms/bedrock/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export class Bedrock extends LLM implements BaseBedrockInput {
super(fields ?? {});

this.model = fields?.model ?? this.model;
const allowedModels = ["ai21", "anthropic", "amazon", "cohere"];
const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"];
if (!allowedModels.includes(this.model.split(".")[0])) {
throw new Error(
`Unknown model: '${this.model}', only these are supported: ${allowedModels}`
Expand Down Expand Up @@ -238,7 +238,7 @@ export class Bedrock extends LLM implements BaseBedrockInput {
): AsyncGenerator<GenerationChunk> {
const provider = this.model.split(".")[0];
const bedrockMethod =
provider === "anthropic" || provider === "cohere"
provider === "anthropic" || provider === "cohere" || provider === "meta"
? "invoke-with-response-stream"
: "invoke";

Expand All @@ -261,7 +261,11 @@ export class Bedrock extends LLM implements BaseBedrockInput {
);
}

if (provider === "anthropic" || provider === "cohere") {
if (
provider === "anthropic" ||
provider === "cohere" ||
provider === "meta"
) {
const reader = response.body?.getReader();
const decoder = new TextDecoder();
for await (const chunk of this._readChunks(reader)) {
Expand Down
Loading