ConcreteDropout


PyTorch implementation of Concrete Dropout

This repository implements the method described in the Concrete Dropout paper. It exposes a simple PyTorch interface so that the module can be integrated into existing code with ease.

  • Python 3.6+
  • MIT License

Overview:

Obtaining reliable uncertainty estimates from dropout networks typically requires a grid search over dropout probabilities - for larger models this can be computationally prohibitive. The Concrete Dropout paper proposes a novel dropout variant which avoids the grid search, improves performance, and yields better-calibrated uncertainty estimates.

Concrete Dropout optimises the dropout probability directly through gradient descent, minimising the objective with respect to that parameter. Dropout can be viewed as an approximating distribution q(w) to the posterior over the network weights. Under this interpretation it is possible to add a regularisation term to the loss function which depends on the KL divergence, KL[q(w) || p(w)]; this ensures that the approximate posterior does not deviate too far from the prior. As is often the case, the KL divergence is computationally intractable, so an approximation is used instead - details can be seen in equations [2-4] of the paper.
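
For intuition, that approximation reduces per layer to an L2 penalty on the weights scaled by 1 / (1 - p), plus a term proportional to the negative entropy of the dropout distribution. The sketch below is a minimal PyTorch rendering of this regulariser; the function name and the weight_regularizer / dropout_regularizer coefficients are illustrative assumptions, not this repository's API.

```python
import torch


def concrete_dropout_regulariser(layer_weights, p, input_dim,
                                 weight_regularizer=1e-6,
                                 dropout_regularizer=1e-5):
    """Approximate per-layer KL term (cf. equations [2-4] of the paper).

    `p` is the layer's (learnable) dropout probability as a tensor in (0, 1).
    """
    # L2 penalty on the weights, scaled by 1 / (1 - p)
    weights_reg = weight_regularizer * layer_weights.pow(2).sum() / (1.0 - p)

    # negative entropy of the Bernoulli(p) dropout distribution,
    # scaled by the layer's input dimensionality
    dropout_reg = p * torch.log(p) + (1.0 - p) * torch.log(1.0 - p)
    dropout_reg = dropout_regularizer * input_dim * dropout_reg

    return weights_reg + dropout_reg
```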

In standard dropout the mask is sampled from a Bernoulli distribution - unfortunately this discrete distribution does not play well with the re-parameterisation trick, which is required to calculate the derivative of the objective with respect to the dropout probability. To allow this derivative to be calculated, a continuous relaxation of the discrete Bernoulli distribution is used - specifically the Concrete distribution relaxation. This has a simple parameterisation which, in the one-dimensional case, reduces to a sigmoid of the log-odds plus uniform noise, as seen in equation [5] of the paper.
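
A minimal sketch of that sampling step, assuming a temperature of 0.1 and a small eps constant for numerical stability (both illustrative choices rather than values taken from this repository):

```python
import torch


def concrete_dropout_mask(x, p, temperature=0.1, eps=1e-7):
    """Sample a relaxed dropout mask z~ as in equation [5] of the paper."""
    u = torch.rand_like(x)  # uniform noise, u ~ Uniform(0, 1)

    # sigmoid relaxation of the discrete Bernoulli mask; differentiable in p
    drop_prob = torch.sigmoid(
        (torch.log(p + eps) - torch.log(1.0 - p + eps)
         + torch.log(u + eps) - torch.log(1.0 - u + eps)) / temperature
    )
    return 1.0 - drop_prob  # values near 1 keep a unit, near 0 drop it
```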

With the Concrete relaxation in place, the derivatives of the objective can be computed via the re-parameterisation trick, and the dropout probability can be optimised through gradient descent alongside the network weights.
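
Putting the pieces together, the usual pattern is to wrap a layer and store its dropout probability as an unconstrained logit, so that p = sigmoid(p_logit) stays in (0, 1) during optimisation. The sketch below (reusing concrete_dropout_mask from above) is illustrative only; the class name and wrapping pattern are assumptions, not necessarily the interface this repository exposes.

```python
import torch
import torch.nn as nn


class ConcreteDropoutSketch(nn.Module):
    """Wraps a layer and learns its dropout probability (illustrative)."""

    def __init__(self, layer, init_p=0.1, temperature=0.1):
        super().__init__()
        self.layer = layer
        self.temperature = temperature
        # unconstrained parameter; p = sigmoid(p_logit) stays in (0, 1)
        self.p_logit = nn.Parameter(torch.logit(torch.tensor(init_p)))

    def forward(self, x):
        p = torch.sigmoid(self.p_logit)
        # assumes concrete_dropout_mask from the sketch above is in scope
        z = concrete_dropout_mask(x, p, self.temperature)
        # apply the relaxed mask and rescale to preserve the expectation
        return self.layer(x * z / (1.0 - p))
```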

Example:

An example of ConcreteDropout has been implemented in mnist_example.py - this example can be run with:

```bash
python3 mnist_example.py
```

(Figure: MNIST results)
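
At training time, the approximate KL term is added to the task loss for every wrapped layer. Continuing the sketches above - model, criterion, optimiser, and a Linear layer inside each wrapper are assumed names for illustration, not this repository's API - a training step looks roughly like:

```python
logits = model(inputs)
loss = criterion(logits, targets)

# add the approximate KL term for every wrapped layer
for module in model.modules():
    if isinstance(module, ConcreteDropoutSketch):
        p = torch.sigmoid(module.p_logit)
        loss = loss + concrete_dropout_regulariser(
            module.layer.weight, p, input_dim=module.layer.in_features
        )

optimiser.zero_grad()
loss.backward()
optimiser.step()
```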

References:

```bibtex
@misc{gal2017concrete,
    title={Concrete Dropout},
    author={Yarin Gal and Jiri Hron and Alex Kendall},
    year={2017},
    eprint={1705.07832},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}
```
Code by Yarin Gal, author of the paper.
Made by Daniel Kelshaw