
Commit

feat: Implement .batch support (#370)
davidmigloz committed Apr 9, 2024
1 parent 8f514dd commit d254f92
Showing 26 changed files with 509 additions and 56 deletions.
37 changes: 35 additions & 2 deletions docs/expression_language/interface.md
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ The type of the input and output varies by component:
| `LLM` | `PromptValue` | `LLMResult` |
| `ChatModel` | `PromptValue` | `ChatResult` |
| `Chain` | `Map<String, dynamic>` | `Map<String, dynamic>` |
| `OutputParser` | `LanguageModelResult` | Parser output type |
| `OutputParser` | Runnable input type | Parser output type |
| `Tool` | `Map<String, dynamic>` | `String` |
| `RunnableSequence` | First input type | Last output type |
| `RunnableMap` | Runnable input type | `Map<String, dynamic>` |
@@ -104,7 +104,40 @@ await for (final res in stream) {

### Batch

Batch is not supported yet.
Batches the invocation of the `Runnable` on the given `inputs`.

```dart
final res = await chain.batch([
{'topic': 'bears'},
{'topic': 'cats'},
]);
print(res);
//['Why did the bear break up with his girlfriend? Because she was too "grizzly" for him!',
// 'Why was the cat sitting on the computer? Because it wanted to keep an eye on the mouse!']
```

If the underlying provider supports batching, this method will try to batch the calls to the provider. Otherwise, it will just call `invoke` on each input concurrently. You can configure the concurrency limit by setting the `concurrencyLimit` field in the `options` parameter.
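The chunked-concurrency fallback described above can be sketched outside Dart. The following Python analogue is a hypothetical illustration, not the library's API: `batch`, `concurrencyLimit`, and the `invoke` callback mirror the names used in this commit, while everything else is invented for the sketch.

```python
import asyncio

async def batch(invoke, inputs, options=None, default_concurrency=5):
    """Sketch of the fallback: call `invoke` on each input concurrently,
    in chunks whose size is bounded by a concurrency limit."""
    # The options list may be None, hold one shared element, or one per input.
    assert options is None or len(options) in (1, len(inputs))
    # As in the commit, the limit is read from the first options entry.
    limit = (options[0] if options else {}).get(
        "concurrencyLimit", default_concurrency,
    )
    results = []
    for start in range(0, len(inputs), limit):
        chunk = inputs[start:start + limit]
        # Pick the matching options for each input in this chunk.
        opts = [
            options[0] if options and len(options) == 1
            else options[start + i] if options
            else None
            for i in range(len(chunk))
        ]
        # Invoke the whole chunk concurrently, then move to the next chunk.
        results += await asyncio.gather(
            *(invoke(inp, opt) for inp, opt in zip(chunk, opts))
        )
    return results
```

With `concurrencyLimit: 1` this degenerates to sequential invocation; with a limit at or above `len(inputs)` everything runs in a single concurrent wave.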

You can provide call options to the `batch` method using the `options` parameter. It can be:
- `null`: the default options are used.
- List with 1 element: the same options are used for all inputs.
- List with the same length as the inputs: each input gets its own options.

```dart
final res = await chain.batch(
[
{'topic': 'bears'},
{'topic': 'cats'},
],
options: [
const ChatOpenAIOptions(model: 'gpt-3.5-turbo', temperature: 0.5),
const ChatOpenAIOptions(model: 'gpt-4', temperature: 0.7),
],
);
print(res);
//['Why did the bear break up with his girlfriend? Because he couldn't bear the relationship anymore!',
// 'Why don't cats play poker in the jungle? Because there's too many cheetahs!']
```

## Runnable types

46 changes: 46 additions & 0 deletions examples/docs_examples/bin/expression_language/interface.dart
@@ -10,6 +10,8 @@ void main(final List<String> arguments) async {
// Runnable interface
await _runnableInterfaceInvoke();
await _runnableInterfaceStream();
await _runnableInterfaceBatch();
await _runnableInterfaceBatchOptions();

// Runnable types
await _runnableTypesRunnableSequence();
@@ -72,6 +74,50 @@ Future<void> _runnableInterfaceStream() async {
// 15: !
}

Future<void> _runnableInterfaceBatch() async {
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final model = ChatOpenAI(apiKey: openaiApiKey);

final promptTemplate = ChatPromptTemplate.fromTemplate(
'Tell me a joke about {topic}',
);

final chain = promptTemplate.pipe(model).pipe(const StringOutputParser());

final res = await chain.batch([
{'topic': 'bears'},
{'topic': 'cats'},
]);
print(res);
//['Why did the bear break up with his girlfriend? Because she was too "grizzly" for him!',
// 'Why was the cat sitting on the computer? Because it wanted to keep an eye on the mouse!']
}

Future<void> _runnableInterfaceBatchOptions() async {
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final model = ChatOpenAI(apiKey: openaiApiKey);

final promptTemplate = ChatPromptTemplate.fromTemplate(
'Tell me a joke about {topic}',
);

final chain = promptTemplate.pipe(model).pipe(const StringOutputParser());

final res = await chain.batch(
[
{'topic': 'bears'},
{'topic': 'cats'},
],
options: [
const ChatOpenAIOptions(model: 'gpt-3.5-turbo', temperature: 0.5),
const ChatOpenAIOptions(model: 'gpt-4', temperature: 0.7),
],
);
print(res);
//['Why did the bear break up with his girlfriend? Because he couldn't bear the relationship anymore!',
// 'Why don't cats play poker in the jungle? Because there's too many cheetahs!']
}

Future<void> _runnableTypesRunnableSequence() async {
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final model = ChatOpenAI(apiKey: openaiApiKey);
2 changes: 1 addition & 1 deletion packages/langchain/README.md
@@ -18,7 +18,7 @@ The components can be grouped into a few core modules:

![LangChain.dart](https://raw.githubusercontent.com/davidmigloz/langchain_dart/main/docs/img/langchain.dart.png)

- 📃 **Model I/O:** LangChain offers a unified API for interacting with various LLM providers (e.g. OpenAI, Google, Mistral, Ollama, etc.), allowing developers to switch between them with ease. Additionally, it provides tools for managing model inputs (prompt templates and example selectors) and parses the resulting model outputs (output parsers).
- 📃 **Model I/O:** LangChain offers a unified API for interacting with various LLM providers (e.g. OpenAI, Google, Mistral, Ollama, etc.), allowing developers to switch between them with ease. Additionally, it provides tools for managing model inputs (prompt templates and example selectors) and parsing the resulting model outputs (output parsers).
- 📚 **Retrieval:** assists in loading user data (via document loaders), transforming it (with text splitters), extracting its meaning (using embedding models), storing (in vector stores) and retrieving it (through retrievers) so that it can be used to ground the model's responses (i.e. Retrieval-Augmented Generation or RAG).
- 🤖 **Agents:** "bots" that leverage LLMs to make informed decisions about which available tools (such as web search, calculators, database lookup, etc.) to use to accomplish the designated task.

4 changes: 3 additions & 1 deletion packages/langchain_core/lib/src/chains/types.dart
@@ -11,5 +11,7 @@ typedef ChainValues = Map<String, dynamic>;
@immutable
class ChainOptions extends BaseLangChainOptions {
/// {@macro chain_options}
const ChainOptions();
const ChainOptions({
super.concurrencyLimit,
});
}
17 changes: 16 additions & 1 deletion packages/langchain_core/lib/src/chat_models/types.dart
@@ -8,7 +8,9 @@ import '../language_models/language_models.dart';
/// {@endtemplate}
class ChatModelOptions extends LanguageModelOptions {
/// {@macro chat_model_options}
const ChatModelOptions();
const ChatModelOptions({
super.concurrencyLimit,
});
}

/// {@template chat_result}
@@ -44,6 +46,19 @@ class ChatResult extends LanguageModelResult<AIChatMessage> {
streaming: other.streaming,
);
}

@override
String toString() {
return '''
ChatResult{
id: $id,
output: $output,
finishReason: $finishReason,
metadata: $metadata,
usage: $usage,
streaming: $streaming
}''';
}
}

/// {@template chat_message}
4 changes: 3 additions & 1 deletion packages/langchain_core/lib/src/langchain/types.dart
@@ -8,5 +8,7 @@ import '../runnables/types.dart';
@immutable
class BaseLangChainOptions extends RunnableOptions {
/// {@macro base_lang_chain_options}
const BaseLangChainOptions();
const BaseLangChainOptions({
super.concurrencyLimit,
});
}
17 changes: 3 additions & 14 deletions packages/langchain_core/lib/src/language_models/types.dart
@@ -9,7 +9,9 @@ import '../langchain/types.dart';
@immutable
abstract class LanguageModelOptions extends BaseLangChainOptions {
/// {@macro language_model_options}
const LanguageModelOptions();
const LanguageModelOptions({
super.concurrencyLimit,
});
}

/// {@template language_model}
@@ -75,19 +77,6 @@ abstract class LanguageModelResult<O extends Object> {

/// Merges this result with another by concatenating the outputs.
LanguageModelResult<O> concat(final LanguageModelResult<O> other);

@override
String toString() {
return '''
LanguageModelResult{
id: $id,
output: $output,
finishReason: $finishReason,
metadata: $metadata,
usage: $usage,
streaming: $streaming
}''';
}
}

/// {@template language_model_usage}
20 changes: 19 additions & 1 deletion packages/langchain_core/lib/src/llms/types.dart
@@ -1,11 +1,16 @@
import 'package:meta/meta.dart';

import '../language_models/types.dart';

/// {@template llm_options}
/// Generation options to pass into the LLM.
/// {@endtemplate}
@immutable
class LLMOptions extends LanguageModelOptions {
/// {@macro llm_options}
const LLMOptions();
const LLMOptions({
super.concurrencyLimit,
});
}

/// {@template llm_result}
@@ -41,4 +46,17 @@ class LLMResult extends LanguageModelResult<String> {
streaming: other.streaming,
);
}

@override
String toString() {
return '''
LLMResult{
id: $id,
output: $output,
finishReason: $finishReason,
metadata: $metadata,
usage: $usage,
streaming: $streaming
}''';
}
}
4 changes: 3 additions & 1 deletion packages/langchain_core/lib/src/output_parsers/types.dart
@@ -8,5 +8,7 @@ import '../langchain/types.dart';
@immutable
class OutputParserOptions extends BaseLangChainOptions {
/// {@macro output_parser_options}
const OutputParserOptions();
const OutputParserOptions({
super.concurrencyLimit,
});
}
10 changes: 8 additions & 2 deletions packages/langchain_core/lib/src/runnables/map.dart
@@ -62,7 +62,10 @@ class RunnableMap<RunInput extends Object>
final output = <String, dynamic>{};

await Future.forEach(steps.entries, (final entry) async {
output[entry.key] = await entry.value.invoke(input, options: options);
output[entry.key] = await entry.value.invoke(
input,
options: entry.value.getCompatibleOptions(options),
);
});

return output;
@@ -87,7 +90,10 @@ class RunnableMap<RunInput extends Object>
return StreamGroup.merge(
steps.entries.map((final entry) {
return entry.value
.streamFromInputStream(inputStream, options: options)
.streamFromInputStream(
inputStream,
options: entry.value.getCompatibleOptions(options),
)
.map((final output) => {entry.key: output});
}),
).asBroadcastStream();
52 changes: 52 additions & 0 deletions packages/langchain_core/lib/src/runnables/runnable.dart
@@ -1,5 +1,6 @@
import 'dart:async';

import '../../utils.dart';
import 'binding.dart';
import 'function.dart';
import 'input_getter.dart';
@@ -118,6 +119,49 @@ abstract class Runnable<RunInput extends Object?,
final CallOptions? options,
});

/// Batches the invocation of the [Runnable] on the given [inputs].
///
/// If the underlying provider supports batching, this method will try to
/// batch the calls to the provider. Otherwise, it will just call [invoke] on
/// each input concurrently.
///
/// You can configure the concurrency limit by setting the `concurrencyLimit`
/// field in the [options] parameter.
///
/// - [inputs] - the inputs to invoke the [Runnable] on concurrently.
/// - [options] - the options to use when invoking the [Runnable]. It can be:
/// * `null`: the default options are used.
/// * List with 1 element: the same options are used for all inputs.
/// * List with the same length as the inputs: each input gets its own options.
Future<List<RunOutput>> batch(
final List<RunInput> inputs, {
final List<CallOptions>? options,
}) async {
// By default, it just calls `.invoke` on each input concurrently.
// Subclasses should override this method if they support batching
assert(
options == null || options.length == 1 || options.length == inputs.length,
);

final finalOptions = options?.first ?? defaultOptions;
final concurrencyLimit = finalOptions.concurrencyLimit;

var index = 0;
final results = <RunOutput>[];
for (final chunk in chunkList(inputs, chunkSize: concurrencyLimit)) {
final chunkResults = await Future.wait(
chunk.map(
(final input) => invoke(
input,
options: options?.length == 1 ? options![0] : options?[index++],
),
),
);
results.addAll(chunkResults);
}
return results;
}

/// Streams the output of invoking the [Runnable] on the given [input].
///
/// - [input] - the input to invoke the [Runnable] on.
@@ -168,4 +212,12 @@ abstract class Runnable<RunInput extends Object?,
options: options,
);
}

/// Returns the given [options] if they are compatible with the [Runnable],
/// otherwise returns `null`.
CallOptions? getCompatibleOptions(
final RunnableOptions? options,
) {
return options is CallOptions ? options : null;
}
}
25 changes: 20 additions & 5 deletions packages/langchain_core/lib/src/runnables/sequence.dart
@@ -102,10 +102,16 @@ class RunnableSequence<RunInput extends Object?, RunOutput extends Object?>
Object? nextStepInput = input;

for (final step in [first, ...middle]) {
nextStepInput = await step.invoke(nextStepInput, options: options);
nextStepInput = await step.invoke(
nextStepInput,
options: step.getCompatibleOptions(options),
);
}

return last.invoke(nextStepInput, options: options);
return last.invoke(
nextStepInput,
options: last.getCompatibleOptions(options),
);
}

@override
@@ -126,21 +132,30 @@
}) {
Stream<Object?> nextStepStream;
try {
nextStepStream = first.streamFromInputStream(inputStream);
nextStepStream = first.streamFromInputStream(
inputStream,
options: first.getCompatibleOptions(options),
);
} on TypeError catch (e) {
_throwInvalidInputTypeStream(e, first);
}

for (final step in middle) {
try {
nextStepStream = step.streamFromInputStream(nextStepStream);
nextStepStream = step.streamFromInputStream(
nextStepStream,
options: step.getCompatibleOptions(options),
);
} on TypeError catch (e) {
_throwInvalidInputTypeStream(e, step);
}
}

try {
return last.streamFromInputStream(nextStepStream);
return last.streamFromInputStream(
nextStepStream,
options: last.getCompatibleOptions(options),
);
} on TypeError catch (e) {
_throwInvalidInputTypeStream(e, last);
}
