This notebook is part of the [T3-vis](https://arxiv.org/abs/2108.13587) implementation for visualizing Transformer neural networks. Using the dataset, model, and functions defined below, we predict over the validation set and collect the attentions for every example. We then aggregate the attentions, normalize them, and save the result to a file for use in another notebook, where we display the aggregate attention to look for patterns.

### Importing Packages

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# import sys
# sys.path.append('/content/drive/My Drive/{}'.format("cogs402longformer/"))

In [3]:
pip install datasets --quiet

In [4]:
pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [5]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

### Import Dataset and Model

Here we import the model and the dataset we want to assess. The import replicates the approach used by the T3-vis implementation, minus a few items such as "idx" and "visualize columns", which are unnecessary here.

In [6]:
from datasets import load_dataset
from transformers import LongformerForSequenceClassification, AutoTokenizer

Here we import the "notes" model and dataset.

In [7]:
def longformer_finetuned_notes():
    test = torch.load("/content/drive/MyDrive/fakeclinicalnotes/models/full_augmented_lr2e-5_dropout3_10_trained_threshold.pt")
    model = LongformerForSequenceClassification.from_pretrained('allenai/longformer-base-4096', state_dict=test['state_dict'], num_labels = 2)
    return model

def preprocess_function(tokenizer, example, max_length):
    example.update(tokenizer(example['text'], padding='max_length', max_length=max_length, truncation=True))
    return example

def get_notes_dataset(dataset_type):
    max_length = 2048
    dataset = load_dataset("danielhou13/cogs402datafake")[dataset_type]

    tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-base-4096')
    # tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    dataset = dataset.map(lambda x: preprocess_function(tokenizer, x, max_length), batched=True)
    setattr(dataset, 'input_columns', ['input_ids', 'attention_mask'])
    setattr(dataset, 'target_columns', ['labels'])
    setattr(dataset, 'max_length', max_length)
    setattr(dataset, 'tokenizer', tokenizer)
    return dataset

def notes_train_set():
    return get_notes_dataset('train')

Here we call the functions to import the dataset and the model, and set the dataset's format so that it returns PyTorch tensors.

In [8]:
cogs402_test = notes_train_set()
model = longformer_finetuned_notes()
columns = cogs402_test.input_columns + cogs402_test.target_columns
print(columns)
cogs402_test.set_format(type='torch', columns=columns)
cogs402_test=cogs402_test.remove_columns(['text'])

Using custom data configuration danielhou13--cogs402datafake-f5349e6cf83e41d8
Reusing dataset parquet (/root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402datafake-f5349e6cf83e41d8/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402datafake-f5349e6cf83e41d8/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-decb6b766464f86e.arrow
Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerForSequenceClassification: ['longformer_model.embeddings.token_type_embeddings.weight', 'longformer_model.encoder.layer.3.attention.self.key_global.weight', 'longformer_model.encoder.layer.8.output.dense.bias', 'longformer_model.embeddings.LayerNorm.weight', 'longformer_model.encoder.layer.5.attention.self.query.weight', 'longformer_model.encoder.layer.0.attention.self.query_global.bias', 'longformer_model.encoder.layer.3.attention.self.key.weight', 'longformer_model.encoder.layer.4.attention.self.value_global.bias', 'longformer_model.encoder.layer.9.attention.self.key_global.bias', 'longformer_model.encoder.layer.7.attention.self.query.weight

['input_ids', 'attention_mask', 'labels']


Don't forget to use your GPU if you have one for faster performance.

In [9]:
if torch.cuda.is_available():
    model = model.cuda()
    cuda0 = torch.device('cuda:0')

print(model.device)

cuda:0


### Functions

The following block is the normalization code from T3-vis. It converts attention values into colour values that we can use for plotting. For each layer and head in the aggregated attention matrix, we take the (seq_len, seq_len) matrix, flatten it, min-max normalize the values to [0, 1], then scale them so each value is between 0 and 255. This scaled array becomes the alpha channel. Arrays of the same shape are then made for the colour channels: red is set to 255, and the colours we do not want (green and blue) are set to 0. Finally, we stack the 4 arrays so that each item in the original matrix now has 4 colour values (red, green, blue, alpha), and convert the result into a flat list, keeping the 4 colour values of each item adjacent.

**The input is an array of shape: (layer, head, seq_len, seq_len)**, since the batch dimension has already been summed out during aggregation.

**The output is a list with one entry per (layer, head) pair; each entry holds a flat list of 4 x seq_len x seq_len values.**

In total, the aggregated attention contains layer x head x seq_len x seq_len values, and each value is expanded into 4 colour channels: red, green, blue, alpha (controls how opaque the colour is).

In [10]:
def format_attention_image(attention):
    formatted_attn = []
    for layer_idx in range(attention.shape[0]):
        for head_idx in range(attention.shape[1]):
            formatted_entry = {
                'layer': layer_idx,
                'head': head_idx
            }

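            # Skip empty heads, then flatten the matrix and min-max normalize it to [0, 1]
            if len(attention[layer_idx, head_idx]) == 0:
                continue
            attn = np.array(attention[layer_idx, head_idx]).flatten()
            attn_range = attn.max() - attn.min()
            # Guard against division by zero when every value in the head is identical
            attn = (attn - attn.min()) / attn_range if attn_range > 0 else np.zeros_like(attn)
            # Alpha encodes attention strength (0-255); the colour itself is pure red
            alpha = np.round(attn * 255)
            red = np.ones_like(alpha) * 255
            green = np.zeros_like(alpha)
            blue = np.zeros_like(alpha)

            # Interleave the channels so each value's R, G, B, A are adjacent in the flat list
            attn_data = np.dstack([red,green,blue,alpha]).reshape(alpha.shape[0] * 4).astype('uint8')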
            formatted_entry['attn'] = attn_data.tolist()
            formatted_attn.append(formatted_entry)
    return formatted_attn
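
To make the colour encoding concrete, here is a minimal sketch of the same idea applied to a single toy 2 x 2 matrix. The helper name `to_rgba` and the toy values are illustrative assumptions, not part of T3-vis; the zero-range guard is also an assumption about how a constant matrix should be handled.

```python
import numpy as np

def to_rgba(matrix):
    """Min-max normalize a 2-D attention matrix and encode it as flat RGBA bytes."""
    flat = np.asarray(matrix, dtype=float).flatten()
    value_range = flat.max() - flat.min()
    # Assumption: a constant matrix maps to alpha 0 instead of dividing by zero.
    scaled = (flat - flat.min()) / value_range if value_range > 0 else np.zeros_like(flat)
    alpha = np.round(scaled * 255)
    red = np.full_like(alpha, 255)
    green = np.zeros_like(alpha)
    blue = np.zeros_like(alpha)
    # dstack gives shape (1, n, 4); flattening keeps each pixel's R, G, B, A adjacent.
    return np.dstack([red, green, blue, alpha]).reshape(-1).astype("uint8").tolist()

toy = [[0.0, 0.5], [0.25, 1.0]]
pixels = to_rgba(toy)
print(pixels[:4])   # first pixel: the lowest attention value, fully transparent red
print(pixels[-4:])  # last pixel: the highest attention value, fully opaque red
```

Every pixel is red; only the alpha channel varies, so stronger attention shows up as a more opaque cell when the image is drawn.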

This block of code, **not found in T3-vis**, adapts the Longformer model's sliding window attention into the traditional attention format of seq_len x seq_len. More information can be found in [this notebook](https://colab.research.google.com/drive/1Kxx26NtIlUzioRCHpsR8IbSz_DpRFxEZ).

**Input: Tensor of shape: (layer, batch, head, seq_len, x + sliding_window_size + 1), where x is the number of tokens with global attention.**

**Output: Tensor of shape: (layer, batch, head, seq_len, seq_len)**

By converting the matrix into a seq_len x seq_len matrix, the T3-vis functions will work as intended.

In [11]:
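def create_head_matrix(output_attentions, global_attentions):
    # output_attentions: (seq_len, x + window + 1) local attention for one head
    # global_attentions: attention from the global token to every position
    seq_len = output_attentions.shape[0]
    new_attention_matrix = torch.zeros((seq_len, seq_len))
    for i in range(seq_len):
        test_non_zeroes = torch.nonzero(output_attentions[i]).squeeze()
        # Drop the first nonzero entry, assumed to be the global-attention column at index 0
        test2 = output_attentions[i][test_non_zeroes[1:]]
        # Shift window offsets to absolute positions; 257 assumes one global token
        # plus a one-sided window of 256 (attention window of 512)
        new_attention_matrix_indices = test_non_zeroes[1:]-257 + i
        new_attention_matrix[i][new_attention_matrix_indices] = test2
        # Column 0 holds each token's attention to the global token
        new_attention_matrix[i][0] = output_attentions[i][0]
    # Row 0 is the global token's attention over the whole sequence
    new_attention_matrix[0] = global_attentions.squeeze()[:seq_len]
    return new_attention_matrix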


def attentions_all_heads(output_attentions, global_attentions):
    new_matrix = []
    for i in range(output_attentions.shape[0]):
        matrix = create_head_matrix(output_attentions[i], global_attentions[i])
        new_matrix.append(matrix)
    return torch.stack(new_matrix)

def all_batches(output_attentions, global_attentions):
    new_matrix = []
    for i in range(output_attentions.shape[0]):
        matrix = attentions_all_heads(output_attentions[i], global_attentions[i])
        new_matrix.append(matrix)
    return torch.stack(new_matrix)

def all_layers(output_attentions, global_attentions):
    new_matrix = []
    for i in range(output_attentions.shape[0]):
        matrix = all_batches(output_attentions[i], global_attentions[i])
        new_matrix.append(matrix)
    return torch.stack(new_matrix)
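
The core of the conversion is an index shift from "offset within the window" to "absolute key position". The sketch below shows that shift on a tiny banded matrix. It is a simplified illustration under stated assumptions: the helper `band_to_full` is hypothetical, it uses NumPy rather than torch, and it deliberately ignores the global-attention columns and the global token's row that the functions above handle.

```python
import numpy as np

def band_to_full(local_attn, one_sided_window):
    """Scatter (seq_len, 2*w + 1) banded attention into a (seq_len, seq_len) matrix."""
    seq_len, width = local_attn.shape
    assert width == 2 * one_sided_window + 1
    full = np.zeros((seq_len, seq_len))
    for query in range(seq_len):
        for offset in range(width):
            key = query - one_sided_window + offset  # absolute key position
            if 0 <= key < seq_len:                   # drop slots that fall off the sequence
                full[query, key] = local_attn[query, offset]
    return full

# 4 tokens, one token of context on each side (w = 1).
local = np.array([
    [0.0, 0.7, 0.3],   # token 0 has no left neighbour
    [0.2, 0.5, 0.3],
    [0.1, 0.6, 0.3],
    [0.4, 0.6, 0.0],   # token 3 has no right neighbour
])
full = band_to_full(local, one_sided_window=1)
print(full)
```

In `create_head_matrix` the same shift appears as `- 257 + i`, which seems to correspond to a one-sided window of 256 plus one leading global-attention column.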

This T3-vis function collects and aggregates the attentions over the entire validation set. It iterates through the dataset, gets the attention for each example, converts each example's attention into the (seq_len, seq_len) shape, then sums the attention across examples. It then divides by the number of non-padding tokens seen at each position, and lastly sends the aggregated attention matrix to the normalizer function, which returns one flat list of 4 x seq_len x seq_len colour values per layer and head.

In [12]:
from tqdm import tqdm
def compute_aggregated_attn(model, dataloader, max_input_len):

    n_layers = model.longformer.config.num_hidden_layers
    n_heads = model.longformer.config.num_attention_heads
    # head_size = int(model.longformer.config.hidden_size / n_heads)
    # n_examples = len(dataloader.dataset)

    # importance_scores = np.zeros((n_layers, n_heads))

    device = model.device
    # total_loss = 0.
    attn = np.zeros((n_layers, n_heads, max_input_len, max_input_len))
    print(attn.shape)
    model.eval()

    attn_normalize_count = torch.zeros(max_input_len, device=device)

    for step, inputs in enumerate(tqdm(dataloader, position=0, leave=True)):

        # batch_size_ = inputs['input_ids'].__len__()

        if torch.cuda.is_available():
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.cuda()
        
        inputs['output_attentions']=True
        
        with torch.no_grad():
            output = model(**inputs)
        
        
        attn_normalize_count += inputs['attention_mask'].sum(dim=0)
        batch_attn = output[-2]
        global_attn = output[-1]
        
        # print(batch_attn[1].shape)
        output_attentions = torch.stack(batch_attn).cpu()

        # print("output_attention.shape", output_attentions.shape)
        global_attentions = torch.stack(global_attn).cpu()
         
        # print(output_attentions.device)
        # print(global_attentions.device)
        
        batch_attn2 = all_layers(output_attentions, global_attentions)
    
        # print(batch_attn2.shape)
        batch_attn = torch.cat([l.sum(dim=0).unsqueeze(0) for l in batch_attn2], dim=0).cpu().numpy()
        
        attn += batch_attn
        
    max_input_len = len(attn_normalize_count.nonzero(as_tuple=False))
    
    attn = attn[:, :, :max_input_len, :max_input_len]
    attn /= attn_normalize_count.cpu().numpy()[:max_input_len]
    print(type(attn))
    formatted_attn = format_attention_image(attn)
    return formatted_attn
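
The aggregation step can be sketched in isolation for a single layer and head. The toy matrices and masks below are made-up values for illustration only; the point is how the per-position token counts are accumulated from the attention masks and then broadcast over the last (key) axis during the division, mirroring `attn /= attn_normalize_count...` above.

```python
import numpy as np

# Hypothetical toy data: two examples, max length 3; the second has one padding token.
attn_per_example = [
    np.array([[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.2, 0.3, 0.5]]),
    np.array([[1.0, 0.0, 0.0], [0.4, 0.6, 0.0], [0.0, 0.0, 0.0]]),
]
masks = [np.array([1, 1, 1]), np.array([1, 1, 0])]

total = np.zeros((3, 3))
counts = np.zeros(3)
for attn, mask in zip(attn_per_example, masks):
    total += attn    # running sum of attention over examples
    counts += mask   # how many examples had a real token at each position

# Trim to the positions that were used by at least one example.
used_len = int(np.count_nonzero(counts))
total = total[:used_len, :used_len]
mean_attn = total / counts[:used_len]   # broadcasts over the last (key) axis
print(mean_attn)
```

Dividing by the counts rather than by the number of examples keeps late positions, which few documents reach, from looking artificially faint.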

The functions operate on a dataloader, so we convert our validation set into a dataloader, using batch_size=1 to minimize memory usage, as Longformer uses a lot of memory.

In [13]:
dataloader = torch.utils.data.DataLoader(cogs402_test, batch_size=1)

We run the function here, passing in the model, the dataloader, and how many tokens you want your attention matrix to visualize.

In [None]:
test = compute_aggregated_attn(model, dataloader, cogs402_test.max_length)

(12, 12, 2048, 2048)


100%|██████████| 12/12 [04:59<00:00, 24.99s/it]


<class 'numpy.ndarray'>


In [None]:
print(type(test))

Lastly, we save our new aggregate attention matrix. Remember to change the path to whatever suits your project's needs. The commented-out line of code saves the attention matrix in the current working directory.

Warning: This line of code may take a long time, as the saved list of colour values can be very large.

In [None]:
torch.save(test, "/content/drive/MyDrive/fakeclinicalnotes/t3-visapplication/notes/aggregate_attn.pt")
# torch.save(test, "aggregate_attn.pt")
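
For reference, a downstream notebook needs to turn each entry's flat list back into an image before plotting. The following is a minimal sketch of that step, assuming the entry layout produced by `format_attention_image`; the helper `entry_to_image` and the 2 x 2 toy entry are hypothetical and only illustrate the reshaping.

```python
import math
import numpy as np

def entry_to_image(entry):
    """Reshape one entry's flat RGBA list into a (seq_len, seq_len, 4) uint8 array."""
    flat = np.asarray(entry["attn"], dtype="uint8")
    seq_len = math.isqrt(flat.size // 4)
    if seq_len * seq_len * 4 != flat.size:
        raise ValueError("attn length is not 4 * seq_len**2")
    return flat.reshape(seq_len, seq_len, 4)

# Hypothetical 2 x 2 entry in the same layout format_attention_image produces.
entry = {"layer": 0, "head": 0,
         "attn": [255, 0, 0, 0,  255, 0, 0, 128,  255, 0, 0, 64,  255, 0, 0, 255]}
image = entry_to_image(entry)
print(image.shape)
print(image[1, 1])   # RGBA of the bottom-right cell
```

An array of this shape can be passed directly to `plt.imshow`, which interprets a (rows, cols, 4) uint8 array as RGBA.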