# Subnetwork Inference on Snelson 1D dataset

In [1]:
# Load IPython's autoreload extension and enable mode 2 so that edits to the
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Imports and configuration

In [2]:
import copy
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import torch
from laplace import Laplace
from laplace.utils import LargestVarianceDiagLaplaceSubnetMask
from plotly.subplots import make_subplots
from torch.utils.data import ConcatDataset, DataLoader

from data.snelson1d import Snelson1D
from main import set_seed, get_device
from models.nets import create_mlp
from strategies.kfe import KronckerFactoredEigenSubnetMask
from strategies.pruning import OBDSubnetMask
from trainer import ModelTrainer
from util.plots import plot_data, plot_regression, plot_bayesian_regression



### Read configuration for Snelson experiments 

In [3]:
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

# Hydra keeps process-wide state, so calling initialize() a second time (for
# example when this cell is re-run) raises. Check that state explicitly rather
# than swallowing every exception, so genuine configuration errors (wrong
# config path, malformed YAML, ...) still surface here with their real cause.
if not GlobalHydra.instance().is_initialized():
    initialize(version_base=None, config_path="configuration")

config = compose(config_name="snelson.yaml")
set_seed(config.seed)


## Import data and visualize

In [4]:
# Build the dataset wrapper and ask it for the train / validation / test loaders.
snelson1d = Snelson1D(config.data.path)
train_dataloader, val_dataloader, test_dataloader = snelson1d.get_dataloaders(
    batch_size=config.trainer.batch_size,
    val_size=config.data.val_size,
    random_state=config.data.seed,
)

In [5]:
# The training and validation splits are merged and shown together as the
# "train" points in the plot.
train_X = np.squeeze(
    np.concatenate(
        (train_dataloader.dataset.X.numpy(), val_dataloader.dataset.X.numpy())
    )
)
train_y = np.squeeze(
    np.concatenate(
        (train_dataloader.dataset.y.numpy(), val_dataloader.dataset.y.numpy())
    )
)
test_X = np.squeeze(test_dataloader.dataset.X.numpy())
test_y = np.squeeze(test_dataloader.dataset.y.numpy())

plot_data(train_X, train_y, test_X, test_y, title="Snelson1D data")


## Subnetwork inference and comparison

### Train a MAP model

In [6]:
# One input feature and one regression target.
input_size, output_size = 1, 1

device = get_device()
print(f"Using device: {device}")

# Build the MLP and move it to the target device in double precision.
model = create_mlp(
    input_size=input_size,
    hidden_sizes=config.model.hidden_sizes,
    output_size=output_size,
).to(device=device, dtype=torch.float64)
print(f"Using model: {model}")

trainer = ModelTrainer(config.trainer, device=device)

Using device: cpu
Using model: Sequential(
  (0): Linear(in_features=1, out_features=50, bias=True)
  (1): ReLU()
  (2): Linear(in_features=50, out_features=50, bias=True)
  (3): ReLU()
  (4): Linear(in_features=50, out_features=1, bias=True)
)


In [7]:
# Train the network; the trainer hands back the trained model together with
# the `sigma` value that is reused as observation noise in later cells.
map_model, sigma = trainer.train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)
print(f"Using sigma={sigma}")

Using sigma=0.08778076618909836


In [8]:
# Score the MAP model on the held-out test loader.
map_nll, map_err = trainer.evaluate(
    model=map_model,
    sigma=sigma,
    dataloader=test_dataloader,
)
print(f"MAP NLL: {map_nll}")

MAP NLL: 2.1785243352013874


In [9]:
# Candidate prior precisions handed to every `trainer.train_la_posthoc` call
# below, together with the validation loader (presumably the best candidate is
# picked on validation data -- confirm in `ModelTrainer`).
prior_precisions = [0.1, 1.0, 2.0, 5.0]

In [10]:
def bayesian_regression(model, train_dataloader, test_dataloader, title):
    """Plot the predictive mean and standard deviation of a fitted model.

    The model is evaluated on the training inputs followed by the test inputs,
    and the result is passed to ``plot_bayesian_regression`` together with the
    training points.

    Args:
        model: Callable invoked as ``model(x=X)`` that returns a
            ``(mean, variance)`` pair and exposes a ``sigma_noise`` attribute.
        train_dataloader: Loader whose ``dataset`` exposes tensors ``X`` and ``y``.
        test_dataloader: Loader whose ``dataset`` exposes a tensor ``X``.
        title: Title forwarded to the plotting helper.

    Returns:
        Whatever ``plot_bayesian_regression`` returns (the notebook calls
        ``.show()`` / ``.write_image()`` on it).
    """
    X_train = train_dataloader.dataset.X.numpy().squeeze()
    y_train = train_dataloader.dataset.y.numpy().squeeze()
    X_test = test_dataloader.dataset.X.numpy().squeeze()

    # Evaluation grid: training inputs followed by test inputs, as a column.
    X_eval = np.concatenate([X_train, X_test]).reshape(-1, 1)
    X = torch.from_numpy(X_eval).to(device=device, dtype=torch.float64)

    f_mu, f_var = model(x=X)
    f_mu = f_mu.detach().squeeze().cpu().numpy()
    # Predictive std combines the model variance with the observation noise.
    pred_std = torch.sqrt(f_var.squeeze() + model.sigma_noise**2).detach().cpu().numpy()

    return plot_bayesian_regression(
        X_train=X_train,
        y_train=y_train,
        X_test=X.squeeze().detach().cpu().numpy(),
        y_test=f_mu,
        y_std=pred_std,
        title=title,
    )


### Last layer diagonal approximation

In [11]:
# Work on a copy so the MAP model itself is left untouched.
model_copy = copy.deepcopy(map_model)

# Post-hoc Laplace approximation: last layer only, diagonal Hessian structure.
last_layer_diag_la, prior_precision = trainer.train_la_posthoc(
    model=model_copy,
    dataloader=train_dataloader,
    subset_of_weights="last_layer",
    hessian_structure="diag",
    sigma_noise=sigma,
    prior_mean=config.trainer.la.prior_mean,
    prior_precisions=prior_precisions,
    val_dataloader=val_dataloader,
)

last_layer_diag_la_nll = trainer.evaluate_la(last_layer_diag_la, test_dataloader)
print(f"Last layer diag covariance approximation NLL: {last_layer_diag_la_nll}")

Last layer diag covariance approximation NLL: 0.8896953326786995


In [12]:
# Visualise the predictive distribution of the last-layer diagonal model.
fig = bayesian_regression(
    last_layer_diag_la,
    train_dataloader,
    test_dataloader,
    title="Last layer diag covariance approximation",
)
fig.show()

### Last layer full covariance approximation

In [13]:
# Work on a copy so the MAP model itself is left untouched.
model_copy = copy.deepcopy(map_model)

# Post-hoc Laplace approximation: last layer only, full Hessian structure.
last_layer_full_la, prior_precision = trainer.train_la_posthoc(
    model=model_copy,
    dataloader=train_dataloader,
    subset_of_weights="last_layer",
    hessian_structure="full",
    sigma_noise=sigma,
    prior_mean=config.trainer.la.prior_mean,
    prior_precisions=prior_precisions,
    val_dataloader=val_dataloader,
)

last_layer_full_la_nll = trainer.evaluate_la(last_layer_full_la, test_dataloader)
print(f"Last layer full covariance approximation NLL:: {last_layer_full_la_nll}")

Last layer full covariance approximation NLL:: 0.5192822066399496


In [14]:
# Visualise the predictive distribution of the last-layer *full* covariance
# model. The title previously said "diag", copied from the cell above, which
# mislabelled this figure.
fig = bayesian_regression(
    last_layer_full_la,
    train_dataloader,
    test_dataloader,
    title="Last layer full covariance approximation",
)
fig.show()

### All layers full approximation

In [15]:
# Work on a copy so the MAP model itself is left untouched.
model_copy = copy.deepcopy(map_model)

# Post-hoc Laplace approximation: all weights, full Hessian structure.
all_full_la, prior_precision = trainer.train_la_posthoc(
    model=model_copy,
    dataloader=train_dataloader,
    subset_of_weights="all",
    hessian_structure="full",
    sigma_noise=sigma,
    prior_mean=config.trainer.la.prior_mean,
    prior_precisions=prior_precisions,
    val_dataloader=val_dataloader,
)

all_full_la_nll = trainer.evaluate_la(all_full_la, test_dataloader)
print(f"All layers full covariance approximation NLL:: {all_full_la_nll}")

All layers full covariance approximation NLL:: 0.038498227599171445


### All layers diag approximation

In [16]:
# Work on a copy so the MAP model itself is left untouched.
model_copy = copy.deepcopy(map_model)

# Post-hoc Laplace approximation: all weights, diagonal Hessian structure.
all_diag_la, prior_precision = trainer.train_la_posthoc(
    model=model_copy,
    dataloader=train_dataloader,
    subset_of_weights="all",
    hessian_structure="diag",
    sigma_noise=sigma,
    prior_mean=config.trainer.la.prior_mean,
    prior_precisions=prior_precisions,
    val_dataloader=val_dataloader,
)

all_diag_la_nll = trainer.evaluate_la(all_diag_la, test_dataloader)
print(f"All layers diag covariance approximation NLL:: {all_diag_la_nll}")

All layers diag covariance approximation NLL:: 0.17705288379140538


In [17]:
# Subnetwork sizes (passed as `n_params_subnet`) to sweep over; every selection
# strategy below is run once per size.
subnet_sizes = [30, 90, 300, 600, 1200, 1800]


### Subnetwork Inference using LargestVarianceDiagLaplaceSubnetMask

In [18]:
# Figures are written inside the loop; create the target directory up front so
# a fresh checkout without `figures/` does not fail on the first write.
figures_dir = Path("figures")
figures_dir.mkdir(parents=True, exist_ok=True)

largest_variance_subnet_results = []
for n_params_subnet in subnet_sizes:
    # Selection runs on its own copy so `map_model` is never touched.
    model_for_selection = copy.deepcopy(map_model)
    laplace_model_for_selection = Laplace(
        model=model_for_selection,
        likelihood="regression",
        subset_of_weights="all",
        hessian_structure="diag",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
    )

    # The mask uses the diagonal Laplace model to choose `n_params_subnet` parameters.
    subnetwork_mask = LargestVarianceDiagLaplaceSubnetMask(
        model_for_selection,
        n_params_subnet=n_params_subnet,
        diag_laplace_model=laplace_model_for_selection,
    )
    subnetwork_indices = subnetwork_mask.select(train_loader=train_dataloader)
    # Rebuild the selected indices as a CPU int64 tensor before handing them on.
    subnetwork_indices = torch.tensor(
        subnetwork_indices.cpu().numpy(), dtype=torch.long
    )

    # Full-structure Laplace approximation restricted to the selected subnetwork.
    model_copy = copy.deepcopy(map_model)
    la, prior_precision = trainer.train_la_posthoc(
        model=model_copy,
        dataloader=train_dataloader,
        subset_of_weights="subnetwork",
        hessian_structure="full",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
        prior_precisions=prior_precisions,
        subnetwork_indices=subnetwork_indices,
        val_dataloader=val_dataloader,
    )

    nll = trainer.evaluate_la(la, test_dataloader)
    largest_variance_subnet_results.append(nll)
    print(f"LA NLL: {nll}")

    fig = bayesian_regression(la, train_dataloader, test_dataloader, title=f"LVD with subnetwork size {n_params_subnet}")
    fig.write_image(str(figures_dir / f"LVD_{n_params_subnet}.png"))

LA NLL: 1.3187190531038138
LA NLL: 0.8784570077040961
LA NLL: 0.5605159126039748
LA NLL: 0.5333250890330301
LA NLL: 0.5410132977950712
LA NLL: 0.47584607316075167


### Subnetwork Inference using OBDSubnetMask (Pruning using optimal brain damage)

In [19]:
# Figures are written inside the loop; create the target directory up front so
# a fresh checkout without `figures/` does not fail on the first write.
figures_dir = Path("figures")
figures_dir.mkdir(parents=True, exist_ok=True)

obd_subnet_results = []
for n_params_subnet in subnet_sizes:
    # Selection runs on its own copy so `map_model` is never touched.
    model_for_selection = copy.deepcopy(map_model)
    laplace_model_for_selection = Laplace(
        model=model_for_selection,
        likelihood="regression",
        subset_of_weights="all",
        hessian_structure="diag",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
    )

    # The OBD mask uses the diagonal Laplace model to choose `n_params_subnet` parameters.
    subnetwork_mask = OBDSubnetMask(
        model_for_selection,
        n_params_subnet=n_params_subnet,
        diag_laplace_model=laplace_model_for_selection,
    )
    subnetwork_indices = subnetwork_mask.select(train_loader=train_dataloader)
    # Rebuild the selected indices as a CPU int64 tensor before handing them on.
    subnetwork_indices = torch.tensor(
        subnetwork_indices.cpu().numpy(), dtype=torch.long
    )

    # Full-structure Laplace approximation restricted to the selected subnetwork.
    model_copy = copy.deepcopy(map_model)
    la, prior_precision = trainer.train_la_posthoc(
        model=model_copy,
        dataloader=train_dataloader,
        subset_of_weights="subnetwork",
        hessian_structure="full",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
        prior_precisions=prior_precisions,
        subnetwork_indices=subnetwork_indices,
        val_dataloader=val_dataloader,
    )

    nll = trainer.evaluate_la(la, test_dataloader)
    obd_subnet_results.append(nll)
    print(f"LA NLL: {nll}")

    fig = bayesian_regression(la, train_dataloader, test_dataloader, title=f"OBD with subnetwork size {n_params_subnet}")
    fig.write_image(str(figures_dir / f"OBD_{n_params_subnet}.png"))

LA NLL: 0.887423153731364
LA NLL: -0.10934171526389458
LA NLL: -0.1311254595824427
LA NLL: -0.11773724600288567
LA NLL: -0.10200218936993706
LA NLL: -0.024923823348913864


### Subnetwork Inference using KroneckerFactoredEigenSubnetMask

In [20]:
# Figures are written inside the loop; create the target directory up front so
# a fresh checkout without `figures/` does not fail on the first write.
figures_dir = Path("figures")
figures_dir.mkdir(parents=True, exist_ok=True)

kfe_subnet_results = []
for n_params_subnet in subnet_sizes:
    # Selection runs on its own copy so `map_model` is never touched.
    model_for_selection = copy.deepcopy(map_model)
    # Unlike the two strategies above, selection here is driven by a
    # Kronecker-structured ("kron") Laplace model rather than a diagonal one.
    laplace_model_for_selection = Laplace(
        model=model_for_selection,
        likelihood="regression",
        subset_of_weights="all",
        hessian_structure="kron",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
    )

    subnetwork_mask = KronckerFactoredEigenSubnetMask(
        model_for_selection,
        n_params_subnet=n_params_subnet,
        kron_laplace_model=laplace_model_for_selection,
    )
    subnetwork_indices = subnetwork_mask.select(train_loader=train_dataloader)
    # Rebuild the selected indices as a CPU int64 tensor before handing them on.
    subnetwork_indices = torch.tensor(
        subnetwork_indices.cpu().numpy(), dtype=torch.long
    )

    # Full-structure Laplace approximation restricted to the selected subnetwork.
    model_copy = copy.deepcopy(map_model)
    la, prior_precision = trainer.train_la_posthoc(
        model=model_copy,
        dataloader=train_dataloader,
        subset_of_weights="subnetwork",
        hessian_structure="full",
        sigma_noise=sigma,
        prior_mean=config.trainer.la.prior_mean,
        subnetwork_indices=subnetwork_indices,
        prior_precisions=prior_precisions,
        val_dataloader=val_dataloader,
    )

    nll = trainer.evaluate_la(la, test_dataloader)
    kfe_subnet_results.append(nll)
    print(f"LA NLL: {nll}")

    fig = bayesian_regression(la, train_dataloader, test_dataloader, title=f"KFE with subnetwork size {n_params_subnet}")
    fig.write_image(str(figures_dir / f"KFE_{n_params_subnet}.png"))

LA NLL: 0.6282977502766957
LA NLL: 0.151787379523569
LA NLL: 0.050763550564994674
LA NLL: 0.09392451283720997
LA NLL: -0.07065211042267279
LA NLL: -0.022591145445839563


### Compare the results

In [21]:

# Shared x-axis categories: one per subnetwork size for the selection
# strategies, plus 'All' / 'LL' for the all-layers and last-layer baselines.
size_labels = [str(size) for size in subnet_sizes]
baseline_labels = ['All', 'LL']

fig = go.Figure(
    data=[
        go.Bar(name='KFE', x=size_labels, y=kfe_subnet_results),
        go.Bar(name='LVD', x=size_labels, y=largest_variance_subnet_results),
        go.Bar(name='OBD', x=size_labels, y=obd_subnet_results),
        go.Bar(
            name='Full Covariance',
            x=baseline_labels,
            y=[all_full_la_nll, last_layer_full_la_nll],
        ),
        go.Bar(
            name='Diag',
            x=baseline_labels,
            y=[all_diag_la_nll, last_layer_diag_la_nll],
        ),
    ]
)

# All layout options are applied in a single call.
fig.update_layout(
    barmode='group',
    yaxis_title="NLL",
    xaxis_title="subnetwork size",
    title="Comparison of selection strategies (Snelson 1D)",
    hovermode="x",
    autosize=True,
    width=1000,
    height=500,
)
fig.show()


## Conclusion

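- OBDSubnetMask is the best method for this dataset. It outperforms the other methods at every subnetwork size except the smallest (30 parameters), where KroneckerFactoredEigenSubnetMask achieves a lower NLL.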
    - At the start of the project, we didn't expect the pruning methods to work better than the approach proposed by the authors, as the pruning methods are not designed to retain the weights with the largest variance. One of the core ideas of the paper is that by doing subnetwork inference over the weights with maximum variance, uncertainty is captured better. However, on the Snelson 1D dataset, the pruning methods work better than the approach proposed by the authors.
    - One reason for this could be that the dataset is low-dimensional and the network used to model it is too expressive: there are many parameters with low saliency, and Optimal Brain Damage can easily prune them out.
    - This led us to run some experiments on the overlap between the subnetworks selected by the different strategies.
- KroneckerFactoredEigenSubnetMask didn't outperform OBDSubnetMask but it consistently outperformed LargestVarianceDiagLaplaceSubnetMask
    - The results are in line with the hypothesis that KroneckerFactoredEigenSubnetMask is a better approximation than LargestVarianceDiagLaplaceSubnetMask.
    - Taking into consideration the covariances between weights within each layer seems to be important for subnetwork inference.
    - It is also quite efficient and computationally feasible for larger networks.
    - It also enables more structured subnetwork inference by considering more complex covariances, for example the covariances within the same channel of a convolutional layer.

- Overall, LargestVarianceDiagLaplaceSubnetMask is the worst-performing method on Snelson 1D, and the same result was repeated in other experiments as well.