In [1]:
import os
import sys

sys.path.append("../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset


In [3]:
name = "bert-4-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 128
ratio = 0.5
seed = 44

In [4]:
config = Config(name, device)

In [5]:
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [7]:
import copy
import torch
import numpy as np
from typing import *
import torch.nn as nn
from functools import partial
from transformers.pytorch_utils import (
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)


def calculate_head_importance(
    model,
    config,
    data,
    normalize_scores_by_layer=True,
):
    device = config.device
    from functools import partial

    gradients = {}
    context_layers = {}

    def save_grad(gradients, layer_idx, grad):
        gradients[f"context_layer_{layer_idx}"] = grad

    def forward_hook(module, input, output, gradients, context_layers, layer_idx):
        context_layers[f"context_layer_{layer_idx}"] = output[0]
        output[0].register_hook(partial(save_grad, gradients, layer_idx))

    def reshape(tensors, shape, num_heads):
        batch_size = shape[0]
        seq_len = shape[1]
        head_dim = shape[2] // num_heads
        tensors = tensors.reshape(batch_size, seq_len, num_heads, head_dim)
        tensors = tensors.permute(0, 2, 1, 3)
        return tensors

    forward_handles = []

    for layer_idx in range(model.bert.config.num_hidden_layers):
        self_att = model.bert.encoder.layer[layer_idx].attention.self
        handle = self_att.register_forward_hook(
            partial(
                forward_hook,
                gradients=gradients,
                context_layers=context_layers,
                layer_idx=layer_idx,
            )
        )
        forward_handles.append(handle)

    """Calculate head importance scores"""
    # Disable dropout
    model.eval()
    # Device
    device = device or next(model.parameters()).device

    # Prepare data loader
    # Head importance tensor
    n_layers = model.bert.config.num_hidden_layers
    n_heads = model.bert.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(device)
    tot_tokens = 0
    first_batch = next(iter(data))
    is_embeds = "embeddings" in first_batch
    for step, batch in enumerate(data):
        if is_embeds:
            embeddings = batch["embeddings"].to(device)
        else:
            input_ids = batch["input_ids"].to(device)
        input_mask = batch["attention_mask"].to(device)
        label_ids = batch["labels"].to(device)
        # Compute gradients
        if is_embeds:
            loss = model(
                inputs_embeds=embeddings, attention_mask=input_mask, labels=label_ids
            ).loss
        else:
            loss = model(input_ids, attention_mask=input_mask, labels=label_ids).loss
        loss.backward()

        for layer_idx in range(model.bert.config.num_hidden_layers):
            ctx = context_layers[f"context_layer_{layer_idx}"]
            grad_ctx = gradients[f"context_layer_{layer_idx}"]
            shape = ctx.shape
            ctx = reshape(ctx, shape, n_heads)
            grad_ctx = reshape(grad_ctx, shape, n_heads)

            # Take the dot
            dot = torch.einsum("bhli,bhli->bhl", [grad_ctx, ctx])
            head_importance[layer_idx] += dot.abs().sum(-1).sum(0).detach()
            del ctx, grad_ctx, dot

        tot_tokens += input_mask.float().detach().sum().data

    head_importance[:-1] /= tot_tokens

    for handle in forward_handles:
        handle.remove()
    return head_importance

def normalize(tensors):
    exponent = 2
    norm_by_layer = torch.pow(
        torch.pow(tensors, exponent).sum(-1), 1 / exponent
    )
    tensors /= norm_by_layer.unsqueeze(-1) + 1e-20
    return tensors

def head_importance_prunning(
    model, config, dominant_concern, sparsity_ratio, method="unstructed", scheduler=None
):
    num_attention_heads = model.config.num_attention_heads
    num_hidden_layers = model.config.num_hidden_layers
    model = model.to(config.device)
    total_heads_to_prune = int(num_attention_heads * num_hidden_layers * sparsity_ratio)

    total_heads_to_prune = max(total_heads_to_prune, num_hidden_layers)
    print(f"Total heads to prune: {total_heads_to_prune}")
    pruned_heads = set()

    if scheduler is not None:
        steps = scheduler.get_steps()
    else:
        steps = [0.25, 0.25, 0.25, 0.25]
        # steps = [1.0]

    for step_ratio in steps:
        heads_to_prune = int(total_heads_to_prune * step_ratio)
        
        head_importance_list = calculate_head_importance(
            model, config, dominant_concern
        )
        head_importance_list = head_importance_list.cpu()
        print(f"head_importance_list\n {head_importance_list}")
        head_importance_list = normalize(head_importance_list)
        
        head_score = head_importance_list
        
        print(f"head_score\n {head_score}")
        if method == "unstructed":
            sorted_indices = torch.argsort(head_score.view(-1))
            prune_list = []
            for idx in sorted_indices:
                layer_index = int(idx // num_attention_heads)
                head_index = int(idx % num_attention_heads)

                if (layer_index, head_index) not in pruned_heads:
                    prune_list.append((layer_index, head_index))
                
                if len(prune_list) >= heads_to_prune:
                    break
            
        elif method == "structed":
            heads_per_layer = heads_to_prune // num_hidden_layers
            prune_list = []
            for layer_idx in range(num_hidden_layers):
                sorted_heads = torch.argsort(head_importance_list[layer_idx])
                prune_list.extend(
                    [
                        (layer_idx, head.item())
                        for head in sorted_heads[:heads_per_layer]
                    ]
                )
        for layer_index, head_index in prune_list:
            if (layer_index, head_index) not in pruned_heads:
                prune_heads(
                    model.bert.encoder.layer[layer_index].attention,
                    [head_index],
                    method=method,
                )
                pruned_heads.add((layer_index, head_index))
        print(sorted(pruned_heads))


def prune_heads(layer, heads, method):
    if len(heads) == 0:
        return
    heads, index = find_pruneable_heads_and_indices(
        heads,
        layer.self.num_attention_heads,
        layer.self.attention_head_size,
        layer.pruned_heads,
    )

    # if method == "unstructed":
    layer.self.query = zero_out_head_weights(
        layer.self.query, heads, layer.self.attention_head_size
    )
    layer.self.key = zero_out_head_weights(
        layer.self.key, heads, layer.self.attention_head_size
    )
    layer.self.value = zero_out_head_weights(
        layer.self.value, heads, layer.self.attention_head_size
    )
    layer.output.dense = zero_out_head_weights(
        layer.output.dense, heads, layer.self.attention_head_size, dim=1
    )
    # elif method == "structed":
    #     layer.self.query = prune_linear_layer(layer.self.query, index)
    #     layer.self.key = prune_linear_layer(layer.self.key, index)
    #     layer.self.value = prune_linear_layer(layer.self.value, index)
    #     layer.output.dense = prune_linear_layer(layer.output.dense, index)

    #     layer.self.num_attention_heads = layer.self.num_attention_heads - len(heads)
    #     layer.self.all_head_size = layer.self.attention_head_size *  layer.self.num_attention_heads
    #     layer.pruned_heads = layer.pruned_heads.union(heads)


def zero_out_head_weights(
    layer: nn.Linear, heads: Set[int], head_size: int, dim: int = 0
) -> nn.Linear:
    """
    Zero out the weights of the specified heads in the linear layer.

    Args:
        layer (`torch.nn.Linear`): The layer to modify.
        heads (`Set[int]`): The indices of heads to zero out.
        head_size (`int`): The size of each head.
        dim (`int`, *optional*, defaults to 0): The dimension on which to zero out the weights.

    Returns:
        `torch.nn.Linear`: The modified layer with weights of specified heads zeroed out.
    """
    for head in heads:
        start_index = head * head_size
        end_index = (head + 1) * head_size
        if dim == 0:
            layer.weight.data[start_index:end_index] = 0
        elif dim == 1:
            layer.weight.data[:, start_index:end_index] = 0

    return layer


In [8]:
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)

In [9]:
result_list = []

for concern in range(config.num_labels):
  config.init_seed()
  positive_samples = SamplingDataset(
      train_dataloader,
      config,
      concern,
      num_samples,
      True,
      4,
      resample=False,
  )
  negative_samples = SamplingDataset(
      train_dataloader,
      config,
      concern,
      num_samples,
      False,
      4,
      resample=False,
  )
  module = copy.deepcopy(model)

  head_importance_prunning(module, config, positive_samples, ratio)

  print(f"Evaluate the pruned model {concern}")
  result = evaluate_model(module, config, test_dataloader)
  result_list.append(result)
  get_similarity(model, module, valid_dataloader, concern, num_samples, config)
  print("original model's perplexity")
  get_perplexity(model, valid_dataloader, config)
  print("pruned model's perplexity")
  get_perplexity(module, valid_dataloader, config)
  

Total heads to prune: 8
head_importance_list
 tensor([[1.8039e-03, 2.0566e-03, 1.8249e-03, 1.8589e-03],
        [1.5613e-03, 2.5663e-03, 2.0641e-03, 2.2096e-03],
        [1.6074e-03, 2.3024e-03, 1.6433e-03, 1.7042e-03],
        [1.9489e+00, 1.6367e+00, 2.2909e+00, 1.9624e+00]])
head_score
 tensor([[0.4775, 0.5444, 0.4831, 0.4921],
        [0.3663, 0.6021, 0.4843, 0.5184],
        [0.4377, 0.6269, 0.4474, 0.4640],
        [0.4938, 0.4147, 0.5804, 0.4972]])
[(1, 0), (3, 1)]
head_importance_list
 tensor([[1.8696e-03, 1.9767e-03, 1.9059e-03, 1.8763e-03],
        [0.0000e+00, 2.6282e-03, 2.0795e-03, 2.2699e-03],
        [1.5154e-03, 2.1690e-03, 1.7973e-03, 1.6888e-03],
        [2.0914e+00, 0.0000e+00, 2.5977e+00, 2.2976e+00]])
head_score
 tensor([[0.4900, 0.5181, 0.4996, 0.4918],
        [0.0000, 0.6493, 0.5137, 0.5608],
        [0.4190, 0.5997, 0.4969, 0.4669],
        [0.5164, 0.0000, 0.6414, 0.5673]])
[(1, 0), (2, 0), (2, 3), (3, 1)]
head_importance_list
 tensor([[1.8762e-03, 2.0669e-03,

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3344
Precision: 0.6432, Recall: 0.5870, F1-Score: 0.5965
              precision    recall  f1-score   support

           0     0.4406    0.5973    0.5071      2992
           1     0.6312    0.5010    0.5586      2992
           2     0.6487    0.6315    0.6400      3012
           3     0.3183    0.6448    0.4262      2998
           4     0.7624    0.7101    0.7353      2973
           5     0.8481    0.6637    0.7447      3054
           6     0.7567    0.3367    0.4660      3003
           7     0.5724    0.6232    0.5967      3012
           8     0.6743    0.5761    0.6213      2982
           9     0.7789    0.5858    0.6687      2982

    accuracy                         0.5871     30000
   macro avg     0.6432    0.5870    0.5965     30000
weighted avg     0.6434    0.5871    0.5966     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3359
Precision: 0.6407, Recall: 0.5867, F1-Score: 0.5960
              precision    recall  f1-score   support

           0     0.4324    0.6080    0.5053      2992
           1     0.6342    0.4873    0.5511      2992
           2     0.6404    0.6345    0.6374      3012
           3     0.3243    0.6334    0.4290      2998
           4     0.7681    0.6862    0.7248      2973
           5     0.8774    0.6516    0.7478      3054
           6     0.7505    0.3457    0.4733      3003
           7     0.5737    0.6139    0.5931      3012
           8     0.6577    0.5812    0.6171      2982
           9     0.7486    0.6251    0.6813      2982

    accuracy                         0.5867     30000
   macro avg     0.6407    0.5867    0.5960     30000
weighted avg     0.6410    0.5867    0.5962     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3256
Precision: 0.6312, Recall: 0.5844, F1-Score: 0.5907
              precision    recall  f1-score   support

           0     0.4046    0.6213    0.4900      2992
           1     0.6169    0.5582    0.5861      2992
           2     0.6421    0.6451    0.6436      3012
           3     0.3416    0.5807    0.4301      2998
           4     0.6661    0.8032    0.7283      2973
           5     0.8263    0.6372    0.7195      3054
           6     0.7132    0.3593    0.4779      3003
           7     0.6174    0.5664    0.5908      3012
           8     0.6751    0.5540    0.6086      2982
           9     0.8091    0.5188    0.6322      2982

    accuracy                         0.5844     30000
   macro avg     0.6312    0.5844    0.5907     30000
weighted avg     0.6315    0.5844    0.5908     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3244
Precision: 0.6409, Recall: 0.5880, F1-Score: 0.5978
              precision    recall  f1-score   support

           0     0.4304    0.6046    0.5028      2992
           1     0.7164    0.4111    0.5224      2992
           2     0.6439    0.6112    0.6272      3012
           3     0.3205    0.6468    0.4286      2998
           4     0.7935    0.6953    0.7411      2973
           5     0.8300    0.7161    0.7689      3054
           6     0.6771    0.3953    0.4992      3003
           7     0.5833    0.6116    0.5971      3012
           8     0.6558    0.5624    0.6055      2982
           9     0.7576    0.6258    0.6854      2982

    accuracy                         0.5882     30000
   macro avg     0.6409    0.5880    0.5978     30000
weighted avg     0.6410    0.5882    0.5980     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3003
Precision: 0.6353, Recall: 0.5969, F1-Score: 0.6032
              precision    recall  f1-score   support

           0     0.4197    0.5966    0.4928      2992
           1     0.6928    0.4689    0.5593      2992
           2     0.6642    0.6325    0.6480      3012
           3     0.3421    0.5720    0.4282      2998
           4     0.6790    0.7999    0.7345      2973
           5     0.8462    0.7299    0.7838      3054
           6     0.6645    0.4003    0.4996      3003
           7     0.6261    0.6042    0.6150      3012
           8     0.6295    0.6107    0.6199      2982
           9     0.7886    0.5543    0.6510      2982

    accuracy                         0.5970     30000
   macro avg     0.6353    0.5969    0.6032     30000
weighted avg     0.6356    0.5970    0.6034     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3611
Precision: 0.6295, Recall: 0.5738, F1-Score: 0.5813
              precision    recall  f1-score   support

           0     0.3640    0.5876    0.4495      2992
           1     0.7033    0.4278    0.5320      2992
           2     0.7187    0.5259    0.6074      3012
           3     0.3385    0.6024    0.4335      2998
           4     0.7828    0.6872    0.7319      2973
           5     0.7912    0.7705    0.7807      3054
           6     0.6884    0.3826    0.4919      3003
           7     0.5158    0.6793    0.5863      3012
           8     0.6870    0.4195    0.5209      2982
           9     0.7054    0.6553    0.6794      2982

    accuracy                         0.5741     30000
   macro avg     0.6295    0.5738    0.5813     30000
weighted avg     0.6296    0.5741    0.5816     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3141
Precision: 0.6451, Recall: 0.5912, F1-Score: 0.5998
              precision    recall  f1-score   support

           0     0.3849    0.6324    0.4785      2992
           1     0.7682    0.3700    0.4994      2992
           2     0.6625    0.6132    0.6369      3012
           3     0.3415    0.5881    0.4321      2998
           4     0.7396    0.7508    0.7451      2973
           5     0.8946    0.6893    0.7786      3054
           6     0.6396    0.4279    0.5128      3003
           7     0.6299    0.6046    0.6170      3012
           8     0.6072    0.6496    0.6277      2982
           9     0.7827    0.5858    0.6701      2982

    accuracy                         0.5912     30000
   macro avg     0.6451    0.5912    0.5998     30000
weighted avg     0.6454    0.5912    0.6000     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3033
Precision: 0.6448, Recall: 0.5904, F1-Score: 0.6007
              precision    recall  f1-score   support

           0     0.4109    0.6003    0.4878      2992
           1     0.7222    0.4215    0.5323      2992
           2     0.6790    0.6046    0.6396      3012
           3     0.3190    0.6291    0.4233      2998
           4     0.7577    0.7511    0.7544      2973
           5     0.8448    0.7397    0.7888      3054
           6     0.6393    0.4179    0.5054      3003
           7     0.6103    0.6394    0.6245      3012
           8     0.6568    0.5731    0.6121      2982
           9     0.8083    0.5275    0.6384      2982

    accuracy                         0.5906     30000
   macro avg     0.6448    0.5904    0.6007     30000
weighted avg     0.6450    0.5906    0.6009     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3298
Precision: 0.6324, Recall: 0.5890, F1-Score: 0.5963
              precision    recall  f1-score   support

           0     0.3853    0.6404    0.4811      2992
           1     0.6970    0.4666    0.5590      2992
           2     0.6342    0.6408    0.6375      3012
           3     0.3633    0.5720    0.4444      2998
           4     0.7273    0.7588    0.7427      2973
           5     0.8520    0.6012    0.7049      3054
           6     0.6638    0.4089    0.5061      3003
           7     0.6446    0.5292    0.5812      3012
           8     0.6069    0.6636    0.6340      2982
           9     0.7500    0.6087    0.6720      2982

    accuracy                         0.5888     30000
   macro avg     0.6324    0.5890    0.5963     30000
weighted avg     0.6328    0.5888    0.5963     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3242
Precision: 0.6320, Recall: 0.5898, F1-Score: 0.5942
              precision    recall  f1-score   support

           0     0.3995    0.6120    0.4834      2992
           1     0.6433    0.4786    0.5489      2992
           2     0.6873    0.5495    0.6107      3012
           3     0.3572    0.5927    0.4458      2998
           4     0.7909    0.6361    0.7051      2973
           5     0.7619    0.7911    0.7762      3054
           6     0.7523    0.3287    0.4575      3003
           7     0.5712    0.6152    0.5924      3012
           8     0.6766    0.5936    0.6324      2982
           9     0.6796    0.7005    0.6899      2982

    accuracy                         0.5900     30000
   macro avg     0.6320    0.5898    0.5942     30000
weighted avg     0.6321    0.5900    0.5944     30000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

In [10]:
from src.utils.helper import report_to_df, append_nth_row
df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
csv_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-11-04_00-55-15


Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.4406,0.5973,0.5071,2992
1,1,0.6342,0.4873,0.5511,2992
2,2,0.6421,0.6451,0.6436,3012
3,3,0.3205,0.6468,0.4286,2998
4,4,0.679,0.7999,0.7345,2973
5,5,0.7912,0.7705,0.7807,3054
6,6,0.6396,0.4279,0.5128,3003
7,7,0.6103,0.6394,0.6245,3012
8,8,0.6069,0.6636,0.634,2982
9,9,0.6796,0.7005,0.6899,2982
