Skip to content

Commit

Permalink
feat: add_zhipuai_model (#600)
Browse files Browse the repository at this point in the history
Co-authored-by: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com>
Co-authored-by: Wendong <w3ndong.fan@gmail.com>
  • Loading branch information
3 people committed Jun 16, 2024
1 parent 26100a9 commit 53b9308
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 0 deletions.
2 changes: 2 additions & 0 deletions camel/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from .openai_audio_models import OpenAIAudioModels
from .openai_model import OpenAIModel
from .stub_model import StubModel
from .zhipuai_model import ZhipuAIModel

__all__ = [
'BaseModelBackend',
'OpenAIModel',
'AnthropicModel',
'StubModel',
'ZhipuAIModel',
'OpenSourceModel',
'ModelFactory',
'LiteLLMModel',
Expand Down
3 changes: 3 additions & 0 deletions camel/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from camel.models.open_source_model import OpenSourceModel
from camel.models.openai_model import OpenAIModel
from camel.models.stub_model import StubModel
from camel.models.zhipuai_model import ZhipuAIModel
from camel.types import ModelType


Expand Down Expand Up @@ -58,6 +59,8 @@ def create(
model_class = OpenSourceModel
elif model_type.is_anthropic:
model_class = AnthropicModel
elif model_type.is_zhipuai:
model_class = ZhipuAIModel
else:
raise ValueError(f"Unknown model type `{model_type}` is input")

Expand Down
125 changes: 125 additions & 0 deletions camel/models/zhipuai_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

import os
from typing import Any, Dict, List, Optional, Union

from openai import OpenAI, Stream

from camel.configs import OPENAI_API_PARAMS
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
from camel.utils import (
BaseTokenCounter,
OpenAITokenCounter,
model_api_key_required,
)


class ZhipuAIModel(BaseModelBackend):
r"""ZhipuAI API in a unified BaseModelBackend interface."""

def __init__(
self,
model_type: ModelType,
model_config_dict: Dict[str, Any],
api_key: Optional[str] = None,
url: Optional[str] = None,
) -> None:
r"""Constructor for ZhipuAI backend.
Args:
model_type (ModelType): Model for which a backend is created,
such as GLM_* series.
model_config_dict (Dict[str, Any]): A dictionary that will
be fed into openai.ChatCompletion.create().
api_key (Optional[str]): The API key for authenticating with the
ZhipuAI service. (default: :obj:`None`)
"""
super().__init__(model_type, model_config_dict)
self._url = url or os.environ.get("ZHIPUAI_API_BASE_URL")
self._api_key = api_key or os.environ.get("ZHIPUAI_API_KEY")
self._client = OpenAI(
timeout=60,
max_retries=3,
api_key=self._api_key,
base_url=self._url,
)
self._token_counter: Optional[BaseTokenCounter] = None

@model_api_key_required
def run(
self,
messages: List[OpenAIMessage],
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
r"""Runs inference of OpenAI chat completion.
Args:
messages (List[OpenAIMessage]): Message list with the chat history
in OpenAI API format.
Returns:
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
`ChatCompletion` in the non-stream mode, or
`Stream[ChatCompletionChunk]` in the stream mode.
"""
# Use OpenAI cilent as interface call ZhipuAI
# Reference: https://open.bigmodel.cn/dev/api#openai_sdk
response = self._client.chat.completions.create(
messages=messages,
model=self.model_type.value,
**self.model_config_dict,
)
return response

@property
def token_counter(self) -> BaseTokenCounter:
r"""Initialize the token counter for the model backend.
Returns:
OpenAITokenCounter: The token counter following the model's
tokenization style.
"""

if not self._token_counter:
# It's a temporary setting for token counter.
self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
return self._token_counter

def check_model_config(self):
r"""Check whether the model configuration contains any
unexpected arguments to OpenAI API.
Raises:
ValueError: If the model configuration dictionary contains any
unexpected arguments to OpenAI API.
"""
for param in self.model_config_dict:
if param not in OPENAI_API_PARAMS:
raise ValueError(
f"Unexpected argument `{param}` is "
"input into OpenAI model backend."
)
pass

@property
def stream(self) -> bool:
r"""Returns whether the model is in stream mode, which sends partial
results each time.
Returns:
bool: Whether the model is in stream mode.
"""
return self.model_config_dict.get('stream', False)
18 changes: 18 additions & 0 deletions camel/types/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class ModelType(Enum):
GPT_4_32K = "gpt-4-32k"
GPT_4_TURBO = "gpt-4-turbo"
GPT_4O = "gpt-4o"
GLM_4 = "glm-4"
GLM_4V = 'glm-4v'
GLM_3_TURBO = "glm-3-turbo"

STUB = "stub"

Expand Down Expand Up @@ -62,6 +65,15 @@ def is_openai(self) -> bool:
ModelType.GPT_4O,
}

@property
def is_zhipuai(self) -> bool:
r"""Returns whether this type of models is an ZhipuAI model."""
return self in {
ModelType.GLM_3_TURBO,
ModelType.GLM_4,
ModelType.GLM_4V,
}

@property
def is_open_source(self) -> bool:
r"""Returns whether this type of models is open-source."""
Expand Down Expand Up @@ -103,6 +115,12 @@ def token_limit(self) -> int:
return 128000
elif self is ModelType.GPT_4O:
return 128000
elif self == ModelType.GLM_4:
return 8192
elif self == ModelType.GLM_3_TURBO:
return 8192
elif self == ModelType.GLM_4V:
return 1024
elif self is ModelType.STUB:
return 4096
elif self is ModelType.LLAMA_2:
Expand Down
4 changes: 4 additions & 0 deletions camel/utils/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def wrapper(self, *args, **kwargs):
if not self._api_key and 'OPENAI_API_KEY' not in os.environ:
raise ValueError('OpenAI API key not found.')
return func(self, *args, **kwargs)
elif self.model_type.is_zhipuai:
if 'ZHIPUAI_API_KEY' not in os.environ:
raise ValueError('ZhiPuAI API key not found.')
return func(self, *args, **kwargs)
elif self.model_type.is_anthropic:
if not self._api_key and 'ANTHROPIC_API_KEY' not in os.environ:
raise ValueError('Anthropic API key not found.')
Expand Down
117 changes: 117 additions & 0 deletions examples/zhipuai_models/zhipuai_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.messages import BaseMessage
from camel.types import ModelType

# Define system message
sys_msg = BaseMessage.make_assistant_message(
role_name="Assistant",
content="You are a helpful assistant.",
)

# Set model config
model_config = ChatGPTConfig(
temperature=0.2, top_p=0.9
) # temperature=,top_p here can not be 1 or 0.

# Set agent
camel_agent = ChatAgent(
sys_msg,
model_config=model_config,
model_type=ModelType.GLM_4,
)
camel_agent.reset()

user_msg = BaseMessage.make_user_message(
role_name="User",
content="I want to practice my legs today."
"Help me make a fitness and diet plan",
)

# Get response information
response = camel_agent.step(user_msg)
print(response.msgs[0].content)
'''
===============================================================================
Certainly! Focusing on leg workouts can help improve strength, endurance, and
overall lower-body fitness. Here's a sample fitness
and diet plan for leg training:
**Fitness Plan:**
1. **Warm-Up:**
- 5-10 minutes of light cardio (jogging, cycling, or jumping jacks)
- Leg swings (forward and backward)
- Hip circles
2. **Strength Training:**
- Squats: 3 sets of 8-12 reps
- Deadlifts: 3 sets of 8-12 reps
- Lunges: 3 sets of 10-12 reps per leg
- Leg press: 3 sets of 10-12 reps
- Calf raises: 3 sets of 15-20 reps
3. **Cardio:**
- Hill sprints: 5-8 reps of 30-second sprints
- Cycling or stationary biking: 20-30 minutes at moderate intensity
4. **Cool Down:**
- Stretching (focus on the legs, hip flexors, and hamstrings)
- Foam rolling (optional)
**Diet Plan:**
1. **Breakfast:**
- Greek yogurt with mixed berries and a tablespoon of chia seeds
- Whole-grain toast with avocado
2. **Snack:**
- A banana with a tablespoon of natural peanut butter
3. **Lunch:**
- Grilled chicken breast with quinoa and steamed vegetables
- A side of mixed greens with a light vinaigrette
4. **Snack:**
- A serving of mixed nuts and dried fruits
5. **Dinner:**
- Baked salmon with sweet potato and roasted asparagus
- A side of lentil soup or a bean salad
6. **Post-Workout Snack:**
- A protein shake or a serving of cottage cheese with fruit
7. **Hydration:**
- Drink plenty of water throughout the day to
stay hydrated, especially after workouts.
**Tips:**
- Ensure you get enough rest and recovery, as leg workouts
can be demanding on the body.
- Listen to your body and adjust the weights and reps
according to your fitness level.
- Make sure to include a variety of nutrients in your diet to
support muscle recovery and overall health.
- Consult a fitness professional or trainer if you need personalized
guidance or have any pre-existing health conditions.
Remember, consistency is key to seeing results, so stick to
your plan and modify it as needed to suit your goals and progress.
===============================================================================
'''
61 changes: 61 additions & 0 deletions test/models/test_zhipuai_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import re

import pytest

from camel.configs import ChatGPTConfig, OpenSourceConfig
from camel.models import ZhipuAIModel
from camel.types import ModelType
from camel.utils import OpenAITokenCounter


@pytest.mark.model_backend
@pytest.mark.parametrize(
"model_type",
[
ModelType.GLM_3_TURBO,
ModelType.GLM_4,
ModelType.GLM_4V,
],
)
def test_zhipuai_model(model_type):
model_config_dict = ChatGPTConfig().__dict__
model = ZhipuAIModel(model_type, model_config_dict)
assert model.model_type == model_type
assert model.model_config_dict == model_config_dict
assert isinstance(model.token_counter, OpenAITokenCounter)
assert isinstance(model.model_type.value_for_tiktoken, str)
assert isinstance(model.model_type.token_limit, int)


@pytest.mark.model_backend
def test_zhipuai_model_unexpected_argument():
model_type = ModelType.GLM_4V
model_config = OpenSourceConfig(
model_path="vicuna-7b-v1.5",
server_url="http://localhost:8000/v1",
)
model_config_dict = model_config.__dict__

with pytest.raises(
ValueError,
match=re.escape(
(
"Unexpected argument `model_path` is "
"input into OpenAI model backend."
)
),
):
_ = ZhipuAIModel(model_type, model_config_dict)

0 comments on commit 53b9308

Please sign in to comment.