Skip to content

Commit

Permalink
fix: Allow async functions in Runnable.mapInput (#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz committed Apr 30, 2024
1 parent c7dc979 commit e4c3509
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
4 changes: 3 additions & 1 deletion packages/langchain_core/lib/src/runnables/input_map.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import 'dart:async';

import 'runnable.dart';
import 'types.dart';

Expand Down Expand Up @@ -32,7 +34,7 @@ class RunnableMapInput<RunInput extends Object, RunOutput extends Object>
: super(defaultOptions: const RunnableOptions());

/// A function that maps [RunInput] to [RunOutput].
final RunOutput Function(RunInput input) inputMapper;
final FutureOr<RunOutput> Function(RunInput input) inputMapper;

/// Invokes the [RunnableMapInput] on the given [input].
///
Expand Down
2 changes: 1 addition & 1 deletion packages/langchain_core/lib/src/runnables/runnable.dart
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ abstract class Runnable<RunInput extends Object?,
/// - [inputMapper] - a function that maps [RunInput] to [RunOutput].
static Runnable<RunInput, RunnableOptions, RunOutput>
mapInput<RunInput extends Object, RunOutput extends Object>(
final RunOutput Function(RunInput input) inputMapper,
final FutureOr<RunOutput> Function(RunInput input) inputMapper,
) {
return RunnableMapInput<RunInput, RunOutput>(inputMapper);
}
Expand Down
16 changes: 15 additions & 1 deletion packages/langchain_core/test/runnables/input_map_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import 'package:test/test.dart';

void main() {
group('RunnableMapInput tests', () {
test('RunnableMapInput from Runnable.getItemFromMap', () async {
test('Invoke RunnableMapInput', () async {
final chain =
Runnable.mapInput<Map<String, dynamic>, Map<String, dynamic>>(
(final input) => {
Expand All @@ -15,6 +15,20 @@ void main() {
expect(res, {'input': 'foo1bar1'});
});

test('Invoke async RunnableMapInput', () async {
Future<int> asyncFunc(final Map<String, dynamic> input) async {
await Future<void>.delayed(const Duration(milliseconds: 100));
return input.length;
}

final chain = Runnable.mapInput<Map<String, dynamic>, int>(
(final input) async => asyncFunc(input),
);

final res = await chain.invoke({'foo': 'foo1', 'bar': 'bar1'});
expect(res, 2);
});

test('Streaming RunnableMapInput', () async {
final chain =
Runnable.mapInput<Map<String, dynamic>, Map<String, dynamic>>(
Expand Down

0 comments on commit e4c3509

Please sign in to comment.