Skip to content

Commit

Permalink
feat(agents): Support LLMChain in OpenAIFunctionsAgent and memory
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz committed Aug 5, 2023
1 parent edf4ea5 commit c5b0a0a
Show file tree
Hide file tree
Showing 13 changed files with 256 additions and 93 deletions.
3 changes: 3 additions & 0 deletions packages/langchain/lib/src/agents/agent.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ abstract class BaseActionAgent {
/// {@macro base_action_agent}
const BaseActionAgent();

/// The key for the scratchpad (intermediate steps) of the agent.
static const agentScratchpadInputKey = 'agent_scratchpad';

/// Return key for the agent's output.
static const agentReturnKey = 'output';

Expand Down
6 changes: 6 additions & 0 deletions packages/langchain/lib/src/agents/executors.dart
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ import 'tools/invalid.dart';
/// retrieves the output, and passes it back to the agent to determine the next
/// action. This process continues until the agent determines it can directly
/// respond to the user or completes its task.
///
/// If you add [memory] to the [AgentExecutor], it will save the
/// [AgentExecutor]'s inputs and outputs. It won't save the agent's
/// intermediate inputs and outputs. If you want to save the agent's
/// intermediate inputs and outputs, you should add [memory] to the agent
/// instead.
/// {@endtemplate}
class AgentExecutor extends BaseChain {
AgentExecutor({
Expand Down
10 changes: 10 additions & 0 deletions packages/langchain/lib/src/agents/tools/base.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import 'dart:async';

import 'package:meta/meta.dart';

import '../../model_io/chat_models/models/models.dart';
import 'models/models.dart';

/// {@template base_tool}
Expand Down Expand Up @@ -104,6 +105,15 @@ abstract base class BaseTool {
return run(toolInput);
}

/// Converts the tool to a [ChatFunction].
ChatFunction toChatFunction() {
return ChatFunction(
name: name,
description: description,
parameters: inputJsonSchema,
);
}

@override
bool operator ==(covariant final BaseTool other) =>
identical(this, other) || name == other.name;
Expand Down
17 changes: 14 additions & 3 deletions packages/langchain/lib/src/memory/chat.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import 'dart:async';

import '../model_io/chat_models/models/models.dart';
import '../utils/exception.dart';
import 'base.dart';
import 'models/models.dart';
Expand Down Expand Up @@ -46,11 +47,21 @@ abstract base class BaseChatMemory implements BaseMemory {
}) async {
// this is purposefully done in sequence so they're saved in order
final (input, output) = _getInputOutputValues(inputValues, outputValues);
await chatHistory.addUserChatMessage(input);
await chatHistory.addAIChatMessage(output);

if (input is ChatMessage) {
await chatHistory.addChatMessage(input);
} else {
await chatHistory.addHumanChatMessage(input.toString());
}

if (output is ChatMessage) {
await chatHistory.addChatMessage(output);
} else {
await chatHistory.addAIChatMessage(output.toString());
}
}

(String input, String output) _getInputOutputValues(
(dynamic input, dynamic output) _getInputOutputValues(
final MemoryInputValues inputValues,
final MemoryOutputValues outputValues,
) {
Expand Down
4 changes: 2 additions & 2 deletions packages/langchain/lib/src/memory/stores/message/history.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ abstract base class BaseChatMessageHistory {
/// Add [ChatMessage] to the history.
Future<void> addChatMessage(final ChatMessage message);

/// Add a user message to the history.
Future<void> addUserChatMessage(final String message) {
/// Add a human message to the history.
Future<void> addHumanChatMessage(final String message) {
return addChatMessage(ChatMessage.human(message));
}

Expand Down
15 changes: 10 additions & 5 deletions packages/langchain/lib/src/memory/utils.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import '../agents/agent.dart';
import '../utils/exception.dart';
import 'models/models.dart';

Expand All @@ -10,13 +11,17 @@ String getPromptInputKey(
final MemoryInputValues inputValues,
final Set<String> memoryKeys,
) {
// "stop" is a special key that can be passed as input but is not used to
// format the prompt
final promptInputKeys =
inputValues.keys.toSet().difference({...memoryKeys, 'stop'});
// Reserved keys can be passed as input but is not used to format the prompt
final promptInputKeys = inputValues.keys.toSet().difference({
...memoryKeys,
'stop',
BaseActionAgent.agentScratchpadInputKey,
});
if (promptInputKeys.length != 1) {
throw LangChainException(
message: 'One input key expected got $promptInputKeys',
message: 'One input key expected got $promptInputKeys. '
'If you have multiple input keys in your prompt you need to specify '
'the input key to use for the memory using the `inputKey` parameter.',
);
}
return promptInputKeys.first;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ CustomChatMessage{
/// Role of a chat message
enum ChatMessageRole { system, human, ai, custom }

/// {@template openai_function_model}
/// {@template chat_function}
/// The description of a function that can be called by the chat model.
/// {@endtemplate
@immutable
class ChatFunction {
/// {@macro openai_function_model}
/// {@macro chat_function}
const ChatFunction({
required this.name,
this.description,
Expand Down
90 changes: 83 additions & 7 deletions packages/langchain/test/memory/buffer_test.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import 'package:langchain/src/agents/agent.dart';
import 'package:langchain/src/memory/memory.dart';
import 'package:langchain/src/model_io/chat_models/chat_models.dart';
import 'package:test/test.dart';
Expand All @@ -7,21 +8,21 @@ void main() {
test('Test buffer memory', () async {
final memory = ConversationBufferMemory();
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': ''});
expect(result1, {BaseMemory.defaultMemoryKey: ''});

await memory.saveContext(
inputValues: {'foo': 'bar'},
outputValues: {'bar': 'foo'},
);
const expectedString = 'Human: bar\nAI: foo';
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': expectedString});
expect(result2, {BaseMemory.defaultMemoryKey: expectedString});
});

test('Test buffer memory return messages', () async {
final memory = ConversationBufferMemory(returnMessages: true);
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': <ChatMessage>[]});
expect(result1, {BaseMemory.defaultMemoryKey: <ChatMessage>[]});

await memory.saveContext(
inputValues: {'foo': 'bar'},
Expand All @@ -32,7 +33,24 @@ void main() {
ChatMessage.ai('foo'),
];
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': expectedResult});
expect(result2, {BaseMemory.defaultMemoryKey: expectedResult});
});

test('Test chat message as input and output', () async {
final memory = ConversationBufferMemory(returnMessages: true);
final result1 = await memory.loadMemoryVariables();
expect(result1, {BaseMemory.defaultMemoryKey: <ChatMessage>[]});

await memory.saveContext(
inputValues: {'foo': ChatMessage.function(name: 'foo', content: 'bar')},
outputValues: {'bar': ChatMessage.ai('baz')},
);
final expectedResult = [
ChatMessage.function(name: 'foo', content: 'bar'),
ChatMessage.ai('baz'),
];
final result2 = await memory.loadMemoryVariables();
expect(result2, {BaseMemory.defaultMemoryKey: expectedResult});
});

test('Test buffer memory with pre-loaded history', () async {
Expand All @@ -45,7 +63,7 @@ void main() {
chatHistory: ChatMessageHistory(messages: pastMessages),
);
final result = await memory.loadMemoryVariables();
expect(result, {'history': pastMessages});
expect(result, {BaseMemory.defaultMemoryKey: pastMessages});
});

test('Test clear memory', () async {
Expand All @@ -56,11 +74,69 @@ void main() {
);
const expectedString = 'Human: bar\nAI: foo';
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': expectedString});
expect(result1, {BaseMemory.defaultMemoryKey: expectedString});

memory.clear();
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': ''});
expect(result2, {BaseMemory.defaultMemoryKey: ''});
});

test('Test reserved keys are ignored when selecting prompt input keys',
() async {
final memory = ConversationBufferMemory(returnMessages: true);
await memory.saveContext(
inputValues: {
'foo': 'bar',
'stop': 'stop',
BaseActionAgent.agentScratchpadInputKey: 'baz',
},
outputValues: {'bar': 'foo'},
);
final expectedResult = [
ChatMessage.human('bar'),
ChatMessage.ai('foo'),
];
final result1 = await memory.loadMemoryVariables();
expect(result1, {BaseMemory.defaultMemoryKey: expectedResult});
});

test('Test multiple input values with inputKey specified', () async {
final memory = ConversationBufferMemory(
returnMessages: true,
inputKey: 'foo2',
);
await memory.saveContext(
inputValues: {
'foo1': 'bar1',
'foo2': 'bar2',
BaseActionAgent.agentScratchpadInputKey: 'baz',
},
outputValues: {'bar': 'foo'},
);
final expectedResult = [
ChatMessage.human('bar2'),
ChatMessage.ai('foo'),
];
final result1 = await memory.loadMemoryVariables();
expect(result1, {BaseMemory.defaultMemoryKey: expectedResult});
});

test(
'Test error is thrown if inputKey not specified when using with '
'multiple input values', () async {
final memory = ConversationBufferMemory(returnMessages: true);

// expect throws exception if no input keys are selected
expect(
() async => memory.saveContext(
inputValues: {
'foo1': 'bar1',
'foo2': 'bar2',
},
outputValues: {'bar': 'foo'},
),
throwsException,
);
});
});
}
14 changes: 7 additions & 7 deletions packages/langchain/test/memory/buffer_window_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ void main() {
);
const expectedString = 'Human: bar\nAI: foo';
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': expectedString});
expect(result2, {BaseMemory.defaultMemoryKey: expectedString});
});

test('Test buffer memory return messages', () async {
final memory = ConversationBufferWindowMemory(k: 1, returnMessages: true);
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': <ChatMessage>[]});
expect(result1, {BaseMemory.defaultMemoryKey: <ChatMessage>[]});

await memory.saveContext(
inputValues: {'foo': 'bar'},
Expand All @@ -32,7 +32,7 @@ void main() {
ChatMessage.ai('foo'),
];
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': expectedResult});
expect(result2, {BaseMemory.defaultMemoryKey: expectedResult});

await memory.saveContext(
inputValues: {'foo': 'bar1'},
Expand All @@ -44,7 +44,7 @@ void main() {
ChatMessage.ai('foo1'),
];
final result3 = await memory.loadMemoryVariables();
expect(result3, {'history': expectedResult2});
expect(result3, {BaseMemory.defaultMemoryKey: expectedResult2});
});

test('Test buffer memory with pre-loaded history', () async {
Expand All @@ -57,7 +57,7 @@ void main() {
chatHistory: ChatMessageHistory(messages: pastMessages),
);
final result = await memory.loadMemoryVariables();
expect(result, {'history': pastMessages});
expect(result, {BaseMemory.defaultMemoryKey: pastMessages});
});

test('Test clear memory', () async {
Expand All @@ -68,11 +68,11 @@ void main() {
);
const expectedString = 'Human: bar\nAI: foo';
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': expectedString});
expect(result1, {BaseMemory.defaultMemoryKey: expectedString});

memory.clear();
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': ''});
expect(result2, {BaseMemory.defaultMemoryKey: ''});
});
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void main() {

test('Test addUserMessage', () async {
final history = ChatMessageHistory()
..addUserChatMessage('This is a human msg');
..addHumanChatMessage('This is a human msg');
final messages = await history.getChatMessages();
expect(messages.first, isA<HumanChatMessage>());
expect(messages.first.content, 'This is a human msg');
Expand Down
14 changes: 7 additions & 7 deletions packages/langchain/test/memory/token_buffer_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void main() {
);
const expectedString = 'Human: bar\nAI: foo';
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': expectedString});
expect(result2, {BaseMemory.defaultMemoryKey: expectedString});
});

test('Test buffer memory return messages', () async {
Expand All @@ -29,7 +29,7 @@ void main() {
maxTokenLimit: 4,
);
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': <ChatMessage>[]});
expect(result1, {BaseMemory.defaultMemoryKey: <ChatMessage>[]});

await memory.saveContext(
inputValues: {'foo': 'bar'},
Expand All @@ -40,7 +40,7 @@ void main() {
ChatMessage.ai('foo'),
];
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': expectedResult});
expect(result2, {BaseMemory.defaultMemoryKey: expectedResult});

await memory.saveContext(
inputValues: {'foo': 'bar1'},
Expand All @@ -53,7 +53,7 @@ void main() {
ChatMessage.ai('foo1'),
];
final result3 = await memory.loadMemoryVariables();
expect(result3, {'history': expectedResult2});
expect(result3, {BaseMemory.defaultMemoryKey: expectedResult2});
});

test('Test buffer memory with pre-loaded history', () async {
Expand All @@ -69,7 +69,7 @@ void main() {
chatHistory: ChatMessageHistory(messages: pastMessages),
);
final result = await memory.loadMemoryVariables();
expect(result, {'history': pastMessages});
expect(result, {BaseMemory.defaultMemoryKey: pastMessages});
});

test('Test clear memory', () async {
Expand All @@ -80,11 +80,11 @@ void main() {
);
const expectedString = 'Human: bar\nAI: foo';
final result1 = await memory.loadMemoryVariables();
expect(result1, {'history': expectedString});
expect(result1, {BaseMemory.defaultMemoryKey: expectedString});

memory.clear();
final result2 = await memory.loadMemoryVariables();
expect(result2, {'history': ''});
expect(result2, {BaseMemory.defaultMemoryKey: ''});
});
});
}
Loading

0 comments on commit c5b0a0a

Please sign in to comment.