diff --git a/packages/langchain_core/lib/src/runnables/input_map.dart b/packages/langchain_core/lib/src/runnables/input_map.dart index fb6b8cf2..8e5c05b8 100644 --- a/packages/langchain_core/lib/src/runnables/input_map.dart +++ b/packages/langchain_core/lib/src/runnables/input_map.dart @@ -1,3 +1,5 @@ +import 'dart:async'; + import 'runnable.dart'; import 'types.dart'; @@ -32,7 +34,7 @@ class RunnableMapInput : super(defaultOptions: const RunnableOptions()); /// A function that maps [RunInput] to [RunOutput]. - final RunOutput Function(RunInput input) inputMapper; + final FutureOr Function(RunInput input) inputMapper; /// Invokes the [RunnableMapInput] on the given [input]. /// diff --git a/packages/langchain_core/lib/src/runnables/runnable.dart b/packages/langchain_core/lib/src/runnables/runnable.dart index 488ee33f..792bc80a 100644 --- a/packages/langchain_core/lib/src/runnables/runnable.dart +++ b/packages/langchain_core/lib/src/runnables/runnable.dart @@ -133,7 +133,7 @@ abstract class Runnable mapInput( - final RunOutput Function(RunInput input) inputMapper, + final FutureOr Function(RunInput input) inputMapper, ) { return RunnableMapInput(inputMapper); } diff --git a/packages/langchain_core/test/runnables/input_map_test.dart b/packages/langchain_core/test/runnables/input_map_test.dart index 18dafe49..670846c3 100644 --- a/packages/langchain_core/test/runnables/input_map_test.dart +++ b/packages/langchain_core/test/runnables/input_map_test.dart @@ -3,7 +3,7 @@ import 'package:test/test.dart'; void main() { group('RunnableMapInput tests', () { - test('RunnableMapInput from Runnable.getItemFromMap', () async { + test('Invoke RunnableMapInput', () async { final chain = Runnable.mapInput, Map>( (final input) => { @@ -15,6 +15,20 @@ void main() { expect(res, {'input': 'foo1bar1'}); }); + test('Invoke async RunnableMapInput', () async { + Future asyncFunc(final Map input) async { + await Future.delayed(const Duration(milliseconds: 100)); + return input.length; + } + + final chain = Runnable.mapInput, int>( + (final input) async => asyncFunc(input), + ); + + final res = await chain.invoke({'foo': 'foo1', 'bar': 'bar1'}); + expect(res, 2); + }); + test('Streaming RunnableMapInput', () async { final chain = Runnable.mapInput, Map>(