In [None]:
# Setup: third-party imports, backend selection and project helper modules.
import numpy as np
import sys  # NOTE(review): not referenced in any visible cell
from multiprocessing import Queue, Process  # Queue is used for data loading; Process is not referenced in any visible cell
import matplotlib
# Select the non-interactive 'Agg' backend so this notebook can also be run
# as a plain script without a display.
matplotlib.use('Agg')
from IPython import display as jupyter  # NOTE(review): not referenced in any visible cell
# pyplot is imported in a later cell, after the backend has been chosen here.

# Load global constants. Names such as IMAGE_SHAPE_1, IMAGE_SHAPE_3 and
# BATCH_SHAPE_1 are used by later cells without being defined in this notebook.
# NOTE(review): the star import hides where those names come from; consider
# importing them explicitly.
from definitions_plot import *

# Both helper modules are initialised with the name of the constants module
# imported above. Presumably init() makes them read their constants from it —
# TODO confirm against data_utils_plot / model_utils.
import data_utils_plot
data_utils_plot.init("definitions_plot")

import model_utils
model_utils.init("definitions_plot")

In [None]:
from skimage import color, transform
import matplotlib.pyplot as plt
import imageio

In [None]:
# Build a fresh set of compiled models, then restore the generator weights
# saved after adversarial (WGAN) training into the second returned model,
# exposed here as `adverserial_model`.
# NOTE(review): absolute local path; adjust for the machine running this.
(pre_generator,
 adverserial_model,
 generator_trainer,
 critic,
 critic_trainer) = model_utils.generate_and_compile_models()
adverserial_model.load_weights("d:/downloads/models/wgan_c/wildV7/2/generator.model")

In [None]:
# Build a second, independent set of compiled models and restore the weights
# from direct training with cross-entropy loss into the first returned model,
# exposed here as `direct_model`.
# NOTE(review): this unpacking rebinds generator_trainer, critic and
# critic_trainer, replacing the objects created by the adversarial-loading cell.
# NOTE(review): absolute local path; adjust for the machine running this.
(direct_model,
 generator,
 generator_trainer,
 critic,
 critic_trainer) = model_utils.generate_and_compile_models()
direct_model.load_weights("d:/downloads/models/directV3/fields/pre_generator.model")

In [None]:
# NOTE(review): this switches away from the 'Agg' backend selected in the first
# cell. No visible cell draws with pyplot (images are written with imageio), so
# this line looks unnecessary. It is also an IPython magic, not plain Python,
# so it needs an IPython kernel to run.
%matplotlib inline

In [None]:
# Load data. The queue is bounded, so at most one item can wait in it at a time.
# The inference cell pulls one image batch per `get()`.
full_images_queue = Queue(maxsize=1)
# The project helper fills the queue. `tasks` presumably holds handles to the
# background workers it starts; they are never joined or terminated in the
# visible cells — TODO confirm.
tasks = data_utils_plot.populate_trainer_queues(full_images_queue)

In [None]:
def get_most_likely_image(L_batch, direct_batch):
    """Decode the direct (cross-entropy) generator's output into RGB images.

    The pre_generator is trained with cross entropy and outputs, for every
    pixel, one score per quantisation bin of the a and b channels. This picks
    the highest-scoring bin per pixel, decodes it back to a channel value with
    ``data_utils_plot.decode_bin``, and converts the resulting Lab image to RGB.

    Args:
        L_batch: lightness channel of the batch, indexed per sample along
            axis 0. Values are scaled so that multiplying by 100 gives Lab units.
        direct_batch: pair ``(a, b)`` of per-bin predictions. Each is indexed
            per sample along axis 0, with the bins on the last axis.

    Returns:
        np.ndarray stacking one RGB image (as returned by ``color.lab2rgb``)
        per sample of ``L_batch``.
    """
    a_scores, b_scores = direct_batch
    rgb_images = []
    for i in range(L_batch.shape[0]):
        # argmax over the bin axis selects the most likely bin for each pixel
        a_gen = data_utils_plot.decode_bin(np.argmax(a_scores[i], axis=-1)).reshape(IMAGE_SHAPE_1)
        b_gen = data_utils_plot.decode_bin(np.argmax(b_scores[i], axis=-1)).reshape(IMAGE_SHAPE_1)
        lab = np.concatenate([L_batch[i], a_gen, b_gen], axis=-1)
        # lab2rgb works in Lab units; the model's values are divided by 100
        rgb_images.append(color.lab2rgb((lab * 100).astype(np.float64)))
    return np.stack(rgb_images)

In [None]:
# Run index appended to every file written by the export cell
# (./demo/<kind>_<sample>_<counter>.png). Change it between runs so that
# earlier exports with the same sample index are not overwritten.
counter = 8

In [None]:
# Inference: take one batch from the queue and feed only its first channel
# (the lightness, L) to both generators.
img_batch = full_images_queue.get()
L_batch = img_batch[:, :, :, 0].reshape(BATCH_SHAPE_1)

adverserial_images = adverserial_model.predict(L_batch)       # Lab prediction, values divided by 100
direct_batch = direct_model.predict(L_batch)                  # (a, b) per-bin predictions
direct_images = get_most_likely_image(L_batch, direct_batch)  # decoded to RGB

In [None]:
# Export comparison images for the first few samples of the batch. For each
# sample k, four PNGs are written to ./demo:
#   grey        - the lightness input the models see
#   adverserial - output of the generator after WGAN training
#   direct      - output of the generator after cross-entropy pre-training
#   raw         - the original colour image
import os

N_DEMO_IMAGES = 6  # number of samples of the batch to export
DEMO_DIR = "./demo"


def _to_uint8(rgb):
    """Convert a float RGB image in [0, 1] to uint8, rounding to the nearest level.

    Plain ``(rgb * 255).astype(np.uint8)`` truncates (e.g. 0.999 -> 254).
    Clipping guards against tiny excursions outside [0, 1].
    """
    return np.round(np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)


def _save_demo_image(kind, index, rgb):
    """Write a float RGB image as ./demo/<kind>_<index>_<counter>.png."""
    name = "%s/%s_%d_%d.png" % (DEMO_DIR, kind, index, counter)
    imageio.imwrite(name, _to_uint8(rgb))


# imageio does not create missing directories, so make sure ./demo exists
os.makedirs(DEMO_DIR, exist_ok=True)

# never index past the end of a batch smaller than N_DEMO_IMAGES
for k in range(min(N_DEMO_IMAGES, L_batch.shape[0])):
    L_k = L_batch[k]

    # Grey input: the L channel with zero chroma, rendered through lab2rgb like
    # the other three images. L/100 * 255 is not the sRGB value of lightness L,
    # so the old rendering was not comparable with the colour outputs.
    grey_lab = np.concatenate([L_k, np.zeros_like(L_k), np.zeros_like(L_k)], axis=-1)
    _save_demo_image("grey", k, color.lab2rgb(grey_lab.astype(np.float64) * 100))

    # generator after WGAN training: predicts Lab directly (values divided by 100)
    _save_demo_image("adverserial", k, color.lab2rgb(adverserial_images[k] * 100))

    # generator after pre-training: already decoded to RGB by get_most_likely_image
    _save_demo_image("direct", k, direct_images[k])

    # original colour image
    _save_demo_image("raw", k, color.lab2rgb(img_batch[k] * 100))