diff --git a/packages/langchain/lib/src/chains/chains.dart b/packages/langchain/lib/src/chains/chains.dart
index 0040f804..367cc22d 100644
--- a/packages/langchain/lib/src/chains/chains.dart
+++ b/packages/langchain/lib/src/chains/chains.dart
@@ -6,3 +6,4 @@ export 'models/models.dart';
 export 'question_answering/question_answering.dart';
 export 'retrieval_qa.dart';
 export 'sequential.dart';
+export 'summarization/summarization.dart';
diff --git a/packages/langchain/lib/src/chains/combine_documents/map_reduce.dart b/packages/langchain/lib/src/chains/combine_documents/map_reduce.dart
index e88f2b67..279eac83 100644
--- a/packages/langchain/lib/src/chains/combine_documents/map_reduce.dart
+++ b/packages/langchain/lib/src/chains/combine_documents/map_reduce.dart
@@ -49,7 +49,7 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
     required this.reduceDocumentsChain,
     super.inputKey = defaultInputKey,
     super.outputKey = defaultOutputKey,
-    this.llmChainDocumentPromptVar = defaultLlmChainDocumentPromptVar,
+    this.mapLlmChainDocumentPromptVar = defaultLlmChainDocumentPromptVar,
     this.returnIntermediateSteps = false,
   }) {
     _initLlmChainDocumentPromptVar();
@@ -64,7 +64,7 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
 
   /// The variable name in the [mapLlmChain] where to put the documents in.
   /// If only one variable in the [mapLlmChain], this doesn't need to be provided.
-  String llmChainDocumentPromptVar;
+  String mapLlmChainDocumentPromptVar;
 
   /// Return the results of the map steps in the output.
   final bool returnIntermediateSteps;
@@ -77,7 +77,7 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
   static const String defaultOutputKey =
       BaseCombineDocumentsChain.defaultOutputKey;
 
-  /// Default value for [llmChainDocumentPromptVar].
+  /// Default value for [mapLlmChainDocumentPromptVar].
   static const String defaultLlmChainDocumentPromptVar = 'context';
 
   /// Output key for the chain's intermediate steps output.
@@ -104,15 +104,15 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
     // with this variable name
     final llmChainInputVariables = mapLlmChain.prompt.inputVariables;
     if (llmChainInputVariables.length == 1) {
-      llmChainDocumentPromptVar = llmChainInputVariables.first;
-    } else if (llmChainDocumentPromptVar.isEmpty) {
+      mapLlmChainDocumentPromptVar = llmChainInputVariables.first;
+    } else if (mapLlmChainDocumentPromptVar.isEmpty) {
       throw ArgumentError(
         'llmChainDocumentPromptVar must be provided if there are multiple '
         'llmChain input variables',
       );
-    } else if (!llmChainInputVariables.contains(llmChainDocumentPromptVar)) {
+    } else if (!llmChainInputVariables.contains(mapLlmChainDocumentPromptVar)) {
       throw ArgumentError(
-        'llmChainDocumentPromptVar ($llmChainDocumentPromptVar) was not found '
+        'llmChainDocumentPromptVar ($mapLlmChainDocumentPromptVar) was not found '
         'in llmChain input variables',
       );
     }
@@ -143,7 +143,12 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
   }) async {
     final mapResults = await mapLlmChain.apply(
       docs
-          .map((final d) => {...inputs, llmChainDocumentPromptVar: d})
+          .map(
+            (final d) => {
+              ...inputs,
+              mapLlmChainDocumentPromptVar: d.pageContent,
+            },
+          )
           .toList(growable: false),
     );
diff --git a/packages/langchain/lib/src/chains/combine_documents/reduce.dart b/packages/langchain/lib/src/chains/combine_documents/reduce.dart
index 554cabea..95c68861 100644
--- a/packages/langchain/lib/src/chains/combine_documents/reduce.dart
+++ b/packages/langchain/lib/src/chains/combine_documents/reduce.dart
@@ -56,7 +56,7 @@ class ReduceDocumentsChain extends BaseCombineDocumentsChain {
     super.outputKey = defaultOutputKey,
     required this.combineDocumentsChain,
     this.collapseDocumentsChain,
-    this.tokenMax = 3000,
+    this.tokenMax = defaultTokenMax,
   });
 
   /// Final chain to call to combine documents.
@@ -80,12 +80,13 @@ class ReduceDocumentsChain extends BaseCombineDocumentsChain {
   final int tokenMax;
 
   /// Default [inputKey] value.
-  static const String defaultInputKey =
-      BaseCombineDocumentsChain.defaultInputKey;
+  static const defaultInputKey = BaseCombineDocumentsChain.defaultInputKey;
 
   /// Default [outputKey] value.
-  static const String defaultOutputKey =
-      BaseCombineDocumentsChain.defaultOutputKey;
+  static const defaultOutputKey = BaseCombineDocumentsChain.defaultOutputKey;
+
+  /// Default [tokenMax] value.
+  static const defaultTokenMax = 3000;
 
   @override
   String get chainType => 'reduce_documents_chain';
diff --git a/packages/langchain/lib/src/chains/combine_documents/stuff.dart b/packages/langchain/lib/src/chains/combine_documents/stuff.dart
index 613c1209..89a946c1 100644
--- a/packages/langchain/lib/src/chains/combine_documents/stuff.dart
+++ b/packages/langchain/lib/src/chains/combine_documents/stuff.dart
@@ -41,13 +41,10 @@ class StuffDocumentsChain extends BaseCombineDocumentsChain {
     required this.llmChain,
     super.inputKey = defaultInputKey,
     super.outputKey = defaultOutputKey,
-    this.documentPrompt = const PromptTemplate(
-      inputVariables: {StuffDocumentsChain.pageContentPromptVar},
-      template: '{${StuffDocumentsChain.pageContentPromptVar}}',
-    ),
+    this.documentPrompt = defaultDocumentPrompt,
+    this.documentSeparator = defaultDocumentSeparator,
     this.llmChainStuffedDocumentPromptVar =
         defaultLlmChainStuffedDocumentPromptVar,
-    this.documentSeparator = '\n\n',
   }) {
     _initLlmChainDocumentPromptVar();
   }
@@ -58,23 +55,30 @@ class StuffDocumentsChain extends BaseCombineDocumentsChain {
   /// Prompt to use to format each document.
   final BasePromptTemplate documentPrompt;
 
-  /// The variable name in the [llmChain] where to put the documents in.
-  /// If only one variable in the [llmChain], this doesn't need to be provided.
-  String llmChainStuffedDocumentPromptVar;
-
   /// The string with which to join the formatted documents.
   final String documentSeparator;
 
+  /// The variable name in the [llmChain.prompt] where to put the documents in.
+  /// If only one variable in the [llmChain], this doesn't need to be provided.
+  String llmChainStuffedDocumentPromptVar;
+
   /// Default [inputKey] value.
-  static const String defaultInputKey =
-      BaseCombineDocumentsChain.defaultInputKey;
+  static const defaultInputKey = BaseCombineDocumentsChain.defaultInputKey;
 
   /// Default [outputKey] value.
-  static const String defaultOutputKey =
-      BaseCombineDocumentsChain.defaultOutputKey;
+  static const defaultOutputKey = BaseCombineDocumentsChain.defaultOutputKey;
+
+  /// Default [documentPrompt] value.
+  static const defaultDocumentPrompt = PromptTemplate(
+    inputVariables: {StuffDocumentsChain.pageContentPromptVar},
+    template: '{${StuffDocumentsChain.pageContentPromptVar}}',
+  );
+
+  /// Default value for [documentSeparator].
+  static const defaultDocumentSeparator = '\n\n';
 
   /// Default value for [llmChainStuffedDocumentPromptVar].
-  static const String defaultLlmChainStuffedDocumentPromptVar = 'context';
+  static const defaultLlmChainStuffedDocumentPromptVar = 'context';
 
   /// Prompt variable to use for the page content.
   static const pageContentPromptVar =
diff --git a/packages/langchain/lib/src/chains/llm_chain.dart b/packages/langchain/lib/src/chains/llm_chain.dart
index dacaf11e..fe246f7e 100644
--- a/packages/langchain/lib/src/chains/llm_chain.dart
+++ b/packages/langchain/lib/src/chains/llm_chain.dart
@@ -21,7 +21,7 @@ class LLMChain
 separators;
diff --git a/packages/langchain/lib/src/documents/transformers/text_splitters/text_splitter.dart b/packages/langchain/lib/src/documents/transformers/text_splitters/text_splitter.dart
index 71e2e654..6c1b4fca 100644
--- a/packages/langchain/lib/src/documents/transformers/text_splitters/text_splitter.dart
+++ b/packages/langchain/lib/src/documents/transformers/text_splitters/text_splitter.dart
@@ -4,10 +4,6 @@ import 'package:meta/meta.dart';
 import '../../models/models.dart';
 import '../base.dart';
 
-/// Default length function for [TextSplitter].
-/// Measures the length of the given chunk by counting its characters.
-int _defaultLengthFunction(final String chunk) => chunk.characters.length;
-
 /// {@template text_splitter}
 /// Interface for splitting text into chunks.
 /// {@endtemplate}
@@ -15,7 +11,7 @@ abstract class TextSplitter implements BaseDocumentTransformer {
   const TextSplitter({
     this.chunkSize = 4000,
     this.chunkOverlap = 200,
-    this.lengthFunction = _defaultLengthFunction,
+    this.lengthFunction = defaultLengthFunction,
     this.keepSeparator = false,
     this.addStartIndex = false,
   }) : assert(chunkOverlap <= chunkSize);
@@ -35,6 +31,11 @@ abstract class TextSplitter implements BaseDocumentTransformer {
   /// If `true`, includes chunk's `start_index` in metadata.
   final bool addStartIndex;
 
+  /// Default length function for [TextSplitter].
+  /// Measures the length of the given chunk by counting its characters.
+  static int defaultLengthFunction(final String chunk) =>
+      chunk.characters.length;
+
   /// Split text into multiple components.
   List<String> splitText(final String text);
diff --git a/packages/langchain/lib/src/model_io/llms/fake.dart b/packages/langchain/lib/src/model_io/llms/fake.dart
index a7b1b776..f92f0c80 100644
--- a/packages/langchain/lib/src/model_io/llms/fake.dart
+++ b/packages/langchain/lib/src/model_io/llms/fake.dart
@@ -12,8 +12,10 @@ class FakeListLLM extends SimpleLLM {
     required this.responses,
   });
 
+  /// Responses to return in order when called.
   final List<String> responses;
-  int i = 0;
+
+  int _i = 0;
 
   @override
   String get modelType => 'fake-list';
@@ -23,7 +25,7 @@ class FakeListLLM extends SimpleLLM {
     final String prompt, {
     final LLMOptions? options,
   }) {
-    return Future.value(responses[i++ % responses.length]);
+    return Future.value(responses[_i++ % responses.length]);
   }
 
   @override
@@ -64,3 +66,43 @@ class FakeEchoLLM extends SimpleLLM {
         .toList(growable: false);
   }
 }
+
+/// {@template fake_handler_llm}
+/// Fake LLM for testing.
+/// It returns the string returned by the [handler] function.
+/// {@endtemplate}
+class FakeHandlerLLM extends SimpleLLM {
+  /// {@macro fake_handler_llm}
+  FakeHandlerLLM({
+    required this.handler,
+  });
+
+  /// Function called to generate the response.
+  final String Function(
+    String prompt,
+    LLMOptions? options,
+    int callCount,
+  ) handler;
+
+  int _callCount = 0;
+
+  @override
+  String get modelType => 'fake-handler';
+
+  @override
+  Future<String> callInternal(
+    final String prompt, {
+    final LLMOptions? options,
+  }) {
+    return Future.value(handler(prompt, options, ++_callCount));
+  }
+
+  @override
+  Future<List<int>> tokenize(final PromptValue promptValue) async {
+    return promptValue
+        .toString()
+        .split(' ')
+        .map((final word) => word.hashCode)
+        .toList(growable: false);
+  }
+}
diff --git a/packages/langchain/test/chains/combine_documents/stuff.dart b/packages/langchain/test/chains/combine_documents/stuff.dart
index c3b1cd03..70566840 100644
--- a/packages/langchain/test/chains/combine_documents/stuff.dart
+++ b/packages/langchain/test/chains/combine_documents/stuff.dart
@@ -24,9 +24,9 @@ void main() {
       'input_documents': docs,
     });
     expect(res['foo'], foo);
-    expect(res['input_documents'], docs);
+    expect(res[StuffDocumentsChain.defaultInputKey], docs);
     expect(
-      res['output_text'],
+      res[StuffDocumentsChain.defaultOutputKey],
       'Print Hello world!. Context: Hello 1!\n\nHello 2!',
     );
   });
diff --git a/packages/langchain/test/chains/summarization/summarize.dart b/packages/langchain/test/chains/summarization/summarize.dart
new file mode 100644
index 00000000..7e2ff1cb
--- /dev/null
+++ b/packages/langchain/test/chains/summarization/summarize.dart
@@ -0,0 +1,83 @@
+import 'package:langchain/langchain.dart';
+import 'package:test/test.dart';
+
+void main() {
+  group('SummarizeChain tests', () {
+    test('Test SummarizeChain.stuff', () async {
+      final llm = FakeHandlerLLM(
+        handler: (final prompt, final options, final callCount) {
+          switch (callCount) {
+            case 1:
+              expect(
+                prompt,
+                'Write a concise summary of the following:\n\n\n"Hello 1!\n\nHello 2!"\n\n\nCONCISE SUMMARY:',
+              );
+              return 'Hello 12!';
+            default:
+              throw TestFailure('Unexpected call count: $callCount');
+          }
+        },
+      );
+
+      final stuffSummarizeChain = SummarizeChain.stuff(llm: llm);
+
+      const docs = [
+        Document(pageContent: 'Hello 1!'),
+        Document(pageContent: 'Hello 2!'),
+      ];
+      final res = await stuffSummarizeChain.call({
+        SummarizeChain.defaultInputKey: docs,
+      });
+
+      expect(res[SummarizeChain.defaultOutputKey], 'Hello 12!');
+    });
+
+    test('Test SummarizeChain.mapReduce', () async {
+      final llm = FakeHandlerLLM(
+        handler: (final prompt, final options, final callCount) {
+          switch (callCount) {
+            case 1:
+              expect(
+                prompt,
+                'Write a concise summary of the following:\n\n\n"Hello 1!"\n\n\nCONCISE SUMMARY:',
+              );
+              return '1';
+            case 2:
+              expect(
+                prompt,
+                'Write a concise summary of the following:\n\n\n"Hello 2!"\n\n\nCONCISE SUMMARY:',
+              );
+              return '2';
+            case 3:
+              expect(
+                prompt,
+                'Write a concise summary of the following:\n\n\n"Hello 3!"\n\n\nCONCISE SUMMARY:',
+              );
+              return '3';
+            case 4:
+              expect(
+                prompt,
+                'Write a concise summary of the following:\n\n\n"1\n\n2\n\n3"\n\n\nCONCISE SUMMARY:',
+              );
+              return 'Hello 123!';
+            default:
+              throw TestFailure('Unexpected call count: $callCount');
+          }
+        },
+      );
+
+      final mapReduceSummarizeChain = SummarizeChain.mapReduce(llm: llm);
+
+      const docs = [
+        Document(pageContent: 'Hello 1!'),
+        Document(pageContent: 'Hello 2!'),
+        Document(pageContent: 'Hello 3!'),
+      ];
+      final res = await mapReduceSummarizeChain.call({
+        SummarizeChain.defaultInputKey: docs,
+      });
+
+      expect(res[SummarizeChain.defaultOutputKey], 'Hello 123!');
+    });
+  });
+}
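
For reference, here is a minimal usage sketch of the summarization API this diff wires up, based only on what the diff itself shows (the `SummarizeChain` implementation is not included here, only its export and tests, so the exact prompt wording and defaults are assumed from the test file; in real code a production LLM would replace the fake model):

```dart
import 'package:langchain/langchain.dart';

Future<void> main() async {
  // FakeHandlerLLM (added in this diff) lets tests inspect the prompt the
  // chain builds; a real LLM wrapper would normally be used instead.
  final llm = FakeHandlerLLM(
    handler: (prompt, options, callCount) => 'A short summary.',
  );

  // Stuff strategy: all documents are inserted into a single prompt.
  // For many or long documents, SummarizeChain.mapReduce(llm: llm)
  // summarizes each document first and then reduces the partial summaries.
  final chain = SummarizeChain.stuff(llm: llm);

  const docs = [
    Document(pageContent: 'Hello 1!'),
    Document(pageContent: 'Hello 2!'),
  ];

  final res = await chain.call({
    SummarizeChain.defaultInputKey: docs,
  });
  print(res[SummarizeChain.defaultOutputKey]);
}
```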