# AIPI 590 - XAI | Assignment #7
### Hongxuan Li

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1cExtaeDCO1JKHtGKxtxTHCa-TMN1cXVC?usp=sharing)

## References

- Bert: https://arxiv.org/abs/1810.04805
- Captum: https://github.com/pytorch/captum
- Plotly: https://plotly.com/

## Hypothesis
**Null Hypothesis (H0)**: There is no significant difference in BERT's attribution scores between descriptive words (adjectives) and their associated target nouns compared to other contextual tokens.

**Alternative Hypothesis (H1)**: BERT shows significantly different attribution scores for descriptive word-target noun pairs compared to other contextual tokens

## Approach
1. **Model and Dataset**:
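   - Model: BERT-base-uncased with a binary classification head (the head is randomly initialized, not fine-tuned; attributions are computed for the class-1 logit)
   - Dataset: six short sentences, each containing a descriptive word–noun pair (e.g., "vast ocean", "majestic mountain")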

2. **Method**: 
   - Used Integrated Gradients for attribution analysis
   - Computed attribution scores for different token types:
     * Target nouns (e.g., "ocean", "mountain")
     * Descriptive words (e.g., "vast", "majestic")
     * Other contextual tokens

## Findings
- `[CLS]` and `[SEP]` tokens consistently show high absolute attribution scores, likely reflecting their special role in BERT's architecture
- Other contextual tokens show the highest average attribution scores (~0.306); note that this group includes `[CLS]`, `[SEP]`, and punctuation
- Target nouns (~0.182) and their descriptive words (~0.218) show lower average attribution scores

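This suggests that BERT relies more heavily on broad contextual information than on specific adjective–noun relationships. Descriptive words and their target nouns receive lower, not higher, attribution than other contextual tokens. This contradicts our expectation that BERT emphasizes these semantic relationships. These differences are descriptive only: no significance test was run on the six sentences, and the classification head is untrained. They therefore do not formally reject H0 at any stated significance level.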

# Dependencies

In [67]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import IntegratedGradients
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# Model

In [68]:
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the RNG *before* loading the model. The classification head
# ('classifier.weight' / 'classifier.bias') is not part of the
# bert-base-uncased checkpoint and is newly (randomly) initialized, so
# without a fixed seed every run yields a different head and therefore
# different attribution scores.
SEED = 0
torch.manual_seed(SEED)

# Model setup
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.to(device)
# eval() disables dropout so the repeated forward passes made by
# Integrated Gradients are deterministic.
model.eval()

Using device: cpu


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

# Data

In [69]:
# Target noun -> sentence in which that noun is modified by an adjective.
# The key is the string searched for among the sentence's tokens.
texts = dict(
    ocean="The vast ocean stretches beyond the horizon.",
    mountain="The majestic mountain towers over the valley.",
    forest="The peaceful forest surrounds the lake.",
    desert="The beautiful desert extends into infinity.",
    river="The mighty river flows through the canyon.",
    valley="The serene valley lies between tall peaks.",
)

# Integrated Gradients Analysis

## get token attributions

In [70]:

# Integrated Gradients attributions for every sentence, keyed by target noun.
#   attributions_dict[noun]         -> attribution tensor, same shape as the
#                                      sentence's input embeddings
#   tokens_dict[noun]               -> word-piece tokens of the sentence
#   target_token_indices_dict[noun] -> token positions matching the noun
attributions_dict, tokens_dict, target_token_indices_dict = {}, {}, {}

# The embedding layer does not change between sentences; look it up once.
embeddings = model.get_input_embeddings()

for target_token, text in texts.items():
    # Generate token attributions using Integrated Gradients.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Case-insensitive substring match; note this would also match a longer
    # token that merely contains the target string.
    target_token_lower = target_token.lower()
    target_token_indices = [
        i for i, token in enumerate(tokens)
        if target_token_lower in token.lower()
    ]

    if not target_token_indices:
        raise ValueError(f"Target token '{target_token}' not found in tokenized text.")

    def forward_func(inputs_embeds, attention_mask=attention_mask):
        # IG evaluates all interpolation steps as one batch, so repeat the
        # single-sentence mask along the batch dimension. The mask is bound
        # as a default argument to avoid late-binding the loop variable.
        mask = attention_mask.expand(inputs_embeds.shape[0], -1)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=mask)
        # Attribute with respect to the class-1 logit.
        return outputs.logits[:, 1]

    # Captum builds the interpolated inputs and enables gradients on them
    # itself, so the starting embeddings need no autograd graph. Computing
    # them under no_grad keeps the stored attributions free of a graph back
    # into the embedding weights.
    with torch.no_grad():
        input_embeds = embeddings(input_ids)
    # All-zero word-embedding baseline. BERT still adds position and
    # token-type embeddings internally, and [CLS]/[SEP] are zeroed as well.
    baseline = torch.zeros_like(input_embeds)

    ig = IntegratedGradients(forward_func)
    attributions = ig.attribute(
        input_embeds,
        baselines=baseline,
        n_steps=50
    )

    attributions_dict[target_token] = attributions
    tokens_dict[target_token] = tokens
    target_token_indices_dict[target_token] = target_token_indices


## aggregate attribution scores for individual parts

In [71]:
# Per-sentence mean |normalized attribution| for three token groups:
# target noun(s), descriptive adjectives, and all remaining tokens.

# Descriptive words (adjectives); defined once rather than on every iteration.
descriptive_words = ['vast', 'majestic', 'beautiful', 'tall', 'deep', 'wide',
                     'grand', 'mighty', 'peaceful', 'serene']
descriptive_word_set = frozenset(descriptive_words)

# Special tokens ([CLS], [SEP], ...) fall into the "other" group and are
# reported as having high attributions, which can dominate that group's
# average. 'other_scores_no_special' reports the group without them.
special_token_set = frozenset(tokenizer.all_special_tokens)


def _mean_abs(values, indices):
    """Mean of |values[indices]|, or 0 when `indices` is empty.

    np.mean of an empty selection would return nan (with a warning).
    """
    return np.mean(np.abs(values[indices])) if indices else 0


results = {}
for target_token in attributions_dict:
    tokens = tokens_dict[target_token]
    attributions = attributions_dict[target_token]
    target_indices = target_token_indices_dict[target_token]

    # Sum over the embedding dimension -> one score per token, then scale by
    # the largest magnitude so this sentence's scores lie in [-1, 1].
    scores = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
    scores_normalized = scores / (np.abs(scores).max() + 1e-10)

    descriptive_indices = [i for i, token in enumerate(tokens)
                           if token.lower() in descriptive_word_set]

    # Sorted so the selection order (and hence float summation) is stable.
    other_indices = sorted(set(range(len(tokens)))
                           - set(target_indices)
                           - set(descriptive_indices))
    other_content_indices = [i for i in other_indices
                             if tokens[i] not in special_token_set]

    results[target_token] = {
        'target_scores': _mean_abs(scores_normalized, target_indices),
        'descriptive_scores': _mean_abs(scores_normalized, descriptive_indices),
        'other_scores': _mean_abs(scores_normalized, other_indices),
        'other_scores_no_special': _mean_abs(scores_normalized, other_content_indices),
    }

## visualize token attributions

In [72]:
# One bar chart per sentence: normalized attribution per token, with the
# sentence's target noun position(s) highlighted in red.
for target_token in attributions_dict:

    title = f"Token Attribution Scores for '{target_token}'"

    tokens = tokens_dict[target_token]
    attributions = attributions_dict[target_token]
    target_indices = target_token_indices_dict[target_token]
    # Highlight *this* sentence's target positions. The variable
    # `target_token_indices` left over from the attribution loop holds only
    # the last sentence's positions and must not be used here.
    highlighted = set(target_indices)

    # Sum over the embedding dimension -> one score per token, then scale by
    # the largest magnitude so scores lie in [-1, 1].
    scores = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
    scores_normalized = scores / (np.abs(scores).max() + 1e-10)

    colors = ['red' if i in highlighted else 'lightgray'
              for i in range(len(tokens))]

    hover_text = [
        f"Token: {token}<br>Score: {score:.4f}"
        for token, score in zip(tokens, scores_normalized)
    ]

    fig = go.Figure()

    fig.add_trace(go.Bar(
        x=list(range(len(tokens))),
        y=scores_normalized,
        text=tokens,
        hovertext=hover_text,
        hoverinfo='text',
        marker_color=colors,
        name='Attribution Score'
    ))

    fig.update_layout(
        title=title,
        xaxis_title="Token Position",
        yaxis_title="Attribution Score (Normalized)",
        showlegend=False,
        xaxis=dict(
            tickmode='array',
            ticktext=tokens,
            tickvals=list(range(len(tokens))),
            tickangle=45
        ),
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        )
    )
    fig.show()

## visualize average attribution scores

In [73]:
# Summary visualization of semantic patterns: for each token category, the
# average over sentences of that category's per-sentence mean |attribution|.
category_to_key = {
    'Target Tokens': 'target_scores',
    'Descriptive Words': 'descriptive_scores',
    'Other Tokens': 'other_scores',
}
categories = list(category_to_key)
means = [
    np.mean([np.abs(per_sentence[key]) for per_sentence in results.values()])
    for key in category_to_key.values()
]

bar_labels = [f'{v:.3f}' for v in means]
fig = go.Figure(data=[
    go.Bar(
        x=categories,
        y=means,
        marker_color=['#FF6B6B', '#4ECDC4', '#45B7D1'],
        text=bar_labels,
        textposition='auto',
    )
])

fig.update_layout(
    title='Average Attribution Scores by Token Type',
    xaxis_title='Token Category',
    yaxis_title='Average Attribution Score',
    showlegend=False,
    template='plotly_white',
    height=500,
)
fig.show()