In [None]:
%pylab inline
import sys
sys.path.append('/home/peter/code/projects/')
from aidevutil import *
from tqdm import tqdm_notebook as tqdm
import warnings
warnings.filterwarnings('ignore')

In [None]:
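# Create the dataset.
# TODO: the dataset-loading code is missing from this cell. Later cells expect it to define:
#   - x_train: image array of shape (N, imdim, imdim, 3); presumably scaled to [-1, 1]
#     to match the generator's tanh output (TODO confirm)
#   - imdim: image side length; must be 32, because the generator upsamples 1 -> 32 pixels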

In [None]:
# Sanity check: dimensions of the training array (x_train must already exist in the kernel).
np.shape(x_train)

In [None]:
# Separate Adam instances for the two players; the discriminator uses a slightly
# lower learning rate (1.6e-4) than the generator (2e-4). beta_1 is left at Adam's default.
gen_optimizer, disc_optimizer = Adam(lr=2e-4), Adam(lr=1.6e-4)

In [None]:
# Square RGB images, channels-last. `imdim` must come from the dataset-creation
# step, which is not part of this notebook as shown.
img_shape = (imdim,) * 2 + (3,)
# Dimensionality of the latent noise vector fed to the generator.
noise_shape = (64,)

In [None]:
def build_discriminator():
    """Build the real/fake image classifier.

    Two conv -> LeakyReLU -> 2x2 average-pool stages, then a 128-unit dense
    layer and a single sigmoid unit. The inner layer stack is printed with
    ``summary()`` and wrapped in a functional ``Model`` taking an
    ``img_shape`` input.

    Returns:
        keras Model mapping an image batch to a (batch, 1) sigmoid score.
    """
    def leaky():
        # Every hidden activation in this network uses the same leak slope.
        return LeakyReLU(alpha=0.2)

    layer_stack = Sequential([
        # Stage 1: 16 wide (5x5) filters, then halve the spatial size.
        Conv2D(16, (5, 5), padding='same', input_shape=img_shape),
        leaky(),
        AveragePooling2D(pool_size=(2, 2)),
        # Stage 2: normalise, 8 narrow (3x3) filters, halve again.
        BatchNormalization(),
        Conv2D(8, (3, 3), padding='same'),
        leaky(),
        AveragePooling2D(pool_size=(2, 2)),
        Flatten(),
        # Classifier head.
        BatchNormalization(),
        Dense(128),
        leaky(),
        Dense(1),
        Activation('sigmoid'),
    ])
    layer_stack.summary()

    image_in = Input(shape=img_shape)
    return Model(image_in, layer_stack(image_in))

In [None]:
def build_generator():
    """Build the generator mapping a latent vector to an RGB image.

    The noise vector is projected to a 1x1x512 feature map. Five stride-2
    transposed convolutions with 'same' padding then double the spatial size
    each time (1 -> 32 pixels). The last one emits 3 channels through tanh,
    so outputs lie in [-1, 1]. The inner layer stack is printed with
    ``summary()``.

    Returns:
        keras Model with input shape ``noise_shape``.
    """
    net = Sequential()

    # Project the latent vector to a 1x1x512 feature map.
    net.add(Dense(512, input_shape=noise_shape))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Reshape((1, 1, 512)))
    net.add(Dropout(0.5))

    # Four normalised upsampling stages: 1 -> 2 -> 4 -> 8 -> 16 pixels.
    upsampling_stages = ((256, (5, 5)), (128, (5, 5)), (64, (3, 3)), (32, (3, 3)))
    for filters, kernel in upsampling_stages:
        net.add(BatchNormalization())
        net.add(Conv2DTranspose(filters, kernel, strides=2, padding='same'))
        net.add(LeakyReLU(alpha=0.2))

    # Output stage: 16 -> 32 pixels, 3 channels, squashed into [-1, 1].
    net.add(Conv2DTranspose(3, (3, 3), strides=2, padding='same'))
    net.add(Activation('tanh'))

    noise = Input(shape=noise_shape)
    mdl = Model(noise, net(noise))
    net.summary()

    return mdl

In [None]:
def save_imgs(epoch, rows=5, cols=5, out_dir="out_images"):
    """Render a grid of generator samples and write it to disk.

    Draws ``rows * cols`` latent vectors from a standard normal and runs them
    through the global ``generator``. Its tanh output is mapped from [-1, 1]
    to [0, 1] for display, and the grid is saved as
    ``<out_dir>/faces_<epoch>.png``.

    Args:
        epoch: training step number; used only in the output filename.
        rows, cols: grid dimensions (defaults keep the original 5x5 grid).
        out_dir: directory for the PNG; created if it does not exist.
    """
    import os

    noise = np.random.normal(0, 1, (rows * cols,) + noise_shape)
    gen_imgs = generator.predict(noise)

    # Rescale images from [-1, 1] to [0, 1]; clip guards against any overshoot.
    gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0, 1)

    # squeeze=False keeps `axs` 2-D even for a 1xN or 1x1 grid.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img)
        ax.axis('off')

    # Make sure the target directory exists before writing the file.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, "faces_%d.png" % epoch))
    # Close this specific figure so repeated calls don't accumulate open figures.
    plt.close(fig)

In [None]:
def show_imgs(epoch, log_dloss, log_gloss, log_dacc, log_gacc, n_samples=3):
    """Display a row of generator samples next to loss and accuracy curves.

    Args:
        epoch: current training step; shown in the loss panel's title
            (it was previously accepted but unused).
        log_dloss, log_gloss: discriminator / generator loss series to plot.
        log_dacc, log_gacc: discriminator / generator accuracy series to plot.
        n_samples: number of generated images to show (default 3, as before).
    """
    noise = np.random.normal(0, 1, (n_samples,) + noise_shape)
    # Generator output is tanh-range [-1, 1]; map to [0, 1] for imshow.
    gen_imgs = np.clip(0.5 * generator.predict(noise) + 0.5, 0, 1)

    # One row: n_samples image panels, then a loss panel and an accuracy panel.
    fig, axs = plt.subplots(1, n_samples + 2, figsize=(16, 3), squeeze=False)
    axs = axs[0]
    for ax, img in zip(axs[:n_samples], gen_imgs):
        ax.imshow(img)
        ax.axis('off')

    loss_ax, acc_ax = axs[n_samples], axs[n_samples + 1]

    loss_ax.plot(log_dloss, label='D loss', alpha=0.6)
    loss_ax.plot(log_gloss, label='G loss', alpha=0.6)
    loss_ax.set_title('loss (step %d)' % epoch)
    loss_ax.legend()

    acc_ax.plot(log_dacc, label='D acc', alpha=0.6)
    acc_ax.plot(log_gacc, label='G acc', alpha=0.6)
    acc_ax.set_title('accuracy')
    acc_ax.legend()

    fig.tight_layout()
    plt.show()

In [None]:
# Build and compile the generator. In this notebook it is only called through
# predict(); its weights are updated via the `combined` model defined below.
generator = build_generator()
generator.compile(optimizer=gen_optimizer, loss='binary_crossentropy')

In [None]:
# Build the discriminator and compile it for direct training on real/fake
# batches, tracking accuracy alongside the loss.
discriminator = build_discriminator()
discriminator.compile(optimizer=disc_optimizer, loss='binary_crossentropy', metrics=['accuracy'])

In [None]:
# The generator takes noise as input and generates imgs
z = Input(shape=noise_shape)
img = generator(z)

# The valid takes generated images as input and determines validity
valid = discriminator(img)

# For the combined model we will only train the generator
# Trainable is set after it is compiled, so discriminator will still train when called, 
# but it will not be trained after combined is compiled, see https://github.com/eriklindernoren/Keras-GAN/issues/73
discriminator.trainable = False

# The combined model  (stacked generator and discriminator) takes
# noise as input => generates images => determines validity
combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=gen_optimizer, metrics=['accuracy'])
combined.summary()

In [None]:
# Rescale -1 to 1
epochs, batch_size, save_interval = 150000, 64, 100

half_batch = int(batch_size / 2)

log_dloss = []
log_gloss = []
log_dacc = []
log_gacc = []

def moving_average(a, n=20):
    """Trailing moving average of a 1-D sequence, used to smooth training curves.

    Element ``i`` of the result is the mean of ``a[i:i + n]``, giving
    ``len(a) - n + 1`` values. If ``a`` has fewer than ``n`` elements, the
    window shrinks to ``len(a)``. A non-empty history therefore always yields
    at least one point, so the curves appear from the very first plot.

    Args:
        a: 1-D sequence of numbers.
        n: window length; must be >= 1.
    Returns:
        1-D float ndarray (empty if ``a`` is empty).
    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError("moving_average window n must be >= 1, got %r" % (n,))
    ret = np.cumsum(a, dtype=float)
    n = min(n, len(ret))
    if n == 0:
        # Empty input -> empty float array.
        return ret
    # Difference of cumulative sums gives each window's sum in O(len(a)).
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

# Target labels are loop-invariant, so build them once. All have shape (n, 1)
# as floats, matching the discriminator's single sigmoid output.
real_y = np.ones((half_batch, 1))
fake_y = np.zeros((half_batch, 1))
# The generator wants the discriminator to label its samples as valid (ones).
valid_y = np.ones((batch_size, 1))

try:
    for epoch in tqdm(range(epochs)):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random half batch of real images (indices drawn with replacement).
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        imgs = x_train[idx]

        # Generate a half batch of fake images from standard-normal noise.
        noise = np.random.normal(0, 1, (half_batch,) + noise_shape)
        gen_imgs = generator.predict(noise)

        # Real and fake are fed as separate batches; each call returns [loss, acc].
        d_loss_real = discriminator.train_on_batch(imgs, real_y)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_y)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        # Update the generator through the frozen discriminator inside `combined`.
        noise = np.random.normal(0, 1, (batch_size,) + noise_shape)
        g_loss = combined.train_on_batch(noise, valid_y)

        # Progress line, overwritten in place via '\r'.
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f, acc.: %.2f%%]              " % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], 100*g_loss[1]), end='\r')

        log_dloss.append(d_loss[0])
        log_gloss.append(g_loss[0])
        log_dacc.append(d_loss[1])
        log_gacc.append(g_loss[1])

        # Every save_interval steps, refresh the sample/curve preview.
        if epoch % save_interval == 0:
            clear_output(wait=True)
            show_imgs(epoch,
                      moving_average(log_dloss),
                      moving_average(log_gloss),
                      moving_average(log_dacc),
                      moving_average(log_gacc))
except KeyboardInterrupt:
    # Stopping from the notebook is the expected way to end early; the trained
    # models and the logs are kept. Report where training stopped instead of
    # swallowing the interrupt silently.
    print("\nTraining interrupted after %d completed steps." % len(log_dloss))

In [None]:
# Raw per-step losses for the whole run (unsmoothed, so expect heavy noise).
fig, ax = plt.subplots()
ax.plot(log_dloss, label='D')
ax.plot(log_gloss, label='G')
ax.set(xlabel='training step', ylabel='binary cross-entropy loss', title='GAN training losses')
ax.legend();

In [None]:
# Persist the generator weights; ensure the target directory exists first so the
# save does not fail on a fresh checkout.
import os
os.makedirs('models', exist_ok=True)
generator.save_weights('models/face_generator_32x32_weights.h5')

In [None]:
# Decode the all-zeros latent vector (the mean of the N(0, 1) noise used in training).
im = generator.predict(np.zeros((1,) + noise_shape))
# The generator ends in tanh, so map [-1, 1] -> [0, 1] before display, as the
# sample grids do; otherwise imshow clips every negative value to 0.
imshow(np.clip(0.5 * im[0] + 0.5, 0, 1));

In [None]:
# Latent traversal: sweep one latent coordinate across [-8, 8] while every other
# coordinate stays at 0, and show the decoded images in an r x c grid.
r, c = 7, 7
sweep_dim = 0
noise = np.zeros((r * c,) + noise_shape)
noise[:, sweep_dim] = np.linspace(-8, 8, num=r * c)
gen_imgs = generator.predict(noise)

# Rescale images from the tanh range [-1, 1] to [0, 1], clipping any overshoot.
gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0, 1)

fig, axs = plt.subplots(r, c)
for ax, img in zip(axs.flat, gen_imgs):
    # Show all three colour channels; the generator's final layer emits RGB.
    ax.imshow(img)
    ax.axis('off')