Probabilistic Network Ensemble
A PyTorch implementation that trains an ensemble of probabilistic neural networks on toy regression problems, effectively replicating the results from
- Lakshminarayanan et al., Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles (2017)
- Chua et al., Deep Reinforcement Learning in a Handful of Trials using Probabilistic Dynamics Models (2018)
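Both papers fit networks that output a full Gaussian rather than a point estimate, trained by minimising the Gaussian negative log-likelihood. A minimal PyTorch sketch of that idea (layer sizes and names here are illustrative, not the repository's exact model):

```python
import torch
import torch.nn as nn

class ProbabilisticNet(nn.Module):
    """Predicts a Gaussian mean and log-variance per input.

    Architecture and sizes are illustrative, not the repository's exact model.
    """
    def __init__(self, in_dim=1, hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.mean_head = nn.Linear(hidden, 1)
        self.logvar_head = nn.Linear(hidden, 1)

    def forward(self, x):
        h = self.body(x)
        return self.mean_head(h), self.logvar_head(h)

def gaussian_nll(mean, logvar, target):
    # Negative log-likelihood of target under N(mean, exp(logvar)),
    # dropping the constant 0.5 * log(2 * pi) term.
    return 0.5 * (logvar + (target - mean) ** 2 / logvar.exp()).mean()
```

Each ensemble member is trained independently on this loss; predicting a log-variance keeps the variance positive without an explicit constraint.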
Currently implements three data sets:
- A simple sine wave
y = sin(x)
- A simple curve
y = x**3
- A simple two-dimensional system
z = sin(x)cos(y)
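For reference, the three data sets can be generated in a few lines; the sampling ranges and sizes below are illustrative guesses, not necessarily the values used in toy.py:

```python
import numpy as np

def make_sine(n=200):
    # y = sin(x); sampling range is an illustrative choice
    x = np.random.uniform(-4.0, 4.0, size=(n, 1))
    return x, np.sin(x)

def make_cubic(n=200):
    # y = x**3
    x = np.random.uniform(-4.0, 4.0, size=(n, 1))
    return x, x ** 3

def make_2d(n=200):
    # z = sin(x) * cos(y) on random 2D inputs
    xy = np.random.uniform(-3.0, 3.0, size=(n, 2))
    z = np.sin(xy[:, :1]) * np.cos(xy[:, 1:])
    return xy, z
```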
Fig 2: Two-dimensional regression with clustered training data. Left: contour lines display one standard deviation from the mean, indicating low-valued plateaus around the training data. Right: standard deviation in log scale.
Fig 3: Ground truth, ensemble mean and ensemble standard deviation with the same training data. Notice that the mean is accurate where training data is available, but inaccurate outside it; however, this is reflected by the increase in standard deviation.
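The ensemble mean and standard deviation are typically obtained by treating the members as a uniform Gaussian mixture, as in Lakshminarayanan et al. A small sketch (the array-shape convention is my own):

```python
import numpy as np

def ensemble_moments(means, variances):
    """Combine M Gaussian members into a mixture mean and variance.

    means, variances: arrays of shape (M, N), one row per ensemble member.
    """
    mu = means.mean(axis=0)
    # Law of total variance for a uniform mixture of Gaussians
    var = (variances + means ** 2).mean(axis=0) - mu ** 2
    return mu, var
```

Note that the mixture variance grows when the members disagree (the `means ** 2` term), which is exactly what produces the large standard deviation away from the training data.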
In addition to the standard Python 3 libraries, to run the code you will need PyTorch.
Executing the code
Executing the code is done through the run.sh bash file, which requires the script to execute and takes an additional plotting flag. For general use, please execute

bash run.sh toy.py plot

which trains the networks and plots the figures. After training, the output models are stored in the /data/ directory and can be plotted by simply calling python plot.py. Additionally, the plots can be saved with an additional save argument, e.g. python plot.py save saves the figures in the
- Note that the implementation is rather naive and might not work for other data sets, architectures, hyperparameters, etc.
- To ignore the parallel computation, one can simply run the code with python toy.py and increase ensemble_size to any desired value. This will execute the program on a single core.
- The number of ensemble members can be increased by increasing the number of cores within run.sh. I have used 4 cores, since my laptop has 4 cores.
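The contents of run.sh are not shown here, but the one-process-per-core parallelism it provides can be sketched with Python's multiprocessing module (train_member is a hypothetical stand-in for the work done in toy.py):

```python
import multiprocessing as mp

def train_member(seed):
    # Hypothetical stand-in: train one ensemble member with its own seed
    # and return, e.g., the path of the saved model.
    return seed

if __name__ == "__main__":
    ensemble_size = 4  # one process per available core, as in run.sh
    with mp.Pool(processes=ensemble_size) as pool:
        results = pool.map(train_member, range(ensemble_size))
```

Because the members are trained completely independently, this kind of embarrassingly parallel launch is all the coordination the ensemble needs.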
For any further questions, please do not hesitate to contact me.