In [1]:
import torch
import pandas as pd
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from captum.attr import visualization
from captum.attr import IntegratedGradients, DeepLiftShap, DeepLift

In [11]:
from utils.wrappers import CLFWrapper
from utils.utils import *
from utils.attr_utils import *
from utils.vis_utils import *

In [4]:
# Initialize Model
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Disabled alternative that would also try Apple's MPS backend:
# device = torch.device("cuda:0" if torch.cuda.is_available() \
# 					else ("mps" if torch.backends.mps.is_available() else "cpu"))

model_path = "klue/roberta-base"

# The load log for this checkpoint reports that the classification head is
# newly initialized (i.e. random). Seed torch's RNG first so the head -- and
# therefore every prediction and attribution below -- is identical on each
# Restart & Run All.
# NOTE(review): the head is still untrained; point model_path at a fine-tuned
# checkpoint before reading anything into the attributions.
SEED = 42
torch.manual_seed(SEED)

model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load Model
model.to(device)
model.eval()       # inference mode for the attribution runs
model.zero_grad()  # start from clean gradients

Some weights of the model checkpoint at klue/roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at klue/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.bias', 'classifie

In [5]:
# Sample Data
# Each text is kept next to its label so the two lists cannot drift apart;
# `samples` and `labels` are then derived from the pairs.
labeled_samples = [
    ("This is sample sentence 1", 0),
    ("This is sample sentence 2", 1),
]

samples = [text for text, _ in labeled_samples]
labels = [target for _, target in labeled_samples]

In [6]:


# DeepLift
def interpret_sentence_deeplift(explainer, model_wrapper, tokenizer, sentence, label=1):
    """Attribute one sentence with ``explainer`` and package the result for display.

    Relies on the notebook-level ``device`` and on the helpers
    ``build_input_ref_pair``, ``aggregate_attributions`` and ``make_viz_record``
    being in scope. Prints the prediction, its probability and the absolute
    convergence delta as a side effect.

    Args:
        explainer: Object whose ``attribute(inputs, baselines=...,
            return_convergence_delta=True)`` returns ``(attributions, delta)``.
        model_wrapper: Object exposing ``get_embeddings(ids)`` and
            ``predict(inputs_embeds=...)``; the latter must return ``(pred, prob)``.
        tokenizer: Tokenizer handed to ``build_input_ref_pair`` and used to map
            the input ids back to token strings.
        sentence: Text to explain.
        label: Value forwarded to ``make_viz_record`` alongside the prediction
            (presumably the ground-truth class -- confirm against vis_utils).

    Returns:
        Tuple ``(attributions, record)``: the output of
        ``aggregate_attributions`` and the record built by ``make_viz_record``.
    """
    # Token ids for the sentence and for its reference sequence, both moved
    # onto the notebook's device.
    input_ids, ref_input_ids = build_input_ref_pair(sentence, tokenizer)
    input_ids, ref_input_ids = input_ids.to(device), ref_input_ids.to(device)

    # The explainer works on embeddings, so embed both id sequences.
    input_embedding = model_wrapper.get_embeddings(input_ids)
    ref_embedding = model_wrapper.get_embeddings(ref_input_ids)

    # Model prediction for the real input, unpacked to plain Python scalars.
    pred_tensor, prob_tensor = model_wrapper.predict(inputs_embeds=input_embedding)
    pred = pred_tensor[0].item()
    prob = prob_tensor.item()

    # Attributions relative to the reference embedding, plus convergence delta.
    attributions, delta = explainer.attribute(
        input_embedding,
        baselines=ref_embedding,
        return_convergence_delta=True,
    )
    print('pred: ', pred, '(', '%.2f' % prob, ')', ', delta: ', abs(delta))

    # Reduce the raw attributions with the project helper.
    attributions = aggregate_attributions(attributions)

    # Build the visualization record from the token strings of the first row.
    token_ids = input_ids[0].detach().cpu().numpy().tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    record = make_viz_record(attributions, tokens, prob, pred, label, delta)
    return attributions, record
    

In [7]:
# Wrap the classifier in the project's CLFWrapper; the wrapper is what
# interpret_sentence_deeplift calls get_embeddings()/predict() on.
wrapper = CLFWrapper(
    model,
    pretrained_model="roberta",
    embedding_layer="embeddings",
)

# DeepLift explainer built on top of the wrapped model.
explainer = DeepLift(wrapper)

In [8]:
# One visualization record per sample, consumed by the visualization cells.
# (The `_ig` suffix is kept for the later cells even though the explainer
# used here is DeepLift.)
vis_data_records_ig = []

# zip() has no length, so tqdm cannot infer a total and only prints "Nit";
# pass it explicitly. min() mirrors zip's stop-at-shortest behaviour.
n_pairs = min(len(samples), len(labels))

for sample, label in tqdm(zip(samples, labels), total=n_pairs):
    attributions, record = interpret_sentence_deeplift(explainer, wrapper, tokenizer, sample, label=label)
    vis_data_records_ig.append(record)
    

               activations. The hooks and attributes will be removed
            after the attribution is finished
1it [00:00,  3.33it/s]

pred:  1 ( 0.54 ) , delta:  tensor([0.0135])


2it [00:00,  3.59it/s]

pred:  1 ( 0.54 ) , delta:  tensor([0.0053])





In [9]:
# Visualize - Display
# Renders the collected records as the attribution table shown in this cell's
# output (true label, predicted label, attribution score, word importance).
# NOTE(review): the returned value is overwritten by the export cell further
# down, so only the inline rendering matters here.
visualization_html = visualization.visualize_text(vis_data_records_ig)

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,1 (0.54),label,-1.33,[CLS] Th ##is is s ##amp ##le se ##nt ##ence 1 [SEP]
,,,,
1.0,1 (0.54),label,0.58,[CLS] Th ##is is s ##amp ##le se ##nt ##ence 2 [SEP]
,,,,


In [12]:
# Build the visualization object without rendering it in the notebook.
visualization_html = visualize_text_without_display(vis_data_records_ig)

# Save its markup (the object's `.data`) as a standalone UTF-8 HTML file.
with open("visualization.html", mode="w", encoding="utf-8") as html_file:
    html_file.write(visualization_html.data)