In [None]:
import os
import sys
from multiprocessing import Queue, Process

import matplotlib
matplotlib.use('Agg') #for using this notebook as script
import numpy as np
from IPython import display as jupyter
#import matplotlib.pyplot as plt

# load global constants
# NOTE(review): this star import is presumably where names used further down
# (BATCH_SIZE, BATCH_SHAPE_1, THRESHOLD_A, THRESHOLD_B, GENERATOR_ITERATIONS,
# CRITIC_ITERATIONS_INIT/MIN/MAX) come from -- confirm against definitions_main.
from definitions_main import *

# Project helper modules. Each is initialised with the name of the definitions
# module (presumably so that it reads the same constants -- verify in the
# respective init() implementation).
import data_utils
data_utils.init("definitions_main")

import model_utils
model_utils.init("definitions_main")

In [None]:
# Build and compile every model this notebook works with:
#   pre_generator     - compiled generator for traditional cross-entropy training
#   generator_trainer - compiled generator combined with a non-trainable critic
#   critic_trainer    - compiled critic combined with a non-trainable generator
# NOTE(review): `generator` and `critic` are presumably the bare networks that
# the trainers above wrap -- confirm in model_utils.
(
    pre_generator,
    generator,
    generator_trainer,
    critic,
    critic_trainer,
) = model_utils.generate_and_compile_models()

In [None]:
# Load previous weights if they exist; otherwise keep the freshly built weights.
def _load_weights_if_present(model, weights_path):
    """Load ``weights_path`` into ``model`` only when that file exists.

    Returns True when weights were loaded and False when the file is missing,
    so a first run without any checkpoint no longer fails on load_weights().
    """
    if not os.path.exists(weights_path):
        print("No saved weights found at %s, keeping current weights." % weights_path)
        return False
    model.load_weights(weights_path)
    return True


_load_weights_if_present(generator, "generator.model")
# NOTE(review): the critic is restored from an absolute, machine-specific path,
# while the training loop of this notebook saves it as "critic.model" in the
# working directory -- confirm which checkpoint is the intended one.
_load_weights_if_present(critic, "d:/downloads/models/wgan_c/cultivated/3/critic.model")

In [None]:
# Bounded queues the training loop pulls its batches from. Filling them is
# delegated to data_utils (launches about 6 additional processes).
full_images_queue = Queue(20)
only_L_ch_queue = Queue(5)
tasks = data_utils.populate_trainer_queues(full_images_queue, only_L_ch_queue)

In [None]:
# calculates probability for a sample to be positive, based on Cantelli's inequality and a list of i.i.d. samples
def probability_for_positive(threshold, samples):
    """Decide whether the mean of ``samples`` is positive with enough confidence.

    The first third of the samples (rounded up) is discarded and the rest is
    treated as i.i.d. draws. With ``mu`` the mean of the kept samples and
    ``var`` their variance divided by the number of kept samples, the value
    ``p = 1 - var / (mu**2 + var)`` (the Cantelli-style bound this helper is
    based on) is compared against ``threshold``.

    Args:
        threshold: value that ``p`` has to exceed.
        samples: sequence of numeric samples (list or 1-D array); may be empty.

    Returns:
        bool: True only when ``p > threshold`` and ``mu > 0``. False when no
        sample is left after the cut (zero or one sample) or when ``mu`` and
        ``var`` are both zero, i.e. when there is no evidence either way.
    """
    values = np.asarray(samples)
    cut = int(np.ceil(len(values) / 3))
    kept = values[cut:]
    if kept.size == 0:
        # Nothing survives the warm-up cut; avoid mean/var of an empty slice.
        return False

    mu = np.mean(kept)
    # Variance of the sample mean: population variance over the sample count.
    var = np.var(kept) / kept.size

    denominator = mu**2 + var
    if denominator == 0:
        # All kept samples are exactly zero; 0/0 would yield NaN.
        return False

    p = 1 - var / denominator
    return bool(p > threshold and mu > 0)

In [None]:
# Target arrays handed to train_on_batch. Keras wants one target row per sample
# for every model output, even when the loss does not make use of it.
dummy_y = np.zeros(shape=(BATCH_SIZE, 1), dtype=np.float32)
positive_y = np.ones(shape=(BATCH_SIZE, 1), dtype=np.float32)
negative_y = np.full(shape=(BATCH_SIZE, 1), fill_value=-1.0, dtype=np.float32)

# Loss histories collected during training.
generator_loss = []
critic_loss = []

# Per outer loop the generator gets GENERATOR_ITERATIONS training steps, while
# the critic gets training_ratio steps. The ratio is adapted dynamically to the
# critic's performance: it is decreased when the critic performs well and
# increased when it performs poorly.
training_ratio = CRITIC_ITERATIONS_INIT

# Running totals of training steps.
critic_iterations = 0
generator_iterations = 0
# Main adversarial training loop: each round trains the critic for
# training_ratio batches, then (conditionally) the generator, then adapts
# training_ratio and periodically writes a checkpoint.
total_outer_loops = 100000    # number of critic/generator rounds
checkpoint_interval = 25      # rounds between two checkpoints

for outer_loop in range(total_outer_loops):
    jupyter.clear_output(wait=True)
    print("Queue status: full images %d, L images %d" %
          (full_images_queue.qsize(), only_L_ch_queue.qsize()))
    print("Total iterations. Critic: %d, Generator: %d, Outer loop: %d" % (critic_iterations, generator_iterations, outer_loop))
    print("Training ratio Critic-Generator %d" % training_ratio)

    # --- critic training: training_ratio minibatches per round ---
    w_distances = []
    for _ in range(training_ratio):
        critic_iterations += 1
        real_image_batch = full_images_queue.get()
        # Channel 0 of the real images is fed to the critic trainer as the L input.
        L_batch = real_image_batch[:, :, :, 0].reshape(BATCH_SHAPE_1)
        # Targets: -1 for the real-image output, +1 for the fake-image and
        # gradient-penalty outputs.
        # NOTE(review): the previous comment said the critic optimises for a low
        # real score and a high fake score; that looks inverted given that
        # score_gap = real - fake is expected to be positive. Confirm against the
        # loss definitions in model_utils.
        loss = critic_trainer.train_on_batch([real_image_batch, L_batch], [negative_y, positive_y, positive_y])
        critic_loss.append(loss)

        # loss[1..3] are read as the real, fake and gradient-penalty terms;
        # loss[0] is presumably the combined loss (Keras convention -- verify).
        critic_score_real_images = -loss[1]
        critic_score_fake_images = loss[2]
        gp_loss = loss[3]
        score_gap = critic_score_real_images - critic_score_fake_images
        w_distances.append(score_gap)
        print("Wasserstein distance (aestimate) = %f, Score real image= %f, Gradient penalty = %f" % (score_gap, critic_score_real_images, gp_loss))

    # --- generator training ---
    # Only train the generator when the Wasserstein distance estimate can be
    # expected to be positive.
    if probability_for_positive(THRESHOLD_A, w_distances):
        for _ in range(GENERATOR_ITERATIONS):
            generator_iterations += 1
            L_batch = only_L_ch_queue.get()
            # optimizes for high score --> tries to make it more real
            loss = generator_trainer.train_on_batch(L_batch, negative_y)
            fake_images_score = -loss
            print("Adverserial training loss %f" % fake_images_score)
            generator_loss.append(loss)
    else:
        print("Skip generator training since crtic does not perform well.")

    # --- training schedule management ---
    # If the critic clears THRESHOLD_B it gets less training next round,
    # otherwise more; the ratio stays within [CRITIC_ITERATIONS_MIN, CRITIC_ITERATIONS_MAX].
    if probability_for_positive(THRESHOLD_B, w_distances):
        # critic performs well >> reduce the training ratio
        training_ratio = max(training_ratio - 1, CRITIC_ITERATIONS_MIN)
    else:
        # critic performs poorly >> increase the training ratio
        training_ratio = min(training_ratio + 1, CRITIC_ITERATIONS_MAX)

    # --- checkpointing ---
    # Also checkpoint on the very last round so the final training state is not
    # lost when the loop ends between two regular checkpoints.
    is_last_round = outer_loop == total_outer_loops - 1
    if outer_loop % checkpoint_interval == 0 or is_last_round:
        # save some sample image results, the models and the accumulated statistics
        data_utils.save_continuous_images(full_images_queue, generator)
        generator.save("generator.model")
        critic.save("critic.model")
        np.savetxt("generator_loss.txt", generator_loss)
        np.savetxt("critic_loss.txt", critic_loss)