
This notebook is part of the [T3-vis](https://arxiv.org/abs/2108.13587) implementation for visualizing Transformer neural networks. Using the dataset, the model, and the functions below, we compute how important each attention head in each layer is. These importance scores let us scale the attention output of a Transformer model (in this case Longformer) per head and per layer, so that the attention paid to tokens is emphasized for the heads and layers that matter most.
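As a rough illustration of what scaling "per head and per layer" means, here is a minimal NumPy sketch. It assumes the attention weights have been stacked into an array whose first two axes are `(layer, head)` and that `importance` is the normalized `(layer, head)` matrix this notebook produces; `scale_attention` is a hypothetical helper, not a T3-vis function.

```python
import numpy as np


def scale_attention(attention, importance):
    """Weight each head's attention map by that head's importance score.

    attention:  array of shape (n_layers, n_heads, ...), any trailing dims
    importance: array of shape (n_layers, n_heads), e.g. normalized to [0, 1]
    """
    attention = np.asarray(attention, dtype=float)
    importance = np.asarray(importance, dtype=float)
    assert attention.shape[:2] == importance.shape, "layer/head axes must match"
    # Append singleton axes so each score broadcasts over its whole attention map
    extra_dims = (1,) * (attention.ndim - 2)
    return attention * importance.reshape(importance.shape + extra_dims)


# Toy input: 2 layers x 2 heads, each with a uniform 3x3 attention map
attention = np.full((2, 2, 3, 3), 1.0 / 3)
importance = np.array([[1.0, 0.0],
                       [0.5, 0.25]])
scaled = scale_attention(attention, importance)
print(scaled[0, 0, 0])  # layer 0, head 0: weights unchanged (score 1.0)
print(scaled[0, 1, 0])  # layer 0, head 1: weights zeroed out (score 0.0)
```

Because the scores broadcast over the trailing axes, the same helper works whatever the shape of each per-head attention map (Longformer's local attention maps are not square `seq_len` by `seq_len` matrices).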

### Import Packages

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm
import pdb

In [None]:
pip install datasets --quiet

In [None]:
pip install transformers --quiet

### Import model and dataset

Here we load the model and the dataset we want to assess. The loading code follows the T3-vis implementation, minus a few items that are unnecessary here, such as the "idx" column and the visualize columns.
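`preprocess_function` calls the tokenizer with `padding='max_length'` and `truncation=True`, so every example ends up exactly `max_length` (2048) tokens long. The pure-Python sketch below shows that effect on a list of already-tokenized ids. `pad_or_truncate` is a hypothetical helper, the pad id of 1 is an assumption (RoBERTa-style tokenizers such as Longformer's use it), and unlike the real tokenizer it does not keep the closing special token when truncating.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Mimic padding='max_length' + truncation=True on a list of token ids.

    Simplified: a real tokenizer also preserves special tokens (e.g. </s>)
    when truncating, which this sketch does not do.
    """
    ids = list(token_ids[:max_length])   # truncation: drop tokens past max_length
    attention_mask = [1] * len(ids)      # 1 marks a real token
    n_pad = max_length - len(ids)
    ids += [pad_id] * n_pad              # padding: fill up to max_length
    attention_mask += [0] * n_pad        # 0 marks padding the model should ignore
    return {"input_ids": ids, "attention_mask": attention_mask}


short_example = pad_or_truncate([0, 10, 11, 2], max_length=6, pad_id=1)
long_example = pad_or_truncate(list(range(10)), max_length=6, pad_id=1)
print(short_example)
print(long_example)
```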

In [None]:
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def longformer_finetuned_papers():
    model = AutoModelForSequenceClassification.from_pretrained('danielhou13/longformer-finetuned_papers_v2', num_labels = 2)
    return model

def preprocess_function(tokenizer, example, max_length):
    example.update(tokenizer(example['text'], padding='max_length', max_length=max_length, truncation=True))
    return example

def get_papers_dataset(dataset_type):
    max_length = 2048
    dataset = load_dataset("danielhou13/cogs402dataset")[dataset_type]
    new_col = list(np.arange(0, len(dataset)))

    tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-base-4096')
    # tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    dataset = dataset.map(lambda x: preprocess_function(tokenizer, x, max_length), batched=True)
    setattr(dataset, 'input_columns', ['input_ids', 'attention_mask'])
    setattr(dataset, 'target_columns', ['labels'])
    setattr(dataset, 'max_length', max_length)
    setattr(dataset, 'tokenizer', tokenizer)
    return dataset

def papers_train_set():
    return get_papers_dataset('train')

def papers_test_set():
    return get_papers_dataset('test')

### T3-vis functions

These functions are copied from T3-vis, with BERT swapped out for Longformer.

We iterate over the entire dataset and compute the importance of each head for every layer.
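For the default `'taylor'` measure, `get_taylor_importance` scores a head by summing the squared elementwise product of each parameter and its gradient over the parameters that belong to that head: the head's rows of the query, key, and value weights and biases, plus the matching columns of the attention output projection. The NumPy sketch below reproduces that bookkeeping on toy matrices instead of a real model. `taylor_head_scores` and its dictionary layout are hypothetical, and it assumes no heads have been pruned, so head `h` simply owns indices `h * head_size` through `(h + 1) * head_size - 1`.

```python
import numpy as np


def taylor_head_scores(params, grads, n_heads):
    """Per-head score: sum of (grad * param) ** 2 over the parameters a head owns.

    params / grads: dicts keyed 'q_w', 'q_b', 'k_w', 'k_b', 'v_w', 'v_b', 'o_w'.
    Q/K/V weights are (hidden, hidden) with output features on axis 0, as in
    torch.nn.Linear; 'o_w' is the output projection, whose input features
    (the concatenated heads) lie on axis 1.
    """
    hidden = params['q_w'].shape[0]
    head_size = hidden // n_heads
    scores = np.zeros(n_heads)
    for h in range(n_heads):
        own = slice(h * head_size, (h + 1) * head_size)  # this head's indices
        total = 0.0
        for name in ('q', 'k', 'v'):
            total += ((grads[name + '_w'][own] * params[name + '_w'][own]) ** 2).sum()
            total += ((grads[name + '_b'][own] * params[name + '_b'][own]) ** 2).sum()
        # Output projection: the head owns columns, not rows
        total += ((grads['o_w'][:, own] * params['o_w'][:, own]) ** 2).sum()
        scores[h] = total
    return scores


hidden, n_heads = 8, 2
names = ['q_w', 'q_b', 'k_w', 'k_b', 'v_w', 'v_b', 'o_w']
params = {n: np.ones((hidden, hidden)) if n.endswith('_w') else np.ones(hidden)
          for n in names}
# Toy gradients: only head 0's parameters receive a non-zero gradient
grads = {n: np.zeros_like(p) for n, p in params.items()}
for n in ('q_w', 'q_b', 'k_w', 'k_b', 'v_w', 'v_b'):
    grads[n][:4] = 1.0        # head 0 owns rows 0-3
grads['o_w'][:, :4] = 1.0      # ...and columns 0-3 of the output projection
scores = taylor_head_scores(params, grads, n_heads)
print(scores)
```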

**Output:** an array of shape `(layer, head)`.

Each entry of the array is (after normalization) a value from 0.0 to 1.0 indicating how important that particular head is, with 1.0 being the most important and 0.0 the least.
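The `normalize` function defined in the next cell is called with `axis=None`, so it applies min-max scaling over the whole matrix: the single most important head in the model becomes 1.0 and the least important becomes 0.0, rather than each layer being scaled separately. A small NumPy example with made-up scores (not results from this notebook):

```python
import numpy as np


def normalize(matrix, axis=None):
    # Same min-max formula as the notebook's normalize()
    return (matrix - matrix.min(axis=axis)) / (matrix.max(axis=axis) - matrix.min(axis=axis))


# Made-up raw scores for 2 layers x 3 heads
raw = np.array([[2.0, 4.0, 6.0],
                [10.0, 8.0, 2.0]])
scaled = normalize(raw)
print(scaled)
# The global min is 2.0 and the global max is 10.0, so layer 1 / head 0 maps to 1.0
# and both entries equal to 2.0 map to 0.0.
```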


In [None]:
def normalize(matrix, axis=None):
    normalized = (matrix - matrix.min(axis=axis)) /\
                 (matrix.max(axis=axis) - matrix.min(axis=axis))
    return normalized

def find_pruneable_heads_and_indices(heads, n_heads, head_size, already_pruned_heads):
    """
    List, int, int, set -> Tuple[set, "torch.LongTensor"]
    """

    mask = torch.ones(n_heads, head_size)
    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
    for head in heads:
        # Compute how many pruned heads are before the head and move the index accordingly
        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
        mask[head] = 0

    mask = mask.view(-1).contiguous().eq(1)
    index: torch.LongTensor = torch.arange(len(mask))[~mask].long()
    return heads, index

def get_taylor_importance(model):
    # Architecture sizes live on model.config, not on the model object itself
    n_layers = model.config.num_hidden_layers
    n_heads = model.config.num_attention_heads
    head_size = int(model.config.hidden_size / n_heads)
    importance_scores = np.zeros((n_layers, n_heads))

    for i in range(n_layers):
        attention = model.longformer.encoder.layer[i].attention
        num_attention_heads = attention.self.num_heads

        pruned_heads = attention.pruned_heads
        leftover_heads = set(list(range(n_heads))) - pruned_heads

        for head_idx in leftover_heads:
            heads, index = find_pruneable_heads_and_indices([head_idx], num_attention_heads, head_size, pruned_heads)
            index = index.to(model.device)

            query_b_grad = (attention.self.query.bias.grad[index] *\
                            attention.self.query.bias[index]) ** 2
            query_W_grad = (attention.self.query.weight.grad.index_select(0, index) *\
                            attention.self.query.weight.index_select(0, index)) ** 2

            key_b_grad = (attention.self.key.bias.grad[index] *\
                          attention.self.key.bias[index]) ** 2
            key_W_grad = (attention.self.key.weight.grad.index_select(0, index) *\
                          attention.self.key.weight.index_select(0, index)) ** 2

            value_b_grad = (attention.self.value.bias.grad[index] *\
                            attention.self.value.bias[index]) ** 2
            value_W_grad = (attention.self.value.weight.grad.index_select(0, index) *\
                            attention.self.value.weight.index_select(0, index)) ** 2

            output_W_grad = (attention.output.dense.weight.grad.index_select(1, index) *
                             attention.output.dense.weight.index_select(1, index)) ** 2
            abs_grad_magnitude = query_b_grad.sum() + query_W_grad.sum() + key_b_grad.sum() + \
                key_W_grad.sum() + value_b_grad.sum() + value_W_grad.sum() + output_W_grad.sum()

                
            score = abs_grad_magnitude.item()
            importance_scores[i, head_idx] += score
    return importance_scores


def compute_importance(model, dataloader, measure='taylor'):

    assert measure in ['taylor', 'oracle', 'sensitivity']

    max_input_len = model.config.max_position_embeddings
    n_layers = model.config.num_hidden_layers
    n_heads = model.config.num_attention_heads
    head_size = int(model.config.hidden_size / n_heads)

    importance_scores = np.zeros((n_layers, n_heads))

    device = model.device
    total_loss = 0.

    if measure == 'sensitivity':
        head_mask = torch.ones(n_layers, n_heads).to(device)
        head_mask.requires_grad_(requires_grad=True)
    else:
        head_mask = None

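    model.zero_grad()  # start from clean gradients; they accumulate over the loop below

    for step, inputs in enumerate(tqdm(dataloader)):
        batch_size_ = inputs['input_ids'].__len__()

        # Move every tensor in the batch to the model's device
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)

        # Pass head_mask so that head_mask.grad is populated for 'sensitivity'
        output = model(**inputs, head_mask=head_mask)
        loss = output['loss']
        loss.backward()

        if measure == 'sensitivity':
            importance_scores += head_mask.grad.abs().detach().cpu().numpy()
            head_mask.grad.zero_()  # keep each batch's gradient separate

    if measure == 'taylor':
        # Parameter gradients are summed over all batches (they are never zeroed
        # inside the loop), so the Taylor scores only need computing once, here.
        importance_scores = get_taylor_importance(model)

    return importance_scores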

Here we call the functions defined above to load the dataset and the model, and set the dataset's format so that it returns PyTorch tensors.

In [None]:
cogs402_test = papers_train_set()
model = longformer_finetuned_papers()
columns = cogs402_test.input_columns + cogs402_test.target_columns
print(columns)
cogs402_test.set_format(type='torch', columns=columns)  # leave out 'idx': the model's forward() does not accept it
cogs402_test=cogs402_test.remove_columns(['text'])
print(cogs402_test)

Using custom data configuration danielhou13--cogs402dataset-144b958ac1a53abb


Downloading and preparing dataset None/None (download: 157.87 MiB, generated: 311.56 MiB, post-processed: Unknown size, total: 469.43 MiB) to /root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402dataset-144b958ac1a53abb/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8...


Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/danielhou13___parquet/danielhou13--cogs402dataset-144b958ac1a53abb/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8. Subsequent calls will reuse this data.


['input_ids', 'attention_mask', 'labels']
Dataset({
    features: ['labels', 'idx', 'input_ids', 'attention_mask'],
    num_rows: 4280
})


Don't forget to send the model to GPU if you have one.

In [None]:
if torch.cuda.is_available():
    model = model.cuda()

print(model.device)

cuda:0


The importance functions consume batches from a PyTorch `DataLoader`, so we wrap our dataset in one. T3-vis does this with its validation set, whereas here `cogs402_test` holds the training split returned by `papers_train_set()`. We keep the batch size at 1 to minimize memory use, since Longformer models are memory intensive.
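To get a feel for why batch size matters here, the sketch below estimates the memory taken by just one intermediate tensor, the sliding-window attention scores, summed over all layers. It assumes the configuration of `allenai/longformer-base-4096` carries over to the fine-tuned model (12 layers, 12 heads, an `attention_window` of 512, i.e. 256 tokens on each side plus the token itself) and 4-byte floats; `attention_score_bytes` is a hypothetical helper. Real usage is considerably higher once other activations, gradients, global attention, and the model weights are counted.

```python
def attention_score_bytes(batch_size, seq_len=2048, n_layers=12, n_heads=12,
                          attention_window=512, bytes_per_value=4):
    """Rough size of the sliding-window attention scores across all layers.

    Each token attends to attention_window // 2 tokens on each side plus itself.
    """
    window = 2 * (attention_window // 2) + 1
    per_layer = batch_size * seq_len * n_heads * window * bytes_per_value
    return per_layer * n_layers


for bs in (1, 2, 4):
    print(f"batch_size={bs}: ~{attention_score_bytes(bs) / 2**20:.0f} MiB of attention scores")
```

Every term is multiplied by `batch_size`, so this part of the memory footprint grows linearly with it.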

In [None]:
val_dataloader = torch.utils.data.DataLoader(cogs402_test, batch_size=1)

Next, we run `compute_importance` (which uses the Taylor measure by default) and normalize the resulting matrix.

In [None]:
importance = compute_importance(model, val_dataloader)
importance = normalize(importance)

Lastly, we save the importance matrix. Remember to change the path to suit your project; the commented-out line is an alternative that saves the matrix to a local file instead of Google Drive.


In [None]:
torch.save(importance, "/content/drive/MyDrive/cogs402longformer/t3-visapplication/resources/papers/pretrained/head_importance.pt")
# torch.save(importance, "head_importance.pt")  # relative path: saves to the current working directory