In [1]:
import os
import sys

sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
name = "bert-6-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
ratio = 0.3
seed = 44
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-10-15 23:35:46


In [5]:
config = Config(name, device)
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo",
    "4_128-yahoo_top1.pkl",
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
generated["embeddings"] = generated.pop("example_list")
generated["labels"] = generated.pop("example_label")
generated["attention_mask"] = generated.pop("attn_list")

In [9]:
generated_data = CustomEmbeddingDataset(generated)
generated_dataloder = DataLoader(
    generated_data,
    batch_size=4,
)

In [10]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [11]:
result_list = []

for concern in range(config.num_labels):
    config.init_seed()
    positive_samples = SamplingDataset(
        generated_dataloder,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        generated_dataloder,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    all_samples = SamplingDataset(
        generated_dataloder,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.2183




Precision: 0.6475, Recall: 0.6150, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5588    0.4686    0.5097      2992
           1     0.7034    0.5167    0.5958      2992
           2     0.6792    0.6431    0.6606      3012
           3     0.3403    0.6444    0.4454      2998
           4     0.7139    0.7662    0.7391      2973
           5     0.8533    0.7659    0.8072      3054
           6     0.6742    0.4059    0.5068      3003
           7     0.6292    0.6029    0.6158      3012
           8     0.5866    0.6982    0.6376      2982
           9     0.7357    0.6385    0.6837      2982

    accuracy                         0.6152     30000
   macro avg     0.6475    0.6150    0.6202     30000
weighted avg     0.6478    0.6152    0.6204     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995006879253531




CCA coefficients mean non-concern: 0.9994068230500678




Linear CKA concern: 0.9995170427336303




Linear CKA non-concern: 0.9993246213521086




Kernel CKA concern: 0.9983239873239539




Kernel CKA non-concern: 0.9979682494274768




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1821084022521973




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.2190




Precision: 0.6464, Recall: 0.6155, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5600    0.4682    0.5100      2992
           1     0.7003    0.5217    0.5980      2992
           2     0.6798    0.6428    0.6608      3012
           3     0.3425    0.6418    0.4467      2998
           4     0.7135    0.7666    0.7391      2973
           5     0.8424    0.7682    0.8036      3054
           6     0.6780    0.4039    0.5063      3003
           7     0.6247    0.6079    0.6162      3012
           8     0.5924    0.6915    0.6381      2982
           9     0.7302    0.6425    0.6836      2982

    accuracy                         0.6156     30000
   macro avg     0.6464    0.6155    0.6202     30000
weighted avg     0.6467    0.6156    0.6204     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9996119997449254




CCA coefficients mean non-concern: 0.9994931571298752




Linear CKA concern: 0.9995115055943481




Linear CKA non-concern: 0.999682711700787




Kernel CKA concern: 0.9986916232132613




Kernel CKA non-concern: 0.9988944036882221




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1833300590515137




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.2181




Precision: 0.6467, Recall: 0.6152, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5577    0.4686    0.5093      2992
           1     0.7026    0.5157    0.5948      2992
           2     0.6841    0.6421    0.6624      3012
           3     0.3416    0.6424    0.4460      2998
           4     0.7129    0.7659    0.7384      2973
           5     0.8492    0.7672    0.8061      3054
           6     0.6735    0.4053    0.5060      3003
           7     0.6255    0.6066    0.6159      3012
           8     0.5896    0.6972    0.6389      2982
           9     0.7307    0.6415    0.6832      2982

    accuracy                         0.6154     30000
   macro avg     0.6467    0.6152    0.6201     30000
weighted avg     0.6471    0.6154    0.6203     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9991475192599201




CCA coefficients mean non-concern: 0.9994873753544579




Linear CKA concern: 0.9954959483715484




Linear CKA non-concern: 0.9995795097538521




Kernel CKA concern: 0.9907202299268781




Kernel CKA non-concern: 0.9986839161623574




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1810293197631836




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.2170




Precision: 0.6470, Recall: 0.6156, F1-Score: 0.6207




              precision    recall  f1-score   support

           0     0.5519    0.4706    0.5080      2992
           1     0.7013    0.5194    0.5968      2992
           2     0.6855    0.6418    0.6629      3012
           3     0.3422    0.6418    0.4464      2998
           4     0.7104    0.7682    0.7382      2973
           5     0.8540    0.7662    0.8077      3054
           6     0.6712    0.4086    0.5080      3003
           7     0.6308    0.6029    0.6165      3012
           8     0.5896    0.6965    0.6386      2982
           9     0.7332    0.6405    0.6837      2982

    accuracy                         0.6158     30000
   macro avg     0.6470    0.6156    0.6207     30000
weighted avg     0.6473    0.6158    0.6209     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.999371414171741




CCA coefficients mean non-concern: 0.9992261494627815




Linear CKA concern: 0.9988893998627716




Linear CKA non-concern: 0.9988295170361648




Kernel CKA concern: 0.9978283976333677




Kernel CKA non-concern: 0.9965099801166768




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1792306900024414




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.2190




Precision: 0.6464, Recall: 0.6154, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5586    0.4666    0.5085      2992
           1     0.7032    0.5194    0.5975      2992
           2     0.6802    0.6454    0.6624      3012
           3     0.3418    0.6424    0.4462      2998
           4     0.7155    0.7639    0.7389      2973
           5     0.8413    0.7688    0.8034      3054
           6     0.6765    0.4059    0.5074      3003
           7     0.6240    0.6082    0.6160      3012
           8     0.5935    0.6898    0.6380      2982
           9     0.7296    0.6432    0.6837      2982

    accuracy                         0.6155     30000
   macro avg     0.6464    0.6154    0.6202     30000
weighted avg     0.6467    0.6155    0.6204     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9992517498320392




CCA coefficients mean non-concern: 0.999584468234934




Linear CKA concern: 0.9983964130907845




Linear CKA non-concern: 0.9997907344085656




Kernel CKA concern: 0.9980693144563904




Kernel CKA non-concern: 0.9992605633931946




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1831183433532715




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.2177




Precision: 0.6465, Recall: 0.6155, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5591    0.4662    0.5085      2992
           1     0.7022    0.5177    0.5960      2992
           2     0.6784    0.6444    0.6610      3012
           3     0.3425    0.6414    0.4465      2998
           4     0.7203    0.7649    0.7419      2973
           5     0.8462    0.7675    0.8049      3054
           6     0.6703    0.4083    0.5075      3003
           7     0.6286    0.6046    0.6163      3012
           8     0.5864    0.6975    0.6372      2982
           9     0.7308    0.6429    0.6840      2982

    accuracy                         0.6157     30000
   macro avg     0.6465    0.6155    0.6204     30000
weighted avg     0.6468    0.6157    0.6206     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9992185274119567




CCA coefficients mean non-concern: 0.9994619068762868




Linear CKA concern: 0.9951963843184513




Linear CKA non-concern: 0.9994496652471688




Kernel CKA concern: 0.9965683506924207




Kernel CKA non-concern: 0.9976989964102357




original model's perplexity




3.187649726867676




pruned model's perplexity




3.180872917175293




Evaluate the pruned model 6




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2183




Precision: 0.6462, Recall: 0.6151, F1-Score: 0.6199




              precision    recall  f1-score   support

           0     0.5574    0.4693    0.5095      2992
           1     0.7020    0.5174    0.5957      2992
           2     0.6822    0.6408    0.6608      3012
           3     0.3427    0.6418    0.4468      2998
           4     0.7131    0.7666    0.7389      2973
           5     0.8457    0.7678    0.8049      3054
           6     0.6750    0.4053    0.5065      3003
           7     0.6243    0.6069    0.6155      3012
           8     0.5895    0.6938    0.6374      2982
           9     0.7300    0.6419    0.6831      2982

    accuracy                         0.6153     30000
   macro avg     0.6462    0.6151    0.6199     30000
weighted avg     0.6465    0.6153    0.6201     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995669976038045




CCA coefficients mean non-concern: 0.9995253555320858




Linear CKA concern: 0.9996701676701572




Linear CKA non-concern: 0.9995602054547182




Kernel CKA concern: 0.9986732346687048




Kernel CKA non-concern: 0.9987931682356375




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1808671951293945




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.2174




Precision: 0.6469, Recall: 0.6154, F1-Score: 0.6203




              precision    recall  f1-score   support

           0     0.5577    0.4682    0.5091      2992
           1     0.7007    0.5180    0.5957      2992
           2     0.6845    0.6404    0.6617      3012
           3     0.3417    0.6414    0.4459      2998
           4     0.7105    0.7686    0.7384      2973
           5     0.8531    0.7662    0.8073      3054
           6     0.6731    0.4066    0.5070      3003
           7     0.6281    0.6049    0.6163      3012
           8     0.5867    0.6989    0.6379      2982
           9     0.7332    0.6405    0.6837      2982

    accuracy                         0.6155     30000
   macro avg     0.6469    0.6154    0.6203     30000
weighted avg     0.6473    0.6155    0.6205     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9993445875605222




CCA coefficients mean non-concern: 0.9992874341260167




Linear CKA concern: 0.9985335083950615




Linear CKA non-concern: 0.9989422666572801




Kernel CKA concern: 0.9981091767145606




Kernel CKA non-concern: 0.9970918290171985




original model's perplexity




3.187649726867676




pruned model's perplexity




3.179891586303711




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.2180




Precision: 0.6468, Recall: 0.6153, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5576    0.4676    0.5086      2992
           1     0.7018    0.5191    0.5967      2992
           2     0.6825    0.6431    0.6622      3012
           3     0.3414    0.6418    0.4457      2998
           4     0.7105    0.7676    0.7379      2973
           5     0.8507    0.7669    0.8066      3054
           6     0.6737    0.4056    0.5063      3003
           7     0.6285    0.6049    0.6165      3012
           8     0.5894    0.6962    0.6384      2982
           9     0.7322    0.6408    0.6835      2982

    accuracy                         0.6155     30000
   macro avg     0.6468    0.6153    0.6202     30000
weighted avg     0.6472    0.6155    0.6205     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9994352783137592




CCA coefficients mean non-concern: 0.9993933521602115




Linear CKA concern: 0.9993656317744758




Linear CKA non-concern: 0.9992087464924129




Kernel CKA concern: 0.998153486927033




Kernel CKA non-concern: 0.997964736358618




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1809399127960205




Evaluate the pruned model 9




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2188




Precision: 0.6465, Recall: 0.6151, F1-Score: 0.6200




              precision    recall  f1-score   support

           0     0.5576    0.4676    0.5086      2992
           1     0.7013    0.5187    0.5963      2992
           2     0.6816    0.6411    0.6607      3012
           3     0.3417    0.6434    0.4464      2998
           4     0.7145    0.7652    0.7390      2973
           5     0.8466    0.7682    0.8055      3054
           6     0.6765    0.4053    0.5069      3003
           7     0.6254    0.6059    0.6155      3012
           8     0.5899    0.6928    0.6373      2982
           9     0.7300    0.6429    0.6837      2982

    accuracy                         0.6152     30000
   macro avg     0.6465    0.6151    0.6200     30000
weighted avg     0.6468    0.6152    0.6202     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9996479642895966




CCA coefficients mean non-concern: 0.9995026760834851




Linear CKA concern: 0.9998367911023586




Linear CKA non-concern: 0.999577473226377




Kernel CKA concern: 0.9994466039041503




Kernel CKA non-concern: 0.9988000199887845




original model's perplexity




3.187649726867676




pruned model's perplexity




3.182737112045288




In [12]:
df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
csv_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
new_df.to_csv(f"results/{csv_name}.csv", index=False)