
Commit

feat: Add support for litellm library (#596)
Co-authored-by: Miles Bennett <yuezhao@nwu.edu.cn>
Co-authored-by: Wendong <w3ndong.fan@gmail.com>
Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com>
4 people committed Jun 14, 2024
1 parent 849896c commit 8951cd3
Showing 10 changed files with 516 additions and 26 deletions.
3 changes: 3 additions & 0 deletions camel/configs/__init__.py
@@ -13,6 +13,7 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
from .base_config import BaseConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .openai_config import (
OPENAI_API_PARAMS,
ChatGPTConfig,
@@ -26,4 +27,6 @@
'AnthropicConfig',
'ANTHROPIC_API_PARAMS',
'OpenSourceConfig',
'LiteLLMConfig',
'LITELLM_API_PARAMS',
]
113 changes: 113 additions & 0 deletions camel/configs/litellm_config.py
@@ -0,0 +1,113 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union

from camel.configs.base_config import BaseConfig


@dataclass(frozen=True)
class LiteLLMConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
LiteLLM API.
Args:
model (str): The name of the language model to use for text completion.
messages (List): A list of message objects representing the
conversation context. (default: [])
timeout (Optional[Union[float, str]], optional): Request timeout.
(default: None)
temperature (Optional[float], optional): Temperature parameter for
controlling randomness. (default: None)
top_p (Optional[float], optional): Top-p parameter for nucleus
sampling. (default: None)
n (Optional[int], optional): Number of completions to generate.
(default: None)
stream (Optional[bool], optional): Whether to return a streaming
response. (default: None)
stream_options (Optional[dict], optional): Options for the streaming
response. (default: None)
stop (Optional[Union[str, List[str]]], optional): Sequences where the
API will stop generating further tokens. (default: None)
max_tokens (Optional[int], optional): Maximum number of tokens to
generate. (default: None)
presence_penalty (Optional[float], optional): Penalize new tokens
based on their existence in the text so far. (default: None)
frequency_penalty (Optional[float], optional): Penalize new tokens
based on their frequency in the text so far. (default: None)
logit_bias (Optional[dict], optional): Modify the probability of
specific tokens appearing in the completion. (default: None)
user (Optional[str], optional): A unique identifier representing the
end-user. (default: None)
response_format (Optional[dict], optional): Response format
parameters. (default: None)
seed (Optional[int], optional): Random seed. (default: None)
tools (Optional[List], optional): List of tools. (default: None)
tool_choice (Optional[Union[str, dict]], optional): Tool choice
parameters. (default: None)
logprobs (Optional[bool], optional): Whether to return log
probabilities of the output tokens. (default: None)
top_logprobs (Optional[int], optional): Number of most likely tokens
to return at each token position. (default: None)
deployment_id (Optional[str], optional): Deployment ID. (default: None)
extra_headers (Optional[dict], optional): Additional headers for the
request. (default: None)
base_url (Optional[str], optional): Base URL for the API. (default:
None)
api_version (Optional[str], optional): API version. (default: None)
api_key (Optional[str], optional): API key. (default: None)
        model_list (Optional[list], optional): List of API bases, versions,
            and keys. (default: None)
mock_response (Optional[str], optional): Mock completion response for
testing or debugging. (default: None)
custom_llm_provider (Optional[str], optional): Non-OpenAI LLM
provider. (default: None)
max_retries (Optional[int], optional): Maximum number of retries.
(default: None)
"""

model: str = "gpt-3.5-turbo"
messages: List = field(default_factory=list)
timeout: Optional[Union[float, str]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stream: Optional[bool] = None
stream_options: Optional[dict] = None
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[dict] = field(default_factory=dict)
user: Optional[str] = None
response_format: Optional[dict] = None
seed: Optional[int] = None
tools: Optional[List] = field(default_factory=list)
tool_choice: Optional[Union[str, dict]] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
deployment_id: Optional[str] = None
extra_headers: Optional[dict] = field(default_factory=dict)
base_url: Optional[str] = None
api_version: Optional[str] = None
api_key: Optional[str] = None
model_list: Optional[list] = field(default_factory=list)
mock_response: Optional[str] = None
custom_llm_provider: Optional[str] = None
max_retries: Optional[int] = None


LITELLM_API_PARAMS = {param for param in asdict(LiteLLMConfig()).keys()}
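
A quick usage sketch (not part of this commit, assuming only the classes introduced above): build a LiteLLMConfig, flatten it with asdict, and confirm every field name is an accepted LiteLLM parameter, which is the same membership check that check_model_config() in litellm_model.py below performs.

from dataclasses import asdict

from camel.configs import LITELLM_API_PARAMS, LiteLLMConfig

# Only override the parameters you care about; fields left at None fall
# back to LiteLLM's own defaults.
config = LiteLLMConfig(temperature=0.2, max_tokens=256)

# Every dataclass field is also a member of LITELLM_API_PARAMS.
assert set(asdict(config)).issubset(LITELLM_API_PARAMS)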
2 changes: 2 additions & 0 deletions camel/models/__init__.py
@@ -13,6 +13,7 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .anthropic_model import AnthropicModel
from .base_model import BaseModelBackend
from .litellm_model import LiteLLMModel
from .model_factory import ModelFactory
from .open_source_model import OpenSourceModel
from .openai_audio_models import OpenAIAudioModels
@@ -26,5 +27,6 @@
'StubModel',
'OpenSourceModel',
'ModelFactory',
'LiteLLMModel',
'OpenAIAudioModels',
]
112 changes: 112 additions & 0 deletions camel/models/litellm_model.py
@@ -0,0 +1,112 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from camel.configs import LITELLM_API_PARAMS
from camel.messages import OpenAIMessage
from camel.utils import LiteLLMTokenCounter

if TYPE_CHECKING:
from litellm.utils import CustomStreamWrapper, ModelResponse


class LiteLLMModel:
r"""Constructor for LiteLLM backend with OpenAI compatibility."""

# NOTE: Currently "stream": True is not supported with LiteLLM due to the
# limitation of the current camel design.

def __init__(
self, model_type: str, model_config_dict: Dict[str, Any]
) -> None:
r"""Constructor for LiteLLM backend.
Args:
model_type (str): Model for which a backend is created,
such as GPT-3.5-turbo, Claude-2, etc.
model_config_dict (Dict[str, Any]): A dictionary of parameters for
the model configuration.
"""
self.model_type = model_type
self.model_config_dict = model_config_dict
self._client = None
self._token_counter: Optional[LiteLLMTokenCounter] = None
self.check_model_config()

@property
def client(self):
if self._client is None:
from litellm import completion

self._client = completion
return self._client

@property
def token_counter(self) -> LiteLLMTokenCounter:
r"""Initialize the token counter for the model backend.
Returns:
LiteLLMTokenCounter: The token counter following the model's
tokenization style.
"""
if not self._token_counter:
self._token_counter = LiteLLMTokenCounter(self.model_type)
return self._token_counter

def run(
self,
messages: List[OpenAIMessage],
) -> Union['ModelResponse', 'CustomStreamWrapper']:
r"""Runs inference of LiteLLM chat completion.
Args:
messages (List[OpenAIMessage]): Message list with the chat history
in OpenAI format.
Returns:
Union[ModelResponse, CustomStreamWrapper]:
`ModelResponse` in the non-stream mode, or
`CustomStreamWrapper` in the stream mode.
"""
response = self.client(
model=self.model_type,
messages=messages,
**self.model_config_dict,
)
return response

def check_model_config(self):
r"""Check whether the model configuration contains any unexpected
arguments to LiteLLM API.
Raises:
ValueError: If the model configuration dictionary contains any
unexpected arguments.
"""
for param in self.model_config_dict:
if param not in LITELLM_API_PARAMS:
raise ValueError(
f"Unexpected argument `{param}` is "
"input into LiteLLM model backend."
)

@property
def stream(self) -> bool:
r"""Returns whether the model is in stream mode, which sends partial
results each time.
Returns:
bool: Whether the model is in stream mode.
"""
return self.model_config_dict.get('stream', False)
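
A hedged end-to-end sketch (not part of the diff; the model name and message are illustrative, and the relevant provider API key must be available in the environment):

from camel.models import LiteLLMModel

# Unknown keys in model_config_dict make check_model_config() raise a
# ValueError at construction time.
model = LiteLLMModel(
    model_type="gpt-3.5-turbo",
    model_config_dict={"temperature": 0.2, "max_tokens": 256},
)

messages = [{"role": "user", "content": "Say hello in one sentence."}]
response = model.run(messages)  # delegates to litellm.completion

# Assuming LiteLLM's OpenAI-compatible (non-streaming) response shape:
print(response.choices[0].message.content)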
2 changes: 2 additions & 0 deletions camel/utils/__init__.py
@@ -32,6 +32,7 @@
from .token_counting import (
AnthropicTokenCounter,
BaseTokenCounter,
LiteLLMTokenCounter,
OpenAITokenCounter,
OpenSourceTokenCounter,
get_model_encoding,
@@ -53,6 +54,7 @@
'BaseTokenCounter',
'OpenAITokenCounter',
'OpenSourceTokenCounter',
'LiteLLMTokenCounter',
'Constants',
'text_extract_from_web',
'create_chunks',
52 changes: 52 additions & 0 deletions camel/utils/token_counting.py
@@ -301,6 +301,58 @@ def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
return self.client.count_tokens(prompt)


class LiteLLMTokenCounter:
def __init__(self, model_type: str):
r"""Constructor for the token counter for LiteLLM models.
Args:
model_type (str): Model type for which tokens will be counted.
"""
self.model_type = model_type
self._token_counter = None
self._completion_cost = None

@property
def token_counter(self):
if self._token_counter is None:
from litellm import token_counter

self._token_counter = token_counter
return self._token_counter

@property
def completion_cost(self):
if self._completion_cost is None:
from litellm import completion_cost

self._completion_cost = completion_cost
return self._completion_cost

def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
r"""Count number of tokens in the provided message list using
the tokenizer specific to this type of model.
Args:
messages (List[OpenAIMessage]): Message list with the chat history
in LiteLLM API format.
Returns:
int: Number of tokens in the messages.
"""
return self.token_counter(model=self.model_type, messages=messages)

def calculate_cost_from_response(self, response: dict) -> float:
r"""Calculate the cost of the given completion response.
Args:
response (dict): The completion response from LiteLLM.
Returns:
float: The cost of the completion call in USD.
"""
return self.completion_cost(completion_response=response)


def count_tokens_from_image(
image: Image.Image, detail: OpenAIVisionDetailType
) -> int:
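
Finally, a short sketch of how the new counter is meant to be used (not part of the commit; the model name and message are illustrative):

from camel.utils import LiteLLMTokenCounter

counter = LiteLLMTokenCounter("gpt-3.5-turbo")

messages = [{"role": "user", "content": "Hello!"}]
num_tokens = counter.count_tokens_from_messages(messages)  # via litellm.token_counter

# Given a completion response returned by the LiteLLM backend above,
# the same counter can price the call in USD:
# cost_usd = counter.calculate_cost_from_response(response)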