<a href="https://colab.research.google.com/github/danielhou13/cogs402longformer/blob/main/src/ConvertSlidingAttentionMatrix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive inside the Colab runtime (google.colab is only available on Colab).
# NOTE(review): no code visible in this notebook reads from /content/drive (the
# sys.path line in the next cell is commented out) — confirm whether the mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# import sys
# sys.path.append('/content/drive/My Drive/{}'.format("cogs402longformer/"))

In [3]:
pip install transformers --quiet

In [4]:
pip install captum --quiet

In [5]:
pip install datasets --quiet

In [6]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

In [7]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [8]:
# Load the cogs402 dataset from the Hugging Face Hub (a cached copy is reused when
# present) and keep only its "test" split.
from datasets import load_dataset
cogs402_ds = load_dataset("danielhou13/cogs402dataset")["test"]

Using custom data configuration danielhou13--cogs402dataset-cc784554b797f843
Reusing dataset parquet (/root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402dataset-cc784554b797f843/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)


  0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
from transformers import LongformerForSequenceClassification, LongformerTokenizer, LongformerConfig
# Identifier of the fine-tuned checkpoint handed to from_pretrained below;
# change it here to analyse a different saved model.
model_path = 'danielhou13/longformer-finetuned_papers'

# load model
# output_attentions=True is requested so that predict() can read the attention
# weights from the model output.
model = LongformerForSequenceClassification.from_pretrained(model_path, num_labels = 2, output_attentions=True)
model.to(device)
# Evaluation mode: the notebook only runs inference with this model.
model.eval()
model.zero_grad()

# load tokenizer
# NOTE(review): the tokenizer is loaded from the base checkpoint rather than from
# model_path — confirm that both share the same vocabulary.
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

In [10]:
def predict(inputs, position_ids=None, attention_mask=None):
    """Run the notebook-level ``model`` once and unpack the outputs used for analysis.

    Args:
        inputs: Token ids, passed to the model as its first positional argument.
        position_ids: Optional position ids, forwarded unchanged.
        attention_mask: Optional attention mask, forwarded unchanged.

    Returns:
        The tuple ``(logits, attentions, global_attentions)`` read from the
        model output object.
    """
    forward_kwargs = {
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
    result = model(inputs, **forward_kwargs)
    return result.logits, result.attentions, result.global_attentions

In [11]:
ref_token_id = tokenizer.pad_token_id # Padding token; fills the text positions of the reference (baseline) sequence
sep_token_id = tokenizer.sep_token_id # Separator token; appended after the text tokens
cls_token_id = tokenizer.cls_token_id # Classification token; prepended before the text tokens

In [12]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id, max_length=128):
    """Tokenize ``text`` and build the matching reference ("baseline") sequence.

    Both sequences have the form ``[cls] + tokens + [sep]``; in the reference
    sequence every text token is replaced by ``ref_token_id``.

    Args:
        text: Raw string, encoded with the notebook-level ``tokenizer``.
        ref_token_id: Id used in place of each text token in the reference.
        sep_token_id: Id appended after the text tokens.
        cls_token_id: Id prepended before the text tokens.
        max_length: Maximum number of text tokens kept (the two special tokens
            are added on top of this); longer texts are truncated. Defaults to
            128, the value that used to be hard-coded.

    Returns:
        ``(input_ids, ref_input_ids, n_text_tokens)`` — two tensors of shape
        ``(1, n_text_tokens + 2)`` created on the notebook-level ``device``, and
        the number of text tokens (special tokens not counted).
    """
    # Special tokens are added by hand below, so the tokenizer must not add its own.
    text_ids = tokenizer.encode(text, truncation=True, add_special_tokens=False, max_length=max_length)
    n_text_tokens = len(text_ids)

    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids
    ref_input_ids = [cls_token_id] + [ref_token_id] * n_text_tokens + [sep_token_id]

    return (
        torch.tensor([input_ids], device=device),
        torch.tensor([ref_input_ids], device=device),
        n_text_tokens,
    )

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids for the input and for the reference sequence.

    Args:
        input_ids: Tensor of shape ``(batch, seq_length)``.

    Returns:
        ``(position_ids, ref_position_ids)``, both ``torch.long`` tensors with
        the same shape and device as ``input_ids``. ``position_ids`` counts
        ``0 .. seq_length - 1`` along each row; ``ref_position_ids`` is all zeros.

    NOTE(review): RoBERTa-style models such as Longformer appear to number real
    tokens from ``padding_idx + 1`` when no position ids are passed, so a
    0-based range may not match what the model saw in training — confirm
    against the model's embedding code before feeding these to the model.
    """
    seq_length = input_ids.size(1)
    # Follow the input's device instead of a notebook-level global so the
    # returned tensors can never end up on a different device than `input_ids`.
    target_device = input_ids.device

    position_ids = torch.arange(seq_length, dtype=torch.long, device=target_device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=target_device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=target_device)

    # Broadcast the single row of positions across the batch dimension.
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids

def construct_attention_mask(input_ids):
    """Return an all-ones mask with the same shape, dtype and device as ``input_ids``."""
    mask = torch.ones_like(input_ids)
    return mask

def construct_whole_longformer_embeddings(input_ids, ref_input_ids, \
                                          token_type_ids=None, ref_token_type_ids=None, \
                                          position_ids=None, ref_position_ids=None):
    """Embed the input and the reference sequence with the model's embedding layer.

    Args:
        input_ids: Token ids of the real input.
        ref_input_ids: Token ids of the reference (baseline) sequence.
        token_type_ids: Optional token type ids for the input.
        ref_token_type_ids: Optional token type ids for the reference.
        position_ids: Optional position ids for the input.
        ref_position_ids: Optional position ids for the reference.

    Returns:
        ``(input_embeddings, ref_input_embeddings)`` as produced by
        ``model.longformer.embeddings`` for the two sequences.
    """
    embed = model.longformer.embeddings
    input_embeddings = embed(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = embed(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)

    # A leftover debugging print of the embedding shape was removed from here.
    return input_embeddings, ref_input_embeddings

In [13]:
# Pick the norm implementation by feature detection rather than by comparing
# version strings: plain-string comparison is lexicographic and therefore
# fragile for versions such as "1.10.0" vs "1.7.0".
try:
    norm_fn = torch.linalg.norm
except AttributeError:
    # PyTorch builds without torch.linalg.norm fall back to the older API.
    norm_fn = torch.norm

In [14]:
# Index of the test-split example analysed in the rest of the notebook.
testval = 923
# Select the row first and the field second: `ds[i][col]` reads a single example,
# whereas `ds[col][i]` materialises the whole column just to pick one value.
text = cogs402_ds[testval]['text']
label = cogs402_ds[testval]['labels']
print(label)

0


In [15]:
# Token ids for the selected text and for the matching reference sequence.
# Note: despite its name, `sep_id` receives the third return value of
# construct_input_ref_pair, which is the number of text tokens.
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)

position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

# Token strings corresponding to the ids of the (single) input sequence.
indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [16]:
# Sanity check: length of the input sequence once the batch axis is squeezed away.
print(input_ids.squeeze().shape)

torch.Size([130])


In [17]:
# Single forward pass; `attention` and `global_attention` are stacked layer-wise further down.
# NOTE(review): this runs with autograd enabled (the matrix printed later carries a
# grad_fn). If gradients are not needed, wrapping the call in torch.no_grad()
# would avoid keeping the graph alive — confirm nothing downstream backpropagates.
score, attention, global_attention = predict(input_ids, position_ids, attention_mask)

In [18]:
# Convert the label into a tensor.
# NOTE(review): this rebinds `label` to a different type, so re-running the cell
# passes a tensor (not the original value) to torch.tensor — confirm this is acceptable.
label=torch.tensor(label)
print(label)

tensor(0)


In [None]:
# print(attention[1])
# print(attention[1].sum(dim=0))
# print("non_zero", torch.nonzero(attention[1] - attention[1].sum(dim=0)))
# print(attention[1].sum(dim=0).unsqueeze(0).shape)

# batch_attn = torch.cat([l.sum(dim=0).unsqueeze(0) for l in attention], dim=0).detach().cpu().numpy()
# print(batch_attn.shape)

In [26]:
# Stack the per-layer outputs into one tensor each and move them to the CPU.
# shape -> layer x batch x head x seq_len x attention columns
# NOTE(review): the last axis presumably holds the global-token column(s) followed by
# the sliding-window columns (see create_head_matrix) — confirm against the model docs.
output_attentions_all = torch.stack(attention).cpu()

# Global attention, stacked the same way (leading axes: layer x batch x head).
global_attention_all = torch.stack(global_attention).cpu()

In [27]:
# Confirm that both stacked tensors now live on the CPU.
print(output_attentions_all.device)
print(global_attention_all.device)

cpu
cpu


In [33]:
# Pull out a single head for inspection: layer index 11, batch item 0, head index 11.
test = output_attentions_all[11][0][11]
global_test = global_attention_all[11][0][11]

# Number of tokens and number of attention rows for that head, printed for comparison.
print(len(all_tokens))
print(test.shape[0])

# Length of that head's global-attention vector once size-1 axes are dropped.
print(global_test.squeeze().shape)

def create_head_matrix(output_attentions, global_attentions):
    """Expand one head's sliding-window attention into a dense square matrix.

    Assumed layout (Longformer with exactly one global-attention token, located
    at sequence position 0 — the same assumption the hard-coded offset 257
    encoded for ``attention_window=512``):

    * ``output_attentions``: ``(seq_len, 1 + attention_window + 1)``. Column 0
      is each token's weight on the global token; the remaining columns are
      the local window, with the token itself in the middle.
    * ``global_attentions``: the global token's weights over the (possibly
      padded) sequence. Size-1 axes are squeezed away and only the first
      ``seq_len`` entries are used.

    The window size is derived from the width of ``output_attentions`` instead
    of being hard-coded, and window slots are mapped by position rather than by
    "skip the first non-zero entry", so an exactly-zero global weight can no
    longer shift the row.

    Args:
        output_attentions: 2-D tensor of local attention weights for one head.
        global_attentions: Global attention weights for the same head.

    Returns:
        Tensor of shape ``(seq_len, seq_len)`` (default float dtype, on the
        device of ``output_attentions``) whose entry ``[i, j]`` is the weight
        token ``i`` places on token ``j``. Positions outside a token's window
        stay 0.
    """
    seq_len, width = output_attentions.shape
    work_device = output_attentions.device
    dense = torch.zeros((seq_len, seq_len), device=work_device)
    if seq_len == 0:
        return dense

    # Column holding a token's attention to itself: one global column followed
    # by the one-sided window. Evaluates to 257 for attention_window == 512.
    self_col = 1 + (width - 2) // 2

    # For every (row, local column) pair compute the absolute position attended to.
    rows = torch.arange(seq_len, device=work_device).unsqueeze(1)
    cols = torch.arange(1, width, device=work_device).unsqueeze(0)
    targets = rows + cols - self_col

    # Window slots that fall before the first or after the last token have no
    # counterpart in the dense matrix and are skipped.
    in_range = (targets >= 0) & (targets < seq_len)
    row_idx, col_idx = torch.nonzero(in_range, as_tuple=True)
    local = output_attentions[:, 1:]
    dense[row_idx, targets[row_idx, col_idx]] = local[row_idx, col_idx].to(dense.dtype)

    # Attention paid to the global token is stored separately in column 0.
    dense[:, 0] = output_attentions[:, 0]
    # The global token's own row comes from the global-attention output.
    dense[0] = global_attentions.squeeze()[:seq_len]
    return dense


def attentions_all_heads(output_attentions, global_attentions):
    """Build the dense attention matrix of every head and stack them.

    ``create_head_matrix`` is applied to each pair of slices taken along the
    first axis of the two inputs; the results are stacked along a new leading
    axis in the same order.
    """
    n_heads = output_attentions.shape[0]
    per_head = [
        create_head_matrix(output_attentions[head], global_attentions[head])
        for head in range(n_heads)
    ]
    return torch.stack(per_head)

def all_batches(output_attentions, global_attentions):
    """Apply ``attentions_all_heads`` to every batch item and stack the results.

    The two inputs are sliced in step along their first axis; the output keeps
    that axis as its leading dimension.
    """
    n_items = output_attentions.shape[0]
    per_item = [
        attentions_all_heads(output_attentions[item], global_attentions[item])
        for item in range(n_items)
    ]
    return torch.stack(per_item)

def all_layers(output_attentions, global_attentions):
    """Apply ``all_batches`` to every layer and stack the results.

    The two inputs are sliced in step along their first axis; the output keeps
    that axis as its leading dimension.
    """
    n_layers = output_attentions.shape[0]
    per_layer = [
        all_batches(output_attentions[layer], global_attentions[layer])
        for layer in range(n_layers)
    ]
    return torch.stack(per_layer)

130
130
torch.Size([512])


In [35]:
# Dense attention matrix for the single head selected earlier.
new_matrix = create_head_matrix(test, global_test)
# (disabled) intermediate checks of the per-head and per-batch helpers on layer index 11
# new_matrix2 = attentions_all_heads(output_attentions_all[11][0], 
#                                    global_attention_all[11][0])
# new_matrix3 = all_batches(output_attentions_all[11], 
#                           global_attention_all[11])
# Dense attention matrices for every layer, batch item and head at once.
new_matrix4 = all_layers(output_attentions_all, 
                         global_attention_all)
print(new_matrix)
# print(new_matrix2.shape)
# print(new_matrix3.shape)
print(new_matrix4.shape)

tensor([[0.0087, 0.0061, 0.0082,  ..., 0.0070, 0.0053, 0.0058],
        [0.0069, 0.0098, 0.0086,  ..., 0.0084, 0.0057, 0.0072],
        [0.0087, 0.0065, 0.0093,  ..., 0.0077, 0.0053, 0.0060],
        ...,
        [0.0086, 0.0071, 0.0091,  ..., 0.0108, 0.0075, 0.0089],
        [0.0095, 0.0071, 0.0098,  ..., 0.0122, 0.0085, 0.0101],
        [0.0085, 0.0070, 0.0096,  ..., 0.0107, 0.0074, 0.0095]],
       grad_fn=<CopySlices>)
torch.Size([12, 1, 12, 130, 130])


In [36]:
# The single-head matrix should be square: seq_len x seq_len.
print(new_matrix.shape)

torch.Size([130, 130])


In [42]:
# Shape of one layer's dense attention (layer index 1): batch x head x seq_len x seq_len.
print(new_matrix4[1].shape)
# For each layer, sum over the batch axis (dim 0 of the layer slice), then concatenate
# the layers again and convert the result to a NumPy array on the CPU.
# NOTE(review): the saved output of this cell records an AttributeError whose cause is
# not visible from the code here — re-run on a fresh kernel to confirm it still occurs.
batch_attn = torch.cat([l.sum(dim=0).unsqueeze(0) for l in new_matrix4], dim=0).detach().cpu().numpy()

torch.Size([1, 12, 130, 130])


AttributeError: ignored