In [None]:
import boto3
import re
import os
import numpy as np
import pandas as pd
from sagemaker import get_execution_role
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import CSVDeserializer
import roles
import sagemaker as sage
from time import gmtime, strftime

### AWS connection startup

In [None]:
# IAM role ARN that SageMaker assumes for the training job and endpoint below; the account ID
# and role name are read from the local `roles` module instead of being typed into the notebook.
role = "arn:aws:iam::{}:role/{}".format(roles.account_ID, roles.SageMakerExecutionRole)
# SageMaker session shared by the S3 upload, the estimator and the deployment cells.
sess = sage.Session()

---

### Sample data upload to S3

In [None]:
# Local folder with the sample data (relative to the notebook's working directory) and the
# S3 key prefix to upload it under.
WORK_DIRECTORY = "../src/data"
prefix = "DEMO-DATA"
# No bucket is given, so the session's default bucket is used; the returned S3 URI is the
# training input passed to `fit` further down.
data_location = sess.upload_data(path=WORK_DIRECTORY, key_prefix=prefix)

---

### TRAINING

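To create the estimator we need:
* an ECR image URI for the training container (`image`)
* an IAM execution role (`role`)
* an instance type (`instance_type`)
* an S3 output path for the model artifacts (`output_path`)
* a SageMaker session (`sess`)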

In [None]:
# Resolve the caller's account and the session region to locate the training image in ECR.
account = sess.boto_session.client("sts").get_caller_identity()["Account"]
region = sess.boto_session.region_name
image = f"{account}.dkr.ecr.{region}.amazonaws.com/sagemaker-deploy-terraform:latest"
instance_type = "ml.c4.2xlarge"
# Training output is written under the `output/` prefix of the session's default bucket.
output_path = f"s3://{sess.default_bucket()}/output"

# Keyword arguments spell out what the original positional call mapped to
# (image_uri, role, instance_count, instance_type) in SageMaker SDK v2.
tree = sage.estimator.Estimator(
    image_uri=image,
    role=role,
    instance_count=1,
    instance_type=instance_type,
    output_path=output_path,
    sagemaker_session=sess,
    # NOTE(review): `container_port` is not a documented Estimator argument; confirm the
    # installed SDK accepts it (or drop it) rather than relying on **kwargs passthrough.
    container_port=8080,
)

Fitting the estimator launches a SageMaker training job on the data uploaded to S3 above.

In [None]:
# Start the training job on the S3 prefix returned by `upload_data` (SDK defaults: wait for completion, stream logs).
tree.fit(inputs=data_location)

---

### HOSTING

Deployment may take several minutes; you can check the endpoint's status in the SageMaker section of the AWS console.

In [None]:
# Host the trained model on a single ml.m4.xlarge instance; requests are serialized to CSV
# and responses are parsed back from CSV.
predictor = tree.deploy(
    initial_instance_count=1,
    instance_type="ml.m4.xlarge",
    serializer=CSVSerializer(),
    deserializer=CSVDeserializer(),
)

---

### Sample trial

In [None]:
# Read the same sample file that was uploaded for training. Reusing WORK_DIRECTORY keeps the
# local read and the S3 upload pointing at one folder instead of repeating the path.
# header=None: pandas treats the first line as data and labels the columns 0..n-1.
shape = pd.read_csv(os.path.join(WORK_DIRECTORY, "iris.csv"), header=None)
shape.sample(3)

In [None]:
# Drop the label column (column 0) so each row holds only the inputs sent to the endpoint.
# Guarded and not in-place: the original `inplace=True` drop removed another feature column
# every time the cell was re-run. With integer column labels (header=None above), column 0
# is gone after the first run, so re-running is a no-op.
if 0 in shape.columns:
    shape = shape.drop(columns=0)
shape.sample(3)

In [None]:
# Take rows 40-49 of each 50-row block (rows 40-49, 90-99, 140-149) as prediction inputs.
# Presumably each 50-row block is one class of the sample data; TODO confirm.
# A nested comprehension yields the same order as itertools.product(blocks, offsets),
# so no mid-notebook import is needed.
ROW_BLOCK_SIZE = 50
N_BLOCKS = 3
indices = [
    block_start + offset
    for block_start in range(0, ROW_BLOCK_SIZE * N_BLOCKS, ROW_BLOCK_SIZE)
    for offset in range(40, 50)
]

# NOTE(review): the final index is excluded, so 29 rows are sent rather than 30; confirm intended.
test_data = shape.iloc[indices[:-1]]

In [None]:
# Send the feature rows as header-less, index-less CSV. CSVDeserializer already decodes and
# parses the response, so no manual `.decode("utf-8")` is needed.
payload = test_data.to_csv(sep=",", header=False, index=False)
predictions = predictor.predict(payload)
predictions

---

### Clean up

In [None]:
# Deleting the endpoint releases its hosting instance (avoiding ongoing charges). It is off by
# default so the endpoint stays available for more predictions; set to True when finished.
DELETE_ENDPOINT = False
if DELETE_ENDPOINT:
    # Replaces the commented-out `sess.delete_endpoint(predictor.endpoint)`: `predictor.endpoint`
    # is a deprecated alias of `endpoint_name` in SDK v2, and Predictor.delete_endpoint() also
    # removes the endpoint configuration by default.
    predictor.delete_endpoint()