Skip to content

Commit

Permalink
feat: Add support for VectorStoreRetrieverMemory (#54)
Browse files Browse the repository at this point in the history
Co-authored-by: David Miguel <me@davidmiguel.com>
  • Loading branch information
dileep9490 and davidmigloz committed Jul 26, 2023
1 parent 3b5c0b2 commit 72cd1b1
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 4 deletions.
8 changes: 7 additions & 1 deletion packages/langchain/lib/src/memory/base.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ import 'models/models.dart';

/// {@template base_memory}
/// Base interface for memory in chains.
///
/// Memory refers to state in Chains. Memory can be used to store information
/// about past executions of a Chain and inject that information into the
/// inputs of future executions of the Chain. For example, for conversational
/// Chains Memory can be used to store conversations and automatically add them
/// to future model prompts so that the model has the necessary context to
/// respond coherently to the latest input.
/// {@endtemplate}
abstract interface class BaseMemory {
/// {@macro base_memory}
Expand All @@ -11,7 +18,6 @@ abstract interface class BaseMemory {
Set<String> get memoryKeys;

/// Returns key-value pairs given the [MemoryInputValues].
/// If empty, returns all memories.
Future<MemoryVariables> loadMemoryVariables([
final MemoryInputValues values = const {},
]);
Expand Down
8 changes: 5 additions & 3 deletions packages/langchain/lib/src/memory/chat.dart
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ abstract base class BaseChatMemory implements BaseMemory {
final BaseChatMessageHistory chatHistory;

/// The input key to use for the chat history.
///
/// If null, the input key is inferred from the prompt (the input key hat
/// was filled in by the user (i.e. not a memory key)).
final String? inputKey;

/// The output key to use for the chat history.
Expand All @@ -48,9 +51,8 @@ abstract base class BaseChatMemory implements BaseMemory {
final MemoryInputValues inputValues,
final MemoryOutputValues outputValues,
) {
final promptInputKey = inputKey == null
? getPromptInputKey(inputValues, memoryKeys)
: inputKey!;
final promptInputKey =
inputKey ?? getPromptInputKey(inputValues, memoryKeys);
String outputKey;
if (this.outputKey == null) {
if (outputValues.length != 1) {
Expand Down
1 change: 1 addition & 0 deletions packages/langchain/lib/src/memory/memory.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export 'chat.dart';
export 'models/models.dart';
export 'simple.dart';
export 'stores/stores.dart';
export 'vector_store.dart';
90 changes: 90 additions & 0 deletions packages/langchain/lib/src/memory/vector_store.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import '../../langchain.dart';
import 'utils.dart';

/// {@template vector_store_retriever_memory}
/// VectorStoreRetriever-backed memory.
/// {@endtemplate}
class VectorStoreRetrieverMemory implements BaseMemory {
/// {@macro vector_store_retriever_memory}
VectorStoreRetrieverMemory({
required this.retriever,
this.memoryKey = defaultMemoryKey,
this.inputKey,
this.excludeInputKeys = const {},
this.returnDocs = false,
});

/// VectorStoreRetriever object to connect to.
final VectorStoreRetriever retriever;

/// Name of the key where the memories are in the map returned by
/// [loadMemoryVariables].
final String memoryKey;

/// The input key to use for the query to the vector store.
///
/// If null, the input key is inferred from the prompt (the input key hat
/// was filled in by the user (i.e. not a memory key)).
final String? inputKey;

/// Input keys to exclude in addition to memory key when constructing the
/// document.
final Set<String> excludeInputKeys;

/// Whether or not to return the result of querying the database directly.
/// If false, the page content of all the documents is returned as a single
/// string.
final bool returnDocs;

/// Default key for [memoryKey].
static const String defaultMemoryKey = 'memory';

@override
Set<String> get memoryKeys => {memoryKey};

@override
Future<MemoryVariables> loadMemoryVariables([
final MemoryInputValues values = const {},
]) async {
final promptInputKey = inputKey ?? getPromptInputKey(values, memoryKeys);
final query = values[promptInputKey];
final docs = await retriever.getRelevantDocuments(query);
return {
memoryKey: returnDocs
? docs
: docs.map((final doc) => doc.pageContent).join('\n'),
};
}

@override
Future<void> saveContext({
required final MemoryInputValues inputValues,
required final MemoryOutputValues outputValues,
}) async {
final docs = _buildDocuments(inputValues, outputValues);
await retriever.addDocuments(docs);
}

/// Builds the documents to save to the vector store from the given
/// [inputValues] and [outputValues].
List<Document> _buildDocuments(
final MemoryInputValues inputValues,
final MemoryOutputValues outputValues,
) {
final excludeKeys = {memoryKey, ...excludeInputKeys};
final filteredInputs = {
for (final entry in inputValues.entries)
if (!excludeKeys.contains(entry.key)) entry.key: entry.value
};
final inputsOutputs = {...filteredInputs, ...outputValues};
final pageContent = inputsOutputs.entries.map((final entry) {
return '${entry.key}: ${entry.value}';
}).join('\n');
return [Document(pageContent: pageContent)];
}

@override
Future<void> clear() async {
// Nothing to clear
}
}
107 changes: 107 additions & 0 deletions packages/langchain/test/memory/vector_store_test.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import 'package:langchain/langchain.dart';
import 'package:test/test.dart';

void main() {
group('VectorStoreRetrieverMemory tests', () {
test('Test vector store memory', () async {
final embeddings = _FakeEmbeddings();
final vectorStore = MemoryVectorStore(embeddings: embeddings);
final memory = VectorStoreRetrieverMemory(
retriever: vectorStore.asRetriever(),
);

final result1 = await memory.loadMemoryVariables({'input': 'foo'});
expect(result1[VectorStoreRetrieverMemory.defaultMemoryKey], '');

await memory.saveContext(
inputValues: {
'foo': 'bar',
},
outputValues: {
'bar': 'foo',
},
);
final result2 = await memory.loadMemoryVariables({'input': 'foo'});
expect(
result2[VectorStoreRetrieverMemory.defaultMemoryKey],
'foo: bar\nbar: foo',
);
});

test('Test returnDocs', () async {
final embeddings = _FakeEmbeddings();
final vectorStore = MemoryVectorStore(embeddings: embeddings);
final memory = VectorStoreRetrieverMemory(
retriever: vectorStore.asRetriever(),
returnDocs: true,
);

await memory.saveContext(
inputValues: {
'foo': 'bar',
},
outputValues: {
'bar': 'foo',
},
);
final result = await memory.loadMemoryVariables({'input': 'foo'});
const expectedDoc = Document(pageContent: 'foo: bar\nbar: foo');
expect(
result[VectorStoreRetrieverMemory.defaultMemoryKey],
[expectedDoc],
);
});

test('Test excludeInputKeys', () async {
final embeddings = _FakeEmbeddings();
final vectorStore = MemoryVectorStore(embeddings: embeddings);
final memory = VectorStoreRetrieverMemory(
retriever: vectorStore.asRetriever(),
excludeInputKeys: {'foo'},
);

final result1 = await memory.loadMemoryVariables({'input': 'foo'});
expect(result1[VectorStoreRetrieverMemory.defaultMemoryKey], '');

await memory.saveContext(
inputValues: {
'foo': 'bar',
},
outputValues: {
'bar': 'foo',
},
);
final result2 = await memory.loadMemoryVariables({'input': 'foo'});
expect(
result2[VectorStoreRetrieverMemory.defaultMemoryKey],
'bar: foo',
);
});
});
}

class _FakeEmbeddings implements Embeddings {
@override
Future<List<List<double>>> embedDocuments(
final List<String> documents,
) async {
return documents.map(_embed).toList(growable: false);
}

@override
Future<List<double>> embedQuery(
final String query,
) async {
return _embed(query);
}

List<double> _embed(final String text) {
return switch (text) {
'foo' => [1.0, 1.0],
'bar' => [-1.0, -1.0],
'foo: bar\nbar: foo' => [1.0, -1.0],
'bar: foo' => [-1.0, 1.0],
_ => throw Exception('Unknown text: $text'),
};
}
}
File renamed without changes.

0 comments on commit 72cd1b1

Please sign in to comment.