
feat: integrate the llama3 (8B, 70B), Mistral.AI, Gemma (7B, 9B) served by Groq #531

Merged · 45 commits · Jul 24, 2024
Commits
85eaced
feat: add llama3 provided by groq
Appointat Apr 24, 2024
4fef335
feat: add llama3 provided by groq
Appointat Apr 24, 2024
2d78038
Merge remote-tracking branch 'origin' into feature/groq-llama3
Appointat Apr 24, 2024
2590f5c
refactor code to use api_key_required decorator consistently
Appointat Apr 24, 2024
85a55f2
update poetry toml
Appointat Apr 24, 2024
4f307ad
Merge remote-tracking branch 'origin' into feature/groq-llama3
Appointat Apr 27, 2024
f2e57f0
feat: add model configs, fix decorator of step function
Appointat Apr 28, 2024
cff5e04
Merge branch 'master' into feature/groq-llama3
Appointat Apr 28, 2024
6996309
add GROQ_API_KEY to yml
Wendong-Fan Apr 29, 2024
136d0b7
Merge branch 'master' into feature/groq-llama3
Appointat May 5, 2024
5fb90a7
update poetry
Appointat May 5, 2024
fec77a6
remove type ignore
Appointat May 5, 2024
7b0f0b0
reformated by ruff
Appointat May 5, 2024
793a688
use AutoTokenizer as tokennizer
Appointat May 5, 2024
eca647a
fix: update token_counter import in test_groq_llama3_model.py
Appointat May 5, 2024
431f0f0
Merge branch 'master' into feature/groq-llama3
Appointat May 6, 2024
936d6e4
fix: remove GroqModel from camel/models/__init__.py
Appointat May 13, 2024
0dd1a30
fix: remove GroqModel from camel/models/__init__.py
Appointat Jun 2, 2024
a8f96ab
refactor: improve error handling in GroqModel check_model_config method
Appointat Jun 2, 2024
4c8b070
refactor: convert response from OpenAI format to GroqModel
Appointat Jun 2, 2024
6e59352
Merge branch 'master' into feature/groq-llama3
Appointat Jun 2, 2024
71d79ce
refactor: Add API key parameter to GroqModel constructor
Appointat Jun 2, 2024
6240aa2
feat: Update pyproject.toml to include outlines dependency
Appointat Jul 3, 2024
52e455a
feat: Update pyproject.toml to include outlines dependency
Appointat Jul 5, 2024
8c4d086
feat: Add groq-llama3 as a model platform type
Appointat Jul 5, 2024
39dbf03
Update version to 0.1.5.6
Appointat Jul 8, 2024
2a85e23
Update dependencies to latest versions
Appointat Jul 8, 2024
f699a37
feat: Update GroqLlama3TokenCounter placeholder in test_groq_llama3_m…
Appointat Jul 8, 2024
6d84987
chore: Remove unused role playing code
Appointat Jul 8, 2024
7d686fb
chore: Update GroqLlama3TokenCounter placeholder in token_counting.py
Appointat Jul 8, 2024
e748c66
Merge branch 'master' into feature/groq-llama3
Wendong-Fan Jul 14, 2024
1baf7d2
chore: Add stop parameter to GroqLLAMA3Config
Appointat Jul 15, 2024
9a88dcf
refactor: Update GroqLlama3TokenCounter placeholder in token_counting.py
Appointat Jul 15, 2024
f88bbb3
refactor: Improve printing of AI User and AI Assistant messages
Appointat Jul 15, 2024
23fd16b
refactor: Update token_counting.py to use OpenAI API format for messages
Appointat Jul 15, 2024
98e4012
refactor: Rename GroqLLAMA3Config to GroqConfig and update related im…
Appointat Jul 15, 2024
a57126d
refactor: Update GroqLlama3TokenCounter placeholder in token_counting.py
Appointat Jul 15, 2024
9d7948b
refactor: Add support for GroqMixtral8x7b and GroqGemma7bIt models
Appointat Jul 15, 2024
30ef7e7
refactor: Update token_counting.py to use OpenAI API format for messages
Appointat Jul 15, 2024
1d7a862
refactor: Update GroqModel to support multiple Groq models and improv…
Appointat Jul 15, 2024
21132ac
refactor: Update GroqModel to use OpenAI API for multiple models and …
Appointat Jul 23, 2024
0fff2f1
add HF Token
Wendong-Fan Jul 23, 2024
5b0b6ce
Merge branch 'master' into feature/groq-llama3
Wendong-Fan Jul 23, 2024
407b44e
update WD
Wendong-Fan Jul 23, 2024
2340fdf
docstring fix
Wendong-Fan Jul 23, 2024
1 change: 1 addition & 0 deletions .github/workflows/build_package.yml
Original file line number Diff line number Diff line change
@@ -34,4 +34,5 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
GROQ_API_KEY: "${{ secrets.GROQ_API_KEY }}"
run: pytest --fast-test-mode ./test
3 changes: 3 additions & 0 deletions .github/workflows/pytest_package.yml
@@ -28,6 +28,7 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
GROQ_API_KEY: "${{ secrets.GROQ_API_KEY }}"
run: poetry run pytest --fast-test-mode test/

pytest_package_llm_test:
@@ -45,6 +46,7 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
GROQ_API_KEY: "${{ secrets.GROQ_API_KEY }}"
run: poetry run pytest --llm-test-only test/

pytest_package_very_slow_test:
@@ -62,4 +64,5 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
GROQ_API_KEY: "${{ secrets.GROQ_API_KEY }}"
run: poetry run pytest --very-slow-test-only test/
9 changes: 7 additions & 2 deletions camel/agents/chat_agent.py
@@ -19,7 +19,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

from camel.agents import BaseAgent
from camel.configs import ChatGPTConfig, ChatGPTVisionConfig
from camel.configs import ChatGPTConfig, ChatGPTVisionConfig, GroqLLAMA3Config
from camel.memories import (
AgentMemory,
ChatHistoryMemory,
@@ -143,7 +143,12 @@ def __init__(
raise ValueError("Please don't use `ChatGPTVisionConfig` as "
"the `model_config` when `model_type` "
"is not `GPT_4_TURBO_VISION`")
self.model_config = model_config or ChatGPTConfig()
if self.model_type.is_groq:
# Groq models take a configuration different from OpenAI
# models, so branch on the model platform here.
self.model_config = model_config or GroqLLAMA3Config()
else:
self.model_config = model_config or ChatGPTConfig()

self.model_backend: BaseModelBackend = ModelFactory.create(
self.model_type, self.model_config.__dict__)
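The `chat_agent.py` change picks a platform-appropriate default config when none is supplied. A minimal stand-alone sketch of that selection logic (the dataclasses here are simplified stand-ins, not the real CAMEL classes):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ChatGPTConfig:
    # Simplified stand-in for camel.configs.ChatGPTConfig.
    temperature: float = 0.2


@dataclass(frozen=True)
class GroqLLAMA3Config:
    # Simplified stand-in for camel.configs.GroqLLAMA3Config.
    max_tokens: int = 4096
    temperature: float = 1.0
    stream: bool = False


def pick_model_config(is_groq: bool, model_config=None):
    """Mirror the branch in ChatAgent.__init__: fall back to the
    platform-appropriate default config when none is supplied."""
    if is_groq:
        return model_config or GroqLLAMA3Config()
    return model_config or ChatGPTConfig()
```

An explicitly passed config always wins; the `or` only fires on `None`.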
54 changes: 48 additions & 6 deletions camel/configs.py
@@ -15,7 +15,16 @@

from abc import ABC
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)

from anthropic._types import NOT_GIVEN, NotGiven

@@ -222,10 +231,10 @@ class AnthropicConfig(BaseConfig):

See: https://docs.anthropic.com/claude/reference/complete_post
Args:
max_tokens_to_sample (int, optional): The maximum number of tokens to
generate before stopping. Note that Anthropic models may stop
before reaching this maximum. This parameter only specifies the
absolute maximum number of tokens to generate.
max_tokens (int, optional): The maximum number of tokens to generate
before stopping. Note that Anthropic models may stop before
reaching this maximum. This parameter only specifies the absolute
maximum number of tokens to generate.
(default: :obj:`256`)
stop_sequences (List[str], optional): Sequences that will cause the
model to stop generating completion text. Anthropic models stop
@@ -253,7 +262,6 @@ class AnthropicConfig(BaseConfig):
stream (bool, optional): Whether to incrementally stream the response
using server-sent events.
(default: :obj:`False`)

"""
max_tokens: int = 256
stop_sequences: Union[List[str], NotGiven] = NOT_GIVEN
@@ -265,3 +273,37 @@ class AnthropicConfig(BaseConfig):


ANTHROPIC_API_PARAMS = {param for param in asdict(AnthropicConfig()).keys()}


@dataclass(frozen=True)
class GroqLLAMA3Config(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Groq API. Note that CAMEL does not support stream mode for Groq Llama 3.

See: https://console.groq.com/docs/text-chat
Args:
max_tokens (int, optional): The maximum number of tokens to generate
before stopping. Note that Groq-served models may stop before
reaching this maximum. This parameter only specifies the absolute
maximum number of tokens to generate.
(default: :obj:`256`)
temperature (float, optional): Amount of randomness injected into the
response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
for analytical / multiple choice, and closer to 1 for creative
and generative tasks.
(default: :obj:`1`)
stream (bool, optional): Whether to incrementally stream the response
using server-sent events. Camel does not support stream mode for
Groq Llama3.
(default: :obj:`False`)
"""

max_tokens: int = 4096  # Llama 3 has a context window of 8192 tokens,
# so the default completion budget is set to half of that.
temperature: float = 1  # CAMEL does not suggest modifying `top_p`.
# CAMEL does not support stream mode for Groq Llama 3, so the default
# value of `stream` is False.
stream: Literal[False] = False


GROQ_LLAMA3_API_PARAMS = {param for param in asdict(GroqLLAMA3Config()).keys()}
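The pattern above — a frozen dataclass whose fields double as the set of accepted API keyword arguments — can be sketched on its own as follows (simplified from the diff; the `BaseConfig` base class is omitted):

```python
from dataclasses import asdict, dataclass
from typing import Literal


@dataclass(frozen=True)
class GroqLLAMA3Config:
    # Llama 3 served by Groq has an 8192-token context window, so half
    # of it is a conservative default budget for the completion.
    max_tokens: int = 4096
    temperature: float = 1.0
    # Stream mode is unsupported, so the field is pinned to False.
    stream: Literal[False] = False


# The set of accepted keyword arguments is derived from the dataclass
# fields, exactly as the diff does with asdict().
GROQ_LLAMA3_API_PARAMS = set(asdict(GroqLLAMA3Config()).keys())
```

Deriving the parameter set from the dataclass keeps the config class and the validation whitelist from drifting apart.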
2 changes: 2 additions & 0 deletions camel/models/__init__.py
@@ -12,6 +12,7 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .base_model import BaseModelBackend
from .groq_model import GroqModel
from .openai_model import OpenAIModel
from .stub_model import StubModel
from .open_source_model import OpenSourceModel
@@ -20,6 +21,7 @@

__all__ = [
'BaseModelBackend',
'GroqModel',
'OpenAIModel',
'AnthropicModel',
'StubModel',
7 changes: 6 additions & 1 deletion camel/models/anthropic_model.py
@@ -20,7 +20,11 @@
from camel.configs import ANTHROPIC_API_PARAMS
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ModelType
from camel.utils import AnthropicTokenCounter, BaseTokenCounter
from camel.utils import (
AnthropicTokenCounter,
BaseTokenCounter,
api_key_required,
)


class AnthropicModel(BaseModelBackend):
@@ -67,6 +71,7 @@ def count_tokens_from_prompt(self, prompt: str) -> int:
"""
return self.client.count_tokens(prompt)

@api_key_required
def run(
self,
messages,
122 changes: 122 additions & 0 deletions camel/models/groq_model.py
@@ -0,0 +1,122 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import Any, Dict, List, Optional

from groq import Groq

from camel.configs import GROQ_LLAMA3_API_PARAMS
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import (
ChatCompletion,
ChatCompletionMessage,
Choice,
CompletionUsage,
ModelType,
)
from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_key_required


class GroqModel(BaseModelBackend):
r"""LLM API served by Groq in a unified BaseModelBackend interface."""

def __init__(self, model_type: ModelType,
model_config_dict: Dict[str, Any]) -> None:
r"""Constructor for Groq backend.

Args:
model_type (ModelType): Model for which a backend is created.
model_config_dict (Dict[str, Any]): A dictionary that will
be fed into groq.ChatCompletion.create().
"""
super().__init__(model_type, model_config_dict)
url = os.environ.get('GROQ_API_BASE_URL', None)
api_key = os.environ.get('GROQ_API_KEY')
self._client = Groq(api_key=api_key, base_url=url)
self._token_counter: Optional[BaseTokenCounter] = None

@property
def token_counter(self) -> BaseTokenCounter:
r"""Initialize the token counter for the model backend. Since the Groq
API does not provide a token counter, a placeholder is used instead.

Returns:
BaseTokenCounter: The token counter following the model's
tokenization style.
"""
if not self._token_counter:
# Groq API does not provide any token counter, so we use the
# OpenAI token counter as a placeholder.
self._token_counter = OpenAITokenCounter(ModelType.GPT_3_5_TURBO)
return self._token_counter

@api_key_required
def run(
self,
messages: List[OpenAIMessage],
) -> ChatCompletion: # type: ignore
r"""Runs inference of chat completion using the Groq API.

Args:
messages (List[OpenAIMessage]): Message list with the chat history
in OpenAI API format.

Returns:
ChatCompletion: `ChatCompletion` in the non-stream mode, while the
stream mode is not supported by Groq.
"""
# The data structure defined in the Groq client differs slightly from
# the one defined in CAMEL, so the response is converted here. The
# `type: ignore` comments silence the spurious type errors raised by
# mypy for this conversion.
response = self._client.chat.completions.create(
messages=messages, # type: ignore
model=self.model_type.value,
**self.model_config_dict,
)

_choices: List[Choice] = [] # type: ignore
for choice in response.choices:
choice.message = ChatCompletionMessage(
role=choice.message.role, # type: ignore
content=choice.message.content,
tool_calls=choice.message.tool_calls) # type: ignore
_choice = Choice(**choice.__dict__) # type: ignore
_choices.append(_choice)
response.choices = _choices # type: ignore

response.usage = CompletionUsage(
**response.usage.__dict__) # type: ignore
return ChatCompletion(**response.__dict__)

def check_model_config(self):
r"""Check whether the model configuration contains any unexpected
arguments to the Groq API.

Raises:
ValueError: If the model configuration dictionary contains any
unexpected arguments to the Groq API.
"""
for param in self.model_config_dict:
if param not in GROQ_LLAMA3_API_PARAMS:
raise ValueError(f"Unexpected argument `{param}` is "
"input into Groq Llama3 model backend.")

@property
def stream(self) -> bool:
r"""Returns whether the model supports streaming. The Groq API does
not support streaming, so this always returns `False`.
"""
return False
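The core of `run` above is a field-by-field re-wrap of the Groq client's response into CAMEL's OpenAI-style types. A dependency-free sketch of that re-wrapping, using `SimpleNamespace` stand-ins for the Groq objects and plain dicts on the output side (the field names mirror the diff, but the stand-in types are hypothetical):

```python
from types import SimpleNamespace


def convert_groq_response(response):
    """Re-wrap a Groq-style response object into an OpenAI-style dict,
    mirroring the per-choice copy done in GroqModel.run."""
    choices = []
    for choice in response.choices:
        choices.append({
            "index": choice.index,
            "message": {
                "role": choice.message.role,
                "content": choice.message.content,
                "tool_calls": choice.message.tool_calls,
            },
            "finish_reason": choice.finish_reason,
        })
    return {
        "id": response.id,
        "choices": choices,
        "usage": vars(response.usage),
    }


# A fake Groq response built from namespaces, for illustration only.
fake = SimpleNamespace(
    id="chatcmpl-0",
    choices=[SimpleNamespace(
        index=0,
        finish_reason="stop",
        message=SimpleNamespace(role="assistant", content="hi",
                                tool_calls=None),
    )],
    usage=SimpleNamespace(prompt_tokens=3, completion_tokens=1,
                          total_tokens=4),
)
converted = convert_groq_response(fake)
```

The real code builds `ChatCompletionMessage`, `Choice`, and `CompletionUsage` objects instead of dicts, but the shape of the traversal is the same.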
3 changes: 3 additions & 0 deletions camel/models/model_factory.py
@@ -16,6 +16,7 @@
from camel.models import (
AnthropicModel,
BaseModelBackend,
GroqModel,
OpenAIModel,
OpenSourceModel,
StubModel,
@@ -49,6 +50,8 @@ def create(model_type: ModelType,
model_class: Any
if model_type.is_openai:
model_class = OpenAIModel
elif model_type.is_groq:
model_class = GroqModel
elif model_type == ModelType.STUB:
model_class = StubModel
elif model_type.is_open_source:
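`ModelFactory.create` dispatches on the platform flags of the model type. A minimal sketch of that dispatch with stub backend classes (stand-ins, not the real CAMEL backends):

```python
from types import SimpleNamespace


class OpenAIModel: ...
class GroqModel: ...
class StubModel: ...


def pick_backend(model_type):
    """Mirror the dispatch in ModelFactory.create: the platform flags
    on the model type select the backend class."""
    if model_type.is_openai:
        return OpenAIModel
    if model_type.is_groq:
        return GroqModel
    return StubModel


# A fake model type carrying just the two platform flags.
groq_type = SimpleNamespace(is_openai=False, is_groq=True)
```

Adding a platform therefore touches two places: a new flag property on `ModelType` and one `elif` branch here.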
15 changes: 15 additions & 0 deletions camel/types/enums.py
@@ -30,6 +30,9 @@ class ModelType(Enum):
GPT_4_TURBO = "gpt-4-turbo"
GPT_4_TURBO_VISION = "gpt-4-turbo"

GROQ_LLAMA_3_8_B = "llama3-8b-8192"
GROQ_LLAMA_3_70_B = "llama3-70b-8192"

STUB = "stub"

LLAMA_2 = "llama-2"
@@ -62,6 +65,14 @@ def is_openai(self) -> bool:
ModelType.GPT_4_TURBO_VISION,
}

@property
def is_groq(self) -> bool:
r"""Returns whether this type of models is served by Groq."""
return self in {
ModelType.GROQ_LLAMA_3_8_B,
ModelType.GROQ_LLAMA_3_70_B,
}

@property
def is_open_source(self) -> bool:
r"""Returns whether this type of models is open-source."""
@@ -103,6 +114,10 @@ def token_limit(self) -> int:
return 128000
elif self is ModelType.GPT_4_TURBO_VISION:
return 128000
elif self is ModelType.GROQ_LLAMA_3_8_B:
return 8192
elif self is ModelType.GROQ_LLAMA_3_70_B:
return 8192
elif self is ModelType.STUB:
return 4096
elif self is ModelType.LLAMA_2:
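The enum additions — model-name values, an `is_groq` membership check, and per-model token limits — reduce to this self-contained sketch (GPT-4's entry and limit are included only for contrast, not taken from this diff):

```python
from enum import Enum


class ModelType(Enum):
    GPT_4 = "gpt-4"
    GROQ_LLAMA_3_8_B = "llama3-8b-8192"
    GROQ_LLAMA_3_70_B = "llama3-70b-8192"

    @property
    def is_groq(self) -> bool:
        # Membership in a fixed set, as in the diff.
        return self in {
            ModelType.GROQ_LLAMA_3_8_B,
            ModelType.GROQ_LLAMA_3_70_B,
        }

    @property
    def token_limit(self) -> int:
        # Both Groq-served Llama 3 variants expose an 8192-token
        # context window, matching the "-8192" suffix in their names.
        if self.is_groq:
            return 8192
        return 8192  # GPT-4's limit, hard-coded for this sketch
```

Because the enum values are the exact strings Groq's API expects, `self.model_type.value` can be passed straight through as the `model` argument.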
8 changes: 6 additions & 2 deletions camel/utils/commons.py
@@ -47,7 +47,7 @@ def get_lazy_imported_types_module():


def api_key_required(func: F) -> F:
r"""Decorator that checks if the OpenAI API key is available in the
r"""Decorator that checks if the LLM API key is available in the
environment variables.

Args:
@@ -57,7 +57,7 @@ def api_key_required(func: F) -> F:
callable: The decorated function.

Raises:
ValueError: If the OpenAI API key is not found in the environment
ValueError: If the LLM API key is not found in the environment
variables.
"""

@@ -71,6 +71,10 @@ def wrapper(self, *args, **kwargs):
if 'ANTHROPIC_API_KEY' not in os.environ:
raise ValueError('Anthropic API key not found.')
return func(self, *args, **kwargs)
elif self.model_type.is_groq:
if "GROQ_API_KEY" not in os.environ:
raise ValueError('Groq API key not found.')
return func(self, *args, **kwargs)
else:
raise ValueError('Unsupported model type.')

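The generalized `api_key_required` decorator can be sketched in isolation; here `FakeBackend` is a hypothetical stand-in exposing just the `is_groq` flag the wrapper inspects (the real decorator also handles the OpenAI and Anthropic branches):

```python
import functools
import os


def api_key_required(func):
    """Sketch of the decorator: fail fast when the platform's API key
    is missing from the environment (Groq branch only, simplified)."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.is_groq:
            if "GROQ_API_KEY" not in os.environ:
                raise ValueError("Groq API key not found.")
            return func(self, *args, **kwargs)
        raise ValueError("Unsupported model type.")
    return wrapper


class FakeBackend:
    is_groq = True

    @api_key_required
    def run(self):
        return "ok"
```

Checking the environment at call time (rather than at import time) is what lets tests inject the key via CI secrets, as the workflow changes above do.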