In [1]:
import os
import sys

sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
name = "bert-4-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
ratio = 0.3
seed = 44
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-10-15 23:02:02


In [5]:
config = Config(name, device)
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo",
    "4_128-yahoo_top1.pkl",
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
generated["embeddings"] = generated.pop("example_list")
generated["labels"] = generated.pop("example_label")
generated["attention_mask"] = generated.pop("attn_list")

In [9]:
generated_data = CustomEmbeddingDataset(generated)
generated_dataloder = DataLoader(
    generated_data,
    batch_size=4,
)

In [10]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [11]:
result_list = []

for concern in range(config.num_labels):
    config.init_seed()
    positive_samples = SamplingDataset(
        generated_dataloder,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        generated_dataloder,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    all_samples = SamplingDataset(
        generated_dataloder,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.2233




Precision: 0.6485, Recall: 0.6149, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5371    0.4820    0.5080      2992
           1     0.6987    0.4759    0.5662      2992
           2     0.6954    0.6102    0.6500      3012
           3     0.3411    0.6421    0.4455      2998
           4     0.7277    0.7783    0.7522      2973
           5     0.8416    0.7620    0.7998      3054
           6     0.6720    0.4073    0.5072      3003
           7     0.6206    0.6375    0.6289      3012
           8     0.5853    0.7153    0.6438      2982
           9     0.7652    0.6382    0.6959      2982

    accuracy                         0.6150     30000
   macro avg     0.6485    0.6149    0.6198     30000
weighted avg     0.6488    0.6150    0.6200     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9996551087104261




CCA coefficients mean non-concern: 0.9996085155251102




Linear CKA concern: 0.9999156275836261




Linear CKA non-concern: 0.9998635946992503




Kernel CKA concern: 0.9997203300827909




Kernel CKA non-concern: 0.9995478567311875




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.208841323852539




Evaluate the pruned model 1




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2223




Precision: 0.6484, Recall: 0.6155, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5391    0.4820    0.5089      2992
           1     0.6980    0.4783    0.5676      2992
           2     0.6958    0.6129    0.6517      3012
           3     0.3432    0.6408    0.4470      2998
           4     0.7240    0.7807    0.7513      2973
           5     0.8403    0.7616    0.7990      3054
           6     0.6753    0.4066    0.5076      3003
           7     0.6209    0.6378    0.6292      3012
           8     0.5842    0.7156    0.6433      2982
           9     0.7635    0.6388    0.6956      2982

    accuracy                         0.6156     30000
   macro avg     0.6484    0.6155    0.6201     30000
weighted avg     0.6487    0.6156    0.6203     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997269752386584




CCA coefficients mean non-concern: 0.9996703803669899




Linear CKA concern: 0.9998942607751646




Linear CKA non-concern: 0.9999188941159259




Kernel CKA concern: 0.9996962774500842




Kernel CKA non-concern: 0.999708139459496




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2063677310943604




Evaluate the pruned model 2




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2229




Precision: 0.6480, Recall: 0.6144, F1-Score: 0.6194




              precision    recall  f1-score   support

           0     0.5357    0.4820    0.5074      2992
           1     0.7003    0.4756    0.5665      2992
           2     0.6965    0.6102    0.6505      3012
           3     0.3404    0.6408    0.4446      2998
           4     0.7244    0.7790    0.7507      2973
           5     0.8432    0.7587    0.7987      3054
           6     0.6714    0.4089    0.5083      3003
           7     0.6229    0.6351    0.6290      3012
           8     0.5859    0.7146    0.6439      2982
           9     0.7595    0.6395    0.6943      2982

    accuracy                         0.6145     30000
   macro avg     0.6480    0.6144    0.6194     30000
weighted avg     0.6483    0.6145    0.6196     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9996931119735749




CCA coefficients mean non-concern: 0.9996272203813248




Linear CKA concern: 0.9999321893518497




Linear CKA non-concern: 0.9998928352324674




Kernel CKA concern: 0.9998007473112217




Kernel CKA non-concern: 0.9995861090565806




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2079734802246094




Evaluate the pruned model 3




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2224




Precision: 0.6484, Recall: 0.6153, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5340    0.4833    0.5074      2992
           1     0.7006    0.4779    0.5682      2992
           2     0.6960    0.6142    0.6526      3012
           3     0.3428    0.6401    0.4465      2998
           4     0.7235    0.7790    0.7502      2973
           5     0.8429    0.7606    0.7997      3054
           6     0.6720    0.4086    0.5082      3003
           7     0.6218    0.6371    0.6294      3012
           8     0.5861    0.7146    0.6440      2982
           9     0.7639    0.6378    0.6952      2982

    accuracy                         0.6154     30000
   macro avg     0.6484    0.6153    0.6201     30000
weighted avg     0.6487    0.6154    0.6203     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997506003604024




CCA coefficients mean non-concern: 0.9997462026947719




Linear CKA concern: 0.9999379248795243




Linear CKA non-concern: 0.9999505985408415




Kernel CKA concern: 0.9998500099389677




Kernel CKA non-concern: 0.9998316811957938




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2062385082244873




Evaluate the pruned model 4




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2222




Precision: 0.6484, Recall: 0.6155, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5331    0.4843    0.5075      2992
           1     0.7007    0.4766    0.5673      2992
           2     0.6987    0.6112    0.6520      3012
           3     0.3431    0.6394    0.4465      2998
           4     0.7236    0.7804    0.7509      2973
           5     0.8401    0.7620    0.7991      3054
           6     0.6757    0.4086    0.5092      3003
           7     0.6206    0.6388    0.6296      3012
           8     0.5877    0.7136    0.6446      2982
           9     0.7607    0.6405    0.6954      2982

    accuracy                         0.6156     30000
   macro avg     0.6484    0.6155    0.6202     30000
weighted avg     0.6487    0.6156    0.6204     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995886377772498




CCA coefficients mean non-concern: 0.9997259254792548




Linear CKA concern: 0.9998027466885419




Linear CKA non-concern: 0.9999449771731501




Kernel CKA concern: 0.999583703094724




Kernel CKA non-concern: 0.9997948213962443




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.205566883087158




Evaluate the pruned model 5




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2225




Precision: 0.6489, Recall: 0.6154, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5348    0.4823    0.5072      2992
           1     0.6998    0.4776    0.5677      2992
           2     0.6962    0.6155    0.6534      3012
           3     0.3420    0.6401    0.4458      2998
           4     0.7240    0.7783    0.7502      2973
           5     0.8417    0.7626    0.8002      3054
           6     0.6755    0.4056    0.5069      3003
           7     0.6216    0.6375    0.6294      3012
           8     0.5855    0.7176    0.6449      2982
           9     0.7674    0.6372    0.6962      2982

    accuracy                         0.6155     30000
   macro avg     0.6489    0.6154    0.6202     30000
weighted avg     0.6492    0.6155    0.6204     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995681127568334




CCA coefficients mean non-concern: 0.9997238008178295




Linear CKA concern: 0.9993628225877437




Linear CKA non-concern: 0.99994293220772




Kernel CKA concern: 0.9991332433497441




Kernel CKA non-concern: 0.9997787908240484




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.206864833831787




Evaluate the pruned model 6




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2215




Precision: 0.6477, Recall: 0.6155, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5315    0.4856    0.5075      2992
           1     0.7009    0.4746    0.5660      2992
           2     0.6963    0.6129    0.6520      3012
           3     0.3452    0.6384    0.4481      2998
           4     0.7241    0.7793    0.7507      2973
           5     0.8398    0.7623    0.7992      3054
           6     0.6696    0.4103    0.5088      3003
           7     0.6215    0.6378    0.6295      3012
           8     0.5856    0.7146    0.6437      2982
           9     0.7622    0.6395    0.6955      2982

    accuracy                         0.6156     30000
   macro avg     0.6477    0.6155    0.6201     30000
weighted avg     0.6480    0.6156    0.6203     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997800453374122




CCA coefficients mean non-concern: 0.9997070854563624




Linear CKA concern: 0.9999631031134371




Linear CKA non-concern: 0.9999327911522597




Kernel CKA concern: 0.9998619508823108




Kernel CKA non-concern: 0.9997582881144967




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2035937309265137




Evaluate the pruned model 7




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2232




Precision: 0.6485, Recall: 0.6145, F1-Score: 0.6195




              precision    recall  f1-score   support

           0     0.5363    0.4816    0.5075      2992
           1     0.7010    0.4733    0.5650      2992
           2     0.6981    0.6096    0.6508      3012
           3     0.3402    0.6404    0.4443      2998
           4     0.7243    0.7783    0.7503      2973
           5     0.8436    0.7597    0.7994      3054
           6     0.6708    0.4099    0.5089      3003
           7     0.6216    0.6391    0.6302      3012
           8     0.5834    0.7166    0.6432      2982
           9     0.7656    0.6365    0.6951      2982

    accuracy                         0.6146     30000
   macro avg     0.6485    0.6145    0.6195     30000
weighted avg     0.6488    0.6146    0.6197     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997392433227229




CCA coefficients mean non-concern: 0.9996844196915924




Linear CKA concern: 0.9999332276697528




Linear CKA non-concern: 0.9999015224298577




Kernel CKA concern: 0.9998595861954287




Kernel CKA non-concern: 0.9996401541556429




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2092552185058594




Evaluate the pruned model 8




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2217




Precision: 0.6482, Recall: 0.6151, F1-Score: 0.6199




              precision    recall  f1-score   support

           0     0.5365    0.4820    0.5077      2992
           1     0.6995    0.4753    0.5660      2992
           2     0.6973    0.6119    0.6518      3012
           3     0.3417    0.6388    0.4452      2998
           4     0.7222    0.7810    0.7505      2973
           5     0.8436    0.7597    0.7994      3054
           6     0.6712    0.4106    0.5095      3003
           7     0.6233    0.6371    0.6301      3012
           8     0.5850    0.7153    0.6436      2982
           9     0.7617    0.6398    0.6955      2982

    accuracy                         0.6152     30000
   macro avg     0.6482    0.6151    0.6199     30000
weighted avg     0.6485    0.6152    0.6201     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997076448591748




CCA coefficients mean non-concern: 0.9996534914847505




Linear CKA concern: 0.9999466694015099




Linear CKA non-concern: 0.9998926181056845




Kernel CKA concern: 0.9998414409906422




Kernel CKA non-concern: 0.9996295711845393




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.20339298248291




Evaluate the pruned model 9




Evaluating the model:   0%|                                                              | 0/1875 [00:00<?, ?i…

Loss: 1.2232




Precision: 0.6487, Recall: 0.6150, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5372    0.4799    0.5070      2992
           1     0.7016    0.4763    0.5674      2992
           2     0.6956    0.6129    0.6516      3012
           3     0.3416    0.6411    0.4457      2998
           4     0.7225    0.7810    0.7506      2973
           5     0.8435    0.7606    0.7999      3054
           6     0.6759    0.4069    0.5080      3003
           7     0.6201    0.6361    0.6280      3012
           8     0.5839    0.7176    0.6439      2982
           9     0.7654    0.6378    0.6958      2982

    accuracy                         0.6151     30000
   macro avg     0.6487    0.6150    0.6198     30000
weighted avg     0.6490    0.6151    0.6200     30000





0.23429984616283372




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997484593742814




CCA coefficients mean non-concern: 0.9996403378371458




Linear CKA concern: 0.999942149084453




Linear CKA non-concern: 0.9999096617050434




Kernel CKA concern: 0.9998478028924189




Kernel CKA non-concern: 0.9997045966241747




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.208742141723633




In [12]:
df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
csv_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
new_df.to_csv(f"results/{csv_name}.csv", index=False)