-
Notifications
You must be signed in to change notification settings - Fork 5k
/
claude.py
130 lines (114 loc) · 4.64 KB
/
claude.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose
import uvicorn
class ClaudeWorker(ApiModelWorker):
    """Model worker that forwards chat requests to Anthropic's Claude Messages API.

    Requests are POSTed to ``https://api.anthropic.com/v1/messages`` as a
    streaming (server-sent events) call. Results are yielded as dicts with
    ``error_code`` and ``text`` keys.
    """

    def __init__(
        self,
        *,
        controller_addr: str = None,
        worker_addr: str = None,
        model_names: List[str] = ["claude-api"],
        version: str = "2023-06-01",
        **kwargs,
    ):
        """Create the worker.

        Args:
            controller_addr: address of the controller, forwarded to ``ApiModelWorker``.
            worker_addr: address of this worker, forwarded to ``ApiModelWorker``.
            model_names: names served by this worker; ``model_names[0]`` also
                names the conversation template.
            version: value sent in the ``anthropic-version`` request header.
            **kwargs: forwarded to ``ApiModelWorker``; ``context_len`` defaults to 1024.
        """
        # Copy the list so the shared default argument can never be mutated downstream.
        kwargs.update(model_names=list(model_names), controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 1024)
        super().__init__(**kwargs)
        self.version = version

    def create_claude_messages(self, params: ApiChatParams) -> Dict:
        """Build the JSON body for a Claude Messages API request.

        Claude accepts only ``user`` and ``assistant`` roles inside ``messages``,
        so those are forwarded unchanged. Non-empty ``system`` messages are
        joined with newlines and sent through the top-level ``system`` field.

        Args:
            params: chat parameters; ``params.messages`` is a list of
                ``{"role": ..., "content": ...}`` dicts.

        Returns:
            Dict with ``model``, ``max_tokens``, ``messages`` and, when system
            text is present, ``system``.
        """
        # NOTE(review): Claude's ``max_tokens`` caps *generated* tokens; it is
        # taken from ``params.context_len`` here — confirm this is intended.
        claude_msg = {
            "model": params.model_name,
            "max_tokens": params.context_len,
            "messages": [],
        }
        system_parts = []
        for msg in params.messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                if content:
                    system_parts.append(content)
                continue
            claude_msg["messages"].append({"role": role, "content": content})
        if system_parts:
            claude_msg["system"] = "\n".join(system_parts)
        return claude_msg

    def do_chat(self, params: ApiChatParams) -> Dict:
        """Stream a chat completion from Claude.

        Yields:
            ``{"error_code": 0, "text": text}`` after every text delta, and once
            more on ``message_stop``. ``text`` is the cumulative reply so far,
            not the increment. On a non-200 HTTP status or an ``error`` event, a
            single dict with a non-zero ``error_code`` is yielded and the
            generator stops.
        """
        # "stream": True makes the API send server-sent events; without it the
        # reply is a single JSON document with no content_block_delta events.
        data = {**self.create_claude_messages(params), "stream": True}
        url = "https://api.anthropic.com/v1/messages"
        headers = {
            'anthropic-version': self.version,
            'anthropic-beta': 'messages-2023-12-15',
            'Content-Type': 'application/json',
            'x-api-key': params.api_key,
        }
        if log_verbose:
            # Mask the API key so it never ends up in log files.
            safe_headers = {**headers, 'x-api-key': '***'}
            logger.info(f'{self.__class__.__name__}:url: {url}')
            logger.info(f'{self.__class__.__name__}:headers: {safe_headers}')
            logger.info(f'{self.__class__.__name__}:data: {data}')

        text = ""
        # A single client carrying the intended timeout, closed when done.
        with get_httpx_client(timeout=httpx.Timeout(60.0)) as client:
            with client.stream("POST", url, headers=headers, json=data) as response:
                if response.status_code != 200:
                    # A streamed body must be read explicitly before ``.text`` works.
                    response.read()
                    logger.error(f"Failed to get response: {response.text}")
                    yield {
                        "error_code": response.status_code,
                        "text": "Failed to communicate with Claude API.",
                    }
                    return
                for line in response.iter_lines():
                    line = line.strip()
                    # SSE frames are "event: <type>" / "data: <json>" lines split by
                    # blank lines; the JSON repeats the type, so only data lines matter.
                    if not line.startswith("data:"):
                        continue
                    payload = line[len("data:"):].strip()
                    try:
                        event_data = json.loads(payload)
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to decode JSON: {e}; payload: {payload}")
                        continue
                    event_type = event_data.get("type")
                    if event_type == "content_block_delta":
                        text += event_data.get("delta", {}).get("text", "")
                        yield {"error_code": 0, "text": text}
                    elif event_type == "message_stop":
                        # Final (cumulative) text; a repeat of the last delta's text is harmless.
                        yield {"error_code": 0, "text": text}
                        return
                    elif event_type == "error":
                        error = event_data.get("error") or {}
                        logger.error(f"Claude API returned an error event: {error}")
                        yield {
                            "error_code": 500,
                            "text": f"Claude API error: {error.get('message', error)}",
                        }
                        return
                    # Any other event (message_start, content_block_start, ping,
                    # content_block_stop, message_delta) carries no text to emit.

    def get_embeddings(self, params):
        """Placeholder: embeddings are not implemented; only prints the request."""
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: List[Dict[str, str]] = None, model_path: str = None) -> Conversation:
        """Return the ``Conversation`` template for this worker.

        ``Conversation.messages`` holds ``[role, message]`` pairs, which are
        unpacked as ``role, message`` when a prompt is built. Dict-style seed
        messages are therefore converted to that shape. Without seed messages
        the template starts empty, so no canned turns are prepended to real
        conversations.

        Args:
            conv_template: optional seed messages as ``{"role", "content"}``
                dicts; any non-list value is ignored.
            model_path: unused; kept for signature compatibility.
        """
        seed = conv_template if isinstance(conv_template, list) else []
        messages = [[m["role"], m["content"]] for m in seed]
        return Conversation(
            name=self.model_names[0],
            system_message="You are Claude, a helpful, respectful, and honest assistant.",
            messages=messages,
            roles=["user", "assistant"],
            sep="\n### ",
            stop_str="###",
        )
if __name__ == "__main__":
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.base_model_worker import app

    # Local standalone run: controller expected on port 20001, this worker
    # registers and listens on port 21011.
    host = "127.0.0.1"
    worker_port = 21011
    claude_worker = ClaudeWorker(
        controller_addr=f"http://{host}:20001",
        worker_addr=f"http://{host}:{worker_port}",
    )
    # NOTE(review): assumes "fastchat.serve.model_worker" is already present in
    # sys.modules (otherwise this lookup raises KeyError) — confirm.
    target_module = sys.modules["fastchat.serve.model_worker"]
    target_module.worker = claude_worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=worker_port)