# Coupled generative adversarial networks

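Reference: Liu, Ming-Yu; Tuzel, Oncel. *Coupled generative adversarial networks*. In: Advances in Neural Information Processing Systems (NIPS 2016), pp. 469–477.

This notebook trains a CoGAN on two unpaired MNIST domains — the original digits and the same digits rotated by 90° — with generators and discriminators that share part of their weights.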

In [3]:
import scipy
import numpy as np
import matplotlib.pyplot as plt

In [4]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

# Generators

In [5]:
def build_generators(img_shape, latent_dim):
    """Build the two coupled CoGAN generators.

    Both generators read the same noise tensor and pass it through one shared
    trunk (Dense 256 -> Dense 512, each followed by LeakyReLU and
    BatchNormalization). Only the output heads (Dense 1024 -> Dense
    prod(img_shape) with tanh -> Reshape) are separate per domain. All layer
    sizes are hard-coded.

    Args:
        img_shape: shape of one generated image, e.g. ``(28, 28, 1)``.
        latent_dim: length of the input noise vector.

    Returns:
        ``(g1, g2)``: two Keras ``Model`` objects sharing the same noise
        ``Input``; each outputs a ``img_shape`` image with tanh values.

    Side effect: prints the summary of the shared trunk (not of g1/g2).
    """
    # Weights in this stack are shared between the two generators.
    trunk = Sequential([
        Dense(256, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
    ])

    def domain_head(features):
        # Private (unshared) layers that decode shared features into one image.
        x = Dense(1024)(features)
        x = LeakyReLU(alpha=0.2)(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = Dense(np.prod(img_shape), activation='tanh')(x)
        return Reshape(img_shape)(x)

    noise = Input(shape=(latent_dim,))
    shared_features = trunk(noise)

    first_img = domain_head(shared_features)   # domain 1
    second_img = domain_head(shared_features)  # domain 2

    trunk.summary()

    return Model(noise, first_img), Model(noise, second_img)

# Discriminators

In [6]:
def build_discriminators(img_shape):
    """Build the two coupled CoGAN discriminators.

    Each discriminator maps one image of ``img_shape`` to a single sigmoid
    score (real vs. fake). Both inputs go through the same shared stack
    (Flatten -> Dense 512 -> LeakyReLU -> Dense 256 -> LeakyReLU). Only the
    final ``Dense(1, sigmoid)`` layer is specific to each domain.

    NOTE(review): here the *first* discriminator layers are shared. The CoGAN
    paper ties the *last* (high-level) discriminator layers instead. Confirm
    that this deviation is intended.

    Args:
        img_shape: shape of one input image, e.g. ``(28, 28, 1)``.

    Returns:
        ``(d1, d2)``: two uncompiled Keras ``Model`` objects.
    """
    inputs = [Input(shape=img_shape) for _ in range(2)]

    # Layers whose weights are shared by both discriminators.
    shared = Sequential([
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
    ])

    embeddings = [shared(image) for image in inputs]
    # One private output unit per domain.
    scores = [Dense(1, activation='sigmoid')(emb) for emb in embeddings]

    return tuple(Model(image, score) for image, score in zip(inputs, scores))

In [7]:
def sample_images(epoch, g1, g2, latent_dim=None, out_dir="images"):
    """Save a 4x4 grid of generated samples from both generators.

    The same noise batch is fed to ``g1`` and ``g2``. The top two rows of the
    grid show ``g1`` outputs and the bottom two rows show ``g2`` outputs, so
    cell ``(i, j)`` and cell ``(i + 2, j)`` come from the same noise vector.

    Args:
        epoch: iteration number, used in the output file name.
        g1, g2: generator models exposing ``predict``.
        latent_dim: noise vector length. When ``None`` it is read from
            ``g1.input_shape[-1]`` instead of being hard-coded.
        out_dir: directory for the PNG. It is created if missing.

    Writes ``<out_dir>/mnist_<epoch>.png``.
    """
    import os

    r, c = 4, 4
    if latent_dim is None:
        # The generator input is (batch, latent_dim); take the last axis.
        latent_dim = g1.input_shape[-1]

    # Half of the grid per generator.
    noise = np.random.normal(0, 1, (r * c // 2, latent_dim))
    gen_imgs = np.concatenate([g1.predict(noise), g2.predict(noise)])

    # Rescale tanh output from [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the concatenation order.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')

    # savefig fails if the target directory does not exist.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, "mnist_%d.png" % epoch))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

# Training loop

In [8]:
def _load_mnist_domains():
    """Return ``(X1, X2)``: the two halves of the MNIST training images.

    Pixels are scaled to [-1, 1] and a trailing channel axis is added. ``X1``
    is the first half unchanged (domain A). ``X2`` is the second half rotated
    by 90 degrees in the image plane (domain B).
    """
    # Explicit submodule import: a bare `import scipy` does not guarantee that
    # `scipy.ndimage` is loaded, and `scipy.ndimage.interpolation` is a
    # deprecated alias of the `scipy.ndimage` namespace.
    from scipy import ndimage

    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1 (matches the generators' tanh output range).
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    half = X_train.shape[0] // 2
    X1 = X_train[:half]
    # axes=(1, 2) are the two spatial axes after expand_dims.
    X2 = ndimage.rotate(X_train[half:], 90, axes=(1, 2))
    return X1, X2


def train(g1, g2, d1, d2, combined,
          epochs, batch_size=128, sample_interval=50, latent_dim=None):
    """Train the coupled GAN on MNIST (domain 1) vs. rotated MNIST (domain 2).

    Each loop step draws ONE random mini-batch per domain, so ``epochs``
    counts iterations, not passes over the dataset. Each step:

    1. trains ``d1``/``d2`` on real images (target 1) and on generated
       images (target 0);
    2. trains the generators through ``combined`` to make both
       discriminators output 1;
    3. prints the losses and accuracies.

    Args:
        g1, g2: generators, noise -> image.
        d1, d2: compiled discriminators with an accuracy metric (their
            ``train_on_batch`` results are indexed as ``[loss, acc]``).
        combined: compiled model noise -> ``[d1(g1(z)), d2(g2(z))]``.
        epochs: number of training iterations.
        batch_size: images per domain per iteration.
        sample_interval: save a sample grid every this many iterations.
            Values <= 0 disable sampling.
        latent_dim: noise length. When ``None`` it is read from
            ``g1.input_shape[-1]``.
    """
    if latent_dim is None:
        latent_dim = g1.input_shape[-1]

    X1, X2 = _load_mnist_domains()

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        # ----------------------
        #  Train Discriminators
        # ----------------------

        # The same indices select different digits in each domain because X1
        # and X2 come from disjoint halves of MNIST, so the batches are unpaired.
        idx = np.random.randint(0, X1.shape[0], batch_size)
        imgs1 = X1[idx]
        imgs2 = X2[idx]

        # One noise batch drives both generators (coupled samples).
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        gen_imgs1 = g1.predict(noise)
        gen_imgs2 = g2.predict(noise)

        d1_loss_real = d1.train_on_batch(imgs1, valid)
        d2_loss_real = d2.train_on_batch(imgs2, valid)
        d1_loss_fake = d1.train_on_batch(gen_imgs1, fake)
        d2_loss_fake = d2.train_on_batch(gen_imgs2, fake)
        d1_loss = 0.5 * np.add(d1_loss_real, d1_loss_fake)
        d2_loss = 0.5 * np.add(d2_loss_real, d2_loss_fake)

        # ------------------
        #  Train Generators
        # ------------------

        g_loss = combined.train_on_batch(noise, [valid, valid])

        print("%d [D1 loss: %f, acc.: %.2f%%] [D2 loss: %f, acc.: %.2f%%] [G loss: %f]"
              % (epoch, d1_loss[0], 100 * d1_loss[1],
                 d2_loss[0], 100 * d2_loss[1], g_loss[0]))

        # Guard against sample_interval == 0 (would raise ZeroDivisionError).
        if sample_interval > 0 and epoch % sample_interval == 0:
            sample_images(epoch, g1, g2)

# Main: configuration, model assembly and training

In [9]:
# MNIST image geometry (rows, cols, channels) and generator noise length.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100

In [10]:
# Adam hyperparameters (positional Keras 2 order: learning rate, beta_1).
optimizer = Adam(2e-4, 0.5)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [11]:
# Build both discriminators and compile each one.
# Each model gets its own Adam instance cloned from `optimizer`'s config, so
# the hyperparameters stay defined in one place. Reusing one optimizer object
# across several compiled models shares its internal state (e.g. Adam's step
# counter; in tf.keras / Keras 3 also per-variable slots for the shared layers).
d1, d2 = build_discriminators(img_shape)

d1.compile(loss='binary_crossentropy',
           optimizer=Adam.from_config(optimizer.get_config()),
           metrics=['accuracy'])

d2.compile(loss='binary_crossentropy',
           optimizer=Adam.from_config(optimizer.get_config()),
           metrics=['accuracy'])

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [12]:
# Build the two coupled generators (shared trunk, per-domain output heads).
g1, g2 = build_generators(img_shape=img_shape, latent_dim=latent_dim)

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_5 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dense_6 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
Total params: 160,512
Trainable params: 158,976
Non-trainable params: 1,536
____________________________________________

In [13]:
# One noise input feeds both generators, so each z yields one image per domain.
z = Input(shape=(latent_dim,))
img1, img2 = g1(z), g2(z)

In [14]:
# Freeze both discriminators inside the combined model so that only the
# generators are updated by `combined`. d1/d2 were compiled above, before this
# flag changed, and keep training through their own train_on_batch calls.
d1.trainable = d2.trainable = False



In [15]:
# Score each generated image with the discriminator of its own domain.
valid1, valid2 = d1(img1), d2(img2)

In [16]:
# The combined model (stacked generators and frozen discriminators)
# trains the generators to fool both discriminators.
# It gets its own Adam instance cloned from `optimizer`'s config instead of
# sharing the optimizer object (and its internal state) with the discriminators.
combined = Model(z, [valid1, valid2])
combined.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
                 optimizer=Adam.from_config(optimizer.get_config()))

# Train the model

In [None]:
# Each "epoch" here is one mini-batch iteration; use 30000 for a longer run.
epochs = 3000
train(g1, g2, d1, d2, combined,
      epochs=epochs,
      batch_size=32,
      sample_interval=200)




  'Discrepancy between trainable weights and collected trainable'


0 [D1 loss: 0.918094, acc.: 46.88%] [D2 loss: 0.604709, acc.: 62.50%] [G loss: 1.501780]


  'Discrepancy between trainable weights and collected trainable'


1 [D1 loss: 0.397216, acc.: 96.88%] [D2 loss: 0.343117, acc.: 89.06%] [G loss: 1.536288]
2 [D1 loss: 0.325359, acc.: 92.19%] [D2 loss: 0.328741, acc.: 82.81%] [G loss: 1.660351]
3 [D1 loss: 0.303079, acc.: 92.19%] [D2 loss: 0.316144, acc.: 85.94%] [G loss: 1.851377]
4 [D1 loss: 0.281995, acc.: 95.31%] [D2 loss: 0.295434, acc.: 93.75%] [G loss: 2.076351]
5 [D1 loss: 0.265806, acc.: 98.44%] [D2 loss: 0.301856, acc.: 87.50%] [G loss: 2.408001]
6 [D1 loss: 0.222210, acc.: 98.44%] [D2 loss: 0.226912, acc.: 98.44%] [G loss: 2.985777]
7 [D1 loss: 0.170685, acc.: 98.44%] [D2 loss: 0.209495, acc.: 98.44%] [G loss: 3.375984]
8 [D1 loss: 0.156022, acc.: 100.00%] [D2 loss: 0.151447, acc.: 100.00%] [G loss: 3.787433]
9 [D1 loss: 0.114053, acc.: 100.00%] [D2 loss: 0.140640, acc.: 100.00%] [G loss: 4.260783]
10 [D1 loss: 0.087218, acc.: 100.00%] [D2 loss: 0.098014, acc.: 100.00%] [G loss: 4.671067]
11 [D1 loss: 0.081186, acc.: 100.00%] [D2 loss: 0.080641, acc.: 100.00%] [G loss: 4.970495]
12 [D1 loss

92 [D1 loss: 0.114201, acc.: 95.31%] [D2 loss: 0.035227, acc.: 98.44%] [G loss: 11.568077]
93 [D1 loss: 0.170189, acc.: 90.62%] [D2 loss: 0.030626, acc.: 100.00%] [G loss: 13.491845]
94 [D1 loss: 0.792813, acc.: 75.00%] [D2 loss: 1.587731, acc.: 32.81%] [G loss: 7.175697]
95 [D1 loss: 0.266301, acc.: 87.50%] [D2 loss: 0.334540, acc.: 85.94%] [G loss: 7.555627]
96 [D1 loss: 0.155245, acc.: 89.06%] [D2 loss: 0.284816, acc.: 90.62%] [G loss: 10.394873]
97 [D1 loss: 0.035605, acc.: 98.44%] [D2 loss: 0.079459, acc.: 95.31%] [G loss: 10.828533]
98 [D1 loss: 0.103069, acc.: 96.88%] [D2 loss: 0.113216, acc.: 95.31%] [G loss: 11.014616]
99 [D1 loss: 0.017494, acc.: 100.00%] [D2 loss: 0.117127, acc.: 96.88%] [G loss: 10.881680]
100 [D1 loss: 0.045461, acc.: 100.00%] [D2 loss: 0.061805, acc.: 96.88%] [G loss: 10.304309]
101 [D1 loss: 0.044672, acc.: 100.00%] [D2 loss: 0.063246, acc.: 96.88%] [G loss: 10.617430]
102 [D1 loss: 0.016717, acc.: 100.00%] [D2 loss: 0.036694, acc.: 100.00%] [G loss: 10.

184 [D1 loss: 0.843212, acc.: 45.31%] [D2 loss: 0.382824, acc.: 79.69%] [G loss: 7.665075]
185 [D1 loss: 0.735584, acc.: 54.69%] [D2 loss: 0.321174, acc.: 85.94%] [G loss: 6.476449]
186 [D1 loss: 0.528376, acc.: 68.75%] [D2 loss: 0.383958, acc.: 82.81%] [G loss: 5.355880]
187 [D1 loss: 0.503927, acc.: 73.44%] [D2 loss: 0.537280, acc.: 73.44%] [G loss: 5.317835]
188 [D1 loss: 0.448755, acc.: 81.25%] [D2 loss: 0.325535, acc.: 81.25%] [G loss: 6.483703]
189 [D1 loss: 0.585494, acc.: 57.81%] [D2 loss: 0.498598, acc.: 73.44%] [G loss: 5.516922]
190 [D1 loss: 0.572999, acc.: 60.94%] [D2 loss: 0.433203, acc.: 71.88%] [G loss: 7.529742]
191 [D1 loss: 0.709991, acc.: 59.38%] [D2 loss: 0.320691, acc.: 81.25%] [G loss: 6.344965]
192 [D1 loss: 0.669646, acc.: 50.00%] [D2 loss: 0.554466, acc.: 70.31%] [G loss: 4.964752]
193 [D1 loss: 0.539823, acc.: 62.50%] [D2 loss: 0.342125, acc.: 85.94%] [G loss: 5.697367]
194 [D1 loss: 0.684245, acc.: 48.44%] [D2 loss: 0.391266, acc.: 79.69%] [G loss: 6.502151]

276 [D1 loss: 0.653919, acc.: 54.69%] [D2 loss: 0.646939, acc.: 54.69%] [G loss: 1.610748]
277 [D1 loss: 0.654350, acc.: 50.00%] [D2 loss: 0.635938, acc.: 57.81%] [G loss: 1.657691]
278 [D1 loss: 0.675288, acc.: 48.44%] [D2 loss: 0.617238, acc.: 64.06%] [G loss: 1.766468]
279 [D1 loss: 0.702686, acc.: 48.44%] [D2 loss: 0.693721, acc.: 48.44%] [G loss: 1.683612]
280 [D1 loss: 0.633698, acc.: 54.69%] [D2 loss: 0.669446, acc.: 46.88%] [G loss: 1.635872]
281 [D1 loss: 0.633999, acc.: 59.38%] [D2 loss: 0.650157, acc.: 51.56%] [G loss: 1.711440]
282 [D1 loss: 0.666389, acc.: 51.56%] [D2 loss: 0.651758, acc.: 48.44%] [G loss: 1.724130]
283 [D1 loss: 0.638381, acc.: 59.38%] [D2 loss: 0.610584, acc.: 65.62%] [G loss: 1.952689]
284 [D1 loss: 0.640811, acc.: 57.81%] [D2 loss: 0.624832, acc.: 60.94%] [G loss: 1.921546]
285 [D1 loss: 0.654418, acc.: 56.25%] [D2 loss: 0.640638, acc.: 54.69%] [G loss: 1.744110]
286 [D1 loss: 0.634410, acc.: 64.06%] [D2 loss: 0.683082, acc.: 46.88%] [G loss: 1.608051]

369 [D1 loss: 0.649106, acc.: 56.25%] [D2 loss: 0.602324, acc.: 76.56%] [G loss: 1.727276]
370 [D1 loss: 0.626374, acc.: 59.38%] [D2 loss: 0.581048, acc.: 70.31%] [G loss: 1.684571]
371 [D1 loss: 0.646881, acc.: 54.69%] [D2 loss: 0.627415, acc.: 57.81%] [G loss: 1.725569]
372 [D1 loss: 0.616269, acc.: 62.50%] [D2 loss: 0.628430, acc.: 46.88%] [G loss: 1.753407]
373 [D1 loss: 0.656470, acc.: 57.81%] [D2 loss: 0.603479, acc.: 62.50%] [G loss: 1.670519]
374 [D1 loss: 0.627793, acc.: 62.50%] [D2 loss: 0.590040, acc.: 62.50%] [G loss: 1.698863]
375 [D1 loss: 0.650976, acc.: 48.44%] [D2 loss: 0.565456, acc.: 76.56%] [G loss: 1.769246]
376 [D1 loss: 0.631937, acc.: 53.12%] [D2 loss: 0.592884, acc.: 67.19%] [G loss: 1.732720]
377 [D1 loss: 0.675485, acc.: 46.88%] [D2 loss: 0.632708, acc.: 57.81%] [G loss: 1.686882]
378 [D1 loss: 0.662807, acc.: 51.56%] [D2 loss: 0.593269, acc.: 76.56%] [G loss: 1.761949]
379 [D1 loss: 0.647198, acc.: 60.94%] [D2 loss: 0.591179, acc.: 65.62%] [G loss: 1.911217]

461 [D1 loss: 0.572991, acc.: 71.88%] [D2 loss: 0.585293, acc.: 71.88%] [G loss: 1.908547]
462 [D1 loss: 0.607741, acc.: 65.62%] [D2 loss: 0.583373, acc.: 73.44%] [G loss: 1.926935]
463 [D1 loss: 0.608061, acc.: 64.06%] [D2 loss: 0.521038, acc.: 87.50%] [G loss: 1.925344]
464 [D1 loss: 0.614720, acc.: 57.81%] [D2 loss: 0.549417, acc.: 75.00%] [G loss: 2.015236]
465 [D1 loss: 0.687753, acc.: 43.75%] [D2 loss: 0.564469, acc.: 73.44%] [G loss: 2.209119]
466 [D1 loss: 0.629111, acc.: 67.19%] [D2 loss: 0.563425, acc.: 71.88%] [G loss: 2.007701]
467 [D1 loss: 0.622495, acc.: 67.19%] [D2 loss: 0.607698, acc.: 71.88%] [G loss: 1.983852]
468 [D1 loss: 0.690281, acc.: 57.81%] [D2 loss: 0.554197, acc.: 76.56%] [G loss: 2.049994]
469 [D1 loss: 0.656624, acc.: 56.25%] [D2 loss: 0.598251, acc.: 67.19%] [G loss: 1.815611]
470 [D1 loss: 0.650643, acc.: 54.69%] [D2 loss: 0.573157, acc.: 76.56%] [G loss: 1.960771]
471 [D1 loss: 0.702201, acc.: 48.44%] [D2 loss: 0.610442, acc.: 67.19%] [G loss: 1.942354]

554 [D1 loss: 0.656677, acc.: 53.12%] [D2 loss: 0.551433, acc.: 79.69%] [G loss: 1.933757]
555 [D1 loss: 0.657222, acc.: 56.25%] [D2 loss: 0.592855, acc.: 62.50%] [G loss: 1.855131]
556 [D1 loss: 0.671259, acc.: 54.69%] [D2 loss: 0.555087, acc.: 73.44%] [G loss: 1.788247]
557 [D1 loss: 0.668951, acc.: 59.38%] [D2 loss: 0.558062, acc.: 71.88%] [G loss: 1.885299]
558 [D1 loss: 0.689522, acc.: 56.25%] [D2 loss: 0.614891, acc.: 65.62%] [G loss: 1.859081]
559 [D1 loss: 0.635548, acc.: 62.50%] [D2 loss: 0.539409, acc.: 82.81%] [G loss: 1.994881]
560 [D1 loss: 0.661093, acc.: 53.12%] [D2 loss: 0.610273, acc.: 65.62%] [G loss: 1.815862]
561 [D1 loss: 0.629016, acc.: 60.94%] [D2 loss: 0.568264, acc.: 78.12%] [G loss: 1.848925]
562 [D1 loss: 0.618825, acc.: 65.62%] [D2 loss: 0.590310, acc.: 71.88%] [G loss: 1.891282]
563 [D1 loss: 0.655436, acc.: 59.38%] [D2 loss: 0.616844, acc.: 68.75%] [G loss: 1.993249]
564 [D1 loss: 0.629020, acc.: 65.62%] [D2 loss: 0.552940, acc.: 75.00%] [G loss: 1.990035]

646 [D1 loss: 0.632455, acc.: 60.94%] [D2 loss: 0.544936, acc.: 79.69%] [G loss: 1.965943]
647 [D1 loss: 0.642681, acc.: 57.81%] [D2 loss: 0.540561, acc.: 82.81%] [G loss: 1.984786]
648 [D1 loss: 0.598486, acc.: 73.44%] [D2 loss: 0.543306, acc.: 81.25%] [G loss: 2.137443]
649 [D1 loss: 0.678221, acc.: 59.38%] [D2 loss: 0.603401, acc.: 68.75%] [G loss: 1.920970]
650 [D1 loss: 0.619958, acc.: 56.25%] [D2 loss: 0.593940, acc.: 67.19%] [G loss: 1.887103]
651 [D1 loss: 0.653784, acc.: 64.06%] [D2 loss: 0.610702, acc.: 68.75%] [G loss: 1.821408]
652 [D1 loss: 0.609857, acc.: 71.88%] [D2 loss: 0.600728, acc.: 62.50%] [G loss: 1.837106]
653 [D1 loss: 0.626838, acc.: 64.06%] [D2 loss: 0.540287, acc.: 76.56%] [G loss: 1.853718]
654 [D1 loss: 0.621068, acc.: 64.06%] [D2 loss: 0.589529, acc.: 70.31%] [G loss: 1.900631]
655 [D1 loss: 0.640601, acc.: 59.38%] [D2 loss: 0.608633, acc.: 67.19%] [G loss: 2.034640]
656 [D1 loss: 0.601379, acc.: 73.44%] [D2 loss: 0.583598, acc.: 73.44%] [G loss: 2.008693]

739 [D1 loss: 0.568002, acc.: 76.56%] [D2 loss: 0.609512, acc.: 68.75%] [G loss: 2.022719]
740 [D1 loss: 0.676395, acc.: 60.94%] [D2 loss: 0.569555, acc.: 73.44%] [G loss: 1.986013]
741 [D1 loss: 0.600103, acc.: 78.12%] [D2 loss: 0.589144, acc.: 73.44%] [G loss: 2.057251]
742 [D1 loss: 0.617674, acc.: 60.94%] [D2 loss: 0.597452, acc.: 71.88%] [G loss: 2.007678]
743 [D1 loss: 0.612027, acc.: 71.88%] [D2 loss: 0.636504, acc.: 67.19%] [G loss: 2.058329]
744 [D1 loss: 0.591829, acc.: 65.62%] [D2 loss: 0.532425, acc.: 79.69%] [G loss: 2.085129]
745 [D1 loss: 0.616142, acc.: 59.38%] [D2 loss: 0.646717, acc.: 57.81%] [G loss: 1.862557]
746 [D1 loss: 0.652308, acc.: 59.38%] [D2 loss: 0.609483, acc.: 68.75%] [G loss: 1.946368]
747 [D1 loss: 0.620820, acc.: 62.50%] [D2 loss: 0.645504, acc.: 62.50%] [G loss: 1.970520]
748 [D1 loss: 0.653098, acc.: 60.94%] [D2 loss: 0.598462, acc.: 70.31%] [G loss: 2.110260]
749 [D1 loss: 0.660742, acc.: 50.00%] [D2 loss: 0.647441, acc.: 54.69%] [G loss: 1.988805]

831 [D1 loss: 0.541850, acc.: 78.12%] [D2 loss: 0.527481, acc.: 78.12%] [G loss: 2.170136]
832 [D1 loss: 0.549221, acc.: 85.94%] [D2 loss: 0.514486, acc.: 84.38%] [G loss: 2.127663]
833 [D1 loss: 0.596434, acc.: 68.75%] [D2 loss: 0.549695, acc.: 84.38%] [G loss: 1.967735]
834 [D1 loss: 0.555623, acc.: 73.44%] [D2 loss: 0.533949, acc.: 81.25%] [G loss: 2.122163]
835 [D1 loss: 0.552028, acc.: 81.25%] [D2 loss: 0.557519, acc.: 73.44%] [G loss: 2.063494]
836 [D1 loss: 0.548270, acc.: 78.12%] [D2 loss: 0.547294, acc.: 75.00%] [G loss: 2.098363]
837 [D1 loss: 0.633860, acc.: 64.06%] [D2 loss: 0.544058, acc.: 75.00%] [G loss: 2.076979]
838 [D1 loss: 0.589020, acc.: 71.88%] [D2 loss: 0.549479, acc.: 78.12%] [G loss: 2.194468]
839 [D1 loss: 0.598377, acc.: 60.94%] [D2 loss: 0.594772, acc.: 65.62%] [G loss: 2.171523]
