In [1]:
import os, sys, numpy as np, pickle, random
from sklearn.neighbors import kneighbors_graph

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from datasets import load_dataset
from dig import DiscretetizedIntegratedGradients
from attributions import run_dig_explanation
from metrics import eval_log_odds, eval_comprehensiveness, eval_sufficiency
import monotonic_paths
from captum.attr._utils.common import _reshape_and_sum, _validate_input

In [3]:
# Seed the three random number generators used in this notebook (Python's
# `random`, NumPy's legacy global RNG, and PyTorch) with the same fixed value.
# `torch.manual_seed` is left as the last expression, so the cell displays the
# torch Generator object it returns.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fe1da7bd0f0>

In [4]:
# Checkpoint used for BOTH the tokenizer and the classifier. Keeping the name in
# one place guarantees the vocabulary and the weights always come from the same
# checkpoint.
MODEL_NAME = 'textattack/bert-base-uncased-SST-2'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

In [5]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Move the model to the chosen device, switch it to evaluation mode and clear
# any stored gradients. `to` and `eval` both return the module itself, which is
# what allows the calls to be chained.
model.to(device).eval().zero_grad()

cpu


In [None]:
def predict(model, inputs_embeds, attention_mask=None):
    """Run `model` on pre-computed embeddings and return the first element of its output.

    Args:
        model: callable accepting `inputs_embeds` and `attention_mask` keyword
            arguments and returning an indexable result.
        inputs_embeds: embeddings forwarded to the model unchanged.
        attention_mask: optional mask forwarded to the model unchanged.

    Returns:
        Element 0 of the model output (presumably the classification logits
        for a sequence-classification model; confirm against the model used).
    """
    model_output = model(attention_mask=attention_mask, inputs_embeds=inputs_embeds)
    return model_output[0]

def nn_forward_func(model, input_embed, attention_mask=None, position_embed=None, type_embed=None, return_all_logits=False):
    """Forward pass that starts from separate word / position / type embeddings.

    The three embeddings are summed, passed through the model's embedding
    LayerNorm and dropout, and fed to the model via `predict`.

    Args:
        model: model exposing `model.bert.embeddings.LayerNorm` and
            `model.bert.embeddings.dropout`, and callable through `predict`.
        input_embed: word embeddings of the input.
        attention_mask: optional mask forwarded to the model.
        position_embed: position embeddings. Defaults to None only for
            signature compatibility; it is added unconditionally, so a tensor
            must be supplied.
        type_embed: token-type embeddings. Same caveat as `position_embed`.
        return_all_logits: when True, return the full prediction tensor.

    Returns:
        The full output of `predict` when `return_all_logits` is True,
        otherwise the maximum value along dimension 1 of that output
        (presumably the logit of the predicted class; assumes the output is
        (batch, num_labels) -- TODO confirm).
    """
    bert_embeddings = model.bert.embeddings
    embeds = input_embed + position_embed + type_embed
    embeds = bert_embeddings.dropout(bert_embeddings.LayerNorm(embeds))
    pred = predict(model, embeds, attention_mask=attention_mask)
    if return_all_logits:
        return pred
    return pred.max(1).values

# NOTE: `load_mappings` is disabled. The KNN data it used to load from disk is
# built directly in the cell that defines `auxiliary_data`. Its `return` line
# was previously left uncommented and, by indentation, became unreachable code
# at the end of `nn_forward_func`; the whole function is now commented out.
# def load_mappings(dataset, knn_nbrs=500):
#     with open(f'knn/bert_{dataset}_{knn_nbrs}.pkl', 'rb') as f:
#         [word_idx_map, word_features, adj] = pickle.load(f)
#     word_idx_map	= dict(word_idx_map)
#
#     return word_idx_map, word_features, adj

def construct_input_ref_pair(tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device):
    """Build the token-id tensors for `text` and for its reference (baseline) sequence.

    The text is encoded without special tokens and truncated to
    `tokenizer.max_len_single_sentence`; CLS and SEP ids are then added
    around it. The reference keeps CLS/SEP but replaces every text token
    with `ref_token_id`.

    Returns:
        (input_ids, ref_input_ids): two tensors of shape (1, len(text_ids) + 2)
        created on `device`.
    """
    body_ids = tokenizer.encode(
        text,
        add_special_tokens=False,
        truncation=True,
        max_length=tokenizer.max_len_single_sentence,
    )
    real_sequence = [cls_token_id, *body_ids, sep_token_id]
    ref_sequence = [cls_token_id, *([ref_token_id] * len(body_ids)), sep_token_id]

    input_tensor = torch.tensor([real_sequence], device=device)
    ref_tensor = torch.tensor([ref_sequence], device=device)
    return input_tensor, ref_tensor

def construct_input_ref_pos_id_pair(input_ids, device):
    """Return position ids for the input and all-zero position ids for the reference.

    The sequence length is read from dimension 1 of `input_ids`.

    Returns:
        (position_ids, ref_position_ids): two long tensors of shape
        (1, seq_length) on `device`; the first counts 0..seq_length-1, the
        second is all zeros.
    """
    n_tokens = input_ids.size(1)
    position_ids = torch.arange(n_tokens, dtype=torch.long, device=device)[None, :]
    ref_position_ids = torch.zeros((1, n_tokens), dtype=torch.long, device=device)
    return position_ids, ref_position_ids

def construct_input_ref_token_type_pair(input_ids, device):
    """Return all-zero token-type ids for the input and for the reference.

    Both tensors are long tensors of shape (1, seq_len) on `device`, where
    `seq_len` is dimension 1 of `input_ids`.
    """
    n_tokens = input_ids.size(1)
    token_type_ids = torch.zeros((1, n_tokens), dtype=torch.long, device=device)
    # Separate tensor with the same shape, dtype and device as the input ids.
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids

def construct_attention_mask(input_ids):
    """Return an all-ones tensor with the same shape, dtype and device as `input_ids`."""
    attention_mask = torch.ones_like(input_ids)
    return attention_mask

def get_word_embeddings(model):
    """Return the weight of the model's word-embedding layer (`model.bert.embeddings.word_embeddings`)."""
    word_embedding_layer = model.bert.embeddings.word_embeddings
    return word_embedding_layer.weight

def construct_word_embedding(model, input_ids):
    """Look up `input_ids` in the model's word-embedding layer and return the result."""
    word_embedding_layer = model.bert.embeddings.word_embeddings
    return word_embedding_layer(input_ids)

def construct_position_embedding(model, position_ids):
    """Look up `position_ids` in the model's position-embedding layer and return the result."""
    position_embedding_layer = model.bert.embeddings.position_embeddings
    return position_embedding_layer(position_ids)

def construct_type_embedding(model, type_ids):
    """Look up `type_ids` in the model's token-type embedding layer and return the result."""
    token_type_embedding_layer = model.bert.embeddings.token_type_embeddings
    return token_type_embedding_layer(type_ids)

def construct_sub_embedding(model, input_ids, ref_input_ids, position_ids, ref_position_ids, type_ids, ref_type_ids):
    """Embed the input and reference ids with each of the three embedding layers.

    Returns:
        Three (input, reference) tuples, in this order: word embeddings,
        position embeddings, token-type embeddings.
    """
    word_pair = (
        construct_word_embedding(model, input_ids),
        construct_word_embedding(model, ref_input_ids),
    )
    position_pair = (
        construct_position_embedding(model, position_ids),
        construct_position_embedding(model, ref_position_ids),
    )
    type_pair = (
        construct_type_embedding(model, type_ids),
        construct_type_embedding(model, ref_type_ids),
    )
    return word_pair, position_pair, type_pair

def get_base_token_emb(model, tokenizer, device):
    """Return the word embedding of the tokenizer's PAD token.

    The PAD id is wrapped in a one-element tensor on `device` before the
    lookup, so the result has a leading dimension of size 1.
    """
    pad_id = torch.tensor([tokenizer.pad_token_id], device=device)
    return construct_word_embedding(model, pad_id)

def get_tokens(tokenizer, text_ids):
    """Convert token ids back to token strings.

    `text_ids` is squeezed first, so size-1 dimensions (e.g. a leading batch
    dimension of 1) are dropped before the ids reach the tokenizer.
    """
    flat_ids = text_ids.squeeze()
    return tokenizer.convert_ids_to_tokens(flat_ids)

def get_inputs(model, tokenizer, text, device):
    """Build every tensor needed to attribute a single sentence.

    The reference (baseline) sequence replaces each text token with the
    tokenizer's MASK token; CLS and SEP are kept.

    Returns:
        A list of nine items, in this order: input_ids, ref_input_ids,
        input_embed, ref_input_embed, position_embed, ref_position_embed,
        type_embed, ref_type_embed, attention_mask.
    """
    input_ids, ref_input_ids = construct_input_ref_pair(
        tokenizer,
        text,
        tokenizer.mask_token_id,
        tokenizer.sep_token_id,
        tokenizer.cls_token_id,
        device,
    )
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids, device)
    type_ids, ref_type_ids = construct_input_ref_token_type_pair(input_ids, device)
    attention_mask = construct_attention_mask(input_ids)

    # Each pair is (embedding of the input, embedding of the reference).
    word_pair, position_pair, type_pair = construct_sub_embedding(
        model, input_ids, ref_input_ids, position_ids, ref_position_ids, type_ids, ref_type_ids
    )

    return [input_ids, ref_input_ids, *word_pair, *position_pair, *type_pair, attention_mask]

In [9]:
# Build the auxiliary KNN data handed to `monotonic_paths.scale_inputs`.
# `get_word_embeddings` requires the model as its argument; calling it with no
# arguments raises a TypeError.
word_features = get_word_embeddings(model).cpu().detach().numpy()
word_idx_map = tokenizer.vocab
# Neighbour graph over the whole embedding matrix: 500 neighbours per row,
# storing distances (mode='distance') and using every available core (n_jobs=-1).
A = kneighbors_graph(word_features, 500, mode='distance', n_jobs=-1)
auxiliary_data = [word_idx_map, word_features, A]

In [10]:
# Define the Attribution function
# `nn_forward_func` takes the model as its first positional argument, but the
# helpers that drive the attribution (`run_dig_explanation`, `eval_*`) are never
# given the model. The forward function is therefore bound to the loaded
# `model` here, so it can be called with the embeddings as the first argument.
attr_func = DiscretetizedIntegratedGradients(
    lambda *args, **kwargs: nn_forward_func(model, *args, **kwargs)
)

In [11]:
# Load the test split of GLUE SST-2 and flatten it into a list of
# (sentence, label, idx) tuples, one per example.
dataset = load_dataset('glue', 'sst2')['test']
data = list(zip(*(dataset[column] for column in ('sentence', 'label', 'idx'))))

Reusing dataset glue (/Users/ayankundu/.cache/huggingface/datasets/glue/sst2/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)


In [12]:
# NOTE(review): `all_outputs` is not referenced anywhere else in the visible
# notebook; kept in case a later cell relies on it.
all_outputs = []


def calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens, n_steps=125, topk=20):
    """Compute attributions for one example and score them.

    Args:
        inputs: sequence of ten items, in this order: scaled_features,
            input_ids, ref_input_ids, input_embed, ref_input_embed,
            position_embed, ref_position_embed, type_embed, ref_type_embed,
            attention_mask. Every non-None item is moved to `device`.
        device: device the inputs are moved to.
        attr_func: attribution object handed to `run_dig_explanation`.
        base_token_emb: base token embedding passed to the metric helpers.
        nn_forward_func: forward function passed to the metric helpers. It is
            never handed a model here, so it presumably must already have the
            model bound -- verify against `metrics`.
        get_tokens: unused; kept so existing callers keep working.
        n_steps: step count passed to `run_dig_explanation` (default 125, the
            value that was previously hard-coded).
        topk: `topk` value passed to all three metric helpers (default 20, the
            value that was previously hard-coded).

    Returns:
        (log_odd, comp, suff, attr, delta): the three metric values, the
        attributions, and the delta returned by `run_dig_explanation`.
    """
    # move inputs to main device
    inp = [x.to(device) if x is not None else None for x in inputs]

    # compute attribution
    (scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed,
     position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask) = inp
    attr, deltaa = run_dig_explanation(attr_func, scaled_features, position_embed, type_embed, attention_mask, n_steps)

    # compute metrics; all three helpers share the same positional arguments
    metric_args = (nn_forward_func, input_embed, position_embed, type_embed, attention_mask, base_token_emb, attr)
    log_odd, _ = eval_log_odds(*metric_args, topk=topk)  # second return value is not used
    comp = eval_comprehensiveness(*metric_args, topk=topk)
    suff = eval_sufficiency(*metric_args, topk=topk)

    return log_odd, comp, suff, attr, deltaa

In [13]:
%%time
# get ref token embedding
base_token_emb = get_base_token_emb(device)

# compute the DIG attributions for all the inputs
print('Starting attribution computation...')
inputs = []
log_odds, comps, suffs, deltas, delta_pcs, count = 0, 0, 0, 0, 0, 0
print_step = 50

delta_pcs_list = []
for row in tqdm(data):
    inp = get_inputs(row[0], device)
    input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask = inp
    scaled_features 		= monotonic_paths.scale_inputs(input_ids.squeeze().tolist(), ref_input_ids.squeeze().tolist(),\
                                        device, auxiliary_data, method ="UIG", steps=30, nbrs = 50, factor=2, strategy='greedy')
    inputs					= [scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask]
    log_odd, comp, suff, attrib, delta= calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens)
    scaled_features_tpl = _format_input(scaled_features)
    start_point, end_point = _format_input(scaled_features_tpl[0][0].unsqueeze(0)), _format_input(scaled_features_tpl[0][-1].unsqueeze(0))# baselines, inputs (only works for one input, i.e. len(tuple) == 1)
    F_diff = (nn_forward_func(end_point[0],attention_mask,position_embed,type_embed).squeeze() - \
             nn_forward_func(start_point[0],attention_mask,position_embed,type_embed).squeeze())
    delta_pc = delta/F_diff*100
    log_odds	+= log_odd
    comps		+= comp
    suffs 		+= suff
    deltas+= torch.abs(delta).item()
    delta_pcs+= torch.abs(delta_pc).item()
    delta_pcs_list.append(torch.abs(delta_pc).item())
    count		+= 1

    # print the metrics
    if count % print_step == 0:
        print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4),
              'Sufficiency: ', np.round(suffs / count, 4),  'Avg delta: ', np.round(deltas / count, 4),
              'Avg delta pct:', np.round(delta_pcs / count, 4))

print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4),
      'Sufficiency: ', np.round(suffs / count, 4), 'Avg delta: ', np.round(deltas / count, 4),
              'Avg delta pct:', np.round(delta_pcs / count, 4))


  0%|          | 0/1821 [00:00<?, ?it/s]

Starting attribution computation...


  3%|▎         | 50/1821 [36:12<27:48:18, 56.52s/it]

Log-odds:  -0.7727 Comprehensiveness:  0.2292 Sufficiency:  0.2803 Avg delta:  0.1997 Avg delta pct: 164.4344


  5%|▌         | 100/1821 [1:14:30<23:08:31, 48.41s/it]

Log-odds:  -0.6978 Comprehensiveness:  0.1948 Sufficiency:  0.3013 Avg delta:  0.2543 Avg delta pct: 97.1936


  8%|▊         | 150/1821 [1:47:30<23:36:16, 50.85s/it]

Log-odds:  -0.7002 Comprehensiveness:  0.1959 Sufficiency:  0.3445 Avg delta:  0.2659 Avg delta pct: 77.2461


 11%|█         | 200/1821 [2:21:24<18:01:57, 40.05s/it]

Log-odds:  -0.7861 Comprehensiveness:  0.2314 Sufficiency:  0.3584 Avg delta:  0.2434 Avg delta pct: 72.8588


 14%|█▎        | 250/1821 [2:59:08<18:07:21, 41.53s/it]

Log-odds:  -0.8338 Comprehensiveness:  0.2439 Sufficiency:  0.3539 Avg delta:  0.2436 Avg delta pct: 63.8958


 16%|█▋        | 300/1821 [3:36:19<15:42:52, 37.19s/it]

Log-odds:  -0.8433 Comprehensiveness:  0.2583 Sufficiency:  0.3524 Avg delta:  0.2544 Avg delta pct: 101.1932


 19%|█▉        | 350/1821 [4:10:57<17:18:13, 42.35s/it]

Log-odds:  -0.8898 Comprehensiveness:  0.2601 Sufficiency:  0.3533 Avg delta:  0.2654 Avg delta pct: 91.4928


 22%|██▏       | 400/1821 [4:47:23<15:55:25, 40.34s/it]

Log-odds:  -0.9037 Comprehensiveness:  0.2642 Sufficiency:  0.3506 Avg delta:  0.2699 Avg delta pct: 83.6418


 25%|██▍       | 450/1821 [5:25:07<16:33:19, 43.47s/it]

Log-odds:  -0.8922 Comprehensiveness:  0.2701 Sufficiency:  0.3501 Avg delta:  0.2729 Avg delta pct: 83.0168


 27%|██▋       | 500/1821 [5:59:53<16:33:28, 45.12s/it]

Log-odds:  -0.8866 Comprehensiveness:  0.2677 Sufficiency:  0.341 Avg delta:  0.2816 Avg delta pct: 86.6758


 30%|███       | 550/1821 [6:36:31<13:44:32, 38.92s/it]

Log-odds:  -0.9296 Comprehensiveness:  0.2698 Sufficiency:  0.3381 Avg delta:  0.2797 Avg delta pct: 82.4906


 33%|███▎      | 600/1821 [7:21:23<21:49:13, 64.34s/it]

Log-odds:  -0.9326 Comprehensiveness:  0.2695 Sufficiency:  0.3402 Avg delta:  0.2781 Avg delta pct: 78.7615


 36%|███▌      | 650/1821 [7:59:35<19:23:29, 59.62s/it]

Log-odds:  -0.971 Comprehensiveness:  0.2743 Sufficiency:  0.3411 Avg delta:  0.2765 Avg delta pct: 75.7822


 38%|███▊      | 700/1821 [8:34:51<14:26:03, 46.35s/it]

Log-odds:  -0.9881 Comprehensiveness:  0.273 Sufficiency:  0.3342 Avg delta:  0.2825 Avg delta pct: 73.8866


 41%|████      | 750/1821 [9:11:35<9:59:10, 33.57s/it] 

Log-odds:  -1.0072 Comprehensiveness:  0.2743 Sufficiency:  0.3311 Avg delta:  0.2915 Avg delta pct: 83.9861


 44%|████▍     | 800/1821 [9:47:35<12:27:55, 43.95s/it]

Log-odds:  -1.002 Comprehensiveness:  0.2731 Sufficiency:  0.3314 Avg delta:  0.2881 Avg delta pct: 80.3386


 47%|████▋     | 850/1821 [10:27:14<13:12:46, 48.99s/it]

Log-odds:  -0.9968 Comprehensiveness:  0.2726 Sufficiency:  0.3365 Avg delta:  0.2917 Avg delta pct: 129.6338


 49%|████▉     | 900/1821 [11:06:42<12:05:42, 47.28s/it]

Log-odds:  -1.0206 Comprehensiveness:  0.2767 Sufficiency:  0.3342 Avg delta:  0.2945 Avg delta pct: 126.9899


 52%|█████▏    | 950/1821 [11:41:33<11:48:46, 48.82s/it]

Log-odds:  -1.0103 Comprehensiveness:  0.2726 Sufficiency:  0.3327 Avg delta:  0.2942 Avg delta pct: 127.3313


 55%|█████▍    | 1000/1821 [12:18:08<10:58:41, 48.14s/it]

Log-odds:  -1.0101 Comprehensiveness:  0.2752 Sufficiency:  0.3326 Avg delta:  0.3001 Avg delta pct: 125.182


 58%|█████▊    | 1050/1821 [12:56:12<8:09:31, 38.10s/it] 

Log-odds:  -1.0391 Comprehensiveness:  0.2789 Sufficiency:  0.3347 Avg delta:  0.2991 Avg delta pct: 120.8039


 60%|██████    | 1100/1821 [13:33:20<6:32:59, 32.70s/it] 

Log-odds:  -1.0128 Comprehensiveness:  0.2746 Sufficiency:  0.3352 Avg delta:  0.2994 Avg delta pct: 116.7374


 63%|██████▎   | 1150/1821 [14:10:36<8:38:00, 46.32s/it] 

Log-odds:  -1.0157 Comprehensiveness:  0.2752 Sufficiency:  0.3343 Avg delta:  0.2978 Avg delta pct: 112.6728


 66%|██████▌   | 1200/1821 [14:49:18<8:17:08, 48.03s/it] 

Log-odds:  -1.035 Comprehensiveness:  0.2788 Sufficiency:  0.336 Avg delta:  0.297 Avg delta pct: 126.0367


 69%|██████▊   | 1250/1821 [15:29:16<9:20:56, 58.94s/it] 

Log-odds:  -1.0326 Comprehensiveness:  0.2777 Sufficiency:  0.3395 Avg delta:  0.305 Avg delta pct: 123.1208


 71%|███████▏  | 1300/1821 [16:07:05<6:41:44, 46.27s/it] 

Log-odds:  -1.0381 Comprehensiveness:  0.281 Sufficiency:  0.3384 Avg delta:  0.309 Avg delta pct: 119.8531


 74%|███████▍  | 1350/1821 [16:42:48<5:15:41, 40.22s/it]

Log-odds:  -1.0247 Comprehensiveness:  0.2798 Sufficiency:  0.3427 Avg delta:  0.3054 Avg delta pct: 116.5024


 77%|███████▋  | 1400/1821 [17:17:41<3:12:34, 27.44s/it]

Log-odds:  -1.0327 Comprehensiveness:  0.2823 Sufficiency:  0.3406 Avg delta:  0.3071 Avg delta pct: 118.6266


 80%|███████▉  | 1450/1821 [17:55:46<5:26:23, 52.79s/it]

Log-odds:  -1.0275 Comprehensiveness:  0.2806 Sufficiency:  0.342 Avg delta:  0.3096 Avg delta pct: 119.5681


 82%|████████▏ | 1500/1821 [18:37:00<4:41:43, 52.66s/it]

Log-odds:  -1.0396 Comprehensiveness:  0.2817 Sufficiency:  0.3434 Avg delta:  0.3092 Avg delta pct: 117.7329


 85%|████████▌ | 1550/1821 [19:16:57<3:48:00, 50.48s/it]

Log-odds:  -1.0237 Comprehensiveness:  0.279 Sufficiency:  0.3463 Avg delta:  0.3071 Avg delta pct: 115.1804


 88%|████████▊ | 1600/1821 [19:56:21<2:27:25, 40.03s/it]

Log-odds:  -1.0157 Comprehensiveness:  0.279 Sufficiency:  0.3464 Avg delta:  0.3051 Avg delta pct: 114.3882


 91%|█████████ | 1650/1821 [20:31:42<1:57:39, 41.28s/it]

Log-odds:  -1.0177 Comprehensiveness:  0.2781 Sufficiency:  0.35 Avg delta:  0.3094 Avg delta pct: 112.8573


 93%|█████████▎| 1700/1821 [21:12:14<1:42:07, 50.64s/it]

Log-odds:  -1.013 Comprehensiveness:  0.276 Sufficiency:  0.3497 Avg delta:  0.3101 Avg delta pct: 114.8142


 96%|█████████▌| 1750/1821 [21:51:19<55:38, 47.01s/it]  

Log-odds:  -1.006 Comprehensiveness:  0.2748 Sufficiency:  0.3478 Avg delta:  0.3109 Avg delta pct: 115.0807


 99%|█████████▉| 1800/1821 [22:29:28<15:10, 43.36s/it]  

Log-odds:  -1.004 Comprehensiveness:  0.2747 Sufficiency:  0.347 Avg delta:  0.312 Avg delta pct: 115.3253


100%|██████████| 1821/1821 [22:42:36<00:00, 44.90s/it]

Log-odds:  -1.0049 Comprehensiveness:  0.2751 Sufficiency:  0.3479 Avg delta:  0.3135 Avg delta pct: 114.5416
CPU times: user 22h 16min 4s, sys: 50min 40s, total: 23h 6min 45s
Wall time: 22h 42min 36s





In [14]:
# Sanity check: the loop appends exactly one entry to `delta_pcs_list` per
# processed example, so this should equal the number of examples.
len(delta_pcs_list)

1821

In [16]:
import statistics

In [17]:
# Median of the per-example absolute delta percentages. In the recorded run the
# mean printed by the loop (about 114.5) is far above this median, so the mean
# is dominated by a small number of very large values.
statistics.median(delta_pcs_list)

11.941275596618652