# 09 - Calculating and visualising the SHAP values

In this notebook we:


  * demonstrate how to use the `explain` module
  * visualise both global and local Shapley values
  * rank the scored dataset from the riskiest to the least risky record


The `explain` module uses the [SHAP library](https://shap.readthedocs.io/en/latest/).

Pycorrectmatch maps real record values to a copula data representation in which each value is replaced by its frequency rank within its column: 1 is the most common value in the column, 2 the second most common, and so on up to the number of unique values of that feature.
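To make the ranking concrete, here is a minimal pure-Python sketch of frequency-rank encoding for a single column. `frequency_rank_encode` is a hypothetical helper written for illustration; it is not part of pycorrectmatch or `privacy_fingerprint`, and the library's own handling of ties may differ.

```python
from collections import Counter


def frequency_rank_encode(column):
    """Replace each value with its frequency rank (1 = most common).

    Hypothetical illustration only, not the pycorrectmatch implementation.
    """
    counts = Counter(column)
    # most_common() sorts by descending count; ties keep first-seen order
    rank = {value: i + 1 for i, (value, _) in enumerate(counts.most_common())}
    return [rank[v] for v in column]


gender = ["female", "male", "female", "unknown", "female", "male"]
print(frequency_rank_encode(gender))  # [1, 2, 1, 3, 1, 2]
```

The most common value ("female") becomes 1, so a record made entirely of each column's most common value is a vector of 1s.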

Our baseline record can be taken to be the least unique record, i.e. the record with the most common value for every feature. In the copula representation this is a 1 × n vector of 1s, where n is the number of features. The individual uniqueness of this vector is close to zero, since it represents the least unique record.

This is useful for explainability: the baseline effectively stands in for having no information about a record. Because the model prediction equals the baseline value plus the sum of the Shapley values, the Shapley values of each record add up to approximately its individual privacy risk score:

`model.predict(X.iloc[[0]]) = shap_result.base_values[0] + sum(shap_result.values[0])`

where `shap_result.base_values[0]` is the baseline prediction, which we assume to be approximately 0.
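The additivity property can be checked on a toy problem. The sketch below computes exact Shapley values by enumerating every coalition of features, with absent features set to the all-ones baseline. `baseline_shapley` and `toy_risk` are hypothetical and for illustration only: the real risk model is the fitted pycorrectmatch model, and enumerating all coalitions grows exponentially with the number of features, which is why SHAP also offers approximate explainers.

```python
from itertools import combinations
from math import factorial


def baseline_shapley(model, x, baseline):
    """Exact Shapley values of model at x; absent features take baseline values."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions of this size that exclude feature i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                # Coalition features take their real value, the rest the baseline
                with_i = [x[j] if j in subset or j == i else baseline[j] for j in range(n)]
                without_i = [x[j] if j in subset else baseline[j] for j in range(n)]
                phi[i] += weight * (model(with_i) - model(without_i))
    return phi


def toy_risk(z):
    """Hypothetical risk model on rank-encoded features: rarer values add risk."""
    return 0.1 * (z[0] - 1) + 0.05 * (z[1] - 1) * (z[2] - 1)


x = [3, 2, 4]         # a record in the rank (copula) representation
baseline = [1, 1, 1]  # the most common value for every feature
phi = baseline_shapley(toy_risk, x, baseline)
print(phi)
# The baseline prediction plus the Shapley values recovers the record's prediction
print(toy_risk(x), toy_risk(baseline) + sum(phi))
```

Here `toy_risk(baseline)` is exactly 0, so the Shapley values alone sum to the record's score; with the real model the baseline score is only approximately 0.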

In [None]:
import json
import os
from collections import defaultdict

import seaborn as sns

In [None]:
from privacy_fingerprint.common.config import (
    load_experiment_config,
    load_experiment_config_from_file,
    load_global_config_from_file,
)

# Example config files are available in the config directory.
# They need to be modified to include the path to the Julia executable.

load_global_config_from_file("../configs/global_config.yaml")
load_experiment_config_from_file("../configs/experiment_config.yaml")

In [None]:
experiment_config = load_experiment_config()
experiment_config.scoring.encoding_scheme = "rarest"

In [None]:
import privacy_fingerprint.extract.aws_comprehend as aws
from privacy_fingerprint.explain import PrivacyRiskExplainer
from privacy_fingerprint.score import PrivacyRiskScorer, encode, preprocess

In [None]:
# The dataset will be loaded from the directory created in notebook 2.
output_dir = "../experiments/02_generate_dataset/"

with open(os.path.join(output_dir, "ner_dataset.json")) as fp:
    ner_records = json.load(fp)

In [None]:
# The format of the NER records must be standardised to enable scoring
common_ner_results = aws.prepare_common_records(
    aws.DEFAULT_IDENTIFIERS, ner_records
)

In [None]:
pcm_dataset = preprocess(common_ner_results)

In [None]:
# Keep only a limited number of columns for this example, since
# Shapley values are slow to calculate
cols_to_keep = ["gender", "ethnicity", "disease", "treatment", "prescriptions"]

In [None]:
def simplify_ethnicity(text):
    """Map free-text ethnicity to white, black, asian, mixed or unknown."""
    text = text.lower()
    if text == "":
        return "unknown"
    mentions = defaultdict(int)
    for ethnicity, label in [
        ("white", "white"),
        ("black", "black"),
        ("african", "black"),
        ("asian", "asian"),
        ("indian", "asian"),
        ("pakistani", "asian"),
        ("chinese", "asian"),
    ]:
        if ethnicity in text:
            mentions[label] += 1
    if len(mentions) > 1:
        return "mixed"
    elif len(mentions) == 1:
        return list(mentions.keys())[0]
    else:
        return "unknown"

In [None]:
transformations = {
    "gender": lambda x: x.lower()
    if x.lower() in ["female", "male"]
    else "unknown",
    "ethnicity": simplify_ethnicity,
}

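Pycorrectmatch requires the dataset to be encoded, as we have seen in the other notebooks. Before encoding, the transformations above are applied to the selected columns; columns without a transformation are passed through unchanged.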

In [None]:
encoded_dataset, lookup = encode(
    pcm_dataset[cols_to_keep].transform(
        {i: transformations.get(i, lambda x: x) for i in cols_to_keep}
    )
)

Create the privacy risk scorer, fit it, and use it to map the dataset to the copula representation and to score each record.

In [None]:
scorer = PrivacyRiskScorer()
pop_uniqueness = scorer.calculate_population_uniqueness(encoded_dataset)
print("Population uniqueness: ", pop_uniqueness)
# Fit the model first; this must happen before calculating scores or transforming records
scorer.fit(encoded_dataset)
# Map the encoded record values to their copula (marginal) representation
transformed_dataset = scorer.map_records_to_copula(encoded_dataset)
N_FEATURES = encoded_dataset.shape[1]
print("Number of features:", N_FEATURES)
# Calculating individual privacy risk scores
pcm_scored_dataset = scorer.predict(encoded_dataset)
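For intuition about uniqueness, the sketch below computes the simplest empirical measure, *sample* uniqueness: the fraction of records whose full combination of values appears exactly once in the dataset. It is a hypothetical helper for comparison only; how `calculate_population_uniqueness` computes its value is not shown in this notebook and may differ.

```python
from collections import Counter


def sample_uniqueness(records):
    """Fraction of records whose full combination of values occurs exactly once."""
    counts = Counter(tuple(r) for r in records)
    return sum(1 for r in records if counts[tuple(r)] == 1) / len(records)


# Made-up records for illustration, not taken from the experiment dataset
records = [
    ("female", "white", "asthma"),
    ("female", "white", "asthma"),
    ("male", "asian", "diabetes"),
    ("unknown", "black", "asthma"),
]
print(sample_uniqueness(records))  # 0.5
```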

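Create the explainer and pass it the transformed dataframe. The explainer wraps `scorer.predict_transformed`, so the model is evaluated directly on records in the copula representation.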

In [None]:
# SHAP takes a while to run; a progress bar is displayed while it runs
explainer = PrivacyRiskExplainer(scorer.predict_transformed, N_FEATURES)
# Calculate Shapley values on the transformed dataset
local_shapley_df, global_shap, exp_obj = explainer.explain(transformed_dataset)

# Visualise global and local Shapley values

The SHAP library provides plotting functions for visualising SHAP results.

In [None]:
# Global explanation: plot the mean SHAP values per feature
explainer.plot_global_explanation(exp_obj)
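A common way to turn local values into a global explanation is the mean absolute SHAP value per feature, which is what SHAP's bar plot shows by default for a multi-record `Explanation`. Whether `plot_global_explanation` uses exactly this summary is an assumption; the sketch below, with a hypothetical `mean_abs_shap` helper and made-up values, only illustrates the aggregation.

```python
def mean_abs_shap(local_values, feature_names):
    """Global importance per feature: mean of |SHAP value| over all records."""
    n_records = len(local_values)
    return {
        name: sum(abs(row[j]) for row in local_values) / n_records
        for j, name in enumerate(feature_names)
    }


features = ["gender", "ethnicity", "disease"]
# Illustrative local SHAP values (rows = records), not real results
local_values = [
    [0.10, -0.02, 0.30],
    [0.00, 0.04, 0.10],
]
importance = mean_abs_shap(local_values, features)
# Print features from most to least important
for name, value in sorted(importance.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {value:.3f}")
```

Taking the absolute value stops positive and negative contributions from cancelling out when averaging across records.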

In [None]:
# Local explanation: plot the SHAP values for a single record (here, record 985)
explainer.plot_local_explanation(exp_obj, 985)

# Rank records by overall privacy risk

Since the baseline score is approximately zero, the Shapley values of each record should sum to approximately that record's individual privacy risk score.

In [None]:
# Original records (selected columns only), sorted from highest to lowest individual risk score
sorted_pcm_df = pcm_dataset[cols_to_keep].loc[
    pcm_scored_dataset.sort_values(ascending=False).index
]

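The cell below orders the local Shapley values by descending risk score. Because every record shares the same baseline, this is equivalent (up to approximation error) to sorting the rows in descending order of their Shapley sums, `local_shapley_df.sum(axis=1)`.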

In [None]:
# Record indices sorted from highest to lowest risk score
ranked_index = pcm_scored_dataset.sort_values(ascending=False).index
ranked_local_shapley_df = local_shapley_df.loc[ranked_index]

# Keep a copy with the risk score alongside the Shapley values
ranked_local_shapley_df_w_score = ranked_local_shapley_df.copy(deep=True)
ranked_local_shapley_df_w_score["score"] = pcm_scored_dataset.loc[ranked_index]
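The equivalence between ranking by score and ranking by Shapley sum holds because each score is the shared baseline value plus the record's Shapley sum, and adding the same constant to every record cannot change their order. The sketch below checks this with made-up values and a hypothetical `rank_by` helper; a small non-zero baseline is used to show that the constant shift does not matter.

```python
def rank_by(values):
    """Return record ids sorted by descending value."""
    return sorted(values, key=values.get, reverse=True)


# Illustrative per-record Shapley values (not real results)
local_shap = {
    "a": [0.05, 0.10],
    "b": [0.30, 0.20],
    "c": [0.00, 0.01],
}
base_value = 0.02  # shared baseline; in this notebook it is assumed to be ~0
scores = {k: base_value + sum(v) for k, v in local_shap.items()}
shap_sums = {k: sum(v) for k, v in local_shap.items()}
print(rank_by(scores))     # ['b', 'a', 'c']
print(rank_by(shap_sums))  # ['b', 'a', 'c']
```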

The following heatmap shows the individual Shapley values of the ranked dataframe, with rows ordered from the riskiest to the least risky record.

In [None]:
g = sns.heatmap(
    ranked_local_shapley_df,
    cmap=sns.light_palette("r", as_cmap=True),
    annot=False,
)
g.set_xticklabels(g.get_xticklabels(), rotation=45, fontsize=8)

In [None]:
# Plot the local shap values for the riskiest record
explainer.plot_local_explanation(exp_obj, ranked_local_shapley_df.iloc[0].name)