<a href="https://colab.research.google.com/github/danielhou13/cogs402longformer/blob/main/src/Token_head_importance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive so files under /content/drive (e.g. the head-importance
# file loaded further down) are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# import sys
# sys.path.append('/content/drive/My Drive/{}'.format("cogs402longformer/"))

In [3]:
pip install datasets --quiet

In [4]:
pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Load the dataset and the fine-tuned model

In [5]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

In [6]:
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Tokenizer for the base Longformer checkpoint. It is a module-level global that
# get_papers_dataset and the later token lookup cell both read.
tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-base-4096')

def longformer_finetuned_papers():
    """Load the fine-tuned Longformer paper classifier by its checkpoint id.

    The model is built with 2 output labels and ``output_attentions=True`` so
    the attention weights are returned alongside the logits.
    """
    checkpoint = 'danielhou13/longformer-finetuned_papers'
    return AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=2,
        output_attentions=True,
    )

def preprocess_function(tokenizer, example, max_length):
    """Tokenize ``example['text']`` and merge the encoding into ``example``.

    Every sequence is padded and truncated to exactly ``max_length`` tokens.
    The input mapping is updated in place and also returned, so the function
    can be passed straight to ``Dataset.map(batched=True)``.
    """
    encoding = tokenizer(
        example['text'],
        padding='max_length',
        max_length=max_length,
        truncation=True,
    )
    example.update(encoding)
    return example

def get_papers_dataset(dataset_type, max_length=2048):
    """Load one split of the cogs402 papers dataset and tokenize it.

    Args:
        dataset_type: split name passed to ``load_dataset(..., split=...)``,
            e.g. ``'test'``.
        max_length: pad/truncate length for every example. Defaults to 2048,
            the value that used to be hard-coded here.

    Returns:
        The tokenized dataset. It also carries four ad-hoc attributes:
        ``input_columns``, ``target_columns``, ``max_length`` and ``tokenizer``.

    Reads the module-level ``tokenizer``.
    """
    # Request only the needed split instead of building the full DatasetDict
    # and indexing into it.
    dataset = load_dataset("danielhou13/cogs402dataset", split=dataset_type)
    dataset = dataset.map(
        lambda batch: preprocess_function(tokenizer, batch, max_length),
        batched=True,
    )
    # NOTE(review): these are plain Python attributes on this object.
    # Dataset methods that return a new object (e.g. remove_columns)
    # presumably do not carry them over, so read them before such calls.
    dataset.input_columns = ['input_ids', 'attention_mask']
    dataset.target_columns = ['labels']
    dataset.max_length = max_length
    dataset.tokenizer = tokenizer
    return dataset

def papers_test_set():
    """Return the tokenized ``'test'`` split of the papers dataset."""
    split_name = 'test'
    return get_papers_dataset(split_name)

In [7]:
# Load the tokenized test split and the fine-tuned classifier.
cogs402_test = papers_test_set()
model = longformer_finetuned_papers()
# Model inputs plus the label, read from the attributes get_papers_dataset
# attached. They must be read here, before remove_columns below returns a
# new Dataset object.
columns = cogs402_test.input_columns + cogs402_test.target_columns
print(columns)
# Serve these columns as torch tensors when rows are indexed.
cogs402_test.set_format(type='torch', columns=columns)
# The raw text is no longer needed once tokenized; remove_columns returns a new Dataset.
cogs402_test=cogs402_test.remove_columns(['text'])

Using custom data configuration danielhou13--cogs402dataset-cc784554b797f843
Reusing dataset parquet (/root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402dataset-cc784554b797f843/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)


  0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/2 [00:00<?, ?ba/s]

['input_ids', 'attention_mask', 'labels']


In [8]:
# Put the model on the GPU when one exists; otherwise it stays on the CPU.
use_gpu = torch.cuda.is_available()
if use_gpu:
    model = model.to('cuda')

print(model.device)

cuda:0


Select a single test example to analyze

In [9]:
# Index of the test example analyzed in the rest of the notebook.
EXAMPLE_INDEX = 923
testexam = cogs402_test[EXAMPLE_INDEX]

In [10]:
# print(test['labels'][923])

In [11]:
# Run the model on the single example.
# - Inputs follow the model's device, so this also works on CPU (the previous
#   cell only moves the model to the GPU when CUDA is available).
# - no_grad() avoids building an autograd graph that is never back-propagated
#   and would otherwise be dragged through the large densification step later.
device = model.device
with torch.no_grad():
    output = model(
        testexam["input_ids"].unsqueeze(0).to(device),
        attention_mask=testexam["attention_mask"].unsqueeze(0).to(device),
        labels=testexam["labels"].to(device),
    )
# Named fields instead of output[-2] / output[-1]. Positional access shifts if
# another optional output (e.g. hidden states) is enabled.
batch_attn = output.attentions
output_attentions = torch.stack(batch_attn).cpu()
global_attention = output.global_attentions
output_global_attentions = torch.stack(global_attention).cpu()
print("output_attention.shape", output_attentions.shape)
print("gl_output_attention.shape", output_global_attentions.shape)

output_attention.shape torch.Size([12, 1, 12, 2048, 514])
gl_output_attention.shape torch.Size([12, 1, 12, 2048, 1])


In [12]:
# print(os.getcwd())
# yes = torch.load("resources/longformer_test2/epoch_3/aggregate_attn.pt")

Convert Longformer's sliding-window attention output into a full seq_len × seq_len attention matrix

In [13]:
def create_head_matrix(output_attentions, global_attentions):
    """Rebuild a dense (seq_len, seq_len) attention matrix for a single head.

    For each query row, the local attention output holds one column for the
    global token, followed by the sliding-window columns for relative offsets
    -w..+w. With the shapes printed above (514 columns, one global token),
    that is 1 + 513 columns, i.e. w = 256, and window column ``c`` of row ``q``
    maps to key ``q + c - 257``.

    Args:
        output_attentions: (seq_len, 2 + 2*w) local attention for one head.
            Assumes exactly one global token, at position 0 (column 0).
        global_attentions: (seq_len, 1) attention from that global token to
            every token. It becomes row 0 of the result.

    Returns:
        A float32 tensor of shape (seq_len, seq_len). Entry [q, k] is the
        attention from query q to key k; entries outside the window are 0.
    """
    seq_len, n_cols = output_attentions.shape
    device = output_attentions.device
    # Window columns start at index 1 (after the global column). The
    # zero-offset column sits at 1 + w, so key = query + col - (1 + w).
    half_window = (n_cols - 2) // 2
    offset = 1 + half_window  # == 257 for 514 columns
    rows = torch.arange(seq_len, device=device).unsqueeze(1)   # (seq_len, 1)
    cols = torch.arange(1, n_cols, device=device).unsqueeze(0)  # (1, n_cols - 1)
    keys = rows + cols - offset                                 # (seq_len, n_cols - 1)
    # Keep only window slots that land inside the sequence. This selects
    # positions explicitly instead of relying on nonzero values, so negative
    # keys can never wrap around to the end of the row.
    in_range = (keys >= 0) & (keys < seq_len)

    new_attention_matrix = torch.zeros((seq_len, seq_len), device=device)
    new_attention_matrix[rows.expand_as(keys)[in_range], keys[in_range]] = (
        output_attentions[:, 1:][in_range]
    )
    # Column 0 of the local output is attention to the global token
    # (position 0). It is written after the scatter so it takes precedence.
    new_attention_matrix[:, 0] = output_attentions[:, 0]
    # Row 0 is the global token's own attention over the whole sequence.
    new_attention_matrix[0] = global_attentions.squeeze()[:seq_len]
    return new_attention_matrix


def attentions_all_heads(output_attentions, global_attentions):
    """Densify every head of one (layer, batch item) slice.

    Runs ``create_head_matrix`` on each index of the first axis of
    ``output_attentions``, paired with the same index of
    ``global_attentions``. The resulting (seq_len, seq_len) matrices are
    stacked along a new axis 0.
    """
    per_head = [
        create_head_matrix(output_attentions[head], global_attentions[head])
        for head in range(output_attentions.shape[0])
    ]
    return torch.stack(per_head)

def all_batches(output_attentions, global_attentions):
    """Densify every batch item of one layer.

    Applies ``attentions_all_heads`` to each index of the first axis and
    stacks the results along a new axis 0.
    """
    per_item = [
        attentions_all_heads(output_attentions[item], global_attentions[item])
        for item in range(output_attentions.shape[0])
    ]
    return torch.stack(per_item)

def all_layers(output_attentions, global_attentions):
    """Densify the attention of every layer.

    Applies ``all_batches`` to each index of the first axis of the stacked
    per-layer attentions and stacks the results along a new axis 0.
    """
    per_layer = [
        all_batches(output_attentions[layer], global_attentions[layer])
        for layer in range(output_attentions.shape[0])
    ]
    return torch.stack(per_layer)

In [14]:
# Expand every layer / batch item / head into a dense (seq_len, seq_len)
# matrix and move the result to numpy.
# NOTE(review): at the printed shape (12, 1, 12, 2048, 2048) in float32 this
# is about 2.4 GB, so watch the Colab RAM.
converted_mat = all_layers(output_attentions, output_global_attentions).detach().cpu().numpy()
print(converted_mat.shape)

(12, 1, 12, 2048, 2048)


Sum each column over all query tokens (axis 3), giving the total attention each token receives

In [15]:
# Collapse axis 3, the query rows of each dense matrix. What remains, for
# every layer / batch item / head, is the total attention each key token
# receives.
attention_sum = converted_mat.sum(axis=3)
print(attention_sum.shape)

(12, 1, 12, 2048)


Load the precomputed head-importance scores and scale each head's attention totals by its importance

In [16]:
# Precomputed importance scores, indexed [layer][head] by scale_by_importance.
# NOTE(review): torch.load unpickles the file, so only load trusted files.
# If this file holds a pickled numpy array, newer torch releases (where
# weights_only defaults to True) may refuse to load it and need
# weights_only=False. Confirm against the installed torch version.
head_importance = torch.load("/content/drive/MyDrive/cogs402longformer/t3-visapplication/resources/pretrained/head_importance.pt")

In [17]:
# Sanity check: the per-head importance scores for layer index 2.
print(head_importance[2])

[0.00404374 0.02163236 0.00531127 0.01472499 0.03868961 0.00646301
 0.01300004 0.00334545 1.         0.00796655 0.1583804  0.01080081]


In [18]:
def scale_by_importance(attention_matrix, head_importance):
  """Weight each head's attention totals by that head's importance score.

  Args:
    attention_matrix: 4-D array indexed [layer, batch, head, token].
    head_importance: scores indexable as head_importance[layer][head].

  Returns:
    A new array with the same shape and dtype as ``attention_matrix``, where
    entry [l, b, h, t] = attention_matrix[l, b, h, t] * head_importance[l][h].
  """
  n_layers = attention_matrix.shape[0]
  weights = np.asarray(head_importance)[:n_layers]
  # (layers, heads) -> (layers, 1, heads, 1) so it broadcasts over batch and tokens.
  scaled = attention_matrix * weights[:, np.newaxis, :, np.newaxis]
  # Cast back so the result keeps the input dtype.
  return scaled.astype(attention_matrix.dtype, copy=False)

In [19]:
# Importance-weighted attention totals; the shape is unchanged from attention_sum.
attention_matrix_importance = scale_by_importance(attention_sum, head_importance)
print(attention_matrix_importance.shape)

(12, 1, 12, 2048)


In [20]:
# Reminder of the reversed-slice idiom used by find_top_attention:
# arr[:-(k+1):-1] yields the last k elements in reverse order.
# Named `demo` rather than `list` so the built-in list() is not shadowed for
# the rest of the kernel session.
demo = np.array(range(1,10))
print(demo)
print(demo[-4:])
print(demo[:-6:-1])

[1 2 3 4 5 6 7 8 9]
[6 7 8 9]
[9 8 7 6 5]


Get the top-k most-attended tokens for each head, for each example in the batch, and for each layer

In [21]:
def find_top_attention(scores_mat, axis, k):
  """Return the indices and values of the ``k`` largest scores along ``axis``.

  Args:
    scores_mat: numpy array of scores (any rank).
    axis: axis to rank along; negative values are allowed.
    k: number of entries to keep. If it exceeds the axis length, all entries
      are returned.

  Returns:
    (indices, vals): arrays shaped like ``scores_mat`` except that ``axis``
    has length min(k, scores_mat.shape[axis]), ordered from highest to
    lowest score.
  """
  order = scores_mat.argsort(axis=axis)
  # Reverse the ascending order and keep the first k along the requested axis.
  # The slice must target ``axis`` itself rather than assuming the 4th axis.
  top_k = [slice(None)] * order.ndim
  top_k[axis] = slice(None, -(k + 1), -1)
  indices = order[tuple(top_k)]
  vals = np.take_along_axis(scores_mat, indices, axis=axis)
  return indices, vals

In [22]:
# Token strings for the chosen example, aligned with its input_ids positions.
# Used to label the top-attended positions.
all_tokens = tokenizer.convert_ids_to_tokens(testexam["input_ids"])

For each top-k entry we want the token's position (index), its attention value, and the token string itself.

In [23]:
# Top positions per layer / batch item / head, ranked along the token axis (axis 3).
TOP_K = 10
indexes, values = find_top_attention(attention_matrix_importance, axis=3, k=TOP_K)

In [24]:
def get_tokens(index_matrix, tokens=None):
  """Map an array of token positions to the corresponding token strings.

  Args:
    index_matrix: integer array of positions into ``tokens``, of any shape
      (here (layers, batch, heads, k) from find_top_attention). The previous
      version ignored this argument and always read the global ``indexes``.
    tokens: sequence of token strings. Defaults to the notebook-level
      ``all_tokens``, so the existing call ``get_tokens(indexes)`` still works.

  Returns:
    A numpy string array with the same shape as ``index_matrix``.
  """
  if tokens is None:
    tokens = all_tokens
  lookup = np.asarray(tokens, dtype=object)
  # Fancy indexing an object array keeps index_matrix's shape. tolist() plus
  # np.array turns it back into a fixed-width string array, matching the
  # original nested-list build.
  return np.array(lookup[np.asarray(index_matrix)].tolist())

# Token strings for each top-k index, in the same shape as `indexes`.
highest_tokens = get_tokens(indexes)
print(highest_tokens.shape)

(12, 1, 12, 10)


In [25]:
# Top-k positions, scores and tokens for layer 0 (all batch items and heads).
print(indexes[0], values[0], highest_tokens[0])

[[[   0  512 1536 1024   61 1079 2029  558   56  414]
  [   0 1536 1024  195 1311  634  263 1261 1252  276]
  [   0 1024  512  848  839 1838  263  282 1785 1789]
  [   0  735 1024 1852 1536  382 1088 1162  363  420]
  [   0 1145 1480 1074  263 1364 1322  195  252  848]
  [   0  512 1536 1024 2040 1942 1577 1508  414 1362]
  [   0  512 1024 1536 1871   61 1162 1713  138  576]
  [ 663 1706  587  756  382  913    0  854  593  692]
  [   0 1677 1701  195  263  565  382   53   55 1074]
  [1222  570 1721  640 1959  299 1360 1213   24 1413]
  [   0 1024  512  570 1269 1790 1785  382  778  845]
  [   0 1222  569 1316 1696 1026 1219 1633  570 1293]]] [[[ 1.0553687   0.4273339   0.18387423  0.16710557  0.13229102
    0.12186833  0.10956845  0.1094526   0.10861638  0.10686851]
  [ 0.98397076  0.406099    0.34594324  0.30012038  0.29888272
    0.28868958  0.28849742  0.27655417  0.27404934  0.2696984 ]
  [ 0.9272709   0.20314224  0.18950082  0.16722995  0.16298965
    0.15621807  0.15494066  0.154