Slicing linear layers by rows and columns (concern-identification pruning)

In [1]:
import os
import sys

# Presumably the project root (three levels up), so that the `src.*` imports
# in the next cell resolve.
sys.path.append("../../../")
# Disable HuggingFace tokenizers' internal parallelism; it otherwise warns (and
# can deadlock) once DataLoader workers fork the process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning

In [3]:
# Experiment settings.
name = "bert-4-128-yahoo"  # experiment key passed to Config
device = torch.device("cuda:0")
checkpoint = None  # NOTE(review): not read by any visible cell
batch_size = 16
num_workers = 4
num_samples = 128  # samples per SamplingDataset (positive / negative sets)
ci_ratio = 0.3  # NOTE(review): unused; the pruning loop hard-codes sparsity_ratio=0.5
seed = 44

# `seed` was previously defined but never applied. torch.manual_seed seeds
# torch's generators on all devices; python `random` / numpy, if the sampling
# code uses them, are not covered by this call.
torch.manual_seed(seed);

In [4]:
# Project configuration for experiment `name` on `device`; `model_summary()`
# below displays the resolved fields (architecture, dataset, model/tokenizer
# names, num_labels).
config = Config(name, device)

In [5]:
# Sanity-check which model, dataset and tokenizer this run will use.
config.model_summary()

{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}


In [6]:
# Unpruned base model. The pruning loop deep-copies it per concern, so this
# instance stays intact and serves as the evaluation baseline.
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [7]:
# Train/valid/test dataloaders for the dataset named in `config`. With
# do_cache=True the splits are read from a local cache when present (the log
# below shows train.pkl / valid.pkl / test.pkl). NOTE(review): the cache looks
# pickle-based — only load cache files you produced yourself.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [8]:
def get_hook(method):
    """Wrap ``method`` as a PyTorch forward hook.

    The returned callable passes ``(module, input, output)`` straight through
    to ``method`` and always returns ``None``: whatever ``method`` returns is
    discarded, so the hooked module's output is never replaced.
    """

    def _forward_hook(mod, inp, out):
        method(mod, inp, out)

    return _forward_hook

In [9]:
def find_layers(
    model,
    layer_types=None,
    include_layers=None,
    exclude_layers=None,
    prefix: str = "",
):
    """Collect sub-modules of ``model`` whose type is one of ``layer_types``.

    The module tree is walked depth-first in ``named_children()`` order. Each
    module is identified by its dotted path, with ``prefix`` prepended when it
    is non-empty. All filtering is plain substring matching on that path:

    * a path containing any ``exclude_layers`` entry is skipped together with
      everything beneath it;
    * a module of a wanted type is kept only if ``include_layers`` is empty or
      its path contains one of those entries. Wanted-type modules are never
      descended into, whether kept or not;
    * any other module is descended into, so deeper matches are found even
      under containers whose own path does not match ``include_layers``.

    Args:
        model: root ``torch.nn.Module``; the root itself is never collected.
        layer_types: iterable of module classes to collect
            (default ``[torch.nn.Linear]``).
        include_layers: substrings a kept module's path must contain
            (default: no restriction).
        exclude_layers: substrings that prune a path and its subtree.
        prefix: path prefix for the root's children.

    Returns:
        dict mapping dotted path -> module, in discovery order.
    """
    wanted = (torch.nn.Linear,) if layer_types is None else tuple(layer_types)
    includes = [] if include_layers is None else include_layers
    excludes = [] if exclude_layers is None else exclude_layers

    def _child_paths(parent, parent_path):
        # (dotted path, module) pairs for the direct children of ``parent``.
        return [
            (f"{parent_path}.{child_name}" if parent_path else child_name, child)
            for child_name, child in parent.named_children()
        ]

    found = {}
    # Explicit stack instead of recursion: children are pushed in reverse so
    # they pop in natural order, reproducing a pre-order traversal.
    stack = _child_paths(model, prefix)[::-1]
    while stack:
        path, layer = stack.pop()
        if any(token in path for token in excludes):
            continue
        if not isinstance(layer, wanted):
            stack.extend(_child_paths(layer, path)[::-1])
        elif not includes or any(token in path for token in includes):
            found[path] = layer
    return found

In [10]:
def prune_concern_identification(
    model,
    model_config: Config,
    dominant_concern: SamplingDataset,
    non_dominant_concern: SamplingDataset,
    sparsity_ratio: float = 0.6,
    include_layers=None,
    exclude_layers=None,
    compress=False,
) -> None:
    """Prune ``model`` in place by concern identification (CI) forward hooks.

    A ``Methods.ci`` forward hook is attached to every layer selected by
    ``find_layers``. One call to ``propagate`` over the concatenation of the
    dominant and non-dominant samples fires the hooks, and each hook replaces
    its layer's weight (and bias, for row pruning) with a sliced copy. The
    hooks are always removed afterwards, even if ``propagate`` raises.

    Layers whose name contains ``"intermediate"`` are pruned along axis 0
    (output rows). Every other selected layer is pruned along axis 1 (input
    columns).

    NOTE(review): the row mask chosen for an ``intermediate`` layer and the
    column mask chosen for the layer that consumes its output are computed
    independently, so the kept indices are not guaranteed to line up —
    confirm this is intended.

    Args:
        model: module to prune in place.
        model_config: forwarded to ``propagate``.
        dominant_concern: iterable of dict batches (tensor values).
        non_dominant_concern: iterable of dict batches with the same keys,
            the same number of batches and the same total number of samples.
        sparsity_ratio: fraction of rows/columns removed per hooked layer.
        include_layers: substring filters forwarded to ``find_layers``.
        exclude_layers: substring filters forwarded to ``find_layers``.
        compress: forwarded to ``Methods`` (not read by ``Methods.ci``).

    Raises:
        ValueError: if either sample set is empty, if the batch counts
            differ, or if the total sample counts differ. ``Methods.ci``
            splits the combined batch exactly in half, so the counts must
            match. No hooks are left registered in any of these cases.
    """
    dominant_batches = list(dominant_concern)
    non_dominant_batches = list(non_dominant_concern)

    # Validate before touching the model so a bad input cannot leave pruning
    # hooks attached (they would prune on the caller's next forward pass).
    if not dominant_batches or not non_dominant_batches:
        raise ValueError(
            "dominant_concern and non_dominant_concern must both be non-empty."
        )
    if len(dominant_batches) != len(non_dominant_batches):
        raise ValueError(
            "Batch sizes of dominant_concern and non_dominant_concern does not match."
        )

    keys = dominant_batches[0].keys()
    # Dominant samples first, non-dominant second: Methods.ci relies on this
    # order when it splits the combined batch in half.
    combined_batch = {
        key: torch.cat([batch[key] for batch in dominant_batches + non_dominant_batches])
        for key in keys
    }

    first_key = next(iter(keys), None)
    if first_key is not None:
        num_dominant = sum(len(batch[first_key]) for batch in dominant_batches)
        num_total = len(combined_batch[first_key])
        if 2 * num_dominant != num_total:
            raise ValueError(
                "dominant_concern and non_dominant_concern must contain the same "
                f"number of samples (got {num_dominant} and {num_total - num_dominant})."
            )

    from src.pruning.propagate import propagate

    layers = find_layers(
        model, include_layers=include_layers, exclude_layers=exclude_layers
    )
    row_pruner = Methods(sparsity_ratio, axis=0, compress=compress)
    column_pruner = Methods(sparsity_ratio, compress=compress)

    handle_list = []
    try:
        for layer_name, layer in layers.items():
            pruner = row_pruner if "intermediate" in layer_name else column_pruner
            handle_list.append(layer.register_forward_hook(pruner.ci))
        # A single combined batch, so each hook fires once.
        # NOTE(review): assumes propagate runs one forward pass per batch.
        propagate(model, [combined_batch], model_config)
    finally:
        for handle in handle_list:
            handle.remove()

In [11]:
class Methods:
    """Concern-identification (CI) pruning, packaged as a PyTorch forward hook.

    Register ``Methods(...).ci`` with ``register_forward_hook`` on a layer
    that has ``weight`` and ``bias`` attributes (e.g. ``torch.nn.Linear``).
    When the layer next runs forward, the hook scores the weights from the
    activations the layer received. It then slices ``layer.weight`` (and,
    for row pruning, ``layer.bias``) in place.
    """

    def __init__(self, ratio: float, axis=1, compress=False) -> None:
        """
        Args:
            ratio: fraction to remove; ``int(n * ratio)`` of the ``n``
                candidate rows/columns are pruned.
            axis: ``1`` prunes input columns of ``weight``. ``0`` prunes
                output rows of ``weight`` and the matching bias entries.
            compress: stored as ``self.compress``; ``ci`` does not read it.
        """
        self.ratio = ratio
        self.axis = axis
        self.compress = compress

    def ci(self, layer, inputs, outputs):
        """Forward hook: score ``layer``'s weights and prune them in place.

        ``inputs[0]`` is split along dim 0. The first half is treated as
        concern samples, the second half as non-concern samples, and both are
        flattened to ``(-1, in_dim)``. The split needs an even leading dim
        (the original comment assumes ``(batch, seq, in_dim)``). For each
        input feature ``j``:

            coef_j = |c_j| + sin(theta_j) * (|c_j| + |n_j|) / sqrt(|c_j|^2 + |n_j|^2)

        Here ``c_j`` / ``n_j`` are feature ``j``'s concern / non-concern
        activation columns and ``theta_j`` is the angle between them.
        Weight scores are ``|W| * |coef|``, broadcast across rows. These are
        averaged per column (``axis == 1``) or per row (``axis == 0``). The
        ``int(n * ratio)`` lowest-scoring entries are dropped, with ties
        resolved by a stable sort (lowest index first).

        Returns ``None``, so this forward pass's output is unchanged; the
        pruned weights apply from the next call. ``in_features`` /
        ``out_features`` attributes are not updated.
        """
        with torch.no_grad():
            current_weight = layer.weight.data
            current_bias = layer.bias.data if layer.bias is not None else None
            X = inputs[0]

            half = X.shape[0] // 2
            concern_inputs = X[:half].reshape(-1, X.shape[-1])
            non_concern_inputs = X[half:].reshape(-1, X.shape[-1])

            # Per-feature norms over all tokens, shaped (1, in_dim) so they
            # broadcast against weight (out_dim, in_dim).
            concern_norm = torch.norm(concern_inputs, dim=0).reshape(1, -1)
            non_concern_norm = torch.norm(non_concern_inputs, dim=0).reshape(1, -1)

            cosine_similarity = torch.nn.functional.cosine_similarity(
                concern_inputs, non_concern_inputs, dim=0
            ).reshape(1, -1)

            # Clamp: rounding can push cos^2 slightly above 1 -> sqrt(negative) = NaN.
            sine_similarity = torch.sqrt(torch.clamp(1 - cosine_similarity**2, min=0.0))
            distance = torch.sqrt(concern_norm**2 + non_concern_norm**2)
            # A feature that is zero in both halves has distance 0 and a zero
            # numerator. Clamping the denominator scores it 0 (prunable)
            # instead of NaN, which a sort would treat as most important.
            safe_distance = distance.clamp_min(torch.finfo(distance.dtype).tiny)
            coefficient = (
                concern_norm
                + sine_similarity * torch.abs(concern_norm + non_concern_norm) / safe_distance
            )

            importance_score = torch.abs(current_weight) * torch.abs(coefficient)
            # axis == 1 -> one score per input column; axis == 0 -> per output row.
            importance_vector = importance_score.mean(dim=0 if self.axis == 1 else 1)

            num_prune = int(importance_vector.numel() * self.ratio)
            indices_to_prune = torch.sort(importance_vector, stable=True).indices[:num_prune]

            # Build the mask on the weight's device so indexing works on GPU.
            mask = torch.ones(
                current_weight.shape[self.axis],
                dtype=torch.bool,
                device=current_weight.device,
            )
            mask[indices_to_prune] = False

            if self.axis == 1:
                layer.weight.data = current_weight[:, mask]
            else:
                layer.weight.data = current_weight[mask, :]
                # Bias-free layers previously crashed here (None[mask]).
                if current_bias is not None:
                    layer.bias.data = current_bias[mask]

In [12]:
# Prune a fresh copy of the base model for each concern (label) and evaluate
# it on the full test set. The trailing `break` limits this run to concern 0;
# remove it to sweep every label (costs one model deep copy per iteration).
for concern in range(config.num_labels):
    # Per-iteration copy of the train loader, kept from the original.
    # NOTE(review): presumably SamplingDataset consumes/mutates it — confirm.
    train = copy.deepcopy(train_dataloader)
    # The boolean appears to select samples of `concern` (True) versus the
    # other labels (False); the meaning of the positional `4` is not visible
    # here — TODO pass it by keyword.
    positive_samples = SamplingDataset(
        train,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    # NOTE(review): `ci_ratio` from the config cell is not used here.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=0.5,
        compress=True,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader, verbose=True)

    break

Evaluate the pruned model 0


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2424
Precision: 0.6479, Recall: 0.6055, F1-Score: 0.6113
              precision    recall  f1-score   support

           0     0.5415    0.4843    0.5113      2992
           1     0.7079    0.4723    0.5666      2992
           2     0.7242    0.5850    0.6472      3012
           3     0.3243    0.6481    0.4323      2998
           4     0.7303    0.7514    0.7407      2973
           5     0.8624    0.7344    0.7933      3054
           6     0.6908    0.3660    0.4785      3003
           7     0.6221    0.6232    0.6227      3012
           8     0.5645    0.7364    0.6391      2982
           9     0.7105    0.6543    0.6812      2982

    accuracy                         0.6056     30000
   macro avg     0.6479    0.6055    0.6113     30000
weighted avg     0.6482    0.6056    0.6115     30000



In [13]:
# Baseline: evaluate the unpruned `model` for comparison with the pruned copy
# above. NOTE(review): this rebinds `result`, discarding the pruned model's
# metrics from the previous cell — use a distinct name if both are needed.
result = evaluate_model(model, config, test_dataloader)

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2240
Precision: 0.6478, Recall: 0.6149, F1-Score: 0.6195
              precision    recall  f1-score   support

           0     0.5321    0.4843    0.5071      2992
           1     0.7005    0.4723    0.5642      2992
           2     0.6957    0.6119    0.6511      3012
           3     0.3443    0.6421    0.4482      2998
           4     0.7254    0.7783    0.7509      2973
           5     0.8403    0.7600    0.7981      3054
           6     0.6719    0.4106    0.5097      3003
           7     0.6185    0.6384    0.6283      3012
           8     0.5854    0.7146    0.6436      2982
           9     0.7637    0.6362    0.6941      2982

    accuracy                         0.6150     30000
   macro avg     0.6478    0.6149    0.6195     30000
weighted avg     0.6481    0.6150    0.6198     30000



In [14]:
def get_layer0_inputs(model, batch):
    """Return the output of ``model``'s first direct child module on ``batch``.

    The forward call runs under ``torch.no_grad()``, so the result carries no
    autograd history.

    NOTE(review): despite the name, this returns the first child's *output*.
    ``model.children()`` yields only direct children, so for a wrapped model
    (e.g. a classification head around a backbone) the first child may be the
    whole backbone rather than the embedding layer — confirm intent.

    Raises:
        IndexError: if ``model`` has no child modules.
    """
    first_child = tuple(model.children())[0]
    with torch.no_grad():
        embeddings = first_child(batch)
    return embeddings

In [15]:
# NOTE(review): this cell does not run on a fresh kernel. `model_adapter`,
# `inps`, `args`, `kwargs`, `apply_mask` and `ignore_masks` are never defined
# in this notebook (NameError in the output below). It also unpacks three
# values, whereas `get_layer0_inputs` returns the first child module's output
# directly. TODO: define the adapter and accumulators, or delete this cell.
for batch in train_dataloader:
    inp_batch, args_batch, kwargs_batch = get_layer0_inputs(model_adapter, batch)
    inps.append(inp_batch)
    args.append(args_batch)
    kwargs.append(kwargs_batch)
    if apply_mask:
        ignore_masks.append(batch["attention_mask"])

NameError: name 'model_adapter' is not defined

In [14]:
# NOTE(review): fails as written. Each `batch` from `train_dataloader` is a dict
# (AttributeError in the output below), so `.double()` / `.mT` cannot apply.
# The intent appears to be accumulating H = sum_b X_b^T X_b over some tensor
# of activations. TODO: feed the tensor you mean to analyze (e.g. layer
# inputs) instead of the raw batch dict.
H = None
for idx, batch in enumerate(train_dataloader):
    batch = batch.double().to("cuda")
    # Per-sample X^T X, summed over the leading dimension; assumes a 3-D
    # (batch, seq, dim) tensor — TODO confirm.
    H_batch = torch.sum(batch.mT @ batch, dim=0)
    H = H_batch if H is None else H + H_batch

AttributeError: 'dict' object has no attribute 'double'