In [1]:
# !pip install transformers==4.26.1
# !pip install datasets
# !pip install captum==0.3.1

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/Micropaper/

In [53]:
import os, sys, numpy as np, pickle, random
from sklearn.neighbors import kneighbors_graph

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BertModel
from tqdm import tqdm
from datasets import load_dataset
from dig import DiscretetizedIntegratedGradients
from attributions import run_dig_explanation
from metrics import eval_log_odds, eval_comprehensiveness, eval_sufficiency
import monotonic_paths
from captum.attr._utils.common import _format_input_baseline, _reshape_and_sum, _validate_input, _format_input

In [18]:
# Default hidden-state layer index.
# NOTE(review): not read by any visible cell — the attribution loop iterates over
# layers itself and every function takes `layer` as an explicit parameter.
layer = 11

In [19]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch) for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Trailing ';' suppresses the returned torch.Generator repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7ff5d2c5b110>

In [20]:
# Fine-tuned SST-2 BERT checkpoint, referenced once so the three loads cannot drift apart.
MODEL_NAME = 'textattack/bert-base-uncased-SST-2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, output_hidden_states=True)
# Headless BertModel copy of the same checkpoint (its classifier weights are dropped on load).
model1 = BertModel.from_pretrained(MODEL_NAME, output_hidden_states=True)

Some weights of the model checkpoint at textattack/bert-base-uncased-SST-2 were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [21]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# Keep `model1` on the same device as `model` so its modules can be applied to
# `model`'s outputs; leaving it on the CPU fails as soon as `device` is CUDA.
model.to(device)
model1.to(device)
# Inference only: disable dropout in both models.
model.eval()
model1.eval()
model.zero_grad()


cpu


In [22]:

def construct_input_ref_pair(tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device):
	"""Tokenize `text` and build the (input, reference) id pair for attribution.

	Both sequences are wrapped as [CLS] ... [SEP]; in the reference every content
	token is replaced by `ref_token_id`. Each is returned as a (1, seq_len) tensor
	on `device`. The content is truncated to `tokenizer.max_len_single_sentence`.
	"""
	body_ids = tokenizer.encode(
		text,
		add_special_tokens=False,
		truncation=True,
		max_length=tokenizer.max_len_single_sentence,
	)
	real_seq = [cls_token_id, *body_ids, sep_token_id]
	baseline_seq = [cls_token_id, *([ref_token_id] * len(body_ids)), sep_token_id]
	return (
		torch.tensor([real_seq], device=device),
		torch.tensor([baseline_seq], device=device),
	)

def construct_input_ref_pos_id_pair(input_ids, device):
	"""Return (position_ids, ref_position_ids), each a (1, seq_len) long tensor.

	seq_len is taken from dimension 1 of `input_ids`. Real positions count
	0..seq_len-1; the reference uses position 0 everywhere.
	"""
	n_tokens = input_ids.size(1)
	positions = torch.arange(n_tokens, dtype=torch.long, device=device)[None, :]
	ref_positions = torch.zeros_like(positions)
	return positions, ref_positions

def construct_input_ref_token_type_pair(input_ids, device):
	"""Return (token_type_ids, ref_token_type_ids), both all-zero (1, seq_len) long tensors.

	Every token is assigned segment 0 in both the input and the reference.
	"""
	shape = (1, input_ids.size(1))
	segment_ids = torch.zeros(shape, dtype=torch.long, device=device)
	return segment_ids, torch.zeros_like(segment_ids)

def construct_attention_mask(input_ids):
	"""All-ones mask with the shape, dtype and device of `input_ids` (every token attended)."""
	return torch.full_like(input_ids, 1)

def get_word_embeddings():
	"""Return the word-embedding weight matrix of the global `model`'s BERT encoder."""
	embedding_module = model.bert.embeddings
	return embedding_module.word_embeddings.weight

def construct_word_embedding(model, input_ids):
	"""Look up word (token) embeddings for `input_ids` in `model`'s BERT embedding layer."""
	lookup = model.bert.embeddings.word_embeddings
	return lookup(input_ids)

def construct_position_embedding(model, position_ids):
	"""Look up position embeddings for `position_ids` in `model`'s BERT embedding layer."""
	lookup = model.bert.embeddings.position_embeddings
	return lookup(position_ids)

def construct_type_embedding(model, type_ids):
	"""Look up token-type (segment) embeddings for `type_ids` in `model`'s BERT embedding layer."""
	lookup = model.bert.embeddings.token_type_embeddings
	return lookup(type_ids)

def construct_sub_embedding(model, input_ids, ref_input_ids, position_ids, ref_position_ids, type_ids, ref_type_ids):
	"""Embed each (input, reference) id pair with the matching embedding helper.

	Returns three 2-tuples in this order: word embeddings, position embeddings and
	token-type embeddings, each laid out as (input, reference).
	"""
	word_pair = (
		construct_word_embedding(model, input_ids),
		construct_word_embedding(model, ref_input_ids),
	)
	position_pair = (
		construct_position_embedding(model, position_ids),
		construct_position_embedding(model, ref_position_ids),
	)
	type_pair = (
		construct_type_embedding(model, type_ids),
		construct_type_embedding(model, ref_type_ids),
	)
	return word_pair, position_pair, type_pair

def get_base_token_emb(device):
	"""Embed the tokenizer's pad token with the global `model`; used as the metrics' replacement token."""
	pad_ids = torch.tensor([tokenizer.pad_token_id], device=device)
	return construct_word_embedding(model, pad_ids)

def get_tokens(text_ids):
	"""Convert a tensor of token ids (e.g. shape (1, seq_len)) to token strings.

	The ids are flattened to a plain list of ints first. The previous
	`.squeeze()` turned a single-token input into a 0-d tensor, which
	`convert_ids_to_tokens` cannot iterate over.
	"""
	return tokenizer.convert_ids_to_tokens(text_ids.reshape(-1).tolist())

def get_inputs(text, device):
	"""Build every tensor the attribution pipeline needs for one sentence.

	Uses the global `tokenizer` and `model`. The reference sequence replaces each
	content token with the tokenizer's mask token. Returns a 9-element list:
	[input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed,
	 ref_position_embed, type_embed, ref_type_embed, attention_mask].
	"""
	input_ids, ref_input_ids = construct_input_ref_pair(
		tokenizer,
		text,
		tokenizer.mask_token_id,  # reference (baseline) token
		tokenizer.sep_token_id,
		tokenizer.cls_token_id,
		device,
	)
	position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids, device)
	type_ids, ref_type_ids = construct_input_ref_token_type_pair(input_ids, device)
	attention_mask = construct_attention_mask(input_ids)

	word_pair, position_pair, type_pair = construct_sub_embedding(
		model, input_ids, ref_input_ids, position_ids, ref_position_ids, type_ids, ref_type_ids
	)
	return [input_ids, ref_input_ids, *word_pair, *position_pair, *type_pair, attention_mask]

In [23]:

def predict(model, inputs_embeds, attention_mask=None):
    """Run `model.bert` on precomputed input embeddings and return its raw output object."""
    encoder = model.bert
    return encoder(attention_mask=attention_mask, inputs_embeds=inputs_embeds)

def nn_forward_func(input_embed, layer, attention_mask=None, position_embed=None, type_embed=None, return_all_logits=False):
    """Classifier logits computed from the hidden state of an intermediate layer.

    Sums the word/position/type embeddings, applies the embedding LayerNorm and
    dropout of the global `model`, runs its BERT encoder, then passes
    `hidden_states[layer]` through the BERT pooler, dropout and classifier head.

    Args:
        input_embed: word embeddings; added element-wise to `position_embed`
            and `type_embed` (presumably (batch, seq_len, hidden) — TODO confirm).
        layer: index into the encoder's `hidden_states` tuple.
        attention_mask: forwarded to the encoder.
        position_embed, type_embed: position / token-type embeddings. They are
            required despite the None defaults, since they are summed directly.
        return_all_logits: return every logit instead of the max over dim 1.

    Returns:
        The logits, or `logits.max(1).values` when `return_all_logits` is False.

    NOTE(review): the summed and normalised embeddings are passed as
    `inputs_embeds`. Hugging Face BertEmbeddings normally adds position and
    token-type embeddings (and LayerNorm) to `inputs_embeds` again. Confirm this
    is intended, or that the installed transformers is patched.
    """
    embeds = input_embed + position_embed + type_embed
    embeds = model.bert.embeddings.dropout(model.bert.embeddings.LayerNorm(embeds))
    encoder_out = predict(model, embeds, attention_mask=attention_mask)
    # Use the pooler that lives inside `model`. It has the same checkpoint weights
    # as the separately loaded `model1.pooler`, but it always shares `model`'s
    # device, so it cannot hit a CPU/GPU mismatch.
    pooled = model.bert.pooler(encoder_out.hidden_states[layer])
    logits = model.classifier(model.dropout(pooled))
    if return_all_logits:
        return logits
    return logits.max(1).values

# for row in tqdm([[sentence]]):
#     inp = get_inputs(row[0], device)
#     input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask = inp
#     f = nn_forward_func(input_embed[0],12, attention_mask,position_embed,type_embed)


In [24]:
# model.classifier(model.dropout(model1.pooler(f.hidden_states[1])))

In [25]:
# f

In [26]:
# Vocabulary embedding matrix plus a 500-nearest-neighbour distance graph over it.
# Together with the token->id vocab, these are the auxiliary data used for path
# construction. NOTE(review): expensive over the full vocabulary — consider caching.
word_features = get_word_embeddings().detach().cpu().numpy()
word_idx_map = tokenizer.vocab
A = kneighbors_graph(word_features, n_neighbors=500, mode='distance', n_jobs=-1)
auxiliary_data = [word_idx_map, word_features, A]

In [27]:
# Discretized Integrated Gradients attribution object wrapping the layer-aware
# forward function `nn_forward_func`.
attr_func = DiscretetizedIntegratedGradients(nn_forward_func)

In [28]:
# Sanity check (replaces a leftover bare inspection expression).
assert callable(attr_func.attribute), "attr_func must expose an attribute() method"

<bound method DiscretetizedIntegratedGradients.attribute of <dig.DiscretetizedIntegratedGradients object at 0x7ff5d2881d90>>

In [29]:
# GLUE SST-2 test split as (sentence, label, idx) triples.
# NOTE(review): GLUE test labels are typically hidden (-1) — confirm before using `label`.
dataset = load_dataset('glue', 'sst2')['test']
data = list(zip(*(dataset[column] for column in ('sentence', 'label', 'idx'))))

Reusing dataset glue (/Users/ayankundu/.cache/huggingface/datasets/glue/sst2/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)


In [30]:
all_outputs = list()  # NOTE(review): not referenced by any visible cell


def calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens, layer, n_steps=63, topk=20):
    """Compute DIG attributions for one example at a given layer and score them.

    Args:
        inputs: 10-element sequence [scaled_features, input_ids, ref_input_ids,
            input_embed, ref_input_embed, position_embed, ref_position_embed,
            type_embed, ref_type_embed, attention_mask]. None entries are allowed.
        device: every non-None tensor in `inputs` is moved here first.
        attr_func: attribution object handed to `run_dig_explanation`.
        base_token_emb: replacement embedding passed to the three metrics.
        nn_forward_func: forward function the metrics evaluate.
        get_tokens: unused; kept so existing calls keep working.
        layer: hidden-state index forwarded to the attribution and the metrics.
        n_steps: step count passed to `run_dig_explanation`. The default 63
            equals 2**1 * (30 + 1) + 1. NOTE(review): this presumably must match
            the path produced by `scale_inputs(steps=30, factor=1)` — confirm
            before changing either.
        topk: forwarded as `topk` to each metric.

    Returns:
        (log_odd, comp, suff, attr, delta), where attr and delta come from
        `run_dig_explanation`.
    """
    # Move inputs to the main device; None placeholders pass through.
    inp = [x.to(device) if x is not None else None for x in inputs]
    (scaled_features, input_ids, ref_input_ids, input_embed, ref_input_embed,
     position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask) = inp

    # Attributions along the precomputed path.
    attr, deltaa = run_dig_explanation(attr_func, layer, scaled_features, position_embed, type_embed, attention_mask, n_steps)

    # Metrics share every positional argument.
    metric_args = (nn_forward_func, layer, input_embed, position_embed, type_embed, attention_mask, base_token_emb, attr)
    log_odd, _ = eval_log_odds(*metric_args, topk=topk)
    comp = eval_comprehensiveness(*metric_args, topk=topk)
    suff = eval_sufficiency(*metric_args, topk=topk)

    return log_odd, comp, suff, attr, deltaa

In [58]:
%%time

sentence = 'first good, then bothersome'

# Pad-token embedding used by the metrics as the replacement token.
base_token_emb = get_base_token_emb(device)

# Hidden-state indices to explain: 1..num_hidden_layers (1..12 for bert-base).
layer_ids = list(range(1, model.config.num_hidden_layers + 1))

# Running sums kept per layer, so the printed averages stay per-layer even
# when more than one row is processed.
metric_names = ('log_odds', 'comps', 'suffs', 'deltas', 'delta_pcs', 'n_delta_pcs')
layer_totals = {j: dict.fromkeys(metric_names, 0) for j in layer_ids}
layer_counts = dict.fromkeys(layer_ids, 0)
delta_pcs_list = []
inputs = []

# compute the DIG attributions for all the inputs
print('Starting attribution computation...')
for row in tqdm([[sentence]]):
    # Inputs and the interpolation path depend only on the sentence, not on the
    # layer, so build them once per row rather than once per layer.
    inp = get_inputs(row[0], device)
    input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask = inp
    scaled_features = monotonic_paths.scale_inputs(
        input_ids.squeeze().tolist(), ref_input_ids.squeeze().tolist(),
        device, auxiliary_data, method="UIG", steps=30, nbrs=50, factor=1, strategy='greedy')
    inputs = [scaled_features, *inp]

    # Path endpoints: first row = baseline, last row = input
    # (only works for one input, i.e. len(tuple) == 1).
    scaled_features_tpl = _format_input(scaled_features)
    start_point = _format_input(scaled_features_tpl[0][0].unsqueeze(0))
    end_point = _format_input(scaled_features_tpl[0][-1].unsqueeze(0))

    for j in layer_ids:
        print(f"*************** Layer number : {j} ****************")
        print()

        # Forward-only check: no autograd graph needed.
        with torch.no_grad():
            prediction = nn_forward_func(input_embed, j, attention_mask, position_embed, type_embed, return_all_logits=True).squeeze()
        print(f"Prediction: {prediction}")
        print()

        log_odd, comp, suff, attrib, delta = calculate_attributions(inputs, device, attr_func, base_token_emb, nn_forward_func, get_tokens, j)
        print(f"Attributions scores: {attrib}")
        print()

        # Express delta as a percentage of F(input) - F(baseline) along the path.
        with torch.no_grad():
            F_diff = (nn_forward_func(end_point[0], j, attention_mask, position_embed, type_embed).squeeze()
                      - nn_forward_func(start_point[0], j, attention_mask, position_embed, type_embed).squeeze())

        totals = layer_totals[j]
        layer_counts[j] += 1
        totals['log_odds'] += log_odd
        totals['comps'] += comp
        totals['suffs'] += suff
        totals['deltas'] += torch.abs(delta).item()
        if F_diff.ne(0).all():
            delta_pc = torch.abs(delta / F_diff * 100).item()
            totals['delta_pcs'] += delta_pc
            totals['n_delta_pcs'] += 1
        else:
            # F(input) == F(baseline): the percentage is undefined; keep it out of the average.
            delta_pc = float('nan')
        delta_pcs_list.append(delta_pc)

        # print the metrics
        count = layer_counts[j]
        avg_delta_pc = totals['delta_pcs'] / totals['n_delta_pcs'] if totals['n_delta_pcs'] else float('nan')
        print('Log-odds: ', np.round(totals['log_odds'] / count, 4), 'Comprehensiveness: ', np.round(totals['comps'] / count, 4),
              'Sufficiency: ', np.round(totals['suffs'] / count, 4), 'Avg delta: ', np.round(totals['deltas'] / count, 4),
              'Avg delta pct:', np.round(avg_delta_pc, 4))


  0%|          | 0/1 [00:00<?, ?it/s]

Starting attribution computation...
*************** Layer number : 1 ****************

Prediction: tensor([-0.3420,  0.2318], grad_fn=<SqueezeBackward0>)

Attributions scores: tensor([ 0.0000, -0.4475, -0.5731, -0.4156, -0.5305, -0.0909,  0.0946,  0.0000],
       grad_fn=<DivBackward0>)

Log-odds:  0.0033 Comprehensiveness:  -0.0001 Sufficiency:  0.096 Avg delta:  0.004 Avg delta pct: 7.9942
*************** Layer number : 2 ****************

Prediction: tensor([-0.2641,  0.1929], grad_fn=<SqueezeBackward0>)

Attributions scores: tensor([ 0.0000, -0.1137, -0.3334,  0.6424,  0.2180, -0.6411, -0.0690,  0.0000],
       grad_fn=<DivBackward0>)

Log-odds:  0.0117 Comprehensiveness:  -0.0001 Sufficiency:  0.0032 Avg delta:  0.0013 Avg delta pct: 11.1332
*************** Layer number : 3 ****************

Prediction: tensor([-0.1913,  0.2224], grad_fn=<SqueezeBackward0>)

Attributions scores: tensor([ 0.0000,  0.2661,  0.1528,  0.8164,  0.4067, -0.0886, -0.2570,  0.0000],
       grad_fn=<DivBac

100%|██████████| 1/1 [01:47<00:00, 107.76s/it]

Attributions scores: tensor([ 0.0000,  0.1018, -0.2101, -0.1143,  0.3223,  0.9101, -0.0155,  0.0000],
       grad_fn=<DivBackward0>)

Log-odds:  -3.4912 Comprehensiveness:  0.9707 Sufficiency:  0.5965 Avg delta:  0.58 Avg delta pct: 23.0217
CPU times: user 1min 45s, sys: 2.69 s, total: 1min 47s
Wall time: 1min 47s



