## RM-Net with a frequency-domain loss function

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

from sklearn.model_selection import train_test_split
from skimage.util import random_noise

from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, \
    BatchNormalization, concatenate, Dropout, Activation, ReLU, Add

from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array, load_img
from tensorflow.keras.optimizers import SGD, Adam

import tensorflow as tf

from glob import glob
import random
import datetime

In [None]:
# Parameters
##############################################################
modelName = 'FPD_M_net_weights'
# NOTE(review): lrate and decay_Rate are not read by any visible cell; the optimizer in
# getFPDMNet uses its own hard-coded learning rate.
lrate = 0.1
decay_Rate = 1e-6

img_width = 272
img_height = 360
# (width, height): the dsize order that cv2.resize expects.
img_shape = (img_width, img_height)
# Square convolution kernel sizes, keyed by their side length as a string.
conv_window = {str(side): (side, side) for side in (3, 5, 7, 9, 11)}
res_blocks = 9
model_label = 'm70kFFTPlusDCT'

# Powers of two 2**5 .. 2**8, i.e. 32 up to 256 filters.
filters = [2 ** exponent for exponent in range(5, 9)]
##############################################################
filters

[32, 64, 128, 256]

In [None]:
batch_size = 8


# Generator that loads (input, target) image pairs from disk one batch at a time.
class Custom_Generator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of grayscale images scaled by 1/255.

    ``image_filenames`` (inputs) and ``labels`` (targets) are parallel sequences of
    image paths; batch ``idx`` uses the same slice of both. Every image is read with
    OpenCV, resized to ``img_shape`` (OpenCV ``dsize`` order: width, height), averaged
    over its color channels, divided by 255 and returned as float32 with a trailing
    channel axis, i.e. ``(batch, img_height, img_width, 1)``.
    """

    def __init__(self, image_filenames, labels, batch_size):
        super().__init__()
        self.image_filenames = image_filenames
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a trailing partial batch. Uses the builtin int:
        # np.int was deprecated in NumPy 1.20 and removed in NumPy 1.24.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    @staticmethod
    def _load_gray_batch(paths):
        """Read, resize and grayscale-normalize ``paths`` into one float32 batch."""
        images = []
        for path in paths:
            bgr = cv2.imread(path)
            if bgr is None:
                # cv2.imread reports a missing/unreadable file by returning None.
                raise FileNotFoundError(f'Could not read image: {path}')
            images.append(cv2.resize(img_to_array(bgr), img_shape))
        batch = np.stack(images)  # (batch, height, width, 3)
        # Grayscale is the plain channel mean, so BGR vs RGB order does not matter.
        gray = batch.mean(axis=-1, keepdims=True)
        return gray.astype('float32') / 255.

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size
        inputs = self._load_gray_batch(self.image_filenames[start:stop])
        targets = self._load_gray_batch(self.labels[start:stop])
        return inputs, targets


# Wrap the train/validation splits in batch generators.
# NOTE(review): X_train / Y_train / X_val / Y_val are not defined in this section;
# presumably lists of input/target image paths built in an earlier cell -- confirm.
training_batch_generator = Custom_Generator(X_train, Y_train, batch_size)
val_batch_generator = Custom_Generator(X_val, Y_val, batch_size)

In [None]:
## Image-reconstruction metrics
def SSIM(y_true, y_pred):
    """Single-scale SSIM between two image batches, assuming a dynamic range of 1.0."""
    return tf.image.ssim(y_true, y_pred, max_val=1.0)


def SSIM_multi(y_true, y_pred):
    """Multi-scale SSIM between two image batches, assuming a dynamic range of 1.0."""
    return tf.image.ssim_multiscale(y_true, y_pred, max_val=1.0)

In [None]:
# Loss functions
def l1_loss(y_true, y_pred):
    """Mean absolute difference between target and prediction over every element."""
    return tf.math.reduce_mean(tf.math.abs(y_true - y_pred))

def f_loss(y_true, y_pred):
    """Multi-scale L1 distance between the 2-D Fourier spectra of target and prediction.

    Three stages are summed: the tensors at their own resolution, and both tensors
    bicubically resized to 256x256 and to 128x128. Each stage is the mean absolute
    difference of the complex FFT coefficients. Inputs are channels-last image
    batches (the layout ``tf.image.resize`` expects).
    """
    def spectrum(img):
        # Move channels in front of (H, W) so fft2d's two innermost axes are spatial.
        # fft3d on channels-last data would also transform the channel axis: harmless
        # for single-channel images, but it would mix channels otherwise.
        rank = img.shape.rank
        perm = list(range(rank - 3)) + [rank - 1, rank - 3, rank - 2]
        return tf.signal.fft2d(tf.cast(tf.transpose(img, perm), tf.complex64))

    total = 0.0
    for size in (None, (256, 256), (128, 128)):
        a, b = y_true, y_pred
        if size is not None:
            a = tf.image.resize(a, size, method='bicubic')
            b = tf.image.resize(b, size, method='bicubic')
        total += tf.math.reduce_mean(tf.abs(spectrum(a) - spectrum(b)))
    return total


def dct_loss(y_true, y_pred):
    """Multi-scale L1 distance between the 2-D DCT-II spectra of target and prediction.

    Three stages (own resolution, 256x256, 128x128 via bicubic resize) of mean absolute
    coefficient differences are combined and scaled down (see the /1200 factors).

    ``tf.signal.dct`` only transforms the innermost axis. On channels-last images that
    is the channel axis (length 1 here), where an unnormalized DCT-II merely doubles
    each value, which would make this a scaled pixel-space L1 loss. The DCT is
    therefore applied along the width axis and then along the height axis.
    """
    def spatial_dct(img):
        img = tf.cast(img, tf.float32)
        rank = img.shape.rank
        # (..., H, W, C) -> (..., C, H, W): put width innermost for the first DCT.
        perm = list(range(rank - 3)) + [rank - 1, rank - 3, rank - 2]
        along_w = tf.signal.dct(tf.transpose(img, perm), type=2)
        # Swap (H, W) -> (W, H) and transform along height. The transposed layout is
        # fine because only an element-wise mean of the coefficients is used.
        return tf.signal.dct(tf.linalg.matrix_transpose(along_w), type=2)

    def stage(a, b):
        return tf.math.reduce_mean(tf.abs(spatial_dct(a) - spatial_dct(b)))

    def resized(img, side):
        return tf.image.resize(img, (side, side), method='bicubic')

    first_stage = stage(y_true, y_pred)
    # NOTE(review): this stage is divided by 1200 here and again in the return
    # statement (1/1,440,000 overall) while the other two are divided once; kept as-is
    # because training may have been tuned around it -- confirm it is intentional.
    second_stage = stage(resized(y_true, 256), resized(y_pred, 256)) / 1200
    third_stage = stage(resized(y_true, 128), resized(y_pred, 128))
    # Unnormalized DCT-II coefficients grow with image size; the /1200 presumably
    # brings this term down to the scale of the pixel losses -- TODO confirm.
    return (first_stage + second_stage + third_stage) / 1200


def fftdctl1_loss(y_true, y_pred):
    """Blend pixel-space L1 (weight 1 - delta) with the DCT + FFT losses (weight delta)."""
    delta = 0.10
    spatial = l1_loss(y_true, y_pred)
    spectral = dct_loss(y_true, y_pred) + f_loss(y_true, y_pred)
    return (1 - delta) * spatial + delta * spectral

In [None]:
def residual_block(x, num_filters):
    """Apply two 3x3 conv + batch-norm layers, add the result back onto ``x``, then ReLU.

    The shortcut is an identity ``Add``, so ``x`` must already have ``num_filters``
    channels. Layers are constructed in a fixed order (conv, BN, ReLU, dropout, conv,
    BN, add, ReLU); keep it stable, since Keras derives auto-generated layer names
    from construction order.
    """
    branch = x
    for layer in (
        Conv2D(num_filters, conv_window['3'], padding='same'),
        BatchNormalization(),
        ReLU(),
        Dropout(0.2),
        Conv2D(num_filters, conv_window['3'], padding='same'),
        BatchNormalization(),
    ):
        branch = layer(branch)
    merged = Add()([x, branch])
    return ReLU()(merged)

In [None]:
# Define the neural network
def getFPDMNet():
    """Build and compile the encoder/decoder network used in this notebook.

    Architecture (all convolutions use 'same' padding):
      * Encoder: four stages at full, 1/2, 1/4 and 1/8 input resolution with
        32/64/128/256 filters and 9x9/7x7/5x5/3x3 kernels. Every stage after the first
        also receives a max-pooled copy of the raw input.
      * Bottleneck: ``res_blocks`` residual blocks with 256 filters at 1/8 resolution.
      * Decoder: three upsampling stages (128/64/32 filters) concatenated with the
        encoder features of matching resolution.
      * Head: features from every scale are upsampled to full resolution, concatenated,
        and mapped to one sigmoid channel by a 1x1 convolution.

    The input shape is ``(img_height, img_width, 1)``. Layer construction order is
    kept fixed because Keras derives auto-generated layer names from it.

    Returns:
        A compiled ``Model`` (Adam optimizer, ``fftdctl1_loss``, multi-scale SSIM and
        mean-absolute-error metrics). ``model.summary()`` is printed as a side effect.
    """
    # Input
    input1 = Input(shape=(img_height, img_width, 1))

    # Encoder stage 1: full resolution, 9x9 kernels
    conv1 = Conv2D(32, conv_window['9'], padding='same')(input1)
    conv1 = BatchNormalization()(conv1)
    conv1 = Activation('relu')(conv1)
    conv1 = Dropout(0.2)(conv1)

    conv1 = concatenate([input1, conv1], axis=-1)
    conv1 = Conv2D(32, conv_window['9'], padding='same')(conv1)
    conv1 = BatchNormalization()(conv1)
    conv1 = Activation('relu')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    # Encoder stage 2: 1/2 resolution, 7x7 kernels, plus the pooled raw input
    input2 = MaxPooling2D(pool_size=(2, 2))(input1)
    conv21 = concatenate([input2, pool1], axis=-1)

    conv2 = Conv2D(64, conv_window['7'], padding='same')(conv21)
    conv2 = BatchNormalization()(conv2)
    conv2 = Activation('relu')(conv2)
    conv2 = Dropout(0.2)(conv2)

    conv2 = concatenate([conv21, conv2], axis=-1)
    conv2 = Conv2D(64, conv_window['7'], padding='same')(conv2)
    conv2 = BatchNormalization()(conv2)
    conv2 = Activation('relu')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    # Encoder stage 3: 1/4 resolution, 5x5 kernels, plus the pooled raw input
    input3 = MaxPooling2D(pool_size=(2, 2))(input2)
    conv31 = concatenate([input3, pool2], axis=-1)

    conv3 = Conv2D(128, conv_window['5'], padding='same')(conv31)
    conv3 = BatchNormalization()(conv3)
    conv3 = Activation('relu')(conv3)
    conv3 = Dropout(0.2)(conv3)

    conv3 = concatenate([conv31, conv3], axis=-1)
    conv3 = Conv2D(128, conv_window['5'], padding='same')(conv3)
    conv3 = BatchNormalization()(conv3)
    conv3 = Activation('relu')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    # Encoder stage 4: 1/8 resolution, 3x3 kernels, plus the pooled raw input
    input4 = MaxPooling2D(pool_size=(2, 2))(input3)
    conv41 = concatenate([input4, pool3], axis=-1)

    conv4 = Conv2D(256, conv_window['3'], padding='same')(conv41)
    conv4 = BatchNormalization()(conv4)
    conv4 = Activation('relu')(conv4)
    conv4 = Dropout(0.2)(conv4)

    conv4 = concatenate([conv41, conv4], axis=-1)
    conv4 = Conv2D(256, conv_window['3'], padding='same')(conv4)
    conv4 = BatchNormalization()(conv4)
    conv4 = Activation('relu')(conv4)
    conv4 = Dropout(0.2)(conv4)

    conv4 = Conv2D(256, conv_window['3'], padding='same')(conv4)
    conv4 = BatchNormalization()(conv4)
    conv4 = Activation('relu')(conv4)

    # Bottleneck: `res_blocks` residual blocks (256 filters) at 1/8 resolution
    res = conv4
    for _ in range(res_blocks):
        res = residual_block(res, 256)

    # Decoder stage 1: upsample to 1/4 resolution, skip from encoder stage 3
    conv5 = UpSampling2D(size=(2, 2))(res)
    conv51 = concatenate([conv3, conv5], axis=-1)

    conv5 = Conv2D(128, conv_window['5'], padding='same')(conv51)
    conv5 = BatchNormalization()(conv5)
    conv5 = Activation('relu')(conv5)
    conv5 = Dropout(0.2)(conv5)

    conv5 = concatenate([conv51, conv5], axis=-1)
    conv5 = Conv2D(128, conv_window['5'], padding='same')(conv5)
    conv5 = BatchNormalization()(conv5)
    conv5 = Activation('relu')(conv5)

    # Decoder stage 2: upsample to 1/2 resolution, skip from encoder stage 2
    conv6 = UpSampling2D(size=(2, 2))(conv5)
    conv61 = concatenate([conv2, conv6], axis=-1)

    conv6 = Conv2D(64, conv_window['7'], padding='same')(conv61)
    conv6 = BatchNormalization()(conv6)
    conv6 = Activation('relu')(conv6)
    conv6 = Dropout(0.2)(conv6)

    conv6 = concatenate([conv61, conv6], axis=-1)
    conv6 = Conv2D(64, conv_window['7'], padding='same')(conv6)
    conv6 = BatchNormalization()(conv6)
    conv6 = Activation('relu')(conv6)

    # Decoder stage 3: upsample to full resolution, skip from encoder stage 1
    conv7 = UpSampling2D(size=(2, 2))(conv6)
    conv71 = concatenate([conv1, conv7], axis=-1)

    conv7 = Conv2D(32, conv_window['9'], padding='same')(conv71)
    conv7 = BatchNormalization()(conv7)
    conv7 = Activation('relu')(conv7)
    conv7 = Dropout(0.2)(conv7)

    conv7 = concatenate([conv71, conv7], axis=-1)
    conv7 = Conv2D(32, conv_window['9'], padding='same')(conv7)
    conv7 = BatchNormalization()(conv7)
    conv7 = Activation('relu')(conv7)

    # Head: fuse every scale at full resolution.
    # NOTE(review): conv81 taps conv4 (before the residual blocks), not `res` -- confirm
    # this is intended.
    conv81 = UpSampling2D(size=(8, 8))(conv4)
    conv82 = UpSampling2D(size=(4, 4))(conv5)
    conv83 = UpSampling2D(size=(2, 2))(conv6)
    conv8 = concatenate([conv81, conv82, conv83, conv7], axis=-1)
    conv8 = Conv2D(1, (1, 1), activation='sigmoid')(conv8)

    ############
    model = Model(inputs=input1, outputs=conv8)
    adam = Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-07)
    model.compile(
        optimizer=adam,
        loss=fftdctl1_loss,
        # Accuracy() counts exact equality of y_true and y_pred, which almost never
        # holds for continuous sigmoid outputs; mean absolute error is meaningful here.
        metrics=[SSIM_multi, tf.keras.metrics.MeanAbsoluteError()],
    )
    model.summary()

    return model

In [None]:
# Build and compile the network; getFPDMNet also prints model.summary().
ae_01 = getFPDMNet()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 360, 272, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 360, 272, 32  2624        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 360, 272, 32  128        ['conv2d[0][0]']                 
 alization)                     )                                                             

In [None]:
# Callbacks
logdir = f'/content/drive/MyDrive/logs/{model_label}/'

# Stop after 25 epochs without val_loss improvement and roll back to the best weights.
earlystop = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=25,
    verbose=1,
    restore_best_weights=True,
)

tb = TensorBoard(log_dir=logdir, write_graph=True)

# Doubled braces survive the f-string as literal {epoch...}/{val_loss...} placeholders,
# which Keras fills in when it writes each checkpoint.
model_checkpoint = ModelCheckpoint(
    filepath=f'/content/drive/MyDrive/saves/{model_label}.{{epoch:02d}}-{{val_loss:.3f}}.h5',
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

In [None]:
# ae_01.load_weights('m50k.85-0.022.h5')

In [None]:
begin = datetime.datetime.now()

# Train the autoencoder.
# NOTE(review): initial_epoch=20 only offsets the epoch counter; the load_weights line
# in the previous cell is commented out, so confirm weights were restored beforehand.
ae_01.fit(
    training_batch_generator,
    epochs=1500,
    initial_epoch=20,
    shuffle=True,
    validation_data=val_batch_generator,
    callbacks=[tb, earlystop, model_checkpoint],
)
ae_01.save_weights(model_label + '.h5')
print('TEMPO TOTAL: ')
print(datetime.datetime.now() - begin)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  del sys.path[0]


Epoch 21/1500
Epoch 22/1500
Epoch 23/1500
Epoch 24/1500
Epoch 25/1500
Epoch 26/1500
Epoch 27/1500
Epoch 28/1500
Epoch 29/1500