From e3b0cc20c42bd51f6425d7044b2c33ac1bd88beb Mon Sep 17 00:00:00 2001 From: Shubham Date: Thu, 5 Oct 2023 06:56:41 +0530 Subject: [PATCH 1/7] arcee utilities --- libs/langchain/langchain/utilities/arcee.py | 146 ++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 libs/langchain/langchain/utilities/arcee.py diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py new file mode 100644 index 00000000000000..76d18fd02febd6 --- /dev/null +++ b/libs/langchain/langchain/utilities/arcee.py @@ -0,0 +1,146 @@ +from enum import Enum + +from langchain.pydantic_v1 import BaseModel, root_validator +from typing import Any, Dict, TYPE_CHECKING, List, Mapping, Union, Optional, Literal +import requests + + +class ArceeRoute(str, Enum): + generate = "models/generate" + retrieve = "models/retrieve" + model_training_status = "models/status/{model_id}" + + +class DALMFilterType(str, Enum): + fuzzy_search = "fuzzy_search" + strict_search = "strict_search" + + +class DALMFilter(BaseModel): + """Filters available for a dalm retrieval and generation + + Arguments: + field_name: The field to filter on. Can be 'document' or 'name' to filter on your document's raw text or title + Any other field will be presumed to be a metadata field you included when uploading your context data + filter_type: Currently 'fuzzy_search' and 'strict_search' are supported. More to come soon! + 'fuzzy_search' means a fuzzy search on the provided field will be performed. The exact strict doesn't + need to exist in the document for this to find a match. Very useful for scanning a document for some + keyword terms + 'strict_search' means that the exact string must appear in the provided field. This is NOT an exact eq + filter. ie a document with content "the happy dog crossed the street" will match on a strict_search of "dog" + but won't match on "the dog". 
Python equivalent of `return search_string in full_string` + value: The actual value to search for in the context data/metadata + """ + + field_name: str + filter_type: DALMFilterType + value: str + _is_metadata: bool = False + + @root_validator + def set_meta(self) -> "DALMFilter": + """document and name are reserved arcee keys. Anything else is metadata""" + self._is_metadata = self.field_name not in ["document", "name"] + return self + + +class ArceeWrapper(BaseModel): + """Wrapper for Arcee APIs""" + + arcee_api_url: str = "https://api.arcee.ai" + """Arcee API URL""" + + arcee_api_version: str = "v2" + """Arcee API Version""" + + arcee_app_url: str = "https://app.arcee.ai" + """Arcee App URL""" + + model_id: str + """Arcee Model ID""" + + model_kwargs: Mapping[str, Any] = None + """Keyword arguments to pass to the model.""" + + @root_validator() + def validate_model_kwargs(cls, values: Dict) -> Dict: + """Validate that model kwargs are valid.""" + + print("validating kwargs: ", values.get("model_kwargs")) + + if values.get("model_kwargs") is not None: + kw = values.get("model_kwargs") + + # validate size + if kw.get("size") is not None: + if not kw.get("size") >= 0: + raise ValueError("`size` must be positive") + + # validate filters + if kw.get("filters") is not None: + if not isinstance(kw.get("filters"), List): + raise ValueError("`filters` must be a list") + for f in kw.get("filters"): + DALMFilter.validate(f) + + @classmethod + def _make_request_headers(cls, headers: Optional[Dict] = None) -> Dict: + """Make the request headers""" + headers = headers or {} + internal_headers = { + "X-Token": f"{cls.arcee_api_key}", + "Content-Type": "application/json", + } + headers.update(**internal_headers) + return headers + + @classmethod + def _make_request_url(cls, route: ArceeRoute) -> str: + """Make the request url""" + return f"{cls.arcee_api_url}/{cls.arcee_api_version}/{route}" + + def _make_request_body_for_models( + self, prompt: str, **kwargs: Mapping[str, 
Any] + ) -> Mapping[str, Any]: + """Build the kwargs for the Post request, used by sync + + Args: + prompt (str): prompt used in query + kwargs (dict): model kwargs in payload + + Returns: + Dict[str, Union[str,dict]]: _description_ + """ + _model_kwargs = self.model_kwargs or {} + _params = {**_model_kwargs, **kwargs} + + # validate filters + filters = [DALMFilter.validate(f) for f in _params.get("filters", [])] + + return dict( + model_id=self.model_id, + query=prompt, + size=_params.get("size", 3), + filters=filters, + id=self.model_id, + ) + + @classmethod + def make_request( + cls, + request: Literal["post", "get"], + route: ArceeRoute, + body: Optional[dict[str, Any]] = None, + params: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, Any]] = None, + ) -> dict[str, str]: + """Makes the request""" + headers = cls._make_request_headers(headers) + url = cls.make_request_url(route) + + req_type = getattr(requests, request) + + response = req_type(url, json=body, params=params, headers=headers) + if response.status_code not in (200, 201): + raise Exception(f"Failed to make request. 
Response: {response.text}") + return response.json() From fa195ac0c0e09fba311605dea33f35ab615ef33c Mon Sep 17 00:00:00 2001 From: Shubham Kushwaha Date: Thu, 5 Oct 2023 23:05:36 +0530 Subject: [PATCH 2/7] arcee llm --- libs/langchain/langchain/llms/__init__.py | 10 ++ libs/langchain/langchain/llms/arcee.py | 146 +++++++++++++++++ .../langchain/langchain/utilities/__init__.py | 1 + libs/langchain/langchain/utilities/arcee.py | 152 ++++++++---------- 4 files changed, 225 insertions(+), 84 deletions(-) create mode 100644 libs/langchain/langchain/llms/arcee.py diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py index de577c7064c20d..b4d74f470ac1ee 100644 --- a/libs/langchain/langchain/llms/__init__.py +++ b/libs/langchain/langchain/llms/__init__.py @@ -52,6 +52,12 @@ def _import_anyscale() -> Any: return Anyscale +def _import_arcee() -> Any: + from langchain.llms.arcee import Arcee + + return Arcee + + def _import_aviary() -> Any: from langchain.llms.aviary import Aviary @@ -479,6 +485,8 @@ def __getattr__(name: str) -> Any: return _import_anthropic() elif name == "Anyscale": return _import_anyscale() + elif name == "Arcee": + return _import_arcee() elif name == "Aviary": return _import_aviary() elif name == "AzureMLOnlineEndpoint": @@ -633,6 +641,7 @@ def __getattr__(name: str) -> Any: "AmazonAPIGateway", "Anthropic", "Anyscale", + "Arcee", "Aviary", "AzureMLOnlineEndpoint", "AzureOpenAI", @@ -713,6 +722,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: "amazon_bedrock": _import_bedrock, "anthropic": _import_anthropic, "anyscale": _import_anyscale, + "arcee": _import_arcee, "aviary": _import_aviary, "azure": _import_azure_openai, "azureml_endpoint": _import_azureml_endpoint, diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py new file mode 100644 index 00000000000000..a51ff974a14be9 --- /dev/null +++ b/libs/langchain/langchain/llms/arcee.py @@ -0,0 +1,146 
@@ +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.utilities.arcee import ArceeClient, ArceeRoute, DALMFilter +from langchain.llms.base import LLM +from langchain.pydantic_v1 import Extra, root_validator +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from langchain.utils import get_from_dict_or_env + + +class Arcee(LLM): + """Arcee's Domain Adapted Language Models (DALMs). + + To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key, + or pass ``arcee_api_key`` as a named parameter. + + Example: + .. code-block:: python + + from langchain.llms import Arcee + + arcee = Arcee( + model="DPT-PubMed-7b", + arcee_api_key="DUMMY-KEY" + ) + + response = arcee("Can?") + """ + + client: ArceeClient = None #: :meta private: + """Arcee client.""" + + arcee_api_key: str = "" + """Arcee API Key""" + + model: str + """Arcee DALM name""" + + arcee_api_url: str = "https://api.arcee.ai" + """Arcee API URL""" + + arcee_api_version: str = "v2" + """Arcee API Version""" + + arcee_app_url: str = "https://app.arcee.ai" + """Arcee App URL""" + + model_id: str = "" + """Arcee Model ID""" + + model_kwargs: Optional[Dict] = None + """Keyword arguments to pass to the model.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "arcee" + + @root_validator() + def validate_environments(cls, values: Dict) -> Dict: + """Validate Arcee environment variables.""" + + # validate env vars + values["arcee_api_key"] = get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", + ) + + values["arcee_api_url"] = get_from_dict_or_env( + values, + "arcee_api_url", + "ARCEE_API_URL", + ) + + values["arcee_app_url"] = get_from_dict_or_env( + values, + "arcee_app_url", + "ARCEE_APP_URL", + ) + + values["arcee_api_version"] = get_from_dict_or_env( + values, + "arcee_api_version", + "ARCEE_API_VERSION", + ) 
+ + # validate model kwargs + if values.get("model_kwargs") is not None: + kw = values.get("model_kwargs") + + # validate size + if kw.get("size") is not None: + if not kw.get("size") >= 0: + raise ValueError("`size` must be positive") + + # validate filters + if kw.get("filters") is not None: + if not isinstance(kw.get("filters"), List): + raise ValueError("`filters` must be a list") + for f in kw.get("filters"): + DALMFilter(**f) + + values["client"] = ArceeClient( + arcee_api_key=values.get("arcee_api_key"), + arcee_api_url=values.get("arcee_api_url"), + arcee_api_version=values.get("arcee_api_version"), + model_kwargs=values.get("model_kwargs"), + model_name=values.get("model"), + ) + + # validate model training status + values.get("client").validate_model_training_status() + + return values + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Generate text from Arcee DALM. + + Args: + prompt: Prompt to generate text from. + size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). + filters: Filters to apply to the context results. 
+ """ + + try: + response = self.client.make_request( + method="post", + route=ArceeRoute.generate.value.format(id_or_name=self.model_id), + body=self.client.make_request_body_for_models( + prompt=prompt, + **kwargs, + ), + ) + return response["text"] + except Exception as e: + raise ValueError(f"Error while generating text: {e}") from e diff --git a/libs/langchain/langchain/utilities/__init__.py b/libs/langchain/langchain/utilities/__init__.py index e964baa90111cf..3402aeb5a35adc 100644 --- a/libs/langchain/langchain/utilities/__init__.py +++ b/libs/langchain/langchain/utilities/__init__.py @@ -5,6 +5,7 @@ """ from langchain.utilities.alpha_vantage import AlphaVantageAPIWrapper from langchain.utilities.apify import ApifyWrapper +from langchain.utilities.arcee import ArceeClient, ArceeRoute, DALMFilter from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.awslambda import LambdaWrapper from langchain.utilities.bash import BashProcess diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 76d18fd02febd6..2430ee1bbdcdbd 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -8,7 +8,7 @@ class ArceeRoute(str, Enum): generate = "models/generate" retrieve = "models/retrieve" - model_training_status = "models/status/{model_id}" + model_training_status = "models/status/{id_or_name}" class DALMFilterType(str, Enum): @@ -37,86 +37,90 @@ class DALMFilter(BaseModel): value: str _is_metadata: bool = False - @root_validator - def set_meta(self) -> "DALMFilter": - """document and name are reserved arcee keys. 
Anything else is metadata""" - self._is_metadata = self.field_name not in ["document", "name"] - return self - - -class ArceeWrapper(BaseModel): - """Wrapper for Arcee APIs""" - - arcee_api_url: str = "https://api.arcee.ai" - """Arcee API URL""" - - arcee_api_version: str = "v2" - """Arcee API Version""" - - arcee_app_url: str = "https://app.arcee.ai" - """Arcee App URL""" - - model_id: str - """Arcee Model ID""" - - model_kwargs: Mapping[str, Any] = None - """Keyword arguments to pass to the model.""" - @root_validator() - def validate_model_kwargs(cls, values: Dict) -> Dict: - """Validate that model kwargs are valid.""" - - print("validating kwargs: ", values.get("model_kwargs")) + def set_meta(cls, values: Dict) -> Dict: + """document and name are reserved arcee keys. Anything else is metadata""" + values["_is_meta"] = values.get("field_name") not in ["document", "name"] + return values + + +class ArceeClient: + def __init__( + self, + arcee_api_key: str, + arcee_api_url: str, + arcee_api_version: str, + model_kwargs: Dict[str, Any], + model_name: str, + ): + self.arcee_api_key = arcee_api_key + self.model_kwargs = model_kwargs + self.arcee_api_url = arcee_api_url + self.arcee_api_version = arcee_api_version + + try: + route = ArceeRoute.model_training_status.value.format(id_or_name=model_name) + # response = self.make_request("get", route) + response = {"status": "training_complete", "model_id": "123"} + self.model_id = response.get("model_id") + self.model_training_status = response.get("status") + except Exception as e: + raise ValueError( + f"Error while validating model training status for '{model_name}': {e}" + ) from e + + def validate_model_training_status(self): + if self.model_training_status != "training_complete": + raise Exception( + f"Model {self.model_id} is not ready. Please wait for training to complete." 
+ ) - if values.get("model_kwargs") is not None: - kw = values.get("model_kwargs") + def make_request( + self, + method: Literal["post", "get"], + route: ArceeRoute, + body: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + ) -> dict: + """Make a request to the Arcee API + Args: + method: The HTTP method to use + route: The route to call + body: The body of the request + params: The query params of the request + headers: The headers of the request + """ + headers = self._make_request_headers(headers=headers) + url = self._make_request_url(route=route) - # validate size - if kw.get("size") is not None: - if not kw.get("size") >= 0: - raise ValueError("`size` must be positive") + req_type = getattr(requests, method) - # validate filters - if kw.get("filters") is not None: - if not isinstance(kw.get("filters"), List): - raise ValueError("`filters` must be a list") - for f in kw.get("filters"): - DALMFilter.validate(f) + response = req_type(url, json=body, params=params, headers=headers) + if response.status_code not in (200, 201): + raise Exception(f"Failed to make request. 
Response: {response.text}") + return response.json() - @classmethod - def _make_request_headers(cls, headers: Optional[Dict] = None) -> Dict: - """Make the request headers""" + def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers = headers or {} internal_headers = { - "X-Token": f"{cls.arcee_api_key}", + "X-Token": self.arcee_api_key, "Content-Type": "application/json", } - headers.update(**internal_headers) + headers.update(internal_headers) return headers - @classmethod - def _make_request_url(cls, route: ArceeRoute) -> str: - """Make the request url""" - return f"{cls.arcee_api_url}/{cls.arcee_api_version}/{route}" + def _make_request_url(self, route: ArceeRoute) -> str: + return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}" - def _make_request_body_for_models( + def make_request_body_for_models( self, prompt: str, **kwargs: Mapping[str, Any] ) -> Mapping[str, Any]: - """Build the kwargs for the Post request, used by sync - - Args: - prompt (str): prompt used in query - kwargs (dict): model kwargs in payload - - Returns: - Dict[str, Union[str,dict]]: _description_ - """ + """Make the request body for generate/retrieve models endpoint""" _model_kwargs = self.model_kwargs or {} _params = {**_model_kwargs, **kwargs} - # validate filters - filters = [DALMFilter.validate(f) for f in _params.get("filters", [])] - + filters = [DALMFilter(**f) for f in _params.get("filters", [])] return dict( model_id=self.model_id, query=prompt, @@ -124,23 +128,3 @@ def _make_request_body_for_models( filters=filters, id=self.model_id, ) - - @classmethod - def make_request( - cls, - request: Literal["post", "get"], - route: ArceeRoute, - body: Optional[dict[str, Any]] = None, - params: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, Any]] = None, - ) -> dict[str, str]: - """Makes the request""" - headers = cls._make_request_headers(headers) - url = cls.make_request_url(route) - - req_type = getattr(requests, request) - - response = 
req_type(url, json=body, params=params, headers=headers) - if response.status_code not in (200, 201): - raise Exception(f"Failed to make request. Response: {response.text}") - return response.json() From 8fe29cddb4f3852fc4981d1d46eeb676f9a2efc7 Mon Sep 17 00:00:00 2001 From: Shubham Kushwaha Date: Fri, 6 Oct 2023 04:25:33 +0530 Subject: [PATCH 3/7] arcee retriever --- libs/langchain/langchain/llms/arcee.py | 47 ++++--- .../langchain/retrievers/__init__.py | 2 + libs/langchain/langchain/retrievers/arcee.py | 117 ++++++++++++++++++ .../langchain/langchain/utilities/__init__.py | 3 +- libs/langchain/langchain/utilities/arcee.py | 60 ++++++++- 5 files changed, 197 insertions(+), 32 deletions(-) create mode 100644 libs/langchain/langchain/retrievers/arcee.py diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index a51ff974a14be9..deb3885839c227 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,5 +1,5 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.utilities.arcee import ArceeClient, ArceeRoute, DALMFilter +from langchain.utilities.arcee import ArceeWrapper, DALMFilter from langchain.llms.base import LLM from langchain.pydantic_v1 import Extra, root_validator from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -25,8 +25,8 @@ class Arcee(LLM): response = arcee("Can?") """ - client: ArceeClient = None #: :meta private: - """Arcee client.""" + _client: ArceeWrapper = None #: :meta private: + """Arcee _client.""" arcee_api_key: str = "" """Arcee API Key""" @@ -53,12 +53,28 @@ class Config: """Configuration for this pydantic object.""" extra = Extra.forbid + underscore_attrs_are_private = True @property def _llm_type(self) -> str: """Return type of llm.""" return "arcee" + def __init__(self, **data: Any) -> None: + """Initializes private fields.""" + + super().__init__(**data) + + self._client = ArceeWrapper( + 
arcee_api_key=self.arcee_api_key, + arcee_api_url=self.arcee_api_url, + arcee_api_version=self.arcee_api_version, + model_kwargs=self.model_kwargs, + model_name=self.model, + ) + + self._client.validate_model_training_status() + @root_validator() def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" @@ -104,17 +120,6 @@ def validate_environments(cls, values: Dict) -> Dict: for f in kw.get("filters"): DALMFilter(**f) - values["client"] = ArceeClient( - arcee_api_key=values.get("arcee_api_key"), - arcee_api_url=values.get("arcee_api_url"), - arcee_api_version=values.get("arcee_api_version"), - model_kwargs=values.get("model_kwargs"), - model_name=values.get("model"), - ) - - # validate model training status - values.get("client").validate_model_training_status() - return values def _call( @@ -129,18 +134,10 @@ def _call( Args: prompt: Prompt to generate text from. size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). - filters: Filters to apply to the context results. + filters: Filters to apply to the context dataset. 
""" try: - response = self.client.make_request( - method="post", - route=ArceeRoute.generate.value.format(id_or_name=self.model_id), - body=self.client.make_request_body_for_models( - prompt=prompt, - **kwargs, - ), - ) - return response["text"] + return self._client.generate(prompt=prompt, **kwargs) except Exception as e: - raise ValueError(f"Error while generating text: {e}") from e + raise Exception(f"Failed to generate text: {e}") from e diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py index ba50fdfd57b6bc..b93a7e652aba23 100644 --- a/libs/langchain/langchain/retrievers/__init__.py +++ b/libs/langchain/langchain/retrievers/__init__.py @@ -19,6 +19,7 @@ """ from langchain.retrievers.arxiv import ArxivRetriever +from langchain.retrievers.arcee import ArceeRetriever from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever from langchain.retrievers.bm25 import BM25Retriever from langchain.retrievers.chaindesk import ChaindeskRetriever @@ -66,6 +67,7 @@ __all__ = [ "AmazonKendraRetriever", + "ArceeRetriever", "ArxivRetriever", "AzureCognitiveSearchRetriever", "ChatGPTPluginRetriever", diff --git a/libs/langchain/langchain/retrievers/arcee.py b/libs/langchain/langchain/retrievers/arcee.py new file mode 100644 index 00000000000000..b5df98cea3dfb9 --- /dev/null +++ b/libs/langchain/langchain/retrievers/arcee.py @@ -0,0 +1,117 @@ +from typing import Any, Dict, Iterable, List, Optional + +from langchain.pydantic_v1 import Extra, root_validator + +from langchain.callbacks.manager import CallbackManagerForRetrieverRun +from langchain.docstore.document import Document +from langchain.schema import BaseRetriever +from langchain.schema.retriever import BaseRetriever +from langchain.utilities.arcee import ArceeWrapper, ArceeRoute, DALMFilter +from langchain.utils import get_from_dict_or_env + + +class ArceeRetriever(BaseRetriever): + _client: ArceeWrapper = None #: :meta private: + 
"""Arcee client.""" + + arcee_api_key: str = "" + """Arcee API Key""" + + model: str + """Arcee DALM name""" + + arcee_api_url: str = "https://api.arcee.ai" + """Arcee API URL""" + + arcee_api_version: str = "v2" + """Arcee API Version""" + + arcee_app_url: str = "https://app.arcee.ai" + """Arcee App URL""" + + model_kwargs: Optional[Dict] = None + """Keyword arguments to pass to the model.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + underscore_attrs_are_private = True + + def __init__(self, **data: Any) -> None: + """Initializes private fields.""" + + super().__init__(**data) + + self._client = ArceeWrapper( + arcee_api_key=self.arcee_api_key, + arcee_api_url=self.arcee_api_url, + arcee_api_version=self.arcee_api_version, + model_kwargs=self.model_kwargs, + model_name=self.model, + ) + + self._client.validate_model_training_status() + + @root_validator() + def validate_environments(cls, values: Dict) -> Dict: + """Validate Arcee environment variables.""" + + # validate env vars + values["arcee_api_key"] = get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", + ) + + values["arcee_api_url"] = get_from_dict_or_env( + values, + "arcee_api_url", + "ARCEE_API_URL", + ) + + values["arcee_app_url"] = get_from_dict_or_env( + values, + "arcee_app_url", + "ARCEE_APP_URL", + ) + + values["arcee_api_version"] = get_from_dict_or_env( + values, + "arcee_api_version", + "ARCEE_API_VERSION", + ) + + # validate model kwargs + if values.get("model_kwargs") is not None: + kw = values.get("model_kwargs") + + # validate size + if kw.get("size") is not None: + if not kw.get("size") >= 0: + raise ValueError("`size` must not be negative.") + + # validate filters + if kw.get("filters") is not None: + if not isinstance(kw.get("filters"), List): + raise ValueError("`filters` must be a list.") + for f in kw.get("filters"): + DALMFilter(**f) + + return values + + def _get_relevant_documents( + self, query: str, run_manager: 
CallbackManagerForRetrieverRun, **kwargs: Any + ) -> List[Document]: + """Retrieve {size} contexts with your retriever for a given query + + Args: + qeury: Query to submit to the model + size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). + filters: Filters to apply to the context dataset. + """ + + try: + return self._client.retrieve(query=query, **kwargs) + except Exception as e: + raise ValueError(f"Error while retrieving documents: {e}") from e diff --git a/libs/langchain/langchain/utilities/__init__.py b/libs/langchain/langchain/utilities/__init__.py index 3402aeb5a35adc..a091c8d7f47720 100644 --- a/libs/langchain/langchain/utilities/__init__.py +++ b/libs/langchain/langchain/utilities/__init__.py @@ -5,7 +5,7 @@ """ from langchain.utilities.alpha_vantage import AlphaVantageAPIWrapper from langchain.utilities.apify import ApifyWrapper -from langchain.utilities.arcee import ArceeClient, ArceeRoute, DALMFilter +from langchain.utilities.arcee import ArceeWrapper from langchain.utilities.arxiv import ArxivAPIWrapper from langchain.utilities.awslambda import LambdaWrapper from langchain.utilities.bash import BashProcess @@ -42,6 +42,7 @@ __all__ = [ "AlphaVantageAPIWrapper", "ApifyWrapper", + "ArceeWrapper", "ArxivAPIWrapper", "BashProcess", "BibtexparserWrapper", diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 2430ee1bbdcdbd..4f90243f80a119 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -1,6 +1,7 @@ from enum import Enum from langchain.pydantic_v1 import BaseModel, root_validator +from langchain.schema.retriever import Document from typing import Any, Dict, TYPE_CHECKING, List, Mapping, Union, Optional, Literal import requests @@ -44,7 +45,7 @@ def set_meta(cls, values: Dict) -> Dict: return values -class ArceeClient: +class ArceeWrapper: def __init__( self, arcee_api_key: str, @@ -60,8 
+61,8 @@ def __init__( try: route = ArceeRoute.model_training_status.value.format(id_or_name=model_name) - # response = self.make_request("get", route) - response = {"status": "training_complete", "model_id": "123"} + response = self._make_request("get", route) + # response = {"status": "training_complete", "model_id": "123"} # TODO: remove after testing self.model_id = response.get("model_id") self.model_training_status = response.get("status") except Exception as e: @@ -75,7 +76,7 @@ def validate_model_training_status(self): f"Model {self.model_id} is not ready. Please wait for training to complete." ) - def make_request( + def _make_request( self, method: Literal["post", "get"], route: ArceeRoute, @@ -113,14 +114,14 @@ def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: def _make_request_url(self, route: ArceeRoute) -> str: return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}" - def make_request_body_for_models( + def _make_request_body_for_models( self, prompt: str, **kwargs: Mapping[str, Any] ) -> Mapping[str, Any]: """Make the request body for generate/retrieve models endpoint""" _model_kwargs = self.model_kwargs or {} _params = {**_model_kwargs, **kwargs} - filters = [DALMFilter(**f) for f in _params.get("filters", [])] + filters = [DALMFilter(**f) for f in _params.get("filters", [])] # TODO: Get this validated return dict( model_id=self.model_id, query=prompt, @@ -128,3 +129,50 @@ def make_request_body_for_models( filters=filters, id=self.model_id, ) + + def generate( + self, + prompt: str, + **kwargs: Any, + ) -> str: + """Generate text from Arcee DALM. + + Args: + prompt: Prompt to generate text from. + size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). + filters: Filters to apply to the context dataset. 
+ """ + + + response = self._make_request( + method="post", + route=ArceeRoute.generate.value.format(id_or_name=self.model_id), + body=self._make_request_body_for_models( + prompt=prompt, + **kwargs, + ), + ) + return response["text"] # TODO: confirm this transformation + + def retrieve( + self, + query: str, + **kwargs: Any, + ) -> List[Document]: + """Retrieve {size} contexts with your retriever for a given query + + Args: + qeury: Query to submit to the model + size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). + filters: Filters to apply to the context dataset. + """ + + response = self._make_request( + method="post", + route=ArceeRoute.retrieve.value.format(id_or_name=self.model_id), + body=self._make_request_body_for_models( + prompt=query, + **kwargs, + ), + ) + return [Document(**doc) for doc in response["documents"]] # TODO: confirm this transformation \ No newline at end of file From f1c27d0ee7d452be9d7b8c819185a33ba76b4e6b Mon Sep 17 00:00:00 2001 From: Shubham Kushwaha Date: Tue, 10 Oct 2023 00:15:00 +0530 Subject: [PATCH 4/7] arcee llm docs --- .../docs/integrations/llms/arcee.ipynb | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 docs/docs_skeleton/docs/integrations/llms/arcee.ipynb diff --git a/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb b/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb new file mode 100644 index 00000000000000..2fbdf87a36c279 --- /dev/null +++ b/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb @@ -0,0 +1,152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Arcee\n", + "This notebook demonstrates how to use the `Arcee` class for generating text using Arcee's Domain Adapted Language Models (DALMs)." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Setup\n", + "\n", + "Before using Arcee, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.llms import Arcee\n", + "\n", + "# Create an instance of the Arcee class\n", + "arcee = Arcee(\n", + " model=\"DPT-PubMed-7b\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generating Text\n", + "\n", + "You can generate text from Arcee by providing a prompt. Here's an example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate text\n", + "prompt = \"Can you explain the theory of relativity?\"\n", + "response = arcee(prompt)\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional parameters\n", + "\n", + "Arcee allows you to apply `filters` and set retrieved document sizes to aid text generation. Filters help narrow down the results. 
Here's how to use these parameters:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define filters\n", + "filters = [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Einstein\"\n", + " },\n", + " {\n", + " \"field_name\": \"year\",\n", + " \"filter_type\": \"strict_search\",\n", + " \"value\": \"1905\"\n", + " }\n", + "]\n", + "\n", + "response = arcee(prompt, size=5, filters=filters)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional Configuration\n", + "\n", + "You can also configure Arcee's parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n", + "Setting the `model_kwargs` at the object initialization, uses the parameters as default for all the subsequent calls to the generate response." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "arcee = Arcee(\n", + " model=\"DALM-Patent\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n", + " arcee_api_url=\"https://custom-api.arcee.ai\",\n", + " arcee_app_url=\"https://custom-app.arcee.ai\",\n", + " model_kwargs={\n", + " \"size\": 5,\n", + " \"filters\": [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Einstein\"\n", + " }\n", + " ]\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 
4, + "nbformat_minor": 2 +} From 21e056c86664c5971ea22ed6f404b398e562e538 Mon Sep 17 00:00:00 2001 From: Shubham Kushwaha Date: Tue, 10 Oct 2023 01:09:00 +0530 Subject: [PATCH 5/7] reference to arcee utils --- libs/langchain/langchain/utilities/arcee.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 4f90243f80a119..fe866ee7c523f7 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -1,3 +1,6 @@ +# This module contains utility classes and functions for interacting with Arcee API. +# For more information and updates, refer to the Arcee utils page: [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py] + from enum import Enum from langchain.pydantic_v1 import BaseModel, root_validator @@ -62,7 +65,6 @@ def __init__( try: route = ArceeRoute.model_training_status.value.format(id_or_name=model_name) response = self._make_request("get", route) - # response = {"status": "training_complete", "model_id": "123"} # TODO: remove after testing self.model_id = response.get("model_id") self.model_training_status = response.get("status") except Exception as e: @@ -121,7 +123,7 @@ def _make_request_body_for_models( _model_kwargs = self.model_kwargs or {} _params = {**_model_kwargs, **kwargs} - filters = [DALMFilter(**f) for f in _params.get("filters", [])] # TODO: Get this validated + filters = [DALMFilter(**f) for f in _params.get("filters", [])] return dict( model_id=self.model_id, query=prompt, @@ -152,7 +154,7 @@ def generate( **kwargs, ), ) - return response["text"] # TODO: confirm this transformation + return response["text"] def retrieve( self, @@ -175,4 +177,4 @@ def retrieve( **kwargs, ), ) - return [Document(**doc) for doc in response["documents"]] # TODO: confirm this transformation \ No newline at end of file + return [Document(**doc) for doc in response["documents"]] \ No 
newline at end of file From c06452dabcbff65531eb1322678c5e3646236aa1 Mon Sep 17 00:00:00 2001 From: Shubham Kushwaha Date: Tue, 10 Oct 2023 01:10:40 +0530 Subject: [PATCH 6/7] arcee docs: llm + retriever --- .../docs/integrations/llms/arcee.ipynb | 88 +++++------ .../docs/integrations/retrievers/arcee.ipynb | 141 ++++++++++++++++++ 2 files changed, 182 insertions(+), 47 deletions(-) create mode 100644 docs/docs_skeleton/docs/integrations/retrievers/arcee.ipynb diff --git a/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb b/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb index 2fbdf87a36c279..0f0fa3461cbace 100644 --- a/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb +++ b/docs/docs_skeleton/docs/integrations/llms/arcee.ipynb @@ -12,7 +12,6 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", "### Setup\n", "\n", "Before using Arcee, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter." @@ -28,11 +27,45 @@ "\n", "# Create an instance of the Arcee class\n", "arcee = Arcee(\n", - " model=\"DPT-PubMed-7b\",\n", + " model=\"DALM-PubMed\",\n", " # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n", ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional Configuration\n", + "\n", + "You can also configure Arcee's parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n", + "Setting the `model_kwargs` at the object initialization uses the parameters as default for all the subsequent calls to the generate response." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "arcee = Arcee(\n", + " model=\"DALM-Patent\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n", + " arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n", + " arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n", + " model_kwargs={\n", + " \"size\": 5,\n", + " \"filters\": [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Einstein\"\n", + " }\n", + " ]\n", + " }\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -49,9 +82,8 @@ "outputs": [], "source": [ "# Generate text\n", - "prompt = \"Can you explain the theory of relativity?\"\n", - "response = arcee(prompt)\n", - "print(response)" + "prompt = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n", + "response = arcee(prompt)" ] }, { @@ -60,7 +92,9 @@ "source": [ "### Additional parameters\n", "\n", - "Arcee allows you to apply `filters` and set retrieved document sizes to aid text generation. Filters help narrow down the results. Here's how to use these parameters:\n" + "Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s) to aid text generation. Filters help narrow down the results. 
Here's how to use these parameters:\n", + "\n", + "\n" ] }, { @@ -83,49 +117,9 @@ " }\n", "]\n", "\n", + "# Generate text with filters and size params\n", "response = arcee(prompt, size=5, filters=filters)" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Additional Configuration\n", - "\n", - "You can also configure Arcee's parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n", - "Setting the `model_kwargs` at the object initialization, uses the parameters as default for all the subsequent calls to the generate response." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "arcee = Arcee(\n", - " model=\"DALM-Patent\",\n", - " # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n", - " arcee_api_url=\"https://custom-api.arcee.ai\",\n", - " arcee_app_url=\"https://custom-app.arcee.ai\",\n", - " model_kwargs={\n", - " \"size\": 5,\n", - " \"filters\": [\n", - " {\n", - " \"field_name\": \"document\",\n", - " \"filter_type\": \"fuzzy_search\",\n", - " \"value\": \"Einstein\"\n", - " }\n", - " ]\n", - " }\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/docs_skeleton/docs/integrations/retrievers/arcee.ipynb b/docs/docs_skeleton/docs/integrations/retrievers/arcee.ipynb new file mode 100644 index 00000000000000..3cf1b62db3a0c3 --- /dev/null +++ b/docs/docs_skeleton/docs/integrations/retrievers/arcee.ipynb @@ -0,0 +1,141 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Arcee Retriever\n", + "This notebook demonstrates how to use the `ArceeRetriever` class to retrieve relevant document(s) for Arcee's Domain Adapted Language Models (DALMs)." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "Before using `ArceeRetriever`, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import ArceeRetriever\n", + "\n", + "retriever = ArceeRetriever(\n", + " model=\"DALM-PubMed\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional Configuration\n", + "\n", + "You can also configure `ArceeRetriever`'s parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n", + "Setting the `model_kwargs` at the object initialization uses the filters and size as default for all the subsequent retrievals." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = ArceeRetriever(\n", + " model=\"DALM-PubMed\",\n", + " # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n", + " arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n", + " arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n", + " model_kwargs={\n", + " \"size\": 5,\n", + " \"filters\": [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Einstein\"\n", + " }\n", + " ]\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Retrieving documents\n", + "You can retrieve relevant documents from uploaded contexts by providing a query. 
Here's an example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n", + "documents = retriever.get_relevant_documents(query=query)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional parameters\n", + "\n", + "Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s). Filters help narrow down the results. Here's how to use these parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define filters\n", + "filters = [\n", + " {\n", + " \"field_name\": \"document\",\n", + " \"filter_type\": \"fuzzy_search\",\n", + " \"value\": \"Music\"\n", + " },\n", + " {\n", + " \"field_name\": \"year\",\n", + " \"filter_type\": \"strict_search\",\n", + " \"value\": \"1905\"\n", + " }\n", + "]\n", + "\n", + "# Retrieve documents with filters and size params\n", + "documents = retriever.get_relevant_documents(query=query, size=5, filters=filters)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 32b8ff6249627dc0126c95d9a5a262c8290925d1 Mon Sep 17 00:00:00 2001 From: Shubham Kushwaha Date: Tue, 10 Oct 2023 02:25:57 +0530 Subject: [PATCH 7/7] test, lint, format, spell --- libs/langchain/langchain/llms/arcee.py | 26 ++++--- .../langchain/retrievers/__init__.py | 2 +- libs/langchain/langchain/retrievers/arcee.py | 41 ++++++++--- libs/langchain/langchain/utilities/arcee.py | 73 
+++++++++++-------- 4 files changed, 87 insertions(+), 55 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index deb3885839c227..469b5e250fd5dd 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,8 +1,9 @@ +from typing import Any, Dict, List, Optional + from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.utilities.arcee import ArceeWrapper, DALMFilter from langchain.llms.base import LLM from langchain.pydantic_v1 import Extra, root_validator -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from langchain.utilities.arcee import ArceeWrapper, DALMFilter from langchain.utils import get_from_dict_or_env @@ -18,14 +19,14 @@ class Arcee(LLM): from langchain.llms import Arcee arcee = Arcee( - model="DPT-PubMed-7b", - arcee_api_key="DUMMY-KEY" + model="DALM-PubMed", + arcee_api_key="ARCEE-API-KEY" ) - response = arcee("Can?") + response = arcee("AI-driven music therapy") """ - _client: ArceeWrapper = None #: :meta private: + _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" arcee_api_key: str = "" @@ -46,7 +47,7 @@ class Arcee(LLM): model_id: str = "" """Arcee Model ID""" - model_kwargs: Optional[Dict] = None + model_kwargs: Optional[Dict[str, Any]] = None """Keyword arguments to pass to the model.""" class Config: @@ -64,7 +65,7 @@ def __init__(self, **data: Any) -> None: """Initializes private fields.""" super().__init__(**data) - + self._client = None self._client = ArceeWrapper( arcee_api_key=self.arcee_api_key, arcee_api_url=self.arcee_api_url, @@ -105,8 +106,8 @@ def validate_environments(cls, values: Dict) -> Dict: ) # validate model kwargs - if values.get("model_kwargs") is not None: - kw = values.get("model_kwargs") + if values["model_kwargs"]: + kw = values["model_kwargs"] # validate size if kw.get("size") is not None: @@ -133,11 +134,14 @@ def _call( Args: prompt: Prompt to 
generate text from. - size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). + size: The max number of context results to retrieve. + Defaults to 3. (Can be less if filters are provided). filters: Filters to apply to the context dataset. """ try: + if not self._client: + raise ValueError("Client is not initialized.") return self._client.generate(prompt=prompt, **kwargs) except Exception as e: raise Exception(f"Failed to generate text: {e}") from e diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py index b93a7e652aba23..9abdb145fd27bc 100644 --- a/libs/langchain/langchain/retrievers/__init__.py +++ b/libs/langchain/langchain/retrievers/__init__.py @@ -18,8 +18,8 @@ CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun """ -from langchain.retrievers.arxiv import ArxivRetriever from langchain.retrievers.arcee import ArceeRetriever +from langchain.retrievers.arxiv import ArxivRetriever from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever from langchain.retrievers.bm25 import BM25Retriever from langchain.retrievers.chaindesk import ChaindeskRetriever diff --git a/libs/langchain/langchain/retrievers/arcee.py b/libs/langchain/langchain/retrievers/arcee.py index b5df98cea3dfb9..6bfba7eef9a6e3 100644 --- a/libs/langchain/langchain/retrievers/arcee.py +++ b/libs/langchain/langchain/retrievers/arcee.py @@ -1,17 +1,33 @@ -from typing import Any, Dict, Iterable, List, Optional - -from langchain.pydantic_v1 import Extra, root_validator +from typing import Any, Dict, List, Optional from langchain.callbacks.manager import CallbackManagerForRetrieverRun from langchain.docstore.document import Document +from langchain.pydantic_v1 import Extra, root_validator from langchain.schema import BaseRetriever -from langchain.schema.retriever import BaseRetriever -from langchain.utilities.arcee import ArceeWrapper, ArceeRoute, 
DALMFilter +from langchain.utilities.arcee import ArceeWrapper, DALMFilter from langchain.utils import get_from_dict_or_env class ArceeRetriever(BaseRetriever): - _client: ArceeWrapper = None #: :meta private: + """Document retriever for Arcee's Domain Adapted Language Models (DALMs). + + To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key, + or pass ``arcee_api_key`` as a named parameter. + + Example: + .. code-block:: python + + from langchain.retrievers import ArceeRetriever + + retriever = ArceeRetriever( + model="DALM-PubMed", + arcee_api_key="ARCEE-API-KEY" + ) + + documents = retriever.get_relevant_documents("AI-driven music therapy") + """ + + _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee client.""" arcee_api_key: str = "" @@ -29,7 +45,7 @@ class ArceeRetriever(BaseRetriever): arcee_app_url: str = "https://app.arcee.ai" """Arcee App URL""" - model_kwargs: Optional[Dict] = None + model_kwargs: Optional[Dict[str, Any]] = None """Keyword arguments to pass to the model.""" class Config: @@ -83,8 +99,8 @@ def validate_environments(cls, values: Dict) -> Dict: ) # validate model kwargs - if values.get("model_kwargs") is not None: - kw = values.get("model_kwargs") + if values["model_kwargs"]: + kw = values["model_kwargs"] # validate size if kw.get("size") is not None: @@ -106,12 +122,15 @@ def _get_relevant_documents( """Retrieve {size} contexts with your retriever for a given query Args: - qeury: Query to submit to the model - size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). + query: Query to submit to the model + size: The max number of context results to retrieve. + Defaults to 3. (Can be less if filters are provided). filters: Filters to apply to the context dataset. 
""" try: + if not self._client: + raise ValueError("Client is not initialized.") return self._client.retrieve(query=query, **kwargs) except Exception as e: raise ValueError(f"Error while retrieving documents: {e}") from e diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index fe866ee7c523f7..f73cdcad9c8f46 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -1,12 +1,14 @@ # This module contains utility classes and functions for interacting with Arcee API. -# For more information and updates, refer to the Arcee utils page: [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py] +# For more information and updates, refer to the Arcee utils page: +# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py] from enum import Enum +from typing import Any, Dict, List, Literal, Mapping, Optional, Union + +import requests from langchain.pydantic_v1 import BaseModel, root_validator from langchain.schema.retriever import Document -from typing import Any, Dict, TYPE_CHECKING, List, Mapping, Union, Optional, Literal -import requests class ArceeRoute(str, Enum): @@ -24,15 +26,20 @@ class DALMFilter(BaseModel): """Filters available for a dalm retrieval and generation Arguments: - field_name: The field to filter on. Can be 'document' or 'name' to filter on your document's raw text or title - Any other field will be presumed to be a metadata field you included when uploading your context data - filter_type: Currently 'fuzzy_search' and 'strict_search' are supported. More to come soon! - 'fuzzy_search' means a fuzzy search on the provided field will be performed. The exact strict doesn't - need to exist in the document for this to find a match. Very useful for scanning a document for some - keyword terms - 'strict_search' means that the exact string must appear in the provided field. This is NOT an exact eq - filter. 
ie a document with content "the happy dog crossed the street" will match on a strict_search of "dog" - but won't match on "the dog". Python equivalent of `return search_string in full_string` + field_name: The field to filter on. Can be 'document' or 'name' to filter + on your document's raw text or title. Any other field will be presumed + to be a metadata field you included when uploading your context data + filter_type: Currently 'fuzzy_search' and 'strict_search' are supported. + 'fuzzy_search' means a fuzzy search on the provided field is performed. + The exact string doesn't need to exist in the document + for this to find a match. + Very useful for scanning a document for some keyword terms. + 'strict_search' means that the exact string must appear + in the provided field. + This is NOT an exact eq filter. ie a document with content + "the happy dog crossed the street" will match on a strict_search of + "dog" but won't match on "the dog". + Python equivalent of `return search_string in full_string`. value: The actual value to search for in the context data/metadata """ @@ -54,7 +61,7 @@ def __init__( arcee_api_key: str, arcee_api_url: str, arcee_api_version: str, - model_kwargs: Dict[str, Any], + model_kwargs: Optional[Dict[str, Any]], model_name: str, ): self.arcee_api_key = arcee_api_key @@ -72,17 +79,18 @@ def __init__( f"Error while validating model training status for '{model_name}': {e}" ) from e - def validate_model_training_status(self): + def validate_model_training_status(self) -> None: if self.model_training_status != "training_complete": raise Exception( - f"Model {self.model_id} is not ready. Please wait for training to complete." + f"Model {self.model_id} is not ready. " + "Please wait for training to complete." 
) def _make_request( self, method: Literal["post", "get"], - route: ArceeRoute, - body: Optional[dict] = None, + route: Union[ArceeRoute, str], + body: Optional[Mapping[str, Any]] = None, params: Optional[dict] = None, headers: Optional[dict] = None, ) -> dict: @@ -113,7 +121,7 @@ def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers.update(internal_headers) return headers - def _make_request_url(self, route: ArceeRoute) -> str: + def _make_request_url(self, route: Union[ArceeRoute, str]) -> str: return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}" def _make_request_body_for_models( @@ -133,22 +141,22 @@ def _make_request_body_for_models( ) def generate( - self, - prompt: str, - **kwargs: Any, + self, + prompt: str, + **kwargs: Any, ) -> str: """Generate text from Arcee DALM. Args: prompt: Prompt to generate text from. - size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). + size: The max number of context results to retrieve. Defaults to 3. + (Can be less if filters are provided). filters: Filters to apply to the context dataset. """ - response = self._make_request( method="post", - route=ArceeRoute.generate.value.format(id_or_name=self.model_id), + route=ArceeRoute.generate, body=self._make_request_body_for_models( prompt=prompt, **kwargs, @@ -157,24 +165,25 @@ def generate( return response["text"] def retrieve( - self, - query: str, - **kwargs: Any, + self, + query: str, + **kwargs: Any, ) -> List[Document]: """Retrieve {size} contexts with your retriever for a given query - + Args: - qeury: Query to submit to the model - size: The max number of context results to retrieve. Defaults to 3. (Can be less if filters are provided). + query: Query to submit to the model + size: The max number of context results to retrieve. Defaults to 3. + (Can be less if filters are provided). filters: Filters to apply to the context dataset. 
""" response = self._make_request( method="post", - route=ArceeRoute.retrieve.value.format(id_or_name=self.model_id), + route=ArceeRoute.retrieve, body=self._make_request_body_for_models( prompt=query, **kwargs, ), ) - return [Document(**doc) for doc in response["documents"]] \ No newline at end of file + return [Document(**doc) for doc in response["documents"]]