In [1]:
import os
import sys

# Make the project root importable so the `src.*` imports in the next cell resolve.
# NOTE(review): this path is relative to the current working directory, so the
# notebook must be started from its own folder.
project_root = "../../../../../"
if project_root not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.append(project_root)
# Presumably disables HuggingFace tokenizers' internal parallelism to avoid its
# fork warning once DataLoader workers (num_workers > 0) fork the process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# --- Experiment configuration ---
name = "bert-6-128-yahoo"  # model key passed to Config
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced by any visible cell
batch_size = 16  # DataLoader batch size for load_data
num_workers = 4  # DataLoader worker processes for load_data
num_samples = 16  # samples per SamplingDataset and for get_similarity
ratio = 0.5  # sparsity_ratio passed to prune_concern_identification
# NOTE(review): `seed` is not passed anywhere visible; config.init_seed() is called
# without arguments, so confirm where the actual seed comes from.
seed = 44
# Layer-name filters for pruning (same values prune_concern_identification receives).
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Record and display the wall-clock time at which this run started.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-15 20:32:56


In [5]:
# Build the run configuration for `name` on `device`, read the class count from
# its raw config mapping, then load the model described by that configuration.
config = Config(name, device)
raw_model_config = config.config
num_labels = raw_model_config["num_labels"]  # 10 for this run, per the printed config
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
# Build the train/valid/test DataLoaders for the configured dataset.
# With do_cache=True the captured output shows each split read from a cached .pkl
# file. NOTE(review): unpickling runs arbitrary code, so only use cache files you created.
loader_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "do_cache": True,
}
train_dataloader, valid_dataloader, test_dataloader = load_data(config, **loader_options)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Optional baseline: evaluate the unpruned model so the per-concern pruned results
# below have a reference point. Off by default because it is a full pass over the
# test set; set to True to run it.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    original_result = evaluate_model(model, config, test_dataloader)

In [8]:
result_list = []

for concern in range(config.num_labels):
    # Re-seed at the start of every concern so each iteration presumably starts
    # sampling from the same RNG state (seed source is inside Config; not visible here).
    config.init_seed()
    # Samples for `concern` vs. samples from the other classes.
    # NOTE(review): the meaning of the positional True/False and 4 arguments is
    # inferred from the variable names; confirm against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): `all_samples` (concern=200) is never used below. It is kept so any
    # RNG state consumed by SamplingDataset matches previous runs; remove it once
    # that is confirmed unnecessary.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Prune a fresh copy so every concern starts from the unpruned model.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        # Take the layer filters from the configuration cell rather than duplicating
        # them here, so editing the config actually changes what gets pruned.
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # NOTE(review): "unstructed" may be a typo for "unstructured"; left unchanged
        # because the values prune_concern_identification accepts are not visible here.
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader, verbose=True)
    result_list.append(result)
    get_sparsity(module)

    # Compare representations of the original and pruned models (CCA / CKA in the output).
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    # NOTE(review): the original model's perplexity is loop-invariant (identical in every
    # iteration of the logged output). It stays here so any RNG use inside
    # get_perplexity is unchanged, but it could be computed once before the loop.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2153




Precision: 0.6465, Recall: 0.6156, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5586    0.4703    0.5106      2992
           1     0.6991    0.5257    0.6002      2992
           2     0.6780    0.6418    0.6594      3012
           3     0.3415    0.6404    0.4454      2998
           4     0.7246    0.7595    0.7417      2973
           5     0.8388    0.7701    0.8030      3054
           6     0.6801    0.3993    0.5031      3003
           7     0.6254    0.6069    0.6160      3012
           8     0.5932    0.6935    0.6395      2982
           9     0.7261    0.6489    0.6853      2982

    accuracy                         0.6158     30000
   macro avg     0.6465    0.6156    0.6204     30000
weighted avg     0.6468    0.6158    0.6206     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9989692393249631




CCA coefficients mean non-concern: 0.9985807745176424




Linear CKA concern: 0.9995012200176271




Linear CKA non-concern: 0.9992384127500107




Kernel CKA concern: 0.9982317908639831




Kernel CKA non-concern: 0.9968030527036583




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1765120029449463




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2160




Precision: 0.6459, Recall: 0.6146, F1-Score: 0.6193




              precision    recall  f1-score   support

           0     0.5592    0.4686    0.5099      2992
           1     0.6989    0.5197    0.5961      2992
           2     0.6761    0.6418    0.6585      3012
           3     0.3402    0.6414    0.4446      2998
           4     0.7217    0.7622    0.7414      2973
           5     0.8324    0.7711    0.8006      3054
           6     0.6842    0.3983    0.5035      3003
           7     0.6268    0.6049    0.6156      3012
           8     0.5934    0.6911    0.6386      2982
           9     0.7258    0.6472    0.6843      2982

    accuracy                         0.6148     30000
   macro avg     0.6459    0.6146    0.6193     30000
weighted avg     0.6462    0.6148    0.6195     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9991320673564436




CCA coefficients mean non-concern: 0.9986808212439668




Linear CKA concern: 0.9994068118631825




Linear CKA non-concern: 0.9994813872503575




Kernel CKA concern: 0.9984058677947778




Kernel CKA non-concern: 0.9975850338962545




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1780123710632324




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2147




Precision: 0.6465, Recall: 0.6162, F1-Score: 0.6208




              precision    recall  f1-score   support

           0     0.5588    0.4719    0.5117      2992
           1     0.6976    0.5274    0.6007      2992
           2     0.6793    0.6428    0.6605      3012
           3     0.3433    0.6391    0.4467      2998
           4     0.7238    0.7598    0.7414      2973
           5     0.8344    0.7721    0.8020      3054
           6     0.6861    0.3996    0.5051      3003
           7     0.6263    0.6049    0.6154      3012
           8     0.5929    0.6922    0.6387      2982
           9     0.7220    0.6522    0.6853      2982

    accuracy                         0.6163     30000
   macro avg     0.6465    0.6162    0.6208     30000
weighted avg     0.6467    0.6163    0.6210     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.998955243142456




CCA coefficients mean non-concern: 0.9986733104455784




Linear CKA concern: 0.9993035766382111




Linear CKA non-concern: 0.9994008089318343




Kernel CKA concern: 0.9986747199255936




Kernel CKA non-concern: 0.9977185026563358




original model's perplexity




3.187649726867676




pruned model's perplexity




3.173621416091919




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2153




Precision: 0.6464, Recall: 0.6155, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5567    0.4679    0.5084      2992
           1     0.7006    0.5217    0.5981      2992
           2     0.6772    0.6414    0.6588      3012
           3     0.3422    0.6421    0.4464      2998
           4     0.7222    0.7642    0.7426      2973
           5     0.8356    0.7708    0.8019      3054
           6     0.6834    0.3996    0.5043      3003
           7     0.6249    0.6079    0.6163      3012
           8     0.5935    0.6932    0.6394      2982
           9     0.7279    0.6459    0.6844      2982

    accuracy                         0.6156     30000
   macro avg     0.6464    0.6155    0.6201     30000
weighted avg     0.6467    0.6156    0.6203     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9989144527589705




CCA coefficients mean non-concern: 0.9985848781010317




Linear CKA concern: 0.9991676459616222




Linear CKA non-concern: 0.99939900391456




Kernel CKA concern: 0.998747155208261




Kernel CKA non-concern: 0.9975427663703641




original model's perplexity




3.187649726867676




pruned model's perplexity




3.17681884765625




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2139




Precision: 0.6455, Recall: 0.6161, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5566    0.4719    0.5108      2992
           1     0.7019    0.5217    0.5985      2992
           2     0.6787    0.6418    0.6597      3012
           3     0.3452    0.6368    0.4477      2998
           4     0.7179    0.7686    0.7424      2973
           5     0.8322    0.7731    0.8016      3054
           6     0.6815    0.3989    0.5033      3003
           7     0.6263    0.6049    0.6154      3012
           8     0.5915    0.6925    0.6380      2982
           9     0.7231    0.6506    0.6849      2982

    accuracy                         0.6162     30000
   macro avg     0.6455    0.6161    0.6202     30000
weighted avg     0.6458    0.6162    0.6204     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9983424896618026




CCA coefficients mean non-concern: 0.9984919208020866




Linear CKA concern: 0.9982429097723022




Linear CKA non-concern: 0.9992110304218677




Kernel CKA concern: 0.9973946113204931




Kernel CKA non-concern: 0.9963179336107844




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1697733402252197




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2156




Precision: 0.6467, Recall: 0.6157, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5525    0.4733    0.5098      2992
           1     0.7004    0.5244    0.5998      2992
           2     0.6788    0.6421    0.6600      3012
           3     0.3425    0.6421    0.4467      2998
           4     0.7257    0.7598    0.7424      2973
           5     0.8360    0.7711    0.8022      3054
           6     0.6860    0.3979    0.5037      3003
           7     0.6249    0.6062    0.6154      3012
           8     0.5941    0.6911    0.6390      2982
           9     0.7261    0.6489    0.6853      2982

    accuracy                         0.6158     30000
   macro avg     0.6467    0.6157    0.6204     30000
weighted avg     0.6470    0.6158    0.6206     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9986502988305735




CCA coefficients mean non-concern: 0.9987408070186862




Linear CKA concern: 0.9961021710715847




Linear CKA non-concern: 0.9994956428260704




Kernel CKA concern: 0.9978547237065932




Kernel CKA non-concern: 0.997791759394249




original model's perplexity




3.187649726867676




pruned model's perplexity




3.178312301635742




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2154




Precision: 0.6462, Recall: 0.6156, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5568    0.4686    0.5089      2992
           1     0.7016    0.5211    0.5980      2992
           2     0.6789    0.6408    0.6593      3012
           3     0.3428    0.6418    0.4469      2998
           4     0.7230    0.7622    0.7421      2973
           5     0.8355    0.7701    0.8015      3054
           6     0.6812    0.4006    0.5045      3003
           7     0.6237    0.6092    0.6164      3012
           8     0.5937    0.6918    0.6390      2982
           9     0.7250    0.6499    0.6854      2982

    accuracy                         0.6157     30000
   macro avg     0.6462    0.6156    0.6202     30000
weighted avg     0.6465    0.6157    0.6204     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9989622407849577




CCA coefficients mean non-concern: 0.9984881384770631




Linear CKA concern: 0.9996389567103671




Linear CKA non-concern: 0.9992323732276935




Kernel CKA concern: 0.997949904434969




Kernel CKA non-concern: 0.9967068426746596




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1756837368011475




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2151




Precision: 0.6461, Recall: 0.6157, F1-Score: 0.6203




              precision    recall  f1-score   support

           0     0.5548    0.4706    0.5092      2992
           1     0.6998    0.5267    0.6011      2992
           2     0.6762    0.6428    0.6591      3012
           3     0.3436    0.6388    0.4469      2998
           4     0.7205    0.7639    0.7416      2973
           5     0.8368    0.7692    0.8016      3054
           6     0.6848    0.3979    0.5034      3003
           7     0.6208    0.6099    0.6153      3012
           8     0.5948    0.6901    0.6389      2982
           9     0.7284    0.6476    0.6856      2982

    accuracy                         0.6159     30000
   macro avg     0.6461    0.6157    0.6203     30000
weighted avg     0.6464    0.6159    0.6205     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.998580490983921




CCA coefficients mean non-concern: 0.9986071926411558




Linear CKA concern: 0.9974423998865201




Linear CKA non-concern: 0.9991672833540838




Kernel CKA concern: 0.9968364186682545




Kernel CKA non-concern: 0.9970553404865238




original model's perplexity




3.187649726867676




pruned model's perplexity




3.174818277359009




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2142




Precision: 0.6459, Recall: 0.6157, F1-Score: 0.6203




              precision    recall  f1-score   support

           0     0.5564    0.4716    0.5105      2992
           1     0.6972    0.5287    0.6014      2992
           2     0.6801    0.6408    0.6598      3012
           3     0.3433    0.6388    0.4466      2998
           4     0.7219    0.7622    0.7415      2973
           5     0.8374    0.7688    0.8016      3054
           6     0.6830    0.3989    0.5037      3003
           7     0.6262    0.6046    0.6152      3012
           8     0.5929    0.6915    0.6384      2982
           9     0.7203    0.6512    0.6840      2982

    accuracy                         0.6158     30000
   macro avg     0.6459    0.6157    0.6203     30000
weighted avg     0.6462    0.6158    0.6205     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9989382100469775




CCA coefficients mean non-concern: 0.9983074907013637




Linear CKA concern: 0.9997437025652385




Linear CKA non-concern: 0.9988125181375566




Kernel CKA concern: 0.9991931729287611




Kernel CKA non-concern: 0.9958163923899559




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1733696460723877




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2146




Precision: 0.6462, Recall: 0.6161, F1-Score: 0.6206




              precision    recall  f1-score   support

           0     0.5544    0.4719    0.5098      2992
           1     0.7028    0.5241    0.6004      2992
           2     0.6786    0.6414    0.6595      3012
           3     0.3444    0.6384    0.4474      2998
           4     0.7228    0.7649    0.7433      2973
           5     0.8326    0.7718    0.8010      3054
           6     0.6823    0.4013    0.5053      3003
           7     0.6258    0.6052    0.6154      3012
           8     0.5913    0.6942    0.6386      2982
           9     0.7267    0.6482    0.6852      2982

    accuracy                         0.6163     30000
   macro avg     0.6462    0.6161    0.6206     30000
weighted avg     0.6465    0.6163    0.6208     30000





0.32649319722807596




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9990638100679095




CCA coefficients mean non-concern: 0.9986464009710293




Linear CKA concern: 0.9989022319008312




Linear CKA non-concern: 0.9992340027687494




Kernel CKA concern: 0.9971302893502217




Kernel CKA non-concern: 0.9972468585883404




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1731655597686768




In [9]:
# Convert each per-concern evaluation result into a DataFrame and combine them via
# the project helper append_nth_row (presumably taking one row from each report).
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)
run_end_time = datetime.now()
csv_name = run_end_time.strftime("%Y-%m-%d_%H-%M-%S")
# DataFrame.to_csv does not create missing parent directories.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
print(f"Total runtime: {run_end_time - script_start_time}")
new_df