In [1]:
import os
import sys

# Make the repository root (five directories above this notebook) importable,
# presumably so the `src.*` imports in the next cell resolve.
PROJECT_ROOT = "../../../../../"
sys.path += [PROJECT_ROOT]

# Turn off HuggingFace tokenizers' internal parallelism; it conflicts with
# forked DataLoader worker processes.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)
from src.utils.sampling import SamplingDataset
from src.pruning.prune import prune_wanda
from src.utils.helper import report_to_df, append_nth_row

In [3]:
# --- Experiment configuration ------------------------------------------------
name = "bert-tiny-yahoo"  # experiment key handed to Config / load_model
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): `checkpoint` is not referenced by any visible cell.
checkpoint = None
batch_size = 16
num_workers = 4  # passed to load_data as its worker count
num_samples = 16  # used for Wanda calibration sampling and per-concern similarity
wanda_ratio = 0.6  # sparsity_ratio passed to prune_wanda (target fraction of pruned weights)
# NOTE(review): `seed` is not referenced by any visible cell: Config(name, device)
# does not receive it and config.init_seed() is called without arguments.
# Confirm that Config seeds from its own settings, or pass this value explicitly.
seed = 44
# Layer-name filters for prune_wanda (presumably matched against module names).
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Record when this run started and echo it to the cell output for the log.
script_start_time = datetime.now()
started_str = script_start_time.strftime("%Y-%m-%d %H:%M:%S")
print("Script started at:", started_str)

Script started at: 2024-10-21 13:44:25


In [5]:
# Build the experiment config for `name` on `device` and load the dense (unpruned) model.
config = Config(name, device)
model = load_model(config)
# Number of classes; each class index is later treated as one "concern" in the
# per-concern similarity loop.
num_labels = config.config["num_labels"]

Loading the model.




{

'architectures'

: 

'bert'

,
 

'dataset_name'

: 

'YahooAnswersTopics'

,
 

'model_name'

: 

'models/bert-tiny-yahoo'

,
 

'num_labels'

: 

10

,
 

'tokenizer_name'

: 

'fabriceyhc/bert-base-uncased-yahoo_answers_topics'

}




The model models/bert-tiny-yahoo is loaded.




In [6]:
# Train / valid / test loaders. With do_cache=True the pickled splits
# (train.pkl, valid.pkl, test.pkl) are reused when already cached.
loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers, "do_cache": True}
train_dataloader, valid_dataloader, test_dataloader = load_data(config, **loader_kwargs)

Loading cached dataset YahooAnswersTopics.




train.pkl is loaded from cache.




valid.pkl is loaded from cache.




test.pkl is loaded from cache.




The dataset YahooAnswersTopics is loaded




{

'config_name'

: 

'yahoo_answers_topics'

,
 

'features'

: 

{'first_column': 'question_title', 'second_column': 'topic'}

,
 

'path'

: 

'yahoo_answers_topics'

}




In [7]:
# Optional baseline: test-set metrics of the dense (unpruned) model.
# Off by default; set to True to include the baseline evaluation in this run.
# The result goes to `original_result` so it does not clobber the pruned `result`.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    original_result = evaluate_model(model, config, test_dataloader)

In [8]:
# Re-seed (presumably all RNGs) before drawing samples so the subset is reproducible.
config.init_seed()
# Samples from the training split, used below as calibration input for prune_wanda.
# NOTE(review): the positional arguments 200, False and 4 are unexplained here;
# confirm their meaning against SamplingDataset's signature and pass them by keyword.
all_samples = SamplingDataset(
    train_dataloader, config, 200, num_samples, False, 4, resample=False
)

In [9]:
# Prune a deep copy so the dense `model` stays intact for the similarity and
# perplexity comparisons in later cells. prune_wanda's return value is unused,
# so it is expected to modify `module` in place.
result_list = []
module = copy.deepcopy(model)
wanda_options = {
    "sparsity_ratio": wanda_ratio,
    "include_layers": include_layers,
    "exclude_layers": exclude_layers,
}
prune_wanda(module, config, all_samples, **wanda_options)

# Test-set metrics of the pruned copy, collected in `result_list` for the CSV export.
print("Evaluate the pruned model")
result = evaluate_model(module, config, test_dataloader, verbose=True)
result_list += [result]

Evaluate the pruned model




Evaluating the model:   0%|                                                                               | 0/…

In [10]:
def _compare_concern(concern_id):
    """Re-seed, then compare dense vs. pruned representations for one class index.

    get_similarity prints CCA / linear CKA / kernel CKA scores for the concern
    and non-concern samples; its return value is passed through unchanged.
    """
    config.init_seed()
    return get_similarity(model, module, valid_dataloader, concern_id, num_samples, config)


for concern in range(num_labels):
    _compare_concern(concern)

adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9938755845951943




CCA coefficients mean non-concern: 0.9934431952750673




Linear CKA concern: 0.9651669269496925




Linear CKA non-concern: 0.9581939539818984




Kernel CKA concern: 0.9443911468974121




Kernel CKA non-concern: 0.9356376292637246




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9938881004086785




CCA coefficients mean non-concern: 0.9935361494250623




Linear CKA concern: 0.8916879760851454




Linear CKA non-concern: 0.9668009029118299




Kernel CKA concern: 0.8816042110201598




Kernel CKA non-concern: 0.946746140179499




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9929104713715808




CCA coefficients mean non-concern: 0.9933366044823193




Linear CKA concern: 0.920709464162385




Linear CKA non-concern: 0.9588410488323327




Kernel CKA concern: 0.8978158381096605




Kernel CKA non-concern: 0.9381826648911513




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.995154531855621




CCA coefficients mean non-concern: 0.9932311017635703




Linear CKA concern: 0.9504928031994135




Linear CKA non-concern: 0.9573367661661932




Kernel CKA concern: 0.936313811879696




Kernel CKA non-concern: 0.9358670817526339




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.98700232530286




CCA coefficients mean non-concern: 0.9943303810271652




Linear CKA concern: 0.8614033121497995




Linear CKA non-concern: 0.9638609095353611




Kernel CKA concern: 0.8665753862412319




Kernel CKA non-concern: 0.942549472210201




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9882571127420758




CCA coefficients mean non-concern: 0.9934929801304108




Linear CKA concern: 0.733807448961755




Linear CKA non-concern: 0.9668542527935887




Kernel CKA concern: 0.7719194908555316




Kernel CKA non-concern: 0.9461200018455052




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9940197729492398




CCA coefficients mean non-concern: 0.9937554043126341




Linear CKA concern: 0.9655989422721201




Linear CKA non-concern: 0.9558190256372819




Kernel CKA concern: 0.9477623119376885




Kernel CKA non-concern: 0.9338620636849582




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9927538917733468




CCA coefficients mean non-concern: 0.993312794420359




Linear CKA concern: 0.9121946777768817




Linear CKA non-concern: 0.9506163595072386




Kernel CKA concern: 0.8902341823458387




Kernel CKA non-concern: 0.931072274747913




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9931373410984639




CCA coefficients mean non-concern: 0.99358472102668




Linear CKA concern: 0.9646086799938812




Linear CKA non-concern: 0.950042107114373




Kernel CKA concern: 0.9535344220601862




Kernel CKA non-concern: 0.9295528447034934




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: 0.9942517270907958




CCA coefficients mean non-concern: 0.9932205045069255




Linear CKA concern: 0.9108569904231671




Linear CKA non-concern: 0.9530677067270839




Kernel CKA concern: 0.9169976087716779




Kernel CKA non-concern: 0.9326426099294123




In [11]:
# Sparsity of the pruned copy, then validation perplexity of the dense vs. pruned model.
# get_sparsity and get_perplexity print their own results. Their return values are
# stored in variables rather than left as the cell's last expression; previously
# Jupyter echoed the pruned perplexity a second time.
# NOTE(review): the overall sparsity printed (~0.567 in the saved run) is below
# wanda_ratio partly because bias tensors stay dense (0.0 in the per-parameter dict).
sparsity_report = get_sparsity(module)
print("original model's perplexity")
original_perplexity = get_perplexity(model, valid_dataloader, config)
print("pruned model's perplexity")
pruned_perplexity = get_perplexity(module, valid_dataloader, config)

0.5667252166591664




{'bert.encoder.layer.0.attention.self.query.weight': 0.59375, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.59375, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.59375, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.59375, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.59375, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.599609375, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.59375, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.59375, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.59375, 'bert.encoder.layer.1.attentio




original model's perplexity




3.2782363891601562




pruned model's perplexity




3.764113664627075




3.764113664627075

In [12]:
# Convert each evaluation report to a DataFrame and save the first (only) run as a
# timestamped CSV under results/, then display it.
df_list = [report_to_df(report) for report in result_list]
csv_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
new_df = df_list[0]
# DataFrame.to_csv does not create missing directories; make sure results/ exists.
os.makedirs("results", exist_ok=True)
new_df.to_csv(f"results/{csv_name}.csv", index=False)
print(csv_name)
new_df

2024-10-21_13-45-45




Unnamed: 0,class,precision,recall,f1-score,support
0,0,0.4856,0.518,0.5013,2992
1,1,0.7593,0.369,0.4966,2992
2,2,0.6962,0.573,0.6287,3012
3,3,0.3203,0.6471,0.4285,2998
4,4,0.7317,0.7652,0.7481,2973
5,5,0.7762,0.7721,0.7741,3054
6,6,0.7186,0.3453,0.4665,3003
7,7,0.4861,0.658,0.5592,3012
8,8,0.6622,0.6107,0.6354,2982
9,9,0.7602,0.6123,0.6783,2982
