In [1]:
import os
import sys

# Make the project root importable (the `src.*` imports below rely on it).
# The entry is relative, so it is resolved against the working directory.
project_root = "../../../../"
# Guard the append so re-running this cell does not stack duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

# Set before the model/tokenizer code is imported and used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# --- Run parameters -------------------------------------------------------
# Model/config identifier handed to Config(...).
name = "bert-4-128-yahoo"
# Prefer the first CUDA device, but fall back to CPU so the notebook can
# still run on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): `checkpoint` and `seed` are not referenced by any other cell
# visible in this notebook — confirm whether they are still needed.
checkpoint = None
# Forwarded to load_data(...) for the train/valid/test loaders.
batch_size = 16
num_workers = 4
# Sample count passed to SamplingDataset(...) and get_similarity(...).
num_samples = 16
# Passed to prune_concern_identification(...) as `sparsity_ratio`.
ratio = 0.5
seed = 44
# Layer-name filters passed to prune_concern_identification(...).
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Record when the run began and echo it in a fixed, sortable format.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-15 23:18:54


In [5]:
# Build the run configuration from the model name and target device.
config = Config(name, device)
# Label count read from the configuration mapping. Note that the pruning
# loop further down reads `config.num_labels` rather than this variable.
num_labels = config.config["num_labels"]
# Load the model described by `config`; the call prints the resolved
# configuration and a confirmation message (see the cell output).
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Build the train / validation / test loaders for the configured dataset.
# `do_cache=True` is forwarded to load_data; the cell output shows the splits
# being read back from cached .pkl files.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

# Load the pre-generated embedding-based samples for this model.
# NOTE(review): the file is a .pkl — presumably unpickled by load_cache, so
# only point this at trusted files; confirm against load_cache.
generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo", "4_128-yahoo_top1.pkl"
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
# Rename the cached keys in place before the data is wrapped in
# CustomEmbeddingDataset (presumably the names that class reads — confirm).
# Each rename is guarded so that re-running this cell is a no-op instead of
# raising KeyError on a key that was already popped by a previous run.
for old_key, new_key in (
    ("example_list", "embeddings"),
    ("example_label", "labels"),
    ("attn_list", "attention_mask"),
):
    if old_key in generated:
        generated[new_key] = generated.pop(old_key)
    elif new_key not in generated:
        # Neither the old nor the new name exists: the cache is malformed.
        raise KeyError(old_key)

In [9]:
# Wrap the renamed dictionary in the project's dataset class and batch it.
generated_data = CustomEmbeddingDataset(generated)
# The (misspelled) name `generated_dataloder` is read by the pruning loop, so
# it is kept as-is. The pruning loop also passes a literal 4 to SamplingDataset.
generated_dataloder = DataLoader(generated_data, batch_size=4)

In [10]:
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [11]:
def _draw_samples(target, positive):
    """Build a SamplingDataset over the generated data for one target id.

    Forwards `target` and `positive` together with the settings shared by
    every call in this cell (num_samples, the literal 4, resample=False).
    """
    return SamplingDataset(
        generated_dataloder,
        config,
        target,
        num_samples,
        positive,
        4,
        resample=False,
    )


# Keyword arguments that are identical for every pruning call.
prune_options = dict(
    include_layers=include_layers,
    exclude_layers=exclude_layers,
    sparsity_ratio=ratio,
    keep_dim=True,
    method="unstructed",
)

# One evaluation report per concern, consumed by the export cell below.
result_list = []

for concern in range(config.num_labels):
    # Re-seed at the start of every concern.
    config.init_seed()

    positive_samples = _draw_samples(concern, True)
    negative_samples = _draw_samples(concern, False)
    # NOTE(review): `all_samples` is never read in this cell; it is still
    # built here so the sequence of calls matches the original run.
    all_samples = _draw_samples(200, False)

    # Prune a private copy so `model` stays untouched for the comparisons.
    module = copy.deepcopy(model)
    prune_concern_identification(
        module, config, positive_samples, negative_samples, **prune_options
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    get_similarity(model, module, valid_dataloader, concern, num_samples, config)

    # NOTE(review): the original model's perplexity is recomputed on every
    # iteration although the printed value is the same each time.
    for label, net in (
        ("original model's perplexity", model),
        ("pruned model's perplexity", module),
    ):
        print(label)
        get_perplexity(net, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2220




Precision: 0.6484, Recall: 0.6147, F1-Score: 0.6196




              precision    recall  f1-score   support

           0     0.5434    0.4793    0.5093      2992
           1     0.6906    0.4910    0.5739      2992
           2     0.6978    0.6125    0.6524      3012
           3     0.3403    0.6398    0.4442      2998
           4     0.7281    0.7736    0.7502      2973
           5     0.8425    0.7603    0.7993      3054
           6     0.6793    0.3986    0.5024      3003
           7     0.6193    0.6325    0.6258      3012
           8     0.5809    0.7203    0.6431      2982
           9     0.7618    0.6392    0.6951      2982

    accuracy                         0.6148     30000
   macro avg     0.6484    0.6147    0.6196     30000
weighted avg     0.6487    0.6148    0.6198     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9984044508500699




CCA coefficients mean non-concern: 0.9979983112864761




Linear CKA concern: 0.9994990495122578




Linear CKA non-concern: 0.9991556397901576




Kernel CKA concern: 0.9980774082798723




Kernel CKA non-concern: 0.9968188308026228




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.209425687789917




Evaluate the pruned model 1




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2245




Precision: 0.6498, Recall: 0.6136, F1-Score: 0.6186




              precision    recall  f1-score   support

           0     0.5442    0.4776    0.5087      2992
           1     0.6938    0.4870    0.5723      2992
           2     0.7049    0.6066    0.6520      3012
           3     0.3358    0.6438    0.4414      2998
           4     0.7220    0.7773    0.7486      2973
           5     0.8408    0.7626    0.7998      3054
           6     0.6924    0.3913    0.5000      3003
           7     0.6225    0.6328    0.6276      3012
           8     0.5784    0.7233    0.6428      2982
           9     0.7633    0.6338    0.6926      2982

    accuracy                         0.6137     30000
   macro avg     0.6498    0.6136    0.6186     30000
weighted avg     0.6501    0.6137    0.6188     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9985575810989995




CCA coefficients mean non-concern: 0.9980856316280481




Linear CKA concern: 0.9990830762267084




Linear CKA non-concern: 0.9994799548929073




Kernel CKA concern: 0.9971657791988676




Kernel CKA non-concern: 0.9978014530472241




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2209248542785645




Evaluate the pruned model 2




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2215




Precision: 0.6479, Recall: 0.6151, F1-Score: 0.6199




              precision    recall  f1-score   support

           0     0.5312    0.4860    0.5076      2992
           1     0.6966    0.4896    0.5751      2992
           2     0.6980    0.6109    0.6516      3012
           3     0.3422    0.6374    0.4453      2998
           4     0.7269    0.7736    0.7496      2973
           5     0.8450    0.7587    0.7995      3054
           6     0.6814    0.4016    0.5053      3003
           7     0.6240    0.6282    0.6261      3012
           8     0.5855    0.7156    0.6440      2982
           9     0.7486    0.6492    0.6954      2982

    accuracy                         0.6152     30000
   macro avg     0.6479    0.6151    0.6199     30000
weighted avg     0.6483    0.6152    0.6201     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9984166798673041




CCA coefficients mean non-concern: 0.9981086739247881




Linear CKA concern: 0.9992866193637345




Linear CKA non-concern: 0.9989782534643572




Kernel CKA concern: 0.9978172929650009




Kernel CKA non-concern: 0.9960529027055276




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2100226879119873




Evaluate the pruned model 3




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2194




Precision: 0.6486, Recall: 0.6159, F1-Score: 0.6206




              precision    recall  f1-score   support

           0     0.5321    0.4873    0.5087      2992
           1     0.6974    0.4983    0.5813      2992
           2     0.6981    0.6149    0.6538      3012
           3     0.3430    0.6364    0.4457      2998
           4     0.7175    0.7844    0.7495      2973
           5     0.8408    0.7610    0.7989      3054
           6     0.6833    0.4009    0.5054      3003
           7     0.6269    0.6248    0.6259      3012
           8     0.5866    0.7143    0.6442      2982
           9     0.7601    0.6365    0.6928      2982

    accuracy                         0.6160     30000
   macro avg     0.6486    0.6159    0.6206     30000
weighted avg     0.6489    0.6160    0.6208     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9986684756037758




CCA coefficients mean non-concern: 0.9987013958978228




Linear CKA concern: 0.9992607786331787




Linear CKA non-concern: 0.9996348594930945




Kernel CKA concern: 0.998090378786656




Kernel CKA non-concern: 0.9986547351818218




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1981911659240723




Evaluate the pruned model 4




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2225




Precision: 0.6492, Recall: 0.6149, F1-Score: 0.6194




              precision    recall  f1-score   support

           0     0.5257    0.4963    0.5106      2992
           1     0.7017    0.4820    0.5714      2992
           2     0.7076    0.6026    0.6509      3012
           3     0.3414    0.6361    0.4443      2998
           4     0.7109    0.7884    0.7477      2973
           5     0.8422    0.7603    0.7992      3054
           6     0.6918    0.3939    0.5020      3003
           7     0.6284    0.6295    0.6290      3012
           8     0.5848    0.7180    0.6446      2982
           9     0.7574    0.6419    0.6949      2982

    accuracy                         0.6150     30000
   macro avg     0.6492    0.6149    0.6194     30000
weighted avg     0.6495    0.6150    0.6197     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9973230522993799




CCA coefficients mean non-concern: 0.9984217412968447




Linear CKA concern: 0.9981035522482459




Linear CKA non-concern: 0.9994077348331644




Kernel CKA concern: 0.9955184607674862




Kernel CKA non-concern: 0.9978974165809711




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2105278968811035




Evaluate the pruned model 5




Evaluating the model:   0%|                                  | 0/1875 [00:00<?, ?it/s]

Loss: 1.2217




Precision: 0.6491, Recall: 0.6152, F1-Score: 0.6197




              precision    recall  f1-score   support

           0     0.5434    0.4833    0.5116      2992
           1     0.6927    0.4926    0.5758      2992
           2     0.6960    0.6189    0.6552      3012
           3     0.3404    0.6371    0.4437      2998
           4     0.7215    0.7783    0.7489      2973
           5     0.8359    0.7656    0.7992      3054
           6     0.6918    0.3909    0.4996      3003
           7     0.6223    0.6308    0.6265      3012
           8     0.5798    0.7203    0.6424      2982
           9     0.7674    0.6338    0.6942      2982

    accuracy                         0.6153     30000
   macro avg     0.6491    0.6152    0.6197     30000
weighted avg     0.6494    0.6153    0.6199     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9973414099818648




CCA coefficients mean non-concern: 0.998583110956392




Linear CKA concern: 0.9902455792452134




Linear CKA non-concern: 0.999375381574628




Kernel CKA concern: 0.988547678126135




Kernel CKA non-concern: 0.9976982884921334




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2085623741149902




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.2170




Precision: 0.6470, Recall: 0.6180, F1-Score: 0.6222




              precision    recall  f1-score   support

           0     0.5254    0.5010    0.5129      2992
           1     0.6996    0.5013    0.5841      2992
           2     0.6970    0.6155    0.6537      3012
           3     0.3540    0.6261    0.4522      2998
           4     0.7211    0.7756    0.7474      2973
           5     0.8387    0.7613    0.7981      3054
           6     0.6725    0.4063    0.5065      3003
           7     0.6213    0.6351    0.6281      3012
           8     0.5867    0.7146    0.6444      2982
           9     0.7540    0.6435    0.6944      2982

    accuracy                         0.6181     30000
   macro avg     0.6470    0.6180    0.6222     30000
weighted avg     0.6473    0.6181    0.6224     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9988276868714596




CCA coefficients mean non-concern: 0.9984178790791476




Linear CKA concern: 0.9996070761086561




Linear CKA non-concern: 0.9991396280585569




Kernel CKA concern: 0.998578422933008




Kernel CKA non-concern: 0.9970345662128062




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.1918129920959473




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.2194




Precision: 0.6482, Recall: 0.6164, F1-Score: 0.6210




              precision    recall  f1-score   support

           0     0.5375    0.4836    0.5091      2992
           1     0.6936    0.4926    0.5761      2992
           2     0.7069    0.6119    0.6560      3012
           3     0.3447    0.6331    0.4464      2998
           4     0.7221    0.7780    0.7490      2973
           5     0.8400    0.7600    0.7980      3054
           6     0.6740    0.4069    0.5075      3003
           7     0.6179    0.6401    0.6288      3012
           8     0.5845    0.7176    0.6443      2982
           9     0.7602    0.6398    0.6948      2982

    accuracy                         0.6165     30000
   macro avg     0.6482    0.6164    0.6210     30000
weighted avg     0.6485    0.6165    0.6212     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9986524665428133




CCA coefficients mean non-concern: 0.9984436139821459




Linear CKA concern: 0.9992159281877021




Linear CKA non-concern: 0.9990266475924657




Kernel CKA concern: 0.9980019413578924




Kernel CKA non-concern: 0.9965932191528416




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.200343132019043




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.2207




Precision: 0.6486, Recall: 0.6155, F1-Score: 0.6204




              precision    recall  f1-score   support

           0     0.5237    0.4987    0.5109      2992
           1     0.7023    0.4896    0.5770      2992
           2     0.7008    0.6082    0.6513      3012
           3     0.3433    0.6334    0.4453      2998
           4     0.7211    0.7777    0.7483      2973
           5     0.8457    0.7590    0.8000      3054
           6     0.6836    0.4023    0.5065      3003
           7     0.6286    0.6242    0.6264      3012
           8     0.5832    0.7146    0.6423      2982
           9     0.7535    0.6469    0.6961      2982

    accuracy                         0.6155     30000
   macro avg     0.6486    0.6155    0.6204     30000
weighted avg     0.6489    0.6155    0.6206     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9984422547674274




CCA coefficients mean non-concern: 0.9981724450366386




Linear CKA concern: 0.9993594079149926




Linear CKA non-concern: 0.9988972960186712




Kernel CKA concern: 0.9979361017925343




Kernel CKA non-concern: 0.9960831962270636




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2061123847961426




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.2227




Precision: 0.6497, Recall: 0.6146, F1-Score: 0.6193




              precision    recall  f1-score   support

           0     0.5465    0.4769    0.5094      2992
           1     0.7055    0.4893    0.5779      2992
           2     0.6972    0.6175    0.6549      3012
           3     0.3374    0.6391    0.4417      2998
           4     0.7166    0.7851    0.7493      2973
           5     0.8430    0.7613    0.8001      3054
           6     0.6898    0.3909    0.4990      3003
           7     0.6231    0.6301    0.6266      3012
           8     0.5794    0.7193    0.6418      2982
           9     0.7583    0.6365    0.6921      2982

    accuracy                         0.6147     30000
   macro avg     0.6497    0.6146    0.6193     30000
weighted avg     0.6500    0.6147    0.6195     30000





0.39267273726798974




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9987232566643103




CCA coefficients mean non-concern: 0.9981957921031339




Linear CKA concern: 0.9993829315345087




Linear CKA non-concern: 0.999322672261147




Kernel CKA concern: 0.9983200474145365




Kernel CKA non-concern: 0.997568187699403




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.2121024131774902




In [12]:
# Convert each per-concern evaluation report into a frame and combine them.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)

# Timestamped file name so repeated runs do not overwrite each other.
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# to_csv does not create missing parent directories, so make sure the
# output folder exists before writing.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)