-
-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add support for ReduceDocumentsChain (#70)
- Loading branch information
1 parent
e22f22c
commit 34cf10b
Showing
7 changed files
with
336 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
packages/langchain/lib/src/chains/combine_documents/combine_documents.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
export 'base.dart'; | ||
export 'reduce.dart'; | ||
export 'stuff.dart'; |
201 changes: 201 additions & 0 deletions
201
packages/langchain/lib/src/chains/combine_documents/reduce.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
import '../../documents/models/models.dart'; | ||
import '../../model_io/prompts/models/models.dart'; | ||
import 'base.dart'; | ||
import 'stuff.dart'; | ||
|
||
/// {@template reduce_documents_chain}
/// Chain that combines documents by recursively reducing them if needed.
///
/// This involves two chains:
/// - [combineDocumentsChain] this is the chain that combines the documents.
/// - [collapseDocumentsChain] this is the chain that collapses the documents
///   if they exceed [defaultTokenMax].
///
/// The chain works as follows:
/// - If the number of tokens resulting from formatting the prompt of
///   [combineDocumentsChain] does not exceed [defaultTokenMax] (or the prompt
///   length cannot be determined), then [combineDocumentsChain] is called
///   with the documents and the result is returned.
/// - Otherwise, the documents are split into groups of max [defaultTokenMax]
///   tokens and [collapseDocumentsChain] is called for each group. This is
///   repeated on the collapsed documents until they fit. Then, the resulting
///   documents are combined by calling [combineDocumentsChain] and the result
///   is returned.
///
/// Example:
/// ```dart
/// final finalPrompt = PromptTemplate.fromTemplate(
///   'Summarize this content: {context}',
/// );
/// final finalLlmChain = LLMChain(prompt: finalPrompt, llm: llm);
/// final combineDocsChain = StuffDocumentsChain(llmChain: finalLlmChain);
///
/// final collapsePrompt = PromptTemplate.fromTemplate(
///   'Collapse this content: {context}',
/// );
/// final collapseLlmChain = LLMChain(prompt: collapsePrompt, llm: llm);
/// final collapseDocsChain = StuffDocumentsChain(llmChain: collapseLlmChain);
///
/// final reduceChain = ReduceDocumentsChain(
///   combineDocumentsChain: combineDocsChain,
///   collapseDocumentsChain: collapseDocsChain,
/// );
///
/// const docs = [
///   Document(pageContent: 'Hello world 1!'),
///   Document(pageContent: 'Hello world 2!'),
///   Document(pageContent: 'Hello world 3!'),
///   Document(pageContent: 'Hello world 4!'),
/// ];
/// final res = await reduceChain.run(docs);
/// ```
/// {@endtemplate}
class ReduceDocumentsChain extends BaseCombineDocumentsChain {
  /// {@macro reduce_documents_chain}
  const ReduceDocumentsChain({
    required this.combineDocumentsChain,
    this.collapseDocumentsChain,
    this.defaultTokenMax = 3000,
  });

  /// Final chain to call to combine documents.
  /// This is typically a [StuffDocumentsChain].
  final BaseCombineDocumentsChain combineDocumentsChain;

  /// Chain to use to collapse documents if needed until they can all fit.
  /// If null, [combineDocumentsChain] will be used.
  /// This is typically a [StuffDocumentsChain].
  final BaseCombineDocumentsChain? collapseDocumentsChain;

  /// The maximum number of tokens to group documents into. For example, if
  /// set to 3000 then documents will be grouped into chunks of no greater than
  /// 3000 tokens before trying to combine them into a smaller chunk.
  ///
  /// This is useful to avoid exceeding the context size when combining the
  /// documents.
  ///
  /// It is assumed that each document to combine is less than
  /// [defaultTokenMax] tokens.
  final int defaultTokenMax;

  @override
  String get chainType => 'reduce_documents_chain';

  /// Always completes with `null`.
  ///
  /// This combine method doesn't depend on the prompt length, so neither
  /// [docs] nor [inputs] are inspected.
  @override
  Future<int?> promptLength(
    final List<Document> docs, {
    final InputValues inputs = const {},
  }) async {
    return null;
  }

  /// Combines multiple documents.
  ///
  /// - [docs] the documents to combine. It is assumed that each one is less
  ///   than [defaultTokenMax] tokens.
  /// - [inputs] additional parameters to be passed to LLM calls (like other
  ///   input variables besides the documents).
  /// - [tokenMax] overrides [defaultTokenMax] for this call when not null.
  ///
  /// Returns a record with the output and any extra info to return, as
  /// produced by [combineDocumentsChain].
  @override
  Future<(String output, Map<String, dynamic> extraInfo)> combineDocs(
    final List<Document> docs, {
    final InputValues inputs = const {},
    final int? tokenMax,
  }) async {
    final resultDocs = await _splitAndCollapseDocs(
      docs,
      inputs: inputs,
      tokenMax: tokenMax,
    );
    return combineDocumentsChain.combineDocs(resultDocs, inputs: inputs);
  }

  /// Repeatedly splits the documents into groups that are each at most
  /// [tokenMax] tokens (falling back to [defaultTokenMax] when null) and
  /// collapses every group into a single document, until the prompt length
  /// reported by [combineDocumentsChain] no longer exceeds the limit.
  ///
  /// If the prompt length is `null`, [docs] is returned untouched.
  Future<List<Document>> _splitAndCollapseDocs(
    final List<Document> docs, {
    final int? tokenMax,
    final InputValues inputs = const {},
  }) async {
    final finalTokenMax = tokenMax ?? defaultTokenMax;
    final lengthFunc = combineDocumentsChain.promptLength;

    List<Document> resultDocs = docs;
    int? numTokens = await lengthFunc(resultDocs, inputs: inputs);

    while (numTokens != null && numTokens > finalTokenMax) {
      // Split the *current* result, not the original input: each pass has to
      // work on the output of the previous one, otherwise a second pass would
      // just redo the first and the loop could never converge.
      final groups = await _splitDocs(
        resultDocs,
        inputs,
        lengthFunc,
        finalTokenMax,
      );
      final collapsedDocs = <Document>[];
      // Groups are collapsed sequentially to keep the order of the LLM calls
      // deterministic.
      for (final group in groups) {
        collapsedDocs.add(await _collapseDocs(group, inputs));
      }
      resultDocs = collapsedDocs;
      numTokens = await lengthFunc(resultDocs, inputs: inputs);
    }

    return resultDocs;
  }

  /// Splits a list of documents into smaller lists of documents whose prompt
  /// length, as measured by [lengthFunc], is at most [tokenMax] tokens each.
  ///
  /// Throws an [ArgumentError] if a single document on its own exceeds
  /// [tokenMax], as such a document cannot be placed in any group.
  Future<List<List<Document>>> _splitDocs(
    final List<Document> docs,
    final InputValues inputs,
    final Future<int?> Function(
      List<Document> docs, {
      InputValues inputs,
    }) lengthFunc,
    final int tokenMax,
  ) async {
    final List<List<Document>> newResultDocList = [];
    List<Document> subResultDocs = [];

    for (final doc in docs) {
      subResultDocs.add(doc);
      final numTokens = await lengthFunc(subResultDocs, inputs: inputs);
      if (numTokens != null && numTokens > tokenMax) {
        assert(
          subResultDocs.length > 1,
          'We should never have a single document that is longer than the tokenMax.',
        );
        // Asserts are stripped in release builds; without this check an empty
        // group would be emitted and collapsing it would fail later with an
        // unrelated RangeError (after a wasted LLM call).
        if (subResultDocs.length == 1) {
          throw ArgumentError(
            'A single document is longer than the tokenMax ($tokenMax tokens) '
            'and cannot be grouped.',
          );
        }
        // Close the group without the document that made it overflow and
        // start the next group with that document.
        newResultDocList.add(
          subResultDocs.sublist(0, subResultDocs.length - 1),
        );
        subResultDocs = subResultDocs.sublist(subResultDocs.length - 1);
      }
    }
    newResultDocList.add(subResultDocs);
    return newResultDocList;
  }

  /// Combines multiple documents into one using [collapseDocumentsChain] (or
  /// [combineDocumentsChain] if [collapseDocumentsChain] is null).
  ///
  /// The metadata of the different documents is also combined: when a key is
  /// already present and the new value is a [String], the values are joined
  /// with `', '`; otherwise the later value replaces the earlier one.
  ///
  /// [docs] must not be empty.
  Future<Document> _collapseDocs(
    final List<Document> docs,
    final InputValues inputs,
  ) async {
    final collapseChain = collapseDocumentsChain ?? combineDocumentsChain;
    final result = await collapseChain.run({
      ...inputs,
      BaseCombineDocumentsChain.defaultInputKey: docs,
    });
    final combinedMetadata = {...docs[0].metadata};
    for (var i = 1; i < docs.length; i++) {
      for (final entry in docs[i].metadata.entries) {
        final key = entry.key;
        final value = entry.value;
        if (combinedMetadata.containsKey(key) && value is String) {
          combinedMetadata[key] = '${combinedMetadata[key]}, $value';
        } else {
          combinedMetadata[key] = value;
        }
      }
    }
    return Document(pageContent: result, metadata: combinedMetadata);
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
packages/langchain/test/chains/combine_documents/reduce.dart
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import 'package:langchain/src/chains/chains.dart'; | ||
import 'package:langchain/src/documents/documents.dart'; | ||
import 'package:langchain/src/model_io/llms/fake.dart'; | ||
import 'package:langchain/src/model_io/prompts/prompts.dart'; | ||
import 'package:test/test.dart'; | ||
|
||
/// Tests for `ReduceDocumentsChain`, driven by a `FakeListLLM` that replays
/// canned responses in order.
void main() {
  // Documents fed to the chain in every test case.
  const helloDocs = [
    Document(pageContent: 'Hello 1!'),
    Document(pageContent: 'Hello 2!'),
    Document(pageContent: 'Hello 3!'),
    Document(pageContent: 'Hello 4!'),
  ];

  // Wraps [llm] in an LLMChain using [template] and stuffs documents into it.
  StuffDocumentsChain stuffChainFor(
    final FakeListLLM llm,
    final String template,
  ) {
    final prompt = PromptTemplate.fromTemplate(template);
    return StuffDocumentsChain(llmChain: LLMChain(prompt: prompt, llm: llm));
  }

  group('ReduceDocumentsChain tests', () {
    test('Test reduce', () async {
      final fakeLlm = FakeListLLM(
        responses: [
          // Summarize this content: Hello 1!\n\nHello 2!\n\nHello 3!\n\nHello 4!
          'Hello 1234!',
        ],
      );

      // No collapse chain and the default token limit: a single combine call
      // is expected.
      final chain = ReduceDocumentsChain(
        combineDocumentsChain: stuffChainFor(
          fakeLlm,
          'Summarize this content: {context}',
        ),
      );

      final output = await chain.run(helloDocs);
      expect(output, 'Hello 1234!');
    });

    test('Test reduce and collapse', () async {
      final fakeLlm = FakeListLLM(
        responses: [
          // Collapse this content: Hello 1!\n\nHello 2!\n\nHello 3!
          'Hello 123!',
          // Collapse this content: Hello 4!
          'Hello 4!',
          // Summarize this content: Hello 123!\n\nHello 4!
          'Hello 1234!',
        ],
      );

      final combineChain = stuffChainFor(
        fakeLlm,
        'Summarize this content: {context}',
      );
      final collapseChain = stuffChainFor(
        fakeLlm,
        'Collapse this content: {context}',
      );

      // A low token limit so the documents are collapsed in groups before the
      // final combine call (see the canned responses above).
      final chain = ReduceDocumentsChain(
        combineDocumentsChain: combineChain,
        collapseDocumentsChain: collapseChain,
        defaultTokenMax: 7,
      );

      final output = await chain.run(helloDocs);
      expect(output, 'Hello 1234!');
    });
  });
}
Oops, something went wrong.