Fixes integration test #779

Merged · 1 commit · May 30, 2023

29 changes: 12 additions & 17 deletions tests/integration/llm/client.py

@@ -10,7 +10,6 @@
 import numpy as np
 from datetime import datetime
 from io import BytesIO
-import hashlib

 logging.basicConfig(level=logging.INFO)
 parser = argparse.ArgumentParser(description="Build the LLM configs")
@@ -53,11 +52,10 @@
 args = parser.parse_args()


-def compute_model_name_hash(model_name):
-    # This mirrors the Utils.hash implementation from DJL Core
-    m = hashlib.sha256()
-    m.update(model_name)
-    return m.hexdigest()[:40]
+def get_model_name():
+    endpoint = f"http://127.0.0.1:8080/models"
+    res = requests.get(endpoint).json()
+    return res["models"][0]["modelName"]


 ds_raw_model_spec = {
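
For reference, the new get_model_name() helper assumes the DJL Serving management API (GET /models) returns a JSON body whose "models" list exposes a "modelName" field. A minimal sketch of the shape the helper relies on; any field other than "models" and "modelName" is illustrative, not confirmed by this diff:

    # Hypothetical response from GET http://127.0.0.1:8080/models;
    # get_model_name() only reads res["models"][0]["modelName"].
    {
        "models": [
            {"modelName": "gpt4all-j", "status": "READY"}
        ]
    }

Taking models[0] is enough here because these integration tests appear to serve a single model per container; with several registered models the helper would need to select by name.
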
@@ -118,21 +116,18 @@ def compute_model_name_hash(model_name):
         "batch_size": [1, 4],
         "seq_length": [16, 32],
         "worker": 1,
-        "model_name": compute_model_name_hash(b"nomic-ai/gpt4all-j"),
     },
     "no-code/databricks/dolly-v2-7b": {
         "max_memory_per_gpu": [10.0, 12.0],
         "batch_size": [1, 4],
         "seq_length": [16, 32],
         "worker": 2,
-        "model_name": compute_model_name_hash(b"databricks/dolly-v2-7b"),
     },
     "no-code/google/flan-t5-xl": {
         "max_memory_per_gpu": [7.0, 7.0],
         "batch_size": [1, 4],
         "seq_length": [16, 32],
         "worker": 2,
-        "model_name": compute_model_name_hash(b"google/flan-t5-xl")
     }
 }

@@ -266,7 +261,8 @@ def compute_model_name_hash(model_name):
 }


-def check_worker_number(desired, model_name="test"):
+def check_worker_number(desired):
+    model_name = get_model_name()
     endpoint = f"http://127.0.0.1:8080/models/{model_name}"
     res = requests.get(endpoint).json()
     if desired == len(res[0]["models"][0]["workerGroups"]):
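
check_worker_number() now resolves the model name itself and then walks the model description returned by the management API. A sketch of the response shape the assertion indexes into; everything except the res[0]["models"][0]["workerGroups"] path is an assumption:

    # Hypothetical response from GET http://127.0.0.1:8080/models/{model_name};
    # the test only evaluates len(res[0]["models"][0]["workerGroups"]).
    [
        {
            "models": [
                {
                    "modelName": "gpt4all-j",
                    "workerGroups": [
                        {"workers": ["worker-1"]},
                        {"workers": ["worker-2"]}
                    ]
                }
            ]
        }
    ]

With this shape, check_worker_number(2) would pass, since two worker groups are present.
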
@@ … @@
             f"Worker number does not meet requirements! {res}")


-def send_json(data, model_name="test"):
+def send_json(data):
     headers = {'content-type': 'application/json'}
-    endpoint = f"http://127.0.0.1:8080/predictions/{model_name}"
+    endpoint = f"http://127.0.0.1:8080/invocations"
     resp = requests.post(endpoint, headers=headers, json=data)

     if resp.status_code >= 300:
@@ … @@
     return resp


-def send_image_json(img_url, data, model_name="test"):
+def send_image_json(img_url, data):
     multipart_form_data = {
         'data': BytesIO(requests.get(img_url, stream=True).content),
         'json': (None, json.dumps(data), 'application/json')
     }
-    endpoint = f"http://127.0.0.1:8080/predictions/{model_name}"
+    endpoint = f"http://127.0.0.1:8080/invocations"
     resp = requests.post(endpoint, files=multipart_form_data)

     if resp.status_code >= 300:
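
Both helpers now post to /invocations instead of /predictions/{model_name}, so callers no longer thread a model name through; the server is assumed to route /invocations to its single registered model. A minimal usage sketch; the payloads and image URL are illustrative:

    # Sketch: exercising the updated helpers. The request shape mirrors
    # what test_handler builds; the URL below is a placeholder.
    res = send_json({
        "inputs": ["Hello, my name is"],
        "parameters": {"max_new_tokens": 16},
    })
    logging.info(res.json())

    res = send_image_json("https://example.com/sample.jpg", {"top_k": 3})
    logging.info(res.json())
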
@@ -459,8 +455,7 @@ def test_handler(model, model_spec):
     )
     spec = model_spec[args.model]
     if "worker" in spec:
-        check_worker_number(spec["worker"],
-                            model_name=spec.get("model_name", "test"))
+        check_worker_number(spec["worker"])
     for i, batch_size in enumerate(spec["batch_size"]):
         for seq_length in spec["seq_length"]:
             if "t5" in model:
@@ … @@ def test_handler(model, model_spec):
             params = {"max_new_tokens": seq_length}
             req["parameters"] = params
             logging.info(f"req {req}")
-            res = send_json(req, model_name=spec.get("model_name", "test"))
+            res = send_json(req)
             if spec.get("stream_output", False):
                 logging.info(f"res: {res.content}")
                 result = res.content.decode().split("\n")[:-1]
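
Taken together, the client now discovers everything it needs from the server, so spec entries no longer carry a precomputed model_name. A condensed sketch of the path test_handler takes after this change; the spec values and the stand-in prompt are illustrative:

    # Sketch of the simplified control flow in test_handler.
    spec = {"worker": 1, "batch_size": [1, 4], "seq_length": [16, 32]}
    check_worker_number(spec["worker"])  # model name resolved via get_model_name()
    for batch_size in spec["batch_size"]:
        for seq_length in spec["seq_length"]:
            req = {
                "inputs": ["test prompt"] * batch_size,  # stands in for batch_generation()
                "parameters": {"max_new_tokens": seq_length},
            }
            res = send_json(req)  # posts to /invocations, no model name needed
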
4 changes: 1 addition & 3 deletions tests/integration/llm/sagemaker-endpoint-tests.py
@@ -129,9 +129,7 @@
     }
 }

-ENGINE_TO_METRIC_CONFIG_ENGINE = {
-    "Python" : "Accelerate"
-}
+ENGINE_TO_METRIC_CONFIG_ENGINE = {"Python": "Accelerate"}


 def get_sagemaker_session(default_bucket=DEFAULT_BUCKET,
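
The collapsed one-line mapping behaves identically; it appears to translate a serving engine name into the engine label expected by the metric configs. A hypothetical lookup; the engine variable and the .get() fallback are illustrative:

    engine = "Python"
    metric_engine = ENGINE_TO_METRIC_CONFIG_ENGINE.get(engine, engine)
    # "Python" maps to "Accelerate"; any other engine falls through unchanged.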