<a href="https://colab.research.google.com/github/mirfan35/Ridnet/blob/main/ridnet_training_(ssim_vs_assim).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# Mount Google Drive so the pickled dataset under MyDrive can be read below.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
pip install tensorflow-io

Collecting tensorflow-io
  Downloading tensorflow_io-0.33.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (28.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m28.6/28.6 MB[0m [31m55.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorflow-io
Successfully installed tensorflow-io-0.33.0


In [3]:
import tensorflow as tf
import tensorflow_io as tfio
from tensorflow import keras
from tensorflow.keras import layers, models
import cv2
import numpy as np
import pickle

In [5]:
# Report the TensorFlow build and the visible GPUs, then enable on-demand GPU memory
# allocation so TensorFlow does not reserve the whole card up front.
print(tf.__version__)

gpus = tf.config.list_physical_devices('GPU')
print(gpus)
for gpu in gpus:
    print(gpu)
    # TF2 counterpart of the TF1 `gpu_options.allow_growth` flag; it must be set before
    # the GPU is first used, so no tf.compat.v1 ConfigProto/InteractiveSession is needed.
    tf.config.experimental.set_memory_growth(gpu, True)

tf.keras.backend.clear_session()

2.12.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [6]:
# Sanity check: confirm the Colab runtime actually exposes a GPU before training.
physical_devices = tf.config.list_physical_devices('GPU')
num_gpus = len(physical_devices)

print("GPU:", physical_devices)
print("Num GPUs:", num_gpus)

GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Num GPUs: 1


In [7]:
###############################################################################################
# load image
###############################################################################################
DATASET_PATH = '/content/drive/MyDrive/python project/dataset/noise and referance (64x64).pckl'

# `with` guarantees the handle is closed even if unpickling fails.
# NOTE(security): pickle.load can execute arbitrary code; only load files you created/trust.
with open(DATASET_PATH, 'rb') as f:
    # presumably (noisy inputs, clean references) as 64x64 uint8 images — confirm against the file
    data_in, data_out = pickle.load(f)

In [8]:
###############################################################################################
# normalize
###############################################################################################
x_train = np.float32(data_in)/255
y_train = np.float32(data_out)/255

## random shuffle training data (optional) ##
# A fixed seed makes the shuffle, and therefore the last-10% slice that
# `validation_split` uses as validation data, identical on every run.
SHUFFLE_SEED = 42
# The two shuffles below only stay paired when both arrays have the same length.
assert len(x_train) == len(y_train), "noisy and reference arrays must be paired 1:1"
# Two generators with the same seed yield the same permutation for equal lengths;
# shuffling in place avoids materialising a second copy of the dataset.
np.random.default_rng(SHUFFLE_SEED).shuffle(x_train)
np.random.default_rng(SHUFFLE_SEED).shuffle(y_train)

In [9]:
###############################################################################################
# Enhancement Attention Modules (EAM)
###############################################################################################
def EAM(input, filters=None):
    """Enhancement Attention Module (EAM) from RIDNet, built with the Keras functional API.

    Args:
        input: 4-D feature tensor (batch, H, W, C) (Conv2D requires rank 4).
        filters: channel width of every internal convolution. Defaults to the channel
            count of `input`, which the residual Add layers require to match
            (64 in this notebook).

    Returns:
        A tensor with the same shape as `input`.
    """
    if filters is None:
        # Fall back to 64 only if the channel dimension is statically unknown.
        filters = input.shape[-1] or 64

    # Merge-and-run unit: two parallel dilated branches (dilation 1->2 and 3->4),
    # concatenated, fused by a conv, plus a skip from the module input.
    conv1 = layers.Conv2D(filters, (3, 3), dilation_rate=1, padding='same', activation='relu')(input)
    conv1 = layers.Conv2D(filters, (3, 3), dilation_rate=2, padding='same', activation='relu')(conv1)

    conv2 = layers.Conv2D(filters, (3, 3), dilation_rate=3, padding='same', activation='relu')(input)
    conv2 = layers.Conv2D(filters, (3, 3), dilation_rate=4, padding='same', activation='relu')(conv2)

    concat = layers.concatenate([conv1, conv2])
    conv3 = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(concat)
    add1 = layers.Add()([input, conv3])

    # Residual block: conv-relu-conv, skip, relu.
    conv4 = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(add1)
    conv4 = layers.Conv2D(filters, (3, 3), padding='same')(conv4)
    add2 = layers.Add()([conv4, add1])
    add2 = layers.Activation('relu')(add2)

    # Enhanced residual block: two 3x3 convs and a 1x1 conv, skip, relu.
    conv5 = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(add2)
    conv5 = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(conv5)
    conv5 = layers.Conv2D(filters, (1, 1), padding='same')(conv5)
    add3 = layers.Add()([add2, conv5])
    add3 = layers.Activation('relu')(add3)

    # Channel attention: global average pool -> conv(relu) -> conv(sigmoid) per-channel gate.
    gap = layers.GlobalAveragePooling2D()(add3)
    gap = layers.Reshape((1, 1, filters))(gap)
    # NOTE(review): 3x3 kernels on a 1x1 map with 'same' zero padding only use the centre
    # tap (1x1 convs would compute the same thing with fewer weights); kept as-is so
    # parameter counts and saved weights stay compatible.
    conv6 = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(gap)
    conv6 = layers.Conv2D(filters, (3, 3), padding='same', activation='sigmoid')(conv6)

    mul = layers.Multiply()([conv6, add3])
    out = layers.Add()([input, mul])  # outer skip: not included in the reference code
    return out

In [10]:
###############################################################################################
# RIDNet residual denoiser (https://medium.com/analytics-vidhya/image-denoising-using-deep-learning-dc2b19a3fd54)
###############################################################################################
def build_ridnet(input_shape=(64, 64, 3), n_eam=4):
    """Build RIDNet: a 3x3 feature conv, `n_eam` stacked EAM blocks, and a 3x3
    reconstruction conv, all wrapped in a global skip connection.

    Because the reconstruction output is added back onto the input, the body
    learns a residual correction to the noisy image.

    Args:
        input_shape: (H, W, C) of the input images.
        n_eam: number of chained EAM blocks (4 in the original notebook).

    Returns:
        An uncompiled keras.Model mapping images of `input_shape` to the same shape.
    """
    # Named `noisy` rather than `input` so the Python builtin is not shadowed.
    noisy = keras.Input(shape=input_shape)
    features = layers.Conv2D(64, (3, 3), padding='same')(noisy)
    for _ in range(n_eam):
        features = EAM(features)
    residual = layers.Conv2D(input_shape[-1], (3, 3), padding='same')(features)
    denoised = layers.Add()([noisy, residual])
    return keras.Model(noisy, denoised)


# Start from a fresh Keras global state before building the model.
tf.keras.backend.clear_session()
RIDNet = build_ridnet()

In [11]:
##################################################################################
# MSSIM ported from the OpenCV tutorial's getMSSIM (https://docs.opencv.org/4.x/d5/dc4/tutorial_video_input_psnr_ssim.html)
##################################################################################
def MSSIM(I1, I2, win=11):
    """Mean SSIM between two image tensors, ported from OpenCV's getMSSIM tutorial.

    Local statistics use a `win` x `win` Gaussian window (sigma 1.5) with
    reflect padding. The per-pixel SSIM map is averaged over every element
    (batch, spatial and channel axes) into a single scalar.

    Args:
        I1, I2: image tensors on a 0-255 scale (the C1/C2 constants assume it);
            presumably (batch, H, W, C) as accepted by tfio's gaussian filter — confirm.
        win: Gaussian window size.

    Returns:
        Scalar tensor with the mean SSIM.
    """
    # Stabilisers: (0.01*255)**2 and (0.03*255)**2.
    C1 = 6.5025
    C2 = 58.5225

    def local_mean(t):
        # Gaussian-weighted local average (the E[.] in the SSIM formulas).
        return tfio.experimental.filter.gaussian(t, ksize=[win, win], sigma=1.5, mode='REFLECT')

    mu1 = local_mean(I1)
    mu2 = local_mean(I2)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance via E[XY] - E[X]E[Y].
    var1 = local_mean(I1**2) - mu1_sq
    var2 = local_mean(I2**2) - mu2_sq
    cov12 = local_mean(I1 * I2) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * cov12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (var1 + var2 + C2)

    return tf.reduce_mean(numerator / denominator)

In [12]:
##################################################################################
# absolute mssim
##################################################################################
def absoluteMSSIM(I1, I2, win=11, k=450, data_range=255):
    """Mean "absolute" SSIM: a luminance term times a contrast term, averaged.

    Per pixel:
        luminance = 1 - |mu1 - mu2| / data_range
        contrast  = (min(var1, var2) + k) / (max(var1, var2) + k)   (via MinMaxRatio)
    Local means and variances use a `win` x `win` Gaussian window (sigma 1.5, reflect padding).
    Unlike standard SSIM there is no structure (covariance) term.

    Args:
        I1, I2: image tensors on a 0-`data_range` scale; presumably (batch, H, W, C) — confirm.
        win: Gaussian window size.
        k: additive constant in the contrast ratio; larger values make the
            contrast term less sensitive (450 is the original hard-coded value).
        data_range: dynamic range used to normalise the luminance difference.

    Returns:
        Scalar tensor: the mean of luminance * contrast over all elements.
    """
    def local_mean(t):
        # Gaussian-weighted local average.
        return tfio.experimental.filter.gaussian(t, ksize=[win, win], sigma=1.5, mode='REFLECT')

    mu1 = local_mean(I1)
    mu2 = local_mean(I2)
    # Local variances via E[X^2] - E[X]^2. The covariance is deliberately not computed:
    # this index never uses it, and each extra filter pass costs time on every loss call.
    sigma1_2 = local_mean(I1**2) - mu1**2
    sigma2_2 = local_mean(I2**2) - mu2**2

    lm = 1 - tf.math.abs(mu1 - mu2)/data_range  # luminance
    cn = MinMaxRatio(sigma1_2, sigma2_2, thr=k)  # contrast

    ssim_map = lm*cn
    return tf.reduce_mean(ssim_map)

##################################################################################
# ratio of the smaller value to the larger one (a/b if a<b), offset by thr: (min(a,b) + thr) / (max(a,b) + thr)
##################################################################################
def MinMaxRatio(a, b, thr=0):
    """Element-wise (min(a, b) + thr) / (max(a, b) + thr), written with abs instead of min/max.

    Uses a + b - |a - b| == 2*min and a + b + |a - b| == 2*max, so the factor 2
    cancels. The result lies in (0, 1] for non-negative inputs with thr > 0.
    With thr == 0 and a == b == 0 the expression is 0/0 (NaN).
    """
    total = a + b
    spread = tf.math.abs(a - b)
    offset = 2 * thr
    numerator = total - spread + offset
    denominator = total + spread + offset
    return numerator / denominator

In [13]:
###############################################################################################
# loss function (ssim and absolute ssim)
###############################################################################################
def ssim_loss(y_true, y_pred):
    """Loss = 1 - MSSIM. Inputs in [0, 1] are rescaled to 0-255, the range MSSIM's constants assume."""
    similarity = MSSIM(y_true * 255, y_pred * 255)
    return 1 - similarity


def assim_loss(y_true, y_pred):
    """Loss = 1 - absoluteMSSIM. Inputs in [0, 1] are rescaled to the 0-255 range."""
    similarity = absoluteMSSIM(y_true * 255, y_pred * 255)
    return 1 - similarity


# Train on the SSIM objective. To compare objectives, swap `loss=` for
# assim_loss or tf.keras.losses.MeanSquaredError().
RIDNet.compile(
    optimizer=tf.keras.optimizers.Adam(1e-03),
    loss=ssim_loss,
    run_eagerly=True,
)

In [None]:
###############################################################################################
# Training
###############################################################################################
RIDNet.summary()  # summary() prints the table itself and returns None

model_name = 'RIDNet(ssim).h5'
# save_best_only=True keeps the lowest-val_loss epoch on disk. Without it the file is
# overwritten every epoch, so `monitor` has no effect and the reload below would just
# return the last epoch (up to `patience` epochs past the best one).
check_point = tf.keras.callbacks.ModelCheckpoint(model_name, monitor='val_loss', save_best_only=True)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
history = RIDNet.fit(x_train, y_train, epochs=30, validation_split=0.1, batch_size=4,
                     callbacks=[check_point, early_stopping])

# Reload the best checkpoint without compiling it. The stored loss is the custom
# `ssim_loss`, which is replaced immediately below, so it does not need to be
# deserialized. Mapping it under the wrong key (e.g. "assim_loss") makes the load fail.
# The architecture uses only built-in Keras layers (EAM is a plain builder function),
# so no custom_objects are needed.
RIDNet = models.load_model(model_name, compile=False)
RIDNet.compile(optimizer=tf.keras.optimizers.Adam(1e-03), loss=tf.keras.losses.MeanSquaredError())  # replace loss function after training
RIDNet.save(model_name)

print('done')

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 64, 64, 64)   1792        ['input_1[0][0]']                
                                                                                                  
 conv2d_1 (Conv2D)              (None, 64, 64, 64)   36928       ['conv2d[0][0]']                 
                                                                                                  
 conv2d_3 (Conv2D)              (None, 64, 64, 64)   36928       ['conv2d[0][0]']                 
                                                                                              



