Here we walk through an example of using a research-grade denoising model, TomoGAN (Liu, Z., Bicer, T., Kettimuthu, R., Gursoy, D., De Carlo, F. and Foster, I., 2020. TomoGAN: low-dose synchrotron x-ray tomography with generative adversarial networks: discussion. JOSA A, vol. 37(3), pp. 422-434; also available as arXiv:1902.07582). Because training this model takes several hours on GPU clusters, we will not train it here; instead, we use a model that has already been trained on realistic data sets.

As usual, we start by importing the required Python libraries.

In [None]:
import tensorflow as tf 
import numpy as np 
from matplotlib import pyplot as plt
import sys, time, imageio, h5py, skimage, glob, os, shutil

In [None]:
# Start from an empty 'dataset' directory (deleting any previous copy) to hold the downloaded data
if os.path.isdir('dataset'):
    shutil.rmtree('dataset')
os.mkdir('dataset')

!wget -O dataset/demo-dataset-real4test.h5 https://raw.githubusercontent.com/AIScienceTutorial/Denoising/main/dataset/demo-dataset-real4test.h5
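The `!wget` line is IPython shell syntax and needs the `wget` program to be installed. Outside Jupyter, or on a machine without `wget`, a minimal pure-Python sketch using the standard library's `urllib.request.urlretrieve` could stand in for it. The helper name `fetch` is our own illustration, not part of the tutorial.

```python
import os
import urllib.request


def fetch(url, dest_dir):
    """Download `url` into `dest_dir` (created if needed) and return the local path.

    Hypothetical helper: a stand-in for the notebook's `!wget -O ...` call.
    """
    os.makedirs(dest_dir, exist_ok=True)
    # Use the last component of the URL as the local file name
    local_path = os.path.join(dest_dir, url.rstrip('/').split('/')[-1])
    urllib.request.urlretrieve(url, local_path)
    return local_path


# Usage (requires network access):
# fetch('https://raw.githubusercontent.com/AIScienceTutorial/Denoising/main/dataset/demo-dataset-real4test.h5', 'dataset')
```

The same helper would cover the model download further down, with `'model'` as the destination directory.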

We then load the noisy test images (`test_ns`) and the corresponding clean ground-truth images (`test_gt`) from the HDF5 file.

In [None]:
with h5py.File('dataset/demo-dataset-real4test.h5', 'r') as h5fd:
    ns_img_test_real = h5fd["test_ns"][:]
    gt_img_test_real = h5fd["test_gt"][:]
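Before denoising, it can help to confirm that the two arrays line up. The indexing used later (`ns_img_test_real[_idx, 200:-100, 200:-100]`) implies that both datasets are stacks of 2-D images with shape `[n, h, w]`. The hypothetical helper below checks that assumption and reports the value ranges; it is demonstrated on small random placeholder arrays, not on the tutorial data.

```python
import numpy as np


def check_image_pair(noisy, clean):
    """Sanity-check a noisy/clean pair of image stacks, assuming shape [n, h, w].

    Hypothetical helper; raises ValueError on a mismatch, else returns a summary dict.
    """
    if noisy.ndim != 3 or clean.ndim != 3:
        raise ValueError('expected [n, h, w] stacks, got %s and %s' % (noisy.shape, clean.shape))
    if noisy.shape != clean.shape:
        raise ValueError('shape mismatch: %s vs %s' % (noisy.shape, clean.shape))
    return {
        'n_images': noisy.shape[0],
        'image_size': noisy.shape[1:],
        'noisy_range': (float(noisy.min()), float(noisy.max())),
        'clean_range': (float(clean.min()), float(clean.max())),
    }


# Illustration on random placeholder arrays (not the tutorial data)
rng = np.random.default_rng(0)
demo_clean = rng.random((3, 64, 64), dtype=np.float32)
demo_noisy = demo_clean + 0.1 * rng.standard_normal((3, 64, 64)).astype(np.float32)
print(check_image_pair(demo_noisy, demo_clean))
```

On the real data, the call would be `check_image_pair(ns_img_test_real, gt_img_test_real)`.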

Next, we download the pre-trained model, load it, and print a summary of its architecture (layers, output shapes, and parameter counts).

In [None]:
# Start from an empty 'model' directory (deleting any previous copy) to hold the downloaded model file
if os.path.isdir('model'):
    shutil.rmtree('model')
os.mkdir('model')

!wget -O model/TomoGAN.h5 https://raw.githubusercontent.com/AIScienceTutorial/Denoising/main/model/TomoGAN.h5

In [None]:
TomoGAN_mdl = tf.keras.models.load_model('model/TomoGAN.h5')
TomoGAN_mdl.summary()

We then apply TomoGAN to the first two noisy images and display each input next to its clean label and the denoised output.

In [None]:
for _idx in range(ns_img_test_real.shape[0])[:2]:  # first two test images only
    # Denoising is a single call; the model expects a 4-D input of shape [n, h, w, c],
    # so we take a one-image slice, append a channel axis, and squeeze the output back to [h, w]
    dn_img = TomoGAN_mdl.predict(ns_img_test_real[_idx:_idx+1,:,:,np.newaxis]).squeeze()
    
    plt.figure(figsize=(15, 5))
    plt.subplot(131)
    plt.imshow(ns_img_test_real[_idx, 200:-100, 200:-100], cmap='gray')
    plt.title('Noisy/Input', fontsize=18)
    plt.subplot(132)
    plt.imshow(gt_img_test_real[_idx, 200:-100, 200:-100], cmap='gray')
    plt.title('Clean/Label', fontsize=18)
    plt.subplot(133)
    plt.imshow(dn_img[200:-100, 200:-100], cmap='gray')
    plt.title('Denoised/output', fontsize=18)
    plt.tight_layout(); plt.show(); plt.close()
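The comparison above is visual. A common numerical complement is the peak signal-to-noise ratio (PSNR) of an image against the clean label, PSNR = 10 log10(R^2 / MSE), where R is the data range and MSE the mean squared error. The sketch below is our own addition, not part of the tutorial, and it assumes the label's max-minus-min span as R when none is given. Inside the loop above it could be used as `psnr(gt_img_test_real[_idx], dn_img)` and compared with `psnr(gt_img_test_real[_idx], ns_img_test_real[_idx])`. scikit-image (already imported) offers a comparable function, `skimage.metrics.peak_signal_noise_ratio`, which takes a `data_range` argument.

```python
import numpy as np


def psnr(reference, image, data_range=None):
    """Peak signal-to-noise ratio of `image` against `reference`, in dB.

    Assumption: if `data_range` is not given, use reference.max() - reference.min().
    """
    reference = np.asarray(reference, dtype=np.float64)
    image = np.asarray(image, dtype=np.float64)
    if data_range is None:
        data_range = reference.max() - reference.min()
    mse = np.mean((reference - image) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))


# Toy check on synthetic data: a smaller perturbation should give a higher PSNR
rng = np.random.default_rng(1)
clean = rng.random((64, 64))
slightly_noisy = clean + 0.01 * rng.standard_normal(clean.shape)
very_noisy = clean + 0.1 * rng.standard_normal(clean.shape)
print(psnr(clean, slightly_noisy), psnr(clean, very_noisy))
```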

Denoising can also be done in batches of several images per `predict` call, which is typically faster than processing the images one at a time.
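To process a whole stack, the images can be pushed through the model in fixed-size chunks, which keeps each call small. A minimal sketch, assuming a generic `predict_fn` that maps a `[b, h, w, 1]` array to an output of the same shape (the way `TomoGAN_mdl.predict` is used in this notebook); the helper name `denoise_in_batches` is hypothetical. Keras's `Model.predict` also accepts a `batch_size` argument and batches internally when given the full array.

```python
import numpy as np


def denoise_in_batches(predict_fn, stack, batch_size=4):
    """Apply `predict_fn` to a [n, h, w] stack in chunks of `batch_size` images.

    Hypothetical helper: assumes predict_fn takes and returns [b, h, w, 1] arrays.
    Returns a [n, h, w] array.
    """
    outputs = []
    for start in range(0, stack.shape[0], batch_size):
        batch = stack[start:start + batch_size, :, :, np.newaxis]  # add channel axis
        pred = predict_fn(batch)
        outputs.append(pred[..., 0])  # drop channel axis, keep batch axis even if size 1
    return np.concatenate(outputs, axis=0)


# Usage with a stand-in "model" that halves every pixel value
stack = np.ones((10, 8, 8), dtype=np.float32)
result = denoise_in_batches(lambda x: 0.5 * x, stack, batch_size=4)  # chunks of 4, 4, 2
print(result.shape)  # (10, 8, 8)
```

With the real model, the call would be `denoise_in_batches(TomoGAN_mdl.predict, ns_img_test_real, batch_size=4)`. Indexing with `[..., 0]` rather than `.squeeze()` keeps a final chunk of a single image from losing its batch axis.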

In [None]:
batch_sz = 4
_idx = 0
batch = ns_img_test_real[_idx:_idx+batch_sz, :, :, np.newaxis]  # shape [batch_sz, h, w, 1]
tick = time.time()
dn_img = TomoGAN_mdl.predict(batch).squeeze()  # remove size-1 axes -> [batch_sz, h, w]
print('It took %.1f seconds to denoise %d images of %dx%d pixels.' % (
    time.time() - tick, batch.shape[0],
    ns_img_test_real.shape[1], ns_img_test_real.shape[2]))
dn_img.shape
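The first `predict` call on a Keras model typically includes one-off setup work (such as building the prediction function), so a single timed call can overstate the steady-state cost. A sketch of a slightly more careful measurement, assuming a generic `predict_fn` that takes a `[b, h, w, 1]` array; it uses `time.perf_counter`, runs one untimed warm-up call, and reports the best per-image time over a few repeats. The helper name is hypothetical.

```python
import time

import numpy as np


def seconds_per_image(predict_fn, batch, repeats=3):
    """Best-of-`repeats` wall-clock time per image for predict_fn(batch).

    Hypothetical helper: one untimed warm-up call absorbs one-off setup costs.
    `batch` is assumed to be a [b, h, w, 1] array.
    """
    predict_fn(batch)  # warm-up, not timed
    best = float('inf')
    for _ in range(repeats):
        start = time.perf_counter()
        predict_fn(batch)
        best = min(best, time.perf_counter() - start)
    return best / batch.shape[0]


# Usage with a stand-in "model" (a cheap numpy operation) on placeholder data
batch = np.zeros((4, 256, 256, 1), dtype=np.float32)
t = seconds_per_image(lambda x: np.sqrt(x + 1.0), batch)
print('%.2e seconds per image' % t)
```

With the real model, the call would be `seconds_per_image(TomoGAN_mdl.predict, ns_img_test_real[:4, :, :, np.newaxis])`.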

In [None]:
for _img_dn in dn_img:
    plt.figure(figsize=(7, 7))
    plt.imshow(_img_dn[200:-100, 200:-100], cmap='gray')
    plt.title('Denoised/output', fontsize=18)
    plt.show(); plt.close()