# Explainability of speech classification model

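This notebook demonstrates how to use the `Partition` explainer for a multiclass text classification scenario. Once the SHAP values are computed for a set of sentences, we visualize feature attributions towards individual classes. The text classification model we use is BERT fine-tuned on an emotion dataset to classify a sentence among six classes: joy, sadness, anger, fear, love and surprise.

**Note (TODO confirm):** the cells below do not compute SHAP values themselves. Instead, they load precomputed explanations for three model variants (`bal`, `imbal`, `imbal_cw`) on a 4000-review sample of the multilingual Amazon reviews test set (`review_body` text, `stars` labels). They then compare, per language, how the two imbalanced variants' outputs differ from the balanced model's. The emotion-dataset description above may be left over from an earlier version of this notebook.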

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import torch
import pickle
import transformers
import datasets
import shap
import seaborn  as sns


# Load SHAP values

In [None]:
# Locations of the pickled SHAP explanations, one per model variant:
# balanced, imbalanced, and imbalanced-with-"cw" training.
# Please fill these in before running the notebook.
PATH_TO_SHAP_BAL, PATH_TO_SHAP_IMBAL, PATH_TO_SHAP_IMBAL_CW = "", "", ""

In [None]:
# Load the test split of the Amazon reviews dataset and draw a fixed-seed
# sample of 4000 reviews. Later cells match these rows to the SHAP values by
# position, so the sample presumably has to be the one the SHAP values were
# computed on (TODO confirm).
dataset = datasets.load_from_disk("../datasets/amz_reviews")
dataset = dataset["test"]
data = pd.DataFrame(
    {"text": dataset["review_body"], "label": dataset["stars"], "language": dataset["language"]}
)
data = data.sample(4000, random_state=0)


def load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    Raises ``ValueError`` if ``path`` is empty (placeholder not filled in).
    Only use on trusted files: unpickling can execute arbitrary code.
    """
    if not path:
        raise ValueError(
            "SHAP pickle path is empty; fill in the PATH_TO_SHAP_* constants in the configuration cell."
        )
    with open(path, "rb") as f:
        return pickle.load(f)


shap_values_bal = load_pickle(PATH_TO_SHAP_BAL)
shap_values_imbal = load_pickle(PATH_TO_SHAP_IMBAL)
shap_values_imbal_cw = load_pickle(PATH_TO_SHAP_IMBAL_CW)

In [None]:
def explanation_to_frame(explanation):
    """Turn a SHAP explanation into a DataFrame with one row per explained review.

    Columns: ``values`` (per-token attributions), ``base_values`` (one entry per
    row of ``explanation.base_values``) and ``feature_names`` (the tokens).
    Presumably ``values[i]`` is a (n_tokens_i, n_classes) array for a text
    explanation — TODO confirm against the pickles.
    """
    return pd.DataFrame(
        {
            "values": explanation.values,
            "base_values": list(explanation.base_values),
            "feature_names": explanation.feature_names,
        }
    )


df_shap_bal = explanation_to_frame(shap_values_bal)
df_shap_imbal = explanation_to_frame(shap_values_imbal)
df_shap_imbal_cw = explanation_to_frame(shap_values_imbal_cw)

In [None]:
def _reconstructed_output(row):
    # SHAP additivity: base value plus the attributions summed over the token
    # axis gives the (per-class) model output the explanation accounts for.
    return row["values"].sum(axis=0) + row["base_values"]


for _frame in (df_shap_bal, df_shap_imbal, df_shap_imbal_cw):
    _frame["overall_shap_values"] = _frame.apply(_reconstructed_output, axis=1)

In [None]:
# Distribution of predicted classes per language group, formatted as LaTeX
# table rows ("12.3\% & 45.6\% & ..."), shown for every model variant.
LANGUAGE_GROUP = {"de": 0, "en": 0, "zh": 0, "fr": 1, "es": 1, "ja": 1}


def with_predictions_and_language(df_shap, languages):
    """Return a copy of ``df_shap`` with ``language``, ``predictions`` and ``language_group`` columns.

    ``languages`` is matched to ``df_shap`` by position. The prediction is the
    argmax of ``overall_shap_values`` (base value + summed token attributions).
    Raises ``KeyError`` for a language that is not in ``LANGUAGE_GROUP``.
    """
    out = df_shap.copy()
    out["language"] = languages.reset_index(drop=True)
    out["predictions"] = out["overall_shap_values"].apply(np.argmax)
    out["language_group"] = out["language"].apply(lambda lang: LANGUAGE_GROUP[lang])
    return out


def prediction_distribution_latex(df_with_predictions):
    """Per language group, one LaTeX row with the share (percent, 1 decimal) of each predicted class.

    Columns are ordered by class index. A class that is never predicted within
    a group is reported as 0.0 instead of NaN, which would break the join.
    """
    shares = (
        df_with_predictions.groupby("language_group")["predictions"]
        .value_counts(normalize=True)
        .apply(lambda share: round(share * 100, 1))
        .unstack(fill_value=0.0)
    )
    # Raw string: "\%" is an invalid escape sequence in a normal string literal.
    return shares.apply(lambda row: " & ".join(str(value) + r"\%" for value in row), axis=1)


# Built from the balanced model, matching its name (it previously copied
# df_shap_imbal by mistake); all three variants are tabulated below.
df_shap_bal_with_language = with_predictions_and_language(df_shap_bal, data["language"])

pd.concat(
    {
        "bal": prediction_distribution_latex(df_shap_bal_with_language),
        "imbal": prediction_distribution_latex(
            with_predictions_and_language(df_shap_imbal, data["language"])
        ),
        "imbal_cw": prediction_distribution_latex(
            with_predictions_and_language(df_shap_imbal_cw, data["language"])
        ),
    },
    axis=1,
)


In [None]:
# Side-by-side frame, one row per review: every column of each variant's frame
# carries that variant's suffix so rows can be compared across models by name.
_frames_by_suffix = {
    "_bal": df_shap_bal,
    "_imbal": df_shap_imbal,
    "_imbal_cw": df_shap_imbal_cw,
}
df_shap_all = pd.concat(
    [frame.add_suffix(suffix) for suffix, frame in _frames_by_suffix.items()],
    axis=1,
)

In [None]:
# Per-token change in attribution between each imbalanced variant and the
# balanced reference model. Assumes both models share a tokenizer so the
# per-review attribution arrays line up element-wise — TODO confirm.
df_shap_all["contribution_diff_per_token_imbal"] = df_shap_all.apply(
    lambda row: row["values_imbal"] - row["values_bal"],
    axis=1,
)
df_shap_all["contribution_diff_per_token_imbal_cw"] = df_shap_all.apply(
    lambda row: row["values_imbal_cw"] - row["values_bal"],
    axis=1,
)

# Balanced-model attributions within [-threshold, threshold] count as
# "neutral"; above it "positive", below -threshold "negative".
threshold = 0.01


def contribution_breakdown(row, suffix, threshold):
    """Split one review's output shift (variant ``suffix`` minus balanced) by token polarity.

    The per-token attribution differences are summed over the token axis,
    separately for tokens the balanced model saw as neutral / positive /
    negative; together with the base-value shift these four parts add up to
    the difference of the reconstructed outputs. The three models'
    reconstructed outputs are carried along for later comparison.
    """
    diff = row[f"contribution_diff_per_token_{suffix}"]
    reference = row["values_bal"]

    def masked_sum(keep):
        # Entries outside `keep` are masked out and excluded from the sum.
        return np.ma.array(diff, mask=~keep).sum(axis=0)

    return pd.Series(
        {
            "contribution_to_diff_from_neutral": masked_sum(np.abs(reference) <= threshold),
            "contribution_to_diff_from_positive": masked_sum(reference > threshold),
            "contribution_to_diff_from_negative": masked_sum(reference < -threshold),
            "contribution_to_diff_from_base_values": row[f"base_values_{suffix}"] - row["base_values_bal"],
            "overall_shap_values_bal": row["overall_shap_values_bal"],
            "overall_shap_values_imbal": row["overall_shap_values_imbal"],
            "overall_shap_values_imbal_cw": row["overall_shap_values_imbal_cw"],
        }
    )


df_shap_contrib_imbal = df_shap_all.apply(
    lambda row: contribution_breakdown(row, "imbal", threshold), axis=1
)
df_shap_contrib_imbal_cw = df_shap_all.apply(
    lambda row: contribution_breakdown(row, "imbal_cw", threshold), axis=1
)

In [None]:
# Attach each review's language; matched by position to the sampled `data`.
df_shap_contrib_imbal["language"] = data["language"].reset_index(drop=True)
df_shap_contrib_imbal_cw["language"] = data["language"].reset_index(drop=True)


def language_average(df_contrib):
    """Average every per-class contribution column within each language.

    Returns a long frame with one row per (language, label), where ``label`` is
    the class index, i.e. the position within each per-class vector.
    Masked entries (from the masked sums) are ignored by the mean.
    """
    # Select the value columns explicitly: on pandas < 3, groupby.apply also
    # passes the grouping column, and averaging the language strings raises
    # (and reset_index would then collide on "language").
    value_columns = [column for column in df_contrib.columns if column != "language"]
    return (
        df_contrib.groupby("language")[value_columns]
        .apply(lambda group: group.apply(lambda col: np.ma.vstack(col).mean(axis=0), axis=0))
        .rename_axis(["language", "label"])
        .reset_index()
    )


df_shap_contrib_langavg_imbal = language_average(df_shap_contrib_imbal)
df_shap_contrib_langavg_imbal_cw = language_average(df_shap_contrib_imbal_cw)

In [None]:
# Per-language breakdown for class index 4 (presumably the top star rating —
# confirm the label encoding) of how the imbalanced model's reconstructed
# output differs from the balanced model's: summed per-token differences,
# grouped by the token's polarity under the balanced model, plus the
# base-value shift.
fig, ax = plt.subplots()
sns.barplot(
    data=(
        df_shap_contrib_langavg_imbal.query("label==4")
        .rename(
            columns={
                "contribution_to_diff_from_neutral": "from_neutral",
                "contribution_to_diff_from_positive": "from_positive",
                "contribution_to_diff_from_negative": "from_negative",
                "contribution_to_diff_from_base_values": "from_base_values",
            }
        )
        .melt(
            id_vars=["language", "label"],
            value_vars=[
                "from_neutral",
                "from_positive",
                "from_negative",
                "from_base_values",
            ],
            var_name="contribution_type",
            value_name="contribution",
        )
    ),
    x="contribution",
    y="contribution_type",
    hue="language",
    hue_order=["en", "de", "zh", "fr", "es", "ja"],
    ax=ax,
)
# Same fixed x-range as the class-weighted figure so the two are comparable.
ax.set_xlim(-0.16, 0.16)
ax.set_title("imbal − bal: contributions to the class 4 output difference, by language")
plt.show()

In [None]:
# Same breakdown as for the imbalanced model, but for the imbal_cw variant:
# per language, for class index 4 (presumably the top star rating — confirm
# the label encoding), how its reconstructed output differs from the balanced
# model's, grouped by token polarity under the balanced model, plus the
# base-value shift.
fig, ax = plt.subplots()
sns.barplot(
    data=(
        df_shap_contrib_langavg_imbal_cw.query("label==4")
        .rename(
            columns={
                "contribution_to_diff_from_neutral": "from_neutral",
                "contribution_to_diff_from_positive": "from_positive",
                "contribution_to_diff_from_negative": "from_negative",
                "contribution_to_diff_from_base_values": "from_base_values",
            }
        )
        .melt(
            id_vars=["language", "label"],
            value_vars=[
                "from_neutral",
                "from_positive",
                "from_negative",
                "from_base_values",
            ],
            var_name="contribution_type",
            value_name="contribution",
        )
    ),
    x="contribution",
    y="contribution_type",
    hue="language",
    hue_order=["en", "de", "zh", "fr", "es", "ja"],
    ax=ax,
)
# Same fixed x-range as the imbalanced-model figure so the two are comparable.
ax.set_xlim(-0.16, 0.16)
ax.set_title("imbal_cw − bal: contributions to the class 4 output difference, by language")
plt.show()