In [1]:
import os
import sys

# Extend the import search path four directory levels up from the working
# directory (presumably the repository root holding the `src` package that the
# next cell imports from -- verify against the project layout).
_project_root = "../../../../"
sys.path.append(_project_root)

# Set TOKENIZERS_PARALLELISM before any tokenizer is created.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Run configuration: every tunable used by the cells below lives here.

# Model/config identifier handed to Config(name, device).
name = "bert-4-128-yahoo"

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): `checkpoint` is not read by any cell visible in this notebook.
checkpoint = None

# DataLoader settings forwarded to load_data().
batch_size = 16
num_workers = 4

# Passed to SamplingDataset and get_similarity in the pruning loop.
num_samples = 16

# Passed as `sparsity_ratio` to prune_concern_identification().
ratio = 0.4

# NOTE(review): `seed` is not read by any cell visible in this notebook; the
# pruning loop calls config.init_seed() instead -- confirm which value is used.
seed = 44

# Layer-name filters for prune_concern_identification().
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Record the wall-clock time at which this run began and echo it to the log.
script_start_time = datetime.now()
start_stamp = script_start_time.strftime('%Y-%m-%d %H:%M:%S')
print(f"Script started at: {start_stamp}")

Script started at: 2024-10-15 19:47:22


In [5]:
# Build the run configuration for the selected model name and device.
config = Config(name, device)

# Number of target labels, read from the configuration's settings mapping.
model_settings = config.config
num_labels = model_settings["num_labels"]

# Load the model described by the configuration.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Load the train / validation / test dataloaders for the configured dataset.
# do_cache=True is forwarded unchanged to load_data().
loader_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "do_cache": True,
}
dataloaders = load_data(config, **loader_options)
train_dataloader, valid_dataloader, test_dataloader = dataloaders

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Baseline evaluation of the unpruned model on the test split is currently
# disabled. Uncommenting the two lines below would run evaluate_model() on
# `model` and store the outcome in `result`.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
# For every label ("concern"), prune a fresh copy of the model, evaluate it on
# the test split, and print sparsity, similarity and perplexity diagnostics.
# The evaluation results are collected in `result_list`, one entry per concern.
result_list = []


def _build_samples(target, flag):
    """Create a SamplingDataset over the training dataloader.

    Args:
        target: Label index forwarded as the third positional argument.
        flag: Boolean forwarded as the fifth positional argument; the loop
            passes True for `positive_samples` and False otherwise
            (presumably selects samples of `target` vs. the rest -- confirm
            against SamplingDataset).

    Returns:
        The constructed SamplingDataset.
    """
    return SamplingDataset(
        train_dataloader,
        config,
        target,
        num_samples,
        flag,
        # NOTE(review): unnamed positional constant carried over unchanged;
        # confirm its meaning against the SamplingDataset signature.
        4,
        resample=False,
    )


for concern in range(config.num_labels):
    # Called once per concern, presumably to reset the RNG state so each
    # iteration starts from the same seed -- confirm in Config.init_seed.
    config.init_seed()

    positive_samples = _build_samples(concern, True)
    negative_samples = _build_samples(concern, False)
    # NOTE(review): `all_samples` is never read below, and 200 is not a valid
    # label for this 10-class dataset. It is kept because constructing it may
    # consume random state and thereby affect the reported numbers; remove it
    # once that has been verified.
    all_samples = _build_samples(200, False)

    # Work on a private copy so the original `model` stays available as the
    # reference for the similarity and perplexity comparisons below.
    module = copy.deepcopy(model)

    # The return value is ignored and `module` is evaluated afterwards, so the
    # call is expected to modify `module` in place.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        # Use the layer filters from the configuration cell instead of
        # repeating the literals here, so a change there takes effect.
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # NOTE(review): spelling kept as-is; presumably the exact token the
        # pruning API expects -- verify before "correcting" it.
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)

    # NOTE(review): the original model's perplexity does not depend on
    # `concern` (the captured output repeats the same value every iteration);
    # it could be computed once before the loop if it does not touch RNG state.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2207




Precision: 0.6481, Recall: 0.6155, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5368    0.4826    0.5083      2992
           1     0.6994    0.4783    0.5681      2992
           2     0.6980    0.6145    0.6536      3012
           3     0.3433    0.6398    0.4468      2998
           4     0.7225    0.7804    0.7503      2973
           5     0.8414    0.7606    0.7990      3054
           6     0.6709    0.4079    0.5074      3003
           7     0.6213    0.6378    0.6294      3012
           8     0.5865    0.7143    0.6441      2982
           9     0.7615    0.6392    0.6950      2982

    accuracy                         0.6156     30000
   macro avg     0.6481    0.6155    0.6202     30000
weighted avg     0.6484    0.6156    0.6204     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995441187934736




CCA coefficients mean non-concern: 0.9995014088979999




Linear CKA concern: 0.9999203028154574




Linear CKA non-concern: 0.9998930538350006




Kernel CKA concern: 0.9997110510800413




Kernel CKA non-concern: 0.9995852304109686




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2021968364715576




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2210




Precision: 0.6487, Recall: 0.6155, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5403    0.4813    0.5091      2992
           1     0.6988    0.4776    0.5674      2992
           2     0.7000    0.6142    0.6543      3012
           3     0.3423    0.6401    0.4461      2998
           4     0.7217    0.7807    0.7500      2973
           5     0.8415    0.7613    0.7994      3054
           6     0.6753    0.4079    0.5086      3003
           7     0.6213    0.6384    0.6298      3012
           8     0.5837    0.7160    0.6431      2982
           9     0.7623    0.6378    0.6945      2982

    accuracy                         0.6156     30000
   macro avg     0.6487    0.6155    0.6202     30000
weighted avg     0.6490    0.6156    0.6204     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995812367317525




CCA coefficients mean non-concern: 0.9995442624380847




Linear CKA concern: 0.999898025076727




Linear CKA non-concern: 0.9999001189922719




Kernel CKA concern: 0.9997223367338528




Kernel CKA non-concern: 0.9996049167665643




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2035419940948486




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2213




Precision: 0.6483, Recall: 0.6147, F1-Score: 0.6195




              precision    recall  f1-score   support

           0     0.5379    0.4816    0.5082      2992
           1     0.7021    0.4719    0.5645      2992
           2     0.6982    0.6129    0.6528      3012
           3     0.3415    0.6418    0.4458      2998
           4     0.7220    0.7800    0.7499      2973
           5     0.8413    0.7587    0.7979      3054
           6     0.6734    0.4086    0.5086      3003
           7     0.6220    0.6371    0.6295      3012
           8     0.5839    0.7150    0.6428      2982
           9     0.7604    0.6395    0.6947      2982

    accuracy                         0.6148     30000
   macro avg     0.6483    0.6147    0.6195     30000
weighted avg     0.6486    0.6148    0.6197     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995204789569822




CCA coefficients mean non-concern: 0.9994741308715099




Linear CKA concern: 0.9999312258088592




Linear CKA non-concern: 0.9998800089907287




Kernel CKA concern: 0.9998126017577437




Kernel CKA non-concern: 0.9995383576929318




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2046332359313965




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2209




Precision: 0.6482, Recall: 0.6157, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5367    0.4843    0.5091      2992
           1     0.7001    0.4799    0.5695      2992
           2     0.6988    0.6145    0.6539      3012
           3     0.3440    0.6394    0.4474      2998
           4     0.7220    0.7777    0.7488      2973
           5     0.8414    0.7606    0.7990      3054
           6     0.6718    0.4096    0.5089      3003
           7     0.6190    0.6391    0.6289      3012
           8     0.5866    0.7133    0.6438      2982
           9     0.7618    0.6382    0.6945      2982

    accuracy                         0.6158     30000
   macro avg     0.6482    0.6157    0.6204     30000
weighted avg     0.6485    0.6158    0.6206     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9994981015595755




CCA coefficients mean non-concern: 0.9995073474241608




Linear CKA concern: 0.9997539310767969




Linear CKA non-concern: 0.9998892705438515




Kernel CKA concern: 0.9994546282491419




Kernel CKA non-concern: 0.9995674383795842




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2037553787231445




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2216




Precision: 0.6489, Recall: 0.6150, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5362    0.4833    0.5083      2992
           1     0.7008    0.4753    0.5664      2992
           2     0.6995    0.6145    0.6543      3012
           3     0.3413    0.6408    0.4453      2998
           4     0.7211    0.7783    0.7486      2973
           5     0.8435    0.7590    0.7990      3054
           6     0.6781    0.4069    0.5086      3003
           7     0.6217    0.6351    0.6283      3012
           8     0.5836    0.7163    0.6432      2982
           9     0.7627    0.6402    0.6961      2982

    accuracy                         0.6151     30000
   macro avg     0.6489    0.6150    0.6198     30000
weighted avg     0.6492    0.6151    0.6200     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9993768703323213




CCA coefficients mean non-concern: 0.999497433162665




Linear CKA concern: 0.9998391924521726




Linear CKA non-concern: 0.9998828799258643




Kernel CKA concern: 0.9996229745700104




Kernel CKA non-concern: 0.9995441026018276




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2052536010742188




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2198




Precision: 0.6487, Recall: 0.6154, F1-Score: 0.6202




              precision    recall  f1-score   support

           0     0.5345    0.4813    0.5065      2992
           1     0.6981    0.4769    0.5667      2992
           2     0.6997    0.6165    0.6555      3012
           3     0.3418    0.6391    0.4454      2998
           4     0.7220    0.7773    0.7486      2973
           5     0.8432    0.7626    0.8009      3054
           6     0.6751    0.4083    0.5088      3003
           7     0.6258    0.6358    0.6308      3012
           8     0.5836    0.7190    0.6442      2982
           9     0.7631    0.6372    0.6944      2982

    accuracy                         0.6155     30000
   macro avg     0.6487    0.6154    0.6202     30000
weighted avg     0.6490    0.6155    0.6204     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9993489508710774




CCA coefficients mean non-concern: 0.9994726415912469




Linear CKA concern: 0.9992037354186611




Linear CKA non-concern: 0.9998990723047062




Kernel CKA concern: 0.9990194501867121




Kernel CKA non-concern: 0.9996640685033673




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.199854612350464




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2216




Precision: 0.6480, Recall: 0.6146, F1-Score: 0.6194




              precision    recall  f1-score   support

           0     0.5370    0.4799    0.5069      2992
           1     0.7003    0.4756    0.5665      2992
           2     0.6995    0.6135    0.6537      3012
           3     0.3419    0.6414    0.4460      2998
           4     0.7230    0.7770    0.7490      2973
           5     0.8423    0.7590    0.7985      3054
           6     0.6720    0.4079    0.5077      3003
           7     0.6180    0.6381    0.6279      3012
           8     0.5851    0.7150    0.6435      2982
           9     0.7607    0.6385    0.6943      2982

    accuracy                         0.6147     30000
   macro avg     0.6480    0.6146    0.6194     30000
weighted avg     0.6483    0.6147    0.6196     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995887022933584




CCA coefficients mean non-concern: 0.9994600643492239




Linear CKA concern: 0.9999289420523426




Linear CKA non-concern: 0.9998492608589206




Kernel CKA concern: 0.9997320350637143




Kernel CKA non-concern: 0.9994065306675654




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2063381671905518




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2206




Precision: 0.6483, Recall: 0.6159, F1-Score: 0.6205




              precision    recall  f1-score   support

           0     0.5372    0.4826    0.5085      2992
           1     0.7004    0.4813    0.5705      2992
           2     0.7002    0.6132    0.6538      3012
           3     0.3442    0.6381    0.4472      2998
           4     0.7226    0.7773    0.7490      2973
           5     0.8422    0.7620    0.8001      3054
           6     0.6709    0.4093    0.5084      3003
           7     0.6195    0.6394    0.6293      3012
           8     0.5844    0.7166    0.6438      2982
           9     0.7617    0.6388    0.6949      2982

    accuracy                         0.6160     30000
   macro avg     0.6483    0.6159    0.6205     30000
weighted avg     0.6486    0.6160    0.6207     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9994568305055256




CCA coefficients mean non-concern: 0.9994498013468242




Linear CKA concern: 0.9997157235135041




Linear CKA non-concern: 0.9998551630524842




Kernel CKA concern: 0.9993969318624666




Kernel CKA non-concern: 0.9994896006819652




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2027156352996826




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2207




Precision: 0.6486, Recall: 0.6153, F1-Score: 0.6200




              precision    recall  f1-score   support

           0     0.5377    0.4809    0.5078      2992
           1     0.7014    0.4766    0.5676      2992
           2     0.6992    0.6135    0.6536      3012
           3     0.3422    0.6414    0.4463      2998
           4     0.7212    0.7814    0.7501      2973
           5     0.8446    0.7583    0.7992      3054
           6     0.6729    0.4069    0.5072      3003
           7     0.6230    0.6381    0.6305      3012
           8     0.5851    0.7156    0.6438      2982
           9     0.7581    0.6402    0.6942      2982

    accuracy                         0.6154     30000
   macro avg     0.6486    0.6153    0.6200     30000
weighted avg     0.6489    0.6154    0.6202     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9994371171933906




CCA coefficients mean non-concern: 0.9993694414440907




Linear CKA concern: 0.9998984774873169




Linear CKA non-concern: 0.9997943287600533




Kernel CKA concern: 0.9996808128283503




Kernel CKA non-concern: 0.9991924825984405




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2027523517608643




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2211




Precision: 0.6484, Recall: 0.6155, F1-Score: 0.6203




              precision    recall  f1-score   support

           0     0.5392    0.4830    0.5095      2992
           1     0.6997    0.4796    0.5691      2992
           2     0.6996    0.6139    0.6539      3012
           3     0.3433    0.6404    0.4470      2998
           4     0.7222    0.7773    0.7487      2973
           5     0.8429    0.7606    0.7997      3054
           6     0.6718    0.4096    0.5089      3003
           7     0.6199    0.6384    0.6290      3012
           8     0.5848    0.7143    0.6431      2982
           9     0.7606    0.6382    0.6940      2982

    accuracy                         0.6156     30000
   macro avg     0.6484    0.6155    0.6203     30000
weighted avg     0.6487    0.6156    0.6205     30000





0.3132945569804176




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995533899366119




CCA coefficients mean non-concern: 0.9995099592452662




Linear CKA concern: 0.99984350897043




Linear CKA non-concern: 0.9998614418790616




Kernel CKA concern: 0.999616244159705




Kernel CKA non-concern: 0.999521358213088




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.204106330871582




In [9]:
# Convert each per-concern evaluation result into a DataFrame, merge them with
# append_nth_row(), and write the merged table to a timestamped CSV file.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# to_csv() does not create missing parent directories; make sure the output
# folder exists so a long pruning run is not lost at the final step.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)