In [1]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
%cd '/content/drive/My Drive/CodingProjects/pix2pix'

/content/drive/My Drive/CodingProjects/pix2pix


In [3]:
import warnings
warnings.filterwarnings('ignore')

from keras.preprocessing.image import load_img
from matplotlib import pyplot
import numpy as np

from utils.models import get_discriminator_model, get_gan_model, get_generator_model

Using TensorFlow backend.


In [0]:
PATCH_HEIGHT = 16
PATCH_WIDTH = 16

IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_N_CHANNELS = 3
IMG_SHAPE = IMG_HEIGHT, IMG_WIDTH, IMG_N_CHANNELS

TRAINING_COMPRESSED_DATASET_FILEPATH = 'data/training/training_set.npz'
VALIDATION_COMPRESSED_DATASET_FILEPATH = 'data/validation/validation_set.npz'

In [0]:
def load_dataset(filename):
    data = np.load(filename)
    X_src, X_target_real = data['source'], data['target']
    # scale from [0,255] to [-1,1]
    X_src = (X_src - 127.5) / 127.5
    X_target_real = (X_target_real - 127.5) / 127.5
    
    return X_src, X_target_real

In [0]:
def summarize_performance(epoch_num, g_model, validation_dataset):
    X_src, X_target_real = validation_dataset
    X_target_fake = g_model.predict(X_src)
    n_examples = len(X_src)
    
    # scale all pixels from [-1,1] to [0,1]
    X_src = (X_src + 1) / 2.0
    X_target_real = (X_target_real + 1) / 2.0
    X_target_fake = (X_target_fake + 1) / 2.0
    
    # plot source images
    for i in range(n_examples):
        pyplot.subplot(3, n_examples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_src[i])
        
    # plot generated target image
    for i in range(n_examples):
        pyplot.subplot(3, n_examples, 1 + n_examples + i)
        pyplot.axis('off')
        pyplot.imshow(X_target_fake[i])
        
    # plot real target image
    for i in range(n_examples):
        pyplot.subplot(3, n_examples, 1 + n_examples*2 + i)
        pyplot.axis('off')
        pyplot.imshow(X_target_real[i])
        
    # save plot to file
    filename1 = 'output/plot_%06d.png' % (epoch_num+1)
    pyplot.savefig(filename1, dpi=300)
    pyplot.close()
    
    # save the generator model
    filename2 = 'output/model_%06d.h5' % (epoch_num+1)
    #g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))
    
def shuffle_arrays_unison(arr1, arr2):
    assert(len(arr1) == len(arr2), 'Arrays have different lengths.')
    
    random_idxs = np.arange(len(arr1))
    np.random.shuffle(random_idxs)
    
    return arr1[random_idxs], arr2[random_idxs]

In [0]:
def train(d_model, g_model, gan_model, training_dataset, validation_dataset=None, 
          n_epochs=100, ck_pt_freq=10, batch_size=1):
    src_imgs, target_imgs = training_dataset
    n_examples = len(src_imgs)
    
    y_real = np.ones((batch_size, PATCH_HEIGHT, PATCH_WIDTH, 1))
    y_fake = np.zeros((batch_size, PATCH_HEIGHT, PATCH_WIDTH, 1))
    
    for epoch_num in range(n_epochs):
        src_imgs, target_imgs = shuffle_arrays_unison(src_imgs, target_imgs)
        
        for idx in range(n_examples):
            idxs = [idx] # batch of size 1
            X_src = src_imgs[idxs]
            X_target_real = target_imgs[idxs]
            X_target_fake = g_model.predict(X_src)

            # update discriminator
            d_loss_real = d_model.train_on_batch([X_src, X_target_real], y_real)
            d_loss_fake = d_model.train_on_batch([X_src, X_target_fake], y_fake)

            # update generator
            g_loss, _, _ = gan_model.train_on_batch(X_src, [y_real, X_target_real])
        
        if validation_dataset is not None and (epoch_num+1) % ck_pt_freq == 0:
            # summarize model performance
            print('>%d, d_loss_real[%.3f] d_loss_fake[%.3f] g[%.3f]' % (epoch_num+1, d_loss_real, d_loss_fake, g_loss))
            summarize_performance(epoch_num, g_model, validation_dataset)

In [0]:
training_dataset = load_dataset(TRAINING_COMPRESSED_DATASET_FILEPATH)
validation_dataset = load_dataset(VALIDATION_COMPRESSED_DATASET_FILEPATH)

In [9]:
d_model = get_discriminator_model(IMG_SHAPE)
g_model = get_generator_model(IMG_SHAPE)
gan_model = get_gan_model(g_model, d_model, IMG_SHAPE)

W0827 03:17:39.736654 140230570268544 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0827 03:17:39.761650 140230570268544 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0827 03:17:39.769209 140230570268544 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:4115: The name tf.random_normal is deprecated. Please use tf.random.normal instead.

W0827 03:17:39.810423 140230570268544 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

W0827 03:17:39.811682 140230570268

In [11]:
train(d_model, g_model, gan_model, training_dataset, validation_dataset, n_epochs=100, ck_pt_freq=10)

>9, d_loss_real[0.145] d_loss_fake[0.168] g[4.369]
>Saved: output/plot_000010.png and output/model_000010.h5
>19, d_loss_real[0.296] d_loss_fake[0.117] g[4.804]
>Saved: output/plot_000020.png and output/model_000020.h5
>29, d_loss_real[0.260] d_loss_fake[0.223] g[5.811]
>Saved: output/plot_000030.png and output/model_000030.h5
>39, d_loss_real[0.265] d_loss_fake[0.331] g[4.699]
>Saved: output/plot_000040.png and output/model_000040.h5
>49, d_loss_real[0.048] d_loss_fake[0.246] g[7.198]
>Saved: output/plot_000050.png and output/model_000050.h5


In [0]:
#!kill -9 -1