In [1]:
import os
import sys

# Make the project's `src` package importable from this notebook. The path is
# relative to the working directory (presumably it resolves to the repository
# root — confirm). The membership check keeps the cell idempotent: re-running
# it no longer appends a duplicate entry to sys.path.
repo_root = "../../../../"
if repo_root not in sys.path:
    sys.path.append(repo_root)

# Must be set before the tokenizer library is used. Presumably silences the
# Hugging Face tokenizers fork/parallelism warning when DataLoader workers are
# spawned — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# --- Experiment configuration -------------------------------------------
# Model/experiment identifier handed to Config(name, device).
name = "bert-6-128-yahoo"
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines without a GPU instead of failing at model load time.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Not referenced by the cells visible in this notebook.
checkpoint = None
# Passed to load_data() when building the dataloaders.
batch_size = 16
num_workers = 4
# Number of samples requested from SamplingDataset and get_similarity.
num_samples = 16
# Passed to prune_concern_identification as `sparsity_ratio`.
ratio = 0.4
# NOTE(review): not referenced by the visible cells; seeding goes through
# config.init_seed() — confirm whether this value is actually consumed.
seed = 44
# Layer-name filters for pruning; the same values are what
# prune_concern_identification receives as include_layers / exclude_layers.
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Remember when the run began and echo it in a human-readable form.
script_start_time = datetime.now()
started_at_text = script_start_time.strftime('%Y-%m-%d %H:%M:%S')
print("Script started at: " + started_at_text)

Script started at: 2024-10-15 20:23:56


In [5]:
# Build the experiment configuration from the model name and target device.
config = Config(name, device)
# Number of target classes, read from the configuration mapping
# (the cell output below shows 'num_labels': 10 for this model).
num_labels = config.config["num_labels"]
# Load the model described by `config`; load_model prints its progress.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
# Build the train / validation / test dataloaders for the configured dataset,
# using the batch size and worker count from the configuration cell.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    # Presumably enables the on-disk dataset cache (the cell output reports
    # train/valid/test .pkl files being loaded from cache) — confirm in load_data.
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
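# Optional baseline: evaluate the unpruned model on the test dataloader.
# Intentionally disabled in this run; uncomment the two lines below to enable it.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)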

In [8]:
# Evaluation results of the pruned models, one entry per concern, in label order.
result_list = []

# One pruning experiment per class label ("concern"): prune a fresh copy of the
# model for that concern, evaluate it, and compare it with the original model.
for concern in range(config.num_labels):
    # Called at the top of every iteration, before any sampling
    # (presumably re-seeds the RNGs so each concern is reproducible — confirm).
    config.init_seed()

    # NOTE(review): the positional arguments after `num_samples` are a boolean
    # (presumably "sample the concern class" vs. "sample the other classes")
    # and the literal 4 (presumably a batch size) — confirm against the
    # SamplingDataset signature and consider naming them.
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): `all_samples` is never used below. It is kept because
    # constructing it may have side effects (e.g. consuming RNG state) that the
    # later steps depend on — confirm, then remove it if it is truly dead.
    # 200 is not a valid label index for this model; presumably a sentinel that
    # combined with `False` selects samples from every class — confirm.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Prune a deep copy so the original `model` stays untouched and can serve
    # as the reference in the similarity / perplexity comparisons below.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        # Use the values from the configuration cell instead of repeating the
        # literals here, so editing the configuration actually takes effect.
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # NOTE(review): "unstructed" looks like a misspelling of "unstructured",
        # but it is a runtime value the pruning code may match on exactly, so
        # it is left byte-for-byte unchanged — confirm before renaming.
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    # Return value is discarded; the sparsity figure appears in the cell output.
    get_sparsity(module)

    # Representation similarity between the original and the pruned model.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    # NOTE(review): the original model's perplexity does not depend on
    # `concern` (the output repeats the same value every iteration). It could
    # be computed once before the loop; left in place to keep the per-concern
    # output layout unchanged.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2182




Precision: 0.6464, Recall: 0.6150, F1-Score: 0.6199




              precision    recall  f1-score   support

           0     0.5564    0.4679    0.5084      2992
           1     0.7019    0.5187    0.5966      2992
           2     0.6802    0.6411    0.6601      3012
           3     0.3411    0.6451    0.4462      2998
           4     0.7174    0.7632    0.7396      2973
           5     0.8399    0.7698    0.8033      3054
           6     0.6792    0.4033    0.5061      3003
           7     0.6253    0.6079    0.6165      3012
           8     0.5953    0.6891    0.6388      2982
           9     0.7275    0.6439    0.6832      2982

    accuracy                         0.6151     30000
   macro avg     0.6464    0.6150    0.6199     30000
weighted avg     0.6467    0.6151    0.6201     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9994984508075546




CCA coefficients mean non-concern: 0.9993915887803498




Linear CKA concern: 0.9998169171611085




Linear CKA non-concern: 0.9998080711893171




Kernel CKA concern: 0.9993359401887009




Kernel CKA non-concern: 0.9992178673233639




original model's perplexity




3.187649726867676




pruned model's perplexity




3.181722640991211




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2180




Precision: 0.6461, Recall: 0.6148, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5580    0.4696    0.5100      2992
           1     0.7007    0.5187    0.5961      2992
           2     0.6803    0.6428    0.6610      3012
           3     0.3401    0.6414    0.4445      2998
           4     0.7174    0.7625    0.7393      2973
           5     0.8388    0.7701    0.8030      3054
           6     0.6780    0.4039    0.5063      3003
           7     0.6248    0.6059    0.6152      3012
           8     0.5945    0.6901    0.6387      2982
           9     0.7285    0.6425    0.6828      2982

    accuracy                         0.6149     30000
   macro avg     0.6461    0.6148    0.6197     30000
weighted avg     0.6464    0.6149    0.6199     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.999602034040425




CCA coefficients mean non-concern: 0.9993960494588634




Linear CKA concern: 0.9998085964920257




Linear CKA non-concern: 0.9998383774011631




Kernel CKA concern: 0.999493954999626




Kernel CKA non-concern: 0.9993045931827502




original model's perplexity




3.187649726867676




pruned model's perplexity




3.180698871612549




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2178




Precision: 0.6459, Recall: 0.6149, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5549    0.4699    0.5089      2992
           1     0.7005    0.5191    0.5963      2992
           2     0.6811    0.6418    0.6609      3012
           3     0.3418    0.6438    0.4466      2998
           4     0.7191    0.7629    0.7403      2973
           5     0.8371    0.7708    0.8026      3054
           6     0.6781    0.4033    0.5057      3003
           7     0.6255    0.6056    0.6154      3012
           8     0.5947    0.6878    0.6378      2982
           9     0.7258    0.6445    0.6828      2982

    accuracy                         0.6151     30000
   macro avg     0.6459    0.6149    0.6197     30000
weighted avg     0.6462    0.6151    0.6199     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995027425048238




CCA coefficients mean non-concern: 0.9993792238722841




Linear CKA concern: 0.9997336955381909




Linear CKA non-concern: 0.9998177065352373




Kernel CKA concern: 0.9994872463924926




Kernel CKA non-concern: 0.9992904141889004




original model's perplexity




3.187649726867676




pruned model's perplexity




3.180622100830078




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2182




Precision: 0.6462, Recall: 0.6148, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5590    0.4669    0.5088      2992
           1     0.6997    0.5194    0.5962      2992
           2     0.6788    0.6391    0.6583      3012
           3     0.3403    0.6424    0.4450      2998
           4     0.7190    0.7615    0.7396      2973
           5     0.8422    0.7692    0.8040      3054
           6     0.6780    0.4053    0.5073      3003
           7     0.6247    0.6079    0.6162      3012
           8     0.5931    0.6922    0.6388      2982
           9     0.7270    0.6439    0.6829      2982

    accuracy                         0.6149     30000
   macro avg     0.6462    0.6148    0.6197     30000
weighted avg     0.6465    0.6149    0.6199     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995083307065223




CCA coefficients mean non-concern: 0.9994007690970567




Linear CKA concern: 0.9997860640735019




Linear CKA non-concern: 0.9997555523977799




Kernel CKA concern: 0.999682543775353




Kernel CKA non-concern: 0.9991234070428782




original model's perplexity




3.187649726867676




pruned model's perplexity




3.182474136352539




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2172




Precision: 0.6459, Recall: 0.6152, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5573    0.4699    0.5099      2992
           1     0.7021    0.5184    0.5964      2992
           2     0.6833    0.6398    0.6608      3012
           3     0.3425    0.6424    0.4468      2998
           4     0.7155    0.7639    0.7389      2973
           5     0.8374    0.7705    0.8025      3054
           6     0.6787    0.4016    0.5046      3003
           7     0.6242    0.6082    0.6161      3012
           8     0.5941    0.6911    0.6390      2982
           9     0.7243    0.6459    0.6829      2982

    accuracy                         0.6153     30000
   macro avg     0.6459    0.6152    0.6198     30000
weighted avg     0.6462    0.6153    0.6200     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9992540815700247




CCA coefficients mean non-concern: 0.9993212796464455




Linear CKA concern: 0.9994606836526837




Linear CKA non-concern: 0.999734794420089




Kernel CKA concern: 0.9992991003871754




Kernel CKA non-concern: 0.9988001756343727




original model's perplexity




3.187649726867676




pruned model's perplexity




3.178295850753784




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2185




Precision: 0.6472, Recall: 0.6153, F1-Score: 0.6203




              precision    recall  f1-score   support

           0     0.5580    0.4696    0.5100      2992
           1     0.7010    0.5187    0.5962      2992
           2     0.6821    0.6424    0.6617      3012
           3     0.3404    0.6458    0.4458      2998
           4     0.7215    0.7608    0.7407      2973
           5     0.8410    0.7688    0.8033      3054
           6     0.6788    0.4026    0.5054      3003
           7     0.6257    0.6082    0.6168      3012
           8     0.5944    0.6915    0.6393      2982
           9     0.7290    0.6449    0.6843      2982

    accuracy                         0.6155     30000
   macro avg     0.6472    0.6153    0.6203     30000
weighted avg     0.6475    0.6155    0.6206     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9993791377751405




CCA coefficients mean non-concern: 0.9993908072842262




Linear CKA concern: 0.9983686944558833




Linear CKA non-concern: 0.9998453136183053




Kernel CKA concern: 0.9991683593408137




Kernel CKA non-concern: 0.9993304440858748




original model's perplexity




3.187649726867676




pruned model's perplexity




3.183448076248169




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2180




Precision: 0.6460, Recall: 0.6150, F1-Score: 0.6199




              precision    recall  f1-score   support

           0     0.5564    0.4713    0.5103      2992
           1     0.7000    0.5211    0.5974      2992
           2     0.6824    0.6384    0.6597      3012
           3     0.3412    0.6418    0.4455      2998
           4     0.7189    0.7622    0.7399      2973
           5     0.8381    0.7695    0.8023      3054
           6     0.6779    0.4036    0.5059      3003
           7     0.6250    0.6086    0.6167      3012
           8     0.5938    0.6888    0.6378      2982
           9     0.7268    0.6449    0.6834      2982

    accuracy                         0.6151     30000
   macro avg     0.6460    0.6150    0.6199     30000
weighted avg     0.6463    0.6151    0.6201     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995318292190489




CCA coefficients mean non-concern: 0.999329722766654




Linear CKA concern: 0.9998785038249199




Linear CKA non-concern: 0.9997009420479475




Kernel CKA concern: 0.9994176425747179




Kernel CKA non-concern: 0.9988393032096755




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1814303398132324




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2176




Precision: 0.6456, Recall: 0.6149, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5543    0.4689    0.5081      2992
           1     0.6992    0.5237    0.5989      2992
           2     0.6797    0.6411    0.6598      3012
           3     0.3420    0.6401    0.4458      2998
           4     0.7183    0.7625    0.7398      2973
           5     0.8403    0.7685    0.8028      3054
           6     0.6782    0.4043    0.5066      3003
           7     0.6224    0.6079    0.6150      3012
           8     0.5942    0.6888    0.6380      2982
           9     0.7274    0.6435    0.6829      2982

    accuracy                         0.6151     30000
   macro avg     0.6456    0.6149    0.6198     30000
weighted avg     0.6459    0.6151    0.6200     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9994006090256906




CCA coefficients mean non-concern: 0.9993666528086212




Linear CKA concern: 0.9994736543476751




Linear CKA non-concern: 0.9997058530627714




Kernel CKA concern: 0.9993470007049756




Kernel CKA non-concern: 0.999051347153257




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1798040866851807




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2178




Precision: 0.6463, Recall: 0.6153, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5571    0.4699    0.5098      2992
           1     0.6997    0.5211    0.5973      2992
           2     0.6820    0.6394    0.6600      3012
           3     0.3419    0.6428    0.4464      2998
           4     0.7175    0.7629    0.7395      2973
           5     0.8405    0.7695    0.8034      3054
           6     0.6794    0.4023    0.5053      3003
           7     0.6257    0.6089    0.6172      3012
           8     0.5942    0.6888    0.6380      2982
           9     0.7246    0.6476    0.6839      2982

    accuracy                         0.6154     30000
   macro avg     0.6463    0.6153    0.6201     30000
weighted avg     0.6466    0.6154    0.6203     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995022737000998




CCA coefficients mean non-concern: 0.9992319399583924




Linear CKA concern: 0.9999071440744486




Linear CKA non-concern: 0.9996039257704721




Kernel CKA concern: 0.9997272025638279




Kernel CKA non-concern: 0.9986332857494981




original model's perplexity




3.187649726867676




pruned model's perplexity




3.180799722671509




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2176




Precision: 0.6458, Recall: 0.6150, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5554    0.4709    0.5097      2992
           1     0.6997    0.5194    0.5962      2992
           2     0.6787    0.6424    0.6601      3012
           3     0.3418    0.6414    0.4460      2998
           4     0.7182    0.7629    0.7398      2973
           5     0.8396    0.7695    0.8030      3054
           6     0.6779    0.4036    0.5059      3003
           7     0.6264    0.6062    0.6162      3012
           8     0.5933    0.6898    0.6379      2982
           9     0.7274    0.6435    0.6829      2982

    accuracy                         0.6151     30000
   macro avg     0.6458    0.6150    0.6198     30000
weighted avg     0.6461    0.6151    0.6200     30000





0.260174266541123




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9995640160474402




CCA coefficients mean non-concern: 0.9994070107344353




Linear CKA concern: 0.9997267463193062




Linear CKA non-concern: 0.999743073037157




Kernel CKA concern: 0.9992373345439324




Kernel CKA non-concern: 0.9991796296027643




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1804962158203125




In [9]:
# Turn each per-concern evaluation result into a DataFrame
# (report_to_df presumably expects what evaluate_model returned — confirm).
df_list = [report_to_df(report) for report in result_list]
# Combine the per-concern frames into a single table
# (exact row-selection semantics live in append_nth_row — see its definition).
new_df = append_nth_row(df_list)
# Timestamped file name so successive runs do not overwrite each other.
csv_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# Create the output directory if needed: without it, to_csv raises and the
# results of the whole (long) pruning run would be lost at the last step.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)