In [1]:
import os
import sys

# Extend the import search path with a directory three levels above the
# notebook -- presumably the project root that contains the `src` package
# imported below (TODO confirm the notebook's location in the repo).
sys.path.append("../../../")
# NOTE(review): presumably set to silence the HuggingFace tokenizers
# parallelism/fork warning when DataLoader workers are used -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.utils.load import save_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset

In [3]:
# Experiment configuration: every tunable value used by the cells below.
# Name passed to `Config`; selects the model/dataset setup to load.
name = "bert-4-128-yahoo"
# name = "YahooAnswersTopics"
device = torch.device("cuda:0")
# Not referenced by the other visible cells.
checkpoint = None
# DataLoader settings forwarded to `load_data`.
batch_size = 16
num_workers = 4
# Sample count handed to `SamplingDataset` and `get_similarity`.
num_samples = 128
# Fraction of all attention heads to zero out.
head_pruning_ratio = 0.5
# NOTE(review): `seed` and `include_layers` are defined here but not used in
# the visible cells -- confirm whether seeding was intended.
seed = 44
include_layers = ["attention", "intermediate", "output"]


In [None]:
# Build the run configuration from the selected name and target device.
config = Config(name, device)
# Number of target classes, read from the configuration mapping.
num_labels = config.config["num_labels"]
# Load the model described by `config` via the project loader.
model = load_model(config)

In [None]:
# Obtain the train / validation / test dataloaders for the configured dataset.
# NOTE(review): `do_cache=True` presumably reuses a cached, preprocessed copy
# of the data -- confirm against `load_data`.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

In [6]:
import numpy as np
import torch
import copy
from functools import partial
import torch.nn as nn
from transformers.pytorch_utils import find_pruneable_heads_and_indices
from typing import *

import time


def calculate_prune_head(arr, i, pruned_heads):
    """
    Select the positions of the `i` smallest entries of `arr`.

    Args:
        arr: Array-like of scores; it is walked with `np.ndenumerate`, so each
            position is an index tuple.
        i: Number of positions to return.
        pruned_heads: Collection of index tuples to skip.

    Returns:
        List of index tuples ordered by ascending score; ties keep their
        enumeration order.
    """
    candidates = []
    for position, score in np.ndenumerate(arr):
        if position in pruned_heads:
            continue
        candidates.append((score, position))

    # list.sort is stable, so equal scores stay in enumeration order.
    candidates.sort(key=lambda pair: pair[0])

    return [position for _, position in candidates[:i]]


def prune_head(model, prune_list):
    """
    Apply `prune_heads` to every (layer_index, head_index) pair in `prune_list`.

    Each pair addresses `model.bert.encoder.layer[layer_index].attention` and
    one head inside it. The model is modified in place and also returned.
    """
    for position in prune_list:
        layer_index, head_index = position
        attention = model.bert.encoder.layer[layer_index].attention
        prune_heads(attention, [head_index])
    return model


def preprocess_prunehead(arr, num_layer):
    """
    Overwrite the highest entry of each of the first `num_layer` rows with 100.

    The row-wise maxima are located once, before any entry is overwritten.
    `arr` is modified in place and returned.
    """
    top_head_per_layer = np.argmax(arr, axis=1)
    for layer in range(num_layer):
        arr[layer][top_head_per_layer[layer]] = 100
    return arr


def prune_heads(layer, heads):
    """
    Zero out the weights of the given heads in one attention block.

    The layer shapes are left untouched: the matching rows of the query, key
    and value projections and the matching columns of the output projection
    are set to zero. Nothing happens when `heads` is empty.

    Args:
        layer: Attention block exposing `self.query/key/value`,
            `self.num_attention_heads`, `self.attention_head_size`,
            `output.dense` and `pruned_heads`.
        heads: Head indices to zero out.
    """
    if len(heads) == 0:
        return

    attention = layer.self
    head_size = attention.attention_head_size

    # Only the resolved head set is needed; the index tensor is discarded.
    heads, _ = find_pruneable_heads_and_indices(
        heads,
        attention.num_attention_heads,
        head_size,
        layer.pruned_heads,
    )

    # Zero out weights in linear layers instead of pruning
    for projection in ("query", "key", "value"):
        zeroed = zero_out_head_weights(getattr(attention, projection), heads, head_size)
        setattr(attention, projection, zeroed)

    layer.output.dense = zero_out_head_weights(
        layer.output.dense, heads, head_size, dim=1
    )


def zero_out_head_weights(
    layer: nn.Linear, heads: Set[int], head_size: int, dim: int = 0
) -> nn.Linear:
    """
    Set the weight slices belonging to the given heads to zero, in place.

    Args:
        layer (`torch.nn.Linear`): The layer whose weight is modified.
        heads (`Set[int]`): Head indices; head `h` covers
            `[h * head_size, (h + 1) * head_size)`.
        head_size (`int`): Width of one head's slice.
        dim (`int`, *optional*, defaults to 0): 0 zeroes rows of the weight,
            1 zeroes columns. Any other value leaves the weight untouched.

    Returns:
        `torch.nn.Linear`: The same layer object that was passed in.
    """
    weights = layer.weight.data
    for head in heads:
        span = slice(head * head_size, (head + 1) * head_size)
        if dim == 0:
            weights[span] = 0
        elif dim == 1:
            weights[:, span] = 0

    return layer

In [7]:
import numpy as np


def calculate_head_importance(
    model,
    config,
    data,
    normalize_scores_by_layer=True,
):
    """
    Estimate a per-class importance score for every attention head.

    For each batch, the samples are split by label and one forward/backward
    pass is run per label. Forward hooks on every `attention.self` module
    capture its first output and that output's gradient. A head's score is the
    absolute value of their dot product over the head dimension, summed over
    tokens and samples. The Frobenius norm of each head's value-projection
    rows is then added, all layers except the last are divided by the number
    of attended tokens seen for that label, and scores are optionally
    L2-normalized per layer.

    Args:
        model: Classification model exposing `bert.config`,
            `bert.encoder.layer[i].attention.self` and `num_labels`; calling
            it with `labels=` must return an object with a `.loss`.
        config: Object with a `device` attribute; if it is falsy, the device
            of the model's first parameter is used.
        data: Iterable of batches, each a mapping with "input_ids",
            "attention_mask" and "labels" tensors.
        normalize_scores_by_layer: When True, divide each layer's scores by
            that layer's L2 norm.

    Returns:
        dict mapping each label in `range(model.num_labels)` to a tensor of
        shape (num_hidden_layers, num_attention_heads).

    Side effects:
        Switches the model to eval mode and leaves the parameter gradients of
        the last backward pass in place. The forward hooks are always removed,
        even if an exception is raised while iterating `data`.
    """
    from functools import partial

    gradients = {}
    context_layers = {}

    def save_grad(gradients, layer_idx, grad):
        gradients[f"context_layer_{layer_idx}"] = grad

    def forward_hook(module, input, output, gradients, context_layers, layer_idx):
        # Keep the hooked output and attach a tensor hook to grab its gradient.
        context_layers[f"context_layer_{layer_idx}"] = output[0]
        output[0].register_hook(partial(save_grad, gradients, layer_idx))

    def reshape(tensors, shape, num_heads):
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        batch_size = shape[0]
        seq_len = shape[1]
        head_dim = shape[2] // num_heads
        tensors = tensors.reshape(batch_size, seq_len, num_heads, head_dim)
        tensors = tensors.permute(0, 2, 1, 3)
        return tensors

    # Disable dropout
    model.eval()
    device = config.device or next(model.parameters()).device

    n_layers = model.bert.config.num_hidden_layers
    n_heads = model.bert.config.num_attention_heads
    head_dim = model.bert.config.hidden_size // n_heads
    num_classes = model.num_labels  # Adjust based on your model

    head_importance = {
        label: torch.zeros(n_layers, n_heads).to(device) for label in range(num_classes)
    }
    tot_tokens = {label: 0 for label in range(num_classes)}

    forward_handles = []
    try:
        for layer_idx in range(n_layers):
            self_att = model.bert.encoder.layer[layer_idx].attention.self
            forward_handles.append(
                self_att.register_forward_hook(
                    partial(
                        forward_hook,
                        gradients=gradients,
                        context_layers=context_layers,
                        layer_idx=layer_idx,
                    )
                )
            )

        for batch in data:
            input_ids = batch["input_ids"].to(device)
            input_mask = batch["attention_mask"].to(device)
            label_ids = batch["labels"].to(device)

            # Labels come from `unique()`, so every selection below is non-empty.
            for label in label_ids.unique():
                label_key = label.item()
                mask = label_ids == label
                input_mask_label = input_mask[mask]

                # Zero gradients
                model.zero_grad()
                # Compute loss and backward pass
                loss = model(
                    input_ids[mask],
                    attention_mask=input_mask_label,
                    labels=label_ids[mask],
                ).loss
                loss.backward()

                for layer_idx in range(n_layers):
                    ctx = context_layers[f"context_layer_{layer_idx}"]
                    grad_ctx = gradients[f"context_layer_{layer_idx}"]
                    shape = ctx.shape
                    ctx = reshape(ctx, shape, n_heads)
                    grad_ctx = reshape(grad_ctx, shape, n_heads)

                    # Compute dot product and accumulate head importance per class
                    dot = torch.einsum("bhli,bhli->bhl", grad_ctx, ctx)
                    head_importance[label_key][layer_idx] += (
                        dot.abs().sum(-1).sum(0).detach()
                    )
                    del ctx, grad_ctx, dot

                tot_tokens[label_key] += input_mask_label.float().detach().sum().item()
    finally:
        # Always detach the hooks so a failure above cannot leave them on the model.
        for handle in forward_handles:
            handle.remove()

    # The value-weight norms do not depend on the label, so compute them once.
    # A head owns a block of *rows* of the value projection weight, matching
    # how `prune_heads` zeroes query/key/value along dim 0.
    value_norms = torch.zeros(n_layers, n_heads).to(device)
    with torch.no_grad():
        for layer_idx in range(n_layers):
            value_weight = model.bert.encoder.layer[layer_idx].attention.self.value.weight
            for head in range(n_heads):
                start_idx = head * head_dim
                end_idx = (head + 1) * head_dim
                value_norms[layer_idx, head] = torch.norm(
                    value_weight[start_idx:end_idx, :]
                )

    for label in range(num_classes):
        head_importance[label] += value_norms

        # Normalize head importance per class
        # NOTE(review): the last layer is deliberately left out of the
        # token-count normalization, as in the original code -- confirm intent.
        head_importance[label][:-1] /= (
            tot_tokens[label] + 1e-20
        )  # Avoid division by zero

        if normalize_scores_by_layer:
            exponent = 2
            norm_by_layer = torch.pow(
                torch.pow(head_importance[label], exponent).sum(-1), 1 / exponent
            )
            head_importance[label] /= norm_by_layer.unsqueeze(-1) + 1e-20

    return head_importance

In [8]:
# Two sample sets drawn from the training dataloader. They differ only in the
# second positional argument (0 vs 200) and the fifth (True vs False).
# NOTE(review): presumably these are the target concern/class id and a
# "positive samples only" flag, and the literal 4 is a fixed sampler setting
# -- confirm against SamplingDataset's signature.
# `positive_samples` is reassigned inside the per-concern loop further down;
# the single-module cell below uses `all_samples`.
positive_samples = SamplingDataset(
    train_dataloader,
    0,
    num_samples,
    num_labels,
    True,
    4,
    device=device,
    resample=False,
)
all_samples = SamplingDataset(
    train_dataloader,
    200,
    num_samples,
    num_labels,
    False,
    4,
    device=device,
    resample=False,
)

In [9]:
def head_importance_prunning(
    model, config, dominant_concern, concern, sparsity_ratio, gradually=True
):
    """
    Zero out the least important attention heads of `model` for one concern.

    The head budget is `int(num_heads * num_layers * sparsity_ratio)`. When it
    is at least 4 it is rounded down to a multiple of 4. With `gradually=True`
    the budget is spent in steps of 4 heads, and head importance is recomputed
    on the partially pruned model before every step; otherwise all heads are
    chosen in a single step.

    Args:
        model: Model to prune in place; must expose `config.num_attention_heads`
            and `config.num_hidden_layers`.
        config: Run configuration forwarded to `calculate_head_importance`.
        dominant_concern: Batches used to estimate head importance.
        concern: Label whose importance scores drive the selection.
        sparsity_ratio: Fraction of all heads to prune.
        gradually: Prune in several steps (True) or all at once (False).

    Returns:
        None. The chosen (layer, head) pairs are printed.
    """
    step_size = 4  # heads removed per step in gradual mode

    num_attention_heads = model.config.num_attention_heads
    num_hidden_layers = model.config.num_hidden_layers
    total_heads_to_prune = int(num_attention_heads * num_hidden_layers * sparsity_ratio)

    # Round down to a multiple of the step size. The previous expression,
    # `total -= step_size - total % step_size`, did not do that (5 -> 2, 7 -> 6).
    if total_heads_to_prune >= step_size:
        total_heads_to_prune -= total_heads_to_prune % step_size

    num_steps = max(1, total_heads_to_prune // step_size) if gradually else 1
    heads_per_step = total_heads_to_prune // num_steps
    print(f"Total heads to prune: {total_heads_to_prune}")

    pruned_heads = set()

    for step in range(num_steps):
        # The final step takes whatever is left of the budget.
        if step == num_steps - 1:
            current_heads_to_prune = total_heads_to_prune - (step * heads_per_step)
        else:
            current_heads_to_prune = heads_per_step

        # Re-estimate importance on the current, partially pruned model.
        head_importance_list = calculate_head_importance(
            model,
            config,
            dominant_concern,
            normalize_scores_by_layer=True,
        )
        head_importance_list = head_importance_list[concern]
        print(f"head importance list\n {head_importance_list}")
        head_importance_list = head_importance_list.cpu()

        # preprocess_prunehead(head_importance_list, num_hidden_layers)

        # Heads already pruned in earlier steps are excluded from selection.
        prune_list = calculate_prune_head(
            head_importance_list, current_heads_to_prune, pruned_heads
        )
        pruned_heads.update(prune_list)

        prune_head(model, prune_list)
    print(pruned_heads)

In [None]:
# Prune a copy so the original `model` stays intact for the comparison below.
module = copy.deepcopy(model)
# Importance is estimated on `all_samples`, using the scores of label 0.
head_importance_prunning(module, config, all_samples, 0, head_pruning_ratio)
result = evaluate_model(module, config, test_dataloader, verbose=True)
# Compare the original and pruned models on validation data.
# NOTE(review): the literal 0 is presumably the concern/class id -- confirm
# against get_similarity's signature.
get_similarity(
    model,
    module,
    valid_dataloader,
    0,
    num_samples,
    num_labels,
    device=device,
)

In [None]:
# Repeat the prune-and-evaluate experiment once per label ("concern").
for concern in range(num_labels):
    # Work on a copy of the training dataloader for each concern.
    train = copy.deepcopy(train_dataloader)
    # NOTE(review): positional arguments presumably mean (dataloader, concern,
    # num_samples, num_labels, positive-only flag, 4) -- confirm against
    # SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train,
        concern,
        num_samples,
        num_labels,
        True,
        4,
        device=device,
        resample=False,
    )

    # Prune a fresh copy so every concern starts from the unpruned model.
    module = copy.deepcopy(model)
    head_importance_prunning(
        module, config, positive_samples, concern, head_pruning_ratio
    )
    print(f"Evaluate the pruned model {concern}")

    result = evaluate_model(module, config, test_dataloader, verbose=True)
    # NOTE(review): the fourth argument is the literal 0 on every iteration,
    # not `concern`; if it selects the class to compare on, this looks like a
    # copy-paste from the single-module cell -- confirm against get_similarity.
    get_similarity(
        model,
        module,
        valid_dataloader,
        0,
        num_samples,
        num_labels,
        device=device,
    )