# AIPI 590 - XAI | Assignment #5
### Hongxuan Li

[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ba_9sGCfP0BW0fcbe3pRLCdxzsckDT6P?usp=sharing)

#### References

- Model BERT: https://arxiv.org/abs/1810.04805
- Model Tool Huggingface Pre-trained Model: https://arxiv.org/abs/1810.04805
- Interpretation Tool SHAP: https://github.com/shap/shap

In [46]:
# import os

# # Remove Colab default sample_data
# !rm -r ./sample_data

# # Clone GitHub files to colab workspace
# repo_name = "/content/AIPI590-XAI" # Change to your repo name
# git_path = 'https://github.com/h0ngxuanli/AIPI590-XAI.git' #Change to your path
# !git clone "{git_path}"


# # Install dependencies from requirements.txt file
# !pip install -r "{os.path.join(repo_name,'assignment5/requirements.txt')}" #Add if using requirements.txt

# # Change working directory to location of notebook
# notebook_dir = 'assignment4/'
# path_to_notebook = os.path.join(repo_name,notebook_dir)
# %cd "{path_to_notebook}"
# %ls


# Dependencies

In [47]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import shap
import numpy as np

# Dataset

Test sentiment sentences with various difficulties. Some words could be very tricky in the context to test whetehr BERT to capature the real meaning of sentence.

In [48]:
test_examples = [
    "The software update fixed some issues but introduced new bugs.",
    "Despite the rain, we had a great time at the outdoor festival.",
    "The customer service was helpful, but I still haven't resolved my issue.",
    "I can't recommend this enough.",
    "The performance was out of this world.",
    "This solution is a drop in the bucket."
]

# Load BERT for sentiment analysis

Choose BERT for sentiment analysis, and examine whether SHAP could capture esstential sentiment words

In [49]:
# Load pre-trained BERT and tokenizer
model_name = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

# Move to cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

# Inference Pipeline

In [50]:
def inference(texts):

    processed_texts = []
    for text in texts:
      processed_texts.append(text)

    # tokenize input
    inputs = tokenizer(
        processed_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    ).to(device)

    # get BERT output
    with torch.no_grad():
        outputs = model(**inputs)

    # get prediction based on output logits
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    return probs.cpu().numpy()

# Initialize SHAP

In [51]:
# initialize masker to mask words
masker = shap.maskers.Text(tokenizer)

# create SHAP explainer
explainer = shap.Explainer(inference, masker)

# get SHAP explanation
def analyze_sentiment_with_shap(text):

    # Get sentiment prediction
    prediction = inference([text])[0]

    # Get binary prediction results
    sentiment = "Positive" if prediction[1] > prediction[0] else "Negative"
    sentiment_idx = 1 if prediction[1] > prediction[0] else 0
    confidence = float(max(prediction))

    # Compute SHAP values
    shap_values = explainer([text])
    shap_values = shap_values[:, :, sentiment_idx]

    return sentiment, confidence, shap_values, sentiment_idx

# Visulize Explanations

### Explanation of SHAP results on test examples
**The explaination is based on my run, and it could change for different runs, which also shows the how sensitive SHAP is to BERT outputs**
- "The software update fixed some issues but introduced new bugs."
  - 'but' has highest importance, which is correct. But 'bugs' may be should have positive effect towards making the sentence Negative.
- "Despite the rain, we had a great time at the outdoor festival.",
  - 'we' has the highest importance which is not relevant to the sentiment
- "The customer service was helpful, but I still haven't resolved my issue."
  - BERT's prediction is positive. As a result, SHAP assigned high values to "haven" and "resovled", which is a wrong analysis based on wrong prediction of BERT.
- "I can't recommend this enough."
  - SHAP fail to capture the interaction between "can't" and "enough", where only "can't" is assigned a high importance score
- "The performance was out of this world."
  - SHAP fail to capture the interaction between "out of" and "world", where only "world" is assigned a high importance score
- "This solution is a drop in the bucket."
  - Misclassification of BERT leads SHAP fail to capture mearningful words in the sentence.

### Discussion

**Strength**

1. Model-agnostic, I can even apply SHAP to different layers of BERT to understand how importance evolves through the network
2. SHAP is able to provide exact importane for each token, which provides evidence to know how the BERT make wrong predictions
3. SHAP is able to help find corner cases that BERT fail to handle with, which facilitates identifying the bias, unfairness, privacy problem in the model.
4. SHAP is able to capture the interaction between tokens by calculating each token's contribution in the context of other tokens.

**Limitations**

1. SHAP struggles to fully capture the contextual nature of BERT, where the same word can have different representations based on its context. For instance, SHAP failed to capture the interaction between "can't" and "enough".
2. The large feature space in longer sentences may lead to sparse and less reliable SHAP values, reducing the credibility of explanations.
3. SHAP's interpratation depends on the the stability and correctness of BERT's output. BERT's wrong prediction could lead to incorrect analysis of SHAP.
4. BERT's use of subword tokenization can lead to unintuitive SHAP results, where importance is assigned to partial words rather than complete words.

**Improvements**
1. Incorporate BERT's attention weights into SHAP calculations for more context-aware attributions.
2. Average SHAP values across multiple occurrences of a word in different contexts to capture semantic significance of a word.
3. Apply SHAP to fixed-size windows of text instead of entire long sentences to enable long-context sentence analysis


In [52]:
for example in test_examples:
    sentiment, confidence, shap_values, sentiment_idx = analyze_sentiment_with_shap(example)
    print(f"\nText: {example}")
    print(f"Sentiment: {sentiment}")
    print(f"Confidence: {confidence:.4f}")
    # shap.plots.text(shap_values[0])


    # Display the most influential words
    # Access the SHAP values for the positive class (index 1)
    shap_values_array = shap_values[0]


    # Get the tokens
    tokens = shap_values[0].data

    # Pair tokens with their corresponding SHAP values
    token_importance = list(zip(tokens, shap_values_array))

    # Sort tokens by absolute SHAP value (importance)
    token_importance.sort(key=lambda x: np.abs(x[1]), reverse=True)

    print("\nMost influential words (absolute SHAP value):")
    for token, importance in token_importance[:5]:  # Top 5 most influential words
        print(f"{token}: {importance:.4f}")


    shap.plots.text(shap_values[0])



  0%|          | 0/156 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:14, 14.69s/it]               


Text: The software update fixed some issues but introduced new bugs.
Sentiment: Negative
Confidence: 0.6948

Most influential words (absolute SHAP value):





  0%|          | 0/240 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:38, 38.51s/it]               


Text: Despite the rain, we had a great time at the outdoor festival.
Sentiment: Negative
Confidence: 0.7357

Most influential words (absolute SHAP value):





  0%|          | 0/306 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:34, 34.32s/it]               


Text: The customer service was helpful, but I still haven't resolved my issue.
Sentiment: Negative
Confidence: 0.7439

Most influential words (absolute SHAP value):





  0%|          | 0/90 [00:00<?, ?it/s]


Text: I can't recommend this enough.
Sentiment: Negative
Confidence: 0.7085

Most influential words (absolute SHAP value):


  0%|          | 0/90 [00:00<?, ?it/s]


Text: The performance was out of this world.
Sentiment: Negative
Confidence: 0.6773

Most influential words (absolute SHAP value):


  0%|          | 0/110 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:10, 10.06s/it]               


Text: This solution is a drop in the bucket.
Sentiment: Negative
Confidence: 0.6998

Most influential words (absolute SHAP value):



