In [1]:
import os
import sys

# Add the directory four levels above the working directory to the import
# search path (presumably the repository root that provides `src.*`).
sys.path.extend(["../../../../"])
# Must be set before the tokenizer library is first used.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import prune_concern_identification
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# --- Experiment configuration ------------------------------------------------
# Identifier handed to Config(name, device) to select the model/dataset pair.
name = "bert-6-128-yahoo"
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a GPU instead of failing on the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): not referenced by any cell visible in this notebook.
checkpoint = None
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # forwarded to SamplingDataset and get_similarity
ratio = 0.3  # used for head pruning and as `sparsity_ratio`
# NOTE(review): not referenced by the visible cells; seeding happens through
# config.init_seed() - confirm whether this value is picked up anywhere.
seed = 44
# Layer-name filters intended for prune_concern_identification's
# `include_layers` / `exclude_layers` arguments.
include_layers = ["intermediate", "output"]
exclude_layers = [
    "attention",
]

In [4]:
# Capture the wall-clock start of the run and echo it into the cell output.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-15 17:39:22


In [5]:
# Build the experiment configuration for `name` on `device`.
config = Config(name, device)
# Label count read from the underlying config mapping.
# NOTE(review): the pruning loop below iterates over `config.num_labels`
# rather than this variable - confirm both always agree.
num_labels = config.config["num_labels"]
# Load the model described by `config`.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
# Fetch the data loaders for the configured dataset; `do_cache=True` is passed
# through to load_data unchanged.
loaders = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)
# load_data yields exactly three loaders: train, validation, test.
train_dataloader, valid_dataloader, test_dataloader = loaders

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
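# Optional baseline: evaluate the unpruned model on the test split.
# Left disabled; uncomment both lines below to run it.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)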

In [8]:
# Per-concern pruning experiment.
# For every label index ("concern") a fresh deep copy of the model is
# head-pruned, then concern-pruned, evaluated on the test split, and compared
# with the original model (sparsity, similarity, perplexity).
result_list = []


def _build_samples(concern_id, is_positive):
    """Create one SamplingDataset over the training loader.

    The three sample sets used per concern differ only in these two arguments;
    every other argument is identical, so it is fixed here once.
    """
    return SamplingDataset(
        train_dataloader,
        config,
        concern_id,
        num_samples,
        is_positive,
        4,  # NOTE(review): undocumented positional argument - confirm meaning
        resample=False,
    )


for concern in range(config.num_labels):
    # Called at the start of every iteration, before any sampling
    # (presumably re-seeds the RNGs - confirm in Config.init_seed).
    config.init_seed()

    positive_samples = _build_samples(concern, True)
    negative_samples = _build_samples(concern, False)
    # NOTE(review): 200 is outside the label range of this task; presumably a
    # sentinel so the "negative" sampling covers every class - TODO confirm.
    all_samples = _build_samples(200, False)

    # Work on a copy so the original `model` stays untouched for comparison.
    module = copy.deepcopy(model)

    head_importance_prunning(module, config, all_samples, ratio)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        # Use the config-cell variables rather than duplicated literals so a
        # change to the configuration actually takes effect here.
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # NOTE(review): spelled "unstructed" (not "unstructured"); kept as-is
        # because the callee may match on this exact string - confirm.
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    # NOTE(review): the original model's perplexity is recomputed in every
    # iteration and prints the same value each time; left in place to keep the
    # output and call order unchanged - consider hoisting out of the loop.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.2549




Precision: 0.6453, Recall: 0.6034, F1-Score: 0.6113




              precision    recall  f1-score   support

           0     0.4783    0.5037    0.4906      2992
           1     0.7216    0.4729    0.5714      2992
           2     0.6757    0.6268    0.6504      3012
           3     0.3225    0.6564    0.4325      2998
           4     0.7826    0.6963    0.7369      2973
           5     0.8284    0.7698    0.7980      3054
           6     0.6848    0.3849    0.4929      3003
           7     0.6273    0.6235    0.6254      3012
           8     0.6338    0.6308    0.6323      2982
           9     0.6978    0.6690    0.6831      2982

    accuracy                         0.6036     30000
   macro avg     0.6453    0.6034    0.6113     30000
weighted avg     0.6455    0.6036    0.6116     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.992530340937872




CCA coefficients mean non-concern: 0.9925267784221903




Linear CKA concern: 0.9661432508893693




Linear CKA non-concern: 0.9617625849741985




Kernel CKA concern: 0.9380639477437892




Kernel CKA non-concern: 0.9193762068548302




original model's perplexity




3.187649726867676




pruned model's perplexity




3.302976369857788




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.2551




Precision: 0.6450, Recall: 0.6032, F1-Score: 0.6110




              precision    recall  f1-score   support

           0     0.4784    0.5037    0.4907      2992
           1     0.7223    0.4719    0.5709      2992
           2     0.6762    0.6262    0.6502      3012
           3     0.3225    0.6561    0.4325      2998
           4     0.7815    0.6953    0.7358      2973
           5     0.8273    0.7701    0.7977      3054
           6     0.6841    0.3843    0.4921      3003
           7     0.6254    0.6218    0.6236      3012
           8     0.6325    0.6338    0.6332      2982
           9     0.6995    0.6683    0.6836      2982

    accuracy                         0.6034     30000
   macro avg     0.6450    0.6032    0.6110     30000
weighted avg     0.6452    0.6034    0.6113     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9944160716858368




CCA coefficients mean non-concern: 0.9927686336596517




Linear CKA concern: 0.9609356234058695




Linear CKA non-concern: 0.9636041019606236




Kernel CKA concern: 0.9169113458231192




Kernel CKA non-concern: 0.9250766435084705




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3039956092834473




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.2555




Precision: 0.6453, Recall: 0.6032, F1-Score: 0.6111




              precision    recall  f1-score   support

           0     0.4785    0.5053    0.4915      2992
           1     0.7213    0.4706    0.5696      2992
           2     0.6774    0.6262    0.6508      3012
           3     0.3226    0.6578    0.4329      2998
           4     0.7829    0.6949    0.7363      2973
           5     0.8275    0.7695    0.7974      3054
           6     0.6859    0.3833    0.4918      3003
           7     0.6254    0.6218    0.6236      3012
           8     0.6332    0.6328    0.6330      2982
           9     0.6985    0.6697    0.6838      2982

    accuracy                         0.6034     30000
   macro avg     0.6453    0.6032    0.6111     30000
weighted avg     0.6456    0.6034    0.6113     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9914435238465962




CCA coefficients mean non-concern: 0.9925776228330625




Linear CKA concern: 0.9582366740342915




Linear CKA non-concern: 0.9625484506693741




Kernel CKA concern: 0.930633496908423




Kernel CKA non-concern: 0.92208272384427




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3055524826049805




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.2552




Precision: 0.6455, Recall: 0.6035, F1-Score: 0.6114




              precision    recall  f1-score   support

           0     0.4781    0.5037    0.4906      2992
           1     0.7213    0.4723    0.5708      2992
           2     0.6768    0.6265    0.6507      3012
           3     0.3227    0.6574    0.4329      2998
           4     0.7819    0.6959    0.7364      2973
           5     0.8272    0.7698    0.7975      3054
           6     0.6870    0.3830    0.4918      3003
           7     0.6269    0.6232    0.6250      3012
           8     0.6341    0.6351    0.6346      2982
           9     0.6991    0.6683    0.6834      2982

    accuracy                         0.6037     30000
   macro avg     0.6455    0.6035    0.6114     30000
weighted avg     0.6458    0.6037    0.6116     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9927510077226794




CCA coefficients mean non-concern: 0.9929663565776454




Linear CKA concern: 0.9592173566695616




Linear CKA non-concern: 0.9617488365060584




Kernel CKA concern: 0.9312686846454613




Kernel CKA non-concern: 0.922657289368938




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3042519092559814




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.2548




Precision: 0.6455, Recall: 0.6036, F1-Score: 0.6114




              precision    recall  f1-score   support

           0     0.4801    0.5030    0.4913      2992
           1     0.7218    0.4709    0.5700      2992
           2     0.6764    0.6258    0.6501      3012
           3     0.3227    0.6574    0.4329      2998
           4     0.7837    0.6946    0.7364      2973
           5     0.8259    0.7705    0.7972      3054
           6     0.6849    0.3836    0.4918      3003
           7     0.6262    0.6258    0.6260      3012
           8     0.6345    0.6345    0.6345      2982
           9     0.6986    0.6700    0.6840      2982

    accuracy                         0.6038     30000
   macro avg     0.6455    0.6036    0.6114     30000
weighted avg     0.6457    0.6038    0.6117     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9927238178009647




CCA coefficients mean non-concern: 0.9930807756589788




Linear CKA concern: 0.955528904608107




Linear CKA non-concern: 0.9620196624401584




Kernel CKA concern: 0.9407108929431928




Kernel CKA non-concern: 0.9215235483485202




original model's perplexity




3.187649726867676




pruned model's perplexity




3.302764415740967




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.2549




Precision: 0.6451, Recall: 0.6032, F1-Score: 0.6110




              precision    recall  f1-score   support

           0     0.4776    0.5027    0.4898      2992
           1     0.7209    0.4706    0.5695      2992
           2     0.6769    0.6268    0.6509      3012
           3     0.3223    0.6568    0.4324      2998
           4     0.7821    0.6956    0.7363      2973
           5     0.8274    0.7708    0.7981      3054
           6     0.6843    0.3839    0.4919      3003
           7     0.6263    0.6232    0.6247      3012
           8     0.6339    0.6325    0.6332      2982
           9     0.6989    0.6687    0.6835      2982

    accuracy                         0.6034     30000
   macro avg     0.6451    0.6032    0.6110     30000
weighted avg     0.6453    0.6034    0.6113     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9937213896049278




CCA coefficients mean non-concern: 0.9935046368644637




Linear CKA concern: 0.9506650034963726




Linear CKA non-concern: 0.9639469179591413




Kernel CKA concern: 0.9481958751451249




Kernel CKA non-concern: 0.9240327961575568




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3028857707977295




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.2554




Precision: 0.6451, Recall: 0.6031, F1-Score: 0.6110




              precision    recall  f1-score   support

           0     0.4774    0.5037    0.4902      2992
           1     0.7209    0.4723    0.5707      2992
           2     0.6760    0.6255    0.6498      3012
           3     0.3226    0.6574    0.4328      2998
           4     0.7829    0.6939    0.7357      2973
           5     0.8268    0.7692    0.7969      3054
           6     0.6849    0.3843    0.4923      3003
           7     0.6267    0.6232    0.6249      3012
           8     0.6341    0.6328    0.6334      2982
           9     0.6987    0.6687    0.6833      2982

    accuracy                         0.6033     30000
   macro avg     0.6451    0.6031    0.6110     30000
weighted avg     0.6453    0.6033    0.6112     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9930385717280534




CCA coefficients mean non-concern: 0.9924645914460035




Linear CKA concern: 0.966198662472869




Linear CKA non-concern: 0.9609043340399199




Kernel CKA concern: 0.9336881760246553




Kernel CKA non-concern: 0.9248561620213568




original model's perplexity




3.187649726867676




pruned model's perplexity




3.304591178894043




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.2546




Precision: 0.6450, Recall: 0.6035, F1-Score: 0.6114




              precision    recall  f1-score   support

           0     0.4779    0.5027    0.4900      2992
           1     0.7200    0.4726    0.5706      2992
           2     0.6767    0.6268    0.6508      3012
           3     0.3231    0.6558    0.4329      2998
           4     0.7825    0.6946    0.7359      2973
           5     0.8274    0.7708    0.7981      3054
           6     0.6836    0.3856    0.4931      3003
           7     0.6270    0.6238    0.6254      3012
           8     0.6338    0.6331    0.6335      2982
           9     0.6977    0.6693    0.6832      2982

    accuracy                         0.6037     30000
   macro avg     0.6450    0.6035    0.6114     30000
weighted avg     0.6452    0.6037    0.6116     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9931435540345733




CCA coefficients mean non-concern: 0.9928965233470955




Linear CKA concern: 0.9612132142377636




Linear CKA non-concern: 0.9585784028735845




Kernel CKA concern: 0.9393697936672613




Kernel CKA non-concern: 0.920396144938689




original model's perplexity




3.187649726867676




pruned model's perplexity




3.302596092224121




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.2558




Precision: 0.6452, Recall: 0.6030, F1-Score: 0.6109




              precision    recall  f1-score   support

           0     0.4792    0.5037    0.4911      2992
           1     0.7193    0.4703    0.5687      2992
           2     0.6765    0.6255    0.6500      3012
           3     0.3216    0.6574    0.4319      2998
           4     0.7828    0.6946    0.7361      2973
           5     0.8275    0.7695    0.7974      3054
           6     0.6861    0.3836    0.4921      3003
           7     0.6268    0.6222    0.6245      3012
           8     0.6344    0.6338    0.6341      2982
           9     0.6981    0.6693    0.6834      2982

    accuracy                         0.6032     30000
   macro avg     0.6452    0.6030    0.6109     30000
weighted avg     0.6455    0.6032    0.6112     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9940306639892433




CCA coefficients mean non-concern: 0.9926781327679007




Linear CKA concern: 0.9661026842178142




Linear CKA non-concern: 0.9564885547651615




Kernel CKA concern: 0.8959097538664932




Kernel CKA non-concern: 0.9203129797850418




original model's perplexity




3.187649726867676




pruned model's perplexity




3.3065736293792725




Total heads to prune: 3




{(4, 0), (5, 1), (0, 0)}




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.2548




Precision: 0.6450, Recall: 0.6034, F1-Score: 0.6112




              precision    recall  f1-score   support

           0     0.4777    0.5037    0.4903      2992
           1     0.7211    0.4726    0.5710      2992
           2     0.6762    0.6275    0.6509      3012
           3     0.3227    0.6554    0.4325      2998
           4     0.7819    0.6946    0.7357      2973
           5     0.8272    0.7698    0.7975      3054
           6     0.6834    0.3846    0.4922      3003
           7     0.6271    0.6248    0.6260      3012
           8     0.6345    0.6311    0.6328      2982
           9     0.6979    0.6693    0.6833      2982

    accuracy                         0.6036     30000
   macro avg     0.6450    0.6034    0.6112     30000
weighted avg     0.6452    0.6036    0.6115     30000





0.2761163171870252




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9937981115614117




CCA coefficients mean non-concern: 0.9927947277027416




Linear CKA concern: 0.9512725489050856




Linear CKA non-concern: 0.9598995373053709




Kernel CKA concern: 0.9140595227180366




Kernel CKA non-concern: 0.9246423695354197




original model's perplexity




3.187649726867676




pruned model's perplexity




3.30281925201416




In [9]:
# Persist the per-concern evaluation reports as a single timestamped CSV.
# Each entry of `result_list` is converted with report_to_df, then the
# converted reports are combined by append_nth_row.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# to_csv does not create missing parent directories, so make sure the output
# folder exists; otherwise the whole run's results are lost at the last step.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)