## Example: training and testing a Bayesian neural network (BNN) on CMAPSS with bayesian-torch

In [2]:
# Load the `lab_black` IPython extension (auto-formats code cells with black in
# JupyterLab; presumably provided by the nb_black package). Appears to be a
# formatting aid only — it does not take part in the analysis below.
%load_ext lab_black

In [6]:
import itertools
from pathlib import Path
from types import SimpleNamespace
from functools import partial
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.utils.data import DataLoader

from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss

from bnnrul.cmapss.dataset import CMAPSSDataModule
from bnnrul.cmapss.models import CMAPSSModel

In [4]:
# Run configuration: where the CMAPSS data lives, where results are written,
# and which scenario (`scn`) / architecture (`arch`) this run belongs to.
args = SimpleNamespace(
    data_path="../data/cmapss",
    out_path="../results/cmapss",
    scn="bnn_bt",
    arch="linear",
)
# Checkpoints are grouped per scenario and architecture:
#   <out_path>/<scn>/checkpoints/<arch>
checkpoint_dir = Path(args.out_path) / args.scn / "checkpoints" / args.arch

# Seed the RNGs before any model is built: the DNN -> BNN conversion and the
# Bayesian (reparameterization) layers rely on random sampling, so results are
# only reproducible on Restart & Run All with fixed seeds.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

### Train (setup: training data loader and DNN → BNN conversion)

In [5]:
# Training data: the CMAPSS data module batched 1000 samples at a time,
# and the loader over its training split.
train_batch_size = 1000
data = CMAPSSDataModule(args.data_path, batch_size=train_batch_size)
train_dl = data.train_dataloader()

In [7]:
# Build the deterministic network sized from the data module's window length and
# feature count, then convert it to a Bayesian network with bayesian-torch.
model = CMAPSSModel(data.win_length, data.n_features, args.arch)

# Prior / posterior-initialization settings handed to dnn_to_bnn.
const_bnn_prior_parameters = dict(
    prior_mu=0.0,
    prior_sigma=1.0,
    posterior_mu_init=0.0,
    posterior_rho_init=-3.0,
    type="Reparameterization",  # Flipout or Reparameterization
    moped_enable=False,  # True to initialize mu/sigma from the pretrained dnn weights
    moped_delta=0.5,  # MOPED scale; presumably unused while moped_enable is False
)

# Replaces the layers of `model` in place (the displayed architecture afterwards
# shows LinearReparameterization layers where Linear layers were).
dnn_to_bnn(model, const_bnn_prior_parameters)

In [8]:
# Display the converted network: its Linear layers should now appear as
# LinearReparameterization (see the output below).
model

CMAPSSModel(
  (net): Linear(
    (layers): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): LinearReparameterization()
      (2): Sigmoid()
      (3): LinearReparameterization()
      (4): Sigmoid()
      (5): LinearReparameterization()
      (6): Sigmoid()
      (7): LinearReparameterization()
      (8): Softplus(beta=1, threshold=20)
    )
  )
)

### Test

In [9]:
# Test data: a fresh data module with batch size 1, so the test loader yields one
# sample per batch. NOTE: this rebinds `data`, replacing the training data module.
test_batch_size = 1
data = CMAPSSDataModule(args.data_path, batch_size=test_batch_size)
test_dl = data.test_dataloader()

In [10]:
# FIXME: loading the saved BNN weights does not work yet ("Issue with loading state").
# NOTE(review): even uncommented, this would fail as written — `bnn` and `num_epochs`
# are not defined in the cells shown (the converted network is named `model`);
# confirm the intended variable and checkpoint filename before re-enabling.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent PyTorch versions accept weights_only=True to restrict this).
# bnn.load_state_dict(
#     torch.load(f"{checkpoint_dir}/bnn_state_dict_lr_{num_epochs}_epochs.pt")
# )