In [1]:
import os
import sys

# Extend the import search path so the project's `src` package (imported in
# the next cell) can be found from this notebook's directory.
# NOTE(review): presumably "../../../../" is the repository root — confirm.
sys.path.append("../../../../")
# Set before any tokenizer is loaded.
# NOTE(review): presumably silences the tokenizers parallelism warning when
# DataLoader workers are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import prune_concern_identification
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration: the values a reader is expected to tune.
name = "bert-4-128-yahoo"  # model identifier handed to Config
# Prefer the first CUDA GPU, but fall back to CPU so the notebook can still
# run on a machine without CUDA instead of targeting a device that is absent.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced by any cell visible in this notebook
batch_size = 16  # passed to load_data
num_workers = 4  # passed to load_data
num_samples = 16  # passed to SamplingDataset and get_similarity
ratio = 0.4  # used for head pruning and as sparsity_ratio
# NOTE(review): `seed` is not passed anywhere visible; config.init_seed() is
# called without arguments — confirm this value is actually the one in effect.
seed = 44
# Layer-name filters for prune_concern_identification.
include_layers = ["intermediate", "output"]
exclude_layers = [
    "attention",
]

In [4]:
# Capture the wall-clock start of the run and echo it into the cell output.
script_start_time = datetime.now()
started_at = f"{script_start_time:%Y-%m-%d %H:%M:%S}"
print("Script started at: " + started_at)

Script started at: 2024-10-15 17:05:14


In [5]:
# Build the run configuration from the model name and target device.
config = Config(name, device)
# Number of labels read from the config dictionary.
# NOTE(review): the pruning loop further down reads config.num_labels rather
# than this variable, so `num_labels` is unused in the visible cells.
num_labels = config.config["num_labels"]
# Load the model described by the configuration; this is the unpruned
# reference that later cells deep-copy before pruning.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Load the three dataloaders (train / validation / test) for the configured
# dataset, using the batch size and worker count from the configuration cell.
# NOTE(review): do_cache=True presumably enables the on-disk cache reported
# in the cell output — confirm against load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Optional baseline: uncomment the two lines below to evaluate the unpruned model.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
# Per-concern pruning experiment: for every class label ("concern"), prune a
# fresh copy of the model, evaluate it, and compare it with the original.
result_list = []  # one evaluate_model() result per concern, in label order


def _sample(concern_id, positive):
    """Build a SamplingDataset over the training data with this cell's fixed settings."""
    return SamplingDataset(
        train_dataloader,
        config,
        concern_id,
        num_samples,
        positive,
        4,  # NOTE(review): meaning of this positional argument is not visible here — confirm
        resample=False,
    )


for concern in range(config.num_labels):
    # Re-seed at the start of every iteration.
    config.init_seed()

    # Samples for the current concern (flag True) and its complement (flag False).
    positive_samples = _sample(concern, True)
    negative_samples = _sample(concern, False)
    # NOTE(review): 200 is outside the label range iterated above; combined
    # with the False flag it presumably selects samples from every class —
    # confirm against SamplingDataset.
    all_samples = _sample(200, False)

    # Prune a deep copy so `model` stays untouched as the reference.
    module = copy.deepcopy(model)

    head_importance_prunning(module, config, all_samples, ratio)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        # Use the layer lists from the configuration cell so that editing
        # them there actually changes what is pruned here.
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # NOTE(review): spelling kept byte-for-byte; it may be the exact key
        # prune_concern_identification expects — confirm before "fixing" it.
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    # NOTE(review): the original model's perplexity does not depend on
    # `concern` (the recorded output is identical every iteration); it could
    # be computed once outside the loop if get_perplexity has no side effects.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.2749




Precision: 0.6509, Recall: 0.5967, F1-Score: 0.6072




              precision    recall  f1-score   support

           0     0.5766    0.4589    0.5111      2992
           1     0.6511    0.4559    0.5363      2992
           2     0.7031    0.5810    0.6362      3012
           3     0.2999    0.6771    0.4157      2998
           4     0.8158    0.6404    0.7175      2973
           5     0.8387    0.7610    0.7979      3054
           6     0.6920    0.3936    0.5018      3003
           7     0.5526    0.6823    0.6106      3012
           8     0.6339    0.6626    0.6480      2982
           9     0.7451    0.6539    0.6966      2982

    accuracy                         0.5969     30000
   macro avg     0.6509    0.5967    0.6072     30000
weighted avg     0.6511    0.5969    0.6074     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9955892146612206




CCA coefficients mean non-concern: 0.9951727464202176




Linear CKA concern: 0.9878265049897255




Linear CKA non-concern: 0.9895917616292194




Kernel CKA concern: 0.9649437153868801




Kernel CKA non-concern: 0.9720742668207027




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4037067890167236




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.2735




Precision: 0.6511, Recall: 0.5967, F1-Score: 0.6073




              precision    recall  f1-score   support

           0     0.5774    0.4589    0.5114      2992
           1     0.6509    0.4542    0.5350      2992
           2     0.7023    0.5820    0.6365      3012
           3     0.2992    0.6775    0.4150      2998
           4     0.8158    0.6421    0.7186      2973
           5     0.8395    0.7620    0.7988      3054
           6     0.6910    0.3933    0.5013      3003
           7     0.5538    0.6806    0.6107      3012
           8     0.6359    0.6623    0.6488      2982
           9     0.7450    0.6546    0.6969      2982

    accuracy                         0.5970     30000
   macro avg     0.6511    0.5967    0.6073     30000
weighted avg     0.6512    0.5970    0.6075     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9962927573480796




CCA coefficients mean non-concern: 0.9951695101809369




Linear CKA concern: 0.9896284615784078




Linear CKA non-concern: 0.9900425866522251




Kernel CKA concern: 0.9747486833104223




Kernel CKA non-concern: 0.9714548243119905




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4000022411346436




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.2744




Precision: 0.6511, Recall: 0.5964, F1-Score: 0.6071




              precision    recall  f1-score   support

           0     0.5768    0.4569    0.5099      2992
           1     0.6502    0.4542    0.5348      2992
           2     0.7032    0.5813    0.6365      3012
           3     0.2987    0.6781    0.4147      2998
           4     0.8155    0.6408    0.7176      2973
           5     0.8405    0.7610    0.7988      3054
           6     0.6918    0.3946    0.5025      3003
           7     0.5534    0.6799    0.6102      3012
           8     0.6357    0.6630    0.6490      2982
           9     0.7452    0.6543    0.6968      2982

    accuracy                         0.5967     30000
   macro avg     0.6511    0.5964    0.6071     30000
weighted avg     0.6513    0.5967    0.6073     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9953093580605726




CCA coefficients mean non-concern: 0.9955805998055423




Linear CKA concern: 0.9878518196699968




Linear CKA non-concern: 0.989049199513044




Kernel CKA concern: 0.9654798403489776




Kernel CKA non-concern: 0.9698737714174203




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4028820991516113




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.2741




Precision: 0.6512, Recall: 0.5965, F1-Score: 0.6072




              precision    recall  f1-score   support

           0     0.5776    0.4589    0.5115      2992
           1     0.6530    0.4529    0.5348      2992
           2     0.7045    0.5810    0.6368      3012
           3     0.2989    0.6781    0.4149      2998
           4     0.8155    0.6438    0.7195      2973
           5     0.8404    0.7603    0.7983      3054
           6     0.6898    0.3946    0.5020      3003
           7     0.5521    0.6803    0.6095      3012
           8     0.6342    0.6610    0.6473      2982
           9     0.7458    0.6543    0.6970      2982

    accuracy                         0.5968     30000
   macro avg     0.6512    0.5965    0.6072     30000
weighted avg     0.6514    0.5968    0.6074     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9960750816849766




CCA coefficients mean non-concern: 0.9952310384945579




Linear CKA concern: 0.9867788297878546




Linear CKA non-concern: 0.9892967463248217




Kernel CKA concern: 0.966327459299585




Kernel CKA non-concern: 0.9718646501556252




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4012269973754883




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.2736




Precision: 0.6515, Recall: 0.5969, F1-Score: 0.6075




              precision    recall  f1-score   support

           0     0.5754    0.4589    0.5106      2992
           1     0.6533    0.4535    0.5354      2992
           2     0.7030    0.5823    0.6370      3012
           3     0.2991    0.6785    0.4152      2998
           4     0.8156    0.6458    0.7209      2973
           5     0.8400    0.7600    0.7980      3054
           6     0.6933    0.3936    0.5021      3003
           7     0.5553    0.6790    0.6109      3012
           8     0.6340    0.6633    0.6483      2982
           9     0.7455    0.6543    0.6969      2982

    accuracy                         0.5972     30000
   macro avg     0.6515    0.5969    0.6075     30000
weighted avg     0.6516    0.5972    0.6078     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9886157645134144




CCA coefficients mean non-concern: 0.9959123953762857




Linear CKA concern: 0.9597914583960332




Linear CKA non-concern: 0.9889347244318444




Kernel CKA concern: 0.9192205668329392




Kernel CKA non-concern: 0.971292963919534




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.398770332336426




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.2728




Precision: 0.6504, Recall: 0.5972, F1-Score: 0.6076




              precision    recall  f1-score   support

           0     0.5730    0.4632    0.5123      2992
           1     0.6482    0.4582    0.5369      2992
           2     0.7006    0.5850    0.6376      3012
           3     0.3009    0.6744    0.4161      2998
           4     0.8159    0.6394    0.7170      2973
           5     0.8404    0.7606    0.7986      3054
           6     0.6900    0.3943    0.5018      3003
           7     0.5553    0.6806    0.6116      3012
           8     0.6337    0.6630    0.6480      2982
           9     0.7463    0.6529    0.6965      2982

    accuracy                         0.5974     30000
   macro avg     0.6504    0.5972    0.6076     30000
weighted avg     0.6506    0.5974    0.6079     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9921976771558572




CCA coefficients mean non-concern: 0.9953061466137832




Linear CKA concern: 0.9780016513735531




Linear CKA non-concern: 0.9909063919175897




Kernel CKA concern: 0.9631693471555204




Kernel CKA non-concern: 0.9731251339592




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.397416591644287




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.2737




Precision: 0.6512, Recall: 0.5967, F1-Score: 0.6073




              precision    recall  f1-score   support

           0     0.5772    0.4576    0.5104      2992
           1     0.6507    0.4545    0.5352      2992
           2     0.7035    0.5813    0.6366      3012
           3     0.2992    0.6788    0.4153      2998
           4     0.8148    0.6438    0.7193      2973
           5     0.8404    0.7606    0.7986      3054
           6     0.6920    0.3943    0.5023      3003
           7     0.5538    0.6799    0.6104      3012
           8     0.6348    0.6616    0.6479      2982
           9     0.7452    0.6543    0.6968      2982

    accuracy                         0.5969     30000
   macro avg     0.6512    0.5967    0.6073     30000
weighted avg     0.6513    0.5969    0.6075     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9961440635456665




CCA coefficients mean non-concern: 0.9951032988987183




Linear CKA concern: 0.9869439239083544




Linear CKA non-concern: 0.9901455324032693




Kernel CKA concern: 0.9610711186509054




Kernel CKA non-concern: 0.9728018157505592




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4005250930786133




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.2734




Precision: 0.6512, Recall: 0.5968, F1-Score: 0.6075




              precision    recall  f1-score   support

           0     0.5765    0.4596    0.5114      2992
           1     0.6500    0.4576    0.5371      2992
           2     0.7045    0.5817    0.6372      3012
           3     0.2990    0.6775    0.4149      2998
           4     0.8164    0.6414    0.7184      2973
           5     0.8397    0.7613    0.7986      3054
           6     0.6895    0.3949    0.5022      3003
           7     0.5560    0.6809    0.6121      3012
           8     0.6343    0.6603    0.6471      2982
           9     0.7458    0.6533    0.6965      2982

    accuracy                         0.5971     30000
   macro avg     0.6512    0.5968    0.6075     30000
weighted avg     0.6513    0.5971    0.6078     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9955815567422147




CCA coefficients mean non-concern: 0.9951531298889581




Linear CKA concern: 0.9868845730122829




Linear CKA non-concern: 0.9893619703418218




Kernel CKA concern: 0.9696850912465571




Kernel CKA non-concern: 0.9715488681264534




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.3992464542388916




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.2745




Precision: 0.6512, Recall: 0.5962, F1-Score: 0.6070




              precision    recall  f1-score   support

           0     0.5777    0.4559    0.5096      2992
           1     0.6520    0.4522    0.5340      2992
           2     0.7060    0.5803    0.6370      3012
           3     0.2981    0.6788    0.4142      2998
           4     0.8124    0.6451    0.7192      2973
           5     0.8419    0.7603    0.7990      3054
           6     0.6911    0.3949    0.5026      3003
           7     0.5532    0.6790    0.6096      3012
           8     0.6346    0.6616    0.6478      2982
           9     0.7447    0.6543    0.6965      2982

    accuracy                         0.5965     30000
   macro avg     0.6512    0.5962    0.6070     30000
weighted avg     0.6513    0.5965    0.6072     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9951926469089483




CCA coefficients mean non-concern: 0.995094182782132




Linear CKA concern: 0.9874070725086674




Linear CKA non-concern: 0.989216678142034




Kernel CKA concern: 0.9678947818334719




Kernel CKA non-concern: 0.9718283659851983




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4033639430999756




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.2750




Precision: 0.6510, Recall: 0.5965, F1-Score: 0.6071




              precision    recall  f1-score   support

           0     0.5754    0.4589    0.5106      2992
           1     0.6496    0.4572    0.5367      2992
           2     0.7033    0.5817    0.6367      3012
           3     0.3000    0.6781    0.4160      2998
           4     0.8176    0.6377    0.7166      2973
           5     0.8407    0.7603    0.7985      3054
           6     0.6914    0.3939    0.5019      3003
           7     0.5517    0.6823    0.6101      3012
           8     0.6341    0.6620    0.6477      2982
           9     0.7461    0.6533    0.6966      2982

    accuracy                         0.5968     30000
   macro avg     0.6510    0.5965    0.6071     30000
weighted avg     0.6512    0.5968    0.6074     30000





0.3623786491389163




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.995747252704555




CCA coefficients mean non-concern: 0.9950949207189627




Linear CKA concern: 0.9820191279209517




Linear CKA non-concern: 0.9893338801277923




Kernel CKA concern: 0.9603015531213526




Kernel CKA non-concern: 0.9716998234024933




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4048655033111572




In [9]:
# Convert each per-concern evaluation result with report_to_df, combine the
# pieces with append_nth_row, and write the result to a timestamped CSV.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# to_csv does not create missing directories, so make sure the output folder
# exists; otherwise the whole run's results are lost at the very last step.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)