In [19]:
import os
import sys

# Make the project root importable so the `src.*` imports below resolve.
# NOTE(review): the path is relative to the kernel's working directory, so this assumes the
# notebook is started three directories below the repo root — TODO confirm / make configurable.
sys.path.append("../../../")
# Presumably disables HF `tokenizers` internal parallelism, which is known to warn or deadlock
# when DataLoader workers fork (num_workers > 0 is used below) — verify against tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [20]:
import copy
import random
from datetime import datetime

import numpy as np
import torch

from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset


In [21]:
# Experiment configuration — every tunable for this notebook lives here.
name = "bert-4-128-yahoo"  # model/dataset key passed to Config
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced by the cells in this notebook
batch_size = 16
num_workers = 4
num_samples = 128  # samples per concern handed to get_similarity
ratio = 0.5  # NOTE(review): not referenced by the cells in this notebook
seed = 44

# `seed` was declared but never applied; seed every RNG the experiment may touch so the
# sampling-based similarity metrics are reproducible across Restart & Run All.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [22]:
# Build the experiment config from the model/dataset key and device; later cells read
# `config.device` and `config.num_labels` from it.
config = Config(name, device)

In [23]:
# Load the model described by `config`; this original model is kept unmodified and
# deep-copied before each pruning experiment below.
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [24]:
# Train / validation / test dataloaders; with do_cache=True the cached split
# pickles are reused when present (see the log output of this cell).
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
) = load_data(config, do_cache=True, batch_size=batch_size, num_workers=num_workers)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [25]:
import copy
import torch
import numpy as np
from typing import *
import torch.nn as nn
from functools import partial
from transformers.pytorch_utils import (
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)


def calculate_head_importance(
    model,
    config,
    data,
    normalize_scores_by_layer=True,
):
    """Estimate an importance score for every attention head of a BERT-style model.

    A forward hook on each layer's ``attention.self`` module captures its context
    output (``output[0]``) and a tensor hook captures the gradient of the loss
    w.r.t. that output. Per batch, a head's score is the sum over batch and
    tokens of ``|<dL/dctx_h, ctx_h>|``. Scores are summed over all batches and
    divided by the total number of attended tokens (sum of ``attention_mask``).

    Args:
        model: model exposing ``model.bert.config.num_hidden_layers``,
            ``model.bert.config.num_attention_heads`` and
            ``model.bert.encoder.layer[i].attention.self``; calling it with
            ``labels=`` must return an object with a ``.loss`` attribute.
        config: object whose ``device`` attribute selects where batches and
            scores are placed; if missing/falsy the model's parameter device is used.
        data: iterable of dict batches holding ``attention_mask``, ``labels`` and
            either ``input_ids`` or ``embeddings`` (passed as ``inputs_embeds``).
            One-shot iterators are supported (no batch is consumed twice).
        normalize_scores_by_layer: if True, each layer's row of scores is divided
            by its L2 norm (plus 1e-20), as in ``normalize``.

    Returns:
        Tensor of shape ``(num_hidden_layers, num_attention_heads)``.

    Raises:
        ValueError: if ``data`` yields no batches.
    """
    model.eval()  # disable dropout so scores are deterministic
    device = getattr(config, "device", None) or next(model.parameters()).device

    n_layers = model.bert.config.num_hidden_layers
    n_heads = model.bert.config.num_attention_heads

    contexts = {}
    gradients = {}

    def _forward_hook(layer_idx, module, inputs, output):
        # Store this layer's context output and arrange for its gradient to be stored too.
        context = output[0]
        contexts[layer_idx] = context

        def _save_grad(grad):
            gradients[layer_idx] = grad

        context.register_hook(_save_grad)

    def _split_heads(tensor):
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        batch, seq_len, hidden = tensor.shape
        return tensor.reshape(batch, seq_len, n_heads, hidden // n_heads).permute(0, 2, 1, 3)

    handles = [
        model.bert.encoder.layer[layer_idx].attention.self.register_forward_hook(
            partial(_forward_hook, layer_idx)
        )
        for layer_idx in range(n_layers)
    ]

    head_importance = torch.zeros(n_layers, n_heads, device=device)
    tot_tokens = 0
    n_batches = 0
    try:
        for batch in data:
            n_batches += 1
            # Drop the previous batch's tensors so a missing hook can never reuse stale values.
            contexts.clear()
            gradients.clear()

            input_mask = batch["attention_mask"].to(device)
            label_ids = batch["labels"].to(device)
            if "embeddings" in batch:
                outputs = model(
                    inputs_embeds=batch["embeddings"].to(device),
                    attention_mask=input_mask,
                    labels=label_ids,
                )
            else:
                outputs = model(
                    batch["input_ids"].to(device),
                    attention_mask=input_mask,
                    labels=label_ids,
                )
            outputs.loss.backward()

            with torch.no_grad():
                for layer_idx in range(n_layers):
                    ctx = _split_heads(contexts[layer_idx])
                    grad_ctx = _split_heads(gradients[layer_idx])
                    dot = torch.einsum("bhli,bhli->bhl", grad_ctx, ctx)
                    head_importance[layer_idx] += dot.abs().sum(-1).sum(0)
                tot_tokens = tot_tokens + input_mask.float().sum()
    finally:
        # Always detach the hooks and drop the parameter gradients accumulated by the
        # backward passes, so the model is left as it was found.
        for handle in handles:
            handle.remove()
        model.zero_grad()

    if n_batches == 0:
        raise ValueError("calculate_head_importance: `data` yielded no batches")

    # Average over tokens for EVERY layer (the last layer was previously skipped).
    head_importance /= tot_tokens

    if normalize_scores_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    return head_importance

def normalize(tensors):
    """Scale each vector along the last dimension to unit L2 norm, in place.

    A 1e-20 epsilon is added to every norm so all-zero rows stay zero instead
    of becoming NaN. The input tensor is modified and also returned.
    """
    p = 2
    row_norms = tensors.pow(p).sum(dim=-1).pow(1 / p)
    tensors /= row_norms.unsqueeze(-1) + 1e-20
    return tensors

def head_importance_prunning(model, prune_list):
    """Prune (via ``prune_heads`` with ``method=None``) each listed attention head.

    Args:
        model: model exposing ``model.bert.encoder.layer[i].attention``.
        prune_list: iterable of ``(layer_index, head_index)`` pairs; a pair that
            appears more than once is only pruned the first time.

    Prints the sorted list of distinct ``(layer_index, head_index)`` pairs pruned.
    """
    seen = set()
    for layer_index, head_index in prune_list:
        key = (layer_index, head_index)
        if key in seen:
            continue
        attention = model.bert.encoder.layer[layer_index].attention
        prune_heads(attention, [head_index], method=None)
        seen.add(key)
    print(sorted(seen))


def prune_heads(layer, heads, method):
    """Remove the given attention heads from one BERT attention block.

    Args:
        layer: attention module exposing ``layer.self.query/key/value``,
            ``layer.self.num_attention_heads``, ``layer.self.attention_head_size``,
            ``layer.self.all_head_size``, ``layer.output.dense`` and the set
            ``layer.pruned_heads``.
        heads: head indices, in the original (unpruned) numbering, to remove.
        method: ``None``, ``"unstructed"`` or ``"unstructured"`` zero out the heads'
            weights in place (tensor shapes unchanged, ``pruned_heads`` untouched);
            ``"structed"`` or ``"structured"`` physically drop their rows/columns
            with ``prune_linear_layer`` and record them in ``layer.pruned_heads``.

    Raises:
        ValueError: if ``method`` is not one of the values above.
    """
    unstructured = method in (None, "unstructed", "unstructured")
    if not unstructured and method not in ("structed", "structured"):
        raise ValueError(
            f"unknown pruning method {method!r}; expected None, 'unstructed' or 'structed'"
        )
    if len(heads) == 0:
        return
    heads, index = find_pruneable_heads_and_indices(
        heads,
        layer.self.num_attention_heads,
        layer.self.attention_head_size,
        layer.pruned_heads,
    )
    head_size = layer.self.attention_head_size

    if unstructured:
        # Translate original head ids to their current positions, accounting for heads
        # already removed structurally (identity when pruned_heads is empty).
        positions = {
            head - sum(1 for pruned in layer.pruned_heads if pruned < head)
            for head in heads
        }
        layer.self.query = zero_out_head_weights(layer.self.query, positions, head_size)
        layer.self.key = zero_out_head_weights(layer.self.key, positions, head_size)
        layer.self.value = zero_out_head_weights(layer.self.value, positions, head_size)
        # The output projection consumes the concatenated heads: zero its input columns.
        layer.output.dense = zero_out_head_weights(
            layer.output.dense, positions, head_size, dim=1
        )
        return

    layer.self.query = prune_linear_layer(layer.self.query, index)
    layer.self.key = prune_linear_layer(layer.self.key, index)
    layer.self.value = prune_linear_layer(layer.self.value, index)
    # Prune the output projection along its input dimension (dim=1), not its outputs.
    layer.output.dense = prune_linear_layer(layer.output.dense, index, dim=1)

    layer.self.num_attention_heads = layer.self.num_attention_heads - len(heads)
    layer.self.all_head_size = head_size * layer.self.num_attention_heads
    layer.pruned_heads = layer.pruned_heads.union(heads)


def zero_out_head_weights(
    layer: nn.Linear, heads: Set[int], head_size: int, dim: int = 0
) -> nn.Linear:
    """
    Zero out, in place, the weights belonging to the specified heads of a linear layer.

    Head ``h`` owns the slice ``[h * head_size, (h + 1) * head_size)`` of the weight
    along ``dim``. Biases are left unchanged.

    Args:
        layer (`torch.nn.Linear`): The layer to modify.
        heads (`Set[int]`): The indices of heads to zero out (any iterable of ints).
        head_size (`int`): The size of each head.
        dim (`int`, *optional*, defaults to 0): 0 zeroes weight rows (output features),
            1 zeroes weight columns (input features).

    Returns:
        `torch.nn.Linear`: The same layer object, with the heads' weights zeroed.

    Raises:
        ValueError: if ``dim`` is not 0 or 1, or a head's slice falls outside the weight.
            Validation happens before any weight is modified.
    """
    if dim not in (0, 1):
        raise ValueError(f"dim must be 0 or 1, got {dim}")
    extent = layer.weight.shape[dim]
    heads = list(heads)  # materialize so iterators can be validated and then applied
    for head in heads:
        if head < 0 or (head + 1) * head_size > extent:
            raise ValueError(
                f"head {head} of size {head_size} is out of range for a weight "
                f"of size {extent} along dim {dim}"
            )

    # Modify the parameter in place without recording the edit in autograd.
    with torch.no_grad():
        for head in heads:
            layer.weight.narrow(dim, head * head_size, head_size).zero_()

    return layer


In [26]:
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)

In [27]:
# Zero out heads 0-3 of encoder layer 0 in a copy of the model, then compare it to the original.
prune_list = [(0, head) for head in range(4)]

module = copy.deepcopy(model)  # prune a copy so `model` stays intact for the comparisons
head_importance_prunning(module, prune_list)

result = evaluate_model(module, config, test_dataloader)
# Keep the per-concern similarity results instead of discarding them.
similarities = [
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    for concern in range(config.num_labels)
]
# Bind both perplexities: a bare call in the middle of a cell discards its value, so the
# original model's perplexity was never shown (only the cell's last expression is displayed).
original_perplexity = get_perplexity(model, valid_dataloader, config)
pruned_perplexity = get_perplexity(module, valid_dataloader, config)
print("original model's perplexity", original_perplexity)
print("pruned model's perplexity", pruned_perplexity)

[(0, 0), (0, 1), (0, 2), (0, 3)]


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3199
Precision: 0.6381, Recall: 0.5909, F1-Score: 0.5984
              precision    recall  f1-score   support

           0     0.4615    0.5508    0.5022      2992
           1     0.6804    0.3920    0.4975      2992
           2     0.7023    0.5780    0.6341      3012
           3     0.3284    0.6464    0.4356      2998
           4     0.7947    0.6757    0.7304      2973
           5     0.8383    0.7420    0.7872      3054
           6     0.6815    0.3876    0.4942      3003
           7     0.5336    0.6700    0.5941      3012
           8     0.6162    0.6338    0.6249      2982
           9     0.7438    0.6328    0.6838      2982

    accuracy                         0.5911     30000
   macro avg     0.6381    0.5909    0.5984     30000
weighted avg     0.6383    0.5911    0.5986     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

3.567145586013794

In [28]:
# Zero out heads 0-3 of encoder layer 1 in a copy of the model, then compare it to the original.
prune_list = [(1, head) for head in range(4)]

module = copy.deepcopy(model)  # prune a copy so `model` stays intact for the comparisons
head_importance_prunning(module, prune_list)

result = evaluate_model(module, config, test_dataloader)
# Keep the per-concern similarity results instead of discarding them.
similarities = [
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    for concern in range(config.num_labels)
]
# Bind both perplexities: a bare call in the middle of a cell discards its value, so the
# original model's perplexity was never shown (only the cell's last expression is displayed).
original_perplexity = get_perplexity(model, valid_dataloader, config)
pruned_perplexity = get_perplexity(module, valid_dataloader, config)
print("original model's perplexity", original_perplexity)
print("pruned model's perplexity", pruned_perplexity)

[(1, 0), (1, 1), (1, 2), (1, 3)]


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2621
Precision: 0.6428, Recall: 0.6038, F1-Score: 0.6115
              precision    recall  f1-score   support

           0     0.4305    0.5943    0.4993      2992
           1     0.6753    0.4783    0.5600      2992
           2     0.6894    0.5916    0.6368      3012
           3     0.3475    0.6181    0.4448      2998
           4     0.7432    0.7592    0.7511      2973
           5     0.8539    0.7403    0.7931      3054
           6     0.6627    0.4063    0.5037      3003
           7     0.6209    0.6301    0.6255      3012
           8     0.6287    0.6251    0.6269      2982
           9     0.7764    0.5949    0.6736      2982

    accuracy                         0.6039     30000
   macro avg     0.6428    0.6038    0.6115     30000
weighted avg     0.6431    0.6039    0.6117     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

3.337691307067871

In [29]:
# Zero out heads 0-3 of encoder layer 2 in a copy of the model, then compare it to the original.
prune_list = [(2, head) for head in range(4)]

module = copy.deepcopy(model)  # prune a copy so `model` stays intact for the comparisons
head_importance_prunning(module, prune_list)

result = evaluate_model(module, config, test_dataloader)
# Keep the per-concern similarity results instead of discarding them.
similarities = [
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    for concern in range(config.num_labels)
]
# Bind both perplexities: a bare call in the middle of a cell discards its value, so the
# original model's perplexity was never shown (only the cell's last expression is displayed).
original_perplexity = get_perplexity(model, valid_dataloader, config)
pruned_perplexity = get_perplexity(module, valid_dataloader, config)
print("original model's perplexity", original_perplexity)
print("pruned model's perplexity", pruned_perplexity)

[(2, 0), (2, 1), (2, 2), (2, 3)]


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2475
Precision: 0.6431, Recall: 0.6069, F1-Score: 0.6149
              precision    recall  f1-score   support

           0     0.5066    0.5157    0.5111      2992
           1     0.6570    0.5000    0.5678      2992
           2     0.6923    0.6125    0.6500      3012
           3     0.3238    0.6448    0.4311      2998
           4     0.7521    0.7346    0.7432      2973
           5     0.8246    0.7590    0.7905      3054
           6     0.6528    0.4126    0.5056      3003
           7     0.6578    0.5744    0.6133      3012
           8     0.6226    0.6616    0.6415      2982
           9     0.7415    0.6543    0.6952      2982

    accuracy                         0.6071     30000
   macro avg     0.6431    0.6069    0.6149     30000
weighted avg     0.6434    0.6071    0.6151     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

3.3074309825897217

In [30]:
# Zero out heads 0-3 of encoder layer 3 in a copy of the model, then compare it to the original.
prune_list = [(3, head) for head in range(4)]

module = copy.deepcopy(model)  # prune a copy so `model` stays intact for the comparisons
head_importance_prunning(module, prune_list)

result = evaluate_model(module, config, test_dataloader)
# Keep the per-concern similarity results instead of discarding them.
similarities = [
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    for concern in range(config.num_labels)
]
# Bind both perplexities: a bare call in the middle of a cell discards its value, so the
# original model's perplexity was never shown (only the cell's last expression is displayed).
original_perplexity = get_perplexity(model, valid_dataloader, config)
pruned_perplexity = get_perplexity(module, valid_dataloader, config)
print("original model's perplexity", original_perplexity)
print("pruned model's perplexity", pruned_perplexity)

[(3, 0), (3, 1), (3, 2), (3, 3)]


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2442
Precision: 0.6263, Recall: 0.6145, F1-Score: 0.6145
              precision    recall  f1-score   support

           0     0.4883    0.5150    0.5013      2992
           1     0.6200    0.5709    0.5944      2992
           2     0.6390    0.6335    0.6362      3012
           3     0.3763    0.5143    0.4346      2998
           4     0.7014    0.7656    0.7321      2973
           5     0.7875    0.7839    0.7857      3054
           6     0.6865    0.3813    0.4903      3003
           7     0.6610    0.6142    0.6367      3012
           8     0.6293    0.6536    0.6412      2982
           9     0.6736    0.7129    0.6927      2982

    accuracy                         0.6146     30000
   macro avg     0.6263    0.6145    0.6145     30000
weighted avg     0.6265    0.6146    0.6147     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

3.2736523151397705