In [5]:
from tensorflow.compat.v1 import get_default_graph
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from tensorflow.keras.layers import concatenate, Conv2D, Input, UpSampling2D, RepeatVector, Reshape
from tensorflow.keras.models import Model

from skimage.color import rgb2lab, lab2rgb

import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt
import os

In [6]:
# Helper functions
# Settings shared by the input-pipeline helpers defined below.
BATCH_SIZE = 32                           # number of examples per batch
IMG_HEIGHT, IMG_WIDTH = 256, 256          # spatial size every decoded image is resized to
AUTOTUNE = tf.data.experimental.AUTOTUNE  # lets tf.data choose parallelism / prefetch depth at runtime

def process_path(file_path):
    """Read one image file and build a (network input, training target) pair.

    Args:
        file_path: Scalar string tensor with the path of a JPEG file.

    Returns:
        A tuple ``(grayscale_img, chroma_img)``:
          * ``grayscale_img`` -- the single-channel image from `decode_grayscale_img`.
          * ``chroma_img`` -- the last two channels of the image from `decode_yuv_img`,
            i.e. the colour information without the luminance channel.
    """
    img = tf.io.read_file(file_path)
    grayscale_img = decode_grayscale_img(img)
    yuv_img = decode_yuv_img(img)
    # The network ends in a 2-filter convolution, so a 3-channel label makes the
    # MSE loss fail on incompatible shapes. Channel 0 is the luminance, which the
    # network already receives as its grayscale input; keep only channels 1 and 2.
    chroma_img = yuv_img[..., 1:]
    # TODO: once the fusion branch is enabled, also compute
    # create_inception_embedding(grayscale_img) and return
    # ([grayscale_img, inception_embedding], chroma_img).
    return grayscale_img, chroma_img


def decode_grayscale_img(img):
    """Decode JPEG bytes into a resized single-channel float image.

    Args:
        img: Scalar string tensor holding the raw contents of a JPEG file.

    Returns:
        A float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, 1).
    """
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    # convert to grayscale
    img = tf.image.rgb_to_grayscale(img)
    # Use `convert_image_dtype` to convert to floats in the [0,1] range.
    img = tf.image.convert_image_dtype(img, tf.float32)
    # `tf.image.resize` expects the target size as [new_height, new_width].
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])


def decode_yuv_img(img):
    """Decode JPEG bytes into a resized float image in the YUV colour space.

    Args:
        img: Scalar string tensor holding the raw contents of a JPEG file.

    Returns:
        A float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, 3); channel 0 is the
        luminance (Y) and channels 1-2 are the chrominance (U, V).
    """
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    # Use `convert_image_dtype` to convert to floats in the [0,1] range.
    img = tf.image.convert_image_dtype(img, tf.float32)
    # Convert to YUV, as the function name promises (this used to call
    # `rgb_to_yiq`, which yields a different chrominance encoding).
    img = tf.image.rgb_to_yuv(img)
    # `tf.image.resize` expects the target size as [new_height, new_width].
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])


def create_inception_embedding(grayscaled_img):
    """Run one grayscale image through the global `inception` model.

    Args:
        grayscaled_img: Single-channel image tensor of shape (height, width, 1).
            Assumed to hold floats in the [0, 1] range, which is what
            `decode_grayscale_img` produces.

    Returns:
        The result of `inception.predict` on a batch containing only this image.

    Note:
        Relies on the module-level `inception` model and its `graph` attribute,
        which must have been created before this function is called.
    """
    # Bring the image into the form the network is fed with: 299x299, three
    # channels, with a leading batch dimension of 1.
    grayscaled_img = tf.image.resize(grayscaled_img, [299, 299])
    grayscaled_img = tf.image.grayscale_to_rgb(grayscaled_img)
    grayscaled_img = tf.expand_dims(grayscaled_img, 0)
    # `preprocess_input` expects pixel values in [0, 255] and maps them to
    # [-1, 1]. Feeding it [0, 1] floats would squash every pixel to about -1,
    # so rescale first.
    grayscaled_img = preprocess_input(grayscaled_img * 255.0)

    with inception.graph.as_default():
        embed = inception.predict(grayscaled_img, steps=1)
    return embed


def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
    """Turn a dataset of examples into an endless stream of shuffled batches.

    Args:
        ds: The `tf.data.Dataset` to prepare.
        cache: Falsy to skip caching; a non-empty string to cache to that file
            (useful when the preprocessed data does not fit in memory); any
            other truthy value to cache in memory.
        shuffle_buffer_size: Buffer size handed to `Dataset.shuffle`.

    Returns:
        The dataset after caching (optional), shuffling, repeating forever,
        batching into groups of BATCH_SIZE, and prefetching.
    """
    if cache:
        # A string names an on-disk cache file; anything else means in-memory.
        ds = ds.cache(cache) if isinstance(cache, str) else ds.cache()

    return (
        ds.shuffle(buffer_size=shuffle_buffer_size)
        .repeat()                        # never runs out of data
        .batch(BATCH_SIZE)
        .prefetch(buffer_size=AUTOTUNE)  # fetch upcoming batches while the model trains
    )


def show_batch(image_batch):
    """Display up to the first 25 images of a batch on a 5x5 grid.

    Args:
        image_batch: Sized, indexable collection of images (e.g. a NumPy array
            whose first axis is the batch), each in a form `plt.imshow` accepts.
    """
    plt.figure(figsize=(10,10))
    # Never index past the end of a batch that holds fewer than 25 images.
    for n in range(min(25, len(image_batch))):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n])
        plt.axis('off')


In [None]:
# Download and/or load weights for InceptionResNetV2
# `include_top=True` keeps the network's top classification layers.
# NOTE(review): this cell must run before `create_inception_embedding` is called,
# because that helper reads the global `inception` and `inception.graph`.
inception = InceptionResNetV2(weights='imagenet', include_top=True)
# Remember the graph that is the default right after the model was built;
# `create_inception_embedding` re-enters it before calling `predict`.
inception.graph = get_default_graph()

In [7]:
# Load image paths: a dataset of the file names matching ./Train/*
raw_ds = tf.data.Dataset.list_files('./Train/*')

# Read and decode every file into a (grayscale input, colour target) pair,
# letting tf.data decide how many files to process in parallel.
labeled_ds = raw_ds.map(process_path, num_parallel_calls=AUTOTUNE)

# Cache, shuffle, repeat indefinitely, batch and prefetch.
train_ds = prepare_for_training(labeled_ds)

# Uncomment to inspect one batch visually:
#gray_batch, lab_batch = next(iter(train_ds))
#show_batch(lab_batch.numpy())

In [8]:
# Set up neural network

# Encoder part: a stack of 3x3 ReLU convolutions described as (filters, strides).
# The three stride-2 layers each halve the resolution: 256 -> 128 -> 64 -> 32.
encoder_input = Input(shape=(256, 256, 1))
encoder_output = encoder_input
for n_filters, n_strides in [(64, 2), (128, 1), (128, 2), (256, 1),
                             (256, 2), (512, 1), (512, 1), (256, 1)]:
    encoder_output = Conv2D(
        n_filters, (3, 3), activation='relu', padding='same', strides=n_strides
    )(encoder_output)

# Fusion part (disabled): would tile a 1000-dim embedding over the 32x32 encoder
# output and merge the two with a 1x1 convolution.
#embed_input = Input(shape=(1000,))
# fusion_output = RepeatVector(32 * 32)(embed_input)
# fusion_output = Reshape(([32, 32, 1000]))(fusion_output)
# fusion_output = concatenate([encoder_output, fusion_output], axis=3)
# fusion_output = Conv2D(256, (1, 1), activation='relu', padding='same')(fusion_output)

# Decoder part: 3x3 ReLU convolutions described as (filters, upsample afterwards?),
# then a 2-channel tanh convolution and a final upsampling back to 256x256.
decoder_output = encoder_output
for n_filters, upsample_after in [(128, True), (64, True), (32, False), (16, False)]:
    decoder_output = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(decoder_output)
    if upsample_after:
        decoder_output = UpSampling2D((2, 2))(decoder_output)
decoder_output = Conv2D(2, (3, 3), activation='tanh', padding='same')(decoder_output)
decoder_output = UpSampling2D((2, 2))(decoder_output)


In [9]:
# Create model
# Two-input variant, for when the fusion branch (embed_input) is enabled:
#model = Model(inputs=[encoder_input, embed_input], outputs=decoder_output)
# Current variant: grayscale image in, 2-channel decoder output out.
model = Model(inputs=encoder_input, outputs=decoder_output)
# Train with RMSprop against the mean squared error between prediction and target.
model.compile(optimizer='rmsprop', loss='mse')

In [10]:
# train_ds repeats indefinitely (prepare_for_training calls `.repeat()`), so the
# number of batches that make up one epoch has to be given explicitly.
model.fit(train_ds, steps_per_epoch=10)

ValueError: When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument.