feat: integrate nemotron API (#659)
Wendong-Fan committed Jun 20, 2024
1 parent d9e6f8d commit 896c4c4
Showing 13 changed files with 160 additions and 5 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/pytest_package.yml
@@ -29,6 +29,8 @@ jobs:
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
NVIDIA_API_BASE_URL: "${{ secrets.NVIDIA_API_BASE_URL }}"
NVIDIA_API_KEY: "${{ secrets.NVIDIA_API_KEY }}"
run: poetry run pytest --fast-test-mode test/

pytest_package_llm_test:
@@ -47,6 +49,8 @@ jobs:
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
NVIDIA_API_BASE_URL: "${{ secrets.NVIDIA_API_BASE_URL }}"
NVIDIA_API_KEY: "${{ secrets.NVIDIA_API_KEY }}"
run: poetry run pytest --llm-test-only test/

pytest_package_very_slow_test:
@@ -65,4 +69,6 @@ jobs:
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
NVIDIA_API_BASE_URL: "${{ secrets.NVIDIA_API_BASE_URL }}"
NVIDIA_API_KEY: "${{ secrets.NVIDIA_API_KEY }}"
run: poetry run pytest --very-slow-test-only test/
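For local runs, the same two variables must be present in the environment before the NVIDIA-backed tests can succeed. A minimal sketch (not part of the commit), assuming the public NVIDIA API catalog endpoint; the URL and key below are placeholders:

import os

# Placeholder values; substitute your own NVIDIA endpoint and key.
os.environ.setdefault("NVIDIA_API_BASE_URL", "https://integrate.api.nvidia.com/v1")
os.environ.setdefault("NVIDIA_API_KEY", "nvapi-xxxxxxxxxxxx")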
10 changes: 9 additions & 1 deletion camel/agents/chat_agent.py
@@ -522,11 +522,19 @@ def handle_batch_response(
"""
output_messages: List[BaseMessage] = []
for choice in response.choices:
if isinstance(choice.message, list):
# If choice.message is a list, concatenate the content of each entry.
# This accommodates the Nemotron model's response format.
content = "".join(
[msg.content for msg in choice.message if msg.content]
)
else:
content = choice.message.content or ""
chat_message = BaseMessage(
role_name=self.role_name,
role_type=self.role_type,
meta_dict=dict(),
content=choice.message.content or "",
content=content,
)
output_messages.append(chat_message)
finish_reasons = [
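The list branch above exists because a Nemotron reward response can return choice.message as a list of messages rather than a single message. A minimal sketch (not part of the commit) of what that concatenation does, using hypothetical stand-in objects:

from types import SimpleNamespace

# Hypothetical stand-ins for the entries of a list-valued choice.message;
# entries without content are skipped.
choice_message = [
    SimpleNamespace(content="helpfulness:1.6171875,"),
    SimpleNamespace(content="correctness:1.6484375"),
    SimpleNamespace(content=None),
]

content = "".join(msg.content for msg in choice_message if msg.content)
print(content)  # helpfulness:1.6171875,correctness:1.6484375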
2 changes: 2 additions & 0 deletions camel/models/__init__.py
@@ -15,6 +15,7 @@
from .base_model import BaseModelBackend
from .litellm_model import LiteLLMModel
from .model_factory import ModelFactory
from .nemotron_model import NemotronModel
from .ollama_model import OllamaModel
from .open_source_model import OpenSourceModel
from .openai_audio_models import OpenAIAudioModels
@@ -32,5 +33,6 @@
'ModelFactory',
'LiteLLMModel',
'OpenAIAudioModels',
'NemotronModel',
'OllamaModel',
]
71 changes: 71 additions & 0 deletions camel/models/nemotron_model.py
@@ -0,0 +1,71 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import List, Optional

from openai import OpenAI

from camel.messages import OpenAIMessage
from camel.types import ChatCompletion, ModelType
from camel.utils import (
    BaseTokenCounter,
    model_api_key_required,
)


class NemotronModel:
    r"""Nemotron model API backend with OpenAI compatibility."""

    # NOTE: The Nemotron model does not support additional model
    # configuration options the way the OpenAI backends do.

    def __init__(
        self,
        model_type: ModelType,
        api_key: Optional[str] = None,
    ) -> None:
        r"""Constructor for the Nvidia backend.

        Args:
            model_type (ModelType): Model for which a backend is created.
            api_key (Optional[str]): The API key for authenticating with the
                Nvidia service. (default: :obj:`None`)
        """
        self.model_type = model_type
        url = os.environ.get('NVIDIA_API_BASE_URL', None)
        self._api_key = api_key or os.environ.get("NVIDIA_API_KEY")
        if not url or not self._api_key:
            raise ValueError("The NVIDIA API base URL and key must be set.")
        self._client = OpenAI(
            timeout=60, max_retries=3, base_url=url, api_key=self._api_key
        )
        self._token_counter: Optional[BaseTokenCounter] = None

    @model_api_key_required
    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> ChatCompletion:
        r"""Runs inference of OpenAI chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list.

        Returns:
            ChatCompletion.
        """
        response = self._client.chat.completions.create(
            messages=messages,
            model=self.model_type.value,
        )
        return response
20 changes: 18 additions & 2 deletions camel/types/enums.py
@@ -45,11 +45,14 @@ class ModelType(Enum):
CLAUDE_2_0 = "claude-2.0"
CLAUDE_INSTANT_1_2 = "claude-instant-1.2"

# 3 models
# Claude3 models
CLAUDE_3_OPUS = "claude-3-opus-20240229"
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"

# Nvidia models
NEMOTRON_4_REWARD = "nvidia/nemotron-4-340b-reward"

@property
def value_for_tiktoken(self) -> str:
return (
@@ -103,6 +106,17 @@ def is_anthropic(self) -> bool:
ModelType.CLAUDE_3_HAIKU,
}

@property
def is_nvidia(self) -> bool:
r"""Returns whether this type of models is Nvidia-released model.
Returns:
bool: Whether this type of models is nvidia.
"""
return self in {
ModelType.NEMOTRON_4_REWARD,
}

@property
def token_limit(self) -> int:
r"""Returns the maximum token limit for a given model.
@@ -134,7 +148,7 @@ def token_limit(self) -> int:
return 2048
elif self is ModelType.VICUNA_16K:
return 16384
if self in {ModelType.CLAUDE_2_0, ModelType.CLAUDE_INSTANT_1_2}:
elif self in {ModelType.CLAUDE_2_0, ModelType.CLAUDE_INSTANT_1_2}:
return 100_000
elif self in {
ModelType.CLAUDE_2_1,
@@ -143,6 +157,8 @@
ModelType.CLAUDE_3_HAIKU,
}:
return 200_000
elif self is ModelType.NEMOTRON_4_REWARD:
return 4096
else:
raise ValueError("Unknown model type")

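A quick check of the new enum entry and its properties (a sketch assuming this commit is installed; the printed values follow directly from the definitions above):

from camel.types import ModelType

model = ModelType.NEMOTRON_4_REWARD
print(model.value)        # nvidia/nemotron-4-340b-reward
print(model.is_nvidia)    # True
print(model.token_limit)  # 4096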
4 changes: 4 additions & 0 deletions camel/utils/commons.py
@@ -62,6 +62,10 @@ def wrapper(self, *args, **kwargs):
if not self._api_key and 'ANTHROPIC_API_KEY' not in os.environ:
raise ValueError('Anthropic API key not found.')
return func(self, *args, **kwargs)
elif self.model_type.is_nvidia:
if not self._api_key and 'NVIDIA_API_KEY' not in os.environ:
raise ValueError('NVIDIA API key not found.')
return func(self, *args, **kwargs)
else:
raise ValueError('Unsupported model type.')

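The new branch mirrors the existing provider checks: if neither an explicit key nor the NVIDIA_API_KEY environment variable is available, the call fails early. A sketch of that guard in isolation (require_nvidia_key is a hypothetical helper, not part of the commit):

import os
from typing import Optional

def require_nvidia_key(api_key: Optional[str]) -> None:
    # Mirrors the check added to model_api_key_required above.
    if not api_key and 'NVIDIA_API_KEY' not in os.environ:
        raise ValueError('NVIDIA API key not found.')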
File renamed without changes.
48 changes: 48 additions & 0 deletions examples/models/nemotron_model_example.py
@@ -0,0 +1,48 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

from camel.models import NemotronModel
from camel.types import ModelType

nemotron = NemotronModel(model_type=ModelType.NEMOTRON_4_REWARD)

message = [
    {"role": "user", "content": "I am going to Paris, what should I see?"},
    {
        "role": "assistant",
        "content": "Ah, Paris, the City of Light! There are so "
        "many amazing things to see and do in this beautiful city ...",
    },
]

ans = nemotron.run(message)
print(ans)
'''
===============================================================================
ChatCompletion(id='4668ad22-1dec-4df4-ba92-97ffa5fbd16d', choices=[Choice
(finish_reason='length', index=0, logprobs=ChoiceLogprobs(content=
[ChatCompletionTokenLogprob(token='helpfulness', bytes=None, logprob=1.
6171875, top_logprobs=[]), ChatCompletionTokenLogprob(token='correctness',
bytes=None, logprob=1.6484375, top_logprobs=[]), ChatCompletionTokenLogprob
(token='coherence', bytes=None, logprob=3.3125, top_logprobs=[]),
ChatCompletionTokenLogprob(token='complexity', bytes=None, logprob=0.546875,
top_logprobs=[]), ChatCompletionTokenLogprob(token='verbosity', bytes=None,
logprob=0.515625, top_logprobs=[])]), message=[ChatCompletionMessage
(content='helpfulness:1.6171875,correctness:1.6484375,coherence:3.3125,
complexity:0.546875,verbosity:0.515625', role='assistant', function_call=None,
tool_calls=None)])], created=None, model=None, object=None,
system_fingerprint=None, usage=CompletionUsage(completion_tokens=1,
prompt_tokens=78, total_tokens=79))
===============================================================================
'''
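The reward model packs its scores into the message content as comma-separated attribute:score pairs, as shown above. A sketch of how one might pull them into a dict (parse_reward_scores is a hypothetical helper, not part of the commit):

def parse_reward_scores(content: str) -> dict:
    # Split "attribute:score" pairs such as "helpfulness:1.6171875".
    scores = {}
    for pair in content.split(","):
        name, _, value = pair.partition(":")
        if value:
            scores[name.strip()] = float(value)
    return scores

print(parse_reward_scores("helpfulness:1.6171875,correctness:1.6484375,coherence:3.3125"))
# {'helpfulness': 1.6171875, 'correctness': 1.6484375, 'coherence': 3.3125}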
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions examples/test/test_ai_society_example.py
@@ -15,7 +15,7 @@

import examples.ai_society.role_playing
import examples.function_call.role_playing_with_functions
import examples.open_source_models.role_playing_with_open_source_model
import examples.models.role_playing_with_open_source_model
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType

@@ -42,6 +42,6 @@ def test_role_playing_with_function_example():

def test_role_playing_with_open_source_model():
with patch('time.sleep', return_value=None):
examples.open_source_models.role_playing_with_open_source_model.main(
examples.models.role_playing_with_open_source_model.main(
chat_turn_limit=2
)
