# Demo of XAI for sentiment analysis

This notebook computes token attributions for the SiEBERT sentiment model (`siebert/sentiment-roberta-large-english`) using the Integrated Gradients method. It was inspired by and is based on the official Captum tutorial: [https://captum.ai/tutorials/Bert_SQUAD_Interpret](https://captum.ai/tutorials/Bert_SQUAD_Interpret)

In [None]:
from transformers import pipeline
from captum.attr import LayerIntegratedGradients, visualization
import torch

# Use the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Build the Hugging Face sentiment pipeline around the SiEBERT checkpoint.
# The model is placed on `device` explicitly: the helper functions in this
# notebook create their input tensors on `device`, so a pipeline left on its
# default device would fail with a device mismatch on CUDA machines.
sentiment_analysis = pipeline(
    "sentiment-analysis",
    model="siebert/sentiment-roberta-large-english",
    device=device,
)
# Quick smoke test on an obviously positive sentence.
print(sentiment_analysis("I love this!"))

In [None]:
# Take the model object out of the pipeline so `predict` can call it directly
# and read its logits.
model = sentiment_analysis.model
# Bare expression: as the last statement of the cell, the notebook echoes its repr.
model

In [None]:
# Reuse the pipeline's tokenizer so the manually built inputs use the same
# vocabulary and special-token ids as the pipeline itself.
tokenizer = sentiment_analysis.tokenizer
# Bare expression: as the last statement of the cell, the notebook echoes its repr.
tokenizer

In [None]:
def predict(inputs, position_ids=None, attention_mask=None):
    """Run the classifier and return its raw logits.

    Args:
        inputs: Passed to ``model`` as its first positional argument
            (the token-id tensor in this notebook).
        position_ids: Optional position ids, forwarded unchanged.
        attention_mask: Optional attention mask, forwarded unchanged.

    Returns:
        The ``logits`` attribute of the model output.
    """
    forward_kwargs = {
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
    return model(inputs, **forward_kwargs).logits

In [None]:
ref_token_id = tokenizer.pad_token_id # Padding token id; stands in for every text token in the baseline (reference) sequence
sep_token_id = tokenizer.sep_token_id # Separator token id; appended after the text tokens
cls_token_id = tokenizer.cls_token_id # Start-of-sequence (classification) token id; prepended before the text tokens

In [None]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Build the model input and its attribution baseline for ``text``.

    Both sequences are framed as ``[cls] ... [sep]``. In the baseline every
    text token is replaced by ``ref_token_id``, so the two have equal length.

    Args:
        text: Raw string to tokenize.
        ref_token_id: Token id used to fill the baseline's text positions.
        sep_token_id: Token id appended at the end of both sequences.
        cls_token_id: Token id prepended at the start of both sequences.

    Returns:
        ``(input_ids, ref_input_ids, n_text_tokens)``: two tensors with a
        leading batch axis of size 1, created on ``device``, and the number
        of text tokens (special tokens excluded).
    """
    # Special tokens are added by hand below, so ask the tokenizer for none.
    body_ids = tokenizer.encode(text, add_special_tokens=False)
    n_body = len(body_ids)

    real_sequence = [cls_token_id, *body_ids, sep_token_id]
    baseline_sequence = [cls_token_id, *([ref_token_id] * n_body), sep_token_id]

    input_tensor = torch.tensor([real_sequence], device=device)
    baseline_tensor = torch.tensor([baseline_sequence], device=device)
    return input_tensor, baseline_tensor, n_body

def construct_attention_mask(input_ids):
    """Return an all-ones mask shaped like ``input_ids``.

    Every position is marked as attended; the result has the same shape,
    dtype and device as ``input_ids``.
    """
    attention_mask = torch.ones_like(input_ids)
    return attention_mask

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build segment (token-type) ids and an all-zero reference.

    Positions up to and including ``sep_ind`` get segment 0; every later
    position gets segment 1.

    Args:
        input_ids: Tensor whose second axis is the sequence length.
        sep_ind: Index of the last position that belongs to segment 0.

    Returns:
        ``(token_type_ids, ref_token_type_ids)``, each with a leading batch
        axis of size 1, created on ``device``.
    """
    n_positions = input_ids.size(1)
    segment_flags = [int(pos > sep_ind) for pos in range(n_positions)]
    token_type_ids = torch.tensor([segment_flags], device=device)
    # The reference assigns segment 0 to every position.
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids, padding_idx=1):
    """Build position ids for ``input_ids`` plus an all-zero reference.

    RoBERTa numbers real tokens starting at ``padding_idx + 1``; that is what
    the model computes itself when ``position_ids`` is omitted. Counting from
    0, as the BERT tutorial does, would feed the model different position
    embeddings from the ones it normally uses, so the range is offset here.

    Args:
        input_ids: Tensor of shape ``(batch, seq_len)``. Assumed to contain
            no padding tokens (every position gets a consecutive index).
        padding_idx: Padding index of the position-embedding table. Defaults
            to 1, RoBERTa's default ``pad_token_id``; pass
            ``model.config.pad_token_id`` for checkpoints that differ, or
            ``-1`` to get plain 0-based positions.

    Returns:
        ``(position_ids, ref_position_ids)``, both ``torch.long`` tensors
        expanded to the shape of ``input_ids`` and placed on its device.
    """
    seq_length = input_ids.size(1)
    first_position = padding_idx + 1
    # Create the tensors on the inputs' device so they can never disagree.
    position_ids = torch.arange(
        first_position,
        first_position + seq_length,
        dtype=torch.long,
        device=input_ids.device,
    )
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(
        seq_length, dtype=torch.long, device=input_ids.device
    )

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids

We choose an input text and an attribution label, i.e. the class with respect to which the attributions are computed.
Labels:
* 0 - negative
* 1 - positive

For instance, the text below has a clearly negative sentiment, but we compute the attributions with respect to the positive label. The result therefore shows which tokens push the prediction away from the positive label.

In [None]:
# Example with clearly negative sentiment.
text = "Today is a terrible day and i cant stop crying"
# Class index the attributions are computed for (1 = positive).
attribution_label = torch.tensor([[1]])

# Token ids of the text and of the baseline. The third value is the number of
# text tokens (special tokens excluded), despite the `sep_id` name; it is not
# used by the cells shown here.
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
# `ref_position_ids` is returned for completeness; the attribution calls in
# this notebook only pass `position_ids`.
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

# Human-readable tokens for the first (only) sequence, used for printing and
# for the visualization.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [None]:
# Manual forward pass through `predict`; `scores` holds the raw logits.
scores = predict(input_ids, attention_mask=attention_mask, position_ids=position_ids)

print('Text: ', text)
print('Tokens', all_tokens)
# Note: these are logits, not probabilities.
print('Predicted Sentiment: ', scores)

In [None]:
# Same text through the stock pipeline, to compare with the manual forward pass.
sentiment_analysis(text)

In [None]:
# Softmax over dim 1 of the logits (presumably the class axis, confirm against
# the model output shape) so they can be compared with the pipeline's score.
torch.softmax(scores, 1)

In [None]:
# Integrated Gradients taken with respect to the model's embedding layer, using
# `predict` as the forward function.
lig = LayerIntegratedGradients(predict, model.roberta.embeddings)

# `additional_forward_args` are handed to `predict` after `inputs`, so their
# order must match its signature: (position_ids, attention_mask).
# `delta` is the convergence delta requested via `return_convergence_delta`.
attributions, delta = lig.attribute(inputs=input_ids,
                                    target=attribution_label,
                                    baselines=ref_input_ids,
                                    additional_forward_args=(position_ids,attention_mask),
                                    return_convergence_delta=True)

In [None]:
def summarize_attributions(attributions):
    """Collapse attributions to one normalised score per token.

    Sums over the last axis, drops a leading axis of size 1, and divides by
    the L2 norm of the result.

    Args:
        attributions: Attribution tensor whose last axis is summed away
            (the embedding dimension in this notebook).

    Returns:
        The summed attributions scaled to unit L2 norm. When every summed
        attribution is zero the zeros are returned unchanged, because
        dividing by the zero norm would turn them all into NaN.
    """
    token_scores = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(token_scores)
    if norm == 0:
        # Nothing to normalise; avoid 0 / 0.
        return token_scores
    return token_scores / norm

# One normalised attribution score per token of the current example.
attributions_sum = summarize_attributions(attributions)

In [None]:
# Bundle everything Captum's text visualizer needs for one example.
# `true_class` is filled with the predicted class as well, since this notebook
# has no ground-truth label for the text.
vis = visualization.VisualizationDataRecord(
    word_attributions=attributions_sum,
    pred_prob=torch.max(torch.softmax(scores[0], dim=0)),
    pred_class=torch.argmax(scores[0]),
    true_class=torch.argmax(scores[0]),
    attr_class=str(attribution_label),
    attr_score=attributions_sum.sum(),
    raw_input_ids=all_tokens,
    convergence_score=delta)

# ANSI escape codes: switch bold on, print the heading, switch formatting off.
print('\033[1m', 'Visualizations', '\033[0m')
visualization.visualize_text([vis])

If we change the text, the attribution also changes.


In [None]:
# Second example: same sentence structure, but with clearly positive wording.
text = "Today is a beautiful day and i cant stop smiling"
# Attributions are again computed for class index 1 (positive).
attribution_label = torch.tensor([[1]])

# Rebuild inputs, baseline, position ids and mask for the new text.
# `sep_id` and `ref_position_ids` are not used further in this cell.
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

# Human-readable tokens and raw logits for the new text.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)
scores = predict(input_ids, attention_mask=attention_mask, position_ids=position_ids)

print('Text: ', text)
print('Tokens', all_tokens)
# Print the manual logits next to the stock pipeline's answer for comparison.
print('Predicted Sentiment: ', scores, " HuggingFace model:", sentiment_analysis(text))

# Reuses the `lig` object created earlier in the notebook. The order of
# `additional_forward_args` must match `predict`'s signature.
attributions, delta = lig.attribute(inputs=input_ids,
                                    target=attribution_label,
                                    baselines=ref_input_ids,
                                    additional_forward_args=(position_ids, attention_mask),
                                    return_convergence_delta=True)

# One normalised score per token, then the record for Captum's text visualizer.
# `true_class` is filled with the predicted class as well, since this notebook
# has no ground-truth label for the text.
attributions_sum = summarize_attributions(attributions)
vis = visualization.VisualizationDataRecord(
    word_attributions=attributions_sum,
    pred_prob=torch.max(torch.softmax(scores[0], dim=0)),
    pred_class=torch.argmax(scores[0]),
    true_class=torch.argmax(scores[0]),
    attr_class=str(attribution_label),
    attr_score=attributions_sum.sum(),
    raw_input_ids=all_tokens,
    convergence_score=delta)

# ANSI escape codes: switch bold on, print the heading, switch formatting off.
print('\033[1m', 'Visualizations', '\033[0m')
visualization.visualize_text([vis])