In [6]:
import torch
import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification
import torch.nn.functional as F

# Attention rollout function
def attention_rollout(attentions):
    rollout = torch.eye(attentions[0].size(-1)).to(attentions[0].device)
    for attention in attentions:
        attention_heads_fused = attention.mean(dim=1)
        attention_heads_fused += torch.eye(attention_heads_fused.size(-1)).to(attention_heads_fused.device)
        attention_heads_fused /= attention_heads_fused.sum(dim=-1, keepdim=True)
        rollout = torch.matmul(rollout, attention_heads_fused)
    return rollout

# Load dataset
df = pd.read_csv("sentiment_analysis.csv")
text = df["text"][3]

# Load model and tokenizer
model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, output_attentions=True)
model.eval()

# Tokenize and run model
inputs = tokenizer(text, return_tensors="pt", truncation=True)
with torch.no_grad():
    outputs = model(**inputs)
    attentions = outputs.attentions
    logits = outputs.logits
    probs = F.softmax(logits, dim=-1)

# Simplify prediction to binary (positive/negative)
predicted_class = torch.argmax(probs, dim=1).item()
binary_sentiment = "POSITIVE" if predicted_class >= 3 else "NEGATIVE"

# Get tokens and attention rollout
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
rollout = attention_rollout(attentions)[0]
cls_attention = rollout[0]  # Attention from [CLS] token

# Print results
print("\nToken Attention Weights:")
for token, weight in zip(tokens, cls_attention):
    print(f"Token: {token}, Attention Weight: {weight.item():.4f}")

print(f"\nSentiment Prediction: {binary_sentiment}")



Token Attention Weights:
Token: [CLS], Attention Weight: 0.0260
Token: we, Attention Weight: 0.0307
Token: attend, Attention Weight: 0.0275
Token: in, Attention Weight: 0.0185
Token: the, Attention Weight: 0.0153
Token: class, Attention Weight: 0.0276
Token: just, Attention Weight: 0.0282
Token: for, Attention Weight: 0.0207
Token: listening, Attention Weight: 0.0207
Token: teachers, Attention Weight: 0.0281
Token: reading, Attention Weight: 0.0234
Token: on, Attention Weight: 0.0182
Token: slide, Attention Weight: 0.0261
Token: ., Attention Weight: 0.0298
Token: just, Attention Weight: 0.0430
Token: non, Attention Weight: 0.0355
Token: ##sen, Attention Weight: 0.0224
Token: ##ce, Attention Weight: 0.0246
Token: [SEP], Attention Weight: 0.5339

Sentiment Prediction: NEGATIVE
