Load fine-tuned model

In [1]:
import os
import sys
#correct path
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [12]:
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
from captum.attr import visualization as viz
import thesis_eliott.utilities as ut

Explore data and ground truth explanations

In [39]:
# choose which instance to inspect
INDEX = 15
#DATASET = 'movies'
DATASET = 'tweets'

Read data

In [40]:
if DATASET == 'movies':
    # read movies data
    K = 0.176
    model_path = '../models/bert-base-uncased-finetuned-movies'
    model = BertForSequenceClassification.from_pretrained(model_path, output_attentions=True)
    tokenizer = BertTokenizer.from_pretrained(model_path)
    from data.movies.utils import load_documents, load_datasets, annotations_from_jsonl, Annotation
    data_root = os.path.join(module_path, 'data', 'movies')
    documents = load_documents(data_root)
    test = annotations_from_jsonl(os.path.join(data_root, 'train.jsonl'))

    # extract one example
    annotation = test[INDEX]
    review = documents[annotation.annotation_id]
    input_text_list = [word for sentence in review for word in sentence]
    evidences = annotation.all_evidences()
    label = annotation.classification

    #tokenize input
    input_text_str = " ".join(input_text_list)
    tokenized_input = tokenizer.encode_plus(input_text_list, return_tensors='pt', truncation=True, is_split_into_words=True) #truncated to 512 tokens
    input_ids = tokenized_input['input_ids']
    token_type_ids = tokenized_input['token_type_ids']
    input_id_list = input_ids[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)

    #extract indices in input list for evidences
    indices_tmp_or = []
    words_tmp = []
    for ev in evidences:
        words_tmp.append(ev.text)
        idx_range = list(range(ev.start_token, ev.end_token))
        indices_tmp_or.append(idx_range)
    evidence_indices = [idx for indices in indices_tmp_or for idx in indices]
    #make binary list from evidence indices original
    ev_bin_or = ut.indices_to_binary(input_text_list, evidence_indices)

    #tokenize explanations and extract indices in input
    indices_tmp = []
    tokens_tmp = []
    for ev in evidences:
        if ev.text != "":
            tokenized_expl = tokenizer.encode_plus(ev.text, return_tensors='pt', truncation=True, add_special_tokens=False)
            input_ids_expl = tokenized_expl['input_ids']
            input_id_list_expl = input_ids_expl[0].tolist()
            tokens_expl = tokenizer.convert_ids_to_tokens(input_id_list_expl)
            expl_indices = ut.find_indices(tokens_expl, tokens)
            if expl_indices != None: #explanations on content above 512 tokens
                indices_tmp.append(expl_indices)
                tokens_tmp.append(tokens_expl)
    explanation_indices = [idx for indices in indices_tmp for idx in indices]
    explanation_tokens = [t for tkns in tokens_tmp for t in tkns]
    #make binary list from evidence indices
    ev_bin = ut.indices_to_binary(tokens, explanation_indices)
elif DATASET == 'tweets':
    # import fine-tuned model and tokenizer
    model_path = '../models/bert-base-uncased-finetuned-tweets'
    # read tweets data
    K = 0.599
    model = BertForSequenceClassification.from_pretrained(model_path, output_attentions=True)
    tokenizer = BertTokenizer.from_pretrained(model_path)
    from datasets import load_dataset
    #csv_file = 'test_2k_for_explain.csv'
    csv_file = 'train.csv'
    if csv_file == 'test_2k_for_explain.csv':
        label_name = 'label'
    elif csv_file == 'train.csv':
        label_name = 'sentiment'
    dataset_raw = load_dataset("/workspace/data/tweet-sentiment-extraction", data_files={'train': csv_file})
    train = dataset_raw['train']

    # extract one example
    instance = train[INDEX]
    input_text_list = instance['text'].split()
    label = instance[label_name]
    evidences = instance['selected_text'].split()
    print(evidences)

    # tokenize input
    tokenized_input = tokenizer.encode_plus(input_text_list, return_tensors='pt', truncation=True, is_split_into_words=True) #truncated to 512 tokens
    input_ids = tokenized_input['input_ids']
    token_type_ids = tokenized_input['token_type_ids']
    input_id_list = input_ids[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)

    #tokenize explanations and extract indices in input
    tokenized_expl = tokenizer.encode_plus(evidences, return_tensors='pt', truncation=True, is_split_into_words=True, add_special_tokens=False) #truncated to 512 tokens
    input_ids_expl = tokenized_expl['input_ids']
    input_id_list_expl = input_ids_expl[0].tolist()
    tokens_expl = tokenizer.convert_ids_to_tokens(input_id_list_expl)
    evidence_indices = ut.find_indices(tokens_expl, tokens)
    ev_bin = ut.indices_to_binary(tokens, evidence_indices)
expl_len = int(len(tokens)*K) #0.599 for twitter, 0.176 for movies

Using custom data configuration tweet-sentiment-extraction-4e1d6c5cd620c166
Reusing dataset csv (/home/eliott.remmer/.cache/huggingface/datasets/csv/tweet-sentiment-extraction-4e1d6c5cd620c166/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e)


  0%|          | 0/1 [00:00<?, ?it/s]

['Uh', 'oh,', 'I', 'am', 'sunburned']


Extract random explanation

In [41]:
#random explanation indices
np.random.seed(0)
input_index_range = list(range(len(tokens)))
random_expl_indices = np.random.choice(input_index_range, expl_len)
ev_bin_random = ut.indices_to_binary(tokens, random_expl_indices)

#random explanation tokens
random_expl = np.random.choice(tokens, expl_len)

#print("Ground truth explanation original")
#viz_rec_gt_or = [viz.VisualizationDataRecord(ev_bin_or, 1, 1, label, 1, 1, input_text_list, None)]
#heatmap = viz.visualize_text(viz_rec_gt_or)

print("Ground truth explanation tokenized")
viz_rec_gt = [viz.VisualizationDataRecord(ev_bin, 1, 1, label, 1, 1, tokens, None)]
heatmap = viz.visualize_text(viz_rec_gt)

print("Random explanation")
viz_rec_random = [viz.VisualizationDataRecord(ev_bin_random, 1, 1, label, 1, 1, tokens, None)]
heatmap = viz.visualize_text(viz_rec_random)

Ground truth explanation tokenized


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
negative,1 (1.00),1.0,1.0,"[CLS] uh oh , i am sun ##burn ##ed [SEP]"
,,,,


Random explanation


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
negative,1 (1.00),1.0,1.0,"[CLS] uh oh , i am sun ##burn ##ed [SEP]"
,,,,


Extract attention-based explantions 

In [26]:
# run input through model to get attention weights
attention = model(input_ids, token_type_ids=token_type_ids)[-1]

#get cls attention
cls_attn = ut.get_cls_attention(attention)

#concatenate tokens into words
input_words = ut.tokens2words(tokens, tokens)

#concatenate attention vector, remove weights to ##-tokens
cls_attn_words = ut.tokens2words(tokens, cls_attn)

print("\nAttention explanation tokens")
scaling_factor = 1
viz_rec_attn = [viz.VisualizationDataRecord(cls_attn*scaling_factor, 1, 1, label, 1, 1, tokens, None)]
heatmap = viz.visualize_text(viz_rec_attn)


Attention explanation tokens


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
NEG,1 (1.00),1.0,1.0,"[CLS] ah , and 1999 was going along so well , too . "" she ' s all that "" has the dubious distinction of being the worst movie i ' ve seen so far this year . and quite frankly , i doubt i ' ll see anything equally bad . ( at least , i * hope * i do n ' t see anything equally bad ) . "" she ' s all that "" tells the story of the most popular guy in school ( played by freddie pri ##nz ##e jr . ) who accepts a bet to transform the geek ##iest girl in school ( rachel leigh cook ) into the most popular . that , right there , is problem # 1 . how many times have we seen this storyline ? as cook comments near the end of the film , "" it ' s kind of like "" pretty woman "" , except without the prostitution "" . of course , had the filmmakers attempted to try something new with this material , the well - worn storyline would have been a device to prop ##ell the movie forward . as it is , though , "" she ' s all that "" relies * completely * on the lame and over ##used formula to push it ahead . there ' s not one original or interesting character in the film , either , and if that was n ' t bad enough , there ' s not one good performance featured . the star of the movie , rachel leigh cook , is simply horrible . i usually do n ' t like to get so personal , but in this case , i think it needs to be said . cook wears the same expression throughout the flick and looks to be having as miserable a time as i was . i was never convinced that she was a "" ne ##rd "" , and her transformation was un ##con ##vin ##cing and unnecessary . the movie seems to be saying it ' s better to be popular than to be who you are . as for freddie pri ##nz ##e jr . , an actor i or ##dina ##rily enjoy , he too is quite bad here . he coasts through the film on so - called charm , and never establishes a real character . ki ##ere ##n cu ##lk ##in is here , too , as the brother of cook . and for some ind ##is ##cer ##nable reason , he ' s got hearing aids . no explanation is given and they ' re never brought up . were we supposed to feel * sorry * for him just because he wore hearing aids ? i do n ' t think so . that single element of the film was one of the most offensive things i ' ve seen in a movie in a long time . "" she ' s all that "" sucks [SEP]"
,,,,


Extract SHAP explanations

In [20]:
import torch
import shap
import functools

def predict_fn(input_ids, attention_mask=None, batch_size=32, label=None,
            output_logits=False, repeat_input_ids=False, device='cpu'):
    """
    Wrapper function for a Huggingface Transformers model into the format that KernelSHAP expects,
    i.e. where inputs and outputs are numpy arrays.
    """

    model.to(device)
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.ones_like(input_ids) if attention_mask is None else torch.tensor(attention_mask)

    if repeat_input_ids:
        assert input_ids.shape[0] == 1
        input_ids = input_ids.repeat(attention_mask.shape[0], 1)
 
    ds = torch.utils.data.TensorDataset(input_ids.long(), attention_mask.long())
    dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
    probas = []
    logits = []
    with torch.no_grad():
        for batch in dl:
            out = model(batch[0].to(device), attention_mask=batch[1].to(device))
            logits.append(out.logits.detach().cpu())
            probas.append(torch.nn.functional.softmax(out.logits.detach().cpu(),
                                                      dim=1).detach())
    logits = torch.cat(logits, dim=0).numpy()
    probas = torch.cat(probas, dim=0).numpy()

    if label is not None:
        probas = probas[:, label]
        logits = logits[:, label]

    return (probas, logits) if output_logits else probas

#make prediction on input
input_ids_np = input_ids.detach().numpy()
pred = predict_fn(input_ids.detach().numpy())
pred_label = pred.argmax()
pred_p = pred[0, pred_label]

#create baseline
baseline = input_ids_np.copy()
ref_token = tokenizer.mask_token_id # Could also consider [UNK] or [PAD] tokens
baseline[:, 1:-1] = ref_token # Keep [CLS] and [SEP] tokens fixed in baseline

#define explainer
predict_fn_label = functools.partial(predict_fn, label=pred_label) #creates a copy of predict_fn which always sends the predicted label
explainer = shap.KernelExplainer(predict_fn_label, baseline)

#get shap values (~1 minute with nsamples = 500)
nsamples = 500
phi = explainer.shap_values(input_ids_np, nsamples=nsamples)[0]

#concatenate phi vector, remove weights to ##-tokens
#phi_words = ut.tokens2words(tokens, phi)

#print("\nSHAP explanation words")
#viz_rec_shap_words = [viz.VisualizationDataRecord(phi_words*scaling_factor, 1, 1, label, 1, 1, input_words, None)]
#heatmap = viz.visualize_text(viz_rec_shap_words)

#print("\nSHAP explanation tokens")
#viz_rec_shap = [viz.VisualizationDataRecord(phi*scaling_factor, 1, 1, label, 1, 1, tokens, None)]
#heatmap = viz.visualize_text(viz_rec_shap)

  0%|          | 0/1 [00:00<?, ?it/s]

The default of 'normalize' will be set to False in version 1.2 and deprecated in version 1.4.
If you wish to scale the data, use Pipeline with a StandardScaler in a preprocessing stage. To reproduce the previous behavior:

from sklearn.pipeline import make_pipeline

model = make_pipeline(StandardScaler(with_mean=False), LassoLarsIC())

If you wish to pass a sample_weight parameter, you need to pass it as a fit parameter to each step of the pipeline as follows:

kwargs = {s[0] + '__sample_weight': sample_weight for s in model.steps}
model.fit(X, y, **kwargs)

Set parameter alpha to: original_alpha * np.sqrt(n_samples). 


In [22]:
scaling_factor = 10
print("\nSHAP explanation tokens")
viz_rec_shap = [viz.VisualizationDataRecord(phi*scaling_factor, 1, 1, label, 1, 1, tokens, None)]
heatmap = viz.visualize_text(viz_rec_shap)
print(label)
print(pred)
print(pred_label)


SHAP explanation tokens


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
POS,1 (1.00),1.0,1.0,"[CLS] carry on mat ##ron is the last great carry - on film in my opinion . made in 1972 , it still features most of the regulars of this genre . sid james plays the head of a gang of crook ##s intent on stealing contra ##ceptive pills from the local maternity hospital and selling them off to make a profit , kenneth williams is sir bernard cutting , head of the hospital but also a h ##yp ##och ##ond ##ria ##c , hat ##tie jacques re - takes her role as mat ##ron , and charles ha ##wt ##rey is the psychiatrist dr . good ##e . the jokes come in fast and so do the laughs , with humorous antics between the mat ##ron and sir bernard . this time , williams is after mat ##ron ( jacques ) when he needs to prove himself that he is a man after visiting dr . good ##e . the doctor ( ha ##wt ##rey ) is sworn to secrecy : - "" i assure you , that anything you say to me today will go in one ear and straight out of the other ! "" bernard thinks he is having a sex change and needs to prove himself . there are great cameo ##s from joan sims , an expect ##ant mother who is many weeks over ##due and is eating constantly in every scene , and from kenneth connor as the expect ##ant father who still thinks he ' s at work ( at the railway station ) . sid james ' gang including cyril his son ( kenneth cope ) try to find out where the pills are by getting cyril to dress up as a nurse and live in at the nurses home . he has to share a room with nurse ball ( barbara windsor ) and she soon sees through him ! the film ends with the attempted robbery of the hospital ' s pills with panic ensuing . there are good performances by hat ##tie jacques as the mat ##ron , however her character seems a little more subdued and quieter than her previous ' mat ##ron ' s ' . williams is , as usual , on top form , but sid james is n ' t given a very good part in this movie and i would forgive anyone to forget that he was in the movie at all . the same does not go for charles ha ##wt ##rey , for although he only first appears in the movie after 30 minutes gone and has scarce screen time , he seems to steal every scene he is in . by not making use of sid james and barbara windsor ' s talent to full effect , the film seems to fl ##ound ##er , but it certainly makes up for it with it ' s good storyline and it ' s other appealing characters . this film is genuinely funny and i could watch it again and again and [SEP]"
,,,,


POS
[[0.00865631 0.99134374]]
1


Compare and evaluate

IOU

In [None]:
# 0.176 average fraction of explanations in movies dataset
#k = 0.176
explanation_indices = evidence_indices
ks = np.linspace(0, 1, 100)
ious_tok_attn = []
ious_ind_attn = []
ious_tok_shap = []
ious_ind_shap = []
for k in ks:
    #print("k:", k)
    #indices, tokens
    attn_expl_ind = ut.get_top_k(cls_attn, tokens, k=k, output_indices=True, omit_scores=True, positive_only=True)
    shap_expl_ind = ut.get_top_k(phi, tokens, k=k, output_indices=True, omit_scores=True, positive_only=True)

    #random_iou = ut.calculate_iou(set(explanation_indices), set(random_expl_indices))
    attn_iou_ind = ut.calculate_iou(set(explanation_indices), set(attn_expl_ind))
    shap_iou_ind = ut.calculate_iou(set(explanation_indices), set(shap_expl_ind))
    
    #print("IOU indices for tokens")
    #print("Attention:", attn_iou_ind)
    #print("SHAP:", shap_iou_ind)
    ious_ind_attn.append(attn_iou_ind)
    ious_ind_shap.append(shap_iou_ind)

    #tokens
    attn_expl_tok = ut.get_top_k(cls_attn, tokens, k=k, output_indices=False, omit_scores=True, positive_only=True)
    shap_expl_tok = ut.get_top_k(phi, tokens, k=k, output_indices=False, omit_scores=True, positive_only=True)

    #random_iou_tok = ut.calculate_iou(set(explanation_tokens), set(random_expl))
    attn_iou_tok = ut.calculate_iou(set(explanation_tokens), set(attn_expl_tok))
    shap_iou_tok = ut.calculate_iou(set(explanation_tokens), set(shap_expl_tok))

    #print("\nIOU tokens")
    #print("Attention:", attn_iou_tok)
    #print("SHAP:", shap_iou_tok)
    #print("\n")
    ious_tok_attn.append(attn_iou_tok)
    ious_tok_shap.append(shap_iou_tok)


In [None]:
import matplotlib.pyplot as plt
plt.title('Attention IOU, token-set')
plt.xlabel('fraction of tokens included')
plt.axvline(x=len(explanation_indices)/len(tokens), linestyle='--', label='ground truth expl fraction')
plt.ylabel('IOU')
plt.plot(ks, ious_tok_attn, label='attention')
plt.plot(ks, ious_tok_shap, label='shap')
plt.legend()
plt.show()

plt.title('Attention IOU, indices of tokens')
plt.xlabel('fraction of indices included')
plt.ylabel('IOU')
plt.axvline(x=len(explanation_indices)/len(tokens), linestyle='--', label='ground truth expl fraction')
plt.plot(ks, ious_ind_attn, label='attention')
plt.plot(ks, ious_ind_shap, label='shap')
plt.legend()
plt.show()

print("input len:", len(tokens))
print("ground truth len:", len(explanation_indices))
print("attn len:", len(attn_expl_tok))
print("shap len:", len(shap_expl_tok))
print("gt ratio:", len(explanation_indices)/len(tokens))


PR AUC

In [None]:
from sklearn.metrics import precision_recall_curve, auc, roc_curve, roc_auc_score

print("attention")
precision, recall, thresholds = precision_recall_curve(ev_bin[1:-1], cls_attn[1:-1])
pr_auc = auc(recall, precision)
fpr, tpr, thresholds2 = roc_curve(ev_bin[1:-1], cls_attn[1:-1])
print("pr auc:", pr_auc)
plt.xlabel('recall')
plt.ylabel('precision')
plt.title('pr curve')
plt.xlim(0, 1.01)
plt.ylim(0, 1.01)
plt.plot(recall, precision)
plt.show()

roc_auc = roc_auc_score(ev_bin, cls_attn)
print("roc auc:", roc_auc)
plt.xlabel('false positive rate')
plt.ylabel('true positive rate')
plt.title('roc curve')
plt.xlim(-0.01, 1.01)
plt.ylim(-0.01, 1.01)
plt.plot(fpr, tpr)
plt.show()

print("shap")
precision, recall, thresholds = precision_recall_curve(ev_bin[1:-1], phi[1:-1])
pr_auc = auc(recall, precision)
fpr, tpr, thresholds2 = roc_curve(ev_bin[1:-1], phi[1:-1])
print("pr auc:", pr_auc)
plt.xlabel('recall')
plt.ylabel('precision')
plt.title('pr curve')
plt.xlim(0, 1.01)
plt.ylim(0, 1.01)
plt.plot(recall, precision)
plt.show()

roc_auc = roc_auc_score(ev_bin, cls_attn)
print("roc auc:", roc_auc)
plt.xlabel('false positive rate')
plt.ylabel('true positive rate')
plt.title('roc curve')
plt.xlim(-0.01, 1.01)
plt.ylim(-0.01, 1.01)
plt.plot(fpr, tpr)
plt.show()

Compare explanations before threshold

In [179]:
print("GROUND TRUTH")
heatmap = viz.visualize_text(viz_rec_gt)
print("\nATTENTION")
heatmap = viz.visualize_text(viz_rec_attn)
print("\nSHAP")
heatmap = viz.visualize_text(viz_rec_shap)
print("\nRANDOM")
heatmap = viz.visualize_text(viz_rec_random)

GROUND TRUTH


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
negative,1 (1.00),1.0,1.0,[CLS] can ` t up ##load a picture . i already hate twitter [SEP]
,,,,



ATTENTION


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
negative,1 (1.00),1.0,1.0,[CLS] can ` t up ##load a picture . i already hate twitter [SEP]
,,,,



SHAP


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
negative,1 (1.00),1.0,1.0,[CLS] can ` t up ##load a picture . i already hate twitter [SEP]
,,,,



RANDOM


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
negative,1 (1.00),1.0,1.0,[CLS] can ` t up ##load a picture . i already hate twitter [SEP]
,,,,


Compare explanations after threshold

In [176]:
print("\nAttention explanation tokens binary")
attn_expl_ind = ut.get_top_k(cls_attn, tokens, k=K, output_indices=True, omit_scores=True, positive_only=True)
attn_bin = ut.indices_to_binary(tokens, attn_expl_ind)
viz_rec_attn_bin = [viz.VisualizationDataRecord(attn_bin, 1, 1, label, 1, 1, tokens, None)]
heatmap = viz.visualize_text(viz_rec_attn_bin)

print("\nSHAP explanation tokens binary")
shap_expl_ind = ut.get_top_k(phi, tokens, k=1, output_indices=True, omit_scores=True, positive_only=True)
shap_bin = ut.indices_to_binary(tokens, shap_expl_ind)
viz_rec_shap_bin = [viz.VisualizationDataRecord(shap_bin, 1, 1, label, 1, 1, tokens, None)]
heatmap = viz.visualize_text(viz_rec_shap_bin)
print(pred)
print(pred_label)


Attention explanation tokens binary


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
negative,1 (1.00),1.0,1.0,[CLS] can ` t up ##load a picture . i already hate twitter [SEP]
,,,,



SHAP explanation tokens binary


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
negative,1 (1.00),1.0,1.0,[CLS] can ` t up ##load a picture . i already hate twitter [SEP]
,,,,


[[0.9857985  0.01144176 0.00275973]]
0
