In [1]:
import os
import sys

sys.path.append("../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning

In [3]:
# name = "bert-tiny-yahoo"
name = "bert-4-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
ci_ratio = 0.5
seed = 44

In [4]:
config = Config(name, device)

In [5]:
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [7]:
positive_samples = SamplingDataset(
    train_dataloader,
    config,
    0,
    num_samples,
    True,
    4,
    resample=False,
)

In [8]:
negative_samples = SamplingDataset(
    train_dataloader,
    config,
    0,
    num_samples,
    False,
    4,
    resample=False,
)

In [9]:
import torch
import torch.nn as nn
from scipy.stats import norm
from typing import *
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F
from functools import partial
from src.utils.sampling import SamplingDataset
from src.pruning.propagate import propagate
from src.utils.helper import Config
import gc
from src.pruning.prune import prune_concern_identification

In [10]:
result_list = []

for concern in range(config.num_labels):
    config.init_seed()
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=ci_ratio,
        keep_dim=True,
        method="structed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    break

Evaluate the pruned model 0


Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.2143
Precision: 0.6435, Recall: 0.6156, F1-Score: 0.6189
              precision    recall  f1-score   support

           0     0.5123    0.5080    0.5102      2992
           1     0.7087    0.4813    0.5732      2992
           2     0.7166    0.5920    0.6484      3012
           3     0.3570    0.6151    0.4518      2998
           4     0.7222    0.7679    0.7444      2973
           5     0.8380    0.7639    0.7992      3054
           6     0.6700    0.4016    0.5022      3003
           7     0.6073    0.6398    0.6231      3012
           8     0.5873    0.7207    0.6472      2982
           9     0.7160    0.6653    0.6897      2982

    accuracy                         0.6156     30000
   macro avg     0.6435    0.6156    0.6189     30000
weighted avg     0.6439    0.6156    0.6191     30000



In [11]:
from src.utils.helper import report_to_df, append_nth_row

df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
new_df

Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.5123,0.508,0.5102,2992
