In [1]:
import os
import sys

# Make the directory five levels up importable; the `src.*` imports in the
# next cell rely on it (presumably the repository root - confirm).
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
_REPO_ROOT = "../../../../../"
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

# Must be set before any tokenizer is created. Presumably silences the
# Hugging Face tokenizers fork/parallelism warning when DataLoader workers
# are used - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Run parameters for this pruning experiment. Later cells read these names
# directly, so the names themselves must stay stable.
name = "bert-4-128-yahoo"  # model/config identifier handed to Config
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA instead of failing when the model is loaded.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced by any cell visible here
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # forwarded to SamplingDataset and get_similarity
ratio = 0.6  # forwarded to head_importance_prunning
# NOTE(review): `seed` is never passed anywhere in the visible cells;
# config.init_seed() is called without arguments - confirm it is honoured.
seed = 44

In [4]:
# Capture the wall-clock start of the run and announce it.
script_start_time = datetime.now()
_start_stamp = script_start_time.strftime("%Y-%m-%d %H:%M:%S")
print("Script started at: " + _start_stamp)

Script started at: 2024-10-22 18:27:11


In [5]:
# Build the run configuration from the model identifier and target device.
config = Config(name, device)
# Number of target classes, read from the loaded configuration; the
# similarity cell later iterates over range(num_labels).
num_labels = config.config["num_labels"]
# Load the model described by `config`. This instance is kept as the
# reference: pruning is applied to a deep copy of it in a later cell.
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-4-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Fetch the train / validation / test loaders for the configured dataset.
# With do_cache=True the log below shows the three splits being read from
# cached .pkl files.
_loaders = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)
train_dataloader, valid_dataloader, test_dataloader = _loaders

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
from src.utils.load import load_cache
from src.utils.data_class import CustomEmbeddingDataset
from torch.utils.data import DataLoader

# Load the pre-generated, embedding-based samples for this model from the
# project cache.
# NOTE(review): the .pkl extension suggests a pickle file; unpickling can
# execute arbitrary code, so only load caches from a trusted source.
_generated_dir = "datasets/generated_dataset/embedding_based/4_128-yahoo"
_generated_file = "4_128-yahoo_top1.pkl"
generated = load_cache(_generated_dir, _generated_file)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
# Rename the cached keys, in place, to the names the next cell hands to
# CustomEmbeddingDataset.
# The plain `pop`-based version raised KeyError when this cell was run a
# second time, because the old keys were already gone. Here a key that has
# already been renamed is skipped, while a key that is missing under both
# names still raises KeyError, as before.
_KEY_RENAMES = {
    "example_list": "embeddings",
    "example_label": "labels",
    "attn_list": "attention_mask",
}
for _old_key, _new_key in _KEY_RENAMES.items():
    if _old_key in generated:
        generated[_new_key] = generated.pop(_old_key)
    elif _new_key not in generated:
        raise KeyError(_old_key)

In [9]:
# Wrap the renamed cache entries in the project dataset class and expose
# them through a DataLoader.
# The (misspelled) name `generated_dataloder` is read by the sampling cell,
# so it is kept as is.
# NOTE(review): batch_size=4 is a literal here and differs from the
# notebook-level `batch_size` - confirm this is intentional.
generated_data = CustomEmbeddingDataset(generated)
generated_dataloder = DataLoader(dataset=generated_data, batch_size=4)

In [10]:
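# Optional baseline: evaluate the original (unpruned) model.
# Disabled in this run; uncomment the two lines below to enable it.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)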

In [11]:
# Re-initialise the seed through the project Config helper before sampling
# (presumably resets the RNG state so the draw is reproducible - confirm).
config.init_seed()
# Build the sample set that head-importance pruning consumes in the next
# cell, drawing from the generated-embedding loader.
# NOTE(review): the positional arguments 200, False and 4 are unlabeled
# here; check the SamplingDataset signature before changing them, and
# consider passing them by keyword.
all_samples = SamplingDataset(
    generated_dataloder,
    config,
    200,
    num_samples,
    False,
    4,
    resample=False,
)

In [12]:
# Evaluation reports gathered in this run; the export cell reads this list.
result_list = []

# Prune a deep copy so the original `model` stays available for the
# similarity and perplexity comparisons further down. The pruning call
# returns nothing that is used here; it modifies `module` in place.
module = copy.deepcopy(model)
head_importance_prunning(module, config, all_samples, ratio)

print("Evaluate the pruned model")
result_list.append(
    evaluate_model(module, config, test_dataloader, verbose=True)
)
result = result_list[-1]

Total heads to prune: 9




tensor([[0.3952, 0.6048, 0.4106, 0.4775],
        [0.4851, 0.5362, 0.4935, 0.4638],
        [0.5171, 0.5280, 0.4731, 0.4720],
        [0.4691, 0.5217, 0.5309, 0.4839]])




{(0, 0), (0, 3), (3, 0), (2, 3), (0, 2), (3, 3), (2, 2), (1, 0), (1, 3)}




Evaluate the pruned model




Evaluating the model:   0%|                                                                                   …

In [13]:
# For every class index ("concern"), compare the original model with the
# pruned copy on validation data. The log below reports mean CCA
# coefficients, linear CKA and kernel CKA for the concern and non-concern
# groups of each class.
# NOTE(review): return values are discarded, so the metrics exist only in
# the printed log.
for concern in range(num_labels):
    # Re-seed before each call (presumably so every class draws its samples
    # from the same RNG state - confirm).
    config.init_seed()
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)

adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9862076962013993




CCA coefficients mean non-concern: 0.98617136651885




Linear CKA concern: 0.9334554887829777




Linear CKA non-concern: 0.9114847097961117




Kernel CKA concern: 0.8515362495191547




Kernel CKA non-concern: 0.8341169472131778




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9877121179025464




CCA coefficients mean non-concern: 0.9861770075891382




Linear CKA concern: 0.9075024608263896




Linear CKA non-concern: 0.9176234417674409




Kernel CKA concern: 0.8103079810876685




Kernel CKA non-concern: 0.8402708086304089




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9855006614193921




CCA coefficients mean non-concern: 0.9866627302251467




Linear CKA concern: 0.9222224087261223




Linear CKA non-concern: 0.9151189026889306




Kernel CKA concern: 0.8445513985219167




Kernel CKA non-concern: 0.8352209449698134




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9856189712464131




CCA coefficients mean non-concern: 0.9867413425760173




Linear CKA concern: 0.9004924356766868




Linear CKA non-concern: 0.9143815409190104




Kernel CKA concern: 0.8070913331550318




Kernel CKA non-concern: 0.8391782223356486




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9728651595019593




CCA coefficients mean non-concern: 0.9871875523790853




Linear CKA concern: 0.9070782457272729




Linear CKA non-concern: 0.9172302937296967




Kernel CKA concern: 0.8237034182434612




Kernel CKA non-concern: 0.8394099701316319




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9787413856039417




CCA coefficients mean non-concern: 0.986903660830672




Linear CKA concern: 0.8146836235246396




Linear CKA non-concern: 0.9246846846425851




Kernel CKA concern: 0.7221803970079734




Kernel CKA non-concern: 0.8471580262750619




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9877138592685142




CCA coefficients mean non-concern: 0.985857419542005




Linear CKA concern: 0.9177739349436249




Linear CKA non-concern: 0.9180058199561987




Kernel CKA concern: 0.824842538175383




Kernel CKA non-concern: 0.844821611688882




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9840345572805177




CCA coefficients mean non-concern: 0.9862039167430389




Linear CKA concern: 0.9387558546189911




Linear CKA non-concern: 0.9114141884137891




Kernel CKA concern: 0.8654319599190622




Kernel CKA non-concern: 0.8350828659190458




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9842663052744814




CCA coefficients mean non-concern: 0.9861627806344367




Linear CKA concern: 0.9220967016490078




Linear CKA non-concern: 0.9072692131407027




Kernel CKA concern: 0.8077238454230705




Kernel CKA non-concern: 0.8313337885197591




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9870157073356465




CCA coefficients mean non-concern: 0.9864115324531461




Linear CKA concern: 0.8999471175057384




Linear CKA non-concern: 0.9114243476549491




Kernel CKA concern: 0.8150074031444621




Kernel CKA non-concern: 0.8359196226995756




In [14]:
# Report the sparsity of the pruned copy. The output below shows an overall
# figure followed by a per-parameter dictionary.
get_sparsity(module)
print("original model's perplexity")
get_perplexity(model, valid_dataloader, config)
print("pruned model's perplexity")
# As the cell's last expression, this call's return value is also echoed
# as the cell result, which is why the pruned perplexity appears twice.
get_perplexity(module, valid_dataloader, config)

0.11043920735662212




{'bert.encoder.layer.0.attention.self.query.weight': 0.75, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.75, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.75, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.75, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.0, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.0, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.5, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.5, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.5, 'bert.encoder.layer.1.attention.self.value.bias': 0.0, 'bert.encod




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.861149787902832




3.861149787902832

In [15]:
# Convert every collected evaluation report into a DataFrame.
df_list = [report_to_df(report) for report in result_list]
# Timestamped file name (second resolution) so separate runs do not
# overwrite each other's results.
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Only one report is produced in this notebook, so export the first entry.
new_df = df_list[0]
# DataFrame.to_csv does not create missing parent directories; make sure
# the output folder exists so a fresh checkout does not fail here.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-10-22_18-29-09




Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.4428,0.5775,0.5013,2992
1,1,0.6815,0.3419,0.4554,2992
2,2,0.6577,0.5455,0.5964,3012
3,3,0.2874,0.6718,0.4026,2998
4,4,0.8215,0.5991,0.6929,2973
5,5,0.8194,0.7371,0.7761,3054
6,6,0.6659,0.3976,0.4979,3003
7,7,0.5204,0.6647,0.5838,3012
8,8,0.7057,0.4906,0.5788,2982
9,9,0.7553,0.6241,0.6834,2982
