# Subnetwork Inference on Snelson 1D dataset

In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Imports and configuration

In [2]:
import numpy as np
import torch
import copy
import plotly.graph_objects as go

from torch.utils.data import ConcatDataset, DataLoader

from data.snelson1d import Snelson1D
from main import set_seed, get_device
from util.plots import plot_data, plot_regression, plot_bayesian_regression
from models.nets import create_mlp
from trainer import ModelTrainer

from laplace import Laplace
from laplace.utils import LargestVarianceDiagLaplaceSubnetMask
from strategies.pruning import OBDSubnetMask
from strategies.kfe import KronckerFactoredEigenSubnetMask



IndentationError: unexpected indent (main.py, line 154)

### Read configuration for Snelson experiments 

In [None]:
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

# Hydra keeps a single global instance per process and `initialize` raises when
# it is called a second time, which happens whenever this cell is re-run.
# Check for that case explicitly instead of catching and printing every
# exception, so unrelated failures are not silently hidden.
if not GlobalHydra.instance().is_initialized():
    initialize(version_base=None, config_path="configuration")

# Load the Snelson experiment configuration and seed the run from it.
config = compose(config_name="snelson.yaml")
set_seed(config.seed)


## Import data and visualize

In [None]:
# Build the dataset from the configured path, then request the train /
# validation / test loaders using the batch size, validation size and random
# state taken from the config.
snelson1d = Snelson1D(config.data.path)
train_dataloader, val_dataloader, test_dataloader = snelson1d.get_dataloaders(
    batch_size=config.trainer.batch_size,
    val_size=config.data.val_size,
    random_state=config.data.seed,
)

In [None]:
# Training and validation points are pooled for the plot; `squeeze` drops the
# singleton dimensions so plain arrays are handed to the plotting helper.
train_X = np.concatenate(
    [loader.dataset.X.numpy() for loader in (train_dataloader, val_dataloader)],
    axis=0,
).squeeze()
train_y = np.concatenate(
    [loader.dataset.y.numpy() for loader in (train_dataloader, val_dataloader)],
    axis=0,
).squeeze()

test_X = test_dataloader.dataset.X.numpy().squeeze()
test_y = test_dataloader.dataset.y.numpy().squeeze()

plot_data(train_X, train_y, test_X, test_y, title="Snelson1D data")


## Subnetwork inference and comparison

### Train a MAP model

In [None]:
# One input feature and one regression target.
input_size = 1
output_size = 1

device = get_device()
print(f"Using device: {device}")

# Build the MLP with the hidden sizes from the config and move it to the
# selected device in double precision.
model = create_mlp(
    input_size=input_size,
    hidden_sizes=config.model.hidden_sizes,
    output_size=output_size,
).to(device=device, dtype=torch.float64)
print(f"Using model: {model}")

trainer = ModelTrainer(config.trainer, device=device)

Using device: cpu
Using model: Sequential(
  (0): Linear(in_features=1, out_features=50, bias=True)
  (1): ReLU()
  (2): Linear(in_features=50, out_features=50, bias=True)
  (3): ReLU()
  (4): Linear(in_features=50, out_features=1, bias=True)
)


In [None]:
# Train the network. The trainer returns the trained model together with
# `sigma`, which later cells pass on as `sigma_noise`.
map_model, sigma = trainer.train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)
print(f"Using sigma={sigma}")

Using sigma=0.09092176705598831


In [None]:
# Evaluate the MAP model on the test loader; only the NLL is reported here.
map_nll, map_err = trainer.evaluate(
    model=map_model,
    sigma=sigma,
    dataloader=test_dataloader,
)
print(f"MAP NLL: {map_nll}")

MAP NLL: 2.7441849465458312


In [None]:
# Candidate prior precisions handed to `trainer.train_la_posthoc` in every
# Laplace fit below (together with the validation loader).
prior_precisions = [0.1, 1.0, 2.0, 5.0]

In [None]:
def bayesian_regression(model, train_dataloader, test_dataloader, title):
    """Plot the predictive mean and uncertainty of a fitted model.

    Predictions are computed on the training inputs followed by the test
    inputs, in that order.

    Args:
        model: Object that returns ``(f_mu, f_var)`` when called as
            ``model(x=...)`` and exposes a ``sigma_noise`` attribute (in this
            notebook: a fitted Laplace approximation).
        train_dataloader: Loader whose ``dataset`` exposes ``X`` and ``y``
            tensors; these are forwarded as the training points to plot.
        test_dataloader: Loader whose ``dataset`` exposes an ``X`` tensor.
            Its targets are not used.
        title: Title forwarded to ``plot_bayesian_regression``.

    Returns:
        The value returned by ``plot_bayesian_regression`` (the notebook calls
        ``.show()`` on it).

    Note:
        Reads the notebook-level ``device`` variable.
    """
    X_train = train_dataloader.dataset.X.numpy().squeeze()
    y_train = train_dataloader.dataset.y.numpy().squeeze()
    X_test = test_dataloader.dataset.X.numpy().squeeze()

    # Evaluate on train and test inputs together, shaped as a single column.
    X_eval = np.concatenate([X_train, X_test]).reshape(-1, 1)
    X = torch.from_numpy(X_eval).to(device=device, dtype=torch.float64)

    f_mu, f_var = model(x=X)
    f_mu = f_mu.detach().squeeze().cpu().numpy()
    # Predictive std combines the model's variance with the noise variance.
    pred_std = torch.sqrt(f_var.squeeze() + model.sigma_noise**2).detach().cpu().numpy()

    return plot_bayesian_regression(
        X_train=X_train,
        y_train=y_train,
        X_test=X.squeeze().detach().cpu().numpy(),
        y_test=f_mu,
        y_std=pred_std,
        title=title,
    )


### Last layer diagonal approximation

In [None]:
# Work on a copy so the MAP model itself is left untouched.
model_copy = copy.deepcopy(map_model)

# Post-hoc Laplace fit restricted to the last layer with a diagonal Hessian
# structure; the candidate prior precisions and the validation loader are
# passed along to the trainer.
last_layer_diag_la, prior_precision = trainer.train_la_posthoc(
    model=model_copy,
    dataloader=train_dataloader,
    subset_of_weights="last_layer",
    hessian_structure="diag",
    sigma_noise=sigma,
    prior_mean=config.trainer.la.prior_mean,
    prior_precisions=prior_precisions,
    val_dataloader=val_dataloader,
)

last_layer_diag_la_nll = trainer.evaluate_la(last_layer_diag_la, test_dataloader)
print(f"Last layer diag covariance approximation NLL: {last_layer_diag_la_nll}")

Last layer diag covariance approximation NLL: 0.8422183812002569


In [None]:
# Predictive plot for the last-layer diagonal approximation.
fig = bayesian_regression(
    last_layer_diag_la,
    train_dataloader,
    test_dataloader,
    title="Last layer diag covariance approximation",
)
fig.show()

### Last layer full covariance approximation

In [None]:
# Work on a copy so the MAP model itself is left untouched.
model_copy = copy.deepcopy(map_model)

# Post-hoc Laplace fit restricted to the last layer with a full Hessian
# structure; the candidate prior precisions and the validation loader are
# passed along to the trainer.
last_layer_full_la, prior_precision = trainer.train_la_posthoc(
    model=model_copy,
    dataloader=train_dataloader,
    subset_of_weights="last_layer",
    hessian_structure="full",
    sigma_noise=sigma,
    prior_mean=config.trainer.la.prior_mean,
    prior_precisions=prior_precisions,
    val_dataloader=val_dataloader,
)

last_layer_full_la_nll = trainer.evaluate_la(last_layer_full_la, test_dataloader)
# Label uses a single colon, matching the diag-approximation report.
print(f"Last layer full covariance approximation NLL: {last_layer_full_la_nll}")

Last layer full covariance approximation NLL:: 0.9788088966738945


In [None]:
# Predictive plot for the last-layer full-covariance approximation. The title
# names the full-covariance model being plotted (it previously said "diag").
fig = bayesian_regression(
    last_layer_full_la,
    train_dataloader,
    test_dataloader,
    title="Last layer full covariance approximation",
)
fig.show()

In [None]:
# Subnetwork sizes to compare; each value is passed as `n_params_subnet` to the
# selection masks in the loops below.
subnet_sizes = [30, 90, 300, 600, 1200, 1800]


### Subnetwork Inference using LargestVarianceDiagLaplaceSubnetMask

In [None]:
# For every subnetwork size: select parameter indices with
# LargestVarianceDiagLaplaceSubnetMask, fit a full-Hessian Laplace
# approximation on that subnetwork, and record the NLL on the test loader.
largest_variance_subnet_results = []
for n_params_subnet in subnet_sizes:
    # The selection step gets its own copy of the MAP model and a diagonal
    # Laplace model over all weights, which the mask takes as input.
    model_for_selection = copy.deepcopy(map_model)
    laplace_model_for_selection = Laplace(
        model=model_for_selection,
        likelihood="regression",
        subset_of_weights="all",
        hessian_structure="diag",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
    )

    subnetwork_mask = LargestVarianceDiagLaplaceSubnetMask(
        model_for_selection,
        n_params_subnet=n_params_subnet,
        diag_laplace_model=laplace_model_for_selection,
    )
    subnetwork_indices = subnetwork_mask.select(train_loader=train_dataloader)
    # Rebuild the selected indices as a fresh CPU tensor of dtype long.
    subnetwork_indices = torch.tensor(
        subnetwork_indices.cpu().numpy(), dtype=torch.long
    )

    # Fit the subnetwork Laplace approximation on a separate copy of the MAP model.
    model_copy = copy.deepcopy(map_model)
    la, prior_precision = trainer.train_la_posthoc(
        model=model_copy,
        dataloader=train_dataloader,
        subset_of_weights="subnetwork",
        hessian_structure="full",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
        prior_precisions=prior_precisions,
        subnetwork_indices=subnetwork_indices,
        val_dataloader=val_dataloader,
    )

    # Results are appended in the same order as `subnet_sizes`.
    nll = trainer.evaluate_la(la, test_dataloader)
    largest_variance_subnet_results.append(nll)
    print(f"LA NLL: {nll}")

LA NLL: 2.2408168961536554
LA NLL: 1.7028994646171505
LA NLL: 1.6844595647224019
LA NLL: 0.5417533579179888
LA NLL: 0.5512365860923402
LA NLL: 0.40136672058716155


### Subnetwork Inference using OBDSubnetMask (Pruning using optimal brain damage)

In [None]:
# For every subnetwork size: select parameter indices with OBDSubnetMask, fit a
# full-Hessian Laplace approximation on that subnetwork, and record the NLL on
# the test loader.
obd_subnet_results = []
for n_params_subnet in subnet_sizes:
    # The selection step gets its own copy of the MAP model and a diagonal
    # Laplace model over all weights, which the mask takes as input.
    model_for_selection = copy.deepcopy(map_model)
    laplace_model_for_selection = Laplace(
        model=model_for_selection,
        likelihood="regression",
        subset_of_weights="all",
        hessian_structure="diag",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
    )

    subnetwork_mask = OBDSubnetMask(
        model_for_selection,
        n_params_subnet=n_params_subnet,
        diag_laplace_model=laplace_model_for_selection,
    )
    subnetwork_indices = subnetwork_mask.select(train_loader=train_dataloader)
    # Rebuild the selected indices as a fresh CPU tensor of dtype long.
    subnetwork_indices = torch.tensor(
        subnetwork_indices.cpu().numpy(), dtype=torch.long
    )

    # Fit the subnetwork Laplace approximation on a separate copy of the MAP model.
    model_copy = copy.deepcopy(map_model)
    la, prior_precision = trainer.train_la_posthoc(
        model=model_copy,
        dataloader=train_dataloader,
        subset_of_weights="subnetwork",
        hessian_structure="full",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
        prior_precisions=prior_precisions,
        subnetwork_indices=subnetwork_indices,
        val_dataloader=val_dataloader,
    )

    # Results are appended in the same order as `subnet_sizes`.
    nll = trainer.evaluate_la(la, test_dataloader)
    obd_subnet_results.append(nll)
    print(f"LA NLL: {nll}")

LA NLL: 0.5884815892787665
LA NLL: 0.08248394014475549
LA NLL: -0.09442881380490104
LA NLL: -0.10528269093827654
LA NLL: -0.13510393987278477
LA NLL: -0.106723727820922


### Subnetwork Inference using KroneckerFactoredEigenSubnetMask

In [None]:
# For every subnetwork size: select parameter indices with
# KronckerFactoredEigenSubnetMask, fit a full-Hessian Laplace approximation on
# that subnetwork, and record the NLL on the test loader.
kfe_subnet_results = []
for n_params_subnet in subnet_sizes:
    # Unlike the two loops above, the selection model here uses the "kron"
    # Hessian structure, which the mask receives as `kron_laplace_model`.
    model_for_selection = copy.deepcopy(map_model)
    laplace_model_for_selection = Laplace(
        model=model_for_selection,
        likelihood="regression",
        subset_of_weights="all",
        hessian_structure="kron",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
    )

    subnetwork_mask = KronckerFactoredEigenSubnetMask(
        model_for_selection,
        n_params_subnet=n_params_subnet,
        kron_laplace_model=laplace_model_for_selection,
    )
    subnetwork_indices = subnetwork_mask.select(train_loader=train_dataloader)
    # Rebuild the selected indices as a fresh CPU tensor of dtype long.
    subnetwork_indices = torch.tensor(
        subnetwork_indices.cpu().numpy(), dtype=torch.long
    )

    # Fit the subnetwork Laplace approximation on a separate copy of the MAP model.
    model_copy = copy.deepcopy(map_model)
    la, prior_precision = trainer.train_la_posthoc(
        model=model_copy,
        dataloader=train_dataloader,
        subset_of_weights="subnetwork",
        hessian_structure="full",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
        subnetwork_indices=subnetwork_indices,
        prior_precisions=prior_precisions,
        val_dataloader=val_dataloader,
    )

    # Results are appended in the same order as `subnet_sizes`.
    nll = trainer.evaluate_la(la, test_dataloader)
    kfe_subnet_results.append(nll)
    print(f"LA NLL: {nll}")

LA NLL: 1.214360673716097
LA NLL: 0.4094849087821557
LA NLL: 0.2900903987165472
LA NLL: 0.2519949163304585
LA NLL: 0.15969015615952484
LA NLL: -0.07581899164076436


### Compare the results

In [None]:

# Every trace shares the same string labels (previously one trace used strings
# and the other two used integers). The x axis is forced to be categorical so
# each subnetwork size gets one evenly spaced group of bars; without this,
# numeric-looking strings may still be placed on a numeric axis.
subnet_size_labels = [str(size) for size in subnet_sizes]

fig = go.Figure(data=[
    go.Bar(name='LVD', x=subnet_size_labels, y=largest_variance_subnet_results),
    go.Bar(name='KFE', x=subnet_size_labels, y=kfe_subnet_results),
    go.Bar(name='OBD', x=subnet_size_labels, y=obd_subnet_results),
])
fig.update_layout(
    barmode='group',
    yaxis_title="NLL",
    xaxis_title="subnetwork size",
    xaxis_type="category",
    title="Comparison of selection strategies (Snelson 1D)",
    hovermode="x",
    autosize=False,
    width=500,
    height=500,
)
fig.show()


## Conclusion