Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions libs/oci/langchain_oci/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
from langchain_oci.chat_models.oci_data_science import (
ChatOCIModelDeployment,
ChatOCIModelDeploymentTGI,
ChatOCIModelDeploymentVLLM
ChatOCIModelDeploymentVLLM,
)
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import (
OCIModelDeploymentEndpointEmbeddings,
)
from langchain_oci.embeddings.oci_generative_ai import OCIGenAIEmbeddings
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import OCIModelDeploymentEndpointEmbeddings
from langchain_oci.llms.oci_data_science_model_deployment_endpoint import (
BaseOCIModelDeployment,
OCIModelDeploymentLLM,
Expand Down
9 changes: 7 additions & 2 deletions libs/oci/langchain_oci/chat_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
from langchain_oci.chat_models.oci_data_science import (
ChatOCIModelDeployment,
ChatOCIModelDeploymentTGI,
ChatOCIModelDeploymentVLLM
ChatOCIModelDeploymentVLLM,
)
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI

__all__ = ["ChatOCIGenAI", "ChatOCIModelDeployment", "ChatOCIModelDeploymentTGI", "ChatOCIModelDeploymentVLLM"]
__all__ = [
"ChatOCIGenAI",
"ChatOCIModelDeployment",
"ChatOCIModelDeploymentTGI",
"ChatOCIModelDeploymentVLLM",
]
13 changes: 13 additions & 0 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,19 @@ def messages_to_oci_params(
message=msg_content, tool_calls=tool_calls
)
)
elif isinstance(msg, ToolMessage):
oci_chat_history.append(
self.oci_chat_message[self.get_role(msg)](
tool_results=[
self.oci_tool_result(
call=self.oci_tool_call(
name=msg.name, parameters={}
),
outputs=[{"output": msg.content}],
)
],
)
)

# Process current turn messages in reverse order until a HumanMessage
current_turn = []
Expand Down
4 changes: 3 additions & 1 deletion libs/oci/langchain_oci/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import OCIModelDeploymentEndpointEmbeddings
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import (
OCIModelDeploymentEndpointEmbeddings,
)
from langchain_oci.embeddings.oci_generative_ai import OCIGenAIEmbeddings

__all__ = ["OCIModelDeploymentEndpointEmbeddings", "OCIGenAIEmbeddings"]
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import Any, Callable, Dict, List, Mapping, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.utils import get_from_dict_or_env
from pydantic import BaseModel, Field, model_validator
import requests
from typing import Any, Callable, Dict, List, Mapping, Optional


DEFAULT_HEADER = {
"Content-Type": "application/json",
Expand Down Expand Up @@ -39,7 +39,7 @@ class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings):
embeddings = OCIModelDeploymentEndpointEmbeddings(
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
)
""" # noqa: E501
""" # noqa: E501

auth: dict = Field(default_factory=dict, exclude=True)
"""ADS auth dictionary for OCI authentication:
Expand Down
12 changes: 6 additions & 6 deletions libs/oci/langchain_oci/embeddings/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
oci_config=client_kwargs["config"]
)
elif values["auth_type"] == OCIAuthType(3).name:
client_kwargs[
"signer"
] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
client_kwargs["signer"] = (
oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
)
elif values["auth_type"] == OCIAuthType(4).name:
client_kwargs[
"signer"
] = oci.auth.signers.get_resource_principals_signer()
client_kwargs["signer"] = (
oci.auth.signers.get_resource_principals_signer()
)
else:
raise ValueError("Please provide valid value to auth_type")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

"""LLM for OCI data science model deployment endpoint."""

from contextlib import asynccontextmanager
import json
import logging
import traceback
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -793,6 +793,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
)

"""

max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""

Expand Down Expand Up @@ -943,6 +944,7 @@ class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
)

"""

max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""

Expand Down
18 changes: 8 additions & 10 deletions libs/oci/langchain_oci/llms/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
class Provider(ABC):
@property
@abstractmethod
def stop_sequence_key(self) -> str:
...
def stop_sequence_key(self) -> str: ...

@abstractmethod
def completion_response_to_text(self, response: Any) -> str:
...
def completion_response_to_text(self, response: Any) -> str: ...


class CohereProvider(Provider):
Expand Down Expand Up @@ -159,13 +157,13 @@ def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
oci_config=client_kwargs["config"]
)
elif values["auth_type"] == OCIAuthType(3).name:
client_kwargs[
"signer"
] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
client_kwargs["signer"] = (
oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
)
elif values["auth_type"] == OCIAuthType(4).name:
client_kwargs[
"signer"
] = oci.auth.signers.get_resource_principals_signer()
client_kwargs["signer"] = (
oci.auth.signers.get_resource_principals_signer()
)
else:
raise ValueError(
"Please provide valid value to auth_type, "
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Test OCI Data Science Model Deployment Endpoint."""

import responses
import pytest
import responses
from pytest_mock import MockerFixture

from langchain_oci.embeddings import OCIModelDeploymentEndpointEmbeddings


Expand All @@ -17,11 +18,7 @@ def test_embedding_call(mocker: MockerFixture) -> None:
responses.POST,
endpoint,
json={
"data": [
{
"embedding": expected_output
}
],
"data": [{"embedding": expected_output}],
},
status=200,
)
Expand Down