Meta Llama2 support for BedrockChat (#3260)
* Add Meta Llama2 support for BedrockChat, and restructure the BedrockChat tests so that additional models are easier to add in the future

* Add Meta compatibility to the Amazon Bedrock LLM class
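
As a quick illustration of what this commit enables, BedrockChat can now be pointed at a Meta model. A minimal sketch (the "langchain/chat_models/bedrock/web" import path and the environment variable names are assumptions carried over from the test file, not part of the diff itself):

import { BedrockChat } from "langchain/chat_models/bedrock/web";
import { HumanMessage } from "langchain/schema";

const chat = new BedrockChat({
  model: "meta.llama2-13b-chat-v1", // the "meta" provider prefix is now accepted
  region: process.env.BEDROCK_AWS_REGION ?? "us-east-1",
  maxTokens: 50,
  credentials: {
    accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
    secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
  },
});

const res = await chat.call([new HumanMessage("What is your name?")]);
console.log(res.content);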
shafkevi committed Nov 15, 2023
1 parent a69a3e4 commit bd87672
Showing 5 changed files with 220 additions and 74 deletions.
10 changes: 7 additions & 3 deletions langchain/src/chat_models/bedrock/web.ts
@@ -142,7 +142,7 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
     super(fields ?? {});
 
     this.model = fields?.model ?? this.model;
-    const allowedModels = ["ai21", "anthropic", "amazon", "cohere"];
+    const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"];
     if (!allowedModels.includes(this.model.split(".")[0])) {
       throw new Error(
         `Unknown model: '${this.model}', only these are supported: ${allowedModels}`
@@ -300,7 +300,7 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
       this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
 
     const bedrockMethod =
-      provider === "anthropic" || provider === "cohere"
+      provider === "anthropic" || provider === "cohere" || provider === "meta"
         ? "invoke-with-response-stream"
         : "invoke";
@@ -318,7 +318,11 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
       );
     }
 
-    if (provider === "anthropic" || provider === "cohere") {
+    if (
+      provider === "anthropic" ||
+      provider === "cohere" ||
+      provider === "meta"
+    ) {
       const reader = response.body?.getReader();
       const decoder = new TextDecoder();
       for await (const chunk of this._readChunks(reader)) {
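
Since "meta" now routes to invoke-with-response-stream, token streaming works for Llama2 the same way it already did for Anthropic and Cohere models. A short sketch, reusing the chat instance from the example above:

const stream = await chat.stream([
  new HumanMessage("Tell me something about yourself."),
]);
for await (const chunk of stream) {
  // Each chunk is a message chunk whose content holds the newly decoded tokens.
  process.stdout.write(chunk.content);
}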
214 changes: 146 additions & 68 deletions langchain/src/chat_models/tests/chatbedrock.int.test.ts
@@ -2,85 +2,162 @@
 /* eslint-disable @typescript-eslint/no-non-null-assertion */
 
 import { test, expect } from "@jest/globals";
-import { ChatBedrock } from "../bedrock/web.js";
+import { BedrockChat } from "../bedrock/web.js";
 import { HumanMessage } from "../../schema/index.js";
 
-test("Test Bedrock chat model: Claude-v2", async () => {
-  const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
-  const model = "anthropic.claude-v2";
-
-  const bedrock = new ChatBedrock({
-    maxTokens: 20,
-    region,
-    model,
-    maxRetries: 0,
-    credentials: {
-      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
-      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
-    },
-  });
-
-  const res = await bedrock.call([new HumanMessage("What is your name?")]);
-  console.log(res);
-});
+void testChatModel(
+  "Test Bedrock chat model: Llama2 13B v1",
+  "us-east-1",
+  "meta.llama2-13b-chat-v1",
+  "What is your name?"
+);
+void testChatStreamingModel(
+  "Test Bedrock streaming chat model: Llama2 13B v1",
+  "us-east-1",
+  "meta.llama2-13b-chat-v1",
+  "What is your name and something about yourself?"
+);
+
+void testChatModel(
+  "Test Bedrock chat model: Claude-v2",
+  "us-east-1",
+  "anthropic.claude-v2",
+  "What is your name?"
+);
+void testChatStreamingModel(
+  "Test Bedrock chat model streaming: Claude-v2",
+  "us-east-1",
+  "anthropic.claude-v2",
+  "What is your name and something about yourself?"
+);
+
+void testChatHandleLLMNewToken(
+  "Test Bedrock chat model HandleLLMNewToken: Claude-v2",
+  "us-east-1",
+  "anthropic.claude-v2",
+  "What is your name and something about yourself?"
+);
+void testChatHandleLLMNewToken(
+  "Test Bedrock chat model HandleLLMNewToken: Llama2 13B v1",
+  "us-east-1",
+  "meta.llama2-13b-chat-v1",
+  "What is your name and something about yourself?"
+);
 
-test("Test Bedrock chat model streaming: Claude-v2", async () => {
-  const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
-  const model = "anthropic.claude-v2";
-
-  const bedrock = new ChatBedrock({
-    maxTokens: 200,
-    region,
-    model,
-    maxRetries: 0,
-    credentials: {
-      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
-      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
-    },
-  });
-
-  const stream = await bedrock.stream([
-    new HumanMessage({
-      content: "What is your name and something about yourself?",
-    }),
-  ]);
-  const chunks = [];
-  for await (const chunk of stream) {
-    console.log(chunk);
-    chunks.push(chunk);
-  }
-  expect(chunks.length).toBeGreaterThan(1);
-});
+/**
+ * Tests a BedrockChat model
+ * @param title The name of the test to run
+ * @param defaultRegion The AWS region to default back to if not set via environment
+ * @param model The model string to test
+ * @param message The prompt text to send to the LLM
+ */
+async function testChatModel(
+  title: string,
+  defaultRegion: string,
+  model: string,
+  message: string
+) {
+  test(title, async () => {
+    const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
+
+    const bedrock = new BedrockChat({
+      maxTokens: 20,
+      region,
+      model,
+      maxRetries: 0,
+      credentials: {
+        secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+        accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+        sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
+      },
+    });
+
+    const res = await bedrock.call([new HumanMessage(message)]);
+    console.log(res);
+  });
+}
 
-test("Test Bedrock chat model handleLLMNewToken: Claude-v2", async () => {
-  const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
-  const model = "anthropic.claude-v2";
-  const tokens: string[] = [];
-
-  const bedrock = new ChatBedrock({
-    maxTokens: 200,
-    region,
-    model,
-    maxRetries: 0,
-    credentials: {
-      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
-      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
-    },
-    streaming: true,
-    callbacks: [
-      {
-        handleLLMNewToken: (token) => {
-          tokens.push(token);
-        },
-      },
-    ],
-  });
-  const stream = await bedrock.call([
-    new HumanMessage("What is your name and something about yourself?"),
-  ]);
-  expect(tokens.length).toBeGreaterThan(1);
-  expect(stream.content).toEqual(tokens.join(""));
-});
+/**
+ * Tests a BedrockChat model with a streaming response
+ * @param title The name of the test to run
+ * @param defaultRegion The AWS region to default back to if not set via environment
+ * @param model The model string to test
+ * @param message The prompt text to send to the LLM
+ */
+async function testChatStreamingModel(
+  title: string,
+  defaultRegion: string,
+  model: string,
+  message: string
+) {
+  test(title, async () => {
+    const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
+
+    const bedrock = new BedrockChat({
+      maxTokens: 200,
+      region,
+      model,
+      maxRetries: 0,
+      credentials: {
+        secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+        accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+        sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
+      },
+    });
+
+    const stream = await bedrock.stream([
+      new HumanMessage({
+        content: message,
+      }),
+    ]);
+    const chunks = [];
+    for await (const chunk of stream) {
+      console.log(chunk);
+      chunks.push(chunk);
+    }
+    expect(chunks.length).toBeGreaterThan(1);
+  });
+}
+/**
+ * Tests a BedrockChat model with a streaming response using a new token callback
+ * @param title The name of the test to run
+ * @param defaultRegion The AWS region to default back to if not set via environment
+ * @param model The model string to test
+ * @param message The prompt text to send to the LLM
+ */
+async function testChatHandleLLMNewToken(
+  title: string,
+  defaultRegion: string,
+  model: string,
+  message: string
+) {
+  test(title, async () => {
+    const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
+    const tokens: string[] = [];
+
+    const bedrock = new BedrockChat({
+      maxTokens: 200,
+      region,
+      model,
+      maxRetries: 0,
+      credentials: {
+        secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+        accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+        sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
+      },
+      streaming: true,
+      callbacks: [
+        {
+          handleLLMNewToken: (token) => {
+            tokens.push(token);
+          },
+        },
+      ],
+    });
+    const stream = await bedrock.call([new HumanMessage(message)]);
+    expect(tokens.length).toBeGreaterThan(1);
+    expect(stream.content).toEqual(tokens.join(""));
+  });
+}
 
 test.skip.each([
   "amazon.titan-text-express-v1",
@@ -90,14 +167,15 @@ test.skip.each([
 ])("Test Bedrock base chat model: %s", async (model) => {
   const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
 
-  const bedrock = new ChatBedrock({
+  const bedrock = new BedrockChat({
     region,
     model,
     maxRetries: 0,
     modelKwargs: {},
     credentials: {
       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
     },
   });

10 changes: 7 additions & 3 deletions langchain/src/llms/bedrock/web.ts
@@ -81,7 +81,7 @@ export class Bedrock extends LLM implements BaseBedrockInput {
     super(fields ?? {});
 
     this.model = fields?.model ?? this.model;
-    const allowedModels = ["ai21", "anthropic", "amazon", "cohere"];
+    const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"];
     if (!allowedModels.includes(this.model.split(".")[0])) {
       throw new Error(
         `Unknown model: '${this.model}', only these are supported: ${allowedModels}`
@@ -238,7 +238,7 @@ export class Bedrock extends LLM implements BaseBedrockInput {
   ): AsyncGenerator<GenerationChunk> {
     const provider = this.model.split(".")[0];
     const bedrockMethod =
-      provider === "anthropic" || provider === "cohere"
+      provider === "anthropic" || provider === "cohere" || provider === "meta"
         ? "invoke-with-response-stream"
         : "invoke";

@@ -261,7 +261,11 @@ export class Bedrock extends LLM implements BaseBedrockInput {
       );
     }
 
-    if (provider === "anthropic" || provider === "cohere") {
+    if (
+      provider === "anthropic" ||
+      provider === "cohere" ||
+      provider === "meta"
+    ) {
       const reader = response.body?.getReader();
       const decoder = new TextDecoder();
       for await (const chunk of this._readChunks(reader)) {
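
The same three changes applied to the completion-style Bedrock LLM mean Meta models work there too. A sketch under the same assumptions as the chat example above (the "langchain/llms/bedrock/web" import path is inferred from the file layout, not stated in the diff):

import { Bedrock } from "langchain/llms/bedrock/web";

const llm = new Bedrock({
  model: "meta.llama2-13b-chat-v1", // validated against the extended allowedModels list
  region: process.env.BEDROCK_AWS_REGION ?? "us-east-1",
  credentials: {
    accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
    secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
  },
});

console.log(await llm.call("What is your name?"));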
