In [1]:
import os
import sys

# Turn off the tokenizers library's internal parallelism via its environment
# switch; this must be set before the tokenizer is first used.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# Add the directory four levels above the working directory to the import path
# (presumably the repository root that provides the `src` package - confirm).
sys.path.extend(["../../../../"])

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration: every tunable value used by the cells below.
# Identifier handed to Config() to select the model/dataset setup.
name = "bert-4-128-yahoo"
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing at model load.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): not referenced by any visible cell - confirm it is still needed.
checkpoint = None
# Forwarded to load_data() for the train/valid/test loaders.
batch_size = 16
num_workers = 4
# Forwarded to SamplingDataset() and get_similarity() in the pruning cell.
num_samples = 16
# Passed to prune_concern_identification() as `sparsity_ratio`.
ratio = 0.4
# NOTE(review): not referenced by any visible cell; seeding is done through
# config.init_seed() - confirm whether this value is actually consumed.
seed = 44
# Layer-name filters passed to prune_concern_identification().
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Remember when the run began and echo it with one-second resolution.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-15 23:10:28


In [5]:
# Build the experiment configuration for `name` on `device`.
config = Config(name, device)
# Number of target classes read from the configuration (10 in the recorded output).
# NOTE(review): this variable is not used by any visible cell - the pruning
# loop reads `config.num_labels` instead.
num_labels = config.config["num_labels"]
# Load the model described by `config`; this is the unpruned reference model
# that the pruning loop deep-copies for every concern.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Load the dataset described by `config` (with do_cache enabled) and split the
# result into its train / validation / test loaders.
_splits = load_data(
    config,
    do_cache=True,
    num_workers=num_workers,
    batch_size=batch_size,
)
train_dataloader, valid_dataloader, test_dataloader = _splits

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

# Load the pre-generated embedding dataset from the project's cache directory.
# NOTE(review): the file is a .pkl, so load_cache presumably unpickles it;
# unpickling can execute arbitrary code - only load cache files you trust.
generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo",
    "4_128-yahoo_top1.pkl",
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
# Rename the cached keys in place before the dict is handed to
# CustomEmbeddingDataset (presumably these are the key names it reads - confirm).
# The rename is guarded so that re-running this cell is a no-op instead of
# raising KeyError on keys that were already moved by a previous run.
for _old_key, _new_key in (
    ("example_list", "embeddings"),
    ("example_label", "labels"),
    ("attn_list", "attention_mask"),
):
    if _old_key in generated:
        generated[_new_key] = generated.pop(_old_key)
    elif _new_key not in generated:
        # Neither the old nor the new name is present: the cache is malformed,
        # so fail loudly exactly as the unguarded pop() would have.
        raise KeyError(_old_key)

In [9]:
# Wrap the renamed dict in the project's dataset class.
generated_data = CustomEmbeddingDataset(generated)
# Serve the generated data in batches of 4; no shuffle argument is passed, so
# DataLoader's default ordering applies.
# NOTE(review): "dataloder" is a typo for "dataloader". The pruning cell refers
# to this exact name, so any rename must change both places together.
generated_dataloder = DataLoader(
    generated_data,
    batch_size=4,
)

In [10]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [11]:
def _build_samples(target, flag):
    """Construct a SamplingDataset over the generated data for `target`.

    `flag` is forwarded as the fifth positional argument (presumably it selects
    positive vs. negative examples of the concern - confirm against
    SamplingDataset). The literal 4 and resample=False are fixed for every call.
    """
    return SamplingDataset(
        generated_dataloder,
        config,
        target,
        num_samples,
        flag,
        4,
        resample=False,
    )


# One evaluation report per concern, consumed by the export cell.
result_list = []

for concern in range(config.num_labels):
    # Re-seed at the start of every concern.
    config.init_seed()

    positive_samples = _build_samples(concern, True)
    negative_samples = _build_samples(concern, False)
    # NOTE(review): `all_samples` is never used below, and 200 is a magic
    # target value. It is kept because constructing it may have side effects
    # (e.g. on random state) - confirm before removing.
    all_samples = _build_samples(200, False)

    # Prune a private copy so the reference `model` is left untouched.
    module = copy.deepcopy(model)
    # NOTE(review): "unstructed" looks like a misspelling of "unstructured";
    # the string is passed through unchanged because the callee may match on it.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    # Compare the reference and pruned models on the validation split.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)

    # NOTE(review): the original model's perplexity is recomputed on every
    # iteration although the recorded value never changes.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2221




Precision: 0.6492, Recall: 0.6147, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5451    0.4783    0.5095      2992
           1     0.7006    0.4779    0.5682      2992
           2     0.6981    0.6125    0.6525      3012
           3     0.3385    0.6428    0.4434      2998
           4     0.7268    0.7783    0.7517      2973
           5     0.8416    0.7623    0.8000      3054
           6     0.6769    0.4053    0.5070      3003
           7     0.6220    0.6338    0.6279      3012
           8     0.5822    0.7173    0.6427      2982
           9     0.7599    0.6388    0.6941      2982

    accuracy                         0.6148     30000
   macro avg     0.6492    0.6147    0.6197     30000
weighted avg     0.6495    0.6148    0.6199     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9991838282142733




CCA coefficients mean non-concern: 0.9990230188786895




Linear CKA concern: 0.9997931582221422




Linear CKA non-concern: 0.9996089878172577




Kernel CKA concern: 0.9991969955127955




Kernel CKA non-concern: 0.9986086521897154




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2069835662841797




Evaluate the pruned model 1




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2222




Precision: 0.6491, Recall: 0.6148, F1-Score: 0.6196




              precision    recall  f1-score   support

           0     0.5411    0.4813    0.5095      2992
           1     0.6961    0.4823    0.5698      2992
           2     0.7000    0.6112    0.6526      3012
           3     0.3394    0.6408    0.4438      2998
           4     0.7232    0.7804    0.7507      2973
           5     0.8413    0.7620    0.7997      3054
           6     0.6833    0.4003    0.5048      3003
           7     0.6210    0.6338    0.6273      3012
           8     0.5819    0.7183    0.6430      2982
           9     0.7641    0.6375    0.6951      2982

    accuracy                         0.6149     30000
   macro avg     0.6491    0.6148    0.6196     30000
weighted avg     0.6494    0.6149    0.6198     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9993112079352184




CCA coefficients mean non-concern: 0.999138037762555




Linear CKA concern: 0.9996680797012191




Linear CKA non-concern: 0.9997739938882251




Kernel CKA concern: 0.9990254358409865




Kernel CKA non-concern: 0.9991121165286747




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2095229625701904




Evaluate the pruned model 2




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2214




Precision: 0.6478, Recall: 0.6143, F1-Score: 0.6192




              precision    recall  f1-score   support

           0     0.5357    0.4820    0.5074      2992
           1     0.7028    0.4773    0.5685      2992
           2     0.6982    0.6099    0.6511      3012
           3     0.3405    0.6414    0.4449      2998
           4     0.7228    0.7797    0.7502      2973
           5     0.8438    0.7587    0.7990      3054
           6     0.6727    0.4073    0.5074      3003
           7     0.6233    0.6285    0.6259      3012
           8     0.5860    0.7150    0.6441      2982
           9     0.7524    0.6439    0.6939      2982

    accuracy                         0.6144     30000
   macro avg     0.6478    0.6143    0.6192     30000
weighted avg     0.6481    0.6144    0.6194     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9992511512095248




CCA coefficients mean non-concern: 0.9990982194206146




Linear CKA concern: 0.9997852780968508




Linear CKA non-concern: 0.9996594859204738




Kernel CKA concern: 0.9993532007741476




Kernel CKA non-concern: 0.9986701738756806




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2056217193603516




Evaluate the pruned model 3




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2215




Precision: 0.6491, Recall: 0.6158, F1-Score: 0.6206




              precision    recall  f1-score   support

           0     0.5409    0.4816    0.5095      2992
           1     0.6997    0.4866    0.5740      2992
           2     0.6995    0.6135    0.6537      3012
           3     0.3417    0.6394    0.4453      2998
           4     0.7205    0.7804    0.7492      2973
           5     0.8410    0.7620    0.7995      3054
           6     0.6776    0.4053    0.5072      3003
           7     0.6253    0.6328    0.6290      3012
           8     0.5827    0.7180    0.6433      2982
           9     0.7625    0.6385    0.6950      2982

    accuracy                         0.6159     30000
   macro avg     0.6491    0.6158    0.6206     30000
weighted avg     0.6494    0.6159    0.6208     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9993909199313915




CCA coefficients mean non-concern: 0.9993931545087008




Linear CKA concern: 0.9997933573122244




Linear CKA non-concern: 0.999888655491356




Kernel CKA concern: 0.9994387593258729




Kernel CKA non-concern: 0.9995569311600507




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2038347721099854




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.2220




Precision: 0.6490, Recall: 0.6150, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5302    0.4893    0.5090      2992
           1     0.7059    0.4756    0.5683      2992
           2     0.7063    0.6069    0.6529      3012
           3     0.3417    0.6381    0.4451      2998
           4     0.7192    0.7797    0.7482      2973
           5     0.8414    0.7606    0.7990      3054
           6     0.6808    0.4063    0.5089      3003
           7     0.6224    0.6341    0.6282      3012
           8     0.5833    0.7183    0.6438      2982
           9     0.7583    0.6408    0.6947      2982

    accuracy                         0.6151     30000
   macro avg     0.6490    0.6150    0.6198     30000
weighted avg     0.6493    0.6151    0.6200     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9988585847868727




CCA coefficients mean non-concern: 0.9992721482031113




Linear CKA concern: 0.9992365177631811




Linear CKA non-concern: 0.9997622422773229




Kernel CKA concern: 0.9983256331364885




Kernel CKA non-concern: 0.999113602141




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.207387685775757




Evaluate the pruned model 5




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2215




Precision: 0.6493, Recall: 0.6159, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5417    0.4816    0.5099      2992
           1     0.7003    0.4820    0.5710      2992
           2     0.6994    0.6172    0.6557      3012
           3     0.3429    0.6404    0.4466      2998
           4     0.7224    0.7773    0.7489      2973
           5     0.8381    0.7646    0.7997      3054
           6     0.6818    0.4009    0.5049      3003
           7     0.6204    0.6381    0.6291      3012
           8     0.5808    0.7186    0.6424      2982
           9     0.7657    0.6378    0.6959      2982

    accuracy                         0.6160     30000
   macro avg     0.6493    0.6159    0.6204     30000
weighted avg     0.6496    0.6160    0.6206     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9988960199052307




CCA coefficients mean non-concern: 0.9993318408981576




Linear CKA concern: 0.9974462503822811




Linear CKA non-concern: 0.999796698498268




Kernel CKA concern: 0.9968871770380049




Kernel CKA non-concern: 0.9992254106029491




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2058043479919434




Evaluate the pruned model 6




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2194




Precision: 0.6473, Recall: 0.6160, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5295    0.4886    0.5083      2992
           1     0.7017    0.4826    0.5719      2992
           2     0.6983    0.6132    0.6530      3012
           3     0.3468    0.6334    0.4482      2998
           4     0.7229    0.7773    0.7491      2973
           5     0.8406    0.7613    0.7990      3054
           6     0.6718    0.4083    0.5079      3003
           7     0.6173    0.6394    0.6282      3012
           8     0.5875    0.7140    0.6446      2982
           9     0.7564    0.6415    0.6942      2982

    accuracy                         0.6161     30000
   macro avg     0.6473    0.6160    0.6204     30000
weighted avg     0.6476    0.6161    0.6206     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9994460061348205




CCA coefficients mean non-concern: 0.9992565406492456




Linear CKA concern: 0.9998718059235012




Linear CKA non-concern: 0.9996859322826633




Kernel CKA concern: 0.9995302004690769




Kernel CKA non-concern: 0.9988854270704233




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1988377571105957




Evaluate the pruned model 7




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2217




Precision: 0.6488, Recall: 0.6148, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5405    0.4813    0.5092      2992
           1     0.6994    0.4766    0.5669      2992
           2     0.7045    0.6079    0.6526      3012
           3     0.3395    0.6391    0.4435      2998
           4     0.7215    0.7773    0.7484      2973
           5     0.8429    0.7606    0.7997      3054
           6     0.6751    0.4096    0.5098      3003
           7     0.6196    0.6388    0.6291      3012
           8     0.5843    0.7183    0.6444      2982
           9     0.7603    0.6382    0.6939      2982

    accuracy                         0.6149     30000
   macro avg     0.6488    0.6148    0.6197     30000
weighted avg     0.6491    0.6149    0.6200     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9993654872289072




CCA coefficients mean non-concern: 0.9992382815842029




Linear CKA concern: 0.9997688044973708




Linear CKA non-concern: 0.9996651286881153




Kernel CKA concern: 0.9994331083670256




Kernel CKA non-concern: 0.9987921616978278




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2057974338531494




Evaluate the pruned model 8




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2223




Precision: 0.6487, Recall: 0.6146, F1-Score: 0.6195




              precision    recall  f1-score   support

           0     0.5342    0.4833    0.5075      2992
           1     0.7049    0.4749    0.5675      2992
           2     0.7008    0.6072    0.6507      3012
           3     0.3404    0.6418    0.4449      2998
           4     0.7213    0.7810    0.7500      2973
           5     0.8432    0.7587    0.7987      3054
           6     0.6785    0.4076    0.5093      3003
           7     0.6248    0.6315    0.6281      3012
           8     0.5840    0.7170    0.6437      2982
           9     0.7550    0.6429    0.6944      2982

    accuracy                         0.6147     30000
   macro avg     0.6487    0.6146    0.6195     30000
weighted avg     0.6490    0.6147    0.6197     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9992858397008709




CCA coefficients mean non-concern: 0.9991366535262257




Linear CKA concern: 0.9998514389302847




Linear CKA non-concern: 0.9996434763897116




Kernel CKA concern: 0.9994947965784066




Kernel CKA non-concern: 0.9987133376607568




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2067880630493164




Evaluate the pruned model 9




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2224




Precision: 0.6501, Recall: 0.6150, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5479    0.4759    0.5094      2992
           1     0.7056    0.4783    0.5701      2992
           2     0.6980    0.6155    0.6542      3012
           3     0.3381    0.6428    0.4431      2998
           4     0.7187    0.7854    0.7506      2973
           5     0.8441    0.7606    0.8002      3054
           6     0.6831    0.4013    0.5056      3003
           7     0.6246    0.6325    0.6285      3012
           8     0.5802    0.7203    0.6427      2982
           9     0.7605    0.6378    0.6938      2982

    accuracy                         0.6151     30000
   macro avg     0.6501    0.6150    0.6198     30000
weighted avg     0.6504    0.6151    0.6200     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9993911381479462




CCA coefficients mean non-concern: 0.9991455268014446




Linear CKA concern: 0.9998442182181438




Linear CKA non-concern: 0.9997247138430085




Kernel CKA concern: 0.9995779488528225




Kernel CKA non-concern: 0.9990545271066533




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.207749366760254




In [12]:
# Convert every per-concern evaluation report to a DataFrame and combine them
# with append_nth_row (project helper; see src.utils.helper for which rows it takes).
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)
# Timestamped file name so successive runs never overwrite each other.
csv_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# to_csv does not create missing directories; make sure "results/" exists so
# the export cannot fail after the expensive pruning loop has finished.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)