Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bagatur/oai v1 scratch #12948

Merged
merged 13 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 46 additions & 25 deletions libs/langchain/langchain/chat_models/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Mapping
from typing import Any, Dict, Union

from langchain.chat_models.openai import ChatOpenAI
from langchain.pydantic_v1 import root_validator
from langchain.chat_models.openai import ChatOpenAI, _is_openai_v1
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.schema import ChatResult
from langchain.utils import get_from_dict_or_env

Expand Down Expand Up @@ -51,13 +51,13 @@ class AzureChatOpenAI(ChatOpenAI):
in, even if not explicitly saved on this class.
"""

deployment_name: str = ""
deployment_name: str = Field(default="", alias="azure_deployment")
model_version: str = ""
openai_api_type: str = ""
openai_api_base: str = ""
openai_api_version: str = ""
openai_api_key: str = ""
openai_organization: str = ""
openai_api_base: str = Field(default="", alias="azure_endpoint")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one's a bit awkward, maybe should just create new param

openai_api_version: str = Field(default="", alias="api_version")
openai_api_key: str = Field(default="", alias="api_key")
openai_organization: str = Field(default="", alias="organization")
openai_proxy: str = ""

@root_validator()
Expand Down Expand Up @@ -101,14 +101,27 @@ def validate_environment(cls, values: Dict) -> Dict:
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
if _is_openai_v1():
values["client"] = openai.AzureOpenAI(
azure_endpoint=values["openai_api_base"],
api_key=values["openai_api_key"],
timeout=values["request_timeout"],
max_retries=values["max_retries"],
organization=values["openai_organization"],
api_version=values["openai_api_version"],
azure_deployment=values["deployment_name"],
).chat.completions
values["async_client"] = openai.AsyncAzureOpenAI(
azure_endpoint=values["openai_api_base"],
api_key=values["openai_api_key"],
timeout=values["request_timeout"],
max_retries=values["max_retries"],
organization=values["openai_organization"],
api_version=values["openai_api_version"],
azure_deployment=values["deployment_name"],
).chat.completions
else:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
Expand All @@ -118,10 +131,13 @@ def validate_environment(cls, values: Dict) -> Dict:
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
**super()._default_params,
"engine": self.deployment_name,
}
if _is_openai_v1():
return super()._default_params
else:
return {
**super()._default_params,
"engine": self.deployment_name,
}

@property
def _identifying_params(self) -> Dict[str, Any]:
Expand All @@ -131,11 +147,14 @@ def _identifying_params(self) -> Dict[str, Any]:
@property
def _client_params(self) -> Dict[str, Any]:
"""Get the config params used for the openai client."""
return {
**super()._client_params,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}
if _is_openai_v1():
return super()._client_params
else:
return {
**super()._client_params,
"api_type": self.openai_api_type,
"api_version": self.openai_api_version,
}

@property
def _llm_type(self) -> str:
Expand All @@ -148,7 +167,9 @@ def lc_attributes(self) -> Dict[str, Any]:
"openai_api_version": self.openai_api_version,
}

def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
if not isinstance(response, dict):
response = response.dict()
for res in response["choices"]:
if res.get("finish_reason", None) == "content_filter":
raise ValueError(
Expand Down
6 changes: 3 additions & 3 deletions libs/langchain/langchain/chat_models/konko.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import _generate_from_stream
from langchain.chat_models.openai import ChatOpenAI, _convert_delta_to_message_chunk
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
from langchain.chat_models.openai import _convert_delta_to_message_chunk
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import AIMessageChunk, BaseMessage
Expand All @@ -35,7 +35,7 @@
logger = logging.getLogger(__name__)


class ChatKonko(ChatOpenAI):
class ChatKonko(BaseChatModel):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaict doesn't actually use any ChatOpenAI functionality

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably accidental, could you also check AnyScale?

"""`ChatKonko` Chat large language models API.

To use, you should have the ``konko`` python package installed, and the
Expand Down
82 changes: 64 additions & 18 deletions libs/langchain/langchain/chat_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import sys
from importlib.metadata import version
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -19,6 +20,8 @@
Union,
)

from packaging.version import Version, parse

from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
Expand All @@ -44,9 +47,13 @@
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import Runnable
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)

if TYPE_CHECKING:
import httpx
import tiktoken


Expand Down Expand Up @@ -91,6 +98,9 @@ async def acompletion_with_retry(
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the async completion call."""
if _is_openai_v1():
return await llm.async_client.create(**kwargs)

retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

@retry_decorator
Expand All @@ -108,6 +118,11 @@ def _convert_delta_to_message_chunk(
content = _dict.get("content") or ""
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
if (
"name" in additional_kwargs["function_call"]
and additional_kwargs["function_call"]["name"] is None
):
additional_kwargs["function_call"]["name"] = ""
else:
additional_kwargs = {}

Expand All @@ -125,6 +140,11 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)


def _is_openai_v1() -> bool:
_version = parse(version("openai"))
return _version >= Version("1.0.0")


class ChatOpenAI(BaseChatModel):
"""`OpenAI` Chat large language models API.

Expand Down Expand Up @@ -166,6 +186,7 @@ def is_lc_serializable(cls) -> bool:
return True

client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: float = 0.7
Expand All @@ -175,16 +196,18 @@ def is_lc_serializable(cls) -> bool:
# When updating this to use a SecretStr
# Check for classes that derive from this class (as some of them
# may assume openai_api_key is a str)
openai_api_key: Optional[str] = None
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
openai_api_base: Optional[str] = None
openai_organization: Optional[str] = None
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
openai_organization: Optional[str] = Field(default=None, alias="organization")
# to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
request_timeout: Union[float, Tuple[float, float], httpx.Timeout, None] = Field(
default=None, alias="timeout"
)
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 6
max_retries: int = 2
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
Expand Down Expand Up @@ -266,14 +289,24 @@ def validate_environment(cls, values: Dict) -> Dict:
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:

if _is_openai_v1():
values["client"] = openai.OpenAI(
api_key=values["openai_api_key"],
timeout=values["request_timeout"],
max_retries=values["max_retries"],
organization=values["openai_organization"],
base_url=values["openai_api_base"] or None,
).chat.completions
values["async_client"] = openai.AsyncOpenAI(
api_key=values["openai_api_key"],
timeout=values["request_timeout"],
max_retries=values["max_retries"],
organization=values["openai_organization"],
base_url=values["openai_api_base"] or None,
).chat.completions
else:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
Expand All @@ -285,7 +318,6 @@ def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model_name,
"request_timeout": self.request_timeout,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"n": self.n,
Expand All @@ -297,6 +329,9 @@ def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
) -> Any:
"""Use tenacity to retry the completion call."""
if _is_openai_v1():
return self.client.create(**kwargs)

retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

@retry_decorator
Expand Down Expand Up @@ -333,6 +368,8 @@ def _stream(
for chunk in self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
Expand Down Expand Up @@ -381,8 +418,10 @@ def _create_message_dicts(
message_dicts = [convert_message_to_dict(m) for m in messages]
return message_dicts, params

def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
generations = []
if not isinstance(response, dict):
response = response.dict()
for res in response["choices"]:
message = convert_dict_to_message(res["message"])
gen = ChatGeneration(
Expand All @@ -408,6 +447,8 @@ async def _astream(
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, run_manager=run_manager, **params
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
Expand Down Expand Up @@ -455,11 +496,16 @@ def _identifying_params(self) -> Dict[str, Any]:
def _client_params(self) -> Dict[str, Any]:
"""Get the parameters used for the openai client."""
openai_creds: Dict[str, Any] = {
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
"organization": self.openai_organization,
"model": self.model_name,
}
if not _is_openai_v1():
openai_creds.update(
{
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
"organization": self.openai_organization,
}
)
if self.openai_proxy:
import openai

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
from typing import Any, Mapping, cast
from unittest import mock

import pytest
Expand Down Expand Up @@ -48,9 +47,8 @@ def test_model_name_set_on_chat_result_when_present_in_response(
"""
# convert sample_response_text to instance of Mapping[str, Any]
sample_response = json.loads(sample_response_text)
mock_response = cast(Mapping[str, Any], sample_response)
mock_chat = AzureChatOpenAI()
chat_result = mock_chat._create_chat_result(mock_response)
chat_result = mock_chat._create_chat_result(sample_response)
assert (
chat_result.llm_output is not None
and chat_result.llm_output["model_name"] == model_name
Expand Down
Loading