In [1]:
import os
import sys

# Make the repository root importable (for the `src.*` packages used below) and
# disable HuggingFace tokenizers' internal parallelism, which conflicts with
# forked DataLoader worker processes.
# The root is resolved to an absolute path so a later working-directory change
# cannot break imports, and it is only appended once so re-running this cell
# does not keep growing sys.path.
_project_root = os.path.abspath("../../../../../")
if _project_root not in sys.path:
    sys.path.append(_project_root)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration — everything a reader might want to tune.
name = "bert-6-128-yahoo"  # identifier passed to Config (model + dataset setup)
# Prefer the first GPU, but fall back to CPU so the notebook can still run on
# machines without CUDA instead of failing when the model is moved to "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): `checkpoint` is not referenced in the cells shown; confirm it is needed.
checkpoint = None
batch_size = 16  # passed to load_data
num_workers = 4  # passed to load_data
num_samples = 16  # passed to SamplingDataset and get_similarity
ratio = 0.5  # passed to head_importance_prunning; presumably the fraction of heads pruned
# NOTE(review): `seed` is not referenced in the cells shown (config.init_seed()
# takes no argument); confirm the intended seed is actually applied.
seed = 44

In [4]:
# Capture the wall-clock time at which the run began and echo it.
script_start_time = datetime.now()
print("Script started at:", script_start_time.strftime("%Y-%m-%d %H:%M:%S"))

Script started at: 2024-10-22 22:49:21


In [5]:
# Build the experiment config for `name` on `device`, read the number of target
# classes from its underlying dict, and load the model the config describes.
config = Config(name, device)
num_labels = config.config["num_labels"]  # used below to iterate over every class ("concern")
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
# Build the train / validation / test DataLoaders for the configured dataset,
# allowing load_data to reuse its on-disk cache.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

# Load the pre-generated, embedding-based sample set (top-1 variant) from cache.
# NOTE(review): this cache is for "4_128-yahoo" while `name` is
# "bert-6-128-yahoo" — confirm the model/dataset mismatch is intentional.
# NOTE(review): the file is a .pkl and is presumably unpickled; only load
# caches from a trusted source.
generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo", "4_128-yahoo_top1.pkl"
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
# Rename the cached keys to the names used when building CustomEmbeddingDataset
# (presumably the keys it reads — confirm against that class).
# Each rename is guarded so that re-running this cell is a no-op instead of
# raising KeyError once the old keys have already been popped.
for _old_key, _new_key in (
    ("example_list", "embeddings"),
    ("example_label", "labels"),
    ("attn_list", "attention_mask"),
):
    if _old_key in generated:
        generated[_new_key] = generated.pop(_old_key)

In [9]:
# Wrap the renamed generated dict in a Dataset and batch it four items at a
# time (DataLoader's default is no shuffling).
generated_data = CustomEmbeddingDataset(generated)
generated_dataloder = DataLoader(generated_data, batch_size=4)  # (sic) name used by later cells

In [10]:
# Evaluation of the original (unpruned) model is disabled in this run;
# uncomment the two lines below to produce a baseline report for comparison.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [11]:
# Re-seed (presumably the RNGs) immediately before sampling so the drawn
# subset is reproducible.
config.init_seed()
# NOTE(review): 200, False and 4 are passed positionally with no names; confirm
# their meaning against SamplingDataset's signature and pass them as keywords.
all_samples = SamplingDataset(
    generated_dataloder, config, 200, num_samples, False, 4, resample=False
)

In [12]:
# Evaluation reports of the pruned model(s); this run produces a single one.
result_list = []

# Prune a deep copy so the original `model` stays untouched for the
# similarity and perplexity comparisons that follow.
module = copy.deepcopy(model)

# Prunes `module` in place (the return value is not used), driven by the
# sampled data and the configured `ratio`.
head_importance_prunning(module, config, all_samples, ratio)

print("Evaluate the pruned model")
result = evaluate_model(module, config, test_dataloader, verbose=True)
result_list.append(result)

Total heads to prune: 6




tensor([[0.5141, 0.4859],
        [0.5482, 0.4518],
        [0.5592, 0.4408],
        [0.4687, 0.5313],
        [0.4890, 0.5110],
        [0.4927, 0.5073]])




{(0, 1), (4, 0), (2, 1), (1, 1), (3, 0), (5, 0)}




Evaluate the pruned model




Evaluating the model:   0%|                                                                                   …

In [13]:
# Compare original vs. pruned model representations for every class
# ("concern") on the validation set.
for concern in range(num_labels):
    # Re-seed before each comparison (presumably the RNGs) so every concern is
    # sampled the same way, independent of how many ran before it.
    config.init_seed()
    get_similarity(
        model, module, valid_dataloader, concern, num_samples, config
    )

adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9782933065994711




CCA coefficients mean non-concern: 0.9769895932078556




Linear CKA concern: 0.8756037175896856




Linear CKA non-concern: 0.8817423604990791




Kernel CKA concern: 0.8341925965208935




Kernel CKA non-concern: 0.7945960620331488




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9867353633204428




CCA coefficients mean non-concern: 0.9754316571148505




Linear CKA concern: 0.8295938655608004




Linear CKA non-concern: 0.8837682821595425




Kernel CKA concern: 0.7354281929765251




Kernel CKA non-concern: 0.7904153301817359




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9765100700189576




CCA coefficients mean non-concern: 0.9747702238867104




Linear CKA concern: 0.8717506958364974




Linear CKA non-concern: 0.8783034350616038




Kernel CKA concern: 0.8809358552875118




Kernel CKA non-concern: 0.7862521356899299




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9811160511875017




CCA coefficients mean non-concern: 0.9776943813489586




Linear CKA concern: 0.8993409527399433




Linear CKA non-concern: 0.8694483121371204




Kernel CKA concern: 0.8422295712656584




Kernel CKA non-concern: 0.7895840268490585




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.980501266047556




CCA coefficients mean non-concern: 0.9773654460085368




Linear CKA concern: 0.8927600397126704




Linear CKA non-concern: 0.8778805743575776




Kernel CKA concern: 0.869461095309524




Kernel CKA non-concern: 0.7817194417498663




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9735178903920166




CCA coefficients mean non-concern: 0.9815827654134192




Linear CKA concern: 0.6212106420568105




Linear CKA non-concern: 0.87492147297455




Kernel CKA concern: 0.5345943448269335




Kernel CKA non-concern: 0.8187458192128652




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9839191272625606




CCA coefficients mean non-concern: 0.9751935548499839




Linear CKA concern: 0.888814376158154




Linear CKA non-concern: 0.870784161066947




Kernel CKA concern: 0.8267666629178069




Kernel CKA non-concern: 0.7934686740280165




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9813432149400895




CCA coefficients mean non-concern: 0.9786567650356887




Linear CKA concern: 0.8845552013446829




Linear CKA non-concern: 0.8604277757409493




Kernel CKA concern: 0.7535396641541515




Kernel CKA non-concern: 0.7876656354486654




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9857409293008769




CCA coefficients mean non-concern: 0.9793690877132822




Linear CKA concern: 0.8851081546847636




Linear CKA non-concern: 0.8690689852392894




Kernel CKA concern: 0.8792321372753099




Kernel CKA non-concern: 0.7924160710353165




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.985420541440683




CCA coefficients mean non-concern: 0.9778639605856895




Linear CKA concern: 0.8712837369973188




Linear CKA non-concern: 0.8684539337568247




Kernel CKA concern: 0.7811670701989467




Kernel CKA non-concern: 0.7922013682539825




In [14]:
# Report how sparse the pruned copy is, then compare validation perplexity of
# the original and pruned models.
# get_perplexity already prints its value. Keeping the results in named
# variables makes them reusable and avoids echoing the last one a second time
# as cell output.
get_sparsity(module)
print("original model's perplexity")
original_perplexity = get_perplexity(model, valid_dataloader, config)
print("pruned model's perplexity")
pruned_perplexity = get_perplexity(module, valid_dataloader, config)

0.16324659861403798




{'bert.encoder.layer.0.attention.self.query.weight': 0.5, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.5, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.5, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.5, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.0, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.0, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.5, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.5, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.5, 'bert.encoder.layer.1.attention.self.value.bias': 0.0, 'bert.encoder.l




original model's perplexity




3.187649726867676




pruned model's perplexity




3.7875261306762695




3.7875261306762695

In [15]:
# Convert each evaluation report into a DataFrame. Only the first one (the
# single pruned run above) is saved and displayed.
df_list = [report_to_df(report) for report in result_list]
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
new_df = df_list[0]
# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists before writing.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-10-22_22-51-09




Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.4241,0.5842,0.4914,2992
1,1,0.694,0.4669,0.5582,2992
2,2,0.681,0.6039,0.6402,3012
3,3,0.3207,0.6117,0.4208,2998
4,4,0.7784,0.6179,0.6889,2973
5,5,0.818,0.5491,0.6571,3054
6,6,0.6239,0.4166,0.4996,3003
7,7,0.6167,0.4797,0.5397,3012
8,8,0.6213,0.6486,0.6346,2982
9,9,0.5795,0.7005,0.6343,2982
