From 2a50aec28385068f9be32392020d727fc9a1561e Mon Sep 17 00:00:00 2001 From: David Miguel Lozano Date: Thu, 2 May 2024 23:25:46 +0200 Subject: [PATCH] refactor!: Improve Tool abstractions (#398) --- .../agent_types/openai_functions_agent.md | 78 +++-- docs/modules/agents/tools/openai_dall_e.md | 4 +- examples/browser_summarizer/pubspec.lock | 10 +- .../agent_types/openai_functions_agent.dart | 63 ++-- .../modules/agents/tools/openai_dalle.dart | 4 +- examples/docs_examples/pubspec.lock | 18 +- examples/hello_world_backend/pubspec.lock | 8 +- examples/hello_world_cli/pubspec.lock | 8 +- examples/hello_world_flutter/pubspec.lock | 8 +- melos.yaml | 1 - .../langchain/lib/src/agents/executor.dart | 27 +- .../langchain/lib/src/tools/exception.dart | 6 +- packages/langchain/lib/src/tools/invalid.dart | 23 -- packages/langchain/lib/src/tools/tools.dart | 1 - .../langchain/test/agents/executor_test.dart | 22 +- .../lib/src/tools/calculator.dart | 6 +- .../test/tools/calculator_test.dart | 11 +- .../langchain_core/lib/src/agents/base.dart | 4 +- .../langchain_core/lib/src/tools/base.dart | 293 +++++++----------- .../langchain_core/lib/src/tools/fake.dart | 9 +- .../langchain_core/lib/src/tools/string.dart | 102 ++++++ .../langchain_core/lib/src/tools/tools.dart | 1 + .../test/runnables/batch_test.dart | 6 +- .../test/runnables/invoke_test.dart | 2 +- .../test/runnables/stream_test.dart | 24 +- .../langchain_core/test/tools/base_test.dart | 100 ++++-- .../test/tools/string_test.dart | 18 ++ .../lib/src/agents/functions.dart | 2 +- .../lib/src/tools/dall_e.dart | 4 +- .../test/agents/functions_test.dart | 41 ++- .../test/chat_models/chat_openai_test.dart | 2 +- .../test/tools/dall_e_test.dart | 8 +- 32 files changed, 527 insertions(+), 387 deletions(-) delete mode 100644 packages/langchain/lib/src/tools/invalid.dart create mode 100644 packages/langchain_core/lib/src/tools/string.dart create mode 100644 packages/langchain_core/test/tools/string_test.dart diff --git a/docs/modules/agents/agent_types/openai_functions_agent.md b/docs/modules/agents/agent_types/openai_functions_agent.md index ecc29423..ae2697bb 100644 --- a/docs/modules/agents/agent_types/openai_functions_agent.md +++ b/docs/modules/agents/agent_types/openai_functions_agent.md @@ -15,8 +15,8 @@ The OpenAI Functions Agent is designed to work with these models. ```dart final llm = ChatOpenAI( apiKey: openaiApiKey, - defaultOptions: const ChatOpenAIOptions( - model: 'gpt-4', + defaultOptions: ChatOpenAIOptions( + model: 'gpt-4-turbo', temperature: 0, ), ); @@ -29,15 +29,32 @@ print(res); // -> '40 raised to the power of 0.43 is approximately 4.8852' You can easily call your own functions by wrapping them in a `Tool`. You can also add memory to the agent by passing it when creating the agent. -Let's see an example of how to do this: +Let's see an example of how to do this. + +First let's create a class that will be the input for our tool. ```dart -final llm = ChatOpenAI( - apiKey: openaiApiKey, - defaultOptions: const ChatOpenAIOptions(temperature: 0), -); +class SearchInput { + const SearchInput({ + required this.query, + required this.n, + }); + + final String query; + final int n; + + SearchInput.fromJson(final Map json) + : this( + query: json['query'] as String, + n: json['n'] as int, + ); +} +``` + +Now let's define the tool: -final tool = BaseTool.fromFunction( +```dart +final tool = Tool.fromFunction( name: 'search', description: 'Tool for searching the web.', inputJsonSchema: const { @@ -54,21 +71,32 @@ final tool = BaseTool.fromFunction( }, 'required': ['query'], }, - func: ( - final Map toolInput, { - final ToolOptions? options, - }) async { - final query = toolInput['query']; - final n = toolInput['n']; - return callYourSearchFunction(query, n); - }, + func: callYourSearchFunction, + getInputFromJson: SearchInput.fromJson, +); +``` + +Notice that we need to provide a function that converts the JSON input that the model will send to our tool into the input class that we defined. + +The tool will call `callYourSearchFunction` function with the parsed input. For simplicity, we will just mock the search function. +```dart +String callYourSearchFunction(final SearchInput input) { + return 'Results:\n${List.generate(input.n, (final i) => 'Result ${i + 1}').join('\n')}'; +} +``` + +Now we can create the agent and run it. + +```dart +final llm = ChatOpenAI( + apiKey: openaiApiKey, + defaultOptions: const ChatOpenAIOptions(temperature: 0), ); -final tools = [tool]; final memory = ConversationBufferMemory(returnMessages: true); final agent = OpenAIFunctionsAgent.fromLLMAndTools( llm: llm, - tools: tools, + tools: [tool], memory: memory, ); @@ -77,6 +105,11 @@ final executor = AgentExecutor(agent: agent); final res1 = await executor.run( 'Search for cats. Return only 3 results.', ); +print(res1); +// Here are 3 search results for "cats": +// 1. Result 1 +// 2. Result 2 +// 3. Result 3 ``` ## Using LangChain Expression Language (LCEL) @@ -86,10 +119,10 @@ You can replicate the functionality of the OpenAI Functions Agent by using the L ```dart final openaiApiKey = Platform.environment['OPENAI_API_KEY']; -final prompt = ChatPromptTemplate.fromPromptMessages([ - SystemChatMessagePromptTemplate.fromTemplate('You are a helpful assistant'), - HumanChatMessagePromptTemplate.fromTemplate('{input}'), - const MessagesPlaceholder(variableName: 'agent_scratchpad'), +final prompt = ChatPromptTemplate.fromTemplates(const [ + (ChatMessageType.system, 'You are a helpful assistant'), + (ChatMessageType.human, '{input}'), + (ChatMessageType.messagesPlaceholder, 'agent_scratchpad'), ]); final tool = CalculatorTool(); @@ -134,6 +167,7 @@ final res = await executor.invoke({ 'input': 'What is 40 raised to the 0.43 power?', }); print(res['output']); +// 40 raised to the power of 0.43 is approximately 4.88524. ``` In this way, you can create your own custom agents with full control over their behavior. diff --git a/docs/modules/agents/tools/openai_dall_e.md b/docs/modules/agents/tools/openai_dall_e.md index 8773982a..7520d18a 100644 --- a/docs/modules/agents/tools/openai_dall_e.md +++ b/docs/modules/agents/tools/openai_dall_e.md @@ -10,11 +10,11 @@ Example: final llm = ChatOpenAI( apiKey: openAiKey, defaultOptions: const ChatOpenAIOptions( - model: 'gpt-4', + model: 'gpt-4-turbo', temperature: 0, ), ); -final tools = [ +final tools = [ CalculatorTool(), OpenAIDallETool(apiKey: openAiKey), ]; diff --git a/examples/browser_summarizer/pubspec.lock b/examples/browser_summarizer/pubspec.lock index e9a63872..c5f12de6 100644 --- a/examples/browser_summarizer/pubspec.lock +++ b/examples/browser_summarizer/pubspec.lock @@ -225,28 +225,28 @@ packages: path: "../../packages/langchain" relative: true source: path - version: "0.5.0+1" + version: "0.6.0+1" langchain_community: dependency: "direct main" description: path: "../../packages/langchain_community" relative: true source: path - version: "0.1.0" + version: "0.1.0+2" langchain_core: dependency: "direct overridden" description: path: "../../packages/langchain_core" relative: true source: path - version: "0.1.0" + version: "0.2.0+1" langchain_openai: dependency: "direct main" description: path: "../../packages/langchain_openai" relative: true source: path - version: "0.5.0+1" + version: "0.5.1+1" langchain_tiktoken: dependency: transitive description: @@ -309,7 +309,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.2.1" + version: "0.2.2" path: dependency: transitive description: diff --git a/examples/docs_examples/bin/modules/agents/agent_types/openai_functions_agent.dart b/examples/docs_examples/bin/modules/agents/agent_types/openai_functions_agent.dart index 00a59dcf..e19283fe 100644 --- a/examples/docs_examples/bin/modules/agents/agent_types/openai_functions_agent.dart +++ b/examples/docs_examples/bin/modules/agents/agent_types/openai_functions_agent.dart @@ -1,4 +1,4 @@ -// ignore_for_file: avoid_print +// ignore_for_file: avoid_print, unreachable_from_main import 'dart:io'; import 'package:langchain/langchain.dart'; @@ -16,7 +16,7 @@ Future _openaiFunctionsAgent() async { final llm = ChatOpenAI( apiKey: openaiApiKey, defaultOptions: const ChatOpenAIOptions( - model: 'gpt-4', + model: 'gpt-4-turbo', temperature: 0, ), ); @@ -29,12 +29,8 @@ Future _openaiFunctionsAgent() async { Future _openaiFunctionsAgentCustomToolsMemory() async { final openaiApiKey = Platform.environment['OPENAI_API_KEY']; - final llm = ChatOpenAI( - apiKey: openaiApiKey, - defaultOptions: const ChatOpenAIOptions(temperature: 0), - ); - final tool = BaseTool.fromFunction( + final tool = Tool.fromFunction( name: 'search', description: 'Tool for searching the web.', inputJsonSchema: const { @@ -51,21 +47,19 @@ Future _openaiFunctionsAgentCustomToolsMemory() async { }, 'required': ['query'], }, - func: ( - final Map toolInput, { - final ToolOptions? options, - }) async { - final query = toolInput['query']; - final n = toolInput['n']; - return callYourSearchFunction(query, n); - }, + func: callYourSearchFunction, + getInputFromJson: SearchInput.fromJson, + ); + + final llm = ChatOpenAI( + apiKey: openaiApiKey, + defaultOptions: const ChatOpenAIOptions(temperature: 0), ); - final tools = [tool]; final memory = ConversationBufferMemory(returnMessages: true); final agent = OpenAIFunctionsAgent.fromLLMAndTools( llm: llm, - tools: tools, + tools: [tool], memory: memory, ); @@ -75,21 +69,39 @@ Future _openaiFunctionsAgentCustomToolsMemory() async { 'Search for cats. Return only 3 results.', ); print(res1); + // Here are 3 search results for "cats": + // 1. Result 1 + // 2. Result 2 + // 3. Result 3 +} + +class SearchInput { + const SearchInput({ + required this.query, + required this.n, + }); + + final String query; + final int n; + + SearchInput.fromJson(final Map json) + : this( + query: json['query'] as String, + n: json['n'] as int, + ); } -String callYourSearchFunction(final String query, final int n) { - return 'Results:\n${List.generate(n, (final i) => 'Result ${i + 1}').join('\n')}'; +String callYourSearchFunction(final SearchInput input) { + return 'Results:\n${List.generate(input.n, (final i) => 'Result ${i + 1}').join('\n')}'; } Future _openaiFunctionsAgentLCEL() async { final openaiApiKey = Platform.environment['OPENAI_API_KEY']; - final prompt = ChatPromptTemplate.fromPromptMessages([ - SystemChatMessagePromptTemplate.fromTemplate( - 'You are a helpful assistant', - ), - HumanChatMessagePromptTemplate.fromTemplate('{input}'), - const MessagesPlaceholder(variableName: 'agent_scratchpad'), + final prompt = ChatPromptTemplate.fromTemplates(const [ + (ChatMessageType.system, 'You are a helpful assistant'), + (ChatMessageType.human, '{input}'), + (ChatMessageType.messagesPlaceholder, 'agent_scratchpad'), ]); final tool = CalculatorTool(); @@ -134,4 +146,5 @@ Future _openaiFunctionsAgentLCEL() async { 'input': 'What is 40 raised to the 0.43 power?', }); print(res['output']); + // 40 raised to the power of 0.43 is approximately 4.88524. } diff --git a/examples/docs_examples/bin/modules/agents/tools/openai_dalle.dart b/examples/docs_examples/bin/modules/agents/tools/openai_dalle.dart index d9c5a120..1b649112 100644 --- a/examples/docs_examples/bin/modules/agents/tools/openai_dalle.dart +++ b/examples/docs_examples/bin/modules/agents/tools/openai_dalle.dart @@ -10,11 +10,11 @@ void main() async { final llm = ChatOpenAI( apiKey: openAiKey, defaultOptions: const ChatOpenAIOptions( - model: 'gpt-4', + model: 'gpt-4-turbo', temperature: 0, ), ); - final tools = [ + final tools = [ CalculatorTool(), OpenAIDallETool(apiKey: openAiKey), ]; diff --git a/examples/docs_examples/pubspec.lock b/examples/docs_examples/pubspec.lock index f0384ca9..665639f6 100644 --- a/examples/docs_examples/pubspec.lock +++ b/examples/docs_examples/pubspec.lock @@ -221,56 +221,56 @@ packages: path: "../../packages/langchain" relative: true source: path - version: "0.5.0+1" + version: "0.6.0+1" langchain_chroma: dependency: "direct main" description: path: "../../packages/langchain_chroma" relative: true source: path - version: "0.2.0" + version: "0.2.0+2" langchain_community: dependency: "direct main" description: path: "../../packages/langchain_community" relative: true source: path - version: "0.1.0" + version: "0.1.0+2" langchain_core: dependency: "direct overridden" description: path: "../../packages/langchain_core" relative: true source: path - version: "0.1.0" + version: "0.2.0+1" langchain_google: dependency: "direct main" description: path: "../../packages/langchain_google" relative: true source: path - version: "0.3.0" + version: "0.3.0+2" langchain_mistralai: dependency: "direct main" description: path: "../../packages/langchain_mistralai" relative: true source: path - version: "0.1.0" + version: "0.1.0+2" langchain_ollama: dependency: "direct main" description: path: "../../packages/langchain_ollama" relative: true source: path - version: "0.1.0" + version: "0.1.0+2" langchain_openai: dependency: "direct main" description: path: "../../packages/langchain_openai" relative: true source: path - version: "0.5.0+1" + version: "0.5.1+1" langchain_tiktoken: dependency: transitive description: @@ -323,7 +323,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.2.1" + version: "0.2.2" path: dependency: transitive description: diff --git a/examples/hello_world_backend/pubspec.lock b/examples/hello_world_backend/pubspec.lock index 286dcebf..8350944e 100644 --- a/examples/hello_world_backend/pubspec.lock +++ b/examples/hello_world_backend/pubspec.lock @@ -119,21 +119,21 @@ packages: path: "../../packages/langchain" relative: true source: path - version: "0.5.0+1" + version: "0.6.0+1" langchain_core: dependency: "direct overridden" description: path: "../../packages/langchain_core" relative: true source: path - version: "0.1.0" + version: "0.2.0+1" langchain_openai: dependency: "direct main" description: path: "../../packages/langchain_openai" relative: true source: path - version: "0.5.0+1" + version: "0.5.1+1" langchain_tiktoken: dependency: transitive description: @@ -156,7 +156,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.2.1" + version: "0.2.2" path: dependency: transitive description: diff --git a/examples/hello_world_cli/pubspec.lock b/examples/hello_world_cli/pubspec.lock index c5718392..a0919a79 100644 --- a/examples/hello_world_cli/pubspec.lock +++ b/examples/hello_world_cli/pubspec.lock @@ -111,21 +111,21 @@ packages: path: "../../packages/langchain" relative: true source: path - version: "0.5.0+1" + version: "0.6.0+1" langchain_core: dependency: "direct overridden" description: path: "../../packages/langchain_core" relative: true source: path - version: "0.1.0" + version: "0.2.0+1" langchain_openai: dependency: "direct main" description: path: "../../packages/langchain_openai" relative: true source: path - version: "0.5.0+1" + version: "0.5.1+1" langchain_tiktoken: dependency: transitive description: @@ -148,7 +148,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.2.1" + version: "0.2.2" path: dependency: transitive description: diff --git a/examples/hello_world_flutter/pubspec.lock b/examples/hello_world_flutter/pubspec.lock index a99021aa..8bcbd6cb 100644 --- a/examples/hello_world_flutter/pubspec.lock +++ b/examples/hello_world_flutter/pubspec.lock @@ -140,21 +140,21 @@ packages: path: "../../packages/langchain" relative: true source: path - version: "0.5.0+1" + version: "0.6.0+1" langchain_core: dependency: "direct overridden" description: path: "../../packages/langchain_core" relative: true source: path - version: "0.1.0" + version: "0.2.0+1" langchain_openai: dependency: "direct main" description: path: "../../packages/langchain_openai" relative: true source: path - version: "0.5.0+1" + version: "0.5.1+1" langchain_tiktoken: dependency: transitive description: @@ -193,7 +193,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.2.1" + version: "0.2.2" path: dependency: transitive description: diff --git a/melos.yaml b/melos.yaml index acb3691e..e47b8950 100644 --- a/melos.yaml +++ b/melos.yaml @@ -41,7 +41,6 @@ command: js: ^0.7.1 json_annotation: ^4.8.1 json_path: ^0.7.1 - langchain_core: ^0.0.1-dev.2 langchain_tiktoken: ^1.0.1 math_expressions: ^2.4.0 meta: ^1.11.0 diff --git a/packages/langchain/lib/src/agents/executor.dart b/packages/langchain/lib/src/agents/executor.dart index 94b8b93b..688d3996 100644 --- a/packages/langchain/lib/src/agents/executor.dart +++ b/packages/langchain/lib/src/agents/executor.dart @@ -5,7 +5,6 @@ import 'package:langchain_core/tools.dart'; import 'package:meta/meta.dart'; import '../tools/exception.dart'; -import '../tools/invalid.dart'; /// {@template agent_executor} /// A chain responsible for executing the actions of an agent using tools. @@ -46,7 +45,7 @@ class AgentExecutor extends BaseChain { /// The valid tools the agent can call plus some internal tools used by the /// executor. - final List _internalTools; + final List _internalTools; /// Whether to return the agent's trajectory of intermediate steps at the /// end in addition to the final output. @@ -64,9 +63,9 @@ class AgentExecutor extends BaseChain { final AgentEarlyStoppingMethod earlyStoppingMethod; /// Handles errors raised by the agent's output parser. - /// The response from this handlers is passed to the agent as the observation - /// resulting from the step. - final String Function(OutputParserException)? handleParsingErrors; + /// The response from this handler will be used as the tool input. + final Map Function(OutputParserException)? + handleParsingErrors; /// Output key for the agent's intermediate steps output. static const intermediateStepsOutputKey = 'intermediate_steps'; @@ -85,7 +84,7 @@ class AgentExecutor extends BaseChain { final agent = this.agent; final tools = _internalTools; if (agent is BaseMultiActionAgent) { - for (final BaseTool tool in tools) { + for (final Tool tool in tools) { if (tool.returnDirect) { return false; } @@ -169,7 +168,7 @@ class AgentExecutor extends BaseChain { /// Override this to take control of how the agent makes and acts on choices. @visibleForOverriding Future<(AgentFinish? result, List? nextSteps)> takeNextStep( - final Map nameToToolMap, + final Map nameToToolMap, final ChainValues inputs, final List intermediateSteps, ) async { @@ -183,7 +182,7 @@ class AgentExecutor extends BaseChain { actions = [ AgentAction( tool: ExceptionTool.toolName, - toolInput: {Tool.inputVar: handleParsingErrors!(e)}, + toolInput: handleParsingErrors!(e), log: e.toString(), ), ]; @@ -198,11 +197,17 @@ class AgentExecutor extends BaseChain { // Otherwise, we run the tool final agentAction = action as AgentAction; final tool = nameToToolMap[agentAction.tool]; + String observation; + if (tool != null) { + final toolInput = tool.getInputFromJson(agentAction.toolInput); + observation = (await tool.invoke(toolInput)).toString(); + } else { + observation = + '${agentAction.tool} is not a valid tool, try another one.'; + } final step = AgentStep( action: action, - observation: await (tool != null - ? tool.run(agentAction.toolInput) - : InvalidTool().run({Tool.inputVar: agentAction.tool})), + observation: observation, ); result.add(step); } diff --git a/packages/langchain/lib/src/tools/exception.dart b/packages/langchain/lib/src/tools/exception.dart index 81677dab..23cb02cc 100644 --- a/packages/langchain/lib/src/tools/exception.dart +++ b/packages/langchain/lib/src/tools/exception.dart @@ -7,7 +7,7 @@ import 'package:langchain_core/tools.dart'; /// /// Returns the output of [AgentExecutor.handleParsingErrors]. /// {@endtemplate} -final class ExceptionTool extends Tool { +final class ExceptionTool extends StringTool { /// {@macro exception_tool} ExceptionTool() : super( @@ -19,10 +19,10 @@ final class ExceptionTool extends Tool { static const toolName = '_exception'; @override - FutureOr runInternalString( + Future invokeInternal( final String toolInput, { final ToolOptions? options, }) { - return toolInput; + return Future.value(toolInput); } } diff --git a/packages/langchain/lib/src/tools/invalid.dart b/packages/langchain/lib/src/tools/invalid.dart deleted file mode 100644 index 67a19cbb..00000000 --- a/packages/langchain/lib/src/tools/invalid.dart +++ /dev/null @@ -1,23 +0,0 @@ -import 'dart:async'; - -import 'package:langchain_core/tools.dart'; - -/// {@template invalid_tool} -/// Tool that is run when invalid tool name is encountered by agent -/// {@endtemplate} -final class InvalidTool extends Tool { - /// {@macro invalid_tool} - InvalidTool() - : super( - name: 'invalid_tool', - description: 'Called when tool name is invalid.', - ); - - @override - FutureOr runInternalString( - final String toolInput, { - final ToolOptions? options, - }) { - return '$toolInput is not a valid tool, try another one.'; - } -} diff --git a/packages/langchain/lib/src/tools/tools.dart b/packages/langchain/lib/src/tools/tools.dart index b22f57d5..303491e0 100644 --- a/packages/langchain/lib/src/tools/tools.dart +++ b/packages/langchain/lib/src/tools/tools.dart @@ -1,4 +1,3 @@ export 'package:langchain_core/tools.dart'; export 'exception.dart'; -export 'invalid.dart'; diff --git a/packages/langchain/test/agents/executor_test.dart b/packages/langchain/test/agents/executor_test.dart index da0b1c86..7c3bfa17 100644 --- a/packages/langchain/test/agents/executor_test.dart +++ b/packages/langchain/test/agents/executor_test.dart @@ -13,7 +13,7 @@ void main() { actions: [ AgentAction( tool: tool.name, - toolInput: {Tool.inputVar: 'mock'}, + toolInput: {'input': 'mock'}, ), ], ); @@ -32,7 +32,7 @@ void main() { actions: [ AgentAction( tool: tool.name, - toolInput: {Tool.inputVar: 'mock'}, + toolInput: {'input': 'mock'}, ), ], ); @@ -51,7 +51,7 @@ void main() { actions: [ AgentAction( tool: tool.name, - toolInput: {Tool.inputVar: 'mock'}, + toolInput: {'input': 'mock'}, ), const AgentFinish( returnValues: {BaseActionAgent.agentReturnKey: 'mock'}, @@ -70,7 +70,7 @@ void main() { actions: [ const AgentAction( tool: 'tool', - toolInput: {Tool.inputVar: 'mock'}, + toolInput: {'input': 'mock'}, ), ], ); @@ -98,13 +98,13 @@ void main() { actions: [ const AgentAction( tool: 'invalid_tool', - toolInput: {Tool.inputVar: 'mock'}, + toolInput: {'input': 'mock'}, ), ], ); final executor = AgentExecutor( agent: agent, - handleParsingErrors: (final _) => 'fallback', + handleParsingErrors: (final _) => {'input': 'fallback'}, maxIterations: 1, returnIntermediateSteps: true, ); @@ -143,19 +143,17 @@ void main() { }); } -final class _MockTool extends Tool { +final class _MockTool extends StringTool { _MockTool({ super.name = 'tool', super.returnDirect = false, - }) : super( - description: '$name-description', - ); + }) : super(description: '$name-description'); @override - FutureOr runInternalString( + Future invokeInternal( final String toolInput, { final ToolOptions? options, - }) { + }) async { return '$name-output'; } } diff --git a/packages/langchain_community/lib/src/tools/calculator.dart b/packages/langchain_community/lib/src/tools/calculator.dart index f3b5bfa5..510b8e36 100644 --- a/packages/langchain_community/lib/src/tools/calculator.dart +++ b/packages/langchain_community/lib/src/tools/calculator.dart @@ -20,7 +20,7 @@ import 'package:math_expressions/math_expressions.dart'; /// print(res); // -> '40 raised to the power of 0.43 is approximately 4.8852' /// ``` /// {@endtemplate} -final class CalculatorTool extends Tool { +final class CalculatorTool extends StringTool { /// {@macro calculator_tool} CalculatorTool() : super( @@ -34,10 +34,10 @@ final class CalculatorTool extends Tool { final _parser = Parser(); @override - FutureOr runInternalString( + Future invokeInternal( final String toolInput, { final ToolOptions? options, - }) { + }) async { try { return _parser .parse(toolInput) diff --git a/packages/langchain_community/test/tools/calculator_test.dart b/packages/langchain_community/test/tools/calculator_test.dart index 1499208c..4dd78553 100644 --- a/packages/langchain_community/test/tools/calculator_test.dart +++ b/packages/langchain_community/test/tools/calculator_test.dart @@ -1,5 +1,4 @@ import 'package:langchain_community/langchain_community.dart'; -import 'package:langchain_core/tools.dart'; import 'package:test/test.dart'; void main() { @@ -7,13 +6,11 @@ void main() { test('Calculate expressions', () async { final echoTool = CalculatorTool(); - expect(echoTool.run({Tool.inputVar: '1 + 1'}), '2.0'); - expect(echoTool.run({Tool.inputVar: '1 - 1'}), '0.0'); - expect(echoTool.run({Tool.inputVar: '10*1 - (-5)'}), '15.0'); + expect(await echoTool.invoke('1 + 1'), '2.0'); + expect(await echoTool.invoke('1 - 1'), '0.0'); + expect(await echoTool.invoke('10*1 - (-5)'), '15.0'); expect( - double.parse( - await echoTool.run({Tool.inputVar: '(2^2 + cos(3.14)) / 3'}), - ), + double.parse(await echoTool.invoke('(2^2 + cos(3.14)) / 3')), closeTo(1.0, 0.000001), ); }); diff --git a/packages/langchain_core/lib/src/agents/base.dart b/packages/langchain_core/lib/src/agents/base.dart index d70e6059..1dd7dd07 100644 --- a/packages/langchain_core/lib/src/agents/base.dart +++ b/packages/langchain_core/lib/src/agents/base.dart @@ -15,7 +15,7 @@ abstract class Agent { static BaseMultiActionAgent fromRunnable( final Runnable> runnable, { - required final List tools, + required final List tools, }) { return RunnableAgent(runnable, tools: tools); } @@ -46,7 +46,7 @@ abstract class BaseActionAgent extends Agent { String get agentType; /// The tools this agent can use. - final List tools; + final List tools; /// Given the input and previous steps, returns the next action to take. /// diff --git a/packages/langchain_core/lib/src/tools/base.dart b/packages/langchain_core/lib/src/tools/base.dart index 43756fb0..266daaeb 100644 --- a/packages/langchain_core/lib/src/tools/base.dart +++ b/packages/langchain_core/lib/src/tools/base.dart @@ -1,4 +1,4 @@ -// ignore_for_file: avoid_equals_and_hash_code_on_mutable_classes +// ignore_for_file: avoid_equals_and_hash_code_on_mutable_classes, avoid_implementing_value_types import 'dart:async'; import 'package:meta/meta.dart'; @@ -6,16 +6,55 @@ import 'package:meta/meta.dart'; import '../chat_models/types.dart'; import '../langchain/base.dart'; import '../utils/reduce.dart'; +import 'string.dart'; import 'types.dart'; -/// {@template base_tool} -/// Base class LangChain tools must extend. -/// The input to the tool needs to be described by [inputJsonSchema]. +/// {@template tool_spec} +/// The specification of a LangChain tool without the actual implementation. /// {@endtemplate} -abstract base class BaseTool - extends BaseLangChain, Options, String> { - /// {@macro base_tool} - BaseTool({ +@immutable +class ToolSpec { + /// {@macro tool_spec} + const ToolSpec({ + required this.name, + required this.description, + required this.inputJsonSchema, + }); + + /// The unique name of the tool that clearly communicates its purpose. + final String name; + + /// Used to tell the model how/when/why to use the tool. + /// You can provide few-shot examples as a part of the description. + final String description; + + /// Schema to parse and validate tool's input arguments. + /// Following the [JSON Schema specification](https://json-schema.org). + final Map inputJsonSchema; + + @override + bool operator ==(covariant final ToolSpec other) => + identical(this, other) || name == other.name; + + @override + int get hashCode => name.hashCode; +} + +/// {@template tool} +/// A LangChain tool. +/// +/// The [Input] to the tool needs to be described by the [inputJsonSchema]. +/// +/// You can easily create a tool from a function using [Tool.fromFunction]. +/// +/// If you want to create a tool that accepts a single string input and returns +/// a string output, you can use [StringTool] or [StringTool.fromFunction]. +/// {@endtemplate} +abstract base class Tool extends BaseLangChain + implements ToolSpec { + /// {@macro tool} + Tool({ required this.name, required this.description, required this.inputJsonSchema, @@ -26,15 +65,13 @@ abstract base class BaseTool assert(description.isNotEmpty, 'Tool description cannot be empty.'), super(defaultOptions: defaultOptions ?? const ToolOptions() as Options); - /// The unique name of the tool that clearly communicates its purpose. + @override final String name; - /// Used to tell the model how/when/why to use the tool. - /// You can provide few-shot examples as a part of the description. + @override final String description; - /// Schema to parse and validate tool's input arguments. - /// Following the [JSON Schema specification](https://json-schema.org). + @override final Map inputJsonSchema; /// Whether to return the tool's output directly. @@ -43,42 +80,43 @@ abstract base class BaseTool final bool returnDirect; /// Handle the content of the [ToolException] thrown by the tool. - final String Function(ToolException)? handleToolError; + final Output Function(ToolException)? handleToolError; - /// Creates a [BaseTool] from a function. + /// Creates a [Tool] from a function. /// /// - [name] is the unique name of the tool that clearly communicates its /// purpose. /// - [description] is used to tell the model how/when/why to use the tool. /// You can provide few-shot examples as a part of the description. - /// - [func] is the function that will be called when the tool is run. /// - [inputJsonSchema] is the schema to parse and validate tool's input + /// - [func] is the function that will be called when the tool is run. /// arguments. + /// - [getInputFromJson] is a function that parses the input JSON to the + /// tool's input type. By default, it assumes the input values is under + /// the key 'input'. Define your own deserialization logic if the input + /// is not a primitive type or is under a different key. /// - [returnDirect] whether to return the tool's output directly. /// Setting this to true means that after the tool is called, /// the AgentExecutor will stop looping. /// - [handleToolError] is a function that handles the content of the /// [ToolException] thrown by the tool. - static BaseTool fromFunction({ + static Tool fromFunction({ required final String name, required final String description, - required final FutureOr Function( - Map toolInput, { - Options? options, - }) func, required final Map inputJsonSchema, - final Options? defaultOptions, + required final FutureOr Function(Input input) func, + Input Function(Map json)? getInputFromJson, final bool returnDirect = false, - final String Function(ToolException)? handleToolError, + final Output Function(ToolException)? handleToolError, }) { - return _BaseToolFunc( + return _ToolFunc( name: name, description: description, - func: func, inputJsonSchema: inputJsonSchema, + function: func, + getInputFromJson: getInputFromJson ?? (json) => json['input'] as Input, returnDirect: returnDirect, handleToolError: handleToolError, - defaultOptions: defaultOptions ?? const ToolOptions() as Options, ); } @@ -87,38 +125,12 @@ abstract base class BaseTool /// - [input] is the input to the tool. /// - [options] is the options to pass to the tool. @override - Future invoke( - final Map input, { + Future invoke( + final Input input, { final Options? options, }) async { - return run(input); - } - - /// Streams the tool's output for the input resulting from - /// reducing the input stream. - /// - /// - [inputStream] - the input stream to reduce and use as the input. - /// - [options] is the options to pass to the tool. - @override - Stream streamFromInputStream( - final Stream> inputStream, { - final Options? options, - }) async* { - final input = await inputStream.toList(); - final reduced = reduce>(input); - yield* stream(reduced, options: options); - } - - /// Runs the tool. - /// - /// - [toolInput] the input to the tool. - /// - [options] the options to pass to the tool. - FutureOr run( - final Map toolInput, { - final Options? options, - }) { try { - return runInternal(toolInput); + return invokeInternal(input, options: options); } on ToolException catch (e) { if (handleToolError != null) { return handleToolError!(e); @@ -130,24 +142,31 @@ abstract base class BaseTool } } - /// Actual implementation of [run] method logic. + /// Actual implementation of [invoke] method logic. @protected - FutureOr runInternal( - final Map toolInput, { + Future invokeInternal( + final Input input, { final Options? options, }); - /// Runs the tool (same as [run] but using callable class syntax). + /// Streams the tool's output for the input resulting from + /// reducing the input stream. /// - /// - [toolInput] the input to the tool. - /// - [options] the options to pass to the tool. - FutureOr call({ - required final Map toolInput, + /// - [inputStream] - the input stream to reduce and use as the input. + /// - [options] is the options to pass to the tool. + @override + Stream streamFromInputStream( + final Stream inputStream, { final Options? options, - }) { - return run(toolInput, options: options); + }) async* { + final input = await inputStream.toList(); + final reduced = reduce(input); + yield* stream(reduced, options: options); } + /// Parses the input JSON to the tool's input type. + Input getInputFromJson(final Map json); + /// Converts the tool to a [ChatFunction]. ChatFunction toChatFunction() { return ChatFunction( @@ -158,150 +177,48 @@ abstract base class BaseTool } @override - bool operator ==(covariant final BaseTool other) => + bool operator ==(covariant final Tool other) => identical(this, other) || name == other.name; @override int get hashCode => name.hashCode; } -/// {@template base_tool_func} +/// {@template tool_func} /// A tool that accepts a function as input. -/// Used in [BaseTool.fromFunction]. +/// Used in [Tool.fromFunction]. /// {@endtemplate} -final class _BaseToolFunc - extends BaseTool { - /// {@macro base_tool_func} - _BaseToolFunc({ +final class _ToolFunc + extends Tool { + /// {@macro tool_func} + _ToolFunc({ required super.name, required super.description, - required this.func, required super.inputJsonSchema, + required FutureOr Function(Input input) function, + required Input Function(Map json) getInputFromJson, super.returnDirect = false, super.handleToolError, super.defaultOptions, - }); + }) : _getInputFromJson = getInputFromJson, + _function = function; /// The function to run when the tool is called. - final FutureOr Function( - Map toolInput, { - Options? options, - }) func; - - @override - FutureOr runInternal( - final Map toolInput, { - final Options? options, - }) { - return func(toolInput, options: options); - } -} - -/// {@template tool} -/// This class wraps functions that accept a single string input and returns a -/// string output. -/// {@endtemplate} -abstract base class Tool - extends BaseTool { - /// {@macro tool} - Tool({ - required super.name, - required super.description, - final String inputDescription = 'The input to the tool', - super.returnDirect = false, - super.handleToolError, - super.defaultOptions, - }) : super( - inputJsonSchema: { - 'type': 'object', - 'properties': { - inputVar: { - 'type': 'string', - 'description': inputDescription, - }, - }, - 'required': ['input'], - }, - ); + final FutureOr Function(Input toolInput) _function; - /// The name of the input variable. - static const inputVar = 'input'; - - /// Creates a [Tool] from a function. - /// - /// - [name] is the unique name of the tool that clearly communicates its - /// purpose. - /// - [description] is used to tell the model how/when/why to use the tool. - /// You can provide few-shot examples as a part of the description. - /// - [func] is the function that will be called when the tool is run. - /// - [returnDirect] whether to return the tool's output directly. - /// Setting this to true means that after the tool is called, - /// the AgentExecutor will stop looping. - /// - [handleToolError] is a function that handles the content of the - /// [ToolException] thrown by the tool. - static Tool fromFunction({ - required final String name, - required final String description, - final String inputDescription = 'The input to the tool', - required final FutureOr Function( - String toolInput, { - Options? options, - }) func, - final bool returnDirect = false, - final String Function(ToolException)? handleToolError, - }) { - return _ToolFunc( - name: name, - description: description, - inputDescription: inputDescription, - func: func, - returnDirect: returnDirect, - handleToolError: handleToolError, - ); - } + /// The function to parse the input JSON to the tool's input type. + final Input Function(Map json) _getInputFromJson; @override - FutureOr runInternal( - final Map toolInput, { - final Options? options, - }) { - return runInternalString(toolInput[Tool.inputVar], options: options); + Future invokeInternal( + final Input toolInput, { + final ToolOptions? options, + }) async { + return _function(toolInput); } - /// Actual implementation of [run] method logic with string input. - @protected - FutureOr runInternalString( - final String toolInput, { - final Options? options, - }); -} - -/// {@template tool_func} -/// Implementation of [Tool] that accepts a function as input. -/// Used in [Tool.fromFunction]. -/// {@endtemplate} -final class _ToolFunc extends Tool { - /// {@macro tool_func} - _ToolFunc({ - required super.name, - required super.description, - super.inputDescription, - required this.func, - super.returnDirect = false, - super.handleToolError, - super.defaultOptions, - }); - - final FutureOr Function( - String toolInput, { - Options? options, - }) func; - @override - FutureOr runInternalString( - final String toolInput, { - final Options? options, - }) { - return func(toolInput, options: options); + Input getInputFromJson(final Map json) { + return _getInputFromJson(json); } } diff --git a/packages/langchain_core/lib/src/tools/fake.dart b/packages/langchain_core/lib/src/tools/fake.dart index bd7b9b09..60540e67 100644 --- a/packages/langchain_core/lib/src/tools/fake.dart +++ b/packages/langchain_core/lib/src/tools/fake.dart @@ -1,12 +1,13 @@ import 'dart:async'; -import '../../tools.dart'; +import 'string.dart'; +import 'types.dart'; /// {@template fake_tool} /// Fake tool for testing. /// It just returns the input string as is. /// {@endtemplate} -final class FakeTool extends Tool { +final class FakeTool extends StringTool { /// {@macro fake_tool} FakeTool() : super( @@ -16,10 +17,10 @@ final class FakeTool extends Tool { ); @override - FutureOr runInternalString( + Future invokeInternal( final String toolInput, { final ToolOptions? options, - }) { + }) async { try { return toolInput; } catch (e) { diff --git a/packages/langchain_core/lib/src/tools/string.dart b/packages/langchain_core/lib/src/tools/string.dart new file mode 100644 index 00000000..3c9973d5 --- /dev/null +++ b/packages/langchain_core/lib/src/tools/string.dart @@ -0,0 +1,102 @@ +import 'dart:async'; + +import 'base.dart'; +import 'types.dart'; + +/// {@template string_tool} +/// Base class for tools that accept a single string input and returns a +/// string output. +/// {@endtemplate} +abstract base class StringTool + extends Tool { + /// {@macro string_tool} + StringTool({ + required super.name, + required super.description, + final String inputDescription = 'The input to the tool', + super.returnDirect = false, + super.handleToolError, + super.defaultOptions, + }) : super( + inputJsonSchema: { + 'type': 'object', + 'properties': { + 'input': { + 'type': 'string', + 'description': inputDescription, + }, + }, + 'required': ['input'], + }, + ); + + /// Creates a [StringTool] from a function. + /// + /// - [name] is the unique name of the tool that clearly communicates its + /// purpose. + /// - [description] is used to tell the model how/when/why to use the tool. + /// You can provide few-shot examples as a part of the description. + /// - [func] is the function that will be called when the tool is run. + /// - [returnDirect] whether to return the tool's output directly. + /// Setting this to true means that after the tool is called, + /// the AgentExecutor will stop looping. + /// - [handleToolError] is a function that handles the content of the + /// [ToolException] thrown by the tool. + static StringTool fromFunction({ + required final String name, + required final String description, + final String inputDescription = 'The input to the tool', + required final FutureOr Function(String input) func, + final bool returnDirect = false, + final String Function(ToolException)? handleToolError, + }) { + return _StringToolFunc( + name: name, + description: description, + inputDescription: inputDescription, + func: func, + returnDirect: returnDirect, + handleToolError: handleToolError, + ); + } + + /// Actual implementation of [invoke] method logic with string input. + @override + Future invokeInternal( + final String toolInput, { + final Options? options, + }); + + @override + String getInputFromJson(final Map json) { + return json['input'] as String; + } +} + +/// {@template string_tool_func} +/// Implementation of [StringTool] that accepts a function as input. +/// Used in [StringTool.fromFunction]. +/// {@endtemplate} +final class _StringToolFunc + extends StringTool { + /// {@macro string_tool_func} + _StringToolFunc({ + required super.name, + required super.description, + super.inputDescription, + required FutureOr Function(String) func, + super.returnDirect = false, + super.handleToolError, + super.defaultOptions, + }) : _func = func; + + final FutureOr Function(String input) _func; + + @override + Future invokeInternal( + final String toolInput, { + final Options? options, + }) async { + return _func(toolInput); + } +} diff --git a/packages/langchain_core/lib/src/tools/tools.dart b/packages/langchain_core/lib/src/tools/tools.dart index 10eeb1d8..6c1c26ed 100644 --- a/packages/langchain_core/lib/src/tools/tools.dart +++ b/packages/langchain_core/lib/src/tools/tools.dart @@ -1,3 +1,4 @@ export 'base.dart'; export 'fake.dart'; +export 'string.dart'; export 'types.dart'; diff --git a/packages/langchain_core/test/runnables/batch_test.dart b/packages/langchain_core/test/runnables/batch_test.dart index e0ca7263..a06b71ff 100644 --- a/packages/langchain_core/test/runnables/batch_test.dart +++ b/packages/langchain_core/test/runnables/batch_test.dart @@ -124,9 +124,9 @@ void main() { test('Tool batch', () async { final run = FakeTool(); final res = await run.batch([ - {'input': 'hello1'}, - {'input': 'hello2'}, - {'input': 'hello3'}, + 'hello1', + 'hello2', + 'hello3', ]); expect( res.map((final e) => e).toList(), diff --git a/packages/langchain_core/test/runnables/invoke_test.dart b/packages/langchain_core/test/runnables/invoke_test.dart index 3e6c8a56..6f873dc3 100644 --- a/packages/langchain_core/test/runnables/invoke_test.dart +++ b/packages/langchain_core/test/runnables/invoke_test.dart @@ -75,7 +75,7 @@ void main() { test('Tool as Runnable', () async { final run = FakeTool(); - final res = await run.invoke({'input': 'hello'}); + final res = await run.invoke('hello'); expect(res, 'hello'); }); }); diff --git a/packages/langchain_core/test/runnables/stream_test.dart b/packages/langchain_core/test/runnables/stream_test.dart index 3e25a7d6..eabd57e3 100644 --- a/packages/langchain_core/test/runnables/stream_test.dart +++ b/packages/langchain_core/test/runnables/stream_test.dart @@ -109,7 +109,7 @@ void main() { test('Streaming Tool', () async { final run = FakeTool(); - final stream = run.stream({'input': 'hello'}); + final stream = run.stream('hello'); final streamList = await stream.toList(); expect(streamList.length, 1); @@ -241,17 +241,17 @@ void main() { test('Test call to Tool from streaming input', () async { final inputStream = Stream.fromIterable([ - {'input': 'H'}, - {'input': 'e'}, - {'input': 'l'}, - {'input': 'l'}, - {'input': 'o'}, - {'input': ' '}, - {'input': 'W'}, - {'input': 'o'}, - {'input': 'r'}, - {'input': 'l'}, - {'input': 'd'}, + 'H', + 'e', + 'l', + 'l', + 'o', + ' ', + 'W', + 'o', + 'r', + 'l', + 'd', ]); final tool = FakeTool(); diff --git a/packages/langchain_core/test/tools/base_test.dart b/packages/langchain_core/test/tools/base_test.dart index 2131f4a1..ab1bbfba 100644 --- a/packages/langchain_core/test/tools/base_test.dart +++ b/packages/langchain_core/test/tools/base_test.dart @@ -1,39 +1,95 @@ import 'package:langchain_core/tools.dart'; +import 'package:meta/meta.dart'; import 'package:test/test.dart'; void main() { - group('BaseTool tests', () { - test('StructuredTool.fromFunction', () async { - final echoTool = BaseTool.fromFunction( + group('Tool tests', () { + test('Tool.fromFunction', () async { + final echoTool = Tool.fromFunction( name: 'echo-int', description: 'echo-int', - func: ( - final Map toolInput, { - final ToolOptions? options, - }) => - toolInput['input'].toString(), - inputJsonSchema: const {'type': 'integer'}, + func: (final int toolInput) => toolInput.toString(), + inputJsonSchema: const { + 'type': 'object', + 'properties': { + 'input': { + 'type': 'integer', + 'description': 'The input to the tool', + }, + }, + 'required': ['input'], + }, ); expect(echoTool.name, 'echo-int'); expect(echoTool.description, 'echo-int'); - expect(echoTool.run({'input': 1}), '1'); + expect(await echoTool.invoke(1), '1'); + expect(echoTool.getInputFromJson({'input': 1}), 1); }); - test('Tool.fromFunction', () async { - final echoTool = Tool.fromFunction( - name: 'echo', - description: 'echo', - func: ( - final String toolInput, { - final ToolOptions? options, - }) => - toolInput, + test('Tool.fromFunction with custom deserialization', () async { + final tool = Tool.fromFunction<_SearchInput, String>( + name: 'search', + description: 'Tool for searching the web.', + inputJsonSchema: const { + 'type': 'object', + 'properties': { + 'query': { + 'type': 'string', + 'description': 'The query to search for', + }, + 'n': { + 'type': 'number', + 'description': 'The number of results to return', + }, + }, + 'required': ['query'], + }, + func: (final _SearchInput toolInput) async { + final n = toolInput.n; + final res = List.generate(n, (final i) => 'Result ${i + 1}'); + return 'Results:\n${res.join('\n')}'; + }, + getInputFromJson: _SearchInput.fromJson, ); - expect(echoTool.name, 'echo'); - expect(echoTool.description, 'echo'); - expect(echoTool.run({'input': 'Hello world!'}), 'Hello world!'); + expect(tool.name, 'search'); + expect(tool.description, 'Tool for searching the web.'); + expect( + await tool.invoke(const _SearchInput(query: 'cats', n: 3)), + 'Results:\nResult 1\nResult 2\nResult 3', + ); + expect( + tool.getInputFromJson({ + 'query': 'cats', + 'n': 3, + }), + const _SearchInput(query: 'cats', n: 3), + ); }); }); } + +@immutable +class _SearchInput { + const _SearchInput({ + required this.query, + required this.n, + }); + + final String query; + final int n; + + _SearchInput.fromJson(final Map json) + : this( + query: json['query'] as String, + n: json['n'] as int, + ); + + @override + bool operator ==(covariant _SearchInput other) => + identical(this, other) || query == other.query && n == other.n; + + @override + int get hashCode => query.hashCode ^ n.hashCode; +} diff --git a/packages/langchain_core/test/tools/string_test.dart b/packages/langchain_core/test/tools/string_test.dart new file mode 100644 index 00000000..fe938d3f --- /dev/null +++ b/packages/langchain_core/test/tools/string_test.dart @@ -0,0 +1,18 @@ +import 'package:langchain_core/tools.dart'; +import 'package:test/test.dart'; + +void main() { + group('StringTool tests', () { + test('StringTool.fromFunction', () async { + final echoTool = StringTool.fromFunction( + name: 'echo', + description: 'echo', + func: (String input) => input, + ); + + expect(echoTool.name, 'echo'); + expect(echoTool.description, 'echo'); + expect(await echoTool.invoke('Hello world!'), 'Hello world!'); + }); + }); +} diff --git a/packages/langchain_openai/lib/src/agents/functions.dart b/packages/langchain_openai/lib/src/agents/functions.dart index 0fe5f148..9bedbf03 100644 --- a/packages/langchain_openai/lib/src/agents/functions.dart +++ b/packages/langchain_openai/lib/src/agents/functions.dart @@ -125,7 +125,7 @@ class OpenAIFunctionsAgent extends BaseSingleActionAgent { /// system message and the input from the agent. factory OpenAIFunctionsAgent.fromLLMAndTools({ required final ChatOpenAI llm, - required final List> tools, + required final List tools, final BaseChatMemory? memory, final SystemChatMessagePromptTemplate systemChatMessage = _systemChatMessagePromptTemplate, diff --git a/packages/langchain_openai/lib/src/tools/dall_e.dart b/packages/langchain_openai/lib/src/tools/dall_e.dart index 425a648f..bdece57a 100644 --- a/packages/langchain_openai/lib/src/tools/dall_e.dart +++ b/packages/langchain_openai/lib/src/tools/dall_e.dart @@ -45,7 +45,7 @@ export 'package:openai_dart/openai_dart.dart' /// ); /// ``` /// {@endtemplate} -final class OpenAIDallETool extends Tool { +final class OpenAIDallETool extends StringTool { /// {@macro dall_e_tool} OpenAIDallETool({ final String? apiKey, @@ -80,7 +80,7 @@ final class OpenAIDallETool extends Tool { String get apiKey => _client.apiKey; @override - FutureOr runInternalString( + Future invokeInternal( final String toolInput, { final OpenAIDallEToolOptions? options, }) async { diff --git a/packages/langchain_openai/test/agents/functions_test.dart b/packages/langchain_openai/test/agents/functions_test.dart index 62aba653..1f9be7c0 100644 --- a/packages/langchain_openai/test/agents/functions_test.dart +++ b/packages/langchain_openai/test/agents/functions_test.dart @@ -12,6 +12,7 @@ import 'package:langchain_core/prompts.dart'; import 'package:langchain_core/runnables.dart'; import 'package:langchain_core/tools.dart'; import 'package:langchain_openai/langchain_openai.dart'; +import 'package:meta/meta.dart'; import 'package:test/test.dart'; void main() { @@ -22,7 +23,7 @@ void main() { final llm = ChatOpenAI( apiKey: openaiApiKey, defaultOptions: const ChatOpenAIOptions( - model: 'gpt-4-turbo-preview', + model: 'gpt-4-turbo', temperature: 0, ), ); @@ -45,12 +46,12 @@ void main() { final llm = ChatOpenAI( apiKey: openaiApiKey, defaultOptions: const ChatOpenAIOptions( - model: 'gpt-4-turbo-preview', + model: 'gpt-4-turbo', temperature: 0, ), ); - final tool = BaseTool.fromFunction( + final tool = Tool.fromFunction<_SearchInput, String>( name: 'search', description: 'Tool for searching the web.', inputJsonSchema: const { @@ -67,14 +68,12 @@ void main() { }, 'required': ['query'], }, - func: ( - final Map toolInput, { - final ToolOptions? options, - }) async { - final n = toolInput['n']; + func: (final _SearchInput toolInput) async { + final n = toolInput.n; final res = List.generate(n, (final i) => 'Result ${i + 1}'); return 'Results:\n${res.join('\n')}'; }, + getInputFromJson: _SearchInput.fromJson, ); final tools = [tool]; @@ -136,7 +135,7 @@ void main() { final model = ChatOpenAI( apiKey: openaiApiKey, defaultOptions: const ChatOpenAIOptions( - model: 'gpt-4-turbo-preview', + model: 'gpt-4-turbo', temperature: 0, ), ).bind(ChatOpenAIOptions(functions: [tool.toChatFunction()])); @@ -172,3 +171,27 @@ void main() { }); }); } + +@immutable +class _SearchInput { + const _SearchInput({ + required this.query, + required this.n, + }); + + final String query; + final int n; + + _SearchInput.fromJson(final Map json) + : this( + query: json['query'] as String, + n: json['n'] as int, + ); + + @override + bool operator ==(covariant _SearchInput other) => + identical(this, other) || query == other.query && n == other.n; + + @override + int get hashCode => query.hashCode ^ n.hashCode; +} diff --git a/packages/langchain_openai/test/chat_models/chat_openai_test.dart b/packages/langchain_openai/test/chat_models/chat_openai_test.dart index bcf1f775..36355465 100644 --- a/packages/langchain_openai/test/chat_models/chat_openai_test.dart +++ b/packages/langchain_openai/test/chat_models/chat_openai_test.dart @@ -323,7 +323,7 @@ void main() { final llm = ChatOpenAI( apiKey: openaiApiKey, defaultOptions: const ChatOpenAIOptions( - model: 'gpt-4-turbo-preview', + model: 'gpt-4-turbo', temperature: 0, seed: 12345, ), diff --git a/packages/langchain_openai/test/tools/dall_e_test.dart b/packages/langchain_openai/test/tools/dall_e_test.dart index 9a9f34a8..aa19def0 100644 --- a/packages/langchain_openai/test/tools/dall_e_test.dart +++ b/packages/langchain_openai/test/tools/dall_e_test.dart @@ -22,7 +22,7 @@ void main() { size: ImageSize.v256x256, ), ); - final res = await tool.invoke({Tool.inputVar: 'A cute baby sea otter'}); + final res = await tool.invoke('A cute baby sea otter'); expect(res, startsWith('http')); tool.close(); }); @@ -36,13 +36,13 @@ void main() { responseFormat: ImageResponseFormat.b64Json, ), ); - final res = await tool.invoke({Tool.inputVar: 'A cute baby sea otter'}); + final res = await tool.invoke('A cute baby sea otter'); expect(res, isNot(startsWith('http'))); tool.close(); }); test('Test OpenAIDallETool in an agent', - timeout: const Timeout(Duration(minutes: 2)), skip: true, () async { + timeout: const Timeout(Duration(minutes: 2)), skip: false, () async { final llm = ChatOpenAI( apiKey: openAiKey, defaultOptions: const ChatOpenAIOptions( @@ -51,7 +51,7 @@ void main() { ), ); - final List> tools = [ + final List tools = [ CalculatorTool(), OpenAIDallETool( apiKey: openAiKey,