In [1]:
import os
import sys

sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
name = "bert-6-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
ratio = 0.6
seed = 44
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-10-15 20:41:55


In [5]:
config = Config(name, device)
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
result_list = []

for concern in range(config.num_labels):
    config.init_seed()
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=["intermediate", "output"],
        exclude_layers=["attention"],
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2151




Precision: 0.6465, Recall: 0.6151, F1-Score: 0.6200




              precision    recall  f1-score   support

           0     0.5534    0.4729    0.5100      2992
           1     0.7006    0.5287    0.6027      2992
           2     0.6777    0.6388    0.6577      3012
           3     0.3411    0.6371    0.4443      2998
           4     0.7329    0.7484    0.7406      2973
           5     0.8378    0.7714    0.8033      3054
           6     0.6896    0.3966    0.5036      3003
           7     0.6235    0.6076    0.6154      3012
           8     0.5877    0.6968    0.6376      2982
           9     0.7208    0.6529    0.6852      2982

    accuracy                         0.6153     30000
   macro avg     0.6465    0.6151    0.6200     30000
weighted avg     0.6468    0.6153    0.6203     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9980182816109396




CCA coefficients mean non-concern: 0.9973322569328651




Linear CKA concern: 0.998920680827511




Linear CKA non-concern: 0.9982787294858838




Kernel CKA concern: 0.995765245031148




Kernel CKA non-concern: 0.9924221189010626




original model's perplexity




3.187649726867676




pruned model's perplexity




3.182018518447876




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2141




Precision: 0.6469, Recall: 0.6154, F1-Score: 0.6201




              precision    recall  f1-score   support

           0     0.5589    0.4723    0.5120      2992
           1     0.6986    0.5244    0.5991      2992
           2     0.6765    0.6401    0.6578      3012
           3     0.3403    0.6431    0.4451      2998
           4     0.7234    0.7652    0.7437      2973
           5     0.8343    0.7714    0.8016      3054
           6     0.6895    0.3956    0.5028      3003
           7     0.6279    0.6006    0.6139      3012
           8     0.5952    0.6908    0.6395      2982
           9     0.7243    0.6502    0.6853      2982

    accuracy                         0.6155     30000
   macro avg     0.6469    0.6154    0.6201     30000
weighted avg     0.6472    0.6155    0.6203     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9983586646562664




CCA coefficients mean non-concern: 0.9975078302243144




Linear CKA concern: 0.9985051633026042




Linear CKA non-concern: 0.9987567109774037




Kernel CKA concern: 0.9959628775524315




Kernel CKA non-concern: 0.993971476047603




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1766064167022705




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2135




Precision: 0.6459, Recall: 0.6162, F1-Score: 0.6206




              precision    recall  f1-score   support

           0     0.5560    0.4743    0.5119      2992
           1     0.6981    0.5364    0.6067      2992
           2     0.6771    0.6411    0.6586      3012
           3     0.3448    0.6374    0.4475      2998
           4     0.7274    0.7531    0.7400      2973
           5     0.8306    0.7754    0.8020      3054
           6     0.6888    0.3936    0.5010      3003
           7     0.6250    0.6033    0.6140      3012
           8     0.5932    0.6918    0.6387      2982
           9     0.7178    0.6559    0.6855      2982

    accuracy                         0.6164     30000
   macro avg     0.6459    0.6162    0.6206     30000
weighted avg     0.6462    0.6164    0.6208     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9978909538754782




CCA coefficients mean non-concern: 0.9973449583068938




Linear CKA concern: 0.9973586149903818




Linear CKA non-concern: 0.9984721189605548




Kernel CKA concern: 0.9952038127066433




Kernel CKA non-concern: 0.9938566053594102




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1748292446136475




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2143




Precision: 0.6466, Recall: 0.6159, F1-Score: 0.6206




              precision    recall  f1-score   support

           0     0.5561    0.4706    0.5098      2992
           1     0.6986    0.5331    0.6047      2992
           2     0.6790    0.6384    0.6581      3012
           3     0.3427    0.6398    0.4463      2998
           4     0.7273    0.7528    0.7398      2973
           5     0.8377    0.7708    0.8029      3054
           6     0.6844    0.3986    0.5038      3003
           7     0.6268    0.6062    0.6164      3012
           8     0.5913    0.6972    0.6399      2982
           9     0.7217    0.6512    0.6846      2982

    accuracy                         0.6160     30000
   macro avg     0.6466    0.6159    0.6206     30000
weighted avg     0.6469    0.6160    0.6209     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9979600086694372




CCA coefficients mean non-concern: 0.9973039636447298




Linear CKA concern: 0.9976146445319738




Linear CKA non-concern: 0.9983195445605524




Kernel CKA concern: 0.9962986014620018




Kernel CKA non-concern: 0.9930712134059925




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1809496879577637




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2125




Precision: 0.6452, Recall: 0.6162, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5489    0.4786    0.5113      2992
           1     0.6963    0.5334    0.6041      2992
           2     0.6855    0.6338    0.6586      3012
           3     0.3458    0.6331    0.4473      2998
           4     0.7207    0.7656    0.7425      2973
           5     0.8307    0.7744    0.8016      3054
           6     0.6890    0.3946    0.5018      3003
           7     0.6255    0.6033    0.6142      3012
           8     0.5943    0.6908    0.6390      2982
           9     0.7150    0.6546    0.6835      2982

    accuracy                         0.6163     30000
   macro avg     0.6452    0.6162    0.6204     30000
weighted avg     0.6455    0.6163    0.6206     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9964791384151782




CCA coefficients mean non-concern: 0.9972905342221847




Linear CKA concern: 0.9955693222609859




Linear CKA non-concern: 0.9980854598876441




Kernel CKA concern: 0.9939566795829434




Kernel CKA non-concern: 0.9911412889456859




original model's perplexity




3.187649726867676




pruned model's perplexity




3.168870687484741




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2152




Precision: 0.6475, Recall: 0.6157, F1-Score: 0.6208




              precision    recall  f1-score   support

           0     0.5478    0.4786    0.5109      2992
           1     0.6937    0.5374    0.6056      2992
           2     0.6803    0.6401    0.6596      3012
           3     0.3402    0.6418    0.4447      2998
           4     0.7289    0.7514    0.7400      2973
           5     0.8353    0.7737    0.8033      3054
           6     0.6924    0.3919    0.5005      3003
           7     0.6326    0.6019    0.6169      3012
           8     0.5950    0.6901    0.6390      2982
           9     0.7286    0.6499    0.6870      2982

    accuracy                         0.6158     30000
   macro avg     0.6475    0.6157    0.6208     30000
weighted avg     0.6478    0.6158    0.6210     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9974701949203755




CCA coefficients mean non-concern: 0.9976987374564095




Linear CKA concern: 0.9932420746415882




Linear CKA non-concern: 0.9989834603312602




Kernel CKA concern: 0.9948686252980915




Kernel CKA non-concern: 0.9950672755342934




original model's perplexity




3.187649726867676




pruned model's perplexity




3.18182635307312




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2144




Precision: 0.6465, Recall: 0.6159, F1-Score: 0.6206




              precision    recall  f1-score   support

           0     0.5554    0.4723    0.5105      2992
           1     0.7012    0.5311    0.6044      2992
           2     0.6805    0.6384    0.6588      3012
           3     0.3423    0.6384    0.4457      2998
           4     0.7274    0.7575    0.7421      2973
           5     0.8361    0.7701    0.8018      3054
           6     0.6860    0.3979    0.5037      3003
           7     0.6230    0.6072    0.6150      3012
           8     0.5930    0.6928    0.6390      2982
           9     0.7206    0.6529    0.6851      2982

    accuracy                         0.6160     30000
   macro avg     0.6465    0.6159    0.6206     30000
weighted avg     0.6468    0.6160    0.6208     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9979158691890685




CCA coefficients mean non-concern: 0.9971951046391639




Linear CKA concern: 0.9990034448403419




Linear CKA non-concern: 0.9981103696096653




Kernel CKA concern: 0.9940379660416203




Kernel CKA non-concern: 0.9923958301156737




original model's perplexity




3.187649726867676




pruned model's perplexity




3.177887439727783




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2125




Precision: 0.6456, Recall: 0.6168, F1-Score: 0.6210




              precision    recall  f1-score   support

           0     0.5496    0.4763    0.5103      2992
           1     0.6940    0.5398    0.6073      2992
           2     0.6774    0.6401    0.6582      3012
           3     0.3463    0.6301    0.4469      2998
           4     0.7244    0.7656    0.7444      2973
           5     0.8342    0.7711    0.8014      3054
           6     0.6915    0.3933    0.5014      3003
           7     0.6210    0.6082    0.6146      3012
           8     0.5945    0.6911    0.6392      2982
           9     0.7233    0.6522    0.6859      2982

    accuracy                         0.6169     30000
   macro avg     0.6456    0.6168    0.6210     30000
weighted avg     0.6459    0.6169    0.6212     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.997121978960503




CCA coefficients mean non-concern: 0.9972617908326613




Linear CKA concern: 0.9934844658105738




Linear CKA non-concern: 0.9977033733711418




Kernel CKA concern: 0.9914229284668198




Kernel CKA non-concern: 0.9918805548844637




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1704235076904297




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2134




Precision: 0.6461, Recall: 0.6168, F1-Score: 0.6209




              precision    recall  f1-score   support

           0     0.5538    0.4746    0.5112      2992
           1     0.6979    0.5381    0.6077      2992
           2     0.6833    0.6368    0.6592      3012
           3     0.3469    0.6351    0.4487      2998
           4     0.7263    0.7639    0.7446      2973
           5     0.8358    0.7701    0.8016      3054
           6     0.6897    0.3923    0.5001      3003
           7     0.6242    0.6049    0.6144      3012
           8     0.5878    0.6948    0.6369      2982
           9     0.7151    0.6573    0.6850      2982

    accuracy                         0.6169     30000
   macro avg     0.6461    0.6168    0.6209     30000
weighted avg     0.6464    0.6169    0.6211     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.997996139367477




CCA coefficients mean non-concern: 0.9967359793334555




Linear CKA concern: 0.9992143841830944




Linear CKA non-concern: 0.9966756347308128




Kernel CKA concern: 0.9974057568990983




Kernel CKA non-concern: 0.9877191234241982




original model's perplexity




3.187649726867676




pruned model's perplexity




3.174039840698242




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2130




Precision: 0.6459, Recall: 0.6160, F1-Score: 0.6205




              precision    recall  f1-score   support

           0     0.5505    0.4756    0.5103      2992
           1     0.7001    0.5291    0.6027      2992
           2     0.6785    0.6391    0.6582      3012
           3     0.3442    0.6368    0.4469      2998
           4     0.7276    0.7582    0.7425      2973
           5     0.8357    0.7728    0.8030      3054
           6     0.6843    0.3963    0.5019      3003
           7     0.6243    0.6062    0.6151      3012
           8     0.5931    0.6942    0.6397      2982
           9     0.7210    0.6516    0.6845      2982

    accuracy                         0.6161     30000
   macro avg     0.6459    0.6160    0.6205     30000
weighted avg     0.6462    0.6161    0.6207     30000





0.3896237177858484




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9981792903999036




CCA coefficients mean non-concern: 0.9975282853211034




Linear CKA concern: 0.9975784791004971




Linear CKA non-concern: 0.9981100776797727




Kernel CKA concern: 0.9937328557180516




Kernel CKA non-concern: 0.9930481459239481




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1740171909332275




In [9]:
df_list = [report_to_df(df) for df in result_list]
new_df = append_nth_row(df_list)
csv_name = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
new_df.to_csv(f"results/{csv_name}.csv", index=False)