#### The next two cells are only needed in Google Colab: they mount Google Drive and change into the project directory on it.

In [1]:
# Mount Google Drive only when running inside Colab. Elsewhere the
# google.colab package is missing, so the mount is skipped instead of the
# cell raising ImportError and breaking "Restart & Run All".
try:
    from google.colab import drive
except ImportError:  # not running on Colab
    drive = None

if drive is not None:
    drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
# Colab-only: change into the project folder on the mounted Drive so the
# relative paths used below (the `utils` package, the `output/` directory)
# presumably resolve from there. Skipped when the folder does not exist,
# e.g. when running locally from the project root.
import os

COLAB_PROJECT_DIR = '/content/drive/My Drive/CodingProjects/pix2pix'
if os.path.isdir(COLAB_PROJECT_DIR):
    os.chdir(COLAB_PROJECT_DIR)
print(os.getcwd())

/content/drive/My Drive/CodingProjects/pix2pix


In [3]:
import os
import warnings
warnings.filterwarnings('ignore')

from keras.preprocessing.image import load_img
from matplotlib import pyplot
import numpy as np

from utils import config
from utils.data_generator import DataGenerator
from utils.models import get_discriminator_model, get_gan_model, get_generator_model

Using TensorFlow backend.


## Data Generators

In [0]:
# Positional arguments (per the config names): source image dir,
# target image dir, batch size.
training_generator = DataGenerator(
    config.TRAINING_SOURCE_DIR,
    config.TRAINING_TARGET_DIR,
    config.TRAINING_BATCH_SIZE,
)

# shuffle=False presumably keeps validation batches in a fixed order, so the
# plots saved at each checkpoint show the same images (confirm in DataGenerator).
validation_generator = DataGenerator(
    config.VALIDATION_SOURCE_DIR,
    config.VALIDATION_TARGET_DIR,
    config.VALIDATION_BATCH_SIZE,
    shuffle=False,
)

In [0]:
def save_validation_results(g_model, validation_generator, epoch_num, output_dir='output', save_model=False):
    """Plot source / generated / real images for each validation batch and save them.

    For every batch, one PNG named ``plot_{epoch_num:05d}_{batch_idx}.png`` is
    written to ``output_dir``. Each PNG is a 3-row grid: source images (top),
    generator output (middle) and real target images (bottom).

    Args:
        g_model: generator exposing ``predict(imgs_source)``.
        validation_generator: iterable of ``(imgs_source, imgs_target_real)``
            batches. Pixel values are assumed to lie in [-1, 1]; they are
            mapped to [0, 1] for display.
        epoch_num: epoch number embedded in the output filenames.
        output_dir: directory for the plots (and the model file). It is
            created if missing.
        save_model: if True, also save ``g_model`` to
            ``{output_dir}/model_{epoch_num:05d}.h5``. Defaults to False,
            matching the previous behavior where the save call was disabled.
    """
    os.makedirs(output_dir, exist_ok=True)

    for batch_idx, (imgs_source, imgs_target_real) in enumerate(validation_generator):
        imgs_target_fake = g_model.predict(imgs_source)
        n_examples = len(imgs_source)
        # Grid rows, top to bottom: source, generated target, real target.
        rows = (imgs_source, imgs_target_fake, imgs_target_real)

        # Use a dedicated figure rather than pyplot's implicit current figure,
        # so an unrelated open figure is never drawn into or closed.
        fig = pyplot.figure()
        try:
            for row_idx, imgs in enumerate(rows):
                # scale all pixels from [-1,1] to [0,1]
                imgs = (imgs + 1) / 2.0
                for i in range(n_examples):
                    ax = fig.add_subplot(3, n_examples, 1 + row_idx * n_examples + i)
                    ax.axis('off')
                    ax.imshow(imgs[i])

            img_output_filename = os.path.join(output_dir, f'plot_{epoch_num:05d}_{batch_idx}.png')
            fig.savefig(img_output_filename, dpi=300)
        finally:
            # Always release the figure; otherwise figures pile up across
            # batches and epochs.
            pyplot.close(fig)

    if save_model:
        model_output_filename = os.path.join(output_dir, f'model_{epoch_num:05d}.h5')
        g_model.save(model_output_filename)

## Train GAN Model

In [0]:
def train(d_model, g_model, gan_model, training_generator, validation_generator=None, n_epochs=100, ck_pt_freq=10):
    """Train the conditional GAN by alternating discriminator and generator updates.

    For every training batch:
      1. the discriminator is updated on real ``(source, target)`` pairs
         labelled ``y_real``, then on ``(source, generated)`` pairs labelled
         ``y_fake``;
      2. the generator is updated through ``gan_model`` with targets
         ``[y_real, imgs_target_real]``.

    Args:
        d_model: discriminator with ``train_on_batch([source, target], labels)``.
        g_model: generator with ``predict(source)``.
        gan_model: composite model with ``train_on_batch(source, [labels, target])``
            returning a 3-item sequence whose first item is the total loss.
        training_generator: iterable of ``(imgs_source, imgs_target_real)``
            batches that also provides ``get_labels_real()``,
            ``get_labels_fake()`` and ``on_epoch_end()``.
        validation_generator: optional. When given, ``save_validation_results``
            is called at every checkpoint epoch.
        n_epochs: number of passes over ``training_generator``.
        ck_pt_freq: checkpoint period in epochs. At each checkpoint the losses
            of the last training batch are printed. A value <= 0 disables
            checkpoints.
    """
    y_real = training_generator.get_labels_real()
    y_fake = training_generator.get_labels_fake()

    # Bound up front so the checkpoint report cannot raise NameError when the
    # generator yields no batches.
    d_loss_real = d_loss_fake = g_loss = float('nan')

    for epoch_num in range(1, n_epochs + 1):
        for imgs_source, imgs_target_real in training_generator:
            # The label arrays are fetched once. Trimming them to the current
            # batch length keeps them valid if a final batch is smaller
            # (a no-op for full batches).
            n_batch = len(imgs_source)
            batch_y_real = y_real[:n_batch]
            batch_y_fake = y_fake[:n_batch]

            imgs_target_fake = g_model.predict(imgs_source)

            # update discriminator: real pairs, then generated pairs
            d_loss_real = d_model.train_on_batch([imgs_source, imgs_target_real], batch_y_real)
            d_loss_fake = d_model.train_on_batch([imgs_source, imgs_target_fake], batch_y_fake)

            # update generator via the composite model
            g_loss, _, _ = gan_model.train_on_batch(imgs_source, [batch_y_real, imgs_target_real])

        if ck_pt_freq > 0 and epoch_num % ck_pt_freq == 0:
            loss_output_str = f'epoch: {epoch_num}, d_loss_real: {d_loss_real:.2f}'
            loss_output_str += f', d_loss_fake: {d_loss_fake:.2f}, g: {g_loss:.2f}'
            print(loss_output_str)
            if validation_generator is not None:
                save_validation_results(g_model, validation_generator, epoch_num)

        training_generator.on_epoch_end()

In [0]:
img_shape = config.IMG_SHAPE

# discriminator: trained in train() on (source, target) image pairs
d_model = get_discriminator_model(img_shape)
# generator: predict(source) yields the fake target batch
g_model = get_generator_model(img_shape)
# composite model built from the two above; train() uses it for the generator update
gan_model = get_gan_model(g_model, d_model, img_shape)

In [0]:
# 100 epochs, with a checkpoint (loss report + validation plots) every 10 epochs.
train(
    d_model,
    g_model,
    gan_model,
    training_generator,
    validation_generator=validation_generator,
    n_epochs=100,
    ck_pt_freq=10,
)

In [0]:
# Optional, Colab-only: uncommenting the line below sends SIGKILL to every
# process the user can signal. Presumably meant to force a runtime reset
# (e.g. to free GPU memory). It ends the session, so confirm before use.
#!kill -9 -1