# Ptychography using gradient descent
### Chengyu Wang, Duke University
### David J. Brady, University of Arizona
### Timothy J. Schulz, Michigan Technological University

This script implements the 2D ptychography and a phase retrieval algorithm using gradient descent as described in "Photon-limited bounds for phase retrieval."



In [None]:
import tensorflow as tf
import numpy as np
from numpy import matlib
import random
tf.compat.v1.disable_eager_execution()

# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

## Error metrics

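Let $\underline{F}$ be an $N$-element optical field and $\widehat{\underline{F}}$ an estimate of this field. Because only intensities are measured, the field can be recovered only up to a global phase factor, so a plain MSE would penalize an otherwise correct estimate.
One way to address this is to minimize the MSE over a global phase: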
\begin{equation} 
MSE = \underset{\psi}{\min} \frac{1}{N} \left\| e^{j \psi} \underline{F} - \widehat{\underline{F}} \right\|^2. 
\end{equation}

In [None]:
def minAngMSE(xt,xest):
    """Return the MSE between two complex fields, minimized over a global phase.

    Expanding ||exp(j*psi) * xt - xest||^2 and choosing the phase psi that
    aligns the two fields gives ||xt||^2 + ||xest||^2 - 2 * |<xt, xest>|,
    which is what is evaluated here.

    Parameters
    ----------
    xt : array_like
        Reference (ground-truth) field. Any shape; it is flattened.
    xest : array_like
        Estimated field with the same number of elements as ``xt``.

    Returns
    -------
    float
        Phase-invariant squared error divided by the number of elements.
    """
    xt = np.asarray(xt).ravel()
    xest = np.asarray(xest).ravel()
    energy = np.abs(np.vdot(xt, xt)) + np.abs(np.vdot(xest, xest))
    # |<xt, xest>| does not change when either field is multiplied by a
    # global phase factor, which is what makes the metric phase-invariant.
    overlap = np.abs(np.vdot(xt, xest))
    # Divide by the element count (not the length of the first axis) so that
    # non-flattened inputs are normalized correctly as well.
    return (energy - 2 * overlap) / xt.size

## Forward model
For the forward model, $N_x$ is the number of pixels in each dimension of the random Gaussian signal, $L_x$ is the size of the subaperture, and $M_x$ is the size of each sampled frame. The subaperture moves circularly with step size $\Delta$. The total number of frames is $\left(\frac{N_x}{\Delta}\right)^2$. We consider $N_x$ and $\Delta$ to be powers of 2.

A phase shift can be added to the aperture using the argument *is_mask*.

In [None]:
## Problem geometry
N = 256                              ## pixels per dimension of the signal
L = 16                               ## subaperture size
pad_size = 8                         ## zero padding added on every side of a subaperture
M = 2 * pad_size + L                 ## sampled frame size (M = 2 * L when pad_size = L / 2)
pitch = 1                            ## step size of the circularly moving subaperture
total_frame = int(N / pitch) ** 2    ## number of measured frames

## Switches
is_noisy = True   ## add Poisson noise
is_mask = False   ## add a phase shift

# ## random phase mask
# mask = np.exp(1j * np.random.random([L,L]) * 2 * np.pi)

# ## periodic phase mask
# K = np.pi/L
# mesh = np.arange(L)
# xlocations,ylocations = np.meshgrid(mesh,mesh);
# mask = np.exp(1j * np.cos(K*xlocations) * np.cos(K*ylocations)  * 2*np.pi)

# ## MURA mask
# def qres(p):
#     qr = np.zeros([p-1])
#     for i in range(p-1):
#         qr[np.mod((i+1)**2,p)-1]=1
#     return qr

# def gen_mura(p):
#     ## p is a prime number
#     mask = np.zeros([p,p])
#     mask[1:,:] = 1
#     qr = qres(p)
#     for i in range(1,p):
#         for j in range(1,p):
#             if qr[i-1] == 1 and qr[j-1] == 1:
#                 mask[i,j] = 1
#             elif qr[i-1] == 0 and qr[j-1] == 0:
#                 mask[i,j] = 1
#             else:
#                 mask[i,j] = 0
#     return mask
# mask = gen_mura(L)*2-1 + 0 * 1j

In [None]:
## Subaperture layout
## xlocations / ylocations hold, for every frame, the coordinates of the
## top-left pixel of its L x L subaperture.
mesh = np.arange(0, N, pitch)
grid_x, grid_y = np.meshgrid(mesh, mesh)
xlocations = grid_x.astype(int).flatten()
ylocations = grid_y.astype(int).flatten()

## Energy constraint
## Count how many times each pixel (in the Fourier domain) is measured; the
## energy of a pixel is divided equally among all measurements that see it.
## Windows are accumulated on a 2N x 2N canvas so that a subaperture running
## past index N - 1 lands in the overflow area, which is then folded back onto
## the N x N plane (circular wrap-around).
canvas = np.zeros([2 * N, 2 * N])
for x0, y0 in zip(xlocations[:total_frame], ylocations[:total_frame]):
    canvas[x0:x0 + L, y0:y0 + L] += 1
counts = sum(canvas[r:r + N, c:c + N] for r in (0, N) for c in (0, N))

## Define the tensorflow graph to compute gradient

In [None]:
## define the tensorflow model
## Graph inputs: a batch of L x L complex subaperture spectra (the current
## estimate) and the matching batch of M x M real-valued measurements.
X_FT_EST = tf.compat.v1.placeholder(tf.complex128, shape=(None,L,L))
## Zero-pad the last two axes by pad_size on every side (L x L -> M x M);
## the batch axis is left untouched.
paddings = tf.constant([[0,0],[pad_size, pad_size], [pad_size, pad_size]])
MEASUREMENT = tf.compat.v1.placeholder(tf.float64, shape=(None,M,M))
if is_mask:
    ## NOTE: `mask` only exists if one of the commented-out mask definitions
    ## in the dimensions cell has been enabled.
    MASK = tf.constant(mask)
    X_FT_EST_PAD = tf.pad(tf.multiply(X_FT_EST,MASK),paddings)
else:
    X_FT_EST_PAD = tf.pad(X_FT_EST,paddings)
## Forward model: magnitude of the inverse 2-D FFT of the padded spectrum,
## scaled by M. This mirrors the NumPy expression used to simulate the frames.
FORWARD = tf.abs(tf.signal.ifft2d(tf.signal.ifftshift(X_FT_EST_PAD,[1,2]))) * M
## Mean squared error between measured and predicted magnitudes, cast to
## complex128 and reshaped to [1, 1] (presumably so that its dtype matches the
## complex seed passed as grad_ys below -- TODO confirm).
loss = tf.reshape(tf.cast(tf.compat.v1.losses.mean_squared_error(abs(MEASUREMENT),abs(FORWARD)),tf.complex128),[1,1])
## Seed (grad_ys) for back-propagating the [1, 1] loss.
weight = [[1 + 1j*0]]
## tf.gradients returns a list; its single entry is the gradient with respect
## to X_FT_EST (callers take element [0] of the evaluated result).
gradient = tf.gradients(loss, X_FT_EST, grad_ys = weight)

## session
sess = tf.compat.v1.Session()

## Simulation

In phase retrieval, we initialize the gradient descent with a simple projection method. The projection method itself is sensitive to noise, resulting in poor MSE, but it approximates the ground truth faster than the gradient descent method. The phase retrieval implemented in this script has two stages:

- First, a projection method approximates the ground truth.
- Second, the gradient descent algorithm improves the accuracy.

FYI: Because of the circular connections, the if ... elif ... elif ... else part is used to consider different locations of the aperture.

FP_phaseretrieval_w_tensorflow.ipynb provides a version without circular connection. The phase retrieval in that file uses the same scheme, but it may look more straightforward.

In [None]:
## Projection stage
num_loop = 3         ## number of loops for the projection method

## Gradient-descent stage
num_iteration = 50   ## number of iterations for gradient descent
lr = 5e7             ## learning rate for the gradient update
tol_gd = 0           ## tolerance of the gradient descent; 0 if early stopping is not needed

## Simulation
photons = 1e6        ## average number of photons per signal element
num_trial = 10       ## number of simulations

In [None]:
## random signals
## The ground-truth signals are read from 'XT.npy' and indexed as XT[trial, :, :].
## The commented-out lines show how to draw num_trial complex Gaussian N x N
## signals scaled by sqrt(photons / 2) and save them (presumably how the file
## was produced -- uncomment to regenerate it).
# XT = np.random.randn(num_trial,N,N)+1j*np.random.randn(num_trial,N,N);
# XT = np.sqrt(photons)*XT/np.sqrt(2.)
# np.save('XT.npy',XT)
XT = np.load('XT.npy')

In [None]:
## Run num_trial reconstructions (projection stage followed by gradient
## descent) and store the phase-invariant MSE of each one in MSE.

## Circular subaperture windows: one row of N-periodic indices per frame, so a
## window that runs past the edge of the N x N Fourier plane wraps around to
## the opposite edge. Indexing with these handles interior, x-wrapped,
## y-wrapped and corner-wrapped apertures with a single code path.
offsets = np.arange(L)
row_idx = (xlocations[:total_frame, None] + offsets) % N   ## (total_frame, L)
col_idx = (ylocations[:total_frame, None] + offsets) % N   ## (total_frame, L)
inner = slice(pad_size, pad_size + L)   ## un-padded L x L centre of an M x M frame

MSE = np.zeros([num_trial])
for ct in range(num_trial):
    print(ct,end='\r',flush=True)

    ## signal
    xt = XT[ct,:,:]
    ## Each Fourier sample is divided by sqrt(counts) so that its energy is
    ## shared equally by all frames that measure it.
    objectFT = np.fft.fftshift(np.fft.fft2(xt))/N/np.sqrt(counts)
    ## Tile 2 x 2 so plain slices can read windows that wrap around the edge.
    objectFT = matlib.repmat(objectFT, 2, 2)

    ## measurement
    imSeqLowRes = np.zeros([total_frame, M, M])
    for i in range(total_frame):
        imSeqLowFT = objectFT[xlocations[i]:xlocations[i] + L,ylocations[i]:ylocations[i] + L]
        if is_mask:
            imSeqLowFT = imSeqLowFT * mask
        imSeqLowFT = np.pad(imSeqLowFT,pad_size)
        imSeqLowRes[i,:,:] = np.abs(np.fft.ifft2(np.fft.ifftshift(imSeqLowFT))) * M
    if is_noisy:
        ## Poisson noise acts on the intensity (squared magnitude); the square
        ## root converts the noisy counts back to a magnitude.
        imSeqLowRes = np.sqrt(np.random.poisson(np.power(np.abs(imSeqLowRes),2)))

    ## Projection
    ## NOTE(review): the estimate starts from the ground-truth spectrum plus a
    ## constant offset of 10 -- confirm this is the intended starting point.
    imageRecoverFT = objectFT[0:N,0:N] + 10
    seq = list(range(total_frame))
    for loop in range(num_loop):
        random.shuffle(seq)
        for i in seq:
            window = np.ix_(row_idx[i], col_idx[i])
            patch = imageRecoverFT[window]
            if is_mask:
                patch = patch * mask
            imLowRes = np.fft.ifft2(np.fft.ifftshift(np.pad(patch,pad_size)))
            ## Keep the estimated phase, impose the measured magnitude.
            imLowRes = imSeqLowRes[i,:,:] * np.exp(1j * np.angle(imLowRes))
            inverse = (np.fft.fftshift(np.fft.fft2(imLowRes)) / M)[inner, inner]
            if is_mask:
                inverse = inverse / mask
            imageRecoverFT[window] = inverse
        ## Undo the 1/sqrt(counts) scaling before comparing with the signal.
        imageRecover = np.fft.ifft2(np.fft.ifftshift(imageRecoverFT*np.sqrt(counts))) * N
        print(loop,minAngMSE(xt.flatten(),imageRecover.flatten()),end='\r',flush=True)

    ## GD
    loss_pre = 0
    for iter_ in range(num_iteration):

        ## generate batch: gather every frame's current L x L spectrum at once,
        ## giving an array of shape (total_frame, L, L)
        batch_tensor = imageRecoverFT[row_idx[:, :, None], col_idx[:, None, :]]

        ## Evaluate loss and gradient in one run so the forward model is only
        ## computed once per iteration.
        feed = {X_FT_EST: batch_tensor, MEASUREMENT: imSeqLowRes}
        loss_value, gradient_value = sess.run([loss, gradient], feed_dict=feed)
        current_loss = abs(loss_value[0, 0])

        if abs(loss_pre - current_loss) < tol_gd:  ## early stop
            break
        loss_pre = current_loss

        ## Scatter the per-frame gradients back onto the N x N plane. Indices
        ## inside one window are unique (L <= N), so += is safe per frame.
        gradient_tensor = gradient_value[0]
        gradient_sum = np.zeros([N, N], dtype=complex)
        for i in range(total_frame):
            gradient_sum[np.ix_(row_idx[i], col_idx[i])] += gradient_tensor[i]
        imageRecoverFT = imageRecoverFT - lr * gradient_sum/counts

        imageRecover = np.fft.ifft2(np.fft.ifftshift(imageRecoverFT*np.sqrt(counts))) * N
        print(iter_,minAngMSE(xt.flatten(),imageRecover.flatten()),end='\r',flush=True)
    imageRecover = np.fft.ifft2(np.fft.ifftshift(imageRecoverFT*np.sqrt(counts))) * N
    MSE[ct] = minAngMSE(xt.flatten(),imageRecover.flatten())
print('photons = %d, pitch = %d,L = %d, M = %d, MSE = %f.'%(photons,pitch,L,M,np.mean(MSE)))

photons = 1000000, pitch = 1,L = 16, M = 32, MSE = 1.115040.


In [None]:
## Mean MSE over at most the first 18 trials; a slice end beyond the length of
## MSE is truncated by NumPy, so this averages whatever entries exist.
np.mean(MSE[0:18])

1.1814776699627854

In [None]:
## Per-trial MSE of the first 10 trials.
MSE[0:10]

array([1.10005762, 1.09636615, 1.1067594 , 1.09373958, 1.10623215,
       1.0996757 , 1.10551584, 1.09144104, 1.10184666, 1.10203794])

In [None]:
## Summary line built from the current values of photons, pitch, L, M and MSE.
print('photons = %d, pitch = %d,L = %d, M = %d, MSE = %f.'%(photons,pitch,L,M,np.mean(MSE)))

photons = 1000000, pitch = 4,L = 32, M = 96, MSE = 1.102353.


# results

|pitch|L = 128, M = 256|L = 64, M = 128|L = 64, M = 192|L = 64, M = 256|
|-----|----------------|---------------|---------------|---------------|
|64   |1.225894        |N/A            |N/A            |N/A            |
|32   |1.127360        |1.191672       |1.178917       |1.176729       |
|16   |1.110422        |1.109887       |1.108261       |1.108130       |
|8    |1.126297        |1.096419       |1.100261       |1.115072       |
|4    |1.445           |1.112696       |               |               |

|pitch|L = 32,M = 64|L = 32,M = 96|L = 32, M = 128|L = 32, M = 256|L = 16, M = 32|L = 16, M = 64|L = 16, M = 128|
|-----|-------------|-------------|---------------|---------------|--------------|--------------|---------------|
|64   |N/A          |N/A          |N/A            |N/A            |N/A           |N/A           |N/A            |
|32   |N/A          |N/A          |N/A            |N/A            |N/A           |N/A           |N/A            |
|16   |1.191906     |1.177208     |1.175811       |1.177193       |N/A           |N/A           |N/A            |
|8    |1.108914     |1.107747     |1.108682       |1.129494       |1.197728      |1.181478      |1.183515       | 
|4    |1.095429     |1.101360     |1.112736       |               |1.114534      |              |               |
|2    |1.109849     |             |               |               |1.100355      |              |               |
|1    |             |             |               |               |1.115040      |              |               |



In [None]:
## Scratch calculation: average of two MSE results (presumably two runs of the
## L = 32, M = 96, pitch = 4 configuration, whose table entry is 1.101360).
(1.100367+1.102353)/2

1.1013600000000001

In [None]:
## Scratch calculation: half of a summed pair of MSE results (presumably the
## L = 32, M = 256, pitch = 8 table entry, 1.129494).
2.2589870000000003/2

1.1294935000000002

|L = 64, M = 128   | no mask||L = 32, M = 64   | no mask|
|-|-|-|-|-|
|pitch 32    |   1.191672||||
|pitch 16   |    1.109887| |pitch16||
|pitch 8 |1.096419||pitch 8||
|pitch 4 |1.112696||pitch 4||


|L (M = 2L)       |random  mask  |  period      |      no mask|   |L (M = 2L-1) |MURA|random mask|
 |-|-|-|-|-|-|-|-|
|L = 16,pitch = 4| 1.033013      |  *1.068053*   |    1.114833|   |L = 17,pitch = 4 |1.039511 |*1.031405*|
|L = 32,pitch = 8 | 1.026172      |   *1.065940*   |    1.110008|   |L = 29,pitch = 8 |1.037748 ||
|L = 64,pitch = 16 | 1.025036      |  *1.064933*    |    1.110296|   |L = 61,pitch = 16 | 1.030068|1.027090|
|L = 128,pitch = 32| 1.023604      |  *1.065430*   |    1.127946|   |L = 113,pitch = 32|1.034224 ||
|L = 128,pitch = 16|1.007899 | | | |L = 127,pitch = 16| 1.008056|1.007153 ||
|| | | | |L = 127,pitch = 8|*1.016922*  |  |

- Italic digits are simulated on 10 samples; the others are simulated on 100 samples.
- Because a MURA mask has an odd number of pixels in each dimension, some points (in the Fourier domain) are measured more times than others.
- Decreasing the pitch (subaperture shift) from 50% to 12.5% of aperture width, the MSE decreases, with or without mask. However, when pitch is 6.25% of aperture width, the MSE increases. (I have tested with photons = 1e7, it still happens, so it is not related to the low energy per sample.)

|L (M = 2L)       |random  mask  |  period      |      no mask|   |L (M = 2L-1) |MURA|random mask|
 |-|-|-|-|-|-|-|-|
|L = 16,pitch = 4| 1.033013      |  *1.068053*   |    1.114833|   |L = 17,pitch = 4 |1.039511 |*1.031405*|
|L = 32,pitch = 8 | 1.026172      |   *1.065940*   |    1.110008|   |L = 31,pitch = 8 | 1.034318|1.027950|
|L = 64,pitch = 16 | 1.025036      |  *1.064933*    |    1.110296|   |L = 61,pitch = 16 | 1.030068|1.027090|
|L = 128,pitch = 32| 1.023604      |  *1.065430*   |    1.127946|   |L = 127,pitch = 32| 1.026307|1.024718|
|L = 128,pitch = 16|1.007899 | | | |L = 127,pitch = 16| 1.008056|1.007153 ||
|| | | | |L = 127,pitch = 8|*1.016922*  |  |

|L (M = 2L)       |random  mask  |  period      |      no mask|   |L (M = 2L-1) |MURA|random mask|
 |-|-|-|-|-|-|-|-|
|L = 16,pitch = 4| 1.033013      |  *1.068053*   |    1.114833|   |L = 17,pitch = 4 |1.039511 ||
|L = 32,pitch = 8 | 1.026172      |   *1.065940*   |    1.110008|   |L = 31,pitch = 8 | 1.034318|1.027950|
|L = 64,pitch = 16 | 1.025036      |  *1.064933*    |    1.110296|   |L = 61,pitch = 16 | 1.030068|1.027090|
|L = 128,pitch = 32| 1.023604      |  *1.065430*   |    1.127946|   |L = 127,pitch = 32| 1.026307|1.024718|
|L = 128,pitch = 16|1.007899 | | | |L = 127,pitch = 16| 1.008056|1.007153 ||
|| | | | |L = 127,pitch = 8|*1.016922*  |  |

decreasing: pitch > L/8

increasing: pitch < L/8

In [None]:
## Scratch calculation (presumably N - L for N = 256 and L = 32, i.e. the
## largest start index at which a subaperture does not wrap around).
256-32

224

In [None]:
## Same two-stage reconstruction as the full-batch simulation loop, but the
## gradient is evaluated in chunks of frames to limit memory use, and the
## per-iteration MSE is printed on its own line.

## Circular subaperture windows: one row of N-periodic indices per frame, so a
## window that runs past the edge of the N x N Fourier plane wraps around to
## the opposite edge. Indexing with these handles interior, x-wrapped,
## y-wrapped and corner-wrapped apertures with a single code path.
offsets = np.arange(L)
row_idx = (xlocations[:total_frame, None] + offsets) % N   ## (total_frame, L)
col_idx = (ylocations[:total_frame, None] + offsets) % N   ## (total_frame, L)
inner = slice(pad_size, pad_size + L)   ## un-padded L x L centre of an M x M frame

## Frame ranges for the chunked gradient evaluation. Integer boundaries cover
## every frame even when total_frame is not a multiple of num_chunks; empty
## ranges (total_frame < num_chunks) are skipped.
num_chunks = 4
bounds = [k * total_frame // num_chunks for k in range(num_chunks + 1)]
chunks = [slice(lo, hi) for lo, hi in zip(bounds[:-1], bounds[1:]) if hi > lo]

MSE = np.zeros([num_trial])
for ct in range(num_trial):
    print(ct,end='\r',flush=True)

    ## signal
    xt = XT[ct,:,:]
    ## Each Fourier sample is divided by sqrt(counts) so that its energy is
    ## shared equally by all frames that measure it.
    objectFT = np.fft.fftshift(np.fft.fft2(xt))/N/np.sqrt(counts)
    ## Tile 2 x 2 so plain slices can read windows that wrap around the edge.
    objectFT = matlib.repmat(objectFT, 2, 2)

    ## measurement
    imSeqLowRes = np.zeros([total_frame, M, M])
    for i in range(total_frame):
        imSeqLowFT = objectFT[xlocations[i]:xlocations[i] + L,ylocations[i]:ylocations[i] + L]
        if is_mask:
            imSeqLowFT = imSeqLowFT * mask
        imSeqLowFT = np.pad(imSeqLowFT,pad_size)
        imSeqLowRes[i,:,:] = np.abs(np.fft.ifft2(np.fft.ifftshift(imSeqLowFT))) * M
    if is_noisy:
        ## Poisson noise acts on the intensity (squared magnitude); the square
        ## root converts the noisy counts back to a magnitude.
        imSeqLowRes = np.sqrt(np.random.poisson(np.power(np.abs(imSeqLowRes),2)))

    ## Projection
    ## NOTE(review): the estimate starts from the ground-truth spectrum plus a
    ## constant offset of 10 -- confirm this is the intended starting point.
    imageRecoverFT = objectFT[0:N,0:N] + 10
    seq = list(range(total_frame))
    for loop in range(num_loop):
        random.shuffle(seq)
        for i in seq:
            window = np.ix_(row_idx[i], col_idx[i])
            patch = imageRecoverFT[window]
            if is_mask:
                patch = patch * mask
            imLowRes = np.fft.ifft2(np.fft.ifftshift(np.pad(patch,pad_size)))
            ## Keep the estimated phase, impose the measured magnitude.
            imLowRes = imSeqLowRes[i,:,:] * np.exp(1j * np.angle(imLowRes))
            inverse = (np.fft.fftshift(np.fft.fft2(imLowRes)) / M)[inner, inner]
            if is_mask:
                inverse = inverse / mask
            imageRecoverFT[window] = inverse
        ## Undo the 1/sqrt(counts) scaling before comparing with the signal.
        imageRecover = np.fft.ifft2(np.fft.ifftshift(imageRecoverFT*np.sqrt(counts))) * N
        print(loop,minAngMSE(xt.flatten(),imageRecover.flatten()),end='\r',flush=True)

    ## GD (the loss itself is not evaluated here, so there is no early stop)
    gradient_tensor = np.zeros([total_frame, L, L], dtype=complex)
    for iter_ in range(num_iteration):

        ## generate batch: gather every frame's current L x L spectrum at once,
        ## giving an array of shape (total_frame, L, L)
        batch_tensor = imageRecoverFT[row_idx[:, :, None], col_idx[:, None, :]]

        ## The loss is a mean over the frames that are fed, so the gradient of
        ## a chunk is total_frame / chunk_size times the full-batch gradient.
        ## Rescale each chunk by its share of the frames so that `lr` has the
        ## same meaning as in the full-batch loop.
        ## NOTE(review): without this factor the step was 4x larger for the
        ## same lr; retune lr if the earlier chunked behaviour was intended.
        for chunk in chunks:
            feed = {X_FT_EST: batch_tensor[chunk], MEASUREMENT: imSeqLowRes[chunk]}
            share = (chunk.stop - chunk.start) / total_frame
            gradient_tensor[chunk] = sess.run(gradient, feed_dict=feed)[0] * share

        ## Scatter the per-frame gradients back onto the N x N plane. Indices
        ## inside one window are unique (L <= N), so += is safe per frame.
        gradient_sum = np.zeros([N, N], dtype=complex)
        for i in range(total_frame):
            gradient_sum[np.ix_(row_idx[i], col_idx[i])] += gradient_tensor[i]
        imageRecoverFT = imageRecoverFT - lr * gradient_sum/counts

        imageRecover = np.fft.ifft2(np.fft.ifftshift(imageRecoverFT*np.sqrt(counts))) * N
        print(iter_,minAngMSE(xt.flatten(),imageRecover.flatten()))
    imageRecover = np.fft.ifft2(np.fft.ifftshift(imageRecoverFT*np.sqrt(counts))) * N
    MSE[ct] = minAngMSE(xt.flatten(),imageRecover.flatten())
print('photons = %d, pitch = %d,L = %d, M = %d, MSE = %f.'%(photons,pitch,L,M,np.mean(MSE)))

0 192.76839499548078
1 99.91567951394245
2 54.36699071037583
3 31.1281609616708
4 18.823639750713482
5 12.06997874029912
6 8.229133875574917
7 5.966254276689142
8 4.585195736493915
9 3.7122472829651088
10 3.1410764548927546
11 2.7545914959628135
12 2.4845264572650194
13 2.290009975200519
14 2.145923960953951
15 2.0364304708782583
16 1.951289305696264
17 1.8837160470429808
18 1.8291085872333497
19 1.7842737992759794
20 1.7469479783903807
21 1.715493369847536
22 1.6887023337185383
23 1.6656684495974332
24 1.6457001818343997
25 1.6282619086559862
26 1.6129330324474722
27 1.599379074992612
28 1.587330993032083
29 1.5765700936317444
30 1.5669169209431857
31 1.5582229369319975
32 1.5503642074763775
33 1.543236592784524
34 1.5367519895080477
35 1.530835394980386
36 1.5254226250108331
37 1.520458438899368
38 1.515895067481324
39 1.5116910196375102
40 1.5078101137187332
41 1.5042206263169646
42 1.5008946976158768
43 1.4978077372070402
44 1.4949379763565958
45 1.4922660852316767
46 1.48977483622

KeyboardInterrupt: 

MSE vs photons

1e7 1.108
1e6 1.110
1e5 1.113
1e4 1.129
1e3 1.280
1e2 2.596