-
Notifications
You must be signed in to change notification settings - Fork 809
/
chat_app.py
207 lines (176 loc) · 6.71 KB
/
chat_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from datetime import datetime
import time
import requests
import sys
import json
from azure.identity import AzureCliCredential
import logging
from azure.ai.ml import MLClient
from sseclient import SSEClient
class ColoredFormatter(logging.Formatter):
# Color code dictionary
color_codes = {
'debug': '\033[0;32m', # Green
'info': '\033[0;36m', # Cyan
'warning': '\033[0;33m', # Yellow
'error': '\033[0;31m', # Red
'critical': '\033[0;35m', # Magenta
}
def format(self, record):
# Get the original message
message = super().format(record)
# Add color codes
message = f"{self.color_codes.get(record.levelname.lower(), '')}{message}\033[0m"
return message
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(ColoredFormatter())
logger.setLevel(logging.INFO)
logger.addHandler(handler)
def apply_delta(base: dict, delta: dict):
for k, v in delta.items():
if k in base:
base[k] += v
else:
base[k] = v
def score(url, api_key, body, stream=True, on_event=None):
headers = {
"Content-Type": "application/json",
"Authorization": ("Bearer " + api_key),
# The azureml-model-deployment header will force the request to go to a specific deployment.
# Remove this header to have the request observe the endpoint traffic rules
"azureml-model-deployment": "blue",
"Accept": "text/event-stream, application/json" if stream else "application/json"
}
logger.info("Sending HTTP request...")
logger.debug("POST %s", url)
for name, value in headers.items():
if name == "Authorization":
value = "[REDACTED]"
logger.debug(f">>> {name}: {value}")
logger.debug(json.dumps(body, indent=4, ensure_ascii=False))
logger.debug("")
time1 = datetime.now()
response = None
try:
response = requests.post(url, json=body, headers=headers, stream=stream)
response.raise_for_status()
finally:
time2 = datetime.now()
if response is not None:
logger.info(
"Got response: %d %s (elapsed %s)",
response.status_code,
response.reason,
time2 - time1,
)
for name, value in response.headers.items():
logger.debug(f"<<< {name}: {value}")
time1 = datetime.now()
try:
content_type = response.headers.get('Content-Type')
if "text/event-stream" in content_type:
output = {}
client = SSEClient(response)
for event in client.events():
if on_event:
on_event(event)
dct = json.loads(event.data)
apply_delta(output, dct)
return output, True
else:
return response.json(), False
finally:
time2 = datetime.now()
logger.info("\nResponse reading elapsed: %s", time2 - time1)
class ChatApp:
def __init__(self, ml_client, endpoint_name, chat_input_name, chat_output_name, stream=True, debug=False):
self._chat_input_name = chat_input_name
self._chat_output_name = chat_output_name
self._chat_history = []
self._stream = stream
if debug:
logger.setLevel(logging.DEBUG)
logger.info("Getting endpoint info...")
endpoint = ml_client.online_endpoints.get(endpoint_name)
keys = ml_client.online_endpoints.get_keys(endpoint_name)
self._endpoint_url = endpoint.scoring_uri
self._endpoint_key = keys.primary_key if endpoint.auth_mode == "key" else keys.access_token
logger.info(f"Done.")
logger.debug(f"Target endpoint: {endpoint.id}")
@property
def url(self):
return self._endpoint_url
@property
def api_key(self):
return self._endpoint_key
def get_payload(self, chat_input, chat_history=[]):
return {
self._chat_input_name: chat_input,
"chat_history": chat_history,
}
def chat_once(self, chat_input):
def on_event(event):
dct = json.loads(event.data)
answer_delta = dct.get(self._chat_output_name)
if answer_delta:
print(answer_delta, end='')
# We need to flush the output
# otherwise the text does not appear on the console
# unless a new line comes.
sys.stdout.flush()
# Sleep for 20ms for better animation effects
time.sleep(0.02)
try:
payload = self.get_payload(chat_input=chat_input, chat_history=self._chat_history)
output, stream = score(self.url, self.api_key, payload, stream=self._stream, on_event=on_event)
# We don't use self._stream here since the result may not always be the same as self._stream specified.
if stream:
# Print a new line at the end of the content to make sure
# the next logger line will always starts from a new line.
pass
# print("\n")
else:
print(output.get(self._chat_output_name, "<empty>"))
self._chat_history.append({
"inputs": {
self._chat_input_name: chat_input,
},
"outputs": output,
})
logger.info("Length of chat history: %s", len(self._chat_history))
except requests.HTTPError as e:
logger.error(e.response.text)
def chat(self):
while True:
try:
question = input("Chat with Wikipedia:> ")
if question in ("exit", "bye"):
print("Bye.")
break
self.chat_once(question)
except KeyboardInterrupt:
# When pressed Ctrl_C, exit
print("\nBye.")
break
except Exception as e:
logger.exception("An error occurred: %s", e)
# Do not raise the errors out so that we can continue the chat
if __name__ == "__main__":
ml_client = MLClient(
credential=AzureCliCredential(),
# Replace with your subscription ID, resource group name, and workspace name
subscription_id="<your_sub_id>",
resource_group_name="<your_resource_group_name>",
workspace_name="<your_workspace_name>",
)
chat_app = ChatApp(
ml_client=ml_client,
# TODO: Replace with your online endpoint name
endpoint_name="chat-with-wikipedia-stream",
chat_input_name="question",
chat_output_name="answer",
stream=False,
debug=True,
)
chat_app.chat()