In [1]:
import os
import sys

# Turn off HuggingFace tokenizers' internal parallelism before anything builds a
# tokenizer. Presumably this silences the fork warning triggered by DataLoaders
# with num_workers > 0 (num_workers is 4 below) -- TODO confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Put the repository root on the import path so the `src.*` imports resolve.
# NOTE: the path is relative to the notebook's current working directory.
sys.path.append("../../../../../")

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration -- the knobs a reader is most likely to tune.
name = "bert-4-128-yahoo"  # model/dataset key handed to Config(...)
# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA instead of failing when the model is placed on the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced by any visible cell
batch_size = 16  # batch size for the train/valid/test DataLoaders
num_workers = 4  # DataLoader worker processes
num_samples = 16  # passed to SamplingDataset and get_similarity
# Passed to head_importance_prunning; appears to be the fraction of attention
# heads to prune (the logged run prunes 6 of the 16 heads in the 4x4 score
# matrix) -- TODO confirm.
ratio = 0.4
# NOTE(review): `seed` is never passed anywhere; seeding happens through
# config.init_seed(), which does not receive this value. Wire it through or remove it.
seed = 44

In [4]:
# Record (and log) when this run began.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-22 18:22:35


In [5]:
# Build the experiment config for `name` on `device`, read the class count from
# it, and load the model. `model` is kept unpruned: pruning later works on a deep copy.
config = Config(name, device)
# Number of output classes; the per-concern similarity loop later iterates over this range.
num_labels = config.config["num_labels"]
model = load_model(config)

Loading the model.




{'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'model_name': 'models/bert-4-128-yahoo', 'num_labels': 10, 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}




The model models/bert-4-128-yahoo is loaded.




In [6]:
# Build train/validation/test DataLoaders for the dataset described by `config`.
# do_cache=True lets load_data reuse cached splits (the log below shows the
# train/valid/test .pkl files being read from cache).
# NOTE(review): train_dataloader is not used by any visible cell.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{'config_name': 'yahoo_answers_topics', 'features': {'first_column': 'question_title', 'second_column': 'topic'}, 'path': 'yahoo_answers_topics'}




In [7]:
from torch.utils.data import DataLoader

from src.utils.data_class import CustomEmbeddingDataset
from src.utils.load import load_cache

# Load the pre-generated, embedding-based samples that feed the sampling set
# used to score attention-head importance.
# NOTE(review): this is a .pkl file; if load_cache unpickles it, only point it
# at trusted files (unpickling can execute arbitrary code).
generated = load_cache(
    "datasets/generated_dataset/embedding_based/4_128-yahoo",
    "4_128-yahoo_top1.pkl",
)

4_128-yahoo_top1.pkl is loaded from cache.




In [8]:
# Rename the cached generator's keys to the names passed on to
# CustomEmbeddingDataset (presumably the keys it expects -- TODO confirm).
# Each rename is guarded: the first run pops the old keys, so an unguarded
# re-run of this cell would raise KeyError. Now re-running is a no-op.
for old_key, new_key in {
    "example_list": "embeddings",
    "example_label": "labels",
    "attn_list": "attention_mask",
}.items():
    if old_key in generated:
        generated[new_key] = generated.pop(old_key)

assert {"embeddings", "labels", "attention_mask"} <= generated.keys(), (
    f"generated cache is missing expected keys; found {sorted(generated.keys())}"
)

In [9]:
# Wrap the renamed samples in a dataset and batch them. DataLoader does not
# shuffle by default, so batches follow the cached sample order.
generated_data = CustomEmbeddingDataset(generated)
generated_dataloader = DataLoader(
    generated_data,
    batch_size=4,
)
# Backward-compatible alias: later cells still refer to the original misspelled name.
generated_dataloder = generated_dataloader

In [10]:
# Optional baseline: evaluate the unpruned model on the test split.
# Disabled by default (this was previously commented-out code); set the flag to
# True to run it. The result goes to its own name so it cannot be confused with
# the pruned model's `result`.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    original_result = evaluate_model(model, config, test_dataloader)

In [11]:
# Re-initialise seeds via the config, then build the sampling set that
# head_importance_prunning uses to score attention heads.
config.init_seed()
all_samples = SamplingDataset(
    generated_dataloder,
    config,
    200,  # NOTE(review): magic positional argument; meaning not visible here -- TODO name it
    num_samples,
    False,  # NOTE(review): positional boolean flag; meaning not visible here -- TODO confirm
    4,  # NOTE(review): equals the generated DataLoader's batch_size=4 -- TODO confirm that is what it is
    resample=False,
)

In [12]:
# Evaluation reports collected for export at the end of the notebook.
result_list = []

# Prune a deep copy so `model` remains the unpruned baseline for the
# similarity and perplexity comparisons that follow.
module = copy.deepcopy(model)

# Prunes attention heads of `module` in place, using importance estimated on
# `all_samples`. `ratio` sets how many are pruned (the logged run pruned 6 heads).
head_importance_prunning(module, config, all_samples, ratio)

print("Evaluate the pruned model")
result = evaluate_model(module, config, test_dataloader, verbose=True)
result_list.append(result)

Total heads to prune: 6




tensor([[0.3952, 0.6048, 0.4106, 0.4775],
        [0.4851, 0.5362, 0.4935, 0.4638],
        [0.5171, 0.5280, 0.4731, 0.4720],
        [0.4691, 0.5217, 0.5309, 0.4839]])




{(0, 0), (3, 0), (2, 3), (0, 2), (2, 2), (1, 3)}




Evaluate the pruned model




Evaluating the model:   0%|                                                                                   …

In [13]:
# For every class ("concern"), compare the original and pruned models'
# representations on validation data. Each call prints CCA and linear/kernel
# CKA scores for the concern class and for the remaining classes (see the log
# below). The return value is not kept.
for concern in range(num_labels):
    # Re-seed before every call, presumably so each concern starts from the same
    # sampling state -- TODO confirm init_seed semantics.
    config.init_seed()
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)

adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9936841425382332




CCA coefficients mean non-concern: 0.9933468934490761




Linear CKA concern: 0.9757933798980143




Linear CKA non-concern: 0.966444886029343




Kernel CKA concern: 0.9383350213879563




Kernel CKA non-concern: 0.9230318701684435




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9937571316249806




CCA coefficients mean non-concern: 0.9933385466591088




Linear CKA concern: 0.9638299639086196




Linear CKA non-concern: 0.9687711078525099




Kernel CKA concern: 0.911945327863503




Kernel CKA non-concern: 0.9274355203549988




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9929768099773091




CCA coefficients mean non-concern: 0.9933307864546425




Linear CKA concern: 0.9694072983543901




Linear CKA non-concern: 0.9678263904205174




Kernel CKA concern: 0.9298401337207406




Kernel CKA non-concern: 0.9246068971586648




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9910842531808708




CCA coefficients mean non-concern: 0.9937542196034126




Linear CKA concern: 0.960379482476435




Linear CKA non-concern: 0.9671949019553153




Kernel CKA concern: 0.9078463303117691




Kernel CKA non-concern: 0.92601487489479




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9880830409832357




CCA coefficients mean non-concern: 0.9936951407932181




Linear CKA concern: 0.9737821754378974




Linear CKA non-concern: 0.9687738496825533




Kernel CKA concern: 0.9475136184122607




Kernel CKA non-concern: 0.9272718722159281




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9895669449858067




CCA coefficients mean non-concern: 0.9936488167628246




Linear CKA concern: 0.9272983125906719




Linear CKA non-concern: 0.9717275566520257




Kernel CKA concern: 0.8830608694709883




Kernel CKA non-concern: 0.9311096137754428




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9934516248384186




CCA coefficients mean non-concern: 0.9931194626517942




Linear CKA concern: 0.9694061144176108




Linear CKA non-concern: 0.969074411443013




Kernel CKA concern: 0.9171603148265811




Kernel CKA non-concern: 0.9301098652480422




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9911930687860538




CCA coefficients mean non-concern: 0.9932372853435922




Linear CKA concern: 0.9823588449083589




Linear CKA non-concern: 0.9659428521426094




Kernel CKA concern: 0.9555482772451639




Kernel CKA non-concern: 0.9237551969576606




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9901710582010288




CCA coefficients mean non-concern: 0.9931071758001175




Linear CKA concern: 0.9588604210010787




Linear CKA non-concern: 0.9645593861103698




Kernel CKA concern: 0.88676682987365




Kernel CKA non-concern: 0.9220245764063781




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9940583639038497




CCA coefficients mean non-concern: 0.9932926619817075




Linear CKA concern: 0.9632421795457882




Linear CKA non-concern: 0.9662137706442092




Kernel CKA concern: 0.9191428522137369




Kernel CKA non-concern: 0.9247921080331482




In [14]:
# Report how sparse the pruned model is, then compare validation perplexity of
# the original and pruned models. These helpers print their own results.
# Binding the return values to names keeps the final call from also being
# echoed as cell output; previously the pruned perplexity appeared twice.
sparsity = get_sparsity(module)
print("original model's perplexity")
original_perplexity = get_perplexity(model, valid_dataloader, config)
print("pruned model's perplexity")
pruned_perplexity = get_perplexity(module, valid_dataloader, config)

0.07362613823774808




{'bert.encoder.layer.0.attention.self.query.weight': 0.5, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.5, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.5, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.5, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.0, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.0, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.25, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.25, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.25, 'bert.encoder.layer.1.attention.self.value.bias': 0.0, 'bert.encode




original model's perplexity




3.2110652923583984




pruned model's perplexity




3.498887777328491




3.498887777328491

In [15]:
# Convert each collected evaluation report to a DataFrame and save the first
# one (the pruned model's test-set report) under a timestamped name.
df_list = [report_to_df(report) for report in result_list]
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
new_df = df_list[0]
# DataFrame.to_csv does not create missing parent directories.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-10-22_18-24-56




Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.5057,0.5191,0.5123,2992
1,1,0.6864,0.4104,0.5137,2992
2,2,0.6531,0.6189,0.6355,3012
3,3,0.3212,0.6531,0.4307,2998
4,4,0.7809,0.7027,0.7397,2973
5,5,0.8023,0.7705,0.786,3054
6,6,0.7102,0.3763,0.4919,3003
7,7,0.5042,0.6902,0.5828,3012
8,8,0.6914,0.5439,0.6089,2982
9,9,0.7503,0.6348,0.6877,2982
