Feat add Amazon Bedrock support #1231

Merged on May 17, 2024 (35 commits)

Commits
fd4c0fd
implement framework
usamimeri Apr 25, 2024
4f14ee7
implement base provider
usamimeri Apr 25, 2024
ec7df8a
support mistral
usamimeri Apr 25, 2024
4d1fb20
add generate_kwargs
usamimeri Apr 25, 2024
a741488
add stream
usamimeri Apr 25, 2024
9775a2b
implement meta
usamimeri Apr 25, 2024
187e9ef
support anthropic
usamimeri Apr 25, 2024
6355c5c
implement all model
usamimeri Apr 25, 2024
a6058ca
change provider to private
usamimeri Apr 25, 2024
f45a379
add titan
usamimeri Apr 25, 2024
aded1dc
add some type hint
usamimeri Apr 25, 2024
6452adf
update provider package
usamimeri Apr 25, 2024
784c626
update max tokens and support max_tokens field
usamimeri Apr 25, 2024
0cdca1b
fix llama3 chat template bug
usamimeri Apr 25, 2024
e9723f4
add claude chat template
usamimeri Apr 25, 2024
6561c7a
add docs
usamimeri Apr 25, 2024
cafe666
lazy installation
usamimeri Apr 25, 2024
4c394a1
add test
usamimeri Apr 26, 2024
8fafa2e
stream test
usamimeri Apr 26, 2024
83c8ccb
update
usamimeri Apr 26, 2024
b0ed292
update test
usamimeri Apr 26, 2024
4c77d6c
Merge branch 'geekan:main' into main
usamimeri Apr 26, 2024
6776e5c
Update amazon_bedrock_api.py
usamimeri Apr 26, 2024
3411e7d
Merge branch 'geekan:main' into main
usamimeri Apr 27, 2024
a05d257
fix claude streaming bug
usamimeri Apr 27, 2024
98cb452
remove opus since unavailable now and fix bug
usamimeri Apr 27, 2024
79251cd
resolve problems
usamimeri Apr 28, 2024
cb8e14d
add token counts
usamimeri Apr 28, 2024
aa8e123
delete cache
usamimeri Apr 28, 2024
986e3c8
fix
usamimeri Apr 29, 2024
3f108ab
rename bedrock class and add more tests
usamimeri Apr 29, 2024
f14a1f6
add pre-commit
usamimeri Apr 29, 2024
0006b62
resolve problem and add cost manager
usamimeri Apr 29, 2024
f7b29ed
change log for non-streaming model
usamimeri Apr 29, 2024
0916399
Merge branch 'geekan:main' into main
usamimeri May 16, 2024
8 changes: 3 additions & 5 deletions config/puppeteer-config.json
@@ -1,6 +1,4 @@
{
"executablePath": "/usr/bin/chromium",
"args": [
"--no-sandbox"
]
}
"executablePath": "/usr/bin/chromium",
"args": ["--no-sandbox"]
}
7 changes: 6 additions & 1 deletion metagpt/configs/llm_config.py
@@ -32,6 +32,7 @@ class LLMType(Enum):
MISTRAL = "mistral"
YI = "yi" # lingyiwanwu
OPENROUTER = "openrouter"
BEDROCK = "bedrock"

def __missing__(self, key):
return self.OPENAI
@@ -74,10 +75,14 @@ class LLMConfig(YamlModel):
best_of: Optional[int] = None
n: Optional[int] = None
stream: bool = False
logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs
# https://cookbook.openai.com/examples/using_logprobs
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
timeout: int = 600

# For Amazon Bedrock
region_name: str = None

# For Network
proxy: Optional[str] = None

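Not part of this PR's diff: a minimal sketch of how the new BEDROCK type and region_name field could be filled in. The model id, region, and placeholder key below are illustrative assumptions; Bedrock itself is normally authenticated through the standard AWS credential chain rather than an API key.

# Illustrative sketch only, not code from this PR.
from metagpt.configs.llm_config import LLMConfig, LLMType

bedrock_config = LLMConfig(
    api_type=LLMType.BEDROCK,
    model="anthropic.claude-3-sonnet-20240229-v1:0",  # example Bedrock model id
    region_name="us-east-1",                          # new field added by this PR
    api_key="placeholder",                            # assumed unused; AWS credentials come from the environment
)
assert bedrock_config.api_type == LLMType.BEDROCK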
2 changes: 2 additions & 0 deletions metagpt/provider/__init__.py
@@ -17,6 +17,7 @@
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock_api import BedrockLLM

__all__ = [
"GeminiLLM",
@@ -30,4 +31,5 @@
"QianFanLLM",
"DashScopeLLM",
"AnthropicLLM",
"BedrockLLM",
]
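Also not part of the diff: with the export above, the provider becomes importable from the package root. The constructor call below is an assumption based on the usual MetaGPT provider pattern of passing an LLMConfig; the class itself comes from metagpt/provider/bedrock_api.py.

# Hypothetical usage; the BedrockLLM(config) signature is an assumption.
from metagpt.provider import BedrockLLM

llm = BedrockLLM(bedrock_config)  # bedrock_config from the sketch above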
Empty file.
28 changes: 28 additions & 0 deletions metagpt/provider/bedrock/base_provider.py
@@ -0,0 +1,28 @@
import json
from abc import ABC, abstractmethod


class BaseBedrockProvider(ABC):
# to handle different generation kwargs
max_tokens_field_name = "max_tokens"

@abstractmethod
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
...

def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str:
body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
return body

def get_choice_text(self, response_body: dict) -> str:
completions = self._get_completion_from_dict(response_body)
return completions

def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = self._get_completion_from_dict(rsp_dict)
return completions

def messages_to_prompt(self, messages: list[dict]) -> str:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
121 changes: 121 additions & 0 deletions metagpt/provider/bedrock/bedrock_provider.py
@@ -0,0 +1,121 @@
import json
from typing import Literal

from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
from metagpt.provider.bedrock.utils import (
messages_to_prompt_llama2,
messages_to_prompt_llama3,
)


class MistralProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html

def messages_to_prompt(self, messages: list[dict]):
return messages_to_prompt_llama2(messages)

def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["outputs"][0]["text"]


class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html

def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
return body

def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["content"][0]["text"]

def get_choice_text_from_stream(self, event) -> str:
# https://docs.anthropic.com/claude/reference/messages-streaming
rsp_dict = json.loads(event["chunk"]["bytes"])
if rsp_dict["type"] == "content_block_delta":
completions = rsp_dict["delta"]["text"]
return completions
else:
return ""


class CohereProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html

def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generations"][0]["text"]

def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps(
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
)
return body

def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict.get("text", "")
return completions


class MetaProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html

max_tokens_field_name = "max_gen_len"

def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None:
self.llama_version = llama_version

def messages_to_prompt(self, messages: list[dict]):
if self.llama_version == "llama2":
return messages_to_prompt_llama2(messages)
else:
return messages_to_prompt_llama3(messages)

def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generation"]


class Ai21Provider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html

max_tokens_field_name = "maxTokens"

def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["completions"][0]["data"]["text"]


class AmazonProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html

max_tokens_field_name = "maxTokenCount"

def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps({"inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs})
return body

def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["results"][0]["outputText"]

def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict["outputText"]
return completions


PROVIDERS = {
"mistral": MistralProvider,
"meta": MetaProvider,
"ai21": Ai21Provider,
"cohere": CohereProvider,
"anthropic": AnthropicProvider,
"amazon": AmazonProvider,
}


def get_provider(model_id: str):
provider, model_name = model_id.split(".")[0:2]  # e.g. "meta", "mistral", ...
if provider not in PROVIDERS:
raise KeyError(f"{provider} is not supported!")
if provider == "meta":
# distinguish llama2 and llama3
return PROVIDERS[provider](model_name[:6])
return PROVIDERS[provider]()
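For reference (not part of the diff), how get_provider resolves a model id; the ids below appear in the model tables in utils.py further down.

# Illustrative usage of get_provider.
from metagpt.provider.bedrock.bedrock_provider import get_provider

claude = get_provider("anthropic.claude-3-haiku-20240307-v1:0")  # -> AnthropicProvider()
llama = get_provider("meta.llama3-8b-instruct-v1:0")             # -> MetaProvider("llama3")
print(type(claude).__name__, llama.llama_version)                # AnthropicProvider llama3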
112 changes: 112 additions & 0 deletions metagpt/provider/bedrock/utils.py
@@ -0,0 +1,112 @@
from metagpt.logs import logger

# max_tokens for each model
NOT_SUUPORT_STREAM_MODELS = {
"ai21.j2-grande-instruct": 8000,
"ai21.j2-jumbo-instruct": 8000,
"ai21.j2-mid": 8000,
"ai21.j2-mid-v1": 8000,
"ai21.j2-ultra": 8000,
"ai21.j2-ultra-v1": 8000,
}

SUPPORT_STREAM_MODELS = {
"amazon.titan-tg1-large": 8000,
"amazon.titan-text-express-v1": 8000,
"amazon.titan-text-express-v1:0:8k": 8000,
"amazon.titan-text-lite-v1:0:4k": 4000,
"amazon.titan-text-lite-v1": 4000,
"anthropic.claude-instant-v1": 100000,
"anthropic.claude-instant-v1:2:100k": 100000,
"anthropic.claude-v1": 100000,
"anthropic.claude-v2": 100000,
"anthropic.claude-v2:1": 200000,
"anthropic.claude-v2:0:18k": 18000,
"anthropic.claude-v2:1:200k": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
"anthropic.claude-3-opus-20240229-v1:0": 200000,
"cohere.command-text-v14": 4000,
"cohere.command-text-v14:7:4k": 4000,
"cohere.command-light-text-v14": 4000,
"cohere.command-light-text-v14:7:4k": 4000,
"meta.llama2-13b-chat-v1:0:4k": 4000,
"meta.llama2-13b-chat-v1": 2000,
"meta.llama2-70b-v1": 4000,
"meta.llama2-70b-v1:0:4k": 4000,
"meta.llama2-70b-chat-v1": 4000,
"meta.llama2-70b-chat-v1:0:4k": 4000,
"meta.llama3-8b-instruct-v1:0": 2000,
"meta.llama3-70b-instruct-v1:0": 2000,
"mistral.mistral-7b-instruct-v0:2": 32000,
"mistral.mixtral-8x7b-instruct-v0:1": 32000,
"mistral.mistral-large-2402-v1:0": 32000,
}

# TODO: use a more general function for constructing chat templates.


def messages_to_prompt_llama2(messages: list[dict]) -> str:
BOS = ("<s>",)
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

prompt = f"{BOS}"
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
if role == "system":
prompt += f"{B_SYS} {content} {E_SYS}"
elif role == "user":
prompt += f"{B_INST} {content} {E_INST}"
elif role == "assistant":
prompt += f"{content}"
else:
logger.warning(f"Unknown role name {role} when formatting messages")
prompt += f"{content}"

return prompt


def messages_to_prompt_llama3(messages: list[dict]) -> str:
BOS = "<|begin_of_text|>"
GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"

prompt = f"{BOS}"
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
prompt += GENERAL_TEMPLATE.format(role=role, content=content)

if role != "assistant":
Collaborator review comment: role will have a NameError if messages is empty
prompt += "<|start_header_id|>assistant<|end_header_id|>"

return prompt


def messages_to_prompt_claude2(messages: list[dict]) -> str:
GENERAL_TEMPLATE = "\n\n{role}: {content}"
prompt = ""
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
prompt += GENERAL_TEMPLATE.format(role=role, content=content)

if role != "assistant":
prompt += "\n\nAssistant:"

return prompt


def get_max_tokens(model_id: str) -> int:
try:
max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
except KeyError:
logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048")
max_tokens = 2048
return max_tokens
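Finally, a small sketch (not part of the diff) of what the chat-template helpers and get_max_tokens return, following the templates defined above.

# Illustrative sketch only.
from metagpt.provider.bedrock.utils import (
    get_max_tokens,
    messages_to_prompt_llama2,
    messages_to_prompt_llama3,
)

msgs = [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]

# Llama 2 style: "<s>" + "<<SYS>>\n Be brief. \n<</SYS>>\n\n" + "[INST] Hi [/INST]"
prompt2 = messages_to_prompt_llama2(msgs)

# Llama 3 style: a header/eot block per message, plus a trailing assistant header
prompt3 = messages_to_prompt_llama3(msgs)

print(get_max_tokens("meta.llama3-8b-instruct-v1:0"))  # 2000, from SUPPORT_STREAM_MODELS
print(get_max_tokens("unknown.model"))                 # falls back to 2048 and logs a warning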