In [1]:
import os
import sys

# Make the project root (presumably the repository root) importable so the
# `src.*` imports in the next cell resolve. The path is resolved to an absolute
# path once, so a later working-directory change does not break imports, and it
# is appended only if absent, so re-running this cell does not grow sys.path.
PROJECT_ROOT = os.path.abspath("../../../../../")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Disable tokenizer-level parallelism (presumably to avoid the HuggingFace
# tokenizers fork warning when DataLoaders use worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
from datetime import datetime

import torch
from torch.utils.data import DataLoader

from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.utils.load import load_cache
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.utils.data_class import CustomEmbeddingDataset
from src.pruning.prune_head import head_importance_prunning
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Model/dataset configuration name passed to Config below.
name = "bert-6-128-yahoo"
# Use the first GPU when one is available, otherwise fall back to CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): `checkpoint` is not referenced by any later cell.
checkpoint = None
# DataLoader settings forwarded to load_data.
batch_size = 16
num_workers = 4
# Forwarded to SamplingDataset and get_similarity.
num_samples = 16
# Forwarded to head_importance_prunning; presumably the fraction of attention
# heads to prune (the run below reports 7 heads pruned) — TODO confirm.
ratio = 0.6
# NOTE(review): `seed` is not passed to anything visible in this notebook;
# seeding happens through config.init_seed(). Confirm whether Config should use it.
seed = 44

In [4]:
# Record when this run started; the timestamp is echoed for the log.
script_start_time = datetime.now()
start_stamp = script_start_time.strftime('%Y-%m-%d %H:%M:%S')
print(f"Script started at: {start_stamp}")

Script started at: 2024-10-22 22:51:12


In [5]:
# Build the run configuration, read the number of target classes from it,
# then load the model described by that configuration.
config = Config(name, device)
model_settings = config.config
num_labels = model_settings["num_labels"]
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-6-128-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-6-128-yahoo is loaded.




In [6]:
# Build the train / valid / test loaders for the dataset named in `config`,
# reusing the on-disk cache when it is present.
loader_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "do_cache": True,
}
data_splits = load_data(config, **loader_options)
train_dataloader, valid_dataloader, test_dataloader = data_splits

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Load the pre-generated embedding-based samples that are fed (via
# CustomEmbeddingDataset and SamplingDataset) into head_importance_prunning.
# NOTE(review): the cache is tagged "4_128-yahoo" while the model is
# "bert-6-128-yahoo" — confirm this pairing is intended.
# NOTE(review): the cache is a .pkl file; if load_cache unpickles it, only load
# caches produced by trusted runs.
generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo",
    "4_128-yahoo_top1.pkl",
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
# Rename the cached keys to the names used for the dict passed on to
# CustomEmbeddingDataset. Each rename is guarded so re-running this cell after
# the keys were already renamed is a no-op instead of raising KeyError.
for cached_key, dataset_key in (
    ("example_list", "embeddings"),
    ("example_label", "labels"),
    ("attn_list", "attention_mask"),
):
    if cached_key in generated:
        generated[dataset_key] = generated.pop(cached_key)

In [9]:
# Wrap the renamed cache dict in a dataset and batch it in groups of 4.
# NOTE(review): `generated_dataloder` is misspelled, but the sampling cell
# below refers to it by this name.
generated_data = CustomEmbeddingDataset(generated)
generated_dataloder = DataLoader(generated_data, batch_size=4)

In [10]:
# Optional baseline: evaluate the unpruned model on the test split.
# Disabled by default; set to True to include the baseline evaluation in a run.
EVALUATE_ORIGINAL = False
if EVALUATE_ORIGINAL:
    print("Evaluate the original model")
    original_result = evaluate_model(model, config, test_dataloader)

In [11]:
# Re-seed before drawing samples (presumably so the draw is reproducible).
config.init_seed()
# Positional arguments are forwarded unchanged; their meaning is defined by
# SamplingDataset. NOTE(review): consider passing 200 / False / 4 as keywords.
sampling_args = (generated_dataloder, config, 200, num_samples, False, 4)
all_samples = SamplingDataset(*sampling_args, resample=False)

In [12]:
# Collect evaluation reports; turned into DataFrames in the results cell.
result_list = []

# Prune a deep copy so `model` stays intact as the unpruned reference for the
# similarity and perplexity comparisons later in the notebook.
module = copy.deepcopy(model)

# head_importance_prunning modifies `module` in place (its return value is unused).
head_importance_prunning(module, config, all_samples, ratio)

print("Evaluate the pruned model")
result = evaluate_model(module, config, test_dataloader, verbose=True)
result_list.append(result)
result_list.append(result)

Total heads to prune: 7




tensor([[0.5141, 0.4859],
        [0.5482, 0.4518],
        [0.5592, 0.4408],
        [0.4687, 0.5313],
        [0.4890, 0.5110],
        [0.4927, 0.5073]])




{(0, 1), (4, 0), (2, 1), (1, 1), (5, 1), (3, 0), (5, 0)}




Evaluate the pruned model




Evaluating the model:   0%|                                                                                   …

In [13]:
# Compare original vs. pruned representations for each class ("concern"),
# re-seeding before every call (presumably so each call samples deterministically).
for concern in range(num_labels):
    config.init_seed()
    get_similarity(
        model,
        module,
        valid_dataloader,
        concern,
        num_samples,
        config,
    )

adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9663349040559495




CCA coefficients mean non-concern: 0.9655897508704713




Linear CKA concern: 0.6800403401203248




Linear CKA non-concern: 0.6945792658221206




Kernel CKA concern: 0.7192965640331522




Kernel CKA non-concern: 0.6787128941625845




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9773717559769853




CCA coefficients mean non-concern: 0.9646143484367995




Linear CKA concern: 0.6094203372360666




Linear CKA non-concern: 0.6948998959255149




Kernel CKA concern: 0.6175760478406822




Kernel CKA non-concern: 0.6704993678322833




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9620679727908867




CCA coefficients mean non-concern: 0.9611658983564919




Linear CKA concern: 0.67706979488708




Linear CKA non-concern: 0.6887200048839596




Kernel CKA concern: 0.8026399301059781




Kernel CKA non-concern: 0.6712790245471627




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9671212718791344




CCA coefficients mean non-concern: 0.9679851820187327




Linear CKA concern: 0.7265358939445472




Linear CKA non-concern: 0.6736503630034966




Kernel CKA concern: 0.7019534811066122




Kernel CKA non-concern: 0.6748787183640129




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9677922531042592




CCA coefficients mean non-concern: 0.965577050499151




Linear CKA concern: 0.7518809272354209




Linear CKA non-concern: 0.690141730809361




Kernel CKA concern: 0.8006100613565771




Kernel CKA non-concern: 0.6599865168947305




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9671294106292957




CCA coefficients mean non-concern: 0.9697867830922428




Linear CKA concern: 0.5716636880382339




Linear CKA non-concern: 0.6768209790100139




Kernel CKA concern: 0.47572472304991686




Kernel CKA non-concern: 0.6934109769207912




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9735031780358563




CCA coefficients mean non-concern: 0.9643783961479651




Linear CKA concern: 0.6994870189746675




Linear CKA non-concern: 0.6767600039065




Kernel CKA concern: 0.7097653199070246




Kernel CKA non-concern: 0.6758055206918715




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9710560810270623




CCA coefficients mean non-concern: 0.9705544882734841




Linear CKA concern: 0.7702768289842299




Linear CKA non-concern: 0.6547925150937132




Kernel CKA concern: 0.6585576665074232




Kernel CKA non-concern: 0.6707316797163448




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9738592457274112




CCA coefficients mean non-concern: 0.9715577433624482




Linear CKA concern: 0.7434937444444835




Linear CKA non-concern: 0.6714538200683751




Kernel CKA concern: 0.7891550694524047




Kernel CKA non-concern: 0.6753793336440551




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9790668372017922




CCA coefficients mean non-concern: 0.9690072566462817




Linear CKA concern: 0.6602754015320456




Linear CKA non-concern: 0.6695680075385132




Kernel CKA concern: 0.6028211572449116




Kernel CKA non-concern: 0.6737755008468624




In [14]:
# get_sparsity and get_perplexity print their results themselves (see the cell
# output). Whatever they return is kept in named variables instead of being
# discarded. Assigning the last call also stops Jupyter from echoing the pruned
# perplexity a second time as the cell's output.
sparsity_report = get_sparsity(module)
print("original model's perplexity")
original_perplexity = get_perplexity(model, valid_dataloader, config)
print("pruned model's perplexity")
pruned_perplexity = get_perplexity(module, valid_dataloader, config)

0.19045436504971097




{'bert.encoder.layer.0.attention.self.query.weight': 0.5, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.5, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.5, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.5, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.0, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.0, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.5, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.5, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.5, 'bert.encoder.layer.1.attention.self.value.bias': 0.0, 'bert.encoder.l




original model's perplexity




3.187649726867676




pruned model's perplexity




3.850386381149292




3.850386381149292

In [15]:
# Convert each evaluation report to a DataFrame and save the pruned-model
# report as results/<timestamp>.csv; the timestamp is printed so the run can be
# matched to its file.
df_list = [report_to_df(report) for report in result_list]
csv_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
new_df = df_list[0]
# DataFrame.to_csv does not create missing parent directories.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-10-22_22-53-08




Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.4059,0.5826,0.4785,2992
1,1,0.6574,0.5164,0.5784,2992
2,2,0.6511,0.6238,0.6372,3012
3,3,0.3218,0.5794,0.4138,2998
4,4,0.7655,0.6038,0.6751,2973
5,5,0.7971,0.5815,0.6725,3054
6,6,0.6638,0.38,0.4833,3003
7,7,0.6803,0.4641,0.5518,3012
8,8,0.6148,0.6512,0.6325,2982
9,9,0.5799,0.6935,0.6316,2982
