# Emotion classification multiclass example

This notebook demonstrates how to use the `Partition` explainer for a multiclass text classification scenario. Once the SHAP values are computed for a set of sentences, we visualize the feature attributions towards individual classes. The text classification model we use is BERT fine-tuned on an emotion dataset to classify a sentence into one of six classes: joy, sadness, anger, fear, love and surprise.

In [68]:
pip install datasets

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
Note: you may need to restart the kernel to use updated packages.


In [69]:
pip install transformers



In [70]:
pip install shap



In [71]:
import pandas as pd


data = pd.read_csv("./data/train.jsonl")


print(data.head())

                   {"text":"i didnt feel humiliated"  label:0}
0  {"text":"i can go from feeling so hopeless to ...  label:0}
1  {"text":"im grabbing a minute to post i feel g...  label:3}
2  {"text":"i am ever feeling nostalgic about the...  label:2}
3                     {"text":"i am feeling grouchy"  label:3}
4  {"text":"ive been feeling a little burdened la...  label:0}


In [72]:
import pandas as pd


file_path = "./data/train.jsonl"


data = pd.read_json(file_path, lines=True)


print(data.head())

                                                text  label
0                            i didnt feel humiliated      0
1  i can go from feeling so hopeless to so damned...      0
2   im grabbing a minute to post i feel greedy wrong      3
3  i am ever feeling nostalgic about the fireplac...      2
4                               i am feeling grouchy      3


In [73]:
import pandas as pd


data = pd.read_json('./data/train.jsonl', lines=True)


data = pd.DataFrame({"text": data["text"], "emotion": data["label"]})


print(data.head())

                                                text  emotion
0                            i didnt feel humiliated        0
1  i can go from feeling so hopeless to so damned...        0
2   im grabbing a minute to post i feel greedy wrong        3
3  i am ever feeling nostalgic about the fireplac...        2
4                               i am feeling grouchy        3


In [74]:
import transformers


model_path = "./model/"


tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, use_fast=True)


model = transformers.AutoModelForSequenceClassification.from_pretrained(model_path).cuda()


pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0,
    return_all_scores=True,
)

`return_all_scores` is now deprecated,  if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.


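### Build a transformers pipeline

Note that we set `return_all_scores=True` for the pipeline so that we can observe the model's behavior for all classes, not just the top output. As the warning above indicates, this argument is deprecated in recent versions of transformers in favor of `top_k=None`.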

### Create an explainer for the pipeline

A transformers pipeline object can be passed directly to `shap.Explainer`, which will then wrap the pipeline model as a `shap.models.TransformersPipeline` model and the pipeline tokenizer as a `shap.maskers.Text` masker.

In [75]:
import shap
explainer = shap.Explainer(pred)
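
To make the wrapping step more concrete, the sketch below shows one way a per-class pipeline output could be turned into a fixed-order score matrix, which is the kind of numeric output an explainer needs. The `raw_output` data, the `label2id` mapping, and the `to_score_matrix` helper are all hypothetical and only illustrate the idea; this is not the code inside `shap.models.TransformersPipeline`.

```python
# Hypothetical pipeline output for two sentences: one list of
# {"label", "score"} dicts per input, in no guaranteed order.
raw_output = [
    [{"label": "joy", "score": 0.7}, {"label": "anger", "score": 0.1}, {"label": "sadness", "score": 0.2}],
    [{"label": "sadness", "score": 0.6}, {"label": "joy", "score": 0.3}, {"label": "anger", "score": 0.1}],
]

# Assumed label-to-column mapping (illustrative only).
label2id = {"sadness": 0, "joy": 1, "anger": 2}


def to_score_matrix(outputs, label2id):
    """Return one row per input with scores placed in label2id column order."""
    matrix = []
    for per_input in outputs:
        row = [0.0] * len(label2id)
        for item in per_input:
            row[label2id[item["label"]]] = item["score"]
        matrix.append(row)
    return matrix


scores = to_score_matrix(raw_output, label2id)
print(scores)
```

Indexing by label name rather than by list position keeps the columns stable even if the pipeline returns the classes in a different order for each input.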

### Compute SHAP values

Explainers have the same method signature as the models they are explaining, so we just pass a list of strings for which to explain the classifications.

In [76]:
shap_values = explainer(data["text"][:3])

  0%|          | 0/498 [00:00<?, ?it/s]

### Visualize the impact on all the output classes

In the plots below, when you hover your mouse over an output class you get the explanation for that output class. When you click an output class name then that class remains the focus of the explanation visualization until you click another class.

The base value is what the model outputs when the entire input text is masked, while $f_{output class}(inputs)$ is the output of the model for the full original input. The SHAP values explain in an additive way how unmasking each word changes the model output from the base value (where the entire input is masked) to the final prediction value.

In [77]:
shap.plots.text(shap_values)
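
The additive property described above can be checked on a toy example. The sketch below computes exact Shapley values by brute-force subset enumeration for a hypothetical three-token scoring function; it is not the `Partition` algorithm, and `toy_model`, its scores, and `shapley_values` are made up for illustration.

```python
from itertools import combinations
from math import factorial

tokens = ["i", "feel", "happy"]


def toy_model(present):
    """Hypothetical 'joy' score given the set of unmasked token indices."""
    score = 0.1                      # output when everything is masked
    if 2 in present:
        score += 0.5                 # "happy" alone is strong evidence
    if 1 in present and 2 in present:
        score += 0.2                 # "feel" only helps alongside "happy"
    return score


def shapley_values(model, n):
    """Exact Shapley values by enumerating every subset of the other tokens."""
    values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi_i = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi_i += weight * (model(s | {i}) - model(s))
        values.append(phi_i)
    return values


base_value = toy_model(set())          # fully masked input
full_output = toy_model({0, 1, 2})     # full original input
phi = shapley_values(toy_model, len(tokens))

# Additivity: base value plus all attributions recovers the full output.
assert abs(base_value + sum(phi) - full_output) < 1e-9
print(dict(zip(tokens, phi)))
```

In this toy setting the interaction bonus between "feel" and "happy" is split evenly between the two tokens, which is how Shapley values handle effects that only appear when words occur together.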

### Visualize the impact on a single class

Since `Explanation` objects are sliceable we can slice out just a single output class to visualize the model output towards that class.

In [80]:
shap.plots.text(shap_values[:, :, "anger"])

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (3,) + inhomogeneous part.

In [79]:
shap.plots.text(shap_values[0])
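
As a rough mental model of class slicing, the sketch below keeps one class column from a set of per-sentence attribution tables. The numbers, `class_names`, and `slice_class` are hypothetical. The tables are deliberately ragged because sentences have different token counts, and an inhomogeneous shape like this cannot be stacked into one rectangular array, which may be related to the `ValueError` shown above (an assumption, not a confirmed diagnosis).

```python
# Hypothetical attributions: one (tokens x classes) table per sentence.
# The sentences have different token counts, so the tables are ragged.
class_names = ["sadness", "joy", "anger"]
attributions = [
    [[0.1, -0.2, 0.3], [0.0, 0.1, -0.1]],                    # 2 tokens
    [[0.2, 0.0, -0.3], [0.1, 0.1, 0.0], [-0.4, 0.2, 0.5]],   # 3 tokens
]


def slice_class(per_sentence, class_names, name):
    """Keep only the column for one class, sentence by sentence."""
    col = class_names.index(name)
    return [[token_row[col] for token_row in table] for table in per_sentence]


anger = slice_class(attributions, class_names, "anger")
print(anger)
```

Slicing sentence by sentence keeps each result at its own length, so no rectangular shape is ever required.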

### Plotting the top words impacting a specific class

In addition to slicing, `Explanation` objects also support a set of reducing methods. Here we use `.mean(0)` to take the average impact of all words towards the "joy" class. Note that we are averaging over only three examples here; to get a better summary, you would want to use a larger portion of the dataset.

In [63]:
shap.plots.bar(shap_values[:, :, "joy"].mean(0))

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (3,) + inhomogeneous part.

In [64]:
# we can sort the bar chart in descending order
shap.plots.bar(shap_values[:, :, "joy"].mean(0), order=shap.Explanation.argsort)

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (3,) + inhomogeneous part.

In [65]:
# ...or ascending order
shap.plots.bar(shap_values[:, :, "joy"].mean(0), order=shap.Explanation.argsort.flip)

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (3,) + inhomogeneous part.
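
The idea behind averaging word impacts over several sentences can be sketched in plain Python. The example assumes the reduction groups attributions by token and averages every occurrence of that token; the `examples` data and the `mean_by_token` helper are hypothetical and do not reproduce the internals of `Explanation.mean`.

```python
from collections import defaultdict

# Hypothetical per-token attributions towards "joy" for three sentences.
examples = [
    (["i", "feel", "happy"], [0.0, 0.1, 0.6]),
    (["i", "feel", "sad"], [0.0, 0.1, -0.5]),
    (["so", "happy"], [0.05, 0.4]),
]


def mean_by_token(examples):
    """Average each token's attribution over every occurrence of that token."""
    grouped = defaultdict(list)
    for tokens, values in examples:
        for token, value in zip(tokens, values):
            grouped[token].append(value)
    return {token: sum(vals) / len(vals) for token, vals in grouped.items()}


means = mean_by_token(examples)

# One possible ordering: mean attribution, largest first.
ranked = sorted(means.items(), key=lambda kv: kv[1], reverse=True)
print(ranked)
```

With only three sentences most tokens are averaged over one or two occurrences, which is why a larger sample gives a more trustworthy ranking.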

### Explain the log odds instead of the probabilities

In the examples above we explained the direct output of the pipeline object, which consists of class probabilities. Sometimes it makes more sense to work in a log odds space, where it is natural to add and subtract effects (addition and subtraction correspond to adding or removing bits of evidence). To work with logits we can use a parameter of the `shap.models.TransformersPipeline` object:

In [66]:
logit_explainer = shap.Explainer(
    shap.models.TransformersPipeline(pred, rescale_to_logits=True)
)

logit_shap_values = logit_explainer(data["text"][:3])
shap.plots.text(logit_shap_values)

  0%|          | 0/498 [00:00<?, ?it/s]
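
To see why effects add more naturally in log odds space, consider the small sketch below. It assumes the rescaling is the standard logit transform, log(p / (1 - p)); the base probability and the size of the word effect are hypothetical values chosen for illustration.

```python
import math


def logit(p):
    """Log odds of a probability p in (0, 1)."""
    return math.log(p / (1.0 - p))


def sigmoid(x):
    """Inverse of logit: map log odds back to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


base_prob = 0.5          # hypothetical base value for one class
evidence = math.log(9)   # hypothetical word effect: multiplies the odds by 9

# In log-odds space, the same effect is simply added once per occurrence.
one_hit = sigmoid(logit(base_prob) + evidence)
two_hits = sigmoid(logit(base_prob) + 2 * evidence)
print(one_hit, two_hits)
```

Applying the same evidence twice moves the probability from 0.5 to 0.9 and then to about 0.988, always staying below 1, whereas adding a fixed probability increment of 0.4 twice would overshoot the valid range.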

<hr>
Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged! 