Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(stores): Initial vectors, ids, and delete in MemoryVectorStore #123

Merged
merged 1 commit into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 91 additions & 33 deletions packages/langchain/lib/src/documents/vector_stores/memory.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,65 +7,88 @@ import '../embeddings/base.dart';
import '../models/models.dart';
import 'base.dart';

/// {@template memory_vector_store}
/// Vector store that stores vectors in memory.
///
/// By default, it uses cosine similarity to compare vectors.
///
/// It iterates over all vectors in the store to find the most similar vectors.
/// This is not efficient for large vector stores as it has a time complexity
/// of O(vector_dimensionality * num_vectors).
/// {@endtemplate}
///
/// For more efficient vector stores, see [VertexAIMatchingEngine].
class MemoryVectorStore extends VectorStore {
/// {@macro memory_vector_store}
/// Main constructor for [MemoryVectorStore].
///
/// - [embeddings] is the embeddings model to use to embed the documents.
/// - [similarityFunction] is the similarity function to use when comparing
/// vectors. By default, it uses cosine similarity.
/// - [initialMemoryVectors] is an optional list of [MemoryVector] to
/// initialize the vector store with. This is useful when loading a vector
/// store from a database or file. The entries are copied into a new list,
/// so adding to or removing from the passed list afterwards does not change
/// the store.
///
/// If you want to create and populate a [MemoryVectorStore] from a list of
/// documents or texts, use [MemoryVectorStore.fromDocuments] or
/// [MemoryVectorStore.fromText].
MemoryVectorStore({
required super.embeddings,
this.similarityFunction = cosineSimilarity,
});
final List<MemoryVector>? initialMemoryVectors,
}) : memoryVectors = [...?initialMemoryVectors];

/// Similarity function to use when comparing vectors.
final double Function(List<double> a, List<double> b) similarityFunction;

/// Vectors stored in memory.
final List<MemoryVector> memoryVectors = [];
final List<MemoryVector> memoryVectors;

/// Builds a [MemoryVectorStore] and fills it with the given documents.
///
/// - [documents] are the documents to add to the new store.
/// - [embeddings] is the embeddings model the new store is created with.
static Future<MemoryVectorStore> fromDocuments({
  required final List<Document> documents,
  required final Embeddings embeddings,
}) async {
  final vectorStore = MemoryVectorStore(embeddings: embeddings);
  await vectorStore.addDocuments(documents: documents);
  return vectorStore;
}

/// Creates a vector store from a list of texts.
///
/// Each text is wrapped in a [Document] and added to a new store through
/// `addDocuments`.
///
/// - [ids] is an optional list of ids to add to the vector store, paired with
/// [texts] by index.
/// - [texts] is a list of texts to add to the vector store.
/// - [metadatas] is an optional list of metadata to add to the vector store,
/// paired with [texts] by index. A text without metadata gets an empty map.
/// - [embeddings] is the embeddings model to use to embed the texts.
///
/// When provided, [ids] and [metadatas] must have the same length as
/// [texts]. This is verified with `assert`, so it is only checked when
/// assertions are enabled.
static Future<MemoryVectorStore> fromText({
final List<String>? ids,
required final List<String> texts,
required final List<Map<String, dynamic>> metadatas,
final List<Map<String, dynamic>>? metadatas,
required final Embeddings embeddings,
}) async {
// Length preconditions for the optional, index-aligned lists.
assert(
ids == null || ids.length == texts.length,
'ids and texts must have the same length',
);
assert(
metadatas == null || metadatas.length == texts.length,
'metadatas and texts must have the same length',
);
final vs = MemoryVectorStore(embeddings: embeddings);
await vs.addDocuments(
documents: texts
.mapIndexed(
(final i, final text) => Document(
id: ids?[i],
pageContent: text,
metadata: i < metadatas.length ? metadatas[i] : const {},
metadata: metadatas?[i] ?? const {},
),
)
.toList(growable: false),
);
return vs;
}

/// Builds a [MemoryVectorStore] and fills it with the given documents.
///
/// - [documents] are the documents to add to the new store.
/// - [embeddings] is the embeddings model the new store is created with.
static Future<MemoryVectorStore> fromDocuments({
  required final List<Document> documents,
  required final Embeddings embeddings,
}) async {
  final vectorStore = MemoryVectorStore(embeddings: embeddings);
  await vectorStore.addDocuments(documents: documents);
  return vectorStore;
}

@override
Future<List<String>> addVectors({
required final List<List<double>> vectors,
Expand All @@ -75,18 +98,20 @@ class MemoryVectorStore extends VectorStore {
vectors.mapIndexed((final i, final vector) {
final doc = documents[i];
return MemoryVector(
content: doc.pageContent,
document: doc,
embedding: vector,
metadata: doc.metadata,
);
}),
);
return const [];
}

@override
Future<bool> delete({required final List<String> ids}) {
throw UnimplementedError();
Future<bool> delete({required final List<String> ids}) async {
  // Removes every stored vector whose document id is listed in [ids] and
  // always completes with `true`.
  //
  // The ids are copied into a set first so each membership check is a hash
  // lookup instead of a scan over the whole `ids` list per stored vector.
  final idsToDelete = ids.toSet();
  memoryVectors.removeWhere(
    (final vector) => idsToDelete.contains(vector.document.id),
  );
  return true;
}

@override
Expand All @@ -109,10 +134,7 @@ class MemoryVectorStore extends VectorStore {
return searches
.map(
(final search) => (
Document(
pageContent: memoryVectors[search.key].content,
metadata: memoryVectors[search.key].metadata,
),
memoryVectors[search.key].document,
search.value,
),
)
Expand All @@ -121,20 +143,56 @@ class MemoryVectorStore extends VectorStore {
}

/// {@template memory_vector}
/// Represents a vector in memory.
/// Represents an entry of [MemoryVectorStore].
/// {@endtemplate}
@immutable
class MemoryVector {
/// {@macro memory_vector}
///
/// - [document] is the document associated with the vector.
/// - [embedding] is the vector embedding.
const MemoryVector({
required this.content,
required this.document,
required this.embedding,
required this.metadata,
});

final String content;
/// Document associated with the vector.
final Document document;

/// Vector embedding.
final List<double> embedding;
final Map<String, dynamic> metadata;

/// Creates a vector from a map with `document` and `embedding` entries.
///
/// - `document` must be a `Map<String, dynamic>` accepted by
///   `Document.fromMap`.
/// - `embedding` may be any list of numbers, for example the `List<dynamic>`
///   produced by `jsonDecode`; every element is converted to a [double].
///
/// Throws a [TypeError] if an entry is missing or has an unexpected type.
factory MemoryVector.fromMap(final Map<String, dynamic> map) {
  return MemoryVector(
    document: Document.fromMap(map['document'] as Map<String, dynamic>),
    // A direct `as List<double>` cast fails for `List<dynamic>` or
    // `List<num>` sources even when every element is a number, so the
    // elements are converted one by one.
    embedding: (map['embedding'] as List<dynamic>)
        .map((final value) => (value as num).toDouble())
        .toList(),
  );
}

/// Serializes this vector into a map with `document` and `embedding` entries.
Map<String, dynamic> toMap() => {
      'document': document.toMap(),
      'embedding': embedding,
    };

/// Whether [other] is a [MemoryVector] of the same runtime type with an
/// equal [document] and an element-wise equal [embedding].
///
/// Accepts any [Object], so comparing against a value of another type
/// returns `false` instead of throwing a [TypeError].
@override
bool operator ==(final Object other) {
  return identical(this, other) ||
      other is MemoryVector &&
          runtimeType == other.runtimeType &&
          document == other.document &&
          const ListEquality<double>().equals(embedding, other.embedding);
}

/// Hash code derived from [document] and the contents of [embedding].
///
/// Uses [Object.hash] rather than XOR so that the two components are mixed
/// instead of being able to cancel each other out.
@override
int get hashCode =>
    Object.hash(document, const ListEquality<double>().hash(embedding));

/// Returns a short description with the document and the embedding length.
@override
String toString() =>
    'MemoryVector{document: $document, embedding: ${embedding.length}}';
}

/// Measures the cosine of the angle between two vectors in a vector space.
Expand Down
154 changes: 133 additions & 21 deletions packages/langchain/test/documents/vector_stores/memory_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,132 @@ import 'package:test/test.dart';

void main() {
group('MemoryVectorStore tests', () {
test('Test MemoryVectorStore search', () async {
final embeddings = _FakeEmbeddings(vectors: [_chaoVector]);
final store = MemoryVectorStore(embeddings: embeddings);
await store.addVectors(
vectors: [
_helloVector,
_hiVector,
_byeVector,
_whatsThisVector,
],
test('Test MemoryVectorStore.fromDocuments', () async {
const embeddings = _FakeEmbeddings();
final store = await MemoryVectorStore.fromDocuments(
documents: [
const Document(pageContent: 'hello'),
const Document(pageContent: 'hi'),
const Document(pageContent: 'bye'),
const Document(pageContent: "what's this"),
const Document(id: '1', pageContent: 'hello'),
const Document(id: '2', pageContent: 'hi'),
const Document(id: '3', pageContent: 'bye'),
const Document(id: '4', pageContent: "what's this"),
],
embeddings: embeddings,
);

final results = await store.similaritySearch(query: 'chao', k: 1);

expect(results.length, 1);
expect(results.first.id, '3');
expect(results.first.pageContent, 'bye');
});

test('Test MemoryVectorStore.fromText', () async {
const embeddings = _FakeEmbeddings();
final store = await MemoryVectorStore.fromText(
ids: const ['1', '2', '3', '4'],
texts: const ['hello', 'hi', 'bye', "what's this"],
embeddings: embeddings,
);

final results = await store.similaritySearch(query: 'chao', k: 1);

expect(results.length, 1);
expect(results.first.id, '3');
expect(results.first.pageContent, 'bye');
});

test('Test MemoryVectorStore with initialMemoryVectors', () async {
const embeddings = _FakeEmbeddings();
final store = MemoryVectorStore(
embeddings: embeddings,
initialMemoryVectors: [
MemoryVector(
document: const Document(id: '1', pageContent: 'hello'),
embedding: _helloVector,
),
MemoryVector(
document: const Document(id: '2', pageContent: 'hi'),
embedding: _hiVector,
),
MemoryVector(
document: const Document(id: '3', pageContent: 'bye'),
embedding: _byeVector,
),
MemoryVector(
document: const Document(id: '4', pageContent: "what's this"),
embedding: _whatsThisVector,
),
],
);

final results = await store.similaritySearch(query: 'chao', k: 1);

expect(results.length, 1);
expect(results.first.id, '3');
expect(results.first.pageContent, 'bye');
});

test('Test delete entry', () async {
const embeddings = _FakeEmbeddings();
final store = await MemoryVectorStore.fromDocuments(
documents: [
const Document(id: '1', pageContent: 'hello'),
const Document(id: '2', pageContent: 'hi'),
const Document(id: '3', pageContent: 'bye'),
const Document(id: '4', pageContent: "what's this"),
],
embeddings: embeddings,
);
await store.delete(ids: ['3']);

final results = await store.similaritySearch(query: 'chao', k: 1);

expect(results.length, 1);
expect(results.first.id, '2');
expect(results.first.pageContent, 'hi');
});

test('Test toMap and fromMap', () async {
const embeddings = _FakeEmbeddings();
final store = MemoryVectorStore(
embeddings: embeddings,
initialMemoryVectors: const [
MemoryVector(
document: Document(id: '1', pageContent: 'hello'),
embedding: [1, 2, 3],
),
MemoryVector(
document: Document(id: '2', pageContent: 'hi'),
embedding: [4, 5, 6],
),
],
);

final map = store.memoryVectors.map((final v) => v.toMap());
final expectedMap = [
{
'document': {
'id': '1',
'pageContent': 'hello',
'metadata': <String, dynamic>{},
},
'embedding': [1.0, 2.0, 3.0],
},
{
'document': {
'id': '2',
'pageContent': 'hi',
'metadata': <String, dynamic>{},
},
'embedding': [4.0, 5.0, 6.0],
},
];
expect(map, expectedMap);

final newMap =
expectedMap.map(MemoryVector.fromMap).toList(growable: false);
expect(newMap, store.memoryVectors);
});
});

group('Search algorithms tests', () {
Expand Down Expand Up @@ -61,24 +164,33 @@ void main() {
}

/// Fake [Embeddings] implementation used by the tests in this file.
class _FakeEmbeddings implements Embeddings {
_FakeEmbeddings({
this.vectors = const [],
});

final List<List<double>> vectors;
const _FakeEmbeddings();

/// Returns the embedding for [query].
@override
Future<List<double>> embedQuery(
final String query,
) async {
return vectors[0];
return embedText(query);
}

/// Returns one embedding per entry of [documents].
@override
Future<List<List<double>>> embedDocuments(
final List<String> documents,
) async {
return vectors;
return [
for (final document in documents) embedText(document),
];
}

/// Returns the predefined vector for [text].
///
/// Throws an [Exception] if [text] is not one of the known texts.
List<double> embedText(final String text) {
return switch (text) {
'hello' => _helloVector,
'hi' => _hiVector,
'bye' => _byeVector,
"what's this" => _whatsThisVector,
'chao' => _chaoVector,
_ => throw Exception('Unknown text: $text'),
};
}
}

Expand Down