In [1]:
import shap
import seaborn as sns
import plotly.graph_objects as go

import pandas as pd
import transformers
import torch
import numpy as np


In [17]:
# Checkpoint shared by the tokenizer and the classifier; naming it once keeps
# the two `from_pretrained` calls pointing at the same model.
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment-latest"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)


Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
# Alternative example sentence:
# text = ["I respect the visuals, I respect the performances but... nah."]
text = ["I was not disappointed with the result."]

# Token ids for the sentence in `text`. They are kept on the CPU because they
# are decoded back to strings right below.
encoded_input = tokenizer(text, return_tensors='pt')['input_ids']

# One display string per token id of the first (only) sentence; special tokens
# such as <s> / </s> are kept so the list lines up with the model input.
words = [
    tokenizer.convert_tokens_to_string([token]).strip()
    for token in tokenizer.convert_ids_to_tokens(encoded_input[0])
]

with torch.no_grad():
    # Feed the ids on whatever device the model currently lives on. The
    # pipeline cell further down is built on this same `model` object with an
    # explicit device, so after that cell has run the model may sit on the GPU
    # and a plain CPU tensor would raise a device-mismatch error on re-run.
    output = model(encoded_input.to(model.device)).logits
    print(output.shape)
    # Class probabilities per sentence; moved back to the CPU so the displayed
    # tensor looks the same regardless of where the model ran.
    scores = torch.softmax(output, dim=1).cpu()
scores


torch.Size([1, 3])


tensor([[0.2047, 0.5419, 0.2535]])

In [None]:
# https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/sentiment_analysis/Emotion%20classification%20multiclass%20example.html

# Use the first GPU when CUDA is available, otherwise fall back to the CPU
# (-1) so the notebook also runs on machines without a GPU instead of failing
# on a hard-coded `device=0`.
pipeline_device = 0 if torch.cuda.is_available() else -1

# Text-classification pipeline wrapping the already loaded model/tokenizer.
# NOTE(review): newer transformers releases flag `return_all_scores` as
# deprecated in favour of `top_k=None`; it is kept here to match the SHAP
# example linked above -- revisit once the transformers version is pinned.
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=pipeline_device,
    return_all_scores=True,
)

# Let SHAP choose the explainer implementation for the pipeline.
explainer = shap.Explainer(pred)


In [13]:
# Show which concrete explainer class `shap.Explainer` selected for the pipeline.
type(explainer)


shap.explainers._partition.PartitionExplainer

In [20]:
# Run the explainer on the sentence(s) in `text`; the result is indexed by the
# later cells through `shap_values.values`.
shap_values = explainer(text)


In [21]:
# Render SHAP's built-in text plot for the computed explanation.
shap.plots.text(shap_values)


In [109]:
# Class names in the column order used for the attribution matrix below.
labels = ["negative", "neutral", "positive"]
# Link colour per target class (semi-transparent so overlapping links stay visible).
colors = {
    "negative": "rgba(0, 202, 255, 0.5)",
    "neutral": "rgba(112, 255, 145, 0.5)",
    "positive": "rgba(255, 0, 250, 0.3)"
}
# Sentence-boundary tokens: they always get the minimal link width.
SPECIAL_TOKENS = ("<s>", "</s>")
# Smallest width given to any link so that no link has a zero or negative flow.
MIN_LINK_VALUE = 1e-5

# Attribution matrix of the first (only) explained text: one row per token,
# one column per entry of `labels`. The name `logits` is kept because later
# cells refer to it, but these are SHAP values, not model logits.
# Indexing with [0] (instead of squeezing) keeps the token axis intact even
# for a one-token input.
logits = np.asarray(shap_values.values[0])

# `zip` would silently truncate on a mismatch, e.g. when `words` and
# `shap_values` were computed from different sentences in a stale kernel.
assert len(words) == len(logits), (
    f"{len(words)} tokens but {len(logits)} SHAP rows; "
    "re-run the tokenisation and explainer cells on the same text"
)

# One Sankey link per (token, class) pair. Token nodes use indices
# 0..len(words)-1, class nodes follow directly after them.
# A Sankey flow must be positive, while SHAP values are signed, so the link
# width is the attribution magnitude; the signed value is preserved under
# 'shap_value' for anyone who needs the direction of the contribution.
links = [
    {
        'source': token_idx,
        'target': len(words) + label_idx,
        'value': (
            MIN_LINK_VALUE
            if word in SPECIAL_TOKENS
            else max(abs(contribution), MIN_LINK_VALUE)
        ),
        'shap_value': contribution,
        'color': colors[label],
    }
    for token_idx, (word, token_contributions) in enumerate(zip(words, logits.tolist()))
    for label_idx, (label, contribution) in enumerate(zip(labels, token_contributions))
]


In [None]:
# Inspect the decoded token strings (special tokens included) that serve as
# the source nodes of the Sankey diagram.
words


['<s>',
 'I',
 'respect',
 'the',
 'visuals',
 ',',
 'I',
 'respect',
 'the',
 'performances',
 'but',
 '...',
 'n',
 'ah',
 '.',
 '</s>']

In [108]:
# Inspect the per-token attribution matrix. Despite its name, `logits` holds
# the values taken from `shap_values.values`, one column per class label.
logits


array([[ 0.        ,  0.        ,  0.        ],
       [-0.15133477,  0.0132775 ,  0.13805727],
       [ 0.06220398, -0.01353531, -0.04866867],
       [ 0.00905484, -0.01270499,  0.00365016],
       [ 0.04431918, -0.0186793 , -0.02563988],
       [-0.02031578, -0.02319141,  0.04350719],
       [-0.06339632, -0.01112055,  0.07451687],
       [ 0.06425602,  0.02058847, -0.08484448],
       [-0.01260841, -0.01619044,  0.02879884],
       [ 0.00299603, -0.01694284,  0.01394681],
       [ 0.0768812 ,  0.10103078, -0.17791197],
       [ 0.04006428,  0.00111538, -0.04117968],
       [ 0.03348455,  0.05938965, -0.0928742 ],
       [-0.0062831 ,  0.0081635 , -0.00188039],
       [ 0.04356562,  0.04611192, -0.08967755],
       [ 0.        ,  0.        ,  0.        ]])

In [110]:
# Tabular view of the link records; each selected column feeds one Sankey
# link attribute.
df = pd.DataFrame(links)

n_words, n_labels = len(words), len(labels)

# Node list: token nodes first, class-label nodes after them. Tokens are
# placed in a column at x=0.1 and labels in a column at x=0.5, each column
# spread evenly between y=0.01 and y=0.99.
sankey_nodes = {
    "label": words + labels,
    "x": [0.1] * n_words + [0.5] * n_labels,
    "y": list(np.linspace(0.01, 0.99, n_words)) + list(np.linspace(0.01, 0.99, n_labels)),
    "color": "grey",
    "pad": 500,
}

# Link attributes, taken column by column from the link table.
sankey_links = {
    column: df[column].tolist()
    for column in ("source", "target", "value", "color")
}

fig = go.Figure(go.Sankey(node=sankey_nodes, link=sankey_links))

fig.show()
