# Subnetwork Inference

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import polars as pl
import numpy as np
import glob
import plotly.express as px

## Experiments using UCI gap datasets

### Reproduce results using the subnetwork selection strategy proposed in Daxberger et al., 2020

In [5]:
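def plot_nll(path, title):
    """Box-plot the per-(seed, split) mean NLL for each selection strategy and subnetwork size."""
    files = sorted(glob.glob(path))
    if not files:
        raise FileNotFoundError(f"No result files match {path!r}")
    # Lazily scan each newline-delimited JSON result file and stack them into one frame.
    lazy_df = pl.concat([pl.scan_ndjson(file) for file in files])
    group_cols = ["seed", "split", "selection_strategy", "subset_size"]
    # Average the NLL within each (seed, split, strategy, size) group.
    agg_df = (
        lazy_df.groupby(group_cols)  # named `group_by` in newer Polars releases
        .agg(pl.mean("nll"))
        .sort(["seed", "split", "subset_size", "selection_strategy"])
        .collect()
    )
    # Rename the columns to the labels shown on the plot axes and legend.
    plot_df = agg_df.select([pl.col("seed"), pl.col("split"), pl.col("selection_strategy").alias("Selection strategy"), pl.col("subset_size").alias("Subnetwork size"), pl.col("nll").alias("NLL")]).to_pandas()
    fig = px.box(plot_df, x="Subnetwork size", y="NLL", color="Selection strategy", title=title, hover_data=["seed", "split"])
    fig.update_traces(quartilemethod="exclusive", showlegend=True)  # or "inclusive", or "linear" by default
    fig.show()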
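
To make the aggregation step in `plot_nll` concrete, here is a minimal standard-library sketch of what the group-by computes. The field names match those read by `plot_nll`. The records themselves, the strategy labels `"magnitude"` and `"kfac"`, and the helper `mean_nll_per_group` are made up for illustration and are not experimental results.

```python
import json
from collections import defaultdict
from statistics import mean

# Hypothetical NDJSON content: one JSON object per line. The values are
# invented for illustration and are not results from the experiments.
ndjson_text = "\n".join(json.dumps(r) for r in [
    {"seed": 0, "split": 0, "selection_strategy": "magnitude", "subset_size": 50, "nll": 1.0},
    {"seed": 0, "split": 0, "selection_strategy": "magnitude", "subset_size": 50, "nll": 2.0},
    {"seed": 0, "split": 0, "selection_strategy": "kfac", "subset_size": 50, "nll": 0.5},
    {"seed": 1, "split": 0, "selection_strategy": "kfac", "subset_size": 50, "nll": 1.5},
])


def mean_nll_per_group(text):
    """Average `nll` over records sharing (seed, split, strategy, size)."""
    groups = defaultdict(list)
    for line in text.splitlines():
        rec = json.loads(line)
        key = (rec["seed"], rec["split"], rec["selection_strategy"], rec["subset_size"])
        groups[key].append(rec["nll"])
    return {key: mean(vals) for key, vals in groups.items()}


mean_nll = mean_nll_per_group(ndjson_text)
for key, value in sorted(mean_nll.items()):
    print(key, value)
```

Each box in the resulting plot then summarises these per-(seed, split) means for one selection strategy at one subnetwork size.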

#### Box plot of mean negative log-likelihood (NLL) per selection strategy on the wine-gap dataset



In [6]:
plot_nll(path="results/wine_GAP_*LA*.json", title="Mean NLL for wine-gap dataset")

#### Box plot of mean negative log-likelihood (NLL) per selection strategy on the wine dataset



In [7]:
plot_nll(path="results/wine_STANDARD_*LA*.json", title="Mean NLL for wine dataset")

#### Box plot of mean negative log-likelihood (NLL) per selection strategy on the kin8nm-gap dataset



In [6]:
plot_nll(path="results/kin8nm_GAP_*LA*.json", title="Mean NLL for kin8nm-gap dataset")

#### Box plot of mean negative log-likelihood (NLL) per selection strategy on the kin8nm dataset



In [7]:
plot_nll(path="results/kin8nm_STANDARD_*LA*.json", title="Mean NLL for kin8nm dataset")

## Conclusion

From the above experiments, we conclude that pruning methods cannot be ruled out for subnetwork inference. However, because the datasets used in these experiments are low-dimensional, we cannot say with confidence that pruning methods will outperform the approach proposed by the authors. The MAP models used in the experiments may be too expressive for the size of the datasets, so that many of the network's weights are already redundant and pruning techniques can easily remove them.

The KFAC-based strategy appears to be a more reliable choice for subnetwork selection than the method in the paper. It is computationally feasible for larger networks, and it also enables more structured subnetwork inference by accounting for richer covariance structure, for example covariances within the same channel of a convolutional layer. We propose extending the experiments to larger datasets and more complex networks to see whether these results hold.
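
Several subnetwork selection strategies can be viewed as ranking weights by a per-weight score and keeping the top `k` for inference. The sketch below contrasts a pruning-style ranking by weight magnitude with a ranking by per-weight marginal variance. It is an illustration under stated assumptions, not the experiment code: `top_k_indices` is a hypothetical helper, and both the weights and the variance estimates are synthetic random values.

```python
import numpy as np


def top_k_indices(scores, k):
    """Return the indices of the k largest scores, in descending score order."""
    scores = np.asarray(scores)
    if not 0 < k <= scores.size:
        raise ValueError("k must be between 1 and the number of weights")
    # A stable sort on the negated scores gives descending order with deterministic ties.
    return np.argsort(-scores, kind="stable")[:k]


rng = np.random.default_rng(0)
n_weights, k = 1000, 50

# Synthetic stand-ins (assumptions, not values from the experiments):
weights = rng.normal(size=n_weights)                 # flattened MAP weights
marginal_var = rng.gamma(2.0, 1.0, size=n_weights)   # per-weight variance estimates

# Pruning-style selection: keep the weights with the largest magnitude.
magnitude_subnet = top_k_indices(np.abs(weights), k)
# Variance-based selection: keep the weights with the largest marginal variance.
variance_subnet = top_k_indices(marginal_var, k)

overlap = np.intersect1d(magnitude_subnet, variance_subnet).size
print(f"{overlap} of {k} selected weights are shared by both strategies")
```

With real models, the two score vectors would come from the trained MAP network and from a posterior approximation respectively; how much the selected index sets overlap is one way to check whether the strategies actually pick different subnetworks.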