In [1]:
import os
import sys

# Extend the import search path four directory levels up (presumably the
# repository root, so the `src.*` imports in the next cell resolve — TODO confirm).
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path.
repo_root = "../../../../"
if repo_root not in sys.path:
    sys.path.append(repo_root)

# NOTE(review): presumably silences the Hugging Face tokenizers parallelism
# warning when DataLoader workers are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import (
    prune_concern_identification,
)
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration: the values a reader is expected to tune.
name = "bert-6-128-yahoo"  # identifier handed to Config(...)

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

checkpoint = None
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # samples requested from SamplingDataset / get_similarity
ratio = 0.3  # passed as sparsity_ratio to prune_concern_identification
# NOTE(review): `seed` is not referenced by the visible cells; seeding goes
# through config.init_seed() — confirm whether this value is still needed.
seed = 44
# Layer-name filters for pruning (same values the pruning call uses).
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Remember when this run began and echo the timestamp into the notebook log.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-15 20:14:54


In [5]:
# Build the experiment configuration from the model name and target device.
config = Config(name, device)
# Number of class labels, read from the configuration's `config` mapping.
num_labels = config.config["num_labels"]
# Load the model described by `config`; this is the unpruned baseline that the
# later cells copy and compare against.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
# Fetch the three dataloaders for the configured dataset in one call, then
# unpack them into the train / validation / test names used by later cells.
# do_cache=True is passed through as-is (presumably enables the on-disk cache
# reported in the cell output — TODO confirm).
dataloaders = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)
train_dataloader, valid_dataloader, test_dataloader = dataloaders

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Baseline evaluation of the unpruned model (currently disabled; uncomment to run).
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
# Results returned by evaluate_model for each pruned copy, in concern order.
result_list = []

for concern in range(config.num_labels):
    # Called at the top of every iteration; presumably re-seeds the RNGs so
    # each concern is sampled and pruned from the same random state — TODO confirm.
    config.init_seed()

    # NOTE(review): the positional boolean presumably selects samples of
    # `concern` (True) versus the remaining classes (False), and the literal 4
    # looks like a worker/batch setting — confirm against SamplingDataset's
    # signature before replacing it with a named constant.
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): `all_samples` is not referenced again in this cell, and 200
    # lies outside the concern range iterated here. It is kept because
    # constructing it may advance the random state used by the steps below —
    # confirm before removing.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Prune an independent copy so `model` stays available as the unpruned
    # baseline for the similarity and perplexity comparisons below.
    module = copy.deepcopy(model)

    # The return value is ignored, so this call is relied on to modify `module`.
    # The layer filters are read from the configuration cell instead of being
    # repeated as literals, so editing that cell actually changes the run.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ratio,
        keep_dim=True,
        # Spelling kept verbatim: this is a runtime value interpreted by the
        # pruning helper, not a comment.
        method="unstructed",
    )
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    # Compare the pruned copy against the untouched baseline.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)
    # NOTE(review): the baseline perplexity is recomputed for every concern and
    # the recorded output shows the same value each time; it could be computed
    # once outside the loop if the call does not depend on the random state.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2202




Precision: 0.6460, Recall: 0.6145, F1-Score: 0.6195




              precision    recall  f1-score   support

           0     0.5562    0.4699    0.5094      2992
           1     0.6995    0.5197    0.5964      2992
           2     0.6811    0.6411    0.6605      3012
           3     0.3400    0.6438    0.4450      2998
           4     0.7181    0.7608    0.7389      2973
           5     0.8403    0.7682    0.8026      3054
           6     0.6767    0.4036    0.5056      3003
           7     0.6237    0.6069    0.6152      3012
           8     0.5962    0.6881    0.6389      2982
           9     0.7282    0.6425    0.6827      2982

    accuracy                         0.6146     30000
   macro avg     0.6460    0.6145    0.6195     30000
weighted avg     0.6463    0.6146    0.6197     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.999800027995468




CCA coefficients mean non-concern: 0.9997557848026797




Linear CKA concern: 0.999959341664185




Linear CKA non-concern: 0.9999342067469565




Kernel CKA concern: 0.9998351420594074




Kernel CKA non-concern: 0.9997665386574712




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1860411167144775




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2199




Precision: 0.6459, Recall: 0.6146, F1-Score: 0.6195




              precision    recall  f1-score   support

           0     0.5549    0.4693    0.5085      2992
           1     0.7006    0.5194    0.5965      2992
           2     0.6804    0.6411    0.6602      3012
           3     0.3407    0.6431    0.4454      2998
           4     0.7161    0.7635    0.7391      2973
           5     0.8404    0.7692    0.8032      3054
           6     0.6758    0.4039    0.5056      3003
           7     0.6248    0.6066    0.6156      3012
           8     0.5947    0.6888    0.6383      2982
           9     0.7302    0.6408    0.6826      2982

    accuracy                         0.6147     30000
   macro avg     0.6459    0.6146    0.6195     30000
weighted avg     0.6462    0.6147    0.6197     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998405633280478




CCA coefficients mean non-concern: 0.9997608394650191




Linear CKA concern: 0.9999474240544095




Linear CKA non-concern: 0.9999446621522042




Kernel CKA concern: 0.999861227351618




Kernel CKA non-concern: 0.9997690601760959




original model's perplexity




3.187649726867676




pruned model's perplexity




3.184821367263794




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2200




Precision: 0.6463, Recall: 0.6149, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5569    0.4713    0.5105      2992
           1     0.7000    0.5187    0.5959      2992
           2     0.6822    0.6408    0.6608      3012
           3     0.3410    0.6454    0.4463      2998
           4     0.7196    0.7615    0.7400      2973
           5     0.8392    0.7692    0.8027      3054
           6     0.6762    0.4033    0.5052      3003
           7     0.6250    0.6076    0.6162      3012
           8     0.5945    0.6888    0.6382      2982
           9     0.7284    0.6422    0.6826      2982

    accuracy                         0.6150     30000
   macro avg     0.6463    0.6149    0.6198     30000
weighted avg     0.6466    0.6150    0.6201     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.999793947984486




CCA coefficients mean non-concern: 0.999752509209763




Linear CKA concern: 0.9999200462910791




Linear CKA non-concern: 0.9999372690956759




Kernel CKA concern: 0.9998531793524356




Kernel CKA non-concern: 0.9997737907865105




original model's perplexity




3.187649726867676




pruned model's perplexity




3.185863733291626




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2202




Precision: 0.6461, Recall: 0.6147, F1-Score: 0.6196




              precision    recall  f1-score   support

           0     0.5566    0.4703    0.5098      2992
           1     0.6997    0.5187    0.5958      2992
           2     0.6806    0.6411    0.6603      3012
           3     0.3403    0.6428    0.4450      2998
           4     0.7160    0.7632    0.7388      2973
           5     0.8406    0.7685    0.8029      3054
           6     0.6773    0.4039    0.5060      3003
           7     0.6243    0.6069    0.6155      3012
           8     0.5949    0.6895    0.6387      2982
           9     0.7311    0.6419    0.6836      2982

    accuracy                         0.6148     30000
   macro avg     0.6461    0.6147    0.6196     30000
weighted avg     0.6464    0.6148    0.6199     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997957521856163




CCA coefficients mean non-concern: 0.9997551221300537




Linear CKA concern: 0.9999472082815284




Linear CKA non-concern: 0.9999382625795429




Kernel CKA concern: 0.999898177983931




Kernel CKA non-concern: 0.9997595796095924




original model's perplexity




3.187649726867676




pruned model's perplexity




3.185988664627075




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2192




Precision: 0.6459, Recall: 0.6149, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5561    0.4706    0.5098      2992
           1     0.7023    0.5187    0.5967      2992
           2     0.6803    0.6414    0.6603      3012
           3     0.3415    0.6424    0.4459      2998
           4     0.7165    0.7642    0.7396      2973
           5     0.8390    0.7695    0.8027      3054
           6     0.6756    0.4036    0.5053      3003
           7     0.6242    0.6072    0.6156      3012
           8     0.5955    0.6888    0.6388      2982
           9     0.7281    0.6429    0.6828      2982

    accuracy                         0.6151     30000
   macro avg     0.6459    0.6149    0.6198     30000
weighted avg     0.6462    0.6151    0.6200     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.999697520520867




CCA coefficients mean non-concern: 0.9997210245041708




Linear CKA concern: 0.9998905880584283




Linear CKA non-concern: 0.9999112584798331




Kernel CKA concern: 0.999843355727957




Kernel CKA non-concern: 0.9996405610507534




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1827616691589355




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2193




Precision: 0.6458, Recall: 0.6148, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5553    0.4696    0.5089      2992
           1     0.6987    0.5217    0.5974      2992
           2     0.6796    0.6408    0.6596      3012
           3     0.3411    0.6428    0.4457      2998
           4     0.7192    0.7615    0.7397      2973
           5     0.8397    0.7685    0.8025      3054
           6     0.6748    0.4056    0.5067      3003
           7     0.6245    0.6062    0.6152      3012
           8     0.5951    0.6885    0.6384      2982
           9     0.7302    0.6425    0.6836      2982

    accuracy                         0.6149     30000
   macro avg     0.6458    0.6148    0.6198     30000
weighted avg     0.6461    0.6149    0.6200     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997446497630545




CCA coefficients mean non-concern: 0.9997569879560277




Linear CKA concern: 0.9996222771478309




Linear CKA non-concern: 0.9999389219166668




Kernel CKA concern: 0.9997491175746077




Kernel CKA non-concern: 0.9997701677877203




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1833832263946533




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2201




Precision: 0.6461, Recall: 0.6148, F1-Score: 0.6198




              precision    recall  f1-score   support

           0     0.5553    0.4699    0.5091      2992
           1     0.6997    0.5211    0.5973      2992
           2     0.6806    0.6418    0.6606      3012
           3     0.3405    0.6424    0.4451      2998
           4     0.7177    0.7629    0.7396      2973
           5     0.8412    0.7685    0.8032      3054
           6     0.6763    0.4036    0.5055      3003
           7     0.6243    0.6062    0.6151      3012
           8     0.5948    0.6891    0.6385      2982
           9     0.7304    0.6422    0.6834      2982

    accuracy                         0.6149     30000
   macro avg     0.6461    0.6148    0.6198     30000
weighted avg     0.6464    0.6149    0.6200     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998135412757466




CCA coefficients mean non-concern: 0.9997210631373982




Linear CKA concern: 0.9999621029801632




Linear CKA non-concern: 0.9999070166637012




Kernel CKA concern: 0.9998117317811426




Kernel CKA non-concern: 0.9996090642100298




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1856937408447266




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2198




Precision: 0.6463, Recall: 0.6150, F1-Score: 0.6200




              precision    recall  f1-score   support

           0     0.5561    0.4723    0.5108      2992
           1     0.6995    0.5221    0.5979      2992
           2     0.6810    0.6408    0.6603      3012
           3     0.3408    0.6421    0.4453      2998
           4     0.7151    0.7632    0.7384      2973
           5     0.8413    0.7688    0.8034      3054
           6     0.6771    0.4043    0.5063      3003
           7     0.6247    0.6072    0.6158      3012
           8     0.5956    0.6881    0.6386      2982
           9     0.7318    0.6415    0.6837      2982

    accuracy                         0.6152     30000
   macro avg     0.6463    0.6150    0.6200     30000
weighted avg     0.6466    0.6152    0.6203     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.999760902492138




CCA coefficients mean non-concern: 0.9997420419062482




Linear CKA concern: 0.9998665389936706




Linear CKA non-concern: 0.9999280131462084




Kernel CKA concern: 0.99984082840341




Kernel CKA non-concern: 0.9997605784518611




original model's perplexity




3.187649726867676




pruned model's perplexity




3.184856414794922




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2202




Precision: 0.6463, Recall: 0.6146, F1-Score: 0.6196




              precision    recall  f1-score   support

           0     0.5553    0.4682    0.5081      2992
           1     0.6994    0.5187    0.5957      2992
           2     0.6820    0.6408    0.6607      3012
           3     0.3401    0.6461    0.4456      2998
           4     0.7185    0.7625    0.7399      2973
           5     0.8399    0.7678    0.8023      3054
           6     0.6771    0.4036    0.5057      3003
           7     0.6260    0.6062    0.6160      3012
           8     0.5957    0.6891    0.6390      2982
           9     0.7288    0.6425    0.6829      2982

    accuracy                         0.6147     30000
   macro avg     0.6463    0.6146    0.6196     30000
weighted avg     0.6466    0.6147    0.6198     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9997891642273887




CCA coefficients mean non-concern: 0.9996950248014522




Linear CKA concern: 0.9999718839710341




Linear CKA non-concern: 0.9999032348947725




Kernel CKA concern: 0.9999147874639077




Kernel CKA non-concern: 0.9996680798021327




original model's perplexity




3.187649726867676




pruned model's perplexity




3.1858785152435303




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   ???

Loss: 1.2197




Precision: 0.6456, Recall: 0.6145, F1-Score: 0.6195




              precision    recall  f1-score   support

           0     0.5539    0.4706    0.5089      2992
           1     0.7001    0.5204    0.5970      2992
           2     0.6814    0.6411    0.6606      3012
           3     0.3409    0.6424    0.4455      2998
           4     0.7160    0.7622    0.7384      2973
           5     0.8409    0.7685    0.8031      3054
           6     0.6752    0.4049    0.5062      3003
           7     0.6243    0.6052    0.6146      3012
           8     0.5948    0.6881    0.6381      2982
           9     0.7290    0.6415    0.6825      2982

    accuracy                         0.6146     30000
   macro avg     0.6456    0.6145    0.6195     30000
weighted avg     0.6460    0.6146    0.6197     30000





0.19449301788000617




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9998292045293704




CCA coefficients mean non-concern: 0.9997552561305613




Linear CKA concern: 0.999950190268071




Linear CKA non-concern: 0.9999142666479581




Kernel CKA concern: 0.999835992004992




Kernel CKA non-concern: 0.9997505724926494




original model's perplexity




3.187649726867676




pruned model's perplexity




3.184286117553711




In [9]:
# Convert each per-concern evaluation result to a DataFrame, then hand the
# list to append_nth_row, which produces the single frame written below.
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)

# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists; otherwise the whole run's results would be lost
# at the very last step.
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Timestamped file name so repeated runs do not overwrite each other.
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
new_df.to_csv(f"{results_dir}/{csv_name}.csv", index=False)