diff --git a/langchain/src/chat_models/bedrock/web.ts b/langchain/src/chat_models/bedrock/web.ts
index 7db648503ad..9c530adfbdd 100644
--- a/langchain/src/chat_models/bedrock/web.ts
+++ b/langchain/src/chat_models/bedrock/web.ts
@@ -142,7 +142,7 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
     super(fields ?? {});
 
     this.model = fields?.model ?? this.model;
-    const allowedModels = ["ai21", "anthropic", "amazon", "cohere"];
+    const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"];
     if (!allowedModels.includes(this.model.split(".")[0])) {
       throw new Error(
         `Unknown model: '${this.model}', only these are supported: ${allowedModels}`
@@ -300,7 +300,7 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
       this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
 
     const bedrockMethod =
-      provider === "anthropic" || provider === "cohere"
+      provider === "anthropic" || provider === "cohere" || provider === "meta"
         ? "invoke-with-response-stream"
         : "invoke";
 
@@ -318,7 +318,11 @@ export class BedrockChat extends SimpleChatModel implements BaseBedrockInput {
       );
     }
 
-    if (provider === "anthropic" || provider === "cohere") {
+    if (
+      provider === "anthropic" ||
+      provider === "cohere" ||
+      provider === "meta"
+    ) {
       const reader = response.body?.getReader();
       const decoder = new TextDecoder();
       for await (const chunk of this._readChunks(reader)) {
diff --git a/langchain/src/chat_models/tests/chatbedrock.int.test.ts b/langchain/src/chat_models/tests/chatbedrock.int.test.ts
index 8be4f077f43..7ca0fe38c59 100644
--- a/langchain/src/chat_models/tests/chatbedrock.int.test.ts
+++ b/langchain/src/chat_models/tests/chatbedrock.int.test.ts
@@ -2,85 +2,162 @@
 /* eslint-disable @typescript-eslint/no-non-null-assertion */
 
 import { test, expect } from "@jest/globals";
-import { ChatBedrock } from "../bedrock/web.js";
+import { BedrockChat } from "../bedrock/web.js";
 import { HumanMessage } from "../../schema/index.js";
 
-test("Test Bedrock chat model: Claude-v2", async () => {
-  const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
-  const model = "anthropic.claude-v2";
+void testChatModel(
+  "Test Bedrock chat model: Llama2 13B v1",
+  "us-east-1",
+  "meta.llama2-13b-chat-v1",
+  "What is your name?"
+);
+void testChatStreamingModel(
+  "Test Bedrock streaming chat model: Llama2 13B v1",
+  "us-east-1",
+  "meta.llama2-13b-chat-v1",
+  "What is your name and something about yourself?"
+);
 
-  const bedrock = new ChatBedrock({
-    maxTokens: 20,
-    region,
-    model,
-    maxRetries: 0,
-    credentials: {
-      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
-      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
-    },
-  });
+void testChatModel(
+  "Test Bedrock chat model: Claude-v2",
+  "us-east-1",
+  "anthropic.claude-v2",
+  "What is your name?"
+);
+void testChatStreamingModel(
+  "Test Bedrock chat model streaming: Claude-v2",
+  "us-east-1",
+  "anthropic.claude-v2",
+  "What is your name and something about yourself?"
+);
 
-  const res = await bedrock.call([new HumanMessage("What is your name?")]);
-  console.log(res);
-});
+void testChatHandleLLMNewToken(
+  "Test Bedrock chat model HandleLLMNewToken: Claude-v2",
+  "us-east-1",
+  "anthropic.claude-v2",
+  "What is your name and something about yourself?"
+);
+void testChatHandleLLMNewToken(
+  "Test Bedrock chat model HandleLLMNewToken: Llama2 13B v1",
+  "us-east-1",
+  "meta.llama2-13b-chat-v1",
+  "What is your name and something about yourself?"
+);
 
-test("Test Bedrock chat model streaming: Claude-v2", async () => {
-  const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
-  const model = "anthropic.claude-v2";
+/**
+ * Tests a BedrockChat model
+ * @param title The name of the test to run
+ * @param defaultRegion The AWS region to default back to if not set via environment
+ * @param model The model string to test
+ * @param message The prompt text to send to the LLM
+ */
+async function testChatModel(
+  title: string,
+  defaultRegion: string,
+  model: string,
+  message: string
+) {
+  test(title, async () => {
+    const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
 
-  const bedrock = new ChatBedrock({
-    maxTokens: 200,
-    region,
-    model,
-    maxRetries: 0,
-    credentials: {
-      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
-      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
-    },
+    const bedrock = new BedrockChat({
+      maxTokens: 20,
+      region,
+      model,
+      maxRetries: 0,
+      credentials: {
+        secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+        accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+        sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
+      },
+    });
+
+    const res = await bedrock.call([new HumanMessage(message)]);
+    console.log(res);
   });
+}
 
+/**
+ * Tests a BedrockChat model with a streaming response
+ * @param title The name of the test to run
+ * @param defaultRegion The AWS region to default back to if not set via environment
+ * @param model The model string to test
+ * @param message The prompt text to send to the LLM
+ */
+async function testChatStreamingModel(
+  title: string,
+  defaultRegion: string,
+  model: string,
+  message: string
+) {
+  test(title, async () => {
+    const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
 
-  const stream = await bedrock.stream([
-    new HumanMessage({
-      content: "What is your name and something about yourself?",
-    }),
-  ]);
-  const chunks = [];
-  for await (const chunk of stream) {
-    console.log(chunk);
-    chunks.push(chunk);
-  }
-  expect(chunks.length).toBeGreaterThan(1);
-});
+    const bedrock = new BedrockChat({
+      maxTokens: 200,
+      region,
+      model,
+      maxRetries: 0,
+      credentials: {
+        secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+        accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+        sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
+      },
+    });
 
-test("Test Bedrock chat model handleLLMNewToken: Claude-v2", async () => {
-  const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
-  const model = "anthropic.claude-v2";
-  const tokens: string[] = [];
+    const stream = await bedrock.stream([
+      new HumanMessage({
+        content: message,
+      }),
+    ]);
+    const chunks = [];
+    for await (const chunk of stream) {
+      console.log(chunk);
+      chunks.push(chunk);
+    }
+    expect(chunks.length).toBeGreaterThan(1);
+  });
+}
 
+/**
+ * Tests a BedrockChat model with a streaming response using a new token callback
+ * @param title The name of the test to run
+ * @param defaultRegion The AWS region to default back to if not set via environment
+ * @param model The model string to test
+ * @param message The prompt text to send to the LLM
+ */
+async function testChatHandleLLMNewToken(
+  title: string,
+  defaultRegion: string,
+  model: string,
+  message: string
+) {
+  test(title, async () => {
+    const region = process.env.BEDROCK_AWS_REGION ?? defaultRegion;
+    const tokens: string[] = [];
 
-  const bedrock = new ChatBedrock({
-    maxTokens: 200,
-    region,
-    model,
-    maxRetries: 0,
-    credentials: {
-      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
-      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
-    },
-    streaming: true,
-    callbacks: [
-      {
-        handleLLMNewToken: (token) => {
-          tokens.push(token);
-        },
+    const bedrock = new BedrockChat({
+      maxTokens: 200,
+      region,
+      model,
+      maxRetries: 0,
+      credentials: {
+        secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+        accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+        sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
       },
-    ],
+      streaming: true,
+      callbacks: [
+        {
+          handleLLMNewToken: (token) => {
+            tokens.push(token);
+          },
+        },
+      ],
+    });
+    const stream = await bedrock.call([new HumanMessage(message)]);
+    expect(tokens.length).toBeGreaterThan(1);
+    expect(stream.content).toEqual(tokens.join(""));
   });
-  const stream = await bedrock.call([
-    new HumanMessage("What is your name and something about yourself?"),
-  ]);
-  expect(tokens.length).toBeGreaterThan(1);
-  expect(stream.content).toEqual(tokens.join(""));
-});
+}
 
 test.skip.each([
   "amazon.titan-text-express-v1",
@@ -90,7 +167,7 @@ test.skip.each([
 ])("Test Bedrock base chat model: %s", async (model) => {
   const region = process.env.BEDROCK_AWS_REGION ?? "us-east-1";
 
-  const bedrock = new ChatBedrock({
+  const bedrock = new BedrockChat({
     region,
     model,
     maxRetries: 0,
@@ -98,6 +175,7 @@ test.skip.each([
     credentials: {
       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
     },
   });
diff --git a/langchain/src/llms/bedrock/web.ts b/langchain/src/llms/bedrock/web.ts
index be8607f17b2..36f01021641 100644
--- a/langchain/src/llms/bedrock/web.ts
+++ b/langchain/src/llms/bedrock/web.ts
@@ -81,7 +81,7 @@ export class Bedrock extends LLM implements BaseBedrockInput {
     super(fields ?? {});
 
     this.model = fields?.model ?? this.model;
-    const allowedModels = ["ai21", "anthropic", "amazon", "cohere"];
+    const allowedModels = ["ai21", "anthropic", "amazon", "cohere", "meta"];
     if (!allowedModels.includes(this.model.split(".")[0])) {
       throw new Error(
         `Unknown model: '${this.model}', only these are supported: ${allowedModels}`
@@ -238,7 +238,7 @@ export class Bedrock extends LLM implements BaseBedrockInput {
   ): AsyncGenerator<GenerationChunk> {
     const provider = this.model.split(".")[0];
     const bedrockMethod =
-      provider === "anthropic" || provider === "cohere"
+      provider === "anthropic" || provider === "cohere" || provider === "meta"
         ? "invoke-with-response-stream"
         : "invoke";
 
@@ -261,7 +261,11 @@ export class Bedrock extends LLM implements BaseBedrockInput {
       );
     }
 
-    if (provider === "anthropic" || provider === "cohere") {
+    if (
+      provider === "anthropic" ||
+      provider === "cohere" ||
+      provider === "meta"
+    ) {
       const reader = response.body?.getReader();
       const decoder = new TextDecoder();
       for await (const chunk of this._readChunks(reader)) {
diff --git a/langchain/src/llms/tests/bedrock.int.test.ts b/langchain/src/llms/tests/bedrock.int.test.ts
index 0f59adb6267..67858bd8329 100644
--- a/langchain/src/llms/tests/bedrock.int.test.ts
+++ b/langchain/src/llms/tests/bedrock.int.test.ts
@@ -17,6 +17,7 @@ test("Test Bedrock LLM: AI21", async () => {
     credentials: {
       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
     },
   });
 
@@ -26,6 +27,55 @@ test("Test Bedrock LLM: AI21", async () => {
   console.log(res);
 });
 
+test("Test Bedrock LLM: Meta Llama2", async () => {
+  const region = process.env.BEDROCK_AWS_REGION!;
+  const model = "meta.llama2-13b-chat-v1";
+  const prompt = "Human: What is your name?";
+
+  const bedrock = new Bedrock({
+    maxTokens: 20,
+    region,
+    model,
+    maxRetries: 0,
+    credentials: {
+      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
+    },
+  });
+
+  const res = await bedrock.call(prompt);
+  expect(typeof res).toBe("string");
+
+  console.log(res);
+});
+
+test("Test Bedrock LLM streaming: Meta Llama2", async () => {
+  const region = process.env.BEDROCK_AWS_REGION!;
+  const model = "meta.llama2-13b-chat-v1";
+  const prompt = "What is your name?";
+
+  const bedrock = new Bedrock({
+    maxTokens: 20,
+    region,
+    model,
+    maxRetries: 0,
+    credentials: {
+      accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
+      secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
+    },
+  });
+
+  const stream = await bedrock.stream(prompt);
+  const chunks = [];
+  for await (const chunk of stream) {
+    console.log(chunk);
+    chunks.push(chunk);
+  }
+  expect(chunks.length).toBeGreaterThan(1);
+});
+
 test("Test Bedrock LLM: Claude-v2", async () => {
   const region = process.env.BEDROCK_AWS_REGION!;
   const model = "anthropic.claude-v2";
@@ -39,6 +89,7 @@ test("Test Bedrock LLM: Claude-v2", async () => {
     credentials: {
       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
     },
   });
 
@@ -60,6 +111,7 @@ test("Test Bedrock LLM streaming: AI21", async () => {
     credentials: {
       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
     },
   });
 
@@ -86,6 +138,7 @@ test("Test Bedrock LLM handleLLMNewToken: Claude-v2", async () => {
     credentials: {
       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
     },
     streaming: true,
     callbacks: [
@@ -115,6 +168,7 @@ test("Test Bedrock LLM streaming: Claude-v2", async () => {
     credentials: {
       accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
       secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
+      sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN,
     },
   });
diff --git a/langchain/src/util/bedrock.ts b/langchain/src/util/bedrock.ts
index c6222e6a437..82a5d21ca4c 100644
--- a/langchain/src/util/bedrock.ts
+++ b/langchain/src/util/bedrock.ts
@@ -88,6 +88,10 @@ export class BedrockLLMInputOutputAdapter {
       inputBody.maxTokens = maxTokens;
       inputBody.temperature = temperature;
       inputBody.stopSequences = stopSequences;
+    } else if (provider === "meta") {
+      inputBody.prompt = prompt;
+      inputBody.max_gen_len = maxTokens;
+      inputBody.temperature = temperature;
     } else if (provider === "amazon") {
       inputBody.inputText = prompt;
       inputBody.textGenerationConfig = {
@@ -120,6 +124,8 @@ export class BedrockLLMInputOutputAdapter {
       return responseBody?.completions?.[0]?.data?.text ?? "";
     } else if (provider === "cohere") {
       return responseBody?.generations?.[0]?.text ?? responseBody?.text ?? "";
+    } else if (provider === "meta") {
+      return responseBody.generation;
     }
 
     // I haven't been able to get a response with more than one result in it.
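
Usage note (not part of the patch): a minimal sketch of what the new "meta" provider support enables. It assumes the public "langchain/chat_models/bedrock/web" entrypoint mirrors the source path used in the tests, that the snippet runs in an ESM context where top-level await is allowed, and that AWS credentials are set in the environment; the model ID and region are illustrative.

import { BedrockChat } from "langchain/chat_models/bedrock/web";
import { HumanMessage } from "langchain/schema";

const chat = new BedrockChat({
  model: "meta.llama2-13b-chat-v1", // now passes the "meta" allow-list check
  region: process.env.BEDROCK_AWS_REGION ?? "us-east-1",
  maxTokens: 200,
  credentials: {
    accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!,
    secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!,
    sessionToken: process.env.BEDROCK_AWS_SESSION_TOKEN, // optional; the updated tests exercise it
  },
});

// Streaming goes through the "invoke-with-response-stream" Bedrock method,
// which this patch enables for "meta" alongside "anthropic" and "cohere".
// BedrockLLMInputOutputAdapter maps the prompt to { prompt, max_gen_len,
// temperature } and reads completions from responseBody.generation.
const stream = await chat.stream([new HumanMessage("What is your name?")]);
for await (const chunk of stream) {
  console.log(chunk.content);
}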