In [1]:
import os
import sys

# Make the project root importable so the `src.*` imports in the next cell
# resolve. The path is relative to the kernel's working directory (assumed to
# be this notebook's folder). Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
_PROJECT_ROOT = "../../../../"
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)
# Presumably disables the HuggingFace `tokenizers` internal parallelism, which
# otherwise warns when DataLoader worker processes (num_workers > 0) are forked.
# Must be set before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# ---- Experiment configuration -------------------------------------------
# Dataset / model key handed to Config, and the device the model runs on.
name = "YahooAnswersTopics"
device = torch.device("cuda:0")
checkpoint = None  # NOTE(review): not referenced in the visible cells

# DataLoader settings forwarded to load_data.
batch_size, num_workers = 16, 4

# Concern-identification pruning settings.
num_samples = 16  # sample count given to SamplingDataset and get_similarity
ratio = 0.4  # forwarded as `sparsity_ratio` to prune_concern_identification
seed = 44  # NOTE(review): not referenced in the visible cells; seeding goes through config.init_seed()

# Layer groups to include / exclude when pruning.
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Capture and announce the wall-clock start time of this run.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-15 13:40:11


In [5]:
# Build the experiment config for the chosen dataset/device, read the label
# count from its settings dict, then load the model described by the config.
config = Config(name, device)
model_settings = config.config
num_labels = model_settings["num_labels"]  # NOTE(review): later cells use config.num_labels instead
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Build the train / valid / test DataLoaders for the configured dataset.
# do_cache=True lets load_data reuse cached splits (the recorded output shows
# train/valid/test .pkl files being loaded from cache).
loader_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "do_cache": True,
}
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
) = load_data(config, **loader_options)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
# For each concern (label index), prune a fresh copy of the model using the
# positive/negative sample sets for that concern, then report:
#   - test-set evaluation of the pruned copy (collected into result_list),
#   - sparsity of the pruned copy,
#   - similarity between original and pruned models on the validation split,
#   - perplexity of the original and of the pruned model on the validation split.
result_list = []

for concern in range(config.num_labels):
    # Presumably re-seeds the RNGs so every concern starts from the same state.
    config.init_seed()
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,  # NOTE(review): unnamed positional argument; meaning not visible here
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): `all_samples` is never used below (and 200 is not a valid
    # label index for this 10-class config). It is kept only because building it
    # may advance shared RNG/dataloader state that later steps depend on;
    # confirm and remove it if not.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Work on a deep copy so the original `model` stays intact for comparison.
    # The return value of prune_concern_identification is not used, so it
    # presumably prunes `module` in place.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        # Use the layer filters from the configuration cell rather than
        # re-hardcoding them, so editing the config actually takes effect.
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # NOTE(review): "unstructed" looks like a typo for "unstructured"; left
        # unchanged because it must match what prune_concern_identification expects.
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2223




Precision: 0.6482, Recall: 0.6152, F1-Score: 0.6200




              precision    recall  f1-score   support

           0     0.5332    0.4853    0.5081      2992
           1     0.7011    0.4736    0.5653      2992
           2     0.6966    0.6122    0.6517      3012
           3     0.3435    0.6411    0.4473      2998
           4     0.7235    0.7790    0.7502      2973
           5     0.8421    0.7613    0.7997      3054
           6     0.6707    0.4096    0.5086      3003
           7     0.6204    0.6381    0.6291      3012
           8     0.5868    0.7140    0.6442      2982
           9     0.7640    0.6382    0.6954      2982

    accuracy                         0.6153     30000
   macro avg     0.6482    0.6152    0.6200     30000
weighted avg     0.6485    0.6153    0.6202     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998137869762803




CCA coefficients mean non-concern: 0.9998025462296554




Linear CKA concern: 0.9999742832177394




Linear CKA non-concern: 0.9999685525106372




Kernel CKA concern: 0.9999222723494795




Kernel CKA non-concern: 0.9998873481444548




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2059054374694824




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2226




Precision: 0.6481, Recall: 0.6150, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5367    0.4816    0.5077      2992
           1     0.7010    0.4733    0.5650      2992
           2     0.6968    0.6119    0.6516      3012
           3     0.3433    0.6421    0.4474      2998
           4     0.7232    0.7787    0.7499      2973
           5     0.8412    0.7613    0.7992      3054
           6     0.6710    0.4089    0.5082      3003
           7     0.6193    0.6394    0.6292      3012
           8     0.5846    0.7146    0.6431      2982
           9     0.7637    0.6385    0.6955      2982

    accuracy                         0.6151     30000
   macro avg     0.6481    0.6150    0.6197     30000
weighted avg     0.6484    0.6151    0.6199     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998315218823854




CCA coefficients mean non-concern: 0.9998147154532449




Linear CKA concern: 0.9999721379513291




Linear CKA non-concern: 0.9999733083141116




Kernel CKA concern: 0.9999302381822566




Kernel CKA non-concern: 0.9998926826942484




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.207026720046997




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2222




Precision: 0.6481, Recall: 0.6149, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5362    0.4833    0.5083      2992
           1     0.7011    0.4719    0.5641      2992
           2     0.6954    0.6125    0.6514      3012
           3     0.3424    0.6414    0.4464      2998
           4     0.7232    0.7787    0.7499      2973
           5     0.8423    0.7610    0.7996      3054
           6     0.6697    0.4099    0.5086      3003
           7     0.6224    0.6371    0.6297      3012
           8     0.5850    0.7146    0.6433      2982
           9     0.7631    0.6385    0.6953      2982

    accuracy                         0.6150     30000
   macro avg     0.6481    0.6149    0.6197     30000
weighted avg     0.6484    0.6150    0.6199     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998072029115264




CCA coefficients mean non-concern: 0.9997899243293964




Linear CKA concern: 0.9999760186193344




Linear CKA non-concern: 0.999968046576815




Kernel CKA concern: 0.9999351344622254




Kernel CKA non-concern: 0.9998812014721399




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.205592155456543




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2223




Precision: 0.6483, Recall: 0.6157, F1-Score: 0.6203




              precision    recall  f1-score   support

           0     0.5349    0.4840    0.5082      2992
           1     0.6994    0.4759    0.5664      2992
           2     0.6971    0.6129    0.6523      3012
           3     0.3440    0.6401    0.4475      2998
           4     0.7239    0.7797    0.7508      2973
           5     0.8425    0.7620    0.8002      3054
           6     0.6708    0.4099    0.5089      3003
           7     0.6192    0.6388    0.6289      3012
           8     0.5867    0.7150    0.6445      2982
           9     0.7644    0.6385    0.6958      2982

    accuracy                         0.6158     30000
   macro avg     0.6483    0.6157    0.6203     30000
weighted avg     0.6486    0.6158    0.6205     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998084300401567




CCA coefficients mean non-concern: 0.999799880408359




Linear CKA concern: 0.9999536190567219




Linear CKA non-concern: 0.999969405468412




Kernel CKA concern: 0.9998868924794098




Kernel CKA non-concern: 0.9998839995928296




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2062530517578125




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2225




Precision: 0.6482, Recall: 0.6152, F1-Score: 0.6199




              precision    recall  f1-score   support

           0     0.5362    0.4823    0.5078      2992
           1     0.7000    0.4749    0.5659      2992
           2     0.6967    0.6132    0.6523      3012
           3     0.3426    0.6398    0.4463      2998
           4     0.7232    0.7793    0.7502      2973
           5     0.8428    0.7603    0.7994      3054
           6     0.6721    0.4089    0.5085      3003
           7     0.6225    0.6384    0.6304      3012
           8     0.5837    0.7160    0.6431      2982
           9     0.7625    0.6385    0.6950      2982

    accuracy                         0.6153     30000
   macro avg     0.6482    0.6152    0.6199     30000
weighted avg     0.6485    0.6153    0.6201     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997519452783521




CCA coefficients mean non-concern: 0.9997983126823534




Linear CKA concern: 0.9999609003711746




Linear CKA non-concern: 0.9999603253782172




Kernel CKA concern: 0.9999064358749452




Kernel CKA non-concern: 0.9998653718452847




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2064731121063232




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2222




Precision: 0.6485, Recall: 0.6149, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5366    0.4826    0.5082      2992
           1     0.7012    0.4706    0.5632      2992
           2     0.6969    0.6145    0.6531      3012
           3     0.3416    0.6418    0.4458      2998
           4     0.7241    0.7767    0.7494      2973
           5     0.8418    0.7616    0.7997      3054
           6     0.6710    0.4103    0.5092      3003
           7     0.6228    0.6381    0.6304      3012
           8     0.5840    0.7156    0.6432      2982
           9     0.7649    0.6372    0.6952      2982

    accuracy                         0.6150     30000
   macro avg     0.6485    0.6149    0.6197     30000
weighted avg     0.6488    0.6150    0.6200     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997346912704375




CCA coefficients mean non-concern: 0.9997832467137365




Linear CKA concern: 0.999822439271243




Linear CKA non-concern: 0.9999628740367791




Kernel CKA concern: 0.9997873225814013




Kernel CKA non-concern: 0.9998831168501817




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.20586895942688




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2224




Precision: 0.6483, Recall: 0.6153, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5331    0.4850    0.5079      2992
           1     0.7009    0.4739    0.5655      2992
           2     0.6977    0.6122    0.6522      3012
           3     0.3430    0.6408    0.4468      2998
           4     0.7243    0.7787    0.7505      2973
           5     0.8426    0.7606    0.7995      3054
           6     0.6696    0.4109    0.5093      3003
           7     0.6216    0.6381    0.6298      3012
           8     0.5864    0.7146    0.6442      2982
           9     0.7639    0.6378    0.6952      2982

    accuracy                         0.6154     30000
   macro avg     0.6483    0.6153    0.6201     30000
weighted avg     0.6486    0.6154    0.6203     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998349467237888




CCA coefficients mean non-concern: 0.9997869542676253




Linear CKA concern: 0.9999753392699955




Linear CKA non-concern: 0.9999586859991613




Kernel CKA concern: 0.999921420792039




Kernel CKA non-concern: 0.9998475494842061




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2064051628112793




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2223




Precision: 0.6481, Recall: 0.6150, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5345    0.4836    0.5078      2992
           1     0.7002    0.4723    0.5641      2992
           2     0.6968    0.6119    0.6516      3012
           3     0.3423    0.6401    0.4461      2998
           4     0.7243    0.7783    0.7503      2973
           5     0.8422    0.7620    0.8001      3054
           6     0.6697    0.4113    0.5096      3003
           7     0.6202    0.6391    0.6295      3012
           8     0.5858    0.7143    0.6437      2982
           9     0.7651    0.6368    0.6951      2982

    accuracy                         0.6151     30000
   macro avg     0.6481    0.6150    0.6198     30000
weighted avg     0.6484    0.6151    0.6200     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997839689083851




CCA coefficients mean non-concern: 0.9997794749130167




Linear CKA concern: 0.9999422814656356




Linear CKA non-concern: 0.9999597471961174




Kernel CKA concern: 0.999876972719151




Kernel CKA non-concern: 0.9998615644146218




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2063405513763428




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2225




Precision: 0.6480, Recall: 0.6152, F1-Score: 0.6199




              precision    recall  f1-score   support

           0     0.5374    0.4830    0.5087      2992
           1     0.7002    0.4746    0.5657      2992
           2     0.6960    0.6112    0.6509      3012
           3     0.3432    0.6408    0.4470      2998
           4     0.7233    0.7790    0.7501      2973
           5     0.8427    0.7613    0.7999      3054
           6     0.6701    0.4099    0.5087      3003
           7     0.6197    0.6384    0.6289      3012
           8     0.5856    0.7150    0.6438      2982
           9     0.7619    0.6385    0.6948      2982

    accuracy                         0.6153     30000
   macro avg     0.6480    0.6152    0.6199     30000
weighted avg     0.6483    0.6153    0.6201     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997707122951452




CCA coefficients mean non-concern: 0.9997475984993133




Linear CKA concern: 0.9999686493127394




Linear CKA non-concern: 0.9999499985166191




Kernel CKA concern: 0.9999045425864466




Kernel CKA non-concern: 0.9998044081194628




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2063791751861572




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2219




Precision: 0.6482, Recall: 0.6154, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5341    0.4846    0.5081      2992
           1     0.7010    0.4749    0.5662      2992
           2     0.6968    0.6119    0.6516      3012
           3     0.3437    0.6408    0.4474      2998
           4     0.7239    0.7787    0.7503      2973
           5     0.8419    0.7620    0.7999      3054
           6     0.6701    0.4099    0.5087      3003
           7     0.6206    0.6381    0.6292      3012
           8     0.5860    0.7140    0.6437      2982
           9     0.7639    0.6392    0.6960      2982

    accuracy                         0.6155     30000
   macro avg     0.6482    0.6154    0.6201     30000
weighted avg     0.6485    0.6155    0.6203     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998238725486269




CCA coefficients mean non-concern: 0.999804917124792




Linear CKA concern: 0.9999651643942417




Linear CKA non-concern: 0.9999621689864914




Kernel CKA concern: 0.9999194270216412




Kernel CKA non-concern: 0.9998768633063776




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2047946453094482




In [9]:
# Convert each per-concern evaluation report to a DataFrame, combine them with
# append_nth_row, and write the result to a timestamped CSV under results/.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)
csv_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
# DataFrame.to_csv does not create missing parent directories; ensure results/
# exists so a fresh checkout does not fail here after the long pruning loop.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)