#### The next two cells are only needed when running in Google Colab: they mount Google Drive and change into the project directory.

In [None]:
# Mount Google Drive at /content/drive so project files stored on Drive are
# reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Switch the working directory to the project folder on the mounted Drive so the
# relative paths used below (the `utils` package, the dataset download into '.',
# and the output directories) resolve inside the project.
%cd '/content/drive/My Drive/CodingProjects/pix2pix'

In [None]:
import os
import warnings
warnings.filterwarnings('ignore')

from keras.models import load_model
from keras.preprocessing.image import load_img
import numpy as np

from utils import config, training_utils
from utils.data_generator import DataGenerator
from utils.models import get_discriminator_model, get_gan_model, get_generator_model

#### Download Dataset

In [None]:
from keras.utils import get_file

# Fetch the preprocessed dataset archive into the current directory
# (cache_dir='.' with an empty cache_subdir) and ask Keras to extract it.
# Left as the cell's final expression so the notebook shows the returned path.
get_file(
    fname='data.zip',
    origin=config.PREPROCESSED_DATASET_URL,
    extract=True,
    cache_dir='.',
    cache_subdir='',
)

## Data Generators

In [None]:
# One generator per split. Each takes (source dir, target dir, batch size) from
# `config`; `is_training` toggles DataGenerator's training-vs-evaluation mode
# (exact effect lives in utils.data_generator — not visible here).
training_generator = DataGenerator(
    config.TRAINING_SOURCE_DIR,
    config.TRAINING_TARGET_DIR,
    config.TRAINING_BATCH_SIZE,
    is_training=True,
)

validation_generator = DataGenerator(
    config.VALIDATION_SOURCE_DIR,
    config.VALIDATION_TARGET_DIR,
    config.VALIDATION_BATCH_SIZE,
    is_training=False,
)

# Train GAN Model

In [None]:
def train(gen_model, d_model, gan_model, training_generator, validation_generator=None,
          epochs=100, initial_epoch=0, ck_pt_freq=5, output_dir='output', save_models=True):
    """Train a conditional GAN with alternating discriminator/generator updates.

    For every batch yielded by ``training_generator`` the discriminator is
    updated once on real (source, target) pairs and once on (source, generated)
    pairs, and then the generator is updated through ``gan_model``.

    Args:
        gen_model: Generator; ``predict(imgs_source)`` returns the fake targets.
        d_model: Discriminator, trained on ``[imgs_source, imgs_target]`` inputs.
        gan_model: Combined model. Its ``train_on_batch`` takes the source images
            and ``[d_labels_real, imgs_target_real]``, and returns a 3-element
            result whose first entry is reported as the generator loss.
        training_generator: Iterable of
            ``(imgs_source, imgs_target_real, d_labels_real, d_labels_fake)``
            batches. It must also expose ``on_epoch_end()``, which is called
            after each epoch.
        validation_generator: Forwarded to ``training_utils.save_results`` at
            each checkpoint. When ``None``, nothing is printed or saved.
        epochs: End epoch index (exclusive).
        initial_epoch: First epoch index, e.g. when resuming training.
        ck_pt_freq: Checkpoint every ``ck_pt_freq`` completed epochs. The final
            epoch is always checkpointed, so the last weights are never lost.
            A value <= 0 disables the periodic checkpoints, leaving only the
            final one.
        output_dir: Forwarded to ``training_utils.save_results``.
        save_models: Forwarded to ``training_utils.save_results``.
    """
    for epoch_num in range(initial_epoch, epochs):
        # Reset every epoch. If an epoch yields no batches, we must not hit an
        # unbound local on the first epoch or report a stale loss later.
        g_loss = float('nan')

        for imgs_source, imgs_target_real, d_labels_real, d_labels_fake in training_generator:
            imgs_target_fake = gen_model.predict(imgs_source)

            # update discriminator: once on real pairs, once on generated pairs
            d_model.train_on_batch([imgs_source, imgs_target_real], d_labels_real)
            d_model.train_on_batch([imgs_source, imgs_target_fake], d_labels_fake)

            # Update the generator. The GAN targets are the *real* discriminator
            # labels (so G learns to fool D) plus the real target images.
            g_loss, _, _ = gan_model.train_on_batch(imgs_source, [d_labels_real, imgs_target_real])

        completed = epoch_num + 1
        # Guard the modulo so ck_pt_freq == 0 cannot raise ZeroDivisionError.
        is_periodic = ck_pt_freq > 0 and completed % ck_pt_freq == 0
        # Also checkpoint the last epoch when epochs is not a multiple of ck_pt_freq.
        is_last = completed == epochs
        if validation_generator is not None and (is_periodic or is_last):
            print(f'epoch {completed}, g_loss: {g_loss:.2f}')
            training_utils.save_results(gen_model, d_model, validation_generator,
                                        completed, output_dir, save_models)

        training_generator.on_epoch_end()

In [None]:
# Build the generator and the discriminator first (left to right). get_gan_model
# then receives both and returns the combined model that `train` uses for
# generator updates.
gen_model, d_model = get_generator_model(), get_discriminator_model()
gan_model = get_gan_model(gen_model, d_model)

In [None]:
# Train for 50 epochs with ck_pt_freq=1, so results are printed and saved
# (models included) after every epoch into 'output_pts350'.
train(
    gen_model,
    d_model,
    gan_model,
    training_generator,
    validation_generator=validation_generator,
    epochs=50,
    ck_pt_freq=1,
    output_dir='output_pts350',
    save_models=True,
)