In [8]:
import torch
import ssl
import matplotlib.pyplot as plt
import pyro
from src.logger.logger_initializer import LoggerInitializer
from src.utils.path_getter import PathGetter
from src.data_loader.dataset_getter import DatasetGetter
from src.models.bnn.pyro_resnet18_bnn_classifier import PyroResnet18BnnClassifier

ssl._create_default_https_context = ssl._create_unverified_context
device = 'cuda'
torch.__version__

pyro.set_rng_seed(777)
torch.manual_seed(777)


LoggerInitializer().init()
dataset = DatasetGetter(PathGetter.get_data_directory(), batch_size=12).get()

In [2]:
from src.models.bnn.pyro_resnet18_bnn_classifier import PyroResnet18BnnClassifier

classifier = PyroResnet18BnnClassifier(dataset=dataset, device=device).init()

In [3]:
finetune_result_store = dict()

def finetune_layer(classifier, layer_name: str, num_epoch: int, result_store: dict):
    classifier.unfreeze_layer(layer_name)
    result_df, result_dict = classifier.fit(num_epoch=num_epoch)
    # display(result_df)
    stage_name = f'unfreeze_{layer_name}'
    result_store[stage_name] = result_dict
    result_df.to_csv(f'{classifier.name}_{classifier._uuid}_{stage_name}_{num_epoch}_epoch.csv')
    classifier.save_model(result_store)
    print(classifier.score(dataset.validate_dataloader))
    return result_df, result_dict

In [4]:
finetune_layer(classifier, 'fc', num_epoch=50, result_store=finetune_result_store)

[2023-04-17 21:48:49] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/50]: f1_score: 0.149 accuracy_score: 0.213 loss: 1114690261651989.000 
[2023-04-17 21:49:30] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/50]: f1_score: 0.174 accuracy_score: 0.230 loss: 843850609226034.125 
[2023-04-17 21:50:09] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/50]: f1_score: 0.158 accuracy_score: 0.227 loss: 620902579671494.625 
[2023-04-17 21:50:49] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/50]: f1_score: 0.145 accuracy_score: 0.254 loss: 441186719857898.938 
[2023-04-17 21:51:28] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/50]: f1_score: 0.145 accuracy_score: 0.230 loss: 300042824339096.438 
[2023-04-17 21:52:08] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/50]: f1_score: 0.164 accuracy_score: 0.239 

(    f1_score  accuracy_score          loss
 0   0.148919        0.212891  1.114690e+15
 1   0.173644        0.230469  8.438506e+14
 2   0.157619        0.226562  6.209026e+14
 3   0.144521        0.253906  4.411867e+14
 4   0.144602        0.230078  3.000428e+14
 5   0.164348        0.239453  1.928035e+14
 6   0.215923        0.283984  1.148039e+14
 7   0.190406        0.270703  6.137868e+13
 8   0.295126        0.360156  2.786040e+13
 9   0.302044        0.389062  9.581951e+12
 10  0.236418        0.303906  1.876118e+12
 11  0.448486        0.516406  8.132246e+10
 12  0.546238        0.620703  3.099449e+09
 13  0.603084        0.677344  2.826191e+09
 14  0.662766        0.748828  2.707200e+09
 15  0.665022        0.751953  2.605686e+09
 16  0.672734        0.769531  2.506585e+09
 17  0.665167        0.774219  2.406108e+09
 18  0.675359        0.784375  2.302876e+09
 19  0.670383        0.779297  2.197031e+09
 20  0.673180        0.783203  2.088986e+09
 21  0.751266        0.817187  1

In [5]:
finetune_layer(classifier, 'layer4', num_epoch=30, result_store=finetune_result_store)

[2023-04-17 22:22:12] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/30]: f1_score: 0.877 accuracy_score: 0.889 loss: 614829287.923 
[2023-04-17 22:22:52] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/30]: f1_score: 0.849 accuracy_score: 0.879 loss: 578037767.645 
[2023-04-17 22:23:31] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/30]: f1_score: 0.889 accuracy_score: 0.900 loss: 578902644.119 
[2023-04-17 22:24:12] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/30]: f1_score: 0.855 accuracy_score: 0.878 loss: 579317691.090 
[2023-04-17 22:24:52] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/30]: f1_score: 0.866 accuracy_score: 0.883 loss: 579488255.216 
[2023-04-17 22:25:32] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/30]: f1_score: 0.875 accuracy_score: 0.884 loss: 579696155.994 
[2023-04-1

(    f1_score  accuracy_score          loss
 0   0.876911        0.889453  6.148293e+08
 1   0.849320        0.879297  5.780378e+08
 2   0.889499        0.900391  5.789026e+08
 3   0.855321        0.878125  5.793177e+08
 4   0.866225        0.883203  5.794883e+08
 5   0.875429        0.883984  5.796962e+08
 6   0.877078        0.887109  5.799998e+08
 7   0.870886        0.875781  5.799984e+08
 8   0.869388        0.881641  5.800599e+08
 9   0.859730        0.877734  5.802039e+08
 10  0.868836        0.884375  5.801395e+08
 11  0.843690        0.864062  5.802919e+08
 12  0.900996        0.903516  5.803482e+08
 13  0.867601        0.883203  5.802453e+08
 14  0.838916        0.864844  5.803457e+08
 15  0.867598        0.897266  5.802639e+08
 16  0.869205        0.894531  5.803747e+08
 17  0.883052        0.903125  5.802954e+08
 18  0.871033        0.891406  5.802388e+08
 19  0.853933        0.884375  5.803512e+08
 20  0.899223        0.915625  5.802803e+08
 21  0.873601        0.900000  5

In [6]:
finetune_layer(classifier, 'layer3', num_epoch=30, result_store=finetune_result_store)

[2023-04-17 22:42:17] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/30]: f1_score: 0.845 accuracy_score: 0.871 loss: 621172258.518 
[2023-04-17 22:42:57] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/30]: f1_score: 0.882 accuracy_score: 0.894 loss: 583156582.261 
[2023-04-17 22:43:37] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/30]: f1_score: 0.894 accuracy_score: 0.912 loss: 582756768.530 
[2023-04-17 22:44:17] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/30]: f1_score: 0.866 accuracy_score: 0.895 loss: 582539857.552 
[2023-04-17 22:44:57] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/30]: f1_score: 0.845 accuracy_score: 0.884 loss: 582284054.503 
[2023-04-17 22:45:36] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/30]: f1_score: 0.848 accuracy_score: 0.881 loss: 581903371.894 
[2023-04-1

(    f1_score  accuracy_score          loss
 0   0.844718        0.871094  6.211723e+08
 1   0.881537        0.894141  5.831566e+08
 2   0.894175        0.912109  5.827568e+08
 3   0.866339        0.895312  5.825399e+08
 4   0.845310        0.883984  5.822841e+08
 5   0.848061        0.880859  5.819034e+08
 6   0.832918        0.850781  5.817434e+08
 7   0.858305        0.884766  5.814548e+08
 8   0.870947        0.887891  5.813138e+08
 9   0.855594        0.882812  5.811371e+08
 10  0.863822        0.880078  5.809938e+08
 11  0.879453        0.892188  5.808886e+08
 12  0.853504        0.889453  5.807584e+08
 13  0.864810        0.892969  5.806520e+08
 14  0.845116        0.869922  5.807054e+08
 15  0.857522        0.886328  5.805510e+08
 16  0.856394        0.869922  5.805058e+08
 17  0.867581        0.883984  5.805726e+08
 18  0.849109        0.881250  5.804489e+08
 19  0.860104        0.877344  5.804420e+08
 20  0.865283        0.889453  5.804687e+08
 21  0.875990        0.905078  5

In [7]:
finetune_layer(classifier, 'layer2', num_epoch=30, result_store=finetune_result_store)

[2023-04-17 23:02:22] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/30]: f1_score: 0.863 accuracy_score: 0.890 loss: 621137482.096 
[2023-04-17 23:03:02] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/30]: f1_score: 0.864 accuracy_score: 0.883 loss: 583038399.949 
[2023-04-17 23:03:42] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/30]: f1_score: 0.854 accuracy_score: 0.873 loss: 582845143.609 
[2023-04-17 23:04:23] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/30]: f1_score: 0.843 accuracy_score: 0.886 loss: 582532438.934 
[2023-04-17 23:05:03] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/30]: f1_score: 0.869 accuracy_score: 0.887 loss: 582077739.118 
[2023-04-17 23:05:43] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/30]: f1_score: 0.884 accuracy_score: 0.908 loss: 581932363.888 
[2023-04-1

(    f1_score  accuracy_score          loss
 0   0.863423        0.889844  6.211375e+08
 1   0.864344        0.882812  5.830384e+08
 2   0.854024        0.873437  5.828451e+08
 3   0.842657        0.885547  5.825324e+08
 4   0.869473        0.886719  5.820777e+08
 5   0.883643        0.907813  5.819324e+08
 6   0.863968        0.885156  5.816486e+08
 7   0.832574        0.869922  5.814512e+08
 8   0.866071        0.884375  5.813816e+08
 9   0.896730        0.905859  5.810631e+08
 10  0.869820        0.893750  5.810948e+08
 11  0.876972        0.885547  5.809220e+08
 12  0.857536        0.874219  5.808498e+08
 13  0.870449        0.890625  5.806662e+08
 14  0.872083        0.890234  5.806242e+08
 15  0.856408        0.875000  5.806381e+08
 16  0.875238        0.887500  5.805948e+08
 17  0.851246        0.877734  5.806180e+08
 18  0.846980        0.882422  5.804683e+08
 19  0.898302        0.904687  5.804152e+08
 20  0.895149        0.913281  5.804973e+08
 21  0.872440        0.896484  5

In [8]:
finetune_layer(classifier, 'layer1', num_epoch=30, result_store=finetune_result_store)

[2023-04-17 23:22:34] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 0/30]: f1_score: 0.888 accuracy_score: 0.905 loss: 621309409.111 
[2023-04-17 23:23:14] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 1/30]: f1_score: 0.869 accuracy_score: 0.887 loss: 583211211.318 
[2023-04-17 23:23:55] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 2/30]: f1_score: 0.851 accuracy_score: 0.880 loss: 582879557.237 
[2023-04-17 23:24:35] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 3/30]: f1_score: 0.881 accuracy_score: 0.897 loss: 582464883.032 
[2023-04-17 23:25:15] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 4/30]: f1_score: 0.856 accuracy_score: 0.879 loss: 582230881.042 
[2023-04-17 23:25:55] 13455 src.scoring.epoch_score_printer {epoch_score_printer-21} INFO - [Epoch 5/30]: f1_score: 0.852 accuracy_score: 0.878 loss: 581955758.724 
[2023-04-1

(    f1_score  accuracy_score          loss
 0   0.888417        0.904687  6.213094e+08
 1   0.868572        0.886719  5.832112e+08
 2   0.851045        0.879687  5.828796e+08
 3   0.881116        0.896875  5.824649e+08
 4   0.855962        0.878906  5.822309e+08
 5   0.851699        0.877734  5.819558e+08
 6   0.866791        0.871094  5.815796e+08
 7   0.869906        0.891016  5.813557e+08
 8   0.854749        0.880469  5.813458e+08
 9   0.840056        0.871875  5.811573e+08
 10  0.813474        0.857812  5.809543e+08
 11  0.861074        0.888281  5.808452e+08
 12  0.854219        0.878125  5.807224e+08
 13  0.873092        0.898047  5.807094e+08
 14  0.845438        0.874609  5.806270e+08
 15  0.841374        0.867188  5.806048e+08
 16  0.891182        0.900391  5.805089e+08
 17  0.863620        0.888672  5.804693e+08
 18  0.858286        0.890625  5.805149e+08
 19  0.857174        0.885547  5.804966e+08
 20  0.884035        0.895703  5.805389e+08
 21  0.880749        0.890234  5

In [9]:
classifier.score(dataset.validate_dataloader)

{'f1_score': 0.8294524897117426, 'accuracy_score': 0.8478260869565217}

In [None]:
# Pull out weights

import pyro
pyro.get_param_store().get_all_param_names()
pyro.get_param_store().get_param('guide.locs._resnet.fc.0.weight')
pyro.get_param_store().get_param('guide.scales._resnet.fc.0.weight')

In [2]:
test_model_weights_path = '../../bayesian_refactoring/assets/model_weights/pyro_resnet18bnn_13455_dac7c351-4890-4316-a5dc-fafd5a2075b7.pt'

In [3]:
test_results = torch.load(test_model_weights_path)

In [11]:
display(test_results.keys())

dict_keys(['unfreeze_fc', 'pyro_state_dict', 'unfreeze_layer4', 'unfreeze_layer3', 'unfreeze_layer2', 'unfreeze_layer1'])

In [9]:
new_classifier = PyroResnet18BnnClassifier(
    dataset=dataset, device=device
).load_model(
    model_state_dict=test_results['unfreeze_layer4']['best_f1_score_model'],
    pyro_state_dict=test_results['unfreeze_layer4']['pyro_state_dict']
)

KeyError: 'pyro_state_dict'

In [None]:
new_classifier.score(dataset.validate_dataloader)