Skip to content

Commit

Permalink
refactor MLP model class
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Sep 22, 2020
1 parent e570530 commit 880688a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 25 deletions.
4 changes: 2 additions & 2 deletions torch_mnf/flows/affine_half_flow.py
Expand Up @@ -30,9 +30,9 @@ def __init__(self, dim, parity, net_class=MLP, nh=24, scale=True, shift=True):
self.parity = parity
self.s_net = self.t_net = lambda x: x.new_zeros(x.size(0), dim // 2)
if scale:
self.s_net = net_class(dim // 2, dim // 2, nh)
self.s_net = net_class(dim // 2, nh, nh, nh, dim // 2)
if shift:
self.t_net = net_class(dim // 2, dim // 2, nh)
self.t_net = net_class(dim // 2, nh, nh, nh, dim // 2)

def forward(self, z, inverse=False):
z0, z1 = z.chunk(2, dim=1)
Expand Down
10 changes: 5 additions & 5 deletions torch_mnf/flows/spline_flow.py
Expand Up @@ -182,15 +182,15 @@ def RQS(
class NSF_AR(nn.Module):
"""Neural spline flow, coupling layer, [Durkan et al. 2019]"""

def __init__(self, dim, K=5, B=3, hidden_dim=8, base_network=MLP):
def __init__(self, dim, K=5, B=3, n_h=8, net_class=MLP):
super().__init__()
self.dim = dim
self.K = K
self.B = B
self.layers = nn.ModuleList()
self.init_param = nn.Parameter(torch.Tensor(3 * K - 1))
for i in range(1, dim):
self.layers += [base_network(i, 3 * K - 1, hidden_dim)]
self.layers += [net_class(i, n_h, n_h, n_h, 3 * K - 1)]
self.reset_parameters()

def reset_parameters(self):
Expand Down Expand Up @@ -238,13 +238,13 @@ def inverse(self, x):
class NSF_CL(nn.Module):
"""Neural spline flow, coupling layer, [Durkan et al. 2019]"""

def __init__(self, dim, K=5, B=3, hidden_dim=8, base_network=MLP):
def __init__(self, dim, K=5, B=3, n_h=8, net_class=MLP):
super().__init__()
self.dim = dim
self.K = K
self.B = B
self.f1 = base_network(dim // 2, (3 * K - 1) * dim // 2, hidden_dim)
self.f2 = base_network(dim // 2, (3 * K - 1) * dim // 2, hidden_dim)
self.f1 = net_class(dim // 2, n_h, n_h, n_h, (3 * K - 1) * dim // 2)
self.f2 = net_class(dim // 2, n_h, n_h, n_h, (3 * K - 1) * dim // 2)

def forward(self, z):
log_det = torch.zeros(z.shape[0])
Expand Down
26 changes: 9 additions & 17 deletions torch_mnf/models/mlp.py
@@ -1,20 +1,12 @@
from torch import nn


class MLP(nn.Module):
"""Just a 4-layer perceptron. """

def __init__(self, n_in, n_out, n_h):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_in, n_h),
nn.LeakyReLU(0.2),
nn.Linear(n_h, n_h),
nn.LeakyReLU(0.2),
nn.Linear(n_h, n_h),
nn.LeakyReLU(0.2),
nn.Linear(n_h, n_out),
)

def forward(self, x):
return self.net(x)
class MLP(nn.Sequential):
"""Multilayer perceptron"""

def __init__(self, *layer_sizes, leaky_a=0.2):
layers = []
for s1, s2 in zip(layer_sizes, layer_sizes[1:]):
layers.append(nn.Linear(s1, s2))
layers.append(nn.LeakyReLU(leaky_a))
super().__init__(*layers[:-1]) # drop last ReLU
2 changes: 1 addition & 1 deletion torch_mnf/notebooks/2d.py
Expand Up @@ -56,7 +56,7 @@


# %% -- Neural Spline Flow --
flows = [nf.NSF_CL(dim=2, K=8, B=3, hidden_dim=16) for _ in range(3)]
flows = [nf.NSF_CL(dim=2, K=8, B=3, n_h=16) for _ in range(3)]
# prepend each NSF flow with ActNormFlow and Glow
for idx in reversed(range(len(flows))):
flows.insert(idx, nf.ActNormFlow(dim=2))
Expand Down

0 comments on commit 880688a

Please sign in to comment.