diff --git a/packages/langchain/lib/src/documents/vector_stores/memory.dart b/packages/langchain/lib/src/documents/vector_stores/memory.dart index 9a402a2f..26cc4519 100644 --- a/packages/langchain/lib/src/documents/vector_stores/memory.dart +++ b/packages/langchain/lib/src/documents/vector_stores/memory.dart @@ -7,7 +7,6 @@ import '../embeddings/base.dart'; import '../models/models.dart'; import 'base.dart'; -/// {@template memory_vector_store} /// Vector store that stores vectors in memory. /// /// By default, it uses cosine similarity to compare vectors. @@ -15,37 +14,74 @@ import 'base.dart'; /// It iterates over all vectors in the store to find the most similar vectors. /// This is not efficient for large vector stores as it has a time complexity /// of O(vector_dimensionality * num_vectors). -/// {@endtemplate} +/// +/// For more efficient vector stores, see [VertexAIMatchingEngine]. class MemoryVectorStore extends VectorStore { - /// {@macro memory_vector_store} + /// Main constructor for [MemoryVectorStore]. + /// + /// - [embeddings] is the embeddings model to use to embed the documents. + /// - [similarityFunction] is the similarity function to use when comparing + /// vectors. By default, it uses cosine similarity. + /// - [initialMemoryVectors] is an optional list of [MemoryVector] to + /// initialize the vector store with. This is useful when loading a vector + /// store from a database or file. + /// + /// If you want to create and populate a [MemoryVectorStore] from a list of + /// documents or texts, use [MemoryVectorStore.fromDocuments] or + /// [MemoryVectorStore.fromText]. MemoryVectorStore({ required super.embeddings, this.similarityFunction = cosineSimilarity, - }); + final List? initialMemoryVectors, + }) : memoryVectors = [...?initialMemoryVectors]; /// Similarity function to use when comparing vectors. final double Function(List a, List b) similarityFunction; /// Vectors stored in memory. - final List memoryVectors = []; + final List memoryVectors; + + /// Creates a vector store from a list of documents. + /// + /// - [documents] is a list of documents to add to the vector store. + /// - [embeddings] is the embeddings model to use to embed the documents. + static Future fromDocuments({ + required final List documents, + required final Embeddings embeddings, + }) async { + final store = MemoryVectorStore(embeddings: embeddings); + await store.addDocuments(documents: documents); + return store; + } /// Creates a vector store from a list of texts. /// + /// - [ids] is a list of ids to add to the vector store. /// - [texts] is a list of texts to add to the vector store. /// - [metadatas] is a list of metadata to add to the vector store. /// - [embeddings] is the embeddings model to use to embed the texts. static Future fromText({ + final List? ids, required final List texts, - required final List> metadatas, + final List>? metadatas, required final Embeddings embeddings, }) async { + assert( + ids == null || ids.length == texts.length, + 'ids and texts must have the same length', + ); + assert( + metadatas == null || metadatas.length == texts.length, + 'metadatas and texts must have the same length', + ); final vs = MemoryVectorStore(embeddings: embeddings); await vs.addDocuments( documents: texts .mapIndexed( (final i, final text) => Document( + id: ids?[i], pageContent: text, - metadata: i < metadatas.length ? metadatas[i] : const {}, + metadata: metadatas?[i] ?? const {}, ), ) .toList(growable: false), @@ -53,19 +89,6 @@ class MemoryVectorStore extends VectorStore { return vs; } - /// Creates a vector store from a list of documents. - /// - /// - [documents] is a list of documents to add to the vector store. - /// - [embeddings] is the embeddings model to use to embed the documents. - static Future fromDocuments({ - required final List documents, - required final Embeddings embeddings, - }) async { - final store = MemoryVectorStore(embeddings: embeddings); - await store.addDocuments(documents: documents); - return store; - } - @override Future> addVectors({ required final List> vectors, @@ -75,9 +98,8 @@ class MemoryVectorStore extends VectorStore { vectors.mapIndexed((final i, final vector) { final doc = documents[i]; return MemoryVector( - content: doc.pageContent, + document: doc, embedding: vector, - metadata: doc.metadata, ); }), ); @@ -85,8 +107,11 @@ class MemoryVectorStore extends VectorStore { } @override - Future delete({required final List ids}) { - throw UnimplementedError(); + Future delete({required final List ids}) async { + memoryVectors.removeWhere( + (final vector) => ids.contains(vector.document.id), + ); + return true; } @override @@ -109,10 +134,7 @@ class MemoryVectorStore extends VectorStore { return searches .map( (final search) => ( - Document( - pageContent: memoryVectors[search.key].content, - metadata: memoryVectors[search.key].metadata, - ), + memoryVectors[search.key].document, search.value, ), ) @@ -121,20 +143,56 @@ class MemoryVectorStore extends VectorStore { } /// {@template memory_vector} -/// Represents a vector in memory. +/// Represents an entry of [MemoryVectorStore]. /// {@endtemplate} @immutable class MemoryVector { /// {@macro memory_vector} const MemoryVector({ - required this.content, + required this.document, required this.embedding, - required this.metadata, }); - final String content; + /// Document associated with the vector. + final Document document; + + /// Vector embedding. final List embedding; - final Map metadata; + + /// Creates a vector from a map. + factory MemoryVector.fromMap(final Map map) { + return MemoryVector( + document: Document.fromMap(map['document'] as Map), + embedding: map['embedding'] as List, + ); + } + + /// Converts the vector to a map. + Map toMap() { + return { + 'document': document.toMap(), + 'embedding': embedding, + }; + } + + @override + bool operator ==(covariant final MemoryVector other) { + return identical(this, other) || + runtimeType == other.runtimeType && + document == other.document && + const ListEquality().equals(embedding, other.embedding); + } + + @override + int get hashCode => + document.hashCode ^ const ListEquality().hash(embedding); + + @override + String toString() { + return 'MemoryVector{' + 'document: $document, ' + 'embedding: ${embedding.length}}'; + } } /// Measures the cosine of the angle between two vectors in a vector space. diff --git a/packages/langchain/test/documents/vector_stores/memory_test.dart b/packages/langchain/test/documents/vector_stores/memory_test.dart index f7ffa081..955b1a2f 100644 --- a/packages/langchain/test/documents/vector_stores/memory_test.dart +++ b/packages/langchain/test/documents/vector_stores/memory_test.dart @@ -3,29 +3,132 @@ import 'package:test/test.dart'; void main() { group('MemoryVectorStore tests', () { - test('Test MemoryVectorStore search', () async { - final embeddings = _FakeEmbeddings(vectors: [_chaoVector]); - final store = MemoryVectorStore(embeddings: embeddings); - await store.addVectors( - vectors: [ - _helloVector, - _hiVector, - _byeVector, - _whatsThisVector, - ], + test('Test MemoryVectorStore.fromDocuments', () async { + const embeddings = _FakeEmbeddings(); + final store = await MemoryVectorStore.fromDocuments( documents: [ - const Document(pageContent: 'hello'), - const Document(pageContent: 'hi'), - const Document(pageContent: 'bye'), - const Document(pageContent: "what's this"), + const Document(id: '1', pageContent: 'hello'), + const Document(id: '2', pageContent: 'hi'), + const Document(id: '3', pageContent: 'bye'), + const Document(id: '4', pageContent: "what's this"), ], + embeddings: embeddings, + ); + + final results = await store.similaritySearch(query: 'chao', k: 1); + + expect(results.length, 1); + expect(results.first.id, '3'); + expect(results.first.pageContent, 'bye'); + }); + + test('Test MemoryVectorStore.fromText', () async { + const embeddings = _FakeEmbeddings(); + final store = await MemoryVectorStore.fromText( + ids: const ['1', '2', '3', '4'], + texts: const ['hello', 'hi', 'bye', "what's this"], + embeddings: embeddings, ); final results = await store.similaritySearch(query: 'chao', k: 1); expect(results.length, 1); + expect(results.first.id, '3'); expect(results.first.pageContent, 'bye'); }); + + test('Test MemoryVectorStore with initialMemoryVectors', () async { + const embeddings = _FakeEmbeddings(); + final store = MemoryVectorStore( + embeddings: embeddings, + initialMemoryVectors: [ + MemoryVector( + document: const Document(id: '1', pageContent: 'hello'), + embedding: _helloVector, + ), + MemoryVector( + document: const Document(id: '2', pageContent: 'hi'), + embedding: _hiVector, + ), + MemoryVector( + document: const Document(id: '3', pageContent: 'bye'), + embedding: _byeVector, + ), + MemoryVector( + document: const Document(id: '4', pageContent: "what's this"), + embedding: _whatsThisVector, + ), + ], + ); + + final results = await store.similaritySearch(query: 'chao', k: 1); + + expect(results.length, 1); + expect(results.first.id, '3'); + expect(results.first.pageContent, 'bye'); + }); + + test('Test delete entry', () async { + const embeddings = _FakeEmbeddings(); + final store = await MemoryVectorStore.fromDocuments( + documents: [ + const Document(id: '1', pageContent: 'hello'), + const Document(id: '2', pageContent: 'hi'), + const Document(id: '3', pageContent: 'bye'), + const Document(id: '4', pageContent: "what's this"), + ], + embeddings: embeddings, + ); + await store.delete(ids: ['3']); + + final results = await store.similaritySearch(query: 'chao', k: 1); + + expect(results.length, 1); + expect(results.first.id, '2'); + expect(results.first.pageContent, 'hi'); + }); + + test('Test toMap and fromMap', () async { + const embeddings = _FakeEmbeddings(); + final store = MemoryVectorStore( + embeddings: embeddings, + initialMemoryVectors: const [ + MemoryVector( + document: Document(id: '1', pageContent: 'hello'), + embedding: [1, 2, 3], + ), + MemoryVector( + document: Document(id: '2', pageContent: 'hi'), + embedding: [4, 5, 6], + ), + ], + ); + + final map = store.memoryVectors.map((final v) => v.toMap()); + final expectedMap = [ + { + 'document': { + 'id': '1', + 'pageContent': 'hello', + 'metadata': {}, + }, + 'embedding': [1.0, 2.0, 3.0], + }, + { + 'document': { + 'id': '2', + 'pageContent': 'hi', + 'metadata': {}, + }, + 'embedding': [4.0, 5.0, 6.0], + }, + ]; + expect(map, expectedMap); + + final newMap = + expectedMap.map(MemoryVector.fromMap).toList(growable: false); + expect(newMap, store.memoryVectors); + }); }); group('Search algorithms tests', () { @@ -61,24 +164,33 @@ void main() { } class _FakeEmbeddings implements Embeddings { - _FakeEmbeddings({ - this.vectors = const [], - }); - - final List> vectors; + const _FakeEmbeddings(); @override Future> embedQuery( final String query, ) async { - return vectors[0]; + return embedText(query); } @override Future>> embedDocuments( final List documents, ) async { - return vectors; + return [ + for (final document in documents) embedText(document), + ]; + } + + List embedText(final String text) { + return switch (text) { + 'hello' => _helloVector, + 'hi' => _hiVector, + 'bye' => _byeVector, + "what's this" => _whatsThisVector, + 'chao' => _chaoVector, + _ => throw Exception('Unknown text: $text'), + }; } }