In [1]:
import os
import sys

# Extend the import search path four directories up from the working
# directory so the `src.*` imports in the next cell can resolve
# (presumably the repository root — the path is relative to the kernel's cwd).
sys.path += ["../../../../"]
# Set before any tokenizer is created; "false" is the value the HuggingFace
# tokenizers library reads to disable its internal parallelism.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import prune_concern_identification
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# --- Experiment configuration (everything a reader might tune) ---

# Model/config key handed to Config(name, device) in a later cell.
name = "bert-4-128-yahoo"

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# end-to-end on a machine without CUDA instead of failing at model load time.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): `checkpoint` is not referenced by any visible cell — confirm
# whether it is still needed.
checkpoint = None

# Forwarded to load_data() as batch_size / num_workers.
batch_size = 16
num_workers = 4

# Forwarded to SamplingDataset(...) and get_similarity(...).
num_samples = 16

# Forwarded to head_importance_prunning(...) and, as sparsity_ratio, to
# prune_concern_identification(...).
ratio = 0.3

# NOTE(review): `seed` is never passed anywhere in the visible cells
# (config.init_seed() is called without arguments) — confirm it has an effect.
seed = 44

# Layer-name filters forwarded to prune_concern_identification(...).
include_layers = ["intermediate", "output"]
exclude_layers = [
    "attention",
]

In [4]:
# Record the wall-clock start of the run and echo it in a readable format.
script_start_time = datetime.now()
print("Script started at:", script_start_time.strftime("%Y-%m-%d %H:%M:%S"))

Script started at: 2024-10-15 16:53:45


In [5]:
# Build the experiment configuration from the model key and target device
# chosen in the configuration cell.
config = Config(name, device)
# Label count stored in the loaded config dictionary (10 in the recorded run).
# NOTE(review): the pruning loop iterates over `config.num_labels` rather than
# this variable, so `num_labels` appears unused in the visible cells.
num_labels = config.config["num_labels"]
# Load the model described by `config`; this is the unpruned baseline that the
# pruning loop deep-copies for each concern.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Fetch the train / validation / test loaders for the configured dataset.
# do_cache=True is forwarded unchanged; the recorded run reports the three
# splits being loaded from cached .pkl files.
(train_dataloader, valid_dataloader, test_dataloader) = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
# Concern id used to build the concern-agnostic sample set handed to head
# pruning. It lies far outside range(config.num_labels) and is combined with
# the "negative" flag, so it presumably selects samples from every class —
# TODO confirm against SamplingDataset.
ALL_CLASSES_CONCERN = 200


def _draw_samples(concern_id, is_positive):
    """Build a SamplingDataset over the training loader for one concern id.

    Args:
        concern_id: Class index (or ALL_CLASSES_CONCERN) forwarded as the
            third positional argument.
        is_positive: Boolean forwarded as the fifth positional argument; the
            loop passes True for the "positive" set and False otherwise.

    Returns:
        The constructed SamplingDataset.

    NOTE(review): the positional literal 4 and resample=False are forwarded
    exactly as in the original calls; their meaning is defined by
    SamplingDataset and is not visible here.
    """
    return SamplingDataset(
        train_dataloader,
        config,
        concern_id,
        num_samples,
        is_positive,
        4,
        resample=False,
    )


# One evaluation result per concern, consumed by the CSV export cell.
result_list = []

# The baseline model is only ever deep-copied below, never pruned itself, and
# the recorded run printed the same baseline perplexity for every concern.
# Measure it once here instead of repeating the validation pass per concern.
config.init_seed()
print("original model's perplexity")
get_perplexity(model, valid_dataloader, config)

for concern in range(config.num_labels):
    # Called at the top of every iteration, presumably to reset the RNG state
    # so each concern's sampling is reproducible — confirm in Config.
    config.init_seed()

    # Construction order (positive, negative, all) is kept as in the original
    # in case sampling consumes random state.
    positive_samples = _draw_samples(concern, True)
    negative_samples = _draw_samples(concern, False)
    all_samples = _draw_samples(ALL_CLASSES_CONCERN, False)

    # Prune a private copy so `model` stays intact as the comparison baseline.
    module = copy.deepcopy(model)

    # Step 1: head pruning driven by the concern-agnostic samples.
    head_importance_prunning(module, config, all_samples, ratio)

    # Step 2: concern-specific pruning of the selected layers.
    # "unstructed" is the method string the original passes; it is kept
    # byte-for-byte because the callee decides which spellings it accepts.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    # Compare the pruned copy against the untouched baseline.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.2761




Precision: 0.6501, Recall: 0.5958, F1-Score: 0.6063




              precision    recall  f1-score   support

           0     0.5711    0.4579    0.5083      2992
           1     0.6528    0.4505    0.5331      2992
           2     0.6997    0.5810    0.6349      3012
           3     0.2994    0.6791    0.4156      2998
           4     0.8153    0.6428    0.7188      2973
           5     0.8391    0.7600    0.7976      3054
           6     0.6896    0.3943    0.5017      3003
           7     0.5536    0.6776    0.6093      3012
           8     0.6347    0.6606    0.6474      2982
           9     0.7454    0.6539    0.6967      2982

    accuracy                         0.5960     30000
   macro avg     0.6501    0.5958    0.6063     30000
weighted avg     0.6502    0.5960    0.6066     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9957659285496266




CCA coefficients mean non-concern: 0.995393425641079




Linear CKA concern: 0.9878827115949036




Linear CKA non-concern: 0.9895866006717882




Kernel CKA concern: 0.9651110899146832




Kernel CKA non-concern: 0.9722597011329673




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4068710803985596




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.2766




Precision: 0.6506, Recall: 0.5958, F1-Score: 0.6065




              precision    recall  f1-score   support

           0     0.5741    0.4582    0.5097      2992
           1     0.6514    0.4509    0.5329      2992
           2     0.7028    0.5793    0.6351      3012
           3     0.2988    0.6798    0.4152      2998
           4     0.8146    0.6445    0.7196      2973
           5     0.8399    0.7593    0.7976      3054
           6     0.6900    0.3943    0.5018      3003
           7     0.5522    0.6793    0.6092      3012
           8     0.6354    0.6593    0.6471      2982
           9     0.7464    0.6533    0.6967      2982

    accuracy                         0.5961     30000
   macro avg     0.6506    0.5958    0.6065     30000
weighted avg     0.6507    0.5961    0.6067     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9964560316929786




CCA coefficients mean non-concern: 0.9954236776328511




Linear CKA concern: 0.9895965975610964




Linear CKA non-concern: 0.9899990949514906




Kernel CKA concern: 0.9748885062016048




Kernel CKA non-concern: 0.9715876427777835




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4089980125427246




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.2757




Precision: 0.6500, Recall: 0.5958, F1-Score: 0.6064




              precision    recall  f1-score   support

           0     0.5707    0.4589    0.5087      2992
           1     0.6529    0.4482    0.5315      2992
           2     0.7001    0.5813    0.6352      3012
           3     0.2990    0.6788    0.4151      2998
           4     0.8144    0.6465    0.7208      2973
           5     0.8406    0.7600    0.7983      3054
           6     0.6876    0.3943    0.5012      3003
           7     0.5545    0.6760    0.6092      3012
           8     0.6351    0.6613    0.6479      2982
           9     0.7455    0.6533    0.6963      2982

    accuracy                         0.5961     30000
   macro avg     0.6500    0.5958    0.6064     30000
weighted avg     0.6502    0.5961    0.6067     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9955593507413191




CCA coefficients mean non-concern: 0.9958159456917799




Linear CKA concern: 0.987952698293125




Linear CKA non-concern: 0.9889989743844748




Kernel CKA concern: 0.9662060369861182




Kernel CKA non-concern: 0.9700274993906729




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4048755168914795




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.2761




Precision: 0.6504, Recall: 0.5958, F1-Score: 0.6065




              precision    recall  f1-score   support

           0     0.5716    0.4576    0.5083      2992
           1     0.6526    0.4495    0.5324      2992
           2     0.7020    0.5810    0.6358      3012
           3     0.2986    0.6795    0.4149      2998
           4     0.8153    0.6458    0.7207      2973
           5     0.8391    0.7597    0.7974      3054
           6     0.6894    0.3939    0.5014      3003
           7     0.5547    0.6770    0.6097      3012
           8     0.6353    0.6613    0.6480      2982
           9     0.7455    0.6533    0.6963      2982

    accuracy                         0.5961     30000
   macro avg     0.6504    0.5958    0.6065     30000
weighted avg     0.6506    0.5961    0.6067     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9963503727900719




CCA coefficients mean non-concern: 0.9954192211055408




Linear CKA concern: 0.98732930179026




Linear CKA non-concern: 0.9891617944314459




Kernel CKA concern: 0.9682316004611726




Kernel CKA non-concern: 0.9717380085655761




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.407057046890259




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.2754




Precision: 0.6503, Recall: 0.5961, F1-Score: 0.6066




              precision    recall  f1-score   support

           0     0.5736    0.4596    0.5103      2992
           1     0.6537    0.4492    0.5325      2992
           2     0.7012    0.5813    0.6357      3012
           3     0.2993    0.6788    0.4154      2998
           4     0.8132    0.6458    0.7199      2973
           5     0.8400    0.7597    0.7978      3054
           6     0.6896    0.3936    0.5012      3003
           7     0.5536    0.6773    0.6092      3012
           8     0.6348    0.6610    0.6476      2982
           9     0.7442    0.6546    0.6965      2982

    accuracy                         0.5963     30000
   macro avg     0.6503    0.5961    0.6066     30000
weighted avg     0.6505    0.5963    0.6068     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9891490580823623




CCA coefficients mean non-concern: 0.9960976784577631




Linear CKA concern: 0.9601262175049861




Linear CKA non-concern: 0.9888371711875804




Kernel CKA concern: 0.9193771061696046




Kernel CKA non-concern: 0.971175467781624




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4043149948120117




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.2763




Precision: 0.6502, Recall: 0.5955, F1-Score: 0.6060




              precision    recall  f1-score   support

           0     0.5730    0.4579    0.5090      2992
           1     0.6536    0.4472    0.5311      2992
           2     0.6988    0.5817    0.6349      3012
           3     0.2989    0.6801    0.4153      2998
           4     0.8158    0.6404    0.7175      2973
           5     0.8395    0.7606    0.7981      3054
           6     0.6894    0.3939    0.5014      3003
           7     0.5540    0.6780    0.6097      3012
           8     0.6328    0.6623    0.6472      2982
           9     0.7459    0.6526    0.6961      2982

    accuracy                         0.5957     30000
   macro avg     0.6502    0.5955    0.6060     30000
weighted avg     0.6503    0.5957    0.6063     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9929159569807253




CCA coefficients mean non-concern: 0.9954941999411822




Linear CKA concern: 0.9783755951047551




Linear CKA non-concern: 0.9908800980431536




Kernel CKA concern: 0.9646175672136222




Kernel CKA non-concern: 0.9732305940752625




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4076740741729736




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.2760




Precision: 0.6503, Recall: 0.5957, F1-Score: 0.6063




              precision    recall  f1-score   support

           0     0.5714    0.4589    0.5090      2992
           1     0.6532    0.4475    0.5311      2992
           2     0.7014    0.5810    0.6356      3012
           3     0.2989    0.6788    0.4151      2998
           4     0.8151    0.6435    0.7192      2973
           5     0.8411    0.7590    0.7979      3054
           6     0.6892    0.3943    0.5016      3003
           7     0.5535    0.6790    0.6098      3012
           8     0.6343    0.6620    0.6479      2982
           9     0.7450    0.6536    0.6963      2982

    accuracy                         0.5960     30000
   macro avg     0.6503    0.5957    0.6063     30000
weighted avg     0.6505    0.5960    0.6066     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.996268965998561




CCA coefficients mean non-concern: 0.9953070354375851




Linear CKA concern: 0.9869155834978894




Linear CKA non-concern: 0.9899953788083639




Kernel CKA concern: 0.9610592876359108




Kernel CKA non-concern: 0.9727121830200585




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4065604209899902




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.2759




Precision: 0.6501, Recall: 0.5958, F1-Score: 0.6064




              precision    recall  f1-score   support

           0     0.5715    0.4582    0.5086      2992
           1     0.6530    0.4509    0.5334      2992
           2     0.6995    0.5820    0.6354      3012
           3     0.2995    0.6791    0.4156      2998
           4     0.8149    0.6428    0.7187      2973
           5     0.8390    0.7593    0.7972      3054
           6     0.6894    0.3946    0.5019      3003
           7     0.5545    0.6770    0.6097      3012
           8     0.6335    0.6620    0.6474      2982
           9     0.7459    0.6526    0.6961      2982

    accuracy                         0.5961     30000
   macro avg     0.6501    0.5958    0.6064     30000
weighted avg     0.6502    0.5961    0.6066     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9957213138151177




CCA coefficients mean non-concern: 0.9953607185991555




Linear CKA concern: 0.9866849176144178




Linear CKA non-concern: 0.9892623001929095




Kernel CKA concern: 0.9697732093139919




Kernel CKA non-concern: 0.9714138482118979




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4061169624328613




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.2762




Precision: 0.6503, Recall: 0.5960, F1-Score: 0.6065




              precision    recall  f1-score   support

           0     0.5726    0.4576    0.5086      2992
           1     0.6531    0.4499    0.5328      2992
           2     0.7011    0.5810    0.6354      3012
           3     0.2996    0.6795    0.4158      2998
           4     0.8144    0.6465    0.7208      2973
           5     0.8416    0.7583    0.7978      3054
           6     0.6892    0.3936    0.5011      3003
           7     0.5523    0.6786    0.6090      3012
           8     0.6340    0.6616    0.6475      2982
           9     0.7455    0.6533    0.6963      2982

    accuracy                         0.5962     30000
   macro avg     0.6503    0.5960    0.6065     30000
weighted avg     0.6505    0.5962    0.6067     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9955270297553519




CCA coefficients mean non-concern: 0.9953943162912836




Linear CKA concern: 0.9876928271995334




Linear CKA non-concern: 0.9892776521108769




Kernel CKA concern: 0.9688617688720154




Kernel CKA non-concern: 0.9721172576345616




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.407153844833374




Total heads to prune: 4




{(0, 2), (0, 3), (3, 0), (2, 2)}




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.2760




Precision: 0.6502, Recall: 0.5958, F1-Score: 0.6064




              precision    recall  f1-score   support

           0     0.5720    0.4569    0.5080      2992
           1     0.6524    0.4499    0.5325      2992
           2     0.6990    0.5813    0.6348      3012
           3     0.2994    0.6791    0.4156      2998
           4     0.8152    0.6424    0.7186      2973
           5     0.8401    0.7603    0.7982      3054
           6     0.6896    0.3936    0.5012      3003
           7     0.5533    0.6790    0.6097      3012
           8     0.6347    0.6620    0.6481      2982
           9     0.7460    0.6539    0.6969      2982

    accuracy                         0.5961     30000
   macro avg     0.6502    0.5958    0.6064     30000
weighted avg     0.6503    0.5961    0.6066     30000





0.28338393832133246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9958928523221626




CCA coefficients mean non-concern: 0.9953454744320714




Linear CKA concern: 0.9821805967056623




Linear CKA non-concern: 0.9893736772731835




Kernel CKA concern: 0.9609864083520072




Kernel CKA non-concern: 0.9719004779580833




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.4065914154052734




In [9]:
# Convert every per-concern evaluation result into a DataFrame and merge them
# with the project helper.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)

# Timestamped file name so runs started at different times get separate files.
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# DataFrame.to_csv does not create missing parent directories, so make sure
# the relative "results" folder exists before writing (no-op when it does).
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)