In [1]:
import tensorflow as tf
import tensorflow_io as tfio
from efficientnet.tfkeras import EfficientNetB3

In [2]:
import numpy as np
from matplotlib import pyplot as plt
import glob
from sklearn.model_selection import train_test_split
import os

In [3]:
# Hyper params
learning_rate = 1e-4
epochs = 100
batch_size = 32

# Build Model

In [4]:
effnet_model = EfficientNetB3(weights='imagenet', include_top=True)

In [5]:
embed_input = tf.keras.layers.Input(shape=(1000,))

# Encoder input
encoder_input = tf.keras.layers.Input(shape=(256, 256, 1))
encoder = tf.keras.layers.Conv2D(64, (3, 3), padding='same', strides=(2, 2))(encoder_input)
encoder = tf.keras.layers.Activation('relu')(encoder)
encoder = tf.keras.layers.Conv2D(128, (3, 3), padding='same')(encoder)
encoder = tf.keras.layers.Activation('relu')(encoder)
encoder = tf.keras.layers.Conv2D(128, (3, 3), padding='same', strides=(2, 2))(encoder)
encoder = tf.keras.layers.Activation('relu')(encoder)
encoder = tf.keras.layers.Conv2D(256, (3, 3), padding='same')(encoder)
encoder = tf.keras.layers.Activation('relu')(encoder)
encoder = tf.keras.layers.Conv2D(256, (3, 3), padding='same', strides=(2, 2))(encoder)
encoder = tf.keras.layers.Activation('relu')(encoder)
encoder = tf.keras.layers.Conv2D(512, (3, 3), padding='same')(encoder)
encoder = tf.keras.layers.Activation('relu')(encoder)
encoder = tf.keras.layers.Conv2D(512, (3, 3), padding='same')(encoder)
encoder = tf.keras.layers.Activation('relu')(encoder)
encoder = tf.keras.layers.Conv2D(256, (3, 3), padding='same')(encoder)
encoder = tf.keras.layers.Activation('relu')(encoder)

# Fusion
fusion_output = tf.keras.layers.RepeatVector(32 * 32)(embed_input)
fusion_output = tf.keras.layers.Reshape(([32, 32, 1000]))(fusion_output)
fusion_output = tf.keras.layers.concatenate([encoder, fusion_output], axis=3)
fusion_output = tf.keras.layers.Conv2D(256, (1, 1), padding='same')(fusion_output)
fusion_output = tf.keras.layers.Activation('relu')(fusion_output)

# Decoder
decoder_output = tf.keras.layers.Conv2D(128, (3, 3), padding='same')(fusion_output)
decoder_output = tf.keras.layers.Activation('relu')(decoder_output)
decoder_output = tf.keras.layers.UpSampling2D((2, 2))(decoder_output)
decoder_output = tf.keras.layers.Conv2D(64, (3, 3), padding='same')(decoder_output)
decoder_output = tf.keras.layers.Activation('relu')(decoder_output)
decoder_output = tf.keras.layers.UpSampling2D((2, 2))(decoder_output)
decoder_output = tf.keras.layers.Conv2D(32, (3, 3), padding='same')(decoder_output)
decoder_output = tf.keras.layers.Activation('relu')(decoder_output)
decoder_output = tf.keras.layers.Conv2D(16, (3, 3), padding='same')(decoder_output)
decoder_output = tf.keras.layers.Activation('relu')(decoder_output)
decoder_output = tf.keras.layers.Conv2D(2, (3, 3), padding='same')(decoder_output)
decoder_output = tf.keras.layers.Activation('tanh')(decoder_output)
decoder_output = tf.keras.layers.UpSampling2D((2, 2))(decoder_output)

model = tf.keras.models.Model(inputs=[encoder_input, embed_input], outputs=decoder_output)

In [6]:
model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 64) 640         input_3[0][0]                    
__________________________________________________________________________________________________
activation (Activation)         (None, 128, 128, 64) 0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 128 73856       activation[0][0]                 
_______________________________________________________________________________________

In [7]:
model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate))

# Load Data

In [8]:
def get_data(image_path):
    byte_img = tf.io.read_file(image_path)
    img = tf.io.decode_jpeg(byte_img)
    img_norm = tf.cast(img, tf.float32)
    img_norm = img_norm / 255.
    
    img = tf.image.resize(img, (300, 300))
    img = img / 255.
    img_expanded = tf.expand_dims(img, axis=0)
    embeddings = tf.squeeze(effnet_model(img_expanded), axis=0)
    
    lab_img = tfio.experimental.color.rgb_to_lab(img_norm)
    
    l_img = tf.expand_dims(lab_img[:, :, 0], axis=2)
    ab_img = lab_img[:, :, 1:] / 128.
    
    return {'input_1': l_img, 'input_2': embeddings}, ab_img

In [9]:
images = glob.glob('data/images/Train/*.jpg')

In [10]:
train_images, validation_images = train_test_split(images, test_size=0.4, random_state=2021)
validation_images, test_images = train_test_split(validation_images, test_size=0.5, random_state=2021)

In [13]:
train_data = tf.data.Dataset.from_tensor_slices(train_images)
train_data = train_data.map(get_data)
train_data = train_data.batch(batch_size)
train_data = train_data.shuffle(len(train_images))
train_data = train_data.cache()
train_data = train_data.prefetch(tf.data.experimental.AUTOTUNE)

validation_data = tf.data.Dataset.from_tensor_slices(validation_images)
validation_data = validation_data.map(get_data)
validation_data = validation_data.batch(batch_size)
validation_data = validation_data.cache()
validation_data = validation_data.prefetch(tf.data.experimental.AUTOTUNE)

test_data = tf.data.Dataset.from_tensor_slices(test_images)

In [14]:
model.fit(train_data, epochs=epochs)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100

KeyboardInterrupt: 

# Train