In [1]:
import os
import sys

sys.path.append("../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import prune_concern_identification

In [3]:
# name = "bert-tiny-yahoo"
name = "bert-4-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
ci_ratio = 0.3
seed = 44

In [4]:
config = Config(name, device)

In [5]:
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [7]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

generated = load_cache(
    # "datasets/generated_dataset/embedding_based/4_128-yahoo", "4_128-yahoo_top1.pkl",
    # "datasets/generated_dataset/embedding_based/4_128-yahoo", "4_128-yahoo_top2_bottom2.pkl"
    # "datasets/generated_dataset/embedding_based/4_128-yahoo", "4_128-yahoo_top4.pkl"
    # "datasets/generated_dataset/embedding_based/4_128-yahoo", "4_128-yahoo_bottom1.pkl"
    "datasets/generated_dataset/embedding_based/4_128-yahoo",
    "4_128-yahoo_bottom4.pkl",
)

4_128-yahoo_bottom4.pkl is loaded from cache.


In [8]:
print(generated.keys())

dict_keys(['example_label', 'example_list', 'attn_list'])


In [9]:
generated["embeddings"] = generated.pop("example_list")
generated["labels"] = generated.pop("example_label")
generated["attention_mask"] = generated.pop("attn_list")

In [10]:
generated.keys()

dict_keys(['embeddings', 'labels', 'attention_mask'])

In [11]:
generated_data = CustomEmbeddingDataset(generated)
generated_dataloder = DataLoader(
    generated_data,
    batch_size=4,
)

In [12]:
import torch
import torch.nn as nn
from scipy.stats import norm
from typing import *
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
from functools import partial
from src.utils.sampling import SamplingDataset
from src.pruning.propagate import propagate
from src.utils.helper import Config
import gc

In [13]:
result_list = []

for concern in range(config.num_labels):
    config.init_seed()
    positive_samples = SamplingDataset(
        generated_dataloder,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        generated_dataloder,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    all_samples = SamplingDataset(
        generated_dataloder,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=0.5,
        keep_dim=True,
        method="structed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    break

Evaluate the pruned model 0


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2282
Precision: 0.6537, Recall: 0.6094, F1-Score: 0.6156
              precision    recall  f1-score   support

           0     0.5701    0.4646    0.5120      2992
           1     0.7071    0.4639    0.5602      2992
           2     0.7191    0.5900    0.6482      3012
           3     0.3200    0.6608    0.4312      2998
           4     0.7192    0.7804    0.7485      2973
           5     0.8529    0.7502    0.7983      3054
           6     0.7044    0.3879    0.5003      3003
           7     0.6103    0.6375    0.6236      3012
           8     0.5814    0.7210    0.6437      2982
           9     0.7527    0.6378    0.6905      2982

    accuracy                         0.6095     30000
   macro avg     0.6537    0.6094    0.6156     30000
weighted avg     0.6540    0.6095    0.6159     30000



In [14]:
from src.utils.helper import report_to_df, append_nth_row

df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
new_df

Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.5701,0.4646,0.512,2992
