# Subnetwork Analysis

In [1]:
# Load IPython's autoreload extension and use mode 2, so edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import copy
import plotly.graph_objects as go

from torch.utils.data import ConcatDataset, DataLoader

from data.uci_datasets import UCIData
from main import set_seed, get_device
from util.plots import plot_data, plot_regression, plot_bayesian_regression
from models.nets import create_mlp
from trainer import ModelTrainer

from laplace import Laplace
from laplace.utils import LargestVarianceDiagLaplaceSubnetMask
from strategies.pruning import OBDSubnetMask, MNSubnetMask, SPRSubnetMask
from strategies.kfe import KronckerFactoredEigenSubnetMask



In [3]:
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

# Hydra keeps process-wide state and refuses to be initialized twice, which
# happens whenever this cell is re-run in the same kernel. Check that state
# explicitly instead of catching (and hiding) every exception, so that genuine
# configuration errors still surface here rather than later in `compose`.
if not GlobalHydra.instance().is_initialized():
    initialize(version_base=None, config_path="configuration")

# Compose the experiment configuration and seed the run from it.
config = compose(config_name="uci.yaml")
set_seed(config.seed)


## Train a MAP model

In [4]:
# Dataset wrapper, its metadata table, and the device used for training.
data = UCIData(config.data.path)
meta_data = data.get_metadata()
device = get_device()

# Train / validation / test loaders for the configured dataset and split.
use_gap_split = config.data.split == "GAP"
dataloaders = data.get_dataloaders(
    dataset=config.data.name,
    batch_size=config.trainer.batch_size,
    seed=config.data.seed,
    val_size=config.data.val_size,
    split_index=config.data.split_index,
    gap=use_gap_split,
)
train_dataloader, val_dataloader, test_dataloader = dataloaders

trainer = ModelTrainer(config.trainer, device=device)

# MLP sized from the dataset metadata, moved to the target device in float64.
dataset_meta = meta_data[config.data.name]
model = create_mlp(
    input_size=dataset_meta["input_dim"],
    hidden_sizes=config.model.hidden_sizes,
    output_size=dataset_meta["output_dim"],
).to(device=device, dtype=torch.float64)

# `trainer.train` returns the trained model together with `sigma`, which the
# cells below pass to Laplace as `sigma_noise`.
map_model, sigma = trainer.train(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)

## Overlap of selected parameters between different strategies

In [5]:
# Subnetwork sizes to sweep over; each value is passed as `n_params_subnet`
# to the subnet masks in the cells below.
subnet_sizes = [30, 90, 300, 600, 1200, 1800]


In [7]:
# Pairwise overlap between the index sets chosen by the pruning-based
# strategies, as a fraction of the subnetwork size. One entry per element of
# `subnet_sizes`.
mn_obd_overlap = []
spr_obd_overlap = []
mn_spr_overlap = []

# Strategies compared in this cell. LVD and KFE are intentionally not selected
# here: their overlap curves are not part of this comparison, so computing
# them would only add cost.
diag_mask_classes = {
    "mn": MNSubnetMask,
    "spr": SPRSubnetMask,
    "obd": OBDSubnetMask,
}

for n_params_subnet in subnet_sizes:
    selected_indices = {}
    for strategy, mask_cls in diag_mask_classes.items():
        # Every strategy gets its own copy of the MAP model and a diagonal
        # Laplace approximation built on that same copy. The original code
        # handed OBD the Laplace object that belonged to the LVD model copy.
        model_for_selection = copy.deepcopy(map_model)
        laplace_for_selection = Laplace(
            model=model_for_selection,
            likelihood="regression",
            subset_of_weights="all",
            hessian_structure="diag",
            sigma_noise=sigma,
            prior_mean=config.trainer.la.prior_mean,
        )
        subnet_mask = mask_cls(
            model_for_selection,
            n_params_subnet=n_params_subnet,
            diag_laplace_model=laplace_for_selection,
        )
        indices = subnet_mask.select(train_loader=train_dataloader)
        selected_indices[strategy] = indices.cpu().numpy()

    mn_indices = selected_indices["mn"]
    spr_indices = selected_indices["spr"]
    obd_indices = selected_indices["obd"]

    # Share of the subnetwork that two strategies have in common.
    mn_obd_overlap.append(len(np.intersect1d(mn_indices, obd_indices)) / n_params_subnet)
    spr_obd_overlap.append(len(np.intersect1d(spr_indices, obd_indices)) / n_params_subnet)
    mn_spr_overlap.append(len(np.intersect1d(mn_indices, spr_indices)) / n_params_subnet)

In [8]:
# One line per strategy pair: overlap fraction as a function of subnetwork size.
overlap_curves = {
    'MN_OBD': mn_obd_overlap,
    'SPR_OBD': spr_obd_overlap,
    'MN_SPR': mn_spr_overlap,
}

fig = go.Figure()
for pair_name, overlap_values in overlap_curves.items():
    curve = go.Scatter(
        x=subnet_sizes,
        y=overlap_values,
        mode='lines',
        name=pair_name,
        line=dict(width=3.5),
    )
    fig.add_trace(curve)

fig.update_layout(title="Overlap of selected parameters between different strategies")

fig.show()

## Selected parameters by layer

In [8]:
def plot_indices_by_layer(layers, obd_indices, lvd_indices, kfe_indices, title):
    """Show a grouped bar chart with one group per layer and one bar per strategy.

    Args:
        layers: Layer labels used as the x-axis categories.
        obd_indices: Per-layer values for the OBD strategy, aligned with ``layers``.
        lvd_indices: Per-layer values for the LVD strategy, aligned with ``layers``.
        kfe_indices: Per-layer values for the KFE strategy, aligned with ``layers``.
        title: Title of the figure.
    """
    # Plot what the caller passed in. The previous body ignored `lvd_indices`
    # and `kfe_indices` and read the notebook globals `spr_indices` /
    # `mn_indices` instead, so the bars showed stale data under wrong labels.
    values_by_strategy = {
        'OBD': obd_indices,
        'LVD': lvd_indices,
        'KFE': kfe_indices,
    }
    fig = go.Figure(data=[
        go.Bar(name=strategy, x=layers, y=values)
        for strategy, values in values_by_strategy.items()
    ])
    fig.update_layout(barmode='group', title=title)
    fig.show()

In [9]:
# Flat parameter index ranges attributed to each layer.
# NOTE(review): these boundaries are hard-coded and presumably match the
# parameter layout of the configured MLP - confirm whenever the dataset or
# `config.model.hidden_sizes` changes.
indices_by_layer = {"layer_0" : list(range(0, 600)), "layer_1" : list(range(600, 3100)), "layer_2" : list(range(3100, 3201)) }

# Strategies that select from a diagonal Laplace approximation.
diag_mask_classes = {
    "lvd": LargestVarianceDiagLaplaceSubnetMask,
    "obd": OBDSubnetMask,
}

for n_params_subnet in subnet_sizes:
    selected_indices = {}
    for strategy, mask_cls in diag_mask_classes.items():
        # Each strategy works on its own copy of the MAP model together with
        # a Laplace approximation of that same copy. The original code passed
        # the LVD Laplace object (bound to another model copy) to OBD.
        model_for_selection = copy.deepcopy(map_model)
        laplace_for_selection = Laplace(
            model=model_for_selection,
            likelihood="regression",
            subset_of_weights="all",
            hessian_structure="diag",
            sigma_noise=sigma,
            prior_mean=config.trainer.la.prior_mean,
        )
        subnet_mask = mask_cls(
            model_for_selection,
            n_params_subnet=n_params_subnet,
            diag_laplace_model=laplace_for_selection,
        )
        indices = subnet_mask.select(train_loader=train_dataloader)
        selected_indices[strategy] = indices.cpu().numpy()

    lvd_indices = selected_indices["lvd"]
    obd_indices = selected_indices["obd"]

    # KFE selects from a Kronecker-factored Laplace approximation instead.
    model_for_selection = copy.deepcopy(map_model)
    kfe_laplace_for_selection = Laplace(
        model=model_for_selection,
        likelihood="regression",
        subset_of_weights="all",
        hessian_structure="kron",
        sigma_noise=sigma,
        prior_precision=10,  # NOTE(review): fixed value, only set for KFE - confirm intent
        prior_mean=config.trainer.la.prior_mean,
    )
    kfe_mask = KronckerFactoredEigenSubnetMask(
        model_for_selection,
        n_params_subnet=n_params_subnet,
        kron_laplace_model=kfe_laplace_for_selection,
    )
    kfe_indices = kfe_mask.select(train_loader=train_dataloader)
    kfe_indices = kfe_indices.cpu().numpy()

    # Number of selected parameters that fall into each layer's index range.
    obd_indices_by_layer = [len(np.intersect1d(layer_indices, obd_indices)) for layer_indices in indices_by_layer.values()]
    lvd_indices_by_layer = [len(np.intersect1d(layer_indices, lvd_indices)) for layer_indices in indices_by_layer.values()]
    kfe_indices_by_layer = [len(np.intersect1d(layer_indices, kfe_indices)) for layer_indices in indices_by_layer.values()]
    plot_indices_by_layer(
        list(indices_by_layer.keys()),
        obd_indices_by_layer,
        lvd_indices_by_layer,
        kfe_indices_by_layer,
        f"Selected parameters by layer between different strategies for {n_params_subnet} parameters",
    )
    
        

