<a href="https://colab.research.google.com/github/francescopatane96/ESM2_experiments/blob/main/xAI_BERT_text_classification_PT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

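This notebook follows the Towards Data Science tutorial [Interpreting the prediction of BERT model for text classification](https://towardsdatascience.com/interpreting-the-prediction-of-bert-model-for-text-classification-5ab09f8ef074).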

In [None]:
pip install transformers

In [32]:
from transformers import BertTokenizer

# Load the cased BERT tokenizer; later cells reuse this `tokenizer` object
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

text = 'The movie is horrible. i hate it'

# Encode the sentence to vocabulary ids, explicitly requesting the special
# tokens (the printed token list starts with [CLS] and ends with [SEP])
text_ids = tokenizer.encode(text, add_special_tokens=True)

# Show the token strings first, then the raw ids they correspond to
for printable in (tokenizer.convert_ids_to_tokens(text_ids), text_ids):
    print(printable)

['[CLS]', 'The', 'movie', 'is', 'horrible', '.', 'i', 'hate', 'it', '[SEP]']
[101, 1109, 2523, 1110, 9210, 119, 178, 4819, 1122, 102]


In [33]:
from transformers import BertModel
import torch

# Load the pretrained BERT encoder.
# NOTE: the name `model` is rebound to the fine-tuned classifier in a later cell.
model = BertModel.from_pretrained('bert-base-cased')

# Wrapping the id list in another list gives the tensor a leading batch axis
token_batch = torch.tensor([text_ids])
embeddings = model.embeddings(token_batch)
print(embeddings.shape)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([1, 10, 768])


In [34]:
from torch import nn

# Class of model architecture
class BertClassifier(nn.Module):
    """BERT encoder followed by dropout, a linear layer and a ReLU.

    The sub-module names (``bert``, ``dropout``, ``linear``, ``relu``) are kept
    stable so that previously saved state dicts keep loading.

    Args:
        dropout: Drop probability applied to BERT's pooled output.
        num_classes: Number of output units of the linear layer.
        pretrained_model_name: Checkpoint name handed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, dropout=0.5, num_classes=2,
                 pretrained_model_name='bert-base-cased'):

        super().__init__()

        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout)
        # Take the input width from the encoder's config instead of
        # hard-coding 768, so other BERT sizes work as well.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask = None):
        """Return one row of non-negative class scores per input sequence.

        Args:
            input_id: Token-id tensor passed to BERT as ``input_ids``.
            mask: Optional attention mask passed to BERT as ``attention_mask``.

        Returns:
            ``relu(linear(dropout(pooled_output)))``. Because of the final
            ReLU these are clamped scores, not raw logits or probabilities.
        """
        # With return_dict=False BERT returns a tuple; only the second
        # element (the pooled output) is used here.
        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer

In [None]:
pip install wget

In [14]:
import os

import wget

url = 'https://github.com/marcellusruben/bert_captum/raw/main/bert_model.pt'
checkpoint_path = 'bert_model.pt'

# Only fetch the fine-tuned weights when they are not already on disk, so a
# Restart & Run All does not download the checkpoint again.
if not os.path.exists(checkpoint_path):
    wget.download(url)

# Display the local file name as the cell output
checkpoint_path

'bert_model.pt'

In [35]:
model = BertClassifier()
# SECURITY: 'bert_model.pt' was downloaded from the internet and torch.load
# unpickles it. weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state dict needs.
state_dict = torch.load('bert_model.pt', map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(state_dict)
# Switch to inference mode (disables dropout); the returned module is the cell output
model.eval()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [36]:
# Forward function used for attribution
def model_output(inputs):
    """Run the classifier on ``inputs`` and return the first row of its output.

    NOTE(review): the classifier returns one row of class scores per input
    sequence, so ``[0]`` picks the first sequence in the batch rather than a
    particular class. Confirm this is the intended attribution target.
    """
    scores = model(inputs)
    return scores[0]

# Layer against which attributions are computed: BERT's embedding module
model_input = model.bert.embeddings

In [None]:
pip install captum

In [43]:
from captum.attr import LayerIntegratedGradients

# Attribution object built from the forward function (`model_output`) and the
# layer to attribute to (`model_input`, i.e. the BERT embedding module).
lig = LayerIntegratedGradients(model_output, model_input)

In [45]:
def construct_input_and_baseline(text):
    """Build the token ids for ``text`` and a same-length padding baseline.

    Uses the notebook-level ``tokenizer``.

    Args:
        text: Raw input string.

    Returns:
        A 3-tuple ``(input_ids, baseline_input_ids, token_list)``:
        two CPU tensors with a leading batch axis of size 1, and the token
        strings for ``input_ids``. The baseline keeps [CLS] and [SEP] and
        replaces every other position with the pad token id.
    """
    # Presumably 512 positions minus the two special tokens added below
    max_length = 510

    pad_id = tokenizer.pad_token_id
    sep_id = tokenizer.sep_token_id
    cls_id = tokenizer.cls_token_id

    # Special tokens are added by hand so the baseline can mirror them exactly
    body_ids = tokenizer.encode(
        text,
        max_length=max_length,
        truncation=True,
        add_special_tokens=False,
    )

    def wrap(ids):
        return [cls_id, *ids, sep_id]

    input_ids = wrap(body_ids)
    token_list = tokenizer.convert_ids_to_tokens(input_ids)
    baseline_input_ids = wrap([pad_id] * len(body_ids))

    return (
        torch.tensor([input_ids], device='cpu'),
        torch.tensor([baseline_input_ids], device='cpu'),
        token_list,
    )

# Example sentence (note: rebinds the notebook-level `text` used earlier)
text = 'This movie is horrible. i hate it'
# Token ids, padding baseline and token strings for the example
input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

# Print both id tensors so the input and its baseline can be compared
print(f'original text: {input_ids}')
print(f'baseline text: {baseline_input_ids}')

original text: tensor([[ 101, 1188, 2523, 1110, 9210,  119,  178, 4819, 1122,  102]])
baseline text: tensor([[101,   0,   0,   0,   0,   0,   0,   0,   0, 102]])


In [46]:
# Attribute the example against its padding baseline. Asking for the
# convergence delta makes the call return an (attributions, delta) pair.
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_input_ids,
    internal_batch_size=1,
    return_convergence_delta=True,
)
print(attributions.shape)

torch.Size([1, 10, 768])


In [47]:
def summarize_attributions(attributions):
    """Collapse embedding-level attributions into one normalized score per token.

    Args:
        attributions: Tensor whose last axis is summed away; a leading axis
            of size 1 (a single-example batch) is then squeezed out.

    Returns:
        The summed attributions divided by their norm. If the norm is zero
        (all attributions are zero), a zero tensor of the same shape is
        returned instead of NaNs from a division by zero.
    """
    per_token = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(per_token)

    # Guard the degenerate case: 0 / 0 would fill the result with NaN
    if norm == 0:
        return torch.zeros_like(per_token)

    return per_token / norm

# One normalized attribution score per token of the example sentence
attributions_sum = summarize_attributions(attributions)
print(attributions_sum.shape)

torch.Size([10])


In [48]:
from captum.attr import visualization as viz

# Run the classifier once, without building an autograd graph, and reuse the
# scores for both the predicted class and its score.
with torch.no_grad():
    class_scores = model(input_ids)[0]

# NOTE(review): pred_prob receives the largest raw model score (the model ends
# in a ReLU, not a softmax), so the displayed value is not a probability.
score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = torch.max(class_scores),
                        pred_class = torch.argmax(class_scores).numpy(),
                        true_class = 1,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),
                        raw_input_ids = all_tokens,
                        convergence_score = delta)

viz.visualize_text([score_vis])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (4.57),This movie is horrible. i hate it,1.23,[CLS] This movie is horrible . i hate it [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,0 (4.57),This movie is horrible. i hate it,1.23,[CLS] This movie is horrible . i hate it [SEP]
,,,,


In [49]:
def interpret_text(text, true_class):
    """Compute and render token-level attributions for ``text``.

    Relies on the notebook-level ``model``, ``lig``, ``viz``,
    ``construct_input_and_baseline`` and ``summarize_attributions``.

    Args:
        text: Sentence to classify and explain.
        true_class: Label shown as the true class in the visualization.

    Returns:
        None; the visualization is rendered as a side effect.
    """
    input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)
    attributions, delta = lig.attribute(inputs= input_ids,
                                    baselines= baseline_input_ids,
                                    return_convergence_delta=True,
                                    internal_batch_size=1
                                    )
    attributions_sum = summarize_attributions(attributions)

    # Run the classifier once, without building an autograd graph, and reuse
    # the scores for both the predicted class and its score.
    with torch.no_grad():
        class_scores = model(input_ids)[0]

    # NOTE(review): pred_prob receives the largest raw model score, which is
    # not normalized here; confirm whether a softmax probability is wanted.
    score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = torch.max(class_scores),
                        pred_class = torch.argmax(class_scores).numpy(),
                        true_class = true_class,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),
                        raw_input_ids = all_tokens,
                        convergence_score = delta)

    viz.visualize_text([score_vis])

In [54]:
# Example run on a new sentence; `true_class` is the label passed through to
# the visualization as the true class.
text = "bugs are smart animals"
true_class = 1
interpret_text(text, true_class)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (4.04),bugs are smart animals,1.19,[CLS] bugs are smart animals [SEP]
,,,,
