Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: integrate nemotron API #659

Merged
merged 6 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/pytest_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ jobs:
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
NVIDIA_API_BASE_URL: "${{ secrets.NVIDIA_API_BASE_URL }}"
NVIDIA_API_KEY: "${{ secrets.NVIDIA_API_KEY }}"
run: poetry run pytest --fast-test-mode test/

pytest_package_llm_test:
Expand All @@ -47,6 +49,8 @@ jobs:
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
NVIDIA_API_BASE_URL: "${{ secrets.NVIDIA_API_BASE_URL }}"
NVIDIA_API_KEY: "${{ secrets.NVIDIA_API_KEY }}"
run: poetry run pytest --llm-test-only test/

pytest_package_very_slow_test:
Expand All @@ -65,4 +69,6 @@ jobs:
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
NVIDIA_API_BASE_URL: "${{ secrets.NVIDIA_API_BASE_URL }}"
NVIDIA_API_KEY: "${{ secrets.NVIDIA_API_KEY }}"
run: poetry run pytest --very-slow-test-only test/
10 changes: 9 additions & 1 deletion camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,11 +522,19 @@ def handle_batch_response(
"""
output_messages: List[BaseMessage] = []
for choice in response.choices:
if isinstance(choice.message, list):
# If choice.message is a list, handle accordingly
# It's a check to fit with Nemotron model integration.
content = "".join(
[msg.content for msg in choice.message if msg.content]
)
else:
content = choice.message.content or ""
chat_message = BaseMessage(
role_name=self.role_name,
role_type=self.role_type,
meta_dict=dict(),
content=choice.message.content or "",
content=content,
)
output_messages.append(chat_message)
finish_reasons = [
Expand Down
2 changes: 2 additions & 0 deletions camel/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .base_model import BaseModelBackend
from .litellm_model import LiteLLMModel
from .model_factory import ModelFactory
from .nemotron_model import NemotronModel
from .ollama_model import OllamaModel
from .open_source_model import OpenSourceModel
from .openai_audio_models import OpenAIAudioModels
Expand All @@ -32,5 +33,6 @@
'ModelFactory',
'LiteLLMModel',
'OpenAIAudioModels',
'NemotronModel',
'OllamaModel',
]
71 changes: 71 additions & 0 deletions camel/models/nemotron_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import List, Optional

from openai import OpenAI

from camel.messages import OpenAIMessage
from camel.types import ChatCompletion, ModelType
from camel.utils import (
BaseTokenCounter,
model_api_key_required,
)


class NemotronModel:
r"""Nemotron model API backend with OpenAI compatibility."""

# NOTE: Nemotron model doesn't support additional model config like OpenAI.

def __init__(
self,
model_type: ModelType,
api_key: Optional[str] = None,
) -> None:
r"""Constructor for Nvidia backend.

Args:
model_type (ModelType): Model for which a backend is created.
api_key (Optional[str]): The API key for authenticating with the
Nvidia service. (default: :obj:`None`)
"""
self.model_type = model_type
url = os.environ.get('NVIDIA_API_BASE_URL', None)
self._api_key = api_key or os.environ.get("NVIDIA_API_KEY")
if not url or not self._api_key:
raise ValueError("The NVIDIA API base url and key should be set.")
self._client = OpenAI(
timeout=60, max_retries=3, base_url=url, api_key=self._api_key
)
self._token_counter: Optional[BaseTokenCounter] = None

@model_api_key_required
def run(
self,
messages: List[OpenAIMessage],
) -> ChatCompletion:
r"""Runs inference of OpenAI chat completion.

Args:
messages (List[OpenAIMessage]): Message list.

Returns:
ChatCompletion.
"""
response = self._client.chat.completions.create(
messages=messages,
model=self.model_type.value,
)
return response
20 changes: 18 additions & 2 deletions camel/types/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ class ModelType(Enum):
CLAUDE_2_0 = "claude-2.0"
CLAUDE_INSTANT_1_2 = "claude-instant-1.2"

# 3 models
# Claude3 models
CLAUDE_3_OPUS = "claude-3-opus-20240229"
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"

# Nvidia models
NEMOTRON_4_REWARD = "nvidia/nemotron-4-340b-reward"

@property
def value_for_tiktoken(self) -> str:
return (
Expand Down Expand Up @@ -103,6 +106,17 @@ def is_anthropic(self) -> bool:
ModelType.CLAUDE_3_HAIKU,
}

@property
def is_nvidia(self) -> bool:
r"""Returns whether this type of models is Nvidia-released model.

Returns:
bool: Whether this type of models is nvidia.
"""
return self in {
ModelType.NEMOTRON_4_REWARD,
}

@property
def token_limit(self) -> int:
r"""Returns the maximum token limit for a given model.
Expand Down Expand Up @@ -134,7 +148,7 @@ def token_limit(self) -> int:
return 2048
elif self is ModelType.VICUNA_16K:
return 16384
if self in {ModelType.CLAUDE_2_0, ModelType.CLAUDE_INSTANT_1_2}:
elif self in {ModelType.CLAUDE_2_0, ModelType.CLAUDE_INSTANT_1_2}:
return 100_000
elif self in {
ModelType.CLAUDE_2_1,
Expand All @@ -143,6 +157,8 @@ def token_limit(self) -> int:
ModelType.CLAUDE_3_HAIKU,
}:
return 200_000
elif self is ModelType.NEMOTRON_4_REWARD:
return 4096
else:
raise ValueError("Unknown model type")

Expand Down
4 changes: 4 additions & 0 deletions camel/utils/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def wrapper(self, *args, **kwargs):
if not self._api_key and 'ANTHROPIC_API_KEY' not in os.environ:
raise ValueError('Anthropic API key not found.')
return func(self, *args, **kwargs)
elif self.model_type.is_nvidia:
if not self._api_key and 'NVIDIA_API_KEY' not in os.environ:
raise ValueError('NVIDIA API key not found.')
return func(self, *args, **kwargs)
else:
raise ValueError('Unsupported model type.')

Expand Down
48 changes: 48 additions & 0 deletions examples/models/nemotron_model_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

from camel.models import NemotronModel
from camel.types import ModelType

nemotro = NemotronModel(model_type=ModelType.NEMOTRON_4_REWARD)

message = [
{"role": "user", "content": "I am going to Paris, what should I see?"},
{
"role": "assistant",
"content": "Ah, Paris, the City of Light! There are so "
"many amazing things to see and do in this beautiful city ...",
},
]

ans = nemotro.run(message)
print(ans)
'''
===============================================================================
ChatCompletion(id='4668ad22-1dec-4df4-ba92-97ffa5fbd16d', choices=[Choice
(finish_reason='length', index=0, logprobs=ChoiceLogprobs(content=
[ChatCompletionTokenLogprob(token='helpfulness', bytes=None, logprob=1.
6171875, top_logprobs=[]), ChatCompletionTokenLogprob(token='correctness',
bytes=None, logprob=1.6484375, top_logprobs=[]), ChatCompletionTokenLogprob
(token='coherence', bytes=None, logprob=3.3125, top_logprobs=[]),
ChatCompletionTokenLogprob(token='complexity', bytes=None, logprob=0.546875,
top_logprobs=[]), ChatCompletionTokenLogprob(token='verbosity', bytes=None,
logprob=0.515625, top_logprobs=[])]), message=[ChatCompletionMessage
(content='helpfulness:1.6171875,correctness:1.6484375,coherence:3.3125,
complexity:0.546875,verbosity:0.515625', role='assistant', function_call=None,
tool_calls=None)])], created=None, model=None, object=None,
system_fingerprint=None, usage=CompletionUsage(completion_tokens=1,
prompt_tokens=78, total_tokens=79))
===============================================================================
'''
4 changes: 2 additions & 2 deletions examples/test/test_ai_society_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import examples.ai_society.role_playing
import examples.function_call.role_playing_with_functions
import examples.open_source_models.role_playing_with_open_source_model
import examples.models.role_playing_with_open_source_model
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType

Expand All @@ -42,6 +42,6 @@ def test_role_playing_with_function_example():

def test_role_playing_with_open_source_model():
with patch('time.sleep', return_value=None):
examples.open_source_models.role_playing_with_open_source_model.main(
examples.models.role_playing_with_open_source_model.main(
chat_turn_limit=2
)
Loading