Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Anthropic support #288

Merged
merged 26 commits into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a1b9b28
add Anthropic support #283
ocss884 Sep 14, 2023
b242cc0
correct lint and format
ocss884 Sep 14, 2023
201a4e2
refresh poetry.lock
ocss884 Sep 15, 2023
de6d6ee
Use union instead of "|" for py3.8
ocss884 Sep 15, 2023
c3265f8
Create pull.yml
ocss884 Oct 6, 2023
e735d9a
Merge branch 'camel-ai:master' into master
ocss884 Nov 10, 2023
c0e55ea
Merge branch 'camel-ai:master' into master
ocss884 Nov 18, 2023
35b3673
Merge branch 'camel-ai:master' into master
ocss884 Dec 5, 2023
808b2cf
update to "master"
ocss884 Dec 5, 2023
ff6f94c
apply #288, migrate to openai v1.0.0
ocss884 Dec 5, 2023
e597c05
update lock file and correct code format
ocss884 Dec 27, 2023
4fde44a
Merge branch 'master' into anthropic_support
ocss884 Dec 27, 2023
1f85339
bug fix in commons.py
ocss884 Jan 2, 2024
72bfdef
update test file and remove mistakenly tracked pull.yml
ocss884 Jan 3, 2024
faec388
Merge remote-tracking branch 'upstream/master'
ocss884 Mar 17, 2024
fcf4511
Merge branch 'master' into anthropic_support
ocss884 Mar 17, 2024
fe368e1
update lock file
ocss884 Mar 17, 2024
9dd1822
limit openai version<1.14.0
ocss884 Mar 17, 2024
a48a57e
add docstring and improve api_key checking logic
ocss884 Mar 24, 2024
2f5d17f
Merge branch 'master' of https://github.com/camel-ai/camel
ocss884 Apr 15, 2024
156402e
Merge branch 'master' into anthropic_support
ocss884 Apr 15, 2024
f202c7c
add Claude3 support & migrate to Message endpoint
ocss884 Apr 17, 2024
85bbb18
small fix
ocss884 Apr 17, 2024
1138868
Merge branch 'master' into anthropic_support
ocss884 Apr 17, 2024
425a531
Merge branch 'master' into anthropic_support
Wendong-Fan Apr 17, 2024
cdc6744
Merge branch 'master' into anthropic_support
Wendong-Fan Apr 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions camel/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Union

from anthropic._types import NOT_GIVEN, NotGiven

from camel.functions import OpenAIFunction


Expand Down Expand Up @@ -159,3 +161,55 @@ class OpenSourceConfig(BaseConfig):
param
for param in asdict(FunctionCallingConfig()).keys()
}


@dataclass(frozen=True)
class AnthropicConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Anthropic API.

See: https://docs.anthropic.com/claude/reference/complete_post
Args:
max_tokens_to_sample (int, optional): The maximum number of tokens to
generate before stopping. Note that Anthropic models may stop
before reaching this maximum. This parameter only specifies the
absolute maximum number of tokens to generate.
(default: :obj:`256`)
stop_sequences (List[str], optional): Sequences that will cause the
model to stop generating completion text. Anthropic models stop
on "\n\nHuman:", and may include additional built-in stop sequences
in the future. By providing the stop_sequences parameter, you may
include additional strings that will cause the model to stop
generating.
temperature (float, optional): Amount of randomness injected into the
response. Defaults to 1. Ranges from 0 to 1. Use temp closer to 0
for analytical / multiple choice, and closer to 1 for creative
and generative tasks.
(default: :obj:`1`)
top_p (float, optional): Use nucleus sampling. In nucleus sampling, we
compute the cumulative distribution over all the options for each
subsequent token in decreasing probability order and cut it off
once it reaches a particular probability specified by `top_p`.
You should either alter `temperature` or `top_p`,
but not both.
(default: :obj:`0.7`)
top_k (int, optional): Only sample from the top K options for each
subsequent token. Used to remove "long tail" low probability
responses.
(default: :obj:`5`)
metadata: An object describing metadata about the request.
stream (bool, optional): Whether to incrementally stream the response
using server-sent events.
(default: :obj:`False`)

"""
max_tokens_to_sample: int = 256
stop_sequences: Union[List[str], NotGiven] = NOT_GIVEN
temperature: float = 1
top_p: float = 0.7
top_k: int = 5
metadata: NotGiven = NOT_GIVEN
stream: bool = False


ANTHROPIC_API_PARAMS = {param for param in asdict(AnthropicConfig()).keys()}
4 changes: 2 additions & 2 deletions camel/embeddings/openai_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from camel.embeddings import BaseEmbedding
from camel.types import EmbeddingModelType
from camel.utils import openai_api_key_required
from camel.utils import api_key_required


class OpenAIEmbedding(BaseEmbedding[str]):
Expand All @@ -41,7 +41,7 @@ def __init__(
self.output_dim = model_type.output_dim
self.client = OpenAI()

@openai_api_key_required
@api_key_required
def embed_list(
self,
objs: List[str],
Expand Down
2 changes: 2 additions & 0 deletions camel/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
from .openai_model import OpenAIModel
from .stub_model import StubModel
from .open_source_model import OpenSourceModel
from .anthropic_model import AnthropicModel
from .model_factory import ModelFactory

__all__ = [
'BaseModelBackend',
'OpenAIModel',
'AnthropicModel',
'StubModel',
'OpenSourceModel',
'ModelFactory',
Expand Down
115 changes: 115 additions & 0 deletions camel/models/anthropic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import Any, Dict, List, Optional, Union

from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic
from anthropic.types import Completion
from openai import Stream

from camel.configs import ANTHROPIC_API_PARAMS
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
from camel.utils import AnthropicTokenCounter, BaseTokenCounter
from camel.utils.token_counting import messages_to_prompt


class AnthropicModel(BaseModelBackend):
r"""Anthropic API in a unified BaseModelBackend interface."""

def __init__(self, model_type: ModelType,
model_config_dict: Dict[str, Any]) -> None:

super().__init__(model_type, model_config_dict)

self.client = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
self._token_counter: Optional[BaseTokenCounter] = None

def _convert_openai_messages_to_anthropic_prompt(
self, messages: List[OpenAIMessage]):
return messages_to_prompt(messages, self.model_type)

def _convert_response_from_anthropic_to_openai(self, response: Completion):
# openai ^1.0.0 format, reference openai/types/chat/chat_completion.py
obj = ChatCompletion.construct(
id=None, choices=[
dict(
index=0, message={
"role": "assistant",
"content": response.completion
}, finish_reason=response.stop_reason)
], created=None, model=response.model, object="chat.completion")
return obj

@property
def token_counter(self) -> BaseTokenCounter:
if not self._token_counter:
self._token_counter = AnthropicTokenCounter(self.model_type)
return self._token_counter

def count_tokens_from_messages(self, messages: List[OpenAIMessage]):
prompt = self._convert_openai_messages_to_anthropic_prompt(messages)
return self.count_tokens_from_prompt(prompt)

def count_tokens_from_prompt(self, prompt):
return self.client.count_tokens(prompt)
Wendong-Fan marked this conversation as resolved.
Show resolved Hide resolved

def run(
self,
messages: List[OpenAIMessage],
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
r"""Run inference of Anthropic chat completion.

Args:
messages (List[Dict]): Message list with the chat history
in OpenAI API format.

Returns:
Dict[str, Any]: Response in the OpenAI API format.
"""

prompt = self._convert_openai_messages_to_anthropic_prompt(messages)
response = self.client.completions.create(
model=self.model_type.value,
prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
**self.model_config_dict)

# format response to openai format
response = self._convert_response_from_anthropic_to_openai(response)

return response

def check_model_config(self):
r"""Check whether the model configuration is valid for anthropic
model backends.

Raises:
ValueError: If the model configuration dictionary contains any
unexpected arguments to OpenAI API, or it does not contain
:obj:`model_path` or :obj:`server_url`.
"""
for param in self.model_config_dict:
if param not in ANTHROPIC_API_PARAMS:
raise ValueError(f"Unexpected argument `{param}` is "
"input into Anthropic model backend.")

@property
def stream(self) -> bool:
r"""Returns whether the model is in stream mode,
which sends partial results each time.
Returns:
bool: Whether the model is in stream mode.
"""
return self.model_config_dict.get("stream", False)
3 changes: 3 additions & 0 deletions camel/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Dict

from camel.models import (
AnthropicModel,
BaseModelBackend,
OpenAIModel,
OpenSourceModel,
Expand Down Expand Up @@ -52,6 +53,8 @@ def create(model_type: ModelType,
model_class = StubModel
elif model_type.is_open_source:
model_class = OpenSourceModel
elif model_type.is_anthropic:
model_class = AnthropicModel
else:
raise ValueError(f"Unknown model type `{model_type}` is input")

Expand Down
8 changes: 2 additions & 6 deletions camel/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
from camel.utils import (
BaseTokenCounter,
OpenAITokenCounter,
openai_api_key_required,
)
from camel.utils import BaseTokenCounter, OpenAITokenCounter, api_key_required


class OpenAIModel(BaseModelBackend):
Expand Down Expand Up @@ -57,7 +53,7 @@ def token_counter(self) -> BaseTokenCounter:
self._token_counter = OpenAITokenCounter(self.model_type)
return self._token_counter

@openai_api_key_required
@api_key_required
def run(
self,
messages: List[OpenAIMessage],
Expand Down
17 changes: 17 additions & 0 deletions camel/types/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class ModelType(Enum):
VICUNA = "vicuna"
VICUNA_16K = "vicuna-16k"

CLAUDE_2 = "claude-2"
CLAUDE_INSTANT = "claude-instant-1"

@property
def value_for_tiktoken(self) -> str:
return self.value if self is not ModelType.STUB else "gpt-3.5-turbo"
Expand All @@ -62,6 +65,18 @@ def is_open_source(self) -> bool:
ModelType.VICUNA_16K,
}

@property
def is_anthropic(self) -> bool:
r"""Returns whether this type of models is Anthropic-released model.

Returns:
bool: Whether this type of models is anthropic.
"""
if self.name in {"CLAUDE_2", "CLAUDE_INSTANT"}:
return True
else:
return False

@property
def token_limit(self) -> int:
r"""Returns the maximum token limit for a given model.
Expand Down Expand Up @@ -89,6 +104,8 @@ def token_limit(self) -> int:
return 2048
elif self is ModelType.VICUNA_16K:
return 16384
elif self in {ModelType.CLAUDE_2, ModelType.CLAUDE_INSTANT}:
return 100_000
else:
raise ValueError("Unknown model type")

Expand Down
14 changes: 6 additions & 8 deletions camel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .commons import (
openai_api_key_required,
api_key_required,
print_text_animated,
get_prompt_template_key_words,
get_first_int,
Expand All @@ -23,21 +23,19 @@
to_pascal,
PYDANTIC_V2,
)
from .token_counting import (
get_model_encoding,
BaseTokenCounter,
OpenAITokenCounter,
OpenSourceTokenCounter,
)
from .token_counting import (get_model_encoding, BaseTokenCounter,
OpenAITokenCounter, OpenSourceTokenCounter,
AnthropicTokenCounter)

__all__ = [
'openai_api_key_required',
'api_key_required',
'print_text_animated',
'get_prompt_template_key_words',
'get_first_int',
'download_tasks',
'get_task_list',
'check_server_running',
'AnthropicTokenCounter',
'get_system_information',
'to_pascal',
'PYDANTIC_V2',
Expand Down
13 changes: 10 additions & 3 deletions camel/utils/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
F = TypeVar('F', bound=Callable[..., Any])


def openai_api_key_required(func: F) -> F:
def api_key_required(func: F) -> F:
r"""Decorator that checks if the OpenAI API key is available in the
environment variables.

Expand All @@ -46,8 +46,15 @@ def openai_api_key_required(func: F) -> F:

@wraps(func)
def wrapper(self, *args, **kwargs):
if 'OPENAI_API_KEY' in os.environ:
return func(self, *args, **kwargs)
print(self)
if self.model_type.is_openai:
if 'OPENAI_API_KEY' in os.environ:
return func(self, *args, **kwargs)
elif self.model_type.is_anthropic:
if 'ANTHROPIC_API_KEY' in os.environ:
return func(self, *args, **kwargs)
else:
raise ValueError('Anthropic API key not found.')
else:
raise ValueError('OpenAI API key not found.')
Wendong-Fan marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
Loading
Loading