
Commit

feat: Reduce input stream for PromptTemplate, LLM, ChatModel, Retriever and Tool (#388)
davidmigloz committed Apr 26, 2024
1 parent 827e262 commit b59bcd4
Showing 28 changed files with 373 additions and 106 deletions.
6 changes: 3 additions & 3 deletions docs/modules/memory/memory.md
@@ -42,7 +42,7 @@ You may want to use this class directly if you are managing memory outside of a
```dart
final history = ChatMessageHistory();
- history.addUserChatMessage('hi!');
+ history.addHumanChatMessage('hi!');
history.addAIChatMessage('whats up?');
print(await history.getChatMessages());
@@ -60,7 +60,7 @@ We can first extract it as a string.
```dart
final memory = ConversationBufferMemory();
- memory.chatHistory.addUserChatMessage('hi!');
+ memory.chatHistory.addHumanChatMessage('hi!');
memory.chatHistory.addAIChatMessage('whats up
print(await memory.loadMemoryVariables());
@@ -72,7 +72,7 @@ We can also get the history as a list of messages:
```dart
final memory = ConversationBufferMemory(returnMessages: true);
- memory.chatHistory.addUserChatMessage('hi!');
+ memory.chatHistory.addHumanChatMessage('hi!');
memory.chatHistory.addAIChatMessage('whats up?');
print(await memory.loadMemoryVariables());
4 changes: 2 additions & 2 deletions docs/modules/model_io/models/llms/how_to/fake_llm.md
@@ -4,14 +4,14 @@ We expose some fake LLM classes that can be used for testing. This allows you
to mock out calls to the LLM and simulate what would happen if the LLM
responded in a certain way.

- ## FakeListLLM
+ ## FakeLLM

You can configure a list of responses that the LLM will return in order.

Example:
```dart
test('Test LLMChain call', () async {
- final model = FakeListLLM(responses: ['Hello world!']);
+ final model = FakeLLM(responses: ['Hello world!']);
final prompt = PromptTemplate.fromTemplate('Print {foo}');
final chain = LLMChain(prompt: prompt, llm: model);
final res = await chain.call({'foo': 'Hello world!'});
8 changes: 8 additions & 0 deletions examples/browser_summarizer/pubspec.lock
@@ -382,6 +382,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "0.2.0"
+ rxdart:
+   dependency: transitive
+   description:
+     name: rxdart
+     sha256: "0c7c0cedd93788d996e33041ffecda924cc54389199cde4e6a34b440f50044cb"
+     url: "https://pub.dev"
+   source: hosted
+   version: "0.27.7"
shared_preferences:
dependency: "direct main"
description:
8 changes: 8 additions & 0 deletions examples/docs_examples/pubspec.lock
@@ -356,6 +356,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "0.2.0"
+ rxdart:
+   dependency: transitive
+   description:
+     name: rxdart
+     sha256: "0c7c0cedd93788d996e33041ffecda924cc54389199cde4e6a34b440f50044cb"
+     url: "https://pub.dev"
+   source: hosted
+   version: "0.27.7"
source_span:
dependency: transitive
description:
8 changes: 8 additions & 0 deletions examples/hello_world_backend/pubspec.lock
@@ -165,6 +165,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "1.9.0"
+ rxdart:
+   dependency: transitive
+   description:
+     name: rxdart
+     sha256: "0c7c0cedd93788d996e33041ffecda924cc54389199cde4e6a34b440f50044cb"
+     url: "https://pub.dev"
+   source: hosted
+   version: "0.27.7"
shelf:
dependency: "direct main"
description:
8 changes: 8 additions & 0 deletions examples/hello_world_cli/pubspec.lock
@@ -157,6 +157,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "1.9.0"
+ rxdart:
+   dependency: transitive
+   description:
+     name: rxdart
+     sha256: "0c7c0cedd93788d996e33041ffecda924cc54389199cde4e6a34b440f50044cb"
+     url: "https://pub.dev"
+   source: hosted
+   version: "0.27.7"
source_span:
dependency: transitive
description:
8 changes: 8 additions & 0 deletions examples/hello_world_flutter/pubspec.lock
@@ -210,6 +210,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "6.1.1"
+ rxdart:
+   dependency: transitive
+   description:
+     name: rxdart
+     sha256: "0c7c0cedd93788d996e33041ffecda924cc54389199cde4e6a34b440f50044cb"
+     url: "https://pub.dev"
+   source: hosted
+   version: "0.27.7"
sky_engine:
dependency: transitive
description: flutter
2 changes: 1 addition & 1 deletion packages/langchain/example/langchain_example.dart
@@ -5,7 +5,7 @@ void main() async {
final promptTemplate = PromptTemplate.fromTemplate(
'tell me a joke about {subject}',
);
- final llm = FakeListLLM(
+ final llm = FakeLLM(
responses: ['Why did the AI go on a diet? Because it had too many bytes!'],
);
final chain = promptTemplate.pipe(llm).pipe(const StringOutputParser());
4 changes: 2 additions & 2 deletions packages/langchain/test/chains/base_test.dart
@@ -94,15 +94,15 @@ void main() {

group('Runnable tests', () {
test('Chain as Runnable', () async {
- final model = FakeListLLM(responses: ['Hello world!']);
+ final model = FakeLLM(responses: ['Hello world!']);
final prompt = PromptTemplate.fromTemplate('Print {foo}');
final run = LLMChain(prompt: prompt, llm: model);
final res = await run.invoke({'foo': 'Hello world!'});
expect(res[LLMChain.defaultOutputKey], 'Hello world!');
});

test('Streaming Chain', () async {
- final model = FakeListLLM(responses: ['Hello world!']);
+ final model = FakeLLM(responses: ['Hello world!']);
final prompt = PromptTemplate.fromTemplate('Print {foo}');
final run = LLMChain(prompt: prompt, llm: model);
final stream = run.stream({'foo': 'Hello world!'});
@@ -37,7 +37,7 @@ void main() {
}

test('Test MapReduceDocumentsChain with LLM', () async {
- final model = FakeListLLM(
+ final model = FakeLLM(
responses: [
// Summarize this content: Hello 1!
'1',
@@ -4,7 +4,7 @@ import 'package:test/test.dart';
void main() {
group('ReduceDocumentsChain tests', () {
test('Test reduce', () async {
- final llm = FakeListLLM(
+ final llm = FakeLLM(
responses: [
// Summarize this content: Hello 1!\n\nHello 2!\n\nHello 3!\n\nHello 4!
'Hello 1234!',
@@ -32,7 +32,7 @@ void main() {
});

test('Test reduce and collapse', () async {
- final llm = FakeListLLM(
+ final llm = FakeLLM(
responses: [
// Collapse this content: Hello 1!\n\nHello 2!\n\nHello 3!
'Hello 123!',
22 changes: 8 additions & 14 deletions packages/langchain_core/lib/src/chat_models/base.dart
@@ -2,6 +2,7 @@ import 'package:meta/meta.dart';

import '../language_models/language_models.dart';
import '../prompts/types.dart';
+ import '../utils/reduce.dart';
import 'types.dart';

/// {@template base_chat_model}
@@ -15,22 +16,15 @@ abstract class BaseChatModel<Options extends ChatModelOptions>
required super.defaultOptions,
});

- /// Runs the chat model on the given prompt value.
- ///
- /// - [input] The prompt value to pass into the model.
- /// - [options] Generation options to pass into the Chat Model.
- ///
- /// Example:
- /// ```dart
- /// final result = await chat.invoke(
- /// PromptValue.chat([ChatMessage.humanText('say hi!')]),
- /// );
- /// ```
@override
- Future<ChatResult> invoke(
- final PromptValue input, {
+ Stream<ChatResult> streamFromInputStream(
+ final Stream<PromptValue> inputStream, {
final Options? options,
- });
+ }) async* {
+ final input = await inputStream.toList();
+ final reduced = reduce<PromptValue>(input);
+ yield* stream(reduced, options: options);
+ }

/// Runs the chat model on the given messages and returns a chat message.
///
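The practical effect of the hunk above: when a chat model is fed by a stream of prompt values, the incoming stream is now collected in full, reduced to a single `PromptValue`, and only then streamed through the model. A rough sketch of that behaviour follows; `FakeEchoChatModel` and the exact reduction of two string prompts into one concatenated prompt are assumptions for illustration and are not part of this diff.

```dart
import 'package:langchain/langchain.dart';

void main() async {
  // Assumed test helper (not part of this diff): echoes the prompt back.
  final chat = FakeEchoChatModel();

  // Two partial prompt values arriving over time.
  final inputStream = Stream.fromIterable([
    PromptValue.string('Hello '),
    PromptValue.string('world!'),
  ]);

  // The new streamFromInputStream collects the stream, reduces it to a
  // single PromptValue, and then delegates to stream().
  await for (final chunk in chat.streamFromInputStream(inputStream)) {
    print(chunk.output.content); // Presumably echoes 'Hello world!'.
  }
}
```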
2 changes: 1 addition & 1 deletion packages/langchain_core/lib/src/chat_models/types.dart
@@ -31,7 +31,7 @@ class ChatResult extends LanguageModelResult<AIChatMessage> {
String get outputAsString => output.content;

@override
- LanguageModelResult<AIChatMessage> concat(
+ ChatResult concat(
final LanguageModelResult<AIChatMessage> other,
) {
return ChatResult(
10 changes: 0 additions & 10 deletions packages/langchain_core/lib/src/language_models/base.dart
@@ -25,16 +25,6 @@ abstract class BaseLanguageModel<
/// Return type of language model.
String get modelType;

- /// Runs the Language Model on the given prompt value.
- ///
- /// - [input] The prompt value to pass into the model.
- /// - [options] Generation options to pass into the model.
- @override
- Future<Output> invoke(
- final PromptValue input, {
- final Options? options,
- });

/// Tokenizes the given prompt using the encoding used by the language
/// model.
///
17 changes: 0 additions & 17 deletions packages/langchain_core/lib/src/llms/base.dart
@@ -16,23 +16,6 @@ abstract class BaseLLM<Options extends LLMOptions>
required super.defaultOptions,
});

- /// Runs the LLM on the given prompt value.
- ///
- /// - [input] The prompt value to pass into the model.
- /// - [options] Generation options to pass into the LLM.
- ///
- /// Example:
- /// ```dart
- /// final result = await openai.invoke(
- /// PromptValue.string('Tell me a joke.'),
- /// );
- /// ```
- @override
- Future<LLMResult> invoke(
- final PromptValue input, {
- final Options? options,
- });

/// Runs the LLM on the given String prompt and returns a String with the
/// generated text.
///
22 changes: 20 additions & 2 deletions packages/langchain_core/lib/src/llms/fake.dart
@@ -7,9 +7,9 @@ import 'types.dart';
/// Fake LLM for testing.
/// You can pass in a list of responses to return in order when called.
/// {@endtemplate}
- class FakeListLLM extends SimpleLLM {
+ class FakeLLM extends SimpleLLM {
/// {@macro fake_list_llm}
- FakeListLLM({
+ FakeLLM({
required this.responses,
}) : super(defaultOptions: const LLMOptions());

@@ -29,6 +29,24 @@ class FakeListLLM extends SimpleLLM {
return Future<String>.value(responses[_i++ % responses.length]);
}

+ @override
+ Stream<LLMResult> stream(
+ final PromptValue input, {
+ final LLMOptions? options,
+ }) {
+ final res = responses[_i++ % responses.length].split('');
+ return Stream.fromIterable(res).map(
+ (final item) => LLMResult(
+ id: 'fake-echo',
+ output: item,
+ finishReason: FinishReason.unspecified,
+ metadata: const {},
+ usage: const LanguageModelUsage(),
+ streaming: true,
+ ),
+ );
+ }

@override
Future<List<int>> tokenize(
final PromptValue promptValue, {
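With the stream override added above, FakeLLM emits its canned response one character at a time, so streaming pipelines can be exercised in tests. Below is a minimal sketch in the style of the example file earlier in this diff; the character-by-character chunking follows from the split('') above, everything else is illustrative.

```dart
import 'package:langchain/langchain.dart';

void main() async {
  final promptTemplate = PromptTemplate.fromTemplate(
    'tell me a joke about {subject}',
  );
  final llm = FakeLLM(responses: ['Bytes!']);
  final chain = promptTemplate.pipe(llm).pipe(const StringOutputParser());

  // FakeLLM streams one character per chunk; the output parser passes
  // each chunk straight through.
  await for (final chunk in chain.stream({'subject': 'AI'})) {
    print(chunk); // 'B', 'y', 't', 'e', 's', '!'
  }
}
```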
2 changes: 1 addition & 1 deletion packages/langchain_core/lib/src/llms/types.dart
@@ -31,7 +31,7 @@ class LLMResult extends LanguageModelResult<String> {
String get outputAsString => output;

@override
- LanguageModelResult<String> concat(
+ LLMResult concat(
final LanguageModelResult<String> other,
) {
return LLMResult(
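The narrower return type (LLMResult instead of LanguageModelResult&lt;String&gt;) means streamed chunks can be folded back into a single result without casting. An illustrative sketch, assuming concat joins the chunk outputs (its body is outside this hunk):

```dart
import 'package:langchain/langchain.dart';

void main() async {
  final llm = FakeLLM(responses: ['Hi!']);

  // FakeLLM.stream emits one LLMResult per character (see fake.dart above).
  final chunks = await llm.stream(PromptValue.string('Say hi')).toList();

  // With the covariant return type, the fold stays typed as LLMResult.
  final LLMResult merged =
      chunks.reduce((final acc, final chunk) => acc.concat(chunk));
  print(merged.output); // Presumably 'Hi!', assuming concat joins outputs.
}
```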
42 changes: 6 additions & 36 deletions packages/langchain_core/lib/src/prompts/base_prompt.dart
@@ -3,6 +3,7 @@ import 'package:meta/meta.dart';

import '../langchain/types.dart';
import '../runnables/runnable.dart';
+ import '../utils/reduce.dart';
import 'template.dart';
import 'types.dart';

@@ -68,46 +69,15 @@ abstract base class BasePromptTemplate
return Future.value(formatPrompt(input));
}

- @override
- Stream<PromptValue> stream(
- final InputValues input, {
- final BaseLangChainOptions? options,
- }) {
- return streamFromInputStream(
- Stream.value(input).asBroadcastStream(),
- options: options,
- );
- }

@override
Stream<PromptValue> streamFromInputStream(
final Stream<InputValues> inputStream, {
final BaseLangChainOptions? options,
- }) {
- final userKeys = inputVariables.difference(
- partialVariables?.keys.toSet() ?? {},
- );
- final userInput = <String, dynamic>{};
- return inputStream
- .asyncMap((final InputValues inputValues) {
- for (final input in inputValues.entries) {
- final key = input.key;
- final value = input.value;
- if (value is String) {
- userInput[key] = (userInput[key] as String? ?? '') + value;
- } else {
- userInput[key] = value;
- }
- }
- final hasAllUserValues = userKeys.every(userInput.containsKey);
- if (hasAllUserValues) {
- return formatPrompt(userInput);
- } else {
- return null;
- }
- })
- .where((final res) => res != null)
- .cast();
+ }) async* {
+ final List<InputValues> input = await inputStream.toList();
+ final InputValues reduced =
+ input.isEmpty ? const {} : reduce<InputValues>(input);
+ yield* stream(reduced, options: options);
}

/// Format the prompt given the input values and return a formatted string.
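The incremental accumulation removed above is replaced by one call to the shared reduce helper from utils/reduce.dart, which is not included in this diff. As a rough mental model only, a reducer over InputValues maps could look like the sketch below, mirroring the string-concatenation behaviour of the removed code; the real implementation may differ.

```dart
/// Hypothetical sketch (not the actual utils/reduce.dart): merge a list of
/// partial input-value maps into one, concatenating string values and
/// letting later non-string values overwrite earlier ones.
Map<String, dynamic> reduceInputValues(
  final List<Map<String, dynamic>> values,
) {
  final merged = <String, dynamic>{};
  for (final inputValues in values) {
    for (final entry in inputValues.entries) {
      final previous = merged[entry.key];
      final value = entry.value;
      merged[entry.key] = (previous is String && value is String)
          ? previous + value
          : value;
    }
  }
  return merged;
}

void main() {
  // Streaming two partial inputs for the same variable...
  final reduced = reduceInputValues([
    {'subject': 'artificial '},
    {'subject': 'intelligence'},
  ]);
  print(reduced); // {subject: artificial intelligence}
}
```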