/
生成对抗模型GAN
121 lines (96 loc) · 4.86 KB
/
生成对抗模型GAN
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import moxing.tensorflow as mox
# Command-line flag `data_url`: the directory handed to
# input_data.read_data_sets() in input_fn. The help string is left empty.
tf.flags.DEFINE_string('data_url', '/export1/zzy/mnist/zip/data', '')
# Accessor for flag values; input_fn reads `flags.data_url`.
flags = tf.flags.FLAGS
# Mini-batch size shared by the data generator in input_fn, the noise tensor
# and loss targets in model_fn, and the `batch_size` argument of mox.run.
batch_size = 50
def input_fn(run_mode, **kwargs):
    """Return an ``(images, labels)`` tensor pair streaming MNIST batches.

    MNIST is read from ``flags.data_url`` with one-hot labels, and batches of
    ``batch_size`` examples are drawn endlessly from the training split via a
    Python generator wrapped in a ``tf.data.Dataset``.

    Args:
        run_mode: Accepted for interface compatibility; not used here, the
            training split is always served.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A tuple ``(images, labels)`` where ``images`` has been reshaped to
        ``[-1, 28, 28, 1]``, resized to 64x64 and rescaled with
        ``(x - 0.5) / 0.5``, and ``labels`` is the ``[None, 10]`` int64 tensor
        produced by the dataset.
    """
    dataset = input_data.read_data_sets(flags.data_url, one_hot=True)

    def endless_batches():
        # Never terminates: the training loop decides when to stop.
        while True:
            yield dataset.train.next_batch(batch_size)

    element_types = (tf.float32, tf.int64)
    element_shapes = (tf.TensorShape([None, 784]), tf.TensorShape([None, 10]))
    batches = tf.data.Dataset.from_generator(
        endless_batches,
        output_types=element_types,
        output_shapes=element_shapes)
    flat_images, labels = batches.make_one_shot_iterator().get_next()

    # Flat 784-vectors -> 28x28 single-channel images -> 64x64.
    square_images = tf.reshape(flat_images, shape=[-1, 28, 28, 1])
    upsampled = tf.image.resize_images(square_images, [64, 64])
    # Assuming pixel values in [0, 1], this maps them to [-1, 1], the range of
    # the generator's tanh output.
    images = (upsampled - 0.5) / 0.5
    return images, labels
def model_fn(inputs, run_mode, **kwargs):
    """Build the DCGAN graph and return its losses as a ``mox.ModelSpec``.

    A generator maps ``[batch_size, 1, 1, 100]`` Gaussian noise to 64x64x1
    images; a discriminator scores both the real ``images`` and the generated
    ones. Both losses are sigmoid cross-entropy on the discriminator logits.

    Args:
        inputs: Tuple ``(images, labels)``; only ``images`` is used.
        run_mode: Compared against ``mox.ModeKeys.TRAIN`` to set the
            ``training`` flag of every batch-normalization layer.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``mox.ModelSpec`` with ``loss=[D_loss, G_loss]``, the matching
        variable scopes ``[['discriminator'], ['generator']]`` and both losses
        in ``log_info``.
    """
    # The class labels are not used by this unconditional GAN.
    images, _ = inputs
    is_training = (run_mode == mox.ModeKeys.TRAIN)

    def lrelu(x, th=0.2):
        """Leaky ReLU computed as ``max(th * x, x)``."""
        return tf.maximum(th * x, x)

    def bn_lrelu(x, training):
        """Batch normalization followed by leaky ReLU (slope 0.2)."""
        # NOTE(review): tf.layers.batch_normalization registers its moving
        # average updates in tf.GraphKeys.UPDATE_OPS; confirm that mox.run
        # executes them alongside the train ops.
        return lrelu(tf.layers.batch_normalization(x, training=training), 0.2)

    # G(z)
    def generator(x, training=True, reuse=False):
        """Map noise ``x`` to an image tensor with tanh output."""
        with tf.variable_scope('generator', reuse=reuse):
            # 1st hidden layer: 1x1 -> 4x4 for the [batch, 1, 1, 100] noise
            net = tf.layers.conv2d_transpose(x, 1024, [4, 4], strides=(1, 1), padding='valid')
            net = bn_lrelu(net, training)
            # 2nd hidden layer (each stride-2 'same' layer doubles H and W)
            net = tf.layers.conv2d_transpose(net, 512, [4, 4], strides=(2, 2), padding='same')
            net = bn_lrelu(net, training)
            # 3rd hidden layer
            net = tf.layers.conv2d_transpose(net, 256, [4, 4], strides=(2, 2), padding='same')
            net = bn_lrelu(net, training)
            # 4th hidden layer
            net = tf.layers.conv2d_transpose(net, 128, [4, 4], strides=(2, 2), padding='same')
            net = bn_lrelu(net, training)
            # output layer: single channel, squashed to [-1, 1]
            net = tf.layers.conv2d_transpose(net, 1, [4, 4], strides=(2, 2), padding='same')
            return tf.nn.tanh(net)

    # D(x)
    def discriminator(x, training=True, reuse=False):
        """Score images ``x``; returns ``(sigmoid probability, raw logits)``."""
        with tf.variable_scope('discriminator', reuse=reuse):
            # 1st hidden layer (no batch normalization on the input layer)
            net = tf.layers.conv2d(x, 128, [4, 4], strides=(2, 2), padding='same')
            net = lrelu(net, 0.2)
            # 2nd hidden layer (each stride-2 'same' layer halves H and W)
            net = tf.layers.conv2d(net, 256, [4, 4], strides=(2, 2), padding='same')
            net = bn_lrelu(net, training)
            # 3rd hidden layer
            net = tf.layers.conv2d(net, 512, [4, 4], strides=(2, 2), padding='same')
            net = bn_lrelu(net, training)
            # 4th hidden layer
            net = tf.layers.conv2d(net, 1024, [4, 4], strides=(2, 2), padding='same')
            net = bn_lrelu(net, training)
            # output layer: 4x4 'valid' kernel collapses a 4x4 map to 1x1
            logits = tf.layers.conv2d(net, 1, [4, 4], strides=(1, 1), padding='valid')
            return tf.nn.sigmoid(logits), logits

    def cross_entropy(logits, make_targets):
        """Mean sigmoid cross-entropy of ``logits`` against constant targets.

        Targets are built from the logits themselves (``tf.ones_like`` /
        ``tf.zeros_like``) so their shape always matches the discriminator
        output, whatever the actual batch size is.
        """
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=make_targets(logits)))

    # networks : generator
    z = tf.random_normal(shape=[batch_size, 1, 1, 100], mean=0.0, stddev=1.0)
    G_z = generator(z, is_training)

    # networks : discriminator (second call shares the first call's variables)
    _, D_real_logits = discriminator(images, is_training)
    _, D_fake_logits = discriminator(G_z, is_training, reuse=True)

    # loss for each network
    D_loss_real = cross_entropy(D_real_logits, tf.ones_like)
    D_loss_fake = cross_entropy(D_fake_logits, tf.zeros_like)
    D_loss = D_loss_real + D_loss_fake
    # The generator is rewarded when the discriminator labels fakes as real.
    G_loss = cross_entropy(D_fake_logits, tf.ones_like)

    tf.summary.image('G_z', G_z, max_outputs=10)

    return mox.ModelSpec(loss=[D_loss, G_loss],
                         var_scopes=[['discriminator'], ['generator']],
                         log_info={'D_loss': D_loss, 'G_loss': G_loss})
if __name__ == '__main__':
    # One Adam factory per loss returned by model_fn ([D_loss, G_loss]);
    # presumably mox pairs them positionally -- confirm against the MoXing docs.
    d_optimizer_fn = mox.get_optimizer_fn(name='adam', learning_rate=0.0002, beta1=0.5)
    g_optimizer_fn = mox.get_optimizer_fn(name='adam', learning_rate=0.0002, beta1=0.5)

    # Launch training; checkpoints and summaries go to log_dir.
    mox.run(
        input_fn=input_fn,
        model_fn=model_fn,
        optimizer_fn=[d_optimizer_fn, g_optimizer_fn],
        run_mode=mox.ModeKeys.TRAIN,
        batch_size=batch_size,
        auto_batch=False,
        log_dir='/tmp/delete_me/dcgan_mnist',
        max_number_of_steps=11000,
        log_every_n_steps=9,
        save_summary_steps=99,
        save_model_secs=60,
    )