# Fine tune the backbone of a BLL model

In this notebook, we continue training the backbone of a BLL model based on the Bayesian layer's predictions. We allow a reconstruction loss to be included according to $\alpha$ 

## Setup


Import libraries

In [8]:
import importlib
import models.regene_models as regene_models
importlib.reload(regene_models)
import models.BLL_VI
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os

Set the device

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: mps


Load the Datasets

In [3]:
# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

Set the latent dimension

In [4]:
latent_dim = 256

Create a models directory if it doesn't exist

In [5]:
# Create models directory in parent directory if it doesn't exist
os.makedirs(os.path.join('..', 'model_saves'), exist_ok=True)
model_saves_path = os.path.join('..', 'model_saves')

## Training

### Loading models

Define and load the backbone and decoder

In [10]:
importlib.reload(regene_models)

decoder = regene_models.Decoder(latent_dim=256, device=device)
backbone = regene_models.Classifier(latent_dim=latent_dim, num_classes=10, device=device)

decoder.load_state_dict(torch.load(os.path.join(model_saves_path, 'joint_decoder.pth'), map_location=device))
backbone.load_state_dict(torch.load(os.path.join(model_saves_path, 'joint_classifier.pth'), map_location=device))

  decoder.load_state_dict(torch.load(os.path.join(model_saves_path, 'joint_decoder.pth'), map_location=device))
  backbone.load_state_dict(torch.load(os.path.join(model_saves_path, 'joint_classifier.pth'), map_location=device))


<All keys matched successfully>

Define and load the BLL model

In [12]:
from models.BLL_VI import BayesianLastLayerVI
importlib.reload(models.BLL_VI)

bll_vi = BayesianLastLayerVI(
    backbone=backbone,
    input_dim=256,
    output_dim=10,
    device=device  
)

bll_vi.load_checkpoint(model_saves_path + '/mnist_bll_vi_models/BLL_VI_Joint_Decoder.pt')

 [load_checkpoint] Loaded checkpoint from ../model_saves/mnist_bll_vi_models/BLL_VI_Joint_Decoder.pt


  checkpoint = torch.load(path, map_location=self.device)


### Finetuning

In [17]:
import train
importlib.reload(models.BLL_VI)
importlib.reload(train)

train_loss, val_loss = train.fine_tune_backbone(bll_vi, decoder, 'BLL_VI_Joint_Decoder_Finetuned', trainloader, testloader, num_epochs=5, lr=0.001, lambda_recon=0.8, model_saves_dir=model_saves_path, patience=10)



Epoch [1/5], Total Loss: 0.0171, Class Loss: 0.0616, Recon Loss: 0.0060, [31mtime: 23.47 seconds
[0m
[32m    Val Total Loss: 0.0230, Class Loss: 0.0883, Recon Loss: 0.0067
[0m
[34mbest validation loss[0m
Epoch [2/5], Total Loss: 0.0172, Class Loss: 0.0624, Recon Loss: 0.0059, [31mtime: 23.20 seconds
[0m
[32m    Val Total Loss: 0.0235, Class Loss: 0.0896, Recon Loss: 0.0070
[0m
Epoch [3/5], Total Loss: 0.0171, Class Loss: 0.0619, Recon Loss: 0.0059, [31mtime: 23.56 seconds
[0m
[32m    Val Total Loss: 0.0220, Class Loss: 0.0836, Recon Loss: 0.0066
[0m
[34mbest validation loss[0m
Epoch [4/5], Total Loss: 0.0171, Class Loss: 0.0618, Recon Loss: 0.0059, [31mtime: 24.80 seconds
[0m
[32m    Val Total Loss: 0.0240, Class Loss: 0.0933, Recon Loss: 0.0066
[0m
Epoch [5/5], Total Loss: 0.0172, Class Loss: 0.0624, Recon Loss: 0.0059, [31mtime: 23.39 seconds
[0m
[32m    Val Total Loss: 0.0264, Class Loss: 0.1052, Recon Loss: 0.0067
[0m
[31m   average time: 28.86 seconds
[0m