In [1]:
import numpy as np
import PIL
import os
import tensorflow.keras as keras
import tensorflow as tf

In [2]:
from model import get_model
from data_utils import batch_generator

In [3]:
# Restoration network from model.get_model, built for 64x64 RGB inputs.
# use_DWT presumably enables discrete-wavelet-transform blocks — see model.py.
model = get_model(
    input_shape=(64, 64, 3),
    num_blocks=3,
    use_DWT=True,
)

In [4]:
# Print the layer-by-layer summary of `model` as built by get_model above.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 64, 64, 16)   448         input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 64, 64, 16)   0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 16)   64          input_1[0][0]                    
______________________________________________________________________________________________

In [5]:

# Dataset location and image geometry; the sizes are (width, height) tuples
# passed to PIL.Image.resize.
train_path = 'train2017'   # directory listed for training file names

im_shape = (32, 32)        # low-resolution size images are downsampled to
im_shape_final = (64, 64)  # size the degraded image is upsampled back to (and target size)

In [6]:
HUBER_DELTA = 0.5  # absolute error at which the loss switches from quadratic to linear


def smoothL1(y_true, y_pred):
    """Huber ("smooth L1") loss, summed over every element of the batch.

    Errors with magnitude below HUBER_DELTA contribute 0.5 * err**2; larger
    errors contribute the linear tail HUBER_DELTA * (err - 0.5 * HUBER_DELTA),
    which meets the quadratic branch continuously at err == HUBER_DELTA.

    Returns a single scalar tensor (a sum, not a mean), so its magnitude grows
    with batch size and image resolution.
    """
    K = keras.backend
    abs_err = K.abs(y_true - y_pred)
    quadratic = 0.5 * abs_err ** 2
    linear = HUBER_DELTA * (abs_err - 0.5 * HUBER_DELTA)
    return K.sum(K.switch(abs_err < HUBER_DELTA, quadratic, linear))

# Adam with a 1e-4 step size, minimizing the summed smooth-L1 loss defined above.
model.compile(
    loss=smoothL1,
    optimizer=keras.optimizers.Adam(1e-4),
)

In [7]:
# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
# makes the file list (and thus the batch sequence, unless batch_generator
# shuffles) reproducible across machines and kernel restarts.
train_names = sorted(os.listdir(train_path))
batch_num = 16  # images per training batch

In [8]:
# One pass over the training list with the smooth-L1 objective.
# steps_per_epoch uses floor division, so a trailing partial batch is skipped.
model.fit(
    batch_generator(
        ids=train_names,
        batch_size=batch_num,
        im_shape=im_shape,
        im_shape_final=im_shape_final,
        train_path=train_path,
    ),
    epochs=1,
    steps_per_epoch=len(train_names) // batch_num,
)


  ...
    to  
  ['...']
Train for 7392 steps
 506/7392 [=>............................] - ETA: 1:21:11 - loss: 1529.8355

KeyboardInterrupt: 

In [9]:
from PIL import Image

In [28]:
# NOTE(review): `full_model` is only defined by the DIVA pipeline cell further
# down (In [24]); on a fresh kernel this cell fails until that cell has run.
# Move it below the full-model cell so Restart & Run All works.
test_path = os.path.join(train_path, "000000000025.jpg")
with Image.open(test_path) as src_img:
    # Force 3 channels: a grayscale/CMYK JPEG would otherwise produce an array
    # the (64, 64, 3)-input model cannot accept. convert() returns a loaded
    # copy, so closing the source file afterwards is safe.
    hr_img = src_img.convert('RGB')
# Network input: downsample to im_shape, then upsample back to im_shape_final.
x = np.array(hr_img.resize(im_shape).resize(im_shape_final)) / 255.0
# Reference image at the final resolution, same [0, 1] scaling.
y = np.array(hr_img.resize(im_shape_final)) / 255.0
y_pred = full_model.predict(np.expand_dims(x, 0))
y_pred = y_pred.squeeze()  # drop the batch axis of size 1

In [29]:
# The network output is unbounded: clip to [0, 1] before scaling, because
# casting values outside [0, 255] to uint8 wraps around instead of saturating.
# Round (rather than truncate) so float noise does not shift pixels down by one.
Image.fromarray(np.rint(np.clip(y_pred, 0.0, 1.0) * 255).astype('uint8'), 'RGB').show()

In [30]:
# Show the degraded network input `x` for comparison with the prediction.
# `x` was already resized to im_shape_final, so the former
# .resize(im_shape_final, resample=Image.NEAREST) was a no-op and is dropped.
# NOTE(review): if a blocky view of the 32x32 input was intended, resize the
# downsampled image (im_shape) with NEAREST instead.
# Round rather than truncate so x * 255 maps back onto the original 8-bit values.
Image.fromarray(np.rint(x * 255).astype('uint8'), 'RGB').show()

In [31]:
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2

In [14]:
# MobileNetV2 backbone (Keras default ImageNet weights, no classifier head)
# used only as a fixed feature extractor for the perceptual loss; every layer
# is frozen so training the SR model never updates it.
pretrained_model = MobileNetV2(input_shape=im_shape_final + (3,), include_top=False)
for frozen_layer in pretrained_model.layers:
    frozen_layer.trainable = False



In [15]:
# Inspect the frozen backbone; the integers in selectedLayers index into
# pretrained_model.layers in the order listed here.
pretrained_model.summary()

Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 65, 65, 3)    0           input_2[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 32, 32, 32)   864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 32, 32, 32)   128         Conv1[0][0]                      
_______________________________________________________________________________

In [16]:
# Positions in pretrained_model.layers whose outputs feed the perceptual loss.
# NOTE(review): positional indices are fragile across Keras versions; selecting
# layers by name (pretrained_model.get_layer(name)) would be more robust.
selectedLayers = [4,12,16,21,30,39]

In [17]:
# Multi-output feature extractor: one output tensor per index in selectedLayers,
# sharing the frozen MobileNetV2 weights.
outputs = list(map(lambda idx: pretrained_model.layers[idx].output, selectedLayers))
loss_model = keras.Model(inputs=pretrained_model.inputs, outputs=outputs)

In [18]:
def _to_mobilenet_range(images):
    """Map images from [0, 1] to the [-1, 1] range MobileNetV2 was trained on.

    Equivalent to `mobilenet_v2.preprocess_input` applied to [0, 255] pixels
    (x / 127.5 - 1). NOTE(review): assumes the generator and the SR model work
    in [0, 1], as the /255.0 scaling in the inference cell suggests — confirm
    against data_utils.batch_generator.
    """
    return images * 2.0 - 1.0


def perceptual_loss(y_true, y_pred):
    """Sum over the selected MobileNetV2 layers of the feature-map RMSE.

    Both images are rescaled to MobileNetV2's input range and passed through the
    frozen `loss_model`. For each pair of feature maps, the square root of the
    mean squared difference is accumulated. Returns a scalar tensor.
    """
    y_true_outs = loss_model(_to_mobilenet_range(y_true))
    y_pred_outs = loss_model(_to_mobilenet_range(y_pred))
    # A single-output Model returns a bare tensor instead of a list; wrap it so
    # zip() pairs feature maps rather than iterating over the batch axis.
    if not isinstance(y_true_outs, (list, tuple)):
        y_true_outs, y_pred_outs = [y_true_outs], [y_pred_outs]
    loss = 0
    for true_feat, pred_feat in zip(y_true_outs, y_pred_outs):
        loss += keras.backend.sqrt(keras.backend.mean(keras.backend.square(true_feat - pred_feat)))
    return loss


def mse_perceptual_loss(y_true, y_pred, percep_weight=0.001):
    """Pixel MSE plus a down-weighted perceptual term.

    Args:
        y_true, y_pred: image batches in the same scale (assumed [0, 1]).
        percep_weight: multiplier on perceptual_loss (0.001 keeps the original
            weighting; Keras calls losses with two arguments, so it stays at the
            default during training).

    keras.losses.mean_squared_error averages over the last (channel) axis, so
    the MSE term is per-pixel; the scalar perceptual term broadcasts onto it.
    """
    percep_loss = perceptual_loss(y_true, y_pred)
    mse_loss = keras.losses.mean_squared_error(y_true, y_pred)
    return percep_weight * percep_loss + mse_loss
    

In [19]:
# Re-display the architecture of `model` before recompiling it with the perceptual loss.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 64, 64, 16)   448         input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 64, 64, 16)   0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 16)   64          input_1[0][0]                    
______________________________________________________________________________________________

In [20]:
# Recompile `model` (fresh Adam state, 1e-4 step size) with MSE + perceptual loss.
model.compile(
    loss=mse_perceptual_loss,
    optimizer=keras.optimizers.Adam(1e-4),
)

In [21]:
# Continue training the same `model` with the MSE + perceptual objective.
# steps_per_epoch uses floor division, so a trailing partial batch is skipped.
model.fit(
    batch_generator(
        ids=train_names,
        batch_size=batch_num,
        im_shape=im_shape,
        im_shape_final=im_shape_final,
        train_path=train_path,
    ),
    epochs=1,
    steps_per_epoch=len(train_names) // batch_num,
)



  ...
    to  
  ['...']
Train for 7392 steps
 297/7392 [>.............................] - ETA: 1:24:21 - loss: 0.0088

KeyboardInterrupt: 

In [22]:
from DIVA import DIVA2D


In [23]:
# DIVA2D stage that is chained in front of `model` when building full_model.
DIVA_model = DIVA2D(
    depth=3,
    filters=64,
    image_channels=3,
    use_bnorm=True,
)

In [24]:
# Chain DIVA_model -> model three times. The same two sub-model instances are
# reused at every stage, so their weights are shared across all three passes.
full_input = DIVA_model.input
full_output = full_input
for _stage in range(3):
    full_output = model(DIVA_model(full_output))
full_model = keras.Model(full_input, full_output)

In [25]:
# Summary of the chained pipeline; each shared sub-model is listed once with
# several inbound connections (one per stage).
full_model.summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input0 (InputLayer)             [(None, None, None,  0                                            
__________________________________________________________________________________________________
model_2 (Model)                 multiple             215296      input0[0][0]                     
                                                                 model[1][0]                      
                                                                 model[2][0]                      
__________________________________________________________________________________________________
model (Model)                   (None, 64, 64, 3)    330259      model_2[1][0]                    
                                                                 model_2[2][0]              

In [26]:
# Adam (1e-4 step size) on the MSE + perceptual loss for the whole chained pipeline.
full_model.compile(
    loss=mse_perceptual_loss,
    optimizer=keras.optimizers.Adam(1e-4),
)

In [32]:
# Train the chained DIVA -> model pipeline end to end for one pass.
# steps_per_epoch uses floor division, so a trailing partial batch is skipped.
full_model.fit(
    batch_generator(
        ids=train_names,
        batch_size=batch_num,
        im_shape=im_shape,
        im_shape_final=im_shape_final,
        train_path=train_path,
    ),
    epochs=1,
    steps_per_epoch=len(train_names) // batch_num,
)

  ...
    to  
  ['...']
Train for 7392 steps
  31/7392 [..............................] - ETA: 9:22:35 - loss: 0.0165