In [1]:
import os
import sys

# Disable HuggingFace tokenizers' internal parallelism (commonly done to avoid
# fork-related warnings once DataLoader worker processes are used).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Put the directory five levels up on the import path so the `src.*` imports resolve.
sys.path.append("../../../../../")

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration.
name = "bert-4-128-yahoo"  # model/config identifier passed to Config
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): `checkpoint` is not referenced anywhere else in this notebook.
checkpoint = None
batch_size = 16  # batch size of the real train/valid/test loaders
num_workers = 4  # DataLoader worker processes for the real splits
num_samples = 16  # passed to SamplingDataset and get_similarity
ratio = 0.5  # pruning ratio for head_importance_prunning (8 of 16 heads in the logged run)
# NOTE(review): `seed` is never passed to Config or init_seed in this notebook —
# confirm it actually takes effect.
seed = 44

In [4]:
# Record and log the wall-clock time at which this run began.
script_start_time = datetime.now()
print("Script started at: " + script_start_time.strftime("%Y-%m-%d %H:%M:%S"))

Script started at: 2024-10-22 18:24:59


In [5]:
# Build the experiment config for this model name and device, then load the model it describes.
config = Config(name, device)
# Number of output classes, read from the loaded config (10 in the printed config for this run).
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Build the train/valid/test loaders; with do_cache=True this run loaded each split from cache.
(train_dataloader, valid_dataloader, test_dataloader) = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

# Load the cached embedding-based generated samples ("top1" file) for this model.
# NOTE(review): the tag "4_128-yahoo" is hard-coded separately from `name`; keep them in sync.
# The .pkl extension suggests pickle deserialization — only load caches you trust.
generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo", "4_128-yahoo_top1.pkl"
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
# Rename the generator's keys before the dict is handed to CustomEmbeddingDataset.
# Idempotent: re-running the cell after the keys were already renamed is a no-op
# instead of raising KeyError on the (already popped) old key.
for old_key, new_key in (
    ("example_list", "embeddings"),
    ("example_label", "labels"),
    ("attn_list", "attention_mask"),
):
    if old_key in generated:
        generated[new_key] = generated.pop(old_key)
    elif new_key not in generated:
        raise KeyError(f"generated dataset has neither {old_key!r} nor {new_key!r}")

In [9]:
# Wrap the generated samples and batch them 4 at a time (DataLoader's default: no shuffling).
generated_data = CustomEmbeddingDataset(generated)
generated_dataloder = DataLoader(generated_data, batch_size=4)  # (sic) spelling used by the sampling cell

In [10]:
# Optional baseline: evaluate the unpruned model on the test split.
# Disabled by default; set to True to compare against the pruned model.
evaluate_baseline = False
if evaluate_baseline:
    print("Evaluate the original model")
    # Separate name so the pruned model's `result` is never overwritten by the baseline.
    baseline_result = evaluate_model(model, config, test_dataloader)

In [11]:
# Re-seed immediately before building the sampler so the drawn samples are reproducible
# (assumes init_seed seeds the RNGs SamplingDataset uses — TODO confirm).
config.init_seed()
# NOTE(review): the positional arguments 200, False and 4 are opaque from here; check
# SamplingDataset's signature. The 4 presumably matches batch_size=4 of
# generated_dataloder — confirm.
all_samples = SamplingDataset(
    generated_dataloder,
    config,
    200,
    num_samples,
    False,
    4,
    resample=False,
)

In [12]:
result_list = []

# Prune a deep copy so `model` stays intact as the unpruned reference for the
# similarity and perplexity comparisons that follow.
module = copy.deepcopy(model)

# Prune attention heads of `module` using importance estimated on `all_samples`;
# `ratio` sets how many (8 of 16 heads in the logged run). The return value is not
# used, so pruning is presumably applied in place — TODO confirm.
head_importance_prunning(module, config, all_samples, ratio)

print("Evaluate the pruned model")
result = evaluate_model(module, config, test_dataloader, verbose=True)
result_list.append(result)

Total heads to prune: 8




tensor([[0.3952, 0.6048, 0.4106, 0.4775],
        [0.4851, 0.5362, 0.4935, 0.4638],
        [0.5171, 0.5280, 0.4731, 0.4720],
        [0.4691, 0.5217, 0.5309, 0.4839]])




{(0, 0), (0, 3), (3, 0), (2, 3), (0, 2), (3, 3), (2, 2), (1, 3)}




Evaluate the pruned model




Evaluating the model:   0%|                                                                                   …

In [13]:
# Compare the original and pruned models once per class ("concern"); the printed
# output reports CCA, linear CKA and kernel CKA for concern vs. non-concern samples.
for concern in range(num_labels):
    # Re-seed every iteration so each class's comparison starts from the same RNG state
    # (assumes init_seed resets the seeds get_similarity samples with — TODO confirm).
    config.init_seed()
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)

adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9887567807311284




CCA coefficients mean non-concern: 0.9884964257161947




Linear CKA concern: 0.9390688146409307




Linear CKA non-concern: 0.9189985411284896




Kernel CKA concern: 0.8661325697048274




Kernel CKA non-concern: 0.8488119850828026




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9896527722768762




CCA coefficients mean non-concern: 0.9885099845348712




Linear CKA concern: 0.9212903100407269




Linear CKA non-concern: 0.924198415624473




Kernel CKA concern: 0.8406480592119452




Kernel CKA non-concern: 0.8542550895524943




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9879094047209239




CCA coefficients mean non-concern: 0.9888503356860221




Linear CKA concern: 0.9255513914676455




Linear CKA non-concern: 0.9227500930601016




Kernel CKA concern: 0.8523734788871034




Kernel CKA non-concern: 0.8511074048412212




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9880189564959944




CCA coefficients mean non-concern: 0.9890926962327664




Linear CKA concern: 0.910878475654154




Linear CKA non-concern: 0.9221345705060212




Kernel CKA concern: 0.8318423365193048




Kernel CKA non-concern: 0.8545246718386005




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9768954218456946




CCA coefficients mean non-concern: 0.9894055908663465




Linear CKA concern: 0.9181368349510114




Linear CKA non-concern: 0.9239240788806539




Kernel CKA concern: 0.843083198570638




Kernel CKA non-concern: 0.8540694729791902




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9837375714323495




CCA coefficients mean non-concern: 0.9890033627273423




Linear CKA concern: 0.8829088841549837




Linear CKA non-concern: 0.9310683889037439




Kernel CKA concern: 0.8295732177061432




Kernel CKA non-concern: 0.8617068958037336




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9895409374025449




CCA coefficients mean non-concern: 0.9881253831956665




Linear CKA concern: 0.9254078611527278




Linear CKA non-concern: 0.9256800542497663




Kernel CKA concern: 0.8401938546336292




Kernel CKA non-concern: 0.8607103122946325




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9866455195880126




CCA coefficients mean non-concern: 0.988505136641098




Linear CKA concern: 0.9488061376799213




Linear CKA non-concern: 0.9192231120475051




Kernel CKA concern: 0.9018993337205834




Kernel CKA non-concern: 0.8506991337752966




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9863615134077472




CCA coefficients mean non-concern: 0.9885444399297288




Linear CKA concern: 0.9260223727154764




Linear CKA non-concern: 0.9157444503950892




Kernel CKA concern: 0.8185601456568675




Kernel CKA non-concern: 0.847695727036734




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9892400929208482




CCA coefficients mean non-concern: 0.9886331458566466




Linear CKA concern: 0.9168541520684045




Linear CKA non-concern: 0.9196585513612614




Kernel CKA concern: 0.848180194433529




Kernel CKA non-concern: 0.8523217329397217




In [14]:
# Report the pruned model's sparsity (get_sparsity prints its own summary).
get_sparsity(module)

# Validation perplexity before and after pruning. Keep both values in named variables
# so they can be reused later; this also stops the cell from echoing the pruned
# perplexity a second time after get_perplexity has already printed it.
print("original model's perplexity")
original_perplexity = get_perplexity(model, valid_dataloader, config)
print("pruned model's perplexity")
pruned_perplexity = get_perplexity(module, valid_dataloader, config)

0.09816818431699743




{'bert.encoder.layer.0.attention.self.query.weight': 0.75, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.75, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.75, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.75, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.0, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.0, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.25, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.25, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.25, 'bert.encoder.layer.1.attention.self.value.bias': 0.0, 'bert.en




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.773810386657715




3.773810386657715

In [15]:
# Convert each evaluation report to a DataFrame and save the (single) pruned-model
# report as a timestamped CSV under results/.
df_list = [report_to_df(report) for report in result_list]
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
new_df = df_list[0]
# to_csv does not create missing parent directories, so ensure results/ exists first.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-10-22_18-27-08




Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.4816,0.5338,0.5063,2992
1,1,0.6741,0.3352,0.4478,2992
2,2,0.6663,0.5422,0.5978,3012
3,3,0.2875,0.6778,0.4037,2998
4,4,0.8261,0.588,0.687,2973
5,5,0.7896,0.7731,0.7813,3054
6,6,0.6917,0.3863,0.4957,3003
7,7,0.5116,0.6823,0.5847,3012
8,8,0.7064,0.5044,0.5885,2982
9,9,0.7253,0.6683,0.6956,2982
