# bert 분류모델을 captum으로 XAI 하기

### wygo, 240502

- [ref1](https://colab.research.google.com/drive/1pgAbzUF2SzF0BdFtGpJbZPWUOhFxT2NZ#scrollTo=6grV8dFnj9xO)
- [ref2](https://captum.ai/tutorials/Bert_SQUAD_Interpret)

In [None]:
# Captum for BERT classification
# https://colab.research.google.com/drive/1pgAbzUF2SzF0BdFtGpJbZPWUOhFxT2NZ#scrollTo=6grV8dFnj9xO
# https://captum.ai/tutorials/Bert_SQUAD_Interpret
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
import captum
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import torch
import matplotlib.pyplot as plt

# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# function
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Build the input token ids and a same-length baseline for `text`.

    The text is encoded with the module-level `tokenizer` (no special tokens
    added by the tokenizer); the CLS and SEP ids are then attached manually.
    The baseline keeps CLS/SEP in place and replaces every text token with
    `ref_token_id`.

    Args:
        text: Raw string to encode.
        ref_token_id: Token id used for every text position in the baseline.
        sep_token_id: Token id appended after the text.
        cls_token_id: Token id prepended before the text.

    Returns:
        A tuple `(input_ids, ref_input_ids, n_text_tokens)` where both tensors
        have a leading batch dimension of 1 and live on the module-level
        `device`, and `n_text_tokens` is the number of ids produced for `text`
        (special tokens excluded).
    """
    encoded = tokenizer.encode(text, add_special_tokens=False)
    n_text_tokens = len(encoded)

    # [CLS] text... [SEP]
    real_ids = [cls_token_id, *encoded, sep_token_id]
    # [CLS] ref ref ... [SEP] -- same length as the real sequence
    baseline_ids = [cls_token_id, *([ref_token_id] * n_text_tokens), sep_token_id]

    input_tensor = torch.tensor([real_ids], device=device)
    baseline_tensor = torch.tensor([baseline_ids], device=device)
    return input_tensor, baseline_tensor, n_text_tokens

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build segment (token-type) ids and an all-zero baseline for them.

    Positions whose index is `<= sep_ind` get segment 0; every later position
    gets segment 1. The tensors are created on the same device as
    `input_ids`, so the function does not depend on any global device setting.

    Args:
        input_ids: Tensor whose second dimension is the sequence length.
        sep_ind: Last position index that still belongs to segment 0.

    Returns:
        A tuple `(token_type_ids, ref_token_type_ids)`, both int64 tensors of
        shape `(1, seq_len)`; the reference is all zeros.
    """
    seq_len = input_ids.size(1)
    positions = torch.arange(seq_len, device=input_ids.device)
    # Boolean mask -> 0/1 int64 ids, with a leading batch dimension of 1.
    token_type_ids = (positions > sep_ind).long().unsqueeze(0)
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids and an all-zero baseline matching `input_ids`.

    The tensors are created on the same device as `input_ids`, so the function
    does not depend on any global device setting.

    Args:
        input_ids: Tensor whose second dimension is the sequence length.

    Returns:
        A tuple `(position_ids, ref_position_ids)` of int64 tensors expanded
        to the shape of `input_ids`. `position_ids` counts 0..seq_len-1 along
        the sequence; `ref_position_ids` is all zeros.
    """
    seq_length = input_ids.size(1)
    target_device = input_ids.device

    position_ids = torch.arange(seq_length, dtype=torch.long, device=target_device)
    # A random permutation (`torch.randperm(seq_length)`) would be another
    # possible baseline; zeros are used here.
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=target_device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    """Return an all-ones mask with the same shape and dtype as `input_ids`."""
    mask = torch.ones_like(input_ids)
    return mask

def predict(inputs):
    """Run the module-level `model` on `inputs` and return its first output.

    The first element of the model output is presumably the classification
    logits (the caller applies softmax to it) -- confirm against the model's
    output type.
    """
    outputs = model(inputs)
    first_output = outputs[0]
    return first_output

def custom_forward(inputs):
    """Return the softmax probability of class index 0 for each input row.

    Per the original author's note, index 0 yields the negative-class
    attribution; selecting `[:, 1]` instead would yield the positive-class
    attribution.
    """
    class_probabilities = torch.softmax(predict(inputs), dim=1)
    return class_probabilities[:, 0]

def summarize_attributions(attributions):
    """Collapse attributions over the last dimension and scale by the norm.

    The last dimension is summed away, a leading dimension of size 1 (if any)
    is squeezed out, and the result is divided by its `torch.norm`.
    """
    per_position = attributions.sum(dim=-1).squeeze(0)
    scale = torch.norm(per_position)
    return per_position / scale

In [None]:
## model
# Get model and config files from https://huggingface.co/lvwerra/bert-imdb
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/config.json
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/pytorch_model.bin
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/special_tokens_map.json
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/tokenizer_config.json
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/training_args.bin
# !wget -P ./model https://s3.amazonaws.com/models.huggingface.co/bert/lvwerra/bert-imdb/vocab.txt

# Load the sequence-classification checkpoint from ./model, move it to the
# selected device, switch to inference mode and clear any stored gradients.
model = BertForSequenceClassification.from_pretrained('./model').to(device)
model.eval()
model.zero_grad()

# Load the tokenizer that ships with the same checkpoint.
tokenizer = BertTokenizer.from_pretrained('./model')

# Special token ids used to assemble inputs and baselines.
# Original note: ref_token_id/sep_token_id/cls_token_id are 0, 102, 101.
ref_token_id = tokenizer.pad_token_id  # padding token, used to fill the baseline sequence
sep_token_id = tokenizer.sep_token_id  # separator token, appended after the text
cls_token_id = tokenizer.cls_token_id  # classification token, prepended before the text

# Layer Integrated Gradients over the BERT embedding layer, using
# `custom_forward` as the scalar-per-example function being explained.
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

# Most recent module output captured by `save_act`; None until the hook fires.
saved_act = None


def save_act(module, inp, out):
    """Forward hook that records a module's output in `saved_act`.

    Args:
        module: The module the hook is attached to (unused).
        inp: The inputs passed to the module's forward call (unused).
        out: The output produced by the module's forward call.

    Returns:
        The same `out` object that was passed in, after storing it in the
        module-level `saved_act`.
    """
    global saved_act
    saved_act = out
    return saved_act

# Attach `save_act` as a forward hook on the embedding layer ...
hook = model.bert.embeddings.register_forward_hook(save_act)
# ... and detach it again right away, so the hook is not active for any
# forward pass that runs after this point.
hook.remove()

In [None]:
# Sample input for checking how the sentiment classifier behaves.
# Alternative sentences to try:
#   "The movie was one of those amazing movies"
#   "The movie was one of those amazing movies you can not forget"
text = "The first movie is great but the second is horrible and bad"
# text = "The movie was one of those crappy movies you can't forget."

In [None]:
# Run Captum: build the input and its baseline, attribute, then visualise.
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(
    text, ref_token_id, sep_token_id, cls_token_id
)
# The token-type ids, position ids and attention mask are built here but are
# not passed to `lig.attribute` below.
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

# Token strings for the single sequence in the batch, used for display.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

# Layer Integrated Gradients from the baseline to the real input; `delta` is
# the convergence delta returned alongside the attributions.
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    n_steps=700,
    internal_batch_size=5,
    return_convergence_delta=True,
)

# One normalised attribution value per token.
attributions_sum = summarize_attributions(attributions)

# Class probabilities for the real input (batching should be possible here).
pred = predict(input_ids)
score = torch.softmax(pred, dim=1)

batch_idx = 0
# Index of the most probable class, e.g. 0 for [9.9928e-01, 7.1609e-04].
index_predict = torch.argmax(score[batch_idx]).item()
print(f'Sentiment: {index_predict}')

predict_index_probability = score.detach().cpu().numpy().squeeze()[index_predict]
print(f'Probability {index_predict}: {predict_index_probability*100:.3f}')

# Ground-truth label shown next to the prediction in the visualisation.
index_true = 0

# Bundle everything into a record for Captum's text visualisation.
# NOTE(review): the attribution-class slot receives the raw `text`, while
# `custom_forward` attributes with respect to class index 0 -- confirm whether
# a class label was intended here.
score_vis = viz.VisualizationDataRecord(
    word_attributions=attributions_sum,
    pred_prob=predict_index_probability,
    pred_class=index_predict,
    true_class=index_true,
    attr_class=text,
    attr_score=attributions_sum.sum(),
    raw_input_ids=all_tokens,
    convergence_score=delta,
)

viz.visualize_text([score_vis]);