In [None]:
import boto3
import sagemaker
import pandas as pd

# One boto3 session serves both the region lookup and the low-level client.
boto_session = boto3.Session()
region = boto_session.region_name

# SageMaker SDK session, its default bucket, and the execution role.
sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()

sm = boto_session.client(service_name='sagemaker', region_name=region)

In [None]:
!pip install --upgrade sagemaker==1.56.1

In [None]:
!pip list

# Setup Input Data

In [None]:
# Restore `scikit_processing_job_name` from IPython's %store database.
# NOTE(review): presumably stored by an earlier notebook in this series; if it
# was never stored, later cells that use the name will raise NameError.
%store -r scikit_processing_job_name

In [None]:
# Sanity check: confirm the restored job name is present before it is used to build S3 paths.
print(scikit_processing_job_name)

In [None]:
# Echo the processing job whose output feeds this batch transform.
print('Previous Scikit Processing Job Name: {job_name}'.format(job_name=scikit_processing_job_name))

In [None]:
# S3 key prefix of the test data, rooted at the processing job name.
prefix_test = '{job_name}/output/bert-test'.format(job_name=scikit_processing_job_name)

# Full S3 URI of that prefix inside the session's default bucket.
test_s3_uri = 's3://{bucket}/{prefix}'.format(bucket=bucket, prefix=prefix_test)

In [None]:
print(test_s3_uri)

!aws s3 ls $test_s3_uri/

# Setup Batch Transform Model

In [None]:
# Restore `training_job_name` from IPython's %store database.
# NOTE(review): presumably stored by the training notebook; if it was never
# stored, the model cell below will raise NameError.
%store -r training_job_name

In [None]:
# Sanity check: confirm the restored training job name before building the model artifact path.
print(training_job_name)

In [None]:
# Inspect the SavedModel stored locally under ./tensorflow/saved_model/0/.
# NOTE(review): nothing in this notebook creates that directory; confirm an
# earlier notebook downloads/extracts the model there.
!saved_model_cli show --all --dir ./tensorflow/saved_model/0/

In [None]:
from sagemaker.tensorflow.serving import Model

# Environment variables handed to the serving container.
# Keep SAGEMAKER_TFS_DEFAULT_MODEL_NAME as 'saved_model' unless you are using
# multi-model: if the named model doesn't exist, you may see the dreaded
# ping/ error in the logs.
batch_env = dict(
    SAGEMAKER_TFS_DEFAULT_MODEL_NAME='saved_model',
    SAGEMAKER_TFS_ENABLE_BATCHING='true',
    SAGEMAKER_TFS_BATCH_TIMEOUT_MICROS='50000',
    SAGEMAKER_TFS_MAX_BATCH_SIZE='16',
)

# S3 location of the model artifact, derived from the training job name.
model_artifact_s3_uri = 's3://{bucket}/{job_name}/output/model.tar.gz'.format(
    bucket=bucket,
    job_name=training_job_name,
)

batch_model = Model(
    # entry_point must stay 'inference.py'.
    # See https://github.com/aws/sagemaker-python-sdk/issues/1451
    entry_point='inference.py',
    source_dir='src_batch_tfrecord',
    model_data=model_artifact_s3_uri,
    role=role,
    framework_version='2.1.0',
    env=batch_env,
)

In [None]:
# Settings for the batch transformer built from `batch_model`.
# Note: max_payload is expressed in megabytes, not in number of records.
transformer_options = dict(
    instance_count=1,
    instance_type='ml.c5.9xlarge',
    strategy='MultiRecord',
    assemble_with='Line',
    max_payload=100,
    env=batch_env,
)

batch_predictor = batch_model.transformer(**transformer_options)

# Start Batch Predictions

In [None]:
# Submit the transform job over the test data. wait=False returns right away;
# a later cell blocks on completion.
transform_options = dict(
    data=test_s3_uri,
    split_type='TFRecord',
    compression_type=None,
    experiment_config=None,
    wait=False,
)

batch_predictor.transform(**transform_options)

In [None]:
# `display` is imported from the public IPython.display module; importing it
# from IPython.core.display is deprecated.
from IPython.display import display, HTML

# Console link to the transform job that was just started.
transform_job_name = batch_predictor.latest_transform_job.job_name
transform_job_url = 'https://console.aws.amazon.com/sagemaker/home?region={0}#/transform-jobs/{1}?region={0}&tab=Monitor'.format(region, transform_job_name)

display(HTML('<b>Review <a href="{}">Batch Prediction Job</a></b>'.format(transform_job_url)))


In [None]:
# `display` is imported from the public IPython.display module; importing it
# from IPython.core.display is deprecated.
from IPython.display import display, HTML

# Console link to the CloudWatch log streams prefixed with the transform job name.
transform_job_name = batch_predictor.latest_transform_job.job_name
cloudwatch_logs_url = 'https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TransformJobs;prefix={};streamFilter=typeLogStreamPrefix'.format(region, transform_job_name)

display(HTML('<b>Review <a href="{}">CloudWatch Logs</a></b>'.format(cloudwatch_logs_url)))


In [None]:
# `display` is imported from the public IPython.display module; importing it
# from IPython.core.display is deprecated.
from IPython.display import display, HTML

# Console link to the S3 location <bucket>/<transform job name>.
transform_job_name = batch_predictor.latest_transform_job.job_name
s3_output_url = 'https://console.aws.amazon.com/s3/buckets/{}/{}/?region={}'.format(bucket, transform_job_name, region)

display(HTML('<b>Review <a href="{}">Batch Prediction S3 Output</a></b>'.format(s3_output_url)))


In [None]:
latest_job_name = batch_predictor.latest_transform_job.job_name
print('Waiting for batch prediction job: {}'.format(latest_job_name))

# Block until the transform job finishes (logs=False: do not request job logs while waiting).
batch_predictor.wait(logs=False)

# Check Output Data

After the transform job has completed, download the output data from S3.

For each input file "f", there is a corresponding output file "f.out" containing the predicted labels for each input row.

We can compare the predicted labels to the true labels saved earlier.


In [None]:
# Download the output data from S3 to local filesystem
batch_prediction_output_s3_uri = batch_predictor.output_path

# !mkdir -p ./batch_prediction_output

In [None]:
%%bash 

aws s3 cp --recursive $batch_prediction_output_s3_uri/ ./batch_prediction_output

ls ./batch_prediction_output