<a href="https://colab.research.google.com/github/hila-chefer/Transformer-Explainability/blob/main/BERT_explainability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
# os.chdir(f'./Transformer-Explainability')

In [None]:
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from transformers import BertTokenizer
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from transformers import AutoTokenizer

from captum.attr import visualization
import torch
import numpy as np
# Pick the compute device. The second GPU ("cuda:1") is still preferred when
# it exists, but on single-GPU hosts (e.g. Colab) that index is invalid and
# moving the model to it would fail, so fall back to "cuda:0" there and to
# the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
print(f'running on {device}')

In [None]:
# Load the tokenizer and the project's BertForSequenceClassification from the
# same pretrained checkpoint; the model is moved to `device` and put in eval
# mode before it is handed to the explanation generator.
checkpoint = "textattack/bert-base-uncased-SST-2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BertForSequenceClassification.from_pretrained(checkpoint)
model = model.to(device)
model.eval()

# explanation generator built around the loaded model
explanations = Generator(model)

# class names, looked up by the predicted label index
classifications = ["NEGATIVE", "POSITIVE"]


In [None]:
# Display the model's repr as the cell output.
model

## Interpretation method from the second paper
This section implements the interpretation method from the follow-up paper so that its results can be compared with those of the original method.

In [None]:
# Gather the encoder's layer modules into a plain list and show how many
# there are.
attn_blocks = list(model.bert.encoder.layer.children())
len(attn_blocks)

In [None]:
# Show the shape of the attention tensor cached on the first encoder layer.
# `attn` only exists after an inference has been run on an input sequence
# (before that it is missing or None), so guard the lookup instead of letting
# the cell crash when the notebook is executed top to bottom.
first_attn = getattr(attn_blocks[0].attention.self, "attn", None)
None if first_attn is None else first_attn.shape


In [None]:
start_layer = 0


def interpret(input_ids, model, attention_mask, start_layer_text=start_layer):
    """Explain the model's predicted class with gradient-weighted attention.

    Runs a forward pass, then walks the encoder layers from
    ``start_layer_text`` onwards. For each layer the cached attention map is
    multiplied element-wise by its gradient w.r.t. the predicted-class logit,
    negative values are clamped to zero, the heads are averaged, and the
    result ``cam`` is accumulated into a relevance matrix ``R`` via
    ``R = R + cam @ R`` (``R`` starts as the identity).

    Args:
        input_ids: Token ids; the first dimension is the batch.
        model: Classifier exposing ``model.device``,
            ``model.bert.encoder.layer`` and, on every layer,
            ``attention.self.attn`` (the attention tensor stored during the
            forward pass). Calling it must return the logits as element 0.
        attention_mask: Mask passed through to the model unchanged.
        start_layer_text: Index of the first encoder layer to include.
            ``-1`` means "last layer only". The default is bound to the
            module-level ``start_layer`` at definition time.

    Returns:
        ``(probs, relevance)`` where ``probs`` is a NumPy array of softmax
        class probabilities and ``relevance`` is, per sample, row 0 of the
        relevance matrix (one score per token, relative to position 0 --
        presumably the [CLS] token). The entry at position 0 itself is
        replaced by that sample's row minimum.
    """
    device = model.device
    batch_size = input_ids.shape[0]

    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    probs = logits.softmax(dim=-1).detach().cpu().numpy()

    # Scalar to differentiate: the sum over the batch of each sample's
    # predicted-class logit. Selecting it with tensor ops on the logits' own
    # device replaces the old one-hot + `.cuda(device)` path, which raised
    # when the model was running on the CPU.
    index = torch.from_numpy(np.argmax(probs, axis=-1)).long().to(logits.device)
    target_score = logits.gather(1, index.unsqueeze(1)).sum()

    text_attn_blocks = list(model.bert.encoder.layer.children())

    if start_layer_text == -1:
        # resolve "-1" to the index of the last layer
        start_layer_text = len(text_attn_blocks) - 1

    first_attn = text_attn_blocks[0].attention.self.attn
    num_tokens = first_attn.shape[-1]
    # `repeat` (rather than `expand`) gives every sample its own writable
    # identity matrix, so the in-place write after the loop is valid even
    # when no layer is processed.
    R_text = torch.eye(num_tokens, num_tokens, dtype=first_attn.dtype, device=device)
    R_text = R_text.unsqueeze(0).repeat(batch_size, 1, 1)

    # max(..., 0) keeps the original meaning of other negative values
    # ("skip nothing") instead of Python's from-the-end slicing.
    for blk in text_attn_blocks[max(start_layer_text, 0):]:
        attn = blk.attention.self.attn
        # retain_graph: the same graph is differentiated once per layer
        grad = torch.autograd.grad(target_score, [attn], retain_graph=True)[0].detach()
        side = attn.shape[-1]
        cam = (grad * attn.detach()).reshape(batch_size, -1, side, side)
        # keep positive contributions only, then average over heads
        cam = cam.clamp(min=0).mean(dim=1)
        R_text = R_text + torch.bmm(cam, R_text)

    expl = R_text[:, 0]
    # Overwrite position 0's self-relevance with the row minimum, computed
    # per sample so that batch members do not influence each other.
    expl[:, 0] = expl.min(dim=-1).values

    return probs, expl

In [None]:
# Tokenize one review and move its tensors to the model's device.
text_batch = ["This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."]
encoding = tokenizer(text_batch, return_tensors="pt")
input_ids, attention_mask = (
    encoding[key].to(device) for key in ("input_ids", "attention_mask")
)

# ground-truth label for this sentence: 1 (positive)
true_class = 1

# Explain the prediction, starting the accumulation at the first layer.
probs, expl = interpret(input_ids, model, attention_mask, start_layer_text=0)
# keep the only sample in the batch and min-max scale its scores to [0, 1]
expl = expl[0]
expl = (expl - expl.min()) / (expl.max() - expl.min())

# predicted label index and its readable name
classification = probs.argmax(axis=-1).item()
class_name = classifications[classification]
# For a NEGATIVE prediction a higher score means "more negative", so the
# sign is flipped for visualization.
if class_name == "NEGATIVE":
    expl.neg_()

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(token, score.item()) for token, score in zip(tokens, expl)])
vis_data_records = [
    visualization.VisualizationDataRecord(
        expl,                    # per-token attribution scores
        probs[0][classification],  # probability of the predicted class
        classification,          # predicted class
        true_class,              # true class
        true_class,              # attribution class
        1,                       # attribution score
        tokens,                  # raw input tokens
        1,                       # convergence score
    )
]
visualization.visualize_text(vis_data_records)

In [None]:
# Display the current value of `expl` as the cell output.
expl

## Original paper
### Positive sentiment example

In [None]:
# Display the shape of `expl` as it was left by the last cell that assigned it.
expl.shape

In [None]:
# Display the shape of the most recently encoded `input_ids` tensor.
input_ids.shape

In [None]:
# Tokenize one review and move its tensors to the model's device.
text_batch = ["This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."]
encoding = tokenizer(text_batch, return_tensors="pt")
input_ids, attention_mask = (
    encoding[key].to(device) for key in ("input_ids", "attention_mask")
)

# ground-truth label for this sentence: 1 (positive)
true_class = 1

# LRP explanation for the input (first sample of the batch), min-max scaled
expl = explanations.generate_LRP(
    input_ids=input_ids, attention_mask=attention_mask, start_layer=0
)[0]
expl = (expl - expl.min()) / (expl.max() - expl.min())

# class probabilities, predicted label index and its readable name
output = model(input_ids=input_ids, attention_mask=attention_mask)[0].softmax(dim=-1)
classification = output.argmax(dim=-1).item()
class_name = classifications[classification]
# For a NEGATIVE prediction a higher score means "more negative", so the
# sign is flipped for visualization.
if class_name == "NEGATIVE":
    expl.neg_()

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(token, score.item()) for token, score in zip(tokens, expl)])
vis_data_records = [
    visualization.VisualizationDataRecord(
        expl,                       # per-token attribution scores
        output[0][classification],  # probability of the predicted class
        classification,             # predicted class
        true_class,                 # true class
        true_class,                 # attribution class
        1,                          # attribution score
        tokens,                     # raw input tokens
        1,                          # convergence score
    )
]
visualization.visualize_text(vis_data_records)

### Negative sentiment example

In [None]:
# encode a sentence
text_batch = ["I really didn't like this movie. Some of the actors were good, but overall the movie was boring."]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# true class is negative - 0
# (the record below previously hard-coded 1, i.e. POSITIVE, as the true and
# attribution label of this negative review)
true_class = 0

# generate an explanation for the input
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]
# normalize scores to [0, 1]
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model classification
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
classification = output.argmax(dim=-1).item()
# get class name
class_name = classifications[classification]
# if the classification is negative, higher explanation scores are more negative
# flip for visualization
if class_name == "NEGATIVE":
    expl *= (-1)

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(token, score.item()) for token, score in zip(tokens, expl)])
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,                       # per-token attribution scores
                                output[0][classification],  # probability of the predicted class
                                classification,             # predicted class
                                true_class,                 # true class
                                true_class,                 # attribution class
                                1,                          # attribution score
                                tokens,                     # raw input tokens
                                1)]                         # convergence score
visualization.visualize_text(vis_data_records)

# Choosing class for visualization example

In [None]:
# encode a sentence
text_batch = ["I hate that I love you."]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# true class is positive - 1
true_class = 1

# generate an explanation for the input, for an explicitly chosen class
# (0 = NEGATIVE) rather than the predicted one
target_class = 0
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]
# normalize scores to [0, 1]
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model classification
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
# Compute the prediction for THIS sentence; the cell previously reused
# whatever `classification` an earlier cell had left behind.
classification = output.argmax(dim=-1).item()

# get class name of the class being explained
class_name = classifications[target_class]
# if the explained class is negative, higher explanation scores are more negative
# flip for visualization
if class_name == "NEGATIVE":
    expl *= (-1)

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(token, score.item()) for token, score in zip(tokens, expl)])
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,                       # per-token attribution scores
                                output[0][classification],  # probability of the predicted class
                                classification,             # predicted class
                                true_class,                 # true class
                                target_class,               # attribution class: the class that was explained
                                1,                          # attribution score
                                tokens,                     # raw input tokens
                                1)]                         # convergence score
visualization.visualize_text(vis_data_records)

In [None]:
# encode a sentence
text_batch = ["I hate that I love you."]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# true class is positive - 1
true_class = 1

# generate an explanation for the input, for an explicitly chosen class
# (1 = POSITIVE) rather than the predicted one
target_class = 1
expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=11, index=target_class)[0]
# normalize scores to [0, 1]
expl = (expl - expl.min()) / (expl.max() - expl.min())

# get the model classification
output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
# Compute the prediction for THIS sentence; the cell previously reused
# whatever `classification` an earlier cell had left behind.
classification = output.argmax(dim=-1).item()

# get class name of the class being explained
class_name = classifications[target_class]
# if the explained class is negative, higher explanation scores are more negative
# flip for visualization
if class_name == "NEGATIVE":
    expl *= (-1)

tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
print([(token, score.item()) for token, score in zip(tokens, expl)])
vis_data_records = [visualization.VisualizationDataRecord(
                                expl,                       # per-token attribution scores
                                output[0][classification],  # probability of the predicted class
                                classification,             # predicted class
                                true_class,                 # true class
                                target_class,               # attribution class: the class that was explained
                                1,                          # attribution score
                                tokens,                     # raw input tokens
                                1)]                         # convergence score
visualization.visualize_text(vis_data_records)