In [1]:
import os
import sys

# Make the repository root importable so the `src.*` imports in the next cell
# resolve. The path is relative, so it depends on the kernel's working
# directory being three levels below the root.
sys.path.append("../../../")
# Must be set before any tokenizer is created.
# NOTE(review): presumably silences the HuggingFace tokenizers fork warning
# triggered by DataLoader workers — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning

In [3]:
# Run configuration: every value a reader might want to tune lives here.
# name = "bert-tiny-yahoo"
name = "bert-4-128-yahoo"
# Prefer the first GPU but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 128
ci_ratio = 0.3
seed = 44

In [4]:
# Build the project Config from the model name and target device chosen above.
config = Config(name, device)

In [5]:
# Load the model described by `config`; this unpruned instance is later
# deep-copied before each pruning run.
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [6]:
# Build the train / validation / test dataloaders for the configured dataset.
# NOTE(review): do_cache=True presumably reuses the pickled splits reported in
# the cell output — confirm against load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [7]:
# Sample from the training data for label 0 with the boolean flag set to True.
# NOTE(review): the positional arguments are presumably (concern label, sample
# count, positive flag, 4 = ?) — confirm against SamplingDataset's signature
# and consider passing them by keyword.
# NOTE(review): `positive_samples` is rebound inside the evaluation loop
# further down, so this cell looks like a leftover smoke test.
positive_samples = SamplingDataset(
    train_dataloader,
    config,
    0,
    num_samples,
    True,
    4,
    resample=False,
)

In [8]:
# Same call as for `positive_samples`, but with the boolean flag set to False.
# NOTE(review): presumably this selects samples that do NOT belong to label 0
# — confirm against SamplingDataset. `negative_samples` is also rebound inside
# the evaluation loop further down.
negative_samples = SamplingDataset(
    train_dataloader,
    config,
    0,
    num_samples,
    False,
    4,
    resample=False,
)

In [9]:
import torch
import torch.nn as nn
from scipy.stats import norm
from typing import *
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
from functools import partial
from src.utils.sampling import SamplingDataset
from src.pruning.propagate import propagate
from src.utils.helper import Config
import gc

In [10]:
def calc_coefficient(combined_dataloader, dim=0):
    """Compute one concern coefficient per embedding feature.

    Only the first element of ``combined_dataloader`` is read. Its
    ``"embeddings"`` tensor is split in half along the leading axis: the first
    half is treated as the concern inputs, the second half as the non-concern
    inputs. Both halves are flattened to ``(-1, hidden)`` and compared feature
    column by feature column::

        coefficient = ||c|| + sin(c, n) * | ||c|| + ||n|| | / sqrt(||c||^2 + ||n||^2)

    where ``sin`` carries the sign of the cosine similarity.

    Args:
        combined_dataloader: Indexable whose element 0 is a mapping holding an
            ``"embeddings"`` tensor. Both halves must flatten to the same
            number of rows, otherwise the cosine similarity cannot be formed.
        dim: Layout of the result. ``0`` returns a column of shape
            ``(hidden, 1)``; any other value returns a row of shape
            ``(1, hidden)``.

    Returns:
        Tensor with one coefficient per feature, in the requested layout.
    """
    embeddings = combined_dataloader[0]["embeddings"]

    half = embeddings.shape[0] // 2
    hidden = embeddings.shape[-1]
    concern = embeddings[:half].reshape((-1, hidden))
    non_concern = embeddings[half:].reshape((-1, hidden))

    out_shape = (-1, 1) if dim == 0 else (1, -1)

    # All statistics are reduced over the flattened sample axis (dim=0), which
    # leaves one value per feature.
    concern_norm = torch.norm(concern, dim=0).reshape(out_shape)
    non_concern_norm = torch.norm(non_concern, dim=0).reshape(out_shape)
    cosine_similarity = F.cosine_similarity(concern, non_concern, dim=0).reshape(
        out_shape
    )

    # Rounding can push |cos| marginally above 1, which would make the radicand
    # negative and the square root NaN; clamp it at zero.
    sine_similarity = torch.sign(cosine_similarity) * torch.sqrt(
        torch.clamp(1 - cosine_similarity**2, min=0)
    )

    # A feature that is zero in both halves has a zero denominator (0 / 0 ->
    # NaN). Clamping to the smallest normal float only affects that case and
    # yields a coefficient of 0 for it.
    euclidean_distance = torch.sqrt(concern_norm**2 + non_concern_norm**2)
    euclidean_distance = euclidean_distance.clamp(
        min=torch.finfo(euclidean_distance.dtype).tiny
    )

    coefficient = (
        concern_norm
        + sine_similarity
        * torch.abs(concern_norm + non_concern_norm)
        / euclidean_distance
    )
    return coefficient

In [11]:
def find_layers(
    model: Module,
    layer_types: Optional[List[Type[Module]]] = None,
    include_layers: Optional[List[str]] = None,
    exclude_layers: Optional[List[str]] = None,
    prefix: str = "",
) -> Dict[str, Module]:
    """Collect sub-modules of the requested types, keyed by dotted name.

    The module tree is walked depth-first in ``named_children`` order.

    Args:
        model: Root module to search.
        layer_types: Module classes to collect; defaults to ``[nn.Linear]``.
        include_layers: Substrings; when non-empty, a matching-type module is
            only collected if its dotted name contains at least one of them.
            Containers that do not match are still descended into.
        exclude_layers: Substrings; any module whose dotted name contains one
            of them is skipped together with everything below it.
        prefix: Dotted name prepended to every key.

    Returns:
        Mapping from dotted module name to module, in traversal order.
    """
    target_types = (nn.Linear,) if layer_types is None else tuple(layer_types)
    wanted = [] if include_layers is None else include_layers
    banned = [] if exclude_layers is None else exclude_layers
    found: Dict[str, Module] = {}

    def walk(parent: Module, path: str) -> None:
        for child_name, child in parent.named_children():
            qualified = f"{path}.{child_name}" if path else child_name
            if any(token in qualified for token in banned):
                continue
            if not isinstance(child, target_types):
                # Containers are always descended into, whether or not their
                # own name matches the include filter.
                walk(child, qualified)
            elif not wanted or any(token in qualified for token in wanted):
                found[qualified] = child

    walk(model, prefix)

    return found


def get_hook(method):
    """Wrap ``method`` as a hook that discards the callback's return value.

    The returned callable forwards ``(module, input, output)`` to ``method``
    and always returns ``None``.
    """

    def _forward_hook(module, inputs, outputs):
        method(module, inputs, outputs)
        return None

    return _forward_hook


def get_embeddings(model, dataloader):
    """Run every batch through the model's input-embedding layer.

    Args:
        model: Object exposing ``get_input_embeddings()``, which must return a
            callable that maps ``input_ids`` to embeddings.
        dataloader: Iterable of mappings with ``"input_ids"``, ``"labels"`` and
            ``"attention_mask"`` entries.

    Returns:
        A ``CustomEmbeddingDataset`` built from a dict with the per-batch lists
        ``"embeddings"``, ``"labels"`` and ``"attention_mask"``.
    """
    # Imported before the loop: the original imported inside the loop body, so
    # an empty dataloader left the name undefined and the return raised
    # NameError.
    from src.utils.data_class import CustomEmbeddingDataset

    collected = {"embeddings": [], "labels": [], "attention_mask": []}
    embed = model.get_input_embeddings()  # loop-invariant lookup

    with torch.no_grad():
        for batch in dataloader:
            collected["embeddings"].append(embed(batch["input_ids"]))
            collected["labels"].append(batch["labels"])
            collected["attention_mask"].append(batch["attention_mask"])

    return CustomEmbeddingDataset(collected)

In [12]:
class Pruner:
    """Derive pruning masks for layers via a forward hook and apply masks.

    ``ci`` has the forward-hook signature ``(layer, inputs, outputs)``. For
    each layer it scores weights as ``|W| * |coefficient|`` and records which
    entries (``"unstructed"``) or which output rows (``"structed"``) have the
    lowest scores. ``apply`` multiplies a layer's weight by a mask and can
    physically drop the rows/columns that became all-zero.
    """

    def __init__(self, layers, ratio: float, method="unstructed") -> None:
        """Store the configuration.

        Args:
            layers: Mapping from layer name to layer; used to translate the
                layer object seen by the hook back into its name.
            ratio: Fraction of entries per row (``"unstructed"``) or of rows
                (``"structed"``) to select for pruning.
            method: ``"unstructed"`` or ``"structed"`` (spelling as used by the
                callers in this file).
        """
        self.layers = layers
        self.ratio = ratio
        # Must be assigned by the caller before the hooked forward pass; it is
        # broadcast against each layer's weight matrix.
        self.coefficient = None
        self.method = method
        # layer name -> boolean mask, True where the weight/row was selected.
        self.pruning_mask = {}
        # layer name -> selected indices, sorted ascending.
        self.pruning_indices = {}

    def ci(self, layer, inputs, outputs):
        """Forward hook: record the lowest-importance selection for ``layer``.

        ``inputs`` and ``outputs`` are unused; only the layer's weight and
        ``self.coefficient`` are read.

        Raises:
            NotImplementedError: If ``self.method`` is not supported.
            IndexError: If ``layer`` is not one of ``self.layers``.
        """
        # torch.abs allocates a new tensor, so no defensive clone is needed.
        weight = layer.weight.data
        importance_score = torch.abs(weight) * torch.abs(self.coefficient)

        if self.method == "unstructed":
            # Per output row, select the `num_prune` lowest-scoring inputs.
            num_prune = int(weight.shape[1] * self.ratio)
            order = torch.sort(importance_score, dim=-1, stable=True)[1]
            selected = order[:, :num_prune]
            W_mask = torch.zeros_like(importance_score, dtype=torch.bool).scatter_(
                1, selected, True
            )
            indices = torch.sort(selected, dim=1)[0]
        elif self.method == "structed":
            # Score each output row by the norm of its importance scores and
            # select the `num_prune` lowest-scoring rows.
            importance_vector = torch.norm(importance_score, dim=1)
            num_prune = int(importance_vector.shape[0] * self.ratio)
            selected = torch.sort(importance_vector)[1][:num_prune]
            W_mask = torch.zeros_like(importance_vector, dtype=torch.bool).scatter_(
                0, selected, True
            )
            indices = torch.sort(selected)[0]
        else:
            raise NotImplementedError(f"{self.method} is not implemented")

        layer_name = next(
            (key for key, val in self.layers.items() if val is layer), None
        )
        if layer_name is None:
            raise IndexError("hooked layer is not registered in Pruner.layers")
        self.pruning_mask[layer_name] = W_mask
        self.pruning_indices[layer_name] = indices

    @staticmethod
    def apply(layer, method, axis, mask, keepdim):
        """Multiply ``layer.weight`` by ``mask`` and optionally shrink the layer.

        Args:
            layer: Layer with ``weight`` and ``bias`` attributes; modified in
                place.
            method: Only ``"structed"`` triggers row/column removal.
            axis: ``0`` removes all-zero rows (and the matching bias entries),
                ``1`` removes all-zero columns.
            mask: Multiplier broadcastable to the weight; entries equal to zero
                remove the corresponding weights.
            keepdim: When true the weight keeps its shape and is only zeroed.
        """
        weight = layer.weight.data
        if isinstance(mask, torch.Tensor):
            # Align device and dtype so that a mask built on another device
            # does not make the multiplication fail or change the weight dtype.
            mask = mask.to(device=weight.device, dtype=weight.dtype)
        current_weight = weight * mask

        if not keepdim and method == "structed":
            if axis == 0:
                zero_rows = (current_weight == 0).all(dim=1)
                current_weight = current_weight[~zero_rows]
                if layer.bias is not None:
                    layer.bias.data = layer.bias.data[~zero_rows]
            elif axis == 1:
                zero_cols = (current_weight == 0).all(dim=0)
                current_weight = current_weight[:, ~zero_cols]

        layer.in_features = current_weight.shape[1]
        layer.out_features = current_weight.shape[0]
        layer.weight.data = current_weight

In [13]:
def prune_concern_identification(
    model: Module,
    config: Config,
    dominant_concern: SamplingDataset,
    non_dominant_concern: SamplingDataset,
    sparsity_ratio: float = 0.6,
    include_layers: Optional[List[str]] = None,
    exclude_layers: Optional[List[str]] = None,
    method="unstructed",
    keep_dim=True,
) -> None:
    """Prune ``model`` in place, removing the weights least important to a concern.

    Steps:
      1. Select the layers with ``find_layers``.
      2. Embed both sample sets, concatenate them (dominant first) into one
         batch and derive the per-feature coefficient from it.
      3. Hook ``Pruner.ci`` onto the layers and call ``propagate`` so that the
         hooks record a pruning mask per layer (``propagate`` is a project
         helper; presumably it runs a forward pass over the batch).
      4. Apply the masks with ``Pruner.apply``.

    With ``method="structed"`` only layers whose name contains
    ``"intermediate"`` are scored; each following layer whose name contains
    ``"output"`` reuses that mask along its input axis so the two layers stay
    shape-compatible.

    Args:
        model: Model to prune; modified in place.
        config: Project configuration; ``config.device`` receives the
            coefficient and ``config`` is forwarded to ``propagate``.
        dominant_concern: Samples of the concern being kept.
        non_dominant_concern: Samples of the remaining concerns.
        sparsity_ratio: Fraction of weights (or rows) to remove per layer.
        include_layers: Forwarded to ``find_layers``.
        exclude_layers: Forwarded to ``find_layers``.
        method: ``"unstructed"`` or ``"structed"``.
        keep_dim: When true, pruned weights are zeroed but shapes are kept.

    Raises:
        RuntimeError: In structured mode, if an "output" layer is reached
            before any "intermediate" layer produced a mask.
    """
    layers = find_layers(
        model, include_layers=include_layers, exclude_layers=exclude_layers
    )
    pruner = Pruner(layers, ratio=sparsity_ratio, method=method)

    # Concatenate every dominant batch followed by every non-dominant batch;
    # calc_coefficient splits the result in half along the leading axis.
    dominant_batches = list(get_embeddings(model, dominant_concern))
    non_dominant_batches = list(get_embeddings(model, non_dominant_concern))
    all_batches = dominant_batches + non_dominant_batches
    combined_batches = {
        key: torch.cat([batch[key] for batch in all_batches])
        for key in dominant_batches[0].keys()
    }
    combined_dataloader = [combined_batches]
    pruner.coefficient = calc_coefficient(combined_dataloader, dim=1).to(config.device)

    handle_list = []
    try:
        for name, layer in layers.items():
            # Structured pruning only scores the intermediate layers.
            if method != "structed" or "intermediate" in name:
                handle_list.append(layer.register_forward_hook(pruner.ci))
        propagate(model, combined_dataloader, config)
    finally:
        # Always detach the hooks, even when the forward pass fails, so the
        # model is not left with stale hooks attached.
        for handle in handle_list:
            handle.remove()

    intermediate_prune_mask = None
    for name, layer in layers.items():
        print(name)
        weight = layer.weight
        if method == "structed":
            if "intermediate" in name:
                # Keep the mask on the weight's device rather than forcing CPU.
                intermediate_prune_mask = pruner.pruning_mask[name].to(weight.device)
                # pruning_mask is True for the rows selected for removal, while
                # Pruner.apply multiplies by its mask; invert so the selected
                # rows are the ones zeroed.
                keep_mask = (
                    (~intermediate_prune_mask)
                    .unsqueeze(dim=1)
                    .expand(-1, weight.shape[1])
                )
                Pruner.apply(
                    layer,
                    method="structed",
                    axis=0,
                    mask=keep_mask,
                    keepdim=keep_dim,
                )
            elif "output" in name:
                if intermediate_prune_mask is None:
                    raise RuntimeError(
                        f"No intermediate mask available for output layer {name}"
                    )
                # Drop the input columns fed by the pruned intermediate rows.
                keep_mask = (
                    (~intermediate_prune_mask.to(weight.device))
                    .unsqueeze(dim=0)
                    .expand(weight.shape[0], -1)
                )
                Pruner.apply(
                    layer,
                    method="structed",
                    axis=1,
                    mask=keep_mask,
                    keepdim=keep_dim,
                )
        elif method == "unstructed":
            # Use the full per-row mask; indexing it with [0] applied the first
            # row's selection to every row of the weight matrix.
            keep_mask = ~pruner.pruning_mask[name].to(weight.device)
            Pruner.apply(
                layer, method="unstructed", axis=0, mask=keep_mask, keepdim=keep_dim
            )

In [14]:
# Prune a fresh copy of the model for each concern (class label) and collect
# the evaluation report of every pruned copy.
result_list = []

for concern in range(config.num_labels):
    # Re-seed on every iteration before sampling.
    config.init_seed()
    # Same SamplingDataset call twice, differing only in the boolean flag.
    # NOTE(review): presumably True/False select samples inside/outside the
    # concern — confirm against SamplingDataset.
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): `all_samples` is never used below, and 200 is not a valid
    # label index for range(config.num_labels) — confirm whether it is a
    # deliberate "match everything" sentinel or dead code.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Work on a deep copy so the loaded `model` stays unpruned.
    module = copy.deepcopy(model)

    # NOTE(review): sparsity_ratio is the literal 0.5 while the config cell
    # defines ci_ratio = 0.3 — confirm which value is intended.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=0.5,
        keep_dim=True,
        method="structed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader, verbose=True)
    result_list.append(result)
    # NOTE(review): this stops after the first concern, so only label 0 is
    # ever evaluated — looks like a debugging leftover; confirm.
    break

bert.encoder.layer.0.intermediate.dense


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [None]:
# Display the original, unpruned model.
model

In [None]:
# Display the pruned deep copy left over from the last loop iteration.
module

In [None]:
from src.utils.helper import report_to_df, append_nth_row

# Convert every collected evaluation result with report_to_df, combine the
# converted results with append_nth_row and display the combined table.
df_list = list(map(report_to_df, result_list))
new_df = append_nth_row(df_list)
new_df