In [1]:
import torch
import ssl
import matplotlib.pyplot as plt
from src.logger.logger_initializer import LoggerInitializer
from src.utils.path_getter import PathGetter
from src.data_loader.dataset_getter import DatasetGetter

# NOTE(security): this disables TLS certificate verification for everything that uses
# the default `ssl` HTTPS context in this process (presumably needed for downloading
# pretrained weights behind a broken certificate chain). Remove once the trust store is fixed.
ssl._create_default_https_context = ssl._create_unverified_context

SEED = 42
BATCH_SIZE = 12

# Seed torch's RNG so torch-driven randomness is repeatable across runs.
# NOTE(review): `random` / `numpy` RNGs (if used by the data pipeline) are not seeded here.
torch.manual_seed(SEED)

# Prefer the GPU but fall back to CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# A bare `torch.__version__` in the middle of a cell is never displayed; print it explicitly.
print(f'torch {torch.__version__}, device={device}')

LoggerInitializer().init()
dataset = DatasetGetter(PathGetter.get_data_directory(), batch_size=BATCH_SIZE).get()

[2023-04-17 21:48:06] 13455 root {logger_initializer-47} INFO - Current process ID: 13455


In [2]:
from src.models.bnn.pyro_resnet18_bnn_classifier import PyroResnet18BnnClassifier

# Build the Bayesian ResNet-18 classifier on the loaded dataset and initialise it.
classifier = PyroResnet18BnnClassifier(
    device=device,
    dataset=dataset,
).init()

In [3]:
finetune_result_store = dict()

def finetune_layer(classifier, layer_name: str, num_epoch: int, result_store: dict, validate_dataloader=None):
    """Unfreeze one layer, train, persist the results, then score on validation data.

    Args:
        classifier: Model wrapper exposing ``unfreeze_layer``, ``fit``, ``save_model``,
            ``score`` and the ``name`` / ``_uuid`` attributes used in the CSV filename.
        layer_name: Layer name passed to ``classifier.unfreeze_layer``.
        num_epoch: Number of epochs passed to ``classifier.fit``.
        result_store: Dict accumulating per-stage results. It is mutated in place under the
            key ``f'unfreeze_{layer_name}'`` and then passed to ``classifier.save_model``.
        validate_dataloader: Loader passed to ``classifier.score`` after training. If None,
            falls back to the notebook-global ``dataset.validate_dataloader`` (original behaviour).

    Returns:
        ``(result_df, result_dict)`` exactly as returned by ``classifier.fit``.
    """
    if validate_dataloader is None:
        # Resolve the fallback before training so a missing notebook-level `dataset`
        # fails immediately instead of after the whole (possibly long) fit.
        validate_dataloader = dataset.validate_dataloader

    classifier.unfreeze_layer(layer_name)
    result_df, result_dict = classifier.fit(num_epoch=num_epoch)

    stage_name = f'unfreeze_{layer_name}'
    result_store[stage_name] = result_dict
    # Per-stage CSV, written relative to the current working directory.
    result_df.to_csv(f'{classifier.name}_{classifier._uuid}_{stage_name}_{num_epoch}_epoch.csv')
    classifier.save_model(result_store)
    classifier.score(validate_dataloader)
    return result_df, result_dict

In [None]:
# Unfreeze the final `fc` layer and train it for 50 epochs.
finetune_layer(
    classifier,
    layer_name='fc',
    num_epoch=50,
    result_store=finetune_result_store,
)

[2023-04-17 21:48:49] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/50]: f1_score: 0.149 accuracy_score: 0.213 loss: 1114690261651989.000 
[2023-04-17 21:49:30] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/50]: f1_score: 0.174 accuracy_score: 0.230 loss: 843850609226034.125 
[2023-04-17 21:50:09] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/50]: f1_score: 0.158 accuracy_score: 0.227 loss: 620902579671494.625 


In [None]:
# Unfreeze `layer4` and train for 30 epochs.
# NOTE(review): presumably previously unfrozen layers stay trainable — confirm `unfreeze_layer` semantics.
finetune_layer(
    classifier,
    layer_name='layer4',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Unfreeze `layer3` and train for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer3',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Unfreeze `layer2` and train for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer2',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Unfreeze `layer1` and train for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer1',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Pull out the learned guide parameters (loc / scale) for the `fc` weight from Pyro's param store.
# Only a cell's last expression is rendered, so intermediate results are displayed explicitly
# (previously the param-name list and the loc tensor were computed and silently discarded).
import pyro

param_store = pyro.get_param_store()
display(sorted(param_store.get_all_param_names()))

fc_weight_loc = param_store.get_param('guide.locs._resnet.fc.0.weight')
fc_weight_scale = param_store.get_param('guide.scales._resnet.fc.0.weight')
display(fc_weight_loc)
fc_weight_scale

In [None]:
# Relative path (from the notebook's working directory) to a saved checkpoint to evaluate.
test_model_weights_path = (
    '../../bayesian_refactoring/assets/model_weights/'
    'pyro_resnet18bnn_11922_3b79ad55-d56c-484d-9328-d101444ac6f1.pt'
)

In [None]:
# Load the checkpoint directly onto the notebook's `device`, so a checkpoint saved on a
# different device (e.g. a GPU machine) can still be opened here.
# NOTE(security): torch.load unpickles arbitrary objects — only load checkpoints you trust.
test_results = torch.load(test_model_weights_path, map_location=device)

In [None]:
# Rebuild a classifier and restore it from the checkpoint: the best-F1 model weights plus
# the saved Pyro state (presumably the param-store contents — confirm against save_model).
# NOTE(review): unlike the training cell, `.init()` is not called before `load_model`.
best_f1_model_state = test_results['best_f1_score_model']
pyro_param_state = test_results['pyro_state_dict']
new_classifier = PyroResnet18BnnClassifier(dataset=dataset, device=device).load_model(
    model_state_dict=best_f1_model_state,
    pyro_state_dict=pyro_param_state,
)

In [None]:
# Score the reloaded checkpoint on the validation split.
validation_loader = dataset.validate_dataloader
new_classifier.score(validation_loader)