Skip to content

Commit

Permalink
Merge pull request #2 from ekrim/master
Browse files Browse the repository at this point in the history
MADE generative pass, tanh layer for scaling params, moons example
  • Loading branch information
ikostrikov2 committed Sep 12, 2018
2 parents fc788f8 + 8193d4d commit a162a5f
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# Data
datasets/data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
1 change: 1 addition & 0 deletions datasets/__init__.py
Expand Up @@ -5,3 +5,4 @@
from .hepmass import HEPMASS
from .miniboone import MINIBOONE
from .bsds300 import BSDS300
from .moons import MOONS
40 changes: 40 additions & 0 deletions datasets/moons.py
@@ -0,0 +1,40 @@
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets as ds

import datasets
import datasets.util


class MOONS:

class Data:

def __init__(self, data):

self.x = data.astype(np.float32)
self.N = self.x.shape[0]

def __init__(self):

trn, val, tst = load_data()

self.trn = self.Data(trn)
self.val = self.Data(val)
self.tst = self.Data(tst)

self.n_dims = self.trn.x.shape[1]

def show_histograms(self, split):

data_split = getattr(self, split, None)
if data_split is None:
raise ValueError('Invalid data split')

datasets.util.plot_hist_marginals(data_split.x)
plt.show()


def load_data():
x = ds.make_moons(n_samples=30000, shuffle=True, noise=0.05)[0]
return x[:24000], x[24000:27000], x[27000:]
25 changes: 15 additions & 10 deletions flows.py
Expand Up @@ -43,9 +43,11 @@ class MADE(nn.Module):
(https://arxiv.org/abs/1502.03509s).
"""

def __init__(self, num_inputs, num_hidden):
def __init__(self, num_inputs, num_hidden, use_tanh=True):
super(MADE, self).__init__()

self.use_tanh = use_tanh

input_mask = get_mask(
num_inputs, num_hidden, num_inputs, mask_type='input')
hidden_mask = get_mask(num_hidden, num_hidden, num_inputs)
Expand All @@ -59,17 +61,20 @@ def __init__(self, num_inputs, num_hidden):

def forward(self, inputs, mode='direct'):
if mode == 'direct':
x = self.main(inputs)

m, a = x.chunk(2, 1)
m, a = self.main(inputs).chunk(2, 1)
if self.use_tanh:
a = torch.tanh(a)
u = (inputs - m) * torch.exp(-a)
return u, -a.sum(-1, keepdim=True)

u = (inputs - m) * torch.exp(a)
return u, a.sum(-1, keepdim=True)
else:
# TODO:
# Sampling with MADE is tricky.
# We need to perform N forward passes.
raise NotImplementedError
x = torch.zeros_like(inputs)
for i_col in range(inputs.shape[1]):
m, a = self.main(x).chunk(2, 1)
if self.use_tanh:
a = torch.tanh(a)
x[:, i_col] = inputs[:, i_col] * torch.exp(a[:, i_col]) + m[:, i_col]
return x, -a.sum(-1, keepdim=True)


class BatchNormFlow(nn.Module):
Expand Down
28 changes: 25 additions & 3 deletions main.py
Expand Up @@ -38,7 +38,7 @@
parser.add_argument(
'--dataset',
default='POWER',
help='POWER | GAS | HEPMASS | MINIBONE | BSDS300')
help='POWER | GAS | HEPMASS | MINIBONE | BSDS300 | MOONS')
parser.add_argument('--flow', default='maf', help='flow to use: maf | glow')
parser.add_argument(
'--no-cuda',
Expand Down Expand Up @@ -68,7 +68,7 @@

kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

assert args.dataset in ['POWER', 'GAS', 'HEPMASS', 'MINIBONE', 'BSDS300']
assert args.dataset in ['POWER', 'GAS', 'HEPMASS', 'MINIBONE', 'BSDS300', 'MOONS']
dataset = getattr(datasets, args.dataset)()

train_tensor = torch.from_numpy(dataset.trn.x)
Expand Down Expand Up @@ -103,7 +103,8 @@
'GAS': 100,
'HEPMASS': 512,
'MINIBOONE': 512,
'BSDS300': 512
'BSDS300': 512,
'MOONS': 64
}[args.dataset]

modules = []
Expand Down Expand Up @@ -212,3 +213,24 @@ def validate(epoch, model, loader, prefix='Validation'):
best_validation_epoch, best_validation_loss))

validate(best_validation_epoch, best_model, test_loader, prefix='Test')

if args.dataset == 'MOONS':
# generate some examples
best_model.eval()
u = np.random.randn(500, 2).astype(np.float32)
u_tens = torch.from_numpy(u).to(device)
x_synth = best_model.forward(u_tens, mode='inverse')[0].detach().cpu().numpy()

import matplotlib.pyplot as plt

fig = plt.figure()

ax = fig.add_subplot(121)
ax.plot(dataset.val.x[:,0], dataset.val.x[:,1], '.')
ax.set_title('Real data')

ax = fig.add_subplot(122)
ax.plot(x_synth[:,0], x_synth[:,1], '.')
ax.set_title('Synth data')

plt.show()

0 comments on commit a162a5f

Please sign in to comment.