In [1]:
import os
import sys

# Make the project root (four directories up from the notebook's working dir)
# importable so the `src.*` imports in the next cell resolve. Guarded so that
# re-running this cell does not append duplicate entries to sys.path.
PROJECT_ROOT = "../../../../"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Disable HuggingFace tokenizers' internal parallelism — presumably to avoid the
# fork warning / deadlock risk when DataLoader workers (num_workers > 0) are used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
from datetime import datetime

import torch
from torch.utils.data import DataLoader

from src.models.evaluate import (
    evaluate_model,
    get_perplexity,
    get_similarity,
    get_sparsity,
)
from src.pruning.prune import prune_concern_identification
from src.utils.data_class import CustomEmbeddingDataset
from src.utils.helper import Config, append_nth_row, color_print, report_to_df
from src.utils.load import load_cache, load_data, load_model, save_checkpoint
from src.utils.sampling import SamplingDataset

In [3]:
# Experiment configuration — every value a reader might tune lives here.
name = "bert-4-128-yahoo"  # model/config name passed to Config
# Use the first GPU when one is visible; otherwise fall back to CPU instead of
# failing later when tensors are moved to "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced elsewhere in this notebook
batch_size = 16  # batch size for the train/valid/test loaders from load_data
num_workers = 4  # DataLoader worker processes for load_data
num_samples = 16  # passed to SamplingDataset and get_similarity
ratio = 0.6  # sparsity_ratio for prune_concern_identification
# NOTE(review): `seed` is never passed anywhere below; seeding happens via
# config.init_seed() in the pruning loop — confirm which seed value that uses.
seed = 44
include_layers = ["intermediate", "output"]  # layer-name filters passed to pruning
exclude_layers = ["attention"]

In [4]:
# Record and print the wall-clock time at which this run started.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-15 23:27:20


In [5]:
# Build the run configuration for model `name` on `device`, then load the model it describes.
config = Config(name, device)
# Number of classes as stored in the model config (the printed config below shows 10).
# NOTE(review): this variable is not used later; the pruning loop reads `config.num_labels`.
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Build the train / valid / test loaders for the dataset named in `config`.
# do_cache=True lets load_data reuse cached splits (the output shows
# train.pkl / valid.pkl / test.pkl "loaded from cache").
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    do_cache=True,
    num_workers=num_workers,
    batch_size=batch_size,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Load the pre-generated, embedding-based samples for this model/dataset pair.
# (load_cache, CustomEmbeddingDataset and DataLoader are imported in the top imports cell.)
# NOTE(review): the cache is a .pkl file, so it is presumably unpickled — only
# point this at trusted files.
generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo",
    "4_128-yahoo_top1.pkl",
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
# Rename the generator's keys to the names used when building the
# CustomEmbeddingDataset in the next cell (presumably the keys it expects).
# Idempotent: keys that were already renamed by a previous run are left alone,
# so re-running this cell no longer raises KeyError.
GENERATED_KEY_MAP = {
    "example_list": "embeddings",
    "example_label": "labels",
    "attn_list": "attention_mask",
}
for old_key, new_key in GENERATED_KEY_MAP.items():
    if old_key in generated:
        generated[new_key] = generated.pop(old_key)
    elif new_key not in generated:
        raise KeyError(f"generated data has neither {old_key!r} nor {new_key!r}")

In [9]:
# Wrap the renamed dict in a Dataset and batch it 4 examples at a time.
# (sic) `generated_dataloder` is misspelled, but later cells reference this exact name.
generated_data = CustomEmbeddingDataset(generated)
generated_dataloder = DataLoader(dataset=generated_data, batch_size=4)

In [10]:
# Optional baseline: set to True to evaluate the unpruned model over the whole
# test_dataloader before pruning. Disabled by default (same as before).
EVALUATE_ORIGINAL = False
if EVALUATE_ORIGINAL:
    print("Evaluate the original model")
    result = evaluate_model(model, config, test_dataloader)

In [11]:
# For each concern (class label), prune a deep copy of the model and compare the
# pruned copy against the untouched original. Per-concern evaluation results are
# collected in `result_list` for the CSV export in the next cell.
result_list = []

# `model` itself is never modified below (each iteration prunes a deepcopy), so its
# perplexity is loop-invariant: compute it once instead of once per concern.
print("original model's perplexity")
get_perplexity(model, valid_dataloader, config)

for concern in range(config.num_labels):
    # Re-seed so every concern starts from the same RNG state.
    config.init_seed()
    # Sample from the generated data: the True/False flag presumably selects
    # examples of this concern vs. examples of the other labels.
    # NOTE(review): the positional `4` presumably matches generated_dataloder's
    # batch size — confirm against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        generated_dataloder,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        generated_dataloder,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Prune a copy so the original model stays intact for the comparisons below.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # NOTE(review): spelling must match what prune_concern_identification accepts.
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    # Representation similarity (CCA / CKA) between original and pruned model.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.2223




Precision: 0.6488, Recall: 0.6136, F1-Score: 0.6187




              precision    recall  f1-score   support

           0     0.5555    0.4733    0.5111      2992
           1     0.6868    0.4933    0.5742      2992
           2     0.6939    0.6112    0.6500      3012
           3     0.3368    0.6441    0.4423      2998
           4     0.7303    0.7649    0.7472      2973
           5     0.8420    0.7590    0.7983      3054
           6     0.6891    0.3949    0.5021      3003
           7     0.6134    0.6375    0.6252      3012
           8     0.5810    0.7183    0.6424      2982
           9     0.7591    0.6392    0.6940      2982

    accuracy                         0.6137     30000
   macro avg     0.6488    0.6136    0.6187     30000
weighted avg     0.6491    0.6137    0.6189     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9972993750836708




CCA coefficients mean non-concern: 0.9966586198897556




Linear CKA concern: 0.9987846825880102




Linear CKA non-concern: 0.9979161077456626




Kernel CKA concern: 0.9958206640895865




Kernel CKA non-concern: 0.9927480777962914




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2159111499786377




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.2273




Precision: 0.6511, Recall: 0.6129, F1-Score: 0.6179




              precision    recall  f1-score   support

           0     0.5551    0.4696    0.5088      2992
           1     0.6884    0.4933    0.5748      2992
           2     0.6971    0.6106    0.6510      3012
           3     0.3333    0.6514    0.4410      2998
           4     0.7226    0.7730    0.7470      2973
           5     0.8366    0.7626    0.7979      3054
           6     0.7087    0.3839    0.4981      3003
           7     0.6256    0.6252    0.6254      3012
           8     0.5739    0.7264    0.6412      2982
           9     0.7693    0.6328    0.6944      2982

    accuracy                         0.6130     30000
   macro avg     0.6511    0.6129    0.6179     30000
weighted avg     0.6514    0.6130    0.6181     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9973111268760767




CCA coefficients mean non-concern: 0.9963658073274687




Linear CKA concern: 0.997782562152526




Linear CKA non-concern: 0.9987201540720392




Kernel CKA concern: 0.9933073641926498




Kernel CKA non-concern: 0.9947077085204123




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2353711128234863




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.2241




Precision: 0.6484, Recall: 0.6152, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5355    0.4913    0.5125      2992
           1     0.6953    0.4973    0.5799      2992
           2     0.6959    0.6086    0.6493      3012
           3     0.3414    0.6364    0.4444      2998
           4     0.7277    0.7723    0.7493      2973
           5     0.8459    0.7567    0.7988      3054
           6     0.6874    0.3946    0.5014      3003
           7     0.6225    0.6275    0.6250      3012
           8     0.5860    0.7150    0.6441      2982
           9     0.7458    0.6522    0.6959      2982

    accuracy                         0.6153     30000
   macro avg     0.6484    0.6152    0.6201     30000
weighted avg     0.6487    0.6153    0.6203     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9970914689945634




CCA coefficients mean non-concern: 0.9966966950885668




Linear CKA concern: 0.9985014342537395




Linear CKA non-concern: 0.9980515824142057




Kernel CKA concern: 0.9952262109675336




Kernel CKA non-concern: 0.992528806127453




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.222069263458252




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.2201




Precision: 0.6485, Recall: 0.6160, F1-Score: 0.6205




              precision    recall  f1-score   support

           0     0.5337    0.4943    0.5133      2992
           1     0.6911    0.5070    0.5849      2992
           2     0.6973    0.6112    0.6515      3012
           3     0.3439    0.6358    0.4464      2998
           4     0.7118    0.7867    0.7474      2973
           5     0.8446    0.7616    0.8010      3054
           6     0.6900    0.3906    0.4988      3003
           7     0.6269    0.6258    0.6263      3012
           8     0.5887    0.7089    0.6432      2982
           9     0.7566    0.6378    0.6921      2982

    accuracy                         0.6161     30000
   macro avg     0.6485    0.6160    0.6205     30000
weighted avg     0.6488    0.6161    0.6207     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9974779164029137




CCA coefficients mean non-concern: 0.9976138031143067




Linear CKA concern: 0.9975449527844799




Linear CKA non-concern: 0.9990391585812153




Kernel CKA concern: 0.9932763757682721




Kernel CKA non-concern: 0.99659025989804




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2023377418518066




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.2255




Precision: 0.6518, Recall: 0.6151, F1-Score: 0.6195




              precision    recall  f1-score   support

           0     0.5233    0.4993    0.5110      2992
           1     0.7034    0.4860    0.5748      2992
           2     0.7157    0.5959    0.6504      3012
           3     0.3412    0.6431    0.4459      2998
           4     0.7108    0.7871    0.7470      2973
           5     0.8396    0.7626    0.7992      3054
           6     0.7184    0.3823    0.4990      3003
           7     0.6288    0.6298    0.6293      3012
           8     0.5823    0.7223    0.6448      2982
           9     0.7543    0.6425    0.6940      2982

    accuracy                         0.6152     30000
   macro avg     0.6518    0.6151    0.6195     30000
weighted avg     0.6521    0.6152    0.6197     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9945389268703316




CCA coefficients mean non-concern: 0.9970881706918651




Linear CKA concern: 0.9951074318084094




Linear CKA non-concern: 0.9986391726677676




Kernel CKA concern: 0.9881423214013584




Kernel CKA non-concern: 0.99510544710741




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.223405122756958




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.2244




Precision: 0.6510, Recall: 0.6153, F1-Score: 0.6196




              precision    recall  f1-score   support

           0     0.5428    0.4833    0.5113      2992
           1     0.6958    0.4970    0.5798      2992
           2     0.6939    0.6202    0.6550      3012
           3     0.3404    0.6391    0.4442      2998
           4     0.7193    0.7773    0.7472      2973
           5     0.8314    0.7672    0.7980      3054
           6     0.7183    0.3796    0.4967      3003
           7     0.6235    0.6295    0.6265      3012
           8     0.5755    0.7257    0.6419      2982
           9     0.7691    0.6345    0.6953      2982

    accuracy                         0.6154     30000
   macro avg     0.6510    0.6153    0.6196     30000
weighted avg     0.6513    0.6154    0.6198     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9947153636391649




CCA coefficients mean non-concern: 0.9973937054024664




Linear CKA concern: 0.9755904985613275




Linear CKA non-concern: 0.9987027201471851




Kernel CKA concern: 0.9677024616725272




Kernel CKA non-concern: 0.9951368680098512




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2203612327575684




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.2139




Precision: 0.6453, Recall: 0.6186, F1-Score: 0.6222




              precision    recall  f1-score   support

           0     0.5188    0.5064    0.5125      2992
           1     0.6957    0.5074    0.5868      2992
           2     0.6974    0.6129    0.6524      3012
           3     0.3604    0.6187    0.4555      2998
           4     0.7198    0.7783    0.7479      2973
           5     0.8370    0.7633    0.7984      3054
           6     0.6703    0.4056    0.5054      3003
           7     0.6153    0.6368    0.6259      3012
           8     0.5891    0.7119    0.6447      2982
           9     0.7490    0.6445    0.6929      2982

    accuracy                         0.6187     30000
   macro avg     0.6453    0.6186    0.6222     30000
weighted avg     0.6456    0.6187    0.6224     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9978004978850122




CCA coefficients mean non-concern: 0.9969859013058504




Linear CKA concern: 0.9991123573460666




Linear CKA non-concern: 0.9979953317203425




Kernel CKA concern: 0.9967781337671463




Kernel CKA non-concern: 0.9931413572870417




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1850666999816895




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.2177




Precision: 0.6476, Recall: 0.6170, F1-Score: 0.6213




              precision    recall  f1-score   support

           0     0.5363    0.4860    0.5099      2992
           1     0.6851    0.5127    0.5865      2992
           2     0.7092    0.6033    0.6520      3012
           3     0.3482    0.6284    0.4481      2998
           4     0.7211    0.7783    0.7486      2973
           5     0.8366    0.7629    0.7981      3054
           6     0.6857    0.3989    0.5044      3003
           7     0.6136    0.6401    0.6266      3012
           8     0.5853    0.7193    0.6454      2982
           9     0.7549    0.6405    0.6930      2982

    accuracy                         0.6171     30000
   macro avg     0.6476    0.6170    0.6213     30000
weighted avg     0.6479    0.6171    0.6215     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9974667030938066




CCA coefficients mean non-concern: 0.997153233455915




Linear CKA concern: 0.9981339435186088




Linear CKA non-concern: 0.9978907409868522




Kernel CKA concern: 0.9951016868768764




Kernel CKA non-concern: 0.9928081423347606




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1959803104400635




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.2232




Precision: 0.6482, Recall: 0.6149, F1-Score: 0.6196




              precision    recall  f1-score   support

           0     0.5199    0.4983    0.5089      2992
           1     0.7013    0.4960    0.5810      2992
           2     0.7037    0.6039    0.6500      3012
           3     0.3432    0.6321    0.4448      2998
           4     0.7208    0.7773    0.7480      2973
           5     0.8442    0.7577    0.7986      3054
           6     0.6885    0.3923    0.4998      3003
           7     0.6261    0.6248    0.6255      3012
           8     0.5849    0.7193    0.6452      2982
           9     0.7490    0.6476    0.6946      2982

    accuracy                         0.6150     30000
   macro avg     0.6482    0.6149    0.6196     30000
weighted avg     0.6485    0.6150    0.6198     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.99705353216064




CCA coefficients mean non-concern: 0.9967698175043291




Linear CKA concern: 0.9988667016888721




Linear CKA non-concern: 0.9979260144189008




Kernel CKA concern: 0.9962643766454913




Kernel CKA non-concern: 0.9926138582750977




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2173144817352295




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.2254




Precision: 0.6526, Recall: 0.6135, F1-Score: 0.6186




              precision    recall  f1-score   support

           0     0.5516    0.4716    0.5085      2992
           1     0.7062    0.4860    0.5757      2992
           2     0.7021    0.6135    0.6549      3012
           3     0.3309    0.6484    0.4382      2998
           4     0.7171    0.7827    0.7485      2973
           5     0.8471    0.7580    0.8001      3054
           6     0.7133    0.3820    0.4975      3003
           7     0.6222    0.6298    0.6260      3012
           8     0.5798    0.7210    0.6428      2982
           9     0.7553    0.6419    0.6940      2982

    accuracy                         0.6136     30000
   macro avg     0.6526    0.6135    0.6186     30000
weighted avg     0.6529    0.6136    0.6188     30000





0.46859969232566745




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.997656317351566




CCA coefficients mean non-concern: 0.9968062604704488




Linear CKA concern: 0.9982311531960796




Linear CKA non-concern: 0.9985765336246364




Kernel CKA concern: 0.9955379805290723




Kernel CKA non-concern: 0.9948027664501596




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.225411891937256




In [12]:
# Convert each per-concern evaluation report to a DataFrame, combine them with
# append_nth_row, and write the table to a timestamped CSV under results/.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# DataFrame.to_csv does not create missing parent directories.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)