In [1]:
import os
import sys

# Make the repository root (five directories up from the notebook's working
# directory) importable so the `src.*` imports resolve.
# The path is resolved to an absolute one so later changes of the working
# directory cannot break imports, and it is only added once so re-running
# this cell does not pile up duplicate sys.path entries.
_repo_root = os.path.abspath("../../../../../")
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

# Set before the tokenizer library is imported/used elsewhere in the notebook.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import prune_wanda
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# Experiment configuration: the values a reader is expected to tune.
name = "bert-tiny-yahoo"  # identifier passed to Config(name, device)
# Use the first GPU when one is present; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA instead of erroring the
# first time the device is used.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced in the visible cells
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # forwarded to SamplingDataset and get_similarity
wanda_ratio = 0.4  # forwarded to prune_wanda as sparsity_ratio
# NOTE(review): `seed` is never passed anywhere in the visible cells;
# config.init_seed() is called without arguments — confirm which seed is used.
seed = 44
include_layers = ["attention", "intermediate", "output"]  # forwarded to prune_wanda
exclude_layers = None  # forwarded to prune_wanda

In [4]:
# Capture the wall-clock time at which this run began and echo it.
script_start_time = datetime.now()
# A datetime format spec inside the f-string is equivalent to calling strftime.
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-10-21 13:41:57


In [5]:
# Build the run configuration for `name` on `device`.
config = Config(name, device)
# Number of labels read from the loaded configuration; used later as the
# range of "concern" indices in the similarity loop.
num_labels = config.config["num_labels"]
# Load the model described by `config` (this is the unpruned reference model).
model = load_model(config)

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-tiny-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-tiny-yahoo is loaded.




In [6]:
# Load the three dataloaders (train / validation / test) for the configured dataset.
# do_cache=True: presumably enables the on-disk cache — the cell output reports
# the splits being loaded from cached .pkl files.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
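# Optional baseline (currently disabled): uncomment the two lines below to
# evaluate the unpruned model on the test set for comparison.
# print("Evaluate the original model")
# result = evaluate_model(model, config, test_dataloader)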

In [8]:
# Re-initialise the seed before sampling.
# NOTE(review): presumably this makes the drawn samples reproducible; the
# notebook-level `seed` variable is not passed here — confirm which seed applies.
config.init_seed()
# Build the sample set from the training dataloader; it is later passed to prune_wanda.
# NOTE(review): the positional arguments 200, False and 4 are opaque at this
# call site — confirm their meaning against the SamplingDataset signature and
# consider passing them by keyword.
all_samples = SamplingDataset(
    train_dataloader,
    config,
    200,
    num_samples,
    False,
    4,
    resample=False,
)

In [9]:
# Prune a deep copy so the unpruned `model` stays available for the
# similarity and perplexity comparisons made in later cells.
module = copy.deepcopy(model)
prune_wanda(
    module, config, all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)

# Evaluate the pruned copy on the test split and keep the report for export.
print("Evaluate the pruned model")
result = evaluate_model(module, config, test_dataloader, verbose=True)
result_list = [result]

Evaluate the pruned model




Evaluating the model:   0%|                                                                               | 0/…

In [10]:
# Compare the original model with the pruned copy once per label index ("concern").
for concern in range(num_labels):
    # The seed is re-initialised before every call; presumably so each
    # comparison draws its samples reproducibly — TODO confirm.
    config.init_seed()
    # Return value is discarded; the CCA / CKA figures appear in the cell output.
    get_similarity(model, module, valid_dataloader, concern, num_samples, config)

adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.998653104299616




CCA coefficients mean non-concern: 0.9983457124618468




Linear CKA concern: 0.9928747568900207




Linear CKA non-concern: 0.9892045431759973




Kernel CKA concern: 0.9899529367510793




Kernel CKA non-concern: 0.9850221044732527




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.99819927748926




CCA coefficients mean non-concern: 0.9985236587050125




Linear CKA concern: 0.9709405679627435




Linear CKA non-concern: 0.9918740483696277




Kernel CKA concern: 0.9708983165444128




Kernel CKA non-concern: 0.9882741731517374




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9981121761809534




CCA coefficients mean non-concern: 0.9983196848832545




Linear CKA concern: 0.985849220307225




Linear CKA non-concern: 0.98947584665801




Kernel CKA concern: 0.9830441723965982




Kernel CKA non-concern: 0.9856464696475955




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9989774705125843




CCA coefficients mean non-concern: 0.9983091284418872




Linear CKA concern: 0.9933750321411803




Linear CKA non-concern: 0.9885385449555694




Kernel CKA concern: 0.9927014998901961




Kernel CKA non-concern: 0.9845923935111689




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.997794647586893




CCA coefficients mean non-concern: 0.9984823804009009




Linear CKA concern: 0.9777388209056648




Linear CKA non-concern: 0.9903916407180934




Kernel CKA concern: 0.9790167816271069




Kernel CKA non-concern: 0.9857618621259333




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9968195610682841




CCA coefficients mean non-concern: 0.9985597124351139




Linear CKA concern: 0.8503601340901498




Linear CKA non-concern: 0.9937246600885716




Kernel CKA concern: 0.8595406857839304




Kernel CKA non-concern: 0.9908890103081428




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9988283743483978




CCA coefficients mean non-concern: 0.9983426086982112




Linear CKA concern: 0.9952848368217594




Linear CKA non-concern: 0.9884591016062758




Kernel CKA concern: 0.9938407553521851




Kernel CKA non-concern: 0.9845292637815299




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9987281924199164




CCA coefficients mean non-concern: 0.998358107530523




Linear CKA concern: 0.9847647973753613




Linear CKA non-concern: 0.9877287464338307




Kernel CKA concern: 0.9837465353616068




Kernel CKA non-concern: 0.9846549929393218




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9983963678661527




CCA coefficients mean non-concern: 0.9983681173598475




Linear CKA concern: 0.9922682673963491




Linear CKA non-concern: 0.9872707282776824




Kernel CKA concern: 0.991099311421534




Kernel CKA non-concern: 0.9840007986743688




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9987509709016865




CCA coefficients mean non-concern: 0.9983068299897999




Linear CKA concern: 0.9818624247323617




Linear CKA non-concern: 0.9877894242148766




Kernel CKA concern: 0.9852883664402404




Kernel CKA non-concern: 0.9843832602961885




In [11]:
# Report the sparsity of the pruned model (the return value is not kept).
get_sparsity(module)
# Perplexity of the unpruned reference model on the validation split.
print("original model's perplexity")
get_perplexity(model, valid_dataloader, config)
# Perplexity of the pruned copy on the same split. As the last expression of
# the cell, its return value is also shown as the cell result.
print("pruned model's perplexity")
get_perplexity(module, valid_dataloader, config)

0.3790555547490818




{'bert.encoder.layer.0.attention.self.query.weight': 0.3984375, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.3984375, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.3984375, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.3984375, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.3984375, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.3984375, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.3984375, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.3984375, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.3984375, 'bert.encoder.la




original model's perplexity




3.2782363891601562




pruned model's perplexity




3.3599812984466553




3.3599812984466553

In [12]:
# Apply report_to_df to every collected evaluation result.
df_list = [report_to_df(report) for report in result_list]
# Timestamp used as the CSV file name so repeated runs do not overwrite each other.
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
new_df = df_list[0]
# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists before writing (no-op when it is already there).
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
# Last expression: displayed as the cell result.
new_df

2024-10-21_13-43-05




Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.546,0.488,0.5154,2992
1,1,0.7147,0.4873,0.5795,2992
2,2,0.6607,0.6341,0.6471,3012
3,3,0.3429,0.6087,0.4387,2998
4,4,0.6808,0.7999,0.7355,2973
5,5,0.7545,0.7881,0.771,3054
6,6,0.676,0.3786,0.4854,3003
7,7,0.565,0.6292,0.5954,3012
8,8,0.6567,0.6261,0.641,2982
9,9,0.7423,0.6251,0.6787,2982
