From 34cf10bd485618bff4cddb5b29a1b46ac9f3a9fa Mon Sep 17 00:00:00 2001 From: David Miguel Date: Wed, 12 Jul 2023 00:23:18 +0200 Subject: [PATCH] feat: Add support for ReduceDocumentsChain (#70) --- packages/langchain/lib/src/chains/base.dart | 11 +- .../src/chains/combine_documents/base.dart | 21 +- .../combine_documents/combine_documents.dart | 1 + .../src/chains/combine_documents/reduce.dart | 201 ++++++++++++++++++ .../src/chains/combine_documents/stuff.dart | 18 +- .../test/chains/combine_documents/reduce.dart | 77 +++++++ .../test/chains/combine_documents/stuff.dart | 20 +- 7 files changed, 336 insertions(+), 13 deletions(-) create mode 100644 packages/langchain/lib/src/chains/combine_documents/reduce.dart create mode 100644 packages/langchain/test/chains/combine_documents/reduce.dart diff --git a/packages/langchain/lib/src/chains/base.dart b/packages/langchain/lib/src/chains/base.dart index aabacacc..7ac25017 100644 --- a/packages/langchain/lib/src/chains/base.dart +++ b/packages/langchain/lib/src/chains/base.dart @@ -132,7 +132,8 @@ abstract class BaseChain { return Future.wait(inputs.map(call)); } - /// Convenience method for executing chain when there's a single output. + /// Convenience method for executing chain when there's a single string + /// output. /// /// The main difference between this method and [call] is that this method /// can only be used for chains that return a single output. If a Chain has @@ -146,7 +147,7 @@ abstract class BaseChain { /// Eg: `chain.run('Hello world!')` /// - A map of key->values, if the chain has multiple input keys. 
/// Eg: `chain.run({'foo': 'Hello', 'bar': 'world!'})` - Future run(final dynamic input) async { + Future run(final dynamic input) async { final outputKey = runOutputKey; final returnValues = await call(input, returnOnlyOutputs: true); return returnValues[outputKey]; @@ -161,12 +162,12 @@ abstract class BaseChain { final inputKeysSet = inputMap.keys.toSet(); final inputKeysSetLength = inputKeysSet.length; - if (inputKeysSetLength != inputKeys.length) { + if (inputKeysSetLength < inputKeys.length) { return false; } - final inputKeysSetDiff = inputKeysSet.difference(inputKeys); - if (inputKeysSetDiff.isNotEmpty) { + final inputKeysDiff = inputKeys.difference(inputKeysSet); + if (inputKeysDiff.isNotEmpty) { return false; } diff --git a/packages/langchain/lib/src/chains/combine_documents/base.dart b/packages/langchain/lib/src/chains/combine_documents/base.dart index 1f311039..2e7b5c8c 100644 --- a/packages/langchain/lib/src/chains/combine_documents/base.dart +++ b/packages/langchain/lib/src/chains/combine_documents/base.dart @@ -61,15 +61,30 @@ abstract class BaseCombineDocumentsChain extends BaseChain { }; } - // TODO add promptLength method to base chain the prompt length given the documents passed in + /// Returns the prompt length (number of tokens) given the documents passed + /// in. + /// + /// This can be used by a caller to determine whether passing in a list of + /// documents would exceed a certain prompt length. This is useful when trying + /// to ensure that the size of a prompt remains below a certain context limit. + /// + /// - [docs] is the list of documents to combine. + /// - [inputs] is a map of other inputs to use in the combination. + /// + /// Returns null if the combine method doesn't depend on the prompt length. + /// Otherwise, the length of the prompt in tokens. + Future promptLength( + final List docs, { + final InputValues inputs = const {}, + }); /// Combines the given [docs] into a single string. 
/// /// - [docs] is the list of documents to combine. /// - [inputs] is a map of other inputs to use in the combination. /// - /// Returns a tuple of the output string and any extra info to return. - Future<(dynamic output, Map extraInfo)> combineDocs( + /// Returns a tuple of the output and any extra info to return. + Future<(String output, Map extraInfo)> combineDocs( final List docs, { final InputValues inputs = const {}, }); diff --git a/packages/langchain/lib/src/chains/combine_documents/combine_documents.dart b/packages/langchain/lib/src/chains/combine_documents/combine_documents.dart index cd6328fc..de943631 100644 --- a/packages/langchain/lib/src/chains/combine_documents/combine_documents.dart +++ b/packages/langchain/lib/src/chains/combine_documents/combine_documents.dart @@ -1,2 +1,3 @@ export 'base.dart'; +export 'reduce.dart'; export 'stuff.dart'; diff --git a/packages/langchain/lib/src/chains/combine_documents/reduce.dart b/packages/langchain/lib/src/chains/combine_documents/reduce.dart new file mode 100644 index 00000000..dd13c443 --- /dev/null +++ b/packages/langchain/lib/src/chains/combine_documents/reduce.dart @@ -0,0 +1,201 @@ +import '../../documents/models/models.dart'; +import '../../model_io/prompts/models/models.dart'; +import 'base.dart'; +import 'stuff.dart'; + +/// {@template reduce_documents_chain} +/// Chain that combines documents by recursively reducing them if needed. +/// +/// This involves two chains: +/// - [combineDocumentsChain] this is the chain that combines the documents. +/// - [collapseDocumentsChain] this is the chain that collapses the documents +/// if they exceed [defaultTokenMax]. +/// +/// The chain works as follows: +/// - If the number of tokens resulting from formatting the prompt of +/// [combineDocumentsChain] is less than [defaultTokenMax], then +/// [combineDocumentsChain] is called with the documents and the result is +/// returned. 
+/// - Otherwise, the documents are split into groups of max [defaultTokenMax] +/// tokens and [collapseDocumentsChain] is called for each group. Then, the +/// resulting documents are combined by calling [combineDocumentsChain] and +/// the result is returned. +/// +/// Example: +/// ```dart +/// final finalPrompt = PromptTemplate.fromTemplate( +/// 'Summarize this content: {context}', +/// ); +/// final finalLlmChain = LLMChain(prompt: finalPrompt, llm: llm); +/// final combineDocsChain = StuffDocumentsChain(llmChain: finalLlmChain); +/// +/// final collapsePrompt = PromptTemplate.fromTemplate( +/// 'Collapse this content: {context}', +/// ); +/// final collapseLlmChain = LLMChain(prompt: collapsePrompt, llm: llm); +/// final collapseDocsChain = StuffDocumentsChain(llmChain: collapseLlmChain); +/// +/// final reduceChain = ReduceDocumentsChain( +/// combineDocumentsChain: combineDocsChain, +/// collapseDocumentsChain: collapseDocsChain, +/// ); +/// +/// const docs = [ +/// Document(pageContent: 'Hello world 1!'), +/// Document(pageContent: 'Hello world 2!'), +/// Document(pageContent: 'Hello world 3!'), +/// Document(pageContent: 'Hello world 4!'), +/// ]; +/// final res = await reduceChain.run(docs); +/// ``` +/// {@endtemplate} +class ReduceDocumentsChain extends BaseCombineDocumentsChain { + /// {@macro reduce_documents_chain} + const ReduceDocumentsChain({ + required this.combineDocumentsChain, + this.collapseDocumentsChain, + this.defaultTokenMax = 3000, + }); + + /// Final chain to call to combine documents. + /// This is typically a [StuffDocumentsChain]. + final BaseCombineDocumentsChain combineDocumentsChain; + + /// Chain to use to collapse documents if needed until they can all fit. + /// If null, [combineDocumentsChain] will be used. + /// This is typically a [StuffDocumentsChain]. + final BaseCombineDocumentsChain? collapseDocumentsChain; + + /// The maximum number of tokens to group documents into. 
For example, if + /// set to 3000 then documents will be grouped into chunks of no greater than + /// 3000 tokens before trying to combine them into a smaller chunk. + /// + /// This is useful to avoid exceeding the context size when combining the + /// documents. + /// + /// It is assumed that each document to combine is less than + /// [defaultTokenMax] tokens. + final int defaultTokenMax; + + @override + String get chainType => 'reduce_documents_chain'; + + @override + Future promptLength( + final List docs, { + final InputValues inputs = const {}, + }) async { + // This combine method doesn't depend on the prompt length. + return null; + } + + /// Combine multiple documents. + /// + /// - [docs] the documents to combine. It is assumed that each one is less + /// than [defaultTokenMax] tokens. + /// - [inputs] additional parameters to be passed to LLM calls (like other + /// input variables besides the documents). + /// + /// Returns a tuple of the output and any extra info to return. + @override + Future<(String output, Map extraInfo)> combineDocs( + final List docs, { + final InputValues inputs = const {}, + final int? tokenMax, + }) async { + final resultDocs = await _splitAndCollapseDocs( + docs, + inputs: inputs, + tokenMax: tokenMax, + ); + return combineDocumentsChain.combineDocs(resultDocs, inputs: inputs); + } + + /// Splits the documents into smaller chunks that are each less than + /// [tokenMax] tokens. And then collapses them into a single document. + Future> _splitAndCollapseDocs( + final List docs, { + final int? tokenMax, + final InputValues inputs = const {}, + }) async { + final finalTokenMax = tokenMax ?? defaultTokenMax; + final lengthFunc = combineDocumentsChain.promptLength; + + List resultDocs = docs; + int? 
numTokens = await lengthFunc(resultDocs, inputs: inputs); + + while (numTokens != null && numTokens > finalTokenMax) { + final newResultDocList = await _splitDocs( + resultDocs, + inputs, + lengthFunc, + finalTokenMax, + ); + resultDocs = []; + for (final docs in newResultDocList) { + final newDoc = await _collapseDocs(docs, inputs); + resultDocs.add(newDoc); + } + numTokens = await lengthFunc(resultDocs, inputs: inputs); + } + + return resultDocs; + } + + /// Split a list of documents into smaller lists of documents that are each + /// less than [tokenMax] tokens. + Future>> _splitDocs( + final List docs, + final InputValues inputs, + final Future Function( + List docs, { + InputValues inputs, + }) lengthFunc, + final int tokenMax, + ) async { + final List> newResultDocList = []; + List subResultDocs = []; + + for (final doc in docs) { + subResultDocs.add(doc); + final numTokens = await lengthFunc(subResultDocs, inputs: inputs); + if (numTokens != null && numTokens > tokenMax) { + assert( + subResultDocs.length > 1, + 'We should never have a single document that is longer than the tokenMax.', + ); + newResultDocList.add( + subResultDocs.sublist(0, subResultDocs.length - 1), + ); + subResultDocs = subResultDocs.sublist(subResultDocs.length - 1); + } + } + newResultDocList.add(subResultDocs); + return newResultDocList; + } + + /// Combines multiple documents into one using [collapseDocumentsChain] (or + /// [combineDocumentsChain] if [collapseDocumentsChain] is null). + /// The metadata of the different documents is also combined. + Future _collapseDocs( + final List docs, + final InputValues inputs, + ) async { + final collapseChain = collapseDocumentsChain ?? 
combineDocumentsChain; + final result = await collapseChain.run({ + ...inputs, + BaseCombineDocumentsChain.defaultInputKey: docs, + }); + final combinedMetadata = {...docs[0].metadata}; + for (var i = 1; i < docs.length; i++) { + docs[i].metadata.forEach((final key, final value) { + if (combinedMetadata.containsKey(key) && value is String) { + combinedMetadata[key] = '${combinedMetadata[key]}, $value'; + } else { + combinedMetadata[key] = value; + } + }); + } + return Document(pageContent: result, metadata: combinedMetadata); + } +} diff --git a/packages/langchain/lib/src/chains/combine_documents/stuff.dart b/packages/langchain/lib/src/chains/combine_documents/stuff.dart index bf55c9ea..f9dc9ddd 100644 --- a/packages/langchain/lib/src/chains/combine_documents/stuff.dart +++ b/packages/langchain/lib/src/chains/combine_documents/stuff.dart @@ -9,9 +9,9 @@ import 'base.dart'; /// This chain takes a list of documents and first combines them into a single /// string. It does this by formatting each document into a string with the /// [documentPrompt] and then joining them together with [documentSeparator]. -/// It then adds that new string to the inputs with the variable name set by -/// [llmChainStuffedDocumentInputKey]. Those inputs are then passed to the -/// [llmChain]. +/// It then adds that new resulting string to the inputs with the variable +/// name set by [llmChainStuffedDocumentInputKey]. Those inputs are then +/// passed to the [llmChain]. /// /// The content of each document is formatted using [documentPrompt]. /// By default, it just takes the content of the document. 
@@ -110,6 +110,16 @@ class StuffDocumentsChain extends BaseCombineDocumentsChain { } } + @override + Future promptLength( + final List docs, { + final InputValues inputs = const {}, + }) { + final llmInputs = _getInputs(docs, inputs); + final prompt = llmChain.prompt.formatPrompt(llmInputs); + return llmChain.llm.countTokens(prompt); + } + /// Stuff all documents into one prompt and pass to LLM. /// /// - [docs] the documents to combine. @@ -117,7 +127,7 @@ class StuffDocumentsChain extends BaseCombineDocumentsChain { /// /// Returns a tuple of the output string and any extra info to return. @override - Future<(dynamic output, Map extraInfo)> combineDocs( + Future<(String output, Map extraInfo)> combineDocs( final List docs, { final InputValues inputs = const {}, }) async { diff --git a/packages/langchain/test/chains/combine_documents/reduce.dart b/packages/langchain/test/chains/combine_documents/reduce.dart new file mode 100644 index 00000000..736eda6f --- /dev/null +++ b/packages/langchain/test/chains/combine_documents/reduce.dart @@ -0,0 +1,77 @@ +import 'package:langchain/src/chains/chains.dart'; +import 'package:langchain/src/documents/documents.dart'; +import 'package:langchain/src/model_io/llms/fake.dart'; +import 'package:langchain/src/model_io/prompts/prompts.dart'; +import 'package:test/test.dart'; + +void main() { + group('ReduceDocumentsChain tests', () { + test('Test reduce', () async { + final llm = FakeListLLM( + responses: [ + // Summarize this content: Hello 1!\n\nHello 2!\n\nHello 3!\n\nHello 4! 
+ 'Hello 1234!', + ], + ); + + final finalPrompt = PromptTemplate.fromTemplate( + 'Summarize this content: {context}', + ); + final finalLlmChain = LLMChain(prompt: finalPrompt, llm: llm); + final combineDocsChain = StuffDocumentsChain(llmChain: finalLlmChain); + + final reduceChain = ReduceDocumentsChain( + combineDocumentsChain: combineDocsChain, + ); + + const docs = [ + Document(pageContent: 'Hello 1!'), + Document(pageContent: 'Hello 2!'), + Document(pageContent: 'Hello 3!'), + Document(pageContent: 'Hello 4!'), + ]; + final res = await reduceChain.run(docs); + expect(res, 'Hello 1234!'); + }); + + test('Test reduce and collapse', () async { + final llm = FakeListLLM( + responses: [ + // Collapse this content: Hello 1!\n\nHello 2!\n\nHello 3! + 'Hello 123!', + // Collapse this content: Hello 4! + 'Hello 4!', + // Summarize this content: Hello 123!\n\nHello 4! + 'Hello 1234!', + ], + ); + + final finalPrompt = PromptTemplate.fromTemplate( + 'Summarize this content: {context}', + ); + final finalLlmChain = LLMChain(prompt: finalPrompt, llm: llm); + final combineDocsChain = StuffDocumentsChain(llmChain: finalLlmChain); + + final collapsePrompt = PromptTemplate.fromTemplate( + 'Collapse this content: {context}', + ); + final collapseLlmChain = LLMChain(prompt: collapsePrompt, llm: llm); + final collapseDocsChain = StuffDocumentsChain(llmChain: collapseLlmChain); + + final reduceChain = ReduceDocumentsChain( + combineDocumentsChain: combineDocsChain, + collapseDocumentsChain: collapseDocsChain, + defaultTokenMax: 7, + ); + + const docs = [ + Document(pageContent: 'Hello 1!'), + Document(pageContent: 'Hello 2!'), + Document(pageContent: 'Hello 3!'), + Document(pageContent: 'Hello 4!'), + ]; + final res = await reduceChain.run(docs); + expect(res, 'Hello 1234!'); + }); + }); +} diff --git a/packages/langchain/test/chains/combine_documents/stuff.dart b/packages/langchain/test/chains/combine_documents/stuff.dart index 194bb96a..c3b1cd03 100644 --- 
a/packages/langchain/test/chains/combine_documents/stuff.dart +++ b/packages/langchain/test/chains/combine_documents/stuff.dart @@ -6,7 +6,7 @@ import 'package:test/test.dart'; void main() { group('StuffDocumentsChain tests', () { - test('Test LLMChain call', () async { + test('Test StuffDocumentsChain call', () async { const model = FakeEchoLLM(); final prompt = PromptTemplate.fromTemplate( 'Print {foo}. Context: {context}', @@ -30,5 +30,23 @@ void main() { 'Print Hello world!. Context: Hello 1!\n\nHello 2!', ); }); + + test('Test promptLength', () async { + const model = FakeEchoLLM(); + final prompt = PromptTemplate.fromTemplate( + 'Print {foo}. Context: {context}', + ); + final llmChain = LLMChain(prompt: prompt, llm: model); + final stuffChain = StuffDocumentsChain(llmChain: llmChain); + + const foo = 'Hello world!'; + const docs = [ + Document(pageContent: 'Hello 1!'), + Document(pageContent: 'Hello 2!'), + ]; + + final tokens = await stuffChain.promptLength(docs, inputs: {'foo': foo}); + expect(tokens, 7); + }); }); }