# Feature attribution

In [1]:
__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Summer 2021"

## Contents

1. [Overview](#Overview)
1. [InputXGradients](#InputXGradients)
1. [Selectivity examples](#Selectivity-examples)
1. [Simple feed-forward classifier example](#Simple-feed-forward-classifier-example)
1. [Bag-of-words classifier for the SST](#Bag-of-words-classifier-for-the-SST)
1. [BERT example](#BERT-example)

## Overview

This notebook is an experimental extension of the CS224u course code. It focuses on the [Integrated Gradients](https://arxiv.org/abs/1703.01365) method for feature attribution, with comparisons to the "inputs $\times$ gradients" method. To run the notebook, first install [the Captum library](https://captum.ai/):

```pip install captum```

This is not currently a required installation (but it will be in future years).

## InputXGradients

For both implementations, the `forward` method of `model` is used. `X` is an (m x n) tensor of inputs, and the return value is an (m x n) tensor of attributions. Use `targets=None` for models with scalar outputs; otherwise, supply a LongTensor giving a label for each example.

In [2]:
import torch

def grad_x_input(model, X, targets=None):
    """Implementation using PyTorch directly."""
    X.requires_grad = True
    y = model(X)
    y = y if targets is None else y[list(range(len(y))), targets]
    (grads, ) = torch.autograd.grad(y.unbind(), X)
    return grads * X

In [3]:
from captum.attr import InputXGradient

def captum_grad_x_input(model, X, target):
    """Captum-based implementation."""
    X.requires_grad = True
    amod = InputXGradient(model)
    return amod.attribute(X, target=target)
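
As a sanity check on what these functions compute, the following is a minimal pure-Python sketch for a linear scorer, where the gradient with respect to each input is just the corresponding weight. The weights, bias, and input are hypothetical values chosen for illustration, and `linear_score` and `input_x_gradient_linear` are names introduced here rather than part of the course code.

```python
# Input x gradient for a linear scorer f(x) = w . x + b, without autograd.
# All numbers below are hypothetical illustration values.

def linear_score(w, b, x):
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def input_x_gradient_linear(w, x):
    # For a linear model, df/dx_i = w_i, so the attribution is x_i * w_i.
    return [xi * wi for wi, xi in zip(w, x)]

w = [2.0, -1.0, 0.5]
b = 0.25
x = [1.0, 3.0, 4.0]

attrs = input_x_gradient_linear(w, x)
print(attrs)  # [2.0, -3.0, 2.0]

# The attributions account for everything in the score except the bias:
print(sum(attrs), linear_score(w, b, x) - b)  # 1.0 1.0
```

For non-linear models the gradient depends on the input, which is where automatic differentiation comes in.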

## Selectivity examples

In [4]:
import numpy as np
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients
from captum.attr import InputXGradient

In [5]:
class SelectivityAssessor(nn.Module):
    """Model used by Sundararajan et al, section 2.1 to show that
    input * gradients violates their sensitivity axiom (called
    "selectivity" in this notebook).
    """
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, X):
        return 1.0 - self.relu(1.0 - X)

In [6]:
sel_mod = SelectivityAssessor()

Simple inputs with just one feature:

In [7]:
X_sel = torch.FloatTensor([[0.0], [2.0]])

The outputs for our two examples differ:

In [8]:
sel_mod(X_sel)

tensor([[0.],
        [1.]])

However, `InputXGradient` assigns the same importance to the feature across the two examples, violating selectivity:

In [9]:
captum_grad_x_input(sel_mod, X_sel, target=None)

tensor([[0.],
        [-0.]], grad_fn=<MulBackward0>)
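
To see why both attributions come out as zero, it may help to write the function and its derivative by hand. This is a scalar, pure-Python sketch; `f` and `df_dx` are names introduced here, and the choice of 0 for the derivative at the kink follows the convention PyTorch uses for ReLU.

```python
def f(x):
    # Same function as SelectivityAssessor.forward, for a scalar input.
    return 1.0 - max(0.0, 1.0 - x)

def df_dx(x):
    # The slope is 1 where x < 1 and 0 elsewhere (taking 0 at the kink).
    return 1.0 if x < 1.0 else 0.0

# Input x gradient for the two examples:
attributions = {x: x * df_dx(x) for x in (0.0, 2.0)}

for x, attr in attributions.items():
    print(f"x={x}  f(x)={f(x)}  gradient={df_dx(x)}  attribution={attr}")
```

At `x = 0` the gradient is 1 but the input is 0; at `x = 2` the input is non-zero but the function has flattened out, so the local gradient is 0. Either way the product vanishes, even though the outputs differ.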

Integrated gradients addresses the problem by averaging gradients over inputs interpolated between a baseline and the actual input, and then scaling that average by the difference between the input and the baseline:

In [10]:
ig_sel = IntegratedGradients(sel_mod)

In [11]:
sel_baseline = torch.FloatTensor([[0.0]])

In [12]:
ig_sel.attribute(X_sel, sel_baseline)

tensor([[0.],
        [1.]], dtype=torch.float64, grad_fn=<MulBackward0>)
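
For reference, the quantity being approximated here is the integrated gradients attribution as defined by Sundararajan et al., for a model $F$, input $x$, and baseline $x'$. The Riemann-sum form on the right is one common approximation with $m$ steps; the exact quadrature rule a given library uses may differ.

$$
\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\left(x' + \alpha (x - x')\right)}{\partial x_i} \, d\alpha
\;\approx\; (x_i - x'_i) \cdot \frac{1}{m} \sum_{k=1}^{m} \frac{\partial F\left(x' + \tfrac{k}{m}(x - x')\right)}{\partial x_i}
$$

In the example above, $F(x) = 1 - \mathrm{ReLU}(1 - x)$, $x' = 0$, and $x = 2$. The interpolated input is $2\alpha$, so the gradient is 1 for $\alpha < 1/2$ and 0 afterwards. The integral is therefore $1/2$, and the attribution is $(2 - 0) \cdot 1/2 = 1$, which equals $F(2) - F(0)$.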

A toy implementation to help bring out what is happening:

In [13]:
def ig_reference_implementation(model, x, base, m=50):
    """Left Riemann-sum approximation of integrated gradients."""
    vals = []
    for k in range(m):
        # Interpolated representation, k/m of the way from `base` to `x`:
        xx = base + (k / m) * (x - base)
        # Gradient for the interpolated example:
        xx.requires_grad = True
        y = model(xx)
        (grads, ) = torch.autograd.grad(y.unbind(), xx)
        vals.append(grads)
    # Average the gradients and scale by the distance from the baseline:
    return (1 / m) * torch.cat(vals).sum(axis=0) * (x - base)

In [14]:
ig_reference_implementation(sel_mod, torch.FloatTensor([[2.0]]), sel_baseline)

tensor([[1.]])
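
The same computation can be sketched without PyTorch for a non-zero baseline, which makes it easier to check that the attribution approaches the difference between the model's output at the input and at the baseline. This is a scalar illustration under stated assumptions: the derivative is written by hand, the baseline of 0.5 is arbitrary, and `integrated_gradients_1d` is a name introduced here.

```python
def f(x):
    # Scalar version of SelectivityAssessor.forward.
    return 1.0 - max(0.0, 1.0 - x)

def df_dx(x):
    # Hand-written derivative: 1 where x < 1, else 0.
    return 1.0 if x < 1.0 else 0.0

def integrated_gradients_1d(x, base, m=50):
    """Left Riemann-sum approximation of integrated gradients for scalar x."""
    total = 0.0
    for k in range(m):
        point = base + (k / m) * (x - base)  # interpolated input
        total += df_dx(point)
    return (x - base) * total / m

attr = integrated_gradients_1d(x=2.0, base=0.5)
gap = f(2.0) - f(0.5)
print(attr, gap)
```

With 50 steps the two numbers agree to within a couple of hundredths, and the approximation tightens as `m` grows.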

## Simple feed-forward classifier example

In [15]:
from captum.attr import IntegratedGradients
from sklearn.datasets import make_classification
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import torch
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier

In [16]:
X_cls, y_cls = make_classification(
    n_samples=5000,
    n_classes=3,
    n_features=5,
    n_informative=3,
    n_redundant=0,
    random_state=42)

The classification problem has two uninformative features:

In [17]:
mutual_info_classif(X_cls, y_cls)

array([0.20138107, 0.02833358, 0.11584416, 0.        , 0.        ])

In [18]:
X_cls_train, X_cls_test, y_cls_train, y_cls_test = train_test_split(X_cls, y_cls)

In [19]:
classifier = TorchShallowNeuralClassifier()

In [20]:
_ = classifier.fit(X_cls_train, y_cls_train)

Stopping after epoch 350. Training loss did not improve more than tol=1e-05. Final error is 1.4553862810134888.

In [21]:
cls_preds = classifier.predict(X_cls_test)

In [22]:
accuracy_score(y_cls_test, cls_preds)

0.844

In [23]:
classifier_ig = IntegratedGradients(classifier.model)

In [24]:
classifier_baseline = torch.zeros(1, X_cls_train.shape[1])

Integrated gradients with respect to the actual labels:

In [25]:
classifier_attrs = classifier_ig.attribute(
    torch.FloatTensor(X_cls_test),
    classifier_baseline,
    target=torch.LongTensor(y_cls_test))

Average attribution is low for the two uninformative features:

In [26]:
classifier_attrs.mean(axis=0)

tensor([ 0.9523,  0.5059,  0.7190, -0.0193, -0.0127], dtype=torch.float64)
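
A signed mean near zero can arise either because the model makes little use of a feature or because positive and negative attributions cancel across examples. One way to tell these apart is to also look at the mean absolute attribution. The sketch below uses hypothetical numbers, not outputs of the classifier above; with the tensor above, the analogous call would presumably be `classifier_attrs.abs().mean(axis=0)`.

```python
# Hypothetical per-example attributions: rows are examples, columns are features.
attrs = [
    [0.9, 0.4],
    [-0.8, 0.5],
    [0.7, -0.3],
    [-0.8, 0.2],
]

def column_means(rows, transform=lambda v: v):
    # Mean of each column after applying `transform` to every value.
    n = len(rows)
    return [sum(transform(r[j]) for r in rows) / n for j in range(len(rows[0]))]

signed = column_means(attrs)
magnitude = column_means(attrs, transform=abs)

print("signed mean:   ", [round(v, 3) for v in signed])
print("absolute mean: ", [round(v, 3) for v in magnitude])
```

Here the first feature has a signed mean of roughly zero but the largest absolute mean, so it matters a great deal on individual examples.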

## Bag-of-words classifier for the SST

In [27]:
from collections import Counter
from captum.attr import IntegratedGradients
from nltk.corpus import stopwords
from operator import itemgetter
import os
from sklearn.metrics import classification_report
import torch
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
import sst

In [28]:
SST_HOME = os.path.join("data", "sentiment")

Bag-of-words featurization with stopword removal, to make this a little easier to study:

In [29]:
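# Use a new name so that the imported `stopwords` module is not overwritten
# (which would make this cell fail if it were run a second time):
english_stopwords = set(stopwords.words('english'))

def phi(text):
    return Counter([w for w in text.lower().split() if w not in english_stopwords])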
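
To make the behavior of `phi` concrete, here is a self-contained sketch that swaps in a tiny stand-in stopword set so that it runs without NLTK. The set `demo_stopwords`, the function `phi_demo`, and the example sentence are all assumptions for illustration.

```python
from collections import Counter

# Tiny stand-in stopword list (an assumption for illustration; the notebook
# itself uses NLTK's full English list).
demo_stopwords = {"the", "is", "and", "it", "a"}

def phi_demo(text):
    # Lowercase, split on whitespace, drop stopwords, count what remains.
    return Counter(w for w in text.lower().split() if w not in demo_stopwords)

feats = phi_demo("The film is fun , and the kids loved it .")
print(feats)
```

Punctuation tokens survive this kind of filtering, since they are whitespace-separated tokens that do not appear in the stopword list.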

In [30]:
def fit_mlp(X, y):
    mod = TorchShallowNeuralClassifier(early_stopping=True)
    mod.fit(X, y)
    return mod

In [31]:
experiment = sst.experiment(
    sst.train_reader(SST_HOME),
    phi,
    fit_mlp,
    sst.dev_reader(SST_HOME))

Stopping after epoch 45. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.5286356508731842

              precision    recall  f1-score   support

    negative      0.632     0.666     0.648       428
     neutral      0.252     0.144     0.183       229
    positive      0.638     0.745     0.687       444

    accuracy                          0.589      1101
   macro avg      0.507     0.518     0.506      1101
weighted avg      0.555     0.589     0.567      1101



Trained model:

In [32]:
sst_classifier = experiment['model']

Captum needs to have labels as indices rather than strings:

In [33]:
sst_classifier.classes_

['negative', 'neutral', 'positive']

In [34]:
y_sst_test = [sst_classifier.classes_.index(label)
              for label in experiment['assess_datasets'][0]['y']]

sst_preds = [sst_classifier.classes_.index(label)
             for label in experiment['predictions'][0]]
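
For larger label sets, a dictionary lookup avoids the repeated linear scans that `list.index` performs. A minimal sketch of the same conversion, using the class list shown above and a hypothetical list of gold labels:

```python
classes = ['negative', 'neutral', 'positive']  # as in sst_classifier.classes_

# Map each label string to its position in `classes`:
label2index = {label: i for i, label in enumerate(classes)}

gold_labels = ['positive', 'negative', 'neutral', 'positive']  # hypothetical
gold_indices = [label2index[label] for label in gold_labels]
print(gold_indices)  # [2, 0, 1, 2]
```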

Our featurized test set:

In [35]:
X_sst_test = experiment['assess_datasets'][0]['X']

Feature names to help with analyses:

In [36]:
# In scikit-learn >= 1.2, use `get_feature_names_out()` instead:
fnames = experiment['train_dataset']['vectorizer'].get_feature_names()

Integrated gradients:

In [37]:
sst_ig = IntegratedGradients(sst_classifier.model)

All-0s baseline:

In [38]:
sst_baseline = torch.zeros(1, experiment['train_dataset']['X'].shape[1])

Attributions with respect to the model's predictions:

In [39]:
sst_attrs = sst_ig.attribute(
    torch.FloatTensor(X_sst_test),
    sst_baseline,
    target=torch.LongTensor(sst_preds))

Helper functions for error analysis:

In [40]:
def error_analysis(gold=1, predicted=2):
    err_ind = [i for i, (g, p) in enumerate(zip(y_sst_test, sst_preds))
               if g == gold and p == predicted]
    attr_lookup = create_attr_lookup(sst_attrs[err_ind])
    return attr_lookup, err_ind

def create_attr_lookup(attrs):
    mu = attrs.mean(axis=0).detach().numpy()
    return sorted(zip(fnames, mu), key=itemgetter(1), reverse=True)

In [41]:
sst_attrs_lookup, sst_err_ind = error_analysis(gold=1, predicted=2)

In [42]:
sst_attrs_lookup[: 5]

[('.', 0.0810196881304765),
 ('fun', 0.06947951198804361),
 ('film', 0.04929371582902589),
 ('solid', 0.04672621050246706),
 ('kids', 0.03809466066035495)]

Error analysis for a specific example:

In [43]:
ex_ind = sst_err_ind[0]

In [44]:
experiment['assess_datasets'][0]['raw_examples'][ex_ind]

'No one goes unindicted here , which is probably for the best .'

In [45]:
ex_attr_lookup = create_attr_lookup(sst_attrs[ex_ind:ex_ind+1])

In [46]:
[(f, a) for f, a in ex_attr_lookup if a != 0]

[('best', 0.7275746240134193),
 ('probably', 0.3349310239020713),
 ('.', 0.08365320489322038),
 (',', 0.01769884396379488),
 ('one', 0.002754690330329966),
 ('goes', -0.19133889066064283)]
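
The `a != 0` filter picks out the words that occur in the example because, with an all-0s baseline, each attribution is the feature's count multiplied by an averaged gradient, so absent words receive exactly zero. The sketch below illustrates this with a hypothetical vocabulary and hypothetical path-averaged gradients; none of the numbers come from the trained model.

```python
fnames_demo = ["best", "goes", "probably", "terrible"]  # hypothetical vocabulary
x = [1.0, 1.0, 1.0, 0.0]           # word counts for one example
baseline = [0.0, 0.0, 0.0, 0.0]    # all-0s baseline
avg_grads = [0.7, -0.2, 0.3, -0.9]  # hypothetical path-averaged gradients

# Integrated gradients has the form (x_i - baseline_i) * averaged gradient_i:
attrs = [(xi - bi) * g for xi, bi, g in zip(x, baseline, avg_grads)]

nonzero = [(f, a) for f, a in zip(fnames_demo, attrs) if a != 0]
print(nonzero)
```

The word "terrible" has a large averaged gradient in this toy setup but is filtered out because it does not occur in the example.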

## BERT example

In [47]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

In [48]:
hf_weights_name = 'cardiffnlp/twitter-roberta-base-sentiment'

In [49]:
hf_tokenizer = AutoTokenizer.from_pretrained(hf_weights_name)

In [50]:
hf_model = AutoModelForSequenceClassification.from_pretrained(hf_weights_name)

In [51]:
def hf_predict_one_proba(text):
    input_ids = hf_tokenizer.encode(
        text, add_special_tokens=True, return_tensors='pt')
    hf_model.eval()
    with torch.no_grad():
        logits = hf_model(input_ids)[0]
        preds = F.softmax(logits, dim=1)
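    # Note: this leaves the model in training mode, so dropout is active
    # for any later forward passes, including attribution calls:
    hf_model.train()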
    return preds.squeeze(0)

In [52]:
def hf_ig_encodings(text):
    pad_id = hf_tokenizer.pad_token_id
    cls_id = hf_tokenizer.cls_token_id
    sep_id = hf_tokenizer.sep_token_id
    input_ids = hf_tokenizer.encode(text, add_special_tokens=False)
    base_ids = [pad_id] * len(input_ids)
    input_ids = [cls_id] + input_ids + [sep_id]
    base_ids = [cls_id] + base_ids + [sep_id]
    return torch.LongTensor([input_ids]), torch.LongTensor([base_ids])
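
The baseline built by `hf_ig_encodings` keeps the special tokens in place and swaps every content token for the padding token, so that input and baseline have the same length. The list logic can be sketched on its own as follows; `build_ig_ids` is a name introduced here, and all of the ids are hypothetical stand-ins rather than values read from the tokenizer.

```python
def build_ig_ids(token_ids, cls_id, sep_id, pad_id):
    # Wrap the real tokens in the special start and end tokens:
    input_ids = [cls_id] + token_ids + [sep_id]
    # Baseline: same special tokens, padding in place of every real token:
    base_ids = [cls_id] + [pad_id] * len(token_ids) + [sep_id]
    return input_ids, base_ids

# Hypothetical ids, for illustration only:
input_ids, base_ids = build_ig_ids([713, 16, 372], cls_id=0, sep_id=2, pad_id=1)
print(input_ids)  # [0, 713, 16, 372, 2]
print(base_ids)   # [0, 1, 1, 1, 2]
```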

In [53]:
def hf_ig_analyses(text2class):
    data = []
    for text, true_class in text2class.items():
        score_vis = hf_ig_analysis_one(text, true_class)
        data.append(score_vis)
    viz.visualize_text(data)


def hf_ig_analysis_one(text, true_class):
    # Option to look at different layers:
    # layer = hf_model.roberta.encoder.layer[0]
    # layer = hf_model.roberta.embeddings.word_embeddings
    layer = hf_model.roberta.embeddings

    def ig_forward(inputs):
        return hf_model(inputs).logits

    ig = LayerIntegratedGradients(ig_forward, layer)

    input_ids, base_ids = hf_ig_encodings(text)

    attrs, delta = ig.attribute(
        input_ids,
        base_ids,
        target=true_class,
        return_convergence_delta=True)

    # Sum over the hidden dimension to get one score per token, then
    # mean-center the scores and scale them by their L2 norm:
    scores = attrs.sum(dim=-1).squeeze(0)
    scores = (scores - scores.mean()) / scores.norm()

    # Intuitive tokens to help with analysis:
    raw_input = hf_tokenizer.convert_ids_to_tokens(input_ids.tolist()[0])
    # RoBERTa-specific clean-up:
    raw_input = [x.strip("Ġ") for x in raw_input]

    # Predictions for comparisons:
    pred_probs = hf_predict_one_proba(text)
    pred_class = pred_probs.argmax()

    score_vis = viz.VisualizationDataRecord(
        word_attributions=scores,
        pred_prob=pred_probs.max(),
        pred_class=pred_class,
        true_class=true_class,
        attr_class=None,
        attr_score=attrs.sum(),
        raw_input=raw_input,
        convergence_score=delta)

    return score_vis
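
The score normalization in `hf_ig_analysis_one` subtracts the mean and divides by the L2 norm of the unnormalized scores (rather than by their standard deviation). A pure-Python sketch of that arithmetic, with hypothetical per-token scores and a function name, `center_and_scale`, introduced here:

```python
import math

def center_and_scale(scores):
    # Subtract the mean, then divide by the L2 norm of the original scores.
    mean = sum(scores) / len(scores)
    norm = math.sqrt(sum(s * s for s in scores))
    return [(s - mean) / norm for s in scores]

raw = [3.0, 0.0, 4.0, 1.0]  # hypothetical per-token attribution sums
scaled = center_and_scale(raw)
print([round(s, 3) for s in scaled])
```

The centered vector can never be longer than the original one, so every scaled score falls between -1 and 1, which is convenient for mapping scores to colors.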

In [54]:
hf_ig_analyses({
    "They said it would be great, and they were right.": 2,
    "They said it would be great, and they were wrong.": 0,
    "They were right to say it would be great.": 2,
    "They were wrong to say it would be great.": 0,
    "They said it would be stellar, and they were correct.": 2,
    "They said it would be stellar, and they were incorrect.": 0})

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 2.0 | 2 (0.82) | | 2.69 | #s They said it would be great , and they were right . #/s |
| 0.0 | 0 (0.50) | | 1.67 | #s They said it would be great , and they were wrong . #/s |
| 2.0 | 2 (0.76) | | 1.17 | #s They were right to say it would be great . #/s |
| 0.0 | 0 (0.62) | | 3.81 | #s They were wrong to say it would be great . #/s |
| 2.0 | 2 (0.77) | | 1.6 | #s They said it would be stellar , and they were correct . #/s |
