Skip to content

Commit

Permalink
community[patch]: Modify LLMs/Anyscale to work with OpenAI API v1 (langchain-ai#14206)
Browse files Browse the repository at this point in the history

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
- **Description:** 
1. Modify LLMs/Anyscale to work with OAI v1
2. Get rid of openai_ prefixed variables in Chat_model/ChatAnyscale
3. Rename `anyscale_api_base` to `anyscale_base_url` to follow the OAI naming
convention (reverted)

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
  • Loading branch information
2 people authored and Hayden Wolff committed Feb 27, 2024
1 parent 29dd346 commit 8987b3d
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 114 deletions.
30 changes: 12 additions & 18 deletions libs/community/langchain_community/chat_models/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

logger = logging.getLogger(__name__)


DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"

Expand Down Expand Up @@ -60,7 +59,7 @@ def lc_secrets(self) -> Dict[str, str]:
def is_lc_serializable(cls) -> bool:
return False

anyscale_api_key: SecretStr
anyscale_api_key: SecretStr = Field(default=None)
"""AnyScale Endpoints API keys."""
model_name: str = Field(default=DEFAULT_MODEL, alias="model")
"""Model name to use."""
Expand Down Expand Up @@ -102,22 +101,17 @@ def get_available_models(

return {model["id"] for model in models_response.json()["data"]}

@root_validator(pre=True)
def validate_environment_override(cls, values: dict) -> dict:
@root_validator()
def validate_environment(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values,
"anyscale_api_key",
"ANYSCALE_API_KEY",
)
values["anyscale_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"anyscale_api_key",
"ANYSCALE_API_KEY",
)
)
values["openai_api_base"] = get_from_dict_or_env(
values["anyscale_api_base"] = get_from_dict_or_env(
values,
"anyscale_api_base",
"ANYSCALE_API_BASE",
Expand All @@ -140,8 +134,8 @@ def validate_environment_override(cls, values: dict) -> dict:
try:
if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"base_url": values["openai_api_base"],
"api_key": values["anyscale_api_key"].get_secret_value(),
"base_url": values["anyscale_api_base"],
# To do: future support
# "organization": values["openai_organization"],
# "timeout": values["request_timeout"],
Expand All @@ -152,6 +146,8 @@ def validate_environment_override(cls, values: dict) -> dict:
}
values["client"] = openai.OpenAI(**client_params).chat.completions
else:
values["openai_api_base"] = values["anyscale_api_base"]
values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
values["client"] = openai.ChatCompletion
except AttributeError as exc:
raise ValueError(
Expand All @@ -164,10 +160,9 @@ def validate_environment_override(cls, values: dict) -> dict:
values["model_name"] = DEFAULT_MODEL

model_name = values["model_name"]

available_models = cls.get_available_models(
values["openai_api_key"],
values["openai_api_base"],
values["anyscale_api_key"].get_secret_value(),
values["anyscale_api_base"],
)

if model_name not in available_models:
Expand Down Expand Up @@ -197,9 +192,8 @@ def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:

def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
"""Calculate num tokens with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
"""
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()
Expand Down
Loading

0 comments on commit 8987b3d

Please sign in to comment.