In [None]:
import numpy as np
import matplotlib.pyplot as plt
import datetime
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras

from data import read_data
from utils import add_noise_est, normalize, add_noise, squeeze_patches

#from model_global_dfn import GDFN
from model_baseline import Unet
from model_mwcnn import MWCNN
from model_mwkpn import MWKPN
from model_kpn import KPN, LossFunc, LossBasic

# Report the TensorFlow version and whether a GPU is visible.
# tf.test.is_gpu_available() is deprecated in TF 2.x; listing physical GPU
# devices is the documented replacement.
gpu_ok = len(tf.config.list_physical_devices('GPU')) > 0
print("tf version:", tf.__version__)
print("use GPU:", gpu_ok)

# transfer learning - from speckle noise to gaussian noise

In [None]:
# Load the image set; these images serve as the clean targets (Y) further down.
dataset_name = 'imagenet'
ims = read_data(dataset_name)

In [None]:
# Corrupt the clean images with Gaussian noise at several levels.
# `ims` is cut into len(variances) consecutive chunks and each chunk gets one
# variance; the noisy chunks are concatenated back in the same order, so
# ims_noise stays index-aligned with ims (assuming add_noise/normalize keep
# the leading axis length).
N_ims, h, w, _ = ims.shape
ims = ims[:N_ims].astype(np.float32)   # slice is a no-op as written (N_ims is the full count)

noise_est = False   # if True, use model to estimate the noise; if False, use known noise info

variances = np.arange(1, 11) * 5e-4
#variances = [15*5e-4]

# Seed so the noisy images -- and hence the reported test loss -- are
# reproducible, matching the fixed random_state used for the train/test split.
# NOTE(review): assumes utils.add_noise draws from NumPy's global RNG -- confirm.
np.random.seed(42)

# for noise estimation map, there should be different noise levels
ims_noise = []
ims_noise_with_est = []
for chunk, var in zip(np.array_split(ims, len(variances)), variances):
    noisy = normalize(add_noise(chunk, mean=0, var=var, n_type='gaussian'))
    ims_noise.append(noisy)
    if noise_est:
        ims_noise_with_est.append(add_noise_est(noisy, if_est=True))
    else:
        # known-noise mode: pass the true variance for this chunk
        ims_noise_with_est.append(add_noise_est(noisy, if_est=False, var=var))
ims_noise = np.concatenate(ims_noise)
ims_noise_with_est = np.concatenate(ims_noise_with_est)

print(ims.shape)
print(ims_noise.shape)
print(ims_noise_with_est.shape)

In [None]:
# Hold out a fixed fraction of the (noisy, clean) pairs for testing;
# random_state pins which samples land in each split.
test_size = 0.1

features, targets = ims_noise, ims
#features, targets = ims_noise_with_est, ims   # alternative input
train_X, test_X, train_Y, test_Y = train_test_split(
    features, targets, test_size=test_size, random_state=42)

# Insert a length-1 axis right after the batch axis of the inputs only.
# NOTE(review): presumably the burst axis the model expects (burst_length=1) -- confirm.
train_X = np.expand_dims(train_X, axis=1)
test_X = np.expand_dims(test_X, axis=1)

for label, arr in (('Training X: ', train_X), ('Training Y: ', train_Y),
                   ('Testing X: ', test_X), ('Testing Y: ', test_Y)):
    print(label, arr.shape, arr.dtype)

In [None]:
# Batch size used by the datasets here and by the evaluation/plotting cells.
# It was never defined in this notebook (it only worked via leftover kernel
# state), so Restart & Run All raised NameError at this cell.
# TODO(review): set to the batch size the checkpoint was trained/evaluated with.
batch_size = 8

# Use tf.data API to shuffle and batch data.
# Shuffle before repeat so each pass over the data is a full permutation,
# rather than letting samples from adjacent epochs mix in the shuffle buffer.
train_dataset = tf.data.Dataset.from_tensor_slices((train_X, train_Y))
train_dataset = train_dataset.shuffle(5000).repeat().batch(batch_size).prefetch(1)

test_dataset = tf.data.Dataset.from_tensor_slices((test_X, test_Y))
test_dataset = test_dataset.batch(batch_size).prefetch(1)

In [None]:
# Build the KPN model. Flag semantics are defined in model_kpn; here:
# single frame (burst_length=1), color=False, kernel sizes 3/5/7,
# spatial attention on, channel attention off.
model = KPN(
    color=False,
    burst_length=1,
    blind_est=True,
    sep_conv=False,
    kernel_size=[3, 5, 7],
    channel_att=False,
    spatial_att=True,
    core_bias=True,
    use_bias=True,
)
# Alternative backbone with the same keyword arguments:
#model = MWKPN(color=False, burst_length=1, blind_est=True, sep_conv=False, kernel_size=[3,5,7],
#             channel_att=False, spatial_att=True, core_bias=True, use_bias=True)

filename = 'kpn_ks357_satt_bias_combinedloss_nvar'

# Restore pretrained weights from the TF checkpoint named after `filename`.
load_model = True
if load_model:
    weights_path = "model_weights/transfer_to_speckle/" + filename + ".ckpt"
    model.load_weights(filepath=weights_path)

In [None]:
loss_func = LossBasic(gradient_L1 = True)

# Average the loss over the whole test set, weighting each batch by its size
# so the final (possibly smaller) batch does not count as much as a full one.
# NOTE(review): assumes loss_func returns a per-batch *mean* -- confirm in model_kpn.
batch_losses = []
batch_counts = []
for (batch_test_X, batch_test_Y) in test_dataset:
    # The model receives the input batch plus its channel-0 slice (re-expanded
    # on the last axis) and returns three outputs; only the first is used here.
    pred_test_Y, _, _ = model(batch_test_X, tf.expand_dims(batch_test_X[..., 0], axis=-1))
    test_loss = loss_func(pred_test_Y, batch_test_Y)
    batch_losses.append(float(test_loss.numpy()))
    batch_counts.append(int(batch_test_X.shape[0]))

if batch_losses:
    total_test_loss = np.average(batch_losses, weights=batch_counts)
else:
    total_test_loss = float('nan')   # empty test set: nothing to average

print("Test data loss: {:.3f}".format(total_test_loss))

In [None]:
# Draw test figures: one row per test image -- noisy input | ground truth | prediction.
test_x = test_X[:batch_size]
test_y = test_Y[:batch_size]
# The test split can hold fewer than batch_size images; size the grid to what
# was actually sliced so the indexing below cannot run past the end.
n_show = len(test_x)
pred_y, _, _ = model(test_x, tf.expand_dims(test_x[..., 0], axis=-1))

fig, axes = plt.subplots(n_show, 3, figsize=(15, 5 * n_show), squeeze=False)
column_titles = ('noisy input', 'ground truth', 'prediction')
for n in range(n_show):
    panels = (
        test_x[n][..., 0].squeeze(),
        test_y[n].squeeze(),
        pred_y[n].numpy().squeeze(),
    )
    for ax, img, title in zip(axes[n], panels, column_titles):
        ax.imshow(img, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

plt.show()