In [10]:
import os
import sys

# Extend the module search path three directories up so the `src.*` imports
# used by this notebook can resolve. The path is relative, so it is interpreted
# against the current working directory of the kernel.
repo_root = "../../../"
if repo_root not in sys.path:  # idempotent: re-running the cell must not stack duplicates
    sys.path.append(repo_root)

# Set before any tokenizer is created. Presumably read by the Hugging Face
# `tokenizers` library to switch off its internal parallelism — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [11]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print

from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning

In [12]:
# Run configuration shared by the cells below.
name = "bert-4-128-yahoo"  # identifier handed to Config to select the model setup
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # not referenced by the visible cells
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 128  # forwarded to SamplingDataset
# NOTE(review): head_ratio is not referenced by any visible cell — confirm it is still needed.
head_ratio = 0.1
seed = 44  # only referenced by the commented-out similarity cell

In [13]:
# Remember when this run began and echo it in a human-readable form.
script_start_time = datetime.now()
start_stamp = script_start_time.strftime('%Y-%m-%d %H:%M:%S')
print(f"Script started at: {start_stamp}")

Script started at: 2024-10-12 17:54:11


In [14]:
# Build the run configuration from `name` and `device`, read the label count
# out of its underlying dict, then load the model that the configuration
# describes. Statement order is kept as-is because the calls print progress
# messages to the cell output.
config = Config(name, device)
num_labels = config.config["num_labels"]  # only used by the commented-out similarity cell
model = load_model(config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [15]:
# Fetch the three dataloaders in one call, then unpack them in the order
# load_data returns them: train, validation, test.
dataloaders = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)
train_dataloader, valid_dataloader, test_dataloader = dataloaders

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [16]:
# Ratio handed to head_importance_prunning. Bound to a name (instead of an
# inline literal) so the pruning call and the checkpoint filename template
# below refer to the same value.
head_pruning_ratio = 0.3

# Build the sample set consumed by the pruning routine. The positional
# arguments are passed through unchanged; their meaning is defined by
# SamplingDataset's signature, which is not visible in this notebook.
all_samples = SamplingDataset(
    train_dataloader,
    config,
    200,
    num_samples,
    False,
    4,
    resample=False,
)


# Prune a deep copy so `model` keeps its loaded weights. The return value of
# head_importance_prunning is not used, so the pruning is presumably applied to
# `module` in place — TODO confirm.
module = copy.deepcopy(model)
head_importance_prunning(module, config, all_samples, head_pruning_ratio)
# save_checkpoint(module, "Modules/", f"head_prune_{name}_{head_pruning_ratio}p.pt")

Total heads to prune: 4
{(0, 2), (0, 3), (2, 2), (3, 0)}


In [17]:
# Disabled cell: when uncommented, it evaluates the pruned `module` on the test
# dataloader and then calls get_sparsity on it.
# print(f"Evaluate the pruned model")
# result = evaluate_model(module, config, test_dataloader)
# get_sparsity(module)

In [18]:
# Disabled cell: when uncommented, it loops over every label index ("concern")
# and calls get_similarity on the original `model` and the pruned `module`,
# using a fresh deep copy of the validation dataloader for each label.
# for concern in range(num_labels):
#     valid = copy.deepcopy(valid_dataloader)
#     get_similarity(
#         model, module, valid, concern, num_samples, num_labels, device=device, seed=seed
#     )