[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/guilbera/colorizing/blob/main/notebooks/keras_implementation/autoencoder_keras.ipynb)

In [12]:
import tensorflow as tf
import numpy as np
from tensorflow.python.keras.layers import Conv2D, UpSampling2D, Input, Reshape, concatenate
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from tensorflow.python.keras.layers.core import RepeatVector
from skimage.transform import resize

In [13]:
def encoder(encoder_input):
    """Apply the encoder's stack of eight 3x3 ReLU convolutions.

    Every layer uses ``padding='same'``. Three of the layers use a stride
    of 2; the others keep the default stride of 1.

    Args:
        encoder_input: Tensor the first convolution is applied to.

    Returns:
        The output tensor of the last (256-filter) convolution.
    """
    # (filters, stride) per convolution, listed in application order.
    layer_specs = (
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 1),
        (256, 1),
    )
    features = encoder_input
    for n_filters, stride in layer_specs:
        conv = Conv2D(n_filters, (3, 3), activation='relu', padding='same', strides=stride)
        features = conv(features)
    return features

In [14]:
def decoder(decoder_input):
    """Apply the decoder's convolution / upsampling stack.

    The stack alternates 3x3 ``padding='same'`` convolutions with 2x2
    upsampling steps. The last convolution has 2 filters and a ``tanh``
    activation, and it is followed by a final 2x2 upsampling.

    Args:
        decoder_input: Tensor the first convolution is applied to.

    Returns:
        The output tensor of the final upsampling layer.
    """
    # Layers are listed in exactly the order they are applied.
    stack = [
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        UpSampling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        UpSampling2D((2, 2)),
        Conv2D(32, (3, 3), activation='relu', padding='same'),
        Conv2D(16, (3, 3), activation='relu', padding='same'),
        Conv2D(2, (3, 3), activation='tanh', padding='same'),
        UpSampling2D((2, 2)),
    ]
    tensor = decoder_input
    for layer in stack:
        tensor = layer(tensor)
    return tensor

In [15]:
def fusion(embed_input, encoder_output, grid_size=(32, 32), embed_dim=1000):
    """Fusion layer for the gamma model.

    Repeats the embedding vector once per spatial position of the encoder
    feature map, reshapes it into a ``(height, width, embed_dim)`` volume,
    concatenates it with the encoder output along axis 3 and mixes the
    result with a 256-filter 1x1 ReLU convolution.

    Args:
        embed_input: Embedding tensor with ``embed_dim`` features per sample.
        encoder_output: Encoder feature map whose spatial size must equal
            ``grid_size`` so the concatenation is valid.
        grid_size: ``(height, width)`` of ``encoder_output``. The default
            ``(32, 32)`` is the value that was previously hard-coded.
        embed_dim: Length of the embedding vector. The default ``1000`` is
            the value that was previously hard-coded.

    Returns:
        The output tensor of the 1x1 convolution.
    """
    height, width = grid_size
    # One copy of the embedding for every spatial location of the feature map.
    tiled = RepeatVector(height * width)(embed_input)
    tiled = Reshape((height, width, embed_dim))(tiled)
    merged = concatenate([encoder_output, tiled], axis=3)
    return Conv2D(256, (1, 1), activation='relu', padding='same')(merged)

In [16]:
def load_model():
    """Create the classifier used to compute image embeddings.

    Returns:
        A ``(graph, inception)`` tuple: a newly created ``tf.Graph`` and an
        ``InceptionResNetV2`` instance with ImageNet weights and its
        classification top included.
    """
    # NOTE(review): the model is not built under this graph; the graph is
    # only created and handed back to the caller.
    fresh_graph = tf.Graph()
    classifier = InceptionResNetV2(include_top=True, weights='imagenet')
    return fresh_graph, classifier

In [17]:
def create_inception_embedding(grayscaled_rgb, graph, inception):
    """Preprocess images for the classifier and return its predictions.

    Args:
        grayscaled_rgb: Iterable of images; each one is resized to
            ``(299, 299, 3)`` before being batched.
        graph: Not used by this function.
        inception: Model whose ``predict`` method is run on the batch.

    Returns:
        Whatever ``inception.predict`` returns for the preprocessed batch.
    """
    batch = np.array([
        resize(image, (299, 299, 3), mode='constant')
        for image in grayscaled_rgb
    ])
    return inception.predict(preprocess_input(batch))

In [18]:
def model_beta(encoder_input, optimizer='adam'):
    """Build and compile the encoder -> decoder model (no fusion layer).

    Args:
        encoder_input: Input tensor fed to the encoder; also the model input.
        optimizer: Optimizer passed to ``compile``.

    Returns:
        The compiled model, using MSE loss and reporting MSE, MAE and MAPE.
    """
    encoded = encoder(encoder_input)
    decoded = decoder(encoded)
    network = Model(inputs=encoder_input, outputs=decoded)
    network.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=['mse', 'mae', 'mape'],
    )
    return network

In [19]:
def model_gamma(encoder_input, embed_input, optimizer='adam'):
    """Build and compile the encoder -> fusion -> decoder model.

    Args:
        encoder_input: Input tensor fed to the encoder.
        embed_input: Input tensor carrying the embedding merged in by
            ``fusion``.
        optimizer: Optimizer passed to ``compile``.

    Returns:
        The compiled two-input model, using MSE loss and reporting MSE, MAE
        and MAPE.
    """
    encoded = encoder(encoder_input)
    fused = fusion(embed_input, encoded)
    decoded = decoder(fused)
    network = Model(inputs=[encoder_input, embed_input], outputs=decoded)
    network.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=['mse', 'mae', 'mape'],
    )
    return network

In [23]:
# Set a flag to True to build the corresponding model and print its summary.
show_model_beta, show_model_gamma = False, False

In [21]:
if show_model_beta:
    # Build the beta model on a single-channel 256x256 input and list its layers.
    encoder_input = Input(shape=(256, 256, 1))
    model = model_beta(encoder_input)
    model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 256, 256, 1)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 128, 128, 64)      640       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 128, 128)     73856     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 64, 64, 128)       147584    
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 64, 64, 256)       295168    
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 32, 32, 256)       590080    
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 32, 32, 512)       1180160   

In [22]:
if show_model_gamma:
    # Build the gamma model on a single-channel 256x256 image input plus a
    # 1000-dimensional embedding input, then list its layers.
    encoder_input = Input(shape=(256, 256, 1))
    embed_input = Input(shape=(1000,))
    model = model_gamma(encoder_input, embed_input)
    model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 128, 128, 64) 640         input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 128, 128, 128 73856       conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 64, 64, 128)  147584      conv2d_14[0][0]                  
____________________________________________________________________________________________