# Making batch predictions using a TensorFlow model with Amazon SageMaker
This notebook shows how to make **batch predictions with TensorFlow on SageMaker**. Many customers have machine learning workloads that need a large number of predictions made reliably on a repeatable schedule. Unlike SageMaker's managed hosting service, where an endpoint runs continuously, batch prediction spins up inference compute capacity on demand and tears it down when the batch completes. For large batch workloads, this can mean significant cost savings compared with an always-on endpoint. Data scientists can stay focused on building the best models, because [SageMaker batch transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html) uses the same trained model for both hosted endpoints and batch predictions.

In [8]:
import boto3
import sagemaker
from sagemaker.tensorflow.serving import Model
from time import gmtime, strftime

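To help evaluate the batch prediction results, enter the list of class labels that your classifier was trained on, in the same order as the model's output classes. The evaluation code at the end of the notebook maps the index of the highest-scoring class back to a name in this list.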

In [9]:
class_name_list = ['013.Bobolink', '017.Cardinal']

This notebook assumes you have already trained your TensorFlow model and that the resulting model artifacts are in S3. Update the `training_job_name` variable to refer to your own training job so that the notebook has the full S3 URI of the model artifacts. These are the same model artifacts that were deployed to a SageMaker hosted endpoint in the previous lab; in this lab, we use the same trained model for batch predictions.

In [10]:
sess = sagemaker.Session()
bucket = sess.default_bucket()
prefix = 'DEMO-TF-image-classification-birds'
training_job_name = 'mpr-tf-ic-2019-09-08-04-24-51-045'

model_artifacts = 's3://{}/{}/output/model.tar.gz'.format(bucket, training_job_name)
print(model_artifacts)

s3://sagemaker-us-east-1-355151823911/mpr-tf-ic-2019-09-08-04-24-51-045/output/model.tar.gz
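Building the URI by hand, as the cell above does, assumes the training job wrote its output to `s3://<default bucket>/<job name>/output/`. If that assumption does not hold, you can ask SageMaker where the artifacts actually landed: the boto3 `describe_training_job` response includes `ModelArtifacts.S3ModelArtifacts`. The following is a minimal sketch; `lookup_model_artifacts` is a hypothetical helper, not part of the SageMaker SDK.

```python
def lookup_model_artifacts(training_job_name, sm_client=None):
    """Return the S3 URI of the model.tar.gz produced by a SageMaker training job.

    Hypothetical helper. `sm_client` defaults to a boto3 SageMaker client; any
    object exposing describe_training_job can be passed instead (e.g. for tests).
    """
    if sm_client is None:
        import boto3  # imported lazily so the helper can be exercised without AWS access
        sm_client = boto3.client('sagemaker')
    desc = sm_client.describe_training_job(TrainingJobName=training_job_name)
    return desc['ModelArtifacts']['S3ModelArtifacts']
```

In the notebook this would be used as `model_artifacts = lookup_model_artifacts(training_job_name)`.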


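Here we instantiate a `Model` object that points to the trained model artifacts and to the TensorFlow Serving image that will run inference for that model. We then register it with SageMaker through the low-level `create_model` API, so that the batch transform job can refer to it by name.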

In [11]:
client = boto3.client('sagemaker')

model_name = 'mpr-tf-ic-gpu-{}'.format(strftime("%d-%H-%M-%S", gmtime()))

tf_serving_model = Model(model_data=model_artifacts,
                         role=sagemaker.get_execution_role(),
                         image='520713654638.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tensorflow-serving:1.12-gpu',
                         framework_version='1.12', # 1.13.1-gpu not found; 1.12 works even if model trained in 1.13.1
                         sagemaker_session=sess)

batch_instance_type = 'ml.p3.2xlarge'
tf_serving_container = tf_serving_model.prepare_container_def(batch_instance_type)
model_params = {
    'ModelName': model_name,
    'Containers': [
        tf_serving_container
    ],
    'ExecutionRoleArn': sagemaker.get_execution_role()
}

client.create_model(**model_params)

{'ModelArn': 'arn:aws:sagemaker:us-east-1:355151823911:model/mpr-tf-ic-gpu-08-04-36-18',
 'ResponseMetadata': {'RequestId': '9c4dbcd5-42e3-4487-a379-3cb96e836203',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '9c4dbcd5-42e3-4487-a379-3cb96e836203',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '87',
   'date': 'Sun, 08 Sep 2019 04:36:17 GMT'},
  'RetryAttempts': 0}}
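The model name above is stamped with `%d-%H-%M-%S` only, so runs at the same day and time in different months would produce the same name, and `create_model` rejects a name that already exists. A minimal sketch of a more collision-resistant name, assuming SageMaker's 63-character limit on model names; `unique_model_name` is a hypothetical helper.

```python
from time import gmtime, strftime

MAX_MODEL_NAME_LEN = 63  # assumption: SageMaker's maximum ModelName length

def unique_model_name(base, now=None):
    """Build '<base>-YYYY-MM-DD-HH-MM-SS' (UTC), trimming base to fit the limit.

    Hypothetical helper: including year and month prevents collisions between
    runs in different months. `now` is a time.struct_time (defaults to gmtime()).
    """
    stamp = strftime('%Y-%m-%d-%H-%M-%S', now if now is not None else gmtime())
    room = MAX_MODEL_NAME_LEN - len(stamp) - 1  # reserve one character for the '-' separator
    return '{}-{}'.format(base[:room], stamp)
```

For example, `unique_model_name('mpr-tf-ic-gpu')` could replace the `strftime` call used to build `model_name`.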

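SageMaker batch transform jobs read their input from S3, and you must provide an S3 output path where SageMaker saves the resulting predictions. In this example, the input is the set of images under the `train/` prefix, so the accuracy computed at the end of the notebook reflects performance on training images rather than on unseen data.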

In [12]:
input_data_path = 's3://{}/{}/train/'.format(bucket, prefix)  # reuses the training images as batch input
output_data_path = 's3://{}/{}/{}'.format(bucket, prefix, 'batch-predictions')
print(output_data_path)

s3://sagemaker-us-east-1-355151823911/DEMO-TF-image-classification-birds/batch-predictions
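Batch transform writes one output object per input object, named after the input key (relative to the input prefix) with `.out` appended. The evaluation code at the end of this notebook relies on that layout to recover each image's true class from its folder name. A small sketch of the expected mapping, assuming default output naming; `expected_output_uri` is a hypothetical helper.

```python
def expected_output_uri(input_uri, input_prefix, output_path):
    """Map an input object's S3 URI to the URI batch transform is expected to write.

    Hypothetical helper; assumes the output key mirrors the input key relative
    to input_prefix, with '.out' appended.
    """
    if not input_uri.startswith(input_prefix):
        raise ValueError('{} is not under {}'.format(input_uri, input_prefix))
    relative_key = input_uri[len(input_prefix):].lstrip('/')
    return '{}/{}.out'.format(output_path.rstrip('/'), relative_key)
```

Under this assumption, an image at `<input_data_path>013.Bobolink/<image>.jpg` produces `<output_data_path>/013.Bobolink/<image>.jpg.out`.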


Before running the batch transform job, we remove prior batch prediction results. In production, you would more likely write each run to its own timestamped S3 prefix and retain the results of every run.

In [13]:
if input('Are you sure you want to remove the old batch predictions?') == 'yes':
    !aws s3 rm --quiet --recursive $output_data_path

Are you sure you want to remove the old batch predictions?no
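As an alternative to deleting earlier results, the timestamped-prefix approach can be sketched as follows. This is a minimal sketch under the assumption that each run gets its own prefix under `batch-predictions/`; `timestamped_output_path` is a hypothetical helper.

```python
from datetime import datetime, timezone

def timestamped_output_path(bucket, prefix, run_time=None):
    """Return an S3 prefix unique to one batch run, e.g.
    s3://<bucket>/<prefix>/batch-predictions/2019-09-08T04-36-18Z

    Hypothetical helper: each run writes to its own prefix, so earlier results
    are kept and nothing has to be deleted first. run_time defaults to now (UTC).
    """
    run_time = run_time or datetime.now(timezone.utc)
    stamp = run_time.strftime('%Y-%m-%dT%H-%M-%SZ')
    return 's3://{}/{}/batch-predictions/{}'.format(bucket, prefix, stamp)
```

The result would be passed as `output_path` when creating the `Transformer`, and the same prefix used when downloading results.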


Likewise, to interpret the results, we copy them down to our local folder. If we have done this before, we first remove the old results.

In [14]:
if input('Are you sure you want to remove the prior local batch predictions from ./batch_predictions') == 'yes':
    !rm -rf ./batch_predictions/*

Are you sure you want to remove the prior local batch predictions from ./batch_predictionsno
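The cell above shells out to `rm -rf`. A pure-Python version avoids depending on a Unix shell; this is a minimal sketch, and `clear_local_predictions` is a hypothetical helper. Unlike the shell glob `*`, it also removes hidden files.

```python
import shutil
from pathlib import Path

def clear_local_predictions(directory='./batch_predictions'):
    """Delete everything inside `directory` but keep the directory itself.

    Hypothetical helper; returns the number of top-level entries removed
    (0 if the directory does not exist).
    """
    root = Path(directory)
    if not root.is_dir():
        return 0
    removed = 0
    for child in root.iterdir():
        if child.is_dir() and not child.is_symlink():
            shutil.rmtree(child)  # remove a class subfolder and its .out files
        else:
            child.unlink()  # remove a file or a symlink
        removed += 1
    return removed
```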


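Here we start the batch prediction job using the SageMaker `Transformer` object. The job runs on `batch_instance_count` instances of type `batch_instance_type`, and `max_concurrent_transforms` caps the number of concurrent requests sent to each instance's container.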

In [15]:
batch_instance_count = 2
concurrency = 100

transformer = sagemaker.transformer.Transformer(
    model_name = model_name,
    instance_count = batch_instance_count,
    instance_type = batch_instance_type,
    max_concurrent_transforms = concurrency,
    output_path = output_data_path,
    base_transform_job_name='tf-birds-image-transform')

transformer.transform(data = input_data_path, content_type = 'application/x-image')
transformer.wait()

................................................!
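For comparison, the same job can be described with the low-level boto3 `create_transform_job` API. This is a sketch of roughly equivalent request parameters, assuming the defaults the `Transformer` uses here (an `S3Prefix` input with no record splitting); `build_transform_job_request` is a hypothetical helper.

```python
def build_transform_job_request(job_name, model_name, input_uri, output_uri,
                                instance_type='ml.p3.2xlarge', instance_count=2,
                                max_concurrent=100, content_type='application/x-image'):
    """Assemble keyword arguments for boto3's SageMaker create_transform_job.

    Hypothetical helper mirroring the Transformer settings in the cell above.
    """
    return {
        'TransformJobName': job_name,
        'ModelName': model_name,
        'MaxConcurrentTransforms': max_concurrent,  # concurrent requests per instance
        'TransformInput': {
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',  # transform every object under input_uri
                    'S3Uri': input_uri,
                }
            },
            'ContentType': content_type,
        },
        'TransformOutput': {'S3OutputPath': output_uri},
        'TransformResources': {
            'InstanceType': instance_type,
            'InstanceCount': instance_count,
        },
    }
```

The request could then be submitted with `client.create_transform_job(**request)` and awaited with `client.get_waiter('transform_job_completed_or_stopped')`.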


To facilitate evaluation of the output, we download the results to our local folder.

In [16]:
!aws s3 cp --quiet --recursive $output_data_path ./batch_predictions
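If the AWS CLI is not available, the download can be done from Python with boto3. This is a minimal sketch that is roughly equivalent to `aws s3 cp --recursive`, using the `list_objects_v2` paginator and `download_file`; `download_prefix` is a hypothetical helper.

```python
import os

def download_prefix(bucket, prefix, local_dir, s3_client=None):
    """Download every object under s3://bucket/prefix into local_dir,
    preserving the key layout relative to the prefix.

    Hypothetical helper; returns the list of local file paths written.
    """
    if s3_client is None:
        import boto3  # imported lazily so the helper can be tested offline
        s3_client = boto3.client('s3')
    prefix = prefix.rstrip('/') + '/'
    written = []
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('/'):
                continue  # skip zero-byte "folder" placeholder objects
            local_path = os.path.join(local_dir, *key[len(prefix):].split('/'))
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            s3_client.download_file(bucket, key, local_path)
            written.append(local_path)
    return written
```

In this notebook the call would be `download_prefix(bucket, '{}/batch-predictions'.format(prefix), './batch_predictions')`.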

In [17]:
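import json
import os
import glob
import numpy as np

total = 0
correct = 0

predicted = []  # predicted class index for each image
actual = []     # ground-truth class index for each image

# Batch transform mirrors the input layout, so each result is saved as
# batch_predictions/<class label>/<image file>.out
for entry in glob.glob('batch_predictions/*/*'):
    try:
        # The name of the parent folder is the ground-truth class label.
        actual_label = os.path.basename(os.path.dirname(entry))
        actual_index = class_name_list.index(actual_label)
        with open(entry, 'r') as f:
            response = json.load(f)
        # Flatten {"predictions": [[score_0, score_1, ...]]} into one list of class scores.
        scores = [float(item) for sublist in response['predictions'] for item in sublist]
        class_index = int(np.argmax(scores))
        predicted.append(class_index)
        actual.append(actual_index)
        if class_index == actual_index:
            correct += 1
        total += 1
    except Exception as e:
        # Report and skip files that are not in a known class folder or are not valid prediction JSON.
        print('Skipping {}: {}'.format(entry, e))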

In [20]:
print('Out of {} total images, accurate predictions were returned for {}'.format(total, correct))
accuracy = correct / total
print('Accuracy is {:.1%}'.format(accuracy))

Out of 70 total images, accurate predictions were returned for 70
Accuracy is 100.0%
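The evaluation loop also collects `predicted` and `actual` class indices, but only overall accuracy is reported. A confusion matrix breaks the result down by class, which matters when classes are imbalanced. The following is a minimal NumPy sketch with hypothetical helper names; `sklearn.metrics.confusion_matrix` computes the same matrix (rows are true classes, columns are predicted classes) if scikit-learn is available.

```python
import numpy as np

def confusion_matrix(actual, predicted, num_classes):
    """Count predictions per (actual, predicted) class pair.

    Row i, column j holds the number of images whose true class index is i
    and whose predicted class index is j; the diagonal holds correct predictions.
    """
    matrix = np.zeros((num_classes, num_classes), dtype=int)
    for a, p in zip(actual, predicted):
        matrix[a, p] += 1
    return matrix

def per_class_accuracy(matrix):
    """Fraction of each true class predicted correctly (NaN for a class with no images)."""
    totals = matrix.sum(axis=1)
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.diag(matrix) / totals
```

In the notebook, `cm = confusion_matrix(actual, predicted, len(class_name_list))` followed by `zip(class_name_list, per_class_accuracy(cm))` gives the accuracy for each bird class.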
