refactor!: Improve Tool abstractions (#398)
davidmigloz committed May 2, 2024
1 parent 7647248 commit 2a50aec
Showing 32 changed files with 527 additions and 387 deletions.
78 changes: 56 additions & 22 deletions docs/modules/agents/agent_types/openai_functions_agent.md
@@ -15,8 +15,8 @@ The OpenAI Functions Agent is designed to work with these models.
```dart
final llm = ChatOpenAI(
apiKey: openaiApiKey,
defaultOptions: const ChatOpenAIOptions(
model: 'gpt-4',
defaultOptions: ChatOpenAIOptions(
model: 'gpt-4-turbo',
temperature: 0,
),
);
@@ -29,15 +29,32 @@ print(res); // -> '40 raised to the power of 0.43 is approximately 4.8852'

You can easily call your own functions by wrapping them in a `Tool`. You can also add memory to the agent by passing it when creating the agent.

Let's see an example of how to do this:
Let's see an example of how to do this.

First, let's create a class that will serve as the input for our tool.

```dart
final llm = ChatOpenAI(
apiKey: openaiApiKey,
defaultOptions: const ChatOpenAIOptions(temperature: 0),
);
class SearchInput {
const SearchInput({
required this.query,
required this.n,
});
final String query;
final int n;
SearchInput.fromJson(final Map<String, dynamic> json)
: this(
query: json['query'] as String,
n: json['n'] as int,
);
}
```
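As a standalone sanity check (not part of this commit), the `fromJson` constructor can be exercised directly with the same JSON shape the model sends as tool input. The `'cats'`/`3` values are made-up sample inputs; the class itself is copied verbatim from the snippet above.

```dart
// Sketch: exercise SearchInput.fromJson outside the agent, using the
// same JSON shape the model would send to the tool.
class SearchInput {
  const SearchInput({
    required this.query,
    required this.n,
  });

  final String query;
  final int n;

  // Redirecting named constructor that parses the model's JSON input.
  SearchInput.fromJson(final Map<String, dynamic> json)
      : this(
          query: json['query'] as String,
          n: json['n'] as int,
        );
}

void main() {
  final input = SearchInput.fromJson({'query': 'cats', 'n': 3});
  print('${input.query} ${input.n}'); // prints: cats 3
}
```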

Now let's define the tool:

final tool = BaseTool.fromFunction(
```dart
final tool = Tool.fromFunction<SearchInput, String>(
name: 'search',
description: 'Tool for searching the web.',
inputJsonSchema: const {
@@ -54,21 +71,32 @@ '},
},
'required': ['query'],
},
func: (
final Map<String, dynamic> toolInput, {
final ToolOptions? options,
}) async {
final query = toolInput['query'];
final n = toolInput['n'];
return callYourSearchFunction(query, n);
},
func: callYourSearchFunction,
getInputFromJson: SearchInput.fromJson,
);
```

Notice that we need to provide a function that converts the JSON input sent by the model into the input class we defined.

The tool will call the `callYourSearchFunction` function with the parsed input. For simplicity, we will just mock the search function.

```dart
String callYourSearchFunction(final SearchInput input) {
return 'Results:\n${List<String>.generate(input.n, (final i) => 'Result ${i + 1}').join('\n')}';
}
```
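Run on its own, the mock is deterministic: an input with `n = 3` always yields three numbered placeholder results. The sketch below (not part of the doc) pairs the mock with a minimal copy of `SearchInput`; the `'cats'` query is a made-up sample input.

```dart
// Sketch: the mocked search function from the doc, called directly.
class SearchInput {
  const SearchInput({required this.query, required this.n});

  final String query;
  final int n;
}

// Copied from the doc: returns n numbered placeholder results.
String callYourSearchFunction(final SearchInput input) {
  return 'Results:\n${List<String>.generate(input.n, (final i) => 'Result ${i + 1}').join('\n')}';
}

void main() {
  print(callYourSearchFunction(const SearchInput(query: 'cats', n: 3)));
  // Results:
  // Result 1
  // Result 2
  // Result 3
}
```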

Now we can create the agent and run it.

```dart
final llm = ChatOpenAI(
apiKey: openaiApiKey,
defaultOptions: const ChatOpenAIOptions(temperature: 0),
);
final tools = [tool];
final memory = ConversationBufferMemory(returnMessages: true);
final agent = OpenAIFunctionsAgent.fromLLMAndTools(
llm: llm,
tools: tools,
tools: [tool],
memory: memory,
);
@@ -77,6 +105,11 @@ final executor = AgentExecutor(agent: agent);
final res1 = await executor.run(
'Search for cats. Return only 3 results.',
);
print(res1);
// Here are 3 search results for "cats":
// 1. Result 1
// 2. Result 2
// 3. Result 3
```

## Using LangChain Expression Language (LCEL)
@@ -86,10 +119,10 @@ You can replicate the functionality of the OpenAI Functions Agent by using the LangChain Expression Language (LCEL).
```dart
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final prompt = ChatPromptTemplate.fromPromptMessages([
SystemChatMessagePromptTemplate.fromTemplate('You are a helpful assistant'),
HumanChatMessagePromptTemplate.fromTemplate('{input}'),
const MessagesPlaceholder(variableName: 'agent_scratchpad'),
final prompt = ChatPromptTemplate.fromTemplates(const [
(ChatMessageType.system, 'You are a helpful assistant'),
(ChatMessageType.human, '{input}'),
(ChatMessageType.messagesPlaceholder, 'agent_scratchpad'),
]);
final tool = CalculatorTool();
@@ -134,6 +167,7 @@ final res = await executor.invoke({
'input': 'What is 40 raised to the 0.43 power?',
});
print(res['output']);
// 40 raised to the power of 0.43 is approximately 4.88524.
```

In this way, you can create your own custom agents with full control over their behavior.
4 changes: 2 additions & 2 deletions docs/modules/agents/tools/openai_dall_e.md
@@ -10,11 +10,11 @@ Example:
final llm = ChatOpenAI(
apiKey: openAiKey,
defaultOptions: const ChatOpenAIOptions(
model: 'gpt-4',
model: 'gpt-4-turbo',
temperature: 0,
),
);
final tools = <BaseTool>[
final tools = <Tool>[
CalculatorTool(),
OpenAIDallETool(apiKey: openAiKey),
];
10 changes: 5 additions & 5 deletions examples/browser_summarizer/pubspec.lock
@@ -225,28 +225,28 @@ packages:
path: "../../packages/langchain"
relative: true
source: path
version: "0.5.0+1"
version: "0.6.0+1"
langchain_community:
dependency: "direct main"
description:
path: "../../packages/langchain_community"
relative: true
source: path
version: "0.1.0"
version: "0.1.0+2"
langchain_core:
dependency: "direct overridden"
description:
path: "../../packages/langchain_core"
relative: true
source: path
version: "0.1.0"
version: "0.2.0+1"
langchain_openai:
dependency: "direct main"
description:
path: "../../packages/langchain_openai"
relative: true
source: path
version: "0.5.0+1"
version: "0.5.1+1"
langchain_tiktoken:
dependency: transitive
description:
@@ -309,7 +309,7 @@ packages:
path: "../../packages/openai_dart"
relative: true
source: path
version: "0.2.1"
version: "0.2.2"
path:
dependency: transitive
description:
@@ -1,4 +1,4 @@
// ignore_for_file: avoid_print
// ignore_for_file: avoid_print, unreachable_from_main
import 'dart:io';

import 'package:langchain/langchain.dart';
@@ -16,7 +16,7 @@ Future<void> _openaiFunctionsAgent() async {
final llm = ChatOpenAI(
apiKey: openaiApiKey,
defaultOptions: const ChatOpenAIOptions(
model: 'gpt-4',
model: 'gpt-4-turbo',
temperature: 0,
),
);
@@ -29,12 +29,8 @@ Future<void> _openaiFunctionsAgentCustomToolsMemory() async {

Future<void> _openaiFunctionsAgentCustomToolsMemory() async {
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final llm = ChatOpenAI(
apiKey: openaiApiKey,
defaultOptions: const ChatOpenAIOptions(temperature: 0),
);

final tool = BaseTool.fromFunction(
final tool = Tool.fromFunction<SearchInput, String>(
name: 'search',
description: 'Tool for searching the web.',
inputJsonSchema: const {
@@ -51,21 +47,19 @@ Future<void> _openaiFunctionsAgentCustomToolsMemory() async {
},
'required': ['query'],
},
func: (
final Map<String, dynamic> toolInput, {
final ToolOptions? options,
}) async {
final query = toolInput['query'];
final n = toolInput['n'];
return callYourSearchFunction(query, n);
},
func: callYourSearchFunction,
getInputFromJson: SearchInput.fromJson,
);

final llm = ChatOpenAI(
apiKey: openaiApiKey,
defaultOptions: const ChatOpenAIOptions(temperature: 0),
);
final tools = [tool];

final memory = ConversationBufferMemory(returnMessages: true);
final agent = OpenAIFunctionsAgent.fromLLMAndTools(
llm: llm,
tools: tools,
tools: [tool],
memory: memory,
);

@@ -75,21 +69,39 @@ Future<void> _openaiFunctionsAgentCustomToolsMemory() async {
'Search for cats. Return only 3 results.',
);
print(res1);
// Here are 3 search results for "cats":
// 1. Result 1
// 2. Result 2
// 3. Result 3
}

class SearchInput {
const SearchInput({
required this.query,
required this.n,
});

final String query;
final int n;

SearchInput.fromJson(final Map<String, dynamic> json)
: this(
query: json['query'] as String,
n: json['n'] as int,
);
}

String callYourSearchFunction(final String query, final int n) {
return 'Results:\n${List<String>.generate(n, (final i) => 'Result ${i + 1}').join('\n')}';
String callYourSearchFunction(final SearchInput input) {
return 'Results:\n${List<String>.generate(input.n, (final i) => 'Result ${i + 1}').join('\n')}';
}

Future<void> _openaiFunctionsAgentLCEL() async {
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];

final prompt = ChatPromptTemplate.fromPromptMessages([
SystemChatMessagePromptTemplate.fromTemplate(
'You are a helpful assistant',
),
HumanChatMessagePromptTemplate.fromTemplate('{input}'),
const MessagesPlaceholder(variableName: 'agent_scratchpad'),
final prompt = ChatPromptTemplate.fromTemplates(const [
(ChatMessageType.system, 'You are a helpful assistant'),
(ChatMessageType.human, '{input}'),
(ChatMessageType.messagesPlaceholder, 'agent_scratchpad'),
]);

final tool = CalculatorTool();
@@ -134,4 +146,5 @@ Future<void> _openaiFunctionsAgentLCEL() async {
'input': 'What is 40 raised to the 0.43 power?',
});
print(res['output']);
// 40 raised to the power of 0.43 is approximately 4.88524.
}
@@ -10,11 +10,11 @@ void main() async {
final llm = ChatOpenAI(
apiKey: openAiKey,
defaultOptions: const ChatOpenAIOptions(
model: 'gpt-4',
model: 'gpt-4-turbo',
temperature: 0,
),
);
final tools = <BaseTool>[
final tools = <Tool>[
CalculatorTool(),
OpenAIDallETool(apiKey: openAiKey),
];
Expand Down
18 changes: 9 additions & 9 deletions examples/docs_examples/pubspec.lock
@@ -221,56 +221,56 @@ packages:
path: "../../packages/langchain"
relative: true
source: path
version: "0.5.0+1"
version: "0.6.0+1"
langchain_chroma:
dependency: "direct main"
description:
path: "../../packages/langchain_chroma"
relative: true
source: path
version: "0.2.0"
version: "0.2.0+2"
langchain_community:
dependency: "direct main"
description:
path: "../../packages/langchain_community"
relative: true
source: path
version: "0.1.0"
version: "0.1.0+2"
langchain_core:
dependency: "direct overridden"
description:
path: "../../packages/langchain_core"
relative: true
source: path
version: "0.1.0"
version: "0.2.0+1"
langchain_google:
dependency: "direct main"
description:
path: "../../packages/langchain_google"
relative: true
source: path
version: "0.3.0"
version: "0.3.0+2"
langchain_mistralai:
dependency: "direct main"
description:
path: "../../packages/langchain_mistralai"
relative: true
source: path
version: "0.1.0"
version: "0.1.0+2"
langchain_ollama:
dependency: "direct main"
description:
path: "../../packages/langchain_ollama"
relative: true
source: path
version: "0.1.0"
version: "0.1.0+2"
langchain_openai:
dependency: "direct main"
description:
path: "../../packages/langchain_openai"
relative: true
source: path
version: "0.5.0+1"
version: "0.5.1+1"
langchain_tiktoken:
dependency: transitive
description:
@@ -323,7 +323,7 @@ packages:
path: "../../packages/openai_dart"
relative: true
source: path
version: "0.2.1"
version: "0.2.2"
path:
dependency: transitive
description:
