Commit c78cb3f

Jaan Altosaar committed: add inverse autoregressive flow classes
1 parent 68c8535 commit c78cb3f

File tree: flow.py, train_variational_autoencoder_pytorch.py

2 files changed: 205 additions, 13 deletions

flow.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
"""Credit: mostly based on Ilya's excellent implementation here: https://github.com/ikostrikov/pytorch-flows"""
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F


class InverseAutoregressiveFlow(nn.Module):
    """Inverse Autoregressive Flows with LSTM-type update. One block.

    Eq 11-14 of https://arxiv.org/abs/1606.04934
    """
    def __init__(self, num_input, num_hidden, num_context):
        super().__init__()
        self.made = MADE(num_input=num_input, num_output=num_input * 2,
                         num_hidden=num_hidden, num_context=num_context)
        # init such that sigmoid(s) is close to 1 for stability
        self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2)
        self.sigmoid = nn.Sigmoid()
        self.log_sigmoid = nn.LogSigmoid()

    def forward(self, input, context=None):
        m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1)
        s = s + self.sigmoid_arg_bias
        sigmoid = self.sigmoid(s)
        z = sigmoid * input + (1 - sigmoid) * m
        return z, -self.log_sigmoid(s)


class FlowSequential(nn.Sequential):
    """Chain of flow blocks; the forward pass also accumulates each block's log-det term."""

    def forward(self, input, context=None):
        total_log_prob = torch.zeros_like(input, device=input.device)
        for block in self._modules.values():
            input, log_prob = block(input, context)
            total_log_prob += log_prob
        return input, total_log_prob


class MaskedLinear(nn.Module):
    """Linear layer with some input-output connections masked."""
    def __init__(self, in_features, out_features, mask, context_features=None, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.register_buffer("mask", mask)
        if context_features is not None:
            self.cond_linear = nn.Linear(context_features, out_features, bias=False)

    def forward(self, input, context=None):
        output = F.linear(input, self.mask * self.linear.weight, self.linear.bias)
        if context is None:
            return output
        else:
            return output + self.cond_linear(context)


class MADE(nn.Module):
    """Implements MADE: Masked Autoencoder for Distribution Estimation.

    Follows https://arxiv.org/abs/1502.03509

    This is used to build MAF: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057).
    """
    def __init__(self, num_input, num_output, num_hidden, num_context):
        super().__init__()
        # m corresponds to m(k), the maximum degree of a node in the MADE paper
        self._m = []
        self._masks = []
        self._build_masks(num_input, num_output, num_hidden, num_layers=3)
        self._check_masks()
        modules = []
        self.input_context_net = MaskedLinear(num_input, num_hidden, self._masks[0], num_context)
        modules.append(nn.ReLU())
        modules.append(MaskedLinear(num_hidden, num_hidden, self._masks[1], context_features=None))
        modules.append(nn.ReLU())
        modules.append(MaskedLinear(num_hidden, num_output, self._masks[2], context_features=None))
        self.net = nn.Sequential(*modules)

    def _build_masks(self, num_input, num_output, num_hidden, num_layers):
        """Build the masks according to Eq 12 and 13 in the MADE paper."""
        rng = np.random.RandomState(0)
        # assign input units a number between 1 and D
        self._m.append(np.arange(1, num_input + 1))
        for i in range(1, num_layers + 1):
            # randomly assign maximum number of input nodes to connect to
            if i == num_layers:
                # assign output layer units a number between 1 and D
                m = np.arange(1, num_input + 1)
                assert num_output % num_input == 0, "num_output must be multiple of num_input"
                self._m.append(np.hstack([m for _ in range(num_output // num_input)]))
            else:
                # assign hidden layer units a number between 1 and D-1
                self._m.append(rng.randint(1, num_input, size=num_hidden))
                # self._m.append(np.arange(1, num_hidden + 1) % (num_input - 1) + 1)
            if i == num_layers:
                mask = self._m[i][None, :] > self._m[i - 1][:, None]
            else:
                # input to hidden & hidden to hidden
                mask = self._m[i][None, :] >= self._m[i - 1][:, None]
            # need to transpose for torch linear layer, shape (num_output, num_input)
            self._masks.append(torch.from_numpy(mask.astype(np.float32).T))

    def _check_masks(self):
        """Check that the connectivity matrix between layers is lower triangular."""
        # (num_input, num_hidden)
        prev = self._masks[0].t()
        for i in range(1, len(self._masks)):
            # num_hidden is second axis
            prev = prev @ self._masks[i].t()
        final = prev.numpy()
        num_input = self._masks[0].shape[1]
        num_output = self._masks[-1].shape[0]
        assert final.shape == (num_input, num_output)
        if num_output == num_input:
            assert np.triu(final).all() == 0
        else:
            for submat in np.split(final,
                                   indices_or_sections=num_output // num_input,
                                   axis=1):
                assert np.triu(submat).all() == 0

    def forward(self, input, context=None):
        # first hidden layer receives input and context
        hidden = self.input_context_net(input, context)
        # rest of the network is conditioned on both input and context
        return self.net(hidden)


class Reverse(nn.Module):
    """An implementation of a reversing layer from
    Density estimation using Real NVP
    (https://arxiv.org/abs/1605.08803).

    From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py
    """

    def __init__(self, num_input):
        super(Reverse, self).__init__()
        self.perm = np.array(np.arange(0, num_input)[::-1])
        self.inv_perm = np.argsort(self.perm)

    def forward(self, inputs, context=None, mode='forward'):
        if mode == "forward":
            return inputs[:, :, self.perm], torch.zeros_like(inputs, device=inputs.device)
        elif mode == "inverse":
            return inputs[:, :, self.inv_perm], torch.zeros_like(inputs, device=inputs.device)
        else:
            raise ValueError("Mode must be one of {forward, inverse}.")
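
Taken together, the classes above form the inverse autoregressive flow used by the training script. Below is a minimal usage sketch (not part of the committed files), assuming flow.py is on the import path and using an arbitrary latent size of 4:

# Usage sketch (hypothetical; not in the commit): compose IAF and Reverse blocks.
import torch
import flow

latent_size, batch_size, n_samples = 4, 8, 1
blocks = []
for _ in range(2):
    blocks.append(flow.InverseAutoregressiveFlow(num_input=latent_size,
                                                 num_hidden=latent_size * 2,
                                                 num_context=latent_size))
    blocks.append(flow.Reverse(latent_size))
q_z_flow = flow.FlowSequential(*blocks)

z_0 = torch.randn(batch_size, n_samples, latent_size)      # sample from the base distribution
context = torch.randn(batch_size, n_samples, latent_size)  # per-datapoint context h
z_T, neg_log_det = q_z_flow(z_0, context=context)
# Each IAF block computes z = sigmoid(s) * input + (1 - sigmoid(s)) * m and returns
# -log sigmoid(s); Reverse blocks return zeros. Summing neg_log_det over the last
# dimension gives -log|det dz_T/dz_0|, the change-of-variables correction for log q(z_T).
print(z_T.shape, neg_log_det.shape)  # both torch.Size([8, 1, 4])

The three-dimensional (batch, samples, latent) layout matches how the training script's VariationalFlow calls the flow, and is what Reverse's inputs[:, :, perm] indexing expects.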

train_variational_autoencoder_pytorch.py

Lines changed: 53 additions & 13 deletions
@@ -15,9 +15,12 @@
 import pathlib
 import h5py
 import random
+import flow

 config = """
 latent_size: 128
+variational: flow
+flow_depth: 2
 data_size: 784
 learning_rate: 0.001
 batch_size: 128
@@ -30,30 +33,29 @@
 seed: 582838
 """

-
 class Model(nn.Module):
     """Bernoulli model parameterized by a generative network with Gaussian latents for MNIST."""
-    def __init__(self, latent_size, data_size, batch_size, device):
+    def __init__(self, latent_size, data_size):
         super().__init__()
-        self.p_z = torch.distributions.Normal(
-            torch.zeros(latent_size, device=device),
-            torch.ones(latent_size, device=device))
+        self.register_buffer('p_z_loc', torch.zeros(latent_size))
+        self.register_buffer('p_z_scale', torch.ones(latent_size))
+        self.log_p_z = NormalLogProb()
         self.log_p_x = BernoulliLogProb()
         self.generative_network = NeuralNetwork(input_size=latent_size,
                                                 output_size=data_size,
                                                 hidden_size=latent_size * 2)

     def forward(self, z, x):
         """Return log probability of model."""
-        log_p_z = self.p_z.log_prob(z).sum(-1)
+        log_p_z = self.log_p_z(self.p_z_loc, self.p_z_scale, z).sum(-1, keepdim=True)
         logits = self.generative_network(z)
         # unsqueeze sample dimension
         logits, x = torch.broadcast_tensors(logits, x.unsqueeze(1))
-        log_p_x = self.log_p_x(logits, x).sum(-1)
+        log_p_x = self.log_p_x(logits, x).sum(-1, keepdim=True)
         return log_p_z + log_p_x


-class Variational(nn.Module):
+class VariationalMeanField(nn.Module):
     """Approximate posterior parameterized by an inference network."""
     def __init__(self, latent_size, data_size):
         super().__init__()
@@ -73,6 +75,38 @@ def forward(self, x, n_samples=1):
         return z, log_q_z


+class VariationalFlow(nn.Module):
+    """Approximate posterior parameterized by a flow (https://arxiv.org/abs/1606.04934)."""
+    def __init__(self, latent_size, data_size, flow_depth):
+        super().__init__()
+        hidden_size = latent_size * 2
+        self.inference_network = NeuralNetwork(input_size=data_size,
+                                               # loc, scale, and context
+                                               output_size=latent_size * 3,
+                                               hidden_size=hidden_size)
+        modules = []
+        for _ in range(flow_depth):
+            modules.append(flow.InverseAutoregressiveFlow(num_input=latent_size,
+                                                          num_hidden=hidden_size,
+                                                          num_context=latent_size))
+            modules.append(flow.Reverse(latent_size))
+        self.q_z_flow = flow.FlowSequential(*modules)
+        self.log_q_z_0 = NormalLogProb()
+        self.softplus = nn.Softplus()
+
+    def forward(self, x, n_samples=1):
+        """Return sample of latent variable and log prob."""
+        loc, scale_arg, h = torch.chunk(self.inference_network(x).unsqueeze(1), chunks=3, dim=-1)
+        scale = self.softplus(scale_arg)
+        eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device)
+        z_0 = loc + scale * eps  # reparameterization
+        log_q_z_0 = self.log_q_z_0(loc, scale, z_0)
+        z_T, log_q_z_flow = self.q_z_flow(z_0, context=h)
+        log_q_z = (log_q_z_0 + log_q_z_flow).sum(-1, keepdim=True)
+        return z_T, log_q_z
+
+
 class NeuralNetwork(nn.Module):
     def __init__(self, input_size, output_size, hidden_size):
         super().__init__()
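
Shape-wise, the new posterior can be exercised on a dummy batch. The hedged sketch below assumes it is run inside this script (where NeuralNetwork and NormalLogProb are defined and torch is imported), with the config values latent_size=128, data_size=784, flow_depth=2, and assuming NeuralNetwork maps a (batch, 784) input to (batch, 3 * 128):

# Hypothetical smoke test, not part of the commit.
variational = VariationalFlow(latent_size=128, data_size=784, flow_depth=2)
x = torch.rand(32, 784).round()          # stand-in for a binarized MNIST batch
z, log_q_z = variational(x, n_samples=1)
# z: (32, 1, 128), samples z_T from the flow posterior.
# log_q_z: (32, 1, 1), the base Normal log-density of z_0 plus the accumulated
# -log sigmoid(s) terms, i.e. log q(z_T | x) after the change of variables.
print(z.shape, log_q_z.shape)
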
@@ -155,11 +189,17 @@ def evaluate(n_samples, model, variational, eval_data):
     random.seed(cfg.seed)

     model = Model(latent_size=cfg.latent_size,
-                  data_size=cfg.data_size,
-                  batch_size=cfg.batch_size,
-                  device=device)
-    variational = Variational(latent_size=cfg.latent_size,
-                              data_size=cfg.data_size)
+                  data_size=cfg.data_size)
+    if cfg.variational == 'flow':
+        variational = VariationalFlow(latent_size=cfg.latent_size,
+                                      data_size=cfg.data_size,
+                                      flow_depth=cfg.flow_depth)
+    elif cfg.variational == 'mean-field':
+        variational = VariationalMeanField(latent_size=cfg.latent_size,
+                                           data_size=cfg.data_size)
+    else:
+        raise ValueError('Variational distribution not implemented: %s' % cfg.variational)
+
     model.to(device)
     variational.to(device)
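
The training and evaluation loops themselves are untouched by this commit. For orientation only, here is a hedged sketch of how the two return values fit together in the ELBO; the actual loop lives outside the hunks shown, and the variable names here are placeholders:

# Sketch only, assuming a batch x of binarized MNIST images inside the training loop.
z, log_q_z = variational(x, n_samples=1)   # z: (batch, 1, latent), log_q_z: (batch, 1, 1)
log_p_x_and_z = model(z, x)                # log p(x, z), shape (batch, 1, 1)
elbo = (log_p_x_and_z - log_q_z).mean()    # maximize the ELBO, i.e. minimize -elbo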
