In [1]:
import os
import sys

sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import prune_concern_identification
from src.utils.helper import report_to_df, append_nth_row

In [3]:
name = "bert-6-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
ratio = 0.5
seed = 44
include_layers = ["intermediate", "output"]
exclude_layers = [
    "attention",
]

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-10-15 18:02:31


In [5]:
config = Config(name, device)
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
result_list = []

for concern in range(config.num_labels):
    config.init_seed()
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    head_importance_prunning(module, config, all_samples, ratio)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.2603




Precision: 0.6402, Recall: 0.6055, F1-Score: 0.6112




              precision    recall  f1-score   support

           0     0.4479    0.5478    0.4929      2992
           1     0.7326    0.4652    0.5691      2992
           2     0.6623    0.6484    0.6553      3012
           3     0.3458    0.6117    0.4419      2998
           4     0.7909    0.6858    0.7346      2973
           5     0.8046    0.7793    0.7917      3054
           6     0.6832    0.3863    0.4935      3003
           7     0.6202    0.6278    0.6240      3012
           8     0.6446    0.6278    0.6361      2982
           9     0.6702    0.6747    0.6725      2982

    accuracy                         0.6057     30000
   macro avg     0.6402    0.6055    0.6112     30000
weighted avg     0.6404    0.6057    0.6114     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9883729990557054




CCA coefficients mean non-concern: 0.9885008058166549




Linear CKA concern: 0.9543155949404444




Linear CKA non-concern: 0.9544365007379255




Kernel CKA concern: 0.917219276319116




Kernel CKA non-concern: 0.9012599396787875




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3325581550598145




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.2598




Precision: 0.6400, Recall: 0.6053, F1-Score: 0.6107




              precision    recall  f1-score   support

           0     0.4548    0.5455    0.4960      2992
           1     0.7334    0.4616    0.5666      2992
           2     0.6583    0.6517    0.6550      3012
           3     0.3442    0.6137    0.4411      2998
           4     0.7905    0.6879    0.7356      2973
           5     0.7990    0.7836    0.7912      3054
           6     0.6865    0.3849    0.4933      3003
           7     0.6203    0.6255    0.6229      3012
           8     0.6444    0.6234    0.6337      2982
           9     0.6682    0.6754    0.6718      2982

    accuracy                         0.6056     30000
   macro avg     0.6400    0.6053    0.6107     30000
weighted avg     0.6401    0.6056    0.6109     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9928303938011273




CCA coefficients mean non-concern: 0.9893632842365275




Linear CKA concern: 0.9563803436330668




Linear CKA non-concern: 0.955269535027096




Kernel CKA concern: 0.8919112739364576




Kernel CKA non-concern: 0.9025331098613393




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3322556018829346




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.2606




Precision: 0.6404, Recall: 0.6058, F1-Score: 0.6112




              precision    recall  f1-score   support

           0     0.4519    0.5451    0.4942      2992
           1     0.7343    0.4656    0.5699      2992
           2     0.6598    0.6504    0.6551      3012
           3     0.3458    0.6121    0.4419      2998
           4     0.7932    0.6862    0.7358      2973
           5     0.8006    0.7809    0.7907      3054
           6     0.6898    0.3843    0.4936      3003
           7     0.6206    0.6268    0.6237      3012
           8     0.6447    0.6261    0.6353      2982
           9     0.6636    0.6801    0.6717      2982

    accuracy                         0.6060     30000
   macro avg     0.6404    0.6058    0.6112     30000
weighted avg     0.6406    0.6060    0.6114     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9884341011825125




CCA coefficients mean non-concern: 0.9890849556151184




Linear CKA concern: 0.9537614879364886




Linear CKA non-concern: 0.953252057695967




Kernel CKA concern: 0.9206208750316777




Kernel CKA non-concern: 0.9000537795585225




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3338773250579834




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.2601




Precision: 0.6403, Recall: 0.6057, F1-Score: 0.6111




              precision    recall  f1-score   support

           0     0.4509    0.5458    0.4938      2992
           1     0.7335    0.4636    0.5681      2992
           2     0.6608    0.6494    0.6551      3012
           3     0.3459    0.6134    0.4424      2998
           4     0.7916    0.6872    0.7357      2973
           5     0.8006    0.7809    0.7907      3054
           6     0.6868    0.3856    0.4939      3003
           7     0.6187    0.6298    0.6242      3012
           8     0.6467    0.6247    0.6355      2982
           9     0.6680    0.6761    0.6720      2982

    accuracy                         0.6059     30000
   macro avg     0.6403    0.6057    0.6111     30000
weighted avg     0.6405    0.6059    0.6114     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9905624609553497




CCA coefficients mean non-concern: 0.9899902959644548




Linear CKA concern: 0.9542026276343404




Linear CKA non-concern: 0.9512290498863049




Kernel CKA concern: 0.901464564053652




Kernel CKA non-concern: 0.9008335723191289




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3337700366973877




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.2592




Precision: 0.6402, Recall: 0.6065, F1-Score: 0.6117




              precision    recall  f1-score   support

           0     0.4496    0.5471    0.4936      2992
           1     0.7356    0.4686    0.5725      2992
           2     0.6668    0.6491    0.6578      3012
           3     0.3476    0.6061    0.4418      2998
           4     0.7885    0.6895    0.7357      2973
           5     0.7996    0.7839    0.7917      3054
           6     0.6881    0.3843    0.4932      3003
           7     0.6211    0.6265    0.6238      3012
           8     0.6440    0.6298    0.6368      2982
           9     0.6611    0.6804    0.6706      2982

    accuracy                         0.6068     30000
   macro avg     0.6402    0.6065    0.6117     30000
weighted avg     0.6404    0.6068    0.6120     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9895006831277655




CCA coefficients mean non-concern: 0.9896598005336579




Linear CKA concern: 0.9415632094248048




Linear CKA non-concern: 0.9540266295775218




Kernel CKA concern: 0.9144769943812165




Kernel CKA non-concern: 0.8984252933543665




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3293867111206055




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.2616




Precision: 0.6412, Recall: 0.6052, F1-Score: 0.6110




              precision    recall  f1-score   support

           0     0.4479    0.5491    0.4934      2992
           1     0.7350    0.4626    0.5678      2992
           2     0.6656    0.6464    0.6559      3012
           3     0.3443    0.6161    0.4418      2998
           4     0.7905    0.6842    0.7335      2973
           5     0.8053    0.7790    0.7919      3054
           6     0.6874    0.3836    0.4924      3003
           7     0.6191    0.6282    0.6236      3012
           8     0.6450    0.6281    0.6364      2982
           9     0.6713    0.6747    0.6730      2982

    accuracy                         0.6054     30000
   macro avg     0.6412    0.6052    0.6110     30000
weighted avg     0.6413    0.6054    0.6112     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9910488973812739




CCA coefficients mean non-concern: 0.9906293268357687




Linear CKA concern: 0.9492794647891969




Linear CKA non-concern: 0.9562728940866897




Kernel CKA concern: 0.9355102324023601




Kernel CKA non-concern: 0.9000598438047628




original model's perplexity




3.187649726867676




pruned model's perplexity




3.33941650390625




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.2607




Precision: 0.6403, Recall: 0.6049, F1-Score: 0.6106




              precision    recall  f1-score   support

           0     0.4489    0.5461    0.4928      2992
           1     0.7343    0.4599    0.5656      2992
           2     0.6643    0.6458    0.6549      3012
           3     0.3440    0.6141    0.4410      2998
           4     0.7911    0.6855    0.7345      2973
           5     0.8051    0.7803    0.7925      3054
           6     0.6829    0.3866    0.4937      3003
           7     0.6179    0.6288    0.6233      3012
           8     0.6440    0.6268    0.6353      2982
           9     0.6701    0.6751    0.6726      2982

    accuracy                         0.6051     30000
   macro avg     0.6403    0.6049    0.6106     30000
weighted avg     0.6405    0.6051    0.6108     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9906260841188758




CCA coefficients mean non-concern: 0.9880637899401608




Linear CKA concern: 0.9582152400558451




Linear CKA non-concern: 0.9537012805883647




Kernel CKA concern: 0.9081622586156763




Kernel CKA non-concern: 0.9047977741066352




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3350608348846436




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.2600




Precision: 0.6405, Recall: 0.6060, F1-Score: 0.6116




              precision    recall  f1-score   support

           0     0.4502    0.5475    0.4941      2992
           1     0.7328    0.4656    0.5694      2992
           2     0.6641    0.6484    0.6561      3012
           3     0.3472    0.6134    0.4434      2998
           4     0.7898    0.6862    0.7343      2973
           5     0.8062    0.7780    0.7919      3054
           6     0.6834    0.3859    0.4933      3003
           7     0.6157    0.6325    0.6240      3012
           8     0.6463    0.6274    0.6367      2982
           9     0.6695    0.6754    0.6725      2982

    accuracy                         0.6063     30000
   macro avg     0.6405    0.6060    0.6116     30000
weighted avg     0.6407    0.6063    0.6118     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9902064166339766




CCA coefficients mean non-concern: 0.9901667516086268




Linear CKA concern: 0.951204230492135




Linear CKA non-concern: 0.9502263656374578




Kernel CKA concern: 0.9217078060160352




Kernel CKA non-concern: 0.8995430844249939




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3334856033325195




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.2607




Precision: 0.6398, Recall: 0.6054, F1-Score: 0.6110




              precision    recall  f1-score   support

           0     0.4486    0.5468    0.4928      2992
           1     0.7315    0.4689    0.5715      2992
           2     0.6627    0.6471    0.6548      3012
           3     0.3468    0.6114    0.4426      2998
           4     0.7916    0.6848    0.7344      2973
           5     0.8038    0.7793    0.7914      3054
           6     0.6858    0.3853    0.4934      3003
           7     0.6215    0.6265    0.6240      3012
           8     0.6427    0.6244    0.6334      2982
           9     0.6633    0.6797    0.6714      2982

    accuracy                         0.6057     30000
   macro avg     0.6398    0.6054    0.6110     30000
weighted avg     0.6400    0.6057    0.6112     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9920661434285547




CCA coefficients mean non-concern: 0.9887094194314243




Linear CKA concern: 0.9569131124436331




Linear CKA non-concern: 0.9501930256809382




Kernel CKA concern: 0.8775974293823315




Kernel CKA non-concern: 0.8995603133159218




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3353195190429688




Total heads to prune: 4




{(1, 1), (4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.2613




Precision: 0.6399, Recall: 0.6048, F1-Score: 0.6106




              precision    recall  f1-score   support

           0     0.4465    0.5468    0.4916      2992
           1     0.7321    0.4659    0.5694      2992
           2     0.6643    0.6451    0.6545      3012
           3     0.3459    0.6144    0.4426      2998
           4     0.7941    0.6798    0.7325      2973
           5     0.8029    0.7803    0.7914      3054
           6     0.6803    0.3869    0.4933      3003
           7     0.6188    0.6278    0.6233      3012
           8     0.6444    0.6278    0.6360      2982
           9     0.6693    0.6734    0.6713      2982

    accuracy                         0.6051     30000
   macro avg     0.6399    0.6048    0.6106     30000
weighted avg     0.6401    0.6051    0.6108     30000





0.43532426297076793




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9915706558756495




CCA coefficients mean non-concern: 0.9886888395094732




Linear CKA concern: 0.9299403132716778




Linear CKA non-concern: 0.9511954959783085




Kernel CKA concern: 0.8640562099276899




Kernel CKA non-concern: 0.9040239956446265




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3379263877868652




In [9]:
df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
csv_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
new_df.to_csv(f"results/{csv_name}.csv", index=False)