# Colorization Test Model
## This notebook loads an already-trained colorization model, runs it on a test set, and saves each resulting image individually.
## Output is saved under ./Test_Output/<trained_model_name>/

In [1]:
from __future__ import print_function, division

from keras.layers import Input, Dense, Flatten, Dropout, Reshape, Concatenate
from keras.layers import BatchNormalization, Activation, Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model
from keras.optimizers import Adam
from keras.engine.saving import load_model

from keras.datasets import cifar10
import keras.backend as K

import matplotlib.pyplot as plt
import os
import sys
import numpy as np

%pylab inline

from PIL import Image
from tqdm import tnrange, tqdm_notebook, tqdm
import cv2
import random

import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Populating the interactive namespace from numpy and matplotlib


In [2]:
# Drop any Keras/TensorFlow backend state left over from earlier runs in this
# kernel before the models below are built.
K.clear_session()






In [3]:
def list_image_files(directory):
    """Return the paths of all image files in `directory`, ordered by file name.

    A directory entry is kept when `is_an_image_file` accepts its name; the
    returned paths are `directory` joined with each accepted name.
    """
    image_paths = []
    for entry in sorted(os.listdir(directory)):
        if is_an_image_file(entry):
            image_paths.append(os.path.join(directory, entry))
    return image_paths

def is_an_image_file(filename):
    """Return True when `filename` ends with a supported image extension.

    The check is a case-insensitive suffix match, so ``photo.JPG`` is accepted
    while names that merely contain an extension somewhere in the middle
    (``photo.png.txt``, ``notes.jpgs.md``) are rejected.

    Args:
        filename: File name or path to test.

    Returns:
        bool: True for names ending in .png, .jpg or .jpeg, otherwise False.
    """
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # str.endswith accepts a tuple of suffixes, so one call covers every extension.
    return filename.lower().endswith(IMAGE_EXTENSIONS)

In [4]:
def load_image(path):
    """Read an image with OpenCV and centre-crop it to 256 x 256 pixels.

    Args:
        path: Path of the image file. A one-element tuple or list holding the
            path is also accepted, because callers used to pass the 1-tuples
            produced by ``zip(paths)``.

    Returns:
        The 256 x 256 crop of the array returned by ``cv2.imread``.

    Raises:
        IOError: If ``cv2.imread`` cannot read the file (it returns None).
        ValueError: If the image is smaller than 256 pixels in either
            dimension, since a centred 256 x 256 crop does not exist then.
    """
    crop_size = 256

    # Backward compatibility with callers that pass a 1-tuple from zip(paths).
    if isinstance(path, (tuple, list)):
        path = path[0]

    img = cv2.imread(path)
    if img is None:
        raise IOError("Could not read image file: {}".format(path))

    rows, cols = img.shape[:2]
    if rows < crop_size or cols < crop_size:
        # Negative offsets would silently slice from the end of the array and
        # yield a crop of the wrong size.
        raise ValueError(
            "Image {} is {}x{}, smaller than the required {}x{} crop".format(
                path, rows, cols, crop_size, crop_size))

    # Centre the crop window inside the image.
    top = (rows - crop_size) // 2
    left = (cols - crop_size) // 2
    return img[top:top + crop_size, left:left + crop_size]

def load_images(path, n_images=-1):
    """Load images from a directory and split them into L and AB channels.

    Each image is cropped by `load_image`, converted from BGR to LAB and
    scaled by `preprocess_image`.

    Args:
        path: Directory that is scanned with `list_image_files`.
        n_images: Maximum number of images to load. A negative value (the
            default) loads every image found; a value larger than the number
            of available images is capped to what exists; 0 loads nothing.

    Returns:
        dict with two numpy arrays:
            'l':  the first LAB channel of every image, with a trailing
                  channel axis of length 1,
            'ab': the remaining two LAB channels of every image.
    """
    all_image_paths = list_image_files(path)

    # Negative means "everything"; never promise more images than exist so the
    # progress bar total matches the work actually done.
    if n_images < 0 or n_images > len(all_image_paths):
        n_images = len(all_image_paths)
    selected_paths = all_image_paths[:n_images]

    images_l, images_ab = [], []

    # Initialize a progress bar with max of n_images
    pbar = tqdm_notebook(total=n_images, desc="Loading Images...")
    try:
        for image_path in selected_paths:
            img = load_image(image_path)
            lab_img = preprocess_image(cv2.cvtColor(img, cv2.COLOR_BGR2LAB))

            # Slicing with :1 keeps the channel axis, giving shape (rows, cols, 1).
            images_l.append(lab_img[:, :, :1])
            # The two colour channels that follow L.
            images_ab.append(lab_img[:, :, 1:])

            # Increase progress by one
            pbar.update(1)
    finally:
        # Always release the progress bar, even when an image fails to load.
        pbar.close()

    return {
        'l': np.array(images_l),
        'ab': np.array(images_ab)
    }

In [5]:
# Image size constant matching the 256 x 256 crop used when loading images.
# NOTE(review): not referenced anywhere else in the visible part of this
# notebook — confirm before relying on or removing it.
RESHAPE = (256,256)

def preprocess_image(cv_img):
    """Scale 8-bit pixel values from [0, 255] to the range [-1, 1].

    Subtracts 127.5 and divides by 127.5, so 0 maps to -1 and 255 maps to 1.
    """
    half_range = 127.5
    return (cv_img - half_range) / half_range

def deprocess_image(img):
    """Map values from [-1, 1] back to 8-bit pixel values in [0, 255].

    This is the inverse of `preprocess_image`. Values are rounded to the
    nearest integer and clipped to [0, 255] before the cast: a bare cast to
    uint8 truncates (making the round trip lossy under floating-point error)
    and wraps around for inputs slightly outside [-1, 1].

    Args:
        img: Array-like of values nominally in [-1, 1].

    Returns:
        numpy.ndarray of dtype uint8 with the same shape as `img`.
    """
    scaled = np.asarray(img) * 127.5 + 127.5
    return np.clip(np.rint(scaled), 0, 255).astype('uint8')

In [6]:
def save_image(np_arr, path):
    """Save an image whose values are scaled to [-1, 1] to `path` using PIL.

    The array is rescaled to [0, 255], rounded, clipped and cast to uint8
    before it is handed to ``Image.fromarray``; passing the rescaled float
    array directly does not produce a standard 8-bit image.

    Args:
        np_arr: Image array with values nominally in [-1, 1].
        path: Destination file path; the format is chosen by PIL from it.
    """
    pixels = np.clip(np.rint(np.asarray(np_arr) * 127.5 + 127.5), 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(path)

In [7]:
def get_generator(H, W, k):
    """Build the U-Net generator that predicts colour channels from a one-channel input.

    Args:
        H, W: Height and width of the single-channel conditioning image.
        k: Multiplier for the output channels; the final layer has 2*k filters
            (two channels for each of k predictions).

    Returns:
        A Keras Model with inputs [noise_in, condition_in] of shapes (100,)
        and (H, W, 1), whose output is a tanh-activated Conv2D with 2*k filters.
    """
    # Inputs: height and width of the input image
    # Returns the model, which generates the AB channels

    # Pix2pix adapted from 
    # https://github.com/eriklindernoren/Keras-GAN/blob/master/pix2pix/pix2pix.py

    def conv2d(layer_input, filters, f_size=4, bn=True):
        """Layers used during downsampling"""
        # Stride-2 convolution halves each spatial dimension.
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if bn:
            d = BatchNormalization(momentum=0.8)(d)
        return d

    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
        """Layers used during upsampling"""
        # Upsample by 2, convolve, optionally apply dropout, normalise, then
        # concatenate with the skip connection coming from the downsampling path.
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
        if dropout_rate:
            u = Dropout(dropout_rate)(u)
        u = BatchNormalization(momentum=0.8)(u)
        u = Concatenate()([u, skip_input])
        return u

    gf = 64 # Number of filters in the first layer of G

    noise_in = Input(shape=(100,))
    condition_in = Input(shape=(H, W, 1))
    
    # NOTE: the noise branch below is built but never connected to the output,
    # because d0 is set to condition_in alone further down. It also uses H for
    # both spatial dimensions (H * H), not H * W.
    # pass noise through a FC layer to get it to the right size
    noise = Dense(H * H)(noise_in)

    # reshape to be the size of an image channel
    noise = Reshape((H, H, 1))(noise)
    
    # stick the (somewhat modified) noise as the second channel after
    # the gray input. Assuming new dimension of hid will be
    # B x 256 x 256 x 2, where B is the batch size.
#     d0 = Concatenate(axis=-1)([condition_in, noise])
    d0 = condition_in # Don't need noise since it's being ignored anyway

    # U-NET
    # Downsampling (seven stride-2 stages)
    d1 = conv2d(d0, gf, bn=False)
    d2 = conv2d(d1, gf*2)
    d3 = conv2d(d2, gf*4)
    d4 = conv2d(d3, gf*8)
    d5 = conv2d(d4, gf*8)
    d6 = conv2d(d5, gf*8)
    d7 = conv2d(d6, gf*8)

    # Upsampling (six stages, each concatenated with the matching downsampling output)
    u1 = deconv2d(d7, d6, gf*8)
    u2 = deconv2d(u1, d5, gf*8)
    u3 = deconv2d(u2, d4, gf*8)
    u4 = deconv2d(u3, d3, gf*4)
    u5 = deconv2d(u4, d2, gf*2)
    u6 = deconv2d(u5, d1, gf)

    # Seventh upsampling step, undoing the last remaining stride-2 stage
    u7 = UpSampling2D(size=2)(u6)
    
    # Final 2*k-channel output (k two-channel AB predictions) with values between -1 and 1
    img_out = Conv2D(2*k, kernel_size=4, strides=1, padding='same', activation='tanh', name='pred_ab')(u7)

    # Make Model
    model = Model(inputs=[noise_in, condition_in], outputs=img_out)
    
    # Show summary of layers
    print("Generator Model:")
    model.summary()

    return model


In [8]:
def get_discriminator(H, W, k):
    """Build the patch discriminator that scores colour channels given the L channel.

    Args:
        H, W: Height and width of the input images.
        k: Multiplier for the channel counts; the colour input has 2*k channels
            and the output map has k channels.

    Returns:
        A Keras Model with inputs [img_in, condition_in] of shapes (H, W, 2*k)
        and (H, W, 1), whose output is a Conv2D with k filters and no
        activation, applied after four stride-2 stages.
    """
    # Inputs: height and width of the input image
    # Returns the model, which predicts real/fake
    # over a set of spatial regions (i.e., predicts a matrix instead of a scalar).

    # Pix2pix adapted from 
    # https://github.com/eriklindernoren/Keras-GAN/blob/master/pix2pix/pix2pix.py

    def d_layer(layer_input, filters, f_size=4, bn=True):
        """Discriminator layer"""
        # Stride-2 convolution halves each spatial dimension.
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if bn:
            d = BatchNormalization(momentum=0.8)(d)
        return d

    # Number of filters in the first layer of D
    df = 64

    img_in = Input(shape=(H, W, 2*k)) # AB channels
    condition_in = Input(shape=(H, W, 1)) # L channel
    
    # Concat the L and AB channels
    concat_imgs = Concatenate()([condition_in, img_in])

    d1 = d_layer(concat_imgs, df, bn=False)
    d2 = d_layer(d1, df*2)
    d3 = d_layer(d2, df*4)
    d4 = d_layer(d3, df*8)

    # validity map is a k-channel matrix whose sides are 1/16 the size of the
    # input's sides (halved 4 times).
    # Each number predicts whether a region of the input is real/fake.
    validity = Conv2D(1*k, kernel_size=4, strides=1, padding='same', name='pred_valid')(d4)

    # Build Model
    model = Model(inputs=[img_in, condition_in], outputs=validity)

    # Show summary of layers
    print("Disciminator Model:")
    model.summary()

    return model

In [9]:
def min_k_diff(y_true, y_pred):
    """Loss: for each sample, the smallest mean absolute error among the k predictions.

    Both tensors are reshaped to (batch, H, W, k, 2) using the notebook-level
    H, W and k, so each consecutive pair of channels is treated as one
    prediction. The absolute error is averaged over height, width and the
    channel pair, the minimum over the k predictions is taken per sample, and
    those minima are summed over the batch.
    """
    hypothesis_shape = (-1, H, W, k, 2)
    y_true = K.reshape(y_true, hypothesis_shape)
    y_pred = K.reshape(y_pred, hypothesis_shape)

    print("true:", y_true.shape)
    print("pred:", y_pred.shape)

    # Averaging over axes (H, W, 2) leaves one error value per prediction: (batch, k)
    error_per_hypothesis = K.mean(K.abs(y_true - y_pred), axis=(1, 2, 4))

    # Keep only the best of the k predictions for every sample, then add them up
    best_per_sample = K.min(error_per_hypothesis, axis=1)
    return K.sum(best_per_sample)

In [10]:
from keras.preprocessing import image

def generate_noise(n_samples, noise_dim):
    """Draw an (n_samples, noise_dim) array of standard-normal noise."""
    return np.random.normal(loc=0, scale=1, size=(n_samples, noise_dim))

## Parameters

In [11]:
# Model parameters
# Name of the trained model; used to build the weight-file location and the
# output folder in the following cell.
trained_model_name = "lsun_colorization_full_model"
# Number of predictions per image: the generator output has 2*k channels and
# the discriminator output has k channels.
k = 5

# Testing parameters
# Passed to load_images as the number of test images to load.
num_test_imgs = 100 
# True: for each image save either the ground truth or one randomly chosen
# prediction. False: save all k predictions for every image.
random_select_gt_or_colorzed = True

# Root folder of the dataset; the test images are read from its 'test' subfolder.
# dataset = 'new_circles/'
# dataset = '../Colorization_GAN/circle_pairs/'
dataset = 'lsun/'
# dataset = 'places2/'

In [12]:
# Location of the trained GAN weights for the selected model
saved_model_location = "Output/" + trained_model_name + "/GAN_Weights_Epoch_100.h5"

# Output folder layout: Test_Output/<model name>/<selection mode>/
generic_output_folder = "Test_Output/"
new_output_folder = trained_model_name + "/"
save_path = generic_output_folder + new_output_folder + (
    "random_colorized_or_ground_truth/"
    if random_select_gt_or_colorzed
    else "all_predictions/"
)

# Create the output folder if it is not there yet
if not os.path.exists(save_path):
    os.makedirs(save_path)

In [13]:
# ===================================
# COULD NOT HANDLE LARGE TRAINING SET
# ===================================

# Load the test images from <dataset>/test (at most num_test_imgs of them).
# load_images converts them to LAB and normalizes to the range [-1, 1].
data = load_images(dataset + 'test', num_test_imgs)

# L channel is the model input; AB channels are kept as the ground truth
l_channel_imgs = data['l']
ab_channel_imgs = data['ab']

HBox(children=(IntProgress(value=0, description='Loading Images...', style=ProgressStyle(description_width='in…




In [14]:
# GAN creation
# The models are rebuilt here with the same construction code so that the
# saved weights can be loaded into them in a later cell.
# H, W (and k from the parameters cell) are also read as globals by min_k_diff.
H = W = 256

# Discriminator loss - MSE seems to produce better results
#discrim_loss = 'binary_crossentropy'
discrim_loss = 'mse'

# 1. Discriminator
# Calculate output shape of D (PatchGAN)
# NOTE(review): `patch` is not used anywhere else in the visible part of this notebook.
patch = H // 2**4 # Input size gets cut in half 4 times
discriminator = get_discriminator(H, W, k)
discriminator.name = 'discrim_model' # Need a name for the loss dictionary below
discriminator.compile(optimizer=Adam(2e-4, 0.5), loss=discrim_loss, metrics=['accuracy'])
discriminator.trainable = False # For the combined model we will only train the generator
print("\n")

# 2. Generator
generator = get_generator(H, W, k)
generator.name = 'gen_model' # Need a name for the loss dictionary below

# 3. GAN
# Combined model: noise + L channel -> generator -> discriminator
gan_noise_in = Input(shape=(100,))
gan_condition_in = Input(shape=(H, W, 1))

# By conditioning on L generate a fake version of AB
fake_AB = generator([gan_noise_in, gan_condition_in])

# Discriminator determines validity of AB images / L pairs
print("fake_ab:", fake_AB.shape)

print("gan_condition_in:", gan_condition_in.shape)

valid = discriminator([fake_AB, gan_condition_in])

# Loss and weight per output, keyed by the model names assigned above
losses = {'gen_model': min_k_diff, # used to be 'gen_loss'
          'discrim_model': discrim_loss}
loss_weights = {'gen_model': 100.0, 'discrim_model': 1.0}

# The combined model outputs both the generated AB channels and the validity map
gan = Model(inputs=[gan_noise_in, gan_condition_in], outputs=[fake_AB, valid])
gan.compile(optimizer=Adam(2e-4, 0.5), loss=losses, loss_weights=loss_weights)
gan.summary()




Disciminator Model:
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 256, 256, 1)  0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            (None, 256, 256, 10) 0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 256, 256, 11) 0           input_2[0][0]                    
                                                                 input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 64) 11328       concatenate_1[0][0]  

fake_ab: (?, 256, 256, 10)
gan_condition_in: (?, 256, 256, 1)
true: (?, 256, 256, 5, 2)
pred: (?, 256, 256, 5, 2)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
input_6 (InputLayer)            (None, 256, 256, 1)  0                                            
__________________________________________________________________________________________________
gen_model (Model)               (None, 256, 256, 10) 41855626    input_5[0][0]                    
                                                                 input_6[0][0]                    
______________________________________________________________________________________________

In [15]:
# Show the GAN weight-file location that was built from trained_model_name
print(saved_model_location)

Output/lsun_colorization_full_model/GAN_Weights_Epoch_100.h5


In [16]:
# Load the weights for the model selected in the parameters cell.
# Both locations are derived from trained_model_name instead of being
# hard-coded, so changing that parameter switches the weights that are loaded
# together with the output folder.
discriminator_weights_location = "Output/" + trained_model_name + "/Discriminator_Weights_Epoch_100.h5"
discriminator.load_weights(discriminator_weights_location)

# The combined model was built from `generator` and `discriminator`, so this
# load targets the layers of both.
gan.load_weights(saved_model_location)

print("Model loaded")

Model loaded


In [17]:
# One noise vector per image. The batch size comes from the images that were
# actually loaded rather than from num_test_imgs, so the two generator inputs
# always have matching lengths (fewer images may exist than were requested).
noise = generate_noise(len(l_channel_imgs), 100)

# colorized_predictions has one entry per image; the last axis holds 2*k
# channels, i.e. k two-channel AB predictions per image
colorized_predictions = generator.predict([noise, l_channel_imgs])

print("Predictions for", len(colorized_predictions), "images complete!")

Predictions for 100 images complete!


In [18]:
def save_rgb_img(l, ab, i, filename):
    """Combine L and AB channels, convert LAB to BGR and write the image to disk.

    The file is written to the notebook-level `save_path` followed by `filename`.

    Args:
        l: L channel scaled to [-1, 1], with a trailing channel axis.
        ab: The two AB channels scaled to [-1, 1].
        i: Index of the image. Unused; kept so existing calls keep working.
        filename: File name (including extension) appended to `save_path`.

    Raises:
        IOError: If ``cv2.imwrite`` reports that the file was not written.
    """
    # Make sure ab is the right type, generated imgs change to float32
    ab = ab.astype(np.float64)

    # Stack L and AB into a single 3-channel LAB image
    merged = cv2.merge((l, ab))

    # Get between 0, 255
    deprocessed = deprocess_image(merged)

    # cv2.imwrite expects BGR channel order, so convert LAB -> BGR
    bgr = cv2.cvtColor(deprocessed, cv2.COLOR_LAB2BGR)

    # cv2.imwrite signals failure through its return value instead of raising;
    # surface it so a missing folder or bad path is not silently ignored.
    out_path = save_path + filename
    if not cv2.imwrite(out_path, bgr):
        raise IOError("Could not write image to: {}".format(out_path))

In [19]:
print("Merging, deprocessing, converting to RGB, and saving images")

# Base the file-name padding and the statistics on the number of predictions
# that actually exist; it can differ from num_test_imgs when fewer images were
# available than requested (or when a negative "load everything" value is used).
num_predicted_imgs = len(colorized_predictions)
filename_width = len(str(num_predicted_imgs))

num_ground_truth_selected = 0

# Loop through images that were colorized
for i, img in enumerate(colorized_predictions):
    # 1-based index, zero-padded so the files sort in order
    filename = str(i+1).zfill(filename_width)

    # For each img, use either ground truth or random colorized prediction
    if random_select_gt_or_colorzed:
        # Determine whether to use ground truth or prediction
        use_ground_truth = random.choice([True, False])

        if use_ground_truth:
            num_ground_truth_selected += 1
            filename += "_Ground_Truth.png"
            save_rgb_img(l_channel_imgs[i], ab_channel_imgs[i], i, filename)
        else:
            # Pick one of the k predictions; prediction j occupies channels 2j and 2j+1
            prediction_i = random.randint(0, k-1)
            filename += "_Colorized_.png"
            prediction = img[:,:,2*prediction_i:2*prediction_i+2]
            save_rgb_img(l_channel_imgs[i], prediction, i, filename)
    else:
        # Loop through predictions
        for j in range(k):
            prediction = img[:,:,2*j:2*j+2]
            save_rgb_img(l_channel_imgs[i], prediction, i, filename + "-" + str(j+1) + ".png")

    if (i + 1) % 25 == 0:
        print("--", i+1, "completed")

print("DONE!")

# Skip the breakdown when nothing was processed to avoid dividing by zero
if random_select_gt_or_colorzed and num_predicted_imgs > 0:
    num_colorized_selected = num_predicted_imgs - num_ground_truth_selected
    print("\nBreakdown of Ground Truth vs. Colorized Selected:")
    print("Ground Truth:", num_ground_truth_selected, "--", str(100 * num_ground_truth_selected / num_predicted_imgs) + "%")
    print("Colorized:", num_colorized_selected, "--", str(100 * num_colorized_selected / num_predicted_imgs) + "%")

Merging, deprocessing, converting to RGB, and saving images
-- 25 completed
-- 50 completed
-- 75 completed
-- 100 completed
DONE!

Breakdown of Ground Truth vs. Colorized Selected:
Ground Truth: 49 -- 49.0%
Colorized: 51 -- 51.0%


In [23]:
# Score every set of predictions with the discriminator. Each entry of `value`
# is a validity map with k channels.
value = discriminator.predict([colorized_predictions, l_channel_imgs])

# For each image, print the summed validity map of every channel and the index
# of the channel with the smallest sum.
# np.sum is called explicitly: the builtin sum() has no `axis` argument, so the
# bare name only worked because %pylab replaces it with numpy's version.
for i, img in enumerate(value):
    print('=====', i, '=====')
    for j in range(k):
        print(j, np.sum(img[:,:,j]))
    print("MIN:", np.argmin(np.sum(img, axis=(0,1))))

===== 0 =====
0 -611.6212
1 -537.7493
2 -454.92267
3 -721.1657
4 -49.701584
MIN: 3
===== 1 =====
0 -180.33813
1 -125.81885
2 -88.437256
3 -325.2655
4 321.8932
MIN: 3
===== 2 =====
0 -1833.7739
1 -1770.6782
2 -1753.2402
3 -1959.8857
4 -1475.3164
MIN: 3
===== 3 =====
0 796.3767
1 860.2129
2 891.1913
3 710.1904
4 1177.862
MIN: 3
===== 4 =====
0 -1400.9674
1 -1340.4893
2 -1281.1377
3 -1522.4576
4 -914.71814
MIN: 3
===== 5 =====
0 4574.956
1 4645.1436
2 4688.954
3 4457.4
4 5044.89
MIN: 3
===== 6 =====
0 -902.0542
1 -839.8125
2 -784.7599
3 -979.7953
4 -472.80008
MIN: 3
===== 7 =====
0 683.6513
1 742.59766
2 847.90283
3 590.50195
4 1196.9446
MIN: 3
===== 8 =====
0 256.5343
1 312.2461
2 315.2163
3 185.22107
4 539.9988
MIN: 3
===== 9 =====
0 1511.5364
1 1570.8456
2 1611.3647
3 1427.2673
4 1924.6301
MIN: 3
===== 10 =====
0 -447.3651
1 -386.97827
2 -340.51184
3 -506.068
4 -83.01221
MIN: 3
===== 11 =====
0 -1074.8641
1 -1013.9446
2 -963.33936
3 -1161.7313
4 -711.56226
MIN: 3
===== 12 =====
0 724.8