#### Next two cells are only needed for a Google Colab environment.

In [1]:
# Colab only: mount Google Drive at /content/drive so that the project
# directory used by the following `%cd` cell is reachable. The call is
# interactive (it asks for an authorization code).
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
# Colab only: switch the working directory to the project folder on the mounted
# Drive; the `utils` imports and the relative output paths used further down
# presumably resolve against this directory.
%cd '/content/drive/My Drive/CodingProjects/pix2pix'

/content/drive/My Drive/CodingProjects/pix2pix


In [3]:
import os
import warnings
warnings.filterwarnings('ignore')

from keras.models import load_model
from keras.preprocessing.image import load_img
from matplotlib import pyplot
import numpy as np

from utils import config
from utils.data_generator import DataGenerator
from utils.models import get_discriminator_model, get_gan_model, get_generator_model

Using TensorFlow backend.


## Data Generators

In [0]:
# Batch source for the training loop, built from the training directories and
# batch size declared in `utils.config`.
training_generator = DataGenerator(
    config.TRAINING_SOURCE_DIR,
    config.TRAINING_TARGET_DIR,
    config.TRAINING_BATCH_SIZE,
    is_training=True,
)

# Batch source for the periodic validation plots, built from the validation
# directories and batch size declared in `utils.config`.
validation_generator = DataGenerator(
    config.VALIDATION_SOURCE_DIR,
    config.VALIDATION_TARGET_DIR,
    config.VALIDATION_BATCH_SIZE,
    is_training=False,
)

In [0]:
def save_validation_results(gen_model, d_model, validation_generator, epoch_num, output_dir, save_models):
    """Write one comparison figure per validation batch and optionally save both models.

    Each figure is a 3-row grid with one column per example in the batch:
    row 1 shows the source images, row 2 the generator's predictions and
    row 3 the real targets. Figures are written to ``<output_dir>/images`` as
    ``plot_<epoch>_<batch index>.png``; models are written to
    ``<output_dir>/models``. Both directories are created if missing.

    Args:
        gen_model: Generator; must expose ``predict(batch)`` and ``save(path)``.
        d_model: Discriminator; only ``save(path)`` is used here.
        validation_generator: Iterable yielding 4-tuples whose first two items
            are the source batch and the real target batch (the remaining two
            items are ignored).
        epoch_num: Epoch number, zero-padded to five digits in file names.
        output_dir: Root directory for the ``images`` and ``models`` folders.
        save_models: When truthy, both models are saved as ``.h5`` files.
    """
    images_dir = os.path.join(output_dir, 'images')
    models_dir = os.path.join(output_dir, 'models')
    for directory in (images_dir, models_dir):
        os.makedirs(directory, exist_ok=True)

    for batch_idx, batch in enumerate(validation_generator):
        sources, real_targets, _, _ = batch
        fake_targets = gen_model.predict(sources)
        n_cols = len(sources)

        # Pixel values are assumed to lie in [-1, 1]; shift them into [0, 1] for imshow.
        sources, real_targets, fake_targets = [
            (images + 1) / 2.0 for images in (sources, real_targets, fake_targets)
        ]

        # Grid rows, top to bottom: source, generated target, real target.
        for row, images in enumerate((sources, fake_targets, real_targets)):
            for col in range(n_cols):
                pyplot.subplot(3, n_cols, 1 + row * n_cols + col)
                pyplot.axis('off')
                pyplot.imshow(images[col])

        # Persist the figure, then close it so figures do not pile up across batches.
        pyplot.savefig(os.path.join(images_dir, f'plot_{epoch_num:05d}_{batch_idx}.png'), dpi=400)
        pyplot.close()

    if not save_models:
        return

    gen_model.save(os.path.join(models_dir, f'{epoch_num:05d}_gen_model.h5'))
    d_model.save(os.path.join(models_dir, f'{epoch_num:05d}_d_model.h5'))

## Train GAN Model

In [0]:
def train(gen_model, d_model, gan_model, training_generator, validation_generator=None,
          epochs=100, initial_epoch=0, ck_pt_freq=10, output_dir='output', save_models=True):
    """Alternate discriminator and generator updates over the training batches.

    For every batch the generator first produces fake targets, the
    discriminator is updated on the real pair and then on the fake pair, and
    finally the combined GAN model is updated with the "real" labels as the
    adversarial target and the real images as the reconstruction target.

    Every ``ck_pt_freq`` epochs (counted from 1) the last generator loss of the
    epoch is printed and ``save_validation_results`` is called. Note that this
    checkpoint step, including model saving, only runs when a
    ``validation_generator`` is supplied.

    Args:
        gen_model: Generator; must expose ``predict``.
        d_model: Discriminator; must expose ``train_on_batch``.
        gan_model: Combined model; ``train_on_batch`` must return three values,
            the first of which is reported as the generator loss.
        training_generator: Iterable of 4-tuples
            ``(imgs_source, imgs_target_real, d_labels_real, d_labels_fake)``
            that also provides ``on_epoch_end()``.
        validation_generator: Optional batches for the periodic validation plots.
        epochs: Exclusive upper bound of the epoch counter.
        initial_epoch: First epoch index (use the checkpoint epoch when resuming).
        ck_pt_freq: Checkpoint/report period in epochs.
        output_dir: Root directory passed on to ``save_validation_results``.
        save_models: Passed on to ``save_validation_results``.
    """
    for epoch_num in range(initial_epoch, epochs):
        # Reset per epoch: if the generator yields no batches the progress line
        # below reports 'nan' instead of raising UnboundLocalError (first epoch)
        # or silently repeating the previous epoch's loss.
        g_loss = float('nan')

        for imgs_source, imgs_target_real, d_labels_real, d_labels_fake in training_generator:
            imgs_target_fake = gen_model.predict(imgs_source)

            # update discriminator: real pair first, then generated pair
            d_model.train_on_batch([imgs_source, imgs_target_real], d_labels_real)
            d_model.train_on_batch([imgs_source, imgs_target_fake], d_labels_fake)

            # update generator through the combined model
            g_loss, _, _ = gan_model.train_on_batch(imgs_source, [d_labels_real, imgs_target_real])

        if validation_generator is not None and (epoch_num+1) % ck_pt_freq == 0:
            print(f'epoch {epoch_num+1}, g_loss: {g_loss:.2f}')
            save_validation_results(gen_model, d_model, validation_generator,
                                    epoch_num+1, output_dir, save_models)

        # The loop above iterates the generator directly, so the end-of-epoch
        # hook has to be triggered by hand.
        training_generator.on_epoch_end()

In [7]:
# True resumes from the saved checkpoint below; False builds untrained networks.
LOAD_FROM_CK_PT = True

if not LOAD_FROM_CK_PT:
    gen_model = get_generator_model()
    d_model = get_discriminator_model()
else:
    # Zero-padded epoch number of the checkpoint files to restore.
    num = '00210'
    gen_model = load_model(f'output_b4_pts250/models/{num}_gen_model.h5')
    d_model = load_model(f'output_b4_pts250/models/{num}_d_model.h5')

# Combined model used for the generator update.
gan_model = get_gan_model(gen_model, d_model, L1_loss_lambda=100)







Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [8]:
# Resume at epoch 210 (the checkpoint loaded above is '00210') and keep writing
# plots and models into the same output directory every 10 epochs.
train(
    gen_model,
    d_model,
    gan_model,
    training_generator,
    validation_generator=validation_generator,
    epochs=1000,
    initial_epoch=210,
    ck_pt_freq=10,
    output_dir='output_b4_pts250',
    save_models=True,
)

epoch 220, g_loss: 12.85
epoch 230, g_loss: 11.91
epoch 240, g_loss: 17.53
epoch 250, g_loss: 14.21
epoch 260, g_loss: 11.38
epoch 270, g_loss: 10.91
epoch 280, g_loss: 13.50
epoch 290, g_loss: 10.82
epoch 300, g_loss: 14.38
epoch 310, g_loss: 11.90
epoch 320, g_loss: 15.90


KeyboardInterrupt: ignored

In [0]:
# WARNING: sends SIGKILL to every process this user is allowed to signal.
# Presumably used to force-reset the Colab runtime; all in-memory state
# (models, generators) is lost. Do not run this outside a disposable VM.
!kill -9 -1

In [0]:
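# Manual one-off validation dump (arguments updated to the current signature, which also takes d_model and save_models):
# save_validation_results(gen_model, d_model, validation_generator, 99, output_dir='output_lambda_100_dot', save_models=False)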