In [1]:
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Flatten
from matplotlib import pyplot
from PIL import Image
import os, sys
from sklearn.model_selection import train_test_split
import numpy

In [2]:
# Define the encoder block based on the original paper
def define_encoder_block(layer, filtersNo, batchnorm=True):
    
    # init weights from a Gaussian distribution with mean 0 and standard deviation 0.02
    init = RandomNormal(stddev=0.02)
    
    # in the original paper, all convolution kernels are (4,4), with stride 2. Stride for decoder means downsampling.
    x = Conv2D(filtersNo, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer)
    
    # Conditional batch normalization (important for the first layer)
    if batchnorm:
        x = BatchNormalization()(x, training=True)
        
    # All ReLUs in the encoder are leaky!
    x = LeakyReLU(alpha=0.2)(x)
    
    return x

In [3]:
# Define the decoder block based on the original paper
def decoder_block(layer, skip, filtersNo, dropout=True, batch=True):
    
    # init weights from a Gaussian distribution with mean 0 and standard deviation 0.02
    init = RandomNormal(stddev=0.02)
    
    # in the original paper, all convolution kernels are (4,4), with stride 2. Stride for decoder means upsampling.
    x = Conv2DTranspose(filtersNo, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer)

    # All layers in the original paper have batch normalization, although we set an if statement just to play with the model
    if batch:
        x = BatchNormalization()(x, training=True)
        
    # Some decoder layers don't have dropout
    if dropout:
        x = Dropout(0.5)(x, training=True)
        
    # Merge with skip connection
    x = Concatenate()([x, skip])
    
    # All ReLUs in the decoder are not leaky!
    x = Activation('relu')(x)
    
    return x

In [4]:
# Define the generator based on encoder/decoder
def define_generator():
    
    # init weights from a Gaussian distribution with mean 0 and standard deviation 0.02
    init = RandomNormal(stddev=0.02)
    
    # image input
    inputImage = Input(shape=(128,128,1))
    
    ###### Encoder
    
    # C64, input (128,128,1), output (64,64,64)
    encoderLayer1 = define_encoder_block(inputImage, 64, batchnorm=False)
    
    #C128, input (64,64,64), output (32,32,128)
    encoderLayer2 = define_encoder_block(encoderLayer1, 128)
    
    #C256, input (32,32,128), output (16,16,256)
    encoderLayer3 = define_encoder_block(encoderLayer2, 256)
    
    #C512, input (16,16,256), output (8,8,512)
    encoderLayer4 = define_encoder_block(encoderLayer3, 512)
    
    #C512, input (8,8,512), output (4,4,512)
    encoderLayer5 = define_encoder_block(encoderLayer4, 512)
    
    #C512, input (4,4,512), output (2,2,512)
    encoderLayer6 = define_encoder_block(encoderLayer5, 512)
    
    ###### Bottleneck layer, will have an input of (2,2,512) and an output of (1,1,512)
    bottleneck = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(encoderLayer6)
    bottleneck = Activation('relu')(bottleneck)
    
    ###### Decoder, with skip connection
    
    #CD512
    decoderLayer1 = decoder_block(bottleneck, encoderLayer6, 512)
    
    #CD512
    decoderLayer2 = decoder_block(decoderLayer1, encoderLayer5, 512)
    
    #C512
    decoderLayer3 = decoder_block(decoderLayer2, encoderLayer4, 512, dropout=False)
    
    #C256
    decoderLayer4 = decoder_block(decoderLayer3, encoderLayer3, 256, dropout=False)
    
    #C128
    decoderLayer5 = decoder_block(decoderLayer4, encoderLayer2, 128, dropout=False)
    
    #C64
    decoderLayer6 = decoder_block(decoderLayer5, encoderLayer1, 64, dropout=False)
    
    # Output with tanh function, as mentioned in the original paper. Output will be (128x128x3)
    g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(decoderLayer6)
    outputImage = Activation('tanh')(g)
    
    # Define model
    model = Model(inputImage, outputImage)
    
    return model

In [5]:
# Define the 70x70 discriminator as in the original paper
def define_discriminator():
    
    # init weights from a Gaussian distribution with mean 0 and standard deviation 0.02
    init = RandomNormal(stddev=0.02)
    
    # source image input
    source = Input(shape=(128,128,1))
    
    # target image input
    target = Input(shape=(128,128,3))
    
    # concatenate images channel-wise
    merged = Concatenate()([source, target])
    
    # C64
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    
    # patch output
    d = Conv2D(1, (8,8), strides=(8,8), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    patch_out = Flatten()(patch_out)
    
    # define model
    model = Model([source, target], patch_out)
    
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model

In [6]:
# Define Pix2Pix GAN
def pix2pix(generator, discriminator):
    
    # make weights in the discriminator not trainable
    for layer in discriminator.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False
            
    # define the source image
    source = Input(shape=(128,128,1))
    
    # connect the source image to the generator input
    genOut = generator(source)
    
    # connect the source input and generator output to the discriminator input
    disOut = discriminator([source, genOut])
    
    # src image as input, generated image and classification output
    model = Model(source, [disOut, genOut])
    
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
    
    return model

In [7]:
def generateFromDataset(trainX, trainY, samples):
    
    # Choose random images from both input and output
    no = randint(0, trainX.shape[0], samples)
    gx, gy = trainX[no], trainY[no]
    
    # Set y-labels to 1, as these images are from dataset
    y = ones((samples, 1))
    return [gx, gy], y

In [8]:
def generateFromGenerator(generator, samples):
    # Generate fake instance
    x = generator.predict(samples)
    
    # Labels will be zero because they come from the generator
    y = zeros((len(x), 1))
    
    return x, y

In [18]:
def summarize_performance(step, g_model, trainX, trainY, n_samples=5):
    # select a sample of input images
    [X_realA, X_realB], _ = generateFromDataset(trainX, trainY, n_samples)
    
    # Generate fake samples
    X_fakeB, _ = generateFromGenerator(g_model, X_realA)
    
    # Plot black and white images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realA[i])
        
    # Plot generated images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_fakeB[i])
        
    # Plot expected output images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realB[i])
        
    # Save plot to file
    filename1 = 'images/plot_%06d.png' % (step+1)
    pyplot.savefig(filename1)
    pyplot.close()
    
    # Save generator model
    filename2 = 'models/model_%06d.h5' % (step+1)
    g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))

In [19]:
def train(discriminator, generator, gan, epochs=100000, samplesPerEpoch=250):
    
    # Load the data
    bwData = load("Flickr8kblackandwhite1dim.npy")
    colorData = load("flickr8k_shuffled.npy")
    y = numpy.ones((8091,1))
    
    BWSplit = numpy.array_split(bwData, 2)
    colorSplit = numpy.array_split(colorData, 2)
    glossMin = 999
    # manually enumerate epochs
    for i in range(epochs):
        
        # Generate real samples
        [realX, realY], realLabel = generateFromDataset(BWSplit[0], colorSplit[0], samplesPerEpoch)
        
        # Generate fake samples
        fakeY, fakeLabel = generateFromGenerator(generator, realX)
        
        # Update discriminator on real samples
        realLoss = discriminator.train_on_batch([realX, realY], realLabel)
        
        # Update discriminator on fake samples
        fakeLoss = discriminator.train_on_batch([realX, fakeY], fakeLabel)
        
        # Update generator
        generatorLoss, _, _ = gan.train_on_batch(realX, [realLabel, realY])
        
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, realLoss, fakeLoss, generatorLoss))
        if glossMin > generatorLoss:
            summarize_performance(i, generator, BWSplit[0], colorSplit[0])
            glossMin = generatorLoss

In [20]:
d = define_discriminator()
g = define_generator()
p2p = pix2pix(g,d)

In [None]:
train(d,g,p2p)

>1, d1[0.793] d2[3.266] g[53.214]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>Saved: images/plot_000001.png and models/model_000001.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>2, d1[0.003] d2[3.653] g[51.900]
>Saved: images/plot_000002.png and models/model_000002.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>3, d1[0.249] d2[1.736] g[51.697]
>Saved: images/plot_000003.png and models/model_000003.h5
>4, d1[0.765] d2[0.967] g[52.894]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>5, d1[0.100] d2[0.437] g[51.178]
>Saved: images/plot_000005.png and models/model_000005.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>6, d1[0.131] d2[0.465] g[50.237]
>Saved: images/plot_000006.png and models/model_000006.h5
>7, d1[0.335] d2[0.739] g[50.699]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>8, d1[0.084] d2[0.045] g[48.057]
>Saved: images/plot_000008.png and models/model_000008.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>9, d1[0.053] d2[0.707] g[47.351]
>Saved: images/plot_000009.png and models/model_000009.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>10, d1[0.734] d2[1.631] g[47.016]
>Saved: images/plot_000010.png and models/model_000010.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>11, d1[0.006] d2[0.000] g[46.390]
>Saved: images/plot_000011.png and models/model_000011.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>12, d1[0.032] d2[0.000] g[43.456]
>Saved: images/plot_000012.png and models/model_000012.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>13, d1[0.019] d2[0.015] g[38.995]
>Saved: images/plot_000013.png and models/model_000013.h5
>14, d1[0.004] d2[1.408] g[39.190]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>15, d1[2.328] d2[4.322] g[34.695]
>Saved: images/plot_000015.png and models/model_000015.h5
>16, d1[0.077] d2[1.118] g[38.503]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>17, d1[2.033] d2[0.251] g[30.942]
>Saved: images/plot_000017.png and models/model_000017.h5
>18, d1[0.001] d2[3.253] g[35.036]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>19, d1[0.971] d2[0.004] g[28.902]
>Saved: images/plot_000019.png and models/model_000019.h5
>20, d1[0.117] d2[3.419] g[34.553]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>21, d1[3.516] d2[0.002] g[24.396]
>Saved: images/plot_000021.png and models/model_000021.h5
>22, d1[0.000] d2[2.461] g[27.470]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>23, d1[0.515] d2[0.061] g[20.713]
>Saved: images/plot_000023.png and models/model_000023.h5
>24, d1[0.157] d2[2.434] g[30.656]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>25, d1[4.930] d2[0.009] g[19.504]
>Saved: images/plot_000025.png and models/model_000025.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>26, d1[0.025] d2[3.357] g[19.400]
>Saved: images/plot_000026.png and models/model_000026.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>27, d1[1.561] d2[0.663] g[18.281]
>Saved: images/plot_000027.png and models/model_000027.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>28, d1[1.362] d2[2.838] g[17.124]
>Saved: images/plot_000028.png and models/model_000028.h5
>29, d1[0.993] d2[2.255] g[21.918]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>30, d1[2.884] d2[0.925] g[15.213]
>Saved: images/plot_000030.png and models/model_000030.h5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


>31, d1[0.830] d2[0.999] g[13.584]
>Saved: images/plot_000031.png and models/model_000031.h5
>32, d1[0.403] d2[1.009] g[15.991]
>33, d1[1.735] d2[0.324] g[14.721]
