In [1]:
%%time

import os
import boto3
import re
import json
import sagemaker
from sagemaker import get_execution_role

region = boto3.Session().region_name

role = get_execution_role()

bucket = sagemaker.Session().default_bucket()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
CPU times: user 1.58 s, sys: 237 ms, total: 1.82 s
Wall time: 3.32 s


In [2]:
prefix = "sagemaker/DEMO-xgboost-byo"
bucket_path = "https://s3-{}.amazonaws.com/{}".format(region, bucket)
# customize to your bucket where you have stored the data

In [3]:
%%time
import pickle
import boto3
import gzip

# Get the data from a public S3
buf = (
    boto3.client("s3")
    .get_object(
        Bucket=f"sagemaker-example-files-prod-{region}", Key="datasets/image/MNIST/mnist.pkl.gz"
    )["Body"]
    .read()
)


# decompress the buffer
decomp_buf = gzip.decompress(buf)
train_set, valid_set, test_set = pickle.loads(decomp_buf, encoding="latin1")

CPU times: user 886 ms, sys: 427 ms, total: 1.31 s
Wall time: 2.17 s


In [4]:
train_X = train_set[0]
train_y = train_set[1]

valid_X = valid_set[0]
valid_y = valid_set[1]

test_X = test_set[0]
test_y = test_set[1]

In [5]:
import xgboost as xgb
import sklearn as sk

print("Version of XGboost",xgb.__version__)

bt = xgb.XGBClassifier(
    max_depth=5, learning_rate=0.2, n_estimators=10, objective="multi:softmax"
)  # Setup xgboost model
bt.fit(train_X, train_y, eval_set=[(valid_X, valid_y)], verbose=False)  # Train it to our data

Version of XGboost 1.7.1


In [6]:
model_file_name = "DEMO-local-xgboost-model"
bt.save_model(model_file_name)

In [7]:
!tar czvf model.tar.gz $model_file_name

DEMO-local-xgboost-model


In [8]:
fObj = open("model.tar.gz", "rb")
key = os.path.join(prefix, model_file_name, "model.tar.gz")
boto3.Session().resource("s3").Bucket(bucket).Object(key).upload_fileobj(fObj)

In [9]:
from sagemaker.amazon.amazon_estimator import get_image_uri

container = get_image_uri(boto3.Session().region_name, "xgboost", "1.7-1")

The method get_image_uri has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


In [10]:
%%time
from time import gmtime, strftime

model_name = model_file_name + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
model_url = "https://s3-{}.amazonaws.com/{}/{}".format(region, bucket, key)
sm_client = boto3.client("sagemaker")

print(model_url)

primary_container = {
    "Image": container,
    "ModelDataUrl": model_url,
}

create_model_response2 = sm_client.create_model(
    ModelName=model_name, ExecutionRoleArn=role, PrimaryContainer=primary_container
)

print(create_model_response2["ModelArn"])

https://s3-eu-west-2.amazonaws.com/sagemaker-eu-west-2-661082688832/sagemaker/DEMO-xgboost-byo/DEMO-local-xgboost-model/model.tar.gz
arn:aws:sagemaker:eu-west-2:661082688832:model/DEMO-local-xgboost-model2024-04-06-15-47-17
CPU times: user 53.1 ms, sys: 38 µs, total: 53.1 ms
Wall time: 628 ms


In [11]:
from time import gmtime, strftime

endpoint_config_name = "DEMO-XGBoostEndpointConfig-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(endpoint_config_name)
create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.m4.xlarge",
            "InitialInstanceCount": 1,
            "InitialVariantWeight": 1,
            "ModelName": model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

DEMO-XGBoostEndpointConfig-2024-04-06-15-47-28
Endpoint Config Arn: arn:aws:sagemaker:eu-west-2:661082688832:endpoint-config/DEMO-XGBoostEndpointConfig-2024-04-06-15-47-28


In [12]:
%%time
import time

endpoint_name = "DEMO-XGBoostEndpoint-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(endpoint_name)
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
print(create_endpoint_response["EndpointArn"])

resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

DEMO-XGBoostEndpoint-2024-04-06-15-47-37
arn:aws:sagemaker:eu-west-2:661082688832:endpoint/DEMO-XGBoostEndpoint-2024-04-06-15-47-37
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: InService
Arn: arn:aws:sagemaker:eu-west-2:661082688832:endpoint/DEMO-XGBoostEndpoint-2024-04-06-15-47-37
Status: InService
CPU times: user 98.8 ms, sys: 4.83 ms, total: 104 ms
Wall time: 4min 1s


In [13]:
runtime_client = boto3.client("runtime.sagemaker")

In [14]:
import numpy as np

point_X = test_X[0]
point_X = np.expand_dims(point_X, axis=0)
point_y = test_y[0]
np.savetxt("test_point.csv", point_X, delimiter=",")

In [15]:
%%time
import json


file_name = (
    "test_point.csv"  # customize to your test file, will be 'mnist.single.test' if use data above
)

with open(file_name, "r") as f:
    payload = f.read().strip()

response = runtime_client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="text/csv", Body=payload
)
result = response["Body"].read().decode("ascii")
print("Predicted Class Probabilities: {}.".format(result))

Predicted Class Probabilities: 7.0
.
CPU times: user 18.3 ms, sys: 0 ns, total: 18.3 ms
Wall time: 128 ms


In [16]:
floatArr = np.array(json.loads(result))
predictedLabel = np.argmax(floatArr)
print("Predicted Class Label: {}.".format(predictedLabel))
print("Actual Class Label: {}.".format(point_y))

Predicted Class Label: 0.
Actual Class Label: 7.


In [17]:
sm_client.delete_endpoint(EndpointName=endpoint_name)


{'ResponseMetadata': {'RequestId': '1e224776-fa7e-44b8-8a40-7330b5a0278e',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '1e224776-fa7e-44b8-8a40-7330b5a0278e',
   'content-type': 'application/x-amz-json-1.1',
   'date': 'Sat, 06 Apr 2024 16:03:08 GMT',
   'content-length': '0'},
  'RetryAttempts': 0}}