In [None]:
import numpy as np
import matplotlib.pyplot as plt
import datetime
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras

from data import read_data
from utils import add_noise_est, normalize, add_noise, squeeze_patches

#from model_global_dfn import GDFN
from model_baseline import Unet
from model_mwcnn import MWCNN
from model_mwkpn import MWKPN
from model_kpn import KPN, LossFunc, LossBasic

# tf.test.is_gpu_available() is deprecated; TensorFlow recommends listing the
# visible physical devices instead. gpu_ok is True when at least one GPU is visible.
gpu_devices = tf.config.list_physical_devices('GPU')
gpu_ok = len(gpu_devices) > 0
print("tf version:", tf.__version__)
print("use GPU:", gpu_ok)

# Analyse de l'influence de l'ondelette - Unet - Speckle

In [None]:
# Data preparation: load the speckle dataset, flatten the patch collections and add
# the singleton axes the rest of the notebook expects.
seed = 42
np.random.seed(seed)
# Seed TensorFlow as well: the tf.data shuffle and any TF-side randomness do not use
# NumPy's global RNG, so np.random.seed alone does not make the run reproducible.
tf.random.set_seed(seed)

(train_X_p, train_Y_p), (test_X_p, test_Y_p) = read_data('speckle')
N_ims = len(train_X_p)  # length of the raw training set, taken before squeeze_patches

# squeeze_patches returns the flattened patches plus an auxiliary label array
# (presumably patch -> source-image indices; TODO confirm in utils.squeeze_patches).
train_X_p, label_train_X_p = squeeze_patches(train_X_p)
train_Y_p, label_train_Y_p = squeeze_patches(train_Y_p)
test_X_p, label_test_X_p = squeeze_patches(test_X_p)
test_Y_p, label_test_Y_p = squeeze_patches(test_Y_p)

# Inputs get a singleton axis at position 1; targets get a trailing singleton axis.
train_X_p = train_X_p[:, np.newaxis, ...]
train_Y_p = train_Y_p[..., np.newaxis]
test_X_p = test_X_p[:, np.newaxis, ...]
test_Y_p = test_Y_p[..., np.newaxis]

# Inputs and targets are later zipped by tf.data.Dataset.from_tensor_slices,
# which requires matching leading dimensions.
assert len(train_X_p) == len(train_Y_p), "train inputs/targets count mismatch"
assert len(test_X_p) == len(test_Y_p), "test inputs/targets count mismatch"

print('\nTrain data:')
print('train_X_p:', train_X_p.shape)
print('train_Y_p:', train_Y_p.shape)

print('\nTest data:')
print('test_X_p:', test_X_p.shape)
print('test_Y_p:', test_Y_p.shape)

In [None]:
use_noise_map = False   # if True, concatenate a noise map to the input
#use_noise_est = False   # if True, use a model to estimate noise map, if False, use known info

# Keep only the first channel when the noise map is not used. `[..., :1]` is the same
# as `[..., 0][..., np.newaxis]` and is idempotent on a single-channel array.
if not use_noise_map:
    train_X_p = train_X_p[..., :1]
    test_X_p = test_X_p[..., :1]

# Drop the singleton axis at position 1 only while it is still present, so re-running
# this cell does not try to squeeze a non-singleton axis (which raises ValueError).
if train_X_p.shape[1] == 1:
    train_X_p = train_X_p.squeeze(1)
if test_X_p.shape[1] == 1:
    test_X_p = test_X_p.squeeze(1)

print('Train data:')
print('train_X_p:', train_X_p.shape)
print('train_Y_p:', train_Y_p.shape)

print('\nTest data:')
print('test_X_p:', test_X_p.shape)
print('test_Y_p:', test_Y_p.shape)

In [None]:
# Use tf.data API to shuffle and batch data.
batch_size = 16
shuffle_buffer = 5000  # number of elements held in the shuffle buffer

# Shuffle *before* repeat so each pass over the training set is a full permutation
# (shuffling after repeat lets the buffer mix elements across epoch boundaries).
# The explicit seed makes the order reproducible across kernel restarts.
train_dataset = tf.data.Dataset.from_tensor_slices((train_X_p, train_Y_p))
train_dataset = (
    train_dataset
    .shuffle(shuffle_buffer, seed=seed)
    .repeat()
    .batch(batch_size)
    .prefetch(1)
)

# Evaluation data: not shuffled, so batches come out in dataset order.
test_dataset = tf.data.Dataset.from_tensor_slices((test_X_p, test_Y_p))
test_dataset = test_dataset.batch(batch_size).prefetch(1)

In [None]:
# Instantiate MWCNN with 3x3 kernels and both attention flags disabled.
# Alternative baseline kept for reference:
# model = Unet(color=False, kernel_size=5, channel_att=False, spatial_att=True, if_wavelet=True)
model = MWCNN(
    color=False,
    kernel_size=3,
    channel_att=False,
    spatial_att=False,
)

# Checkpoint stem, resolved under model_weights/transfer_to_speckle/ with a .ckpt suffix.
# filename = 'unet_satt_bias_combinedloss'
filename = 'mwcnn_satt_combinedloss_nvar'
# NOTE(review): the name contains "satt" while spatial_att=False above - confirm the
# checkpoint was saved from this exact architecture.

load_model = True
if load_model:
    checkpoint_path = "model_weights/transfer_to_speckle/{}.ckpt".format(filename)
    model.load_weights(filepath=checkpoint_path)

In [None]:
# Run the model on the first test batch and keep NumPy copies for plotting.
# next(iter(...)) replaces a one-iteration loop whose index was never used, and fails
# loudly (StopIteration) if the test set is empty instead of leaving names undefined.
batch_test_X, batch_test_Y = next(iter(test_dataset))

# training=False makes inference mode explicit for any training-dependent layers.
# The model returns a pair: the reconstructed image and a second, multi-channel output
# (presumably wavelet-domain coefficients - TODO confirm in model_mwcnn).
pred_test_Y, pred_test_Y_wavelet = model(batch_test_X, training=False)

pred_test_Y = pred_test_Y.numpy()
pred_test_Y_wavelet = pred_test_Y_wavelet.numpy()
batch_test_Y = batch_test_Y.numpy()
batch_test_X = batch_test_X.numpy()

print(pred_test_Y.shape)
print(pred_test_Y_wavelet.shape)

In [None]:
import os

import pywt

# Plot, per sample, two rows of five panels:
#   row 2*i   : ground truth and its single-level Haar DWT
#               (pywt.dwt2 returns cA, (cH, cV, cD) = approximation, horizontal,
#               vertical, diagonal details)
#   row 2*i+1 : model reconstruction and selected channels of the second model output.
# Use the real batch length: the fetched batch may be smaller than batch_size, which
# made range(batch_size) raise IndexError.
n_samples = len(batch_test_Y)
# (channel index, panel title) pairs shown from pred_test_Y_wavelet.
wavelet_panels = [(0, '1st detail'), (4, '5th detail'), (8, '9th detail'), (12, '13th detail')]

fig, axes = plt.subplots(n_samples * 2, 5, figsize=(25, 10 * n_samples), squeeze=False)
for i in range(n_samples):
    original = batch_test_Y[i].squeeze()
    LL, (LH, HL, HH) = pywt.dwt2(original, 'haar')
    gt_panels = [
        (original, 'original image'),
        (LL, 'approximation'),
        (LH, 'horizontal detail'),
        (HL, 'vertical detail'),
        (HH, 'diagonal detail'),
    ]
    pred_panels = [(pred_test_Y[i].squeeze(), 'recovered image')] + [
        (pred_test_Y_wavelet[i][:, :, channel].squeeze(), title)
        for channel, title in wavelet_panels
    ]
    for row, panels in ((2 * i, gt_panels), (2 * i + 1, pred_panels)):
        for col, (image, title) in enumerate(panels):
            axes[row, col].imshow(image, cmap='gray')
            axes[row, col].set_title(title)

# savefig does not create missing directories.
os.makedirs('./eval', exist_ok=True)
fig.savefig('./eval/mwcnn.png')
plt.show()