Skip to content

Commit

Permalink
fix bugs in WGAN, WGAN-GP, DRAGAN
Browse files Browse the repository at this point in the history
output of discriminator is changed into linear output (logit value) from
sigmoid output
  • Loading branch information
hwalsuklee committed Nov 9, 2017
1 parent f24a27f commit 06cc307
Show file tree
Hide file tree
Showing 32 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion DRAGAN.py
Expand Up @@ -119,7 +119,7 @@ def build_model(self):
alpha = tf.random_uniform(shape=self.inputs.get_shape(), minval=0.,maxval=1.)
differences = self.inputs_p - self.inputs # This is different from WGAN-GP
interpolates = self.inputs + (alpha * differences)
D_inter,_,_=self.discriminator(interpolates, is_training=True, reuse=True)
_,D_inter,_=self.discriminator(interpolates, is_training=True, reuse=True)
gradients = tf.gradients(D_inter, [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
Expand Down
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -12,7 +12,7 @@ Pytorch Version is now availabel at https://github.com/znxlwm/pytorch-generative
**GAN** | [Arxiv](https://arxiv.org/abs/1406.2661) | <img src = 'assets/equations/GAN.png' height = '70px'>
**LSGAN**| [Arxiv](https://arxiv.org/abs/1611.04076) | <img src = 'assets/equations/LSGAN.png' height = '70px'>
**WGAN**| [Arxiv](https://arxiv.org/abs/1701.07875) | <img src = 'assets/equations/WGAN.png' height = '105px'>
**WGAN-GP**| [Arxiv](https://arxiv.org/abs/1704.00028) | <img src = 'assets/equations/WGAN_GP.png' height = '70px'>
**WGAN_GP**| [Arxiv](https://arxiv.org/abs/1704.00028) | <img src = 'assets/equations/WGAN_GP.png' height = '70px'>
**DRAGAN**| [Arxiv](https://arxiv.org/abs/1705.07215) | <img src = 'assets/equations/DRAGAN.png' height = '70px'>
**CGAN**| [Arxiv](https://arxiv.org/abs/1411.1784) | <img src = 'assets/equations/CGAN.png' height = '70px'>
**infoGAN**| [Arxiv](https://arxiv.org/abs/1606.03657) | <img src = 'assets/equations/infoGAN.png' height = '70px'>
Expand Down Expand Up @@ -40,7 +40,7 @@ All results are randomly sampled.
GAN | <img src = 'assets/mnist_results/random_generation/GAN_epoch001_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/GAN_epoch009_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/GAN_epoch024_test_all_classes.png' height = '230px'>
LSGAN | <img src = 'assets/mnist_results/random_generation/LSGAN_epoch001_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/LSGAN_epoch009_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/LSGAN_epoch024_test_all_classes.png' height = '230px'>
WGAN | <img src = 'assets/mnist_results/random_generation/WGAN_epoch001_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/WGAN_epoch009_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/WGAN_epoch024_test_all_classes.png' height = '230px'>
WGAN-GP | <img src = 'assets/mnist_results/random_generation/WGAN-GP_epoch001_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/WGAN-GP_epoch009_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/WGAN-GP_epoch024_test_all_classes.png' height = '230px'>
WGAN_GP | <img src = 'assets/mnist_results/random_generation/WGAN_GP_epoch001_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/WGAN_GP_epoch009_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/WGAN_GP_epoch024_test_all_classes.png' height = '230px'>
DRAGAN | <img src = 'assets/mnist_results/random_generation/DRAGAN_epoch001_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/DRAGAN_epoch009_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/DRAGAN_epoch024_test_all_classes.png' height = '230px'>
EBGAN | <img src = 'assets/mnist_results/random_generation/EBGAN_epoch001_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/EBGAN_epoch009_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/EBGAN_epoch024_test_all_classes.png' height = '230px'>
BEGAN | <img src = 'assets/mnist_results/random_generation/BEGAN_epoch001_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/BEGAN_epoch009_test_all_classes.png' height = '230px'> | <img src = 'assets/mnist_results/random_generation/BEGAN_epoch024_test_all_classes.png' height = '230px'>
Expand Down Expand Up @@ -79,7 +79,7 @@ All results are randomly sampled.
GAN | <img src = 'assets/fashion_mnist_results/random_generation/GAN_epoch000_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/GAN_epoch019_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/GAN_epoch039_test_all_classes.png' height = '230px'>
LSGAN | <img src = 'assets/fashion_mnist_results/random_generation/LSGAN_epoch000_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/LSGAN_epoch019_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/LSGAN_epoch039_test_all_classes.png' height = '230px'>
WGAN | <img src = 'assets/fashion_mnist_results/random_generation/WGAN_epoch000_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/WGAN_epoch019_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/WGAN_epoch039_test_all_classes.png' height = '230px'>
WGAN-GP | <img src = 'assets/fashion_mnist_results/random_generation/WGAN-GP_epoch000_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/WGAN-GP_epoch019_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/WGAN-GP_epoch039_test_all_classes.png' height = '230px'>
WGAN_GP | <img src = 'assets/fashion_mnist_results/random_generation/WGAN_GP_epoch000_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/WGAN_GP_epoch019_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/WGAN_GP_epoch039_test_all_classes.png' height = '230px'>
DRAGAN | <img src = 'assets/fashion_mnist_results/random_generation/DRAGAN_epoch000_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/DRAGAN_epoch019_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/DRAGAN_epoch039_test_all_classes.png' height = '230px'>
EBGAN | <img src = 'assets/fashion_mnist_results/random_generation/EBGAN_epoch000_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/EBGAN_epoch019_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/EBGAN_epoch039_test_all_classes.png' height = '230px'>
BEGAN | <img src = 'assets/fashion_mnist_results/random_generation/BEGAN_epoch000_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/BEGAN_epoch019_test_all_classes.png' height = '230px'> | <img src = 'assets/fashion_mnist_results/random_generation/BEGAN_epoch039_test_all_classes.png' height = '230px'>
Expand Down
4 changes: 2 additions & 2 deletions WGAN.py
Expand Up @@ -99,8 +99,8 @@ def build_model(self):
D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

# get loss for discriminator
d_loss_real = - tf.reduce_mean(D_real)
d_loss_fake = tf.reduce_mean(D_fake)
d_loss_real = - tf.reduce_mean(D_real_logits)
d_loss_fake = tf.reduce_mean(D_fake_logits)

self.d_loss = d_loss_real + d_loss_fake

Expand Down
8 changes: 4 additions & 4 deletions WGAN_GP.py
Expand Up @@ -9,7 +9,7 @@
from utils import *

class WGAN_GP(object):
model_name = "WGAN-GP" # name for checkpoint
model_name = "WGAN_GP" # name for checkpoint

def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
self.sess = sess
Expand Down Expand Up @@ -100,8 +100,8 @@ def build_model(self):
D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)

# get loss for discriminator
d_loss_real = - tf.reduce_mean(D_real)
d_loss_fake = tf.reduce_mean(D_fake)
d_loss_real = - tf.reduce_mean(D_real_logits)
d_loss_fake = tf.reduce_mean(D_fake_logits)

self.d_loss = d_loss_real + d_loss_fake

Expand All @@ -113,7 +113,7 @@ def build_model(self):
alpha = tf.random_uniform(shape=self.inputs.get_shape(), minval=0.,maxval=1.)
differences = G - self.inputs # This is different from MAGAN
interpolates = self.inputs + (alpha * differences)
D_inter,_,_=self.discriminator(interpolates, is_training=True, reuse=True)
_,D_inter,_=self.discriminator(interpolates, is_training=True, reuse=True)
gradients = tf.gradients(D_inter, [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
Expand Down
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 06cc307

Please sign in to comment.