diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py index a219511..670076b 100644 --- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py +++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py @@ -545,8 +545,8 @@ def process_stream_tool_calls( return tool_call_chunks -class MetaProvider(Provider): - """Provider implementation for Meta.""" +class GenericProvider(Provider): + """Provider for models using generic API spec.""" stop_sequence_key: str = "stop" @@ -934,6 +934,11 @@ def process_stream_tool_calls( return tool_call_chunks +class MetaProvider(GenericProvider): + """Provider for Meta models. This provider is for backward compatibility.""" + pass + + class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): """ChatOCIGenAI chat model integration. @@ -1018,6 +1023,7 @@ def _provider_map(self) -> Mapping[str, Provider]: return { "cohere": CohereProvider(), "meta": MetaProvider(), + "generic": GenericProvider(), } @property diff --git a/libs/oci/langchain_oci/llms/oci_generative_ai.py b/libs/oci/langchain_oci/llms/oci_generative_ai.py index 8f67d04..240598c 100644 --- a/libs/oci/langchain_oci/llms/oci_generative_ai.py +++ b/libs/oci/langchain_oci/llms/oci_generative_ai.py @@ -40,7 +40,8 @@ def completion_response_to_text(self, response: Any) -> str: return response.data.inference_response.generated_texts[0].text -class MetaProvider(Provider): +class GenericProvider(Provider): + """Provider for models using generic API spec.""" stop_sequence_key: str = "stop" def __init__(self) -> None: @@ -50,6 +51,11 @@ def __init__(self) -> None: def completion_response_to_text(self, response: Any) -> str: return response.data.inference_response.choices[0].text + + +class MetaProvider(GenericProvider): + """Provider for Meta models. This provider is for backward compatibility.""" + pass class OCIAuthType(Enum): @@ -202,14 +208,17 @@ def _identifying_params(self) -> Mapping[str, Any]: def _get_provider(self, provider_map: Mapping[str, Any]) -> Any: if self.provider is not None: provider = self.provider + elif self.model_id is None: + raise ValueError( + "model_id is required to derive the provider, " + "please provide the provider explicitly or specify " + "the model_id to derive the provider." + ) + elif self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX): + raise ValueError("provider is required for custom endpoints.") else: - if self.model_id is None: - raise ValueError( - "model_id is required to derive the provider, " - "please provide the provider explicitly or specify " - "the model_id to derive the provider." - ) - provider = self.model_id.split(".")[0].lower() + + provider = provider_map.get(self.model_id.split(".")[0].lower(), "generic") if provider not in provider_map: raise ValueError( @@ -269,6 +278,7 @@ def _provider_map(self) -> Mapping[str, Any]: return { "cohere": CohereProvider(), "meta": MetaProvider(), + "generic": GenericProvider(), } @property