Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(retrievers)!: Move all retriever config options to RetrieverOptions #248

Merged
merged 1 commit into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/expression_language/cookbook/retrieval.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ In this example, we will add a memory to the chain and return the source documen

```dart
final retriever = vectorStore.asRetriever(
searchType: const VectorStoreSimilaritySearch(k: 1),
defaultOptions: const VectorStoreRetrieverOptions(
searchType: VectorStoreSimilaritySearch(k: 1),
),
);
final model = ChatOpenAI(apiKey: openaiApiKey);
final stringOutputParser = const StringOutputParser();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ Future<void> _conversationalRetrievalChainMemoryAndDocs() async {
final vectorStore = Chroma(embeddings: embeddings);

final retriever = vectorStore.asRetriever(
searchType: const VectorStoreSimilaritySearch(k: 1),
defaultOptions: const VectorStoreRetrieverOptions(
searchType: VectorStoreSimilaritySearch(k: 1),
),
);
final model = ChatOpenAI(apiKey: openaiApiKey);
const stringOutputParser = StringOutputParser();
Expand Down
6 changes: 3 additions & 3 deletions packages/langchain/lib/src/chains/retrieval_qa.dart
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class RetrievalQAChain extends BaseChain {
});

/// Retriever to use.
final BaseRetriever retriever;
final Retriever retriever;

/// Chain to use to combine the documents.
final BaseCombineDocumentsChain combineDocumentsChain;
Expand Down Expand Up @@ -111,7 +111,7 @@ class RetrievalQAChain extends BaseChain {
String get chainType => 'retrieval_qa';

/// Creates a [RetrievalQAChain] from a [BaseLanguageModel] and a
/// [BaseRetriever].
/// [Retriever].
///
/// By default, it uses a prompt template optimized for question answering
/// that includes the retrieved documents and the question.
Expand All @@ -134,7 +134,7 @@ class RetrievalQAChain extends BaseChain {
/// [prompt]. Use 'context' and 'question' as the variable names.
factory RetrievalQAChain.fromLlm({
required final BaseLanguageModel llm,
required final BaseRetriever retriever,
required final Retriever retriever,
final PromptTemplate? prompt,
}) {
return RetrievalQAChain(
Expand Down
18 changes: 12 additions & 6 deletions packages/langchain/lib/src/documents/retrievers/base.dart
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
import '../../core/core.dart';
import '../models/models.dart';
import 'models/models.dart';

/// {@template base_retriever}
/// Base Index class. All indexes should extend this class.
/// {@endtemplate}
abstract class BaseRetriever
extends Runnable<String, BaseLangChainOptions, List<Document>> {
abstract class Retriever<Options extends RetrieverOptions>
extends Runnable<String, Options, List<Document>> {
/// {@macro base_retriever}
const BaseRetriever();
const Retriever();

/// Get the most relevant documents for a given query.
///
/// - [input] - The query to search for.
/// - [options] - Retrieval options.
@override
Future<List<Document>> invoke(
final String input, {
final BaseLangChainOptions? options,
final Options? options,
}) {
return getRelevantDocuments(input);
return getRelevantDocuments(input, options: options);
}

/// Get the most relevant documents for a given query.
///
/// - [query] - The query to search for.
Future<List<Document>> getRelevantDocuments(final String query);
/// - [options] - Retrieval options.
Future<List<Document>> getRelevantDocuments(
final String query, {
final Options? options,
});
}
8 changes: 6 additions & 2 deletions packages/langchain/lib/src/documents/retrievers/fake.dart
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import '../models/models.dart';
import 'base.dart';
import 'models/models.dart';

/// {@template fake_retriever}
/// A retriever that returns a fixed list of documents.
/// This class is meant for testing purposes only.
/// {@endtemplate}
class FakeRetriever extends BaseRetriever {
class FakeRetriever extends Retriever<RetrieverOptions> {
/// {@macro fake_retriever}
const FakeRetriever(this.docs);

/// The documents to return.
final List<Document> docs;

@override
Future<List<Document>> getRelevantDocuments(final String query) {
Future<List<Document>> getRelevantDocuments(
final String query, {
final RetrieverOptions? options,
}) {
return Future.value(docs);
}
}
29 changes: 29 additions & 0 deletions packages/langchain/lib/src/documents/retrievers/models/models.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import 'package:meta/meta.dart';

import '../../../core/base.dart';
import '../../vector_stores/models/models.dart';
import '../retrievers.dart';

/// {@template retriever_options}
/// Base class for [Retriever] options.
/// {@endtemplate}
@immutable
class RetrieverOptions extends BaseLangChainOptions {
/// {@macro retriever_options}
const RetrieverOptions();
}

/// {@template vector_store_retriever_options}
/// Options for [VectorStoreRetriever].
/// {@endtemplate}
class VectorStoreRetrieverOptions extends RetrieverOptions {
/// {@macro vector_store_retriever_options}
const VectorStoreRetrieverOptions({
this.searchType = const VectorStoreSimilaritySearch(),
});

/// The type of search to perform, either:
/// - [VectorStoreSearchType.similarity] (default)
/// - [VectorStoreSearchType.mmr]
final VectorStoreSearchType searchType;
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export 'base.dart';
export 'fake.dart';
export 'models/models.dart';
export 'vector_store.dart';
20 changes: 14 additions & 6 deletions packages/langchain/lib/src/documents/retrievers/vector_store.dart
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
import '../models/models.dart';
import '../vector_stores/vector_stores.dart';
import 'base.dart';
import 'models/models.dart';

/// {@template vector_store_retriever}
/// A retriever that uses a vector store to retrieve documents.
/// {@endtemplate}
class VectorStoreRetriever<V extends VectorStore> extends BaseRetriever {
class VectorStoreRetriever<V extends VectorStore>
extends Retriever<VectorStoreRetrieverOptions> {
/// {@macro vector_store_retriever}
const VectorStoreRetriever({
required this.vectorStore,
this.searchType = const VectorStoreSimilaritySearch(),
this.defaultOptions = const VectorStoreRetrieverOptions(),
});

/// The vector store to retrieve documents from.
final V vectorStore;

/// The type of search to perform.
final VectorStoreSearchType searchType;
/// Default options for this retriever.
final VectorStoreRetrieverOptions defaultOptions;

@override
Future<List<Document>> getRelevantDocuments(final String query) {
return vectorStore.search(query: query, searchType: searchType);
Future<List<Document>> getRelevantDocuments(
final String query, {
final VectorStoreRetrieverOptions? options,
}) {
return vectorStore.search(
query: query,
searchType: options?.searchType ?? defaultOptions.searchType,
);
}
}
11 changes: 5 additions & 6 deletions packages/langchain/lib/src/documents/vector_stores/base.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// ignore_for_file: avoid_unused_constructor_parameters
import '../embeddings/base.dart';
import '../models/models.dart';
import '../retrievers/models/models.dart';
import '../retrievers/vector_store.dart';
import 'models/models.dart';

Expand Down Expand Up @@ -170,16 +171,14 @@ abstract class VectorStore {

/// Returns a [VectorStoreRetriever] that uses this vector store.
///
/// - [searchType] is the type of search to perform, either
/// [VectorStoreSearchType.similarity] (default) or
/// [VectorStoreSearchType.mmr].
/// - [defaultOptions] are the default options for the retriever.
VectorStoreRetriever asRetriever({
final VectorStoreSearchType searchType =
const VectorStoreSimilaritySearch(),
final VectorStoreRetrieverOptions defaultOptions =
const VectorStoreRetrieverOptions(),
}) {
return VectorStoreRetriever(
vectorStore: this,
searchType: searchType,
defaultOptions: defaultOptions,
);
}
}