In [None]:
import datetime
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras

from data import read_data
from utils import add_noise_est, normalize, add_noise, squeeze_patches

# Report the TensorFlow version and whether TensorFlow can see a GPU.
# tf.test.is_gpu_available() is deprecated in TF2; list_physical_devices replaces it.
gpu_ok = len(tf.config.list_physical_devices('GPU')) > 0
print("tf version:", tf.__version__)
print("use GPU:", gpu_ok)

In [None]:
# Seed every RNG this notebook touches: NumPy (used for sampling indices in the
# plotting cells) and TensorFlow (for any TF random ops), so reruns are reproducible.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

# Data
## 1. DIV2K data

In [None]:
# Load the DIV2K patch pairs (X = network inputs, Y = targets).
(train_X_p, train_Y_p), (test_X_p, test_Y_p) = read_data('div2k')
N_ims = len(train_X_p)  # length of the training inputs as returned by read_data

# squeeze_patches is a project helper; it returns (patches, labels) for each array.
(
    (train_X_p, label_train_X_p),
    (train_Y_p, label_train_Y_p),
    (test_X_p, label_test_X_p),
    (test_Y_p, label_test_Y_p),
) = [squeeze_patches(arr) for arr in (train_X_p, train_Y_p, test_X_p, test_Y_p)]

# Inputs gain a singleton axis after the sample axis; targets gain a trailing one.
train_X_p, test_X_p = (arr[:, None] for arr in (train_X_p, test_X_p))
train_Y_p, test_Y_p = (arr[..., None] for arr in (train_Y_p, test_Y_p))

print('\nTrain data:')
print('train_X_p:', train_X_p.shape)
print('train_Y_p:', train_Y_p.shape)

print('\nTest data:')
print('test_X_p:', test_X_p.shape)
print('test_Y_p:', test_Y_p.shape)

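## 2. Ultrasound data

*Note: no ultrasound data is loaded in this section yet; the cell below visualizes random DIV2K training pairs (input on the left, target on the right).*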

In [None]:
# Visualization: random training pairs, input (left) vs. target (right).
N_show = 20

fig, axes = plt.subplots(N_show, 2, figsize=(10, 5 * N_show), squeeze=False)
for ax_input, ax_target in axes:
    # Draw the index from the TRAINING set, because that is what is displayed.
    # Drawing from test_X_p.shape[0] could overrun the training set when the
    # test set is larger, and never reaches later training samples otherwise.
    n = np.random.randint(train_X_p.shape[0])
    ax_input.imshow(train_X_p[n][..., 0].squeeze(), cmap='gray')
    ax_target.imshow(train_Y_p[n].squeeze(), cmap='gray')
plt.show()

In [None]:
# Sanity check of value ranges: mean, std, min, max of the test inputs, then the test targets.
print(*(getattr(test_X_p, stat)() for stat in ('mean', 'std', 'min', 'max')))
print(*(getattr(test_Y_p, stat)() for stat in ('mean', 'std', 'min', 'max')))

In [None]:
# When False, keep only input channel 0 (as a length-1 last axis). When True, keep
# every input channel (presumably image + noise map -- confirm against read_data).
use_noise_map = False

if not use_noise_map:
    train_X_p = train_X_p[..., :1]
    test_X_p = test_X_p[..., :1]

print('Train data:')
print('train_X_p:', train_X_p.shape)
print('train_Y_p:', train_Y_p.shape)

print('\nTest data:')
print('test_X_p:', test_X_p.shape)
print('test_Y_p:', test_Y_p.shape)

In [None]:
# Number of samples in each split (size of the leading axis).
N_train, N_test = (arr.shape[0] for arr in (train_X_p, test_X_p))

# Training hyperparameters
batch_size = 16

In [None]:
# Batch the test set with the tf.data API and prefetch one batch ahead.
# No shuffling is applied, so batches follow the original sample order.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_X_p, test_Y_p))
    .batch(batch_size)
    .prefetch(1)
)

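# Method
## Md1: kernel selection from patch features

Each pixel's neighbourhood is matched to the most similar learned patch, and the kernel associated with that patch is used to filter the pixel.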

In [None]:
def load_stacked_pickle(path):
    '''
    Load a pickled dict from `path` and stack its values along a new axis 0.

    Values are stacked in the dict's iteration (insertion) order; keys are discarded.
    NOTE: pickle can execute arbitrary code while loading -- only open trusted files.
    '''
    with open(path, 'rb') as f:
        table = pickle.load(f)
    return np.stack(list(table.values()), axis=0)


patches = load_stacked_pickle('kernels/patches50.txt')
kernels = load_stacked_pickle('kernels/kernels50.txt')

# simulate_patches returns indices into `patches`, and those indices are used to
# index `kernels`, so the two tables must have the same number of entries.
assert patches.shape[0] == kernels.shape[0], (patches.shape, kernels.shape)

print(patches.shape)
print(kernels.shape)

## Test the data with the hybrid method

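***Several similarity criteria are possible: NCC (normalized cross-correlation, used below), SSIM, or patch direction.*** *(Original note: Plusieurs critères sont possibles : NCC, SSIM, ou selon la direction.)*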

In [None]:
def simulate_patches(batch_test_X_flatten, learned_patches, eps=1e-8):
    '''
    Match each patch to its most similar learned patch and return that patch's index.

    Similarity is normalized cross-correlation (up to a constant factor): each row is
    centred and divided by its standard deviation, then rows are compared by dot product.
    The returned indices are used by the caller to look up the corresponding kernels.

    batch_test_X_flatten: (nb_patches, K*K) flattened patches to match
    learned_patches: (nb_learned_patches, K*K) flattened reference patches
    eps: added to every standard deviation so constant patches (e.g. flat regions or
         zero-padded borders) do not divide by zero and produce NaN scores; such a
         patch scores 0 against every learned patch.

    Returns: int64 tensor of shape (nb_patches,) with indices into learned_patches.
    '''
    queries = tf.convert_to_tensor(batch_test_X_flatten)
    # Use one dtype for both operands: tf.matmul rejects float32/float64 mixes.
    references = tf.cast(learned_patches, queries.dtype)

    def _standardize(x):
        # Per-row (last axis) zero mean and unit standard deviation.
        centred = x - tf.reduce_mean(x, axis=-1, keepdims=True)
        return centred / (tf.math.reduce_std(x, axis=-1, keepdims=True) + eps)

    score = tf.matmul(_standardize(queries), _standardize(references), transpose_b=True)
    return tf.math.argmax(score, axis=-1)

def apply_filtering(frames, core, kernel_size):
    '''
    Filter every pixel of `frames` with its own kernel and average the results.

    frames: 5-D tensor (bs, N, h, w, color); the rank is fixed by the 5-entry padding.
    core: dict mapping each kernel size K to per-pixel kernels that multiply against
          the (bs, N, h, w, color, K*K) stack of shifted frames. The last axis
          enumerates the K x K window in row-major (i, j) order.
    kernel_size: list of odd kernel sizes. It is processed from last to first, and each
          later window is cut out of the previous one, so the list is assumed to be
          in increasing order -- TODO confirm with callers.

    Returns:
        pred_img: result averaged over kernel sizes and over axis N, (bs, h, w, color)
        pred_img_i: result averaged over kernel sizes only, (bs, N, h, w, color)
    '''
    # Take the spatial size from the input itself rather than from global
    # `height` / `width` variables defined in another cell.
    height, width = frames.shape[2], frames.shape[3]
    kernel = kernel_size[::-1]  # largest window first
    img_stack = None
    pred_img = []
    for index, K in enumerate(kernel):
        if index == 0:
            # All K*K shifted copies of the zero-padded frames.
            frame_pad = tf.pad(frames, paddings=[[0, 0], [0, 0], [K//2, K//2], [K//2, K//2], [0, 0]], mode='constant')
            img_stack = tf.stack(
                [frame_pad[:, :, i:i+height, j:j+width, :] for i in range(K) for j in range(K)],
                axis=-1)                                                # (bs, N, h, w, color, K*K)
        else:
            # Keep only the central K x K taps of the previous (larger) window.
            prev = kernel[index - 1]
            k_diff = (prev - K) // 2
            k_chosen = [i*prev + j
                        for i in range(k_diff, prev - k_diff)
                        for j in range(k_diff, prev - k_diff)]
            img_stack = tf.gather(img_stack, k_chosen, axis=-1)
        pred_img.append(tf.reduce_sum(tf.math.multiply(core[K], img_stack), axis=-1, keepdims=False))
    pred_img = tf.stack(pred_img, axis=0)                               # (nb_kernels, bs, N, h, w, color)
    pred_img_i = tf.reduce_mean(pred_img, axis=0, keepdims=False)       # (bs, N, h, w, color)
    pred_img = tf.reduce_mean(pred_img_i, axis=1, keepdims=False)       # (bs, h, w, color)
    return pred_img, pred_img_i

In [None]:
# Hybrid method on one test batch: match the K x K neighbourhood of every pixel to
# the most correlated learned patch, then filter the pixel with that patch's kernel.
K = 3
batch_test_X, batch_test_Y = next(iter(test_dataset))

# Match and filter only input channel 0; the learned patches hold K*K values each
# (see simulate_patches). `n_batch` avoids clobbering the global `batch_size`.
frames = batch_test_X[..., :1]
n_batch, N, height, width, color = frames.shape

# Fetch the K x K patch around every pixel (zero padding at the borders).
frame_pad = tf.pad(frames, paddings=[[0, 0], [0, 0], [K//2, K//2], [K//2, K//2], [0, 0]], mode='constant')
batch_test_X_flatten = tf.stack(
    [frame_pad[:, :, i:i+height, j:j+width, :] for i in range(K) for j in range(K)],
    axis=-1)                                                            # (bs, N, h, w, color, K*K)
batch_test_X_flatten = batch_test_X_flatten.numpy().reshape(-1, K*K)
print(batch_test_X_flatten.shape)

# Replace each patch by its best-matching learned patch and take that patch's kernel.
core = kernels[simulate_patches(batch_test_X_flatten, patches).numpy()]
print(core.shape)

core = {K: core.reshape(n_batch, N, height, width, color, -1)}
print(core[K].shape)

pred_test_Y, _ = apply_filtering(frames, core, kernel_size=[K])
print(pred_test_Y.shape)

In [None]:
# Compare input, ground truth and hybrid-method output for the batch above.
# The batch can hold fewer than 16 samples, so never index past its size.
n_rows = min(16, batch_test_X.shape[0])

fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
for i, (ax_in, ax_gt, ax_pred) in enumerate(axes):
    panels = ((ax_in, batch_test_X[i, ..., 0], 'noisy image'),
              (ax_gt, batch_test_Y[i], 'ground truth'),
              (ax_pred, pred_test_Y[i], 'recovered image'))
    for ax, img, title in panels:
        ax.imshow(img.numpy().squeeze(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')
plt.show()

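## Use the neural network to choose the kernels

*Note: the cells below use `model`, `loss_func`, `sub_dir`, `filename` and `current_time`, none of which is defined in the cells above. They must already exist in the kernel (e.g. from a training run) before these cells are executed.*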

In [None]:
# Average test loss of the trained network. Loop variables get their own names so
# this cell does not overwrite batch_test_X / batch_test_Y / pred_test_Y / core
# from the hybrid-method section.
batch_losses = []
for inputs, targets in test_dataset:
    # The model takes the inputs plus their channel 0 as a second argument and returns
    # three outputs; only the first (the prediction) is scored here.
    preds, _, _ = model(inputs, tf.expand_dims(inputs[..., 0], axis=-1))
    batch_losses.append(loss_func(preds, targets).numpy())
# Unweighted mean over batches: a smaller final batch counts as much as a full one
# (exact only if loss_func returns a per-batch mean -- confirm).
total_test_loss = np.mean(batch_losses)

print("Test data loss: {:.3f}".format(total_test_loss))

In [None]:
# Draw random test samples: input | ground truth | network prediction.
n_chosen = np.random.randint(test_X_p.shape[0], size=16)
test_x = test_X_p[n_chosen]
test_y = test_Y_p[n_chosen]
pred_y, _, _ = model(test_x, tf.expand_dims(test_x[..., 0], axis=-1))

# One row per drawn sample. Looping over `batch_size` instead could disagree with
# the 16 samples drawn above.
n_rows = len(n_chosen)
fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
for n, (ax_in, ax_gt, ax_pred) in enumerate(axes):
    ax_in.imshow(test_x[n][..., 0].squeeze(), cmap='gray')
    ax_gt.imshow(test_y[n].squeeze(), cmap='gray')
    ax_pred.imshow(pred_y[n].numpy().squeeze(), cmap='gray')
    for ax in (ax_in, ax_gt, ax_pred):
        ax.axis('off')

# `sub_dir`, `filename` and `current_time` are not defined in the cells above and
# must already exist in the kernel. savefig fails if the directory does not exist.
out_dir = './results/images/' + sub_dir
os.makedirs(out_dir, exist_ok=True)
plt.savefig(out_dir + '/' + filename + '_' + current_time + '.png')
plt.show()

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def error(x1, x2, mode='mse'):
    '''
    Mean element-wise error between two arrays.

    x1, x2: array-likes with broadcast-compatible shapes (lists are accepted too).
    mode: 'mse' for mean squared error or 'mae' for mean absolute error.

    Returns the error as a NumPy scalar.
    Raises ValueError for any other mode instead of silently returning None.
    '''
    diff = np.subtract(x1, x2)
    if mode == 'mse':
        return np.mean(np.square(diff))
    if mode == 'mae':
        return np.mean(np.abs(diff))
    raise ValueError("mode must be 'mse' or 'mae', got {!r}".format(mode))

In [None]:
def _as_images(x):
    '''Drop singleton axes except the leading sample axis, giving (n, h, w) for these data.'''
    return np.squeeze(x, axis=tuple(ax for ax in range(1, x.ndim) if x.shape[ax] == 1))


def _mean_ssim(a, b):
    '''Mean 2-D SSIM over the leading sample axis.

    Calling ssim on the whole (n, h, w) stack would treat it as one 3-D volume and
    mix neighbouring samples into every window.
    '''
    return np.mean([ssim(x, y, data_range=1) for x, y in zip(a, b)])


def _report(a, b):
    '''Format PSNR (over the whole stack), mean per-image SSIM and MSE of a against b.'''
    return 'psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(
        psnr(a, b, data_range=1), _mean_ssim(a, b), error(a, b))


# Collect inputs, targets and network predictions for the whole test set.
test_X, test_Y, pred_Y = [], [], []
for inputs, target in test_dataset:
    test_X.append(inputs.numpy())
    test_Y.append(target.numpy())
    outputs, _, _ = model(inputs, tf.expand_dims(inputs[..., 0], axis=-1))
    pred_Y.append(outputs.numpy())

test_X = np.concatenate(test_X, axis=0)
test_Y = np.concatenate(test_Y, axis=0)
pred_Y = np.concatenate(pred_Y, axis=0)

noisy = _as_images(test_X[..., 0])
clean = _as_images(test_Y)
recovered = _as_images(pred_Y)

print('Evaluation of ground truth and noised images:')
print(_report(noisy, clean))

print('\nEvaluation of recovered images and ground truth:')
print(_report(recovered, clean))

# Value ranges; every statistic for the noisy images uses the same channel 0.
for header, arr in (('\nGround Truth:', clean),
                    ('\nNoised images:', noisy),
                    ('\nRecovered images:', recovered)):
    print(header)
    print('max:{:.3f}\tmin:{:.3f}\tmean:{:.3f}'.format(arr.max(), arr.min(), arr.mean()))