In [1]:
import torch
import torchvision
import ssl
import matplotlib.pyplot as plt
import pyro
from src.logger.logger_initializer import LoggerInitializer
from src.utils.path_getter import PathGetter
from src.data_loader.dataset_getter import DatasetGetter

ssl._create_default_https_context = ssl._create_unverified_context
device = 'cuda'
torch.__version__

pyro.set_rng_seed(777)
torch.manual_seed(777)

LoggerInitializer().init()
dataset = DatasetGetter(PathGetter.get_data_directory(), batch_size=8).get()

[2023-04-19 16:16:12] 20464 root {logger_initializer-47} INFO - Current process ID: 20464


In [2]:
from src.models.bnn.pyro_resnet18_bnn_classifier import PyroResnet18BnnClassifier

classifier = PyroResnet18BnnClassifier(dataset=dataset, device=device).init()
classifier.freeze_all_layers()

In [4]:
finetune_result_store = dict()

def finetune_layer(classifier, layer_name: str, num_epoch: int, result_store: dict):
    classifier.unfreeze_layer(layer_name)
    result_df, result_dict = classifier.fit(num_epoch=num_epoch)
    # display(result_df)
    stage_name = f'unfreeze_{layer_name}'
    result_store[stage_name] = result_dict
    result_df.to_csv(f'{classifier.name}_{classifier._uuid}_{stage_name}_{num_epoch}_epoch.csv')
    classifier.save_model(result_store)
    print(classifier.score(dataset.validate_dataloader))
    return result_df, result_dict

In [5]:
finetune_layer(classifier, 'fc', num_epoch=50, result_store=finetune_result_store)

[2023-04-19 10:30:02] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/50]: f1_score: 0.184 accuracy_score: 0.213 loss: 1567517185528544.000 
[2023-04-19 10:30:53] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/50]: f1_score: 0.149 accuracy_score: 0.226 loss: 1012170423464647.500 
[2023-04-19 10:31:43] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/50]: f1_score: 0.206 accuracy_score: 0.281 loss: 606673001018050.125 
[2023-04-19 10:32:33] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/50]: f1_score: 0.217 accuracy_score: 0.293 loss: 327428213020129.812 
[2023-04-19 10:33:23] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/50]: f1_score: 0.162 accuracy_score: 0.243 loss: 150829665617597.969 
[2023-04-19 10:34:12] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/50]: f1_score: 0.164 accuracy_score: 0.232

{'f1_score': 0.866291841850007, 'accuracy_score': 0.8885375494071146}


(    f1_score  accuracy_score          loss
 0   0.184260        0.213281  1.567517e+15
 1   0.148596        0.226172  1.012170e+15
 2   0.206088        0.281250  6.066730e+14
 3   0.216895        0.293359  3.274282e+14
 4   0.162376        0.242578  1.508297e+14
 5   0.163971        0.231641  5.323315e+13
 6   0.166192        0.256641  1.100331e+13
 7   0.189957        0.255859  5.344172e+11
 8   0.399728        0.462109  4.616095e+09
 9   0.510169        0.581641  4.105600e+09
 10  0.614645        0.672656  3.870306e+09
 11  0.689954        0.744922  3.647380e+09
 12  0.732335        0.780859  3.414907e+09
 13  0.693013        0.760547  3.174358e+09
 14  0.728242        0.789844  2.927862e+09
 15  0.733684        0.786328  2.678313e+09
 16  0.757364        0.797656  2.429357e+09
 17  0.773558        0.813672  2.184895e+09
 18  0.773969        0.814453  1.949518e+09
 19  0.834981        0.855859  1.727602e+09
 20  0.826305        0.855859  1.523973e+09
 21  0.836309        0.858594  1

In [6]:
finetune_layer(classifier, 'layer4', num_epoch=30, result_store=finetune_result_store)

[2023-04-19 11:28:26] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/30]: f1_score: 0.881 accuracy_score: 0.900 loss: 912685906.404 
[2023-04-19 11:29:40] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/30]: f1_score: 0.898 accuracy_score: 0.917 loss: 874349766.476 
[2023-04-19 11:30:53] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/30]: f1_score: 0.893 accuracy_score: 0.907 loss: 873501547.921 
[2023-04-19 11:32:06] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/30]: f1_score: 0.883 accuracy_score: 0.902 loss: 872894413.190 
[2023-04-19 11:33:19] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/30]: f1_score: 0.870 accuracy_score: 0.893 loss: 872290069.438 
[2023-04-19 11:34:33] 28319 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/30]: f1_score: 0.896 accuracy_score: 0.905 loss: 871751192.696 
[2023-04-1

{'f1_score': 0.8660444879481496, 'accuracy_score': 0.8932806324110671}


(    f1_score  accuracy_score          loss
 0   0.880880        0.900000  9.126859e+08
 1   0.897976        0.917188  8.743498e+08
 2   0.892521        0.906641  8.735015e+08
 3   0.883428        0.902344  8.728944e+08
 4   0.870023        0.892578  8.722901e+08
 5   0.895522        0.905078  8.717512e+08
 6   0.895617        0.908984  8.717040e+08
 7   0.894899        0.909375  8.713005e+08
 8   0.872139        0.900781  8.711912e+08
 9   0.886098        0.909766  8.709263e+08
 10  0.894221        0.910547  8.708255e+08
 11  0.896754        0.913281  8.707784e+08
 12  0.895382        0.914844  8.706982e+08
 13  0.879591        0.898828  8.706312e+08
 14  0.888775        0.905859  8.705872e+08
 15  0.870784        0.900391  8.706733e+08
 16  0.886605        0.903516  8.704878e+08
 17  0.873704        0.894141  8.704877e+08
 18  0.887259        0.896484  8.704265e+08
 19  0.894861        0.907813  8.705412e+08
 20  0.888068        0.900781  8.702582e+08
 21  0.900628        0.918750  8

In [4]:
old_result = torch.load('../../bayesian_refactoring/assets/model_weights/densenet18cnn_28974_e649f283-a6d8-4c26-8c5d-4b24d610de1a.pt')

In [4]:
display(old_result['unfreeze_classifier'].keys())

dict_keys(['best_loss', 'best_f1_score', 'best_accuracy_score', 'best_loss_model', 'best_f1_score_model', 'best_accuracy_score_model'])

In [5]:
from src.models.cnn.densenet121_cnn_classifier import Densenet121CnnClassifier

classifier = Densenet121CnnClassifier(dataset=dataset, device=device).load_model(
    old_result['unfreeze_classifier']['best_f1_score_model'])
classifier.freeze_all_layers()
classifier.unfreeze_layer('classifier')

In [6]:
classifier.score(dataset.validate_dataloader)

{'f1_score': 0.830377275568568, 'accuracy_score': 0.8695652173913043}

In [None]:
finetune_layer(classifier, 'features.denseblock4', num_epoch=30, result_store=finetune_result_store)

In [None]:
classifier.score(dataset.validate_dataloader)

In [None]:
finetune_layer(classifier, 'layer3', num_epoch=30, result_store=finetune_result_store)

In [None]:
finetune_layer(classifier, 'layer2', num_epoch=30, result_store=finetune_result_store)

In [None]:
finetune_layer(classifier, 'layer1', num_epoch=30, result_store=finetune_result_store)

In [None]:
classifier.score(dataset.validate_dataloader)

In [None]:
# Pull out weights

import pyro
pyro.get_param_store().get_all_param_names()
pyro.get_param_store().get_param('guide.locs._resnet.fc.0.weight')
pyro.get_param_store().get_param('guide.scales._resnet.fc.0.weight')

In [None]:
test_model_weights_path = '../../bayesian_refactoring/assets/model_weights/pyro_resnet18bnn_13455_dac7c351-4890-4316-a5dc-fafd5a2075b7.pt'

In [None]:
test_results = torch.load(test_model_weights_path)

In [None]:
display(test_results.keys())

In [None]:
new_classifier = PyroResnet18BnnClassifier(
    dataset=dataset, device=device
).load_model(
    model_state_dict=test_results['unfreeze_layer4']['best_f1_score_model'],
    pyro_state_dict=test_results['unfreeze_layer4']['pyro_state_dict']
)

In [None]:
new_classifier.score(dataset.validate_dataloader)