In [1]:
import os
import sys

# Presumably disables HuggingFace tokenizers' internal parallelism so that the
# DataLoader worker processes created later (num_workers > 0) do not trigger
# fork-related warnings; must be set before tokenizers are imported.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Make the project root importable so the `src.*` imports below resolve.
# NOTE: the path is relative to the current working directory, so the notebook
# has to be started from its own folder.
sys.path.append("../../../../../")

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration — every knob a reader may want to tune.
name = "YahooAnswersTopics"  # dataset key passed to Config
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced anywhere in this notebook
batch_size = 16  # DataLoader batch size for all splits
num_workers = 4  # DataLoader worker processes
num_samples = 16  # passed to SamplingDataset and get_similarity
ratio = 0.6  # fraction of attention heads removed by head_importance_prunning
# NOTE(review): `seed` is never used — config.init_seed() is called without it.
# Confirm which seed init_seed applies, or wire this value through.
seed = 44

In [4]:
# Record the wall-clock start of the run so its duration/timestamp can be traced.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-23 01:48:44


In [5]:
# Build the experiment configuration for the selected dataset and device.
config = Config(name, device)
# Number of target classes; the per-concern similarity loop iterates over it.
num_labels = config.config["num_labels"]
# Load the original (unpruned) classifier; later cells prune a deep copy of it
# and compare the two.
model = load_model(config)

Loading the model.




{'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'num_labels': 10, 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Loader settings shared by the train/valid/test splits; do_cache=True reuses
# the pickled splits when they are already present.
loader_options = dict(batch_size=batch_size, num_workers=num_workers, do_cache=True)
train_dataloader, valid_dataloader, test_dataloader = load_data(config, **loader_options)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{'config_name': 'yahoo_answers_topics', 'features': {'first_column': 'question_title', 'second_column': 'topic'}, 'path': 'yahoo_answers_topics'}




In [7]:
# Baseline evaluation of the unpruned model on the full test set is expensive,
# so it is opt-in: set the flag to True to reproduce the reference numbers.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    original_result = evaluate_model(model, config, test_dataloader)

In [8]:
# Re-seed first so the sampled subset is presumably reproducible across runs.
config.init_seed()
# NOTE(review): 200, False and 4 are passed positionally; their meaning is
# defined by SamplingDataset's signature — give them names once confirmed.
sampling_args = (train_dataloader, config, 200, num_samples, False, 4)
all_samples = SamplingDataset(*sampling_args, resample=False)

In [9]:
# One evaluation report per pruned configuration (a single one in this notebook).
result_list = []

# Prune a deep copy so `model` stays available as the unpruned reference for
# the similarity and perplexity comparisons that follow.
module = copy.deepcopy(model)

# The return value is discarded: `module` itself is treated as the pruned model
# afterwards. `ratio` is the fraction of attention heads to remove.
head_importance_prunning(module, config, all_samples, ratio)

print("Evaluate the pruned model")
result = evaluate_model(module, config, test_dataloader, verbose=True)
result_list.append(result)

Total heads to prune: 86




tensor([[0.6076, 0.4772, 0.3568, 0.3846, 0.4560, 0.3766, 0.5436, 0.4334, 0.6432,
         0.5125, 0.4895, 0.4496],
        [0.5521, 0.6671, 0.3830, 0.3551, 0.5890, 0.4487, 0.3329, 0.3896, 0.4603,
         0.4626, 0.4191, 0.5129],
        [0.7132, 0.4412, 0.3058, 0.3228, 0.2868, 0.3049, 0.4221, 0.3026, 0.3443,
         0.6358, 0.3968, 0.3584],
        [0.3185, 0.2584, 0.2544, 0.3329, 0.4143, 0.7929, 0.2071, 0.3755, 0.3678,
         0.5597, 0.2876, 0.3968],
        [0.5160, 0.3828, 0.3549, 0.6017, 0.6927, 0.4596, 0.3273, 0.3173, 0.3073,
         0.3367, 0.4392, 0.6130],
        [0.4848, 0.3186, 0.5481, 0.2955, 0.5394, 0.2806, 0.2502, 0.4210, 0.2485,
         0.7515, 0.5052, 0.2697],
        [0.6288, 0.4500, 0.3530, 0.3226, 0.4115, 0.5442, 0.6367, 0.3214, 0.4035,
         0.4223, 0.3378, 0.6786],
        [0.5168, 0.3592, 0.4791, 0.3492, 0.6916, 0.3457, 0.3678, 0.4930, 0.3084,
         0.3526, 0.4262, 0.6160],
        [0.2873, 0.2877, 0.7601, 0.3070, 0.4546, 0.4570, 0.4629, 0.2402, 0.2399,




{(4, 9), (5, 1), (8, 0), (8, 9), (9, 8), (11, 5), (2, 2), (0, 5), (2, 11), (6, 2), (7, 1), (4, 2), (3, 6), (5, 3), (8, 11), (9, 10), (11, 7), (2, 4), (6, 4), (7, 3), (3, 8), (5, 5), (9, 3), (11, 9), (1, 10), (7, 5), (3, 1), (3, 10), (5, 7), (11, 2), (0, 2), (1, 3), (3, 3), (9, 7), (10, 8), (7, 9), (9, 0), (5, 11), (9, 9), (10, 1), (10, 10), (1, 7), (2, 6), (3, 7), (4, 6), (10, 3), (2, 8), (6, 8), (3, 0), (5, 6), (4, 8), (8, 8), (10, 5), (1, 2), (2, 10), (6, 10), (3, 2), (4, 1), (3, 11), (8, 1), (8, 10), (10, 7), (11, 6), (2, 3), (6, 3), (3, 4), (8, 3), (10, 0), (10, 9), (9, 11), (11, 8), (2, 5), (1, 6), (10, 2), (9, 4), (11, 1), (11, 10), (2, 7), (6, 7), (7, 6), (4, 7), (5, 8), (8, 7), (11, 3), (0, 3), (7, 8)}




Evaluate the pruned model




Evaluating the model:   0%|                                                                                   …

In [10]:
# Compare original vs. pruned representations one class ("concern") at a time.
# Re-seeding on every iteration makes each concern's (presumably random)
# sampling independent of how many concerns were processed before it.
for concern_id in range(num_labels):
    config.init_seed()
    get_similarity(model, module, valid_dataloader, concern_id, num_samples, config)

adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.7302194758469096




CCA coefficients mean non-concern: 0.7395325360548229




Linear CKA concern: 0.8425841945763825




Linear CKA non-concern: 0.8020152350505368




Kernel CKA concern: 0.7466249508156757




Kernel CKA non-concern: 0.720978060620086




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.727360427323423




CCA coefficients mean non-concern: 0.7408438810728396




Linear CKA concern: 0.8203348742067802




Linear CKA non-concern: 0.8009486077898198




Kernel CKA concern: 0.7396381369828943




Kernel CKA non-concern: 0.71529511325207




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.7230748078080033




CCA coefficients mean non-concern: 0.7400380371487371




Linear CKA concern: 0.822226322774814




Linear CKA non-concern: 0.8103312778329029




Kernel CKA concern: 0.7815471935823405




Kernel CKA non-concern: 0.7012176442685276




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.7329866574753702




CCA coefficients mean non-concern: 0.7409825815267413




Linear CKA concern: 0.7767719422332531




Linear CKA non-concern: 0.796705648510808




Kernel CKA concern: 0.689160620597657




Kernel CKA non-concern: 0.7247660483362663




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.7278364871563751




CCA coefficients mean non-concern: 0.7391290740403247




Linear CKA concern: 0.77112928314459




Linear CKA non-concern: 0.8050385274363655




Kernel CKA concern: 0.6971254480084678




Kernel CKA non-concern: 0.7205289618312514




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.717682398383796




CCA coefficients mean non-concern: 0.7378729191738227




Linear CKA concern: 0.8183213033713443




Linear CKA non-concern: 0.8175713450148253




Kernel CKA concern: 0.7704531081222264




Kernel CKA non-concern: 0.7261590935501472




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.7374376732025133




CCA coefficients mean non-concern: 0.7399377984907329




Linear CKA concern: 0.853058807121265




Linear CKA non-concern: 0.8002192327217951




Kernel CKA concern: 0.7231598618782634




Kernel CKA non-concern: 0.7238653051315874




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.7266811720669368




CCA coefficients mean non-concern: 0.7402582482218842




Linear CKA concern: 0.8108879965148599




Linear CKA non-concern: 0.7969821873240286




Kernel CKA concern: 0.7385144407864315




Kernel CKA non-concern: 0.7258952860173641




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.7302790636628369




CCA coefficients mean non-concern: 0.7389463827389514




Linear CKA concern: 0.8308979147208202




Linear CKA non-concern: 0.7935597022093018




Kernel CKA concern: 0.7705774548532986




Kernel CKA non-concern: 0.7175969471143573




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.7119873789346937




CCA coefficients mean non-concern: 0.7413208617014051




Linear CKA concern: 0.7269541231733093




Linear CKA non-concern: 0.7940429363973663




Kernel CKA concern: 0.6529498297920183




Kernel CKA non-concern: 0.7202686130155923




In [11]:
# get_sparsity reports its figures itself; its return value is not needed here.
get_sparsity(module)

# Bind both perplexities to names so they can be compared later, and so the
# cell does not echo the pruned value a second time as its display output.
print("original model's perplexity")
original_perplexity = get_perplexity(model, valid_dataloader, config)
print("pruned model's perplexity")
pruned_perplexity = get_perplexity(module, valid_dataloader, config)

0.1974900871779841




{'bert.encoder.layer.0.attention.self.query.weight': 0.25, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.25, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.25, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.25, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.0, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.0, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.4166666666666667, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.4166666666666667, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.4166666666666667, 'bert.encoder.layer.1




original model's perplexity




2.6398401260375977




pruned model's perplexity




2.9050800800323486




2.9050800800323486

In [12]:
# Convert each evaluation report into a DataFrame and persist it under a
# timestamped file name. Only one pruning configuration is evaluated above,
# so the first frame is the one saved and displayed.
df_list = [report_to_df(report) for report in result_list]
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# DataFrame.to_csv does not create missing parent directories.
os.makedirs("results", exist_ok=True)
new_df = df_list[0]
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-10-23_02-02-45




Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.5927,0.5194,0.5536,2992
1,1,0.7356,0.5979,0.6597,2992
2,2,0.7122,0.7294,0.7207,3012
3,3,0.5126,0.4827,0.4972,2998
4,4,0.8263,0.7296,0.7749,2973
5,5,0.8909,0.7728,0.8276,3054
6,6,0.4821,0.4805,0.4813,3003
7,7,0.4998,0.7752,0.6078,3012
8,8,0.6534,0.7136,0.6822,2982
9,9,0.7424,0.7076,0.7246,2982
