In [1]:
import os
import sys

# Make the project's `src` package importable. The path is relative to the
# notebook's working directory (presumably four levels below the repository
# root — confirm if the notebook is moved).
# Guarded so that re-running this cell does not stack duplicate sys.path entries.
_repo_root = "../../../../"
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

# Must be set before any tokenizer is created: turns off the Hugging Face
# tokenizers library's internal parallelism (presumably to avoid its fork
# warning when DataLoader worker processes are used).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import prune_concern_identification
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# ---- Run configuration: everything a reader might want to tune ----

# Model/config identifier handed to Config(name, device).
name = "bert-4-128-yahoo"
# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA instead of failing when the model is moved to the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): not referenced by any visible cell — kept for compatibility.
checkpoint = None
# Passed to load_data() when building the dataloaders.
batch_size = 16
num_workers = 4
# Sample count passed to SamplingDataset and get_similarity.
num_samples = 16
# Ratio passed to head_importance_prunning and, as sparsity_ratio, to
# prune_concern_identification.
ratio = 0.6
# NOTE(review): not referenced by any visible cell; seeding is done through
# config.init_seed() — confirm whether this value is meant to be passed there.
seed = 44
# Layer-name filters for prune_concern_identification (same values the pruning
# loop passes as include_layers / exclude_layers).
include_layers = ["intermediate", "output"]
exclude_layers = [
    "attention",
]

In [4]:
# Remember when the run began and echo it as "YYYY-MM-DD HH:MM:SS".
script_start_time = datetime.now()
print("Script started at: " + script_start_time.strftime("%Y-%m-%d %H:%M:%S"))

Script started at: 2024-10-15 17:27:56


In [5]:
# Build the run configuration for the model identified by `name` on `device`.
config = Config(name, device)
# Label count read from the config mapping.
# NOTE(review): the pruning loop iterates over `config.num_labels` rather than
# this variable — confirm the two always agree.
num_labels = config.config["num_labels"]
# Load the model described by `config`; this is the unpruned reference model
# that the pruning loop deep-copies for every concern.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Load the train / validation / test dataloaders for the dataset in `config`,
# using the batch size and worker count from the configuration cell.
# do_cache=True presumably reuses cached splits (the saved output reports
# train/valid/test .pkl files "loaded from cache") — verify in src.utils.load.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Optional baseline (currently disabled): uncomment the two lines below to
# evaluate the unpruned model on the test dataloader before pruning.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)

In [8]:
# For every concern (class label): prune a fresh copy of the model, evaluate
# it on the test set, and compare it against the unpruned model.
# The evaluation results are collected in `result_list` for the export cell.
result_list = []

for concern in range(config.num_labels):
    # Re-seed at the start of every iteration so each concern is processed
    # from the same RNG state.
    config.init_seed()

    # Three sample sets drawn from the training dataloader.
    # NOTE(review): the fifth positional argument (True / False) presumably
    # selects samples of the concern class vs. the other classes, and the
    # sixth (4) is not explained by anything visible here — confirm both
    # against the SamplingDataset signature.
    positive_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        True,
        4,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train_dataloader,
        config,
        concern,
        num_samples,
        False,
        4,
        resample=False,
    )
    # NOTE(review): 200 is not a valid label index (labels are
    # 0..config.num_labels-1); presumably a sentinel meaning "no specific
    # concern" so samples come from all classes — confirm in SamplingDataset.
    all_samples = SamplingDataset(
        train_dataloader,
        config,
        200,
        num_samples,
        False,
        4,
        resample=False,
    )

    # Work on a private copy so the reference `model` stays unpruned.
    module = copy.deepcopy(model)

    # Step 1: attention-head pruning driven by the concern-agnostic samples.
    head_importance_prunning(module, config, all_samples, ratio)

    # Step 2: concern-specific weight pruning.
    # The layer filters now come from the configuration cell instead of being
    # re-typed here, so editing that cell actually takes effect. Fresh copies
    # are passed each iteration, matching the original per-iteration literals
    # in case the callee mutates its arguments.
    # NOTE(review): "unstructed" looks like a misspelling of "unstructured",
    # but it must match whatever prune_concern_identification checks for, so
    # it is left untouched.
    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=list(include_layers),
        exclude_layers=list(exclude_layers),
        sparsity_ratio=ratio,
        keep_dim=True,
        method="unstructed",
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader)
    result_list.append(result)
    get_sparsity(module)

    # Representation similarity between the original and the pruned model.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)

    # NOTE(review): the original model's perplexity is recomputed on every
    # iteration although the saved output shows the same value each time; it
    # could be hoisted out of the loop if get_perplexity has no RNG side
    # effects. Kept here to preserve the exact output order.
    print("original model's perplexity")
    get_perplexity(model, valid_dataloader, config)
    print("pruned model's perplexity")
    get_perplexity(module, valid_dataloader, config)

Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 0




Evaluating the model:   0%|                                                                                   …

Loss: 1.3049




Precision: 0.6508, Recall: 0.5876, F1-Score: 0.6000




              precision    recall  f1-score   support

           0     0.5151    0.5174    0.5163      2992
           1     0.6791    0.4258    0.5234      2992
           2     0.6894    0.5813    0.6308      3012
           3     0.2946    0.6871    0.4124      2998
           4     0.8066    0.6535    0.7220      2973
           5     0.8569    0.7394    0.7938      3054
           6     0.7041    0.3843    0.4972      3003
           7     0.5387    0.6886    0.6045      3012
           8     0.6512    0.5922    0.6203      2982
           9     0.7720    0.6063    0.6792      2982

    accuracy                         0.5879     30000
   macro avg     0.6508    0.5876    0.6000     30000
weighted avg     0.6510    0.5879    0.6002     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9911543572044084




CCA coefficients mean non-concern: 0.9915058831713319




Linear CKA concern: 0.9802923662528804




Linear CKA non-concern: 0.9787497380832129




Kernel CKA concern: 0.9429136173146264




Kernel CKA non-concern: 0.9453899563905265




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.5301194190979004




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 1




Evaluating the model:   0%|                                                                                   …

Loss: 1.3061




Precision: 0.6518, Recall: 0.5862, F1-Score: 0.5992




              precision    recall  f1-score   support

           0     0.5218    0.5157    0.5187      2992
           1     0.6786    0.4191    0.5182      2992
           2     0.6895    0.5810    0.6306      3012
           3     0.2899    0.6918    0.4085      2998
           4     0.8064    0.6529    0.7216      2973
           5     0.8614    0.7367    0.7942      3054
           6     0.7054    0.3820    0.4956      3003
           7     0.5432    0.6846    0.6058      3012
           8     0.6500    0.5922    0.6198      2982
           9     0.7713    0.6063    0.6789      2982

    accuracy                         0.5865     30000
   macro avg     0.6518    0.5862    0.5992     30000
weighted avg     0.6519    0.5865    0.5994     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9909392467526921




CCA coefficients mean non-concern: 0.991688075781389




Linear CKA concern: 0.9754069513594423




Linear CKA non-concern: 0.9808362186210896




Kernel CKA concern: 0.937689910167579




Kernel CKA non-concern: 0.9489889799630862




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.537315607070923




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 2




Evaluating the model:   0%|                                                                                   …

Loss: 1.3051




Precision: 0.6504, Recall: 0.5873, F1-Score: 0.5999




              precision    recall  f1-score   support

           0     0.5108    0.5194    0.5151      2992
           1     0.6835    0.4258    0.5247      2992
           2     0.6870    0.5830    0.6307      3012
           3     0.2933    0.6855    0.4108      2998
           4     0.8046    0.6535    0.7212      2973
           5     0.8586    0.7374    0.7934      3054
           6     0.7018    0.3839    0.4963      3003
           7     0.5464    0.6826    0.6069      3012
           8     0.6482    0.5939    0.6199      2982
           9     0.7699    0.6083    0.6797      2982

    accuracy                         0.5876     30000
   macro avg     0.6504    0.5873    0.5999     30000
weighted avg     0.6506    0.5876    0.6001     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9908683854337911




CCA coefficients mean non-concern: 0.9917204848065599




Linear CKA concern: 0.9759418521654104




Linear CKA non-concern: 0.9798175618886713




Kernel CKA concern: 0.9357039127568267




Kernel CKA non-concern: 0.9463497152280791




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.53151798248291




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 3




Evaluating the model:   0%|                                                                                   …

Loss: 1.3032




Precision: 0.6504, Recall: 0.5884, F1-Score: 0.6006




              precision    recall  f1-score   support

           0     0.5108    0.5204    0.5156      2992
           1     0.6772    0.4285    0.5249      2992
           2     0.6913    0.5813    0.6316      3012
           3     0.2957    0.6828    0.4127      2998
           4     0.8068    0.6589    0.7254      2973
           5     0.8587    0.7384    0.7940      3054
           6     0.7008    0.3853    0.4972      3003
           7     0.5406    0.6899    0.6062      3012
           8     0.6512    0.5899    0.6190      2982
           9     0.7706    0.6083    0.6799      2982

    accuracy                         0.5886     30000
   macro avg     0.6504    0.5884    0.6006     30000
weighted avg     0.6506    0.5886    0.6009     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9918658928679506




CCA coefficients mean non-concern: 0.991713188113922




Linear CKA concern: 0.9692673915073483




Linear CKA non-concern: 0.9793446158616146




Kernel CKA concern: 0.9192270623806477




Kernel CKA non-concern: 0.9475303675197717




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.5237159729003906




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 4




Evaluating the model:   0%|                                                                                   …

Loss: 1.3025




Precision: 0.6507, Recall: 0.5882, F1-Score: 0.6006




              precision    recall  f1-score   support

           0     0.5036    0.5321    0.5175      2992
           1     0.6828    0.4281    0.5263      2992
           2     0.6911    0.5793    0.6303      3012
           3     0.2949    0.6828    0.4119      2998
           4     0.8059    0.6566    0.7236      2973
           5     0.8565    0.7407    0.7944      3054
           6     0.7032    0.3826    0.4956      3003
           7     0.5491    0.6799    0.6075      3012
           8     0.6497    0.5895    0.6181      2982
           9     0.7698    0.6100    0.6806      2982

    accuracy                         0.5884     30000
   macro avg     0.6507    0.5882    0.6006     30000
weighted avg     0.6509    0.5884    0.6008     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9810618488190375




CCA coefficients mean non-concern: 0.9924015718137261




Linear CKA concern: 0.945079363386222




Linear CKA non-concern: 0.9797734983204749




Kernel CKA concern: 0.8895168606082856




Kernel CKA non-concern: 0.9474810016759263




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.5218703746795654




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 5




Evaluating the model:   0%|                                                                                   …

Loss: 1.3031




Precision: 0.6517, Recall: 0.5889, F1-Score: 0.6013




              precision    recall  f1-score   support

           0     0.5162    0.5227    0.5194      2992
           1     0.6770    0.4315    0.5270      2992
           2     0.6898    0.5847    0.6329      3012
           3     0.2957    0.6858    0.4132      2998
           4     0.8090    0.6566    0.7248      2973
           5     0.8574    0.7384    0.7935      3054
           6     0.7051    0.3846    0.4977      3003
           7     0.5397    0.6899    0.6057      3012
           8     0.6506    0.5912    0.6195      2982
           9     0.7771    0.6033    0.6793      2982

    accuracy                         0.5891     30000
   macro avg     0.6517    0.5889    0.6013     30000
weighted avg     0.6519    0.5891    0.6015     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9855847358740982




CCA coefficients mean non-concern: 0.991419521087833




Linear CKA concern: 0.9154024572750052




Linear CKA non-concern: 0.9829014616289268




Kernel CKA concern: 0.8615465007218519




Kernel CKA non-concern: 0.952795121794629




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.522143602371216




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 6




Evaluating the model:   0%|                                                                                   …

Loss: 1.3032




Precision: 0.6501, Recall: 0.5882, F1-Score: 0.6005




              precision    recall  f1-score   support

           0     0.5104    0.5227    0.5165      2992
           1     0.6782    0.4275    0.5244      2992
           2     0.6900    0.5800    0.6302      3012
           3     0.2958    0.6855    0.4133      2998
           4     0.8069    0.6549    0.7230      2973
           5     0.8571    0.7387    0.7935      3054
           6     0.6999    0.3859    0.4975      3003
           7     0.5447    0.6859    0.6072      3012
           8     0.6503    0.5899    0.6186      2982
           9     0.7681    0.6110    0.6806      2982

    accuracy                         0.5885     30000
   macro avg     0.6501    0.5882    0.6005     30000
weighted avg     0.6503    0.5885    0.6007     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9914330756648949




CCA coefficients mean non-concern: 0.9910620273260425




Linear CKA concern: 0.9771180742593665




Linear CKA non-concern: 0.9813122133537491




Kernel CKA concern: 0.9362932978267068




Kernel CKA non-concern: 0.9499231626755472




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.5251379013061523




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 7




Evaluating the model:   0%|                                                                                   …

Loss: 1.3028




Precision: 0.6500, Recall: 0.5891, F1-Score: 0.6012




              precision    recall  f1-score   support

           0     0.5125    0.5267    0.5195      2992
           1     0.6752    0.4328    0.5275      2992
           2     0.6914    0.5823    0.6322      3012
           3     0.2979    0.6795    0.4142      2998
           4     0.8082    0.6562    0.7243      2973
           5     0.8567    0.7400    0.7941      3054
           6     0.6990    0.3866    0.4979      3003
           7     0.5371    0.6906    0.6042      3012
           8     0.6509    0.5902    0.6191      2982
           9     0.7716    0.6060    0.6788      2982

    accuracy                         0.5894     30000
   macro avg     0.6500    0.5891    0.6012     30000
weighted avg     0.6502    0.5894    0.6014     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9895326495880511




CCA coefficients mean non-concern: 0.9911236863112813




Linear CKA concern: 0.9729818053545187




Linear CKA non-concern: 0.97868851378274




Kernel CKA concern: 0.9251733429857415




Kernel CKA non-concern: 0.9447312722710955




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.522033214569092




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 8




Evaluating the model:   0%|                                                                                   …

Loss: 1.3059




Precision: 0.6507, Recall: 0.5872, F1-Score: 0.5999




              precision    recall  f1-score   support

           0     0.5131    0.5180    0.5155      2992
           1     0.6784    0.4251    0.5227      2992
           2     0.6887    0.5833    0.6317      3012
           3     0.2919    0.6861    0.4095      2998
           4     0.8068    0.6502    0.7201      2973
           5     0.8579    0.7394    0.7942      3054
           6     0.7029    0.3830    0.4958      3003
           7     0.5494    0.6809    0.6082      3012
           8     0.6474    0.5973    0.6213      2982
           9     0.7700    0.6087    0.6799      2982

    accuracy                         0.5875     30000
   macro avg     0.6507    0.5872    0.5999     30000
weighted avg     0.6509    0.5875    0.6001     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9888087601313321




CCA coefficients mean non-concern: 0.9910647527769108




Linear CKA concern: 0.96644032356303




Linear CKA non-concern: 0.9783955733683007




Kernel CKA concern: 0.905875703874428




Kernel CKA non-concern: 0.9456430875023084




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.5362589359283447




Total heads to prune: 6




{(0, 0), (0, 3), (3, 0), (0, 2), (2, 2), (1, 0)}




Evaluate the pruned model 9




Evaluating the model:   0%|                                                                                   …

Loss: 1.3037




Precision: 0.6506, Recall: 0.5878, F1-Score: 0.6003




              precision    recall  f1-score   support

           0     0.5144    0.5177    0.5161      2992
           1     0.6760    0.4295    0.5252      2992
           2     0.6893    0.5827    0.6315      3012
           3     0.2943    0.6875    0.4121      2998
           4     0.8075    0.6492    0.7197      2973
           5     0.8567    0.7397    0.7939      3054
           6     0.7017    0.3846    0.4969      3003
           7     0.5450    0.6859    0.6074      3012
           8     0.6505    0.5922    0.6200      2982
           9     0.7702    0.6093    0.6804      2982

    accuracy                         0.5881     30000
   macro avg     0.6506    0.5878    0.6003     30000
weighted avg     0.6507    0.5881    0.6006     30000





0.5422258305634156




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9912079749505229




CCA coefficients mean non-concern: 0.9912423624294632




Linear CKA concern: 0.9637031063511854




Linear CKA non-concern: 0.9793986912835758




Kernel CKA concern: 0.9125069108447312




Kernel CKA non-concern: 0.9475906663623236




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.527446985244751




In [9]:
# Convert every per-concern evaluation result into a DataFrame and combine
# them with append_nth_row (see src.utils.helper for the exact row selection).
df_list = [report_to_df(report) for report in result_list]
new_df = append_nth_row(df_list)

# Timestamped file name so successive runs never overwrite each other.
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Create the output directory first: to_csv raises if "results/" is missing,
# which would discard the results of the whole pruning run.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)