# Run Model Explanation with SageMaker Clarify (Post-Training)

## Using SageMaker Processing Jobs

In [None]:
import boto3
import sagemaker
import pandas as pd
import numpy as np

# Shared handles used by the later cells: the SageMaker session, its default
# S3 bucket, the notebook's execution role, and the current AWS region.
boto_session = boto3.Session()
sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()
region = boto_session.region_name

# Low-level SageMaker API client for the same region.
sm = boto_session.client(service_name='sagemaker', region_name=region)

In [None]:
# Restore `training_job_name` from IPython's %store (expected to be saved by an earlier notebook).
%store -r training_job_name

In [None]:
# Report whether `training_job_name` is defined; this only prints a message
# and does not stop execution when the variable is missing.
try:
    training_job_name
except NameError:
    print('+++++++++++++++++++++++++++++++')
    print('[ERROR] Please run the notebooks in the previous TRAIN section before you continue.')
    print('+++++++++++++++++++++++++++++++')
else:
    print('[OK]')

In [None]:
# Show the training job that the model is created from later in the notebook.
print(training_job_name)

In [None]:
# # updated inference.py with correct json parsing
# training_job_name='tensorflow-training-2021-01-27-23-37-00-180'

# Get Data

In [None]:
import pandas as pd

# Load the reviews dataset (JSON Lines: one record per line) and preview the first rows.
data = pd.read_json(
    path_or_buf='./data-clarify/amazon_reviews_us_giftcards_software_videogames_balanced.jsonl',
    lines=True,
)
data.head()

In [None]:
# (rows, columns) of the loaded dataset.
data.shape

### Data inspection
Plotting histograms for the distribution of the different features is a good way to visualize the data. Let's plot a few of the features that can be considered _sensitive_.  
Let's take a look specifically at the `star_rating` label and the `product_category` feature of each review. The first plot shows the number of reviews for each star rating, and the second plot shows the number of reviews for each product category.

In [None]:
import matplotlib.pyplot as plt
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Ask the inline backend for the 'retina' figure format.
%config InlineBackend.figure_format='retina'

In [None]:
# Bar chart of the number of reviews for each star rating, ordered by count.
# The title names the star rating (it previously duplicated the product-category plot's title).
data['star_rating'].value_counts().sort_values().plot(kind='bar', title='Count Reviews per Star Rating', rot=0)

In [None]:
# Bar chart of the number of reviews for each product category, ordered by count.
category_counts = data['product_category'].value_counts().sort_values()
category_counts.plot(kind='bar', title='Count Reviews per Category', rot=0)

# Create Model

In [None]:
import sagemaker

# Look up the TensorFlow 2.3.1 (py37) inference container image for the current region.
image_lookup_args = dict(
    framework="tensorflow",
    region=region,
    version="2.3.1",
    py_version="py37",
    instance_type='ml.m5.4xlarge',
    image_scope="inference",
)
inference_image_uri = sagemaker.image_uris.retrieve(**image_lookup_args)
print(inference_image_uri)

In [None]:
# Register a SageMaker model from the training job, served with the inference image looked up above.
model_name = sess.create_model_from_job(training_job_name, image_uri=inference_image_uri)

In [None]:
# Show the name returned by create_model_from_job.
print(model_name)

# Explaining Predictions with Amazon SageMaker Clarify

There are expanding business needs and legislative regulations that require explanations of _why_ a model made the decision it did. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision.

In [None]:
from sagemaker import clarify

# Processor that runs the Clarify analysis on a single ml.c5.2xlarge instance.
clarify_processor = clarify.SageMakerClarifyProcessor(
    role=role,
    instance_count=1,
    instance_type='ml.c5.2xlarge',
    sagemaker_session=sess,
)

# Override the container image with the patched build hosted in the Thundera devo
# account (the initial launch image plus the JSONLines bug fix).
# NOTE(review): this URI is pinned to one account and to us-east-1, and the
# "Run Clarify" section later rebinds `clarify_processor` — confirm whether the patch is still needed.
clarify_processor.image_uri = "678264136642.dkr.ecr.us-east-1.amazonaws.com/sagemaker-clarify-processing:1.0_jsonlines_patch"

## Writing DataConfig and ModelConfig
A `DataConfig` object communicates some basic information about data I/O to Clarify. We specify where to find the input dataset, where to store the output, the target column (`label`), the header names, and the dataset type.

Similarly, the `ModelConfig` object communicates information about your trained model and `ModelPredictedLabelConfig` provides information on the format of your predictions.  

**Note**: To avoid additional traffic to your production models, SageMaker Clarify sets up and tears down a dedicated endpoint when processing. `ModelConfig` specifies the instance type and instance count used to host your model during Clarify's processing.

In [None]:
# Keep only the label column and the two feature columns used for the analysis.
selected_columns = ['star_rating', 'product_category', 'review_body']
post_train_dataset = data[selected_columns]
post_train_dataset.shape

In [None]:
# Preview the first rows of the reduced dataset.
post_train_dataset.head()

# Select n samples across each category

## _Note: We need to have >1 product_categories (facet values) in our training data._

In [None]:
# TODO:  CHANGE THIS BACK TO .sample(10)
# TODO:  REMOVE THE INDEX (0, 1, etc)
# TODO:  REMOVE THE DUPLICATE product_category IN THIS GROUPBY
# TODO:  WHY ARE WE USING groupby?

# Draw the same number of reviews from every product category so that each
# category (facet value) is represented in the sample.
# A fixed random_state makes the draw reproducible across "Restart & Run All";
# the chained reset_index returns a new frame instead of mutating in place.
bias_data = (
    post_train_dataset
    .groupby('product_category', group_keys=False)
    .apply(lambda s: s.sample(5, random_state=42))
    .reset_index(drop=True)
)
bias_data.shape

In [None]:
# Display the sampled rows.
bias_data

# Convert To Categorical dtype

## _Note: Otherwise Clarify converts (int) and (str) to continuous threshold values instead of categorical values._

In [None]:
# Cast product_category to the pandas categorical dtype.
bias_data = bias_data.astype({'product_category': 'category'})

In [None]:
# Show the dtype of the product_category column.
bias_data['product_category'].dtype

In [None]:
# Cast star_rating to the pandas categorical dtype.
bias_data = bias_data.astype({'star_rating': 'category'})

In [None]:
# Show the dtype of the star_rating column.
bias_data['star_rating'].dtype

In [None]:
# Display the sampled rows again after the dtype conversions.
bias_data

# Create Explainability Data Set Without Label Column

In [None]:
# Explainability input carries features only, so remove the star_rating label column.
explainability_data = bias_data.drop(columns=['star_rating'])
explainability_data.shape

In [None]:
# Preview the label-free dataset.
explainability_data.head()

# Convert to `jsonlines` Format and Upload To S3

In [None]:
# Local JSON Lines file that is uploaded to S3 as the explainability dataset.
# NOTE(review): no active cell in this notebook writes this file (the only to_json call is
# commented out and targets a different path) — presumably it must already exist in ./data-clarify; confirm.
path = './data-clarify/test_explainability.jsonl'

In [None]:
# path = './data-clarify/post_train_data_explainability.jsonl'
# data.to_json(path, orient="records", lines=True)

In [None]:
# Upload the local file to the session bucket under a prefix named after the training job.
post_train_dataset_explainability_s3_uri = sess.upload_data(
    path=path,
    bucket=bucket,
    key_prefix=training_job_name,
)
post_train_dataset_explainability_s3_uri

# Configure Clarify

In [None]:
from sagemaker import clarify

# Model settings for the Clarify job: which model to call, the instance it runs on,
# and the JSON Lines request/response formats with the request body template.
model_config = clarify.ModelConfig(
    model_name=model_name,
    instance_type='ml.m5.4xlarge',
    instance_count=1,
    content_type='application/jsonlines',
    accept_type='application/jsonlines',
    content_template='{"features":$features}',
)

In [None]:
# S3 prefix that receives the Clarify analysis output.
explainability_output_path = 's3://{}/clarify-explainability'.format(bucket)

# Describe the JSON Lines input dataset and where results go.
# No `label` is passed (label='star_rating' is intentionally left disabled).
explainability_data_config = clarify.DataConfig(
    s3_data_input_path=post_train_dataset_explainability_s3_uri,
    s3_output_path=explainability_output_path,
    headers=['review_body', 'product_category'],
    features='features',
    dataset_type='application/jsonlines',
)

In [None]:
# SHAP settings: the uploaded dataset itself is passed as the baseline
# (alternative that was tried: [explainability_data.iloc[0].values.tolist()]).
shap_config = clarify.SHAPConfig(
    baseline=post_train_dataset_explainability_s3_uri,
    num_samples=5,
    agg_method='mean_abs',
)

## _Note: `label` is set to the JSON key for the model prediction results_

In [None]:
#predictions_config = clarify.ModelPredictedLabelConfig(label='predicted_label')

## Run Clarify

In [None]:
from sagemaker import clarify

# Build the processor that actually runs the job.
# NOTE(review): this rebinds `clarify_processor`, so an `image_uri` assigned to an
# earlier instance does not carry over to this one — confirm that is intended.
clarify_processor = clarify.SageMakerClarifyProcessor(
    role=role,
    instance_count=1,
    instance_type='ml.c5.2xlarge',
    sagemaker_session=sess,
)

In [None]:
# Start the explainability job and return immediately (wait=False, logs=False);
# a later cell waits for completion.
clarify_processor.run_explainability(
    data_config=explainability_data_config,
    model_config=model_config,
    explainability_config=shap_config,
    model_scores='predicted_label',
    wait=False,
    logs=False,
)

In [None]:
# Keep the name of the job just launched; later cells use it to build links and to re-attach to the job.
run_explainability_job_name = clarify_processor.latest_job.job_name
run_explainability_job_name

In [None]:
# Import from the public IPython.display module; importing `display` from
# IPython.core.display is deprecated.
from IPython.display import display, HTML

# Link to this processing job in the SageMaker console; target="_blank" opens it in a new tab.
display(HTML('<b>Review <a target="_blank" href="https://console.aws.amazon.com/sagemaker/home?region={}#/processing-jobs/{}">Processing Job</a></b>'.format(region, run_explainability_job_name)))


In [None]:
# Import from the public IPython.display module; importing `display` from
# IPython.core.display is deprecated.
from IPython.display import display, HTML

# Link to the job's CloudWatch log streams; target="_blank" opens it in a new tab.
display(HTML('<b>Review <a target="_blank" href="https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix">CloudWatch Logs</a> After About 5 Minutes</b>'.format(region, run_explainability_job_name)))


In [None]:
from urllib.parse import urlparse

# Import from the public IPython.display module; importing `display` from
# IPython.core.display is deprecated.
from IPython.display import display, HTML

# The job writes its results to `explainability_output_path` (the DataConfig's
# s3_output_path), so link to that bucket/prefix rather than to a prefix named
# after the processing job.
output_location = urlparse(explainability_output_path)
output_bucket = output_location.netloc
output_prefix = output_location.path.strip('/')

display(HTML('<b>Review <a target="_blank" href="https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview">S3 Output Data</a> After The Processing Job Has Completed</b>'.format(output_bucket, output_prefix, region)))


In [None]:
# Re-attach to the launched processing job by name and print its current description.
running_processor = sagemaker.processing.ProcessingJob.from_processing_name(
    processing_job_name=run_explainability_job_name,
    sagemaker_session=sess,
)
processing_job_description = running_processor.describe()
print(processing_job_description)

In [None]:
# Wait for the processing job to finish (logs=False: do not stream its logs here).
running_processor.wait(logs=False)

#### Viewing the Explainability Report
As with the bias report, you can view the explainability report in Studio under the experiments tab


<img src="img/explainability_detail.gif">

The Model Insights tab contains direct links to the report and model insights.

If you're not a Studio user yet, as with the Bias Report, you can access this report at the following S3 bucket.

# Download Report From S3

In [None]:
# List what is stored under the explainability output prefix.
!aws s3 ls $explainability_output_path/

In [None]:
# Copy everything under the output prefix into the local ./explainability_report/ folder.
!aws s3 cp --recursive $explainability_output_path ./explainability_report/

In [None]:
# Import from the public IPython.display module; importing `display` from
# IPython.core.display is deprecated.
from IPython.display import display, HTML

# Link to the locally downloaded report; target="_blank" opens it in a new tab.
display(HTML('<b>Review <a target="_blank" href="./explainability_report/report.html">Explainability Report</a></b>'))


# Release Resources

In [None]:
%%html

<!-- Shows a notice, then programmatically clicks a hidden button bound to the "kernelmenu:shutdown" command. -->
<p><b>Shutting down your kernel for this notebook to release resources.</b></p>
<button class="sm-command-button" data-commandlinker-command="kernelmenu:shutdown" style="display:none;">Shutdown Kernel</button>
        
<script>
// Click the first element with the "sm-command-button" class; any error is deliberately ignored.
try {
    els = document.getElementsByClassName("sm-command-button");
    els[0].click();
}
catch(err) {
    // NoOp
}    
</script>

In [None]:
%%javascript

// Save a checkpoint and delete the notebook's session through the `Jupyter` global.
// Errors are ignored (presumably for front ends where `Jupyter` is not defined — confirm).
try {
    Jupyter.notebook.save_checkpoint();
    Jupyter.notebook.session.delete();
}
catch(err) {
    // NoOp
}