In [1]:
import os
import sys

sys.path.append("../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.utils.helper import report_to_df, append_nth_row

In [3]:
name = "bert-4-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
ratio = 0.5
seed = 44
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-11-03 16:03:47


In [5]:
config = Config(name, device)
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [7]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns

def see(data):
  data = data.cpu().numpy().flatten()
  
  # # Heatmap
  # plt.figure(figsize=(10, 1))
  # sns.heatmap([data], cmap="viridis", cbar=True, annot=False, xticklabels=False, yticklabels=False)
  # plt.title("Heatmap of Tensor Values")
  # plt.show() 
  
  
  # Bar plot
  plt.figure(figsize=(15, 5))
  plt.bar(range(len(data)), data)
  plt.title("Bar Plot of Tensor Values")
  plt.xlabel("Index")
  plt.ylabel("Value")
  plt.show()



In [8]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [9]:
import torch
import torch.nn as nn
from scipy.stats import norm
from typing import *
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
from functools import partial
from src.utils.sampling import SamplingDataset
from src.pruning.propagate import propagate
from src.utils.helper import Config
import gc


class Pruner:
    def __init__(self, layers, ratio: float, method="unstructed") -> None:
        self.ratio = ratio
        self.method = method
        self.layers = layers
        self.pruning_mask = {}
        self.pruning_indices = {}
        
    def ci(self, layer, inputs, outputs):
        current_weight = layer.weight.data
        X = inputs[0]
        batch_size = X.shape[0] // 2

        concern_inputs, non_concern_inputs = (
            X[:batch_size],
            X[batch_size:],
        )  # (batch_size, seq_dim, input_dim)

        calc_norm = lambda tensors, dim: torch.norm(
            tensors.reshape((-1, tensors.shape[-1])), dim=dim
        )

        new_shape = (1, -1)
        concern_norm = calc_norm(concern_inputs, dim=0).reshape(new_shape)
        non_concern_norm = calc_norm(non_concern_inputs, dim=0).reshape(new_shape)
        print(layer)
        # see(concern_norm)
        # see(non_concern_norm)
        cosine_similarity = F.cosine_similarity(
            concern_inputs.reshape((-1, concern_inputs.shape[-1])),
            non_concern_inputs.reshape((-1, non_concern_inputs.shape[-1])),
            dim=0,
        ).reshape(1, -1)

        sine_similarity = torch.sqrt(1 - cosine_similarity**2)
        alpha = (cosine_similarity * sine_similarity) / (cosine_similarity - sine_similarity) 
        coefficient = (
            concern_norm +  alpha * non_concern_norm
        )
        # see(coefficient)
        importance_score = torch.abs(current_weight) * torch.abs(coefficient)

        indices_vector = None
        if self.method == "unstructed":
            sort_res = torch.sort(importance_score, dim=-1, stable=True)
            num_prune = int(current_weight.shape[1] * self.ratio)
            indices_matrix = sort_res[1][:, :num_prune]
            W_mask = (torch.ones_like(importance_score) == 1).scatter_(
                1, indices_matrix, False
            )
        elif self.method == "structed":
            importance_vector = torch.norm(importance_score, dim=1)
            num_prune = int(importance_vector.shape[0] * self.ratio)
            sort_res = torch.sort(importance_vector)
            indices_vector = sort_res[1][:num_prune]
            W_mask = (torch.ones_like(importance_vector) == 1).scatter_(
                0, indices_vector, False
            )
        else:
            raise NotImplementedError(f"{self.method} is not implemented")

        if self.method == "unstructed":
            sorted_indices_matrix = torch.sort(indices_matrix, dim=1)[0]
            indices = sorted_indices_matrix

        elif self.method == "structed":
            sorted_indices_vector = torch.sort(indices_vector)[0]
            indices = sorted_indices_vector
        else:
            raise NotImplementedError(f"The method {self.method} is not implemented")

        layer_id = id(layer)
        layer_name = [key for key, val in self.layers.items() if id(val) == layer_id][0]
        self.pruning_mask[layer_name] = W_mask
        self.pruning_indices[layer_name] = indices

    @staticmethod
    def apply(layer, method, axis, mask, keepdim):
        current_weight = layer.weight.data.clone()
        current_weight = current_weight * mask
        if not keepdim:
            if method == "structed":
                if axis == 0:
                    zero_rows = (current_weight == 0).all(dim=1)
                    current_weight = current_weight[~zero_rows]

                    if layer.bias is not None:
                        current_bias = layer.bias.data.clone()
                        layer.bias.data = current_bias[~zero_rows]
                elif axis == 1:
                    zero_cols = (current_weight == 0).all(dim=0)
                    current_weight = current_weight[:, ~zero_cols]
        layer.in_features = current_weight.shape[1]
        layer.out_features = current_weight.shape[0]
        layer.weight.data = current_weight


def find_layers(
    model: Module,
    layer_types: Optional[List[Type[Module]]] = None,
    include_layers: Optional[List[str]] = None,
    exclude_layers: Optional[List[str]] = None,
    prefix: str = "",
) -> Dict[str, Module]:
    if layer_types is None:
        layer_types = [nn.Linear]
    if include_layers is None:
        include_layers = []
    if exclude_layers is None:
        exclude_layers = []
    layers_dict: Dict[str, Module] = {}

    def recursive_find(module: Module, prefix: str) -> None:
        for name, layer in module.named_children():
            layer_name = f"{prefix}.{name}" if prefix else name
            if any(exclude in layer_name for exclude in exclude_layers):
                continue
            if include_layers and not any(
                include in layer_name for include in include_layers
            ):
                if not any(isinstance(layer, t) for t in layer_types):
                    recursive_find(layer, layer_name)
                continue
            if isinstance(layer, tuple(layer_types)):
                layers_dict[layer_name] = layer
            else:
                recursive_find(layer, layer_name)

    recursive_find(model, prefix)

    return layers_dict


def get_hook(method):
    def hook(module, input, output):
        method(module, input, output)

    return hook

def prune_concern_identification(
    model: Module,
    config: Config,
    dominant_concern: SamplingDataset,
    non_dominant_concern: SamplingDataset,
    sparsity_ratio: float = 0.6,
    include_layers: Optional[List[str]] = None,
    exclude_layers: Optional[List[str]] = None,
    method: str = "unstructed",
    keep_dim=True,
) -> None:
    layers = find_layers(
        model, include_layers=include_layers, exclude_layers=exclude_layers
    )
    handle_list = []
    pruner = Pruner(layers, ratio=sparsity_ratio, method=method)

    for name, layer in layers.items():
        if method == "structed":
            if "intermediate" in name:
                handle = layer.register_forward_hook(pruner.ci)
                handle_list.append(handle)
        else:
            handle = layer.register_forward_hook(pruner.ci)
            handle_list.append(handle)

    dominant_batches = list(dominant_concern)
    non_dominant_batches = list(non_dominant_concern)

    if len(dominant_batches) != len(non_dominant_batches):
        raise ValueError(
            "Batch sizes of dominant_concern and non_dominant_concern does not match."
        )

    combined_batches = {}
    keys = dominant_batches[0].keys()

    for key in keys:
        combined_batches[key] = torch.cat(
            [batch[key] for batch in dominant_batches + non_dominant_batches]
        )

    combined_dataloader = [combined_batches]
    propagate(model, combined_dataloader, config)
    for handle in handle_list:
        handle.remove()

    intermediate_mask = None
    for name, layer in layers.items():
        if method == "structed":
            if "intermediate" in name:
                current_mask = pruner.pruning_mask[name].to("cpu")
                intermediate_mask = current_mask
                current_mask = current_mask.unsqueeze(dim=1).expand(
                    -1, layer.weight.shape[1]
                )
                Pruner.apply(
                    layer,
                    method="structed",
                    axis=0,
                    mask=current_mask,
                    keepdim=keep_dim,
                )
            elif "output" in name:
                current_mask = intermediate_mask.unsqueeze(dim=0).expand(
                    layer.weight.shape[0], -1
                )
                Pruner.apply(
                    layer,
                    method="structed",
                    axis=1,
                    mask=current_mask,
                    keepdim=keep_dim,
                )
        elif method == "unstructed":
            current_mask = pruner.pruning_mask[name].to("cpu")
            Pruner.apply(
                layer, method="unstructed", axis=0, mask=current_mask, keepdim=keep_dim
            )



In [10]:
result_list = []

for concern in range(config.num_labels):
    config.init_seed()
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader, verbose=False)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Linear(in_features=128, out_features=1024, bias=True)
Linear(in_features=1024, out_features=128, bias=True)
Linear(in_features=128, out_features=1024, bias=True)
Linear(in_features=1024, out_features=128, bias=True)
Linear(in_features=128, out_features=1024, bias=True)
Linear(in_features=1024, out_features=128, bias=True)
Linear(in_features=128, out_features=1024, bias=True)
Linear(in_features=1024, out_features=128, bias=True)
Evaluate the pruned model 0


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2153
Precision: 0.6478, Recall: 0.6176, F1-Score: 0.6219
              precision    recall  f1-score   support

           0     0.5296    0.5057    0.5174      2992
           1     0.6928    0.5013    0.5817      2992
           2     0.7021    0.6102    0.6529      3012
           3     0.3513    0.6254    0.4499      2998
           4     0.7244    0.7780    0.7502      2973
           5     0.8412    0.7616    0.7995      3054
           6     0.6857    0.3989    0.5044      3003
           7     0.6096    0.6424    0.6256      3012
           8     0.5886    0.7086    0.6430      2982
           9     0.7522    0.6442    0.6940      2982

    accuracy                         0.6177     30000
   macro avg     0.6478    0.6176    0.6219     30000
weighted avg     0.6481    0.6177    0.6221     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2185
Precision: 0.6491, Recall: 0.6177, F1-Score: 0.6219
              precision    recall  f1-score   support

           0     0.5402    0.4896    0.5137      2992
           1     0.6934    0.5094    0.5873      2992
           2     0.7051    0.6112    0.6548      3012
           3     0.3487    0.6291    0.4487      2998
           4     0.7209    0.7777    0.7482      2973
           5     0.8431    0.7616    0.8003      3054
           6     0.6894    0.3969    0.5038      3003
           7     0.6233    0.6301    0.6267      3012
           8     0.5739    0.7277    0.6417      2982
           9     0.7529    0.6439    0.6941      2982

    accuracy                         0.6178     30000
   macro avg     0.6491    0.6177    0.6219     30000
weighted avg     0.6494    0.6178    0.6221     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2143
Precision: 0.6467, Recall: 0.6192, F1-Score: 0.6223
              precision    recall  f1-score   support

           0     0.5327    0.5013    0.5165      2992
           1     0.6946    0.5100    0.5882      2992
           2     0.6913    0.6208    0.6542      3012
           3     0.3589    0.6167    0.4537      2998
           4     0.7241    0.7716    0.7471      2973
           5     0.8298    0.7662    0.7967      3054
           6     0.6982    0.3913    0.5015      3003
           7     0.6187    0.6361    0.6273      3012
           8     0.5790    0.7227    0.6429      2982
           9     0.7402    0.6553    0.6951      2982

    accuracy                         0.6193     30000
   macro avg     0.6467    0.6192    0.6223     30000
weighted avg     0.6470    0.6193    0.6225     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2188
Precision: 0.6481, Recall: 0.6149, F1-Score: 0.6201
              precision    recall  f1-score   support

           0     0.5387    0.4863    0.5112      2992
           1     0.6845    0.5053    0.5814      2992
           2     0.6991    0.6162    0.6550      3012
           3     0.3399    0.6354    0.4429      2998
           4     0.7218    0.7767    0.7482      2973
           5     0.8429    0.7574    0.7979      3054
           6     0.6817    0.3986    0.5030      3003
           7     0.6236    0.6298    0.6267      3012
           8     0.5864    0.7109    0.6427      2982
           9     0.7627    0.6328    0.6917      2982

    accuracy                         0.6150     30000
   macro avg     0.6481    0.6149    0.6201     30000
weighted avg     0.6484    0.6150    0.6203     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2139
Precision: 0.6488, Recall: 0.6191, F1-Score: 0.6226
              precision    recall  f1-score   support

           0     0.5244    0.5064    0.5152      2992
           1     0.6974    0.5100    0.5892      2992
           2     0.7018    0.6142    0.6551      3012
           3     0.3548    0.6244    0.4525      2998
           4     0.7161    0.7830    0.7481      2973
           5     0.8340    0.7665    0.7988      3054
           6     0.6973    0.3896    0.4999      3003
           7     0.6370    0.6252    0.6310      3012
           8     0.5769    0.7223    0.6415      2982
           9     0.7485    0.6489    0.6952      2982

    accuracy                         0.6191     30000
   macro avg     0.6488    0.6191    0.6226     30000
weighted avg     0.6491    0.6191    0.6228     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2100
Precision: 0.6482, Recall: 0.6181, F1-Score: 0.6217
              precision    recall  f1-score   support

           0     0.5293    0.4943    0.5112      2992
           1     0.7020    0.5040    0.5868      2992
           2     0.6991    0.6202    0.6573      3012
           3     0.3523    0.6234    0.4502      2998
           4     0.7151    0.7827    0.7474      2973
           5     0.8279    0.7705    0.7982      3054
           6     0.6881    0.3923    0.4997      3003
           7     0.6380    0.6232    0.6305      3012
           8     0.5708    0.7317    0.6413      2982
           9     0.7596    0.6388    0.6940      2982

    accuracy                         0.6182     30000
   macro avg     0.6482    0.6181    0.6217     30000
weighted avg     0.6485    0.6182    0.6219     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2114
Precision: 0.6477, Recall: 0.6187, F1-Score: 0.6226
              precision    recall  f1-score   support

           0     0.5243    0.5047    0.5143      2992
           1     0.6999    0.5090    0.5894      2992
           2     0.7051    0.6096    0.6538      3012
           3     0.3559    0.6187    0.4519      2998
           4     0.7232    0.7733    0.7474      2973
           5     0.8394    0.7613    0.7984      3054
           6     0.6851    0.4006    0.5056      3003
           7     0.6220    0.6365    0.6291      3012
           8     0.5769    0.7247    0.6424      2982
           9     0.7453    0.6486    0.6936      2982

    accuracy                         0.6188     30000
   macro avg     0.6477    0.6187    0.6226     30000
weighted avg     0.6480    0.6188    0.6228     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2187
Precision: 0.6481, Recall: 0.6171, F1-Score: 0.6211
              precision    recall  f1-score   support

           0     0.5400    0.4936    0.5158      2992
           1     0.6902    0.5191    0.5925      2992
           2     0.7057    0.6059    0.6520      3012
           3     0.3502    0.6254    0.4490      2998
           4     0.7191    0.7716    0.7444      2973
           5     0.8340    0.7633    0.7971      3054
           6     0.6983    0.3893    0.4999      3003
           7     0.6062    0.6408    0.6230      3012
           8     0.5794    0.7203    0.6422      2982
           9     0.7576    0.6415    0.6948      2982

    accuracy                         0.6172     30000
   macro avg     0.6481    0.6171    0.6211     30000
weighted avg     0.6484    0.6172    0.6213     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2144
Precision: 0.6483, Recall: 0.6195, F1-Score: 0.6230
              precision    recall  f1-score   support

           0     0.5337    0.4977    0.5150      2992
           1     0.6964    0.5137    0.5913      2992
           2     0.6966    0.6175    0.6547      3012
           3     0.3555    0.6181    0.4513      2998
           4     0.7217    0.7773    0.7485      2973
           5     0.8384    0.7613    0.7980      3054
           6     0.6936    0.3943    0.5028      3003
           7     0.6277    0.6308    0.6292      3012
           8     0.5725    0.7284    0.6411      2982
           9     0.7470    0.6556    0.6983      2982

    accuracy                         0.6195     30000
   macro avg     0.6483    0.6195    0.6230     30000
weighted avg     0.6486    0.6195    0.6232     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2127
Precision: 0.6481, Recall: 0.6166, F1-Score: 0.6211
              precision    recall  f1-score   support

           0     0.5323    0.4983    0.5148      2992
           1     0.6912    0.4960    0.5775      2992
           2     0.7038    0.6152    0.6565      3012
           3     0.3458    0.6284    0.4461      2998
           4     0.7211    0.7777    0.7483      2973
           5     0.8430    0.7593    0.7990      3054
           6     0.6815    0.3983    0.5027      3003
           7     0.6196    0.6345    0.6270      3012
           8     0.5865    0.7183    0.6458      2982
           9     0.7565    0.6398    0.6933      2982

    accuracy                         0.6167     30000
   macro avg     0.6481    0.6166    0.6211     30000
weighted avg     0.6484    0.6167    0.6213     30000

0.39267273726798974
{'bert.encoder.layer.0.attention.self.query.weight': 0.0, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.we

In [11]:
df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
csv_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-11-03_16-12-53


Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.5296,0.5057,0.5174,2992
1,1,0.6934,0.5094,0.5873,2992
2,2,0.6913,0.6208,0.6542,3012
3,3,0.3399,0.6354,0.4429,2998
4,4,0.7161,0.783,0.7481,2973
5,5,0.8279,0.7705,0.7982,3054
6,6,0.6851,0.4006,0.5056,3003
7,7,0.6062,0.6408,0.623,3012
8,8,0.5725,0.7284,0.6411,2982
9,9,0.7565,0.6398,0.6933,2982


In [12]:
print("Evaluate the original model")
result = evaluate_model(model, config, test_dataloader)

Evaluate the original model


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2240
Precision: 0.6478, Recall: 0.6149, F1-Score: 0.6195
              precision    recall  f1-score   support

           0     0.5321    0.4843    0.5071      2992
           1     0.7005    0.4723    0.5642      2992
           2     0.6957    0.6119    0.6511      3012
           3     0.3443    0.6421    0.4482      2998
           4     0.7254    0.7783    0.7509      2973
           5     0.8403    0.7600    0.7981      3054
           6     0.6719    0.4106    0.5097      3003
           7     0.6185    0.6384    0.6283      3012
           8     0.5854    0.7146    0.6436      2982
           9     0.7637    0.6362    0.6941      2982

    accuracy                         0.6150     30000
   macro avg     0.6478    0.6149    0.6195     30000
weighted avg     0.6481    0.6150    0.6198     30000

