Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for ZhipuAI's GLM LLMs #900

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions memgpt/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class MemGPTCredentials:
azure_deployment: str = None
azure_embedding_deployment: str = None

# zhipuai config
zhipuai_api_key: str = None

# custom llm API config
openllm_auth_type: str = None
openllm_key: str = None
Expand Down Expand Up @@ -66,6 +69,8 @@ def load(cls) -> "MemGPTCredentials":
"azure_version": get_field(config, "azure", "version"),
"azure_deployment": get_field(config, "azure", "deployment"),
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
# zhipuai
"zhipuai_api_key": get_field(config, "zhipuai", "api_key"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
"openllm_key": get_field(config, "openllm", "key"),
Expand Down
85 changes: 84 additions & 1 deletion memgpt/llm_api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from memgpt.credentials import MemGPTCredentials
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, FunctionCall

from memgpt.data_types import AgentState

Expand Down Expand Up @@ -247,6 +247,70 @@ def openai_embeddings_request(url, api_key, data):
raise e


def _ensure_zhipu_message_format(messages):
zhipuai_messages = []
for message in messages:
role = message["role"]
if role == "system" or role == "user":
zhipuai_messages.append({"role": role, "content": message["content"]})
elif role == "assistant":
if "tool_calls" in message:
tool_calls = message["tool_calls"]
msg = {"role": role, "tool_calls": []}
for i, tool_call in enumerate(tool_calls):
tool_call_type = tool_call["type"]
assert tool_call_type == "function", f"unknown tool_call_type {tool_call_type}"
msg["tool_calls"].append(
{
"id": tool_call["id"],
"index": i,
"type": tool_call_type,
"function": {"name": tool_call["function"]["name"], "args": tool_call["function"]["arguments"]},
}
)
zhipuai_messages.append(msg)
else:
zhipuai_messages.append({"role": role, "content": message["content"]})
elif role == "tool":
zhipuai_messages.append({"role": role, "content": message["content"], "tool_call_id": message["tool_call_id"]})
else:
raise ValueError(f"unknown role {role}")
return zhipuai_messages


def zhipuai_chat_completions_request(api_key, data):
data["messages"] = _ensure_zhipu_message_format(data["messages"])
try:
import zhipuai
except ImportError:
raise ImportError("zhipuai is not installed, please run `pip install zhipuai`")
client = zhipuai.ZhipuAI(api_key=api_key)
response = client.chat.completions.create(**data)
return ChatCompletionResponse(
id=response.id,
choices=[
Choice(
finish_reason=c.finish_reason,
index=c.index,
message=Message(
content=c.message.content,
role=c.message.role,
tool_calls=None
if c.message.tool_calls is None
else [
ToolCall(
id=ct.id,
function=FunctionCall(name=ct.function.name, arguments=ct.function.arguments),
)
for ct in c.message.tool_calls
],
),
)
for c in response.choices
],
)


def azure_openai_chat_completions_request(resource_name, deployment_id, api_version, api_key, data):
"""https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions"""
from memgpt.utils import printd
Expand Down Expand Up @@ -455,6 +519,25 @@ def create(
api_key=credentials.azure_key,
data=data,
)
elif agent_state.llm_config.model_endpoint_type == "zhipuai":
assert function_call == "auto", "zhipuai does not support specifying function_call for now"
tools = [
{
"type": "function",
"function": f,
}
for f in functions
]
data = dict(
model="glm-4",
tools=tools,
messages=messages,
tool_choice=function_call,
)
return zhipuai_chat_completions_request(
api_key=credentials.zhipuai_api_key,
data=data,
)

# local model
else:
Expand Down
Loading