In [237]:
import os
import sys

# Make the `src.*` imports in the next cell resolvable (path presumed to be the repo root).
sys.path += ["../../../"]
# Disable Hugging Face tokenizers parallelism (presumably to avoid fork warnings with DataLoader workers).
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [238]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning

In [239]:
# --- experiment configuration -------------------------------------------
# alternative: name = "bert-4-128-yahoo"
name = "OSDG"
device = torch.device("cuda:0")
checkpoint = None

# data loading
batch_size, num_workers = 16, 4

# concern sampling / pruning
num_samples = 128
ci_ratio = 0.3  # NOTE(review): not referenced by any visible cell
seed = 44

In [240]:
# Project Config for experiment `name` on `device`; later cells read
# config.num_labels and call config.init_seed().
config = Config(name, device)

In [241]:
# Load the pretrained classifier described by `config` (model/tokenizer names are printed below).
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'OSDG',
 'model_name': 'sadickam/sdg-classification-bert',
 'num_labels': 16,
 'tokenizer_name': 'sadickam/sdg-classification-bert'}
The model sadickam/sdg-classification-bert is loaded.


In [242]:
# Display the module tree (output shows BertForSequenceClassification with 12 encoder layers).
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [243]:
# Train/valid/test loaders; do_cache=True reuses the cached pickled splits (see log below).
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config, do_cache=True, num_workers=num_workers, batch_size=batch_size
)

Loading cached dataset OSDG.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset OSDG is loaded
{'config_name': '2024-01-01',
 'features': {'first_column': 'text', 'second_column': 'labels'},
 'path': 'albertmartinez/OSDG'}


In [244]:
# Samples for concern label 0 with the positional flag True (presumably "positive" —
# confirm against SamplingDataset); the trailing 4 is passed through unchanged.
positive_samples = SamplingDataset(
    train_dataloader, config, 0, num_samples, True, 4, resample=False
)

In [245]:
# Same sampling for concern label 0 but with the positional flag False
# (presumably "negative" — confirm against SamplingDataset).
negative_samples = SamplingDataset(
    train_dataloader, config, 0, num_samples, False, 4, resample=False
)

In [246]:
import torch
import torch.nn as nn
from scipy.stats import norm
from typing import *
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
from functools import partial
from src.utils.sampling import SamplingDataset
from src.pruning.propagate import propagate
from src.utils.helper import Config
import gc


class Methods:
    """Forward-hook callbacks that prune a Linear layer's weights from activation statistics.

    The bound ``ci`` method is registered as a forward hook. On each forward pass it
    treats the first half (dim 0) of the layer input as the "concern" batch and the
    second half as the "non-concern" batch, scores every weight, and zeroes the
    lowest-scoring fraction along ``axis``.

    Attributes:
        ratio: fraction of entries along ``axis`` to zero out.
        axis: dimension sorted for pruning. 0 zeroes the lowest-scoring rows within
            each column; 1 zeroes the lowest-scoring columns within each row.
        coefficient: assigned by callers; not read by ``ci``.
        keep_dim: stored for API compatibility; ``ci`` always keeps the layer shape.
        method: pruning mode; only ``"unstructed"`` is implemented.
    """

    def __init__(
        self, ratio: float, keep_dim, axis: int = 0, method="unstructed"
    ) -> None:
        self.ratio = ratio
        self.axis = axis
        self.coefficient = None  # assigned externally; unused by ci()
        self.keep_dim = keep_dim  # the shape-shrinking path is not implemented
        self.method = method

    @staticmethod
    def _input_coefficient(inputs: torch.Tensor) -> torch.Tensor:
        """Per-input-feature scaling coefficient of shape (1, in_features).

        ``inputs`` is split in half along dim 0 (concern first, non-concern second)
        and every leading dim is flattened, keeping only the last (feature) dim.
        The two halves must flatten to the same number of rows.
        """
        half = inputs.shape[0] // 2
        features = inputs.shape[-1]
        concern = inputs[:half].reshape(-1, features)
        non_concern = inputs[half:].reshape(-1, features)

        concern_norm = torch.norm(concern, dim=0).reshape(1, -1)
        non_concern_norm = torch.norm(non_concern, dim=0).reshape(1, -1)

        cosine = F.cosine_similarity(concern, non_concern, dim=0).reshape(1, -1)
        # Clamp so float overshoot of |cos| > 1 cannot produce sqrt(negative) = NaN.
        sine = torch.sqrt((1 - cosine**2).clamp_min(0) + 1e-8)
        distance = torch.sqrt(concern_norm**2 + non_concern_norm**2 + 1e-8)
        coefficient = (
            concern_norm
            + sine * torch.abs(concern_norm + non_concern_norm) / distance
        )

        # Fisher score: squared mean gap over summed variances, scaled into [0, 1).
        mean_gap = concern.mean(dim=0) - non_concern.mean(dim=0)
        summed_var = (concern.var(dim=0) + 1e-8) + (non_concern.var(dim=0) + 1e-8)
        fisher = mean_gap**2 / summed_var
        fisher = (fisher / (fisher.max() + 1e-8)).reshape(1, -1)
        return coefficient * (1 + fisher)

    def ci(self, layer, inputs, outputs):
        """Forward-hook body: zero low-importance entries of ``layer.weight`` in place.

        Args:
            layer: hooked module exposing a 2-D ``weight`` (out x in) and a ``bias``
                that may be ``None``.
            inputs: positional forward inputs; ``inputs[0]`` is split in half along
                dim 0 into concern / non-concern activations.
            outputs: ignored.

        Raises:
            NotImplementedError: if ``self.method`` is not ``"unstructed"``.
        """
        if self.method != "unstructed":
            raise NotImplementedError(f"{self.method} is not implemented")

        with torch.no_grad():
            current_weight = layer.weight.data.clone()
            coefficient = self._input_coefficient(inputs[0])
            # coefficient is (1, in_features) and broadcasts across output rows.
            importance_score = torch.abs(current_weight) * torch.abs(coefficient)

            num_prune = int(current_weight.shape[self.axis] * self.ratio)
            # Indices of the num_prune lowest scores along self.axis (stable for ties).
            order = torch.sort(importance_score, dim=self.axis, stable=True).indices
            lowest = order.narrow(self.axis, 0, num_prune)
            prune_mask = torch.zeros_like(importance_score, dtype=torch.bool)
            prune_mask.scatter_(self.axis, lowest, True)
            current_weight[prune_mask] = 0

        layer.weight.data = current_weight
        # Bias-free layers (bias=None) must not be touched.
        if layer.bias is not None:
            layer.bias.data = layer.bias.data.clone()
        layer.in_features = current_weight.shape[1]
        layer.out_features = current_weight.shape[0]


def find_layers(
    model: Module,
    layer_types: Optional[List[Type[Module]]] = None,
    include_layers: Optional[List[str]] = None,
    exclude_layers: Optional[List[str]] = None,
    prefix: str = "",
) -> Dict[str, Module]:
    """Collect descendant modules of ``model`` whose type is in ``layer_types``.

    Walks ``named_children`` depth-first (the root itself is never returned).
    Dotted names are built from ``prefix``.

    * A name containing any ``exclude_layers`` substring is skipped together with
      its whole subtree.
    * When ``include_layers`` is non-empty, a matching-type module is kept only if
      its name contains one of those substrings. Non-matching containers are still
      descended into.

    Args:
        model: root module to search.
        layer_types: module classes to collect; defaults to ``[nn.Linear]``.
        include_layers: substrings a collected layer's name must contain (optional).
        exclude_layers: substrings that prune a name and its subtree (optional).
        prefix: name prefix prepended to every returned key.

    Returns:
        Dict mapping dotted module name to module, in traversal order.
    """
    target_types = tuple(layer_types) if layer_types is not None else (nn.Linear,)
    includes = include_layers or []
    excludes = exclude_layers or []
    found: Dict[str, Module] = {}

    def _mentions(full_name: str, patterns: List[str]) -> bool:
        return any(pattern in full_name for pattern in patterns)

    def _walk(parent: Module, parent_name: str) -> None:
        for child_name, child in parent.named_children():
            full_name = f"{parent_name}.{child_name}" if parent_name else child_name
            if _mentions(full_name, excludes):
                continue
            is_target = isinstance(child, target_types)
            if not is_target:
                # Containers are always descended into, included or not.
                _walk(child, full_name)
            elif not includes or _mentions(full_name, includes):
                found[full_name] = child

    _walk(model, prefix)
    return found


def get_hook(method):
    """Wrap ``method(module, inputs, output)`` as a forward hook that always returns None.

    Returning None keeps PyTorch from replacing the module's output with the
    callback's return value.
    """

    def _forward_hook(module, inputs, output):
        method(module, inputs, output)
        return None

    return _forward_hook


def get_embeddings(model, dataloader):
    """Embed every batch of ``dataloader`` with the model's input-embedding layer.

    Each batch must be a mapping with ``"input_ids"``, ``"labels"`` and
    ``"attention_mask"``. The embedding lookup runs under ``torch.no_grad()``.
    The per-batch tensors are collected into lists under the keys
    ``"embeddings"``, ``"labels"`` and ``"attention_mask"``. Those lists are
    wrapped in a ``CustomEmbeddingDataset``, which is returned.
    """
    # Imported before the loop so the name is bound even when dataloader yields nothing.
    from src.utils.data_class import CustomEmbeddingDataset

    embedding_layer = model.get_input_embeddings()  # loop-invariant
    embeddings_list = {"embeddings": [], "labels": [], "attention_mask": []}

    for batch in dataloader:
        with torch.no_grad():
            input_embeddings = embedding_layer(batch["input_ids"])
        embeddings_list["embeddings"].append(input_embeddings)
        embeddings_list["labels"].append(batch["labels"])
        embeddings_list["attention_mask"].append(batch["attention_mask"])

    return CustomEmbeddingDataset(embeddings_list)


def calc_svd(tensor1, tensor2):
    """SVD of the cross-product of two activation tensors, with unit-normalized factors.

    Both tensors are flattened to (rows, features), keeping the last dim, so they
    must flatten to the same number of rows. The product ``A.T @ B`` is
    decomposed with ``numpy.linalg.svd``.

    Returns:
        ``(U, sv, Vt)`` as numpy arrays. Each column of ``U`` and each row of
        ``Vt`` is scaled to unit L2 norm. ``sv`` is scaled to unit norm unless it
        is all zeros, in which case it is returned unchanged.
    """
    import numpy as np

    flat1 = tensor1.reshape(-1, tensor1.shape[-1]).detach().cpu()
    flat2 = tensor2.reshape(-1, tensor2.shape[-1]).detach().cpu()
    product = (flat1.T @ flat2).numpy()

    U, sv, Vt = np.linalg.svd(product)
    # U holds singular vectors in columns, Vt in rows: normalize along the matching
    # axis and keep dims so the division broadcasts over the right dimension.
    normalized_U = U / np.linalg.norm(U, axis=0, keepdims=True)
    normalized_Vt = Vt / np.linalg.norm(Vt, axis=1, keepdims=True)
    sv_norm = np.linalg.norm(sv)
    # An all-zero product has all-zero singular values; avoid 0/0 -> NaN.
    normalized_sv = sv / sv_norm if sv_norm > 0 else sv
    return normalized_U, normalized_sv, normalized_Vt


def _combine_concern_batches(model, dominant_concern, non_dominant_concern):
    """Embed both sample sets and concatenate every key along dim 0, dominant samples first.

    Raises:
        ValueError: if either set yields no items, or the two sets contribute a
            different number of rows along dim 0 (the pruning hook splits the
            combined batch exactly in half).
    """
    dominant_batches = list(get_embeddings(model, dominant_concern))
    non_dominant_batches = list(get_embeddings(model, non_dominant_concern))
    if not dominant_batches or not non_dominant_batches:
        raise ValueError("Both concern sample sets must yield at least one batch.")

    def _num_rows(batches):
        return sum(batch["embeddings"].shape[0] for batch in batches)

    dominant_rows = _num_rows(dominant_batches)
    non_dominant_rows = _num_rows(non_dominant_batches)
    if dominant_rows != non_dominant_rows:
        raise ValueError(
            "Dominant and non-dominant samples must have the same number of rows "
            f"along dim 0 (got {dominant_rows} vs {non_dominant_rows})."
        )

    all_batches = dominant_batches + non_dominant_batches
    return {
        key: torch.cat([batch[key] for batch in all_batches])
        for key in dominant_batches[0].keys()
    }


def prune_concern_identification(
    model: Module,
    config: Config,
    dominant_concern: SamplingDataset,
    non_dominant_concern: SamplingDataset,
    sparsity_ratio: float = 0.6,
    include_layers: Optional[List[str]] = None,
    exclude_layers: Optional[List[str]] = None,
    method="unstructed",
    keep_dim=True,
) -> None:
    """Prune ``model``'s Linear layers in place from concern vs. non-concern activations.

    1. Selects layers with ``find_layers``.
    2. Embeds both sample sets and concatenates them into one batch, dominant first.
    3. Registers a ``Methods.ci`` forward hook on each layer: axis 0 for names
       containing ``"intermediate"``, axis 1 otherwise.
    4. Runs ``propagate`` so every hooked layer prunes its weights.

    Hooks are always removed, even if propagation raises.

    Args:
        model: model to prune in place.
        config: passed through to ``propagate``.
        dominant_concern: samples for the concern being kept.
        non_dominant_concern: contrasting samples; must match ``dominant_concern``
            in row count.
        sparsity_ratio: fraction of entries per sorted slice to zero.
        include_layers: name substrings a pruned layer must contain (optional).
        exclude_layers: name substrings whose subtrees are skipped (optional).
        method: pruning mode forwarded to ``Methods``.
        keep_dim: forwarded to ``Methods``.

    Raises:
        ValueError: if a sample set is empty or the two sets differ in row count.
    """
    layers = find_layers(
        model, include_layers=include_layers, exclude_layers=exclude_layers
    )

    method1 = Methods(sparsity_ratio, axis=0, method=method, keep_dim=keep_dim)
    method2 = Methods(sparsity_ratio, axis=1, method=method, keep_dim=keep_dim)

    combined_dataloader = [
        _combine_concern_batches(model, dominant_concern, non_dominant_concern)
    ]
    method1.coefficient = calc_coefficient(combined_dataloader, dim=1)
    method2.coefficient = calc_coefficient(combined_dataloader, dim=0)

    handle_list = []
    try:
        for name, layer in layers.items():
            hook_owner = method1 if "intermediate" in name else method2
            handle_list.append(layer.register_forward_hook(hook_owner.ci))
        propagate(model, combined_dataloader, config)
    finally:
        # Never leave pruning hooks attached to the caller's model.
        for handle in handle_list:
            handle.remove()


def calc_coefficient(combined_dataloader, dim=0):
    """Split the single combined batch's embeddings into (concern, non_concern) halves.

    Despite its name this returns the two halves, not a coefficient. The first
    ``N // 2`` rows along dim 0 are the concern half; the remainder is the
    non-concern half. ``dim`` is accepted for API compatibility and is unused.
    """
    stacked = combined_dataloader[0]["embeddings"]
    midpoint = stacked.shape[0] // 2
    return stacked[:midpoint], stacked[midpoint:]

In [247]:
# Limit the per-label sweep while iterating on the method (replaces a bare `break`
# that silently stopped after the first label). Set to None to sweep every label.
max_concerns = 1
result_list = []

concerns = range(config.num_labels)
if max_concerns is not None:
    concerns = concerns[:max_concerns]

for concern in concerns:
    config.init_seed()
    # Positional flag True/False presumably selects samples of / not of `concern`
    # (matches the variable names) — confirm against SamplingDataset.
    positive_samples = SamplingDataset(
        train_dataloader, config, concern, num_samples, True, 4, resample=False
    )
    negative_samples = SamplingDataset(
        train_dataloader, config, concern, num_samples, False, 4, resample=False
    )
    # NOTE(review): `all_samples` is not used by any visible cell, and label 200 lies
    # outside range(config.num_labels) (16 labels here) — confirm it is still needed.
    all_samples = SamplingDataset(
        train_dataloader, config, 200, num_samples, False, 4, resample=False
    )

    module = copy.deepcopy(model)

    # NOTE(review): the config cell defines ci_ratio = 0.3, but this run prunes at
    # 0.5 — confirm which ratio is intended.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=0.5,
        keep_dim=True,
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)

Evaluate the pruned model 0


Evaluating the model:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 0.9374
Precision: 0.7739, Recall: 0.7771, F1-Score: 0.7712
              precision    recall  f1-score   support

           0     0.7504    0.6562    0.7001       797
           1     0.8424    0.6968    0.7627       775
           2     0.8686    0.8730    0.8708       795
           3     0.8763    0.8108    0.8423      1110
           4     0.8526    0.8032    0.8271      1260
           5     0.8948    0.6848    0.7759       882
           6     0.8480    0.8011    0.8239       940
           7     0.4845    0.5603    0.5196       473
           8     0.6835    0.8365    0.7523       746
           9     0.5727    0.7373    0.6447       689
          10     0.7101    0.7896    0.7477       670
          11     0.6096    0.8109    0.6960       312
          12     0.7114    0.8045    0.7551       665
          13     0.8682    0.8185    0.8426       314
          14     0.8512    0.7791    0.8135       756
          15     0.9583    0.9720    0.9651      1607

    accuracy   

In [248]:
from src.utils.helper import report_to_df, append_nth_row

# Tabulate each classification report, then combine them with append_nth_row.
df_list = list(map(report_to_df, result_list))
new_df = append_nth_row(df_list)
new_df

Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.7504,0.6562,0.7001,797
