<a href="https://colab.research.google.com/github/danielhou13/cogs402longformer/blob/main/src/T3-vis/T3_vis_aggregate_attn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is part of the [T3-vis](https://arxiv.org/abs/2108.13587) implementation for visualizing Transformer neural networks. Using the dataset, the model and the functions below, we predict over the validation set and collect the attentions for every example. We then aggregate all the attentions, normalize them, and save the result to a file for use in another notebook, where we display the aggregate attention to look for patterns.

### Importing Packages

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# import sys
# sys.path.append('/content/drive/My Drive/{}'.format("cogs402longformer/"))

In [None]:
pip install datasets --quiet

In [None]:
pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

### Import Dataset and Model

Here we import the model and the dataset we want to assess. The loading code follows the T3-vis implementation, dropping a few items such as "idx" and "visualize columns" that are unnecessary here.

In [None]:
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

Here we import the "papers" model and dataset.

In [None]:
# def longformer_finetuned_papers():
#     model = AutoModelForSequenceClassification.from_pretrained('danielhou13/longformer-finetuned_papers', num_labels = 2, output_attentions = True)
#     return model

# # def bert_test():
# #     model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels = 2)
# #     setattr(model, 'num_hidden_layers', model.config.num_hidden_layers)
# #     setattr(model, 'num_attention_heads', model.config.num_attention_heads)
# #     setattr(model, 'hidden_size', model.config.hidden_size)
# #     return model

# def preprocess_function(tokenizer, example, max_length):
#     example.update(tokenizer(example['text'], padding='max_length', max_length=max_length, truncation=True))
#     return example

# def get_papers_dataset(dataset_type):
#     max_length = 2048
#     dataset = load_dataset("danielhou13/cogs402dataset")[dataset_type]

#     tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-base-4096')
#     # tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#     dataset = dataset.map(lambda x: preprocess_function(tokenizer, x, max_length), batched=True)
#     setattr(dataset, 'input_columns', ['input_ids', 'attention_mask'])
#     setattr(dataset, 'target_columns', ['labels'])
#     setattr(dataset, 'max_length', max_length)
#     setattr(dataset, 'tokenizer', tokenizer)
#     return dataset

# def papers_test_set():
#     return get_papers_dataset('test')

Here we import the "hyperpartisan news" dataset and model.

In [None]:
def preprocess_function(tokenizer, example, max_length):
    example.update(tokenizer(example['text'], padding='max_length', max_length=max_length, truncation=True))
    return example

def longformer_finetuned_news():
    model = AutoModelForSequenceClassification.from_pretrained('danielhou13/longformer-finetuned-news-cogs402', num_labels = 2)
    return model

def get_news_dataset(dataset_type):
    max_length = 2048
    dataset = load_dataset("danielhou13/cogs402dataset2")[dataset_type]

    tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-base-4096')
    dataset = dataset.map(lambda x: preprocess_function(tokenizer, x, max_length), batched=True)

    labels = map(int, dataset['hyperpartisan'])
    print(type(dataset['hyperpartisan']))
    labels = list(labels)
    dataset = dataset.add_column("labels", labels)

    dataset = dataset.remove_columns(['text', 'title', 'hyperpartisan', 'url', 'published_at', 'bias'])
    print(dataset)
    setattr(dataset, 'input_columns', ['input_ids', 'attention_mask'])
    setattr(dataset, 'target_columns', ['labels'])
    setattr(dataset, 'max_length', max_length)
    setattr(dataset, 'tokenizer', tokenizer)
    return dataset

def news_train_set():
    return get_news_dataset('train')

def news_test_set():
    return get_news_dataset('validation')

Here we call the functions to load the dataset and the model, and set the dataset format so that its columns are returned as PyTorch tensors.

In [None]:
# cogs402_test = papers_test_set()
# model = longformer_finetuned_papers()
# columns = cogs402_test.input_columns + cogs402_test.target_columns
# print(columns)
# cogs402_test.set_format(type='torch', columns=columns)
# cogs402_test=cogs402_test.remove_columns(['text'])

In [None]:
cogs402_test = news_test_set()
model = longformer_finetuned_news()
columns = cogs402_test.input_columns + cogs402_test.target_columns
print(columns)
cogs402_test.set_format(type='torch', columns=columns)

Using custom data configuration danielhou13--cogs402dataset2-6edc4c363493b501


Downloading and preparing dataset hyperpartisan_news_detection/bypublisher (download: 38.81 MiB, generated: 66.89 MiB, post-processed: Unknown size, total: 105.70 MiB) to /root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402dataset2-6edc4c363493b501/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8...


Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402dataset2-6edc4c363493b501/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8. Subsequent calls will reuse this data.


<class 'list'>
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2500
})
['input_ids', 'attention_mask', 'labels']


Don't forget to use your GPU if you have one for faster performance.

In [None]:
if torch.cuda.is_available():
    model = model.cuda()
    cuda0 = torch.device('cuda:0')

print(model.device)

cuda:0


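Sanity check: run one validation example through the model and confirm the shape of the attention tensor it returns.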

In [None]:
test = cogs402_test

In [None]:
# print(test['labels'][923])

In [None]:
output = model(test["input_ids"][923].unsqueeze(0).cuda(), attention_mask=test['attention_mask'][923].unsqueeze(0).cuda(), labels=test['labels'][923].cuda(), output_attentions=True)
batch_attn = output[-2]
output_attentions = torch.stack(batch_attn).cpu()
print("output_attention.shape", output_attentions.shape)

output_attention.shape torch.Size([12, 1, 12, 2048, 514])
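
The last dimension of 514 can be read against the input format described in the Functions section (x + sliding_window_size + 1). The sketch below is only an arithmetic check; the window size of 512 and the single global token are assumptions inferred from the printed shape, not values the notebook prints.

```python
# Assumed values: a sliding window of 512 tokens (256 on each side)
# and a single globally attended token; neither is printed by the notebook.
sliding_window_size = 512
num_global_tokens = 1

# Each query row stores the global columns first, then the local band:
# 256 tokens to the left, the token itself, and 256 tokens to the right.
last_dim = num_global_tokens + sliding_window_size + 1
print(last_dim)  # 514
```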


In [None]:
# print(os.getcwd())
# yes = torch.load("resources/longformer_test2/epoch_3/aggregate_attn.pt")

### Functions

The following block is the normalization code from T3-vis. It min-max normalizes each head's attention matrix and converts the values into RGBA colour values that we can use for plotting.

**The input is an array of shape: (layer, head, seq_len, seq_len)**, with the batch dimension already summed out.

**The output is a list with one entry per (layer, head) pair. Each entry stores its layer index, its head index, and a flat list of 4 x seq_len x seq_len colour values.**

Every value in the attention matrix becomes 4 colour channels: red, green, blue, alpha (controls how opaque the colour is). Red is fixed at 255 and green and blue at 0, so only alpha varies with the attention value.

In [None]:
def format_attention_image(attention):
    formatted_attn = []
    for layer_idx in range(attention.shape[0]):
        for head_idx in range(attention.shape[1]):
            formatted_entry = {
                'layer': layer_idx,
                'head': head_idx
            }

            # Flatten the attention values and min-max normalize them to the range 0-255
            if len(attention[layer_idx, head_idx]) == 0:
                continue
            attn = np.array(attention[layer_idx, head_idx]).flatten()
            attn = (attn - attn.min()) / (attn.max() - attn.min())
            alpha = np.round(attn * 255)
            red = np.ones_like(alpha) * 255
            green = np.zeros_like(alpha) * 255
            blue = np.zeros_like(alpha) * 255

            attn_data = np.dstack([red,green,blue,alpha]).reshape(alpha.shape[0] * 4).astype('uint8')
            formatted_entry['attn'] = attn_data.tolist()
            formatted_attn.append(formatted_entry)
    return formatted_attn
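
As a small illustration of the normalization above, the sketch below applies the same steps to a toy 2 x 2 matrix for a single (layer, head) pair. The input values are made up for the example.

```python
import numpy as np

# Toy 2 x 2 attention matrix (made-up values), flattened row by row.
attn = np.array([[0.0, 0.1],
                 [0.3, 0.4]]).flatten()

# Min-max normalize to [0, 1], then scale to the 0-255 alpha range.
attn = (attn - attn.min()) / (attn.max() - attn.min())
alpha = np.round(attn * 255)

red = np.ones_like(alpha) * 255
green = np.zeros_like(alpha)
blue = np.zeros_like(alpha)

# Interleave the channels so every cell becomes [R, G, B, A].
rgba = np.dstack([red, green, blue, alpha]).reshape(alpha.shape[0] * 4).astype('uint8')
print(rgba.tolist())
# [255, 0, 0, 0, 255, 0, 0, 64, 255, 0, 0, 191, 255, 0, 0, 255]
```

Note that if every value in a head's matrix is identical, `attn.max() - attn.min()` is zero and the division produces NaNs, so a guard may be worth adding for such heads.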

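This block of code, not found in T3-vis, converts the Longformer model's sliding-window attention into the traditional attention format of seq_len x seq_len.

**Input: Tensor of shape: (layer, batch, head, seq_len, x + sliding_window_size + 1) \\
(where x is the number of tokens with global attention)**

**Output:
Tensor of shape: (layer, batch, head, seq_len, seq_len)**

The functions also take the model's global attention output, which fills in the row of the globally attended token. Note that `create_head_matrix` hard-codes an offset of 257, which corresponds to x = 1 and a sliding window size of 512.

That way, the T3-vis functions will work as intended.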



In [None]:
def create_head_matrix(output_attentions, global_attentions):
    new_attention_matrix = torch.zeros((output_attentions.shape[0], 
                                      output_attentions.shape[0]))
    for i in range(output_attentions.shape[0]):
        test_non_zeroes = torch.nonzero(output_attentions[i]).squeeze()
        test2 = output_attentions[i][test_non_zeroes[1:]]
        new_attention_matrix_indices = test_non_zeroes[1:]-257 + i
        new_attention_matrix[i][new_attention_matrix_indices] = test2
        new_attention_matrix[i][0] = output_attentions[i][0]
        new_attention_matrix[0] = global_attentions.squeeze()[:output_attentions.shape[0]]
    return new_attention_matrix


def attentions_all_heads(output_attentions, global_attentions):
    new_matrix = []
    for i in range(output_attentions.shape[0]):
        matrix = create_head_matrix(output_attentions[i], global_attentions[i])
        new_matrix.append(matrix)
    return torch.stack(new_matrix)

def all_batches(output_attentions, global_attentions):
    new_matrix = []
    for i in range(output_attentions.shape[0]):
        matrix = attentions_all_heads(output_attentions[i], global_attentions[i])
        new_matrix.append(matrix)
    return torch.stack(new_matrix)

def all_layers(output_attentions, global_attentions):
    new_matrix = []
    for i in range(output_attentions.shape[0]):
        matrix = all_batches(output_attentions[i], global_attentions[i])
        new_matrix.append(matrix)
    return torch.stack(new_matrix)
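
The index mapping behind this conversion can be sketched on a toy example. The helper `band_to_dense` below is hypothetical and not part of the notebook: it assumes one global token at position 0, uses a one-sided window of 1 instead of 256, and leaves out the step that fills the global token's own row.

```python
import numpy as np

def band_to_dense(band, one_sided_window):
    """Expand (seq_len, 1 + 2*w + 1) banded attention into (seq_len, seq_len)."""
    seq_len = band.shape[0]
    dense = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        for j in range(1, band.shape[1]):
            # Band column j is key position i + j - (w + 1);
            # with w = 256 this is the offset of 257 used above.
            col = i + j - (one_sided_window + 1)
            if 0 <= col < seq_len and band[i, j] != 0:
                dense[i, col] = band[i, j]
        # Band column 0 holds the attention paid to the global token.
        dense[i, 0] = band[i, 0]
    return dense

band = np.zeros((5, 4))
band[1] = [0.5, 0.0, 0.3, 0.2]  # the global token's slot in the local band is zero
band[3] = [0.1, 0.2, 0.3, 0.4]

dense = band_to_dense(band, one_sided_window=1)
print(dense[1])  # attention from token 1 to keys 0, 1, 2
print(dense[3])  # attention from token 3 to keys 0, 2, 3, 4
```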

This T3-vis function collects and aggregates the attentions over the entire validation set. It iterates through the dataset, gets the attention for each example, converts each example's attention into the seq_len x seq_len shape, sums over the batch dimension, and adds the result to a running total. Lastly, it divides the total by the number of examples in which each token position was not padding and sends the aggregated attention matrix to the normalizer function, which returns a list of layer x head entries with 4 x seq_len x seq_len colour values each.

In [None]:
from tqdm import tqdm
def compute_aggregated_attn(model, dataloader, max_input_len):

    n_layers = model.longformer.config.num_hidden_layers
    n_heads = model.longformer.config.num_attention_heads
    # head_size = int(model.longformer.config.hidden_size / n_heads)
    # n_examples = len(dataloader.dataset)

    # importance_scores = np.zeros((n_layers, n_heads))

    device = model.device
    # total_loss = 0.
    attn = np.zeros((n_layers, n_heads, max_input_len, max_input_len))
    print(attn.shape)
    model.eval()

    attn_normalize_count = torch.zeros(max_input_len, device=device)

    for step, inputs in enumerate(tqdm(dataloader, position=0, leave=True)):

        # batch_size_ = inputs['input_ids'].__len__()

        if torch.cuda.is_available():
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.cuda()
        
        inputs['output_attentions']=True
        
        with torch.no_grad():
            output = model(**inputs)
        
        
        attn_normalize_count += inputs['attention_mask'].sum(dim=0)
        batch_attn = output[-2]
        global_attn = output[-1]
        
        # print(batch_attn[1].shape)
        output_attentions = torch.stack(batch_attn).cpu()

        # print("output_attention.shape", output_attentions.shape)
        global_attentions = torch.stack(global_attn).cpu()
         
        # print(output_attentions.device)
        # print(global_attentions.device)
        
        batch_attn2 = all_layers(output_attentions, global_attentions)
    
        # print(batch_attn2.shape)
        batch_attn = torch.cat([l.sum(dim=0).unsqueeze(0) for l in batch_attn2], dim=0).cpu().numpy()
        
        attn += batch_attn
        
    max_input_len = len(attn_normalize_count.nonzero(as_tuple=False))
    
    attn = attn[:, :, :max_input_len, :max_input_len]
    attn /= attn_normalize_count.cpu().numpy()[:max_input_len]
    print(type(attn))
    formatted_attn = format_attention_image(attn)
    return formatted_attn
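
The accumulate-then-normalize step can be seen in isolation on toy data. The sketch below uses two made-up examples for a single head and mirrors the division used above, where the per-position counts broadcast along the last axis.

```python
import numpy as np

max_len = 4
# Two toy examples: one uses all 4 positions, the other only the first 2.
masks = [np.array([1, 1, 1, 1]), np.array([1, 1, 0, 0])]
attns = [np.full((max_len, max_len), 0.25), np.zeros((max_len, max_len))]
attns[1][:2, :2] = 0.5

total = np.zeros((max_len, max_len))
count = np.zeros(max_len)
for attn, mask in zip(attns, masks):
    total += attn   # running sum of attention
    count += mask   # how many examples had a real token at each position

# Keep only positions that were used at least once, then divide.
# The counts broadcast along the last axis, i.e. per key position.
used = int(np.count_nonzero(count))
mean_attn = total[:used, :used] / count[:used]
print(mean_attn)
```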

The functions operate on a dataloader, so we convert our dataset into a dataloader with batch_size=1 to keep memory usage low, as Longformer attention tensors are large.

In [None]:
dataloader = torch.utils.data.DataLoader(cogs402_test, batch_size=1)
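
A rough estimate shows why a batch size of 1 is a sensible choice. The sketch below uses the shapes printed in this notebook and assumes float32 attention from the model; the accumulator is float64 because `np.zeros` defaults to that type.

```python
layers, heads, seq_len = 12, 12, 2048
band_width = 514  # last dimension printed in the sanity check

bytes_per_float32 = 4
bytes_per_float64 = 8

# Banded attention returned by the model for one example (float32 assumed).
band_bytes = layers * heads * seq_len * band_width * bytes_per_float32
# Dense accumulator of shape (12, 12, 2048, 2048) created with np.zeros.
dense_bytes = layers * heads * seq_len * seq_len * bytes_per_float64

print(f"banded, per example: {band_bytes / 2**20:.0f} MiB")  # 578 MiB
print(f"dense accumulator:   {dense_bytes / 2**30:.1f} GiB")  # 4.5 GiB
```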

We run the function here, passing in the model, the dataloader and the number of tokens you want the attention matrix to cover.

In [None]:
test = compute_aggregated_attn(model, dataloader, cogs402_test.max_length)

(12, 12, 2048, 2048)


100%|██████████| 2500/2500 [14:32:45<00:00, 20.95s/it]


<class 'numpy.ndarray'>


In [None]:
print(type(test))

Lastly, we save our new aggregate attention. Remember to change the path to whatever suits your project's needs. The commented-out line of code saves it in the current working directory instead.

Warning: This line of code may take a long time, as the saved list of colour values can be very large.

In [None]:
torch.save(test, "/content/drive/MyDrive/cogs402longformer/t3-visapplication/resources/news/aggregate_attn.pt")
# torch.save(test, "aggregate_attn.pt")
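
When the saved file is read back, each entry's flat `attn` list can be reshaped into an image. The sketch below is a minimal illustration on a hypothetical 2 x 2 entry that follows the layout produced by `format_attention_image`; it is not the code of the display notebook.

```python
import math
import numpy as np

# Hypothetical entry mimicking one element of the saved list (2 x 2 matrix).
entry = {'layer': 0, 'head': 0,
         'attn': [255, 0, 0, 0, 255, 0, 0, 64, 255, 0, 0, 191, 255, 0, 0, 255]}

flat = np.array(entry['attn'], dtype=np.uint8)
side = math.isqrt(flat.size // 4)    # the flat list holds 4 channels per cell
image = flat.reshape(side, side, 4)  # rows x columns x RGBA

alpha = image[:, :, 3]               # alpha carries the normalized attention
print(alpha)
```

An array of this shape can be passed directly to `plt.imshow`, which interprets a uint8 array with a last dimension of 4 as RGBA.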