-
Notifications
You must be signed in to change notification settings - Fork 34
/
toy_flow.py
82 lines (67 loc) · 2.26 KB
/
toy_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Data
from survae.data.datasets.toy import CheckerboardDataset
from torch.utils.data import DataLoader
# Model
import torch
import torch.nn as nn
from survae.flows import Flow
from survae.distributions import StandardNormal
from survae.transforms import AffineCouplingBijection, ActNormBijection, Reverse
from survae.nn.layers import ElementwiseParams
# Optim
from torch.optim import Adam
# Plot
import matplotlib.pyplot as plt
##########
## Data ##
##########
# Two checkerboard datasets of 128 * 1000 points each, i.e. 1000 mini-batches
# of 128 per epoch. (Presumably each is an independent random draw from the
# toy distribution -- confirm in CheckerboardDataset.)
train = CheckerboardDataset(num_points=128*1000)
test = CheckerboardDataset(num_points=128*1000)
# Shuffle the training stream so every epoch visits the mini-batches in a
# fresh order; evaluation order is irrelevant, so the test loader is kept
# deterministic.
train_loader = DataLoader(train, batch_size=128, shuffle=True)
test_loader = DataLoader(test, batch_size=128, shuffle=False)
###########
## Model ##
###########
def net():
    """Return a fresh conditioner MLP for one affine coupling layer.

    Architecture: Linear(1 -> 200) + ReLU, Linear(200 -> 100) + ReLU,
    Linear(100 -> 2), followed by ``ElementwiseParams(2)``. The single input
    feature is presumably the untouched half of the 2-D point and the two
    outputs the per-element coupling parameters -- confirm against survae's
    ``AffineCouplingBijection`` / ``ElementwiseParams``.
    """
    hidden_sizes = [1, 200, 100]
    modules = []
    # Hidden stack: each Linear is followed by a ReLU, created in order.
    for n_in, n_out in zip(hidden_sizes[:-1], hidden_sizes[1:]):
        modules += [nn.Linear(n_in, n_out), nn.ReLU()]
    # Output head (no activation) plus the parameter-grouping layer.
    modules += [nn.Linear(hidden_sizes[-1], 2), ElementwiseParams(2)]
    return nn.Sequential(*modules)
def _coupling_stack(num_couplings=4, dim=2):
    """Return the flow's transform list: coupling + ActNorm steps joined by Reverse.

    The result is [C, A, R, C, A, R, C, A, R, C, A] for the defaults, where
    C = AffineCouplingBijection(net()), A = ActNormBijection(dim) and
    R = Reverse(dim). Objects are constructed in exactly that order.
    """
    stack = []
    for step in range(num_couplings):
        # A Reverse separates consecutive steps, but none follows the last one.
        if step > 0:
            stack.append(Reverse(dim))
        stack.extend([AffineCouplingBijection(net()), ActNormBijection(dim)])
    return stack


# 2-D flow with a standard-normal base distribution.
model = Flow(base_dist=StandardNormal((2,)), transforms=_coupling_stack())
###########
## Optim ##
###########
# Adam over every parameter registered on the flow, fixed learning rate 1e-3.
optimizer = Adam(params=model.parameters(), lr=1e-3)
###########
## Train ##
###########
num_epochs = 10
print('Training...')
for epoch in range(num_epochs):
    running_nll = 0.0
    for num_batches, x in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        # Maximum likelihood: minimise the batch-mean negative log-likelihood.
        loss = -model.log_prob(x).mean()
        loss.backward()
        optimizer.step()
        running_nll += loss.item()
        # The reported value is the loss, i.e. the NLL (not the log-likelihood).
        # '\r' overwrites the same console line with the running epoch average.
        print('Epoch: {}/{}, NLL: {:.3f}'.format(epoch+1, num_epochs, running_nll/num_batches), end='\r')
    # Move to a new line so the final average of each epoch stays visible.
    print('')
############
## Sample ##
############
print('Sampling...')
data = test.data.numpy()
# No gradients are needed to draw samples. no_grad avoids building an autograd
# graph through the inverse transforms for 100k points, and guarantees that
# .numpy() receives a tensor that does not require grad (PyTorch raises a
# RuntimeError otherwise). .cpu() keeps this working if the model is moved
# to a GPU.
with torch.no_grad():
    samples = model.sample(100000).cpu().numpy()

# Side-by-side 2-D histograms of the data and the model samples, both drawn
# over the same [-4, 4] x [-4, 4] window.
lim = [-4, 4]
fig, ax = plt.subplots(1, 2, figsize=(12,6))
for panel, title, points in zip(ax, ('Data', 'Samples'), (data, samples)):
    panel.set_title(title)
    panel.hist2d(points[...,0], points[...,1], bins=256, range=[lim, lim])
    panel.set_xlim(lim); panel.set_ylim(lim); panel.axis('off')
plt.tight_layout()
plt.show()