# Databricks notebook source
# MAGIC %md
# MAGIC
# MAGIC # Create a model serving endpoint with Python
# MAGIC Now that we have a fine-tuned model registered in Unity Catalog, the final step is to deploy it behind a Model Serving endpoint. This notebook wraps the Model Serving REST API in Python helpers that create an endpoint, update its configuration to a new model version, and delete it, so you can drive the whole lifecycle from your Python workflows.
# COMMAND ----------
import mlflow
mlflow.set_registry_uri("databricks-uc")
client = mlflow.tracking.MlflowClient()
catalog = "rlaif"
log_schema = "inference_log" # A schema within the catalog where the inference log is going to be stored
model_name = "rlaif.model.llama3-8b-vegetarian"
model_serving_endpoint_name = "llama3-8b-vegetarian"
# COMMAND ----------
# MAGIC %md
# MAGIC ## Get token and model version
# MAGIC
# MAGIC The following section demonstrates how to obtain an API token from the notebook context and how to look up the version of the model you plan to serve (here, the version carrying the `champion` alias).
# COMMAND ----------
token = (
    dbutils.notebook.entry_point.getDbutils()
    .notebook()
    .getContext()
    .apiToken()
    .getOrElse(None)
)

# With the token, you can create the authorization header for the subsequent REST calls
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

# Next, you need the workspace instance at which to execute your requests, which you can get from the notebook's tags collection
java_tags = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags()

# This object comes from the Java context - convert the Java Map object to a Python dictionary
tags = sc._jvm.scala.collection.JavaConversions.mapAsJavaMap(java_tags)

# Lastly, extract the Databricks instance (domain name) from the dictionary
instance = tags["browserHostName"]
champion_version = client.get_model_version_by_alias(model_name, "champion")
model_version = champion_version.version
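# COMMAND ----------
# MAGIC %md
# MAGIC As an optional sanity check, you can hit the serving-endpoints list API with the token and instance derived above; an HTTP 200 response confirms the authorization header and workspace URL are wired up correctly. This is a minimal sketch and is not required for the deployment steps that follow.
# COMMAND ----------
import requests

# List the serving endpoints visible to this token; expect HTTP 200 if auth is correct
r = requests.get(f"https://{instance}/api/2.0/serving-endpoints", headers=headers)
print(r.status_code)
print(len(r.json().get("endpoints", [])), "serving endpoints visible to this token")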
# COMMAND ----------
# MAGIC %md
# MAGIC ## Set up configurations
# MAGIC
# MAGIC Depending on the latency and throughput requirements of your use case, choose the appropriate `workload_type` and `workload_size`. **Note that if you're using Azure Databricks, use `GPU_LARGE` for `workload_type`**. The `auto_capture_config` block specifies where to write the inference logs, that is, the requests and responses from the endpoint together with a timestamp.
# COMMAND ----------
import requests

my_json = {
    "name": model_serving_endpoint_name,
    "config": {
        "served_models": [
            {
                "model_name": model_name,
                "model_version": model_version,
                "workload_type": "GPU_LARGE",
                "workload_size": "Small",
                "scale_to_zero_enabled": False,  # boolean, not the string "false"
                "environment_vars": {
                    "HF_TOKEN": "{{secrets/rlaif/hf_token}}"
                },
            }
        ],
        "auto_capture_config": {
            "catalog_name": catalog,
            "schema_name": log_schema,
            "table_name_prefix": model_serving_endpoint_name,
        },
    },
}

# Make sure the schema for the inference table exists
_ = spark.sql(
    f"CREATE SCHEMA IF NOT EXISTS {catalog}.{log_schema}"
)

# Make sure to drop the inference table if it exists
_ = spark.sql(
    f"DROP TABLE IF EXISTS {catalog}.{log_schema}.`{model_serving_endpoint_name}_payload`"
)
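# COMMAND ----------
# MAGIC %md
# MAGIC Once the endpoint is serving traffic, the captured requests and responses land in a Delta table named after the `table_name_prefix` with a `_payload` suffix (the same name the `DROP TABLE` statement above uses). The sketch below assumes some traffic has already been logged; new rows can take a few minutes to appear.
# COMMAND ----------
# Inspect the auto-captured inference log (run after the endpoint has served requests)
payload_table = f"{catalog}.{log_schema}.`{model_serving_endpoint_name}_payload`"
display(spark.sql(f"SELECT * FROM {payload_table} LIMIT 10"))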
# COMMAND ----------
# MAGIC %md
# MAGIC The following defines Python functions that:
# MAGIC - create a model serving endpoint
# MAGIC - update a model serving endpoint configuration with the latest model version
# MAGIC - delete a model serving endpoint
# COMMAND ----------
def func_create_endpoint(model_serving_endpoint_name):
    # get endpoint status
    endpoint_url = f"https://{instance}/api/2.0/serving-endpoints"
    url = f"{endpoint_url}/{model_serving_endpoint_name}"
    r = requests.get(url, headers=headers)
    if "RESOURCE_DOES_NOT_EXIST" in r.text:
        print(
            "Creating this new endpoint: ",
            f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations",
        )
        re = requests.post(endpoint_url, headers=headers, json=my_json)
    else:
        new_model_version = (my_json["config"])["served_models"][0]["model_version"]
        print(
            "This endpoint existed previously! We are updating it to a new config with new model version: ",
            new_model_version,
        )
        # update config
        url = f"{endpoint_url}/{model_serving_endpoint_name}/config"
        re = requests.put(url, headers=headers, json=my_json["config"])
        # wait until the new config is in place
        import time, json

        # get endpoint status
        url = f"https://{instance}/api/2.0/serving-endpoints/{model_serving_endpoint_name}"
        retry = True
        total_wait = 0
        while retry:
            r = requests.get(url, headers=headers)
            assert (
                r.status_code == 200
            ), f"Expected an HTTP 200 response when accessing endpoint info, received {r.status_code}"
            endpoint = json.loads(r.text)
            if "pending_config" in endpoint.keys():
                seconds = 10
                print("New config still pending")
                if total_wait < 6000:
                    # if fewer than 6,000 seconds (100 minutes) have elapsed, keep waiting
                    print(f"Wait for {seconds} seconds")
                    print(f"Total waiting time so far: {total_wait} seconds")
                    time.sleep(seconds)
                    total_wait += seconds
                else:
                    print(f"Stopping, waited for {total_wait} seconds")
                    retry = False
            else:
                print("New config in place now!")
                retry = False

    assert (
        re.status_code == 200
    ), f"Expected an HTTP 200 response, received {re.status_code}"
def func_delete_model_serving_endpoint(model_serving_endpoint_name):
    endpoint_url = f"https://{instance}/api/2.0/serving-endpoints"
    url = f"{endpoint_url}/{model_serving_endpoint_name}"
    response = requests.delete(url, headers=headers)
    if response.status_code != 200:
        raise Exception(
            f"Request failed with status {response.status_code}, {response.text}"
        )
    else:
        print(model_serving_endpoint_name, "endpoint is deleted!")
    return response.json()
# COMMAND ----------
func_create_endpoint(model_serving_endpoint_name)
# COMMAND ----------
# MAGIC %md
# MAGIC ## Wait for endpoint to be ready
# MAGIC
# MAGIC The `wait_for_endpoint()` function defined in the following command polls the serving endpoint and returns once its state is `READY`.
# COMMAND ----------
import time, mlflow
def wait_for_endpoint():
    endpoint_url = f"https://{instance}/api/2.0/serving-endpoints"
    while True:
        url = f"{endpoint_url}/{model_serving_endpoint_name}"
        response = requests.get(url, headers=headers)
        assert (
            response.status_code == 200
        ), f"Expected an HTTP 200 response, received {response.status_code}\n{response.text}"
        status = response.json().get("state", {}).get("ready", "")
        if status == "READY":
            print(status)
            print("-" * 80)
            return
        else:
            print(f"Endpoint not ready ({status}), waiting 5 minutes")
            time.sleep(300)  # wait 300 seconds (5 minutes)
api_url = mlflow.utils.databricks_utils.get_webapp_url()
wait_for_endpoint()
# COMMAND ----------
# MAGIC %md
# MAGIC ## Score the model
# MAGIC
# MAGIC The following command defines the `generate_response()` function, which builds the scoring request body, sends it to the endpoint, and returns the JSON response.
# COMMAND ----------
def prompt_generate(text):
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an AI assistant that specializes in cuisine. Your task is to generate a text related to food preferences, recipes, or ingredients based on the question provided below. Generate 1 text and do not generate more than 1 text. Be concise and use no more than 100 words.<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
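# COMMAND ----------
# MAGIC %md
# MAGIC Printing the rendered template for a sample question is a cheap way to verify that the Llama 3 special tokens (`<|begin_of_text|>`, the header markers, and `<|eot_id|>`) are in place before sending real scoring requests. The sample question below is arbitrary.
# COMMAND ----------
# Sanity check: inspect the fully rendered prompt for a sample question
print(prompt_generate("What are some vegetarian sources of iron?"))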
# COMMAND ----------
import requests
import pandas as pd
import json
# This is the endpoint invocation URL; you can also copy it from the endpoint's Model Serving page
endpoint_url = (
    f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations"
)
token = (
    dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
)
def generate_response(text, url=endpoint_url, databricks_token=token):
    headers = {
        "Authorization": f"Bearer {databricks_token}",
        "Content-Type": "application/json",
    }
    body = {
        "dataframe_records": [{"prompt": prompt_generate(text)}],
        "params": {"max_tokens": 250},
    }
    data = json.dumps(body)
    response = requests.request(method="POST", headers=headers, url=url, data=data)
    if response.status_code != 200:
        raise Exception(
            f"Request failed with status {response.status_code}, {response.text}"
        )
    return response.json()
# COMMAND ----------
text = "What are some protein sources that can be used in dishes?"
print(generate_response(text)["predictions"][0]["candidates"][0]["text"])
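# COMMAND ----------
# MAGIC %md
# MAGIC You can reuse `generate_response()` to score several questions in a loop and collect the answers in a pandas DataFrame. This is a minimal sketch that assumes the response shape shown in the previous cell; the example questions are arbitrary.
# COMMAND ----------
questions = [
    "What are some protein sources that can be used in dishes?",
    "Suggest a quick weeknight pasta dish.",
]
# Score each question and pull the generated text out of the response payload
answers = [
    generate_response(q)["predictions"][0]["candidates"][0]["text"] for q in questions
]
display(pd.DataFrame({"question": questions, "answer": answers}))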
# COMMAND ----------
# MAGIC %md
# MAGIC ## Delete the endpoint
# COMMAND ----------
func_delete_model_serving_endpoint(model_serving_endpoint_name)