In [1]:
import pyro
import torch
import ssl
import matplotlib.pyplot as plt


from src.logger.logger_initializer import LoggerInitializer
from src.models.bnn.pyro_miniresnet_bnn_classifier import PyroMiniresnetBnnClassifier
from src.utils.path_getter import PathGetter
from src.data_loader.dataset_getter import DatasetGetter


# SECURITY NOTE(review): this replaces the process-wide default HTTPS context factory,
# so stdlib HTTPS connections (urllib / http.client) made later in this kernel skip
# certificate verification, not just the dataset download. This exposes them to MITM.
# Prefer fixing the local CA bundle and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context

# Use the GPU when present, otherwise fall back to CPU so the notebook still runs
# on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# A bare `torch.__version__` mid-cell is never displayed; print it explicitly.
print(f'torch {torch.__version__}, device: {device}')

# One seed constant for every RNG used by the notebook.
SEED = 777
pyro.set_rng_seed(SEED)
torch.manual_seed(SEED)

LoggerInitializer().init()
BATCH_SIZE = 12
dataset = DatasetGetter(PathGetter.get_data_directory(), batch_size=BATCH_SIZE).get()

[2023-04-18 11:05:34] 4596 root {logger_initializer-47} INFO - Current process ID: 4596


In [4]:
# Build the Pyro MiniResNet BNN classifier on the loaded dataset and initialise it.
untrained_classifier = PyroMiniresnetBnnClassifier(dataset=dataset, device=device)
classifier = untrained_classifier.init()

In [5]:
# Unfreeze every layer (per the method name) so the whole network is trained.
classifier.unfreeze_all_layers()
# The logged run took ~25 s per epoch, so 300 epochs is roughly two hours.
num_epoch = 300
result_df, result_dict = classifier.fit(num_epoch=num_epoch)
# Persist the checkpoint data returned by fit(), then the score table.
classifier.save_model(result_dict)
# NOTE(review): `_uuid` is a private attribute; consider exposing a public run id.
result_df.to_csv(f'{classifier.name}_{classifier._uuid}_{num_epoch}_epoch.csv')
# Show only the last few rows instead of dumping the whole frame.
result_df.tail()

[2023-04-18 00:36:44] 451 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/300]: f1_score: 0.326 accuracy_score: 0.395 loss: 489601134885036.625 
[2023-04-18 00:37:09] 451 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/300]: f1_score: 0.390 accuracy_score: 0.457 loss: 370690796339438.188 
[2023-04-18 00:37:35] 451 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/300]: f1_score: 0.376 accuracy_score: 0.430 loss: 272801893492163.125 
[2023-04-18 00:38:00] 451 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/300]: f1_score: 0.375 accuracy_score: 0.432 loss: 193888135889356.188 
[2023-04-18 00:38:26] 451 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/300]: f1_score: 0.339 accuracy_score: 0.383 loss: 131903019515274.250 
[2023-04-18 00:38:52] 451 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/300]: f1_score: 0.368 accuracy_score: 0.411 loss: 8

In [None]:
# Score the freshly trained classifier on the validation split.
validation_scores = classifier.score(dataset.validate_dataloader)
validation_scores

In [None]:
# Label the file with the layer mode actually used by the training cell
# (`unfreeze_all_layers()`) and its real epoch count, instead of a stale
# "freeze_all_layers_except_fc" tag and a hard-coded 300.
result_df.to_csv(f'{classifier.name}_{num_epoch}_epoch_unfreeze_all_layers_scores.csv')

In [4]:
# Checkpoint written by an earlier training run; the path is relative to the
# notebook's working directory.
test_model_weights_path = (
    '../../bayesian_refactoring/assets/model_weights/'
    'pyro_miniresnet_15014_2dc1ce9b-0c57-4387-acab-451c9b83999f.pt'
)

In [5]:
# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on a GPU also loads on a CPU-only machine.
# SECURITY: torch.load unpickles the file and can execute arbitrary code.
# Only load trusted checkpoints. (weights_only=True may reject the pyro
# param-store entries; verify before enabling it.)
test_results = torch.load(test_model_weights_path, map_location=device)

In [6]:
# List which entries (best-loss / best-F1 / best-accuracy variants) the checkpoint holds.
checkpoint_keys = test_results.keys()
display(checkpoint_keys)

dict_keys(['best_loss', 'best_f1_score', 'best_accuracy_score', 'best_loss_model', 'best_f1_score_model', 'best_accuracy_score_model', 'best_loss_model_pyro_params', 'best_f1_score_model_pyro_params', 'best_accuracy_score_model_pyro_params'])

In [7]:
# Rebuild a classifier and restore the network and pyro parameters of the
# best-F1 checkpoint.
# NOTE(review): unlike the training cell, `.init()` is not called here; confirm
# that load_model() does not require it.
fresh_classifier = PyroMiniresnetBnnClassifier(dataset=dataset, device=device)
new_classifier = fresh_classifier.load_model(
    model_state_dict=test_results['best_f1_score_model'],
    pyro_state_dict=test_results['best_f1_score_model_pyro_params'],
)

In [8]:
# Evaluate the restored model on the validation split.
validation_dataloader = dataset.validate_dataloader
new_classifier.score(validation_dataloader)

{'f1_score': 0.35777670714739807, 'accuracy_score': 0.391699604743083}