In [1]:
from prototypes import Generator, Discriminator
from data.loader import generate_labels, denorm_image, load_data
import tensorflow as tf
import numpy as np
import os
%matplotlib inline


def display_image(*images, col=None, width=20):
    """Show ``images`` in a grid of subplots with axes hidden and no spacing.

    Args:
        *images: anything ``plt.imshow`` accepts; each is drawn with ``cmap='gray'``.
        col: number of grid columns. Defaults to ``len(images)`` (a single row).
        width: figure width in inches; the height is ``(width + 1) * rows / col``.

    Calling with no images is a no-op instead of raising ``ZeroDivisionError``.
    """
    if not images:
        return

    import math
    from matplotlib import pyplot as plt

    if col is None:
        col = len(images)
    # Stdlib math.ceil: NumPy deprecated the np.math alias in 1.25 and removed it in 2.0.
    row = math.ceil(len(images) / col)
    plt.figure(figsize=(width, (width + 1) * row / col))
    for index, image in enumerate(images, start=1):
        plt.subplot(row, col, index)
        plt.axis('off')
        plt.imshow(image, cmap='gray')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
    
# Shared configuration for every later cell.
num_classes = 14   # load_data reports 14 classes for this dataset (see output below)
image_size = 128   # passed to load_data as image_width and to both networks
min_neurons = 64   # forwarded to Discriminator; meaning is defined in prototypes (not shown here)
noise_size = 128   # length of each noise vector fed to the generator when sampling
batch_size = 64
SEED = 0

# Sampling uses tf.random.uniform; seed TF (and NumPy, in case project helpers
# use its global RNG) so candidate selection is reproducible across re-runs.
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Dataset location; set ARTWORKS_DATA_ROOT to point elsewhere.
data_root = os.environ.get(
    'ARTWORKS_DATA_ROOT',
    os.path.join(os.path.expanduser('~'), 'datasets', 'artworks'),
)
train_gen, valid_gen = load_data(data_root, batch_size=batch_size, image_width=image_size, split=.05)

# Class-name <-> index maps, as assigned by the validation iterator.
label2idx = valid_gen.class_indices
idx2label = {idx: name for name, idx in label2idx.items()}

gen = Generator(num_classes, image_size, bn=False)
disc = Discriminator(num_classes, image_size, min_neurons)

Found 34959 images belonging to 14 classes.
Found 1832 images belonging to 14 classes.


In [22]:
# Restore both networks by the same checkpoint name. The cell displays the
# value returned by disc.load (True in the recorded output).
CHECKPOINT = 'GANGogh10000'
gen.load(CHECKPOINT)
disc.load(CHECKPOINT)

True

In [23]:
SAMPLES = 64  # NOTE(review): not referenced by any cell shown here -- confirm before relying on it


class Candidate:
    """A generated image bundled with the discriminator outputs used to rank it.

    Plain record: the constructor stores its arguments unchanged, with no
    copying or validation. In this notebook, ``select_best_images`` fills it with
    a denormalised sample, its critic (realness) score, the conditioning label
    and the highest class score from the classifier head.
    """

    def __init__(self, arr_image, critic_score, label, label_confidence):
        # Only the image attribute is named differently from its parameter.
        self.image, self.critic_score = arr_image, critic_score
        self.label, self.label_confidence = label, label_confidence


def select_best_images(label, num_samples, look_at=1, batch_size=64, shortlist_factor=3):
    """Generate images conditioned on ``label`` and keep the most convincing ones.

    The generator is run ``look_at`` times on ``batch_size`` noise vectors drawn
    with ``tf.random.uniform(minval=-1, maxval=1)``, all conditioned on ``label``.
    Samples are kept only if the discriminator's class head predicts ``label``.
    The survivors are ranked by that class score. The top
    ``num_samples * shortlist_factor`` are then re-ranked by critic score, and
    the best ``num_samples`` are returned.

    Args:
        label: class index to condition on and to require from the classifier head.
        num_samples: maximum number of candidates to return.
        look_at: number of generator batches to draw (default 1, as before).
        batch_size: samples per batch (default 64, as before).
        shortlist_factor: multiple of ``num_samples`` that survives the confidence
            filter before critic ranking (default 3, as before).

    Returns:
        List of ``Candidate`` sorted by critic score, descending. It may be shorter
        than ``num_samples`` (or empty) if few samples are classified as ``label``.
    """
    input_label = generate_labels(batch_size, num_classes, condition=label)
    candidates = []
    for _ in range(look_at):
        noise = tf.random.uniform(shape=[batch_size, noise_size], minval=-1., maxval=1.)
        samples = gen.model.predict([noise, input_label])
        pred_realness, pred_labels = disc.model.predict(samples)
        # One critic score per sample is assumed (e.g. shape (batch, 1)); reshape
        # instead of squeeze so batch_size == 1 still yields an indexable 1-D array.
        pred_realness = np.reshape(pred_realness, -1)
        guess = np.argmax(pred_labels, axis=1)
        confidence = np.amax(pred_labels, axis=1)
        # Filter on the requested label, not on a global loop variable.
        matches = np.flatnonzero(guess == label)
        images = denorm_image(samples)
        candidates.extend(
            Candidate(images[k], pred_realness[k], label, confidence[k]) for k in matches
        )
    shortlist = sorted(candidates, key=lambda c: c.label_confidence, reverse=True)
    shortlist = shortlist[:num_samples * shortlist_factor]
    shortlist.sort(key=lambda c: c.critic_score, reverse=True)
    return shortlist[:num_samples]

In [24]:
# For each class: print its name, then (critic_score, label_confidence) for the
# selected candidates. Set SHOW_IMAGES = True to also draw them in a 4-column grid.
SHOW_IMAGES = False
CANDIDATES_PER_CLASS = 16

for i in range(num_classes):
    print(idx2label[i])
    candidates = select_best_images(i, CANDIDATES_PER_CLASS)
    print(*[(c.critic_score, c.label_confidence) for c in candidates])
    if SHOW_IMAGES:
        display_image(*[c.image for c in candidates], col=4, width=15)

    

abstract

animal-painting

cityscape

figurative

flower-painting

genre-painting

landscape
(2704208.0, 1.0) (2543536.2, 1.0) (2505433.5, 1.0) (2504969.2, 1.0) (2482940.2, 1.0) (2468786.8, 1.0) (2464449.2, 1.0) (2424318.5, 1.0) (2388862.0, 1.0) (2354084.8, 1.0) (2242093.5, 1.0) (2201366.2, 1.0) (2163498.2, 1.0) (2130853.5, 1.0) (2128357.2, 1.0) (2048733.5, 1.0)
marina

mythological-painting

nude-painting-nu

portrait
(2558373.5, 1.0) (2545971.2, 1.0) (2393372.8, 1.0) (2114965.5, 1.0)
religious-painting

still-life

symbolic-painting
(-76095.164, 1.0) (-512046.8, 1.0)
