In [1]:
import tensorflow as tf
import numpy as np
import random
from PIL import Image
from skimage import io
from scipy.interpolate import splprep, splev
import cv2
from scipy import misc
from scipy import ndimage
from scipy import signal
from scipy.ndimage import gaussian_filter
from scipy.interpolate import interp1d
from scipy import stats
from skimage.transform import resize
import matplotlib.pyplot as plt
import math
import gc
import os
import pandas as pd
import time

# Exclusive upper bound on the horizontal offsets (in pixels, to the right of
# the PSF peak) that psf_crop() scans when looking for the half-maximum point.
psfsize = 100

def posimage(image):
    """Clamp negative entries of ``image`` to zero, in place.

    A single boolean-mask assignment replaces the former Python loop over
    ``np.where`` indices. It is vectorised and works for arrays of any
    dimensionality; the old loop only handled 2-D arrays.

    Parameters
    ----------
    image : numpy.ndarray
        Array to rectify. It is modified in place.

    Returns
    -------
    numpy.ndarray
        The same array object, with every value ``< 0`` replaced by 0.
        NaN entries are left untouched because ``NaN < 0`` is False.
    """
    image[image < 0] = 0
    return image

def filtimage(image):
    """Return a float64 mask that is 0 where ``image < 0`` and 100 elsewhere.

    This is vectorised with ``np.where``; it previously looped in Python over
    the negative pixels and only supported 2-D input. NaN pixels map to 100,
    since ``NaN < 0`` is False.

    Parameters
    ----------
    image : numpy.ndarray
        Input array of any shape.

    Returns
    -------
    numpy.ndarray
        float64 array with the same shape as ``image``.
    """
    return np.where(image < 0, 0.0, 100.0)

def lp_filt(image, psf, imsize, rel_threshold=0.01):
    """Low-pass filter ``image`` with the pass-band of the PSF's OTF.

    Frequencies where |FFT(psf)| exceeds ``rel_threshold`` times its peak
    magnitude are kept unchanged. Every other frequency of the image spectrum
    is zeroed.

    Parameters
    ----------
    image : array_like
        Image to filter; ``fft2`` operates on its last two axes.
    psf : array_like
        Point-spread function, zero-padded/truncated to ``imsize`` before
        the FFT.
    imsize : sequence of int
        FFT size ``(rows, cols)``, e.g. ``image.shape``.
    rel_threshold : float, optional
        Pass-band cutoff relative to the OTF peak magnitude. Defaults to
        0.01, the value this function previously hard-coded.

    Returns
    -------
    numpy.ndarray
        Complex inverse FFT of the masked spectrum, shape ``imsize``. The
        imaginary part is not dropped.
    """
    otf_mag = np.absolute(np.fft.fft2(psf, s=imsize))
    # Boolean pass-band mask; multiplying by it zeroes the rejected bins.
    passband = otf_mag > np.max(otf_mag) * rel_threshold
    spectrum = np.fft.fft2(image, s=imsize)
    return np.fft.ifft2(spectrum * passband, s=imsize)

def lp_filt_2(image, psf, imsize):
    """Low-pass filter with a 10 % OTF cutoff.

    Keeps only the image frequencies where |FFT(psf)| exceeds one tenth of
    its peak magnitude and zeroes the rest.

    Returns the complex inverse FFT of the masked spectrum, shape ``imsize``.
    """
    spectrum = np.fft.fft2(image, s=imsize)
    otf_mag = np.absolute(np.fft.fft2(psf, s=imsize))
    keep = otf_mag > np.max(otf_mag) * 0.1
    return np.fft.ifft2(spectrum * keep, s=imsize)


def det_offset(image, psf, imsize):
    """Keep only the strongest part of the OTF pass-band.

    Every frequency where |FFT(psf)| is below 80 % of its peak magnitude is
    zeroed; all other frequencies of the image spectrum pass unchanged.
    Returns the complex inverse FFT, shape ``imsize``.

    NOTE(review): for a typical PSF this presumably leaves mostly the
    slowly varying (offset/background) content. The processing loop uses
    the mean of the result as an offset estimate.
    """
    otf_mag = np.absolute(np.fft.fft2(psf, s=imsize))
    weak = otf_mag < np.max(otf_mag) * 0.80
    spectrum = np.fft.fft2(image, s=imsize)
    return np.fft.ifft2(spectrum * np.logical_not(weak), s=imsize)


def SNR_cal(a, psf, imsize):
    """Estimate a signal-to-noise ratio for image ``a``.

    Signal is the mean magnitude of the ``lp_filt`` component, i.e. what
    lies inside the OTF pass-band. Noise is the standard deviation of the
    ``hp_filt`` component, i.e. frequencies the PSF essentially cannot
    transmit. Both filters use FFT size ``imsize``.
    """
    signal_mean = np.mean(np.abs(lp_filt(a, psf, imsize)))
    noise_std = np.std(hp_filt(a, psf, imsize))
    return signal_mean / noise_std

def xcov(x, y):
    """Sample covariance of two tensors, computed in float64 and returned as float32.

    The means and the sum run over *all* elements, so the normaliser must
    be ``N - 1`` with ``N`` the total element count. The previous version
    divided by ``tf.shape(x)[0] - 1`` (first dimension only). For 2-D
    images that overstated the covariance by roughly the number of columns.
    For 1-D inputs the result is unchanged.

    Parameters
    ----------
    x, y : tensors
        Same-shaped inputs; cast to float64 internally.

    Returns
    -------
    tf.Tensor
        float32 scalar.
    """
    x_64 = tf.cast(x, dtype=tf.float64)
    y_64 = tf.cast(y, dtype=tf.float64)
    n = tf.cast(tf.size(x_64), tf.float64)

    centered_product = (x_64 - tf.reduce_mean(x_64)) * (y_64 - tf.reduce_mean(y_64))
    cov_xy = tf.reduce_sum(centered_product) / (n - 1.0)

    return tf.cast(cov_xy, dtype=tf.float32)

def SSIM_cal(orgimg, image):
    """Global (single-window) SSIM-style similarity between two images.

    Computes

        (2 mu_o mu_i)(2 sig_oi) / ((mu_o^2 + mu_i^2)(sig_o^2 + sig_i^2))

    without the usual stabilising constants C1/C2, so the result is
    undefined when both means or both standard deviations are zero.

    Bug fix: the variance term previously read ``sig_o^2 + sig_o^2``
    instead of ``sig_o^2 + sig_i^2``.

    NOTE(review): ``tf.math.reduce_std`` is a population (N) estimate while
    ``xcov`` uses N - 1; the mismatch is small for large images.
    """
    sig_o = tf.math.reduce_std(orgimg)
    sig_i = tf.math.reduce_std(image)
    mu_o = tf.math.reduce_mean(orgimg)
    mu_i = tf.math.reduce_mean(image)
    sig_oi = xcov(orgimg, image)

    mean_term = tf.math.add(tf.math.square(mu_o), tf.math.square(mu_i))
    var_term = tf.math.add(tf.math.square(sig_o), tf.math.square(sig_i))
    numerator = tf.math.multiply(tf.math.multiply(mu_o, mu_i), sig_oi)

    return 4 * tf.math.divide(numerator, tf.math.multiply(mean_term, var_term))


def hp_filt(image, psf, imsize):
    """High-pass ("noise") filter: keep only frequencies the PSF barely transmits.

    An image frequency survives unless |FFT(psf)| there exceeds 0.1 % of
    the OTF's peak magnitude. Returns the complex inverse FFT of the masked
    spectrum, shape ``imsize``.
    """
    otf_mag = np.absolute(np.fft.fft2(psf, s=imsize))
    transmitted = otf_mag > np.max(otf_mag) * 0.001
    spectrum = np.fft.fft2(image, s=imsize)
    return np.fft.ifft2(spectrum * np.logical_not(transmitted), s=imsize)


def psf_crop(psf, max_radius=None):
    """Crop a 2-D PSF to a square window centred on its peak.

    The crop half-width is five times the first horizontal offset, to the
    right of the peak, at which the PSF drops below half its maximum.

    Parameters
    ----------
    psf : numpy.ndarray
        2-D point-spread function.
    max_radius : int, optional
        Exclusive upper bound on the offsets scanned for the half-maximum
        point. Defaults to the module-level ``psfsize``. The scan is also
        limited to the array's right-hand edge.

    Returns
    -------
    new_psf : numpy.ndarray
        The cropped PSF, normally ``(2*psf_size + 1)`` square. If the peak
        lies closer than ``psf_size`` to an edge, the window is truncated at
        that edge instead of wrapping around.
    psf_size : int
        Half-width of the crop window (5 x the half-maximum offset).

    Raises
    ------
    ValueError
        If no pixel below half maximum is found within the scanned range.
        Previously this surfaced as ``UnboundLocalError``/``IndexError``.
    """
    if max_radius is None:
        max_radius = psfsize

    # First occurrence of the maximum in row-major order, i.e. the same
    # pixel np.where(psf == psf.max()) lists first.
    row, col = np.unravel_index(np.argmax(psf), psf.shape)
    half_max = 0.5 * psf[row, col]

    # Never index past the right-hand edge of the array.
    search_end = min(max_radius, psf.shape[1] - col)
    for j in range(1, search_end):
        if psf[row, col + j] < half_max:
            psf_size = j * 5
            break
    else:
        raise ValueError(
            "psf_crop: PSF does not fall below half maximum within "
            "{0} pixels right of its peak".format(max(search_end - 1, 0))
        )

    # A negative slice start would wrap around and yield an empty or
    # mis-centred crop; clamp it (NumPy already truncates the upper bound).
    r0 = max(row - psf_size, 0)
    c0 = max(col - psf_size, 0)
    new_psf = psf[r0:row + psf_size + 1, c0:col + psf_size + 1]

    return new_psf, psf_size

        
def virtualimage(imsize, addr_x, addr_y, sizeaddr):
    """Build a flattened "virtual emitter" image from point coordinates.

    Each of the ``sizeaddr`` points deposits intensity ``unit_inten`` (= 1)
    into a flat zero image of length ``imsize[0] * imsize[1]``. The deposit
    goes to linear index ``round(new_addr_y + new_addr_x * imsize[1])``.
    Coincident points accumulate via ``tensor_scatter_nd_add``.

    Parameters
    ----------
    imsize : sequence
        ``imsize[0]`` and ``imsize[1]`` give the image height and width.
        NOTE(review): ``imsize[2]`` is used below as the y upper bound, but
        the processing loop builds ``imsize = [rows, cols, intratio]``.
        Confirm this is intended.
    addr_x, addr_y : float tensors
        Point coordinates. They are presumably shaped ``(sizeaddr, 1)``,
        since the ``tf.slice`` of the ``tf.where`` results assumes rank-2
        input. TODO confirm.
    sizeaddr : int
        Number of points.

    Returns
    -------
    tf.Tensor
        float32 tensor of shape ``(imsize[0] * imsize[1],)``.
    """
    
    tempimg = np.zeros(imsize[0]*imsize[1])
    
    unit_inten = 1
    
    
    # One unit of intensity per point.
    intarray = np.ones(sizeaddr)*unit_inten
    
    # --- Out-of-range handling for x ---------------------------------------
    # Row indices of points with x < 1 and x > imsize[1]-1.
    # NOTE(review): x is the row coordinate in the linear index below, so the
    # upper bound would normally be imsize[0]-1. Confirm.
    Less_addr_x = tf.where(tf.less(addr_x, 1))
    Less_addr_x = tf.cast(tf.slice(Less_addr_x, [0, 0], [tf.shape(Less_addr_x)[0], 1]), tf.int32)
    Greater_addr_x = tf.where(tf.greater(addr_x, (imsize[1]-1)))
    # NOTE(review): slices Greater_addr_x using Less_addr_x's row count; this
    # looks like a copy/paste slip.
    Greater_addr_x = tf.cast(tf.slice(Greater_addr_x, [0, 0], [tf.shape(Less_addr_x)[0], 1]), tf.int32)
    # NOTE(review): tf.ones([0, n]) has zero rows, so these update tensors
    # cannot match the (n, 1) index tensors. The scatter below would fail at
    # run time whenever its branch is taken. Low values are mapped to
    # imsize[1]-1 and high values to 1 (wrap-around rather than clamping?).
    Less_Tensor_x = tf.ones([0, tf.size(Less_addr_x)])*(imsize[1]-1)
    Less_Tensor_x = tf.cast(Less_Tensor_x, tf.float32)
    Greater_Tensor_x = tf.ones([0, tf.size(Greater_addr_x)])
    Greater_Tensor_x = tf.cast(Greater_Tensor_x, tf.float32)
    
    # Unused: immediately rebound by tf.cond below.
    new_addr_x = tf.placeholder(tf.float32, name='new_address_x')
    
    new_addr_x = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_x)), 
                          lambda: tf.tensor_scatter_nd_update(addr_x, Less_addr_x, Less_Tensor_x), 
                          lambda: tf.cast(addr_x, tf.float32))
    
    # NOTE(review): both branches start again from addr_x, not new_addr_x, so
    # the "less than 1" update above is always discarded.
    new_addr_x = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_x)), 
                          lambda: tf.tensor_scatter_nd_update(addr_x, Greater_addr_x, Greater_Tensor_x), 
                          lambda: addr_x)
    
    # --- Out-of-range handling for y (same pattern and same caveats) ------
    Less_addr_y = tf.where(tf.less(addr_y, 1))
    Less_addr_y = tf.cast(tf.slice(Less_addr_y, [0, 0], [tf.shape(Less_addr_y)[0], 1]), tf.int32)
    Greater_addr_y = tf.where(tf.greater(addr_y, (imsize[2]-1)))
    Greater_addr_y = tf.cast(tf.slice(Greater_addr_y, [0, 0], [tf.shape(Less_addr_y)[0], 1]), tf.int32)
    Less_Tensor_y = tf.ones([0, tf.size(Less_addr_y)])*(imsize[2]-1)
    Less_Tensor_y = tf.cast(Less_Tensor_y, tf.float32)
    Greater_Tensor_y = tf.ones([0, tf.size(Greater_addr_y)])
    Greater_Tensor_y = tf.cast(Greater_Tensor_y, tf.float32)
    
    # Unused: immediately rebound by tf.cond below.
    new_addr_y = tf.placeholder(tf.float32, name='new_address_y')
    
    new_addr_y = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_y)), 
                          lambda: tf.tensor_scatter_nd_update(addr_y, Less_addr_y, Less_Tensor_y), 
                          lambda: tf.cast(addr_y, tf.float32))
    
    new_addr_y = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_y)), 
                          lambda: tf.tensor_scatter_nd_update(addr_y, Greater_addr_y, Greater_Tensor_y), 
                          lambda: addr_y)
    
    
    # Row-major linear index: x selects the row (stride imsize[1]), y the column.
    new_address = new_addr_y + new_addr_x*imsize[1]
    
    int_address = tf.math.round(new_address)
    int_address = tf.cast(int_address, dtype=tf.int32)
    tensorintarray = tf.cast(intarray, dtype=tf.float32)
    img_var = tf.constant(tempimg, dtype=tf.float32)
    
    # Accumulate one unit per point; repeated indices add up.
    new_img_var = tf.tensor_scatter_nd_add(img_var, int_address, tensorintarray)
    new_img_var = tf.cast(new_img_var, dtype=tf.float32)
                                   
    
    # Only unbinds the Python names; graph ops are unaffected.
    del [[tempimg, Less_addr_x, Less_addr_y, Greater_addr_x, Greater_addr_y, Less_Tensor_x, Less_Tensor_y, Greater_Tensor_x, Greater_Tensor_y, new_addr_x, new_addr_y, new_address, int_address, tensorintarray, img_var]]
    
    return new_img_var
    
def convimggen(imsize, temp_addr_x, temp_addr_y, psf, sizeaddr):
    """Render point coordinates into a PSF-blurred image.

    Builds the virtual point image with ``virtualimage`` and filters it with
    ``psf`` via ``tf.nn.conv2d`` (stride 1, 'SAME' padding). Note that
    conv2d is a cross-correlation, so the PSF is not flipped. The result is
    rescaled so that it sums to ``sum(psf) * sizeaddr``.

    Parameters
    ----------
    imsize : sequence
        ``imsize[0]``/``imsize[1]`` are the output height/width; the whole
        sequence is forwarded to ``virtualimage``.
    temp_addr_x, temp_addr_y : float tensors
        Point coordinates, forwarded to ``virtualimage``.
    psf : numpy.ndarray
        2-D PSF (must support ``.reshape``).
    sizeaddr : int or float
        Number of points; scales the total intensity.

    Returns
    -------
    tf.Tensor
        float32 tensor of shape ``(imsize[0], imsize[1])``. It is NaN/inf
        if the convolved image sums to zero.
    """
    # conv2d expects a [height, width, in_channels, out_channels] filter.
    temp_psf = tf.constant(psf.reshape(psf.shape[0], psf.shape[1], 1, 1), dtype=tf.float32)

    temp_image = virtualimage(imsize, temp_addr_x, temp_addr_y, sizeaddr)
    temp_image = tf.reshape(temp_image, [1, imsize[0], imsize[1], 1])
    temp_convimg = tf.nn.conv2d(temp_image, temp_psf, strides=[1, 1, 1, 1], padding='SAME')

    # Normalise total intensity to sum(psf) * sizeaddr.
    return (tf.math.reduce_sum(temp_psf)
            * tf.math.divide(tf.reshape(temp_convimg, [imsize[0], imsize[1]]), tf.math.reduce_sum(temp_convimg))
            * sizeaddr)


def fftconvimggen(imsize, temp_addr_x, temp_addr_y, psf, sizeaddr):
    """FFT-based rendering of point coordinates blurred by ``psf``.

    The virtual point image (``virtualimage``) and the PSF are both taken
    through ``rfft2d`` at size ``(imsize[0], imsize[1])`` and multiplied.
    The product is transformed back with ``irfft2d``, circularly shifted,
    and normalised to sum to ``sizeaddr``.

    Fixes relative to the previous version:
      * ``fft_length`` must have exactly two entries for ``rfft2d``/``irfft2d``.
        A 3-element int32 constant that also contained ``imsize[2]`` (the
        float ``intratio`` in the processing loop) was passed before.
      * The complex product now uses ``tf.math.multiply`` instead of a
        Keras layer.

    Returns
    -------
    tf.Tensor
        float32 tensor of shape ``(imsize[0], imsize[1])``.
    """
    fft_shape = [imsize[0], imsize[1]]

    temp_psf = tf.constant(psf, dtype=tf.float32)
    temp_fft_psf = tf.signal.rfft2d(temp_psf, fft_length=fft_shape)

    temp_image = tf.reshape(virtualimage(imsize, temp_addr_x, temp_addr_y, sizeaddr), [imsize[0], imsize[1]])
    temp_fft_image = tf.signal.rfft2d(temp_image, fft_length=fft_shape)

    conv_fft_image = tf.math.multiply(temp_fft_image, temp_fft_psf)

    conv_image = tf.cast(tf.signal.irfft2d(conv_fft_image, fft_length=fft_shape), dtype=tf.float32)
    # NOTE(review): shifts by half the image size. For a small PSF zero-padded
    # at the top-left corner, re-centring would usually need -(psf.shape // 2)
    # instead. Kept as before; confirm intent.
    conv_image = tf.roll(conv_image, [int(math.ceil(imsize[0] / 2)), int(math.ceil(imsize[1] / 2))], [0, 1])

    return tf.reshape(conv_image, [imsize[0], imsize[1]]) / tf.math.reduce_sum(conv_image) * sizeaddr



def cost_fun(re_image, imsize, temp_addr_x, temp_addr_y, psf, sizeaddr):
    """L2 data-fidelity cost between ``re_image`` and the rendered model.

    The model is ``convimggen`` of the current coordinates. The cost is
    ``tf.nn.l2_loss(re_image - model)``, i.e. half the sum of squared
    residuals.
    """
    model_image = convimggen(imsize, temp_addr_x, temp_addr_y, psf, sizeaddr)
    residual = tf.math.subtract(re_image, model_image)
    return tf.nn.l2_loss(residual)

def genetOpt(re_image, imsize, temp_addr_x, temp_addr_y, psf, sizeaddr, addr_x, addr_y):
    """One genetic-style mutation step on the emitter coordinates.

    ``num_sel`` randomly chosen entries of the pool (``addr_x``, ``addr_y``)
    are copied into ``num_sel`` randomly chosen slots of the current
    solution (``temp_addr_x``, ``temp_addr_y``). The mutated solution is
    kept only if it lowers the L1 residual against ``re_image``.

    Returns
    -------
    res_x, res_y : float32 tensors
        The accepted coordinates, either mutated or the original ones.
    diff_sum : float32 tensor
        ``|L1(old) - L1(new)|``, the absolute change in the L1 residual.
    """
    num_sel = 3

    tf_addr_x = tf.cast(addr_x, tf.float32)
    tf_addr_y = tf.cast(addr_y, tf.float32)
    cur_x = tf.cast(temp_addr_x, dtype=tf.float32)
    cur_y = tf.cast(temp_addr_y, dtype=tf.float32)

    # Integer tf.random.uniform excludes maxval, so pass the full length.
    # The previous "size - 1" could never pick the last entry. Source
    # indices address the pool; target indices address the current solution.
    source_idx = tf.random.uniform([num_sel, 1], minval=0, maxval=tf.shape(tf_addr_x)[0], dtype=tf.int32)
    target_idx = tf.random.uniform([num_sel, 1], minval=0, maxval=tf.shape(cur_x)[0], dtype=tf.int32)

    # Exchange information: copy sampled pool entries into the solution.
    donor_x = tf.reshape(tf.gather(tf_addr_x, source_idx), [num_sel, 1])
    donor_y = tf.reshape(tf.gather(tf_addr_y, source_idx), [num_sel, 1])
    new_addr_x = tf.tensor_scatter_nd_update(cur_x, target_idx, donor_x)
    new_addr_y = tf.tensor_scatter_nd_update(cur_y, target_idx, donor_y)

    # Evaluation: L1 residual before and after the mutation.
    old_eval = tf.math.reduce_sum(tf.abs(tf.math.subtract(re_image, convimggen(imsize, temp_addr_x, temp_addr_y, psf, sizeaddr))))
    new_eval = tf.math.reduce_sum(tf.abs(tf.math.subtract(re_image, convimggen(imsize, new_addr_x, new_addr_y, psf, sizeaddr))))

    improved = tf.less(new_eval, old_eval)
    res_x = tf.cond(improved, lambda: new_addr_x, lambda: cur_x)
    res_y = tf.cond(improved, lambda: new_addr_y, lambda: cur_y)

    # |a - b| == |b - a| exactly, so reuse the two evaluations instead of
    # rendering both images a second time.
    diff_sum = tf.math.abs(old_eval - new_eval)

    return res_x, res_y, diff_sum
    
    

def adamIntVar(re_image, imsize, temp_addr_x, temp_addr_y, psf, sizeaddr, m_x, v_x, m_y, v_y, extrapolnum, i, imagefilter):
    """One Adam-style update of the emitter coordinates.

    After the update, the coordinates go through a boundary fix-up and a
    convergence measure is computed.

    Parameters
    ----------
    re_image : tensor
        Target image the rendering is fitted to.
    imsize : sequence
        Image size; see NOTE(review) on ``imsize[2]`` in the y fix-up below.
    temp_addr_x, temp_addr_y : float32 tensors
        Current coordinates.
    psf, sizeaddr
        Forwarded to ``gradgenfun`` / ``convimggen``.
    m_x, v_x, m_y, v_y : float32 tensors
        First/second moment estimates from the previous step.
    extrapolnum
        Not used in this function.
    i : float32 scalar tensor
        Step counter for bias correction. The notebook's while_loop body
        increments it before calling.
    imagefilter
        Forwarded to ``gradgenfun``.

    Returns
    -------
    tuple
        ``(new_x_res, new_y_res, m_x_new, v_x_new, m_y_new, v_y_new,
        diff_sum)``. ``diff_sum`` is |L1(old residual) - L1(new residual)|,
        which the caller's while_loop compares against its threshold.

    Reads the module-level ``psf_size`` (set from ``psf_crop`` in the
    notebook) to set the step size: LR = 0.1 * psf_size.
    """
    
    LR_xy = tf.constant(psf_size*0.1, dtype=tf.float32)
    #LR_xy = tf.constant(25, dtype=tf.float32)
    
    
    beta_1 = tf.constant(0.9, dtype=tf.float32)
    beta_1_1 = 1-beta_1
    beta_2 = tf.constant(0.99, dtype=tf.float32)
    beta_2_2 = 1-beta_2
    
    # These placeholders/casts are overwritten below before any use (the
    # moment updates read m_x / v_x / m_y / v_y directly).
    m_x_new = tf.placeholder(dtype=tf.float32)
    v_x_new = tf.placeholder(dtype=tf.float32)
    m_x_new = tf.cast(m_x, dtype=tf.float32)
    v_x_new = tf.cast(v_x, dtype=tf.float32)
    
    m_y_new = tf.placeholder(dtype=tf.float32)
    v_y_new = tf.placeholder(dtype=tf.float32)
    m_y_new = tf.cast(m_y, dtype=tf.float32)
    v_y_new = tf.cast(v_y, dtype=tf.float32)
    
    x_res = tf.placeholder(dtype=tf.float32)
    x_res = temp_addr_x
    
    y_res = tf.placeholder(dtype=tf.float32)
    y_res = temp_addr_y
    
    grad_res_x, grad_res_y = gradgenfun(re_image, imsize, temp_addr_x, temp_addr_y, psf, sizeaddr, imagefilter)
    
    # Exponential moving averages of the gradient and squared gradient.
    m_x_new = beta_1*m_x + beta_1_1*tf.cast(grad_res_x, dtype=tf.float32)
    v_x_new = beta_2*v_x + beta_2_2*tf.cast(tf.math.square(grad_res_x), dtype=tf.float32)
    
    # Bias-corrected step. NOTE: the 0.001 epsilon sits inside the sqrt
    # (sqrt(v_hat + eps)) rather than standard Adam's sqrt(v_hat) + eps.
    x_res = temp_addr_x - LR_xy*(m_x_new/(1-tf.math.pow(beta_1, i)))/tf.math.sqrt(v_x_new/(1-tf.math.pow(beta_2,i)) + tf.constant(0.001, dtype=tf.float32))
    
    # Snap to integer pixel positions.
    x_res = tf.math.round(x_res)
    
    m_y_new = beta_1*m_y + beta_1_1*tf.cast(grad_res_y, dtype=tf.float32)
    v_y_new = beta_2*v_y + beta_2_2*tf.cast(tf.math.square(grad_res_y), dtype=tf.float32)
    
    y_res = temp_addr_y - LR_xy*(m_y_new/(1-tf.math.pow(beta_1, i)))/tf.math.sqrt(v_y_new/(1-tf.math.pow(beta_2,i)) + tf.constant(0.001, dtype=tf.float32))
    
    y_res = tf.math.round(y_res)
    
    
    # --- Out-of-range fix-up for x -----------------------------------------
    # NOTE(review): this repeats virtualimage's pattern with the same issues:
    #   * Greater_addr_x is sliced using Less_addr_x's row count;
    #   * tf.ones([0, n]) yields zero-row updates that cannot match the (n, 1)
    #     indices, so a taken scatter branch would fail at run time;
    #   * the second tf.cond starts from x_res (not new_x_res), and its false
    #     branch returns x_res, so the "< 1" update is always discarded.
    Less_addr_x = tf.where(tf.less(x_res, 1))
    Less_addr_x = tf.cast(tf.slice(Less_addr_x, [0, 0], [tf.shape(Less_addr_x)[0], 1]), tf.int32)
    Greater_addr_x = tf.where(tf.greater(x_res, (imsize[1]-1)))
    Greater_addr_x = tf.cast(tf.slice(Greater_addr_x, [0, 0], [tf.shape(Less_addr_x)[0], 1]), tf.int32)
    Less_Tensor_x = tf.ones([0, tf.size(Less_addr_x)])*(imsize[1]-1)
    Less_Tensor_x = tf.cast(Less_Tensor_x, tf.float32)
    Greater_Tensor_x = tf.ones([0, tf.size(Greater_addr_x)])
    Greater_Tensor_x = tf.cast(Greater_Tensor_x, tf.float32)
    
    new_x_res = tf.placeholder(tf.float32, name='new_address_x')
    
    new_x_res = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_x)), 
                          lambda: tf.tensor_scatter_nd_update(x_res, Less_addr_x, Less_Tensor_x), 
                          lambda: tf.cast(x_res, tf.float32))
    
    new_x_res = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_x)), 
                          lambda: tf.tensor_scatter_nd_update(x_res, Greater_addr_x, Greater_Tensor_x), 
                          lambda: x_res)
    
    # --- Out-of-range fix-up for y (same caveats) ---------------------------
    # NOTE(review): the y upper bound uses imsize[2], which the processing
    # loop sets to intratio rather than an image dimension. Confirm.
    Less_addr_y = tf.where(tf.less(y_res, 1))
    Less_addr_y = tf.cast(tf.slice(Less_addr_y, [0, 0], [tf.shape(Less_addr_y)[0], 1]), tf.int32)
    Greater_addr_y = tf.where(tf.greater(y_res, (imsize[2]-1)))
    Greater_addr_y = tf.cast(tf.slice(Greater_addr_y, [0, 0], [tf.shape(Less_addr_y)[0], 1]), tf.int32)
    Less_Tensor_y = tf.ones([0, tf.size(Less_addr_y)])*(imsize[2]-1)
    Less_Tensor_y = tf.cast(Less_Tensor_y, tf.float32)
    Greater_Tensor_y = tf.ones([0, tf.size(Greater_addr_y)])
    Greater_Tensor_y = tf.cast(Greater_Tensor_y, tf.float32)
    
    new_y_res = tf.placeholder(tf.float32, name='new_address_x')
    
    new_y_res = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_y)), 
                          lambda: tf.tensor_scatter_nd_update(y_res, Less_addr_y, Less_Tensor_y), 
                          lambda: tf.cast(y_res, tf.float32))
    
    new_y_res = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_y)), 
                          lambda: tf.tensor_scatter_nd_update(y_res, Greater_addr_y, Greater_Tensor_y), 
                          lambda: y_res)
    
    # Convergence measure: absolute change of the L1 residual between the
    # old and the updated coordinates.
    diff_sum = tf.math.abs(tf.math.reduce_sum(tf.math.abs(tf.math.subtract(convimggen(imsize, temp_addr_x, temp_addr_y, psf, sizeaddr), re_image)))
                           - tf.math.reduce_sum(tf.math.abs(tf.math.subtract(re_image, convimggen(imsize, new_x_res, new_y_res, psf, sizeaddr)))))
    
    del [[Less_addr_x, Less_addr_y, Greater_addr_x, Greater_addr_y, Less_Tensor_x, Less_Tensor_y, Greater_Tensor_x, Greater_Tensor_y]]
    
    
    return new_x_res, new_y_res, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum


def imgradfun(imsize, temp_addr_x, temp_addr_y, psf, sizeaddr):
    """Apply a combined first-difference kernel to the rendered image.

    The 3x3 kernel has -1 above and left of centre and +1 below and right
    of it. Because ``tf.nn.conv2d`` is a cross-correlation, each output
    pixel is (right - left) + (below - above) of the ``convimggen``
    rendering ('SAME' padding).

    Returns a float32 tensor of shape ``(imsize[0], imsize[1])``.
    """
    kernel = np.array([[0.0, -1.0, 0.0],
                       [-1.0, 0.0, 1.0],
                       [0.0, 1.0, 0.0]]).reshape(3, 3, 1, 1)
    rendered = convimggen(imsize, temp_addr_x, temp_addr_y, psf, sizeaddr)
    batch = tf.reshape(rendered, [1, imsize[0], imsize[1], 1])
    filtered = tf.nn.conv2d(batch, tf.constant(kernel, dtype=tf.float32), strides=[1, 1, 1, 1], padding='SAME')
    return tf.reshape(filtered, [imsize[0], imsize[1]])
                            
    


def gradgenfun(re_image, imsize, temp_addr_x, temp_addr_y, psf, sizeaddr, imagefilter):
    """Finite-difference gradient of the scaled residual at each point.

    The scaled residual is ``R = -1000 * (re_image - model)``, where
    ``model`` is the ``convimggen`` rendering of the current coordinates.
    For every point (x, y) the unscaled central difference one pixel away
    is returned:

        grad_x = R[x + 1, y] - R[x - 1, y]
        grad_y = R[x, y + 1] - R[x, y - 1]

    Twelve further ``gather_nd`` samples at offsets of 4, 8 and 12 pixels,
    and their sums, were computed here but never used. They have been
    removed; the returned values are unchanged.

    Parameters
    ----------
    imagefilter
        Unused; kept for interface compatibility.

    Returns
    -------
    result_x, result_y : float32 tensors of shape ``(sizeaddr, 1)``.

    NOTE(review): per the TF docs, ``gather_nd`` raises on CPU (and returns
    0 on GPU) for out-of-range indices, so points on the image border
    error or read zeros here.
    """
    temp_result = -1000 * tf.math.subtract(re_image, convimggen(imsize, temp_addr_x, temp_addr_y, psf, sizeaddr))

    def sample(dx, dy):
        # Residual at (x + dx, y + dy) for every point; the index order is
        # [x, y], i.e. x selects the row of temp_result.
        idx = tf.cast(tf.concat([temp_addr_x + dx, temp_addr_y + dy], 1), dtype=tf.int32)
        return tf.gather_nd(temp_result, idx)

    result_x = tf.math.subtract(sample(1, 0), sample(-1, 0))
    result_x = tf.reshape(tf.cast(result_x, dtype=tf.float32), [sizeaddr, 1])

    result_y = tf.math.subtract(sample(0, 1), sample(0, -1))
    result_y = tf.reshape(tf.cast(result_y, dtype=tf.float32), [sizeaddr, 1])

    return result_x, result_y

def clearall():
    """Delete every public (non-underscore) name from this module's globals.

    This removes imported modules, user variables and this function itself,
    leaving only names that start with ``_``. The local list is no longer
    called ``all``, which shadowed the builtin.
    """
    # Snapshot the names first; the dict must not change while iterating it.
    public_names = [name for name in globals() if name[0] != "_"]
    for name in public_names:
        del globals()[name]
        
def address_exporter(finalres_x, finalres_y):
    """Return the two coordinate objects unchanged, as an ``(x, y)`` tuple."""
    return finalres_x, finalres_y


def intensity_ratio(convimage, dotimage, jj, iternumb, addrsize):
    """Find the smallest qualifying intensity multiplier.

    Returns the smallest integer k in ``0..iternumb`` for which
    ``posimage(k * dotimage - convimage)`` has a positive sum.

    If no k qualifies, ``iternumb`` is returned. The incoming ``jj`` is only
    returned when the search range is empty (``iternumb < 0``); otherwise
    it is ignored. ``addrsize`` is unused.
    """
    multipliers = range(iternumb + 1)
    for factor in multipliers:
        if np.sum(posimage(factor * dotimage - convimage)) > 0:
            return factor
    return multipliers[-1] if multipliers else jj

  # Removed stray statement: '{0}.{1}.{2}'.format(*version.hdf5_built_version_tuple)
  # It had an unexpected indent (IndentationError for the whole cell) and used
  # an undefined name `version` (presumably h5py.version).


In [2]:
# Load the measured PSF and normalise it to unit total intensity.
psf = io.imread('./deconvolution_data/PSF_HEK_5x.tif')
psf_size = [0]
psf = psf / np.sum(psf)

# Crop around the peak (half-width = 5 x the half-maximum offset, returned
# as psf_size), then renormalise the crop to unit sum.
new_psf, psf_size = psf_crop(psf)
big_psf = new_psf / np.sum(new_psf)


In [13]:
# Scales the total number of virtual emitters (totaladdrsize) computed in the
# processing loop; original note: usable range about 0.1~10000.
vir_num_scaling_factor = 35  #0.1~10000
# Constant background level subtracted from the raw image before processing.
background_threshold = 100
# Multiplier on the std of the high-pass (noise) image when selecting
# candidate emitter pixels.
addr_threshold = 1
# Maximum number of virtual emitters per batch; larger totals are split into
# several batches.
numbmaxaddr = 10000000
# Extra factor applied to the PSF amplitude, together with the intensity ratios.
signal_bg_ratio = 0.85
# Relative stopping tolerance for the optimisation while_loop; multiplied by
# sum(big_psf) * addrsize to obtain the absolute threshold.
iter_criterion = 1e-24

In [14]:
# Input directory and base name of the image to deconvolve (read as
# fileaddr + filename + '.tif'); output folders are created under fileaddr.
fileaddr = './deconvolution_data/'
filename = 'HEKCell_fourier5x'
# Initial intensity ratio; recomputed per image in the processing loop.
intratio = 1

In [15]:
start = time.time()

for ii in range(1, 2):

    # --- Load and pre-process the raw image -------------------------------
    image = np.float32(io.imread(fileaddr + filename +'.tif'))
    # Alternative for numbered input files (filename1.tif, filename2.tif, ...):
    #image = np.float32(io.imread(fileaddr + filename + str(ii) + '.tif'))
    image = posimage(image-background_threshold)

    max_val = np.max(image)

    hp_image = hp_filt(image, big_psf, image.shape)
    offset_image = det_offset(image, big_psf, image.shape)
    thrval = 1

    # Candidate emitter pixels: brighter than the offset estimate plus
    # addr_threshold standard deviations of the high-pass (noise) image.
    addr = np.where(image > (np.abs(np.mean(offset_image)) + addr_threshold*np.std(hp_image)))

    [numberofaxes, tempsize] = np.shape(addr)

    intensityratio = np.sum(posimage(image - (np.abs(np.mean(offset_image)) + addr_threshold*np.std(hp_image)))/np.sum(image))

    del [[hp_image, offset_image]]

    totaladdrsize = np.var(image)*tempsize/np.mean(image+1)/np.mean(image+1)*vir_num_scaling_factor

    imagefilter = filtimage(image - thrval)
    io.imsave(filename+'_imagefilter.tif', np.float32(imagefilter))

    image = posimage(image)

    # Split the virtual emitters into batches of at most numbmaxaddr.
    # np.int was removed in NumPy 1.24; the builtin int is its exact equivalent.
    if totaladdrsize <= numbmaxaddr:
        addrsize = int(np.floor(totaladdrsize))
        iternumb = 1
    else:
        iternumb = int(np.floor((totaladdrsize/numbmaxaddr)))
        addrsize = int(np.floor((totaladdrsize/iternumb)))

    intratio = np.sum(posimage(image))/totaladdrsize

    big_psf = big_psf/np.sum(big_psf)

    real_psf = signal_bg_ratio*intensityratio*big_psf*intratio
    big_psf = signal_bg_ratio*intensityratio*big_psf*intratio

    image_org = image
    io.imsave(filename+'_temp.tif', np.float32(image))

    imsize = [image.shape[0], image.shape[1], intratio]
    im_orgsize = [image_org.shape[0], image_org.shape[1]]

    config = tf.ConfigProto()
    config.gpu_options.allow_growth=True

    addr_x = np.zeros((addrsize, 1))
    addr_y = np.zeros((addrsize, 1))

    addr = np.array(addr)

    for i in range(addrsize):
        randaddr = random.randint(0, addr.shape[1]-1)
        addr_x[i] = addr[0, randaddr]
        addr_y[i] = addr[1, randaddr]

    io.imsave('temp_imsave.tif', np.zeros((im_orgsize[0], im_orgsize[1])))
    io.imsave('temp_imresidue.tif', np.float32(posimage(image)))

    gc.enable()

    jj = 0

    SNR = SNR_cal(image, psf, image.shape)

    # --- Batch loop: fit one batch of virtual emitters per pass -----------
    while (iternumb - jj)/iternumb > 0.01:

        tf.reset_default_graph()

        image = io.imread('temp_imresidue.tif')
        simage = io.imread('temp_imsave.tif')

        org_psf = psf

        SNR = SNR_cal(image, org_psf, image.shape)

        image = posimage(image)

        # tf.Variable's second positional parameter is `trainable`, not
        # `name`, so the names are passed by keyword.
        sum_image = tf.Variable(simage, name="sum_image", dtype=tf.float32)

        residue = tf.Variable(image, name="residue", dtype=tf.float32)

        extrapolnum = np.float32(math.ceil(big_psf.shape[0]/2))

        # Fresh random sample of candidate positions for this batch.
        for i in range(addrsize):
            randaddr = random.randint(0, addr.shape[1]-1)
            addr_x[i] = addr[0, randaddr]
            addr_y[i] = addr[1, randaddr]

        sess = tf.Session(config=config)

        # NOTE(review): jj starts at 0, so this branch runs on the second
        # batch, not the first. Confirm whether `jj == 0` was intended.
        if jj == 1:
            re_image = tf.cast(image, tf.float32)
        else:
            re_image = residue

        finalres_x = tf.cast(addr_x, dtype=tf.float32)
        finalres_y = tf.cast(addr_y, dtype=tf.float32)

        # Adam moment accumulators start at zero.
        m = tf.zeros(tf.shape(finalres_x), dtype=tf.float32)
        v = tf.zeros(tf.shape(finalres_x), dtype=tf.float32)
        i = tf.constant(1, dtype=tf.float32)

        m_x_new = m
        v_x_new = v
        m_y_new = m
        v_y_new = v

        # tf.initialize_all_variables is a deprecated alias of this op.
        sess.run(tf.global_variables_initializer())

        j = tf.constant(1, dtype=tf.float32)

        diff_sum = tf.constant(1, dtype=tf.float32)

        iterthrval = np.sum(big_psf)*addrsize*iter_criterion

        # Keep iterating while j < 10, and afterwards while the change in the
        # L1 residual (diff_sum) stays at or above iterthrval.
        def cond(j, finalres_x, finalres_y, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum):
            return tf.math.logical_or(tf.less(j, 10), tf.greater_equal(diff_sum, iterthrval))

        def body(j, finalres_x, finalres_y, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum):
            j = tf.add(j, 1)
            finalres_x, finalres_y,m_x_new, v_x_new, m_y_new, v_y_new, diff_sum = adamIntVar(re_image, imsize, finalres_x, finalres_y, big_psf, addrsize, m_x_new, v_x_new, m_y_new, v_y_new, extrapolnum, j, imagefilter)
            return [j, finalres_x, finalres_y, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum]

        res_loop = tf.while_loop(cond, body, [j, finalres_x, finalres_y, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum])

        [j, finalres_x, finalres_y, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum] = res_loop

        resultconvimage = convimggen(imsize, finalres_x, finalres_y, big_psf, addrsize)
        resultconvimage = tf.reshape(resultconvimage, [imsize[0], imsize[1]])

        resultimage = tf.reshape(virtualimage(imsize, finalres_x, finalres_y, addrsize), [imsize[0], imsize[1]])

        sub_image = tf.cast(tf.math.subtract(re_image, resultconvimage), tf.float32)
        temp_image = tf.cast(sum_image, dtype=tf.float32)

        if jj == 1:
            sum_image = resultimage
            residue = sub_image
        else:
            sum_image = tf.cast(tf.add(temp_image, resultimage), tf.float32)
            residue = tf.cast(sub_image, tf.float32)

        _, _, rimage, _, _, smallconvimage = sess.run([residue, sum_image, resultimage, sub_image, temp_image, resultconvimage])
        # The session is not used after this point; release it instead of
        # leaking one session per batch.
        sess.close()

        orgimage = image[:, :]

        jj += 1

        orgimage = orgimage - smallconvimage

        image = orgimage

        io.imsave('temp_imresidue.tif', np.float32(image))
        io.imsave('temp_imresidue_pos.tif', posimage(np.float32(image)))

        io.imsave('temp_imsave.tif', np.float32(rimage+simage))
        io.imsave('temp_imsave_'+str(jj)+'.tif', np.float32(rimage+simage))

        print(jj/iternumb*100, "% processed, ", "time:", time.time() - start, "SNR:", SNR)
        tf.keras.backend.clear_session()

        del [[simage, residue, temp_image, re_image, sub_image, sum_image, resultconvimage, resultimage, m_x_new, v_x_new, m_y_new, v_y_new, m, v, i, finalres_x, finalres_y]]

        gc.collect()

    # --- Save results -------------------------------------------------------
    simage = io.imread('temp_imsave.tif')

    rimage = io.imread('temp_imresidue.tif')

    # The first folder used to be created as 'Doconvolved' (typo) while the
    # result below is written into 'Deconvolved/', so that save failed
    # unless the folder already existed.
    for subdir in ('Deconvolved', 'Deconvolved_pos_residue', 'Deconvolved_residue'):
        os.makedirs(os.path.join(fileaddr, subdir), exist_ok=True)

    io.imsave(fileaddr + 'Deconvolved/Deconvolved_'+filename+'Denoised-'+str(ii)+'.tif', np.float32(intratio*simage))
    # posimage clips in place; clip a copy so the signed residue saved next
    # is not overwritten with its positive part.
    io.imsave(fileaddr + 'Deconvolved_pos_residue/Deconvolved_'+filename+'Denoised_posresidue-'+str(ii)+'.tif', posimage(rimage.copy()))
    io.imsave(fileaddr + 'Deconvolved_residue/Deconvolved_'+filename+'Denoised_residue-'+str(ii)+'.tif', rimage)

    print("File number: ", ii, "elapsed time: ", (time.time() - start)/60, "min, SNR: ", SNR)

    del [[image, simage, rimage, jj, iternumb, thrval, totaladdrsize, addr, addr_x, addr_y]]

    gc.collect()

    tf.reset_default_graph()
    
    
    

    



100.0 % processed,  time: 2.6610584259033203 SNR: 352.5379092641093
File number:  1 elapsed time:  0.04611293077468872 min, SNR:  352.5379092641093
