In [1]:
!pip install transformers==4.26.1
!pip install datasets
!pip install captum==0.3.1

Collecting transformers==4.26.1
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m56.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0 (from transformers==4.26.1)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.26.1)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m121.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.16.4 tokenizers-0.13.3 transformers-4.26.1
Collecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/Micropaper/

Mounted at /content/drive
/content/drive/MyDrive/Micropaper


In [1]:
import os, sys, numpy as np, pickle, random
from sklearn.neighbors import kneighbors_graph

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from datasets import load_dataset
from dig import DiscretetizedIntegratedGradients
from attributions import run_dig_explanation
from metrics import eval_log_odds, eval_comprehensiveness, eval_sufficiency
import monotonic_paths
from captum.attr._utils.common import _format_input_baseline, _reshape_and_sum, _validate_input, _format_input

In [3]:
# One seed shared by every random number generator the notebook seeds, so the
# value only has to be changed in a single place.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fc4a81ab0f0>

In [4]:
# Checkpoint identifier used for BOTH the tokenizer and the classifier; keeping it in
# one constant guarantees the two are always loaded from the same checkpoint.
MODEL_NAME = 'textattack/bert-base-uncased-SST-2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

In [5]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

# Move the classifier to the chosen device, switch it to evaluation mode and
# clear any gradients currently stored on its parameters.
model.to(device)
model.eval()
model.zero_grad()

cpu


In [6]:
def predict(model, inputs_embeds, attention_mask=None):
    """Call `model` on pre-computed embeddings and return the first element of its output.

    Args:
        model: callable accepting `inputs_embeds` and `attention_mask` keyword arguments.
        inputs_embeds: forwarded unchanged as the `inputs_embeds` keyword.
        attention_mask: forwarded unchanged as the `attention_mask` keyword (default None).

    Returns:
        Item 0 of whatever `model(...)` returns.
    """
    outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
    return outputs[0]

def nn_forward_func(input_embed, attention_mask=None, position_embed=None, type_embed=None, return_all_logits=False):
    """Forward function used for attribution: embeddings in, model scores out.

    Sums the three embedding components, applies the LayerNorm and dropout of the
    module-level `model`'s embedding layer, and passes the result to `predict` as
    `inputs_embeds`.

    Args:
        input_embed: word-embedding component of the input.
        attention_mask: forwarded to `predict` (default None).
        position_embed: added to `input_embed`. Despite the None default it must be
            supplied, because adding None to a tensor raises TypeError.
        type_embed: added to `input_embed`; required for the same reason.
        return_all_logits: when true, return the complete output of `predict`.

    Returns:
        The output of `predict` when `return_all_logits` is true, otherwise
        `pred.max(1).values`, i.e. the maximum taken along dimension 1.
    """
    global model
    embeds = input_embed + position_embed + type_embed
    embeds = model.bert.embeddings.dropout(model.bert.embeddings.LayerNorm(embeds))
    pred = predict(model, embeds, attention_mask=attention_mask)
    if return_all_logits:
        return pred
    return pred.max(1).values

# Disabled helper kept for reference. Its final `return` used to be left uncommented,
# which made it an unreachable statement inside `nn_forward_func`.
# def load_mappings(dataset, knn_nbrs=500):
#     with open(f'knn/bert_{dataset}_{knn_nbrs}.pkl', 'rb') as f:
#         [word_idx_map, word_features, adj] = pickle.load(f)
#     word_idx_map	= dict(word_idx_map)
#
#     return word_idx_map, word_features, adj

def construct_input_ref_pair(tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device):
    """Build the token-id tensors for `text` and for its same-length reference sequence.

    `text` is encoded without special tokens and truncated to
    `tokenizer.max_len_single_sentence`. Both sequences are then wrapped as
    `[cls] ... [sep]`; the reference replaces every text token with `ref_token_id`.

    Returns:
        `(input_ids, ref_input_ids)`, each a tensor with a leading dimension of 1,
        created on `device`.
    """
    text_ids = tokenizer.encode(
        text,
        add_special_tokens=False,
        truncation=True,
        max_length=tokenizer.max_len_single_sentence,
    )
    prefix, suffix = [cls_token_id], [sep_token_id]
    input_ids = prefix + text_ids + suffix
    ref_input_ids = prefix + [ref_token_id] * len(text_ids) + suffix

    input_tensor = torch.tensor([input_ids], device=device)
    ref_tensor = torch.tensor([ref_input_ids], device=device)
    return input_tensor, ref_tensor

def construct_input_ref_pos_id_pair(input_ids, device):
    """Build position ids for the input and an all-zero counterpart for the reference.

    The sequence length is read from `input_ids.size(1)`.

    Returns:
        `(position_ids, ref_position_ids)`: two long tensors of shape
        `(1, seq_length)` on `device`. The first holds `0 .. seq_length - 1`, the
        second holds zeros only.
    """
    seq_length = input_ids.size(1)
    running_index = torch.arange(seq_length, dtype=torch.long, device=device)
    position_ids = running_index.unsqueeze(0)
    ref_position_ids = torch.zeros((1, seq_length), dtype=torch.long, device=device)
    return position_ids, ref_position_ids

def construct_input_ref_token_type_pair(input_ids, device):
    """Build all-zero token-type ids for the input and for the reference.

    The sequence length is read from `input_ids.size(1)`.

    Returns:
        `(token_type_ids, ref_token_type_ids)`: two long tensors of shape
        `(1, seq_len)` on `device`, both filled with zeros.
    """
    seq_len = input_ids.size(1)
    zero_row = [0] * seq_len
    token_type_ids = torch.tensor([zero_row], dtype=torch.long, device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, dtype=torch.long, device=device)
    return token_type_ids, ref_token_type_ids

def construct_attention_mask(input_ids):
    """Return a tensor of ones shaped like `input_ids` (every position is attended to)."""
    attention_mask = torch.ones_like(input_ids)
    return attention_mask

def get_word_embeddings():
    """Return the weight of the word-embedding layer of the module-level `model`."""
    global model
    embedding_layer = model.bert.embeddings.word_embeddings
    return embedding_layer.weight

def construct_word_embedding(model, input_ids):
    """Apply `model.bert.embeddings.word_embeddings` to `input_ids` and return the result."""
    word_embedding_layer = model.bert.embeddings.word_embeddings
    return word_embedding_layer(input_ids)

def construct_position_embedding(model, position_ids):
    """Apply `model.bert.embeddings.position_embeddings` to `position_ids` and return the result."""
    position_embedding_layer = model.bert.embeddings.position_embeddings
    return position_embedding_layer(position_ids)

def construct_type_embedding(model, type_ids):
    """Apply `model.bert.embeddings.token_type_embeddings` to `type_ids` and return the result."""
    type_embedding_layer = model.bert.embeddings.token_type_embeddings
    return type_embedding_layer(type_ids)

def construct_sub_embedding(model, input_ids, ref_input_ids, position_ids, ref_position_ids, type_ids, ref_type_ids):
    """Look up word, position and type embeddings for the input and for the reference.

    Returns:
        Three `(input, reference)` tuples, in the order word embeddings,
        position embeddings, type embeddings.
    """
    word_pair = (
        construct_word_embedding(model, input_ids),
        construct_word_embedding(model, ref_input_ids),
    )
    position_pair = (
        construct_position_embedding(model, position_ids),
        construct_position_embedding(model, ref_position_ids),
    )
    type_pair = (
        construct_type_embedding(model, type_ids),
        construct_type_embedding(model, ref_type_ids),
    )
    return word_pair, position_pair, type_pair

def get_base_token_emb(device):
    """Return the word embedding of the tokenizer's pad token.

    The pad token id is wrapped in a one-element tensor on `device` and looked up
    with `construct_word_embedding` on the module-level `model`.
    """
    global model
    pad_id_tensor = torch.tensor([tokenizer.pad_token_id], device=device)
    return construct_word_embedding(model, pad_id_tensor)

def get_tokens(text_ids):
    """Squeeze `text_ids` and convert the ids to tokens with the module-level `tokenizer`."""
    global tokenizer
    squeezed_ids = text_ids.squeeze()
    return tokenizer.convert_ids_to_tokens(squeezed_ids)

def get_inputs(text, device):
    """Prepare every id tensor and embedding needed to attribute a single `text`.

    The tokenizer's pad token id serves as the reference token id.

    Returns:
        A list of nine items, in this order: `input_ids`, `ref_input_ids`,
        `input_embed`, `ref_input_embed`, `position_embed`, `ref_position_embed`,
        `type_embed`, `ref_type_embed`, `attention_mask`.
    """
    global model, tokenizer
    ref_token_id = tokenizer.pad_token_id
    sep_token_id = tokenizer.sep_token_id
    cls_token_id = tokenizer.cls_token_id

    # Id-level inputs and their reference counterparts.
    input_ids, ref_input_ids = construct_input_ref_pair(
        tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device
    )
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids, device)
    type_ids, ref_type_ids = construct_input_ref_token_type_pair(input_ids, device)
    attention_mask = construct_attention_mask(input_ids)

    # Embedding-level inputs, each returned as an (input, reference) pair.
    word_pair, position_pair, type_pair = construct_sub_embedding(
        model, input_ids, ref_input_ids, position_ids, ref_position_ids, type_ids, ref_type_ids
    )

    return [input_ids, ref_input_ids, *word_pair, *position_pair, *type_pair, attention_mask]

In [7]:
# Word-embedding matrix as a NumPy array. Detaching before the host copy keeps the
# copy out of the autograd graph.
word_features = get_word_embeddings().detach().cpu().numpy()
word_idx_map = tokenizer.vocab

# Number of nearest neighbours kept for every row of `word_features`.
KNN_NBRS = 500
# Sparse neighbour graph over the embedding rows; edge values are distances
# (mode='distance'), and n_jobs=-1 uses all available processors.
A = kneighbors_graph(word_features, KNN_NBRS, mode='distance', n_jobs=-1)

# Bundle handed to `monotonic_paths.scale_inputs` by the evaluation loop.
auxiliary_data = [word_idx_map, word_features, A]

In [8]:
# Define the attribution function: wrap `nn_forward_func` in the
# DiscretetizedIntegratedGradients class imported from the `dig` module.
# The evaluation loop passes the resulting `attr_func` to `calculate_attributions`.
attr_func = DiscretetizedIntegratedGradients(nn_forward_func)

In [10]:
# Test split of GLUE SST-2.
dataset = load_dataset('glue', 'sst2')['test']

# One (sentence, label, idx) tuple per example. The evaluation loop only reads the
# sentence (element 0).
# NOTE(review): GLUE test splits are presumably distributed without gold labels, so
# `label` may be a placeholder value — confirm before using it.
data = [
    (sentence, label, idx)
    for sentence, label, idx in zip(dataset['sentence'], dataset['label'], dataset['idx'])
]

Reusing dataset glue (/Users/ayankundu/.cache/huggingface/datasets/glue/sst2/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)


In [11]:
# NOTE(review): `all_outputs` is never appended to or read in the visible cells —
# confirm it is unused before removing it.
all_outputs = []


def calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens,
                           path_steps=125, topk=20):
    """Compute DIG attributions for one input and score them with three metrics.

    Args:
        inputs: sequence of ten items, in this order: `scaled_features`, `input_ids`,
            `ref_input_ids`, `input_embed`, `ref_input_embed`, `position_embed`,
            `ref_position_embed`, `type_embed`, `ref_type_embed`, `attention_mask`.
            Every item that is not None is moved to `device` with `.to(device)`.
        device: target device for the inputs.
        attr_func: attribution object handed to `run_dig_explanation`.
        base_token_emb: baseline token embedding forwarded to the metric functions.
        nn_forward_func: forward function forwarded to the metric functions.
        get_tokens: accepted for interface compatibility; not used in this function.
        path_steps: forwarded as the final positional argument of
            `run_dig_explanation` (default 125, the previously hard-coded value).
            NOTE(review): presumably the number of points on the interpolation path,
            i.e. (2**factor) * (steps + 1) + 1 for the caller's steps=30, factor=2 —
            confirm against `run_dig_explanation`.
        topk: forwarded as `topk` to all three metric functions (default 20, the
            previously hard-coded value).

    Returns:
        `(log_odd, comp, suff, attr, delta)`, where `attr` and `delta` are the two
        values returned by `run_dig_explanation`.
    """
    # Move inputs to the main device; None entries are passed through untouched.
    moved = [x.to(device) if x is not None else None for x in inputs]
    (scaled_features, _input_ids, _ref_input_ids, input_embed, _ref_input_embed,
     position_embed, _ref_position_embed, type_embed, _ref_type_embed, attention_mask) = moved

    # Compute the attribution and its delta.
    attr, delta = run_dig_explanation(attr_func, scaled_features, position_embed, type_embed,
                                      attention_mask, path_steps)

    # All three metrics take the same positional arguments.
    metric_args = (nn_forward_func, input_embed, position_embed, type_embed, attention_mask,
                   base_token_emb, attr)
    # eval_log_odds returns a second value that this function does not need.
    log_odd, _ = eval_log_odds(*metric_args, topk=topk)
    comp = eval_comprehensiveness(*metric_args, topk=topk)
    suff = eval_sufficiency(*metric_args, topk=topk)

    return log_odd, comp, suff, attr, delta

In [12]:
%%time
# get ref token embedding
base_token_emb = get_base_token_emb(device)


def _print_running_metrics(log_odds, comps, suffs, deltas, delta_pcs, count):
    """Print the mean of each accumulated metric over `count` processed inputs."""
    print('Log-odds: ', np.round(log_odds / count, 4), 'Comprehensiveness: ', np.round(comps / count, 4),
          'Sufficiency: ', np.round(suffs / count, 4), 'Avg delta: ', np.round(deltas / count, 4),
          'Avg delta pct:', np.round(delta_pcs / count, 4))


# compute the DIG attributions for all the inputs
print('Starting attribution computation...')
inputs = []
log_odds, comps, suffs, deltas, delta_pcs, count = 0, 0, 0, 0, 0, 0
print_step = 50  # report the running averages every `print_step` inputs
for row in tqdm(data):
    # row[0] is the sentence; the remaining fields of the row are not used here.
    inp = get_inputs(row[0], device)
    input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask = inp
    scaled_features = monotonic_paths.scale_inputs(input_ids.squeeze().tolist(), ref_input_ids.squeeze().tolist(),
                                                   device, auxiliary_data, method="DIG", steps=30, nbrs=50, factor=2, strategy='greedy')
    inputs = [scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask]
    log_odd, comp, suff, attrib, delta = calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens)

    # First and last point of the path: baselines, inputs
    # (only works for one input, i.e. len(tuple) == 1)
    scaled_features_tpl = _format_input(scaled_features)
    start_point = _format_input(scaled_features_tpl[0][0].unsqueeze(0))
    end_point = _format_input(scaled_features_tpl[0][-1].unsqueeze(0))

    # Difference of the forward function between the two path end points. Only the
    # values are needed, so skip building autograd graphs for these two passes.
    with torch.no_grad():
        F_diff = (nn_forward_func(end_point[0], attention_mask, position_embed, type_embed).squeeze() -
                  nn_forward_func(start_point[0], attention_mask, position_embed, type_embed).squeeze())

    # delta expressed as a percentage of F_diff.
    # NOTE(review): an F_diff of exactly zero yields inf/nan here and would then
    # poison the running `delta_pcs` sum — decide how such inputs should be treated.
    delta_pc = delta / F_diff * 100

    log_odds += log_odd
    comps += comp
    suffs += suff
    deltas += torch.abs(delta).item()
    delta_pcs += torch.abs(delta_pc).item()
    count += 1

    # print the metrics
    if count % print_step == 0:
        _print_running_metrics(log_odds, comps, suffs, deltas, delta_pcs, count)

# Final averages; guarded so that an empty `data` does not raise ZeroDivisionError.
if count:
    _print_running_metrics(log_odds, comps, suffs, deltas, delta_pcs, count)
else:
    print('No inputs were processed; nothing to report.')


  0%|          | 0/1821 [00:00<?, ?it/s]

Starting attribution computation...


  3%|▎         | 50/1821 [35:08<28:03:41, 57.04s/it]

Log-odds:  -0.6249 Comprehensiveness:  0.1815 Sufficiency:  0.4127 Avg delta:  0.4539 Avg delta pct: 43.1137


  5%|▌         | 100/1821 [1:11:32<21:18:55, 44.59s/it]

Log-odds:  -0.6849 Comprehensiveness:  0.2062 Sufficiency:  0.3716 Avg delta:  0.4255 Avg delta pct: 33.249


  8%|▊         | 150/1821 [1:43:12<23:27:10, 50.53s/it]

Log-odds:  -0.8062 Comprehensiveness:  0.2383 Sufficiency:  0.3669 Avg delta:  0.4894 Avg delta pct: 37.728


 11%|█         | 200/1821 [2:16:34<17:48:34, 39.55s/it]

Log-odds:  -0.8277 Comprehensiveness:  0.256 Sufficiency:  0.3782 Avg delta:  0.481 Avg delta pct: 41.4828


 14%|█▎        | 250/1821 [2:53:46<17:45:56, 40.71s/it]

Log-odds:  -0.8333 Comprehensiveness:  0.2709 Sufficiency:  0.3881 Avg delta:  0.4702 Avg delta pct: 37.3392


 16%|█▋        | 300/1821 [3:30:38<15:30:27, 36.70s/it]

Log-odds:  -0.7667 Comprehensiveness:  0.2711 Sufficiency:  0.3948 Avg delta:  0.4907 Avg delta pct: 38.9284


 19%|█▉        | 350/1821 [4:04:46<17:04:09, 41.77s/it]

Log-odds:  -0.7677 Comprehensiveness:  0.2579 Sufficiency:  0.4023 Avg delta:  0.5172 Avg delta pct: 39.1773


 22%|██▏       | 400/1821 [4:40:45<15:36:50, 39.56s/it]

Log-odds:  -0.7512 Comprehensiveness:  0.2606 Sufficiency:  0.3883 Avg delta:  0.511 Avg delta pct: 37.8321


 25%|██▍       | 450/1821 [5:18:05<16:19:28, 42.87s/it]

Log-odds:  -0.7648 Comprehensiveness:  0.2608 Sufficiency:  0.3752 Avg delta:  0.519 Avg delta pct: 39.282


 27%|██▋       | 500/1821 [5:52:22<16:18:20, 44.44s/it]

Log-odds:  -0.7714 Comprehensiveness:  0.2648 Sufficiency:  0.3657 Avg delta:  0.5251 Avg delta pct: 44.2194


 30%|███       | 550/1821 [6:28:32<13:32:21, 38.35s/it]

Log-odds:  -0.793 Comprehensiveness:  0.2735 Sufficiency:  0.3695 Avg delta:  0.5291 Avg delta pct: 43.7595


 33%|███▎      | 600/1821 [7:13:09<21:46:21, 64.19s/it]

Log-odds:  -0.7966 Comprehensiveness:  0.275 Sufficiency:  0.3707 Avg delta:  0.5255 Avg delta pct: 47.9392


 36%|███▌      | 650/1821 [7:50:55<19:17:21, 59.30s/it]

Log-odds:  -0.8053 Comprehensiveness:  0.2793 Sufficiency:  0.3688 Avg delta:  0.519 Avg delta pct: 47.1363


 38%|███▊      | 700/1821 [8:25:43<14:17:35, 45.90s/it]

Log-odds:  -0.7994 Comprehensiveness:  0.2786 Sufficiency:  0.3637 Avg delta:  0.5211 Avg delta pct: 51.9619


 41%|████      | 750/1821 [9:02:00<9:46:20, 32.85s/it] 

Log-odds:  -0.8179 Comprehensiveness:  0.279 Sufficiency:  0.3668 Avg delta:  0.5326 Avg delta pct: 52.7093


 44%|████▍     | 800/1821 [9:37:39<12:15:19, 43.21s/it]

Log-odds:  -0.8179 Comprehensiveness:  0.2786 Sufficiency:  0.3685 Avg delta:  0.5266 Avg delta pct: 51.626


 47%|████▋     | 850/1821 [10:15:22<11:57:06, 44.31s/it]

Log-odds:  -0.8361 Comprehensiveness:  0.2853 Sufficiency:  0.3659 Avg delta:  0.5245 Avg delta pct: 52.7057


 49%|████▉     | 900/1821 [10:53:53<12:01:15, 46.99s/it]

Log-odds:  -0.8462 Comprehensiveness:  0.2865 Sufficiency:  0.3647 Avg delta:  0.5316 Avg delta pct: 53.4362


 52%|█████▏    | 950/1821 [11:28:13<11:17:16, 46.66s/it]

Log-odds:  -0.8231 Comprehensiveness:  0.2826 Sufficiency:  0.3655 Avg delta:  0.5275 Avg delta pct: 54.9588


 55%|█████▍    | 1000/1821 [12:00:49<10:27:25, 45.85s/it]

Log-odds:  -0.8343 Comprehensiveness:  0.2838 Sufficiency:  0.3668 Avg delta:  0.5266 Avg delta pct: 55.9614


 58%|█████▊    | 1050/1821 [12:38:29<8:03:16, 37.61s/it] 

Log-odds:  -0.8536 Comprehensiveness:  0.2833 Sufficiency:  0.3705 Avg delta:  0.5245 Avg delta pct: 55.0351


 60%|██████    | 1100/1821 [13:15:11<6:21:21, 31.74s/it] 

Log-odds:  -0.8346 Comprehensiveness:  0.2822 Sufficiency:  0.3697 Avg delta:  0.5272 Avg delta pct: 54.1869


 63%|██████▎   | 1150/1821 [13:53:15<8:58:18, 48.14s/it] 

Log-odds:  -0.8284 Comprehensiveness:  0.2815 Sufficiency:  0.3668 Avg delta:  0.5323 Avg delta pct: 53.2724


 66%|██████▌   | 1200/1821 [14:31:32<7:32:38, 43.73s/it] 

Log-odds:  -0.8367 Comprehensiveness:  0.2842 Sufficiency:  0.3694 Avg delta:  0.5373 Avg delta pct: 56.4634


 69%|██████▊   | 1250/1821 [15:09:31<9:04:07, 57.18s/it]

Log-odds:  -0.8278 Comprehensiveness:  0.2841 Sufficiency:  0.3713 Avg delta:  0.538 Avg delta pct: 55.9078


 71%|███████▏  | 1300/1821 [15:46:26<6:35:23, 45.53s/it] 

Log-odds:  -0.8361 Comprehensiveness:  0.2883 Sufficiency:  0.3686 Avg delta:  0.5391 Avg delta pct: 55.3998


 74%|███████▍  | 1350/1821 [16:21:09<5:07:58, 39.23s/it]

Log-odds:  -0.84 Comprehensiveness:  0.2915 Sufficiency:  0.3724 Avg delta:  0.538 Avg delta pct: 54.7957


 77%|███████▋  | 1400/1821 [16:54:57<3:04:53, 26.35s/it]

Log-odds:  -0.8406 Comprehensiveness:  0.2926 Sufficiency:  0.3746 Avg delta:  0.537 Avg delta pct: 55.4344


 80%|███████▉  | 1450/1821 [17:32:23<5:19:27, 51.66s/it]

Log-odds:  -0.8437 Comprehensiveness:  0.2915 Sufficiency:  0.3742 Avg delta:  0.5363 Avg delta pct: 57.0886


 82%|████████▏ | 1500/1821 [18:12:48<4:38:15, 52.01s/it]

Log-odds:  -0.8395 Comprehensiveness:  0.2921 Sufficiency:  0.3746 Avg delta:  0.5381 Avg delta pct: 56.9176


 85%|████████▌ | 1550/1821 [18:51:46<3:44:30, 49.71s/it]

Log-odds:  -0.8355 Comprehensiveness:  0.2913 Sufficiency:  0.375 Avg delta:  0.5343 Avg delta pct: 56.1787


 88%|████████▊ | 1600/1821 [19:30:43<2:28:45, 40.39s/it]

Log-odds:  -0.8256 Comprehensiveness:  0.2894 Sufficiency:  0.3753 Avg delta:  0.5329 Avg delta pct: 56.526


 91%|█████████ | 1650/1821 [20:07:56<2:02:47, 43.09s/it]

Log-odds:  -0.82 Comprehensiveness:  0.287 Sufficiency:  0.3746 Avg delta:  0.5324 Avg delta pct: 60.5845


 93%|█████████▎| 1700/1821 [20:49:27<1:45:21, 52.24s/it]

Log-odds:  -0.8163 Comprehensiveness:  0.2864 Sufficiency:  0.3717 Avg delta:  0.5312 Avg delta pct: 60.0439


 96%|█████████▌| 1750/1821 [21:27:08<53:43, 45.41s/it]  

Log-odds:  -0.8181 Comprehensiveness:  0.2855 Sufficiency:  0.3733 Avg delta:  0.5329 Avg delta pct: 60.6395


 99%|█████████▉| 1800/1821 [22:04:02<14:55, 42.64s/it]

Log-odds:  -0.8095 Comprehensiveness:  0.2836 Sufficiency:  0.3741 Avg delta:  0.5337 Avg delta pct: 59.9291


100%|██████████| 1821/1821 [22:17:11<00:00, 44.06s/it]

Log-odds:  -0.8106 Comprehensiveness:  0.2836 Sufficiency:  0.3751 Avg delta:  0.5396 Avg delta pct: 60.667
CPU times: user 21h 55min 3s, sys: 47min 24s, total: 22h 42min 28s
Wall time: 22h 17min 11s



