feat: Add support for SummarizeChain (#58)
davidmigloz committed Jul 22, 2023
1 parent 381a676 commit 9499fc0
Showing 12 changed files with 411 additions and 41 deletions.
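This commit introduces SummarizeChain with two factory constructors, stuff and mapReduce (defined in the new summarize.dart below). A rough usage sketch, assembled from the dartdoc examples added in this commit; llm is assumed to be a configured model (e.g. ChatOpenAI) and docsChunks a list of split documents:

```dart
// Stuff strategy: all chunks are joined into a single prompt
// (limited by the model's context length).
final stuffChain = SummarizeChain.stuff(llm: llm);

// Map-reduce strategy: each chunk is summarized individually, then the
// partial summaries are combined (and collapsed if they exceed
// summaryMaxTokens, 3000 tokens by default).
final mapReduceChain = SummarizeChain.mapReduce(
  llm: llm,
  summaryMaxTokens: ReduceDocumentsChain.defaultTokenMax,
);

final summary = await mapReduceChain.run(docsChunks);
```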
1 change: 1 addition & 0 deletions packages/langchain/lib/src/chains/chains.dart
@@ -6,3 +6,4 @@ export 'models/models.dart';
export 'question_answering/question_answering.dart';
export 'retrieval_qa.dart';
export 'sequential.dart';
+export 'summarization/summarization.dart';
packages/langchain/lib/src/chains/combine_documents/map_reduce.dart
@@ -49,7 +49,7 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
required this.reduceDocumentsChain,
super.inputKey = defaultInputKey,
super.outputKey = defaultOutputKey,
-this.llmChainDocumentPromptVar = defaultLlmChainDocumentPromptVar,
+this.mapLlmChainDocumentPromptVar = defaultLlmChainDocumentPromptVar,
this.returnIntermediateSteps = false,
}) {
_initLlmChainDocumentPromptVar();
@@ -64,7 +64,7 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {

/// The variable name in the [mapLlmChain] where to put the documents in.
/// If only one variable in the [mapLlmChain], this doesn't need to be provided.
-String llmChainDocumentPromptVar;
+String mapLlmChainDocumentPromptVar;

/// Return the results of the map steps in the output.
final bool returnIntermediateSteps;
@@ -77,7 +77,7 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
static const String defaultOutputKey =
BaseCombineDocumentsChain.defaultOutputKey;

-/// Default value for [llmChainDocumentPromptVar].
+/// Default value for [mapLlmChainDocumentPromptVar].
static const String defaultLlmChainDocumentPromptVar = 'context';

/// Output key for the chain's intermediate steps output.
@@ -104,15 +104,15 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
// with this variable name
final llmChainInputVariables = mapLlmChain.prompt.inputVariables;
if (llmChainInputVariables.length == 1) {
-llmChainDocumentPromptVar = llmChainInputVariables.first;
-} else if (llmChainDocumentPromptVar.isEmpty) {
+mapLlmChainDocumentPromptVar = llmChainInputVariables.first;
+} else if (mapLlmChainDocumentPromptVar.isEmpty) {
throw ArgumentError(
'llmChainDocumentPromptVar must be provided if there are multiple '
'llmChain input variables',
);
-} else if (!llmChainInputVariables.contains(llmChainDocumentPromptVar)) {
+} else if (!llmChainInputVariables.contains(mapLlmChainDocumentPromptVar)) {
throw ArgumentError(
-'llmChainDocumentPromptVar ($llmChainDocumentPromptVar) was not found '
+'llmChainDocumentPromptVar ($mapLlmChainDocumentPromptVar) was not found '
'in llmChain input variables',
);
}
@@ -143,7 +143,12 @@ class MapReduceDocumentsChain extends BaseCombineDocumentsChain {
}) async {
final mapResults = await mapLlmChain.apply(
docs
-.map((final d) => {...inputs, llmChainDocumentPromptVar: d})
+.map(
+(final d) => {
+...inputs,
+mapLlmChainDocumentPromptVar: d.pageContent,
+},
+)
.toList(growable: false),
);

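The functional change in the map step above is that each document's pageContent, rather than the Document object itself, is now substituted into the map prompt variable. Conceptually, with the default 'context' variable, the inputs handed to mapLlmChain.apply look like this (a minimal sketch of the new behavior, not the literal committed code):

```dart
// One prompt-input map per document chunk: pass the chain inputs through and
// put the chunk's text into the map prompt's document variable ('context').
final mapInputs = docs
    .map((final d) => {...inputs, 'context': d.pageContent})
    .toList(growable: false);
final mapResults = await mapLlmChain.apply(mapInputs);
```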
11 changes: 6 additions & 5 deletions packages/langchain/lib/src/chains/combine_documents/reduce.dart
@@ -56,7 +56,7 @@ class ReduceDocumentsChain extends BaseCombineDocumentsChain {
super.outputKey = defaultOutputKey,
required this.combineDocumentsChain,
this.collapseDocumentsChain,
-this.tokenMax = 3000,
+this.tokenMax = defaultTokenMax,
});

/// Final chain to call to combine documents.
@@ -80,12 +80,13 @@ class ReduceDocumentsChain extends BaseCombineDocumentsChain {
final int tokenMax;

/// Default [inputKey] value.
-static const String defaultInputKey =
-BaseCombineDocumentsChain.defaultInputKey;
+static const defaultInputKey = BaseCombineDocumentsChain.defaultInputKey;

/// Default [outputKey] value.
-static const String defaultOutputKey =
-BaseCombineDocumentsChain.defaultOutputKey;
+static const defaultOutputKey = BaseCombineDocumentsChain.defaultOutputKey;

+/// Default [tokenMax] value.
+static const defaultTokenMax = 3000;

@override
String get chainType => 'reduce_documents_chain';
32 changes: 18 additions & 14 deletions packages/langchain/lib/src/chains/combine_documents/stuff.dart
@@ -41,13 +41,10 @@ class StuffDocumentsChain extends BaseCombineDocumentsChain {
required this.llmChain,
super.inputKey = defaultInputKey,
super.outputKey = defaultOutputKey,
-this.documentPrompt = const PromptTemplate(
-inputVariables: {StuffDocumentsChain.pageContentPromptVar},
-template: '{${StuffDocumentsChain.pageContentPromptVar}}',
-),
+this.documentPrompt = defaultDocumentPrompt,
+this.documentSeparator = defaultDocumentSeparator,
this.llmChainStuffedDocumentPromptVar =
defaultLlmChainStuffedDocumentPromptVar,
-this.documentSeparator = '\n\n',
}) {
_initLlmChainDocumentPromptVar();
}
@@ -58,23 +55,30 @@ class StuffDocumentsChain extends BaseCombineDocumentsChain {
/// Prompt to use to format each document.
final BasePromptTemplate documentPrompt;

-/// The variable name in the [llmChain] where to put the documents in.
-/// If only one variable in the [llmChain], this doesn't need to be provided.
-String llmChainStuffedDocumentPromptVar;

/// The string with which to join the formatted documents.
final String documentSeparator;

+/// The variable name in the [llmChain.prompt] where to put the documents in.
+/// If only one variable in the [llmChain], this doesn't need to be provided.
+String llmChainStuffedDocumentPromptVar;

/// Default [inputKey] value.
-static const String defaultInputKey =
-BaseCombineDocumentsChain.defaultInputKey;
+static const defaultInputKey = BaseCombineDocumentsChain.defaultInputKey;

/// Default [outputKey] value.
-static const String defaultOutputKey =
-BaseCombineDocumentsChain.defaultOutputKey;
+static const defaultOutputKey = BaseCombineDocumentsChain.defaultOutputKey;

+/// Default [documentPrompt] value.
+static const defaultDocumentPrompt = PromptTemplate(
+inputVariables: {StuffDocumentsChain.pageContentPromptVar},
+template: '{${StuffDocumentsChain.pageContentPromptVar}}',
+);

+/// Default value for [documentSeparator].
+static const defaultDocumentSeparator = '\n\n';

/// Default value for [llmChainStuffedDocumentPromptVar].
-static const String defaultLlmChainStuffedDocumentPromptVar = 'context';
+static const defaultLlmChainStuffedDocumentPromptVar = 'context';

/// Prompt variable to use for the page content.
static const pageContentPromptVar =
5 changes: 4 additions & 1 deletion packages/langchain/lib/src/chains/llm_chain.dart
@@ -21,7 +21,7 @@ class LLMChain<LLMInput extends Object, LLMOptions extends LanguageModelOptions,
const LLMChain({
required this.prompt,
required this.llm,
-this.outputKey = 'text',
+this.outputKey = defaultOutputKey,
this.outputParser,
this.returnFinalOnly = true,
this.llmOptions,
@@ -51,6 +51,9 @@ class LLMChain<LLMInput extends Object, LLMOptions extends LanguageModelOptions,
/// Options to pass to the language model.
final LLMOptions? llmOptions;

+/// Default output key.
+static const defaultOutputKey = 'text';

/// Output key to use for returning the full generation.
static const fullGenerationOutputKey = 'full_generation';

1 change: 1 addition & 0 deletions packages/langchain/lib/src/chains/summarization/summarization.dart
@@ -0,0 +1 @@
export 'summarize.dart';
229 changes: 229 additions & 0 deletions packages/langchain/lib/src/chains/summarization/summarize.dart
@@ -0,0 +1,229 @@
import '../../model_io/language_models/language_models.dart';
import '../../model_io/prompts/prompts.dart';
import '../combine_documents/combine_documents.dart';
import '../llm_chain.dart';

const _template = '''
Write a concise summary of the following:
"{context}"
CONCISE SUMMARY:''';

const _promptTemplate = PromptTemplate(
template: _template,
inputVariables: {'context'},
);

/// Chain for summarizing documents.
///
/// There are two methods to summarize documents:
/// - [stuff] uses the [StuffDocumentsChain] to combine all the documents into
/// a single string, then prompts the model to summarize that string. This
/// method is limited by the context length limit of the model.
/// - [mapReduce] uses the [MapReduceDocumentsChain] to summarize each document
/// individually, then combines the results into a single summary.
abstract class SummarizeChain {
/// The [stuff] method uses the [StuffDocumentsChain] to combine all the
/// documents into a single string, then prompts the model to summarize that
/// string. This method is limited by the context length limit of the [llm].
///
/// - [llm] is the language model to use for summarization.
/// - [inputKey] is the input key where the documents to summarize will be
/// placed.
/// - [outputKey] is the output key where the summary will be placed.
/// - [promptTemplate] is the prompt to use to summarize the documents.
/// The default prompt template instructs the model to create a
/// "concise summary".
/// - [documentPrompt] is the prompt to use to format each document before
/// combining them. The default prompt just takes the content of the
/// document.
/// - [stuffedDocumentPromptVar] is the variable used in the [promptTemplate]
/// to indicate where the stuffed document should be placed.
/// - [documentSeparator] is the separator used to join the documents while
/// stuffing them.
///
/// Example:
/// ```dart
/// final loader = TextLoader('path/to/file.txt');
/// final docs = await loader.load();
///
/// const textSplitter = RecursiveCharacterTextSplitter();
/// final docsChunks = textSplitter.splitDocuments(docs);
///
/// final llm = ChatOpenAI(apiKey: openAIKey);
/// final summarizeChain = SummarizeChain.stuff(llm: llm);
///
/// final summary = await summarizeChain.run(docsChunks);
/// ```
static StuffDocumentsChain stuff({
required final BaseLanguageModel llm,
final String inputKey = SummarizeChain.defaultInputKey,
final String outputKey = SummarizeChain.defaultOutputKey,
final BasePromptTemplate promptTemplate = _promptTemplate,
final BasePromptTemplate documentPrompt =
StuffDocumentsChain.defaultDocumentPrompt,
final String stuffedDocumentPromptVar =
StuffDocumentsChain.defaultLlmChainStuffedDocumentPromptVar,
final String documentSeparator =
StuffDocumentsChain.defaultDocumentSeparator,
}) {
final llmChain = LLMChain(
llm: llm,
prompt: promptTemplate,
);

return StuffDocumentsChain(
llmChain: llmChain,
inputKey: inputKey,
outputKey: outputKey,
documentPrompt: documentPrompt,
llmChainStuffedDocumentPromptVar: stuffedDocumentPromptVar,
documentSeparator: documentSeparator,
);
}

/// The [mapReduce] method uses the [MapReduceDocumentsChain] to summarize
/// each document individually, then combines the results into a single
/// summary.
///
/// The [MapReduceDocumentsChain] involves two chains behind the scenes:
/// - [MapReduceDocumentsChain.mapLlmChain] this is the chain that is applied
/// to each document to create a summary.
/// - [MapReduceDocumentsChain.reduceDocumentsChain] this is a
/// [ReduceDocumentsChain] that reduces the summaries of each document into
/// a single summary.
///
/// - [llm] is the language model to use for summarization.
/// - [summaryMaxTokens] is the maximum number of tokens allowed in the final
/// summary. If the final summary exceeds this limit, it will be collapsed
/// using [collapsePrompt].
/// - [inputKey] is the input key where the documents to summarize will be
/// placed.
/// - [outputKey] is the output key where the summary will be placed.
/// - [mapPrompt] is the prompt to use to summarize each document
/// individually.
/// - [mapDocumentPromptVar] is the variable used in the [mapPrompt] to
/// indicate where the document should be placed.
/// - [combinePrompt] is the prompt to use to summarize the summaries of each
/// document.
/// - [combineLlm] is the language model to use to summarize the summaries of
/// each document. By default, [llm] is used.
/// - [combineDocumentPrompt] is the prompt to use to format each individual
/// document before summarizing it. The default prompt just takes the
/// content of the document.
/// - [combineDocumentPromptVar] is the variable used in the [combinePrompt]
/// to indicate where the summaries should be placed.
/// - [combineDocumentSeparator] is the separator used to join the summaries.
/// - [collapsePrompt] is the prompt to use to collapse the final summary if
/// it exceeds the [summaryMaxTokens] limit. By default, [combinePrompt] is used.
/// - [collapseLlm] is the language model to use to collapse the final
/// summary. By default, [combineLlm] is used if it is not null, otherwise
/// [llm] is used.
/// - [collapseDocumentPrompt] is the prompt to use to format the final
/// summary before collapsing it. The default prompt just takes the content
/// of the document.
/// - [collapseDocumentPromptVar] is the variable used in the [collapsePrompt]
/// to indicate where the summary to be collapsed should be placed.
/// - [collapseDocumentSeparator] is the separator used to join the summary
/// to be collapsed.
/// - [returnIntermediateSteps] indicates whether to return the intermediate
/// steps of the summarization process. If true, the intermediate steps
/// will be placed in the [MapReduceDocumentsChain.intermediateStepsOutputKey]
/// output key.
///
/// Example:
/// ```dart
/// final loader = WebBaseLoader(['https://example.com']);
/// final docs = await loader.load();
///
/// const textSplitter = RecursiveCharacterTextSplitter();
/// final docsChunks = textSplitter.splitDocuments(docs);
///
/// final llm = ChatOpenAI(apiKey: openAIKey);
/// final summarizeChain = SummarizeChain.mapReduce(llm: llm);
///
/// final summary = await summarizeChain.run(docsChunks);
/// ```
static MapReduceDocumentsChain mapReduce({
required final BaseLanguageModel llm,
final int summaryMaxTokens = ReduceDocumentsChain.defaultTokenMax,
final String inputKey = SummarizeChain.defaultInputKey,
final String outputKey = SummarizeChain.defaultOutputKey,
final BasePromptTemplate mapPrompt = _promptTemplate,
final String mapDocumentPromptVar =
MapReduceDocumentsChain.defaultLlmChainDocumentPromptVar,
final BasePromptTemplate combinePrompt = _promptTemplate,
final BaseLanguageModel? combineLlm,
final BasePromptTemplate combineDocumentPrompt =
StuffDocumentsChain.defaultDocumentPrompt,
final String combineDocumentPromptVar =
StuffDocumentsChain.defaultLlmChainStuffedDocumentPromptVar,
final String combineDocumentSeparator =
StuffDocumentsChain.defaultDocumentSeparator,
final BasePromptTemplate? collapsePrompt,
final BaseLanguageModel? collapseLlm,
final BasePromptTemplate collapseDocumentPrompt =
StuffDocumentsChain.defaultDocumentPrompt,
final String collapseDocumentPromptVar =
StuffDocumentsChain.defaultLlmChainStuffedDocumentPromptVar,
final String collapseDocumentSeparator =
StuffDocumentsChain.defaultDocumentSeparator,
final bool returnIntermediateSteps = false,
}) {
final finalCombineLlm = combineLlm ?? llm;
final combineLlmChain = LLMChain(
llm: finalCombineLlm,
prompt: combinePrompt,
);

final combineDocumentsChain = StuffDocumentsChain(
llmChain: combineLlmChain,
documentPrompt: combineDocumentPrompt,
llmChainStuffedDocumentPromptVar: combineDocumentPromptVar,
documentSeparator: combineDocumentSeparator,
);

StuffDocumentsChain? collapseDocumentsChain;
if (collapsePrompt != null) {
final finalCollapseLLm = collapseLlm ?? combineLlm ?? llm;
final collapseLlmChain = LLMChain(
llm: finalCollapseLLm,
prompt: collapsePrompt,
);
collapseDocumentsChain = StuffDocumentsChain(
llmChain: collapseLlmChain,
documentPrompt: collapseDocumentPrompt,
llmChainStuffedDocumentPromptVar: collapseDocumentPromptVar,
documentSeparator: collapseDocumentSeparator,
);
}

final reduceDocumentsChain = ReduceDocumentsChain(
combineDocumentsChain: combineDocumentsChain,
collapseDocumentsChain: collapseDocumentsChain,
tokenMax: summaryMaxTokens,
);

final mapLlmChain = LLMChain(llm: llm, prompt: mapPrompt);

return MapReduceDocumentsChain(
inputKey: inputKey,
outputKey: outputKey,
mapLlmChain: mapLlmChain,
reduceDocumentsChain: reduceDocumentsChain,
mapLlmChainDocumentPromptVar: mapDocumentPromptVar,
returnIntermediateSteps: returnIntermediateSteps,
);
}

/// Default input key for the summarization chain where to place the
/// documents to summarize.
static const defaultInputKey = BaseCombineDocumentsChain.defaultInputKey;

/// Default output key for the summarization chain where to place the
/// summary.
static const defaultOutputKey = BaseCombineDocumentsChain.defaultOutputKey;
}
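Both factory constructors also accept a custom prompt in place of the default "concise summary" template. A hedged sketch: the prompt's input variable must match the document variable ('context' by default), and llm / docsChunks are assumed to be set up as in the examples above:

```dart
// Custom summarization prompt; '{context}' receives the stuffed documents.
const bulletPrompt = PromptTemplate(
  template: '''
Write a three-bullet summary of the following:
"{context}"
BULLET SUMMARY:''',
  inputVariables: {'context'},
);

final chain = SummarizeChain.stuff(
  llm: llm,
  promptTemplate: bulletPrompt,
);
final summary = await chain.run(docsChunks);
```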
