In [1]:
import tensorflow as tf
import numpy as np
import random
from PIL import Image
from skimage import io
#import cv2
from scipy import misc
from scipy import ndimage
from scipy import signal
from scipy.ndimage import gaussian_filter
from scipy.interpolate import interp1d
from scipy import stats
from skimage.transform import resize
import matplotlib.pyplot as plt
import math
import gc
import os
import pandas as pd
import time



    

def z_extrapolimg(image, psf, z_xtrapolnum=0):
    """Pad a 3-D stack along axis 0, optionally filling the padding.

    ``image`` is copied into a float64 array with ``z_xtrapolnum`` extra
    slices before and after it along axis 0.

    Trailing pad slice ``i`` (1-based, counted outwards) is the last image
    slice times ``0.1 * ratio[i-1]``. Here ``ratio`` is the PSF's axial
    profile taken from its (first) maximum outwards and normalised to 1 at
    the peak. Leading pad slices are multiplied by 0, i.e. they stay zero.

    Parameters
    ----------
    image : ndarray
        3-D array; axis 0 is the padded axis.
    psf : ndarray
        3-D PSF; only consulted when ``z_xtrapolnum > 0``.
    z_xtrapolnum : int, optional
        Number of slices added on each side. Defaults to 0, the value that
        was previously hard-coded. With 0 the result is simply a float64
        copy of ``image``.

    Returns
    -------
    ndarray
        float64 array of shape
        ``(image.shape[0] + 2*z_xtrapolnum, image.shape[1], image.shape[2])``.

    Raises
    ------
    ValueError
        If ``z_xtrapolnum`` is negative.
    """
    if z_xtrapolnum < 0:
        raise ValueError("z_xtrapolnum must be >= 0")

    depth = image.shape[0]
    newimg = np.zeros([depth + 2 * z_xtrapolnum, image.shape[1], image.shape[2]])
    newimg[z_xtrapolnum:z_xtrapolnum + depth, :, :] = image
    if z_xtrapolnum == 0:
        # Nothing to extrapolate: the result is a float64 copy of ``image``.
        return newimg

    # Axial profile of the PSF from its first (C-order) maximum outwards,
    # normalised to 1 at the peak.
    pz, py, px = np.unravel_index(np.argmax(psf), psf.shape)
    inten_ratio = psf[pz:pz + z_xtrapolnum, py, px] / psf[pz, py, px]

    # Fill pad slices 1..n on each side. The earlier range(1, n - 1) skipped
    # the two outermost slices. If the PSF has fewer than n slices past its
    # peak, the remaining pad slices stay zero instead of raising IndexError.
    for i in range(1, inten_ratio.shape[0] + 1):
        newimg[z_xtrapolnum + depth + i - 1, :, :] = image[depth - 1, :, :] * inten_ratio[i - 1] * 0.1
        newimg[z_xtrapolnum - i, :, :] = image[0, :, :] * inten_ratio[i - 1] * 0
    return newimg

def posimage(image):
    """Clip negative entries of ``image`` to zero, in place.

    Vectorised, so it works for arrays of any dimensionality. The previous
    loop assumed exactly three axes and ran one Python iteration per
    negative voxel.

    Parameters
    ----------
    image : ndarray
        Array to clip; it is modified in place.

    Returns
    -------
    ndarray
        The same ``image`` object, with every entry ``< 0`` set to 0. NaNs are
        left untouched because ``NaN < 0`` is False.
    """
    image[image < 0] = 0
    return image

def filtimage(image):
    """Return a 0/100 mask marking the non-negative entries of ``image``.

    Parameters
    ----------
    image : ndarray
        Input array of any dimensionality. Previously only 3-D input worked,
        via a per-voxel Python loop.

    Returns
    -------
    ndarray
        float64 array with the shape of ``image``. It holds 0 where
        ``image < 0`` and 100 everywhere else, including NaN entries.
    """
    return np.where(image < 0, 0.0, 1.0) * 100

def lp_filt(image, psf, imsize, threshold=0.01):
    """Low-pass filter ``image`` using the PSF's OTF support as the pass band.

    Frequencies where ``|FFT(psf)|`` exceeds ``threshold`` times its maximum
    are kept unchanged; all other frequencies are zeroed.

    Parameters
    ----------
    image, psf : ndarray
        Transformed with ``np.fft.fftn(..., s=imsize)``, i.e. zero-padded or
        cropped to ``imsize``.
    imsize : sequence of int
        FFT size per axis.
    threshold : float, optional
        Pass-band cutoff as a fraction of the OTF peak magnitude. Defaults to
        0.01, the previously hard-coded value.

    Returns
    -------
    ndarray
        Complex array of shape ``imsize``; take ``.real`` for a real image.
    """
    otf_mag = np.absolute(np.fft.fftn(psf, s=imsize))
    fftimage = np.fft.fftn(image, s=imsize)

    # 1.0 inside the pass band, 0.0 outside (NaN magnitudes count as outside).
    passband = (otf_mag > np.max(otf_mag) * threshold).astype(np.float64)

    return np.fft.ifftn(passband * fftimage, s=imsize)

def lp_filt_2(image, psf, imsize):
    """Low-pass filter with a tighter pass band (10 % of the OTF peak).

    Keeps only the frequencies of ``image`` where ``|FFT(psf)|`` is strictly
    greater than 0.1 times its maximum, and zeroes the rest. Both arrays are
    transformed with ``s=imsize``.

    Returns the complex inverse FFT, of shape ``imsize``.
    """
    spectrum = np.fft.fftn(image, s=imsize)
    otf_abs = np.absolute(np.fft.fftn(psf, s=imsize))

    cutoff = np.max(otf_abs) * 0.1
    keep = np.zeros(spectrum.shape)
    keep[otf_abs > cutoff] = 1

    return np.fft.ifftn(keep * spectrum, s=imsize)

def SNR_cal(a, psf, imsize):
    """Estimate the signal-to-noise ratio of ``a`` from its OTF split.

    The signal is the mean magnitude of the in-band part ``lp_filt(a)``. The
    noise is the standard deviation of the out-of-band part ``hp_filt(a)``.
    Both are called with their default thresholds.

    Returns
    -------
    float
        ``mean(|lp_filt(a)|) / std(hp_filt(a))``. NumPy yields inf/nan, with
        a RuntimeWarning, when the out-of-band part is constant.
    """
    img_noise = hp_filt(a, psf, imsize)
    img_signal = lp_filt(a, psf, imsize)
    # A det_offset(...) result used to be computed here but never used;
    # dropping it saves a pair of full-size FFTs per call.
    return (np.mean(np.abs(img_signal))) / (np.std(img_noise))

def xcov(x, y):
    """Sample covariance of two tensors over all elements, as a float32 scalar.

    Both inputs are cast to float64 and their global means are removed. The
    sum of the element-wise products is then divided by ``N - 1``, where
    ``N`` is the total number of elements of ``x``. The inputs are assumed
    to be broadcast-compatible, normally the same shape (TODO confirm at
    call sites).

    Previously the divisor was ``tf.shape(x)[0] - 1``, the length of the
    first axis only, which is wrong for any input of more than one
    dimension. For 1-D input both give the same result. The unused
    placeholders the old version created have been dropped.
    """
    x_64 = tf.cast(x, dtype=tf.float64)
    y_64 = tf.cast(y, dtype=tf.float64)

    n_minus_1 = tf.cast(tf.size(x_64), dtype=tf.float64) - 1
    cov_xy = tf.reduce_sum((x_64 - tf.reduce_mean(x_64)) * (y_64 - tf.reduce_mean(y_64))) / n_minus_1

    return tf.cast(cov_xy, dtype=tf.float32)

def SSIM_cal(orgimg, image):
    """Global (single-window) structural similarity between two tensors.

    Computed over all elements as::

        4 * mu_o * mu_i * sig_oi / ((mu_o**2 + mu_i**2) * (sig_o**2 + sig_i**2))

    Here ``mu_*`` are means, ``sig_o``/``sig_i`` come from
    ``tf.math.reduce_std`` (population standard deviation), and ``sig_oi``
    is ``xcov(orgimg, image)``. No stabilising constants are added, so the
    result is NaN/inf when both means or both deviations are zero.

    The denominator previously used ``sig_o**2 + sig_o**2``, counting the
    reference deviation twice, instead of ``sig_o**2 + sig_i**2``.
    NOTE(review): the population std here and the N-1 sample covariance in
    ``xcov`` use different normalisations; confirm which is intended.
    """
    sig_o = tf.math.reduce_std(orgimg)
    sig_i = tf.math.reduce_std(image)
    mu_o = tf.math.reduce_mean(orgimg)
    mu_i = tf.math.reduce_mean(image)
    sig_oi = xcov(orgimg, image)

    numerator = tf.math.multiply(tf.math.multiply(mu_o, mu_i), sig_oi)
    denominator = tf.math.multiply(
        tf.math.add(tf.math.square(mu_o), tf.math.square(mu_i)),
        tf.math.add(tf.math.square(sig_o), tf.math.square(sig_i)),
    )
    return 4 * tf.math.divide(numerator, denominator)


def hp_filt(image, psf, imsize, threshold=0.01):
    """High-pass (out-of-band) part of ``image`` relative to the PSF's OTF.

    Zeroes every frequency where ``|FFT(psf)|`` exceeds ``threshold`` times
    its maximum and keeps the rest. With the default threshold this is
    exactly the complement of the band kept by ``lp_filt``.

    Parameters
    ----------
    image, psf : ndarray
        Transformed with ``np.fft.fftn(..., s=imsize)``.
    imsize : sequence of int
        FFT size per axis.
    threshold : float, optional
        Stop-band cutoff as a fraction of the OTF peak magnitude. Defaults to
        0.01, the previously hard-coded value.

    Returns
    -------
    ndarray
        Complex array of shape ``imsize``.
    """
    otf_mag = np.absolute(np.fft.fftn(psf, s=imsize))
    fftimage = np.fft.fftn(image, s=imsize)

    # 0.0 inside the OTF support, 1.0 outside (NaN magnitudes count as outside).
    stopband = np.where(otf_mag > np.max(otf_mag) * threshold, 0.0, 1.0)

    return np.fft.ifftn(stopband * fftimage, s=imsize)

def det_offset(image, psf, imsize):
    """Keep only the strongest OTF frequencies of ``image``.

    Every frequency whose ``|FFT(psf)|`` is below 90 % of the peak magnitude
    is zeroed; the remaining ones pass unchanged. For a non-negative PSF the
    peak magnitude sits at the zero frequency, so the output is dominated by
    the image's mean level.

    Returns the complex inverse FFT, of shape ``imsize``.
    """
    spectrum = np.fft.fftn(image, s=imsize)
    otf_abs = np.absolute(np.fft.fftn(psf, s=imsize))

    keep = np.where(otf_abs < np.max(otf_abs) * 0.90, 0.0, 1.0)

    return np.fft.ifftn(keep * spectrum, s=imsize)



def adamdenoise(orgimg, image, P_img, G_img, m_P, v_P, m_G, v_G, i=1):
    """One Adam-style update of the gain (``P_img``) and offset (``G_img``) maps.

    Gradients come from ``gradgenfun_denoise``. The update uses learning
    rates 0.01 (P) and 10 (G), beta1 = 0.9 and beta2 = 0.99, with 0.001
    added inside the square root.

    Parameters
    ----------
    orgimg, image : tensor-like
        Forwarded to ``gradgenfun_denoise``.
    P_img, G_img : tensor-like
        Current gain / offset estimates.
    m_P, v_P, m_G, v_G : tensor-like
        First / second moment estimates from the previous step; cast to
        float32.
    i : scalar or tensor, optional
        1-based step index used for bias correction. This used to be read
        from an undefined (global) name ``i``; it is now an explicit
        argument defaulting to 1.

    Returns
    -------
    tuple
        ``(P_res, G_res, m_P_new, v_P_new, m_G_new, v_G_new, diff_mean)``.
        ``diff_mean`` is the mean of ``|P_res - P_img| + |G_res - G_img|``.
        The returned moments are the updated ones; the previous version
        returned its inputs unchanged, so Adam never accumulated state.
    """
    LR_P = tf.constant(0.01, dtype=tf.float32)
    LR_G = tf.constant(10, dtype=tf.float32)
    beta_1 = tf.constant(0.9, dtype=tf.float32)
    beta_2 = tf.constant(0.99, dtype=tf.float32)
    eps = tf.constant(0.001, dtype=tf.float32)
    step = tf.cast(i, dtype=tf.float32)

    def _adam(param, grad, m, v, lr):
        # Moment update + bias correction; eps is added inside the sqrt.
        m_new = beta_1 * tf.cast(m, dtype=tf.float32) + (1 - beta_1) * tf.cast(grad, dtype=tf.float32)
        v_new = beta_2 * tf.cast(v, dtype=tf.float32) + (1 - beta_2) * tf.cast(tf.math.square(grad), dtype=tf.float32)
        m_hat = m_new / (1 - tf.math.pow(beta_1, step))
        v_hat = v_new / (1 - tf.math.pow(beta_2, step))
        return param - lr * m_hat / tf.math.sqrt(v_hat + eps), m_new, v_new

    grad_P_img, grad_G_img = gradgenfun_denoise(orgimg, image, P_img, G_img)

    P_res, m_P_new, v_P_new = _adam(P_img, grad_P_img, m_P, v_P, LR_P)
    G_res, m_G_new, v_G_new = _adam(G_img, grad_G_img, m_G, v_G, LR_G)

    diff_mean = tf.math.reduce_mean(
        tf.math.abs(tf.math.subtract(P_res, P_img)) + tf.math.abs(tf.math.subtract(G_res, G_img)))

    return P_res, G_res, m_P_new, v_P_new, m_G_new, v_G_new, diff_mean




def gradgenfun_denoise(orgimg, image, P_img, G_img):
    """Central-difference gradients of ``|orgimg - P*image - G|``.

    Each gradient is ``f(x + 1) - f(x - 1)`` for the parameter being probed,
    with the other parameter held fixed. No division by the step is applied.

    Returns
    -------
    tuple
        ``(P_grad, G_grad)``.
    """
    delta = tf.constant(1, dtype=tf.float32)

    def abs_residual(p, g):
        # |orgimg - p*image - g|
        model = tf.math.multiply(p, image)
        return tf.math.abs(tf.math.subtract(tf.math.subtract(orgimg, model), g))

    P_grad = tf.subtract(abs_residual(P_img + delta, G_img),
                         abs_residual(P_img - delta, G_img))
    G_grad = tf.subtract(abs_residual(P_img, G_img + delta),
                         abs_residual(P_img, G_img - delta))
    return P_grad, G_grad
                             
                




def _half_max_offset(profile, peak, half_max):
    """Return the first offset k >= 1 with ``profile[peak + k] < half_max``.

    If the profile never drops below ``half_max`` before the end of the
    array, return the distance from ``peak`` to just past the last sample.
    """
    for k in range(1, profile.shape[0] - peak):
        if profile[peak + k] < half_max:
            return k
    return int(profile.shape[0] - peak)


def psf_crop(psf):
    """Crop a 3-D PSF to a box around its peak sized from the half-maximum.

    Starting at the first (C-order) maximum voxel, the PSF is walked
    outwards along axis 0 and along axis 2 until it first drops below half
    the peak value. With ``k`` the offset where that happens, the box
    half-width is ``round(1.5 * k)``. The axis-2 half-width is used for both
    lateral axes (1 and 2). Crop bounds are clamped to the array, so a peak
    near an edge no longer produces a wrapped or empty slice.

    Parameters
    ----------
    psf : ndarray
        3-D PSF array.

    Returns
    -------
    new_psf : ndarray
        View of ``psf`` restricted to the crop box.
    psf_size : list of int
        ``[axial_half_width, lateral_half_width]``. This is now a fresh local
        list; it no longer depends on a pre-existing global ``psf_size``.
    """
    pz, py, px = np.unravel_index(np.argmax(psf), psf.shape)
    half_max = 0.5 * psf[pz, py, px]

    axial = _half_max_offset(psf[:, py, px], pz, half_max)
    # Walk along axis 2, bounded by axis 2's own length. The old loop bound
    # came from axis 1.
    lateral = _half_max_offset(psf[pz, py, :], px, half_max)

    psf_size = [int(round(axial * 1.5)), int(round(lateral * 1.5))]
    hz, hxy = psf_size

    new_psf = psf[max(pz - hz, 0):pz + hz + 1,
                  max(py - hxy, 0):py + hxy + 1,
                  max(px - hxy, 0):px + hxy + 1]
    return new_psf, psf_size

        
def virtualimage(imsize, addr_z, addr_x, addr_y, sizeaddr):
    """Build a flattened point-emitter image from (z, x, y) addresses.

    Each emitter adds ``unit_inten`` (1) at the flat index
    ``y + x*imsize[2] + z*imsize[2]*imsize[1]``, after rounding, of a zero
    vector of length ``imsize[0]*imsize[1]*imsize[2]``. Callers reshape the
    result to ``imsize[:3]``.

    Parameters
    ----------
    imsize : sequence
        Image size; only ``imsize[0..2]`` are used here.
    addr_z, addr_x, addr_y : tensor-like
        Emitter coordinates. Presumably float tensors of shape
        ``(sizeaddr, 1)`` -- TODO confirm against callers.
    sizeaddr : int
        Number of emitters (length of the per-emitter intensity vector).

    Returns
    -------
    tf.Tensor
        float32 vector of length ``imsize[0]*imsize[1]*imsize[2]``.
    """
    
    # Flat zero image; it is embedded below as a graph constant.
    tempimg = np.zeros(imsize[0]*imsize[1]*imsize[2])
    
    unit_inten = 1
    
    
    # Intensity added per emitter.
    intarray = np.ones(sizeaddr)*unit_inten
    
    # Out-of-range handling for z. Addresses < 1 appear intended to become
    # imsize[0]-1, and addresses > imsize[0]-1 to become 1 (a wrap-around).
    # tf.where returns (row, col) coordinates; the slice keeps the row column.
    # NOTE(review): several problems in this pattern (repeated for x and y):
    #   * the Greater_* slice size uses tf.shape(Less_addr_*)[0] instead of
    #     tf.shape(Greater_addr_*)[0];
    #   * tf.ones([0, n]) has shape (0, n), i.e. it is EMPTY, so it cannot
    #     supply n update values to tf.tensor_scatter_nd_update;
    #   * the second tf.cond scatters into addr_* rather than new_addr_*,
    #     discarding the first update. Its false branch also returns addr_*
    #     uncast, while the first cond casts it to float32.
    Less_addr_z = tf.where(tf.less(addr_z, 1))
    Less_addr_z = tf.cast(tf.slice(Less_addr_z, [0, 0], [tf.shape(Less_addr_z)[0], 1]), tf.int32)
    Greater_addr_z = tf.where(tf.greater(addr_z, (imsize[0]-1)))
    Greater_addr_z = tf.cast(tf.slice(Greater_addr_z, [0, 0], [tf.shape(Less_addr_z)[0], 1]), tf.int32)
    Less_Tensor_z = tf.ones([0, tf.size(Less_addr_z)])*(imsize[0]-1)
    Less_Tensor_z = tf.cast(Less_Tensor_z, tf.float32)
    Greater_Tensor_z = tf.ones([0, tf.size(Greater_addr_z)])
    Greater_Tensor_z = tf.cast(Greater_Tensor_z, tf.float32)
    
    # Dead: this placeholder is immediately rebound by the tf.cond below.
    new_addr_z = tf.placeholder(tf.float32, name='new_address_z')
    
    new_addr_z = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_z)), 
                          lambda: tf.tensor_scatter_nd_update(addr_z, Less_addr_z, Less_Tensor_z), 
                          lambda: tf.cast(addr_z, tf.float32))
    
    new_addr_z = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_z)), 
                          lambda: tf.tensor_scatter_nd_update(addr_z, Greater_addr_z, Greater_Tensor_z), 
                          lambda: addr_z)
    
    # Same out-of-range pattern for x (bounds 1 .. imsize[1]-1).
    Less_addr_x = tf.where(tf.less(addr_x, 1))
    Less_addr_x = tf.cast(tf.slice(Less_addr_x, [0, 0], [tf.shape(Less_addr_x)[0], 1]), tf.int32)
    Greater_addr_x = tf.where(tf.greater(addr_x, (imsize[1]-1)))
    Greater_addr_x = tf.cast(tf.slice(Greater_addr_x, [0, 0], [tf.shape(Less_addr_x)[0], 1]), tf.int32)
    Less_Tensor_x = tf.ones([0, tf.size(Less_addr_x)])*(imsize[1]-1)
    Less_Tensor_x = tf.cast(Less_Tensor_x, tf.float32)
    Greater_Tensor_x = tf.ones([0, tf.size(Greater_addr_x)])
    Greater_Tensor_x = tf.cast(Greater_Tensor_x, tf.float32)
    
    new_addr_x = tf.placeholder(tf.float32, name='new_address_x')
    
    new_addr_x = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_x)), 
                          lambda: tf.tensor_scatter_nd_update(addr_x, Less_addr_x, Less_Tensor_x), 
                          lambda: tf.cast(addr_x, tf.float32))
    
    new_addr_x = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_x)), 
                          lambda: tf.tensor_scatter_nd_update(addr_x, Greater_addr_x, Greater_Tensor_x), 
                          lambda: addr_x)
    
    # Same out-of-range pattern for y (bounds 1 .. imsize[2]-1).
    Less_addr_y = tf.where(tf.less(addr_y, 1))
    Less_addr_y = tf.cast(tf.slice(Less_addr_y, [0, 0], [tf.shape(Less_addr_y)[0], 1]), tf.int32)
    Greater_addr_y = tf.where(tf.greater(addr_y, (imsize[2]-1)))
    Greater_addr_y = tf.cast(tf.slice(Greater_addr_y, [0, 0], [tf.shape(Less_addr_y)[0], 1]), tf.int32)
    Less_Tensor_y = tf.ones([0, tf.size(Less_addr_y)])*(imsize[2]-1)
    Less_Tensor_y = tf.cast(Less_Tensor_y, tf.float32)
    Greater_Tensor_y = tf.ones([0, tf.size(Greater_addr_y)])
    Greater_Tensor_y = tf.cast(Greater_Tensor_y, tf.float32)
    
    new_addr_y = tf.placeholder(tf.float32, name='new_address_y')
    
    new_addr_y = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_y)), 
                          lambda: tf.tensor_scatter_nd_update(addr_y, Less_addr_y, Less_Tensor_y), 
                          lambda: tf.cast(addr_y, tf.float32))
    
    new_addr_y = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_y)), 
                          lambda: tf.tensor_scatter_nd_update(addr_y, Greater_addr_y, Greater_Tensor_y), 
                          lambda: addr_y)
    
    
    # Row-major flat index with axis order (z, x, y).
    new_address = new_addr_y + new_addr_x*imsize[2] + new_addr_z*imsize[2]*imsize[1]
    
    int_address = tf.math.round(new_address)
    int_address = tf.cast(int_address, dtype=tf.int32)
    tensorintarray = tf.cast(intarray, dtype=tf.float32)
    img_var = tf.constant(tempimg, dtype=tf.float32)
    
    # Accumulate one unit per emitter; duplicate addresses add up.
    new_img_var = tf.tensor_scatter_nd_add(img_var, int_address, tensorintarray)
    new_img_var = tf.cast(new_img_var, dtype=tf.float32)
                                   
    # Unbinds local names only; it does not remove ops from the graph.
    del [[tempimg, Less_addr_z, Less_addr_x, Less_addr_y, Greater_addr_z, Greater_addr_x, Greater_addr_y, Less_Tensor_z, Less_Tensor_x, Less_Tensor_y, Greater_Tensor_z, Greater_Tensor_x, Greater_Tensor_y, new_addr_z, new_addr_x, new_addr_y, new_address, int_address, tensorintarray, img_var]]
    
    return new_img_var
    
def convimggen(imsize, temp_addr_z, temp_addr_x, temp_addr_y, psf, sizeaddr):
    """Render emitters at the given addresses and blur them with ``psf``.

    A point image is built with ``virtualimage`` and filtered with ``psf``
    via ``tf.nn.conv3d`` (stride 1, 'SAME' padding). The result is rescaled
    so that its total equals ``sum(psf) / imsize[3]``.
    NOTE(review): ``tf.nn.conv3d`` computes a cross-correlation (the kernel
    is not flipped); this only equals convolution for a symmetric PSF.

    Parameters
    ----------
    imsize : sequence
        ``imsize[0..2]`` give the image shape; ``imsize[3]`` is the divisor
        in the final scaling.
    temp_addr_z, temp_addr_x, temp_addr_y : tensor-like
        Emitter coordinates, forwarded to ``virtualimage``.
    psf : ndarray
        3-D kernel.
    sizeaddr : int
        Number of emitters, forwarded to ``virtualimage``.

    Returns
    -------
    tf.Tensor
        float32 tensor of shape ``imsize[:3]``.
    """
    # conv3d expects a kernel shaped [depth, height, width, in_ch, out_ch].
    temp_psf = tf.constant(psf.reshape(psf.shape[0], psf.shape[1], psf.shape[2], 1, 1), dtype=tf.float32)

    # A stray tf.reshape(virtualimage(imsize, finalres_z, ..., addrsize)) used
    # to sit here. It read notebook globals, discarded its result and only
    # added dead ops (or raised NameError when those globals were missing).
    temp_image = virtualimage(imsize, temp_addr_z, temp_addr_x, temp_addr_y, sizeaddr)
    temp_image = tf.reshape(temp_image, [1, imsize[0], imsize[1], imsize[2], 1])
    temp_convimg = tf.nn.conv3d(temp_image, temp_psf, strides=[1, 1, 1, 1, 1], padding='SAME')

    # Normalise to unit total, then scale by the PSF's total and 1/imsize[3].
    return (tf.math.reduce_sum(temp_psf)
            * tf.math.divide(tf.reshape(temp_convimg, [imsize[0], imsize[1], imsize[2]]),
                             tf.math.reduce_sum(temp_convimg))
            / imsize[3])


def fftconvimggen(imsize, temp_addr_z, temp_addr_x, temp_addr_y, psf, sizeaddr):
    """FFT-based rendering of emitters blurred by ``psf``.

    Steps:

    1. Build a point image with ``virtualimage`` and take its real 3-D FFT.
    2. Multiply it element-wise with the FFT of ``psf`` (zero-padded or
       cropped to ``imsize[:3]``), i.e. a circular convolution.
    3. Invert the transform and roll the result by ``ceil(n/2)`` along each
       axis.
    4. Rescale so the total equals ``sizeaddr``.

    Returns
    -------
    tf.Tensor
        float32 tensor of shape ``imsize[:3]``.
    """
    fft_shape = tf.constant([imsize[0], imsize[1], imsize[2]], tf.int32)
    fft_psf = tf.signal.rfft3d(tf.constant(psf, dtype=tf.float32), fft_length=fft_shape)

    point_image = tf.reshape(
        virtualimage(imsize, temp_addr_z, temp_addr_x, temp_addr_y, sizeaddr),
        [imsize[0], imsize[1], imsize[2]])
    fft_image = tf.signal.rfft3d(point_image, fft_length=fft_shape)

    # Plain element-wise product; a Keras Multiply layer was used before.
    conv_image = tf.cast(
        tf.signal.irfft3d(tf.math.multiply(fft_image, fft_psf), fft_length=fft_shape),
        dtype=tf.float32)

    # Shift computed in Python; same values the old tf.math.ceil/cast produced.
    # NOTE(review): rfft3d pads the PSF at the array origin, so its centre
    # lands at psf.shape // 2; a shift of -(psf.shape // 2) may be what is
    # actually intended -- confirm.
    shift = [int(math.ceil(imsize[k] / 2)) for k in range(3)]
    conv_image = tf.roll(conv_image, shift, [0, 1, 2])

    return tf.reshape(conv_image, [imsize[0], imsize[1], imsize[2]]) / tf.math.reduce_sum(conv_image) * sizeaddr



def cost_fun(re_image, imsize, temp_addr_x, temp_addr_y, psf, sizeaddr, temp_addr_z=None):
    """L2 cost between ``re_image`` and the model image of the given emitters.

    Returns ``tf.nn.l2_loss(re_image - convimggen(...))``, i.e. half the sum
    of squared residuals.

    ``temp_addr_z`` was missing from the original signature. As a result the
    ``convimggen`` call received only five of its six arguments and always
    raised ``TypeError``. It is now a trailing argument, so the existing
    positional order is unchanged. Omitting it still raises ``TypeError``,
    now with an explicit message.

    Raises
    ------
    TypeError
        If ``temp_addr_z`` is not supplied.
    """
    if temp_addr_z is None:
        raise TypeError("cost_fun() requires temp_addr_z (the z coordinates of the emitters)")
    model = convimggen(imsize, temp_addr_z, temp_addr_x, temp_addr_y, psf, sizeaddr)
    return tf.nn.l2_loss(tf.math.subtract(re_image, model))


def adamIntVar(re_image, imsize, temp_addr_z, temp_addr_x, temp_addr_y, psf, sizeaddr, m_z, v_z, m_x, v_x, m_y, v_y, extrapolnum, i, imagefilter):
    """One Adam-style step on the emitter coordinates, rounded to integers.

    Gradients come from ``gradgenfun``. Each coordinate is updated with
    learning rate 2, beta1 = 0.9, beta2 = 0.99 and 0.001 added inside the
    square root, then rounded with ``tf.math.round``. Out-of-range
    coordinates then go through the handling below.

    ``i`` is the bias-correction step. The visible while_loop body
    increments ``j`` before passing it here, so its first call sees i = 2.

    Returns
    -------
    tuple
        ``(new_z_res, new_x_res, new_y_res, m_z_new, v_z_new, m_x_new,
        v_x_new, m_y_new, v_y_new, diff_sum)``. ``diff_sum`` is the summed
        absolute residual of the model at the OLD coordinates minus that at
        the NEW ones; positive means the step reduced the L1 residual.
    """
    
    LR_z = tf.constant(2, dtype=tf.float32)
    LR_xy = tf.constant(2, dtype=tf.float32)
    
    beta_1 = tf.constant(0.9, dtype=tf.float32)
    beta_1_1 = 1-beta_1
    beta_2 = tf.constant(0.99, dtype=tf.float32)
    beta_2_2 = 1-beta_2
    
    # Dead: the placeholders and the casts below are all rebound before use.
    # The moment updates further down read the raw m_*/v_* arguments.
    m_z_new = tf.placeholder(dtype=tf.float32)
    v_z_new = tf.placeholder(dtype=tf.float32)
    m_z_new = tf.cast(m_z, dtype=tf.float32)
    v_z_new = tf.cast(v_z, dtype=tf.float32)
        
    m_x_new = tf.placeholder(dtype=tf.float32)
    v_x_new = tf.placeholder(dtype=tf.float32)
    m_x_new = tf.cast(m_x, dtype=tf.float32)
    v_x_new = tf.cast(v_x, dtype=tf.float32)
    
    m_y_new = tf.placeholder(dtype=tf.float32)
    v_y_new = tf.placeholder(dtype=tf.float32)
    m_y_new = tf.cast(m_y, dtype=tf.float32)
    v_y_new = tf.cast(v_y, dtype=tf.float32)
    
    z_res = tf.placeholder(dtype=tf.float32)
    z_res = temp_addr_z
    
    x_res = tf.placeholder(dtype=tf.float32)
    x_res = temp_addr_x
    
    y_res = tf.placeholder(dtype=tf.float32)
    y_res = temp_addr_y
    
    # imagefilter is forwarded, but gradgenfun (as written in this file) does not use it.
    grad_res_z, grad_res_x, grad_res_y = gradgenfun(re_image, imsize, temp_addr_z, temp_addr_x, temp_addr_y, psf, sizeaddr, imagefilter)
    
    m_z_new = beta_1*m_z + beta_1_1*tf.cast(grad_res_z, dtype=tf.float32)
    v_z_new = beta_2*v_z + beta_2_2*tf.cast(tf.math.square(grad_res_z), dtype=tf.float32)
    
    z_res = temp_addr_z - LR_z*(m_z_new/(1-tf.math.pow(beta_1, i)))/tf.math.sqrt(v_z_new/(1-tf.math.pow(beta_2,i)) + tf.constant(0.001, dtype=tf.float32))
    
    # Coordinates are kept on the integer voxel grid.
    z_res = tf.math.round(z_res)
    
    m_x_new = beta_1*m_x + beta_1_1*tf.cast(grad_res_x, dtype=tf.float32)
    v_x_new = beta_2*v_x + beta_2_2*tf.cast(tf.math.square(grad_res_x), dtype=tf.float32)
    
    x_res = temp_addr_x - LR_xy*(m_x_new/(1-tf.math.pow(beta_1, i)))/tf.math.sqrt(v_x_new/(1-tf.math.pow(beta_2,i)) + tf.constant(0.001, dtype=tf.float32))
    
    x_res = tf.math.round(x_res)
    
    m_y_new = beta_1*m_y + beta_1_1*tf.cast(grad_res_y, dtype=tf.float32)
    v_y_new = beta_2*v_y + beta_2_2*tf.cast(tf.math.square(grad_res_y), dtype=tf.float32)
    
    y_res = temp_addr_y - LR_xy*(m_y_new/(1-tf.math.pow(beta_1, i)))/tf.math.sqrt(v_y_new/(1-tf.math.pow(beta_2,i)) + tf.constant(0.001, dtype=tf.float32))
    
    y_res = tf.math.round(y_res)
    
    
    # Out-of-range handling for z. Values < extrapolnum appear intended to
    # become imsize[0]-extrapolnum-1, and values > imsize[0]-extrapolnum-1 to
    # become extrapolnum (a wrap-around).
    # NOTE(review): same defects as in virtualimage, repeated for x and y:
    #   * the Greater_* slice size uses tf.shape(Less_addr_*)[0];
    #   * tf.ones([0, n]) is an EMPTY (0, n) tensor, so it cannot supply
    #     n update values;
    #   * the second tf.cond scatters into *_res rather than new_*_res,
    #     discarding the first update.
    Less_addr_z = tf.where(tf.less(z_res, extrapolnum))
    Less_addr_z = tf.cast(tf.slice(Less_addr_z, [0, 0], [tf.shape(Less_addr_z)[0], 1]), tf.int32)
    Greater_addr_z = tf.where(tf.greater(z_res, (imsize[0]-extrapolnum-1)))
    Greater_addr_z = tf.cast(tf.slice(Greater_addr_z, [0, 0], [tf.shape(Less_addr_z)[0], 1]), tf.int32)
    Less_Tensor_z = tf.ones([0, tf.size(Less_addr_z)])*(imsize[0]-extrapolnum-1)
    Less_Tensor_z = tf.cast(Less_Tensor_z, tf.float32)
    Greater_Tensor_z = tf.ones([0, tf.size(Greater_addr_z)])*extrapolnum
    Greater_Tensor_z = tf.cast(Greater_Tensor_z, tf.float32)
    
    new_z_res = tf.placeholder(tf.float32, name='new_address_z')
    
    new_z_res = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_z)), 
                          lambda: tf.tensor_scatter_nd_update(z_res, Less_addr_z, Less_Tensor_z), 
                          lambda: tf.cast(z_res, tf.float32))
    
    new_z_res = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_z)), 
                          lambda: tf.tensor_scatter_nd_update(z_res, Greater_addr_z, Greater_Tensor_z), 
                          lambda: z_res)
    
    # x: values < 1 -> imsize[1]-1, values > imsize[1]-1 -> 1 (same defects).
    Less_addr_x = tf.where(tf.less(x_res, 1))
    Less_addr_x = tf.cast(tf.slice(Less_addr_x, [0, 0], [tf.shape(Less_addr_x)[0], 1]), tf.int32)
    Greater_addr_x = tf.where(tf.greater(x_res, (imsize[1]-1)))
    Greater_addr_x = tf.cast(tf.slice(Greater_addr_x, [0, 0], [tf.shape(Less_addr_x)[0], 1]), tf.int32)
    Less_Tensor_x = tf.ones([0, tf.size(Less_addr_x)])*(imsize[1]-1)
    Less_Tensor_x = tf.cast(Less_Tensor_x, tf.float32)
    Greater_Tensor_x = tf.ones([0, tf.size(Greater_addr_x)])
    Greater_Tensor_x = tf.cast(Greater_Tensor_x, tf.float32)
    
    new_x_res = tf.placeholder(tf.float32, name='new_address_x')
    
    new_x_res = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_x)), 
                          lambda: tf.tensor_scatter_nd_update(x_res, Less_addr_x, Less_Tensor_x), 
                          lambda: tf.cast(x_res, tf.float32))
    
    new_x_res = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_x)), 
                          lambda: tf.tensor_scatter_nd_update(x_res, Greater_addr_x, Greater_Tensor_x), 
                          lambda: x_res)
    
    # y: values < 1 -> imsize[2]-1, values > imsize[2]-1 -> 1 (same defects).
    Less_addr_y = tf.where(tf.less(y_res, 1))
    Less_addr_y = tf.cast(tf.slice(Less_addr_y, [0, 0], [tf.shape(Less_addr_y)[0], 1]), tf.int32)
    Greater_addr_y = tf.where(tf.greater(y_res, (imsize[2]-1)))
    Greater_addr_y = tf.cast(tf.slice(Greater_addr_y, [0, 0], [tf.shape(Less_addr_y)[0], 1]), tf.int32)
    Less_Tensor_y = tf.ones([0, tf.size(Less_addr_y)])*(imsize[2]-1)
    Less_Tensor_y = tf.cast(Less_Tensor_y, tf.float32)
    Greater_Tensor_y = tf.ones([0, tf.size(Greater_addr_y)])
    Greater_Tensor_y = tf.cast(Greater_Tensor_y, tf.float32)
    
    # NOTE(review): the name says 'new_address_x' although this is y; the
    # placeholder is rebound immediately, so the name is never used.
    new_y_res = tf.placeholder(tf.float32, name='new_address_x')
    
    new_y_res = tf.cond(tf.less(tf.constant(0), tf.size(Less_addr_y)), 
                          lambda: tf.tensor_scatter_nd_update(y_res, Less_addr_y, Less_Tensor_y), 
                          lambda: tf.cast(y_res, tf.float32))
    
    new_y_res = tf.cond(tf.less(tf.constant(0), tf.size(Greater_addr_y)), 
                          lambda: tf.tensor_scatter_nd_update(y_res, Greater_addr_y, Greater_Tensor_y), 
                          lambda: y_res)
    
    # L1 residual at the old coordinates minus L1 residual at the new ones.
    diff_sum = tf.math.reduce_sum(tf.math.abs(tf.math.subtract(convimggen(imsize, temp_addr_z, temp_addr_x, temp_addr_y, psf, sizeaddr), re_image))) - tf.math.reduce_sum(tf.math.abs(tf.math.subtract(re_image, convimggen(imsize, new_z_res, new_x_res, new_y_res, psf, sizeaddr))))
    
    # Unbinds local names only; it does not remove ops from the graph.
    del [[Less_addr_z, Less_addr_x, Less_addr_y, Greater_addr_z, Greater_addr_x, Greater_addr_y, Less_Tensor_z, Less_Tensor_x, Less_Tensor_y, Greater_Tensor_z, Greater_Tensor_x, Greater_Tensor_y]]
    
    
    return new_z_res, new_x_res, new_y_res, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum
    


def gradgenfun(re_image, imsize, temp_addr_z, temp_addr_x, temp_addr_y, psf, sizeaddr, imagefilter):
    """Finite-difference gradient of the fit residual at each emitter.

    The scaled residual ``-1000 * (re_image - convimggen(...))`` is sampled
    on both sides of every emitter and differenced:

    * z: offsets +1 and -1,
    * x: offsets +2 and -2,
    * y: offsets +2 and -2.

    Parameters
    ----------
    re_image : tensor
        Target image; it is subtracted from the ``convimggen`` output, which
        has shape ``imsize[:3]``.
    imsize, psf, sizeaddr :
        Forwarded to ``convimggen``. ``sizeaddr`` also fixes the output length.
    temp_addr_z, temp_addr_x, temp_addr_y : tensor
        Emitter coordinates, concatenated along axis 1 into gather indices.
        Presumably each has shape ``(sizeaddr, 1)`` -- TODO confirm.
    imagefilter :
        Accepted for interface compatibility; not used.

    Returns
    -------
    tuple of tf.Tensor
        ``(result_z, result_x, result_y)``, each float32 of shape
        ``(sizeaddr, 1)``.

    Notes
    -----
    The previous version also gathered offsets up to +/-8 and averaged them,
    but never used those averages; the dead gathers have been removed.
    NOTE(review): shifted indices are not clamped here. ``tf.gather_nd``
    errors on CPU and returns 0 on GPU for out-of-range indices.
    """
    residual = -1000 * tf.math.subtract(
        re_image, convimggen(imsize, temp_addr_z, temp_addr_x, temp_addr_y, psf, sizeaddr))

    def _sample(dz, dx, dy):
        # Residual value at each emitter address shifted by (dz, dx, dy).
        index = tf.concat([temp_addr_z + dz, temp_addr_x + dx, temp_addr_y + dy], 1)
        return tf.gather_nd(residual, tf.cast(index, dtype=tf.int32))

    def _difference(plus, minus):
        # Forward-minus-backward sample, as a float32 column vector.
        return tf.reshape(tf.cast(tf.math.subtract(plus, minus), dtype=tf.float32), [sizeaddr, 1])

    result_z = _difference(_sample(1, 0, 0), _sample(-1, 0, 0))
    result_x = _difference(_sample(0, 2, 0), _sample(0, -2, 0))
    result_y = _difference(_sample(0, 0, 2), _sample(0, 0, -2))
    return result_z, result_x, result_y

def clearall():
    """Delete every global of this module whose name does not start with '_'.

    Meant for interactive (notebook) use. It also removes imported modules
    and this function itself from the module namespace.
    """
    namespace = globals()
    doomed = [name for name in namespace if name[0] != "_"]
    for name in doomed:
        del namespace[name]
        

  '{0}.{1}.{2}'.format(*version.hdf5_built_version_tuple)


In [2]:
# Load the measured PSF, normalise it to unit sum, crop it around its peak
# with psf_crop, then renormalise the cropped kernel to unit sum.
psf = io.imread('./deconvolution_data/PSF_791nm_62nm_3D.tif')
psf_size = [0, 0]  # defined before psf_crop in case it relies on a global psf_size
psf = psf / psf.sum()
new_psf, psf_size = psf_crop(psf)
big_psf = new_psf / new_psf.sum()

TiffPage 0: TypeError: read_bytes() missing 3 required positional arguments: 'dtype', 'count', and 'offsetsize'


In [3]:
# Run parameters.
vir_num_scaling_factor = 35   # scales the emitter-count estimate; noted range 0.1~10000
background_threshold = 100    # subtracted from the raw image before negatives are clipped
addr_threshold = 1            # voxels above this value are counted for the emitter estimate
numbmaxaddr = 10_000_000      # emitters per batch; larger estimates are split into batches
signal_bg_ratio = 0.85        # multiplies the PSF scaling (together with intratio)
iter_criterion = 1e-24        # not referenced in the code visible here

In [4]:
# Input stack; the processing cell reads fileaddr + filename + '.tif'.
fileaddr = './deconvolution_data/Brain_image/slices/'
filename = ('Denoised-791.0-200mw-zoom8_3D_0001'
            '_cropped_fourier8x_1')
intratio = 1  # initial value; recomputed from the image during the run

In [5]:
start = time.time()  # wall-clock start of the run; presumably used for timing later in this cell (not visible here)

for ii in range(1, 2):
    
    big_psf = new_psf
    big_psf = big_psf/np.sum(big_psf)
    image = np.float32(io.imread(fileaddr + filename +'.tif'))
    #image = np.float32(io.imread(fileaddr + filename + str(ii) +'.tif'))
    
    image = posimage(image - background_threshold)
    
    lp_image = lp_filt(image, big_psf, image.shape)
    hp_image = hp_filt(image, big_psf, image.shape)
    offset_image = det_offset(image, big_psf, image.shape)
    addr = np.where(image > addr_threshold)
    [numberofaxes, tempsize] = np.shape(addr)
    signalmatrix = image[addr[0][0:tempsize-1], addr[1][0:tempsize-1], addr[2][0:tempsize-1]]
    
    intensityratio = np.sum(posimage(image - (np.abs(np.mean(offset_image)) + 0*np.std(hp_image)))/np.sum(image))
    


    totaladdrsize = np.var(image)*tempsize/np.mean(image)/np.mean(image)*vir_num_scaling_factor
                      
    imagefilter = filtimage(image - background_threshold)
    imagefilter = z_extrapolimg(imagefilter, big_psf)
    io.imsave(filename+'_imagefilter.tif', np.float32(imagefilter))

    image = posimage(image)
    
    if np.isnan(totaladdrsize) == True:
        
        totaladdrsize = 0
        

    if totaladdrsize <= numbmaxaddr:
    
        addrsize = np.int(np.floor(totaladdrsize))
        iternumb = 1
    
    else:
    
        iternumb = np.int(np.floor((totaladdrsize/numbmaxaddr)))
        addrsize = np.int(np.floor((totaladdrsize/iternumb)))
    
        

    intratio = np.sum(posimage(image-background_threshold))/iternumb/addrsize
    
    big_psf = signal_bg_ratio*big_psf*intratio

    image_org = image
    image = z_extrapolimg(image_org, big_psf)
    io.imsave(filename+'_temp.tif', np.float32(image))
    
    imsize = [image.shape[0], image.shape[1], image.shape[2], intratio]
    im_orgsize = [image_org.shape[0], image_org.shape[1], image_org.shape[2]]
    
    config = tf.ConfigProto()
    config.gpu_options.allow_growth=True
    

    addr_x = np.zeros((addrsize, 1))
    addr_y = np.zeros((addrsize, 1))
    addr_z = np.zeros((addrsize, 1))
    
    addr = np.where(image > np.abs(np.mean(offset_image)) + 1*np.std(hp_image))
    addr = np.array(addr)
    
        
    for i in range(addrsize):
    
        randaddr = random.randint(0, addr.shape[1]-1)
    
        addr_z[i] = addr[0, randaddr]
        addr_x[i] = addr[1, randaddr]
        addr_y[i] = addr[2, randaddr]
    
    io.imsave('temp_imsave.tif', np.zeros((im_orgsize[0], im_orgsize[1], im_orgsize[2])))
    io.imsave('temp_imresidue.tif', np.float32(posimage(image-background_threshold)))
    
    gc.enable()
    
    jj = 0
    
    SNR = SNR_cal(image, psf, image.shape)
    
    
    while (iternumb - jj)/iternumb > 0.01:
        
        
        tf.reset_default_graph()
        
        
        image = io.imread('temp_imresidue.tif')
        simage = io.imread('temp_imsave.tif')
        
        org_psf = psf
    
        SNR = SNR_cal(image, org_psf, image.shape)
        
        image = posimage(image)
               
        sum_image = tf.Variable(simage, "sum_image", dtype=tf.float32)
            
        residue = tf.Variable(image, "residue", dtype=tf.float32)
        
        extrapolnum = np.float32(math.ceil(big_psf.shape[0]/2))
        
        for i in range(addrsize):
    
            randaddr = random.randint(0, addr.shape[1]-1)
    
            addr_z[i] = addr[0, randaddr]
            addr_x[i] = addr[1, randaddr]
            addr_y[i] = addr[2, randaddr]
    
        sess = tf.Session(config=config)
        
            
        finalres_z = tf.placeholder(dtype=tf.float32)
        finalres_x = tf.placeholder(dtype=tf.float32)
        finalres_y = tf.placeholder(dtype=tf.float32)
            
        
        if jj == 1:
            
            re_image = tf.placeholder(dtype=tf.float32)
            re_image = tf.cast(image, tf.float32)
                
        else:
            
            re_image = tf.placeholder(dtype=tf.float32)
            re_image = residue
                
        finalres_z = tf.cast(addr_z, dtype=tf.float32)
        finalres_x = tf.cast(addr_x, dtype=tf.float32)
        finalres_y = tf.cast(addr_y, dtype=tf.float32)
            
    
            
        m = tf.placeholder(dtype=tf.float32)
        v = tf.placeholder(dtype=tf.float32)
        m = tf.zeros(tf.shape(finalres_x), dtype=tf.float32)
        v = tf.zeros(tf.shape(finalres_x), dtype=tf.float32)
        i = tf.constant(1, dtype=tf.float32)
        
        m_z_new = m
        v_z_new = v
        m_x_new = m
        v_x_new = v
        m_y_new = m
        v_y_new = v
        
        sess.run(tf.initialize_all_variables())
        
        
        j = tf.constant(1, dtype=tf.float32)
        
        diff_sum = tf.constant(1, dtype=tf.float32)
        
        iterthrval = intratio*addrsize*1e-8
        
            
        def cond(j, finalres_z, finalres_x, finalres_y, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum):
            """Continuation predicate for the tf.while_loop built right after this.

            tf.while_loop calls this with every loop variable positionally, so
            all twelve parameters must be accepted even though only ``j`` and
            ``diff_sum`` are actually read.

            Returns a boolean tensor that is True (keep iterating) while either
              * ``j < 10`` -- j starts at 1 and ``body`` adds 1 per step, so the
                body always runs at least 9 times, or
              * ``diff_sum >= iterthrval`` -- ``diff_sum`` is the value returned
                by ``adamIntVar`` on the previous step, and ``iterthrval`` is
                captured from the enclosing scope (intratio*addrsize*1e-8).

            NOTE(review): once j >= 10 nothing bounds the iteration count; if
            ``diff_sum`` never falls below ``iterthrval`` the loop never
            terminates. Consider adding a maximum-iteration guard.
            """

            return tf.math.logical_or(tf.less(j, 10), tf.greater_equal(diff_sum, iterthrval))
            
        def body(j, finalres_z, finalres_x, finalres_y, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum):
            """One step of the tf.while_loop optimisation.

            Increments the step counter ``j``, then hands the coordinate tensors
            (``finalres_z/x/y``, initialised from the randomly sampled
            ``addr_z/x/y`` arrays) and their ``m_*``/``v_*`` accumulators to
            ``adamIntVar``. ``adamIntVar`` is defined elsewhere; given the m/v
            naming it is presumably an Adam-style update step -- confirm. It
            returns the updated tensors plus the ``diff_sum`` value that ``cond``
            tests for convergence.

            ``re_image``, ``imsize``, ``big_psf``, ``addrsize`` and
            ``imagefilter`` are read from the enclosing scope. ``extrapolnum``
            is passed through unchanged.

            Returns the loop variables in exactly the same order as the
            ``loop_vars`` list given to tf.while_loop; the body's output
            structure must match that list, so do not reorder it.
            """
            j = tf.add(j, 1)
            # The post-increment j is forwarded to adamIntVar, presumably as the
            # step index (e.g. for Adam bias correction) -- confirm in adamIntVar.
            finalres_z, finalres_x, finalres_y, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum = adamIntVar(re_image, imsize, finalres_z, finalres_x, finalres_y, big_psf, addrsize, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, extrapolnum, j, imagefilter)
            return [j, finalres_z, finalres_x, finalres_y, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum]
        
        res_loop = tf.while_loop(cond, body, [j, finalres_z, finalres_x, finalres_y, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum])
        
        [j, finalres_z, finalres_x, finalres_y, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, diff_sum, extrapolnum] = res_loop
        
        resultconvimage = convimggen(imsize, finalres_z, finalres_x, finalres_y, big_psf, addrsize)
    
        resultconvimage = tf.reshape(resultconvimage, [imsize[0], imsize[1], imsize[2]])
        
        resultimage = tf.reshape(virtualimage(imsize, finalres_z, finalres_x, finalres_y, addrsize), [imsize[0], imsize[1], imsize[2]])
    
        resultimage = tf.slice(resultimage, [np.int32(np.round(np.abs(im_orgsize[0]-imsize[0])/2)), 0, 0], [im_orgsize[0], im_orgsize[1], im_orgsize[2]])
        
        
        
        sub_image = tf.placeholder(tf.float32)
        sub_image = tf.cast(tf.math.subtract(re_image, resultconvimage), tf.float32)
        temp_image = tf.placeholder(tf.float32)
        temp_image = tf.cast(sum_image, dtype=tf.float32)
        
        if jj == 1:
        
            sum_image = resultimage
            residue = sub_image
            
            
        else:
                
            sum_image = tf.cast(tf.add(temp_image, resultimage), tf.float32)
            residue = tf.cast(sub_image, tf.float32)
                
        _, _, rimage, _, _ = sess.run([residue, sum_image, resultimage, sub_image, temp_image])
        
        convimage = signal.fftconvolve(rimage, psf, mode='full', axes=None)
        convsize = [math.floor((convimage.shape[0]-rimage.shape[0])/2),math.floor((convimage.shape[1]-rimage.shape[1])/2), math.floor((convimage.shape[2]-rimage.shape[2])/2)]
        orgimage = image[math.floor((imsize[0]-im_orgsize[0])/2):math.floor((imsize[0]-im_orgsize[0])/2)+im_orgsize[0], :, :]
        smallconvimage = convimage[convsize[0]-1:convsize[0]+rimage.shape[0]-1, convsize[1]-1:convsize[1]+rimage.shape[1]-1, convsize[2]-1:convsize[2]+rimage.shape[2]-1]
        smallconvimage = smallconvimage/np.sum(smallconvimage)*addrsize
        intratio = 1
        
        
        
        jj += 1
        
        orgimage = orgimage - smallconvimage
        
        image = z_extrapolimg(orgimage, big_psf)
        
        io.imsave('temp_imresidue.tif', np.float32(image))
        io.imsave('temp_imresidue_pos.tif', posimage(np.float32(image)))

        
        io.imsave('temp_imsave.tif', np.float32(rimage+simage))
        
        
        print(jj/iternumb*100, "% processed", "time:", time.time() - start, "SNR:", SNR)
        tf.keras.backend.clear_session()
    
        del [[simage, residue, temp_image, re_image, sub_image, sum_image, resultconvimage, resultimage, m_z_new, v_z_new, m_x_new, v_x_new, m_y_new, v_y_new, m, v, i, finalres_z, finalres_x, finalres_y]]
        
        gc.collect()
        
        
        
    simage = io.imread('temp_imsave.tif')
    
    rimage = io.imread('temp_imresidue.tif')
    
    if not(os.path.isdir(fileaddr + 'Doconvolved')):
        os.makedirs(os.path.join(fileaddr + 'Doconvolved'))
    if not(os.path.isdir(fileaddr + 'Deconvolved_pos_residue')):
        os.makedirs(os.path.join(fileaddr + 'Deconvolved_pos_residue'))
    if not(os.path.isdir(fileaddr + 'Deconvolved_residue')):
        os.makedirs(os.path.join(fileaddr + 'Deconvolved_residue'))
    
    io.imsave(fileaddr + 'Doconvolved/Deconvolved_'+filename+'Denoised-'+str(ii)+'.tif', intratio*simage)
    io.imsave(fileaddr + 'Deconvolved_pos_residue/Deconvolved_'+filename+'Denoised_posresidue-'+str(ii)+'.tif', posimage(rimage))
    io.imsave(fileaddr + 'Deconvolved_residue/Deconvolved_'+filename+'Denoised_residue-'+str(ii)+'.tif', rimage)
    
    print("finished!! total calculation time:", (time.time() - start)/60, "min, SNR: ", SNR)
    
    
    del [[image, simage, rimage, jj, iternumb, lp_image, hp_image, background_threshold]]
    
    gc.collect()
    
    
    tf.reset_default_graph()
    
    
    
    

    









Instructions for updating:
Use `tf.global_variables_initializer` instead.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
100.0 % processed time: 191.34406304359436 SNR: 17.266463437420615
finished!! total calculation time: 3.200179986159007 min, SNR:  17.266463437420615
