In [None]:
!pip install transformers
!pip install captum

In [58]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from captum.attr import LayerIntegratedGradients, visualization as viz
import torch

def visualize_sentiment(text: str):
    """
    Visualizes the sentiment of the given text using a pre-trained DistilBERT model.

    Args:
        text (str): The text to visualize.
    """

    # Pre-trained model and tokenizer
    model_path = 'distilbert-base-uncased-finetuned-sst-2-english'
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
    tokenizer = DistilBertTokenizer.from_pretrained(model_path)
    model.eval()

    # Function to create input tensors and baseline for the given text
    def construct_input_and_baseline(input_text: str):
        """Constructs input and baseline tensors for the given text."""
        max_length = 768
        baseline_token_id = tokenizer.pad_token_id
        sep_token_id = tokenizer.sep_token_id
        cls_token_id = tokenizer.cls_token_id

        text_ids = tokenizer.encode(input_text, max_length=max_length, truncation=True, add_special_tokens=False)
        input_ids = [cls_token_id] + text_ids + [sep_token_id]
        baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]
        token_list = tokenizer.convert_ids_to_tokens(input_ids)

        return torch.tensor([input_ids], device='cpu'), torch.tensor([baseline_input_ids], device='cpu'), token_list

    # Constructing input and baseline
    input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

    # Defining a function for the model's output
    def model_output(inputs):
        return model(inputs)[0]

    # Layer Integrated Gradients
    lig = LayerIntegratedGradients(model_output, model.distilbert.embeddings)

    # Target classes
    target_classes = [0, 1]
    attributions = {}
    delta = {}

    # Calculating attributions for both classes
    for target_class in target_classes:
        attributions[target_class], delta[target_class] = lig.attribute(
            inputs=input_ids,
            baselines=baseline_input_ids,
            target=target_class,
            return_convergence_delta=True,
            internal_batch_size=1)

    # Summarizing attributions
    neg_attributions = attributions[0].sum(dim=-1).squeeze(0) / torch.norm(attributions[0])
    pos_attributions = attributions[1].sum(dim=-1).squeeze(0) / torch.norm(attributions[1])

    # Predicting the class
    pred_prob, pred_class = torch.max(model(input_ids)[0]), int(torch.argmax(model(input_ids)[0]))

    # Selecting the attributions based on the predicted class
    summarized_attr = pos_attributions if pred_class == 1 else neg_attributions

    # Visualization data
    score_vis = viz.VisualizationDataRecord(
                        word_attributions=summarized_attr,
                        pred_prob=pred_prob,
                        pred_class=pred_class,
                        true_class=None,
                        attr_class=text,
                        attr_score=summarized_attr.sum(),
                        raw_input_ids=all_tokens,
                        convergence_score=delta[pred_class])

    # Visualizing the result
    viz.visualize_text([score_vis])


In [60]:
text = "The movie was not bad as mentioned by critics. It was in fact awesome; I enjoyed the whole time"
visualize_sentiment(text)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
,1 (4.65),The movie was not bad as mentioned by critics. It was in fact awesome; I enjoyed the whole time,12.92,[CLS] the movie was not bad as mentioned by critics . it was in fact awesome ; i enjoyed the whole time [SEP]
,,,,
