In [2]:
import torch
import ssl
import matplotlib.pyplot as plt
from src.logger.logger_initializer import LoggerInitializer
from src.utils.path_getter import PathGetter
from src.data_loader.dataset_getter import DatasetGetter

# NOTE(review): this replaces the default HTTPS context factory with the
# unverified one, so certificate verification is off for every HTTPS request
# that relies on the default context in this kernel. Presumably a workaround
# for a download that failed certificate checks — confirm it is still needed
# and remove it if not.
ssl._create_default_https_context = ssl._create_unverified_context

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

LoggerInitializer().init()
dataset = DatasetGetter(PathGetter.get_data_directory(), batch_size=12).get()

[2023-04-17 20:19:08] 8231 root {logger_initializer-47} INFO - Current process ID: 8231


In [3]:
from src.models.bnn.pyro_resnet18_bnn_classifier import PyroResnet18BnnClassifier

# Construct the classifier, then run its init() step; `classifier` ends up
# bound to whatever init() returns.
classifier = PyroResnet18BnnClassifier(dataset=dataset, device=device)
classifier = classifier.init()

In [4]:
# Shared store passed as `result_store` to every finetune_layer call.
finetune_result_store = {}

def finetune_layer(classifier, layer_name: str, num_epoch: int, result_store: dict,
                   validate_dataloader=None):
    """Unfreeze one layer, fine-tune the classifier, persist and score the result.

    Args:
        classifier: Object exposing ``name``, ``unfreeze_layer``, ``fit``,
            ``save_model`` and ``score``.
        layer_name: Name handed to ``classifier.unfreeze_layer``.
        num_epoch: Number of epochs forwarded to ``classifier.fit``.
        result_store: Dict mutated in place; the fit result dict is stored
            under ``'unfreeze_<layer_name>'`` and the whole dict is then
            passed to ``classifier.save_model``.
        validate_dataloader: Loader passed to ``classifier.score``. When
            omitted, the notebook-level ``dataset.validate_dataloader`` is
            used, which is what the function always did before.

    Returns:
        The ``(result_df, result_dict)`` pair returned by ``classifier.fit``.
    """
    classifier.unfreeze_layer(layer_name)
    result_df, result_dict = classifier.fit(num_epoch=num_epoch)

    stage_name = f'unfreeze_{layer_name}'
    result_store[stage_name] = result_dict

    # The per-epoch metrics are written as CSV, so give the file a .csv
    # extension; '.pt' made it look like a PyTorch checkpoint.
    result_df.to_csv(f'{classifier.name}_{stage_name}_{num_epoch}_epoch.csv')
    classifier.save_model(result_store)

    # Resolve the default loader only now, so the order of operations matches
    # the original when no explicit loader is given.
    if validate_dataloader is None:
        validate_dataloader = dataset.validate_dataloader
    classifier.score(validate_dataloader)
    return result_df, result_dict

In [None]:
# Fine-tune with the 'fc' layer unfrozen for 50 epochs.
finetune_layer(
    classifier,
    layer_name='fc',
    num_epoch=50,
    result_store=finetune_result_store,
)

[2023-04-17 20:20:54] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/50]: f1_score: 0.153 accuracy_score: 0.246 loss: 1114740077709074.500 
[2023-04-17 20:21:16] 8231 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/50]: f1_score: 0.178 accuracy_score: 0.266 loss: 843905656445659.375 


In [None]:
# Fine-tune with 'layer4' unfrozen for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer4',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Fine-tune with 'layer3' unfrozen for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer3',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Fine-tune with 'layer2' unfrozen for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer2',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [None]:
# Fine-tune with 'layer1' unfrozen for 30 epochs.
finetune_layer(
    classifier,
    layer_name='layer1',
    num_epoch=30,
    result_store=finetune_result_store,
)

In [7]:
# Relative path of the saved checkpoint file that is loaded with torch.load.
test_model_weights_path = (
    '../../bayesian_refactoring/assets/model_weights/'
    'pyro_resnet18bnn_11922_3b79ad55-d56c-484d-9328-d101444ac6f1.pt'
)

In [8]:
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded when the notebook runs on a different device.
test_results = torch.load(test_model_weights_path, map_location=device)

In [8]:
# Build a fresh classifier, then restore it from the loaded checkpoint;
# `new_classifier` ends up bound to whatever load_model() returns.
new_classifier = PyroResnet18BnnClassifier(dataset=dataset, device=device)
new_classifier = new_classifier.load_model(
    model_state_dict=test_results['best_f1_score_model'],
    pyro_state_dict=test_results['pyro_state_dict'],
)

In [9]:
# Score the restored classifier on the validation loader; the result is the
# cell's displayed output.
validation_loader = dataset.validate_dataloader
new_classifier.score(validation_loader)

{'f1_score': 0.11227969435503775, 'accuracy_score': 0.22924901185770752}