In [1]:
import os
import sys

# Put the directory four levels above the working directory on the import path.
# The `src.*` imports in the next cell presumably resolve from there — TODO
# confirm against the notebook's location in the repository.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
PROJECT_ROOT = "../../../../"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Must be set before any tokenizer is used. Presumably silences the Hugging Face
# tokenizers parallelism/fork warning when DataLoader workers are spawned —
# verify against the tokenizers documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# --- Run configuration: everything a reader might tune lives in this cell ---

# Model identifier handed to Config(name, device) in the loading cell.
name = "bert-4-128-yahoo"

# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without CUDA instead of failing when the model is moved.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): `checkpoint` is not referenced by any code visible in this
# notebook — confirm whether it is still needed.
checkpoint = None

# Passed to load_data() as batch_size / num_workers.
batch_size = 16
num_workers = 4

# Passed to SamplingDataset and get_similarity in the pruning loop.
num_samples = 16

# Passed to prune_concern_identification as `sparsity_ratio`.
ratio = 0.6

# NOTE(review): `seed` is not referenced by any code visible in this notebook;
# seeding is done through config.init_seed() — confirm whether Config reads it.
seed = 44

# Layer-name filters for pruning; the pruning cell hands these same values to
# prune_concern_identification as include_layers / exclude_layers.
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Record when the run began and echo it in a human-readable form.
script_start_time = datetime.now()
started_at_text = script_start_time.strftime("%Y-%m-%d %H:%M:%S")
print(f"Script started at: {started_at_text}")

Script started at: 2024-10-15 20:05:42


In [5]:
# Build the project Config for the chosen model name and device.
config = Config(name, device)
# Number of target classes, read from the config's underlying dictionary
# (the printed config below shows 'num_labels': 10 for this model).
num_labels = config.config["num_labels"]
# Load the model described by `config`; load_model prints its own progress.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Fetch the three data splits for the configured dataset. `do_cache=True` is
# passed through to load_data (the output below shows the splits being read
# from cached .pkl files).
dataloaders = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)
train_dataloader, valid_dataloader, test_dataloader = dataloaders

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
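# Baseline evaluation of the unpruned model is intentionally disabled;
# uncomment the two lines below to run it.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)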

In [8]:
# One evaluation result per concern, in concern order; consumed by the export
# cell at the end of the notebook.
result_list = []

for concern in range(config.num_labels):
    # Re-initialise the seed at the start of every iteration, before sampling.
    config.init_seed()

    # Positional arguments after `config`: the concern id, the sample count, a
    # boolean (True for the "positive" set, False otherwise — presumably selects
    # samples of the concern vs. the remaining classes; confirm against
    # SamplingDataset) and the literal 4 (meaning not visible here — TODO confirm).
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): `all_samples` is never read below. It is kept because
    # building it may advance random state that the later calls depend on —
    # confirm before deleting. The concern argument 200 is a magic value,
    # presumably an id that matches no label — verify.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Work on a deep copy so `model` stays available, unmodified, for the
    # similarity and perplexity comparisons below. The pruning call's return
    # value is not used, so it presumably modifies `module` in place.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        # Use the layer filters from the configuration cell instead of repeating
        # the literals here, so that editing the configuration actually takes effect.
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # NOTE(review): "unstructed" looks like a misspelling of "unstructured",
        # but the string must match whatever prune_concern_identification
        # accepts — left unchanged; confirm against the pruning module.
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    # Compare the original and the pruned model for this concern.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    # NOTE(review): the arguments of this first perplexity call do not depend on
    # `concern`, and the recorded output repeats the same value every iteration;
    # it could be computed once before the loop if the helper has no other
    # side effects — confirm.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2183




Precision: 0.6488, Recall: 0.6166, F1-Score: 0.6212




              precision    recall  f1-score   support

           0     0.5381    0.4856    0.5105      2992
           1     0.6938    0.4983    0.5800      2992
           2     0.7002    0.6119    0.6531      3012
           3     0.3436    0.6351    0.4459      2998
           4     0.7265    0.7763    0.7506      2973
           5     0.8413    0.7656    0.8016      3054
           6     0.6793    0.3986    0.5024      3003
           7     0.6231    0.6318    0.6274      3012
           8     0.5837    0.7213    0.6453      2982
           9     0.7587    0.6412    0.6950      2982

    accuracy                         0.6167     30000
   macro avg     0.6488    0.6166    0.6212     30000
weighted avg     0.6491    0.6167    0.6214     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9980806026240419




CCA coefficients mean non-concern: 0.9977564558021006




Linear CKA concern: 0.9994842237516117




Linear CKA non-concern: 0.9992970929950756




Kernel CKA concern: 0.9980809059085995




Kernel CKA non-concern: 0.9970033002962281




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.20340895652771




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2178




Precision: 0.6480, Recall: 0.6152, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5407    0.4816    0.5095      2992
           1     0.6883    0.4893    0.5720      2992
           2     0.7017    0.6116    0.6535      3012
           3     0.3425    0.6381    0.4458      2998
           4     0.7205    0.7787    0.7485      2973
           5     0.8398    0.7639    0.8001      3054
           6     0.6822    0.3996    0.5040      3003
           7     0.6214    0.6331    0.6272      3012
           8     0.5831    0.7190    0.6439      2982
           9     0.7594    0.6372    0.6929      2982

    accuracy                         0.6153     30000
   macro avg     0.6480    0.6152    0.6197     30000
weighted avg     0.6483    0.6153    0.6199     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9983406488971671




CCA coefficients mean non-concern: 0.9980518279295117




Linear CKA concern: 0.9992883218897163




Linear CKA non-concern: 0.9992968619260993




Kernel CKA concern: 0.998199998564699




Kernel CKA non-concern: 0.997186103509956




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2011353969573975




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2164




Precision: 0.6480, Recall: 0.6169, F1-Score: 0.6213




              precision    recall  f1-score   support

           0     0.5298    0.4987    0.5138      2992
           1     0.6956    0.4950    0.5784      2992
           2     0.6977    0.6116    0.6518      3012
           3     0.3479    0.6328    0.4490      2998
           4     0.7219    0.7790    0.7494      2973
           5     0.8412    0.7633    0.8003      3054
           6     0.6793    0.3999    0.5035      3003
           7     0.6282    0.6272    0.6277      3012
           8     0.5860    0.7186    0.6456      2982
           9     0.7520    0.6435    0.6935      2982

    accuracy                         0.6170     30000
   macro avg     0.6480    0.6169    0.6213     30000
weighted avg     0.6483    0.6170    0.6215     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9978987866047849




CCA coefficients mean non-concern: 0.9977590443604667




Linear CKA concern: 0.9993156887929199




Linear CKA non-concern: 0.9990053336604323




Kernel CKA concern: 0.9978438738086525




Kernel CKA non-concern: 0.9961964103555262




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.195537567138672




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2176




Precision: 0.6484, Recall: 0.6166, F1-Score: 0.6211




              precision    recall  f1-score   support

           0     0.5417    0.4840    0.5112      2992
           1     0.6923    0.4993    0.5802      2992
           2     0.6992    0.6135    0.6536      3012
           3     0.3450    0.6358    0.4473      2998
           4     0.7203    0.7787    0.7483      2973
           5     0.8399    0.7646    0.8005      3054
           6     0.6785    0.3999    0.5032      3003
           7     0.6237    0.6318    0.6277      3012
           8     0.5834    0.7190    0.6441      2982
           9     0.7599    0.6398    0.6947      2982

    accuracy                         0.6167     30000
   macro avg     0.6484    0.6166    0.6211     30000
weighted avg     0.6487    0.6167    0.6213     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9977886318569122




CCA coefficients mean non-concern: 0.9977851040343891




Linear CKA concern: 0.9973504570960653




Linear CKA non-concern: 0.9992311533279363




Kernel CKA concern: 0.9941435363261605




Kernel CKA non-concern: 0.9967089575841446




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.202406644821167




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2172




Precision: 0.6485, Recall: 0.6167, F1-Score: 0.6211




              precision    recall  f1-score   support

           0     0.5271    0.5043    0.5155      2992
           1     0.6926    0.4977    0.5792      2992
           2     0.7022    0.6129    0.6545      3012
           3     0.3470    0.6321    0.4480      2998
           4     0.7153    0.7824    0.7473      2973
           5     0.8438    0.7590    0.7992      3054
           6     0.6882    0.3946    0.5016      3003
           7     0.6261    0.6315    0.6288      3012
           8     0.5884    0.7106    0.6438      2982
           9     0.7544    0.6419    0.6936      2982

    accuracy                         0.6168     30000
   macro avg     0.6485    0.6167    0.6211     30000
weighted avg     0.6488    0.6168    0.6213     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9971691079342435




CCA coefficients mean non-concern: 0.9979450138810007




Linear CKA concern: 0.9985501179079366




Linear CKA non-concern: 0.9992388647210604




Kernel CKA concern: 0.996547025472031




Kernel CKA non-concern: 0.9968537738454195




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1962599754333496




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2131




Precision: 0.6484, Recall: 0.6176, F1-Score: 0.6218




              precision    recall  f1-score   support

           0     0.5382    0.4923    0.5142      2992
           1     0.6930    0.4987    0.5800      2992
           2     0.6960    0.6195    0.6555      3012
           3     0.3487    0.6331    0.4497      2998
           4     0.7191    0.7777    0.7473      2973
           5     0.8346    0.7682    0.8000      3054
           6     0.6822    0.3989    0.5035      3003
           7     0.6201    0.6358    0.6279      3012
           8     0.5870    0.7153    0.6448      2982
           9     0.7651    0.6368    0.6951      2982

    accuracy                         0.6177     30000
   macro avg     0.6484    0.6176    0.6218     30000
weighted avg     0.6487    0.6177    0.6220     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9968699115330867




CCA coefficients mean non-concern: 0.9978652196935675




Linear CKA concern: 0.9902385725580445




Linear CKA non-concern: 0.999209829553827




Kernel CKA concern: 0.9878405361470985




Kernel CKA non-concern: 0.9973146217712303




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1827738285064697




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2169




Precision: 0.6484, Recall: 0.6168, F1-Score: 0.6213




              precision    recall  f1-score   support

           0     0.5377    0.4906    0.5131      2992
           1     0.6943    0.4950    0.5780      2992
           2     0.6990    0.6145    0.6541      3012
           3     0.3468    0.6354    0.4487      2998
           4     0.7224    0.7756    0.7481      2973
           5     0.8436    0.7577    0.7983      3054
           6     0.6769    0.4026    0.5049      3003
           7     0.6230    0.6341    0.6285      3012
           8     0.5844    0.7213    0.6457      2982
           9     0.7563    0.6412    0.6940      2982

    accuracy                         0.6169     30000
   macro avg     0.6484    0.6168    0.6213     30000
weighted avg     0.6488    0.6169    0.6215     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9982528397892427




CCA coefficients mean non-concern: 0.9976583524454468




Linear CKA concern: 0.9996036203378775




Linear CKA non-concern: 0.9990301411934175




Kernel CKA concern: 0.9982947040952165




Kernel CKA non-concern: 0.9961270858949616




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1990296840667725




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2177




Precision: 0.6482, Recall: 0.6170, F1-Score: 0.6214




              precision    recall  f1-score   support

           0     0.5345    0.4923    0.5125      2992
           1     0.6910    0.5000    0.5802      2992
           2     0.7030    0.6082    0.6522      3012
           3     0.3474    0.6324    0.4485      2998
           4     0.7215    0.7763    0.7479      2973
           5     0.8392    0.7639    0.7998      3054
           6     0.6805    0.3993    0.5033      3003
           7     0.6201    0.6378    0.6288      3012
           8     0.5869    0.7203    0.6468      2982
           9     0.7580    0.6398    0.6939      2982

    accuracy                         0.6171     30000
   macro avg     0.6482    0.6170    0.6214     30000
weighted avg     0.6485    0.6171    0.6216     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.997666077959142




CCA coefficients mean non-concern: 0.9976305890945235




Linear CKA concern: 0.9966186149497752




Linear CKA non-concern: 0.9989654338800552




Kernel CKA concern: 0.9925256237311781




Kernel CKA non-concern: 0.9960855062451558




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.199307918548584




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2187




Precision: 0.6485, Recall: 0.6168, F1-Score: 0.6212




              precision    recall  f1-score   support

           0     0.5299    0.4947    0.5117      2992
           1     0.6957    0.4997    0.5816      2992
           2     0.7021    0.6112    0.6535      3012
           3     0.3469    0.6331    0.4482      2998
           4     0.7216    0.7793    0.7494      2973
           5     0.8420    0.7606    0.7992      3054
           6     0.6817    0.3966    0.5015      3003
           7     0.6286    0.6282    0.6284      3012
           8     0.5826    0.7217    0.6447      2982
           9     0.7542    0.6432    0.6943      2982

    accuracy                         0.6169     30000
   macro avg     0.6485    0.6168    0.6212     30000
weighted avg     0.6488    0.6169    0.6214     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9975188459272643




CCA coefficients mean non-concern: 0.9973703104882082




Linear CKA concern: 0.9988225065148738




Linear CKA non-concern: 0.9988913476663341




Kernel CKA concern: 0.9963404205570426




Kernel CKA non-concern: 0.9955162847199245




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2031538486480713




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2162




Precision: 0.6482, Recall: 0.6174, F1-Score: 0.6218




              precision    recall  f1-score   support

           0     0.5365    0.4910    0.5127      2992
           1     0.6875    0.5043    0.5818      2992
           2     0.7014    0.6129    0.6541      3012
           3     0.3480    0.6318    0.4488      2998
           4     0.7249    0.7746    0.7489      2973
           5     0.8399    0.7642    0.8003      3054
           6     0.6818    0.4003    0.5044      3003
           7     0.6191    0.6368    0.6278      3012
           8     0.5859    0.7183    0.6454      2982
           9     0.7571    0.6398    0.6936      2982

    accuracy                         0.6175     30000
   macro avg     0.6482    0.6174    0.6218     30000
weighted avg     0.6485    0.6175    0.6220     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9981027037583977




CCA coefficients mean non-concern: 0.9978571501133603




Linear CKA concern: 0.9982183321056768




Linear CKA non-concern: 0.9990607973955872




Kernel CKA concern: 0.9957921410126994




Kernel CKA non-concern: 0.9966457675774045




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.195718765258789




In [9]:
# Convert every per-concern evaluation result with report_to_df, then combine
# the resulting frames with append_nth_row.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)

# to_csv does not create missing parent directories, so make sure the output
# folder exists first; exist_ok=True makes this a no-op on later runs.
os.makedirs("results", exist_ok=True)

# Timestamped file name so that repeated runs do not overwrite each other.
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
new_df.to_csv(f"results/{csv_name}.csv", index=False)