In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Conv2D,Input,concatenate,AveragePooling2D,Conv2DTranspose,UpSampling2D,BatchNormalization,Activation,Add
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import img_to_array,ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
from keras.preprocessing.image import load_img
from PIL import Image,UnidentifiedImageError
from skimage.metrics import peak_signal_noise_ratio as psnr,mean_squared_error as mse
from sklearn.metrics import mean_absolute_error as mae
from sklearn.model_selection import train_test_split

def load_images(image_dir,image_size=(400,600)):
    images=[]
    for img_name in sorted(os.listdir(image_dir)):
        img_path=os.path.join(image_dir,img_name)
        img=load_img(img_path,target_size=image_size)
        img=img_to_array(img)/255
        images.append(img)
    return np.array(images)

def build_denoising_model(input_shape=(400,600,3)):
    input=Input(shape=(400,600,3))

    #estimation subnetwork
    x=Conv2D(32,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(input)
    x=Conv2D(32,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(x)
    x=Conv2D(32,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(x)
    x=Conv2D(32,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(x)
    x=Conv2D(3,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(x)

    # denoising subnetwork
    x=concatenate([x,input])
    conv1=Conv2D(64,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(x)
    conv2=Conv2D(64,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv1)

    pool1=AveragePooling2D(pool_size=(2,2),padding='same')(conv2)
    conv3=Conv2D(128,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(pool1)
    conv4=Conv2D(128,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv3)
    conv5=Conv2D(128,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv4)

    pool2=AveragePooling2D(pool_size=(2,2),padding='same')(conv5)
    conv6=Conv2D(256,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(pool2)
    conv7=Conv2D(256,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv6)
    conv8=Conv2D(256,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv7)
    conv9=Conv2D(256,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv8)
    conv10=Conv2D(256,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv9)
    conv11=Conv2D(256,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv10)

    upsample1=Conv2DTranspose(128,(3,3),strides=2,activation="relu",kernel_initializer='he_normal',padding="same")(conv11)
    add1=Add()([upsample1,conv5])
    conv12=Conv2D(128,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(add1)
    conv13=Conv2D(128,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv12)
    conv14=Conv2D(128,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv13)

    upsample2=Conv2DTranspose(64,(3,3),strides=2,activation="relu",kernel_initializer='he_normal',padding="same")(conv14)
    add1=Add()([upsample2,conv2])
    conv15=Conv2D(64,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(add1)
    conv16=Conv2D(64,(3,3),activation="relu",kernel_initializer='he_normal',padding="same")(conv15)

    out=Conv2D(3,(1,1),kernel_initializer='he_normal',padding="same")(conv16)
    out=Add()([out,input])

    CBDNet=Model(input,out)
    CBDNet.compile(optimizer=tf.keras.optimizers.Adam(1e-03),loss=tf.keras.losses.MeanSquaredError())
    return CBDNet

def train_model():
    clean_images=load_images('./Train/high')
    noisy_images=load_images('./Train/low')

    model=build_denoising_model(input_shape=(400,600,3))

    early_stopping=EarlyStopping(monitor='val_loss',patience=10,restore_best_weights=True)
    model.fit(noisy_images,clean_images,epochs=20,batch_size=32,validation_split=0.1,callbacks=[early_stopping])
    return model

def evaluate_model(model,clean_images,noisy_images):
    predictions=model.predict(noisy_images)
    mse_scores=[mse(clean,pred) for clean,pred in zip(clean_images,predictions)]
    psnr_scores=[psnr(clean,pred) for clean,pred in zip(clean_images,predictions)]
    mae_scores=[mae(clean.flatten(),pred.flatten()) for clean,pred in zip(clean_images,predictions)]

    print(f"Mean MSE: {np.mean(mse_scores)}")
    print(f"Mean PSNR: {np.mean(psnr_scores)}")
    print(f"Mean MAE: {np.mean(mae_scores)}")

if __name__=='__main__':
    model=train_model()

    clean_images=load_images('./Train/high')
    noisy_images=load_images('./Train/low')

    evaluate_model(model,clean_images,noisy_images)


Epoch 1/20
[1m 1/14[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m1:03:42[0m 294s/step - loss: 0.6538

AbortedError: Graph execution error:

Detected at node StatefulPartitionedCall/functional_1_1/conv2d_transpose_1_2/conv_transpose defined at (most recent call last):
<stack traces unavailable>
Operation received an exception:Status: 1, message: could not create a primitive, in file tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc:546
	 [[{{node StatefulPartitionedCall/functional_1_1/conv2d_transpose_1_2/conv_transpose}}]] [Op:__inference_one_step_on_iterator_8082]