# Sentiment with Transformers

**Interpreting the Prediction of ProsusAI/finbert model for Text Classification**

https://towardsdatascience.com/interpreting-the-prediction-of-bert-model-for-text-classification-5ab09f8ef074

In [1]:
from transformers import BertTokenizer

# Initialize the tokenizer for BERT models
# Note: FinBERT is a pre-trained NLP model to analyze sentiment of financial text
tokenizer = BertTokenizer.from_pretrained('ProsusAI/finbert')

# Set labels
labels = {0: 'Positive', 1: 'Negative', 2: 'Neutral'}

# Create Sentiment classifier

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertForSequenceClassification

class SentimentModel(nn.Module):
   def __init__(self):
        super(SentimentModel, self).__init__()
        # Initialize the model for sequence classification
        self.backbone = BertForSequenceClassification.from_pretrained('ProsusAI/finbert')

   def forward(self, input_id, mask = None):
      output = self.backbone(input_ids=input_id, attention_mask=mask,return_dict=False)
      output = F.softmax(output[0], dim=-1)

      return (output)


model = SentimentModel()
model.to( 'cpu' );

# Sentiment analysis

In [3]:
from captum.attr import LayerIntegratedGradients

# Define model output
def model_output(inputs):
  return model(inputs)[0]

# Define model input
model_input = model.backbone.bert.embeddings


lig = LayerIntegratedGradients(model_output, model_input)

In [4]:
from captum.attr import visualization as viz


def construct_input_and_baseline(text):

    max_length = 512
    baseline_token_id = tokenizer.pad_token_id 
    sep_token_id = tokenizer.sep_token_id 
    cls_token_id = tokenizer.cls_token_id 

    text_ids = tokenizer.encode(text, max_length=max_length, truncation=True, add_special_tokens=False)
   
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    token_list = tokenizer.convert_ids_to_tokens(input_ids)

    baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]
    return torch.tensor([input_ids], device='cpu'), torch.tensor([baseline_input_ids], device='cpu'), token_list




def summarize_attributions(attributions):

    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    
    return attributions




def interpret_text(text, true_class=None):

    input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)
    attributions, delta = lig.attribute(inputs=input_ids,
                                        baselines=baseline_input_ids,
                                        return_convergence_delta=True,
                                        internal_batch_size=1)
    attributions_sum = summarize_attributions(attributions)

    score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = torch.max(model(input_ids)[0]),
                        pred_class = labels[torch.argmax(model(input_ids)[0]).item()],
                        true_class = true_class,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),       
                        raw_input_ids = all_tokens,
                        convergence_score = delta)

    viz.visualize_text([score_vis])

# Examples

In [5]:
text = "The stock market moved down today and most shares showed losses"
interpret_text(text, true_class = 'Negative')

  inputs, allow_unused)


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Negative,Negative (0.96),The stock market moved down today and most shares showed losses,1.04,[CLS] the stock market moved down today and most shares showed losses [SEP]
,,,,


In [6]:
text = "The stock market moved up today and most shares showed gains"
interpret_text(text, true_class = 'Positive')

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Positive,Positive (0.89),The stock market moved up today and most shares showed gains,2.41,[CLS] the stock market moved up today and most shares showed gains [SEP]
,,,,
