In [4]:
import torch
import ssl
import matplotlib.pyplot as plt
from src.logger.logger_initializer import LoggerInitializer
from src.utils.path_getter import PathGetter
from src.data_loader.dataset_getter import DatasetGetter

# NOTE(review): this disables HTTPS certificate verification for the whole process.
# Presumably added so a download (e.g. pretrained weights) succeeds behind a broken
# CA setup — confirm, and prefer fixing the CA bundle and removing this override.
ssl._create_default_https_context = ssl._create_unverified_context

# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 12

LoggerInitializer().init()
dataset = DatasetGetter(PathGetter.get_data_directory(), batch_size=BATCH_SIZE).get()

In [2]:
from src.models.bnn.pyro_resnet18_bnn_classifier import PyroResnet18BnnClassifier

# Build the Pyro ResNet-18 BNN classifier on the shared dataset/device and run its init().
classifier = PyroResnet18BnnClassifier(
    dataset=dataset,
    device=device,
).init()

In [4]:
# Per-stage training results, keyed by stage name (e.g. 'unfreeze_fc').
finetune_result_store = {}


def finetune_layer(classifier, layer_name: str, num_epoch: int, result_store: dict,
                   validate_dataloader=None):
    """Unfreeze one layer of ``classifier``, fine-tune it, then persist and score it.

    Steps, in order:
      1. ``classifier.unfreeze_layer(layer_name)``
      2. ``classifier.fit(num_epoch=num_epoch)``
      3. store the returned result dict in ``result_store['unfreeze_<layer_name>']``
      4. write the returned frame to
         ``<classifier.name>_unfreeze_<layer_name>_<num_epoch>_epoch.csv``
         (relative to the current working directory)
      5. ``classifier.save_model(result_store)``
      6. ``classifier.score(validate_dataloader)``

    Args:
        classifier: object exposing ``unfreeze_layer``, ``fit``, ``save_model``,
            ``score`` and a ``name`` attribute (in this notebook a
            ``PyroResnet18BnnClassifier``).
        layer_name: name of the layer to unfreeze, e.g. ``'fc'`` or ``'layer4'``.
        num_epoch: number of epochs passed to ``classifier.fit``.
        result_store: dict mutated in place to accumulate per-stage results.
        validate_dataloader: loader passed to ``classifier.score``. When omitted,
            falls back to the notebook-level ``dataset.validate_dataloader`` so
            existing calls keep working; pass it explicitly to avoid relying on
            that global.

    Returns:
        The ``(result_df, result_dict)`` pair returned by ``classifier.fit``.
    """
    classifier.unfreeze_layer(layer_name)
    result_df, result_dict = classifier.fit(num_epoch=num_epoch)

    stage_name = f'unfreeze_{layer_name}'
    result_store[stage_name] = result_dict
    result_df.to_csv(f'{classifier.name}_{stage_name}_{num_epoch}_epoch.csv')
    # The whole store (every stage so far) is handed to save_model, not just this stage.
    classifier.save_model(result_store)

    if validate_dataloader is None:
        # Backward-compatible fallback to the notebook global.
        validate_dataloader = dataset.validate_dataloader
    # NOTE(review): the value returned by score() is discarded here.
    classifier.score(validate_dataloader)
    return result_df, result_dict

In [5]:
# Unfreeze the `fc` layer and fine-tune for 50 epochs; the (result_df, result_dict) pair is displayed.
finetune_layer(
    classifier,
    layer_name='fc',
    num_epoch=50,
    result_store=finetune_result_store,
)

[2023-04-17 20:20:54] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/50]: f1_score: 0.153 accuracy_score: 0.246 loss: 1114740077709074.500 
[2023-04-17 20:21:16] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/50]: f1_score: 0.178 accuracy_score: 0.266 loss: 843905656445659.375 
[2023-04-17 20:21:38] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/50]: f1_score: 0.199 accuracy_score: 0.290 loss: 620968505174562.875 
[2023-04-17 20:21:59] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/50]: f1_score: 0.163 accuracy_score: 0.266 loss: 441261394268203.312 
[2023-04-17 20:22:21] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/50]: f1_score: 0.150 accuracy_score: 0.246 loss: 300117898108432.000 
[2023-04-17 20:22:42] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/50]: f1_score: 0.110 accuracy_score: 0.215 loss: 

(    f1_score  accuracy_score          loss
 0   0.152516        0.246094  1.114740e+15
 1   0.177729        0.265625  8.439057e+14
 2   0.198666        0.290234  6.209685e+14
 3   0.163166        0.266406  4.412614e+14
 4   0.150125        0.246094  3.001179e+14
 5   0.110381        0.214844  1.928716e+14
 6   0.177721        0.265234  1.148598e+14
 7   0.165247        0.238281  6.141864e+13
 8   0.180262        0.277344  2.788561e+13
 9   0.179394        0.267578  9.593968e+12
 10  0.226486        0.283984  1.879143e+12
 11  0.424963        0.489844  8.161991e+10
 12  0.510100        0.594141  3.161496e+09
 13  0.581280        0.677344  2.826405e+09
 14  0.629916        0.733984  2.699809e+09
 15  0.634525        0.739062  2.603838e+09
 16  0.648234        0.755469  2.506581e+09
 17  0.654202        0.766406  2.406150e+09
 18  0.672119        0.777734  2.302871e+09
 19  0.703625        0.771875  2.197039e+09
 20  0.732794        0.794141  2.089027e+09
 21  0.686194        0.773828  1

In [6]:
# Unfreeze `layer4` and fine-tune for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer4',
    num_epoch=30,
    result_store=finetune_result_store,
)

[2023-04-17 20:39:08] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/30]: f1_score: 0.851 accuracy_score: 0.890 loss: 614957509.181 
[2023-04-17 20:39:31] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/30]: f1_score: 0.825 accuracy_score: 0.859 loss: 578099486.399 
[2023-04-17 20:39:53] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/30]: f1_score: 0.819 accuracy_score: 0.872 loss: 578848328.551 
[2023-04-17 20:40:15] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/30]: f1_score: 0.818 accuracy_score: 0.863 loss: 579160234.989 
[2023-04-17 20:40:36] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/30]: f1_score: 0.791 accuracy_score: 0.851 loss: 579563490.192 
[2023-04-17 20:40:58] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/30]: f1_score: 0.817 accuracy_score: 0.875 loss: 579797189.141 
[2023-04-17 20:4

KeyboardInterrupt: 

In [None]:
# Unfreeze `layer3` and fine-tune for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer3',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Unfreeze `layer2` and fine-tune for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer2',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Unfreeze `layer1` and fine-tune for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer1',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [7]:
# Pull out the guide parameters for the final `fc` layer weight from Pyro's global
# param store. The 'guide.locs.*' / 'guide.scales.*' names presumably come from an
# AutoNormal-style guide — confirm against the classifier implementation.

import pyro

param_store = pyro.get_param_store()
# Bound to a name (instead of a discarded mid-cell expression) so it can be inspected.
param_names = list(param_store.get_all_param_names())
fc_weight_loc = param_store.get_param('guide.locs._resnet.fc.0.weight')
fc_weight_scale = param_store.get_param('guide.scales._resnet.fc.0.weight')
# Display both tensors; previously only the last expression (the scales) was shown.
fc_weight_loc, fc_weight_scale

<pyro.params.param_store.ParamStoreDict at 0x7f26811fb310>

In [7]:
# Relative path (from the notebook's working directory) to a saved checkpoint to reload.
test_model_weights_path = (
    '../../bayesian_refactoring/assets/model_weights/'
    'pyro_resnet18bnn_11922_3b79ad55-d56c-484d-9328-d101444ac6f1.pt'
)

In [8]:
# torch.load unpickles the checkpoint — only load files from trusted sources.
# map_location places every stored tensor on the notebook's configured `device`,
# so a checkpoint saved on one device can be restored on another.
test_results = torch.load(test_model_weights_path, map_location=device)

In [8]:
# Rebuild a classifier and restore both the network state dict and the Pyro state
# stored in the checkpoint.
# NOTE(review): unlike `classifier` earlier, `.init()` is not called here — confirm
# load_model does not need it. The validation score recorded for this reloaded model
# (f1 ≈ 0.11) is far below the fine-tuning scores, so verify the restore works.
new_classifier = (
    PyroResnet18BnnClassifier(dataset=dataset, device=device)
    .load_model(
        model_state_dict=test_results['best_f1_score_model'],
        pyro_state_dict=test_results['pyro_state_dict'],
    )
)

In [9]:
# Score the reloaded model on the validation loader and display the metrics dict.
validation_scores = new_classifier.score(dataset.validate_dataloader)
validation_scores

{'f1_score': 0.11227969435503775, 'accuracy_score': 0.22924901185770752}