<a href="https://colab.research.google.com/github/danielhou13/cogs402longformer/blob/main/src/CaptumLongformerSequenceClassificationaggregate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook computes the aggregate attributions for both the positive and negative classes over the entire dataset.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Import dependencies

In [None]:
import sys
sys.path.append('/content/drive/My Drive/{}'.format("cogs402longformer/"))

In [None]:
pip install transformers --quiet


In [None]:
pip install captum --quiet

[?25l[K     |▎                               | 10 kB 36.6 MB/s eta 0:00:01[K     |▌                               | 20 kB 7.5 MB/s eta 0:00:01[K     |▊                               | 30 kB 6.7 MB/s eta 0:00:01[K     |█                               | 40 kB 3.6 MB/s eta 0:00:01[K     |█▏                              | 51 kB 3.6 MB/s eta 0:00:01[K     |█▍                              | 61 kB 4.2 MB/s eta 0:00:01[K     |█▋                              | 71 kB 4.4 MB/s eta 0:00:01[K     |█▉                              | 81 kB 4.6 MB/s eta 0:00:01[K     |██                              | 92 kB 5.1 MB/s eta 0:00:01[K     |██▎                             | 102 kB 4.3 MB/s eta 0:00:01[K     |██▌                             | 112 kB 4.3 MB/s eta 0:00:01[K     |██▊                             | 122 kB 4.3 MB/s eta 0:00:01[K     |███                             | 133 kB 4.3 MB/s eta 0:00:01[K     |███▏                            | 143 kB 4.3 MB/s eta 0:00:01[K    

In [None]:
pip install datasets --quiet


In [None]:
import os
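# Synchronous CUDA kernel launches report GPU errors at the call that caused them (useful for debugging, but slower)
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"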

In [None]:
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

import torch
import pandas as pd

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Import model

In [None]:
from transformers import LongformerForSequenceClassification, LongformerTokenizer, LongformerConfig
# path or Hugging Face Hub ID of the fine-tuned model; swap in the commented-out ID below for the model fine-tuned on the news dataset
model_path = 'danielhou13/longformer-finetuned_papers_v2'
#model_path = 'danielhou13/longformer-finetuned-news-cogs402'

# load model
model = LongformerForSequenceClassification.from_pretrained(model_path, num_labels = 2)
model.to(device)
model.eval()
model.zero_grad()

# load tokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")


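Create helper functions: a prediction wrapper that returns the model's logits, and functions that build the input IDs, reference (baseline) IDs, position IDs, and attention mask for the text we want to examine.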

In [None]:
def predict(inputs, position_ids=None, attention_mask=None):
    output = model(inputs,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    return output.logits

In [None]:
ref_token_id = tokenizer.pad_token_id # <pad>: replaces every text token in the reference (baseline) sequence
sep_token_id = tokenizer.sep_token_id # </s>: appended to the end of the text
cls_token_id = tokenizer.cls_token_id # <s>: prepended to the text
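To make the roles of these three tokens concrete, here is a minimal pure-Python sketch of the input/baseline layout that `construct_input_ref_pair` produces. It assumes the RoBERTa-style IDs of the Longformer tokenizer (`<s>` = 0, `<pad>` = 1, `</s>` = 2), uses made-up IDs for the text tokens, and `build_pair` is a hypothetical helper rather than part of the notebook.

```python
# Assumed RoBERTa-style special-token IDs (as used by the Longformer tokenizer).
CLS_ID, PAD_ID, SEP_ID = 0, 1, 2

def build_pair(text_ids, max_length):
    """Hypothetical helper: return (input_ids, baseline_ids) for one document.

    The text is truncated to max_length tokens and wrapped in <s> ... </s>.
    The baseline keeps <s> and </s> but replaces every text token with <pad>,
    so Integrated Gradients measures the effect of the text tokens only.
    """
    text_ids = list(text_ids)[:max_length]
    input_ids = [CLS_ID] + text_ids + [SEP_ID]
    baseline_ids = [CLS_ID] + [PAD_ID] * len(text_ids) + [SEP_ID]
    return input_ids, baseline_ids

# Made-up IDs standing in for tokenizer.encode(text, add_special_tokens=False).
inputs, baseline = build_pair([713, 16, 10, 1296], max_length=3)
print(inputs)    # [0, 713, 16, 10, 2]
print(baseline)  # [0, 1, 1, 1, 2]
```

With the notebook's `max_length = 2046`, a wrapped sequence is at most 2048 tokens long.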

In [None]:
max_length = 2046  # max number of text tokens; adding <s> and </s> gives at most 2048 tokens
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, truncation = True, add_special_tokens=False, max_length = max_length)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_pos_id_pair(input_ids):
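    seq_length = input_ids.size(1)
    # Longformer (like RoBERTa) numbers positions from padding_idx + 1, i.e. 2 with the default pad id of 1;
    # this matches what the model computes itself when position_ids is None. Starting at 0 would shift every position by two.
    start = model.config.pad_token_id + 1
    position_ids = torch.arange(start, start + seq_length, dtype=torch.long, device=device)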
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)
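Longformer reuses RoBERTa's position-embedding convention: position IDs count up from `padding_idx + 1` rather than from 0 (with `padding_idx` equal to the pad token ID, 1), and padding tokens keep `padding_idx`. The pure-Python sketch below mirrors that rule, modeled on the behaviour of Hugging Face's `create_position_ids_from_input_ids`; it is an illustration of the convention, not the library function. When `position_ids` are passed explicitly, as `predict` does, they should follow the same convention the model saw during fine-tuning.

```python
def roberta_style_position_ids(input_ids, padding_idx=1):
    """Sketch of RoBERTa/Longformer position IDs (assumed convention, not the library code).

    Non-pad tokens get padding_idx + 1, padding_idx + 2, ... in order;
    pad tokens get padding_idx itself.
    """
    position_ids = []
    count = 0
    for token_id in input_ids:
        if token_id == padding_idx:
            position_ids.append(padding_idx)
        else:
            count += 1
            position_ids.append(padding_idx + count)
    return position_ids

print(roberta_style_position_ids([0, 713, 16, 10, 2]))  # [2, 3, 4, 5, 6]
print(roberta_style_position_ids([0, 713, 2, 1, 1]))    # [2, 3, 4, 1, 1]
```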

Import the papers dataset (test split) and take one example from it for testing purposes.

In [None]:
from datasets import load_dataset
import numpy as np
cogs402_ds = load_dataset("danielhou13/cogs402dataset")["test"]


In [None]:
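# Alternative: the hyperpartisan news dataset (5000 random validation indices, drawn with replacement)
# cogs402_ds2 = load_dataset('hyperpartisan_news_detection', 'bypublisher')['validation']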
# val_size = 5000
# val_indices = np.random.randint(0, len(cogs402_ds2), val_size)
# val_ds = cogs402_ds2.select(val_indices)
# labels2 = map(int, val_ds['hyperpartisan'])
# labels2 = list(labels2)
# val_ds = val_ds.add_column("labels", labels2)

In [None]:
testval = 923  # index of the example to inspect
text = cogs402_ds[testval]['text']
label = cogs402_ds[testval]['labels']
print(label)

In [None]:
# Return class probabilities; the `target` passed to `lig.attribute` selects which class is explained (1 = positive, 0 = negative)
def custom_forward(inputs, position_ids=None, attention_mask=None):
    preds = predict(inputs,
                   position_ids=position_ids,
                   attention_mask=attention_mask
                   )
    return torch.softmax(preds, dim = 1)
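Because `custom_forward` returns softmax probabilities rather than raw logits, the attributions explain the probability of whichever class `target` selects. A minimal pure-Python sketch of that selection, with made-up logits for one document:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits: [negative class, positive class].
probs = softmax([0.3, 1.7])
target = 1  # 1 = positive, 0 = negative, as with lig.attribute(target=label)
print(round(sum(probs), 6))  # 1.0
print(probs[target] > 0.5)   # True
```

One consequence worth keeping in mind: with two classes, `p0 = 1 - p1`, so the gradient of the class-0 probability is exactly the negative of the class-1 gradient, and Integrated Gradients attributions for `target=0` are the negatives of those for `target=1` on the same document.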

Perform Layer Integrated Gradients on the Longformer's embedding layer (`model.longformer.embeddings`). This can easily be adjusted to attribute separately to the word embeddings and the position embeddings, as in the commented-out `lig2`. Token type embeddings are not considered, since the Longformer uses only a single token type.
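For intuition about what `LayerIntegratedGradients` computes, here is a minimal Integrated Gradients sketch on a toy two-input function with a hand-written gradient. In the notebook the "input" is the embedding layer's output and the function is the target-class probability; here `f` is purely illustrative. The sketch uses a midpoint Riemann sum over `n_steps` points on the straight line from the baseline to the input, whereas Captum's default integration method is Gauss-Legendre, so this shows the idea rather than reproducing Captum. The final check is the completeness property: attributions sum to approximately `f(input) - f(baseline)`.

```python
def f(x):
    """Toy model output: f(x0, x1) = x0**2 + 3*x0*x1."""
    return x[0] ** 2 + 3 * x[0] * x[1]

def grad_f(x):
    """Analytic gradient of the toy function."""
    return [2 * x[0] + 3 * x[1], 3 * x[0]]

def integrated_gradients(x, baseline, grad_fn, n_steps=20):
    """IG_i = (x_i - b_i) * average of dF/dx_i along the path b + alpha * (x - b)."""
    totals = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint rule
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            totals[i] += g / n_steps
    return [(xi - b) * t for xi, b, t in zip(x, baseline, totals)]

x, baseline = [1.0, 2.0], [0.0, 0.0]
attr = integrated_gradients(x, baseline, grad_f, n_steps=20)
print([round(a, 6) for a in attr])              # [4.0, 3.0]
print(round(sum(attr), 6), f(x) - f(baseline))  # 7.0 7.0
```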

In [None]:
lig = LayerIntegratedGradients(custom_forward, model.longformer.embeddings)
# lig2 = LayerIntegratedGradients(custom_forward, \
#                                 [model.longformer.embeddings.word_embeddings, \
#                                  model.longformer.embeddings.position_embeddings])

In [None]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [None]:
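def summarize_attributions(attributions):
    # sum over the embedding dimension to get one score per token position
    attributions = attributions.sum(dim=-1).squeeze(0)
    # L2-normalize across positions so every document's scores have unit norm
    attributions = attributions / torch.linalg.norm(attributions)
    return attributions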
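`summarize_attributions` collapses the attribution tensor (one hidden-size vector per position) to one score per token by summing over the hidden dimension, then divides by the L2 norm across positions so that each document's scores have unit norm and long documents do not dominate simply by magnitude. A pure-Python sketch of the same two steps on a made-up example with 3 tokens and 2 hidden dimensions:

```python
import math

def summarize(token_vectors):
    """Sum each token's attribution vector, then L2-normalize across tokens."""
    per_token = [sum(vec) for vec in token_vectors]
    norm = math.sqrt(sum(s * s for s in per_token))
    return [s / norm for s in per_token]

# Made-up attributions: 3 tokens x 2 hidden dimensions.
scores = summarize([[0.5, 1.0], [-2.0, 0.0], [1.0, 1.0]])
print([round(s, 4) for s in scores])  # [0.4685, -0.6247, 0.6247]
```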

We can aggregate each token's attribution over the entire dataset to find which tokens have the highest and lowest attributions toward each class. Note that `target=label` attributes each document to its true label rather than to the model's prediction, and documents are grouped by that label.
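Inside the loop, each document's scores are grouped by token string and summed, so a token that appears several times in one document contributes the sum of its occurrences. The notebook does this with `DataFrame.groupby(...).aggregate({'attribution': 'sum'})`; the stdlib sketch below mirrors that per-document step with made-up tokens and scores (`Ġ` is the byte-level BPE marker for a leading space, so `Ġthe` and `the` would be separate keys).

```python
from collections import defaultdict

def sum_by_token(tokens, scores):
    """Sum the attribution scores of identical token strings within one document."""
    totals = defaultdict(float)
    for token, score in zip(tokens, scores):
        totals[token] += score
    return dict(totals)

doc = sum_by_token(["<s>", "Ġthe", "Ġmodel", "Ġthe", "</s>"],
                   [0.10, 0.20, 0.50, -0.05, 0.00])
print(doc)
```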

In [None]:
from tqdm import tqdm
aggregate_attrib_zero = []
aggregate_attrib_ones = []
aggregation_function = {'attribution': 'sum'}

for i in tqdm(range(len(cogs402_ds))):
  text = cogs402_ds[i]['text']
  label = cogs402_ds[i]['labels']
  input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
  position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
  attention_mask = construct_attention_mask(input_ids)

  indices = input_ids[0].detach().tolist()
  all_tokens = tokenizer.convert_ids_to_tokens(indices)

  attributions = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    additional_forward_args=(position_ids, attention_mask),
                                    target=label,
                                    n_steps=20,
                                    internal_batch_size = 2)
  
  attributions_sum = summarize_attributions(attributions)
  
  d = {"tokens":all_tokens, "attribution":attributions_sum[:len(all_tokens)].cpu()}  
  df_attrib = pd.DataFrame(d)
  df_attrib = df_attrib.groupby(df_attrib['tokens']).aggregate(aggregation_function)

  if label == 0:
    aggregate_attrib_zero.append(df_attrib)
  else:
    aggregate_attrib_ones.append(df_attrib)


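The commented-out cell below is the multi-embedding version, which uses `lig2` (also commented out above) to attribute separately to the word embeddings and the position embeddings.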

In [None]:
# from tqdm import tqdm
# aggregate_attrib_zero = []
# aggregate_attrib_ones = []
# aggregate_pos_zero = []
# aggregate_pos_ones = []

# aggregation_function = {'attribution': 'sum'}

# for i in tqdm(range(len(cogs402_ds)), position = 0, leave = True):
#   text = cogs402_ds[i]['text']
#   label = cogs402_ds[i]['labels']
#   input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
#   position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
#   attention_mask = construct_attention_mask(input_ids)

#   indices = input_ids[0].detach().tolist()
#   all_tokens = tokenizer.convert_ids_to_tokens(indices)

#   attributions2 = lig2.attribute(inputs=(input_ids, position_ids),
#                                baselines=(ref_input_ids, ref_position_ids),
#                                target=label,
#                                additional_forward_args=(attention_mask),
#                                n_steps=15,
#                                internal_batch_size = 2)
#   attributions_word = summarize_attributions(attributions2[0])
#   attributions_position = summarize_attributions(attributions2[1])

#   d = {"tokens":all_tokens, "attribution":attributions_word[:len(all_tokens)].cpu()}  
#   d2 = {"tokens":all_tokens, "attribution":attributions_position[:len(all_tokens)].cpu()}  
  
#   df_attrib = pd.DataFrame(d)
#   df_attrib2 = pd.DataFrame(d2)

#   df_attrib = df_attrib.groupby(df_attrib['tokens']).aggregate(aggregation_function)
#   df_attrib2 = df_attrib2.groupby(df_attrib2['tokens']).aggregate(aggregation_function)

#   if label == 0:
#     aggregate_attrib_zero.append(df_attrib)
#     aggregate_pos_zero.append(df_attrib2)
#   else:
#     aggregate_attrib_ones.append(df_attrib)
#     aggregate_pos_ones.append(df_attrib2)

In [None]:
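def combinedataframe(listframes, aggregation_func):
  # stack the per-document tables and move the token index back into a column
  df_attrib = pd.concat(listframes)
  df_attrib = df_attrib.reset_index(level=0)
  # total each token's attribution across documents, then divide by the number of documents
  df_attrib = df_attrib.groupby(df_attrib['tokens']).aggregate(aggregation_func)
  df_attrib['attribution'] = df_attrib['attribution'].div(len(listframes))
  # sort so the tokens with the highest average attribution come first
  highest_attrib_tokens_all = df_attrib.sort_values(by=['attribution'], ascending=False)
  return highest_attrib_tokens_all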
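`combinedataframe` concatenates the per-document tables, sums each token's score across documents, and divides by the number of documents in the list. The result is a mean per document in which a document that never contains a token counts as zero for it, so tokens that occur in many documents can rank highly even when their effect in any single document is modest. A stdlib sketch of the same computation on made-up per-document dictionaries (`mean_over_documents` is a hypothetical name):

```python
from collections import defaultdict

def mean_over_documents(per_doc):
    """Sum each token's score across documents, divide by the document count,
    and return (token, score) pairs sorted from highest to lowest."""
    totals = defaultdict(float)
    for doc in per_doc:
        for token, score in doc.items():
            totals[token] += score
    n_docs = len(per_doc)
    averaged = {token: total / n_docs for token, total in totals.items()}
    return sorted(averaged.items(), key=lambda item: item[1], reverse=True)

# "Ġproof" has the largest single-document score but appears in only one document.
ranking = mean_over_documents([{"Ġthe": 0.2, "Ġproof": 0.6},
                               {"Ġthe": 0.4},
                               {"Ġthe": 0.3, "Ġlemma": -0.3}])
print(ranking)
```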

In [None]:
#longformer embeddings/word embeddings if multi-embedding
df_attrib_zero = combinedataframe(aggregate_attrib_zero, aggregation_function)
df_attrib_ones = combinedataframe(aggregate_attrib_ones, aggregation_function)

#position embeddings for multi-embedding
# df_pos_zero = combinedataframe(aggregate_pos_zero, aggregation_function)
# df_pos_ones = combinedataframe(aggregate_pos_ones, aggregation_function)

Here are the attributions for the negative class. We show only the 10 highest attributions, i.e. the tokens that contribute most toward the model predicting negative.

In [None]:
df_attrib_zero[:10]

In [None]:
# df_pos_zero[:10]

Here are the attributions for the positive class, again showing only the 10 highest: the tokens that contribute most toward the model predicting positive.

In [None]:
df_attrib_ones[:10]

In [None]:
# df_pos_ones[:10]

Note: to find the aggregate attributions irrespective of each example's class, combine the per-document lists from both classes and pass them through `combinedataframe` with the same aggregation function.
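When combining classes, keep in mind that each class table is already a mean over a different number of documents. Averaging the two class tables directly would weight both classes equally, whereas pooling the per-document lists (e.g. `aggregate_attrib_zero + aggregate_attrib_ones`) before combining weights each class by its document count. A stdlib sketch of that weighting for a single token, with made-up means and counts (`pooled_mean` is a hypothetical helper):

```python
def pooled_mean(mean_a, n_a, mean_b, n_b):
    """Combine two per-class means into the mean over all n_a + n_b documents."""
    return (mean_a * n_a + mean_b * n_b) / (n_a + n_b)

# Made-up example: one token's mean attribution in each class.
print(pooled_mean(0.10, 800, 0.40, 200))  # weighted by document count
print((0.10 + 0.40) / 2)                  # naive average of the two tables
```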

Save the pandas DataFrames as CSV files so they can be reused later without running through the entire dataset again. Change `papers` in the file names to match the dataset used.

In [None]:
df_attrib_zero.to_csv('/content/drive/MyDrive/cogs402longformer/results/longformer_emb_zero_papers.csv')  
df_attrib_ones.to_csv('/content/drive/MyDrive/cogs402longformer/results/longformer_emb_ones_papers.csv')  

# df_attrib_zero.to_csv('/content/drive/MyDrive/cogs402longformer/results/word_emb_attrib_zero_papers.csv')  
# df_attrib_ones.to_csv('/content/drive/MyDrive/cogs402longformer/results/word_emb_attrib_ones_papers.csv')  
# df_pos_zero.to_csv('/content/drive/MyDrive/cogs402longformer/results/pos_emb_attrib_zero_papers.csv')  
# df_pos_ones.to_csv('/content/drive/MyDrive/cogs402longformer/results/pos_emb_attrib_ones_papers.csv')  