In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import sys
sys.path.append('/content/drive/My Drive/{}'.format("cogs402longformer/"))

In [3]:
pip install transformers --quiet

In [4]:
pip install captum --quiet

In [5]:
pip install datasets --quiet

In [6]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import os
print(os.getcwd())
path_parent = os.path.dirname(os.getcwd())
os.chdir(path_parent)
print(os.getcwd())

In [7]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

import torch

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
from transformers import LongformerForSequenceClassification, LongformerTokenizer, LongformerConfig
# replace <PATH-TO-SAVED-MODEL> with the real path of the saved model
model_path = '/content/drive/MyDrive/cogs402longformer/models/longformer-finetuned_papers/checkpoint-2356'

# load model
model = LongformerForSequenceClassification.from_pretrained(model_path, num_labels = 2)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

In [10]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

In [11]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, truncation = True, add_special_tokens=False, max_length = 2048)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids

def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [12]:
from datasets import load_dataset
cogs402_ds = load_dataset("danielhou13/cogs402dataset")["test"]

Using custom data configuration danielhou13--cogs402dataset-cc784554b797f843
Reusing dataset parquet (/root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402dataset-cc784554b797f843/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


  0%|          | 0/2 [00:00<?, ?it/s]

In [13]:
testval = 923
text = cogs402_ds['text'][testval]
label = cogs402_ds['labels'][testval]
print(label)

0


In [14]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [15]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [16]:
print(all_tokens)

['<s>', 'lp', 'opt', ':', 'ĠA', 'ĠRule', 'ĠOptim', 'ization', 'ĠTool', 'Ġfor', 'ĠAnswer', 'ĠSet', 'ĠProgramming', 'Ġ', 'Ġar', 'X', 'iv', ':', '16', '08', '.', '05', '675', 'v', '2', 'Ġ[', 'cs', '.', 'LO', ']', 'Ġ23', 'ĠAug', 'Ġ2016', 'Ġ', 'ĠManuel', 'ĠB', 'ich', 'ler', ',', 'ĠMichael', 'ĠMor', 'ak', ',', 'Ġand', 'ĠStefan', 'ĠWol', 'tr', 'an', 'ĠT', 'U', 'ĠW', 'ien', ',', 'ĠVienna', ',', 'ĠAustria', 'Ġ{', 's', 'urn', 'ame', '}', '@', 'db', 'ai', '.', 'tu', 'w', 'ien', '.', 'ac', '.', 'at', 'Ġ', 'ĠAbstract', '.', 'ĠState', '-', 'of', '-', 'the', '-', 'art', 'Ġanswer', 'Ġset', 'Ġprogramming', 'Ġ(', 'AS', 'P', ')', 'Ġsol', 'vers', 'Ġrely', 'Ġon', 'Ġa', 'Ġprogram', 'Ġcalled', 'Ġa', 'Ġgrou', 'nder', 'Ġto', 'Ġconvert', 'Ġnon', '-', 'ground', 'Ġprograms', 'Ġcontaining', 'Ġvariables', 'Ġinto', 'Ġvariable', '-', 'free', ',', 'Ġpropos', 'itional', 'Ġprograms', '.', 'ĠThe', 'Ġsize', 'Ġof', 'Ġthis', 'Ġgrounding', 'Ġdepends', 'Ġheavily', 'Ġon', 'Ġthe', 'Ġsize', 'Ġof', 'Ġthe', 'Ġnon', '-', 'ground', 

In [17]:
#set 1 if we are dealing with a positive class, and 0 if dealing with negative class
def custom_forward2(inputs_emb, global_attention_mask) :
    preds = model(inputs_embeds=inputs_emb, global_attention_mask=global_attention_mask)
    return torch.softmax(preds.logits, dim = 1)[:, 0] # for negative attribution, 
    #return torch.softmax(preds, dim = 1)[:, 1] #<- for positive attribution

In [18]:
def construct_whole_longformer_embeddings(input_ids, ref_input_ids, \
                                          token_type_ids=None, ref_token_type_ids=None, \
                                          position_ids=None, ref_position_ids=None):
    input_embeddings = model.longformer.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.longformer.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    print(input_embeddings)
    return input_embeddings, ref_input_embeddings

In [19]:
layer_attrs = []

# The token that we would like to examine separately.
token_to_explain = 334 # the index of the token that we would like to examine more thoroughly
layer_attrs_dist = []

input_embeddings, ref_input_embeddings = construct_whole_longformer_embeddings(input_ids, ref_input_ids, \
                                         position_ids=position_ids, ref_position_ids=ref_position_ids)

print(input_embeddings.shape)
globalattention= torch.zeros_like(input_ids)
globalattention[:, 0] = 1
print(globalattention)

tensor([[[ 0.2746, -0.0095,  0.3482,  ..., -0.0345, -0.0225, -0.1789],
         [-0.4242, -0.0695,  0.5289,  ..., -0.4890, -0.5733, -0.4481],
         [-0.0438,  0.2288,  0.6584,  ...,  0.3309, -0.0427,  0.2786],
         ...,
         [-0.1743, -0.3578,  0.0495,  ..., -0.1654,  0.2040, -0.7195],
         [-0.1707, -0.1062,  0.1560,  ...,  0.2212,  0.1290,  0.1111],
         [ 0.0133,  0.4446, -0.2146,  ...,  0.1282,  0.2887, -0.0445]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 2050, 768])
tensor([[1, 0, 0,  ..., 0, 0, 0]], device='cuda:0')


In [20]:
for i in range(model.config.num_hidden_layers):
    lc = LayerConductance(custom_forward2, model.longformer.encoder.layer[i])
    layer_attributions = lc.attribute(inputs=input_embeddings, 
                                      baselines=ref_input_embeddings, 
                                      additional_forward_args=globalattention, 
                                      n_steps=5, 
                                      internal_batch_size=2)
    layer_attrs.append(summarize_attributions(layer_attributions).cpu().detach().tolist())

    # storing attributions of the token id that we would like to examine in more detail in token_to_explain
    layer_attrs_dist.append(layer_attributions[0,token_to_explain,:].cpu().detach().tolist())


RuntimeError: ignored

In [None]:
import numpy as np

fig, ax = plt.subplots(figsize=(15,5))
xticklabels=all_tokens
yticklabels=list(range(1,13))
ax = sns.heatmap(np.array(layer_attrs), xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2)
plt.xlabel('Tokens')
plt.ylabel('Layers')
plt.show()