In [1]:
import os
import sys

# Extend the import search path with the directory four levels above the working
# directory -- presumably the repository root, so that the `src.*` imports in the
# next cell resolve (NOTE(review): depends on the notebook's working directory).
project_root = "../../../../"
sys.path.append(project_root)

# Set TOKENIZERS_PARALLELISM=false for this process before any tokenizer is
# created (assumed to silence the Hugging Face tokenizers fork warning when
# DataLoader workers are used -- TODO confirm).
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import prune_concern_identification
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# --- Experiment configuration -------------------------------------------------
name = "bert-6-128-yahoo"  # identifier handed to Config(name, device)
device = torch.device("cuda", 0)  # first CUDA device, i.e. "cuda:0"
checkpoint = None  # not referenced by the cells shown in this notebook
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # forwarded to SamplingDataset and get_similarity
ratio = 0.6  # used for head pruning and as sparsity_ratio in concern pruning
seed = 44  # not referenced directly by the cells shown in this notebook
# Layer-name filters for prune_concern_identification; the pruning loop passes
# these same values.
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Record when the run began and echo it in a human-readable form.
script_start_time = datetime.now()
started_at_text = script_start_time.strftime("%Y-%m-%d %H:%M:%S")
print(f"Script started at: {started_at_text}")

Script started at: 2024-10-15 18:14:04


In [5]:
# Build the experiment configuration for `name` on `device`.
config = Config(name, device)
# Label count read from the underlying config mapping (printed as 10 in the
# recorded output of this cell).
num_labels = config.config["num_labels"]
# Load the model described by `config`. Statement order is kept as-is because
# load_model is a project helper whose effects on `config` are not visible here.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
# Fetch the three data loaders for the configured dataset. With do_cache=True the
# recorded output shows the train/valid/test splits being read from cached .pkl
# files.
dataloaders = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)
train_dataloader, valid_dataloader, test_dataloader = dataloaders

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
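# Optional baseline (disabled): uncomment the two lines below to evaluate the
# unpruned model on the test set for comparison with the per-concern results.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)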

In [8]:
# Per-concern pruning experiment.
# For every class index ("concern"), a fresh copy of the loaded model is pruned in
# two stages (attention-head pruning, then concern-identification pruning of the
# selected layers), evaluated on the test set, and compared with the original
# model via sparsity, representation-similarity and perplexity reports.
result_list = []  # one evaluate_model result per concern; consumed by the export cell


def _sample(target, positive):
    """Build a SamplingDataset over the training loader for class id `target`.

    `positive` is forwarded as the fifth positional argument (True for the
    concern's own samples, False otherwise). The literal 4 and `resample=False`
    are passed through unchanged.
    NOTE(review): the 4 looks like a sampling batch size -- confirm against
    SamplingDataset's signature before tying it to a config value.
    """
    return SamplingDataset(
        train_dataloader,
        config,
        target,
        num_samples,
        positive,
        4,
        resample=False,
    )


for concern in range(config.num_labels):
    # Called once per concern, before any sampling (presumably resets the RNG
    # state so every concern is sampled reproducibly -- TODO confirm).
    config.init_seed()
    positive_samples = _sample(concern, True)
    negative_samples = _sample(concern, False)
    # 200 is not a label id of this 10-class dataset; combined with
    # positive=False this presumably samples from all classes -- TODO confirm.
    all_samples = _sample(200, False)

    # Prune a deep copy so `model` stays available as the unpruned reference.
    module = copy.deepcopy(model)

    # Stage 1: head pruning. The return value is discarded, so any pruning must
    # happen on `module` in place.
    head_importance_prunning(module, config, all_samples, ratio)

    # Stage 2: concern-identification pruning. The layer filters come from the
    # configuration cell instead of duplicated literals, so editing the config
    # actually takes effect here. Copies are passed so the shared lists cannot be
    # modified across iterations.
    # NOTE(review): "unstructed" is passed verbatim; confirm it is the spelling
    # prune_concern_identification expects.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=list(include_layers),
        exclude_layers=list(exclude_layers),
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    # Compare original vs. pruned model on the validation loader.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.2917




Precision: 0.6390, Recall: 0.5991, F1-Score: 0.6047




              precision    recall  f1-score   support

           0     0.3978    0.6230    0.4855      2992
           1     0.7032    0.5130    0.5932      2992
           2     0.6520    0.6458    0.6489      3012
           3     0.3567    0.5664    0.4377      2998
           4     0.7799    0.6879    0.7310      2973
           5     0.8271    0.7646    0.7946      3054
           6     0.7206    0.3410    0.4629      3003
           7     0.6366    0.5747    0.6041      3012
           8     0.6825    0.5744    0.6238      2982
           9     0.6338    0.6999    0.6652      2982

    accuracy                         0.5992     30000
   macro avg     0.6390    0.5991    0.6047     30000
weighted avg     0.6393    0.5992    0.6049     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9811133814287224




CCA coefficients mean non-concern: 0.9816697144903449




Linear CKA concern: 0.8996991769636855




Linear CKA non-concern: 0.9019044189997113




Kernel CKA concern: 0.8304195610792994




Kernel CKA non-concern: 0.8008605914263034




original model's perplexity




3.187649726867676




pruned model's perplexity




3.458436965942383




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.2918




Precision: 0.6394, Recall: 0.5992, F1-Score: 0.6046




              precision    recall  f1-score   support

           0     0.4044    0.6207    0.4897      2992
           1     0.7091    0.5043    0.5895      2992
           2     0.6473    0.6484    0.6479      3012
           3     0.3546    0.5767    0.4392      2998
           4     0.7818    0.6895    0.7328      2973
           5     0.8209    0.7682    0.7936      3054
           6     0.7219    0.3407    0.4629      3003
           7     0.6383    0.5684    0.6013      3012
           8     0.6809    0.5738    0.6227      2982
           9     0.6346    0.7012    0.6662      2982

    accuracy                         0.5994     30000
   macro avg     0.6394    0.5992    0.6046     30000
weighted avg     0.6396    0.5994    0.6048     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9883449917312663




CCA coefficients mean non-concern: 0.9808406926728107




Linear CKA concern: 0.9039023743064187




Linear CKA non-concern: 0.9043330918547634




Kernel CKA concern: 0.7849607025908711




Kernel CKA non-concern: 0.8051461965513362




original model's perplexity




3.187649726867676




pruned model's perplexity




3.4605212211608887




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.2932




Precision: 0.6384, Recall: 0.5988, F1-Score: 0.6039




              precision    recall  f1-score   support

           0     0.4019    0.6240    0.4889      2992
           1     0.7043    0.5110    0.5923      2992
           2     0.6496    0.6474    0.6485      3012
           3     0.3577    0.5704    0.4397      2998
           4     0.7782    0.6868    0.7297      2973
           5     0.8226    0.7652    0.7929      3054
           6     0.7227    0.3367    0.4593      3003
           7     0.6352    0.5710    0.6014      3012
           8     0.6833    0.5708    0.6220      2982
           9     0.6291    0.7042    0.6646      2982

    accuracy                         0.5989     30000
   macro avg     0.6384    0.5988    0.6039     30000
weighted avg     0.6387    0.5989    0.6041     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9799780079332511




CCA coefficients mean non-concern: 0.9805476577736506




Linear CKA concern: 0.9059020087880427




Linear CKA non-concern: 0.8955779576964072




Kernel CKA concern: 0.8506833932025479




Kernel CKA non-concern: 0.7924994661384326




original model's perplexity




3.187649726867676




pruned model's perplexity




3.463540554046631




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.2921




Precision: 0.6390, Recall: 0.5992, F1-Score: 0.6046




              precision    recall  f1-score   support

           0     0.4019    0.6233    0.4887      2992
           1     0.7053    0.5070    0.5899      2992
           2     0.6493    0.6467    0.6480      3012
           3     0.3563    0.5724    0.4392      2998
           4     0.7802    0.6879    0.7311      2973
           5     0.8233    0.7672    0.7942      3054
           6     0.7206    0.3410    0.4629      3003
           7     0.6336    0.5747    0.6027      3012
           8     0.6834    0.5718    0.6226      2982
           9     0.6366    0.6995    0.6666      2982

    accuracy                         0.5993     30000
   macro avg     0.6390    0.5992    0.6046     30000
weighted avg     0.6393    0.5993    0.6048     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9816770284038474




CCA coefficients mean non-concern: 0.9821534439036291




Linear CKA concern: 0.8975800217813014




Linear CKA non-concern: 0.8923412107948516




Kernel CKA concern: 0.7975977042853778




Kernel CKA non-concern: 0.7959158219901502




original model's perplexity




3.187649726867676




pruned model's perplexity




3.4615042209625244




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.2926




Precision: 0.6385, Recall: 0.5994, F1-Score: 0.6045




              precision    recall  f1-score   support

           0     0.3973    0.6287    0.4869      2992
           1     0.7026    0.5124    0.5926      2992
           2     0.6566    0.6398    0.6481      3012
           3     0.3624    0.5637    0.4412      2998
           4     0.7727    0.6929    0.7306      2973
           5     0.8268    0.7642    0.7943      3054
           6     0.7207    0.3393    0.4614      3003
           7     0.6357    0.5747    0.6037      3012
           8     0.6860    0.5744    0.6253      2982
           9     0.6238    0.7036    0.6613      2982

    accuracy                         0.5995     30000
   macro avg     0.6385    0.5994    0.6045     30000
weighted avg     0.6387    0.5995    0.6048     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9796421937765347




CCA coefficients mean non-concern: 0.9827326890045448




Linear CKA concern: 0.9026463338643178




Linear CKA non-concern: 0.8925828505546984




Kernel CKA concern: 0.840894158485381




Kernel CKA non-concern: 0.7826840891432556




original model's perplexity




3.187649726867676




pruned model's perplexity




3.4627058506011963




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.2924




Precision: 0.6400, Recall: 0.5988, F1-Score: 0.6045




              precision    recall  f1-score   support

           0     0.3987    0.6273    0.4875      2992
           1     0.7040    0.5087    0.5906      2992
           2     0.6498    0.6481    0.6489      3012
           3     0.3549    0.5730    0.4383      2998
           4     0.7832    0.6805    0.7282      2973
           5     0.8223    0.7695    0.7950      3054
           6     0.7214    0.3397    0.4619      3003
           7     0.6396    0.5727    0.6043      3012
           8     0.6867    0.5711    0.6236      2982
           9     0.6391    0.6972    0.6669      2982

    accuracy                         0.5990     30000
   macro avg     0.6400    0.5988    0.6045     30000
weighted avg     0.6402    0.5990    0.6048     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9828280149715569




CCA coefficients mean non-concern: 0.9828149162791158




Linear CKA concern: 0.875016564938641




Linear CKA non-concern: 0.9018163981645021




Kernel CKA concern: 0.8343329552054427




Kernel CKA non-concern: 0.8063076536440746




original model's perplexity




3.187649726867676




pruned model's perplexity




3.462820529937744




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.2923




Precision: 0.6389, Recall: 0.5990, F1-Score: 0.6047




              precision    recall  f1-score   support

           0     0.4000    0.6257    0.4880      2992
           1     0.7066    0.5064    0.5900      2992
           2     0.6531    0.6438    0.6484      3012
           3     0.3572    0.5697    0.4391      2998
           4     0.7796    0.6855    0.7296      2973
           5     0.8279    0.7623    0.7937      3054
           6     0.7149    0.3440    0.4645      3003
           7     0.6352    0.5770    0.6047      3012
           8     0.6825    0.5751    0.6242      2982
           9     0.6322    0.7009    0.6648      2982

    accuracy                         0.5992     30000
   macro avg     0.6389    0.5990    0.6047     30000
weighted avg     0.6392    0.5992    0.6049     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9846642042716337




CCA coefficients mean non-concern: 0.9801313419207349




Linear CKA concern: 0.9078979299064894




Linear CKA non-concern: 0.895992605743931




Kernel CKA concern: 0.8060102471131925




Kernel CKA non-concern: 0.8007355452739168




original model's perplexity




3.187649726867676




pruned model's perplexity




3.461491584777832




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.2917




Precision: 0.6389, Recall: 0.5990, F1-Score: 0.6049




              precision    recall  f1-score   support

           0     0.3959    0.6243    0.4846      2992
           1     0.7020    0.5164    0.5950      2992
           2     0.6469    0.6471    0.6470      3012
           3     0.3589    0.5674    0.4397      2998
           4     0.7825    0.6778    0.7264      2973
           5     0.8295    0.7662    0.7966      3054
           6     0.7149    0.3440    0.4645      3003
           7     0.6368    0.5797    0.6069      3012
           8     0.6847    0.5718    0.6232      2982
           9     0.6370    0.6955    0.6650      2982

    accuracy                         0.5992     30000
   macro avg     0.6389    0.5990    0.6049     30000
weighted avg     0.6392    0.5992    0.6051     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9820589797080992




CCA coefficients mean non-concern: 0.982932086082009




Linear CKA concern: 0.898651117948678




Linear CKA non-concern: 0.8885900776167117




Kernel CKA concern: 0.8078084795353908




Kernel CKA non-concern: 0.7925829870660134




original model's perplexity




3.187649726867676




pruned model's perplexity




3.460731029510498




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.2932




Precision: 0.6382, Recall: 0.5988, F1-Score: 0.6041




              precision    recall  f1-score   support

           0     0.4012    0.6217    0.4877      2992
           1     0.7032    0.5124    0.5928      2992
           2     0.6530    0.6434    0.6482      3012
           3     0.3574    0.5684    0.4388      2998
           4     0.7766    0.6889    0.7301      2973
           5     0.8255    0.7623    0.7926      3054
           6     0.7196    0.3393    0.4612      3003
           7     0.6370    0.5727    0.6031      3012
           8     0.6815    0.5741    0.6232      2982
           9     0.6265    0.7049    0.6634      2982

    accuracy                         0.5990     30000
   macro avg     0.6382    0.5988    0.6041     30000
weighted avg     0.6384    0.5990    0.6043     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9864204039038451




CCA coefficients mean non-concern: 0.9820649915099328




Linear CKA concern: 0.9093640789351761




Linear CKA non-concern: 0.8917783143638252




Kernel CKA concern: 0.791411326247157




Kernel CKA non-concern: 0.7955439947647506




original model's perplexity




3.187649726867676




pruned model's perplexity




3.464167833328247




Total heads to prune: 6




{(4, 0), (0, 0), (1, 1), (2, 0), (5, 1), (3, 0)}




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.2925




Precision: 0.6392, Recall: 0.5986, F1-Score: 0.6044




              precision    recall  f1-score   support

           0     0.3979    0.6253    0.4864      2992
           1     0.7057    0.5080    0.5908      2992
           2     0.6515    0.6454    0.6484      3012
           3     0.3566    0.5720    0.4393      2998
           4     0.7831    0.6815    0.7288      2973
           5     0.8246    0.7649    0.7936      3054
           6     0.7175    0.3433    0.4644      3003
           7     0.6388    0.5724    0.6037      3012
           8     0.6814    0.5751    0.6237      2982
           9     0.6345    0.6975    0.6645      2982

    accuracy                         0.5987     30000
   macro avg     0.6392    0.5986    0.6044     30000
weighted avg     0.6394    0.5987    0.6046     30000





0.5528703163998864




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.983984954565108




CCA coefficients mean non-concern: 0.981828453559799




Linear CKA concern: 0.8490732101902595




Linear CKA non-concern: 0.8911517443052718




Kernel CKA concern: 0.7431682692337586




Kernel CKA non-concern: 0.7997720673396991




original model's perplexity




3.187649726867676




pruned model's perplexity




3.4613595008850098




In [9]:
# Turn each per-concern evaluation report into a table and merge them with
# append_nth_row (project helper; its exact row selection is not visible here).
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)

# Timestamped file name so repeated runs do not overwrite each other.
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Create the output directory first: to_csv does not create missing directories
# (assuming new_df is a pandas DataFrame), and failing here would discard the
# results of the whole pruning loop.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)