This repository contains a JAX implementation of Variational Neural Networks (VNNs), together with the Epistemic Neural Networks (ENN) experiments, for the Variational Neural Networks paper presented at IJCNN 2023 (the citation for the published paper is given below).
Bayesian Neural Networks (BNNs) provide a tool to estimate the uncertainty of a neural network by considering a distribution over weights and sampling different models for each input. In this paper, we propose a method for uncertainty estimation in neural networks called Variational Neural Networks that, instead of considering a distribution over weights, generates parameters for the output distribution of a layer by transforming its inputs with learnable sub-layers. In uncertainty quality estimation experiments, we show that VNNs achieve better uncertainty quality than the Monte Carlo Dropout and Bayes By Backpropagation methods.
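To make the idea concrete, here is a minimal NumPy sketch of a variational dense layer as described above: two learnable sub-layers map the layer input to the mean and standard deviation of a Gaussian output distribution, and the layer output is a sample from it. Repeating the forward pass gives a Monte Carlo estimate of predictive uncertainty. This is not the repository's JAX code; the function names, the softplus parameterisation of the standard deviation, and the placement of the activation after sampling are assumptions made for illustration.

```python
import numpy as np


def softplus(x):
    # Numerically stable softplus, log(1 + exp(x)); keeps the std positive.
    return np.logaddexp(0.0, x)


def init_vnn_dense(rng, in_dim, out_dim, scale=0.1):
    """Parameters of one variational dense layer (hypothetical layout).

    One learnable sub-layer produces the mean of the output distribution and
    another produces its (pre-softplus) standard deviation.
    """
    return {
        "w_mean": rng.normal(0.0, scale, (in_dim, out_dim)),
        "b_mean": np.zeros(out_dim),
        "w_std": rng.normal(0.0, scale, (in_dim, out_dim)),
        "b_std": np.zeros(out_dim),
    }


def vnn_dense(params, x, rng, activation=None):
    """Draw one sample of the layer output for inputs x of shape (batch, in_dim)."""
    mean = x @ params["w_mean"] + params["b_mean"]
    std = softplus(x @ params["w_std"] + params["b_std"])
    # Reparameterised Gaussian sample: mean + std * eps with eps ~ N(0, I).
    out = mean + std * rng.standard_normal(mean.shape)
    return out if activation is None else activation(out)


def predict_with_uncertainty(layers, x, rng, num_samples=20):
    """Monte Carlo predictive mean and variance over num_samples forward passes."""
    samples = []
    for _ in range(num_samples):
        h = x
        for i, params in enumerate(layers):
            is_last = i == len(layers) - 1
            # tanh on hidden layers, identity on the output layer (an assumption).
            h = vnn_dense(params, h, rng, activation=None if is_last else np.tanh)
        samples.append(h)
    samples = np.stack(samples)  # shape (num_samples, batch, out_dim)
    return samples.mean(axis=0), samples.var(axis=0)


rng = np.random.default_rng(0)
layers = [init_vnn_dense(rng, 3, 16), init_vnn_dense(rng, 16, 1)]
x = rng.standard_normal((5, 3))
mean, var = predict_with_uncertainty(layers, x, rng)
print(mean.shape, var.shape)  # (5, 1) (5, 1)
```

The spread of the sampled outputs (var) is what serves as the uncertainty estimate; with weight-space BNNs the same Monte Carlo loop would instead resample the weights.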
Use run_example.sh to train a single model and evaluate its uncertainty quality. Results are saved in the results folder and can be visualized with tools/create_*_plots.py.
If you use this work for your research, you can cite it as:
@article{oleksiienko2022vnntorchjax,
title = {Variational Neural Networks implementation in Pytorch and JAX},
author = {Oleksiienko, Illia and Tran, Dat Thanh and Iosifidis, Alexandros},
journal = {Software Impacts},
volume = {14},
pages = {100431},
year = {2022},
}
@article{oleksiienko2023vnn,
title = {Variational Neural Networks},
author = {Oleksiienko, Illia and Tran, Dat Thanh and Iosifidis, Alexandros},
journal = {Procedia Computer Science},
volume = {222C},
pages = {104-113},
year = {2023},
}
A library for uncertainty representation and training in neural networks.
Many applications in deep learning require or benefit from going beyond a point estimate and representing uncertainty about the model. The coherent use of Bayes' rule and probability theory is the gold standard for updating beliefs and estimating uncertainty, but exact computation quickly becomes infeasible for even simple problems. Modern machine learning has developed an effective toolkit for learning in high dimensions using a simple and coherent convention. Epistemic neural network (ENN) is a library that provides a similarly simple and coherent convention for defining and training neural networks that represent uncertainty over a hypothesis class of models.
In a supervised setting, for inputs $x_i \in \mathcal{X}$ and outputs $y_i \in \mathcal{Y}$, a point estimate $f_\theta(x)$ is trained by fitting the observed data $D = \{(x_i, y_i) \text{ for } i = 1, \dots, N\}$ by minimizing a loss function $\ell(\theta, D) \in \mathbb{R}$. In epistemic neural networks we introduce the concept of an epistemic index $z \in \mathcal{I} \subseteq \mathbb{R}^{n_z}$ distributed according to some reference distribution $p_z(\cdot)$. An augmented epistemic function approximator then takes the form $f_\theta(x, z)$, where the function class $f_\theta(\cdot, z)$ is a neural network. The index $z$ allows unambiguous identification of a corresponding function value, and sampling $z$ corresponds to sampling from the hypothesis class of functions.
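The definition above can be illustrated with a tiny, hypothetical linear ENN in NumPy. Here θ holds a base weight matrix plus a tensor that turns the index z into an input-dependent perturbation, and p_z is assumed to be a standard Gaussian. This is an illustrative construction, not an ENN library architecture; init_enn, enn_apply, and sample_index are hypothetical names.

```python
import numpy as np


def init_enn(rng, in_dim, out_dim, index_dim, scale=0.5):
    """Hypothetical parameters theta for a tiny linear ENN."""
    return {
        "w": rng.normal(0.0, scale, (in_dim, out_dim)),
        "b": np.zeros(out_dim),
        # For each input feature, an (out_dim x index_dim) slice contracted with z.
        "a": rng.normal(0.0, scale, (in_dim, out_dim, index_dim)),
    }


def enn_apply(theta, x, z):
    """f_theta(x, z): a base prediction plus an index-dependent perturbation."""
    base = x @ theta["w"] + theta["b"]                      # (batch, out_dim)
    perturb = np.einsum("bi,iok,k->bo", x, theta["a"], z)   # (batch, out_dim)
    return base + perturb


def sample_index(rng, index_dim):
    """Draw z from the reference distribution p_z, assumed here to be N(0, I)."""
    return rng.standard_normal(index_dim)


rng = np.random.default_rng(0)
theta = init_enn(rng, in_dim=2, out_dim=1, index_dim=4)
x = rng.standard_normal((3, 2))

# A fixed index identifies exactly one function: repeated calls agree.
z = sample_index(rng, 4)
same = np.array_equal(enn_apply(theta, x, z), enn_apply(theta, x, z))

# Sampling several indices samples several functions from the hypothesis class.
preds = np.stack([enn_apply(theta, x, sample_index(rng, 4)) for _ in range(10)])
print(same, preds.shape)  # True (10, 3, 1)
```

The spread of preds across its first axis reflects the epistemic uncertainty the ENN represents at each input.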
On some level, ENNs are purely a notational convenience, and most existing approaches to dealing with uncertainty in deep learning can be rephrased in this way. For example, an ensemble of point estimates $\{f_{\theta_1}, \dots, f_{\theta_K}\}$ can be viewed as an ENN with $\theta = (\theta_1, \dots, \theta_K)$, $z \in \{1, \dots, K\}$, and $f_\theta(x, z) := f_{\theta_z}(x)$. However, this simplicity hides a deeper insight: through the addition of this epistemic index, the process of epistemic update itself can be tackled with the tools of machine learning typically reserved for point estimates. Furthermore, since these tools were explicitly designed to scale to large and complex problems, they might provide tractable approximations to large-scale Bayesian inference even where the exact computations are intractable.
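The ensemble example can be written out directly. The NumPy sketch below, with hypothetical helper names, stores θ as a list of K small MLP parameter sets and implements f_θ(x, z) := f_{θ_z}(x) by indexing into that list; z is drawn uniformly, and Python's 0-based indices 0, ..., K-1 stand in for 1, ..., K.

```python
import numpy as np


def init_member(rng, in_dim, hidden, out_dim):
    """Parameters theta_k of one small MLP point estimate f_{theta_k}."""
    return {
        "w1": rng.normal(0.0, 1.0 / np.sqrt(in_dim), (in_dim, hidden)),
        "b1": np.zeros(hidden),
        "w2": rng.normal(0.0, 1.0 / np.sqrt(hidden), (hidden, out_dim)),
        "b2": np.zeros(out_dim),
    }


def mlp(params, x):
    h = np.maximum(x @ params["w1"] + params["b1"], 0.0)  # ReLU hidden layer
    return h @ params["w2"] + params["b2"]


def ensemble_enn(theta, x, z):
    """f_theta(x, z) := f_{theta_z}(x), with theta = (theta_1, ..., theta_K)."""
    return mlp(theta[z], x)


K = 10
rng = np.random.default_rng(0)
theta = [init_member(rng, 1, 50, 1) for _ in range(K)]
x = np.linspace(-1.0, 1.0, 5)[:, None]
z = rng.integers(K)  # uniform reference distribution over the K indices
print(ensemble_enn(theta, x, z).shape)  # (5, 1)
```

Nothing about the point-estimate networks changes; the index only selects which member answers, which is exactly why the ENN view costs nothing for ensembles.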
For a more comprehensive overview, see the accompanying paper.
To reproduce the experiments from our paper, please see experiments/neurips_2021.
You can get started in our colab tutorial without installing anything on your machine.
We have tested ENN on Python 3.7. To install the dependencies:
- Optional: We recommend using a Python virtual environment to manage your dependencies, so as not to clobber your system installation:

      python3 -m venv enn
      source enn/bin/activate
      pip install --upgrade pip setuptools

- Install ENN directly from GitHub:

      pip install git+https://github.com/deepmind/enn

- Test that you can load ENN by training a simple ensemble ENN:

      from acme.utils.loggers.terminal import TerminalLogger
      from enn import losses
      from enn import networks
      from enn import supervised
      from enn.supervised import regression_data
      import optax

      # A small dummy dataset
      dataset = regression_data.make_dataset()

      # Logger
      logger = TerminalLogger('supervised_regression')

      # ENN: an ensemble of 10 MLPs with matched prior networks
      enn = networks.MLPEnsembleMatchedPrior(
          output_sizes=[50, 50, 1],
          num_ensemble=10,
      )

      # Loss: a bootstrapped L2 single-index loss averaged over 10 sampled indices
      loss_fn = losses.average_single_index_loss(
          single_loss=losses.L2LossWithBootstrap(),
          num_index_samples=10,
      )

      # Optimizer
      optimizer = optax.adam(1e-3)

      # Train the experiment (the number of batches here is an arbitrary choice)
      num_batches = 1000
      experiment = supervised.Experiment(
          enn, loss_fn, optimizer, dataset, seed=0, logger=logger)
      experiment.train(num_batches)

  More examples can be found in the colab tutorial.

- Optional: run the tests by executing ./test.sh from the ENN root directory.
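The quick-start example trains its ensemble with losses.average_single_index_loss wrapped around losses.L2LossWithBootstrap. Conceptually, such a loss samples several indices z, evaluates an ordinary point-estimate loss for f_θ(·, z), and averages the results. The NumPy sketch below illustrates only that idea and is not ENN's implementation: the function names are hypothetical, and the fixed Exponential(1) weight per (index, example) pair is an assumed stand-in for bootstrapping.

```python
import numpy as np


def weighted_l2_loss(pred, y, weights):
    """Point-estimate loss for a single index: per-example weighted squared error."""
    return np.mean(weights * np.sum((pred - y) ** 2, axis=-1))


def average_single_index_loss_sketch(apply_fn, theta, x, y, boot_weights,
                                     num_index_samples, rng):
    """Average the single-index loss over indices z drawn uniformly from 0..K-1.

    boot_weights has shape (K, N): a fixed random weight per (index, example)
    pair (an assumption), so each ensemble member sees its own reweighting.
    """
    num_index = boot_weights.shape[0]
    losses = []
    for _ in range(num_index_samples):
        z = rng.integers(num_index)
        pred = apply_fn(theta, x, z)
        losses.append(weighted_l2_loss(pred, y, boot_weights[z]))
    return float(np.mean(losses))


def linear_ensemble(theta, x, z):
    # Toy ENN: f_theta(x, z) = x @ theta_z, with theta of shape (K, in_dim, out_dim).
    return x @ theta[z]


rng = np.random.default_rng(0)
K, N = 10, 32
x = rng.standard_normal((N, 3))
true_w = np.array([[1.0], [-2.0], [0.5]])
y = x @ true_w
theta = rng.standard_normal((K, 3, 1))
boot_weights = rng.exponential(1.0, size=(K, N))  # drawn once, reused every step
print(average_single_index_loss_sketch(linear_ensemble, theta, x, y, boot_weights, 10, rng))
```

Keeping the weights fixed per (index, example) pair matters: redrawing them at every step would only add noise rather than giving each member a distinct view of the data.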
If you use ENN in your work, please cite the accompanying paper:
@inproceedings{,
title = {Epistemic Neural Networks},
author = {Osband, Ian and Wen, Zheng and Asghari, Mohammad and Ibrahimi, Morteza and Lu, Xiyuan and Van Roy, Benjamin},
booktitle = {arXiv},
year = {2021},
url = {https://arxiv.org/abs/2107.08924}
}