In [None]:
import cupy as cp
import sys
import pandas as pd
import time
from utils import write_tiff, read_tiff
from utils import mshow, mshow_polar, mshow_complex
cp.cuda.Device(1).use()

In [2]:
# Geometry and run configuration shared (as notebook globals) by all cells below.
n = 256  # data size in each dimension
nobj = 1024 # object size in each dimension
pad = 0  # pad for the reconstructed probe (alternative that was tried: n//16)
nprb = n+2*pad # probe size
extra = 8 # extra padding for shifts
npatch = nprb+2*extra # patch size for shifts
energy = 9  # presumably keV, given the 1.24e-09/energy conversion below -- TODO confirm
wavelength = 1.24e-09/energy  # [m] wave length
npos = 96 # total number of positions
voxelsize = 8e-09  # presumably [m] -- TODO confirm units

noise=False  # True: load the noisy measurements (ndata.npy) instead of data.npy
show = True # do visualization or not at all

path = f'/data/vnikitin/paper/far_field' # input data path 
path_out = f'/data/vnikitin/paper/far_field/rec' # output path prefix for reconstructions

In [None]:

def Lop(psi):
    """Forward propagator.

    Applies an orthonormal 2D FFT over the last two axes of ``psi`` and
    crops the result to the detector window ``[pad, nprb-pad)`` along both
    of those axes (``pad`` and ``nprb`` are notebook globals).
    """
    spectrum = cp.fft.fft2(psi,norm='ortho')

    # crop to detector size
    detector = slice(pad, nprb-pad)
    return spectrum[:, detector, detector]

def LTop(psi):
    """Adjoint propagator.

    Zero-pads the last two axes of ``psi`` by ``pad`` on each side (back to
    the probe size) and applies the orthonormal inverse 2D FFT, mirroring
    the FFT + crop performed by ``Lop``.
    """
    # pad to the probe size
    padded = cp.pad(psi,((0,0),(pad,pad),(pad,pad)))
    return cp.fft.ifft2(padded,norm='ortho')

def Ex(psi,ix):
    """Extract one ``npatch`` x ``npatch`` patch of ``psi`` per position.

    ``ix[:,0]`` / ``ix[:,1]`` are integer offsets along the first / second
    axis of ``psi``; each patch is centred at ``nobj//2 - offset``.
    Returns a complex64 array of shape ``(ix.shape[0], npatch, npatch)``.
    """
    patches = cp.empty([ix.shape[0],npatch,npatch],dtype='complex64')

    # first/last row and column of every patch
    row0 = nobj//2-ix[:,0]-npatch//2
    row1 = row0+npatch
    col0 = nobj//2-ix[:,1]-npatch//2
    col1 = col0+npatch

    for k in range(len(col0)):
        patches[k] = psi[row0[k]:row1[k],col0[k]:col1[k]]
    return patches

def ExT(psi,psir,ix):
    """Adjoint of ``Ex``: accumulate patches back into the object.

    Adds ``psir[k]`` in place onto the window of ``psi`` that ``Ex`` would
    read for position ``k`` and returns the (mutated) ``psi``.
    """
    # first/last row and column of every patch
    row0 = nobj//2-ix[:,0]-npatch//2
    row1 = row0+npatch
    col0 = nobj//2-ix[:,1]-npatch//2
    col1 = col0+npatch

    for k in range(len(col0)):
        psi[row0[k]:row1[k],col0[k]:col1[k]] += psir[k]
    return psi

def S(psi,p):
    """Subpixel shift of a stack of patches via the Fourier shift theorem.

    Parameters
    ----------
    psi : array of shape (K, rows, cols)
        Stack of patches to shift.
    p : array of shape (K, 2)
        Per-patch shifts in pixels; ``p[:,0]`` acts along axis -2 and
        ``p[:,1]`` along axis -1.

    Returns
    -------
    Array with the shape of ``psi``; an integer shift ``s`` is equivalent to
    rolling the patch by ``+s`` along the corresponding axis.

    The frequency grid is derived from ``psi`` itself rather than from the
    notebook-global ``npatch``, so the function works for any patch size
    (for ``npatch``-sized input the result is unchanged).
    """
    rows, cols = psi.shape[-2:]
    freq_rows = cp.fft.fftfreq(rows).astype('float32')
    freq_cols = cp.fft.fftfreq(cols).astype('float32')
    # xi_col varies along the last axis, xi_row along the second-to-last one
    xi_col, xi_row = cp.meshgrid(freq_cols, freq_rows)
    phase = cp.exp(-2*cp.pi*1j * (xi_col*p[:, 1, None, None]+xi_row*p[:, 0, None, None])).astype('complex64')
    return cp.fft.ifft2(phase*cp.fft.fft2(psi))

def Sop(psi,ix,x,ex):
    """Extract patches with subpixel shift.

    Cuts one ``npatch``-sized patch per position out of ``psi`` (``Ex``),
    applies the fractional shifts ``x`` (``S``) and removes ``ex`` border
    pixels on every side of the last two axes.

    The previous version also pre-allocated a zero array that was
    overwritten before use (and sized with ``x.shape[1]``, i.e. 2, rather
    than the number of positions); that dead allocation is removed.
    """
    shifted = S(Ex(psi,ix),x)
    return shifted[:, ex:npatch-ex, ex:npatch-ex]

def STop(d,ix,x,ex):
    """Adjoint of ``Sop``: extract patches with subpixel shift, transposed.

    Zero-pads every patch in ``d`` by ``ex`` pixels, applies the opposite
    fractional shift ``-x`` and accumulates the patches into a fresh
    ``nobj`` x ``nobj`` complex64 object, which is returned.
    """
    padded = cp.pad(d, ((0, 0), (ex, ex), (ex, ex)))
    unshifted = S(padded,-x)
    obj = cp.zeros([nobj, nobj], dtype='complex64')
    ExT(obj,unshifted,ix)
    return obj

# adjoint tests
# For an operator A with adjoint A*, <x, A*(A x)> must equal <A x, A x>.
# Each test below prints both inner products; the two printed values should
# agree up to floating-point error.

# random shifts in [-15, 15) pixels, split into integer and fractional parts
# (astype('int32') truncates toward zero, so fshifts lies in (-1, 1))
shifts_test = 30*(cp.random.random([npos,2])-0.5).astype('float32')
ishifts = shifts_test.astype('int32')
fshifts = shifts_test-ishifts

# Ex / ExT: patch extraction and its adjoint
arr1 = (cp.random.random([nobj,nobj])+1j*cp.random.random([nobj,nobj])).astype('complex64')
arr2 = Ex(arr1,ishifts)
arr3 = arr1*0
ExT(arr3,arr2,ishifts)
print(f'{cp.sum(arr1*cp.conj(arr3))}==\n{cp.sum(arr2*cp.conj(arr2))}')

# Sop / STop: extraction with subpixel shift and its adjoint
arr1 = (cp.random.random([nobj,nobj])+1j*cp.random.random([nobj,nobj])).astype('complex64')
arr2 = Sop(arr1,ishifts,fshifts,extra)
arr3 = STop(arr2,ishifts,fshifts,extra)
print(f'{cp.sum(arr1*cp.conj(arr3))}==\n{cp.sum(arr2*cp.conj(arr2))}')

# Lop / LTop: propagator and its adjoint
arr1 = (cp.random.random([npos,nprb,nprb])+1j*cp.random.random([npos,nprb,nprb])).astype('complex64')
arr2 = Lop(arr1)
arr3 = LTop(arr2)
print(f'{cp.sum(arr1*cp.conj(arr3))}==\n{cp.sum(arr2*cp.conj(arr2))}')

# read data

In [None]:
# NOTE(review): this import lives outside the notebook's import cell.
import random
# Pick npos distinct scan positions out of the first 962 entries.
# NOTE(review): 962 is hard-coded -- presumably the number of positions stored
# in the .npy files; confirm. No seed is set, so the subset differs per run.
random_indices = random.sample(range(0, 962), npos)
print(random_indices)

# positions (two variants), probes, measurements, reference and object
shifts = cp.load(f'{path}/data/gen_shifts.npy')[random_indices]
shifts_random = cp.load(f'{path}/data/gen_shifts_random.npy')[random_indices]
prb = cp.load(f'{path}/data/gen_prb.npy')
deformed_prb = cp.load(f'{path}/data/deformed_prb.npy')
# the `noise` flag selects which measurement file is read
if noise:
    data = cp.load(f'{path}/data/ndata.npy')[random_indices]
else:
    data = cp.load(f'{path}/data/data.npy')[random_indices]
ref = cp.load(f'{path}/data/ref.npy')
psi = cp.load(f'{path}/data/psi.npy')

mshow_polar(prb,show)
mshow_polar(deformed_prb,show)
# first frame and the reference are packed into the real/imaginary parts of a
# single complex array so that mshow_complex displays them together
mshow_complex(cp.fft.fftshift(data[0]+1j*ref,axes=(-2,-1)),show,vmax=0.15)

##### $$\nabla F=2 \left(L^*\left( L\psi-\tilde d\right)\right).$$
##### where $$\tilde d = d \frac{L(\psi)}{|L(\psi)|}$$




In [5]:
def gradientF(pars, reused, d):
    """Gradient of the data-fidelity functional with respect to the wavefront.

    Reads ``reused['Lpsi']`` (the propagated wavefront) and stores the result
    in ``reused['gradF']``; nothing is returned.

    Parameters
    ----------
    pars : dict
        Uses ``pars['model']`` ('Gaussian' or 'Poisson') and ``pars['eps']``
        (regulariser that keeps the divisions finite where ``Lpsi`` vanishes).
    reused : dict
        Cache of intermediate arrays; updated in place.
    d : array
        Measured data, broadcastable against ``reused['Lpsi']``.

    Raises
    ------
    ValueError
        If ``pars['model']`` is not one of the supported models. Previously
        this case fell through to an unbound local variable.
    """
    Lpsi = reused['Lpsi']
    model = pars['model']
    if model=='Gaussian':
        # 2 L*(L psi - d L psi/|L psi|)
        td = d*(Lpsi/(cp.abs(Lpsi)+pars['eps']))
        res = 2*LTop(Lpsi - td)
    elif model=='Poisson':
        # 2 L*(L psi - d L psi/|L psi|^2)
        dd = d*Lpsi/(cp.abs(Lpsi)**2+pars['eps']**2)
        res = 2*LTop(Lpsi-dd)
    else:
        raise ValueError(f"unknown model {model!r}: expected 'Gaussian' or 'Poisson'")
    reused['gradF'] = res

##### $$\nabla_{\psi} G|_{(q_0,\psi_0,\boldsymbol{x}_0)}= S_{\boldsymbol{x}_{0}}^*\left(\overline{J(q_0)}\cdot \nabla F\right)$$

##### $$\nabla_{q} G|_{(q_0,\psi_0,\boldsymbol{x}_0)}=J^*\left( \overline{S_{\boldsymbol{x}_{0}}(C_f^*(\psi_0)+\psi_{fr})}\cdot \nabla F\right).$$
##### $$\nabla_{\boldsymbol{x}_0} G|_{(q_0,\psi_0,\boldsymbol{x}_0)}=\textsf{Re}\Big(\big( \Big\langle \overline{q_0}\cdot \nabla F,   C(\mathcal{F}^{-1}(-2\pi i \xi_1 e^{ -2\pi i \boldsymbol{x}_{0,k}\cdot \boldsymbol{\xi}}\hat{\psi_0}))\Big\rangle,\Big\langle \overline{q_0}\cdot \nabla F,C(\mathcal{F}^{-1}(-2\pi i \xi_2 e^{ -2\pi i \boldsymbol{x}_{0,k}\cdot \boldsymbol{\xi}}\hat{\psi_0})) \Big\rangle\big)\Big)_{k=1}^K. $$




In [6]:
def gradient_psi(q,ix,x,ex,gradF):
    """Gradient with respect to the object: ``STop`` applied to ``conj(q)*gradF``."""
    weighted = cp.conj(q)*gradF
    return STop(weighted,ix,x,ex)

def gradient_prb(spsi,gradF):
    """Gradient with respect to the probe: ``conj(spsi)*gradF`` summed over axis 0."""
    weighted = cp.conj(spsi)*gradF
    return weighted.sum(axis=0)

def gradient_shift(psi, q, ix, x, ex, gradF):
    """Gradient with respect to the fractional shifts.

    Returns a float32 array of shape ``(npos, 2)``; column 0 / 1 pairs with
    ``x[:,0]`` / ``x[:,1]``.
    NOTE(review): ``imdot`` is not defined in the visible part of the
    notebook -- confirm where it comes from.
    """
    # frequency grids: xi2 varies along the last axis, xi1 along the rows
    freq = cp.fft.fftfreq(npatch).astype('float32')
    xi2, xi1 = cp.meshgrid(freq, freq)

    # Fourier multiplier of the current subpixel shift
    w = cp.exp(-2 * cp.pi * 1j * (xi2 * x[:, 1, None, None] + xi1 * x[:, 0, None, None]))

    # spectrum of the extracted patches
    patch_hat = cp.fft.fft2(Ex(psi, ix))

    # frequency-weighted shifted patches, cropped to the probe size
    crop = (slice(None), slice(ex, nprb+ex), slice(ex, nprb+ex))
    dt1 = -2 * cp.pi * cp.fft.ifft2(w*xi1*patch_hat)[crop]
    dt2 = -2 * cp.pi * cp.fft.ifft2(w*xi2*patch_hat)[crop]

    # per-position inner products with gradF
    gradx = cp.zeros([npos, 2], dtype='float32')
    gradx[:, 0] = imdot(gradF, q * dt1, axis=(1, 2))
    gradx[:, 1] = imdot(gradF, q * dt2, axis=(1, 2))
    return gradx

def gradients(vars,pars,reused):
    """Gradients with respect to object, probe and shifts, each scaled by its ``rho``.

    Returns a dict with keys ``'psi'``, ``'prb'`` and ``'fshift'``.
    """
    q = vars['prb']
    psi = vars['psi']
    x = vars['fshift']
    ix = pars['ishift']
    ex = pars['extra']
    rho = pars['rho']
    gradF = reused['gradF']

    dpsi = gradient_psi(q,ix,x,ex,gradF)
    dprb = gradient_prb(reused['spsi'],gradF)
    dx = gradient_shift(psi,q,ix,x,ex,gradF)

    return {
        'psi': rho[0]*dpsi,
        'prb': rho[1]*dprb,
        'fshift': rho[2]*dx,
    }

##### $$\frac{1}{2}\mathcal{H}|_{x_0}(y,z)= \left\langle \mathbf{1}-d_{0}, \mathsf{Re}({L(y)}\overline{L(z)})\right\rangle+\left\langle d_{0},(\mathsf{Re} (\overline{l_0}\cdot L(y)))\cdot (\mathsf{Re} (\overline{l_0}\cdot L(z))) \right\rangle$$
##### $$l_0=L(x_0)/|L(x_0)|$$
##### $$d_0=d/|L(x_0)|$$


In [7]:
def hessianF(Lm,Ldm1,Ldm2,data,pars):
    """Bilinear Hessian of the data-fidelity functional at ``Lm``.

    Evaluated on the direction pair ``(Ldm1, Ldm2)``. ``pars['model']``
    selects the 'Gaussian' branch; any other value takes the second
    (Poisson) branch. ``pars['eps']`` regularises the divisions.
    NOTE(review): ``reprod`` is not defined in the visible part of the
    notebook -- confirm where it comes from.
    """
    eps = pars['eps']
    amp = cp.abs(Lm)
    # Lm normalised to (almost) unit modulus
    unit = Lm/(amp+eps)

    if pars['model']=='Gaussian':
        d0 = data/(amp+eps)
        v1 = cp.sum((1-d0)*reprod(Ldm1,Ldm2))
        v2 = cp.sum(d0*reprod(unit,Ldm1)*reprod(unit,Ldm2))
    else:
        intensity = amp**2+eps**2
        v1 = cp.sum((1-data/intensity)*reprod(Ldm1,Ldm2))
        v2 = 2*cp.sum(data*reprod(unit,Ldm1)*reprod(unit,Ldm2)/intensity)
    return 2*(v1+v2)

##### $$\mathcal{H}^G|_{ (q_0,\psi_0,\boldsymbol{x}_0)}\Big((\Delta q^{(1)},\Delta \psi^{(1)},\Delta \boldsymbol{x}^{(1)}),(\Delta q^{(2)},\Delta \psi^{(2)},\Delta \boldsymbol{x}^{(2)})\Big)=$$
##### $$\Big\langle \nabla F|_{M(q_0,\psi_0,\boldsymbol{x}_0)}, D^2M|_{(q_0,\psi_0,\boldsymbol{x}_0)}\Big((\Delta q^{(1)},\Delta \psi^{(1)},\Delta \boldsymbol{x}^{(1)}),(\Delta q^{(2)},\Delta \psi^{(2)},\Delta \boldsymbol{x}^{(2)})\Big)\Big\rangle +$$
##### $$\mathcal{H}^F|_{M(q_0,\psi_0,\boldsymbol{x}_0)}\Big(DM|_{(q_0,\psi_0,\boldsymbol{x}_0)}(\Delta q^{(1)},\Delta \psi^{(1)},\Delta \boldsymbol{x}^{(1)}),DM|_{(q_0,\psi_0,\boldsymbol{x}_0)}(\Delta q^{(2)},\Delta \psi^{(2)},\Delta \boldsymbol{x}^{(2)})\Big)$$


### Updates:

\begin{equation}
               \alpha_j=-\frac{\mathsf{Re}\langle \nabla F|_{x_j},s_j\rangle}{H|_{x_j}( {s_j},s_j)}
             \end{equation}

\begin{equation}
                \beta_j=\frac{H(\nabla F|_{x_j},s_j)}{H|_{x_j}( {s_j},s_j)}.
\end{equation}

### Scaling variables:

\begin{equation}
\begin{aligned}
\tilde{\beta}_j=\frac{H^{\tilde{F}}|_{\tilde{x}_j} (\nabla \tilde{F}|_{\tilde{x}_j},\tilde{\eta}_j)}{H^{\tilde{F}}|_{\tilde{x}_j} (\tilde{\eta}_j,\tilde{\eta}_j)}=\frac{H^{F}|_{x_j} (\rho\nabla \tilde{F}|_{\tilde{x}_j},\rho\tilde{\eta}_j)}{H^{F}|_{x_j} (\rho\tilde{\eta}_j,\rho\tilde{\eta}_j)}=\frac{H^{F}|_{x_j} (\rho^2\nabla F|_{x_j},\rho\tilde{\eta}_j)}{H^{F}|_{x_j} (\rho\tilde{\eta}_j,\rho\tilde{\eta}_j)}
\end{aligned}
\end{equation}

\begin{equation}
\begin{aligned}
\tilde{\alpha}_j=-\frac{\mathsf{Re}\langle\nabla \tilde{F}|_{\tilde{x}_j},\tilde{\eta}_j\rangle}{H^{\tilde{F}}|_{\tilde{x}_j} (\tilde{\eta}_j,\tilde{\eta}_j)}=-\frac{\mathsf{Re}\langle \rho\nabla F|_{x_j},\tilde{\eta}_j\rangle}{H^{F}|_{x_j} (\rho\tilde{\eta}_j,\rho\tilde{\eta}_j)}
\end{aligned}
\end{equation}

\begin{equation}
    \begin{aligned}
        \tilde{\eta}_{j+1} = -\nabla \tilde{F}|_{\tilde{x}_j}+\tilde{\beta}_j\tilde{\eta}_j=-\rho\nabla F|_{x_j}+\tilde{\beta}_j\tilde{\eta}_j,\quad \text{with } \tilde{\eta}_0=-\rho\nabla F|_{x_0}
    \end{aligned}
\end{equation}

\begin{equation}
    \begin{aligned}
        \tilde{x}_{j+1} = \tilde{x}_{j}+\tilde{\alpha}_j\tilde{\eta}_{j+1}
    \end{aligned}
\end{equation}

Multiplying both sides by $\rho$,

\begin{equation}
    \begin{aligned}
        x_{j+1} = x_j+\rho\tilde{\alpha}_j\tilde{\eta}_{j+1}
    \end{aligned}
\end{equation}

# Optimized version, without extra functions

In [8]:
def calc_beta(vars,grads,etas,pars,reused,d):
    """Conjugate-gradient coefficient beta = H(grad, eta) / H(eta, eta).

    Direction 1 is the (rho-scaled) gradient ``grads`` and direction 2 the
    (rho-scaled) previous search direction ``etas``. Both Hessian values are
    assembled as <gradF, D2M(.,.)> + hessianF(L DM(.), L DM(.)), and their
    ratio ``top/bottom`` is returned.

    Reads ``'prb'``, ``'psi'``, ``'fshift'`` from ``vars``/``grads``/``etas``,
    ``'ishift'``, ``'extra'``, ``'rho'`` from ``pars`` and ``'spsi'``,
    ``'Lpsi'``, ``'gradF'`` from ``reused``; ``d`` is the data passed on to
    ``hessianF``.
    NOTE(review): ``redot`` is not defined in the visible part of the
    notebook -- confirm where it comes from.
    """
    (q,psi,x) = (vars['prb'], vars['psi'], vars['fshift'])    
    (ix,ex,rho) = (pars['ishift'],pars['extra'],pars['rho'])
    (spsi,Lpsi,gradF) = (reused['spsi'], reused['Lpsi'], reused['gradF'])
    
    # note scaling with rho
    (dpsi1,dq1,dx1) = (grads['psi']*rho[0], grads['prb']*rho[1], grads['fshift']*rho[2])
    (dpsi2,dq2,dx2) = (etas['psi']*rho[0], etas['prb']*rho[1], etas['fshift']*rho[2])
        
    # frequencies (xi2 varies along the last axis, xi1 along the rows)
    xi1 = cp.fft.fftfreq(npatch).astype('float32')
    [xi2, xi1] = cp.meshgrid(xi1, xi1)    

    # multipliers in frequencies
    # shift directions get two trailing axes so they broadcast over the grid
    dx1 = dx1[:,:,cp.newaxis,cp.newaxis]
    dx2 = dx2[:,:,cp.newaxis,cp.newaxis]
    # w: phase ramp of the current shift; w1, w2: linear terms for directions
    # 1 and 2; w12, w22: quadratic terms for the pairs (1,2) and (2,2)
    w = cp.exp(-2*cp.pi*1j * (xi2*x[:, 1, None, None]+xi1*x[:, 0, None, None]))
    w1 = xi1*dx1[:,0]+xi2*dx1[:,1]
    w2 = xi1*dx2[:,0]+xi2*dx2[:,1]
    w12 = xi1**2*dx1[:,0]*dx2[:,0]+ \
                xi1*xi2*(dx1[:,0]*dx2[:,1]+dx1[:,1]*dx2[:,0])+ \
                xi2**2*dx1[:,1]*dx2[:,1]
    w22 = xi1**2*dx2[:,0]**2+ 2*xi1*xi2*(dx2[:,0]*dx2[:,1]) + xi2**2*dx2[:,1]**2
    
    # DT, D2T terms
    # object direction 1: shifted patches and their derivative along shift direction 2
    tmp1 = Ex(dpsi1,ix)     
    tmp1 = cp.fft.fft2(tmp1)
    sdpsi1 = cp.fft.ifft2(w*tmp1)[:,ex:nprb+ex,ex:nprb+ex]
    dt12 = -2*cp.pi*1j*cp.fft.ifft2(w*w2*tmp1)[:,ex:nprb+ex,ex:nprb+ex]
    
    # object direction 2: shifted patches and derivatives along shift directions 1 and 2
    tmp2 = Ex(dpsi2,ix)     
    tmp2 = cp.fft.fft2(tmp2)
    sdpsi2 = cp.fft.ifft2(w*tmp2)[:,ex:nprb+ex,ex:nprb+ex]
    dt21 = -2*cp.pi*1j*cp.fft.ifft2(w*w1*tmp2)[:,ex:nprb+ex,ex:nprb+ex]
    dt22 = -2*cp.pi*1j*cp.fft.ifft2(w*w2*tmp2)[:,ex:nprb+ex,ex:nprb+ex]
    
    # current object: first and second derivatives with respect to the shift
    tmp = Ex(psi,ix)     
    tmp = cp.fft.fft2(tmp)        
    dt1 = -2*cp.pi*1j*cp.fft.ifft2(w*w1*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    dt2 = -2*cp.pi*1j*cp.fft.ifft2(w*w2*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    d2t1 = -4*cp.pi**2*cp.fft.ifft2(w*w12*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    d2t2 = -4*cp.pi**2*cp.fft.ifft2(w*w22*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    
    # DM,D2M terms
    # d2m1: second derivative for the mixed pair (direction 1, direction 2)
    d2m1 =  q*dt12 + q*dt21 + q*d2t1
    d2m1 += dq1*sdpsi2 + dq2*sdpsi1
    d2m1 += dq1*dt2 + dq2*dt1

    # d2m2: same expression with both directions equal to direction 2,
    # which is why every term appears twice
    d2m2 =  q*dt22 + q*dt22 + q*d2t2
    d2m2 += dq2*sdpsi2 + dq2*sdpsi2
    d2m2 += dq2*dt2 + dq2*dt2

    # first derivatives along directions 1 and 2
    dm1 = dq1*spsi+q*(sdpsi1+dt1)   
    dm2 = dq2*spsi+q*(sdpsi2+dt2)   

    # top and bottom parts
    Ldm1 = Lop(dm1)
    Ldm2 = Lop(dm2) 
    top = redot(gradF,d2m1)+hessianF(Lpsi, Ldm1, Ldm2, d, pars)            
    bottom = redot(gradF,d2m2)+hessianF(Lpsi, Ldm2, Ldm2,d, pars)
    
    return top/bottom

def calc_alpha(vars,grads,etas,pars,reused,d):
    """Step length along the search direction ``etas``.

    ``top`` is minus the real inner product of ``grads`` and ``etas`` summed
    over object, probe and shifts; ``bottom`` is the Hessian evaluated on the
    rho-scaled direction. Returns ``(top/bottom, top, bottom)``.
    NOTE(review): ``redot`` is not defined in the visible part of the
    notebook -- confirm where it comes from.
    """
    q = vars['prb']
    psi = vars['psi']
    x = vars['fshift']
    ix = pars['ishift']
    ex = pars['extra']
    rho = pars['rho']
    spsi = reused['spsi']
    Lpsi = reused['Lpsi']
    gradF = reused['gradF']

    # top part
    top = -redot(grads['psi'],etas['psi'])-redot(grads['prb'],etas['prb'])-redot(grads['fshift'],etas['fshift'])

    # scale variable for the hessian; the shift direction gets two trailing
    # axes so that it broadcasts over the frequency grid
    dpsi = etas['psi']*rho[0]
    dq = etas['prb']*rho[1]
    dx = (etas['fshift']*rho[2])[:,:,cp.newaxis,cp.newaxis]

    # frequencies (xi2 varies along the last axis, xi1 along the rows)
    freq = cp.fft.fftfreq(npatch).astype('float32')
    xi2, xi1 = cp.meshgrid(freq, freq)

    # multipliers in frequencies: current shift, linear and quadratic terms
    w = cp.exp(-2*cp.pi*1j * (xi2*x[:, 1, None, None]+xi1*x[:, 0, None, None]))
    w1 = xi1*dx[:,0]+xi2*dx[:,1]
    w2 = xi1**2*dx[:,0]**2+ 2*xi1*xi2*(dx[:,0]*dx[:,1]) + xi2**2*dx[:,1]**2

    # crop from patch size to probe size
    crop = (slice(None), slice(ex, nprb+ex), slice(ex, nprb+ex))

    # DT term and shifted patches for the object direction
    eta_hat = cp.fft.fft2(Ex(dpsi,ix))
    sdpsi = cp.fft.ifft2(w*eta_hat)[crop]
    dt2 = -2*cp.pi*1j*cp.fft.ifft2(w*w1*eta_hat)[crop]

    # DT and D2T terms for the current object
    psi_hat = cp.fft.fft2(Ex(psi,ix))
    dt = -2*cp.pi*1j*cp.fft.ifft2(w*w1*psi_hat)[crop]
    d2t = -4*cp.pi**2*cp.fft.ifft2(w*w2*psi_hat)[crop]

    # DM and D2M terms
    d2m2 = q*(2*dt2 + d2t)+2*dq*sdpsi+2*dq*dt
    dm = dq*spsi+q*(sdpsi+dt)

    # bottom part
    Ldm = Lop(dm)
    bottom = redot(gradF,d2m2)+hessianF(Lpsi, Ldm, Ldm,d,pars)

    return top/bottom, top, bottom

## minimization functional and calculation of reused arrays

In [9]:
def minf(Lpsi,d,pars):
    """Value of the minimisation functional, divided by ``n*n*npos``.

    'Gaussian' model: squared norm of ``|Lpsi| - d``. Any other model:
    sum of ``|Lpsi|^2 - 2 d log(|Lpsi| + eps)``.
    """
    amp = cp.abs(Lpsi)
    scale = n*n*npos
    if pars['model']=='Gaussian':
        return cp.linalg.norm(amp-d)**2/scale
    return cp.sum(amp**2-2*d*cp.log(amp+pars['eps']))/scale

def calc_reused(vars, pars):
    """Compute the arrays shared by the gradient and Hessian evaluations.

    Returns a dict with ``'spsi'`` (shifted object patches from ``Sop``) and
    ``'Lpsi'`` (``Lop`` applied to the patches multiplied by the probe).
    """
    spsi = Sop(vars['psi'],pars['ishift'],vars['fshift'],pars['extra'])
    return {
        'spsi': spsi,
        'Lpsi': Lop(spsi*vars['prb']),
    }

## debug functions

In [10]:
def plot_debug(vars,etas,pars,top,bottom,alpha,data,i):
    """Check the minimization functional behaviour along the search direction.

    Every ``pars['vis_step']`` iterations (never if it is -1 or the global
    ``show`` is false) the functional is sampled at 14 step lengths
    ``alpha*k/6`` along ``etas`` and plotted together with the quadratic
    model ``f - top*t + 0.5*bottom*t**2`` (both scaled by ``n*n*npos``).

    Compared with the previous version the dead ``errt2 = cp.zeros(...)``
    allocation is gone and the abscissa ``t`` is computed once and reused
    for both curves.
    NOTE(review): ``plt`` is not imported in the visible part of the
    notebook -- confirm matplotlib.pyplot is in scope.
    """
    if not (i % pars['vis_step'] == 0 and pars['vis_step'] != -1 and show):
        return

    (q,psi,x) = (vars['prb'], vars['psi'], vars['fshift'])
    (ix,ex,rho) = (pars['ishift'],pars['extra'],pars['rho'])
    (dpsi2,dq2,dx2) = (etas['psi'], etas['prb'], etas['fshift'])

    npp = 7
    nsteps = 2*npp

    # functional sampled along the search direction
    errt = cp.zeros(nsteps)
    for k in range(nsteps):
        step = alpha*k/(npp-1)
        psit = psi+step*rho[0]*dpsi2
        qt = q+step*rho[1]*dq2
        xt = x+step*rho[2]*dx2
        errt[k] = minf(Lop(Sop(psit,ix,xt,ex)*qt),data,pars)

    # quadratic model predicted from top and bottom
    t = alpha*(cp.arange(nsteps))/(npp-1)
    errt2 = minf(Lop(Sop(psi,ix,x,ex)*q),data,pars)
    errt2 = errt2 -top*t/(n*n*npos)+0.5*bottom*t**2/(n*n*npos)

    steps = t.get()
    plt.plot(steps,errt.get(),'.')
    plt.plot(steps,errt2.get(),'.')
    plt.show()

def vis_debug(vars,pars,i):
    """Visualization and data saving.

    Every ``pars['vis_step']`` iterations (never if it is -1) shows object
    and probe and writes their phase/amplitude as TIFFs plus the current
    shifts as ``.npy`` under ``{path_out}_{pars['flg']}/``.

    The f-strings no longer reuse the enclosing quote character inside the
    replacement field, which is a SyntaxError before Python 3.12; the
    generated paths are unchanged.
    """
    if i % pars['vis_step'] == 0 and pars['vis_step'] != -1:
        (q,psi,x) = (vars['prb'], vars['psi'], vars['fshift'])
        mshow_polar(psi,show)
        mshow_polar(q,show)
        out = f'{path_out}_{pars["flg"]}'
        write_tiff(cp.angle(psi),f'{out}/crec_psi_angle/{i:03}')
        write_tiff(cp.abs(psi),f'{out}/crec_psi_abs/{i:03}')
        write_tiff(cp.angle(q),f'{out}/crec_prb_angle/{i:03}')
        write_tiff(cp.abs(q),f'{out}/crec_prb_abs/{i:03}')
        cp.save(f'{out}/crec_shift_{i:03}',x)
        

def error_debug(vars, pars, reused, data, i):
    """Print and log the value of the minimization functional.

    Every ``pars['err_step']`` iterations (never if it is -1) evaluates
    ``minf`` on ``reused['Lpsi']``, prints it, appends
    ``[i, err, time.time()]`` to ``vars['table']`` and rewrites the table as
    CSV to the file named ``pars['flg']``.

    The f-string no longer reuses the enclosing quote character inside the
    replacement field, which is a SyntaxError before Python 3.12.
    """
    if i % pars['err_step'] == 0 and pars['err_step'] != -1:
        err = minf(reused['Lpsi'],data,pars)
        print(f'{i}) {err=:1.5e}',flush=True)
        vars['table'].loc[len(vars['table'])] = [i, err.get(), time.time()]
        vars['table'].to_csv(f'{pars["flg"]}', index=False)

def grad_debug(alpha, grads, pars, i):
    """Print the step length and the norms of the three gradients.

    Active every ``pars['grad_step']`` iterations (never if it is -1).
    The norms are computed before formatting so that the f-string does not
    reuse its own quote character inside a replacement field (a SyntaxError
    before Python 3.12); the printed text is unchanged.
    """
    if i % pars['grad_step'] == 0 and pars['grad_step'] != -1:
        npsi = cp.linalg.norm(grads['psi'])
        nprb_grad = cp.linalg.norm(grads['prb'])
        nshift = cp.linalg.norm(grads['fshift'])
        print(f'(alpha,psi,prb,shift): {alpha:.1e} {npsi:.1e},{nprb_grad:.1e},{nshift:.1e}')

# Bilinear Hessian method

In [None]:
def BH(data, vars, pars):
    """Bilinear-Hessian reconstruction loop.

    Runs ``pars['niter']`` iterations that update ``vars['psi']``,
    ``vars['prb']`` and ``vars['fshift']`` in place and returns ``vars``.
    With ``pars['method']=='BH-GD'`` the search direction is always the
    negative gradient; otherwise it is combined with the previous direction
    through ``calc_beta`` from the second iteration on.
    """
    if pars['model']=='Gaussian':
        # work with sqrt
        data = cp.sqrt(data)

    names = ('psi', 'prb', 'fshift')
    for i in range(pars['niter']):
        reused = calc_reused(vars, pars)
        error_debug(vars, pars, reused, data, i)
        vis_debug(vars, pars, i)

        gradientF(pars,reused,data)
        grads = gradients(vars,pars,reused)

        # search direction
        if i==0 or pars['method']=='BH-GD':
            etas = {name: -grads[name] for name in names}
        else:
            beta = calc_beta(vars, grads, etas, pars, reused, data)
            for name in names:
                etas[name] = -grads[name] + beta*etas[name]

        # step length
        alpha,top,bottom = calc_alpha(vars, grads, etas, pars, reused, data)

        plot_debug(vars,etas,pars,top,bottom,alpha,data,i)
        grad_debug(alpha,grads,pars,i)

        # in-place update, each variable scaled by its own rho
        for k, name in enumerate(names):
            vars[name] += pars['rho'][k]*alpha*etas[name]

    return vars

# fixed variables
pars = {'niter':128, 'err_step': 1, 'vis_step': -1, 'grad_step': -1}
pars['rho'] = [1,0.02,0.2]  # scaling of (object, probe, shift) variables
pars['ishift'] = cp.floor(shifts_random).astype('int32')  # integer part of the positions
pars['extra'] = extra
pars['eps'] = 1e-9
pars['model'] = 'Gaussian'
pars['method'] = 'BH-CG'
# Run tag used for output paths and the error-table file name. Double quotes
# inside the replacement fields keep this valid before Python 3.12.
pars['flg'] = f'{pars["method"]}_{pars["rho"][0]}_{pars["rho"][1]}_{pars["rho"][2]}_{noise}'

# variables updated by the reconstruction
vars = {}
vars['psi'] = cp.ones([nobj,nobj],dtype='complex64')
vars['prb'] = deformed_prb.copy()#cp.ones([nprb,nprb],dtype='complex64')
# fractional part of the positions (complements pars['ishift'])
vars['fshift'] = cp.array(shifts_random-cp.floor(shifts_random).astype('int32')).astype('float32')
vars['table'] = pd.DataFrame(columns=["iter", "err", "time"])

vars = BH(data, vars, pars)      

# Show the reconstructed object and probe. The second argument is the `show`
# flag, as in every other mshow_polar call; the previous code passed the
# function `mshow`, which is always truthy and so ignored show=False.
mshow_polar(vars['psi'],show)
mshow_polar(vars['prb'],show)
erra = vars['table']['err'].values
# times=vars['table']['time'].values
# times-=times[0]
# print(times)
rec_pos = (vars['fshift']+pars['ishift'])
# NOTE(review): `plt` is not imported in the visible part of the notebook --
# confirm matplotlib.pyplot is in scope.
plt.plot(erra,label=pars['method'])
plt.yscale('log')
plt.legend()  # the label set above is only displayed once a legend is drawn
plt.show()

# positions from gen_shifts.npy (red) against recovered positions (green)
plt.plot(shifts[:,1].get(),shifts[:,0].get(),'r.')
plt.plot(rec_pos[:,1].get(),rec_pos[:,0].get(),'g.')
plt.axis('equal')
plt.show()
