Commit: mnf_conv big clean up
janosh committed Sep 25, 2020
1 parent 41caa6b commit 7bc0975
Showing 1 changed file with 5 additions and 39 deletions.
44 changes: 5 additions & 39 deletions torch_mnf/layers/mnf_conv.py
@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -17,18 +16,7 @@ class MNFConv2d(nn.Module):
"""

def __init__(
self,
in_channels, # = 1 for black & white images like MNIST
out_channels,
kernel_size, # int for kernel width and height
n_flows_q=2,
n_flows_r=2,
learn_p=False,
prior_var_w=1,
prior_var_b=1,
flow_h_sizes=[200],
std_init=1,
**kwargs,
self, n_in, n_out, kernel_size, n_flows_q=2, n_flows_r=2, h_sizes=[50]
):
"""
Args:
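Note on the hunk above: the long keyword list collapses into a short positional signature. A minimal usage sketch of the new signature, as an assumption rather than code from the commit (channel counts, kernel size, and input shape are illustrative; the import path is inferred from the diff header, and the forward pass is assumed to take a single input tensor):

    import torch
    from torch_mnf.layers.mnf_conv import MNFConv2d  # import path assumed from the file shown

    conv = MNFConv2d(n_in=1, n_out=16, kernel_size=3, n_flows_q=2, n_flows_r=2, h_sizes=[50])
    out = conv(torch.randn(8, 1, 28, 28))  # e.g. a batch of 28x28 single-channel images
    kl = conv.kl_div()  # variational KL term added to the data loss during training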
@@ -93,27 +81,7 @@ def sample_z(self):
         return zs[-1], log_dets.squeeze()

     def kl_div(self):
-        z_sample, log_det_q = self.sample_z(1)
-
-        std_w = torch.exp(self.log_std_W)
-        std_w = std_w.reshape(-1, self.out_channels)
-        mu_w = self.mean_W.reshape(-1, self.out_channels)
-        Mtilde = mu_w * z_sample
-        mean_b = self.mean_b * z_sample
-        Vtilde = std_w ** 2
-        # Stacking yields same result as outer product with ones. See eqs. 11, 12.
-        iUp = torch.stack([torch.exp(self.prior_var_r_p)] * self.out_channels, dim=1)
-
-        kl_div_w = 0.5 * torch.sum(
-            np.log(iUp) - std_w.log() + (Vtilde + Mtilde ** 2) / iUp - 1
-        )
-        kl_div_b = 0.5 * torch.sum(
-            self.prior_var_r_p_bias
-            - self.log_var_b
-            + (torch.exp(self.log_var_b) + mean_b ** 2)
-            / torch.exp(self.prior_var_r_p_bias)
-            - 1
-        )
+        z, log_det_q = self.sample_z()

         W_var = self.W_log_var.exp()
         b_var = self.b_log_var.exp()
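For reference, both the removed block and its shorter replacement evaluate the closed-form KL divergence between a factorized Gaussian posterior and a zero-mean Gaussian prior (the eqs. 11-12 referenced in the comments). A standalone sketch of that formula, with names chosen here for illustration rather than taken from the layer:

    import torch

    def gaussian_kl(mean, var, prior_var):
        # KL( N(mean, var) || N(0, prior_var) ), summed over all weight dimensions:
        # 0.5 * sum( log(prior_var) - log(var) + (var + mean**2) / prior_var - 1 )
        prior_var = torch.as_tensor(prior_var, dtype=mean.dtype)
        return 0.5 * torch.sum(prior_var.log() - var.log() + (var + mean ** 2) / prior_var - 1)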
@@ -133,7 +101,7 @@ def kl_div(self):
         W_var = W_var.view(-1, len(self.r0_c)) @ self.r0_c  # eq. (12)
         epsilon_w = torch.randn_like(W_var)
         # For convolutional layers, linear mappings empirically work better than
-        # tanh non-linearity. Hence the removal of a = tf.tanh(a). Christos Louizos
+        # tanh. Hence no need for act = tf.tanh(act). Christos Louizos
         # confirmed this in https://github.com/AMLab-Amsterdam/MNF_VBNN/issues/4
         # even though the paper states the use of tanh in conv layers.
         act = W_mean + W_var.sqrt() * epsilon_w
@@ -151,10 +119,8 @@

         # Log likelihood of a zero-covariance normal dist: ln N(x | mu, sigma) =
         # -1/2 sum_dims(ln(2 pi) + ln(sigma^2) + (x - mu)^2 / sigma^2)
-        log_r = log_det_r.squeeze() + 0.5 * torch.sum(
-            -torch.exp(log_var_r) * (z_sample[-1] - mean_r) ** 2
-            - np.log(2 * np.pi)
-            + log_var_r
+        log_r = log_det_r + 0.5 * torch.sum(
+            -log_var_r.exp() * (zs[-1] - mean_r) ** 2 + log_var_r
         )

         return kl_div_W + kl_div_b + log_q - log_r
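The rewritten log_r drops the constant -1/2 ln(2 pi) term of the Gaussian log density; that term has no parameter dependence, so removing it only shifts the objective by a fixed amount and leaves gradients unchanged. A sketch of the density formula quoted in the comment above, kept separate from the layer's own parameterization (names illustrative):

    import math
    import torch

    def normal_log_prob(x, mean, log_var, include_const=True):
        # ln N(x | mean, var) = -1/2 * (ln(2 pi) + ln(var) + (x - mean)^2 / var), summed over dims
        log_p = -0.5 * (log_var + (x - mean) ** 2 / log_var.exp())
        if include_const:
            log_p = log_p - 0.5 * math.log(2 * math.pi)
        return log_p.sum()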
