In [1]:
import torch
from transformers_interpret import SequenceClassificationExplainer
from transformers import BertTokenizer, BertForSequenceClassification
from bertviz import head_view
from transformers import BertModel
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


#### Lab 7: Using LRP and Attention Visualization

##### Step 1: Load a BERT model fine-tuned for emotion classification

In [2]:
# Checkpoint to explain: a BERT model fine-tuned for emotion classification.
# You can keep this one or substitute any other classifier checkpoint,
# e.g. "nateraw/bert-base-uncased-emotion".
model_name = "nateraw/bert-base-uncased-emotion"

# Instead of using a pipeline, load the tokenizer and the model separately
# so that both objects can be passed to the explainers in later cells.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path=model_name)

##### Step 2: Get a sample sentence and make a classification

In [3]:
# Sample sentence to classify and explain; make up something relevant to your classifier.
text = "Today is okay but I would not say it is the best day ever"

In [4]:
# Tokenize the sample sentence into PyTorch tensors for the classifier.
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

# Forward pass; no gradients are needed for inference.
with torch.no_grad():
    output = model(**inputs)

# Raw scores for each emotion category, converted to probabilities.
logits = output.logits
probs = F.softmax(logits, dim=-1)

# Probability and index of the top class for the single input sentence (row 0),
# taken together so the reported confidence always belongs to the reported label.
confidence, predicted_index = probs[0].max(dim=-1)
predicted_label = predicted_index.item()

# Read the label names from the checkpoint's own config instead of hard-coding
# them, so the mapping stays correct if a different classifier is loaded above.
id2label = model.config.id2label
emotion_labels = [id2label[i] for i in sorted(id2label)]

predicted_emotion = emotion_labels[predicted_label]  # store final result

# Print classification result
print(f"Predicted Emotion: {predicted_emotion} (Confidence: {confidence.item():.4f})")

Predicted Emotion: joy (Confidence: 0.9947)


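#### Step 3: Layer-wise Relevance Propagation (LRP) using transformers-interpret

**Note:** transformers-interpret computes its token attributions with Captum's layer integrated gradients rather than classical LRP; the scores are read the same way (positive values support the predicted class, negative values oppose it).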


In [None]:
# Build an explainer around the fine-tuned classifier and its tokenizer.
# NOTE(review): transformers-interpret appears to attribute via Captum's layer
# integrated gradients rather than LRP -- confirm against the library docs.
explainer = SequenceClassificationExplainer(model, tokenizer)
# Compute the explanation for the sample sentence: one (token, score) pair per token.
word_importances = explainer(text)

In [None]:
# Show the per-token attribution scores as a list of (token, score) pairs.
word_importances

[('[CLS]', 0.0),
 ('today', -0.34817989902597896),
 ('is', -0.0224294598545257),
 ('okay', 0.3383600736091839),
 ('but', -0.18014577224918712),
 ('i', 0.10099709055763446),
 ('would', 0.01372139886154836),
 ('not', 0.10042423699395508),
 ('say', 0.09939837746565805),
 ('it', 0.13494181605916952),
 ('is', 0.09143401452417635),
 ('the', -0.1292510805280939),
 ('best', 0.7653924829477073),
 ('day', -0.04828484874153694),
 ('ever', 0.2637072060913262),
 ('[SEP]', 0.0)]

In [9]:
# visualize() renders the highlighted attribution table itself and also returns
# the HTML object. The trailing semicolon stops the notebook from echoing that
# return value, which would otherwise show the same table a second time.
explainer.visualize();

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,joy (0.99),joy,1.18,[CLS] today is okay but i would not say it is the best day ever [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,joy (0.99),joy,1.18,[CLS] today is okay but i would not say it is the best day ever [SEP]
,,,,


#### Step 4: Attention-Based Explainability using BERTViz
**Hint:** Load the base (headless) BERT model with `output_attentions=True` to get the attention weights, since the fine-tuned classifier does not return them by default.

In [7]:
#### Load the headless BERT encoder (no classification head) to extract attention weights.
#### It is loaded from the fine-tuned checkpoint (model_name) so the attention shown
#### comes from the same weights as the classifier being explained, rather than from
#### the stock "bert-base-uncased" weights. Expect a warning that the classifier
#### head weights in the checkpoint are not used by BertModel.
#### output_attentions=True with the eager attention implementation makes the
#### forward pass return the per-layer attention maps.
#### If you used a DistilBERT or other model you'll need to load that base model class.
bert_model = BertModel.from_pretrained(model_name, output_attentions=True, attn_implementation="eager")

In [8]:
# Tokenize the sample sentence for the attention model.
inputs = tokenizer(text, return_tensors="pt")

# Map the token IDs of the (single) sequence back to token strings,
# which BertViz uses as labels in the visualization.
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# Forward pass without gradient tracking; only the attention maps are needed.
with torch.no_grad():
    output = bert_model(**inputs)

# Attention weights from the model's output: a tuple with one tensor per layer,
# each of shape (batch_size, num_heads, seq_len, seq_len).
attention = output.attentions

# Visualize attention using BertViz.
head_view(attention, tokens)


<IPython.core.display.Javascript object>