# Model Explainability - Amazon SageMaker Clarify

**SageMaker Studio Kernel**: Data Science

In this exercise you will:
 - Create an Amazon SageMaker Clarify job for evaluating feature importance for your model
 - Visualize results

***

## Part 1/4 - Setup
Here we'll import some libraries and define some variables.

### Import required modules

In [None]:
import boto3
import csv
from io import StringIO
import json
import numpy as np
import logging
import pandas as pd
import sagemaker.session

In [None]:
# Low-level AWS handles used by the rest of the notebook.
s3_client = boto3.client("s3")
s3_resource = boto3.resource("s3")
sagemaker_client = boto3.client("sagemaker")

region = boto3.session.Session().region_name

# Resolve the caller's AWS account ID with a single STS call and reuse it for
# both the execution-role ARN and the KMS key ARN built later on.
kms_account_id = boto3.client("sts").get_caller_identity().get("Account")

role_name = "mlops-sagemaker-execution-role"
role = "arn:aws:iam::{}:role/{}".format(kms_account_id, role_name)

In [None]:
# Configure the root logger so that INFO-level records (and above) are emitted.
logging.basicConfig(level=logging.INFO)
# Logger used by the cells below, named after the current module.
LOGGER = logging.getLogger(__name__)

***

## Part 2/4 - Create Model Predictor
In this step, we create a model predictor for a previously created SageMaker endpoint.

In [None]:
# S3 buckets. Both are empty strings here -- presumably they have to be set to
# real bucket names before the cells below can reach S3 (TODO confirm).
bucket_artifacts = ""
bucket_inference = ""

# Key prefixes used inside the buckets.
clarify_path = "data/clarify"
monitoring_output_path = "data/monitoring/captured"
processing_output_files_path = "data/output"

# Identifiers of related resources.
kms_alias = "ml-kms"
model_package_group = "ml-end-to-end-group"

# Instance settings reused for both the Clarify model config and the processor.
inference_instance_type = "ml.m5.xlarge"
inference_instance_count = 1

# Fully qualified S3 locations derived from the values above.
explainability_output_path = 's3://{}/data/monitoring/explainability'.format(bucket_artifacts)
train_data = "s3://{}/{}/train/train.csv".format(bucket_artifacts, processing_output_files_path)
test_data = "s3://{}/{}/test/test.csv".format(bucket_artifacts, processing_output_files_path)

In [None]:
# ARN of the KMS key alias, assembled from the region, account ID and alias name.
kms_key = "arn:aws:kms:{}:{}:alias/{}".format(region, kms_account_id, kms_alias)

In [None]:
# One boto3 session pinned to the notebook's region; the clients below are created from it.
boto_session = boto3.Session(region_name=region)

# Replaces the `sagemaker_client` created earlier with one bound to `boto_session`.
sagemaker_client = boto_session.client("sagemaker")
runtime_client = boto_session.client("sagemaker-runtime")

# SageMaker SDK session that reuses the clients above and uses the inference
# bucket as its default bucket.
sagemaker_session = sagemaker.session.Session(
    boto_session=boto_session,
    sagemaker_client=sagemaker_client,
    sagemaker_runtime_client=runtime_client,
    default_bucket=bucket_inference
)

Sentiment labels are encoded as follows:

* Negative - 0
* Neutral - 1
* Positive - 2

In [None]:
from sagemaker.deserializers import CSVDeserializer
from sagemaker.serializers import CSVSerializer
from sagemaker.tensorflow.model import TensorFlowPredictor

# Predictor bound to the existing endpoint named "<model_package_group>-dev".
# Requests are serialized as CSV and responses are parsed from CSV.
predictor = TensorFlowPredictor(
    endpoint_name=model_package_group + "-dev",
    model_name="saved_model",
    model_version=1,
    accept_type="text/csv",
    serializer=CSVSerializer(),
    deserializer=CSVDeserializer()
)

In [None]:
# Single sample sentence used to check that the endpoint answers.
inputs = ["ti imploro di guardare questo documentario. molto spaventoso e informativo. uno dei motivi esatti che sto eliminando fb entir"]

# Send the sample to the endpoint through the predictor.
result = predictor.predict(inputs)

# Log the deserialized response.
LOGGER.info("{}".format(result))

## Part 3/4 - Explainability
Here we create an Amazon SageMaker Clarify job for model explainability.

### Prepare data for Amazon SageMaker Clarify

In [None]:
# Parser options shared by both pandas code paths below.
read_csv_options = dict(
    sep=',',
    quotechar='"',
    quoting=csv.QUOTE_ALL,
    escapechar='\\',
    encoding='utf-8',
)

try:
    # pandas >= 1.3: `on_bad_lines` supersedes `error_bad_lines`, which was
    # deprecated in 1.3 and removed in 2.0. Malformed rows are skipped with a warning.
    df_test = pd.read_csv(test_data, on_bad_lines="warn", **read_csv_options)
except TypeError:
    # Older pandas does not accept `on_bad_lines`; fall back to the legacy keyword.
    df_test = pd.read_csv(test_data, error_bad_lines=False, **read_csv_options)

# Drop every row that contains a missing value.
df_test = df_test.dropna()

In [None]:
# Store the character length of every text in a new "len" column.
df_test["len"] = df_test["text"].apply(len)

In [None]:
# Number of rows handed to Clarify.
num_examples = 300

# Draw the rows at random (no seed is set, so every run picks different rows),
# then keep only the "text" column. The original row index is preserved.
sampled_rows = df_test.sample(n=num_examples)
df_test_clarify = pd.DataFrame(sampled_rows, columns=["text"])

In [None]:
# Serialize the sampled texts to CSV in memory (header row included, index omitted).
csv_buffer = StringIO()

df_test_clarify.to_csv(csv_buffer, header=True, index=False)

# Upload the CSV to "<clarify_path>/validation.csv" in the inference bucket.
s3_resource.Object(bucket_inference, "{}/validation.csv".format(clarify_path)).put(Body=csv_buffer.getvalue())

### Create Amazon SageMaker Clarify Job

In [None]:
from sagemaker.clarify import (
    DataConfig,
    ModelConfig,
    SageMakerClarifyProcessor,
    SHAPConfig,
    TextConfig
)

To obtain feature importance for parts of an input text, create a `TextConfig` specifying the granularity of the parts of the text and the language. Clarify then breaks the text down into tokens, sentences, or paragraphs, depending on your choice of granularity.

In [None]:
# Text settings for Clarify: attributions are computed per sentence.
# NOTE(review): the sample request sent to the endpoint earlier in this notebook
# is Italian text -- confirm that "english" matches the language of the dataset.
text_config = TextConfig(
    language="english", 
    granularity="sentence" 
)

In [None]:
# SHAP settings for the explainability job.
# - baseline: a single baseline row whose only (text) feature is the placeholder "<UNK>".
# - agg_method: "mean_abs" is passed as the aggregation method for the SHAP values.
# - save_local_shap_values: presumably what makes Clarify write the per-record
#   explanations file that the visualization part reads -- TODO confirm.
shap_config = SHAPConfig(
    baseline=[["<UNK>"]],
    num_samples=1000,
    agg_method="mean_abs",
    save_local_shap_values=True,
    text_config=text_config
)

In [None]:
# Input/output locations for the Clarify job: the input is the validation.csv
# uploaded above (single "text" column), results go to `explainability_output_path`.
explainability_data_config = DataConfig(
    s3_data_input_path="s3://{}/{}/validation.csv".format(bucket_inference, clarify_path),
    s3_output_path=explainability_output_path,
    headers=["text"],
    dataset_type="text/csv"
)

In [None]:
# Model settings for the Clarify job. The model name is taken from the predictor:
# `_get_model_names` is a private (underscore-prefixed) predictor method and `[0]`
# picks the first name it returns. Requests and responses both use CSV.
model_config = ModelConfig(
    model_name=predictor._get_model_names()[0],
    instance_type=inference_instance_type,  
    instance_count=inference_instance_count,
    accept_type="text/csv",
    content_type="text/csv"
)

Run an Amazon SageMaker Clarify Processing Job

In [None]:
# Processor that runs the Clarify job under the execution role, reusing the
# instance type/count configured for inference and the SageMaker session built above.
clarify_processor = SageMakerClarifyProcessor(
    role=role, 
    instance_count=inference_instance_count, 
    instance_type=inference_instance_type, 
    sagemaker_session=sagemaker_session
)

In [None]:
# Launch the explainability job with the data, model and SHAP configurations defined above.
clarify_processor.run_explainability(
    data_config=explainability_data_config,
    model_config=model_config,
    explainability_config=shap_config)

***

## Part 4/4 - Visualization

### Visualize local explanations

We use Captum to visualize the feature importances computed by Clarify. First, let's load the local explanations. Local text explanations can be found in the analysis results folder, in a file named `out.jsonl` in the `explanations_shap` directory.

In [None]:
local_feature_attributions_file = "out.jsonl"
analysis_results = []

# Full S3 URI of the per-record explanations file written by the Clarify job.
local_explanations_uri = (
    explainability_output_path + "/explanations_shap/" + local_feature_attributions_file
)

# Keep a local copy of the file under ./../data ...
analysis_result = sagemaker.s3.S3Downloader.download(
    local_explanations_uri,
    local_path="./../data",
)

# ... and read the same object straight from S3 for parsing.
shap_out = []
file = sagemaker.s3.S3Downloader.read_file(local_explanations_uri)

# One JSON document per non-empty line.
shap_out.extend(json.loads(row) for row in file.split("\n") if row)

The local explanations file is a JSON Lines file that contains the explanation of one instance per row. Let's examine the output format of the explanations.

In [None]:
# Pretty-print the first parsed explanation record to inspect its structure.
print(json.dumps(shap_out[0], indent=2))

At the highest level of this JSON line there are two keys: `explanations` and `join_source_value` (the latter is not present here because we did not include a join-source column in the input dataset). `explanations` contains a list of attributions for each feature in the dataset. In this case there is a single element, because the input dataset has a single feature. It also contains details such as `feature_name` and `data_type` (indicating whether Clarify inferred the column as numerical, categorical or text). Each token attribution also contains a `description` field that holds the token itself and the starting index of the token in the original input. This allows you to reconstruct the original sentence from the output as well.
In the following block, we create a list of attributions and a list of tokens for use in visualizations.

In [None]:
# For every explained record, collect two parallel arrays taken from its first
# (and here only) feature: the first attribution value of each text part, and
# the text part itself.
attributions_dataset = []
tokens_dataset = []
for record in shap_out:
    part_attributions = record["explanations"][0]["attributions"]
    attributions_dataset.append(
        np.array([part["attribution"][0] for part in part_attributions])
    )
    tokens_dataset.append(
        np.array([part["description"]["partial_text"] for part in part_attributions])
    )

We obtain predictions as well so that they can be displayed alongside the feature attributions.

In [None]:
# Endpoint responses, one per sampled text, in the same order as `df_test_clarify`.
preds = []

# Each `t` is one row of the single-column frame; the endpoint is called once per row.
for t in df_test_clarify.values:
    preds.append(predictor.predict([t]))
    # Progress indicator: one dot per completed request.
    print(".", end="", flush=True)

In [None]:
! pip install captum

In [None]:
from captum.attr import visualization

In [None]:
%matplotlib inline

# This function is a wrapper around Captum that helps produce visualizations for local explanations. It will
# visualize the attributions for the tokens with red or green colors for negative and positive attributions.
def visualization_record(
    attributions,  # list of attributions for the tokens
    text,  # list of tokens
    pred,  # the prediction value obtained from the endpoint
    delta,
    true_label,  # the true label from the dataset
    normalize=True,  # normalizes the attributions so that the max absolute value is 1. Yields stronger colors.
    max_frac_to_show=0.05,  # what fraction of tokens to highlight, set to 1 for all.
    match_to_pred=False,  # whether to limit highlights to red for negative predictions and green for positive ones.
    # By enabling `match_to_pred` you show what tokens contribute to a high/low prediction not those that oppose it.
):
    """Build a Captum ``VisualizationDataRecord`` for one explained text.

    Args:
        attributions: 1-D array-like with one attribution value per token.
        text: Sequence of tokens, parallel to ``attributions``.
        pred: Prediction score; values above 0.5 are reported as class 1,
            all other values as class 0.
        delta: Forwarded unchanged as the last positional argument of the
            record (Captum's convergence score -- see the Captum docs).
        true_label: Ground-truth label shown next to the prediction.
        normalize: Scale the attributions so the largest absolute value is 1.
            Skipped when there are no attributions or all of them are zero,
            so no division by zero (and no NaN) can occur.
        max_frac_to_show: Fraction of tokens to keep highlighted; ``None`` or
            a value >= 1 keeps all of them. The token count is truncated with
            ``int()``, so short texts may end up with no highlighted token.
        match_to_pred: If true, only tokens pushing towards the predicted
            side are eligible for highlighting; otherwise tokens are ranked
            by absolute attribution.

    Returns:
        A ``visualization.VisualizationDataRecord`` holding the (possibly
        normalized and masked) attributions, the prediction, the derived
        class, the true label, the sign and value of the attribution sum,
        the tokens and ``delta``.
    """
    # Accept plain lists as well as arrays; the arithmetic below needs an ndarray.
    attributions = np.asarray(attributions, dtype=float)

    if normalize and attributions.size:
        scale = np.abs(attributions).max()
        # All-zero attributions have nothing to scale; dividing would yield NaNs.
        if scale > 0:
            attributions = attributions / scale

    if max_frac_to_show is not None and max_frac_to_show < 1:
        num_show = int(max_frac_to_show * attributions.shape[0])
        sal = attributions
        # For low predictions the most negative attributions should rank first.
        if pred < 0.5:
            sal = -sal
        if not match_to_pred:
            sal = np.abs(sal)
        # Keep the `num_show` highest-ranked tokens and zero out the rest.
        top_idxs = np.argsort(-sal)[:num_show]
        mask = np.zeros_like(attributions)
        mask[top_idxs] = 1
        attributions = attributions * mask

    return visualization.VisualizationDataRecord(
        attributions,
        pred,
        int(pred > 0.5),
        true_label,
        attributions.sum() > 0,
        attributions.sum(),
        text,
        delta,
    )


In [None]:
# You can customize the following display settings
normalize = True
max_frac_to_show = 1
match_to_pred = False

# The explained texts were drawn with `sample()`, so their labels must be looked
# up through the sampled rows' index. Slicing the first `num_examples` labels of
# `df_test` would pair each text with the label of an unrelated row.
labels = df_test.loc[df_test_clarify.index, "Sentiment"]

vis = []
for attr, token, pred, label in zip(attributions_dataset, tokens_dataset, preds, labels):
    # `pred[0][1]` is the second value of the first row returned by the endpoint;
    # presumably the score of class 1 -- TODO confirm against the endpoint's output format.
    vis.append(
        visualization_record(
            attr, token, float(pred[0][1]), 0.0, label, normalize, max_frac_to_show, match_to_pred
        )
    )


Now that we have compiled the records, we are finally ready to render the visualization.

We see a row per review in the selected dataset. For each row we have the prediction, the label, and the highlighted text. Additionally, we show the total sum of attributions (as attribution score) and its label (as attribution label), which indicates whether it is greater than zero.

In [None]:
# Render all records; the return value is assigned to `_` and not used.
_ = visualization.visualize_text(vis)

In [None]:
! rm -rf ./../data