In [None]:
import datetime
import os

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras

from data import read_data
from utils import add_noise_est, normalize, add_noise
from model_noise_est import FCN

# Report the TF version and whether any GPU is visible.
# tf.test.is_gpu_available() is deprecated; list the physical GPU devices instead.
gpu_ok = len(tf.config.list_physical_devices('GPU')) > 0
print("tf version:", tf.__version__)
print("use GPU:", gpu_ok)

In [None]:
# Load the clean images. The 4-way unpack below requires a 4-D array
# (presumably (N, H, W, C) — confirm against data.read_data).
ims = read_data('imagenet')

N_ims, h, w, _ = ims.shape
# ims[:N_ims] currently keeps every image (N_ims is the full count); the slice is a
# hook for training on a subset.
# The noisy copies (ims_noise) are generated in the next cell; they do not exist yet
# here, so nothing noisy is converted in this cell.
ims = ims[:N_ims].astype(np.float32)

In [None]:
# The noise-estimation map needs a range of noise levels: split the clean images into
# len(variances) near-equal groups and corrupt each group with Gaussian noise of a
# different variance.
variances = np.arange(1, 11) * 5e-4   # 5e-4, 1e-3, ..., 5e-3
# variances = [80e-4]                 # single-level alternative
ims_split = np.array_split(ims, len(variances))

noisy_parts = []
noisy_with_est_parts = []
for var, clean_part in zip(variances, ims_split):
    noisy_part = normalize(add_noise(clean_part, mean=0, var=var, n_type='gaussian'))
    noisy_parts.append(noisy_part)
    # add_noise_est attaches the per-group noise information to the noisy images.
    noisy_with_est_parts.append(add_noise_est(noisy_part, var = var))

ims_noise = np.concatenate(noisy_parts)
ims_noise_with_est = np.concatenate(noisy_with_est_parts)

for arr in (ims, ims_noise, ims_noise_with_est):
    print(arr.shape)

In [None]:
# training hyperparameters
batch_size = 16
lr = 3e-4          # initial learning rate (decayed per epoch by lr_fn)
epochs = 80
test_size = 0.1    # fraction of samples held out by train_test_split

# Number of optimizer steps covering `epochs` passes over the training split.
training_steps = int(epochs*N_ims*(1-test_size)/batch_size)
# Evaluate about 5 times per epoch (every 0.2 of an epoch). Clamped to >= 1 so the
# `step % display_step` check in the training loop cannot divide by zero when an
# epoch has fewer than 5 steps.
display_step = max(1, int(training_steps/epochs*0.2))

print(training_steps)
print(display_step)

In [None]:
# Train/test split: channel 0 of ims_noise_with_est is the network input and channel 1
# the regression target; both keep a trailing singleton channel axis.
noisy_input = ims_noise_with_est[..., 0][..., np.newaxis]
target_map = ims_noise_with_est[..., 1][..., np.newaxis]
train_X, test_X, train_Y, test_Y = train_test_split(
    noisy_input, target_map, test_size=test_size, random_state=42)

for label, arr in (('Training X: ', train_X), ('Training Y: ', train_Y),
                   ('Testing X: ', test_X), ('Testing Y: ', test_Y)):
    print(label, arr.shape, arr.dtype)

In [None]:
# Use tf.data API to shuffle and batch data.
# shuffle() comes BEFORE repeat() so epoch boundaries stay intact: every training
# sample is visited once per pass, and each pass is reshuffled. With repeat() first,
# the shuffle buffer mixes samples from consecutive passes, so an "epoch" (as used
# by lr_fn) can contain duplicates and miss samples.
train_dataset = (tf.data.Dataset.from_tensor_slices((train_X, train_Y))
                 .shuffle(5000)
                 .repeat()
                 .batch(batch_size)
                 .prefetch(1))

# The test pipeline is not shuffled, so its batch order follows test_X / test_Y.
test_dataset = (tf.data.Dataset.from_tensor_slices((test_X, test_Y))
                .batch(batch_size)
                .prefetch(1))

In [None]:
# model: FCN with color=False (presumably single-channel input — confirm in
# model_noise_est) and both attention modules disabled.
model = FCN(
    color=False,
    channels=[16, 16, 32, 32, 64, 64, 32, 32, 16, 16],
    channel_att=False,
    spatial_att=False,
    use_bias=True,
)

# Warm-start from the checkpoint that the save cell at the end of the notebook writes.
load_model = True
if load_model:
    model.load_weights(filepath="model_weights/model_noise_est.ckpt")

# Parameter count, if needed:
# print(np.sum([np.prod(v.get_shape().as_list()) for v in model.trainable_variables]))

In [None]:
# optimizer: Adam starting at `lr`; run_optimization rescales it per epoch via lr_fn.
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
# alternative: tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.9, nesterov=True, decay=1e-6)

# loss: mean squared error between target and predicted maps.
# alternative: tf.keras.losses.MeanAbsoluteError()
loss_func = tf.keras.losses.MeanSquaredError()

In [None]:
# optimization process
def lr_fn(step, cur_lr):
    '''Exponential per-epoch decay of the learning rate.

    An epoch is int(N_ims*(1-test_size)) samples; with batch_size samples per step,
    `cur_lr` is multiplied by 0.95 once for every epoch boundary crossed between
    step-1 and step (normally 0 or 1 times). Reads the notebook globals
    N_ims, test_size and batch_size.
    '''
    samples_per_epoch = int(N_ims*(1-test_size))
    epoch_now = step * batch_size // samples_per_epoch
    epoch_before = (step-1) * batch_size // samples_per_epoch
    return cur_lr * (0.95**(epoch_now - epoch_before))

def run_optimization(step, train_X, train_Y):
    '''Run one optimizer step on a batch and return that batch's loss.

    Args:
        step: 1-based training step; passed to lr_fn, which decays the learning
            rate when an epoch boundary is crossed.
        train_X: batch of model inputs.
        train_Y: batch of target maps.

    Returns:
        Scalar loss tensor computed before the weight update.
    '''
    with tf.GradientTape() as tape:
        # training=True so any train/inference-dependent layers in FCN
        # (e.g. BatchNorm, Dropout, if present) run in training mode.
        pred_Y = model(train_X, training=True)
        # Keras losses take (y_true, y_pred) in that order.
        loss = loss_func(train_Y, pred_Y)

    gradients = tape.gradient(loss, model.trainable_variables)
    # Update the learning rate before applying this step's gradients.
    optimizer.learning_rate = lr_fn(step, optimizer.learning_rate.numpy())
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss

In [None]:
# Writing TensorBoard summaries turned out to make training very, very slow,
# so per-step metrics are kept in plain Python lists instead.
train_losses, test_losses, test_steps, lrs = [], [], [], []

In [None]:
# Custom training loop. Each iteration runs one optimizer step; every `display_step`
# steps it prints the mean training loss since the previous report and the mean
# per-batch loss over the full test pipeline. Results are appended to the lists
# created in the previous cell, so re-running this cell keeps accumulating into them.
mean_train_loss = total_train = 0
for step, (batch_X, batch_Y) in enumerate(train_dataset.take(training_steps), start = 1):
    train_loss = run_optimization(step, batch_X, batch_Y).numpy()

    mean_train_loss += train_loss
    total_train += 1
    train_losses.append(train_loss)
    # Read the same attribute run_optimization updates; `optimizer.lr` is only a
    # legacy alias that newer Keras versions may not provide.
    lrs.append(optimizer.learning_rate.numpy())

    if step % display_step == 0:
        mean_test_loss = total_test = 0
        for (batch_test_X, batch_test_Y) in test_dataset:
            # training=False: evaluate in inference mode.
            pred_test_Y = model(batch_test_X, training=False)
            # Keras losses take (y_true, y_pred) in that order.
            test_loss = loss_func(batch_test_Y, pred_test_Y)

            mean_test_loss += test_loss.numpy()
            total_test += 1

        mean_test_loss /= total_test
        mean_train_loss /= total_train
        test_losses.append(mean_test_loss)
        test_steps.append(step)

        print("step: {:3d}/{:3d} || train loss: {:.5f} || test loss: {:.5f}"
              .format(step, training_steps, mean_train_loss, mean_test_loss))

        # Reset the running training-loss average for the next reporting window.
        mean_train_loss = total_train = 0

In [None]:
# Naming shared by the saved figures and weights: a YYYYmmdd-HHMMSS timestamp,
# the output directory for curves, and the base file name.
current_time = format(datetime.datetime.now(), "%Y%m%d-%H%M%S")
log_dir = './logs/'
filename = 'model_noise_est'

In [None]:
# Training curves: per-step training loss, periodic test loss and learning rate.
# The x-axes are built from the recorded lists (1-based steps, matching the training
# loop) instead of from training_steps, so plotting still works if training was
# interrupted or the training cell was run more than once (the lists accumulate).
skip_first = 100  # presumably hides the large early losses that would flatten the curve
train_axis = np.arange(1, len(train_losses) + 1)
lr_axis = np.arange(1, len(lrs) + 1)

fig, (ax_train, ax_test, ax_lr) = plt.subplots(1, 3, figsize=(15, 5))

ax_train.plot(train_axis[skip_first:], train_losses[skip_first:])
ax_train.set(xlabel='steps', ylabel='value', title='training loss')

ax_test.plot(test_steps[1:], test_losses[1:])
ax_test.set(xlabel='steps', ylabel='value', title='test loss')

ax_lr.plot(lr_axis, lrs)
ax_lr.set(xlabel='steps', ylabel='value', title='learning rate')

os.makedirs(log_dir, exist_ok=True)  # savefig does not create missing directories
fig.savefig(log_dir+filename+'_'+current_time+'.png')
plt.show()

In [None]:
# Final test loss: unweighted mean of the per-batch losses over the test pipeline
# (a smaller final batch counts as much as a full one).
total_test_loss = []
for (batch_test_X, batch_test_Y) in test_dataset:
    pred_test_Y = model(batch_test_X, training=False)
    # Keras losses take (y_true, y_pred) in that order.
    test_loss = loss_func(batch_test_Y, pred_test_Y)
    total_test_loss.append(test_loss.numpy())
total_test_loss = np.mean(total_test_loss)

print("Test data loss: {:.5f}".format(total_test_loss))

In [None]:
# draw test figures: one row per sample -> noisy input | target map | predicted map.
# Show at most batch_size samples, but never more than the test split holds
# (indexing past len(test_X) would raise IndexError).
n_show = min(batch_size, len(test_X))
test_x = test_X[:n_show]
test_y = test_Y[:n_show]
pred_y = model(test_x, training=False).numpy()

fig, axes = plt.subplots(n_show, 3, figsize=(15, 5*n_show), squeeze=False)
for n, (ax_in, ax_target, ax_pred) in enumerate(axes):
    ax_in.imshow(test_x[n].squeeze(), cmap='gray')
    ax_in.set_title('noise var {:.3f}'.format(test_y[n].mean()))

    ax_target.imshow(test_y[n].squeeze(), cmap='gray')

    ax_pred.imshow(pred_y[n].squeeze(), cmap='gray')
    ax_pred.set_title('est var {:.3f}'.format(pred_y[n].mean()))

    for ax in (ax_in, ax_target, ax_pred):
        ax.axis('off')

os.makedirs('./results/images/', exist_ok=True)  # savefig does not create directories
fig.savefig('./results/images/'+filename+'_'+current_time+'.png')
plt.show()

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def error(x1, x2, mode='mse'):
    '''Mean element-wise error between two arrays.

    Args:
        x1, x2: array-likes with broadcast-compatible shapes.
        mode: 'mse' (mean squared error) or 'mae' (mean absolute error).

    Returns:
        The error averaged over all elements, as a NumPy scalar.

    Raises:
        ValueError: if `mode` is neither 'mse' nor 'mae'.
    '''
    diff = np.asarray(x1) - np.asarray(x2)
    if mode == 'mse':
        return np.mean(np.square(diff))
    if mode == 'mae':
        return np.mean(np.abs(diff))
    # Fail loudly instead of returning None, which would only surface later as a
    # confusing formatting error.
    raise ValueError("mode must be 'mse' or 'mae', got {!r}".format(mode))

In [None]:
# Run the model over the (unshuffled) test pipeline and gather inputs, targets and
# predictions as NumPy arrays.
test_X, test_Y, pred_Y = [], [], []
for inputs, target in test_dataset:
    test_X.append(inputs.numpy())
    test_Y.append(target.numpy())
    pred_Y.append(model(inputs, training=False).numpy())

test_X = np.concatenate(test_X, axis=0)
test_Y = np.concatenate(test_Y, axis=0)
pred_Y = np.concatenate(pred_Y, axis=0)


def _mean_ssim(batch_a, batch_b, data_range=1):
    '''Mean 2-D SSIM over paired images, compared one image at a time.

    Passing the whole (N, H, W) stack to ssim would treat it as a single 3-D volume,
    so the sliding window would span neighbouring, unrelated test images.
    '''
    return np.mean([ssim(a.squeeze(), b.squeeze(), data_range=data_range)
                    for a, b in zip(batch_a, batch_b)])


# NOTE(review): test_Y is channel 1 of ims_noise_with_est (the training target map),
# so the "ground truth" / "noised" / "recovered" labels below look inherited from a
# denoising notebook — confirm what they should say.
print('Evaluation of ground truth and noised images:')
print('psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(psnr(test_X[..., 0].squeeze(), test_Y.squeeze(), data_range=1),
                                        _mean_ssim(test_X, test_Y),
                                        error(test_X, test_Y)))

print('\nEvaluation of recovered images and noised images:')
print('psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(psnr(pred_Y, test_Y, data_range=1),
                                        _mean_ssim(pred_Y, test_Y),
                                        error(pred_Y, test_Y)))

print('\nGround Truth:')
print('max:{:.3f}\tmin:{:.3f}\tmean:{:.3f}'.format(test_Y.max(), test_Y.min(), test_Y.mean()))

print('\nNoised images:')
print('max:{:.3f}\tmin:{:.3f}\tmean:{:.3f}'.format(test_X[..., 0].max(), test_X[..., 0].min(), test_X.mean()))

print('\nRecovered images:')
print('max:{:.3f}\tmin:{:.3f}\tmean:{:.3f}'.format(pred_Y.max(), pred_Y.min(), pred_Y.mean()))

In [None]:
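# Viewing training curves in TensorBoard:
# 1. Run in a terminal:
# python -m tensorboard.main --logdir logs
# 2. Open in a browser:
# http://localhost:6006
# Note: this notebook does not currently write TensorBoard summaries (they slowed
# training down); the loss curves are plotted with matplotlib above instead.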

In [None]:
# Save the trained weights to the same checkpoint path the model cell loads from.
weights_path = "model_weights/" + filename + ".ckpt"
model.save_weights(filepath=weights_path)

In [None]:
# List every trainable variable with its index.
for index_and_variable in enumerate(model.trainable_variables):
    print(*index_and_variable)