In [94]:
import os
import sys

# Extend the import path three directories up so the `src.*` imports in the
# following cells can be found (presumably the repository root — confirm
# against the notebook's location).
sys.path.append("../../../")
# Hugging Face tokenizers setting: turn off the tokenizer's own parallelism.
# NOTE(review): commonly set when DataLoader worker processes are used
# (num_workers is non-zero in the configuration cell) — confirm it is needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [95]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning

In [96]:
# Notebook-wide settings.
# name = "bert-tiny-yahoo"
# Model identifier handed to Config.
name = "bert-4-128-yahoo"
# First CUDA device; this notebook needs a GPU to run as written.
device = torch.device("cuda:0")
# NOTE(review): not referenced by any cell shown in this notebook.
checkpoint = None
# batch_size and num_workers are forwarded to load_data.
batch_size = 16
num_workers = 4
# Forwarded to SamplingDataset in the pruning loop.
num_samples = 128
# NOTE(review): not referenced by any cell shown here; the pruning loop
# passes a literal 0.5 instead — confirm which value is intended.
ci_ratio = 0.3
# NOTE(review): not referenced by any cell shown here; confirm whether
# config.init_seed() is expected to use it.
seed = 44

In [97]:
# Project configuration object built from the selected model name and device.
config = Config(name, device)

In [98]:
# Load the model described by `config`; this is the base model that the
# pruning loop deep-copies for each concern.
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [99]:
# Build the train/validation/test dataloaders for the dataset tied to `config`.
# Only test_dataloader is used by the cells shown in this notebook.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [100]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

# Load the pre-generated, embedding-based dataset from the project cache.
# NOTE(review): the cache is a .pkl file; if load_cache unpickles it, only
# load files produced by a trusted run — confirm against load_cache.
generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo", "4_128-yahoo_top1.pkl"
)

4_128-yahoo_top1.pkl is loaded from cache.


In [101]:
# Inspect the key names stored in the cached dict before renaming them.
print(generated.keys())

dict_keys(['example_label', 'example_list', 'attn_list'])


In [102]:
# Rename the cached keys to the names used by the rest of the notebook
# ("embeddings", "labels", "attention_mask").
# The presence check keeps this cell idempotent: running it a second time is a
# no-op instead of a KeyError raised by dict.pop on an already-renamed key.
# The rename order is kept so the resulting key order is unchanged.
for old_key, new_key in (
    ("example_list", "embeddings"),
    ("example_label", "labels"),
    ("attn_list", "attention_mask"),
):
    if old_key in generated:
        generated[new_key] = generated.pop(old_key)

# Still fail loudly (with the same exception type as before) when the cached
# file provides neither the old nor the new name for a required key.
missing_keys = {"embeddings", "labels", "attention_mask"} - generated.keys()
if missing_keys:
    raise KeyError(f"generated dataset is missing keys: {sorted(missing_keys)}")

In [103]:
# Display the key names after the rename.
generated.keys()

dict_keys(['embeddings', 'labels', 'attention_mask'])

In [104]:
# Wrap the renamed dict in the project's dataset class and batch it.
generated_data = CustomEmbeddingDataset(generated)
# NOTE(review): "dataloder" is a misspelling of "dataloader"; later cells use
# this exact name, so any rename has to be done everywhere at once.
# No shuffle argument is passed, so DataLoader's default ordering applies.
generated_dataloder = DataLoader(
    generated_data,
    batch_size=4,
)

In [105]:
# Sanity check: print only the first batch, then stop iterating.
for batch in generated_dataloder:
    print(batch)
    break

{'embeddings': tensor([[[ 0.6627,  0.0542,  0.0560,  ...,  0.3232,  1.5695,  0.2929],
         [ 1.5244, -1.7833,  1.7203,  ...,  0.6870, -1.4302,  0.3291],
         [ 1.5020, -3.0318,  1.9343,  ...,  0.9823, -1.2311, -1.0617],
         ...,
         [ 1.0938,  0.7663,  0.8358,  ...,  1.4675,  0.9290,  1.9437],
         [ 2.1223, -2.8189, -0.8851,  ...,  1.2666,  0.4986,  0.9700],
         [ 0.9699, -0.8722, -0.2652,  ..., -0.6871, -0.8159, -1.4540]],

        [[ 0.6622,  0.0536,  0.0554,  ...,  0.3226,  1.5701,  0.2923],
         [ 1.5238, -1.7827,  1.7198,  ...,  0.6875, -1.4296,  0.3297],
         [ 1.5014, -3.0312,  1.9349,  ...,  0.9828, -1.2305, -1.0611],
         ...,
         [ 1.0938,  0.7663,  0.8358,  ...,  1.4675,  0.9290,  1.9437],
         [ 2.1223, -2.8189, -0.8851,  ...,  1.2666,  0.4986,  0.9700],
         [ 0.9699, -0.8722, -0.2652,  ..., -0.6871, -0.8159, -1.4540]],

        [[ 0.6681,  0.0595,  0.0613,  ...,  0.3285,  1.5642,  0.2982],
         [ 1.5297, -1.7886,  1

In [106]:
import torch
import torch.nn as nn
from scipy.stats import norm
from typing import *
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
from functools import partial
from src.utils.sampling import SamplingDataset
from src.pruning.propagate import propagate
from src.utils.helper import Config
import gc

In [107]:
def calc_coefficient(combined_dataloader, dim=0):
    """Compute a per-feature coefficient contrasting concern and non-concern inputs.

    The ``"embeddings"`` tensor of element ``0`` of ``combined_dataloader`` is
    split in half along its first dimension: the first half is treated as the
    concern inputs, the second half as the non-concern inputs. Each half is
    flattened to ``(-1, feature_dim)``. For every feature column the result is

        concern_norm + sine * |concern_norm + non_concern_norm| / distance

    where ``sine = sign(cos) * sqrt(1 - cos**2)``, ``cos`` is the cosine
    similarity between the two halves along the flattened row axis, and
    ``distance = sqrt(concern_norm**2 + non_concern_norm**2)``.

    Args:
        combined_dataloader: Indexable container whose element ``0`` is a
            mapping holding an ``"embeddings"`` tensor. Both halves must
            flatten to the same number of rows (assumes the caller supplies
            equally sized concern / non-concern parts — TODO confirm).
        dim: ``0`` returns a column vector of shape ``(feature_dim, 1)``;
            any other value returns a row vector of shape ``(1, feature_dim)``.

    Returns:
        Tensor with one finite coefficient per feature, shaped as selected
        by ``dim``.
    """
    embeddings = combined_dataloader[0]["embeddings"]

    half = embeddings.shape[0] // 2
    feature_dim = embeddings.shape[-1]
    concern_rows = embeddings[:half].reshape((-1, feature_dim))
    non_concern_rows = embeddings[half:].reshape((-1, feature_dim))

    new_shape = (-1, 1) if dim == 0 else (1, -1)

    concern_norm = torch.norm(concern_rows, dim=0).reshape(new_shape)
    non_concern_norm = torch.norm(non_concern_rows, dim=0).reshape(new_shape)

    cosine_similarity = F.cosine_similarity(
        concern_rows, non_concern_rows, dim=0
    ).reshape(new_shape)

    # Floating-point rounding can push |cos| marginally above 1, which makes
    # 1 - cos**2 negative and sqrt return NaN; clamp the radicand at zero.
    sine_similarity = torch.sign(cosine_similarity) * torch.sqrt(
        torch.clamp(1 - cosine_similarity**2, min=0)
    )

    euclidean_distance = torch.sqrt(concern_norm**2 + non_concern_norm**2)
    # A feature that is zero in both halves has distance 0 and numerator 0;
    # dividing by 1 there yields 0 instead of the NaN produced by 0 / 0.
    # Non-zero distances are left untouched.
    safe_distance = torch.where(
        euclidean_distance > 0,
        euclidean_distance,
        torch.ones_like(euclidean_distance),
    )

    coefficient = (
        concern_norm
        + sine_similarity
        * torch.abs(concern_norm + non_concern_norm)
        / safe_distance
    )
    return coefficient

In [108]:
def find_layers(
    model: Module,
    layer_types: Optional[List[Type[Module]]] = None,
    include_layers: Optional[List[str]] = None,
    exclude_layers: Optional[List[str]] = None,
    prefix: str = "",
) -> Dict[str, Module]:
    """Collect sub-modules of the requested types, keyed by dotted name.

    The module tree is walked depth-first in ``named_children`` order.

    Args:
        model: Root module to search.
        layer_types: Module classes to collect; defaults to ``[nn.Linear]``.
        include_layers: Substrings; when non-empty, a module of a requested
            type is collected only if its dotted name contains at least one
            of them. Modules of other types are always descended into.
        exclude_layers: Substrings; any child whose dotted name contains one
            of them is skipped together with everything below it.
        prefix: Dotted name used as the starting point for child names.

    Returns:
        Mapping from dotted layer name to the matching module, in visit order.
    """
    wanted_types = tuple([nn.Linear] if layer_types is None else layer_types)
    include_patterns = [] if include_layers is None else include_layers
    exclude_patterns = [] if exclude_layers is None else exclude_layers
    found: Dict[str, Module] = {}

    def _name_matches(patterns, qualified_name) -> bool:
        return any(pattern in qualified_name for pattern in patterns)

    def _walk(parent: Module, parent_name: str) -> None:
        for child_name, child in parent.named_children():
            qualified = f"{parent_name}.{child_name}" if parent_name else child_name

            if _name_matches(exclude_patterns, qualified):
                continue

            if not isinstance(child, wanted_types):
                # Containers and other layer kinds are searched regardless of
                # the include filter.
                _walk(child, qualified)
            elif not include_patterns or _name_matches(include_patterns, qualified):
                found[qualified] = child

    _walk(model, prefix)

    return found


def get_hook(method):
    """Wrap ``method`` in a forward-hook style callable that discards its result.

    The returned callable forwards ``(module, input, output)`` to ``method``
    and always returns ``None``, whatever ``method`` itself returns.
    """

    def _call_and_discard(module, input, output):
        method(module, input, output)
        return None

    return _call_and_discard


def get_embeddings(model, dataloader):
    """Convert token-id batches into input-embedding batches.

    For every batch, ``batch["input_ids"]`` is passed through the module
    returned by ``model.get_input_embeddings()`` with gradients disabled. The
    resulting embeddings are collected together with the batch's ``"labels"``
    and ``"attention_mask"``.

    Args:
        model: Object exposing ``get_input_embeddings()``.
        dataloader: Iterable of mappings with ``"input_ids"``, ``"labels"``
            and ``"attention_mask"`` entries.

    Returns:
        ``CustomEmbeddingDataset`` built from a dict of three per-batch lists
        (``"embeddings"``, ``"labels"``, ``"attention_mask"``). The lists are
        empty when ``dataloader`` yields nothing.
    """
    # Import before the loop: binding this name inside the loop body would
    # leave it unbound for an empty dataloader and make the return statement
    # raise UnboundLocalError.
    from src.utils.data_class import CustomEmbeddingDataset

    embeddings_list = {"embeddings": [], "labels": [], "attention_mask": []}
    # The embedding module does not change between batches; look it up once.
    embedding_layer = model.get_input_embeddings()

    for batch in dataloader:
        with torch.no_grad():
            input_embeddings = embedding_layer(batch["input_ids"])
        embeddings_list["embeddings"].append(input_embeddings)
        embeddings_list["labels"].append(batch["labels"])
        embeddings_list["attention_mask"].append(batch["attention_mask"])
    return CustomEmbeddingDataset(embeddings_list)

In [109]:
class Methods:
    """Forward-hook helper that scores a layer's weights and records what to prune.

    ``ci`` has the ``(layer, inputs, outputs)`` signature of a forward hook.
    Every call scores ``layer.weight`` against ``self.coefficient`` and stores
    the selected indices in ``self.pruning_mask`` under a running counter
    (``0, 1, 2, ...`` in call order). ``apply`` zeroes the weights named by
    one such mask.

    The accepted ``method`` strings are the literals ``"unstructed"``
    (per-row selection of individual weights) and ``"structed"`` (selection
    of whole rows).
    """

    def __init__(self, name: str, ratio: float, method="unstructed") -> None:
        """Store the settings; ``coefficient`` must be assigned before ``ci`` runs.

        Args:
            name: Free-form label for this scorer.
            ratio: Fraction of columns (``"unstructed"``) or rows
                (``"structed"``) to select for pruning.
            method: ``"unstructed"`` or ``"structed"``.
        """
        self.name = name
        self.ratio = ratio
        # Tensor multiplied (via broadcasting) with |weight| in `ci`.
        self.coefficient = None
        self.method = method
        # Number of masks recorded so far; also the key of the next mask.
        self.num_mask = 0
        self.pruning_mask = {}

    def ci(self, layer, inputs, outputs):
        """Score ``layer.weight`` and record the indices selected for pruning.

        The importance score is ``|weight| * |coefficient|``.

        * ``"unstructed"``: for every row, the ``int(n_columns * ratio)``
          lowest-scoring column indices (stable sort), stored as a list of
          lists.
        * ``"structed"``: the ``int(n_rows * ratio)`` rows with the lowest
          ``mean(score) / norm(weight)`` per row, stored as a flat list.

        Indices are stored in ascending order. ``inputs`` and ``outputs`` are
        unused; they exist so the method can be registered as a forward hook.

        Raises:
            ValueError: If ``coefficient`` has not been assigned yet.
            NotImplementedError: If ``method`` is not a supported string.
        """
        if self.coefficient is None:
            raise ValueError(
                f"coefficient of Methods({self.name!r}) must be set before the hook runs"
            )

        # Nothing below writes to the weight, so no defensive copy is needed.
        current_weight = layer.weight.data

        importance_score = torch.abs(current_weight) * torch.abs(self.coefficient)

        if self.method == "unstructed":
            num_prune = int(current_weight.shape[1] * self.ratio)
            column_order = torch.sort(importance_score, dim=-1, stable=True)[1]
            indices_to_prune = column_order[:, :num_prune]
        elif self.method == "structed":
            importance_vector = torch.mean(importance_score, dim=1) / torch.norm(
                current_weight, dim=1
            )
            num_prune = int(importance_vector.shape[0] * self.ratio)
            row_order = torch.sort(importance_vector)[1]
            indices_to_prune = row_order[:num_prune]
        else:
            raise NotImplementedError(f"{self.method} is not implemented")

        # Sorting along the last dimension covers both the 2-D (per-row) and
        # the 1-D (row list) case.
        indices_to_prune, _ = torch.sort(indices_to_prune, dim=-1)

        self.pruning_mask[self.num_mask] = indices_to_prune.tolist()
        self.num_mask += 1

    def apply(self, layer, mask):
        """Zero the weights of ``layer`` named by ``mask``.

        The weight keeps its shape; only the selected entries are set to 0.
        Only ``layer.weight`` is touched (a bias, if any, is left as is).

        Args:
            layer: Module with a ``weight`` parameter.
            mask: A mask in the format produced by ``ci`` for the same
                ``method``: a sequence of per-row column-index sequences for
                ``"unstructed"``, or a flat sequence of row indices for
                ``"structed"``.

        Raises:
            NotImplementedError: If ``method`` is not a supported string.
        """
        current_weight = layer.weight.data.clone()
        device = current_weight.device

        if self.method == "unstructed":
            # `mask` is a sequence (one entry per row), not a dict or tensor
            # scalar, so it is enumerated rather than read with `.item()`.
            for row_idx, prune_indices in enumerate(mask):
                columns = torch.as_tensor(prune_indices, dtype=torch.long, device=device)
                current_weight[row_idx, columns] = 0
        elif self.method == "structed":
            rows = torch.as_tensor(mask, dtype=torch.long, device=device)
            current_weight[rows, :] = 0
        else:
            raise NotImplementedError(f"{self.method} is not implemented")

        layer.weight.data = current_weight

In [110]:
def prune_concern_identification(
    model: Module,
    config: Config,
    dominant_concern: SamplingDataset,
    non_dominant_concern: SamplingDataset,
    sparsity_ratio: float = 0.6,
    include_layers: Optional[List[str]] = None,
    exclude_layers: Optional[List[str]] = None,
    method="unstructed",
    keep_dim=True,
) -> None:
    """Collect concern-identification pruning masks for the selected layers.

    Two ``Methods`` scorers are created. Every layer returned by
    ``find_layers`` gets a forward hook: layers whose name contains
    ``"intermediate"`` use the first scorer, all others the second. The
    dominant and non-dominant batches are concatenated (dominant first) into
    one combined batch, the scorers' coefficients are computed from it with
    ``calc_coefficient``, and ``propagate`` is run on it so the hooks fire and
    record their masks. All masks recorded by the ``"intermediate"`` scorer
    are then printed.

    NOTE(review): the recorded masks are only printed. This function never
    calls ``Methods.apply``, so it does not itself change any layer weight.
    Applying them needs a decision on how recorded masks map back to layers —
    confirm the intended behaviour with the author.

    Args:
        model: Model whose layers are scored.
        config: Project configuration; ``config.device`` receives the
            coefficients and ``config`` is forwarded to ``propagate``.
        dominant_concern: Iterable of batches for the concern of interest.
            If its first batch has no ``"embeddings"`` entry, both datasets
            are converted with ``get_embeddings`` first.
        non_dominant_concern: Iterable of batches for the other concerns.
        sparsity_ratio: Ratio handed to both ``Methods`` scorers.
        include_layers: Forwarded to ``find_layers``.
        exclude_layers: Forwarded to ``find_layers``.
        method: ``"unstructed"`` or ``"structed"``; forwarded to ``Methods``.
        keep_dim: Accepted for interface compatibility; currently unused.
    """
    layers = find_layers(
        model, include_layers=include_layers, exclude_layers=exclude_layers
    )

    method1 = Methods(name="intermediate", ratio=sparsity_ratio, method=method)
    method2 = Methods(name="output", ratio=sparsity_ratio, method=method)

    first_batch = next(iter(dominant_concern))
    if "embeddings" in first_batch:
        pos_embeddings = dominant_concern
        neg_embeddings = non_dominant_concern
    else:
        pos_embeddings = get_embeddings(model, dominant_concern)
        neg_embeddings = get_embeddings(model, non_dominant_concern)

    dominant_batches = list(pos_embeddings)
    non_dominant_batches = list(neg_embeddings)
    all_batches = dominant_batches + non_dominant_batches
    # Dominant batches come first: calc_coefficient treats the first half of
    # the concatenated tensor as the concern inputs.
    combined_batches = {
        key: torch.cat([batch[key] for batch in all_batches])
        for key in dominant_batches[0].keys()
    }

    combined_dataloader = [combined_batches]
    method1.coefficient = calc_coefficient(combined_dataloader, dim=1).to(config.device)
    method2.coefficient = calc_coefficient(combined_dataloader, dim=0).to(config.device)

    handle_list = []
    try:
        # Hooks are registered only once the coefficients exist, and are
        # always removed again, even when the forward pass raises.
        for name, layer in layers.items():
            scorer = method1 if "intermediate" in name else method2
            handle_list.append(layer.register_forward_hook(scorer.ci))
        propagate(model, combined_dataloader, config)
    finally:
        for handle in handle_list:
            handle.remove()

    # Print every mask recorded by the "intermediate" scorer; iterating over
    # num_mask works for any number of layers instead of assuming four.
    for mask_index in range(method1.num_mask):
        print(method1.pruning_mask[mask_index])

In [111]:
result_list = []


def build_samples(label, flag):
    """Create a SamplingDataset over the generated loader.

    Only ``label`` and ``flag`` differ between the datasets built in the loop
    below; the remaining arguments are identical for all of them.
    """
    return SamplingDataset(
        generated_dataloder,
        config,
        label,
        num_samples,
        flag,
        4,
        resample=False,
    )


for concern in range(config.num_labels):
    config.init_seed()

    # The three datasets are built in this fixed order on purpose so the
    # sequence of sampling calls after init_seed() stays the same.
    positive_samples = build_samples(concern, True)
    negative_samples = build_samples(concern, False)
    # NOTE(review): 200 is presumably a label outside the real label range so
    # that flag=False selects every sample — confirm against SamplingDataset.
    all_samples = build_samples(200, False)

    # Work on a copy so the base model stays untouched.
    module = copy.deepcopy(model)
    head_importance_prunning(module, config, all_samples, 0.5)
    module = module.to("cpu")

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        sparsity_ratio=0.5,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        method="structed",
        keep_dim=True,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader, verbose=True)
    result_list.append(result)
    # Only the first concern is processed; the loop stops after one pass.
    break

Total heads to prune: 8
{(1, 2), (0, 0), (3, 0), (2, 3), (0, 2), (3, 3), (3, 2), (1, 3)}
[0, 2, 3, 4, 5, 8, 10, 12, 13, 15, 16, 18, 20, 21, 23, 24, 29, 33, 35, 36, 40, 42, 44, 50, 51, 52, 54, 55, 59, 60, 63, 65, 67, 73, 75, 76, 85, 87, 89, 90, 92, 93, 95, 96, 97, 102, 103, 104, 106, 108, 112, 113, 115, 116, 119, 122, 123, 124, 126, 135, 136, 139, 140, 141, 143, 146, 147, 151, 153, 157, 159, 161, 164, 165, 166, 169, 171, 173, 174, 175, 176, 180, 181, 185, 187, 190, 191, 194, 196, 199, 200, 202, 205, 207, 208, 210, 212, 215, 216, 217, 218, 219, 223, 224, 225, 226, 228, 229, 231, 233, 237, 238, 240, 242, 243, 249, 250, 251, 252, 254, 257, 258, 259, 260, 262, 263, 264, 266, 270, 271, 275, 276, 277, 280, 282, 284, 285, 287, 290, 291, 292, 293, 300, 301, 302, 303, 304, 305, 306, 308, 309, 310, 311, 312, 318, 319, 320, 321, 322, 323, 324, 326, 327, 328, 329, 330, 331, 333, 337, 338, 342, 344, 345, 347, 348, 349, 350, 352, 356, 360, 363, 364, 365, 368, 373, 374, 375, 377, 380, 381, 383, 387, 3

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3300
Precision: 0.6319, Recall: 0.5879, F1-Score: 0.5921
              precision    recall  f1-score   support

           0     0.4183    0.5849    0.4877      2992
           1     0.6675    0.4609    0.5453      2992
           2     0.6015    0.6308    0.6158      3012
           3     0.3468    0.6004    0.4396      2998
           4     0.7768    0.6660    0.7171      2973
           5     0.7932    0.7649    0.7788      3054
           6     0.7569    0.3287    0.4583      3003
           7     0.5551    0.6594    0.6027      3012
           8     0.6974    0.4940    0.5783      2982
           9     0.7060    0.6895    0.6977      2982

    accuracy                         0.5882     30000
   macro avg     0.6319    0.5879    0.5921     30000
weighted avg     0.6320    0.5882    0.5924     30000



In [112]:
from src.utils.helper import report_to_df, append_nth_row

# Turn every collected evaluation result into a DataFrame, combine them with
# append_nth_row, and display the combined frame as the cell output.
df_list = list(map(report_to_df, result_list))
new_df = append_nth_row(df_list)
new_df

Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.4183,0.5849,0.4877,2992
