Skip to content

Commit

Permalink
feat: Add support for ReduceDocumentsChain (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz committed Jul 11, 2023
1 parent e22f22c commit 34cf10b
Show file tree
Hide file tree
Showing 7 changed files with 336 additions and 13 deletions.
11 changes: 6 additions & 5 deletions packages/langchain/lib/src/chains/base.dart
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ abstract class BaseChain {
return Future.wait(inputs.map(call));
}

/// Convenience method for executing chain when there's a single output.
/// Convenience method for executing chain when there's a single string
/// output.
///
/// The main difference between this method and [call] is that this method
/// can only be used for chains that return a single output. If a Chain has
Expand All @@ -146,7 +147,7 @@ abstract class BaseChain {
/// Eg: `chain.run('Hello world!')`
/// - A map of key->values, if the chain has multiple input keys.
/// Eg: `chain.run({'foo': 'Hello', 'bar': 'world!'})`
Future<dynamic> run(final dynamic input) async {
Future<String> run(final dynamic input) async {
final outputKey = runOutputKey;
final returnValues = await call(input, returnOnlyOutputs: true);
return returnValues[outputKey];
Expand All @@ -161,12 +162,12 @@ abstract class BaseChain {
final inputKeysSet = inputMap.keys.toSet();
final inputKeysSetLength = inputKeysSet.length;

if (inputKeysSetLength != inputKeys.length) {
if (inputKeysSetLength < inputKeys.length) {
return false;
}

final inputKeysSetDiff = inputKeysSet.difference(inputKeys);
if (inputKeysSetDiff.isNotEmpty) {
final inputKeysDiff = inputKeys.difference(inputKeysSet);
if (inputKeysDiff.isNotEmpty) {
return false;
}

Expand Down
21 changes: 18 additions & 3 deletions packages/langchain/lib/src/chains/combine_documents/base.dart
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,30 @@ abstract class BaseCombineDocumentsChain extends BaseChain {
};
}

// TODO add promptLength method to base chain the prompt length given the documents passed in
/// Returns the prompt length (number of tokens) given the documents passed
/// in.
///
/// This can be used by a caller to determine whether passing in a list of
/// documents would exceed a certain prompt length. This useful when trying
/// to ensure that the size of a prompt remains below a certain context limit.
///
/// - [docs] is the list of documents to combine.
/// - [inputs] is a map of other inputs to use in the combination.
///
/// Returns null if the combine method doesn't depend on the prompt length.
/// Otherwise, the length of the prompt in tokens.
Future<int?> promptLength(
final List<Document> docs, {
final InputValues inputs = const {},
});

/// Combines the given [docs] into a single string.
///
/// - [docs] is the list of documents to combine.
/// - [inputs] is a map of other inputs to use in the combination.
///
/// Returns a tuple of the output string and any extra info to return.
Future<(dynamic output, Map<String, dynamic> extraInfo)> combineDocs(
/// Returns a tuple of the output and any extra info to return.
Future<(String output, Map<String, dynamic> extraInfo)> combineDocs(
final List<Document> docs, {
final InputValues inputs = const {},
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export 'base.dart';
export 'reduce.dart';
export 'stuff.dart';
201 changes: 201 additions & 0 deletions packages/langchain/lib/src/chains/combine_documents/reduce.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import '../../documents/models/models.dart';
import '../../model_io/prompts/models/models.dart';
import 'base.dart';
import 'stuff.dart';

/// {@template reduce_documents_chain}
/// Chain that combines documents by recursively reducing them if needed.
///
/// This involves two chains:
/// - [combineDocumentsChain] this is the chain that combines the documents.
/// - [collapseDocumentsChain] this is the chain that collapses the documents
/// if they exceed [defaultTokenMax].
///
/// The chain works as follows:
/// - If the number of tokens resulting of formatting the prompt from
/// [combineDocumentsChain] is less than [defaultTokenMax], then
/// [combineDocumentsChain] is called with the documents and the result is
/// returned.
/// - Otherwise, the documents are split into groups of max [defaultTokenMax]
/// tokens and [collapseDocumentsChain] is called for each group. Then, the
/// resulting documents are combined by calling [combineDocumentsChain] and
/// the result is returned.
///
/// Example:
/// ```dart
/// final finalPrompt = PromptTemplate.fromTemplate(
/// 'Summarize this content: {context}',
/// );
/// final finalLlmChain = LLMChain(prompt: finalPrompt, llm: llm);
/// final combineDocsChain = StuffDocumentsChain(llmChain: finalLlmChain);
///
/// final collapsePrompt = PromptTemplate.fromTemplate(
/// 'Collapse this content: {context}',
/// );
/// final collapseLlmChain = LLMChain(prompt: collapsePrompt, llm: llm);
/// final collapseDocsChain = StuffDocumentsChain(llmChain: collapseLlmChain);
///
/// final reduceChain = ReduceDocumentsChain(
/// combineDocumentsChain: combineDocsChain,
/// collapseDocumentsChain: collapseDocsChain,
/// );
///
/// const docs = [
/// Document(pageContent: 'Hello world 1!'),
/// Document(pageContent: 'Hello world 2!'),
/// Document(pageContent: 'Hello world 3!'),
/// Document(pageContent: 'Hello world 4!'),
/// ];
/// final res = await reduceChain.run(docs);
/// ```
/// {@endtemplate}
class ReduceDocumentsChain extends BaseCombineDocumentsChain {
/// {@macro reduce_documents_chain}
const ReduceDocumentsChain({
required this.combineDocumentsChain,
this.collapseDocumentsChain,
this.defaultTokenMax = 3000,
});

/// Final chain to call to combine documents.
/// This is typically a [StuffDocumentsChain].
final BaseCombineDocumentsChain combineDocumentsChain;

/// Chain to use to collapse documents if needed until they can all fit.
/// If null, [combineDocumentsChain] will be used.
/// This is typically a [StuffDocumentsChain].
final BaseCombineDocumentsChain? collapseDocumentsChain;

/// The maximum number of tokens to group documents into. For example, if
/// set to 3000 then documents will be grouped into chunks of no greater than
/// 3000 tokens before trying to combine them into a smaller chunk.
///
/// This is useful to avoid exceeding the context size when combining the
/// documents.
///
/// It is assumed that each document to combine is less than
/// [defaultTokenMax] tokens.
final int defaultTokenMax;

@override
String get chainType => 'reduce_documents_chain';

@override
Future<int?> promptLength(
final List<Document> docs, {
final InputValues inputs = const {},
}) async {
// This combine method doesn't depend on the prompt length.
return null;
}

/// Combine multiple documents.
///
/// - [docs] the documents to combine. It is assumed that each one is less
/// than [defaultTokenMax] tokens.
/// - [inputs] additional parameters to be passed to LLM calls (like other
/// input variables besides the documents).
///
/// Returns a tuple of the output and any extra info to return.
@override
Future<(String output, Map<String, dynamic> extraInfo)> combineDocs(
final List<Document> docs, {
final InputValues inputs = const {},
final int? tokenMax,
}) async {
final resultDocs = await _splitAndCollapseDocs(
docs,
inputs: inputs,
tokenMax: tokenMax,
);
return combineDocumentsChain.combineDocs(resultDocs, inputs: inputs);
}

/// Splits the documents into smaller chunks that are each less than
/// [tokenMax] tokens. And then collapses them into a single document.
Future<List<Document>> _splitAndCollapseDocs(
final List<Document> docs, {
final int? tokenMax,
final InputValues inputs = const {},
}) async {
final finalTokenMax = tokenMax ?? defaultTokenMax;
final lengthFunc = combineDocumentsChain.promptLength;

List<Document> resultDocs = docs;
int? numTokens = await lengthFunc(resultDocs, inputs: inputs);

while (numTokens != null && numTokens > finalTokenMax) {
final newResultDocList = await _splitDocs(
docs,
inputs,
lengthFunc,
finalTokenMax,
);
resultDocs = [];
for (final docs in newResultDocList) {
final newDoc = await _collapseDocs(docs, inputs);
resultDocs.add(newDoc);
}
numTokens = await lengthFunc(resultDocs, inputs: inputs);
}

return resultDocs;
}

/// Split a list of documents into smaller lists of documents that are each
/// less than [tokenMax] tokens.
Future<List<List<Document>>> _splitDocs(
final List<Document> docs,
final InputValues inputs,
final Future<int?> Function(
List<Document> docs, {
InputValues inputs,
}) lengthFunc,
final int tokenMax,
) async {
final List<List<Document>> newResultDocList = [];
List<Document> subResultDocs = [];

for (final doc in docs) {
subResultDocs.add(doc);
final numTokens = await lengthFunc(subResultDocs, inputs: inputs);
if (numTokens != null && numTokens > tokenMax) {
assert(
subResultDocs.length > 1,
'We should never have a single document that is longer than the tokenMax.',
);
newResultDocList.add(
subResultDocs.sublist(0, subResultDocs.length - 1),
);
subResultDocs = subResultDocs.sublist(subResultDocs.length - 1);
}
}
newResultDocList.add(subResultDocs);
return newResultDocList;
}

/// Combines multiple documents into one using [collapseDocumentsChain] (or
/// [combineDocumentsChain] if [collapseDocumentsChain] is null).
/// The metadata of the different documents is also combined.
Future<Document> _collapseDocs(
final List<Document> docs,
final InputValues inputs,
) async {
final collapseChain = collapseDocumentsChain ?? combineDocumentsChain;
final result = await collapseChain.run({
...inputs,
BaseCombineDocumentsChain.defaultInputKey: docs,
});
final combinedMetadata = {...docs[0].metadata};
for (var i = 1; i < docs.length; i++) {
docs[i].metadata.forEach((final key, final value) {
if (combinedMetadata.containsKey(key) && value is String) {
combinedMetadata[key] = '${combinedMetadata[key]}, $value';
} else {
combinedMetadata[key] = value;
}
});
}
return Document(pageContent: result, metadata: combinedMetadata);
}
}
18 changes: 14 additions & 4 deletions packages/langchain/lib/src/chains/combine_documents/stuff.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import 'base.dart';
/// This chain takes a list of documents and first combines them into a single
/// string. It does this by formatting each document into a string with the
/// [documentPrompt] and then joining them together with [documentSeparator].
/// It then adds that new string to the inputs with the variable name set by
/// [llmChainStuffedDocumentInputKey]. Those inputs are then passed to the
/// [llmChain].
/// It then adds that new resulting string to the inputs with the variable
/// name set by [llmChainStuffedDocumentInputKey]. Those inputs are then
/// passed to the [llmChain].
///
/// The content of each document is formatted using [documentPrompt].
/// By default, it just takes the content of the document.
Expand Down Expand Up @@ -110,14 +110,24 @@ class StuffDocumentsChain extends BaseCombineDocumentsChain {
}
}

@override
Future<int?> promptLength(
final List<Document> docs, {
final InputValues inputs = const {},
}) {
final llmInputs = _getInputs(docs, inputs);
final prompt = llmChain.prompt.formatPrompt(llmInputs);
return llmChain.llm.countTokens(prompt);
}

/// Stuff all documents into one prompt and pass to LLM.
///
/// - [docs] the documents to combine.
/// - [inputs] the inputs to pass to the [llmChain].
///
/// Returns a tuple of the output string and any extra info to return.
@override
Future<(dynamic output, Map<String, dynamic> extraInfo)> combineDocs(
Future<(String output, Map<String, dynamic> extraInfo)> combineDocs(
final List<Document> docs, {
final InputValues inputs = const {},
}) async {
Expand Down
77 changes: 77 additions & 0 deletions packages/langchain/test/chains/combine_documents/reduce.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import 'package:langchain/src/chains/chains.dart';
import 'package:langchain/src/documents/documents.dart';
import 'package:langchain/src/model_io/llms/fake.dart';
import 'package:langchain/src/model_io/prompts/prompts.dart';
import 'package:test/test.dart';

void main() {
group('ReduceDocumentsChain tests', () {
test('Test reduce', () async {
final llm = FakeListLLM(
responses: [
// Summarize this content: Hello 1!\n\nHello 2!\n\nHello 3!\n\nHello 4!
'Hello 1234!',
],
);

final finalPrompt = PromptTemplate.fromTemplate(
'Summarize this content: {context}',
);
final finalLlmChain = LLMChain(prompt: finalPrompt, llm: llm);
final combineDocsChain = StuffDocumentsChain(llmChain: finalLlmChain);

final reduceChain = ReduceDocumentsChain(
combineDocumentsChain: combineDocsChain,
);

const docs = [
Document(pageContent: 'Hello 1!'),
Document(pageContent: 'Hello 2!'),
Document(pageContent: 'Hello 3!'),
Document(pageContent: 'Hello 4!'),
];
final res = await reduceChain.run(docs);
expect(res, 'Hello 1234!');
});

test('Test reduce and collapse', () async {
final llm = FakeListLLM(
responses: [
// Collapse this content: Hello 1!\n\nHello 2!\n\nHello 3!
'Hello 123!',
// Collapse this content: Hello 4!
'Hello 4!',
// Summarize this content: Hello 123!\n\nHello 4!
'Hello 1234!',
],
);

final finalPrompt = PromptTemplate.fromTemplate(
'Summarize this content: {context}',
);
final finalLlmChain = LLMChain(prompt: finalPrompt, llm: llm);
final combineDocsChain = StuffDocumentsChain(llmChain: finalLlmChain);

final collapsePrompt = PromptTemplate.fromTemplate(
'Collapse this content: {context}',
);
final collapseLlmChain = LLMChain(prompt: collapsePrompt, llm: llm);
final collapseDocsChain = StuffDocumentsChain(llmChain: collapseLlmChain);

final reduceChain = ReduceDocumentsChain(
combineDocumentsChain: combineDocsChain,
collapseDocumentsChain: collapseDocsChain,
defaultTokenMax: 7,
);

const docs = [
Document(pageContent: 'Hello 1!'),
Document(pageContent: 'Hello 2!'),
Document(pageContent: 'Hello 3!'),
Document(pageContent: 'Hello 4!'),
];
final res = await reduceChain.run(docs);
expect(res, 'Hello 1234!');
});
});
}

0 comments on commit 34cf10b

Please sign in to comment.