
Commit 2dd9910

Author: Agustinus Kristiadi (committed)
Add AVB models & AAE tf implementation
1 parent 89d3833, commit 2dd9910

File tree

4 files changed: +426, -0 lines changed
File renamed without changes.
Lines changed: 141 additions & 0 deletions

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 32
z_dim = 10
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
c = 0
lr = 1e-3

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

def xavier_init(size):
    # stddev = sqrt(2 / fan_in); despite the name, this is He et al.'s
    # scaling, suited to the ReLU layers used below
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

""" Q(z|X) """
43+
X = tf.placeholder(tf.float32, shape=[None, X_dim])
44+
z = tf.placeholder(tf.float32, shape=[None, z_dim])
45+
46+
Q_W1 = tf.Variable(xavier_init([X_dim, h_dim]))
47+
Q_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
48+
49+
Q_W2 = tf.Variable(xavier_init([h_dim, z_dim]))
50+
Q_b2 = tf.Variable(tf.zeros(shape=[z_dim]))
51+
52+
theta_Q = [Q_W1, Q_W2, Q_b1, Q_b2]
53+
54+
55+
def Q(X):
56+
h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
57+
z = tf.matmul(h, Q_W2) + Q_b2
58+
return z
59+
60+
61+
""" P(X|z) """
62+
P_W1 = tf.Variable(xavier_init([z_dim, h_dim]))
63+
P_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
64+
65+
P_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
66+
P_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
67+
68+
theta_P = [P_W1, P_W2, P_b1, P_b2]
69+
70+
71+
def P(z):
72+
h = tf.nn.relu(tf.matmul(z, P_W1) + P_b1)
73+
logits = tf.matmul(h, P_W2) + P_b2
74+
prob = tf.nn.sigmoid(logits)
75+
return prob, logits
76+
77+
78+
""" D(z) """
79+
D_W1 = tf.Variable(xavier_init([z_dim, h_dim]))
80+
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
81+
82+
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
83+
D_b2 = tf.Variable(tf.zeros(shape=[1]))
84+
85+
theta_D = [D_W1, D_W2, D_b1, D_b2]
86+
87+
88+
def D(z):
89+
h = tf.nn.relu(tf.matmul(z, D_W1) + D_b1)
90+
logits = tf.matmul(h, D_W2) + D_b2
91+
prob = tf.nn.sigmoid(logits)
92+
return prob
93+
94+
95+
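# The adversarial losses below regularize the encoder in the style of
# Makhzani et al.'s Adversarial Autoencoders: D learns to tell prior samples
# z ~ N(0, I) ("real") from encoder outputs Q(X) ("fake"), and Q is trained
# to fool it, pushing the aggregated posterior toward the prior in place of
# the analytic KL term of a standard VAE.
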
""" Training """
96+
z_sample = Q(X)
97+
_, logits = P(z_sample)
98+
99+
# Sample from random z
100+
X_samples, _ = P(z)
101+
102+
# E[log P(X|z)]
103+
recon_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits, X))
104+
105+
# Adversarial loss to approx. Q(z|X)
106+
D_real = D(z)
107+
D_fake = D(z_sample)
108+
109+
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
110+
G_loss = -tf.reduce_mean(tf.log(D_fake))
111+
112+
AE_solver = tf.train.AdamOptimizer().minimize(recon_loss, var_list=theta_P + theta_Q)
113+
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
114+
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_Q)
115+
116+
sess = tf.Session()
117+
sess.run(tf.initialize_all_variables())

if not os.path.exists('out/'):
    os.makedirs('out/')

i = 0

for it in range(1000000):
    X_mb, _ = mnist.train.next_batch(mb_size)
    z_mb = np.random.randn(mb_size, z_dim)

    _, recon_loss_curr = sess.run([AE_solver, recon_loss], feed_dict={X: X_mb})
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, z: z_mb})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={X: X_mb})

    if it % 1000 == 0:
        print('Iter: {}; D_loss: {:.4}; G_loss: {:.4}; Recon_loss: {:.4}'
              .format(it, D_loss_curr, G_loss_curr, recon_loss_curr))

        samples = sess.run(X_samples, feed_dict={z: np.random.randn(16, z_dim)})

        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)
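
The -tf.log(...) losses above can blow up to NaN once the discriminator saturates. A common, numerically stable alternative is to phrase both losses with tf.nn.sigmoid_cross_entropy_with_logits over the discriminator's pre-sigmoid logits. The sketch below introduces a hypothetical helper D_logits (reusing D's weights) and the names D_loss_stable / G_loss_stable for illustration; it is not the form used in this commit.

def D_logits(z):
    # same network as D, but returning pre-sigmoid logits
    h = tf.nn.relu(tf.matmul(z, D_W1) + D_b1)
    return tf.matmul(h, D_W2) + D_b2

D_logit_real = D_logits(z)
D_logit_fake = D_logits(z_sample)

D_loss_stable = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_real, labels=tf.ones_like(D_logit_real))
    + tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
G_loss_stable = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))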

VAE/adversarial_vb/avb_pytorch.py

Lines changed: 135 additions & 0 deletions

import torch
import torch.nn
import torch.nn.functional as nn  # note: the functional API is aliased as `nn` in this file
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data  # TensorFlow serves only as an MNIST loader here


mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 32
z_dim = 10
eps_dim = 4
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
cnt = 0
lr = 1e-3


# Encoder: q(z|x,eps)
Q = torch.nn.Sequential(
    torch.nn.Linear(X_dim + eps_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim)
)

# Decoder: p(x|z)
P = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
    torch.nn.Sigmoid()
)

# Discriminator: d(z)
D = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1),
    torch.nn.Sigmoid()
)
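
# Unlike a standard VAE encoder, Q takes auxiliary noise eps as an extra
# input, so q(z|X) is an implicit distribution defined by z = Q([X, eps])
# rather than an explicit Gaussian; this is the core idea of Adversarial
# Variational Bayes (Mescheder et al., 2017).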


def reset_grad():
    Q.zero_grad()
    P.zero_grad()
    D.zero_grad()


def sample_X(size, include_y=False):
    X, y = mnist.train.next_batch(size)
    X = Variable(torch.from_numpy(X))

    if include_y:
        y = np.argmax(y, axis=1).astype(np.int)
        y = Variable(torch.from_numpy(y))
        return X, y

    return X


Q_solver = optim.Adam(Q.parameters(), lr=lr)
P_solver = optim.Adam(P.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)
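
# As in the TensorFlow version, each iteration alternates three updates:
# reconstruction (P and Q), discriminator (D), and the adversarial update
# of Q against D.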
for it in range(1000000):
    X = sample_X(mb_size)
    eps = Variable(torch.randn(mb_size, eps_dim))
    X_eps = torch.cat([X, eps], 1)
    z = Variable(torch.randn(mb_size, z_dim))

    # Optimize VAE w.r.t. reconstruction loss
    z_sample = Q(X_eps)
    X_sample = P(z_sample)

    recon_loss = nn.binary_cross_entropy(X_sample, X)

    recon_loss.backward()
    P_solver.step()
    Q_solver.step()
    reset_grad()

    # Discriminator D(z)
    z_fake = Q(X_eps)
    D_real = D(z)
    D_fake = D(z_fake)

    D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))

    D_loss.backward()
    D_solver.step()
    reset_grad()

    # Q(z|X,eps)
    z_fake = Q(X_eps)
    D_fake = D(z_fake)

    G_loss = -torch.mean(torch.log(D_fake))

    G_loss.backward()
    Q_solver.step()
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'
              .format(it, D_loss.data[0], G_loss.data[0], recon_loss.data[0]))

        samples = P(z).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'
                    .format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)
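
As written, the discriminator only ever sees z, so this script realizes the adversarial-autoencoder flavor of the regularizer. In full AVB, the discriminator conditions on the input as well, T(X, z), and its optimal value approximates log q(z|X) - log p(z). A minimal sketch of that variant, assuming the rest of the script is unchanged (the names T, T_solver, and T_loss are introduced here for illustration):

# Hypothetical AVB-style discriminator over (X, z) pairs
T = torch.nn.Sequential(
    torch.nn.Linear(X_dim + z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1),
    torch.nn.Sigmoid()
)
T_solver = optim.Adam(T.parameters(), lr=lr)

# Inside the training loop, feed concatenated pairs instead of bare z,
# keeping this file's convention that prior samples are "real":
T_real = T(torch.cat([X, z], 1))          # z drawn from the prior
T_fake = T(torch.cat([X, Q(X_eps)], 1))   # z produced by the encoder
T_loss = -torch.mean(torch.log(T_real) + torch.log(1 - T_fake))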

0 commit comments