# AutoGluon Tabular with SageMaker

[AutoGluon](https://github.com/awslabs/autogluon) automates machine learning tasks, enabling you to easily achieve strong predictive performance in your applications. With just a few lines of code, you can train and deploy high-accuracy deep learning models on tabular, image, and text data.
This notebook shows how to use AutoGluon-Tabular with Amazon SageMaker by creating custom containers.

## Prerequisites

If using a SageMaker hosted notebook, select kernel `conda_mxnet_p36`.

In [1]:
import subprocess

# Make sure docker compose is set up properly for local mode
subprocess.run("./setup.sh", shell=True)

CompletedProcess(args='./setup.sh', returncode=126)

In [2]:
# For Studio: install unzip non-interactively
subprocess.run("apt-get update -y", shell=True)
subprocess.run("apt-get install -y unzip", shell=True)

CompletedProcess(args='apt install unzip', returncode=127)

In [3]:
import os
import sys
import boto3
import sagemaker
from time import sleep
from collections import Counter
import numpy as np
import pandas as pd
from sagemaker import get_execution_role, local, Model, utils, s3
from sagemaker.estimator import Estimator
from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import StringDeserializer
from sklearn.metrics import accuracy_score, classification_report
from IPython.display import display, HTML
from IPython.core.interactiveshell import InteractiveShell

# Print settings
InteractiveShell.ast_node_interactivity = "all"
pd.set_option("display.max_columns", 500)
pd.set_option("display.max_rows", 10)

# Account/s3 setup
session = sagemaker.Session()
local_session = local.LocalSession()
bucket = session.default_bucket()
prefix = "sagemaker/autogluon-tabular"
region = session.boto_region_name
role = get_execution_role()
client = session.boto_session.client(
    "sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region)
)
account = client.get_caller_identity()["Account"]

registry_uri_training = sagemaker.image_uris.retrieve(
    "mxnet",
    region,
    version="1.7.0",
    py_version="py3",
    instance_type="ml.m5.2xlarge",
    image_scope="training",
)
registry_uri_inference = sagemaker.image_uris.retrieve(
    "mxnet",
    region,
    version="1.7.0",
    py_version="py3",
    instance_type="ml.m5.2xlarge",
    image_scope="inference",
)
ecr_uri_prefix = account + "." + ".".join(registry_uri_training.split("/")[0].split(".")[1:])
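
The `ecr_uri_prefix` line above keeps the `dkr.ecr.<region>.amazonaws.com` host suffix of the AWS-managed registry but swaps in your own account ID, so the custom images are pushed to your account. The sketch below isolates that string manipulation so it can be checked without AWS credentials; the helper names are hypothetical and the account IDs and image URI are illustrative placeholders, not real values.

```python
from typing import Tuple


def split_registry_uri(registry_uri: str) -> Tuple[str, str]:
    """Split an ECR image URI into (registry_account, host_suffix).

    The host part has the form '<account>.dkr.ecr.<region>.amazonaws.com'.
    """
    host = registry_uri.split("/")[0]  # drop '/<repository>:<tag>'
    registry_account, _, host_suffix = host.partition(".")
    return registry_account, host_suffix


def ecr_prefix_for_account(account: str, registry_uri: str) -> str:
    """Reuse the registry's host suffix with your own account ID."""
    _, host_suffix = split_registry_uri(registry_uri)
    return f"{account}.{host_suffix}"


# Placeholder account IDs and image URI, for illustration only
example_uri = "111122223333.dkr.ecr.us-east-1.amazonaws.com/mxnet-training:1.7.0-cpu-py3"
print(split_registry_uri(example_uri))
print(ecr_prefix_for_account("123456789012", example_uri))
```

The first element returned by `split_registry_uri` is the same value the build-script calls below compute inline with `registry_uri_training.split('/')[0].split('.')[0]`.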

### Build docker images

Build the training and inference images and push them to Amazon ECR.

In [9]:
training_algorithm_name = "autogluon-sagemaker-training"
inference_algorithm_name = "autogluon-sagemaker-inference"

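First, you may want to remove existing Docker images to make room for building the AutoGluon containers. Be aware that `docker system prune -af` removes all stopped containers, unused networks, and every image not used by a container.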

In [None]:
subprocess.run("docker system prune -af", shell=True)

In [None]:
subprocess.run(
    f"/bin/bash ./container-training/build_push_training.sh {account} {region} {training_algorithm_name} {ecr_uri_prefix} {registry_uri_training.split('/')[0].split('.')[0]} {registry_uri_training}",
    shell=True,
)
subprocess.run("docker system prune -af", shell=True)

In [None]:
subprocess.run(
    f"/bin/bash ./container-inference/build_push_inference.sh {account} {region} {inference_algorithm_name} {ecr_uri_prefix} {registry_uri_training.split('/')[0].split('.')[0]} {registry_uri_inference}",
    shell=True,
)
subprocess.run("docker system prune -af", shell=True)

### Alternative way of building docker images using sm-docker

The Amazon SageMaker Studio Image Build convenience package lets data scientists and developers build custom container images directly from Studio notebooks using a CLI (`sm-docker`).
Newly built Docker images are tagged and pushed to Amazon ECR.

To use the CLI, you need to ensure that the Amazon SageMaker execution role used by your Studio notebook environment (or another AWS Identity and Access Management (IAM) role, if you prefer) has the permissions required to interact with the resources used by the CLI, including CodeBuild and Amazon ECR. The role's trust policy must also allow CodeBuild to assume it.

You also need to make sure the appropriate permissions are included in your role to run the build in CodeBuild, create a repository in Amazon ECR, and push images to that repository. 

See also: https://aws.amazon.com/blogs/machine-learning/using-the-amazon-sagemaker-studio-image-build-cli-to-build-container-images-from-your-studio-notebooks/
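
As a rough sketch, a trust policy that lets CodeBuild assume the role has the shape built below. It assumes the same role also serves as the Studio execution role (hence the `sagemaker.amazonaws.com` principal); `build_trust_policy` is a hypothetical helper, and the linked post remains the reference for the exact trust and permission policies.

```python
import json


def build_trust_policy(services):
    """Return an IAM trust policy letting the given AWS service principals assume the role."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": sorted(services)},
                "Action": "sts:AssumeRole",
            }
        ],
    }


policy = build_trust_policy(["codebuild.amazonaws.com", "sagemaker.amazonaws.com"])
print(json.dumps(policy, indent=2))
```

The resulting JSON could be saved to a file and applied with `aws iam update-assume-role-policy --role-name <role> --policy-document file://<file>`, or pasted into the IAM console.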

In [None]:
# subprocess.run("pip install sagemaker-studio-image-build", shell=True)

In [None]:
"""
training_repo_name = training_algorithm_name + ':latest'
training_repo_name  

!sm-docker build . --repository {training_repo_name} \
--file ./container-training/Dockerfile.training --build-arg REGISTRY_URI={registry_uri_training}

inference_repo_name = inference_algorithm_name + ':latest'
inference_repo_name  

!sm-docker build . --repository {inference_repo_name} \
--file ./container-inference/Dockerfile.inference --build-arg REGISTRY_URI={registry_uri_inference}
"""

### Get the data

In this example we use a dataset of audio features (`normalized.csv`) to build a multiclass classification model that predicts the genre of a music track (the `genre` column).  
First we load the data and split it into train and test sets. AutoGluon does not require a separate validation set: it holds out validation data internally or, with presets such as `best_quality`, uses bagged k-fold cross-validation.
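
The `sample`/`drop` split in the next cell relies on `train` keeping the original row index, so `data.drop(train.index)` yields exactly the remaining rows. The toy check below (synthetic data, not the real dataset) confirms the two parts are disjoint and prints per-class counts, which are worth inspecting for a 10-class problem because a plain random split is not stratified.

```python
import pandas as pd

# Synthetic stand-in for the genre dataset (hypothetical values)
toy = pd.DataFrame({
    "genre": ["jazz", "metal", "reggae", "pop"] * 25,
    "meanTempogram": [i / 100 for i in range(100)],
})

train_toy = toy.sample(frac=0.8, random_state=42)
test_toy = toy.drop(train_toy.index)

# The two splits are disjoint and together cover every row
assert train_toy.index.intersection(test_toy.index).empty
assert len(train_toy) + len(test_toy) == len(toy)

# A random split does not guarantee balanced classes, so compare per-class counts
print(train_toy["genre"].value_counts().sort_index())
print(test_toy["genre"].value_counts().sort_index())
```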

In [4]:
# Download and unzip the direct-marketing sample data (unused here: this notebook reads ./normalized.csv)
# subprocess.run(
#     f"aws s3 cp --region {region} s3://sagemaker-sample-data-{region}/autopilot/direct_marketing/bank-additional.zip .",
#     shell=True,
# )
# subprocess.run("unzip -qq -o bank-additional.zip", shell=True)
# subprocess.run("rm bank-additional.zip", shell=True)

local_data_path = "./normalized.csv"
data = pd.read_csv(local_data_path)

# Split train/test data
train = data.sample(frac=0.8, random_state=42)
test = data.drop(train.index)

# Split test X/y
label = "genre"
y_test = test[label]
X_test = test.drop(columns=[label])

##### Check the data

In [5]:
train.head(3)
train.shape

test.head(3)
test.shape

X_test.head(3)
X_test.shape

Unnamed: 0,genre,meanTempogram,stdTempogram,varTempogram,meanMFCC_1,stdMFCC_1,varMFCC_1,meanMFCC_2,stdMFCC_2,varMFCC_2,meanMFCC_3,stdMFCC_3,varMFCC_3,meanMFCC_4,stdMFCC_4,varMFCC_4,meanMFCC_5,stdMFCC_5,varMFCC_5,meanMFCC_6,stdMFCC_6,varMFCC_6,meanMFCC_7,stdMFCC_7,varMFCC_7,meanMFCC_8,stdMFCC_8,varMFCC_8,meanMFCC_9,stdMFCC_9,varMFCC_9,meanMFCC_10,stdMFCC_10,varMFCC_10,meanMFCC_11,stdMFCC_11,varMFCC_11,meanMFCC_12,stdMFCC_12,varMFCC_12,meanMFCC_13,stdMFCC_13,varMFCC_13,meanSpectralCentroid,stdSpectralCentroid,varSpectralCentroid,meanZeroCrossingRate,stdZeroCrossingRate,varZeroCrossingRate,meanChromaFrequencies,stdChromaFrequencies,varChromaFrequencies,meanSpectralRollOff,stdSpectralRollOff,varSpectralRollOff,meanSpectralBandwidth,stdSpectralBandwidth,varSpectralBandwidth,meanSpectralContrast,stdSpectralContrast,varSpectralContrast,meanSpectralFlatness,stdSpectralFlatness,varSpectralFlatness
521,jazz,0.314589,0.254149,0.160138,0.608754,0.227671,0.077467,0.673336,0.259353,0.117975,0.559693,0.31183,0.139814,0.396798,0.340712,0.189345,0.584283,0.093185,0.032761,0.492108,0.194535,0.09047,0.714469,0.096024,0.045933,0.355306,0.184465,0.092547,0.591702,0.174561,0.094353,0.358224,0.178639,0.087415,0.57326,0.095917,0.036938,0.437105,0.184182,0.094273,0.645136,0.283869,0.160472,0.221815,0.201252,0.056115,0.12143,0.087019,0.013563,0.229677,0.786529,0.749919,0.261275,0.380175,0.161852,0.378606,0.392879,0.207209,0.428007,0.177257,0.095407,0.010826,0.037977,0.001465
737,metal,0.427414,0.367745,0.252432,0.791073,0.377333,0.176631,0.51264,0.14285,0.052731,0.500528,0.448799,0.250501,0.876582,0.247469,0.121977,0.573348,0.143785,0.055754,0.791817,0.139737,0.0599,0.505028,0.206719,0.112089,0.667056,0.219528,0.114842,0.416856,0.182787,0.099636,0.67352,0.103689,0.045907,0.443367,0.146669,0.061546,0.66492,0.153438,0.075713,0.46894,0.200424,0.103149,0.365758,0.296179,0.107968,0.255122,0.232964,0.067747,0.650728,0.514647,0.460181,0.445007,0.355386,0.143136,0.464974,0.209664,0.080678,0.535574,0.772773,0.674221,0.016682,0.042346,0.001818
740,metal,0.395358,0.326509,0.217449,0.676874,0.186592,0.056942,0.356315,0.326525,0.164674,0.567737,0.338477,0.158992,0.510937,0.309423,0.165432,0.63961,0.276527,0.133473,0.552426,0.350408,0.199235,0.768443,0.307101,0.184308,0.474861,0.393186,0.247407,0.697506,0.235522,0.135296,0.576919,0.320299,0.184944,0.633876,0.141666,0.058964,0.605591,0.285196,0.163214,0.649899,0.238405,0.128192,0.6199,0.50457,0.27887,0.469943,0.395739,0.174642,0.599233,0.643202,0.593161,0.683152,0.598445,0.375799,0.712569,0.586955,0.398239,0.432068,0.425076,0.287915,0.105506,0.192903,0.037308


(800, 64)

Unnamed: 0,genre,meanTempogram,stdTempogram,varTempogram,meanMFCC_1,stdMFCC_1,varMFCC_1,meanMFCC_2,stdMFCC_2,varMFCC_2,meanMFCC_3,stdMFCC_3,varMFCC_3,meanMFCC_4,stdMFCC_4,varMFCC_4,meanMFCC_5,stdMFCC_5,varMFCC_5,meanMFCC_6,stdMFCC_6,varMFCC_6,meanMFCC_7,stdMFCC_7,varMFCC_7,meanMFCC_8,stdMFCC_8,varMFCC_8,meanMFCC_9,stdMFCC_9,varMFCC_9,meanMFCC_10,stdMFCC_10,varMFCC_10,meanMFCC_11,stdMFCC_11,varMFCC_11,meanMFCC_12,stdMFCC_12,varMFCC_12,meanMFCC_13,stdMFCC_13,varMFCC_13,meanSpectralCentroid,stdSpectralCentroid,varSpectralCentroid,meanZeroCrossingRate,stdZeroCrossingRate,varZeroCrossingRate,meanChromaFrequencies,stdChromaFrequencies,varChromaFrequencies,meanSpectralRollOff,stdSpectralRollOff,varSpectralRollOff,meanSpectralBandwidth,stdSpectralBandwidth,varSpectralBandwidth,meanSpectralContrast,stdSpectralContrast,varSpectralContrast,meanSpectralFlatness,stdSpectralFlatness,varSpectralFlatness
1,reggae,0.251674,0.209781,0.127566,0.601865,0.411215,0.204393,0.617099,0.448078,0.266062,0.569494,0.381634,0.192466,0.598732,0.686925,0.542005,0.506231,0.455667,0.278308,0.73221,0.429512,0.266777,0.431441,0.67615,0.54979,0.491202,0.540061,0.388292,0.501946,0.713977,0.600301,0.697637,0.478588,0.323442,0.503482,0.286948,0.147788,0.555267,0.374413,0.23426,0.523616,0.472941,0.321634,0.290066,0.373626,0.162326,0.188749,0.24154,0.072156,0.400401,0.826589,0.795333,0.365662,0.541328,0.311284,0.407825,0.488073,0.293583,0.652288,0.562604,0.424493,0.009669,0.029061,0.000862
4,reggae,0.174107,0.163321,0.095551,0.752637,0.387832,0.185024,0.348909,0.63572,0.465276,0.597469,0.309881,0.138456,0.435811,0.509114,0.340705,0.672473,0.298856,0.149022,0.478061,0.528161,0.362653,0.632067,0.531175,0.38747,0.407135,0.526068,0.373733,0.64379,0.534757,0.396267,0.362759,0.391483,0.243373,0.681846,0.222939,0.105115,0.524689,0.286852,0.164446,0.781792,0.406446,0.260007,0.616986,0.766661,0.605143,0.416471,0.532536,0.302367,0.585827,0.662927,0.614202,0.645751,0.775026,0.613481,0.686209,0.584828,0.395828,0.315782,0.124489,0.063318,0.215885,0.452068,0.20452
13,reggae,0.190306,0.162489,0.094997,0.818058,0.454506,0.242719,0.510609,0.609768,0.434636,0.484726,0.371896,0.184652,0.459279,0.481339,0.313109,0.752121,0.376308,0.208483,0.467799,0.605514,0.446873,0.773822,0.520902,0.376888,0.479969,0.372597,0.229764,0.620313,0.491667,0.352543,0.513828,0.403445,0.25381,0.848489,0.381108,0.220689,0.61273,0.477182,0.327903,0.772189,0.301585,0.17373,0.390945,0.514024,0.288482,0.241661,0.28183,0.09469,0.414717,0.831555,0.801012,0.463441,0.685964,0.48638,0.523306,0.765962,0.626421,0.320781,0.135522,0.069769,0.061493,0.168321,0.028419


(200, 64)

Unnamed: 0,meanTempogram,stdTempogram,varTempogram,meanMFCC_1,stdMFCC_1,varMFCC_1,meanMFCC_2,stdMFCC_2,varMFCC_2,meanMFCC_3,stdMFCC_3,varMFCC_3,meanMFCC_4,stdMFCC_4,varMFCC_4,meanMFCC_5,stdMFCC_5,varMFCC_5,meanMFCC_6,stdMFCC_6,varMFCC_6,meanMFCC_7,stdMFCC_7,varMFCC_7,meanMFCC_8,stdMFCC_8,varMFCC_8,meanMFCC_9,stdMFCC_9,varMFCC_9,meanMFCC_10,stdMFCC_10,varMFCC_10,meanMFCC_11,stdMFCC_11,varMFCC_11,meanMFCC_12,stdMFCC_12,varMFCC_12,meanMFCC_13,stdMFCC_13,varMFCC_13,meanSpectralCentroid,stdSpectralCentroid,varSpectralCentroid,meanZeroCrossingRate,stdZeroCrossingRate,varZeroCrossingRate,meanChromaFrequencies,stdChromaFrequencies,varChromaFrequencies,meanSpectralRollOff,stdSpectralRollOff,varSpectralRollOff,meanSpectralBandwidth,stdSpectralBandwidth,varSpectralBandwidth,meanSpectralContrast,stdSpectralContrast,varSpectralContrast,meanSpectralFlatness,stdSpectralFlatness,varSpectralFlatness
1,0.251674,0.209781,0.127566,0.601865,0.411215,0.204393,0.617099,0.448078,0.266062,0.569494,0.381634,0.192466,0.598732,0.686925,0.542005,0.506231,0.455667,0.278308,0.73221,0.429512,0.266777,0.431441,0.67615,0.54979,0.491202,0.540061,0.388292,0.501946,0.713977,0.600301,0.697637,0.478588,0.323442,0.503482,0.286948,0.147788,0.555267,0.374413,0.23426,0.523616,0.472941,0.321634,0.290066,0.373626,0.162326,0.188749,0.24154,0.072156,0.400401,0.826589,0.795333,0.365662,0.541328,0.311284,0.407825,0.488073,0.293583,0.652288,0.562604,0.424493,0.009669,0.029061,0.000862
4,0.174107,0.163321,0.095551,0.752637,0.387832,0.185024,0.348909,0.63572,0.465276,0.597469,0.309881,0.138456,0.435811,0.509114,0.340705,0.672473,0.298856,0.149022,0.478061,0.528161,0.362653,0.632067,0.531175,0.38747,0.407135,0.526068,0.373733,0.64379,0.534757,0.396267,0.362759,0.391483,0.243373,0.681846,0.222939,0.105115,0.524689,0.286852,0.164446,0.781792,0.406446,0.260007,0.616986,0.766661,0.605143,0.416471,0.532536,0.302367,0.585827,0.662927,0.614202,0.645751,0.775026,0.613481,0.686209,0.584828,0.395828,0.315782,0.124489,0.063318,0.215885,0.452068,0.20452
13,0.190306,0.162489,0.094997,0.818058,0.454506,0.242719,0.510609,0.609768,0.434636,0.484726,0.371896,0.184652,0.459279,0.481339,0.313109,0.752121,0.376308,0.208483,0.467799,0.605514,0.446873,0.773822,0.520902,0.376888,0.479969,0.372597,0.229764,0.620313,0.491667,0.352543,0.513828,0.403445,0.25381,0.848489,0.381108,0.220689,0.61273,0.477182,0.327903,0.772189,0.301585,0.17373,0.390945,0.514024,0.288482,0.241661,0.28183,0.09469,0.414717,0.831555,0.801012,0.463441,0.685964,0.48638,0.523306,0.765962,0.626421,0.320781,0.135522,0.069769,0.061493,0.168321,0.028419


(200, 63)

Upload the data to S3

In [6]:
train_file = "train.csv"
train.to_csv(train_file, index=False)
train_s3_path = session.upload_data(train_file, key_prefix="{}/data".format(prefix))

test_file = "test.csv"
test.to_csv(test_file, index=False)
test_s3_path = session.upload_data(test_file, key_prefix="{}/data".format(prefix))

X_test_file = "X_test.csv"
X_test.to_csv(X_test_file, index=False)
X_test_s3_path = session.upload_data(X_test_file, key_prefix="{}/data".format(prefix))

## Hyperparameter Selection

The only required setting for training is the target label, `init_args['label']`.

Additional optional hyperparameters can be passed to the `autogluon.tabular.TabularPredictor.fit` function via `fit_args`.

Below is a more in-depth example of AutoGluon-Tabular hyperparameters, based on the tutorial [Predicting Columns in a Table - In Depth](https://auto.gluon.ai/stable/tutorials/tabular_prediction/tabular-indepth.html). See [fit parameters](https://auto.gluon.ai/stable/_modules/autogluon/tabular/predictor/predictor.html#TabularPredictor) for further information. Note that for hyperparameter ranges to work in SageMaker, values passed in `fit_args['hyperparameters']` must be represented as strings.

```python
nn_options = {
    'num_epochs': "10",
    'learning_rate': "ag.space.Real(1e-4, 1e-2, default=5e-4, log=True)",
    'activation': "ag.space.Categorical('relu', 'softrelu', 'tanh')",
    'layers': "ag.space.Categorical([100],[1000],[200,100],[300,200,100])",
    'dropout_prob': "ag.space.Real(0.0, 0.5, default=0.1)"
}

gbm_options = {
    'num_boost_round': "100",
    'num_leaves': "ag.space.Int(lower=26, upper=66, default=36)"
}

model_hps = {'NN': nn_options, 'GBM': gbm_options}

init_args = {
    'eval_metric': 'accuracy',  # 'roc_auc' only applies to binary problems; genre is multiclass
    'label': 'genre'
}

fit_args = {
    'presets': ['best_quality', 'optimize_for_deployment'],
    'time_limit': 60*10,  # seconds
    'hyperparameters': model_hps,
    'hyperparameter_tune_kwargs': 'auto'  # search over the ag.space ranges above
}

hyperparameters = {
    'init_args': init_args,
    'fit_args': fit_args,
    'feature_importance': True
}
```
**Note:** Your hyperparameter choices may affect the size of the model package, which could result in additional time taken to upload your model and complete training. Including `'optimize_for_deployment'` in the list of `fit_args['presets']` is recommended to greatly reduce upload times.
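
Because search-space values in `fit_args['hyperparameters']` must be strings, a small helper can convert every leaf of a nested options dictionary before it is handed to the estimator. `stringify_leaves` is a hypothetical convenience function, not part of the SageMaker SDK or AutoGluon; it leaves existing strings untouched so `ag.space...` expressions pass through unchanged.

```python
def stringify_leaves(options):
    """Recursively convert the non-dict leaf values of a nested dict to strings."""
    if isinstance(options, dict):
        return {key: stringify_leaves(value) for key, value in options.items()}
    # Strings (including 'ag.space...' expressions) are returned as-is
    return options if isinstance(options, str) else str(options)


gbm_options = {"num_boost_round": 100, "num_leaves": "ag.space.Int(lower=26, upper=66, default=36)"}
model_hps = stringify_leaves({"GBM": gbm_options})
print(model_hps)
```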


In [10]:
# Define required label and optional additional parameters
init_args = {"label": "genre"}

# Define additional parameters
fit_args = {
    # Adding 'best_quality' to the presets list gives better performance (but a longer runtime)
    "presets": ["best_quality"],
    "time_limit": 60 * 60,  # seconds
}

# Pass init_args and fit_args to the SageMaker estimator as hyperparameters
# (optionally also add "feature_importance": True)
hyperparameters = {"init_args": init_args, "fit_args": fit_args}

tags = [{"Key": "AlgorithmName", "Value": "AutoGluon-Tabular"}]

## Train

For local training, set `instance_type` to `local`.  
For non-local training, the recommended instance type is `ml.m5.2xlarge`.  

**Note:** Depending on how many underlying models are trained, `volume_size` may need to be increased so that they all fit on disk.

In [None]:
%%time

#instance_type = "ml.m5.2xlarge"
instance_type = 'local'

ecr_image = f"{ecr_uri_prefix}/{training_algorithm_name}:latest"

estimator = Estimator(
    image_uri=ecr_image,
    role=role,
    instance_count=1,
    instance_type=instance_type,
    hyperparameters=hyperparameters,
    volume_size=1000,
    tags=tags,
)

# Set inputs. Test data is optional, but requires a label column.
inputs = {"training": train_s3_path, "testing": test_s3_path}

estimator.fit(inputs)

Creating 32ym1asrgc-algo-1-82cy1 ... 
Creating 32ym1asrgc-algo-1-82cy1 ... done
Attaching to 32ym1asrgc-algo-1-82cy1
[36m32ym1asrgc-algo-1-82cy1 |[0m 2021-12-01 20:00:31,278 sagemaker-training-toolkit INFO     Imported framework sagemaker_mxnet_container.training
[36m32ym1asrgc-algo-1-82cy1 |[0m 2021-12-01 20:00:31,281 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)
[36m32ym1asrgc-algo-1-82cy1 |[0m 2021-12-01 20:00:31,281 sagemaker-training-toolkit INFO     Failed to parse hyperparameter init_args value {'label': 'genre'} to Json.
[36m32ym1asrgc-algo-1-82cy1 |[0m Returning the value itself
[36m32ym1asrgc-algo-1-82cy1 |[0m 2021-12-01 20:00:31,281 sagemaker-training-toolkit INFO     Failed to parse hyperparameter fit_args value {'presets': ['best_quality'], 'time_limit': 3600} to Json.
[36m32ym1asrgc-algo-1-82cy1 |[0m Returning the value itself
[36m32ym1asrgc-algo-1-82cy1 |[0m 2021-12-01 20:00:31,293 sagemaker_mxnet_container.training INFO
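
The `Failed to parse hyperparameter ... to Json` messages are logged at INFO level: the logged values, such as `{'label': 'genre'}`, are Python dictionary reprs with single quotes, which are not valid JSON, so the toolkit passes them through as raw strings. A training script can still recover the dictionaries; the sketch below shows one way with `ast.literal_eval`. `parse_hyperparameter` is a hypothetical helper, and whether this notebook's training container decodes the values exactly this way is an assumption.

```python
import ast
import json


def parse_hyperparameter(raw: str):
    """Decode a hyperparameter string that may be JSON or a Python literal."""
    try:
        return json.loads(raw)  # e.g. '{"label": "genre"}'
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(raw)  # e.g. "{'label': 'genre'}"
    except (ValueError, SyntaxError):
        return raw  # plain string such as 'best_quality'


parsed_init_args = parse_hyperparameter("{'label': 'genre'}")
parsed_fit_args = parse_hyperparameter("{'presets': ['best_quality'], 'time_limit': 3600}")
print(parsed_init_args["label"], parsed_fit_args["time_limit"])
```

`ast.literal_eval` only evaluates Python literals, so unlike `eval` it will not execute arbitrary code passed in as a hyperparameter.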

### Review the performance of the trained model

In [None]:
from utils.ag_utils import launch_viewer

launch_viewer(is_debug=False)

### Create Model

In [None]:
# Create predictor object
class AutoGluonTabularPredictor(Predictor):
    def __init__(self, *args, **kwargs):
        super().__init__(
            *args, serializer=CSVSerializer(), deserializer=StringDeserializer(), **kwargs
        )

In [None]:
ecr_image = f"{ecr_uri_prefix}/{inference_algorithm_name}:latest"

if instance_type == "local":
    model = estimator.create_model(
        image_uri=ecr_image, role=role, predictor_cls=AutoGluonTabularPredictor
    )
else:
    # model_uri = os.path.join(estimator.output_path, estimator._current_job_name, "output", "model.tar.gz")
    model_uri = estimator.model_data
    model = Model(
        ecr_image,
        model_data=model_uri,
        role=role,
        sagemaker_session=session,
        predictor_cls=AutoGluonTabularPredictor,
    )

### Batch Transform

For local mode, either `s3://<bucket>/<prefix>/output/` or `file:///<absolute_local_path>` can be used as the output path.

If the test data includes the label column, you can also evaluate prediction performance (to do so, pass `test_s3_path`, as below, instead of `X_test_s3_path`).
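
With `strategy="MultiRecord"` and `split_type="Line"`, Batch Transform splits the input file on newlines and sends several records per request, keeping each request under `max_payload` (given in MB; set to 6 in the next cell). The sketch below imitates that packing locally to illustrate the idea. It is a simplified greedy illustration under that assumption, not SageMaker's actual implementation, and it measures the limit in bytes so the toy example stays small.

```python
def pack_records(lines, max_payload_bytes):
    """Greedily pack newline-delimited records into mini-batches under a size limit."""
    batches, current, current_size = [], [], 0
    for line in lines:
        record = line + "\n"
        size = len(record.encode("utf-8"))
        if size > max_payload_bytes:
            raise ValueError("a single record exceeds the payload limit")
        # Start a new batch when adding this record would exceed the limit
        if current and current_size + size > max_payload_bytes:
            batches.append("".join(current))
            current, current_size = [], 0
        current.append(record)
        current_size += size
    if current:
        batches.append("".join(current))
    return batches


rows = [f"{i},0.25,0.5" for i in range(10)]  # each record is 11 bytes with its newline
batches = pack_records(rows, max_payload_bytes=40)
print(len(batches), [len(b) for b in batches])
```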

In [None]:
output_path = f"s3://{bucket}/{prefix}/output/"
# output_path = f'file://{os.getcwd()}'

transformer = model.transformer(
    instance_count=1,
    instance_type=instance_type,
    strategy="MultiRecord",
    max_payload=6,
    max_concurrent_transforms=1,
    output_path=output_path,
)

transformer.transform(test_s3_path, content_type="text/csv", split_type="Line")
transformer.wait()

### Endpoint

##### Deploy remote or local endpoint

In [None]:
instance_type = "ml.m5.2xlarge"
# instance_type = 'local'

predictor = model.deploy(initial_instance_count=1, instance_type=instance_type)

##### Attach to endpoint (or reattach if kernel was restarted)

In [None]:
# Select standard or local session based on instance_type
if instance_type == "local":
    sess = local_session
else:
    sess = session

# Attach to endpoint
predictor = AutoGluonTabularPredictor(predictor.endpoint_name, sagemaker_session=sess)

##### Predict on unlabeled test data

In [None]:
print(predictor.endpoint_name)

!aws sagemaker list-endpoints --region {region}

results = predictor.predict(X_test.to_csv(index=False)).splitlines()



In [None]:
# Class order assumed for the returned probabilities; verify that it matches
# the class order used by the trained AutoGluon predictor
genres = ['jazz', 'metal', 'country', 'classical', 'blues', 'hiphop', 'reggae', 'disco', 'pop', 'rock']

# Each response line is a comma-separated list of class probabilities
prob_results = [[float(p) for p in result.split(",")] for result in results]

for probs in prob_results:
    max_index = int(np.argmax(probs))  # index of the most likely class
    print(probs)
    print(max_index)
    print(genres[max_index])


##### Predict on data that includes the label column  
Prediction performance metrics will be printed to the endpoint logs.

In [None]:
results = predictor.predict(test.to_csv(index=False)).splitlines()

# Map each row of class probabilities to the most likely genre
# (uses the `genres` class order defined above)
y_results = np.array([genres[int(np.argmax([float(p) for p in r.split(",")]))] for r in results])

print(Counter(y_results))

##### Check that classification performance metrics match the evaluation printed to the endpoint logs

In [None]:
print("accuracy: {}".format(accuracy_score(y_true=y_test, y_pred=y_results)))
print(classification_report(y_true=y_test, y_pred=y_results, digits=6))

##### Clean up endpoint

In [None]:
predictor.delete_endpoint()

## Explainability with Amazon SageMaker Clarify

A growing number of business needs and legislative regulations require explanations of why a model made a certain decision. SHAP (SHapley Additive exPlanations) is an approach to explaining the output of machine learning models. SHAP values represent a feature's contribution to a change in the model output. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision.
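
To make the definition concrete, the sketch below computes exact Shapley values for a hypothetical three-feature toy model by enumerating every feature subset and filling "absent" features from a baseline. This brute-force version is only practical for a handful of features; Clarify's SHAP implementation is based on Kernel SHAP, which estimates the same quantities by sampling (see `baseline` and `num_samples` in `SHAPConfig` below). The result also shows the additivity property: the values sum to the model output at `x` minus the output at the baseline.

```python
from itertools import combinations
from math import factorial


def exact_shapley(model, x, baseline):
    """Exact Shapley values, replacing 'absent' features with baseline values."""
    n = len(x)

    def value(subset):
        # Features in `subset` take their real values; the rest use the baseline
        point = [x[i] if i in subset else baseline[i] for i in range(n)]
        return model(point)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Weight of a coalition of this size in the Shapley formula
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def toy_model(f):
    # Hypothetical model with an interaction between features 0 and 1
    return 2 * f[0] + 3 * f[1] - f[2] + f[0] * f[1]


x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
phi = exact_shapley(toy_model, x, baseline)
print(phi)
print(sum(phi), toy_model(x) - toy_model(baseline))
```

The interaction term `f[0] * f[1]` contributes 2 at `x`, and the Shapley values split it evenly between features 0 and 1.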

##### Set parameters for SHAP calculation

In [None]:
seed = 0
num_rows = 500

# Write a csv file used by SageMaker Clarify
test_explainavility_file = "test_explainavility.csv"
train.head(num_rows).to_csv(test_explainavility_file, index=False, header=False)
test_explainavility_s3_path = session.upload_data(
    test_explainavility_file, key_prefix="{}/data".format(prefix)
)

##### Specify computing resources

In [None]:
from sagemaker import clarify

model_name = estimator.latest_training_job.job_name
container_def = model.prepare_container_def()
session.create_model(model_name, role, container_def)

clarify_processor = clarify.SageMakerClarifyProcessor(
    role=role, instance_count=1, instance_type="ml.c4.xlarge", sagemaker_session=session
)
model_config = clarify.ModelConfig(
    model_name=model_name, instance_type="ml.c5.xlarge", instance_count=1, accept_type="text/csv"
)

##### Run a SageMaker Clarify job

In [None]:
shap_config = clarify.SHAPConfig(
    baseline=X_test.sample(15, random_state=seed).values.tolist(),
    num_samples=100,
    agg_method="mean_abs",
)

explainability_output_path = "s3://{}/{}/{}/clarify-explainability".format(
    bucket, prefix, model_name
)
explainability_data_config = clarify.DataConfig(
    s3_data_input_path=test_explainavility_s3_path,
    s3_output_path=explainability_output_path,
    label="y",
    headers=train.columns.to_list(),
    dataset_type="text/csv",
)

predictions_config = clarify.ModelPredictedLabelConfig(probability_threshold=0.5)

clarify_processor.run_explainability(
    data_config=explainability_data_config,
    model_config=model_config,
    explainability_config=shap_config,
)

##### View the Explainability Report

You can view the explainability report in Studio under the Experiments tab. If you're not a Studio user, you can download the report from the S3 output path below.

In [None]:
subprocess.run(f"aws s3 cp {explainability_output_path} . --recursive", shell=True)

Global explanatory methods help you understand the model and its feature contributions in aggregate over multiple data points. Here we show a bar plot of the mean absolute SHAP value for each feature.
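
Taking the absolute value before averaging matters: positive and negative contributions would otherwise cancel, hiding features that push predictions strongly in both directions. The minimal sketch below performs the same aggregation as `shap_values_.abs().mean()` on a hypothetical list of per-row SHAP values, then ranks features by that score, which is the ordering the bar plot uses. `mean_abs_by_feature` is a hypothetical helper.

```python
def mean_abs_by_feature(rows, feature_names):
    """Global importance: mean absolute SHAP value per feature across rows."""
    n_rows = len(rows)
    return {
        name: sum(abs(row[j]) for row in rows) / n_rows
        for j, name in enumerate(feature_names)
    }


# Hypothetical per-row SHAP values for two features
shap_rows = [[0.4, -0.1], [-0.2, 0.3], [0.6, -0.2]]
importance = mean_abs_by_feature(shap_rows, ["meanMFCC_1", "meanTempogram"])
ranked = sorted(importance.items(), key=lambda kv: kv[1], reverse=True)
print(ranked)
```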

In [None]:
subprocess.run(f"{sys.executable} -m pip install shap", shell=True)

##### Compute global SHAP values from out.csv

In [None]:
shap_values_ = pd.read_csv("explanations_shap/out.csv")
shap_values_.abs().mean().to_dict()

In [None]:
num_features = len(train.head(num_rows).drop(columns=[label]).columns)

In [None]:
import shap

# out.csv is assumed to hold one block of num_features columns per model output
num_outputs = shap_values_.shape[1] // num_features
shap_values = np.split(shap_values_.to_numpy(), num_outputs, axis=1)
shap.summary_plot(
    shap_values,
    plot_type="bar",
    feature_names=train.head(num_rows).drop(columns=[label]).columns.tolist(),
)

The detailed summary plot below provides more context than the bar chart above. It shows which features are most important and, in addition, their range of effects over the dataset. The color shows how changes in a feature's value affect the prediction: red indicates higher values of the feature and blue indicates lower values (normalized per feature).

In [None]:
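# SHAP values for the second model output (assumes out.csv stores one block
# of num_features columns per output)
shap.summary_plot(
    shap_values_.to_numpy()[:, num_features : 2 * num_features],
    train.head(num_rows).drop(columns=[label]),
)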