In [1]:
import torch
import torchvision
import ssl
import matplotlib.pyplot as plt
import pyro
from src.logger.logger_initializer import LoggerInitializer
from src.utils.path_getter import PathGetter
from src.data_loader.dataset_getter import DatasetGetter

ssl._create_default_https_context = ssl._create_unverified_context
device = 'cuda'
torch.__version__

pyro.set_rng_seed(777)
torch.manual_seed(777)

LoggerInitializer().init()
dataset = DatasetGetter(PathGetter.get_data_directory(), batch_size=8).get()

[2023-04-20 01:26:36] 6117 root {logger_initializer-47} INFO - Current process ID: 6117


In [2]:
from src.models.bnn.pyro_miniresnet_bnn_classifier import PyroMiniresnetBnnClassifier

classifier = PyroMiniresnetBnnClassifier(dataset=dataset, device=device).init()
classifier.unfreeze_all_layers()

In [3]:
finetune_result_store = dict()

def finetune_layer(classifier, layer_name: str, num_epoch: int, result_store: dict):
    classifier.unfreeze_layer(layer_name)
    result_df, result_dict = classifier.fit(num_epoch=num_epoch)
    # display(result_df)
    stage_name = f'unfreeze_{layer_name}'
    result_store[stage_name] = result_dict
    result_df.to_csv(f'{classifier.name}_{classifier._uuid}_{stage_name}_{num_epoch}_epoch.csv')
    classifier.save_model(result_store)
    print(classifier.score(dataset.validate_dataloader))
    return result_df, result_dict

def tune_model(classifier, num_epoch: int, result_store: dict):
    result_df, result_dict = classifier.fit(num_epoch=num_epoch)
    stage_name = f'tune_model'
    result_store[stage_name] = result_dict
    result_df.to_csv(f'{classifier.name}_{classifier._uuid}_{stage_name}_{num_epoch}_epoch.csv')
    classifier.save_model(result_store)
    print(classifier.score(dataset.validate_dataloader))
    return result_df, result_dict

In [4]:
tune_model(classifier, num_epoch=500, result_store=finetune_result_store)

[2023-04-20 01:27:04] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/500]: f1_score: 0.327 accuracy_score: 0.380 loss: 688448498700110.875 
[2023-04-20 01:27:24] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/500]: f1_score: 0.376 accuracy_score: 0.420 loss: 444645322111892.750 
[2023-04-20 01:27:43] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/500]: f1_score: 0.275 accuracy_score: 0.316 loss: 266609663516052.000 
[2023-04-20 01:28:03] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/500]: f1_score: 0.348 accuracy_score: 0.398 loss: 143981328168877.125 
[2023-04-20 01:28:21] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/500]: f1_score: 0.298 accuracy_score: 0.356 loss: 66391403016373.094 
[2023-04-20 01:28:42] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/500]: f1_score: 0.260 accuracy_score: 0.321 lo

[2023-04-20 01:43:25] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 50/500]: f1_score: 0.509 accuracy_score: 0.593 loss: 382168025.249 
[2023-04-20 01:43:45] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 51/500]: f1_score: 0.524 accuracy_score: 0.605 loss: 382134326.126 
[2023-04-20 01:44:05] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 52/500]: f1_score: 0.506 accuracy_score: 0.595 loss: 382046944.659 
[2023-04-20 01:44:25] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 53/500]: f1_score: 0.524 accuracy_score: 0.616 loss: 382072128.121 
[2023-04-20 01:44:45] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 54/500]: f1_score: 0.458 accuracy_score: 0.546 loss: 382107364.597 
[2023-04-20 01:45:05] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 55/500]: f1_score: 0.488 accuracy_score: 0.573 loss: 382127776.934 
[202

[2023-04-20 01:59:40] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 100/500]: f1_score: 0.534 accuracy_score: 0.613 loss: 382036177.512 
[2023-04-20 01:59:59] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 101/500]: f1_score: 0.553 accuracy_score: 0.617 loss: 382078327.831 
[2023-04-20 02:00:18] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 102/500]: f1_score: 0.580 accuracy_score: 0.649 loss: 382055440.591 
[2023-04-20 02:00:37] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 103/500]: f1_score: 0.536 accuracy_score: 0.618 loss: 382061830.844 
[2023-04-20 02:00:55] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 104/500]: f1_score: 0.558 accuracy_score: 0.642 loss: 382110728.538 
[2023-04-20 02:01:15] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 105/500]: f1_score: 0.520 accuracy_score: 0.597 loss: 382111600.549

[2023-04-20 02:15:57] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 150/500]: f1_score: 0.565 accuracy_score: 0.613 loss: 382061929.142 
[2023-04-20 02:16:17] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 151/500]: f1_score: 0.557 accuracy_score: 0.644 loss: 382109986.963 
[2023-04-20 02:16:37] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 152/500]: f1_score: 0.618 accuracy_score: 0.665 loss: 382059822.146 
[2023-04-20 02:16:57] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 153/500]: f1_score: 0.595 accuracy_score: 0.662 loss: 382077382.386 
[2023-04-20 02:17:16] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 154/500]: f1_score: 0.623 accuracy_score: 0.679 loss: 382151389.368 
[2023-04-20 02:17:36] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 155/500]: f1_score: 0.585 accuracy_score: 0.637 loss: 382191372.188

[2023-04-20 02:32:20] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 200/500]: f1_score: 0.613 accuracy_score: 0.663 loss: 382187354.719 
[2023-04-20 02:32:40] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 201/500]: f1_score: 0.621 accuracy_score: 0.677 loss: 382105015.412 
[2023-04-20 02:33:00] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 202/500]: f1_score: 0.541 accuracy_score: 0.598 loss: 382024625.564 
[2023-04-20 02:33:18] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 203/500]: f1_score: 0.609 accuracy_score: 0.658 loss: 382163975.254 
[2023-04-20 02:33:38] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 204/500]: f1_score: 0.580 accuracy_score: 0.630 loss: 382156929.925 
[2023-04-20 02:33:57] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 205/500]: f1_score: 0.634 accuracy_score: 0.681 loss: 382132470.746

[2023-04-20 02:48:36] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 250/500]: f1_score: 0.651 accuracy_score: 0.692 loss: 382176246.386 
[2023-04-20 02:48:55] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 251/500]: f1_score: 0.485 accuracy_score: 0.564 loss: 382104194.719 
[2023-04-20 02:49:15] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 252/500]: f1_score: 0.640 accuracy_score: 0.677 loss: 382106271.135 
[2023-04-20 02:49:34] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 253/500]: f1_score: 0.588 accuracy_score: 0.650 loss: 382107477.563 
[2023-04-20 02:49:53] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 254/500]: f1_score: 0.493 accuracy_score: 0.575 loss: 382201975.196 
[2023-04-20 02:50:11] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 255/500]: f1_score: 0.555 accuracy_score: 0.636 loss: 382134858.749

[2023-04-20 03:04:44] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 300/500]: f1_score: 0.618 accuracy_score: 0.666 loss: 381971979.755 
[2023-04-20 03:05:03] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 301/500]: f1_score: 0.621 accuracy_score: 0.675 loss: 382044272.958 
[2023-04-20 03:05:22] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 302/500]: f1_score: 0.632 accuracy_score: 0.676 loss: 382115125.998 
[2023-04-20 03:05:41] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 303/500]: f1_score: 0.633 accuracy_score: 0.671 loss: 382151805.255 
[2023-04-20 03:06:01] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 304/500]: f1_score: 0.591 accuracy_score: 0.642 loss: 382134467.986 
[2023-04-20 03:06:20] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 305/500]: f1_score: 0.634 accuracy_score: 0.672 loss: 382146025.362

[2023-04-20 03:20:45] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 350/500]: f1_score: 0.645 accuracy_score: 0.684 loss: 382031913.944 
[2023-04-20 03:21:04] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 351/500]: f1_score: 0.608 accuracy_score: 0.670 loss: 382105057.172 
[2023-04-20 03:21:24] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 352/500]: f1_score: 0.624 accuracy_score: 0.673 loss: 382156121.771 
[2023-04-20 03:21:43] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 353/500]: f1_score: 0.563 accuracy_score: 0.621 loss: 382101089.216 
[2023-04-20 03:22:03] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 354/500]: f1_score: 0.638 accuracy_score: 0.686 loss: 382154903.056 
[2023-04-20 03:22:23] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 355/500]: f1_score: 0.631 accuracy_score: 0.673 loss: 382186118.393

[2023-04-20 03:37:05] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 400/500]: f1_score: 0.598 accuracy_score: 0.651 loss: 382097758.895 
[2023-04-20 03:37:25] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 401/500]: f1_score: 0.630 accuracy_score: 0.671 loss: 382148196.982 
[2023-04-20 03:37:45] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 402/500]: f1_score: 0.626 accuracy_score: 0.666 loss: 382130871.037 
[2023-04-20 03:38:04] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 403/500]: f1_score: 0.584 accuracy_score: 0.645 loss: 382107347.893 
[2023-04-20 03:38:24] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 404/500]: f1_score: 0.522 accuracy_score: 0.607 loss: 382063007.524 
[2023-04-20 03:38:43] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 405/500]: f1_score: 0.640 accuracy_score: 0.686 loss: 382143403.557

[2023-04-20 03:53:25] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 450/500]: f1_score: 0.630 accuracy_score: 0.673 loss: 382147726.284 
[2023-04-20 03:53:45] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 451/500]: f1_score: 0.629 accuracy_score: 0.670 loss: 382068005.396 
[2023-04-20 03:54:05] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 452/500]: f1_score: 0.675 accuracy_score: 0.717 loss: 382077586.272 
[2023-04-20 03:54:24] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 453/500]: f1_score: 0.647 accuracy_score: 0.685 loss: 382043648.229 
[2023-04-20 03:54:45] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 454/500]: f1_score: 0.660 accuracy_score: 0.695 loss: 382023530.955 
[2023-04-20 03:55:04] 6117 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 455/500]: f1_score: 0.609 accuracy_score: 0.672 loss: 382097149.328

{'f1_score': 0.659407519161053, 'accuracy_score': 0.6861660079051384}


(     f1_score  accuracy_score          loss
 0    0.327276        0.380078  6.884485e+14
 1    0.376026        0.419922  4.446453e+14
 2    0.274544        0.316406  2.666097e+14
 3    0.348493        0.397656  1.439813e+14
 4    0.297855        0.355859  6.639140e+13
 ..        ...             ...           ...
 495  0.622521        0.668359  3.821463e+08
 496  0.645362        0.680469  3.820997e+08
 497  0.626818        0.677734  3.821333e+08
 498  0.642873        0.688672  3.821472e+08
 499  0.623714        0.663672  3.821649e+08
 
 [500 rows x 3 columns],
 {'best_loss': 368544080.5668788,
  'best_f1_score': 0.6751041469317914,
  'best_accuracy_score': 0.716796875,
  'best_loss_model': OrderedDict([('_resnet.bn1.weight',
                tensor([1.0324, 0.9653, 0.7148, 0.9058, 1.0824, 1.1014, 0.8309, 0.9801, 1.0149,
                        1.0311, 1.0015, 0.8567, 0.6549, 0.9553, 0.8604, 0.9005, 1.2873, 0.9238,
                        0.9865, 1.1878, 1.1924, 0.8117, 1.0019, 0.8752, 0

In [4]:
old_result = torch.load('../../bayesian_refactoring/assets/model_weights/densenet18cnn_28974_e649f283-a6d8-4c26-8c5d-4b24d610de1a.pt')

In [4]:
display(old_result['unfreeze_classifier'].keys())

dict_keys(['best_loss', 'best_f1_score', 'best_accuracy_score', 'best_loss_model', 'best_f1_score_model', 'best_accuracy_score_model'])

In [5]:
from src.models.cnn.densenet121_cnn_classifier import Densenet121CnnClassifier

classifier = Densenet121CnnClassifier(dataset=dataset, device=device).load_model(
    old_result['unfreeze_classifier']['best_f1_score_model'])
classifier.freeze_all_layers()
classifier.unfreeze_layer('classifier')

In [6]:
classifier.score(dataset.validate_dataloader)

{'f1_score': 0.830377275568568, 'accuracy_score': 0.8695652173913043}

In [None]:
finetune_layer(classifier, 'features.denseblock4', num_epoch=30, result_store=finetune_result_store)

In [None]:
classifier.score(dataset.validate_dataloader)

In [None]:
finetune_layer(classifier, 'layer3', num_epoch=30, result_store=finetune_result_store)

In [None]:
finetune_layer(classifier, 'layer2', num_epoch=30, result_store=finetune_result_store)

In [None]:
finetune_layer(classifier, 'layer1', num_epoch=30, result_store=finetune_result_store)

In [None]:
classifier.score(dataset.validate_dataloader)

In [None]:
# Pull out weights

import pyro
pyro.get_param_store().get_all_param_names()
pyro.get_param_store().get_param('guide.locs._resnet.fc.0.weight')
pyro.get_param_store().get_param('guide.scales._resnet.fc.0.weight')

In [None]:
test_model_weights_path = '../../bayesian_refactoring/assets/model_weights/pyro_resnet18bnn_13455_dac7c351-4890-4316-a5dc-fafd5a2075b7.pt'

In [None]:
test_results = torch.load(test_model_weights_path)

In [None]:
display(test_results.keys())

In [None]:
new_classifier = PyroResnet18BnnClassifier(
    dataset=dataset, device=device
).load_model(
    model_state_dict=test_results['unfreeze_layer4']['best_f1_score_model'],
    pyro_state_dict=test_results['unfreeze_layer4']['pyro_state_dict']
)

In [None]:
new_classifier.score(dataset.validate_dataloader)