Squeeze by rows and columns

In [1]:
import os
import sys

# Add the directory four levels up to the import path so the `src.*` imports
# in the following cells resolve (presumably the repository root — confirm).
sys.path.append("../../../../")
# Set TOKENIZERS_PARALLELISM to "false" before any tokenizer is used
# (presumably to avoid tokenizer fork warnings with DataLoader workers — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning

In [3]:
# Experiment configuration: the values a reader is expected to tune.
# name = "bert-tiny-yahoo"
name = "bert-4-128-yahoo"  # model identifier handed to Config
# Use the first GPU when available, otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): `checkpoint`, `ci_ratio` and `seed` are not read by any cell
# visible in this notebook (the pruning loop hard-codes sparsity_ratio=0.5 and
# seeds through config.init_seed()) — confirm whether they are still needed.
checkpoint = None
batch_size = 16  # passed to load_data
num_workers = 4  # passed to load_data
num_samples = 128  # passed to every SamplingDataset in the pruning loop
ci_ratio = 0.3
seed = 44

In [4]:
# Build the experiment configuration from the model name and target device chosen above.
config = Config(name, device)

In [5]:
# Print the resolved model settings (architecture, dataset, model path, number of labels, tokenizer).
config.model_summary()

{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}


In [6]:
# Load the model described by `config`; this original model is only deep-copied
# (never pruned directly) by the pruning loop further down.
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [7]:
# Create the train/validation/test dataloaders for the configured dataset.
# do_cache=True: the run captured below reports the splits being loaded from cache
# (presumably they are written to cache on the first run — confirm in load_data).
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [8]:
def get_hook(method):
    """Adapt ``method`` to a forward-hook style callable.

    The returned function accepts ``(module, input, output)`` positionally,
    forwards all three to ``method`` and discards whatever ``method`` returns,
    so the wrapper itself always returns ``None``.
    """

    def _forward_hook(layer, layer_inputs, layer_outputs):
        # Deliberately drop the callee's return value.
        method(layer, layer_inputs, layer_outputs)
        return None

    return _forward_hook


def prune_concern_identification(
    model,
    model_config: Config,
    dominant_concern: SamplingDataset,
    non_dominant_concern: SamplingDataset,
    sparsity_ratio: float = 0.6,
    include_layers=None,
    exclude_layers=None,
    compress=False,
) -> None:
    """Prune ``model`` in place using concern vs. non-concern activations.

    The layers selected by ``find_layers`` get a ``Methods.ci`` forward hook.
    All batches of ``dominant_concern`` followed by all batches of
    ``non_dominant_concern`` are concatenated key by key into one batch, which
    is pushed through the model once with ``propagate``; the hooks zero out
    weights during that pass. Hooks are always removed afterwards.

    Args:
        model: Model whose selected layers are pruned in place.
        model_config: Configuration forwarded to ``propagate``.
        dominant_concern: Iterable of batches (mappings of tensors) for the
            concern; these end up in the first half of the combined batch.
        non_dominant_concern: Iterable of batches for the other samples; must
            yield the same number of batches as ``dominant_concern``.
        sparsity_ratio: Fraction of weights pruned per row, forwarded to
            ``Methods``.
        include_layers: Name fragments forwarded to ``find_layers``.
        exclude_layers: Name fragments forwarded to ``find_layers``.
        compress: Forwarded to ``Methods``.

    Raises:
        ValueError: If ``dominant_concern`` yields no batches or the two
            iterables yield a different number of batches.
    """
    from src.pruning.propagate import propagate

    layers = find_layers(
        model, include_layers=include_layers, exclude_layers=exclude_layers
    )

    # Validate and assemble the data before touching the model, so a bad input
    # cannot leave hooks registered.
    dominant_batches = list(dominant_concern)
    non_dominant_batches = list(non_dominant_concern)

    if not dominant_batches:
        raise ValueError("dominant_concern yielded no batches.")

    if len(dominant_batches) != len(non_dominant_batches):
        raise ValueError(
            "Batch sizes of dominant_concern and non_dominant_concern does not match."
        )

    # Dominant samples first, non-dominant second: Methods.ci splits its input
    # at the midpoint of the batch dimension and relies on this ordering.
    ordered_batches = dominant_batches + non_dominant_batches
    combined_batches = {
        key: torch.cat([batch[key] for batch in ordered_batches])
        for key in dominant_batches[0].keys()
    }
    combined_dataloader = [combined_batches]

    # NOTE(review): both instances are configured identically (axis=1), so the
    # "intermediate" branch below currently makes no difference — confirm
    # whether one of them was meant to prune along axis=0.
    method1 = Methods(sparsity_ratio, axis=1, compress=compress)
    method2 = Methods(sparsity_ratio, axis=1, compress=compress)

    handle_list = []
    try:
        for name, layer in layers.items():
            if "intermediate" in name:
                handle = layer.register_forward_hook(method1.ci)
            else:
                handle = layer.register_forward_hook(method2.ci)
            handle_list.append(handle)

        propagate(model, combined_dataloader, model_config)
    finally:
        # Always detach the hooks, even if the forward pass fails; otherwise
        # every later forward pass of `model` would prune it again.
        for handle in handle_list:
            handle.remove()

In [9]:
def find_layers(
    model,
    layer_types=None,
    include_layers=None,
    exclude_layers=None,
    prefix: str = "",
):
    """Collect sub-modules of ``model`` by type and dotted name.

    The module tree is walked depth-first through ``named_children``. A child
    whose dotted name contains any ``exclude_layers`` fragment is skipped
    together with its whole subtree. A child that is an instance of one of
    ``layer_types`` is never descended into; it is collected when
    ``include_layers`` is empty or its name contains one of the fragments.
    Every other child is descended into.

    Args:
        model: Root module to search.
        layer_types: Module classes to collect; defaults to ``[torch.nn.Linear]``.
        include_layers: Name fragments a collected layer must contain
            (substring match); empty/``None`` accepts every name.
        exclude_layers: Name fragments that prune a subtree (substring match).
        prefix: Dotted name prepended to every child name.

    Returns:
        dict mapping dotted layer name to module, in traversal order.
    """
    target_types = tuple([torch.nn.Linear] if layer_types is None else layer_types)
    wanted = [] if include_layers is None else include_layers
    blocked = [] if exclude_layers is None else exclude_layers
    found = {}

    def is_wanted(qualified_name: str) -> bool:
        # An empty include list means "accept every name".
        return not wanted or any(token in qualified_name for token in wanted)

    def walk(parent, parent_name: str) -> None:
        for child_name, child in parent.named_children():
            qualified_name = (
                f"{parent_name}.{child_name}" if parent_name else child_name
            )
            if any(token in qualified_name for token in blocked):
                continue
            if not isinstance(child, target_types):
                # Containers are always searched, whatever their own name is.
                walk(child, qualified_name)
            elif is_wanted(qualified_name):
                found[qualified_name] = child

    walk(model, prefix)

    return found

In [10]:
class Methods:
    """Forward-hook callbacks that zero out low-importance weights of a layer.

    Attributes are plain configuration:
        ratio: Fraction of entries pruned along ``axis``.
        axis: Dimension of the weight matrix along which entries are ranked
            and pruned (for ``torch.nn.Linear`` the weight is
            ``(out_features, in_features)``, so ``axis=1`` prunes per output row).
        compress: Stored only; ``ci`` does not read it.
    """

    def __init__(self, ratio: float, axis=1, compress=False) -> None:
        self.ratio = ratio
        self.axis = axis
        self.compress = compress

    @staticmethod
    def _feature_norm(tensors):
        """Return the L2 norm of each last-dimension feature over all other positions."""
        return torch.norm(tensors.reshape((-1, tensors.shape[-1])), dim=0)

    def ci(self, layer, inputs, outputs):
        """Prune ``layer.weight`` in place from one combined forward pass.

        The first half of the batch dimension of ``inputs[0]`` is treated as
        the concern samples and the second half as the non-concern samples.
        For each input feature a coefficient

            ||c|| + sin(theta) * (||c|| + ||n||) / sqrt(||c||^2 + ||n||^2)

        is computed, where ``c``/``n`` are the feature's activations in the two
        halves and ``theta`` is the angle between them. The importance of a
        weight is ``|weight| * |coefficient|``; the ``ratio`` fraction with the
        lowest importance along ``axis`` is set to zero.

        Args:
            layer: Module exposing a 2-D ``weight`` parameter.
            inputs: Forward-hook input tuple; only ``inputs[0]`` is used.
            outputs: Unused; present to satisfy the forward-hook signature.
        """
        current_weight = layer.weight.data
        # Detach so the statistics below never extend the autograd graph.
        X = inputs[0].detach()

        half = X.shape[0] // 2
        # Presumably (batch_size, seq_dim, input_dim) each — only the last
        # dimension is required to be the feature dimension.
        concern_inputs, non_concern_inputs = X[:half], X[half:]

        concern_norm = self._feature_norm(concern_inputs).reshape((1, -1))
        non_concern_norm = self._feature_norm(non_concern_inputs).reshape((1, -1))

        cosine_similarity = torch.nn.functional.cosine_similarity(
            concern_inputs.reshape((-1, concern_inputs.shape[-1])),
            non_concern_inputs.reshape((-1, non_concern_inputs.shape[-1])),
            dim=0,
        ).reshape(1, -1)

        # Rounding can push |cos| marginally above 1; clamp so sqrt never sees
        # a negative value and returns NaN.
        sine_similarity = torch.sqrt(torch.clamp(1 - cosine_similarity**2, min=0))
        distance = torch.sqrt(concern_norm**2 + non_concern_norm**2)
        # A feature that is zero in both halves has distance 0; without the
        # clamp that is 0/0 = NaN, and NaN importance sorts last, so such dead
        # features would never be pruned. With it their coefficient is 0.
        safe_distance = torch.clamp(distance, min=torch.finfo(distance.dtype).tiny)
        coefficient = (
            concern_norm
            + sine_similarity
            * torch.abs(concern_norm + non_concern_norm)
            / safe_distance
        )

        importance_score = torch.abs(current_weight) * torch.abs(coefficient)
        W_mask = torch.zeros_like(importance_score, dtype=torch.bool)
        # Ascending stable sort: the first `num_prune` indices are the least important.
        sort_res = torch.sort(importance_score, dim=self.axis, stable=True)
        num_prune = int(importance_score.shape[self.axis] * self.ratio)

        if self.axis == 0:
            indices_to_prune = sort_res[1][:num_prune, :]
        else:
            indices_to_prune = sort_res[1][:, :num_prune]

        W_mask.scatter_(self.axis, indices_to_prune, True)
        current_weight[W_mask] = 0

In [11]:
# One evaluation result per concern (class label), in label order.
result_list = []

for concern in range(config.num_labels):
    # Called once per concern before sampling (presumably reseeds the RNGs so
    # each iteration samples reproducibly — confirm in Config.init_seed).
    config.init_seed()
    # positive_samples and negative_samples differ only in the fifth positional
    # argument (True vs. False); presumably it selects samples that do / do not
    # belong to `concern` — confirm against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): `all_samples` is not used anywhere else in this cell, and it
    # is built with concern=200 although only config.num_labels labels are
    # iterated. It may be removable if constructing it has no side effects
    # (e.g. on random state) — confirm before deleting.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Prune a private copy so the original `model` stays untouched for the
    # next concern and for the baseline evaluation.
    module = copy.deepcopy(model)

    # NOTE(review): sparsity_ratio is hard-coded to 0.5 here; `ci_ratio` from
    # the configuration cell is not used.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=0.5,
        compress=True,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader, verbose=True)
    result_list.append(result)

Evaluate the pruned model 0


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 1


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 2


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 3


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 4


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 5


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 6


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 7


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 8


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Evaluate the pruned model 9


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

In [12]:
from src.utils.helper import report_to_df, append_nth_row

# Turn each per-concern evaluation result into a DataFrame, then combine them
# into a single frame with append_nth_row.
df_list = list(map(report_to_df, result_list))
new_df = append_nth_row(df_list)

In [13]:
# Display the combined per-class metrics of the pruned modules (presumably row i
# comes from the module pruned for concern i — confirm against append_nth_row).
new_df

Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.5351,0.4843,0.5084,2992
1,1,0.6986,0.4779,0.5676,2992
2,2,0.6986,0.6172,0.6554,3012
3,3,0.3444,0.6424,0.4484,2998
4,4,0.7206,0.779,0.7487,2973
5,5,0.8406,0.7616,0.7992,3054
6,6,0.6763,0.4056,0.5071,3003
7,7,0.6216,0.6375,0.6294,3012
8,8,0.5822,0.7193,0.6436,2982
9,9,0.7601,0.6385,0.694,2982


In [14]:
# Baseline: evaluate the original, unpruned `model` on the test set for
# comparison with the pruned results above.
# Note: this rebinds `result`, which the pruning loop also assigned.
result = evaluate_model(model, config, test_dataloader)

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2240
Precision: 0.6478, Recall: 0.6149, F1-Score: 0.6195
              precision    recall  f1-score   support

           0     0.5321    0.4843    0.5071      2992
           1     0.7005    0.4723    0.5642      2992
           2     0.6957    0.6119    0.6511      3012
           3     0.3443    0.6421    0.4482      2998
           4     0.7254    0.7783    0.7509      2973
           5     0.8403    0.7600    0.7981      3054
           6     0.6719    0.4106    0.5097      3003
           7     0.6185    0.6384    0.6283      3012
           8     0.5854    0.7146    0.6436      2982
           9     0.7637    0.6362    0.6941      2982

    accuracy                         0.6150     30000
   macro avg     0.6478    0.6149    0.6195     30000
weighted avg     0.6481    0.6150    0.6198     30000

