Added VOGN results

Didrik Nielsen committed Aug 3, 2018
1 parent 323d25b commit e89d4b5
Showing 17 changed files with 1,207 additions and 36 deletions.
4 changes: 1 addition & 3 deletions README.md
@@ -1,4 +1,2 @@
# vadam
Code for ICML 2018 paper on "Fast and Scalable Bayesian Deep Learning by Weight-Perturbation in Adam" by Khan, Nielsen, Tangkaratt, Lin, Gal, and Srivastava. (https://arxiv.org/abs/1806.04854)

All code to reproduce results in the paper will be available soon.
Code for ICML 2018 paper on "[Fast and Scalable Bayesian Deep Learning by Weight-Perturbation in Adam](https://arxiv.org/abs/1806.04854)" by Khan, Nielsen, Tangkaratt, Lin, Gal, and Srivastava.
8 changes: 7 additions & 1 deletion pytorch/README.md
@@ -1,6 +1,6 @@
# Install

In the folder containing `setup.py`, run
```
pip install --user -e .
```
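To verify the install (a quick sanity check; assumes the package installs under the name `vadam`, as suggested by the module paths below):
```
python -c "import vadam"
```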
@@ -22,3 +22,9 @@ The code for the UCI experiments together with obtained results can be found in
<img src="uci_code/plots/uci_rmse_boston-page-001.jpg" width="200"><img src="uci_code/plots/uci_rmse_concrete-page-001.jpg" width="200"><img src="uci_code/plots/uci_rmse_energy-page-001.jpg" width="200"><img src="uci_code/plots/uci_rmse_kin8nm-page-001.jpg" width="200">

<img src="uci_code/plots/uci_rmse_naval-page-001.jpg" width="200"><img src="uci_code/plots/uci_rmse_powerplant-page-001.jpg" width="200"><img src="uci_code/plots/uci_rmse_wine-page-001.jpg" width="200"><img src="uci_code/plots/uci_rmse_yacht-page-001.jpg" width="200">

# Reproducing VOGN Experiments

The code for the VOGN experiments together with obtained results can be found in `vogn_code`. Using these results you should be able to reproduce these figures:

<img src="vogn_code/plots/plot_bs1_mc1-page-001.jpg" width="200"><img src="vogn_code/plots/plot_bs1_mc16-page-001.jpg" width="200"><img src="vogn_code/plots/plot_bs128_mc16_legend-page-001.jpg" width="200">
34 changes: 30 additions & 4 deletions pytorch/vadam/datasets.py
@@ -2,6 +2,8 @@
import numpy as np
import torch.utils.data as data
from torch.utils.data.dataloader import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms

import sklearn.model_selection as modsel

@@ -37,7 +39,21 @@ class Dataset():
def __init__(self, data_set, data_folder=DEFAULT_DATA_FOLDER):
super(type(self), self).__init__()

if data_set == "australian_presplit":
if data_set == "mnist":
self.train_set = dset.MNIST(root = data_folder,
train = True,
transform = transforms.ToTensor(),
download = True)

self.test_set = dset.MNIST(root = data_folder,
train = False,
transform = transforms.ToTensor())

self.task = "classification"
self.num_features = 28 * 28
self.num_classes = 10

elif data_set == "australian_presplit":
self.train_set = AustralianPresplit(root = data_folder,
train = True)
self.test_set = AustralianPresplit(root = data_folder,
@@ -53,7 +69,7 @@ def __init__(self, data_set, data_folder=DEFAULT_DATA_FOLDER):
self.test_set = BreastCancerPresplit(root = data_folder,
train = False)

self.task = "classification_presplit"
self.task = "classification"
self.num_features = 10
self.num_classes = 2

@@ -215,7 +231,17 @@ def __init__(self, data_set, n_splits=3, seed=None, data_folder=DEFAULT_DATA_FOL
self.seed = seed
self.current_split = 0

if data_set == "australian_presplit":
if data_set == "mnist":
self.data = dset.MNIST(root = data_folder,
train = True,
transform = transforms.ToTensor(),
download = True)

self.task = "classification"
self.num_features = 28 * 28
self.num_classes = 10

elif data_set == "australian_presplit":
self.data = AustralianPresplit(root = data_folder,
train = True)

@@ -227,7 +253,7 @@ def __init__(self, data_set, n_splits=3, seed=None, data_folder=DEFAULT_DATA_FOL
self.data = BreastCancerPresplit(root = data_folder,
train = True)

self.task = "classification_presplit"
self.task = "classification"
self.num_features = 10
self.num_classes = 2

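As a usage sketch of the new `mnist` option (not part of the commit; the import path `vadam.datasets` is an assumption):
```
from torch.utils.data.dataloader import DataLoader
from vadam.datasets import Dataset

# Downloads MNIST into the default data folder on first use.
data = Dataset("mnist")
print(data.task, data.num_features, data.num_classes)  # classification 784 10

# train_set/test_set are plain torchvision datasets, so they can be
# wrapped in a standard DataLoader.
train_loader = DataLoader(data.train_set, batch_size=128, shuffle=True)
x, y = next(iter(train_loader))
print(x.shape)  # torch.Size([128, 1, 28, 28])
```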
126 changes: 99 additions & 27 deletions pytorch/vadam/models.py
@@ -31,7 +31,7 @@ def __init__(self, input_size, hidden_sizes, output_size, act_func="relu"):

# Define layers
if len(hidden_sizes) == 0:
# Logitstic regression
# Linear model
self.hidden_layers = []
self.output_layer = nn.Linear(self.input_size, self.output_size)
else:
@@ -40,12 +40,13 @@ def __init__(self, input_size, hidden_sizes, output_size, act_func="relu"):
self.output_layer = nn.Linear(hidden_sizes[-1], self.output_size)

def forward(self, x):
x = x.view(-1,self.input_size)
out = x
for layer in self.hidden_layers:
out = self.act(layer(out))
z = self.output_layer(out)
if self.squeeze_output:
z = torch.squeeze(z)
z = torch.squeeze(z).view([-1])
return z


@@ -54,7 +55,7 @@ def forward(self, x):
#############################

class BNN(nn.Module):
def __init__(self, input_size, hidden_sizes, output_size, act_func, prior_prec=1.0, prec_init=1.0):
def __init__(self, input_size, hidden_sizes, output_size, act_func="relu", prior_prec=1.0, prec_init=1.0):
super(type(self), self).__init__()
self.input_size = input_size
sigma_prior = 1.0/math.sqrt(prior_prec)
@@ -65,11 +66,22 @@ def __init__(self, input_size, hidden_sizes, output_size, act_func, prior_prec=1
else :
self.output_size = 1
self.squeeze_output = True
self.act = F.tanh if act_func == "tanh" else F.relu

# Set activation function
if act_func == "relu":
self.act = F.relu
elif act_func == "tanh":
self.act = F.tanh
elif act_func == "sigmoid":
self.act = F.sigmoid

# Define layers
if len(hidden_sizes) == 0:
# Linear model
self.hidden_layers = []
self.output_layer = StochasticLinear(self.input_size, self.output_size, sigma_prior = sigma_prior, sigma_init = sigma_init)
else:
# Neural network
self.hidden_layers = nn.ModuleList([StochasticLinear(in_size, out_size, sigma_prior = sigma_prior, sigma_init = sigma_init) for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
self.output_layer = StochasticLinear(hidden_sizes[-1], self.output_size, sigma_prior = sigma_prior, sigma_init = sigma_init)

@@ -78,10 +90,10 @@ def forward(self, x):
out = x
for layer in self.hidden_layers:
out = self.act(layer(out))
logits = self.output_layer(out)
z = self.output_layer(out)
if self.squeeze_output:
logits = torch.squeeze(logits)
return logits
z = torch.squeeze(z).view([-1])
return z

def kl_divergence(self):
kl = 0
@@ -96,26 +108,8 @@ def kl_divergence(self):
###############################################

class StochasticLinear(nn.Module):
"""Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`
Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to False, the layer will not learn an additive bias.
Default: ``True``
Shape:
- Input: :math:`(N, *, in\_features)` where :math:`*` means any number of
additional dimensions
- Output: :math:`(N, *, out\_features)` where all but the last dimension
are the same shape as the input.
Attributes:
weight: the learnable weights of the module of shape
`(out_features x in_features)`
bias: the learnable bias of the module of shape `(out_features)`
Examples::
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
"""Applies a stochastic linear transformation to the incoming data: :math:`y = Ax + b`.
This is a stochastic variant of the built-in torch.nn.Linear().
"""

def __init__(self, in_features, out_features, sigma_prior=1.0, sigma_init=1.0, bias=True):
@@ -157,6 +151,7 @@ def _kl_gaussian(self, p_mu, p_sigma, q_mu, q_sigma):
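# Closed-form KL(N(p_mu, p_sigma^2) || N(q_mu, q_sigma^2)) for diagonal
# Gaussians; assumes var_ratio = (p_sigma / q_sigma)**2 and
# t1 = ((p_mu - q_mu) / q_sigma)**2 are computed above, as in
# torch.distributions' Normal-Normal KL.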
return 0.5 * torch.sum((var_ratio + t1 - 1 - var_ratio.log()))

def kl_divergence(self):
# Compute KL divergence between current distribution and the prior.
mu = self.weight_mu
sigma = F.softplus(self.weight_spsigma)
mu0 = torch.zeros_like(mu)
@@ -174,4 +169,81 @@ def extra_repr(self):
return 'in_features={}, out_features={}, sigma_prior={}, sigma_init={}, bias={}'.format(
self.in_features, self.out_features, self.sigma_prior, self.sigma_init, self.bias is not None
)


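Together, `BNN` and `StochasticLinear` support the usual variational objective: expected negative log-likelihood plus the KL term, scaled by the dataset size. A minimal training-step sketch, not the repo's actual experiment code; the import path `vadam.models` and the dataset size `N` are assumptions:
```
import torch
import torch.nn.functional as F
from vadam.models import BNN

N = 60000  # assumed training-set size (MNIST)
model = BNN(input_size=784, hidden_sizes=[64], output_size=10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(128, 784)           # stand-in minibatch
y = torch.randint(0, 10, (128,))

logits = model(x)                   # each forward pass samples the weights
loss = F.cross_entropy(logits, y) + model.kl_divergence() / N
optimizer.zero_grad()
loss.backward()
optimizer.step()
```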
#################################################################
## MultiLayer Perceptron with support for individual gradients ##
#################################################################

class IndividualGradientMLP(nn.Module):
def __init__(self, input_size, hidden_sizes, output_size, act_func="relu"):
super(type(self), self).__init__()
self.input_size = input_size
self.hidden_sizes = hidden_sizes
if output_size is not None:
self.output_size = output_size
self.squeeze_output = False
else :
self.output_size = 1
self.squeeze_output = True

# Set activation function
if act_func == "relu":
self.act = F.relu
elif act_func == "tanh":
self.act = F.tanh
elif act_func == "sigmoid":
self.act = F.sigmoid

# Define layers
if len(hidden_sizes) == 0:
# Linear model
self.hidden_layers = []
self.output_layer = nn.Linear(self.input_size, self.output_size)
else:
# Neural network
self.hidden_layers = nn.ModuleList([nn.Linear(in_size, out_size) for in_size, out_size in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes)])
self.output_layer = nn.Linear(hidden_sizes[-1], self.output_size)

def forward(self, x, individual_grads=False):
'''
x: The input patterns/features.
individual_grads: Whether or not the activation tensors and
linear-combination tensors from each layer are returned. These tensors
are necessary for computing the GGN using goodfellow_backprop_ggn.
'''

x = x.view(-1, self.input_size)
out = x
# Save the model inputs, which are considered the activations of the
# 0th layer.
if individual_grads:
H_list = [out]
Z_list = []

for layer in self.hidden_layers:
Z = layer(out)
out = self.act(Z)

# Save the activations and linear combinations from this layer.
if individual_grads:
H_list.append(out)
Z.retain_grad()
Z.requires_grad_(True)
Z_list.append(Z)

z = self.output_layer(out)
if self.squeeze_output:
z = torch.squeeze(z).view([-1])

# Save the final model outputs, which are the linear combinations
# from the final layer.
if individual_grads:
z.retain_grad()
z.requires_grad_(True)
Z_list.append(z)

if individual_grads:
return (z, H_list, Z_list)

return z

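And a sketch of the `individual_grads` interface added above (again assuming the `vadam.models` import path; `goodfellow_backprop_ggn` is referenced in the docstring but not shown in this diff):
```
import torch
from vadam.models import IndividualGradientMLP

model = IndividualGradientMLP(input_size=784, hidden_sizes=[64],
                              output_size=10, act_func="relu")
x = torch.randn(32, 784)

# H_list holds the layer inputs (activations), Z_list the pre-activation
# linear combinations, with retain_grad() set on each Z.
z, H_list, Z_list = model(x, individual_grads=True)
loss = z.pow(2).mean()
loss.backward()
print(Z_list[-1].grad.shape)  # per-example gradients w.r.t. the outputs
```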