In [1]:
import os
import sys

sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import prune_concern_identification
from src.utils.helper import report_to_df, append_nth_row

In [3]:
name = "bert-4-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
ratio = 0.5
seed = 44
include_layers = ["intermediate", "output"]
exclude_layers = [
    "attention",
]

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-10-15 17:16:29


In [5]:
config = Config(name, device)
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
result_list = []

for concern in range(config.num_labels):
    config.init_seed()
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    head_importance_prunning(module, config, all_samples, ratio)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.3507




Precision: 0.6565, Recall: 0.5715, F1-Score: 0.5893




              precision    recall  f1-score   support

           0     0.5063    0.5338    0.5197      2992
           1     0.6787    0.3840    0.4905      2992
           2     0.6794    0.5840    0.6281      3012
           3     0.2677    0.7141    0.3894      2998
           4     0.8137    0.6374    0.7148      2973
           5     0.8883    0.6798    0.7702      3054
           6     0.6721    0.3973    0.4994      3003
           7     0.5499    0.6730    0.6053      3012
           8     0.6916    0.5744    0.6276      2982
           9     0.8170    0.5376    0.6485      2982

    accuracy                         0.5718     30000
   macro avg     0.6565    0.5715    0.5893     30000
weighted avg     0.6567    0.5718    0.5896     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9876461791550796




CCA coefficients mean non-concern: 0.988605286592177




Linear CKA concern: 0.9569220653782988




Linear CKA non-concern: 0.9565279160828727




Kernel CKA concern: 0.8883825900031095




Kernel CKA non-concern: 0.9005037788573363




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.698537588119507




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.3517




Precision: 0.6561, Recall: 0.5705, F1-Score: 0.5883




              precision    recall  f1-score   support

           0     0.5094    0.5334    0.5211      2992
           1     0.6779    0.3750    0.4829      2992
           2     0.6769    0.5843    0.6272      3012
           3     0.2660    0.7161    0.3879      2998
           4     0.8118    0.6398    0.7156      2973
           5     0.8864    0.6798    0.7695      3054
           6     0.6706    0.3959    0.4979      3003
           7     0.5514    0.6697    0.6048      3012
           8     0.6942    0.5724    0.6275      2982
           9     0.8168    0.5382    0.6489      2982

    accuracy                         0.5707     30000
   macro avg     0.6561    0.5705    0.5883     30000
weighted avg     0.6563    0.5707    0.5885     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9900592202953847




CCA coefficients mean non-concern: 0.9886373542725475




Linear CKA concern: 0.9515608224206803




Linear CKA non-concern: 0.9613694934599215




Kernel CKA concern: 0.8888549056403889




Kernel CKA non-concern: 0.9089970132062115




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.7028415203094482




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.3523




Precision: 0.6562, Recall: 0.5708, F1-Score: 0.5885




              precision    recall  f1-score   support

           0     0.5059    0.5307    0.5180      2992
           1     0.6830    0.3767    0.4856      2992
           2     0.6771    0.5840    0.6271      3012
           3     0.2670    0.7161    0.3890      2998
           4     0.8097    0.6384    0.7139      2973
           5     0.8907    0.6778    0.7698      3054
           6     0.6706    0.3953    0.4974      3003
           7     0.5525    0.6707    0.6059      3012
           8     0.6882    0.5788    0.6288      2982
           9     0.8168    0.5396    0.6498      2982

    accuracy                         0.5710     30000
   macro avg     0.6562    0.5708    0.5885     30000
weighted avg     0.6563    0.5710    0.5888     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9882439317251952




CCA coefficients mean non-concern: 0.9895121415822298




Linear CKA concern: 0.9610967732857303




Linear CKA non-concern: 0.9557435857228915




Kernel CKA concern: 0.9033341315204416




Kernel CKA non-concern: 0.8974413578873343




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.704465389251709




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.3515




Precision: 0.6560, Recall: 0.5710, F1-Score: 0.5886




              precision    recall  f1-score   support

           0     0.5070    0.5321    0.5192      2992
           1     0.6808    0.3763    0.4847      2992
           2     0.6798    0.5837    0.6281      3012
           3     0.2675    0.7148    0.3894      2998
           4     0.8096    0.6408    0.7154      2973
           5     0.8893    0.6788    0.7699      3054
           6     0.6708    0.3963    0.4982      3003
           7     0.5474    0.6733    0.6038      3012
           8     0.6918    0.5744    0.6277      2982
           9     0.8163    0.5396    0.6497      2982

    accuracy                         0.5712     30000
   macro avg     0.6560    0.5710    0.5886     30000
weighted avg     0.6562    0.5712    0.5888     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9897672984715888




CCA coefficients mean non-concern: 0.9884413902645156




Linear CKA concern: 0.9577905179053818




Linear CKA non-concern: 0.9539231797666361




Kernel CKA concern: 0.8995788400176618




Kernel CKA non-concern: 0.8974804177880514




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.6999197006225586




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.3484




Precision: 0.6562, Recall: 0.5731, F1-Score: 0.5904




              precision    recall  f1-score   support

           0     0.5016    0.5404    0.5203      2992
           1     0.6821    0.3830    0.4906      2992
           2     0.6772    0.5850    0.6277      3012
           3     0.2702    0.7135    0.3919      2998
           4     0.8076    0.6468    0.7183      2973
           5     0.8881    0.6834    0.7724      3054
           6     0.6731    0.3969    0.4994      3003
           7     0.5544    0.6700    0.6067      3012
           8     0.6934    0.5718    0.6267      2982
           9     0.8148    0.5399    0.6495      2982

    accuracy                         0.5733     30000
   macro avg     0.6562    0.5731    0.5904     30000
weighted avg     0.6564    0.5733    0.5906     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9780254653457862




CCA coefficients mean non-concern: 0.9888670764351943




Linear CKA concern: 0.9237862531265385




Linear CKA non-concern: 0.9574953988452528




Kernel CKA concern: 0.8484464896059045




Kernel CKA non-concern: 0.9011055265252765




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.689180850982666




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.3515




Precision: 0.6569, Recall: 0.5719, F1-Score: 0.5895




              precision    recall  f1-score   support

           0     0.5049    0.5378    0.5208      2992
           1     0.6779    0.3847    0.4908      2992
           2     0.6767    0.5880    0.6292      3012
           3     0.2692    0.7151    0.3911      2998
           4     0.8134    0.6394    0.7160      2973
           5     0.8907    0.6807    0.7717      3054
           6     0.6740    0.3959    0.4988      3003
           7     0.5477    0.6743    0.6045      3012
           8     0.6924    0.5714    0.6261      2982
           9     0.8217    0.5315    0.6455      2982

    accuracy                         0.5721     30000
   macro avg     0.6569    0.5719    0.5895     30000
weighted avg     0.6571    0.5721    0.5897     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9775624214460155




CCA coefficients mean non-concern: 0.9886889361319545




Linear CKA concern: 0.8650065670318329




Linear CKA non-concern: 0.9613939001584538




Kernel CKA concern: 0.7750027232262217




Kernel CKA non-concern: 0.9058454249082181




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.7006912231445312




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.3502




Precision: 0.6562, Recall: 0.5721, F1-Score: 0.5898




              precision    recall  f1-score   support

           0     0.5052    0.5351    0.5197      2992
           1     0.6814    0.3824    0.4898      2992
           2     0.6806    0.5843    0.6288      3012
           3     0.2683    0.7125    0.3898      2998
           4     0.8102    0.6391    0.7146      2973
           5     0.8891    0.6798    0.7705      3054
           6     0.6706    0.3986    0.5000      3003
           7     0.5513    0.6703    0.6050      3012
           8     0.6889    0.5785    0.6289      2982
           9     0.8166    0.5406    0.6505      2982

    accuracy                         0.5723     30000
   macro avg     0.6562    0.5721    0.5898     30000
weighted avg     0.6564    0.5723    0.5900     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9897496872380304




CCA coefficients mean non-concern: 0.9882351594847932




Linear CKA concern: 0.9594896166365167




Linear CKA non-concern: 0.957386264639991




Kernel CKA concern: 0.8949089707972767




Kernel CKA non-concern: 0.9023218118251851




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.6968348026275635




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.3500




Precision: 0.6562, Recall: 0.5727, F1-Score: 0.5901




              precision    recall  f1-score   support

           0     0.5049    0.5354    0.5197      2992
           1     0.6750    0.3867    0.4917      2992
           2     0.6786    0.5860    0.6289      3012
           3     0.2699    0.7108    0.3913      2998
           4     0.8129    0.6414    0.7171      2973
           5     0.8896    0.6804    0.7711      3054
           6     0.6708    0.3983    0.4998      3003
           7     0.5469    0.6760    0.6046      3012
           8     0.6912    0.5771    0.6290      2982
           9     0.8221    0.5345    0.6478      2982

    accuracy                         0.5729     30000
   macro avg     0.6562    0.5727    0.5901     30000
weighted avg     0.6564    0.5729    0.5903     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9878728833394128




CCA coefficients mean non-concern: 0.988661535823688




Linear CKA concern: 0.9516994576768782




Linear CKA non-concern: 0.9531601982476552




Kernel CKA concern: 0.8806766259961886




Kernel CKA non-concern: 0.8949415698582099




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.694831132888794




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.3507




Precision: 0.6567, Recall: 0.5726, F1-Score: 0.5902




              precision    recall  f1-score   support

           0     0.5103    0.5297    0.5198      2992
           1     0.6847    0.3803    0.4890      2992
           2     0.6768    0.5847    0.6274      3012
           3     0.2677    0.7141    0.3894      2998
           4     0.8090    0.6455    0.7181      2973
           5     0.8886    0.6814    0.7713      3054
           6     0.6697    0.3969    0.4984      3003
           7     0.5538    0.6716    0.6071      3012
           8     0.6890    0.5795    0.6295      2982
           9     0.8175    0.5423    0.6520      2982

    accuracy                         0.5728     30000
   macro avg     0.6567    0.5726    0.5902     30000
weighted avg     0.6569    0.5728    0.5904     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9863280662538073




CCA coefficients mean non-concern: 0.9884943355724414




Linear CKA concern: 0.9582657512021315




Linear CKA non-concern: 0.9548443894226948




Kernel CKA concern: 0.8892132789439817




Kernel CKA non-concern: 0.8992485468376961




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.6995773315429688




Total heads to prune: 8




{(1, 2), (0, 3), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.3503




Precision: 0.6564, Recall: 0.5718, F1-Score: 0.5895




              precision    recall  f1-score   support

           0     0.5065    0.5314    0.5187      2992
           1     0.6798    0.3854    0.4919      2992
           2     0.6786    0.5853    0.6285      3012
           3     0.2682    0.7158    0.3902      2998
           4     0.8116    0.6374    0.7140      2973
           5     0.8890    0.6791    0.7700      3054
           6     0.6718    0.3959    0.4982      3003
           7     0.5521    0.6716    0.6061      3012
           8     0.6887    0.5771    0.6280      2982
           9     0.8173    0.5386    0.6493      2982

    accuracy                         0.5720     30000
   macro avg     0.6564    0.5718    0.5895     30000
weighted avg     0.6566    0.5720    0.5897     30000





0.49084092158498716




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9895343377639298




CCA coefficients mean non-concern: 0.9887587534147528




Linear CKA concern: 0.9321660130721936




Linear CKA non-concern: 0.9553131710684395




Kernel CKA concern: 0.8584580898756872




Kernel CKA non-concern: 0.8996148301277036




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.6965231895446777




In [9]:
df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
csv_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
new_df.to_csv(f"results/{csv_name}.csv", index=False)