diff --git a/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb b/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb new file mode 100644 index 00000000000000..0f0fa3461cbace --- /dev/null +++ b/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb @@ -0,0 +1,146 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Arcee\n", + "This notebook demonstrates how to use the `Arcee` class for generating text using Arcee's Domain Adapted Language Models (DALMs)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "Before using Arcee, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.llms import Arcee\n", + "\n", + "# Create an instance of the Arcee class\n", + "arcee = Arcee(\n", + " model=\"DALM-PubMed\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional Configuration\n", + "\n", + "You can also configure Arcee's parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n", + "Setting the `model_kwargs` at the object initialization uses the parameters as default for all the subsequent calls to the generate response." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "arcee = Arcee(\n", + " model=\"DALM-Patent\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n", + " arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n", + " arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n", + " model_kwargs={\n", + " \"size\": 5,\n", + " \"filters\": [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Einstein\"\n", + " }\n", + " ]\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generating Text\n", + "\n", + "You can generate text from Arcee by providing a prompt. Here's an example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate text\n", + "prompt = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n", + "response = arcee(prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional parameters\n", + "\n", + "Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s) to aid text generation. Filters help narrow down the results. Here's how to use these parameters:\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define filters\n", + "filters = [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Einstein\"\n", + " },\n", + " {\n", + " \"field_name\": \"year\",\n", + " \"filter_type\": \"strict_search\",\n", + " \"value\": \"1905\"\n", + " }\n", + "]\n", + "\n", + "# Generate text with filters and size params\n", + "response = arcee(prompt, size=5, filters=filters)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/docs_skeleton/docs/integrations/retrievers/arcee.ipynb b/docs/docs_skeleton/docs/integrations/retrievers/arcee.ipynb new file mode 100644 index 00000000000000..3cf1b62db3a0c3 --- /dev/null +++ b/docs/docs_skeleton/docs/integrations/retrievers/arcee.ipynb @@ -0,0 +1,141 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Arcee Retriever\n", + "This notebook demonstrates how to use the `ArceeRetriever` class to retrieve relevant document(s) for Arcee's Domain Adapted Language Models (DALMs)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "Before using `ArceeRetriever`, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import ArceeRetriever\n", + "\n", + "retriever = ArceeRetriever(\n", + " model=\"DALM-PubMed\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional Configuration\n", + "\n", + "You can also configure `ArceeRetriever`'s parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n", + "Setting the `model_kwargs` at the object initialization uses the filters and size as default for all the subsequent retrievals." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = ArceeRetriever(\n", + " model=\"DALM-PubMed\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n", + " arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n", + " arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n", + " model_kwargs={\n", + " \"size\": 5,\n", + " \"filters\": [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Einstein\"\n", + " }\n", + " ]\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Retrieving documents\n", + "You can retrieve relevant documents from uploaded contexts by providing a query. Here's an example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n", + "documents = retriever.get_relevant_documents(query=query)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional parameters\n", + "\n", + "Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s). Filters help narrow down the results. Here's how to use these parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define filters\n", + "filters = [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Music\"\n", + " },\n", + " {\n", + " \"field_name\": \"year\",\n", + " \"filter_type\": \"strict_search\",\n", + " \"value\": \"1905\"\n", + " }\n", + "]\n", + "\n", + "# Retrieve documents with filters and size params\n", + "documents = retriever.get_relevant_documents(query=query, size=5, filters=filters)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py index de577c7064c20d..b4d74f470ac1ee 100644 --- a/libs/langchain/langchain/llms/__init__.py +++ b/libs/langchain/langchain/llms/__init__.py @@ -52,6 +52,12 @@ def _import_anyscale() -> Any: return Anyscale +def _import_arcee() -> Any: + from langchain.llms.arcee import Arcee + + return Arcee + + def _import_aviary() -> Any: from langchain.llms.aviary import Aviary @@ -479,6 +485,8 @@ def __getattr__(name: str) -> Any: return _import_anthropic() elif name == "Anyscale": return _import_anyscale() + elif name == "Arcee": + return _import_arcee() elif name == "Aviary": return _import_aviary() elif name == "AzureMLOnlineEndpoint": @@ -633,6 +641,7 @@ def __getattr__(name: str) -> Any: "AmazonAPIGateway", "Anthropic", "Anyscale", + "Arcee", "Aviary", "AzureMLOnlineEndpoint", "AzureOpenAI", @@ -713,6 +722,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: "amazon_bedrock": _import_bedrock, "anthropic": _import_anthropic, "anyscale": _import_anyscale, + "arcee": _import_arcee, "aviary": _import_aviary, "azure": _import_azure_openai, "azureml_endpoint": _import_azureml_endpoint, diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py new file mode 100644 index 00000000000000..469b5e250fd5dd --- /dev/null +++ b/libs/langchain/langchain/llms/arcee.py @@ -0,0 +1,147 @@ +from typing import Any, Dict, List, Optional + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM +from langchain.pydantic_v1 import Extra, root_validator +from langchain.utilities.arcee import ArceeWrapper, DALMFilter +from langchain.utils import get_from_dict_or_env + + +class Arcee(LLM): + """Arcee's Domain Adapted Language Models (DALMs). + + To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key, + or pass ``arcee_api_key`` as a named parameter. + + Example: + .. code-block:: python + + from langchain.llms import Arcee + + arcee = Arcee( + model="DALM-PubMed", + arcee_api_key="ARCEE-API-KEY" + ) + + response = arcee("AI-driven music therapy") + """ + + _client: Optional[ArceeWrapper] = None #: :meta private: + """Arcee _client.""" + + arcee_api_key: str = "" + """Arcee API Key""" + + model: str + """Arcee DALM name""" + + arcee_api_url: str = "https://api.arcee.ai" + """Arcee API URL""" + + arcee_api_version: str = "v2" + """Arcee API Version""" + + arcee_app_url: str = "https://app.arcee.ai" + """Arcee App URL""" + + model_id: str = "" + """Arcee Model ID""" + + model_kwargs: Optional[Dict[str, Any]] = None + """Keyword arguments to pass to the model.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + underscore_attrs_are_private = True + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "arcee" + + def __init__(self, **data: Any) -> None: + """Initializes private fields.""" + + super().__init__(**data) + self._client = None + self._client = ArceeWrapper( + arcee_api_key=self.arcee_api_key, + arcee_api_url=self.arcee_api_url, + arcee_api_version=self.arcee_api_version, + model_kwargs=self.model_kwargs, + model_name=self.model, + ) + + self._client.validate_model_training_status() + + @root_validator() + def validate_environments(cls, values: Dict) -> Dict: + """Validate Arcee environment variables.""" + + # validate env vars + values["arcee_api_key"] = get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", + ) + + values["arcee_api_url"] = get_from_dict_or_env( + values, + "arcee_api_url", + "ARCEE_API_URL", + ) + + values["arcee_app_url"] = get_from_dict_or_env( + values, + "arcee_app_url", + "ARCEE_APP_URL", + ) + + values["arcee_api_version"] = get_from_dict_or_env( + values, + "arcee_api_version", + "ARCEE_API_VERSION", + ) + + # validate model kwargs + if values["model_kwargs"]: + kw = values["model_kwargs"] + + # validate size + if kw.get("size") is not None: + if not kw.get("size") >= 0: + raise ValueError("`size` must be positive") + + # validate filters + if kw.get("filters") is not None: + if not isinstance(kw.get("filters"), List): + raise ValueError("`filters` must be a list") + for f in kw.get("filters"): + DALMFilter(**f) + + return values + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Generate text from Arcee DALM. + + Args: + prompt: Prompt to generate text from. + size: The max number of context results to retrieve. + Defaults to 3. (Can be less if filters are provided). + filters: Filters to apply to the context dataset. + """ + + try: + if not self._client: + raise ValueError("Client is not initialized.") + return self._client.generate(prompt=prompt, **kwargs) + except Exception as e: + raise Exception(f"Failed to generate text: {e}") from e diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py index ba50fdfd57b6bc..9abdb145fd27bc 100644 --- a/libs/langchain/langchain/retrievers/__init__.py +++ b/libs/langchain/langchain/retrievers/__init__.py @@ -18,6 +18,7 @@ CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun """ +from langchain.retrievers.arcee import ArceeRetriever from langchain.retrievers.arxiv import ArxivRetriever from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever from langchain.retrievers.bm25 import BM25Retriever @@ -66,6 +67,7 @@ __all__ = [ "AmazonKendraRetriever", + "ArceeRetriever", "ArxivRetriever", "AzureCognitiveSearchRetriever", "ChatGPTPluginRetriever", diff --git a/libs/langchain/langchain/retrievers/arcee.py b/libs/langchain/langchain/retrievers/arcee.py new file mode 100644 index 00000000000000..6bfba7eef9a6e3 --- /dev/null +++ b/libs/langchain/langchain/retrievers/arcee.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, List, Optional + +from langchain.callbacks.manager import CallbackManagerForRetrieverRun +from langchain.docstore.document import Document +from langchain.pydantic_v1 import Extra, root_validator +from langchain.schema import BaseRetriever +from langchain.utilities.arcee import ArceeWrapper, DALMFilter +from langchain.utils import get_from_dict_or_env + + +class ArceeRetriever(BaseRetriever): + """Document retriever for Arcee's Domain Adapted Language Models (DALMs). + + To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key, + or pass ``arcee_api_key`` as a named parameter. + + Example: + .. code-block:: python + + from langchain.retrievers import ArceeRetriever + + retriever = ArceeRetriever( + model="DALM-PubMed", + arcee_api_key="ARCEE-API-KEY" + ) + + documents = retriever.get_relevant_documents("AI-driven music therapy") + """ + + _client: Optional[ArceeWrapper] = None #: :meta private: + """Arcee client.""" + + arcee_api_key: str = "" + """Arcee API Key""" + + model: str + """Arcee DALM name""" + + arcee_api_url: str = "https://api.arcee.ai" + """Arcee API URL""" + + arcee_api_version: str = "v2" + """Arcee API Version""" + + arcee_app_url: str = "https://app.arcee.ai" + """Arcee App URL""" + + model_kwargs: Optional[Dict[str, Any]] = None + """Keyword arguments to pass to the model.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + underscore_attrs_are_private = True + + def __init__(self, **data: Any) -> None: + """Initializes private fields.""" + + super().__init__(**data) + + self._client = ArceeWrapper( + arcee_api_key=self.arcee_api_key, + arcee_api_url=self.arcee_api_url, + arcee_api_version=self.arcee_api_version, + model_kwargs=self.model_kwargs, + model_name=self.model, + ) + + self._client.validate_model_training_status() + + @root_validator() + def validate_environments(cls, values: Dict) -> Dict: + """Validate Arcee environment variables.""" + + # validate env vars + values["arcee_api_key"] = get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", + ) + + values["arcee_api_url"] = get_from_dict_or_env( + values, + "arcee_api_url", + "ARCEE_API_URL", + ) + + values["arcee_app_url"] = get_from_dict_or_env( + values, + "arcee_app_url", + "ARCEE_APP_URL", + ) + + values["arcee_api_version"] = get_from_dict_or_env( + values, + "arcee_api_version", + "ARCEE_API_VERSION", + ) + + # validate model kwargs + if values["model_kwargs"]: + kw = values["model_kwargs"] + + # validate size + if kw.get("size") is not None: + if not kw.get("size") >= 0: + raise ValueError("`size` must not be negative.") + + # validate filters + if kw.get("filters") is not None: + if not isinstance(kw.get("filters"), List): + raise ValueError("`filters` must be a list.") + for f in kw.get("filters"): + DALMFilter(**f) + + return values + + def _get_relevant_documents( + self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any + ) -> List[Document]: + """Retrieve {size} contexts with your retriever for a given query + + Args: + query: Query to submit to the model + size: The max number of context results to retrieve. + Defaults to 3. (Can be less if filters are provided). + filters: Filters to apply to the context dataset. + """ + + try: + if not self._client: + raise ValueError("Client is not initialized.") + return self._client.retrieve(query=query, **kwargs) + except Exception as e: + raise ValueError(f"Error while retrieving documents: {e}") from e diff --git a/libs/langchain/langchain/utilities/__init__.py b/libs/langchain/langchain/utilities/__init__.py index e964baa90111cf..a091c8d7f47720 100644 --- a/libs/langchain/langchain/utilities/__init__.py +++ b/libs/langchain/langchain/utilities/__init__.py @@ -5,6 +5,7 @@ """ from langchain.utilities.alpha_vantage import AlphaVantageAPIWrapper from langchain.utilities.apify import ApifyWrapper +from langchain.utilities.arcee import ArceeWrapper from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.awslambda import LambdaWrapper from langchain.utilities.bash import BashProcess @@ -41,6 +42,7 @@ __all__ = [ "AlphaVantageAPIWrapper", "ApifyWrapper", + "ArceeWrapper", "ArxivAPIWrapper", "BashProcess", "BibtexparserWrapper", diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py new file mode 100644 index 00000000000000..f73cdcad9c8f46 --- /dev/null +++ b/libs/langchain/langchain/utilities/arcee.py @@ -0,0 +1,189 @@ +# This module contains utility classes and functions for interacting with Arcee API. +# For more information and updates, refer to the Arcee utils page: +# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py] + +from enum import Enum +from typing import Any, Dict, List, Literal, Mapping, Optional, Union + +import requests + +from langchain.pydantic_v1 import BaseModel, root_validator +from langchain.schema.retriever import Document + + +class ArceeRoute(str, Enum): + generate = "models/generate" + retrieve = "models/retrieve" + model_training_status = "models/status/{id_or_name}" + + +class DALMFilterType(str, Enum): + fuzzy_search = "fuzzy_search" + strict_search = "strict_search" + + +class DALMFilter(BaseModel): + """Filters available for a dalm retrieval and generation + + Arguments: + field_name: The field to filter on. Can be 'document' or 'name' to filter + on your document's raw text or title. Any other field will be presumed + to be a metadata field you included when uploading your context data + filter_type: Currently 'fuzzy_search' and 'strict_search' are supported. + 'fuzzy_search' means a fuzzy search on the provided field is performed. + The exact strict doesn't need to exist in the document + for this to find a match. + Very useful for scanning a document for some keyword terms. + 'strict_search' means that the exact string must appear + in the provided field. + This is NOT an exact eq filter. ie a document with content + "the happy dog crossed the street" will match on a strict_search of + "dog" but won't match on "the dog". + Python equivalent of `return search_string in full_string`. + value: The actual value to search for in the context data/metadata + """ + + field_name: str + filter_type: DALMFilterType + value: str + _is_metadata: bool = False + + @root_validator() + def set_meta(cls, values: Dict) -> Dict: + """document and name are reserved arcee keys. Anything else is metadata""" + values["_is_meta"] = values.get("field_name") not in ["document", "name"] + return values + + +class ArceeWrapper: + def __init__( + self, + arcee_api_key: str, + arcee_api_url: str, + arcee_api_version: str, + model_kwargs: Optional[Dict[str, Any]], + model_name: str, + ): + self.arcee_api_key = arcee_api_key + self.model_kwargs = model_kwargs + self.arcee_api_url = arcee_api_url + self.arcee_api_version = arcee_api_version + + try: + route = ArceeRoute.model_training_status.value.format(id_or_name=model_name) + response = self._make_request("get", route) + self.model_id = response.get("model_id") + self.model_training_status = response.get("status") + except Exception as e: + raise ValueError( + f"Error while validating model training status for '{model_name}': {e}" + ) from e + + def validate_model_training_status(self) -> None: + if self.model_training_status != "training_complete": + raise Exception( + f"Model {self.model_id} is not ready. " + "Please wait for training to complete." + ) + + def _make_request( + self, + method: Literal["post", "get"], + route: Union[ArceeRoute, str], + body: Optional[Mapping[str, Any]] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + ) -> dict: + """Make a request to the Arcee API + Args: + method: The HTTP method to use + route: The route to call + body: The body of the request + params: The query params of the request + headers: The headers of the request + """ + headers = self._make_request_headers(headers=headers) + url = self._make_request_url(route=route) + + req_type = getattr(requests, method) + + response = req_type(url, json=body, params=params, headers=headers) + if response.status_code not in (200, 201): + raise Exception(f"Failed to make request. Response: {response.text}") + return response.json() + + def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: + headers = headers or {} + internal_headers = { + "X-Token": self.arcee_api_key, + "Content-Type": "application/json", + } + headers.update(internal_headers) + return headers + + def _make_request_url(self, route: Union[ArceeRoute, str]) -> str: + return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}" + + def _make_request_body_for_models( + self, prompt: str, **kwargs: Mapping[str, Any] + ) -> Mapping[str, Any]: + """Make the request body for generate/retrieve models endpoint""" + _model_kwargs = self.model_kwargs or {} + _params = {**_model_kwargs, **kwargs} + + filters = [DALMFilter(**f) for f in _params.get("filters", [])] + return dict( + model_id=self.model_id, + query=prompt, + size=_params.get("size", 3), + filters=filters, + id=self.model_id, + ) + + def generate( + self, + prompt: str, + **kwargs: Any, + ) -> str: + """Generate text from Arcee DALM. + + Args: + prompt: Prompt to generate text from. + size: The max number of context results to retrieve. Defaults to 3. + (Can be less if filters are provided). + filters: Filters to apply to the context dataset. + """ + + response = self._make_request( + method="post", + route=ArceeRoute.generate, + body=self._make_request_body_for_models( + prompt=prompt, + **kwargs, + ), + ) + return response["text"] + + def retrieve( + self, + query: str, + **kwargs: Any, + ) -> List[Document]: + """Retrieve {size} contexts with your retriever for a given query + + Args: + query: Query to submit to the model + size: The max number of context results to retrieve. Defaults to 3. + (Can be less if filters are provided). + filters: Filters to apply to the context dataset. + """ + + response = self._make_request( + method="post", + route=ArceeRoute.retrieve, + body=self._make_request_body_for_models( + prompt=query, + **kwargs, + ), + ) + return [Document(**doc) for doc in response["documents"]]