feat: integrate the llama3 (8B, 70B) served by Groq #531

Open · wants to merge 22 commits into base: master

Changes from all commits (22 commits)
85eaced
feat: add llama3 provided by groq
Appointat Apr 24, 2024
4fef335
feat: add llama3 provided by groq
Appointat Apr 24, 2024
2d78038
Merge remote-tracking branch 'origin' into feature/groq-llama3
Appointat Apr 24, 2024
2590f5c
refactor code to use api_key_required decorator consistently
Appointat Apr 24, 2024
85a55f2
update poetry toml
Appointat Apr 24, 2024
4f307ad
Merge remote-tracking branch 'origin' into feature/groq-llama3
Appointat Apr 27, 2024
f2e57f0
feat: add model configs, fix decorator of step function
Appointat Apr 28, 2024
cff5e04
Merge branch 'master' into feature/groq-llama3
Appointat Apr 28, 2024
6996309
add GROQ_API_KEY to yml
Wendong-Fan Apr 29, 2024
136d0b7
Merge branch 'master' into feature/groq-llama3
Appointat May 5, 2024
5fb90a7
update poetry
Appointat May 5, 2024
fec77a6
remove type ignore
Appointat May 5, 2024
7b0f0b0
reformatted by ruff
Appointat May 5, 2024
793a688
use AutoTokenizer as tokenizer
Appointat May 5, 2024
eca647a
fix: update token_counter import in test_groq_llama3_model.py
Appointat May 5, 2024
431f0f0
Merge branch 'master' into feature/groq-llama3
Appointat May 6, 2024
936d6e4
fix: remove GroqModel from camel/models/__init__.py
Appointat May 13, 2024
0dd1a30
fix: remove GroqModel from camel/models/__init__.py
Appointat Jun 2, 2024
a8f96ab
refactor: improve error handling in GroqModel check_model_config method
Appointat Jun 2, 2024
4c8b070
refactor: convert response from OpenAI format to GroqModel
Appointat Jun 2, 2024
6e59352
Merge branch 'master' into feature/groq-llama3
Appointat Jun 2, 2024
71d79ce
refactor: Add API key parameter to GroqModel constructor
Appointat Jun 2, 2024
1 change: 1 addition & 0 deletions .github/workflows/build_package.yml
@@ -34,5 +34,6 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
GROQ_API_KEY: "${{ secrets.GROQ_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: pytest --fast-test-mode ./test
3 changes: 3 additions & 0 deletions .github/workflows/pytest_package.yml
@@ -28,6 +28,7 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
GROQ_API_KEY: "${{ secrets.GROQ_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: poetry run pytest --fast-test-mode test/

@@ -46,6 +47,7 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
GROQ_API_KEY: "${{ secrets.GROQ_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: poetry run pytest --llm-test-only test/

@@ -64,5 +66,6 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
GROQ_API_KEY: "${{ secrets.GROQ_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: poetry run pytest --very-slow-test-only test/
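The same variable has to be present for local test runs; a minimal sketch of that setup (the key value below is a placeholder, not a real credential):

```python
import os

# Mirror the CI secret locally before invoking pytest; obtain a real key
# from the Groq console (https://console.groq.com/keys).
os.environ.setdefault("GROQ_API_KEY", "gsk-placeholder")
```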
3 changes: 3 additions & 0 deletions camel/configs/__init__.py
@@ -13,6 +13,7 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
from .base_config import BaseConfig
from .groq_llama3_config import GROQ_LLAMA3_API_PARAMS, GroqLLAMA3Config
from .openai_config import (
OPENAI_API_PARAMS,
ChatGPTConfig,
@@ -25,5 +26,7 @@
'OPENAI_API_PARAMS',
'AnthropicConfig',
'ANTHROPIC_API_PARAMS',
'GROQ_LLAMA3_API_PARAMS',
'GroqLLAMA3Config',
'OpenSourceConfig',
]
54 changes: 54 additions & 0 deletions camel/configs/groq_llama3_config.py
@@ -0,0 +1,54 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from __future__ import annotations

from dataclasses import asdict, dataclass

from typing import Literal

from camel.configs.base_config import BaseConfig


@dataclass(frozen=True)
class GroqLLAMA3Config(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Anthropic API. And Camel does not support stream mode for GroqLLAMA3.

See: https://console.groq.com/docs/text-chat
Args:
max_tokens (int, optional): The maximum number of tokens to generate
before stopping. Note that Anthropic models may stop before
reaching this maximum. This parameter only specifies the absolute
maximum number of tokens to generate.
(default: :obj:`256`)
temperature (float, optional): Amount of randomness injected into the
response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
for analytical / multiple choice, and closer to 1 for creative
and generative tasks.
(default: :obj:`1`)
stream (bool, optional): Whether to incrementally stream the response
using server-sent events. Camel does not support stream mode for
Groq Llama3.
(default: :obj:`False`)
"""

# Llama3 typically has a context window of 8192 tokens, so the default
# max_tokens is set to 4096.
max_tokens: int = 4096
# Camel suggests tuning `temperature` rather than `top_p`.
temperature: float = 1
# Camel does not support stream mode for Groq Llama3, so `stream`
# defaults to False.
stream: Literal[False] = False


GROQ_LLAMA3_API_PARAMS = set(asdict(GroqLLAMA3Config()).keys())
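A minimal sketch of how the frozen config and the derived parameter set fit together (values are illustrative, and it is assumed BaseConfig adds no required fields):

```python
from dataclasses import asdict

from camel.configs import GROQ_LLAMA3_API_PARAMS, GroqLLAMA3Config

# Build a config and flatten it into the kwargs dict a backend expects.
config = GroqLLAMA3Config(max_tokens=1024, temperature=0.2)
config_dict = asdict(config)

# By construction every key is an accepted parameter, so a check like
# GroqModel.check_model_config() would pass for this dict.
assert set(config_dict) <= GROQ_LLAMA3_API_PARAMS
```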
2 changes: 2 additions & 0 deletions camel/models/__init__.py
@@ -13,6 +13,7 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .anthropic_model import AnthropicModel
from .base_model import BaseModelBackend
from .groq_model import GroqModel
from .model_factory import ModelFactory
from .open_source_model import OpenSourceModel
from .openai_audio_models import OpenAIAudioModels
@@ -23,6 +24,7 @@
'BaseModelBackend',
'OpenAIModel',
'AnthropicModel',
'GroqModel',
'StubModel',
'OpenSourceModel',
'ModelFactory',
7 changes: 6 additions & 1 deletion camel/models/anthropic_model.py
@@ -20,7 +20,11 @@
from camel.configs import ANTHROPIC_API_PARAMS
from camel.models.base_model import BaseModelBackend
from camel.types import ChatCompletion, ModelType
from camel.utils import AnthropicTokenCounter, BaseTokenCounter
from camel.utils import (
AnthropicTokenCounter,
BaseTokenCounter,
api_key_required,
)


class AnthropicModel(BaseModelBackend):
@@ -90,6 +94,7 @@ def count_tokens_from_prompt(self, prompt: str) -> int:
"""
return self.client.count_tokens(prompt)

@api_key_required
def run(
self,
messages,
149 changes: 149 additions & 0 deletions camel/models/groq_model.py
@@ -0,0 +1,149 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import Any, Dict, List, Optional

from groq import Groq

from camel.configs import GROQ_LLAMA3_API_PARAMS
from camel.messages import OpenAIMessage
# Import from the submodule to avoid a circular import through
# camel.models.__init__, which now also imports GroqModel.
from camel.models.base_model import BaseModelBackend
from camel.types import (
ChatCompletion,
ChatCompletionMessage,
Choice,
CompletionUsage,
ModelType,
)
from camel.utils import (
BaseTokenCounter,
GroqLlama3TokenCounter,
api_key_required,
)


class GroqModel(BaseModelBackend):
r"""LLM API served by Groq in a unified BaseModelBackend interface."""

def __init__(
self,
model_type: ModelType,
model_config_dict: Dict[str, Any],
api_key: Optional[str] = None,
) -> None:
r"""Constructor for Groq backend.

Args:
model_type (ModelType): Model for which a backend is created.
model_config_dict (Dict[str, Any]): A dictionary that will
be fed into groq.ChatCompletion.create().
api_key (Optional[str]): The API key for authenticating with the
Groq service. (default: :obj:`None`).
"""
super().__init__(model_type, model_config_dict)
url = os.environ.get('GROQ_API_BASE_URL', None)
# Prefer an explicitly passed key; fall back to the environment.
api_key = api_key or os.environ.get('GROQ_API_KEY')
self._client = Groq(api_key=api_key, base_url=url)
self._token_counter: Optional[BaseTokenCounter] = None

def _convert_response_from_groq_to_openai(self, response):
# openai ^1.0.0 format, reference openai/types/chat/chat_completion.py
obj = ChatCompletion.construct(
id=response.id,
choices=[
Choice.construct(
finish_reason=response.choices[0].finish_reason,
index=response.choices[0].index,
logprobs=response.choices[0].logprobs,
message=ChatCompletionMessage.construct(
content=response.choices[0].message.content,
role=response.choices[0].message.role,
function_call=None,  # Groq does not return a function call here
tool_calls=response.choices[0].message.tool_calls,
),
)
],
created=response.created,
model=response.model,
object="chat.completion",
system_fingerprint=response.system_fingerprint,
usage=CompletionUsage.construct(
completion_tokens=response.usage.completion_tokens,
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
),
)
return obj

@property
def token_counter(self) -> BaseTokenCounter:
r"""Initialize the token counter for the model backend. But Groq API
does not provide any token counter.

Returns:
BaseTokenCounter: The token counter following the model's
tokenization style.
"""
if not self._token_counter:
self._token_counter = GroqLlama3TokenCounter(self.model_type)
return self._token_counter

@api_key_required
def run(
self,
messages: List[OpenAIMessage],
) -> ChatCompletion:
r"""Runs inference of OpenAI chat completion.

Args:
messages (List[OpenAIMessage]): Message list with the chat history
in OpenAI API format.

Returns:
ChatCompletion: Response in the OpenAI API format (non-stream mode).
"""
_response = self._client.chat.completions.create(
messages=messages, # type: ignore[arg-type]
model=self.model_type.value,
**self.model_config_dict,
)

# Convert the Groq response to the OpenAI format
response = self._convert_response_from_groq_to_openai(_response)

return response

def check_model_config(self):
r"""Check whether the model configuration contains any unexpected
arguments to Groq API. But Groq API does not have any additional
arguments to check.

Raises:
ValueError: If the model configuration dictionary contains any
unexpected arguments to Groq API.
"""
for param in self.model_config_dict:
if param not in GROQ_LLAMA3_API_PARAMS:
raise ValueError(
f"Unexpected argument `{param}` is "
"input into Groq Llama3 model backend."
)

@property
def stream(self) -> bool:
r"""Returns whether the model supports streaming. But Groq API does
not support streaming.
"""
return False
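A minimal end-to-end sketch of the new backend (assumes GROQ_API_KEY is exported; the prompt and config values are illustrative):

```python
from camel.models import GroqModel
from camel.types import ModelType

model = GroqModel(
    model_type=ModelType.GROQ_LLAMA_3_8_B,
    model_config_dict={"max_tokens": 256, "temperature": 0.2},
)
# run() returns a ChatCompletion already converted to the OpenAI format.
response = model.run(
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)
```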
3 changes: 3 additions & 0 deletions camel/models/model_factory.py
@@ -15,6 +15,7 @@

from camel.models.anthropic_model import AnthropicModel
from camel.models.base_model import BaseModelBackend
from camel.models.groq_model import GroqModel
from camel.models.open_source_model import OpenSourceModel
from camel.models.openai_model import OpenAIModel
from camel.models.stub_model import StubModel
@@ -58,6 +59,8 @@ def create(
model_class = OpenSourceModel
elif model_type.is_anthropic:
model_class = AnthropicModel
elif model_type.is_groq:
model_class = GroqModel
else:
raise ValueError(f"Unknown model type `{model_type}` is input")

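With the factory branch above, callers need not import GroqModel directly; a sketch (ModelFactory.create's signature is assumed from the surrounding hunk):

```python
from camel.models import ModelFactory
from camel.types import ModelType

backend = ModelFactory.create(
    model_type=ModelType.GROQ_LLAMA_3_70_B,
    model_config_dict={"max_tokens": 512},
)
# The Groq backend reports that it never streams responses.
assert backend.stream is False
```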
15 changes: 15 additions & 0 deletions camel/types/enums.py
@@ -30,6 +30,9 @@ class ModelType(Enum):
GPT_4_TURBO = "gpt-4-turbo"
GPT_4O = "gpt-4o"

GROQ_LLAMA_3_8_B = "llama3-8b-8192"
GROQ_LLAMA_3_70_B = "llama3-70b-8192"

STUB = "stub"

LLAMA_2 = "llama-2"
@@ -62,6 +65,14 @@ def is_openai(self) -> bool:
ModelType.GPT_4O,
}

@property
def is_groq(self) -> bool:
r"""Returns whether this type of models is served by Groq."""
return self in {
ModelType.GROQ_LLAMA_3_8_B,
ModelType.GROQ_LLAMA_3_70_B,
}

@property
def is_open_source(self) -> bool:
r"""Returns whether this type of models is open-source."""
@@ -103,6 +114,10 @@ def token_limit(self) -> int:
return 128000
elif self is ModelType.GPT_4O:
return 128000
elif self is ModelType.GROQ_LLAMA_3_8_B:
return 8192
elif self is ModelType.GROQ_LLAMA_3_70_B:
return 8192
elif self is ModelType.STUB:
return 4096
elif self is ModelType.LLAMA_2:
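A small sketch of how the new enum members and properties behave:

```python
from camel.types import ModelType

# Both Groq-served variants expose Llama3's 8192-token context window.
for mt in (ModelType.GROQ_LLAMA_3_8_B, ModelType.GROQ_LLAMA_3_70_B):
    assert mt.is_groq and not mt.is_openai
    assert mt.token_limit == 8192
    print(mt.value)  # "llama3-8b-8192", then "llama3-70b-8192"
```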
2 changes: 2 additions & 0 deletions camel/utils/__init__.py
@@ -28,6 +28,7 @@
from .token_counting import (
AnthropicTokenCounter,
BaseTokenCounter,
GroqLlama3TokenCounter,
OpenAITokenCounter,
OpenSourceTokenCounter,
get_model_encoding,
@@ -42,6 +43,7 @@
'get_task_list',
'check_server_running',
'AnthropicTokenCounter',
'GroqLlama3TokenCounter',
'get_system_information',
'to_pascal',
'PYDANTIC_V2',
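Since the Groq API exposes no tokenizer endpoint, counting happens locally; a sketch of the exported counter (per the commit history it wraps a Hugging Face AutoTokenizer, and `count_tokens_from_messages` is assumed from the BaseTokenCounter interface):

```python
from camel.types import ModelType
from camel.utils import GroqLlama3TokenCounter

counter = GroqLlama3TokenCounter(ModelType.GROQ_LLAMA_3_8_B)
n = counter.count_tokens_from_messages(
    [{"role": "user", "content": "How many tokens am I?"}]
)
print(n)  # count under the locally loaded Llama3 tokenizer
```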
4 changes: 4 additions & 0 deletions camel/utils/commons.py
@@ -54,6 +54,10 @@ def wrapper(self, *args, **kwargs):
if 'ANTHROPIC_API_KEY' not in os.environ:
raise ValueError('Anthropic API key not found.')
return func(self, *args, **kwargs)
elif self.model_type.is_groq:
if "GROQ_API_KEY" not in os.environ:
raise ValueError('Groq API key not found.')
return func(self, *args, **kwargs)
else:
raise ValueError('Unsupported model type.')

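A self-contained sketch of the dispatch above, with stubs standing in for a real backend (the decorator may consult additional model-type flags not visible in this hunk, so the stubbed flags are assumptions):

```python
import os

from camel.utils import api_key_required


class _StubType:
    # Flags the decorator is assumed to consult, per the hunk above.
    is_openai = False
    is_anthropic = False
    is_groq = True


class _StubBackend:
    model_type = _StubType()

    @api_key_required
    def run(self):
        return "ok"


os.environ.pop("GROQ_API_KEY", None)
try:
    _StubBackend().run()
except ValueError as err:
    print(err)  # Groq API key not found.
```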