In [1]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time

import matplotlib.pyplot as plt
from utils import write_tiff, read_tiff
from utils import mshow, mshow_polar, mshow_complex

In [2]:
# --- grid sizes (pixels) -----------------------------------------------------
n = 2048  # data size in each dimension
nobj = 4096+4096#n+n//4+n//8 # object size in each dimension
pad = n//16 # pad for the reconstructed probe
nprb = n+2*pad # probe size
extra = 8 # extra padding for shifts
npatch = nprb+2*extra # patch size for shifts

# --- acquisition geometry ----------------------------------------------------
npos = 16 # total number of positions (reassigned later in the data-reading cell)
z1 = -17.75e-3# [m] position of the sample (negative value)
detector_pixelsize = 3.03751e-6  # presumably in [m] -- TODO confirm
energy = 33.35  # [keV] xray energy
wavelength = 1.24e-09/energy  # [m] wave length
focusToDetectorDistance = 1.28  # [m]
# adjustments for the cone beam
z2 = focusToDetectorDistance-z1  # so that z1+z2 == focusToDetectorDistance
distance = (z1*z2)/focusToDetectorDistance  # = z1*z2/(z1+z2); used as the propagation distance in the Fresnel kernel
magnification = focusToDetectorDistance/z1  # negative because z1 is negative
voxelsize = float(cp.abs(detector_pixelsize/magnification))  # absolute value removes the sign of the magnification

show = True # do visualization or not at all

# input data directory and prefix for the output directories
path = f'/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/code2um_nfp18x18_01'
path_out = f'/data/vnikitin/ESRF/ID16A/20240924_rec2/SiemensLH/code2um_nfp18x18_01'



# Fresnel kernel

In [None]:
# Spatial-frequency grid for a probe-sized (nprb x nprb) field sampled at `voxelsize`
fx = cp.fft.fftfreq(nprb, d=voxelsize)#.astype('float32')
[fx, fy] = cp.meshgrid(fx, fx)
# Fresnel kernel in the Fourier domain: exp(-i*pi*lambda*z*|f|^2), a pure phase factor
fker = cp.exp(-1j*cp.pi*wavelength*distance*(fx**2+fy**2))
# Pass the `show` flag, not the `mshow` function object (which is always truthy).
# NOTE(review): assumes the second argument of mshow_complex is the display flag,
# as in the mshow_polar(..., show) calls elsewhere in this notebook -- confirm in utils.
mshow_complex(cp.fft.fftshift(fker),show)

In [4]:

def Lop(psi):
    """Forward propagator.

    Multiplies the 2D Fourier transform of every field in the stack `psi`
    by the global Fresnel kernel `fker`, transforms back, and removes `pad`
    pixels from each side of the last two axes.

    Args:
        psi: stack of fields; the last two axes must match the shape of
            `fker` (assumed nprb x nprb -- TODO confirm at call sites).

    Returns:
        Propagated stack cropped to [pad:nprb-pad] along the last two axes.
    """
    # convolution with the Fresnel kernel, done as a product in Fourier space
    propagated = cp.fft.ifft2(cp.fft.fft2(psi)*fker)

    # crop to detector size
    detector = slice(pad, nprb-pad)
    return propagated[:, detector, detector]

def LTop(psi):
    """Adjoint propagator.

    Zero-pads every field in the stack `psi` by `pad` pixels on each side of
    the last two axes, then divides its 2D Fourier transform by the global
    Fresnel kernel `fker` and transforms back.

    Args:
        psi: stack of detector-sized fields (3D array).

    Returns:
        Stack of back-propagated fields, enlarged by 2*pad in each of the
        last two axes.
    """
    # pad to the probe size
    padded = cp.pad(psi, ((0, 0), (pad, pad), (pad, pad)))

    # inverse of the forward convolution: division by the kernel in Fourier space
    return cp.fft.ifft2(cp.fft.fft2(padded)/fker)

def Ex(psi,ix):
    """Extract patches.

    Cuts one npatch x npatch window out of the 2D array `psi` for every row
    of `ix`. The window for position k starts at row
    nobj//2 - ix[k,0] - npatch//2 and column nobj//2 - ix[k,1] - npatch//2.

    Args:
        psi: 2D object array.
        ix: integer array of shape (npositions, 2) with (row, column) shifts.

    Returns:
        complex64 array of shape (npositions, npatch, npatch).
    """
    count = ix.shape[0]
    row0 = nobj//2-ix[:, 0]-npatch//2
    col0 = nobj//2-ix[:, 1]-npatch//2
    patches = cp.empty([count, npatch, npatch], dtype='complex64')
    for k in range(count):
        top, left = row0[k], col0[k]
        patches[k] = psi[top:top+npatch, left:left+npatch]
    return patches

def ExT(psi,psir,ix):
    """Adjoint extract patches.

    Adds every patch psir[k] into the 2D array `psi`, in place, at the
    window that Ex reads for position ix[k].

    Args:
        psi: 2D object array, modified in place.
        psir: stack of npatch x npatch patches.
        ix: integer array of shape (npositions, 2) with (row, column) shifts.

    Returns:
        The same `psi` array that was passed in.
    """
    row0 = nobj//2-ix[:, 0]-npatch//2
    col0 = nobj//2-ix[:, 1]-npatch//2
    for k in range(len(col0)):
        top, left = row0[k], col0[k]
        psi[top:top+npatch, left:left+npatch] += psir[k]
    return psi

def S(psi,p):
    """Subpixel shift.

    Shifts every image of the stack `psi` by multiplying its 2D Fourier
    transform with a linear phase ramp; p[k,0] is applied along the row
    axis and p[k,1] along the column axis of image k.

    Args:
        psi: stack of npatch x npatch images.
        p: array of shape (nimages, 2) with the shifts in pixels.

    Returns:
        Stack of shifted images (complex).
    """
    freq = cp.fft.fftfreq(npatch).astype('float32')
    # default meshgrid indexing: first output varies along columns, second along rows
    [col_freq, row_freq] = cp.meshgrid(freq, freq)
    ramp = cp.exp(-2*cp.pi*1j * (col_freq*p[:, 1, None, None]+row_freq*p[:, 0, None, None])).astype('complex64')
    return cp.fft.ifft2(ramp*cp.fft.fft2(psi))

def Sop(psi,ix,x,ex):
    """Extract patches with subpixel shift.

    Extracts one patch per position with Ex, applies the subpixel shift S,
    and removes `ex` pixels from every side of each shifted patch.

    The original version first allocated a zero array of shape
    [x.shape[1], nprb, nprb] that was immediately overwritten; that dead
    allocation is removed here, the returned value is unchanged.

    Args:
        psi: 2D object array.
        ix: integer shifts, one (row, column) pair per position.
        x: subpixel shifts, one (row, column) pair per position.
        ex: number of border pixels to crop after shifting.

    Returns:
        Stack of patches of size (npatch-2*ex) x (npatch-2*ex).
    """
    psir = Ex(psi,ix)
    psir = S(psir,x)
    return psir[:, ex:npatch-ex, ex:npatch-ex]

def STop(d,ix,x,ex):
    """Adjoint extract patches with subpixel shift.

    Zero-pads every patch of `d` by `ex` pixels per side, shifts it by -x
    with S, and accumulates the result into a fresh nobj x nobj object
    array with ExT.

    Args:
        d: stack of patches.
        ix: integer shifts, one (row, column) pair per position.
        x: subpixel shifts, one (row, column) pair per position.
        ex: number of border pixels to add before shifting.

    Returns:
        complex64 array of shape (nobj, nobj).
    """
    padded = cp.pad(d, ((0, 0), (ex, ex), (ex, ex)))
    shifted_back = S(padded, -x)
    obj = cp.zeros([nobj, nobj], dtype='complex64')
    ExT(obj, shifted_back, ix)
    return obj
# # adjoint tests
# shifts_test = 30*(cp.random.random([npos,2])-0.5).astype('float32')
# ishifts = shifts_test.astype('int32')
# fshifts = shifts_test-ishifts

# arr1 = (cp.random.random([nobj,nobj])+1j*cp.random.random([nobj,nobj])).astype('complex64')
# arr2 = Ex(arr1,ishifts)
# arr3 = arr1*0
# ExT(arr3,arr2,ishifts)
# print(f'{cp.sum(arr1*cp.conj(arr3))}==\n{cp.sum(arr2*cp.conj(arr2))}')

# arr1 = (cp.random.random([nobj,nobj])+1j*cp.random.random([nobj,nobj])).astype('complex64')
# arr2 = Sop(arr1,ishifts,fshifts,extra)
# arr3 = STop(arr2,ishifts,fshifts,extra)
# print(f'{cp.sum(arr1*cp.conj(arr3))}==\n{cp.sum(arr2*cp.conj(arr2))}')

# arr1 = (cp.random.random([npos,nprb,nprb])+1j*cp.random.random([npos,nprb,nprb])).astype('complex64')
# arr2 = Lop(arr1)
# arr3 = LTop(arr2)
# print(f'{cp.sum(arr1*cp.conj(arr3))}==\n{cp.sum(arr2*cp.conj(arr2))}')
# arr1=arr2=arr3=[]

# read data

In [None]:
import h5py
# Number of scan positions to read; overrides the value from the configuration cell.
npos = 18*18
pos_step = 1 # steps in positions
# measured frames (first npos only), converted to float32
with h5py.File(f'{path}/code2um_nfp18x18_010000.h5') as fid:
    data = fid['/entry_0000/measurement/data'][:npos].astype('float32')
    
# reference and dark frames (all frames in each file)
with h5py.File(f'{path}/ref_0000.h5') as fid:
    ref = fid['/entry_0000/measurement/data'][:].astype('float32')
with h5py.File(f'{path}/dark_0000.h5') as fid:
    dark = fid['/entry_0000/measurement/data'][:].astype('float32')

# Positions from the text file: columns are swapped, values are rescaled by
# 1e-6/voxelsize (presumably micrometers -> pixels, TODO confirm), and the second column is negated.
shifts = np.loadtxt(f'/data/vnikitin/ESRF/ID16A/20240924/positions/shifts_code_nfp18x18ordered.txt')[:,::-1]
shifts = shifts/voxelsize*(2048//n)*1e-6
shifts[:,1]*=-1

print(shifts[-10:])
# NOTE: the positions computed above are only printed; from here on they are
# replaced by the contents of 'shifts_new.npy' (relative path, must exist in the working directory).
shifts = np.load(f'shifts_new.npy')
print(shifts[-10:])
#centering
# subtract the midpoint between the extreme values along each axis
shifts[:,1]-=(np.amax(shifts[:,1])+np.amin(shifts[:,1]))/2
shifts[:,0]-=(np.amax(shifts[:,0])+np.amin(shifts[:,0]))/2
# view positions and frames as a square sqrt(npos) x sqrt(npos) grid and keep every pos_step-th one in both directions
shifts = shifts.reshape(int(np.sqrt(npos)),int(np.sqrt(npos)),2)
shifts = shifts[::pos_step,::pos_step,:].reshape(npos//pos_step**2,2)
data = data.reshape(int(np.sqrt(npos)),int(np.sqrt(npos)),n,n)
data = data[::pos_step,::pos_step,:].reshape(npos//pos_step**2,n,n)

# keep only positions whose shift magnitude is below nobj//2-n//2-pad-extra in both coordinates
ids = np.where((np.abs(shifts[:,0])<nobj//2-n//2-pad-extra)*(np.abs(shifts[:,1])<nobj//2-n//2-pad-extra))[0]#[0:2]
data = data[ids]
shifts = shifts[ids]

plt.plot(shifts[:,0],shifts[:,1],'.')
plt.axis('square')
plt.show()

# from here on npos is the number of positions that survived the filter
npos = len(ids)
print(f'{npos=}')

# positions are processed in groups of `chunk`; nchunk is the number of such groups
chunk = 16
nchunk = int(np.ceil(npos/chunk))

#data=cp.array(data)
#dark=cp.array(dark)
#ref=cp.array(ref)
#shifts=cp.array(shifts)

In [None]:
import cupyx.scipy.ndimage as ndimage
def remove_outliers(data, dezinger, dezinger_threshold):
    """Replace outlier pixels by the local median, frame by frame.

    Each frame is copied to the GPU and median-filtered with a
    dezinger x dezinger window. A pixel is treated as an outlier when its
    absolute difference from the median exceeds
    median*dezinger_threshold; such pixels take the median value. The
    number of replaced pixels is printed for every frame.

    The outlier mask is computed once and reused for both the count and the
    replacement (it was evaluated twice before), and CuPy functions are
    called explicitly instead of relying on NumPy dispatching to CuPy.

    Args:
        data: host array of frames, shape (nframes, height, width).
        dezinger: side length of the median-filter window.
        dezinger_threshold: relative threshold for flagging a pixel.

    Returns:
        A copy of `data` with the flagged pixels replaced; the input is not modified.
    """
    res = data.copy()
    w = [dezinger,dezinger]
    for k in range(data.shape[0]):
        frame = cp.array(data[k])
        filtered = ndimage.median_filter(frame, w)
        outliers = cp.abs(frame-filtered)>filtered*dezinger_threshold
        print(cp.sum(outliers))
        # only the corrected frame is copied back to the host
        res[k] = cp.where(outliers, filtered, frame).get()
    return res

# average the dark and reference stacks and subtract the dark field
dark = np.mean(dark,axis=0)
ref = np.mean(ref,axis=0)
data -= dark
ref -= dark

# clip negative values left by the dark subtraction
data[data<0]=0
ref[ref<0]=0
# overwrite one small detector region with the content of a neighbouring region
# NOTE(review): hard-coded coordinates divided by 3 -- presumably a known defective
# area defined for a different binning; confirm for the current n.
data[:,1320//3:1320//3+25//3,890//3:890//3+25//3] = data[:,1280//3:1280//3+25//3,890//3:890//3+25//3]
ref[1320//3:1320//3+25//3,890//3:890//3+25//3] = ref[1280//3:1280//3+25//3,890//3:890//3+25//3]

data = remove_outliers(data, 3, 0.8)    
ref = remove_outliers(ref[None], 3, 0.8)[0]

# normalize by the mean of the reference; data must be divided first,
# because the second line changes np.mean(ref)
data /= np.mean(ref)
ref /= np.mean(ref)

data[np.isnan(data)] = 1
ref[np.isnan(ref)] = 1

# Pass the `show` flag rather than the `mshow` function object (always truthy).
# NOTE(review): assumes the second argument of mshow is the display flag,
# as in the mshow_polar(..., show) calls elsewhere -- confirm in utils.
mshow(data[0],show)
mshow(ref,show)

# Paganin reconstruction

In [None]:
def Paganin(data, wavelength, voxelsize, delta_beta,  alpha):
    """Single-distance Paganin-type phase retrieval of one image.

    Filters `data` in Fourier space with
    T/(T**2+alpha), where T = 1 + wavelength*distance*pi*delta_beta*|f|^2,
    and returns delta_beta/2 times the logarithm of the filtered image.
    The propagation distance is taken from the global `distance`.

    Args:
        data: 2D image (square; the frequency grid is built from its last axis).
        wavelength: wavelength used in the filter.
        voxelsize: sample spacing used for the frequency grid.
        delta_beta: delta/beta ratio.
        alpha: regularization added to the denominator.

    Returns:
        2D array with the retrieved phase.
    """
    freq = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    [fx, fy] = cp.meshgrid(freq, freq)
    taylorExp = 1 + wavelength * distance * cp.pi * (delta_beta) * (fx**2+fy**2)
    spectrum = cp.fft.fft2(data)
    filtered = cp.fft.ifft2(taylorExp * (spectrum) / (taylorExp**2 + alpha))
    return delta_beta * 0.5 * cp.log(cp.real(filtered))

def rec_init(rdata,ishifts):
    """Build the initial object guess from per-position Paganin reconstructions.

    Every frame rdata[j] is phase-retrieved with Paganin (delta/beta 24.05,
    regularization 1e-1) and added into an nobj x nobj array at the window
    given by ishifts[j]. Overlaps are averaged with a per-pixel counter of
    contributing frames; pixels with no contribution keep phase 0. The
    result is exp(1j*phase).

    Frames are accumulated directly into the target windows. The previous
    version allocated two full nobj x nobj temporaries per position only to
    add one n x n frame.

    Args:
        rdata: host array of frames, shape (npos, n, n).
        ishifts: integer (row, column) shifts, shape (npos, 2).

    Returns:
        Complex array of shape (nobj, nobj) with unit amplitude wherever the phase is finite.
    """
    recMultiPaganin = cp.zeros([nobj,nobj],dtype='float32')
    recMultiPaganinr = cp.zeros([nobj,nobj],dtype='float32')# to compensate for overlap
    for j in range(0,npos):
        r = Paganin(cp.array(rdata[j]), wavelength, voxelsize,  24.05, 1e-1)
        stx = nobj//2-ishifts[j,1]-n//2
        sty = nobj//2-ishifts[j,0]-n//2
        recMultiPaganin[sty:sty+n,stx:stx+n] += r
        recMultiPaganinr[sty:sty+n,stx:stx+n] += 1

    # avoid division by zero where no frame contributed
    recMultiPaganinr[cp.abs(recMultiPaganinr)<5e-2] = 1
    recMultiPaganin /= recMultiPaganinr
    return cp.exp(1j*recMultiPaganin)

# integer pixel positions (rounded) used to place the Paganin reconstructions
ishifts = np.round(np.array(shifts)).astype('int32')
# flat-field corrected frames; 1e-5 guards against division by zero
rdata = np.array(data/(ref+1e-5))
rec_paganin = rec_init(rdata,ishifts)
mshow_polar(rec_paganin,show)
mshow_polar(rec_paganin[:1000,:1000],show)

# smooth borders
# Gaussian window exp(-1000*(vx^2+vy^2)) on a centered grid with coordinates in [-0.5, 0.5)
v = cp.arange(-nobj//2, nobj//2)/nobj
[vx, vy] = cp.meshgrid(v, v)
v = cp.exp(-1000*(vx**2+vy**2)).astype('float32')

# multiply the centered Fourier transform of the guess by the window, i.e. low-pass filter it
rec_paganin = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(rec_paganin)))
rec_paganin = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(rec_paganin*v))).astype('complex64')
mshow_polar(rec_paganin,show)
mshow_polar(rec_paganin[:1000,:1000],show)

# drop the references to the large temporaries
rdata=v=[]


In [None]:
def gradientF0(pars, reused, d, st, end):
    """Gradient of the data-fidelity term for positions st..end-1.

    Reads reused['Lpsi'][st:end] and d[st:end], evaluates the gradient for
    the model named by pars['model'] ('Gaussian' or 'Poisson'), applies the
    adjoint propagator LTop, and writes the result (copied to the host)
    into reused['gradF'][st:end].

    Any other model name leaves the result undefined and raises a NameError
    on the final assignment.
    """
    rows = slice(st, end)
    d = cp.array(d[rows])
    Lpsi = cp.array(reused['Lpsi'][rows])

    model = pars['model']
    if model == 'Gaussian':
        # amplitude model: data times the unit-modulus part of Lpsi (regularized by eps)
        res = 2*LTop(Lpsi - d*(Lpsi/(cp.abs(Lpsi)+pars['eps'])))
    elif model == 'Poisson':
        # intensity model: data times Lpsi over its regularized squared modulus
        res = 2*LTop(Lpsi-d*Lpsi/(cp.abs(Lpsi)**2+pars['eps']**2))
    reused['gradF'][rows] = res.get()

def gradientF(pars, reused, d):
    """Fill reused['gradF'] for all positions, one chunk at a time.

    Allocates reused['gradF'] as a host complex64 array of shape
    (npos, nprb, nprb) and lets gradientF0 fill it chunk by chunk.
    """
    reused['gradF'] = np.zeros([npos,nprb,nprb],dtype='complex64')
    for st in (k*chunk for k in range(nchunk)):
        gradientF0(pars, reused, d, st, min(st+chunk,npos))


In [None]:
def gradient_psi0(psi, q,ix,x,ex,gradF,st,end):
    """Accumulate the object gradient of positions st..end-1 into `psi`.

    Multiplies gradF[st:end] by conj(q), zero-pads by `ex` per side, shifts
    by -x with S and adds the patches into `psi` in place with ExT.
    """
    rows = slice(st, end)
    fshift = cp.array(x[rows])
    ishift = cp.array(ix[rows])
    weighted = cp.conj(q)*cp.array(gradF[rows])
    padded = cp.pad(weighted, ((0, 0), (ex, ex), (ex, ex)))
    ExT(psi, S(padded, -fshift), ishift)

def gradient_psi(q,ix,x,ex,gradF):
    """Object gradient summed over all positions.

    Starts from a zero complex64 array of shape (nobj, nobj) and lets
    gradient_psi0 add the contribution of every chunk of positions.
    """
    acc = cp.zeros([nobj, nobj], dtype='complex64')
    for st in (k*chunk for k in range(nchunk)):
        gradient_psi0(acc, q, ix, x, ex, gradF, st, min(st+chunk,npos))
    return acc

def gradient_prb0(spsi,gradF,st,end):
    """Probe-gradient contribution of positions st..end-1.

    Returns the sum over those positions of conj(spsi[k])*gradF[k].
    """
    rows = slice(st, end)
    patches = cp.array(spsi[rows])
    residual = cp.array(gradF[rows])
    return cp.sum(cp.conj(patches)*residual,axis=0)

def gradient_prb(spsi,gradF):
    """Probe gradient summed over all positions.

    Accumulates the per-chunk results of gradient_prb0 into a complex64
    array of shape (nprb, nprb).
    """
    total = cp.zeros([nprb,nprb],dtype='complex64')
    for st in (k*chunk for k in range(nchunk)):
        total += gradient_prb0(spsi, gradF, st, min(st+chunk,npos))
    return total

def gradient_shift0(psi, q, ix, x, ex, gradF, st, end):
    """Gradient with respect to the subpixel shifts for positions st..end-1.

    For each of the two shift coordinates, the extracted patches are
    multiplied in Fourier space by the shift phase ramp and by the
    corresponding frequency, transformed back, cropped to nprb x nprb,
    scaled by -2*pi, multiplied by the probe `q`, and reduced against gradF
    with `imdot` over the image axes.

    NOTE(review): `imdot` is not defined in the visible part of this
    notebook; it is presumably provided elsewhere -- confirm.

    Returns:
        Host float32 array of shape (end-st, 2); column 0 pairs with
        x[:,0] and column 1 with x[:,1].
    """
    rows = slice(st, end)
    ix = cp.array(ix[rows])
    x = cp.array(x[rows])
    gradF = cp.array(gradF[rows])

    # frequencies
    freq = cp.fft.fftfreq(npatch).astype('float32')
    xi2, xi1 = cp.meshgrid(freq, freq)

    # multipliers in frequencies
    w = cp.exp(-2 * cp.pi * 1j * (xi2 * x[:, 1, None, None] + xi1 * x[:, 0, None, None]))

    # Fourier transform of the extracted patches, shared by both components
    patches_hat = cp.fft.fft2(Ex(psi, ix))

    # inner product with gradF, one column per shift coordinate
    crop = slice(ex, nprb+ex)
    gradx = cp.zeros([gradF.shape[0], 2], dtype='float32')
    for col, xi in enumerate((xi1, xi2)):
        deriv = cp.fft.ifft2(w*xi*patches_hat)
        deriv = -2 * cp.pi * deriv[:, crop, crop]
        gradx[:, col] = imdot(gradF, q * deriv, axis=(1, 2))
    return gradx.get()

def gradient_shift(psi, q, ix, x, ex, gradF):
    """Shift gradient for all positions.

    Collects the per-chunk results of gradient_shift0 into a host float32
    array of shape (npos, 2).
    """
    out = np.zeros([npos, 2], dtype='float32')
    for st in (k*chunk for k in range(nchunk)):
        end = min(st+chunk,npos)
        out[st:end] = gradient_shift0(psi, q, ix, x, ex, gradF, st, end)
    return out


def gradients(vars,pars,reused):
    """Gradients with respect to object, probe and shifts, scaled by rho.

    Args:
        vars: dict with 'prb', 'psi' and 'fshift'.
        pars: dict with 'ishift', 'extra' and the three scaling factors 'rho'.
        reused: dict with the precomputed 'gradF' and 'spsi'.

    Returns:
        dict with keys 'psi', 'prb', 'fshift'; each gradient is multiplied
        by rho[0], rho[1], rho[2] respectively.
    """
    q, psi, x = vars['prb'], vars['psi'], vars['fshift']
    ix, ex, rho = pars['ishift'], pars['extra'], pars['rho']
    gradF, spsi = reused['gradF'], reused['spsi']

    dpsi = gradient_psi(q,ix,x,ex,gradF)
    dprb = gradient_prb(spsi,gradF)
    dx = gradient_shift(psi,q,ix,x,ex,gradF)
    return {'psi': rho[0]*dpsi, 'prb': rho[1]*dprb, 'fshift': rho[2]*dx}

In [None]:
def hessianF(Lm,Ldm1,Ldm2,data,pars):
    """Bilinear Hessian of the data-fidelity term, summed over all pixels.

    Args:
        Lm: propagated field at the current point.
        Ldm1, Ldm2: propagated directions.
        data: measured data (used as given by the caller).
        pars: dict with 'model' and 'eps'; 'Gaussian' selects the first
            formula, any other value the second one.

    Returns:
        2*(v1+v2), a scalar.

    NOTE(review): `reprod` is not defined in the visible part of this
    notebook; it is presumably provided elsewhere -- confirm.
    """
    eps = pars['eps']
    # unit-modulus (regularized) version of Lm, shared by both models
    psi0p = Lm/(cp.abs(Lm)+eps)
    if pars['model']=='Gaussian':
        d0 = data/(cp.abs(Lm)+eps)
        v1 = cp.sum((1-d0)*reprod(Ldm1,Ldm2))
        v2 = cp.sum(d0*reprod(psi0p,Ldm1)*reprod(psi0p,Ldm2))
    else:
        v1 = cp.sum((1-data/(cp.abs(Lm)**2+eps**2))*reprod(Ldm1,Ldm2))
        v2 = 2*cp.sum(data*reprod(psi0p,Ldm1)*reprod(psi0p,Ldm2)/(cp.abs(Lm)**2+eps**2))
    return 2*(v1+v2)

# Optimized version, without extra functions

In [None]:
def calc_beta0(vars,grads,etas,pars,reused,d,st,end):
    """Contribution of positions st..end-1 to the numerator and denominator of beta.

    Direction 1 is built from `grads`, direction 2 from `etas`; both are
    rescaled by pars['rho'] inside this function. The returned values are
    summed over all chunks by calc_beta, which then divides top by bottom.

    Returns:
        (top, bottom) for this chunk:
        top    = redot(gradF, d2m1) + hessianF(Lpsi, L(dm1), L(dm2))
        bottom = redot(gradF, d2m2) + hessianF(Lpsi, L(dm2), L(dm2))

    NOTE(review): `redot` is not defined in the visible part of this
    notebook; it is presumably provided elsewhere -- confirm.
    """
    (q,psi,x) = (vars['prb'], vars['psi'], vars['fshift'][st:end])    
    (ix,ex,rho) = (pars['ishift'][st:end],pars['extra'],pars['rho'])
    (spsi,Lpsi,gradF) = (reused['spsi'][st:end], reused['Lpsi'][st:end], reused['gradF'][st:end])
    # move the per-chunk host slices to the GPU
    x = cp.array(x)
    ix = cp.array(ix)
    spsi = cp.array(spsi)
    Lpsi = cp.array(Lpsi)
    gradF = cp.array(gradF)
    d = cp.array(d[st:end])

    # note scaling with rho
    (dpsi1,dq1,dx1) = (grads['psi']*rho[0], grads['prb']*rho[1], grads['fshift'][st:end]*rho[2])
    (dpsi2,dq2,dx2) = (etas['psi']*rho[0], etas['prb']*rho[1], etas['fshift'][st:end]*rho[2])
    dx1 = cp.array(dx1)
    dx2 = cp.array(dx2)
        
    # frequencies
    xi1 = cp.fft.fftfreq(npatch).astype('float32')
    [xi2, xi1] = cp.meshgrid(xi1, xi1)    

    # multipliers in frequencies
    # two trailing axes are appended so that dx*[:,k] broadcasts against the frequency grids
    dx1 = dx1[:,:,cp.newaxis,cp.newaxis]
    dx2 = dx2[:,:,cp.newaxis,cp.newaxis]
    # w: phase ramp of the current shift; w1, w2: first-order multipliers for directions 1 and 2;
    # w12, w22: second-order multipliers (mixed 1-2 and pure 2-2)
    w = cp.exp(-2*cp.pi*1j * (xi2*x[:, 1, None, None]+xi1*x[:, 0, None, None]))
    w1 = xi1*dx1[:,0]+xi2*dx1[:,1]
    w2 = xi1*dx2[:,0]+xi2*dx2[:,1]
    w12 = xi1**2*dx1[:,0]*dx2[:,0]+ \
                xi1*xi2*(dx1[:,0]*dx2[:,1]+dx1[:,1]*dx2[:,0])+ \
                xi2**2*dx1[:,1]*dx2[:,1]
    w22 = xi1**2*dx2[:,0]**2+ 2*xi1*xi2*(dx2[:,0]*dx2[:,1]) + xi2**2*dx2[:,1]**2
    
    # DT, D2T terms
    # direction-1 object update: shifted patches and their derivative along shift direction 2
    tmp1 = Ex(dpsi1,ix)     
    tmp1 = cp.fft.fft2(tmp1)
    sdpsi1 = cp.fft.ifft2(w*tmp1)[:,ex:nprb+ex,ex:nprb+ex]
    dt12 = -2*cp.pi*1j*cp.fft.ifft2(w*w2*tmp1)[:,ex:nprb+ex,ex:nprb+ex]
    
    # direction-2 object update: shifted patches and derivatives along shift directions 1 and 2
    tmp2 = Ex(dpsi2,ix)     
    tmp2 = cp.fft.fft2(tmp2)
    sdpsi2 = cp.fft.ifft2(w*tmp2)[:,ex:nprb+ex,ex:nprb+ex]
    dt21 = -2*cp.pi*1j*cp.fft.ifft2(w*w1*tmp2)[:,ex:nprb+ex,ex:nprb+ex]
    dt22 = -2*cp.pi*1j*cp.fft.ifft2(w*w2*tmp2)[:,ex:nprb+ex,ex:nprb+ex]
    
    # current object: first and second derivatives with respect to the shifts
    tmp = Ex(psi,ix)     
    tmp = cp.fft.fft2(tmp)        
    dt1 = -2*cp.pi*1j*cp.fft.ifft2(w*w1*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    dt2 = -2*cp.pi*1j*cp.fft.ifft2(w*w2*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    d2t1 = -4*cp.pi**2*cp.fft.ifft2(w*w12*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    d2t2 = -4*cp.pi**2*cp.fft.ifft2(w*w22*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    
    # DM,D2M terms
    # mixed second-order term (directions 1 and 2)
    d2m1 =  q*dt12 + q*dt21 + q*d2t1
    d2m1 += dq1*sdpsi2 + dq2*sdpsi1
    d2m1 += dq1*dt2 + dq2*dt1

    # pure second-order term (direction 2 twice), written in the same symmetric form
    d2m2 =  q*dt22 + q*dt22 + q*d2t2
    d2m2 += dq2*sdpsi2 + dq2*sdpsi2
    d2m2 += dq2*dt2 + dq2*dt2

    # first-order terms for the two directions
    dm1 = dq1*spsi+q*(sdpsi1+dt1)   
    dm2 = dq2*spsi+q*(sdpsi2+dt2)   

    # top and bottom parts
    Ldm1 = Lop(dm1)
    Ldm2 = Lop(dm2) 
    top = redot(gradF,d2m1)+hessianF(Lpsi, Ldm1, Ldm2, d, pars)            
    bottom = redot(gradF,d2m2)+hessianF(Lpsi, Ldm2, Ldm2,d, pars)
    
    return top, bottom

def calc_beta(vars,grads,etas,pars,reused,d):
    """Coefficient beta for updating the search direction.

    Sums the per-chunk (top, bottom) pairs returned by calc_beta0 over all
    positions and returns their ratio as a Python float.
    """
    top_sum = 0
    bottom_sum = 0
    for st in (k*chunk for k in range(nchunk)):
        part = calc_beta0(vars,grads,etas,pars,reused,d,st,min(st+chunk,npos))
        top_sum += part[0]
        bottom_sum += part[1]
    return float(top_sum/bottom_sum)

def calc_alpha0(vars,grads,etas,pars,reused,d,st,end):    
    """Contribution of positions st..end-1 to the numerator and denominator of the step length.

    top is minus the real inner product of the gradient with the direction:
    the shift part is added for every chunk, while the object and probe
    parts (not split by position) are added only by the first chunk (st == 0).
    bottom is the second-order term along the direction `etas`, which is
    rescaled by pars['rho'] inside this function.

    Returns:
        (top/bottom, top, bottom) for this chunk; calc_alpha uses only the
        last two values.

    NOTE(review): `redot` is not defined in the visible part of this
    notebook; it is presumably provided elsewhere -- confirm.
    """
    (q,psi,x) = (vars['prb'], vars['psi'], vars['fshift'][st:end])    
    (ix,ex,rho) = (pars['ishift'][st:end],pars['extra'],pars['rho'])
    (dpsi1,dq1,dx1) = (grads['psi'], grads['prb'], grads['fshift'][st:end])
    (dpsi2,dq2,dx2) = (etas['psi'], etas['prb'], etas['fshift'][st:end])    
    (spsi,Lpsi,gradF) = (reused['spsi'][st:end],reused['Lpsi'][st:end], reused['gradF'][st:end])
    # move the per-chunk host slices to the GPU
    x = cp.array(x)
    ix = cp.array(ix)
    spsi = cp.array(spsi)
    Lpsi = cp.array(Lpsi)
    gradF = cp.array(gradF)
    dx1 = cp.array(dx1)
    dx2 = cp.array(dx2)
    d = cp.array(d[st:end])

    top=-redot(dx1,dx2)
    # top part
    # object and probe terms are global, so count them once (in the first chunk only)
    if st==0:
        top += -redot(dpsi1,dpsi2)-redot(dq1,dq2)
        
    # scale variable for the hessian
    (dpsi,dq,dx) = (etas['psi']*rho[0], etas['prb']*rho[1], etas['fshift'][st:end]*rho[2])
    dx = cp.array(dx)

    # frequencies        
    xi1 = cp.fft.fftfreq(npatch).astype('float32')    
    [xi2, xi1] = cp.meshgrid(xi1, xi1)

    # multipliers in frequencies
    # two trailing axes are appended so that dx[:,k] broadcasts against the frequency grids
    dx = dx[:,:,cp.newaxis,cp.newaxis]
    # w: phase ramp of the current shift; w1: first-order multiplier; w2: second-order multiplier
    w = cp.exp(-2*cp.pi*1j * (xi2*x[:, 1, None, None]+xi1*x[:, 0, None, None]))
    w1 = xi1*dx[:,0]+xi2*dx[:,1]
    w2 = xi1**2*dx[:,0]**2+ 2*xi1*xi2*(dx[:,0]*dx[:,1]) + xi2**2*dx[:,1]**2
    
    # DT,D2T terms, and Spsi
    # direction for the object: shifted patches and their derivative along the shift direction
    tmp = Ex(dpsi,ix)     
    tmp = cp.fft.fft2(tmp)    
    sdpsi = cp.fft.ifft2(w*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    dt2 = -2*cp.pi*1j*cp.fft.ifft2(w*w1*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    
    # current object: first and second derivatives along the shift direction
    tmp = Ex(psi,ix)     
    tmp = cp.fft.fft2(tmp)
    dt = -2*cp.pi*1j*cp.fft.ifft2(w*w1*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    d2t = -4*cp.pi**2*cp.fft.ifft2(w*w2*tmp)[:,ex:nprb+ex,ex:nprb+ex]
    
    # DM and D2M terms
    d2m2 = q*(2*dt2 + d2t)+2*dq*sdpsi+2*dq*dt
    dm = dq*spsi+q*(sdpsi+dt)   
            
    # bottom part
    Ldm = Lop(dm)
    bottom = redot(gradF,d2m2)+hessianF(Lpsi, Ldm, Ldm,d,pars)
    
    return top/bottom, top, bottom


def calc_alpha(vars,grads,etas,pars,reused,d):
    """Step length along the direction `etas`.

    Sums the per-chunk numerators and denominators returned by calc_alpha0
    (its per-chunk ratio is ignored).

    Returns:
        (top/bottom, top, bottom) as Python floats.
    """
    top_sum = 0
    bottom_sum = 0
    for st in (k*chunk for k in range(nchunk)):
        _, top0, bottom0 = calc_alpha0(vars,grads,etas,pars,reused,d,st,min(st+chunk,npos))
        top_sum += top0
        bottom_sum += bottom0

    return float(top_sum/bottom_sum), float(top_sum), float(bottom_sum)

## minimization functional and calculation of reused arrays

In [None]:
def minf0(Lpsi,d,pars,st,end):
    """Contribution of positions st..end-1 to the minimization functional.

    'Gaussian' model: squared norm of |Lpsi|-d.
    Any other model: sum of |Lpsi|^2 - 2*d*log(|Lpsi|+eps).
    Both are divided by n*n*npos.
    """
    rows = slice(st, end)
    amplitude = cp.abs(cp.array(Lpsi[rows]))
    d = cp.array(d[rows])
    norm_factor = n*n*npos
    if pars['model']=='Gaussian':
        return cp.linalg.norm(amplitude-d)**2/norm_factor
    return cp.sum(amplitude**2-2*d*cp.log(amplitude+pars['eps']))/norm_factor

def minf(Lpsi,d,pars):
    """Minimization functional: sum of minf0 over all chunks, as a Python float."""
    total = 0
    for st in (k*chunk for k in range(nchunk)):
        total += minf0(Lpsi, d, pars, st, min(st+chunk,npos))
    return float(total)

def calc_reused0(reused, vars, pars, st, end):
    """Fill the reusable arrays for positions st..end-1.

    Computes the shifted patches spsi = Sop(psi, ...) and the propagated
    exit wave Lop(spsi*prb), and stores host copies in
    reused['spsi'][st:end] and reused['Lpsi'][st:end].

    Returns:
        The same `reused` dict.
    """
    rows = slice(st, end)
    fshift = cp.array(vars['fshift'][rows])
    ishift = cp.array(pars['ishift'][rows])

    spsi = Sop(vars['psi'], ishift, fshift, pars['extra'])
    reused['Lpsi'][rows] = Lop(spsi*vars['prb']).get()
    reused['spsi'][rows] = spsi.get()
    return reused

def calc_reused(vars, pars):
    """Compute the arrays reused across one iteration.

    Returns:
        dict with host complex64 arrays 'spsi' of shape (npos, nprb, nprb)
        and 'Lpsi' of shape (npos, n, n), filled chunk by chunk by
        calc_reused0.
    """
    reused = {
        'spsi': np.zeros([npos,nprb,nprb],dtype='complex64'),
        'Lpsi': np.zeros([npos,n,n],dtype='complex64'),
    }
    for st in (k*chunk for k in range(nchunk)):
        calc_reused0(reused, vars, pars, st, min(st+chunk,npos))
    return reused




## debug functions

In [None]:
def calc_Lpsi0(Lpsi,q,psi,ix,x,ex,st,end):
    """Write the propagated exit waves of positions st..end-1 into Lpsi[st:end].

    The exit wave is Lop(Sop(psi, ...)*q); the result is copied to the host.
    """
    rows = slice(st, end)
    fshift = cp.array(x[rows])
    ishift = cp.array(ix[rows])
    Lpsi[rows] = Lop(Sop(psi, ishift, fshift, ex)*q).get()

def calc_Lpsi(q,psi,ix,x,ex):
    """Propagated exit waves for all positions.

    Returns:
        Host complex64 array of shape (npos, n, n), filled chunk by chunk
        by calc_Lpsi0.
    """
    out = np.zeros([npos,n,n],dtype='complex64')
    for st in (k*chunk for k in range(nchunk)):
        calc_Lpsi0(out, q, psi, ix, x, ex, st, min(st+chunk,npos))
    return out
    

def plot_debug(vars,etas,pars,top,bottom,alpha,data,i):
    '''Compare the functional along the search direction with its quadratic model.

    Active only when i is a multiple of pars['vis_step'], vis_step is not -1
    and the global `show` is true. The functional is evaluated at 2*npp
    step sizes (npp = 3), multiples of alpha/(npp-1), and plotted together
    with f(0) - top*t/(n*n*npos) + 0.5*bottom*t**2/(n*n*npos).
    '''
    if not (i % pars['vis_step'] == 0 and pars['vis_step'] != -1 and show):
        return

    q, psi, x = vars['prb'], vars['psi'], vars['fshift']
    ix, ex, rho = pars['ishift'], pars['extra'], pars['rho']
    dpsi2, dq2, dx2 = etas['psi'], etas['prb'], etas['fshift']

    npp = 3
    nsteps = npp*2

    # functional evaluated at every trial step
    errt = np.zeros(nsteps)
    for k in range(nsteps):
        psit = psi+(alpha*k/(npp-1))*rho[0]*dpsi2
        qt = q+(alpha*k/(npp-1))*rho[1]*dq2
        xt = x+(alpha*k/(npp-1))*rho[2]*dx2
        errt[k] = minf(calc_Lpsi(qt,psit,ix,xt,ex),data,pars)

    # quadratic model built from top and bottom at the same step sizes
    t = alpha*(np.arange(nsteps))/(npp-1)
    errt2 = minf(calc_Lpsi(q,psi,ix,x,ex),data,pars)
    errt2 = errt2 -top*t/(n*n*npos)+0.5*bottom*t**2/(n*n*npos)

    plt.plot(t,errt,'.')
    plt.plot(t,errt2,'.')
    plt.show()

def vis_debug(vars,pars,i):
    '''Visualization and data saving.

    Active only when i is a multiple of pars['vis_step'] and vis_step is
    not -1. Shows the object and the probe, writes phase and amplitude of
    both as tiff files under f'{path_out}_{flg}/', saves the current
    subpixel shifts, and plots their difference from the global
    `fshifts_init`.

    The output prefix is built once and reused. The previous f-strings
    nested the same quote character inside the replacement field
    (pars['flg'] inside a single-quoted f-string), which is a SyntaxError
    before Python 3.12.
    '''
    if i % pars['vis_step'] == 0 and pars['vis_step'] != -1:
        (q,psi,x) = (vars['prb'], vars['psi'], vars['fshift'])
        flg = pars['flg']
        out = f'{path_out}_{flg}'

        mshow_polar(psi,show)

        mshow_polar(q,show)
        # zoom into a fixed region of the probe
        mshow_polar(q[nprb//2-nprb//8:nprb//2+nprb//8,nprb//2+nprb//4:nprb//2+nprb//2],show)
        write_tiff(cp.angle(psi),f'{out}/crec_psi_angle/{i:03}')
        write_tiff(cp.abs(psi),f'{out}/crec_psi_abs/{i:03}')
        write_tiff(cp.angle(q),f'{out}/crec_prb_angle/{i:03}')
        write_tiff(cp.abs(q),f'{out}/crec_prb_abs/{i:03}')
        cp.save(f'{out}/crec_shift_{i:03}',x)
        # drift of the subpixel shifts relative to their initial values
        plt.plot(x[:,0]-fshifts_init[:,0],'.',label='y')
        plt.plot(x[:,1]-fshifts_init[:,1],'.',label='x')
        plt.legend()
        plt.grid()

        plt.show()
        

def error_debug(vars, pars, reused, data, i):
    '''Error logging.

    Active only when i is a multiple of pars['err_step'] and err_step is
    not -1. Evaluates the functional from reused['Lpsi'], prints it,
    appends a row [iteration, error, current time] to vars['table'], and
    rewrites the table as CSV to the file named pars['flg'].

    The file name is read into a local first. The previous f-string nested
    the same quote character inside the replacement field, which is a
    SyntaxError before Python 3.12.
    '''
    if i % pars['err_step'] == 0 and pars['err_step'] != -1:
        err = minf(reused['Lpsi'],data,pars)
        print(f'{i}) {err=:1.5e}',flush=True)
        vars['table'].loc[len(vars['table'])] = [i, err, time.time()]
        flg = pars['flg']
        vars['table'].to_csv(f'{flg}', index=False)

def grad_debug(alpha, grads, pars, i):
    """Print the step length and the norms of the three gradients.

    Active only when i is a multiple of pars['grad_step'] and grad_step is
    not -1.

    The norms are computed into locals first. The previous f-string nested
    the same quote character inside the replacement fields, which is a
    SyntaxError before Python 3.12.
    """
    if i % pars['grad_step'] == 0 and pars['grad_step'] != -1:
        norm_psi = cp.linalg.norm(grads['psi'])
        norm_prb = cp.linalg.norm(grads['prb'])
        norm_shift = cp.linalg.norm(grads['fshift'])
        print(f'(alpha,psi,prb,shift): {alpha:.1e} {norm_psi:.1e},{norm_prb:.1e},{norm_shift:.1e}')

In [None]:
# Initial probe: unit amplitude with a quarter-sine taper of width ppad on every border
prb_init = cp.ones([nprb,nprb],dtype='complex64')
v = cp.ones(nprb)
ppad = n//16
vv = cp.sin(cp.linspace(0.0,cp.pi/2,ppad))
v[:ppad] = vv
v[nprb-ppad:] = vv[::-1]
# separable 2D window from the 1D profile
v = cp.outer(v,v)
prb_init*=v
# Pass the `show` flag rather than the `mshow` function object (always truthy).
mshow_polar(prb_init,show)

# Bilinear Hessian method

In [None]:
def BH(data, vars, pars):
    """Bilinear Hessian iterations for object, probe and shifts.

    In every iteration the reusable arrays are recomputed, the debug hooks
    are called, the gradients are evaluated, and a search direction is
    chosen: steepest descent in the first iteration or when
    pars['method'] == 'BH-GD', otherwise the previous direction scaled by
    beta is added. The step length alpha comes from calc_alpha and the
    three variables are updated in place, each scaled by its entry of
    pars['rho'].

    Args:
        data: measured data; its square root is used for the 'Gaussian' model.
        vars: dict with 'psi', 'prb', 'fshift' (and 'table' for logging), updated in place.
        pars: dict of fixed parameters ('niter', 'model', 'method', 'rho', ...).

    Returns:
        The updated `vars` dict.
    """
    if pars['model']=='Gaussian':
        # work with sqrt
        data = np.sqrt(data)

    keys = ('psi', 'prb', 'fshift')
    alpha = 1
    for i in range(pars['niter']):
        reused = calc_reused(vars, pars)
        error_debug(vars, pars, reused, data, i)
        vis_debug(vars, pars, i)

        gradientF(pars,reused,data)
        grads = gradients(vars,pars,reused)
        if i==0 or pars['method']=='BH-GD':
            etas = {key: -grads[key] for key in keys}
        else:
            beta = calc_beta(vars, grads, etas, pars, reused, data)
            etas = {key: -grads[key] + beta*etas[key] for key in keys}

        alpha,top,bottom = calc_alpha(vars, grads, etas, pars, reused, data)

        plot_debug(vars,etas,pars,top,bottom,alpha,data,i)
        grad_debug(alpha,grads,pars,i)

        for key, rho in zip(keys, pars['rho']):
            vars[key] += rho*alpha*etas[key]

    return vars

# fixed variables
# a step value of -1 disables the corresponding debug hook
pars = {'niter':2049, 'err_step': 1, 'vis_step': 32, 'grad_step': -1}
pars['rho'] = [1,1.5,0.1]  # scaling of the object, probe and shift variables
pars['ishift'] = np.floor(shifts).astype('int32')
pars['extra'] = extra
pars['eps'] = 1e-9
pars['model'] = 'Gaussian'
pars['method'] = 'BH-CG'
# Tag used in output names. Double quotes outside: nesting the same quote
# character inside an f-string is a SyntaxError before Python 3.12.
pars['flg'] = f"{pars['method']}_{pars['rho'][0]}_{pars['rho'][1]}_{pars['rho'][2]}"

vars = {}
vars['psi'] = rec_paganin.copy()
vars['prb'] = prb_init.copy()#cp.ones([nprb,nprb],dtype='complex64')
# fractional part of the positions; the integer part is in pars['ishift']
vars['fshift'] = np.array(shifts-np.floor(shifts).astype('int32')).astype('float32')
vars['table'] = pd.DataFrame(columns=["iter", "err", "time"])

# kept for the shift-drift plot in vis_debug
fshifts_init = vars['fshift'].copy()
vars = BH(data, vars, pars)      

# Pass the `show` flag rather than the `mshow` function object (always truthy).
mshow_polar(vars['psi'],show)
mshow_polar(vars['prb'],show)
erra = vars['table']['err'].values
# times=vars['table']['time'].values
# times-=times[0]
# print(times)
rec_pos = (vars['fshift']+pars['ishift'])
plt.plot(erra,label=pars['method'])
plt.yscale('log')
plt.show()

# initial positions (red) against recovered positions (green)
plt.plot(shifts[:,1],shifts[:,0],'r.')
plt.plot(rec_pos[:,1],rec_pos[:,0],'g.')
plt.axis('equal')
plt.show()
