diff --git a/libs/oci/langchain_oci/__init__.py b/libs/oci/langchain_oci/__init__.py
index 8a532e1..eb6df32 100644
--- a/libs/oci/langchain_oci/__init__.py
+++ b/libs/oci/langchain_oci/__init__.py
@@ -1,14 +1,16 @@
 # Copyright (c) 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
 from langchain_oci.chat_models.oci_data_science import (
     ChatOCIModelDeployment,
     ChatOCIModelDeploymentTGI,
-    ChatOCIModelDeploymentVLLM
+    ChatOCIModelDeploymentVLLM,
+)
+from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
+from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import (
+    OCIModelDeploymentEndpointEmbeddings,
 )
 from langchain_oci.embeddings.oci_generative_ai import OCIGenAIEmbeddings
-from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import OCIModelDeploymentEndpointEmbeddings
 from langchain_oci.llms.oci_data_science_model_deployment_endpoint import (
     BaseOCIModelDeployment,
     OCIModelDeploymentLLM,
diff --git a/libs/oci/langchain_oci/chat_models/__init__.py b/libs/oci/langchain_oci/chat_models/__init__.py
index 56d3c4b..c714b3c 100644
--- a/libs/oci/langchain_oci/chat_models/__init__.py
+++ b/libs/oci/langchain_oci/chat_models/__init__.py
@@ -4,8 +4,13 @@
 from langchain_oci.chat_models.oci_data_science import (
     ChatOCIModelDeployment,
     ChatOCIModelDeploymentTGI,
-    ChatOCIModelDeploymentVLLM
+    ChatOCIModelDeploymentVLLM,
 )
 from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
 
-__all__ = ["ChatOCIGenAI", "ChatOCIModelDeployment", "ChatOCIModelDeploymentTGI", "ChatOCIModelDeploymentVLLM"]
+__all__ = [
+    "ChatOCIGenAI",
+    "ChatOCIModelDeployment",
+    "ChatOCIModelDeploymentTGI",
+    "ChatOCIModelDeploymentVLLM",
+]
diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
index 935b001..1709b05 100644
--- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -362,6 +362,19 @@ def messages_to_oci_params(
                         message=msg_content, tool_calls=tool_calls
                     )
                 )
+            elif isinstance(msg, ToolMessage):
+                oci_chat_history.append(
+                    self.oci_chat_message[self.get_role(msg)](
+                        tool_results=[
+                            self.oci_tool_result(
+                                call=self.oci_tool_call(
+                                    name=msg.name, parameters={}
+                                ),
+                                outputs=[{"output": msg.content}],
+                            )
+                        ],
+                    )
+                )
 
         # Process current turn messages in reverse order until a HumanMessage
         current_turn = []
diff --git a/libs/oci/langchain_oci/embeddings/__init__.py b/libs/oci/langchain_oci/embeddings/__init__.py
index 6867961..a23dbea 100644
--- a/libs/oci/langchain_oci/embeddings/__init__.py
+++ b/libs/oci/langchain_oci/embeddings/__init__.py
@@ -1,7 +1,9 @@
 # Copyright (c) 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import OCIModelDeploymentEndpointEmbeddings
+from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import (
+    OCIModelDeploymentEndpointEmbeddings,
+)
 from langchain_oci.embeddings.oci_generative_ai import OCIGenAIEmbeddings
 
 __all__ = ["OCIModelDeploymentEndpointEmbeddings", "OCIGenAIEmbeddings"]
diff --git a/libs/oci/langchain_oci/embeddings/oci_data_science_model_deployment_endpoint.py b/libs/oci/langchain_oci/embeddings/oci_data_science_model_deployment_endpoint.py
index d203f1f..8f931f2 100644
--- a/libs/oci/langchain_oci/embeddings/oci_data_science_model_deployment_endpoint.py
+++ b/libs/oci/langchain_oci/embeddings/oci_data_science_model_deployment_endpoint.py
@@ -1,13 +1,13 @@
 # Copyright (c) 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
+from typing import Any, Callable, Dict, List, Mapping, Optional
+
+import requests
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.llms import create_base_retry_decorator
 from langchain_core.utils import get_from_dict_or_env
 from pydantic import BaseModel, Field, model_validator
-import requests
-from typing import Any, Callable, Dict, List, Mapping, Optional
-
 
 DEFAULT_HEADER = {
     "Content-Type": "application/json",
@@ -39,7 +39,7 @@ class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings):
         embeddings = OCIModelDeploymentEndpointEmbeddings(
             endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict",
         )
-    """ # noqa: E501
+    """  # noqa: E501
 
     auth: dict = Field(default_factory=dict, exclude=True)
     """ADS auth dictionary for OCI authentication:
diff --git a/libs/oci/langchain_oci/embeddings/oci_generative_ai.py b/libs/oci/langchain_oci/embeddings/oci_generative_ai.py
index 8906cf9..edfb285 100644
--- a/libs/oci/langchain_oci/embeddings/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/embeddings/oci_generative_ai.py
@@ -143,13 +143,13 @@ def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                     oci_config=client_kwargs["config"]
                 )
             elif values["auth_type"] == OCIAuthType(3).name:
-                client_kwargs[
-                    "signer"
-                ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
+                client_kwargs["signer"] = (
+                    oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
+                )
             elif values["auth_type"] == OCIAuthType(4).name:
-                client_kwargs[
-                    "signer"
-                ] = oci.auth.signers.get_resource_principals_signer()
+                client_kwargs["signer"] = (
+                    oci.auth.signers.get_resource_principals_signer()
+                )
             else:
                 raise ValueError("Please provide valid value to auth_type")
 
diff --git a/libs/oci/langchain_oci/llms/oci_data_science_model_deployment_endpoint.py b/libs/oci/langchain_oci/llms/oci_data_science_model_deployment_endpoint.py
index 5a605f8..9089be1 100644
--- a/libs/oci/langchain_oci/llms/oci_data_science_model_deployment_endpoint.py
+++ b/libs/oci/langchain_oci/llms/oci_data_science_model_deployment_endpoint.py
@@ -3,10 +3,10 @@
 
 """LLM for OCI data science model deployment endpoint."""
 
-from contextlib import asynccontextmanager
 import json
 import logging
 import traceback
+from contextlib import asynccontextmanager
 from typing import (
     Any,
     AsyncGenerator,
@@ -793,6 +793,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
         )
 
     """
+
     max_tokens: int = 256
     """Denotes the number of tokens to predict per
    generation."""
@@ -943,6 +944,7 @@ class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
         )
 
     """
+
     max_tokens: int = 256
     """Denotes the number of tokens to predict per generation."""
 
diff --git a/libs/oci/langchain_oci/llms/oci_generative_ai.py b/libs/oci/langchain_oci/llms/oci_generative_ai.py
index ac8ae98..8f67d04 100644
--- a/libs/oci/langchain_oci/llms/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/llms/oci_generative_ai.py
@@ -22,12 +22,10 @@ class Provider(ABC):
 
     @property
     @abstractmethod
-    def stop_sequence_key(self) -> str:
-        ...
+    def stop_sequence_key(self) -> str: ...
 
     @abstractmethod
-    def completion_response_to_text(self, response: Any) -> str:
-        ...
+    def completion_response_to_text(self, response: Any) -> str: ...
 
 
 class CohereProvider(Provider):
@@ -159,13 +157,13 @@ def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                     oci_config=client_kwargs["config"]
                 )
             elif values["auth_type"] == OCIAuthType(3).name:
-                client_kwargs[
-                    "signer"
-                ] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
+                client_kwargs["signer"] = (
+                    oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
+                )
             elif values["auth_type"] == OCIAuthType(4).name:
-                client_kwargs[
-                    "signer"
-                ] = oci.auth.signers.get_resource_principals_signer()
+                client_kwargs["signer"] = (
+                    oci.auth.signers.get_resource_principals_signer()
+                )
             else:
                 raise ValueError(
                     "Please provide valid value to auth_type, "
diff --git a/libs/oci/tests/unit_tests/embeddings/test_oci_model_deployment_endpoint.py b/libs/oci/tests/unit_tests/embeddings/test_oci_model_deployment_endpoint.py
index b74598f..582dbcc 100644
--- a/libs/oci/tests/unit_tests/embeddings/test_oci_model_deployment_endpoint.py
+++ b/libs/oci/tests/unit_tests/embeddings/test_oci_model_deployment_endpoint.py
@@ -1,8 +1,9 @@
 """Test OCI Data Science Model Deployment Endpoint."""
 
-import responses
 import pytest
+import responses
 from pytest_mock import MockerFixture
+
 from langchain_oci.embeddings import OCIModelDeploymentEndpointEmbeddings
 
 
@@ -17,11 +18,7 @@ def test_embedding_call(mocker: MockerFixture) -> None:
         responses.POST,
         endpoint,
         json={
-            "data": [
-                {
-                    "embedding": expected_output
-                }
-            ],
+            "data": [{"embedding": expected_output}],
         },
         status=200,
    )