In [2]:
import boto3
import re
import os
import numpy as np
import pandas as pd
from sagemaker import get_execution_role
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import CSVDeserializer
import roles
import sagemaker as sage
from time import gmtime, strftime
import json



In [3]:
#Set to True if you want to train before deploy
TRAIN = False

### AWS connection startup

In [4]:
role = f"arn:aws:iam::{roles.account_ID}:role/{roles.SageMakerExecutionRole}"
sess = sage.Session()
account = sess.boto_session.client("sts").get_caller_identity()["Account"]
region = sess.boto_session.region_name
image = "{}.dkr.ecr.{}.amazonaws.com/sagemaker-deploy-terraform:latest".format(account, region)
instance_type = "ml.c4.2xlarge"
output_path = "s3://{}/output".format(sess.default_bucket())

---

### Sample data upload to S3

In [5]:
WORK_DIRECTORY = "../src/data"
prefix = "DEMO-DATA"
data_location = sess.upload_data(WORK_DIRECTORY, key_prefix=prefix)

---

### TRAINING / ARTIFACT DOWNLOAD

In [5]:
if TRAIN:
    sagemaker_model = sage.estimator.Estimator(
    image,
    role,
    1,
    instance_type,
    output_path=output_path,
    sagemaker_session=sess,
    container_port=8080
)
    sagemaker_model.fit(data_location)
else:
    #If we do not train, we need to load the model from S3
    model_artifact = f's3://sagemaker-{region}-{account}/output/{roles.artifact}/output/model.tar.gz'
    sagemaker_model = sage.Model(
        model_data=model_artifact,
        role=role,
        image_uri=image
    )

---

### HOSTING

Deploying process may take some minutes, you can check the status in your aws sagemaker client

In [9]:
predictor = sagemaker_model.deploy(1, "ml.m4.xlarge", serializer=CSVSerializer(), deserializer=CSVDeserializer())

----!

---

### Sample trial

In [18]:
data = pd.read_csv("../src/data/iris.csv", header=None)
test_data = data.iloc[50:70,1:] #drop label and sample

In [19]:
if not TRAIN:
    sagemaker_client = boto3.client('sagemaker')
    response = sagemaker_client.list_endpoints()
    endpoints = response['Endpoints']
    for endpoint in endpoints:
        print(f"Available endpoints:\n> {endpoint['EndpointName']}")

Available endpoints:
> sagemaker-deploy-terraform-2023-09-08-14-55-34-332


In [20]:
if not TRAIN:    
    endpoint_name = 'sagemaker-deploy-terraform-2023-09-08-14-55-34-332'
    predictor = sage.predictor.Predictor(endpoint_name)

In [27]:
content_type = 'text/csv'

In [28]:
print(predictor.predict(
    test_data.to_csv(sep=",", header=False, index=False), 
    initial_args={'ContentType':content_type}).decode("utf-8"))

versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor



---

# Invocation with Boto3

In [29]:
sagemaker_runtime = boto3.client('sagemaker-runtime')

In [30]:
response = sagemaker_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Body=test_data.to_csv(sep=",", header=False, index=False)
)

In [31]:
if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
    response_body = response["Body"].read().decode("utf-8")
    print(response_body)
else:
    raise Exception('An error has occured')

versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor



---

### Clean up

In [32]:
sess.delete_endpoint(predictor.endpoint)

The endpoint attribute has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
