## 1. Import Libraries

In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from PIL import Image
from scipy import misc
from sklearn.model_selection import train_test_split

## 2. Define Functions

In [2]:
# Data augmentation: paste a patch of the in-focus target into the defocused input.
def data_augmentation(image, label, cut_size=None):
    """Copy a random square patch of ``label`` into ``image`` at the same position.

    After this call part of the (defocused) input already equals the (in-focus)
    target, so one sample contains both blurred and sharp content.

    Args:
        image: array indexed as ``image[row, col, ...]``; modified IN PLACE.
        label: array whose first two dimensions match ``image``; not modified.
        cut_size: side length of the pasted square. Defaults to
            ``int(crop_size / 4)`` using the notebook-level ``crop_size``.

    Returns:
        ``(image, label)`` -- ``image`` is the same (mutated) object passed in.

    Raises:
        ValueError: if ``cut_size`` exceeds either spatial dimension of ``image``
            (raised by ``np.random.randint`` with a non-positive bound).
    """
    if cut_size is None:
        cut_size = int(crop_size / 4)

    x_size, y_size = image.shape[0], image.shape[1]
    # randint's upper bound is exclusive: +1 lets the patch reach the last
    # row/column and allows cut_size equal to the image size.
    start_x = np.random.randint(x_size - cut_size + 1)
    start_y = np.random.randint(y_size - cut_size + 1)

    region = (slice(start_x, start_x + cut_size), slice(start_y, start_y + cut_size))
    image[region] = label[region]

    return image, label

In [3]:
# set parameters
# crop_size: side length (px) of the square patch cut from each training/validation image
# factor:    downscaling ratio between consecutive pyramid scales
crop_size, factor = 256, 2

# define training data loader
current_batch = 0  # index of the next unread pair in the training file arrays
def train_batch_maker(batch_size):
    """Load the next training mini-batch as three-scale image pyramids.

    Reads sequentially through the global ``train_defocus_files`` /
    ``train_infocus_files`` (paired by position). When fewer than
    ``batch_size`` unread pairs remain, both are permuted with the same random
    index, the cursor restarts at 0, and the first batch of the new order is
    returned.

    Each pair is cropped to ``crop_size`` x ``crop_size`` at the same random
    location, passed through ``data_augmentation``, then resized by
    ``1/factor`` and ``1/factor**2`` (bicubic) to form the pyramid.

    Args:
        batch_size: number of image pairs per batch.

    Returns:
        Six float arrays, each pixel divided by 255 and with a trailing channel
        axis of size 1: defocus (coarsest, intermediate, finer) followed by
        infocus (coarsest, intermediate, finer).
    """
    global current_batch
    global train_defocus_files
    global train_infocus_files

    if len(train_defocus_files) - current_batch < batch_size:
        # Start a new pass in a fresh random order. Both arrays use the same
        # permutation so pairs stay aligned, and the globals are reassigned so
        # the shuffled order is actually used by this and later batches.
        idx_train = np.random.permutation(len(train_defocus_files))
        train_defocus_files = np.asarray(train_defocus_files)[idx_train]
        train_infocus_files = np.asarray(train_infocus_files)[idx_train]
        current_batch = 0

    batch_train_defocus_files = train_defocus_files[current_batch:current_batch + batch_size]
    batch_train_infocus_files = train_infocus_files[current_batch:current_batch + batch_size]
    # Advance after every batch, including the first one after a reshuffle,
    # so no batch is served twice in a row.
    current_batch += batch_size

    def pyramid(img):
        # (coarsest, intermediate, finer): HxWx1 arrays at 1/factor**2, 1/factor
        # and full resolution, divided by 255 (presumably 8-bit input -- TODO confirm).
        finer = img.copy()[:, :, np.newaxis]
        intermediate = misc.imresize(img, 1.0 / factor, interp='bicubic')[:, :, np.newaxis]
        coarsest = misc.imresize(img, 1.0 / (factor ** 2), interp='bicubic')[:, :, np.newaxis]
        return coarsest / 255.0, intermediate / 255.0, finer / 255.0

    defocus_levels = ([], [], [])
    infocus_levels = ([], [], [])
    for image_path, label_path in zip(batch_train_defocus_files, batch_train_infocus_files):
        with Image.open(image_path) as pil_image:
            temp_image = np.array(pil_image)
        with Image.open(label_path) as pil_label:
            temp_label = np.array(pil_label)

        # Same random crop window for input and target. randint's upper bound is
        # exclusive, hence +1 so the window can touch the bottom/right edge and
        # images exactly crop_size wide are accepted.
        start_x = np.random.randint(temp_image.shape[0] - crop_size + 1)
        start_y = np.random.randint(temp_image.shape[1] - crop_size + 1)
        window = (slice(start_x, start_x + crop_size), slice(start_y, start_y + crop_size))
        temp_image, temp_label = data_augmentation(temp_image[window], temp_label[window])

        for store, level in zip(defocus_levels, pyramid(temp_image)):
            store.append(level)
        for store, level in zip(infocus_levels, pyramid(temp_label)):
            store.append(level)

    return tuple(np.array(store) for store in defocus_levels + infocus_levels)

# define validation data loader
def valid_batch_maker(valid_defocus_files, valid_infocus_files, offset=250, augment=False):
    """Build the validation set as three-scale image pyramids.

    Each defocus/infocus pair (matched by position in the two sequences) is
    cropped to ``crop_size`` x ``crop_size`` starting at row and column
    ``offset``, so the same region is used on every call.

    Args:
        valid_defocus_files: paths of the defocused (input) images.
        valid_infocus_files: paths of the matching in-focus (target) images.
        offset: top-left row and column of the crop window.
        augment: apply ``data_augmentation`` to each pair. Off by default,
            because that augmentation pastes a patch of the in-focus target into
            the input. On validation data this leaks the answer into the input,
            and the random paste position makes the set depend on RNG state.

    Returns:
        Six float arrays, each pixel divided by 255 and with a trailing channel
        axis of size 1: defocus (coarsest, intermediate, finer) followed by
        infocus (coarsest, intermediate, finer).

    Raises:
        ValueError: if an image is smaller than ``offset + crop_size`` in either
            spatial dimension (the crop would otherwise be silently truncated).
    """
    def pyramid(img):
        # (coarsest, intermediate, finer): HxWx1 arrays at 1/factor**2, 1/factor
        # and full resolution, divided by 255 (presumably 8-bit input -- TODO confirm).
        finer = img.copy()[:, :, np.newaxis]
        intermediate = misc.imresize(img, 1.0 / factor, interp='bicubic')[:, :, np.newaxis]
        coarsest = misc.imresize(img, 1.0 / (factor ** 2), interp='bicubic')[:, :, np.newaxis]
        return coarsest / 255.0, intermediate / 255.0, finer / 255.0

    needed = offset + crop_size
    window = (slice(offset, needed), slice(offset, needed))

    defocus_levels = ([], [], [])
    infocus_levels = ([], [], [])
    for image_path, label_path in zip(valid_defocus_files, valid_infocus_files):
        with Image.open(image_path) as pil_image:
            temp_image = np.array(pil_image)
        with Image.open(label_path) as pil_label:
            temp_label = np.array(pil_label)

        for arr, path in ((temp_image, image_path), (temp_label, label_path)):
            if arr.shape[0] < needed or arr.shape[1] < needed:
                raise ValueError('%s is %dx%d; the validation crop needs at least %d px per side'
                                 % (path, arr.shape[0], arr.shape[1], needed))

        temp_image = temp_image[window]
        temp_label = temp_label[window]

        if augment:
            temp_image, temp_label = data_augmentation(temp_image, temp_label)

        for store, level in zip(defocus_levels, pyramid(temp_image)):
            store.append(level)
        for store, level in zip(infocus_levels, pyramid(temp_label)):
            store.append(level)

    return tuple(np.array(store) for store in defocus_levels + infocus_levels)

## 3. Load Data

In [4]:
## To train your own MRN, glob your dataset here (replace the placeholder patterns).
## Defocus/infocus images are paired by position after sorting, so corresponding
## files must sort into the same order in both directories.
# train_defocus_files = np.sort(np.array(glob('<train_defocus_dir>/*.png')))
# train_infocus_files = np.sort(np.array(glob('<train_infocus_dir>/*.png')))
# valid_defocus_files = np.sort(np.array(glob('<valid_defocus_dir>/*.png')))
# valid_infocus_files = np.sort(np.array(glob('<valid_infocus_dir>/*.png')))

# glob() returns paths in arbitrary, OS-dependent order; sort for a reproducible run.
test_defocus_files = sorted(glob('./data_files/*.png'))

## 4. Implement Multi-scale Refocusing Network (MRN)

### 4.1. Build MRN Architecture

In [5]:
# Defocused inputs (x_*) and in-focus targets (y_*) at the three pyramid scales.
# NHWC with a single channel; batch, height and width are left unspecified so the
# same graph accepts fixed-size training crops and full-size test images.
x_coarsest, x_intermediate, x_finer = (
    tf.placeholder(tf.float32, [None, None, None, 1], name=scale_name)
    for scale_name in ('x_coarsest', 'x_intermediate', 'x_finer'))

y_coarsest, y_intermediate, y_finer = (
    tf.placeholder(tf.float32, [None, None, None, 1], name=scale_name)
    for scale_name in ('y_coarsest', 'y_intermediate', 'y_finer'))

In [6]:
def ResidualBlock(x, kernel_size, filters, strides = 1):
    """Conv -> PReLU -> conv, added back onto the block input (identity shortcut).

    Both convolutions are same-padded and bias-free and share ``kernel_size``,
    ``filters`` and ``strides``. Because the result is ``conv_out + x``, the
    output must match ``x``'s shape: ``filters`` has to equal the channel count
    of ``x``. NOTE(review): ``strides`` is applied to both convs, so any value
    other than 1 would break the shortcut addition.
    """
    conv_kwargs = dict(kernel_size=kernel_size,
                       filters=filters,
                       strides=strides,
                       padding='same',
                       use_bias=False)
    shortcut = x
    hidden = tf.layers.conv2d(x, **conv_kwargs)
    # PReLU slope shared across height and width (one learned slope per channel).
    hidden = tf.contrib.keras.layers.PReLU(shared_axes=[1, 2])(hidden)
    hidden = tf.layers.conv2d(hidden, **conv_kwargs)
    return hidden + shortcut

def Upsample2xBlock(x, kernel_size, filters, name, strides = 1):
    """Learned 2x spatial upsampling: same-padded conv, depth-to-space(2), ReLU.

    ``depth_to_space`` with block size 2 turns every 4 channels into a 2x2
    spatial patch, so ``filters`` should be a multiple of 4 and the output has
    ``filters // 4`` channels. Variables are created under variable scope
    ``name`` with ``reuse=tf.AUTO_REUSE``.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        expanded = tf.layers.conv2d(x,
                                    kernel_size=kernel_size,
                                    filters=filters,
                                    strides=strides,
                                    padding='same')
        return tf.nn.relu(tf.depth_to_space(expanded, 2))

In [7]:
def _resnet_subnetwork(x, num_blocks, scope_name):
    """Shared body of the three per-scale refocusing sub-networks.

    Structure (all convs 5x5, stride 1, same padding):
      conv(64) + PReLU -> ``num_blocks`` ResidualBlocks -> bias-free conv(64)
      -> add the post-PReLU features (long skip) -> conv(1) named 'forward'
      -> sigmoid.

    Layers are created in exactly the same order and under the same variable
    scope name as the former per-scale copies. NOTE(review): TF1 derives default
    variable names from creation order, so this is what keeps existing
    checkpoints loadable -- do not reorder.

    Args:
        x: input tensor (NHWC).
        num_blocks: number of ResidualBlocks in the body.
        scope_name: variable scope; the optimizer cell selects variables by it.

    Returns:
        Single-channel tensor squashed to (0, 1) by a sigmoid.
    """
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        x = tf.layers.conv2d(x,
                             kernel_size = 5,
                             filters = 64,
                             strides = 1,
                             padding = 'same')
        x = tf.contrib.keras.layers.PReLU(shared_axes = [1,2])(x)
        skip = x

        for _ in range(num_blocks):
            x = ResidualBlock(x, kernel_size = 5, filters = 64, strides = 1)

        x = tf.layers.conv2d(x,
                             kernel_size = 5,
                             filters = 64,
                             strides = 1,
                             padding = 'same',
                             use_bias = False)
        x = x + skip

        x = tf.layers.conv2d(x,
                             kernel_size = 5,
                             filters = 1,
                             strides = 1,
                             padding = 'same',
                             name = 'forward')
        return tf.nn.sigmoid(x)


def resnet_coarsest(x, num_blocks):
    """Coarsest-scale sub-network (variable scope 'resnet_coarsest')."""
    return _resnet_subnetwork(x, num_blocks, 'resnet_coarsest')


def resnet_intermediate(x, num_blocks):
    """Intermediate-scale sub-network (variable scope 'resnet_intermediate')."""
    return _resnet_subnetwork(x, num_blocks, 'resnet_intermediate')


def resnet_finer(x, num_blocks):
    """Finest-scale sub-network (variable scope 'resnet_finer')."""
    return _resnet_subnetwork(x, num_blocks, 'resnet_finer')

### 4.2. Define Loss and Optimizer

In [8]:
# Coarse-to-fine wiring: each scale's output is upsampled 2x (learned conv +
# depth_to_space) and concatenated on the channel axis (axis 3) with the defocused
# input of the next finer scale. With filters = 4 and block size 2 the upsampled
# map has 1 channel, so the intermediate and finer sub-networks see 2 channels.
# NOTE(review): TF1 default variable names depend on layer creation order; the
# checkpoint restored in the evaluation section was presumably saved from this
# exact graph, so do not reorder these statements.

# coarsest scale subnetwork
refocus_coarsest = resnet_coarsest(x_coarsest, 16)
refocus_coarsest_upconv = Upsample2xBlock(refocus_coarsest, kernel_size = 3, filters = 4, name = 'upconv_for_intermediate')
refocus_coarsest_upconv_concat = tf.concat((refocus_coarsest_upconv, x_intermediate), axis = 3)

# intermediate scale subnetwork
refocus_intermediate = resnet_intermediate(refocus_coarsest_upconv_concat, 16)
refocus_intermediate_upconv = Upsample2xBlock(refocus_intermediate, kernel_size = 3, filters = 4, name = 'upconv_for_finer')
refocus_intermediate_upconv_concat = tf.concat((refocus_intermediate_upconv, x_finer), axis = 3)

# finer scale subnetwork (its output is the final refocused image)
refocus_finer = resnet_finer(refocus_intermediate_upconv_concat, 16)

# loss function: L1 (mean absolute error) against the in-focus target of each scale
loss_coarsest = tf.reduce_mean(tf.abs(y_coarsest - refocus_coarsest))
loss_intermediate = tf.reduce_mean(tf.abs(y_intermediate - refocus_intermediate))
loss_finer = tf.reduce_mean(tf.abs(y_finer - refocus_finer))

# learning rate: continuous exponential decay, LR * 0.1 ** (global_step / 50000)
LR = 0.00005
global_step = tf.contrib.framework.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(LR, global_step, 50000, 0.1, staircase = False)
# None of the minimize() calls below receives global_step, so the step counter
# only advances when this op is run explicitly (once per training iteration).
incr_global_step = tf.assign(global_step, global_step + 1)

# variable list: each scale trains only its own variables, selected by scope-name
# substring; the upconv feeding a scale belongs to that scale. Gradients of
# loss_intermediate / loss_finer therefore never update the coarser sub-networks.
var_coarsest = [var for var in tf.get_collection('trainable_variables') if 'resnet_coarsest' in var.name]
var_intermediate = [var for var in tf.get_collection('trainable_variables') if 'resnet_intermediate' in var.name or 'upconv_for_intermediate' in var.name]
var_finer = [var for var in tf.get_collection('trainable_variables') if 'resnet_finer' in var.name or 'upconv_for_finer' in var.name]

# optimizer: one independent Adam per scale, all sharing the decayed learning rate
optm_coarsest = tf.train.AdamOptimizer(learning_rate).minimize(loss_coarsest, var_list = var_coarsest)
optm_intermediate = tf.train.AdamOptimizer(learning_rate).minimize(loss_intermediate, var_list = var_intermediate)
optm_finer = tf.train.AdamOptimizer(learning_rate).minimize(loss_finer, var_list = var_finer)

### 4.3. Optimize

In [9]:
# # define training parameters
# n_iter = 50000
# n_prt = 100
# n_batch = 5
# save_criteria = 10

# # open tf session
# sess = tf.Session()
# sess.run(tf.global_variables_initializer())
# saver = tf.train.Saver()

# # load validation dataset
# valid_defocus_coarsest, valid_defocus_intermediate, valid_defocus_finer, valid_infocus_coarsest, valid_infocus_intermediate, valid_infocus_finer  = valid_batch_maker(valid_defocus_files, valid_infocus_files)

# # training phase
# for epoch in range(n_iter):
    
#     # load batch training dataset
#     train_defocus_coarsest, train_defocus_intermediate, train_defocus_finer, train_infocus_coarsest, train_infocus_intermediate, train_infocus_finer = train_batch_maker(n_batch)
    
#     # optimize MRN
#     sess.run([optm_coarsest, optm_intermediate, optm_finer], feed_dict = {x_coarsest: train_defocus_coarsest, 
#                                                                           x_intermediate: train_defocus_intermediate,
#                                                                           x_finer: train_defocus_finer, 
#                                                                           y_coarsest: train_infocus_coarsest, 
#                                                                           y_intermediate: train_infocus_intermediate, 
#                                                                           y_finer: train_infocus_finer})
#     sess.run(incr_global_step)
    
#     # save a best model
#     criteria_temp = sess.run(loss_finer, feed_dict = {x_coarsest: valid_defocus_coarsest, 
#                                                       x_intermediate: valid_defocus_intermediate,
#                                                       x_finer: valid_defocus_finer, 
#                                                       y_coarsest: valid_infocus_coarsest, 
#                                                       y_intermediate: valid_infocus_intermediate, 
#                                                       y_finer: valid_infocus_finer})
#     if save_criteria > criteria_temp:
#         save_criteria = criteria_temp
#         saver.save(sess, './model/MRN.ckpt')
    
#     # record loss graphs, and print refocused SEM image 
#     if epoch % n_prt == 0:        
        
#         print('Epoch:', '%04d' % epoch)
#         refocus_img = sess.run(refocus_finer, feed_dict = {x_coarsest: train_defocus_coarsest[:1], 
#                                                            x_intermediate: train_defocus_intermediate[:1],
#                                                            x_finer: train_defocus_finer[:1]})
        
#         plt.figure(figsize = (5,5))
#         plt.imshow(refocus_img[0,:,:,0], cmap = 'gray')
#         plt.axis('off')
#         plt.show()

## 5. Evaluation

In [10]:
# Load the best model: restore the checkpoint written by the (commented-out)
# training loop whenever the validation loss improved.
save_file = './model/MRN.ckpt'

saver = tf.train.Saver()  # no var_list, so it covers all saveable variables in the graph
sess = tf.Session()
saver.restore(sess=sess, save_path=save_file)

In [11]:
# Refocus each test SEM image, then display the network output followed by the
# (cropped) defocused input, each in its own titled figure.
for defocus_file in test_defocus_files:
    print(defocus_file)

    with Image.open(defocus_file) as pil_image:
        defocus_image = np.array(pil_image)

    # Scales differ by `factor`, and each Upsample2xBlock output must match the
    # next finer input in tf.concat, so trim height/width down to a multiple of
    # factor**2 (drops at most factor**2 - 1 bottom rows / right columns).
    # NOTE(review): Upsample2xBlock always upsamples 2x, so this assumes factor == 2.
    multiple = factor ** 2
    new_x_shape = defocus_image.shape[0] // multiple * multiple
    new_y_shape = defocus_image.shape[1] // multiple * multiple
    defocus_image = defocus_image[:new_x_shape, :new_y_shape]

    # Add batch and channel axes and divide by 255, matching the training loaders.
    defocus_image_finer = defocus_image.copy()[np.newaxis, :, :, np.newaxis] / 255
    defocus_image_intermediate = misc.imresize(defocus_image, 1.0 / factor, interp = 'bicubic')[np.newaxis, :, :, np.newaxis] / 255
    defocus_image_coarsest = misc.imresize(defocus_image, 1.0 / (factor ** 2), interp = 'bicubic')[np.newaxis, :, :, np.newaxis] / 255

    refocus_img = sess.run(refocus_finer, feed_dict = {x_coarsest: defocus_image_coarsest,
                                                       x_intermediate: defocus_image_intermediate,
                                                       x_finer: defocus_image_finer})

    # Titles distinguish output from input; closing each figure after display keeps
    # figures created inside the loop from accumulating.
    for shown, title in ((refocus_img[0, :, :, 0], 'Refocused (MRN output): ' + defocus_file),
                         (defocus_image, 'Defocused input: ' + defocus_file)):
        fig, ax = plt.subplots(figsize = (20, 20))
        ax.imshow(shown, cmap = 'gray')
        ax.set_title(title)
        ax.axis('off')
        plt.show()
        plt.close(fig)