Skip to content

Commit

Permalink
fix cifar10 resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
igul222 committed Jun 22, 2017
1 parent 9a4b20b commit fa66c57
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions gan_cifar_resnet.py
Expand Up @@ -172,7 +172,7 @@ def Discriminator(inputs, labels):
output = tf.reduce_mean(output, axis=[2,3])
output_wgan = lib.ops.linear.Linear('Discriminator.Output', DIM_D, 1, output)
output_wgan = tf.reshape(output_wgan, [-1])
if ACGAN:
if CONDITIONAL and ACGAN:
output_acgan = lib.ops.linear.Linear('Discriminator.ACGANOutput', DIM_D, 10, output)
return output_wgan, output_acgan
else:
Expand Down Expand Up @@ -220,7 +220,7 @@ def Discriminator(inputs, labels):
disc_real = disc_all[:BATCH_SIZE/len(DEVICES_A)]
disc_fake = disc_all[BATCH_SIZE/len(DEVICES_A):]
disc_costs.append(tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real))
if ACGAN:
if CONDITIONAL and ACGAN:
disc_acgan_costs.append(tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(logits=disc_all_acgan[:BATCH_SIZE/len(DEVICES_A)], labels=real_and_fake_labels[:BATCH_SIZE/len(DEVICES_A)])
))
Expand Down Expand Up @@ -265,7 +265,7 @@ def Discriminator(inputs, labels):
disc_costs.append(gradient_penalty)

disc_wgan = tf.add_n(disc_costs) / len(DEVICES_A)
if ACGAN:
if CONDITIONAL and ACGAN:
disc_acgan = tf.add_n(disc_acgan_costs) / len(DEVICES_A)
disc_acgan_acc = tf.add_n(disc_acgan_accs) / len(DEVICES_A)
disc_acgan_fake_acc = tf.add_n(disc_acgan_fake_accs) / len(DEVICES_A)
Expand All @@ -289,7 +289,7 @@ def Discriminator(inputs, labels):
with tf.device(device):
n_samples = GEN_BS_MULTIPLE * BATCH_SIZE / len(DEVICES)
fake_labels = tf.cast(tf.random_uniform([n_samples])*10, tf.int32)
if ACGAN:
if CONDITIONAL and ACGAN:
disc_fake, disc_fake_acgan = Discriminator(Generator(n_samples,fake_labels), fake_labels)
gen_costs.append(-tf.reduce_mean(disc_fake))
gen_acgan_costs.append(tf.reduce_mean(
Expand All @@ -298,7 +298,7 @@ def Discriminator(inputs, labels):
else:
gen_costs.append(-tf.reduce_mean(Discriminator(Generator(n_samples, fake_labels), fake_labels)[0]))
gen_cost = (tf.add_n(gen_costs) / len(DEVICES))
if ACGAN:
if CONDITIONAL and ACGAN:
gen_cost += (ACGAN_SCALE_G*(tf.add_n(gen_acgan_costs) / len(DEVICES)))


Expand Down Expand Up @@ -370,13 +370,13 @@ def inf_train_gen():

for i in xrange(N_CRITIC):
_data,_labels = gen.next()
if ACGAN:
if CONDITIONAL and ACGAN:
_disc_cost, _disc_wgan, _disc_acgan, _disc_acgan_acc, _disc_acgan_fake_acc, _ = session.run([disc_cost, disc_wgan, disc_acgan, disc_acgan_acc, disc_acgan_fake_acc, disc_train_op], feed_dict={all_real_data_int: _data, all_real_labels:_labels, _iteration:iteration})
else:
_disc_cost, _ = session.run([disc_cost, disc_train_op], feed_dict={all_real_data_int: _data, all_real_labels:_labels, _iteration:iteration})

lib.plot.plot('cost', _disc_cost)
if ACGAN:
if CONDITIONAL and ACGAN:
lib.plot.plot('wgan', _disc_wgan)
lib.plot.plot('acgan', _disc_acgan)
lib.plot.plot('acc_real', _disc_acgan_acc)
Expand Down

0 comments on commit fa66c57

Please sign in to comment.