In [1]:
import os
import sys

# Extend the import search path four directories up (presumably the project
# root that holds the `src` package imported by the next cell — confirm).
sys.path += ["../../../../"]
# Export TOKENIZERS_PARALLELISM="false" into this process's environment
# (presumably read by the HuggingFace tokenizers library — confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration: every tunable value for this run lives in this cell.

# Identifier handed to Config(name, device) when the model is loaded.
name = "bert-4-128-yahoo"
# Prefer the first CUDA device, but fall back to the CPU when CUDA is not
# available so the notebook still runs on GPU-less machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): not referenced anywhere else in the visible notebook.
checkpoint = None
# Forwarded to load_data(...) as batch_size / num_workers.
batch_size = 16
num_workers = 4
# Forwarded to SamplingDataset(...) and get_similarity(...) in the pruning loop.
num_samples = 16
# Passed to prune_concern_identification(...) as sparsity_ratio.
ratio = 0.5
# NOTE(review): not referenced elsewhere in the visible notebook; the loop
# calls config.init_seed() instead — confirm whether this value is used.
seed = 44
# Layer-name filters for pruning; these are the same values the pruning cell
# passes to prune_concern_identification(...).
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Record when the run began and echo it in a human-readable form.
script_start_time = datetime.now()
start_stamp = script_start_time.strftime('%Y-%m-%d %H:%M:%S')
print(f"Script started at: {start_stamp}")

Script started at: 2024-10-15 19:56:31


In [5]:
# Build the experiment configuration from the model identifier and device.
config = Config(name, device)
# Read the number of target classes out of the configuration mapping.
model_settings = config.config
num_labels = model_settings["num_labels"]
# Load the model described by the configuration (logs its settings, see output).
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Options forwarded to load_data; do_cache=True is what the original call
# passes (the output below shows the splits being read from cached .pkl files).
loader_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "do_cache": True,
}
# load_data returns the three splits in train / valid / test order.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config, **loader_options
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Optional baseline (disabled): uncomment to evaluate the unpruned model on the test set.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
# Per-concern evaluation results, collected in label order for the export cell.
result_list = []

for concern in range(config.num_labels):
    # Called once per concern; presumably re-seeds the RNGs so every pruning
    # run is comparable — confirm in Config.init_seed.
    config.init_seed()
    # NOTE(review): the boolean flag looks like "sample the concern class"
    # (True) vs. "sample the other classes" (False), and the literal 4 is an
    # undocumented positional argument — confirm both against SamplingDataset.
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): `all_samples` is never read in this cell, and 200 is not
    # one of the concern indices iterated here (10 labels in this run). The
    # construction is kept because removing it could change behaviour (e.g.
    # RNG consumption) — confirm before deleting.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Prune a deep copy so `model` stays untouched as the reference for the
    # similarity and perplexity comparisons below.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        # Use the config-cell settings rather than repeating the literals here,
        # so editing the config cell actually changes the pruning run.
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # Spelling kept byte-for-byte: this is the value the pruning API is
        # called with in the recorded run.
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    # Compare representations of the original and pruned models (prints
    # CCA / CKA scores for concern and non-concern samples, see output).
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    # NOTE(review): the original model's perplexity is recomputed on every
    # iteration and the recorded output is identical each time; it could be
    # computed once outside the loop if the per-concern log layout may change.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2191




Precision: 0.6488, Recall: 0.6162, F1-Score: 0.6208




              precision    recall  f1-score   support

           0     0.5356    0.4833    0.5081      2992
           1     0.7007    0.4860    0.5739      2992
           2     0.7020    0.6149    0.6556      3012
           3     0.3446    0.6408    0.4482      2998
           4     0.7243    0.7767    0.7496      2973
           5     0.8401    0.7623    0.7993      3054
           6     0.6759    0.4043    0.5059      3003
           7     0.6228    0.6358    0.6292      3012
           8     0.5847    0.7176    0.6444      2982
           9     0.7574    0.6408    0.6943      2982

    accuracy                         0.6163     30000
   macro avg     0.6488    0.6162    0.6208     30000
weighted avg     0.6491    0.6163    0.6211     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9990036613802763




CCA coefficients mean non-concern: 0.9988637246816681




Linear CKA concern: 0.9997045308947307




Linear CKA non-concern: 0.9996605573664932




Kernel CKA concern: 0.9989691119725961




Kernel CKA non-concern: 0.998594321531




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2010529041290283




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2196




Precision: 0.6494, Recall: 0.6156, F1-Score: 0.6203




              precision    recall  f1-score   support

           0     0.5424    0.4809    0.5098      2992
           1     0.6979    0.4779    0.5673      2992
           2     0.7028    0.6155    0.6563      3012
           3     0.3422    0.6448    0.4471      2998
           4     0.7239    0.7787    0.7503      2973
           5     0.8405    0.7610    0.7988      3054
           6     0.6799    0.4046    0.5073      3003
           7     0.6224    0.6375    0.6298      3012
           8     0.5819    0.7173    0.6425      2982
           9     0.7604    0.6375    0.6935      2982

    accuracy                         0.6157     30000
   macro avg     0.6494    0.6156    0.6203     30000
weighted avg     0.6497    0.6157    0.6205     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9991143164510867




CCA coefficients mean non-concern: 0.998991277240882




Linear CKA concern: 0.9997773696302663




Linear CKA non-concern: 0.9997014606079154




Kernel CKA concern: 0.9994369445469891




Kernel CKA non-concern: 0.9987581847758688




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2033021450042725




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2190




Precision: 0.6495, Recall: 0.6159, F1-Score: 0.6207




              precision    recall  f1-score   support

           0     0.5378    0.4846    0.5098      2992
           1     0.7025    0.4806    0.5707      2992
           2     0.6993    0.6162    0.6551      3012
           3     0.3422    0.6428    0.4466      2998
           4     0.7235    0.7797    0.7505      2973
           5     0.8433    0.7613    0.8002      3054
           6     0.6776    0.4046    0.5067      3003
           7     0.6267    0.6315    0.6291      3012
           8     0.5828    0.7166    0.6428      2982
           9     0.7595    0.6408    0.6952      2982

    accuracy                         0.6160     30000
   macro avg     0.6495    0.6159    0.6207     30000
weighted avg     0.6498    0.6160    0.6209     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9989137684243644




CCA coefficients mean non-concern: 0.9988428818236763




Linear CKA concern: 0.9997284983433627




Linear CKA non-concern: 0.9996511912150075




Kernel CKA concern: 0.9991722325897414




Kernel CKA non-concern: 0.9986228191126841




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.199392557144165




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2197




Precision: 0.6487, Recall: 0.6159, F1-Score: 0.6207




              precision    recall  f1-score   support

           0     0.5383    0.4833    0.5093      2992
           1     0.6957    0.4866    0.5727      2992
           2     0.7033    0.6145    0.6559      3012
           3     0.3430    0.6414    0.4470      2998
           4     0.7231    0.7773    0.7492      2973
           5     0.8411    0.7606    0.7988      3054
           6     0.6763    0.4049    0.5066      3003
           7     0.6226    0.6365    0.6295      3012
           8     0.5871    0.7140    0.6444      2982
           9     0.7565    0.6398    0.6933      2982

    accuracy                         0.6160     30000
   macro avg     0.6487    0.6159    0.6207     30000
weighted avg     0.6490    0.6160    0.6209     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9988773807536343




CCA coefficients mean non-concern: 0.9988791348260424




Linear CKA concern: 0.9990895225201966




Linear CKA non-concern: 0.9996406896489796




Kernel CKA concern: 0.9980162419054477




Kernel CKA non-concern: 0.9984662339061613




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.204606533050537




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2177




Precision: 0.6486, Recall: 0.6162, F1-Score: 0.6209




              precision    recall  f1-score   support

           0     0.5325    0.4926    0.5118      2992
           1     0.6955    0.4870    0.5728      2992
           2     0.6999    0.6179    0.6563      3012
           3     0.3450    0.6361    0.4474      2998
           4     0.7204    0.7790    0.7485      2973
           5     0.8427    0.7597    0.7990      3054
           6     0.6799    0.4046    0.5073      3003
           7     0.6257    0.6305    0.6281      3012
           8     0.5846    0.7160    0.6437      2982
           9     0.7597    0.6392    0.6942      2982

    accuracy                         0.6163     30000
   macro avg     0.6486    0.6162    0.6209     30000
weighted avg     0.6489    0.6163    0.6211     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.998591812020558




CCA coefficients mean non-concern: 0.9989175743159063




Linear CKA concern: 0.9995040122338846




Linear CKA non-concern: 0.9996459082750799




Kernel CKA concern: 0.9988857411752584




Kernel CKA non-concern: 0.9986411349722772




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.194809913635254




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2173




Precision: 0.6492, Recall: 0.6166, F1-Score: 0.6213




              precision    recall  f1-score   support

           0     0.5366    0.4903    0.5124      2992
           1     0.6957    0.4860    0.5722      2992
           2     0.7011    0.6175    0.6567      3012
           3     0.3449    0.6391    0.4480      2998
           4     0.7205    0.7777    0.7480      2973
           5     0.8405    0.7626    0.7997      3054
           6     0.6784    0.4053    0.5074      3003
           7     0.6241    0.6361    0.6301      3012
           8     0.5860    0.7163    0.6446      2982
           9     0.7640    0.6351    0.6936      2982

    accuracy                         0.6167     30000
   macro avg     0.6492    0.6166    0.6213     30000
weighted avg     0.6495    0.6167    0.6215     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9984527178665146




CCA coefficients mean non-concern: 0.9988739146419824




Linear CKA concern: 0.9966381364642334




Linear CKA non-concern: 0.999787005160984




Kernel CKA concern: 0.9958731091005908




Kernel CKA non-concern: 0.9992351789680812




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1937785148620605




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2184




Precision: 0.6486, Recall: 0.6166, F1-Score: 0.6212




              precision    recall  f1-score   support

           0     0.5329    0.4873    0.5091      2992
           1     0.6991    0.4870    0.5741      2992
           2     0.7013    0.6165    0.6562      3012
           3     0.3465    0.6378    0.4490      2998
           4     0.7250    0.7767    0.7499      2973
           5     0.8451    0.7577    0.7990      3054
           6     0.6742    0.4059    0.5068      3003
           7     0.6233    0.6361    0.6296      3012
           8     0.5834    0.7180    0.6437      2982
           9     0.7550    0.6429    0.6944      2982

    accuracy                         0.6167     30000
   macro avg     0.6486    0.6166    0.6212     30000
weighted avg     0.6489    0.6167    0.6214     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9990937190799735




CCA coefficients mean non-concern: 0.9987826220547066




Linear CKA concern: 0.9998214291372598




Linear CKA non-concern: 0.9995513775349342




Kernel CKA concern: 0.9993140567785072




Kernel CKA non-concern: 0.9981522789377725




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2004776000976562




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2194




Precision: 0.6483, Recall: 0.6162, F1-Score: 0.6209




              precision    recall  f1-score   support

           0     0.5351    0.4840    0.5082      2992
           1     0.6928    0.4936    0.5765      2992
           2     0.7037    0.6125    0.6550      3012
           3     0.3448    0.6374    0.4475      2998
           4     0.7229    0.7750    0.7481      2973
           5     0.8397    0.7616    0.7988      3054
           6     0.6748    0.4056    0.5067      3003
           7     0.6217    0.6361    0.6288      3012
           8     0.5843    0.7170    0.6439      2982
           9     0.7635    0.6388    0.6956      2982

    accuracy                         0.6163     30000
   macro avg     0.6483    0.6162    0.6209     30000
weighted avg     0.6486    0.6163    0.6211     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9987943179769515




CCA coefficients mean non-concern: 0.998784836881527




Linear CKA concern: 0.9988804473438817




Linear CKA non-concern: 0.9995889221621121




Kernel CKA concern: 0.9977137042645204




Kernel CKA non-concern: 0.9984582215741695




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2018871307373047




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2198




Precision: 0.6490, Recall: 0.6156, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5361    0.4820    0.5076      2992
           1     0.7016    0.4850    0.5735      2992
           2     0.7016    0.6142    0.6550      3012
           3     0.3426    0.6404    0.4464      2998
           4     0.7210    0.7797    0.7492      2973
           5     0.8437    0.7583    0.7988      3054
           6     0.6783    0.4029    0.5055      3003
           7     0.6267    0.6321    0.6294      3012
           8     0.5820    0.7197    0.6436      2982
           9     0.7566    0.6422    0.6947      2982

    accuracy                         0.6157     30000
   macro avg     0.6490    0.6156    0.6204     30000
weighted avg     0.6493    0.6157    0.6206     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9987633835972018




CCA coefficients mean non-concern: 0.9986229415321434




Linear CKA concern: 0.9996570246546301




Linear CKA non-concern: 0.9995108039211884




Kernel CKA concern: 0.9989370629981923




Kernel CKA non-concern: 0.9979256222630161




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.202697992324829




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2185




Precision: 0.6487, Recall: 0.6161, F1-Score: 0.6208




              precision    recall  f1-score   support

           0     0.5376    0.4823    0.5085      2992
           1     0.6947    0.4890    0.5740      2992
           2     0.7017    0.6145    0.6552      3012
           3     0.3448    0.6401    0.4482      2998
           4     0.7246    0.7763    0.7496      2973
           5     0.8408    0.7606    0.7987      3054
           6     0.6772    0.4053    0.5071      3003
           7     0.6221    0.6361    0.6290      3012
           8     0.5827    0.7186    0.6435      2982
           9     0.7607    0.6385    0.6943      2982

    accuracy                         0.6162     30000
   macro avg     0.6487    0.6161    0.6208     30000
weighted avg     0.6490    0.6162    0.6210     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9990293508789665




CCA coefficients mean non-concern: 0.9989289158966695




Linear CKA concern: 0.9993127257146294




Linear CKA non-concern: 0.9995828495986342




Kernel CKA concern: 0.998433828896102




Kernel CKA non-concern: 0.9984934079374677




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1995232105255127




In [9]:
# Convert each per-concern evaluation result with report_to_df, combine the
# frames with append_nth_row, and write the result to a timestamped CSV file.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# DataFrame.to_csv does not create missing directories; make sure the output
# folder exists so a long run is not lost to an OSError at the very end.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)