In [None]:
import keras
from keras.models import load_model, model_from_json
import sys
import matplotlib.pyplot as plt
%matplotlib inline
sys.path.append("./Func")

In [None]:
from loading_data import DataLoader_base
from building_network import NetworkBuild_base
from run_network import NetworkRun_base
from editing_network import NetworkEdit_base

In [None]:
import os

In [None]:
# Restrict TensorFlow/Keras to the first GPU (device index "0").
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
from keras.layers import  Input,Conv2D,BatchNormalization,Activation,Subtract,Dense,Reshape
from keras.models import Model, load_model

In [None]:
def MyDnCNN(depth, filters=64, image_shape=(256, 256, 1),  # Height, Width, Channel
            use_bnorm=True):
    """Build a DnCNN residual denoiser as a Keras model.

    Layout: Conv+ReLU, then ``depth - 2`` blocks of Conv(+BN)+ReLU, then a final
    Conv that predicts the noise. The model output is ``input - predicted_noise``.

    Args:
        depth: total number of convolution layers.
        filters: number of feature maps in the hidden convolutions.
        image_shape: (height, width, channels) of one input image.
        use_bnorm: if True, insert BatchNormalization after each hidden Conv.

    Returns:
        A ``keras.models.Model`` mapping a noisy image batch to its denoised estimate.
    """
    # Layer names carry a running counter ('conv1', 'relu2', 'conv3', 'bn4', ...);
    # keep the counting scheme unchanged so names match the saved weights.
    layer_count = 0
    inpt = Input(shape=image_shape, name='input' + str(layer_count))
    # 1st layer: Conv + ReLU
    layer_count += 1
    x = Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, 1),
               kernel_initializer='Orthogonal', padding='same',
               name='conv' + str(layer_count))(inpt)
    layer_count += 1
    x = Activation('relu', name='relu' + str(layer_count))(x)
    # depth-2 hidden layers: Conv (+ BN) + ReLU
    for _ in range(depth - 2):
        layer_count += 1
        x = Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, 1),
                   kernel_initializer='Orthogonal', padding='same', use_bias=False,
                   name='conv' + str(layer_count))(x)
        if use_bnorm:
            # BN belongs inside this branch. Previously it was dedented, so
            # use_bnorm=False still added a BN layer.
            # momentum=0.0 is kept exactly as in the original configuration.
            layer_count += 1
            x = BatchNormalization(axis=3, momentum=0.0, epsilon=0.0001,
                                   name='bn' + str(layer_count))(x)
        layer_count += 1
        x = Activation('relu', name='relu' + str(layer_count))(x)
    # last layer: Conv predicting the noise (one map per image channel)
    layer_count += 1
    x = Conv2D(filters=image_shape[-1], kernel_size=(3, 3), strides=(1, 1),
               kernel_initializer='Orthogonal', padding='same', use_bias=False,
               name='conv' + str(layer_count))(x)
    layer_count += 1
    x = Subtract(name='subtract' + str(layer_count))([inpt, x])  # input - noise
    return Model(inputs=inpt, outputs=x)

In [None]:
# 17-layer DnCNN wrapped in the project's NetworkEdit_base helper.
test_net = NetworkEdit_base(model=MyDnCNN(depth=17))

In [None]:
# Pretrained DnCNN weights (presumably trained for noise sigma = 25, per the folder name).
load_path = "./models/DnCNN_sigma25/model.h5"
test_net.LoadModel("WEIGHTS", load_path)
# summary() prints the table itself and returns None. Wrapping it in print()
# only added a stray "None" line.
test_net.model.summary()

In [None]:
from skimage.measure import compare_psnr, compare_ssim
from skimage.io import imread, imsave
import numpy as np

In [None]:
# Test image location: image 01 of the Set12 benchmark folder.
test_image_path, image_name = "./data/Test/Set12/", "01.png"

In [None]:
# Clean test image, scaled from 8-bit to [0, 1].
x = np.asarray(imread(os.path.join(test_image_path, image_name)), dtype=np.float32) / 255.0
print(x.shape)

# Noisy observation: add zero-mean Gaussian noise with sigma = 70/255, without clipping.
# NOTE(review): the weights loaded above come from a "sigma25" model, so this noise
# level is presumably outside its training range. Confirm this is intended.
np.random.seed(seed=0)  # for reproducibility
y = x + np.random.normal(0, 70/255.0, x.shape)
y = y.astype(np.float32)
print(y.shape)

# Show clean and noisy images side by side, in grayscale on a common [0, 1] scale.
# Two bare plt.imshow calls would draw on the same axes, so the clean image
# would be hidden behind the noisy one.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, img, title in zip(axes, (x, y), ("clean", "noisy (sigma=70)")):
    ax.imshow(img, cmap="gray", vmin=0.0, vmax=1.0)
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
# Add batch and channel axes: the network takes input of shape (1, 256, 256, 1).
# This must reshape the noisy image `y`. The previous code reshaped `x_predict`,
# which is only created on the next line, so it raised NameError on a fresh kernel.
# `y` is deliberately rebound to the 4-D array because later cells rely on
# that shape (`y - prediction`, `np.zeros(y.shape)`). Re-running this cell is safe.
y = np.reshape(y, (1, 256, 256, 1))
x_predict = test_net.model.predict(y)
print(x_predict.shape)

In [None]:
# Drop the batch/channel axes and show the denoised estimate as a grayscale image
# on the same [0, 1] scale as the clean image.
x_predict = np.reshape(x_predict, (256, 256))
plt.imshow(x_predict, cmap="gray", vmin=0.0, vmax=1.0);

In [None]:
# Estimation error: clean image minus the denoised estimate (signed values).
plt.imshow(np.subtract(x, x_predict))

In [None]:
# Quality of the DnCNN estimate against the clean image.
# data_range is passed explicitly because the images are floats in [0, 1].
# Without it, compare_ssim uses the float dtype range [-1, 1] (data_range = 2),
# which inflates SSIM.
psnr_x_ = compare_psnr(x, x_predict, data_range=1.0)
ssim_x_ = compare_ssim(x, x_predict, data_range=1.0)
print(psnr_x_, ssim_x_)

In [None]:
def MyDnCNN_2(depth, filters=64, image_shape=(256, 256, 1),  # Height, Width, Channel
              use_bnorm=True):
    """DnCNN noise predictor whose input image is stored as trainable weights.

    A scalar input feeds a Dense layer with ``H * W * C`` units, reshaped to
    ``image_shape``. Dense is linear (output = input @ kernel + bias), so for
    input 0 it emits its bias. Writing an image into the bias therefore makes
    the conv stack see that image.

    The DnCNN conv stack follows, but without the final Subtract. The model
    returns the predicted noise (residual) itself, not the denoised image.

    Args:
        depth: total number of convolution layers.
        filters: number of feature maps in the hidden convolutions.
        image_shape: (height, width, channels) of the image held by the Dense layer.
        use_bnorm: if True, insert BatchNormalization after each hidden Conv.

    Returns:
        A ``keras.models.Model`` with a (1,)-shaped input and the predicted noise as output.
    """
    # Conv/BN/ReLU names follow the same counter scheme as MyDnCNN ('conv1', 'relu2', ...).
    layer_count = 0
    inpt_1 = Input(shape=(1,), name='input' + str(layer_count))
    inpt = Dense(image_shape[0] * image_shape[1] * image_shape[2])(inpt_1)
    inpt = Reshape(image_shape)(inpt)
    # 1st layer: Conv + ReLU
    layer_count += 1
    x = Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, 1),
               kernel_initializer='Orthogonal', padding='same',
               name='conv' + str(layer_count))(inpt)
    layer_count += 1
    x = Activation('relu', name='relu' + str(layer_count))(x)
    # depth-2 hidden layers: Conv (+ BN) + ReLU
    for _ in range(depth - 2):
        layer_count += 1
        x = Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, 1),
                   kernel_initializer='Orthogonal', padding='same', use_bias=False,
                   name='conv' + str(layer_count))(x)
        if use_bnorm:
            # BN belongs inside this branch. Previously it was dedented, so
            # use_bnorm=False still added a BN layer.
            layer_count += 1
            x = BatchNormalization(axis=3, momentum=0.0, epsilon=0.0001,
                                   name='bn' + str(layer_count))(x)
        layer_count += 1
        x = Activation('relu', name='relu' + str(layer_count))(x)
    # last layer: Conv predicting the noise. Unlike MyDnCNN, there is no
    # input-minus-noise Subtract here.
    layer_count += 1
    x = Conv2D(filters=image_shape[-1], kernel_size=(3, 3), strides=(1, 1),
               kernel_initializer='Orthogonal', padding='same', use_bias=False,
               name='conv' + str(layer_count))(x)
    return Model(inputs=inpt_1, outputs=x)

In [None]:
# 17-layer noise predictor whose input image lives in a trainable Dense layer.
new_model = NetworkEdit_base(model=MyDnCNN_2(depth=17))

In [None]:
# Same pretrained DnCNN weights as above.
# NOTE(review): MyDnCNN_2 has extra Dense/Reshape layers and no Subtract layer.
# Loading presumably relies on LoadModel matching layers by name. Confirm.
load_path = "./models/DnCNN_sigma25/model.h5"
new_model.LoadModel("WEIGHTS", load_path)
# summary() prints the table itself and returns None. Wrapping it in print()
# only added a stray "None" line.
new_model.model.summary()

In [None]:
# Write the noisy image into the Dense layer (layers[1]).
# The kernel becomes a (1, H*W) row and the bias a flat (H*W,) vector, both
# equal to `y`, so for input 0 the layer outputs exactly `y`.
# (Alternative experiment: seed with `x_predict` instead of `y`.)
my_weight = [np.reshape(y, (1, -1)), np.reshape(y, -1)]
new_model.model.layers[1].set_weights(my_weight)

In [None]:
# Dummy input 0: the Dense layer then emits its bias (the stored image), and the
# network predicts that image's noise. Subtracting the noise from `y` gives the estimate.
net_input = np.zeros(1, dtype=np.float32)
x_predict_2 = y - new_model.model.predict(net_input)
print(x_predict_2.shape)

In [None]:
# Drop the batch/channel axes and show the estimate in grayscale on the [0, 1] scale.
x_predict_2 = np.reshape(x_predict_2, (256, 256))
plt.imshow(x_predict_2, cmap="gray", vmin=0.0, vmax=1.0);

In [None]:
# Quality of the refined estimate against the clean image.
# data_range is passed explicitly because the images are floats in [0, 1].
# Without it, compare_ssim uses the float dtype range [-1, 1] (data_range = 2),
# which inflates SSIM.
psnr_x_ = compare_psnr(x, x_predict_2, data_range=1.0)
ssim_x_ = compare_ssim(x, x_predict_2, data_range=1.0)
print(psnr_x_, ssim_x_)

In [None]:
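# Gradient backpropagation: freeze the DnCNN layers and optimise the Dense layer's weights (the stored image) so that the predicted residual goes to zero.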

In [None]:
# The cells below address layers by position: index 1 must be the Dense layer
# holding the image, and indices >= 2 are the DnCNN layers that get frozen.
# Fail early if the layout ever changes.
print(len(new_model.model.layers))
assert isinstance(new_model.model.layers[1], Dense), new_model.model.layers[1].name

In [None]:
# Inspect the Dense "image" layer: its name and the variables the optimiser may update.
print(new_model.model.get_layer(index=1).name)
print(new_model.model.get_layer(index=1).trainable_weights)

In [None]:
# Freeze everything after the Dense layer so only the stored image is optimised,
# then list each layer's trainable flag to confirm.
for layer in new_model.model.layers[2:]:
    layer.trainable = False
for layer in new_model.model.layers:
    print(layer.trainable)

In [None]:
# Re-initialise the Dense "image" layer with the current estimate `x_predict_2`.
# The kernel is a (1, H*W) row and the bias a flat (H*W,) vector.
my_weight = [np.reshape(x_predict_2, (1, -1)), np.reshape(x_predict_2, -1)]
new_model.model.layers[1].set_weights(my_weight)

In [None]:
# Training target: an all-zero residual (float64), shaped like the model output / `y`.
target = np.zeros_like(y, dtype=np.float64)
print(target.shape)

In [None]:
# Adam with a large step size (0.1) and mean absolute error between the predicted
# residual and the zero target. Layers from index 2 on were frozen above, before
# compiling.
# NOTE(review): `lr=` is the legacy Keras argument name; newer Keras spells it
# `learning_rate`.
opt = keras.optimizers.Adam(lr=0.1)
loss = keras.losses.mean_absolute_error
new_model.model.compile(optimizer=opt, loss=loss)

In [None]:
# Image currently stored in the Dense layer's bias (get_weights() -> [kernel, bias]),
# shown in grayscale and scored before any training.
new_weight = new_model.model.layers[1].get_weights()
new_image = np.reshape(new_weight[1], (256, 256))
plt.imshow(new_image, cmap="gray", vmin=0.0, vmax=1.0)
# data_range is explicit because the images are floats in [0, 1]. compare_ssim
# would otherwise use the float dtype range [-1, 1] (data_range = 2).
psnr_x_ = compare_psnr(x, new_image, data_range=1.0)
ssim_x_ = compare_ssim(x, new_image, data_range=1.0)
print(psnr_x_, ssim_x_)

In [None]:
# Optimise the stored image: 100 epochs on the single dummy input (one update each).
# NOTE(review): depending on the Keras version, frozen BatchNormalization layers may
# still normalise with this single image's batch statistics during fit(), rather than
# their stored moving averages. If so, the training objective differs from what
# predict() computes. Worth checking when judging why the method does not converge.
new_model.model.fit(net_input, target, batch_size=1, epochs=100, verbose=2)

In [None]:
# Image held in the Dense layer's bias after optimisation, in grayscale on the [0, 1] scale.
new_weight = new_model.model.layers[1].get_weights()
new_image = np.reshape(new_weight[1], (256, 256))
plt.imshow(new_image, cmap="gray", vmin=0.0, vmax=1.0);

In [None]:
# Quality of the optimised stored image against the clean image.
# data_range is passed explicitly because the images are floats in [0, 1].
# Without it, compare_ssim uses the float dtype range [-1, 1] (data_range = 2),
# which inflates SSIM.
psnr_x_ = compare_psnr(x, new_image, data_range=1.0)
ssim_x_ = compare_ssim(x, new_image, data_range=1.0)
print(psnr_x_, ssim_x_)

In [None]:
# Estimate after training: the noisy input `y` minus the residual the network now
# predicts for the image stored in the Dense layer (dummy input 0 -> Dense emits its bias).
net_input = np.asarray([0], dtype=np.float32)
x_predict_2 = y - new_model.model.predict(net_input)
print(x_predict_2.shape)
x_predict_2 = np.reshape(x_predict_2, (256, 256))
plt.imshow(x_predict_2, cmap="gray", vmin=0.0, vmax=1.0);

In [None]:
# Quality of the post-training estimate against the clean image.
# data_range is passed explicitly because the images are floats in [0, 1].
# Without it, compare_ssim uses the float dtype range [-1, 1] (data_range = 2),
# which inflates SSIM.
psnr_x_ = compare_psnr(x, x_predict_2, data_range=1.0)
ssim_x_ = compare_ssim(x, x_predict_2, data_range=1.0)
print(psnr_x_, ssim_x_)

This does not work well. The residual network trained by the paper's authors is not a "converged" network: it cannot detect whether the result we obtain is already a "successful" (clean) result. For such a result it would predict a zero residual, so no further updates would occur.

Next step: retrain the network on the dataset and check whether such a converged, "intelligent" network exists. If it does, check whether the method then works.

Concretely, the upcoming task is to rewrite some of the helper modules (the "head files" imported from `./Func`).