# [Feature] Implementation of token buffer memory (#3211)
* [feat] Implementation of ConversationTokenBufferMemory
* Added integration test for buffer_token_memory class
* [feat] Implementation of ConversationTokenBufferMemory, a module already implemented in LangChain Python
* Added two test cases (chain test case does not resolve)
* Add token buffer memory implementation: implemented ConversationTokenBufferMemory and added integration test for buffer_token_memory class
* Added new test case
* Added another test case (buffer token memory return messages); deleted chain test case for now
* Prettier formatting
* Prettier format test file
* Update devcontainer.json
* Add documentation for buffer token memory (#5)
* Adds entrypoint, update docs to use an example
* Passthrough all fields to superclass

Co-authored-by: Henry <tran.c.henry@gmail.com>
Co-authored-by: Jerry Dang <jerry.dang668@gmail.com>
Co-authored-by: Jerry Dang <59210998+jerry-dang@users.noreply.github.com>
Co-authored-by: jacoblee93 <jacoblee93@gmail.com>
Parent: c7763b1 · Commit: dba8a9d · 5 changed files, 210 additions, 0 deletions.
**New file: documentation page for `ConversationTokenBufferMemory` (+20 lines)**
# Conversation token buffer memory

This notebook covers how to use `ConversationTokenBufferMemory`. This memory keeps a buffer of recent interactions in memory, and uses token length rather than the number of interactions to determine when to flush interactions.

Below is a basic demonstration of the usage of token buffer memory.

import CodeBlock from "@theme/CodeBlock";
import Example from "@examples/memory/token_buffer.ts";

<CodeBlock language="typescript">{Example}</CodeBlock>

We can also get the history as a list of messages, which is useful if you are using this memory with a `MessagesPlaceholder` in a chat prompt template.

```typescript
const memory = new ConversationTokenBufferMemory({
  llm: model,
  maxTokenLimit: 10,
  returnMessages: true,
});
```
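For readers wiring this into a chat prompt, here is a minimal end-to-end sketch. It is not part of this diff: it assumes `ChatOpenAI`, `ChatPromptTemplate`, `MessagesPlaceholder`, and `ConversationChain` from the same package, and the message contents are hypothetical.

```typescript
import { ChatOpenAI } from "langchain/chat_models/openai";
import { ConversationChain } from "langchain/chains";
import { ChatPromptTemplate, MessagesPlaceholder } from "langchain/prompts";
import { ConversationTokenBufferMemory } from "langchain/memory";

const chatModel = new ChatOpenAI({});

// The placeholder's name must match the memory's memoryKey ("history" by default).
const chatPrompt = ChatPromptTemplate.fromMessages([
  ["system", "You are a helpful assistant."],
  new MessagesPlaceholder("history"),
  ["human", "{input}"],
]);

const chain = new ConversationChain({
  llm: chatModel,
  prompt: chatPrompt,
  memory: new ConversationTokenBufferMemory({
    llm: chatModel,
    maxTokenLimit: 1000,
    returnMessages: true, // MessagesPlaceholder expects message objects, not a string
  }),
});

const res = await chain.call({ input: "Hi! I'm Jim." });
console.log(res);
```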
**New file: the example imported above as `@examples/memory/token_buffer.ts` (+18 lines)**
```typescript
import { OpenAI } from "langchain/llms/openai";
import { ConversationTokenBufferMemory } from "langchain/memory";

const model = new OpenAI({});
const memory = new ConversationTokenBufferMemory({
  llm: model,
  maxTokenLimit: 10,
});

await memory.saveContext({ input: "hi" }, { output: "whats up" });
await memory.saveContext({ input: "not much you" }, { output: "not much" });

const result1 = await memory.loadMemoryVariables({});
console.log(result1);

/*
  { history: 'Human: not much you\nAI: not much' }
*/
```
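Why only the last exchange survives: with `maxTokenLimit: 10`, the buffer string containing both exchanges exceeds the limit, so the older pair is pruned. A rough way to see the counts, using the same method the memory uses internally (exact totals depend on the model's tokenizer, so treat the numbers below as approximate):

```typescript
import { OpenAI } from "langchain/llms/openai";

const model = new OpenAI({});

const both = "Human: hi\nAI: whats up\nHuman: not much you\nAI: not much";
const lastOnly = "Human: not much you\nAI: not much";

console.log(await model.getNumTokens(both)); // exceeds 10, so pruning kicks in
console.log(await model.getNumTokens(lastOnly)); // roughly 10, so this fits and is kept
```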
**New file: `langchain/src/memory/buffer_token_memory.ts` (+115 lines)**
```typescript
import {
  InputValues,
  MemoryVariables,
  getBufferString,
  OutputValues,
} from "./base.js";

import { BaseChatMemory, BaseChatMemoryInput } from "./chat_memory.js";
import { BaseLanguageModel } from "../base_language/index.js";

/**
 * Interface for the input parameters of the
 * `ConversationTokenBufferMemory` class.
 */
export interface ConversationTokenBufferMemoryInput
  extends BaseChatMemoryInput {
  /* Prefix for human messages in the buffer. */
  humanPrefix?: string;

  /* Prefix for AI messages in the buffer. */
  aiPrefix?: string;

  /* The LLM for this instance. */
  llm: BaseLanguageModel;

  /* Memory key for buffer instance. */
  memoryKey?: string;

  /* Maximum number of tokens allowed in the buffer. */
  maxTokenLimit?: number;
}

/**
 * Class that represents a conversation chat memory with a token buffer.
 * It extends the `BaseChatMemory` class and implements the
 * `ConversationTokenBufferMemoryInput` interface.
 */
export class ConversationTokenBufferMemory
  extends BaseChatMemory
  implements ConversationTokenBufferMemoryInput
{
  humanPrefix = "Human";

  aiPrefix = "AI";

  memoryKey = "history";

  maxTokenLimit = 2000; // Default max token limit of 2000, which can be overridden

  llm: BaseLanguageModel;

  constructor(fields: ConversationTokenBufferMemoryInput) {
    super(fields);
    this.llm = fields.llm;
    this.humanPrefix = fields?.humanPrefix ?? this.humanPrefix;
    this.aiPrefix = fields?.aiPrefix ?? this.aiPrefix;
    this.memoryKey = fields?.memoryKey ?? this.memoryKey;
    this.maxTokenLimit = fields?.maxTokenLimit ?? this.maxTokenLimit;
  }

  get memoryKeys() {
    return [this.memoryKey];
  }

  /**
   * Loads the memory variables. It takes an `InputValues` object as a
   * parameter and returns a `Promise` that resolves with a
   * `MemoryVariables` object.
   * @param _values `InputValues` object.
   * @returns A `Promise` that resolves with a `MemoryVariables` object.
   */
  async loadMemoryVariables(_values: InputValues): Promise<MemoryVariables> {
    const messages = await this.chatHistory.getMessages();
    if (this.returnMessages) {
      const result = {
        [this.memoryKey]: messages,
      };
      return result;
    }
    const result = {
      [this.memoryKey]: getBufferString(
        messages,
        this.humanPrefix,
        this.aiPrefix
      ),
    };
    return result;
  }

  /**
   * Saves the context from this conversation to the buffer. If the number
   * of tokens required to store the buffer exceeds `maxTokenLimit`,
   * prune it.
   */
  async saveContext(inputValues: InputValues, outputValues: OutputValues) {
    await super.saveContext(inputValues, outputValues);

    // Prune the buffer if it exceeds the max token limit set for this instance.
    const buffer = await this.chatHistory.getMessages();
    let currBufferLength = await this.llm.getNumTokens(
      getBufferString(buffer, this.humanPrefix, this.aiPrefix)
    );

    if (currBufferLength > this.maxTokenLimit) {
      // Drop the oldest messages until the rendered buffer fits the limit.
      const prunedMemory = [];
      while (currBufferLength > this.maxTokenLimit) {
        prunedMemory.push(buffer.shift());
        currBufferLength = await this.llm.getNumTokens(
          getBufferString(buffer, this.humanPrefix, this.aiPrefix)
        );
      }
    }
  }
}
```
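One design note: `saveContext` prunes by calling `buffer.shift()` on the array returned from `chatHistory.getMessages()`; for the default in-memory `ChatMessageHistory`, this appears to be the live message array, which is what makes pruning stick. Below is a small usage sketch of the overridable fields (values are hypothetical; the exact pruning point depends on the model's tokenizer):

```typescript
import { OpenAI } from "langchain/llms/openai";
import { ConversationTokenBufferMemory } from "langchain/memory";

const memory = new ConversationTokenBufferMemory({
  llm: new OpenAI({}),
  maxTokenLimit: 50,
  humanPrefix: "User", // replaces the default "Human"
  aiPrefix: "Assistant", // replaces the default "AI"
  memoryKey: "chat_history", // replaces the default "history"
});

await memory.saveContext({ input: "hello" }, { output: "hi there" });
console.log(await memory.loadMemoryVariables({}));
// -> { chat_history: "User: hello\nAssistant: hi there" }
```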
**New file: `langchain/src/memory/tests/buffer_token_memory.int.test.ts` (+53 lines)**
```typescript
import { test, expect } from "@jest/globals";
import { OpenAI } from "../../llms/openai.js";
import { ConversationTokenBufferMemory } from "../buffer_token_memory.js";
import { ChatMessageHistory } from "../../stores/message/in_memory.js";
import { HumanMessage, AIMessage } from "../../schema/index.js";

test("Test buffer token memory with LLM", async () => {
  const memory = new ConversationTokenBufferMemory({
    llm: new OpenAI(),
    maxTokenLimit: 10,
  });
  const result1 = await memory.loadMemoryVariables({});
  expect(result1).toStrictEqual({ history: "" });

  await memory.saveContext({ input: "foo" }, { output: "bar" });
  const expectedString = "Human: foo\nAI: bar";
  const result2 = await memory.loadMemoryVariables({});
  expect(result2).toStrictEqual({ history: expectedString });

  await memory.saveContext({ foo: "foo" }, { bar: "bar" });
  await memory.saveContext({ foo: "bar" }, { bar: "foo" });
  const expectedString3 = "Human: bar\nAI: foo";
  const result3 = await memory.loadMemoryVariables({});
  expect(result3).toStrictEqual({ history: expectedString3 });
});

test("Test buffer token memory return messages", async () => {
  const memory = new ConversationTokenBufferMemory({
    llm: new OpenAI(),
    returnMessages: true,
  });
  const result1 = await memory.loadMemoryVariables({});
  expect(result1).toStrictEqual({ history: [] });

  await memory.saveContext({ foo: "bar" }, { bar: "foo" });
  const expectedResult = [new HumanMessage("bar"), new AIMessage("foo")];
  const result2 = await memory.loadMemoryVariables({});
  expect(result2).toStrictEqual({ history: expectedResult });
});

test("Test buffer token memory with pre-loaded history", async () => {
  const pastMessages = [
    new HumanMessage("My name's Jonas"),
    new AIMessage("Nice to meet you, Jonas!"),
  ];
  const memory = new ConversationTokenBufferMemory({
    llm: new OpenAI(),
    returnMessages: true,
    chatHistory: new ChatMessageHistory(pastMessages),
  });
  const result = await memory.loadMemoryVariables({});
  expect(result).toStrictEqual({ history: pastMessages });
});
```
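A note on the `{ foo: "bar" }` style inputs in these tests: no `inputKey`/`outputKey` is configured, and `BaseChatMemory` falls back to the single key present in each values object, so the key names do not have to be `input`/`output`. A sketch with the keys made explicit (hypothetical; it relies on the `inputKey`/`outputKey` fields inherited from `BaseChatMemoryInput`):

```typescript
import { OpenAI } from "../../llms/openai.js";
import { ConversationTokenBufferMemory } from "../buffer_token_memory.js";

const memory = new ConversationTokenBufferMemory({
  llm: new OpenAI(),
  inputKey: "foo", // read from the first argument to saveContext
  outputKey: "bar", // read from the second argument to saveContext
});

await memory.saveContext({ foo: "bar" }, { bar: "foo" });
// Renders as "Human: bar\nAI: foo", matching the default-key behavior above.
```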
Successfully deployed to the following URLs:

* langchainjs-docs (./docs/core_docs/):
  * langchainjs-docs-langchain.vercel.app
  * langchainjs-docs-git-main-langchain.vercel.app
  * langchainjs-docs-ruddy.vercel.app
  * js.langchain.com