-
Notifications
You must be signed in to change notification settings - Fork 34
/
cifar10_aug_flow.py
115 lines (95 loc) · 3.79 KB
/
cifar10_aug_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import math
import torch
# Data
from survae.data.loaders.image import CIFAR10
# Model
import torch.nn as nn
from survae.flows import Flow
from survae.distributions import StandardNormal, StandardUniform
from survae.transforms import AffineCouplingBijection, ActNormBijection2d, Conv1x1
from survae.transforms import UniformDequantization, Augment, Squeeze2d, Slice
from survae.nn.layers import ElementwiseParams2d
from survae.nn.nets import DenseNet
# Optim
from torch.optim import Adam
# Plot
import torchvision.utils as vutils
############
## Device ##
############
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
# The value stays a plain string ('cuda' / 'cpu') and is passed to .to(...).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
##########
## Data ##
##########
# survae's CIFAR10 dataset wrapper. get_data_loaders(32) returns the training
# and test loaders as a pair; 32 is presumably the batch size (confirm against
# survae's loader signature).
data = CIFAR10()
(train_loader,
 test_loader) = data.get_data_loaders(32)
###########
## Model ##
###########
def net(channels):
    """Return the network used inside each AffineCouplingBijection.

    A DenseNet maps ``channels // 2`` input channels to ``channels`` output
    channels. ElementwiseParams2d(2) then regroups that output into two
    parameters per element, as the affine coupling expects.

    Args:
        channels: total channel count of the tensor the coupling acts on.

    Returns:
        An ``nn.Sequential`` of the DenseNet followed by ElementwiseParams2d(2).
    """
    dense = DenseNet(
        in_channels=channels // 2,
        out_channels=channels,
        num_blocks=1,
        mid_channels=64,
        depth=8,
        growth=16,
        dropout=0.0,
        gated_conv=True,
        zero_init=True,
    )
    params = ElementwiseParams2d(2)
    return nn.Sequential(dense, params)
# Multi-scale flow, applied in this order:
#   * UniformDequantization(num_bits=8).
#   * Augment(StandardUniform((3,32,32)), x_size=3): the 3 image channels are
#     joined by 3 augmentation channels, so the first scale works on 6 channels.
#   * Three scales of four (AffineCouplingBijection -> ActNormBijection2d ->
#     Conv1x1) steps, at 6, 12 and 24 channels.
#   * Between scales, Squeeze2d folds spatial positions into channels and
#     Slice keeps `num_keep` channels. The StandardNormal handed to each Slice
#     is, per survae's Slice API, the distribution for the dropped channels.
# The remaining 24x8x8 tensor is scored under base_dist.
model = Flow(
    base_dist=StandardNormal((24,8,8)),
    transforms=[
        UniformDequantization(num_bits=8),
        Augment(StandardUniform((3,32,32)), x_size=3),
        # Scale 1: 6 channels.
        *[step
          for _ in range(4)
          for step in (AffineCouplingBijection(net(6)), ActNormBijection2d(6), Conv1x1(6))],
        Squeeze2d(), Slice(StandardNormal((12,16,16)), num_keep=12),
        # Scale 2: 12 channels.
        *[step
          for _ in range(4)
          for step in (AffineCouplingBijection(net(12)), ActNormBijection2d(12), Conv1x1(12))],
        Squeeze2d(), Slice(StandardNormal((24,8,8)), num_keep=24),
        # Scale 3: 24 channels.
        *[step
          for _ in range(4)
          for step in (AffineCouplingBijection(net(24)), ActNormBijection2d(24), Conv1x1(24))],
    ],
).to(device)
###########
## Optim ##
###########
# Adam over all parameters of the flow, with learning rate 1e-3.
optimizer = Adam(params=model.parameters(), lr=0.001)
###########
## Train ##
###########
# Number of passes over the training set. It is shared by the loop and the
# progress message so the two cannot disagree.
epochs = 10
print('Training...')
# Be explicit about training mode (any mode-dependent layers).
model.train()
for epoch in range(epochs):
    total_bits = 0.0  # bits accumulated this epoch (bits/dim * dims per batch)
    total_dims = 0    # scalar dimensions seen this epoch
    for i, x in enumerate(train_loader):
        optimizer.zero_grad()
        # Negative log-likelihood, converted from nats to bits and divided by
        # the batch's element count: bits/dim for this batch.
        loss = -model.log_prob(x.to(device)).sum() / (math.log(2) * x.numel())
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so a short final batch does not skew
        # the running average.
        total_bits += loss.item() * x.numel()
        total_dims += x.numel()
        print('Epoch: {}/{}, Iter: {}/{}, Bits/dim: {:.3f}'.format(
            epoch+1, epochs, i+1, len(train_loader), total_bits/total_dims), end='\r')
    print('')
##########
## Test ##
##########
print('Testing...')
# Switch any mode-dependent layers to evaluation behaviour.
model.eval()
with torch.no_grad():
    total_bits = 0.0  # sum of bits over the test set (bits/dim * dims per batch)
    total_dims = 0    # total scalar dimensions seen
    for i, x in enumerate(test_loader):
        # Per-batch bits/dim. No backward pass is needed, hence no_grad above.
        loss = -model.log_prob(x.to(device)).sum() / (math.log(2) * x.numel())
        # Weight by batch size so the reported value is the true per-dimension
        # average over the whole test set, even if the last batch is smaller.
        total_bits += loss.item() * x.numel()
        total_dims += x.numel()
        print('Iter: {}/{}, Bits/dim: {:.3f}'.format(
            i+1, len(test_loader), total_bits/total_dims), end='\r')
print('')
############
## Sample ##
############
print('Sampling...')
# First 64 test images. The permute moves the last axis to position 1,
# presumably NHWC -> NCHW as save_image expects (confirm the array layout).
img = torch.from_numpy(data.test.data[:64]).permute([0,3,1,2])
# Sampling needs no gradients. Without no_grad the autograd graph of the whole
# inverse pass would be built and kept in memory.
model.eval()
with torch.no_grad():
    samples = model.sample(64)
# Dividing by 255 assumes both tensors hold 0-255 pixel values, which maps
# them into [0, 1] for save_image.
vutils.save_image(img.cpu().float()/255, fp='cifar10_data.png', nrow=8)
vutils.save_image(samples.cpu().float()/255, fp='cifar10_aug_flow.png', nrow=8)