In [None]:
import numpy as np
import cupy as cp
import h5py
from holotomocupy.holo import G, GT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import multiPaganin
from holotomocupy.utils import *
from holotomocupy.proc import remove_outliers
##!jupyter nbconvert --to script config_template.ipynb
import sys

# Init data sizes and parameters of the PXM of ID16A

In [None]:

# Input HDF5 file and output prefix for reconstruction results.
path = '/data/WB_slit_exp_4ms_125hz_1184x1200_000010.h5'
path_out = '/data/WB_slit_exp_4ms_125hz_1184x1200_000010_rec'
show = True   # forwarded to the mshow* helpers and checked by the debug plots
npos = 20     # number of frames read from the file / number of shift positions
n = 1184      # side length of the square arrays allocated below

## Read data

In [None]:
# Read the first `npos` frames, cropped to an n x n window, as float32.
# The crop uses the configured `n` (instead of a repeated literal) so the data
# always matches the [1, npos, n, n] arrays allocated by Sop/STop.
with h5py.File(path, 'r') as fid:
    data0 = fid['/entry/data/data'][:npos, :n, :n].astype('float32')
# Normalize by the global mean so the average value becomes 1.
data0 /= np.mean(data0)
mshow(data0[0], show)
# Prepend a singleton leading axis -> [1, npos, n, n].
data0 = data0[np.newaxis]

In [None]:
import scipy.ndimage as ndimage
def remove_outliers(data, dezinger, dezinger_threshold):
    """Replace outlier pixels by their local median (dezingering).

    Note: this definition shadows the `remove_outliers` imported from
    `holotomocupy.proc` at the top of the notebook.

    Args:
        data: 3D array; the median window is [1, w, w], so the filter acts
            only within each slice along the first axis.
        dezinger: window width w (converted with int()); values <= 0 disable
            the filter.
        dezinger_threshold: a pixel is replaced when
            |data - median| > median * dezinger_threshold.

    Returns:
        A new array of the same shape; `data` itself is not modified.
        The number of replaced pixels is printed when the filter is active.
    """
    cleaned = data.copy()
    width = int(dezinger)
    if width <= 0:
        return cleaned

    median = ndimage.median_filter(data, [1, width, width])
    is_outlier = np.abs(data - median) > median * dezinger_threshold
    print(np.sum(is_outlier))
    cleaned[is_outlier] = median[is_outlier]
    return cleaned

In [None]:
# Work on a copy so the raw frames in data0 stay untouched, then dezinger the
# frame stack with a 3x3 median window and a relative threshold of 0.8.
data = np.array(data0)
data[0] = remove_outliers(data[0], dezinger=3, dezinger_threshold=0.8)
mshow(data[0, 0], show)


# Construct operators


In [None]:
def Sop(psi,shifts):
    """Apply the shift operator S to `psi` once per position and stack the results.

    Args:
        psi: array that is copied to the GPU with cp.array; each result is
            stored in a [1, n, n] slot, so presumably psi is [1, n, n] — confirm.
        shifts: indexed as shifts[:, j] for j in range(npos).

    Returns:
        complex64 cupy array of shape [1, npos, n, n]; slice [:, j] holds
        S(psi, shifts[:, j]).
    """
    psi_gpu = cp.array(psi)
    stacked = cp.zeros([1, npos, n, n], dtype='complex64')
    for k in range(npos):
        # S receives a fresh copy of psi for every position.
        stacked[:, k] = S(psi_gpu.copy(), cp.array(shifts[:, k]))
    return stacked

def STop(data,shifts):
    """Accumulate ST(data[:, j], shifts[:, j]) over all positions.

    ST is imported from holotomocupy.shift; presumably it is the adjoint of S,
    which would make this the adjoint of Sop — confirm against the library.

    Returns:
        complex64 cupy array of shape [1, n, n].
    """
    total = cp.zeros([1, n, n], dtype='complex64')
    for k in range(npos):
        total += ST(cp.array(data[:, k]), cp.array(shifts[:, k]))
    return total

# Adjoint check: <arr1, STop(Sop(arr1))> is printed next to <Sop(arr1), Sop(arr1)>.
arr1 = cp.random.random([1, n, n]).astype('complex64')
# Shifts are real-valued, so use float32 like every other shift array in this
# notebook (the original cast them to complex64).
shifts0 = cp.random.random([1, npos, 2]).astype('float32')
arr2 = Sop(arr1, shifts0)
arr3 = STop(arr2, shifts0)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')

# Synthetic data set: every position sees the first frame; all shifts are zero
# except the last position, whose two shift components are set to 0.1.
shifts = cp.zeros([1, npos, 2], dtype='float32')
shifts[:, -1] = 0.1
# NOTE: `data` is rebound here from the dezingered frames to the complex64
# output of Sop; the cells below operate on this synthetic data.
data = Sop(data[:, 0], shifts)
mshow_complex(data[0, 0]-data[0, -1], show)

## Reconstruction with the CG (Carlsson) with Hessians

$$ H(q,\psi,{{x}})=F(J(q)\cdot S_{{x}}(\psi))+\lambda_o\|\nabla C(\psi)\|_2^2+\lambda_p\|\nabla C(q)\|_2^2=\left\||L(M(q,\psi,x))|-d\right\|_2^2+\lambda_o\|\nabla C(\psi)\|_2^2+\lambda_p\|\nabla C(q)\|_2^2+\| |L_1(q)|-d_r\|^2_2
$$

## Gradients

#### $$\nabla F=2 \left(L^*\left( L(M(q_0,\psi_0,\boldsymbol{x}_0))-\tilde d\right)\right).$$
#### where $$\tilde d = d \frac{L(M(q_0,\psi_0,\boldsymbol{x}_0))}{|L(M(q_0,\psi_0,\boldsymbol{x}_0))|}$$

#### $$\nabla F_q=2 \left(L^*\left( L_1(q_0)-\tilde d_r\right)\right).$$
#### where $$\tilde d_r = d_r \frac{L_1(q_0)}{|L_1(q_0)|}$$



In [None]:
def gradientF(vars,d):
    """Return 2*(q*Sop(psi, x) - d).

    This is the gradient of ||q*Sop(psi, x) - d||^2 with respect to the
    modelled data q*Sop(psi, x).

    Args:
        vars: dict providing 'psi', 'prb' (q) and 'shift' (x).
        d: data array subtracted from the model.
    """
    psi = vars['psi']
    q = vars['prb']
    x = vars['shift']
    model = q*Sop(psi,x)
    return 2*(model-d)


##### $$\nabla_{\psi} G|_{(q_0,\psi_0,\boldsymbol{x}_0)}=S_{\boldsymbol{x}_{0}}^*\left(\overline{J(q_0)}\cdot \nabla F\right)+ \underline{2\lambda C^T(\nabla^T (\nabla(C(\psi))))}$$

##### $$\nabla_{q} G|_{(q_0,\psi_0,\boldsymbol{x}_0)}=J^*\left( \overline{S_{\boldsymbol{x}_{0}}(\psi_0)}\cdot \nabla F\right).$$
##### $$\nabla_{\boldsymbol{x}_0} G|_{(q_0,\psi_0,\boldsymbol{x}_0)}=\textsf{Re}\Big(\big( \Big\langle \overline{q_0}\cdot \nabla F,   C(\mathcal{F}^{-1}(-2\pi i \xi_1 e^{ -2\pi i \boldsymbol{x}_{0,k}\cdot \boldsymbol{\xi}}\hat{\psi_0}))\Big\rangle,\Big\langle \overline{q_0}\cdot \nabla F,C(\mathcal{F}^{-1}(-2\pi i \xi_2 e^{ -2\pi i \boldsymbol{x}_{0,k}\cdot \boldsymbol{\xi}}\hat{\psi_0})) \Big\rangle\big)\Big)_{k=1}^K. $$




In [None]:

def gradientpsi(q,x,gradF):
    """Gradient with respect to psi: STop(conj(q)*gradF, x)."""
    weighted = np.conj(q)*gradF
    return STop(weighted,x)

def gradientq(psi,q,x,gradF):
    """Gradient with respect to q: sum over axis 1 of conj(Sop(psi, x))*gradF.

    `q` is accepted for a uniform signature but is not used.
    """
    shifted_conj = np.conj(Sop(psi,x))
    return np.sum(shifted_conj*gradF,axis=1)

# def gradientx(psi,q,x,gradF):
#     xi1 = cp.fft.fftfreq(2*psi.shape[-1]).astype('float32')    
#     [xi2, xi1] = cp.meshgrid(xi1, xi1)  
#     tksi1 = Twop_(psi,x,-2*cp.pi*1j*xi1)
#     tksi2 = Twop_(psi,x,-2*cp.pi*1j*xi2)    
#     gradx = cp.zeros([1,npos,2],dtype='float32')
#     tmp = np.conj(q)*gradF
#     gradx[:,:,0] = redot(tmp,tksi1,axis=(2,3))
#     gradx[:,:,1] = redot(tmp,tksi2,axis=(2,3))
#     return gradx
def gradientx(psi,q,x,gradF):
    """Gradient with respect to the shifts.

    For every position j and shift component k this evaluates
        gradx[:, j, k] = -2*pi*imdot(gradF[:, j], q*t_k, axis=(1, 2)),
        t_k = ifft2(pp * xi_k * fft2(psi)),
    where pp = exp(-2*pi*i*(xi1*x_j[0] + xi2*x_j[1])).
    imdot comes from the star import of holotomocupy.utils; presumably it is
    the imaginary part of the inner product — confirm.

    Compared with the previous version, the spectrum of psi is computed once
    (it does not depend on j) and the phase factor pp is shared by both
    components instead of being rebuilt in a second, duplicated loop. The
    multiplication order pp*xi*fft2(psi) is kept unchanged.

    Returns:
        float32 cupy array of shape [1, npos, 2].
    """
    gradx = cp.zeros([1,npos,2],dtype='float32')
    freq = cp.fft.fftfreq(n).astype('float32')
    # With the default 'xy' indexing xi2 varies along the last axis and xi1
    # along the second-to-last axis.
    [xi2, xi1] = cp.meshgrid(freq, freq)
    psi_hat = cp.fft.fft2(psi)  # loop-invariant
    for j in range(npos):
        xj = cp.array(x[:,j,:,np.newaxis,np.newaxis])
        pp = cp.exp(-2*cp.pi*1j*(xi1*xj[:, 0]+xi2*xj[:, 1]))
        for k, xi in enumerate((xi1, xi2)):
            t = cp.fft.ifft2(pp*xi*psi_hat)
            gradx[:,j,k] = -2*np.pi*imdot(gradF[:,j],q*t,axis=(1,2))

    return gradx

    
def gradients(vars,d,gradF):
    """Collect the rho-scaled gradients for psi, prb and shift in one dict.

    Args:
        vars: dict providing 'psi', 'prb', 'shift' and the three weights 'rho'.
        d: unused; kept for a uniform call signature.
        gradF: result of gradientF for the current variables.

    Returns:
        dict with keys 'psi', 'prb', 'shift'; each entry is multiplied by
        rho[0], rho[1], rho[2] respectively.
    """
    psi, q, x, rho = (vars[key] for key in ('psi', 'prb', 'shift', 'rho'))
    return {
        'psi': rho[0]*gradientpsi(q,x,gradF),
        'prb': rho[1]*gradientq(psi,q,x,gradF),
        'shift': rho[2]*gradientx(psi,q,x,gradF),
    }



##### $$\frac{1}{2}\mathcal{H}|_{x_0}(y,z)= \left\langle \mathbf{1}-d_{0}, \mathsf{Re}({L(y)}\overline{L(z)})\right\rangle+\left\langle d_{0},(\mathsf{Re} (\overline{l_0}\cdot L(y)))\cdot (\mathsf{Re} (\overline{l_0}\cdot L(z))) \right\rangle$$
##### $$l_0=L(x_0)/|L(x_0)|$$
##### $$d_0=d/|L(x_0)|$$


In [None]:
def hessianF(dpsi1,dpsi2):
    """Bilinear form of the data term: 2*redot(dpsi1, dpsi2).

    redot comes from the star import of holotomocupy.utils; presumably it is
    the real part of the inner product — confirm.
    """
    inner = redot(dpsi1,dpsi2)
    return 2*inner

##### $D T_c|_{{{z}_0}}(\Delta {z})=-2\pi iC\Big(\mathcal{F}^{-1}\big(({\Delta z \cdot \xi}) e^{-2\pi i  {z}_0\cdot {\xi}}\hat{c}({\xi})\big)\Big)=-2\pi i C\Big(\mathcal{F}^{-1}\big((\Delta z_1 {\xi_1}+\Delta z_2 {\xi_2}) e^{-2\pi i  {z}_0\cdot {\xi}}\hat{c}({\xi})\big)\Big)$
##### $ D^2{T_c}|_{{{z}_0}}(\Delta{z},\Delta{w})=-4\pi^2C(\mathcal{F}^{-1}((\Delta{z}\cdot\xi)(\Delta{w}\cdot\xi)e^{-2\pi i  {z}_0\cdot {\xi}}\hat{c}))$
##### $=-4\pi^2C(\mathcal{F}^{-1}((\Delta{z_1}\Delta{w_1}\xi_1^2 + (\Delta{z_1}\Delta{w_2}+\Delta{z_2}\Delta{w_1})\xi_1\xi_2+\Delta{z_2}\Delta{w_2}\xi_2^2)e^{-2\pi i  {z}_0\cdot {\xi}}\hat{c}))$

In [None]:
def DT(psi,x,dx):
    """First derivative of the shifted psi with respect to the shifts, in direction dx.

    For every position j:
        res[:, j] = -2*pi*i * ifft2(pp * (xi1*dx_j[0] + xi2*dx_j[1]) * fft2(psi)),
        pp = exp(-2*pi*i*(xi1*x_j[0] + xi2*x_j[1])).

    The spectrum of psi does not depend on j, so it is computed once before
    the loop instead of copying and transforming psi for every position.

    Args:
        psi: cupy array transformed with fft2 over its last two axes.
        x: shifts, indexed as x[:, j, :].
        dx: shift perturbations, indexed like x.

    Returns:
        complex64 cupy array of shape [1, npos, n, n].
    """
    res = cp.zeros([1,npos,n,n],dtype='complex64')
    freq = cp.fft.fftfreq(n).astype('float32')
    [xi2, xi1] = cp.meshgrid(freq, freq)
    psi_hat = cp.fft.fft2(psi)  # loop-invariant
    for j in range(npos):
        xj = cp.array(x[:,j,:,np.newaxis,np.newaxis])
        dxj = cp.array(dx[:,j,:,np.newaxis,np.newaxis])

        pp = cp.exp(-2*cp.pi*1j*(xi1*xj[:, 0]+xi2*xj[:, 1]))
        xiall = xi1*dxj[:,0]+xi2*dxj[:,1]

        res[:,j] = -2*np.pi*1j*cp.fft.ifft2(pp*xiall*psi_hat)
    return res


def D2T(psi,x,dx1,dx2):
    """Second derivative of the shifted psi with respect to the shifts, in directions dx1, dx2.

    For every position j:
        res[:, j] = -4*pi^2 * ifft2(pp * xiall * fft2(psi)),
        xiall = xi1^2*dx1[0]*dx2[0]
              + xi1*xi2*(dx1[0]*dx2[1] + dx1[1]*dx2[0])
              + xi2^2*dx1[1]*dx2[1],
        pp = exp(-2*pi*i*(xi1*x_j[0] + xi2*x_j[1])).

    The spectrum of psi does not depend on j, so it is computed once before
    the loop instead of copying and transforming psi for every position.

    Returns:
        complex64 cupy array of shape [1, npos, n, n].
    """
    res = cp.zeros([1,npos,n,n],dtype='complex64')
    freq = cp.fft.fftfreq(n).astype('float32')
    [xi2, xi1] = cp.meshgrid(freq, freq)
    psi_hat = cp.fft.fft2(psi)  # loop-invariant
    for j in range(npos):
        xj = cp.array(x[:,j,:,np.newaxis,np.newaxis])
        dx1j = cp.array(dx1[:,j,:,np.newaxis,np.newaxis])
        dx2j = cp.array(dx2[:,j,:,np.newaxis,np.newaxis])

        pp = cp.exp(-2*cp.pi*1j*(xi1*xj[:, 0]+xi2*xj[:, 1]))
        xiall = xi1**2*dx1j[:,0]*dx2j[:,0]+ \
                xi1*xi2*(dx1j[:,0]*dx2j[:,1]+dx1j[:,1]*dx2j[:,0])+ \
                xi2**2*dx1j[:,1]*dx2j[:,1]

        res[:,j] = -4*np.pi**2*cp.fft.ifft2(pp*xiall*psi_hat)
    return res


#### $$ DM|_{(q_0,\psi_0,\boldsymbol{x})}(\Delta q, \Delta \psi,\Delta\boldsymbol{x})=$$
#### $$ \Big(\Delta q\cdot T_{\psi_0}({\boldsymbol{x}_{0,k}})+ q_0\cdot \big(T_{\Delta \psi}({\boldsymbol{x}_{0,k}})+  DT_{\psi_0}|_{{\boldsymbol{x}_{0,k}}}( \Delta \boldsymbol{x}_k)\big) \Big)_{k=1}^K=$$
#### $$ J(\Delta q)\cdot S_{\boldsymbol{x}_{0,k}}(\psi_0)+ J(q_0)\cdot S_{\boldsymbol{x}_{0}}{(\Delta \psi)}+  \Big(q_0\cdot DT_{\psi_0}|_{{\boldsymbol{x}_{0,k}}}( \Delta \boldsymbol{x}_k) \Big)_{k=1}^K$$


In [None]:
def DM(psi,q,x,dpsi,dq,dx):
    """First derivative of the model q*Sop(psi, x) in direction (dpsi, dq, dx).

    Returns dq*Sop(psi, x) + q*(Sop(dpsi, x) + DT(psi, x, dx)).
    """
    prb_term = dq*Sop(psi,x)
    obj_and_shift = Sop(dpsi,x)+DT(psi,x,dx)
    return prb_term+q*obj_and_shift

##### $$ D^2M|_{(q_0,\psi_0,\boldsymbol{x})}\big((\Delta q^{(1)}, \Delta \psi^{(1)},\Delta\boldsymbol{x}^{(1)}),(\Delta q^{(2)}, \Delta \psi^{(2)},\Delta\boldsymbol{x}^{(2)})\big)= $$
##### $$\Big( q_0\cdot DT_{\Delta\psi^{(1)}}|_{{\boldsymbol{x}_{0,k}}}( \Delta \boldsymbol{x}^{(2)})+q_0\cdot DT_{\Delta\psi^{(2)}}|_{{\boldsymbol{x}_{0,k}}}( \Delta \boldsymbol{x}^{(1)})+ q_0\cdot D^2{T_\psi}|_{{\boldsymbol{x}_0}}(\Delta\boldsymbol{x}^{(1)},\Delta\boldsymbol{x}^{(2)})+$$
##### $$\Delta q^{(1)}\cdot T_{\Delta \psi^{(2)}}({\boldsymbol{x}_{0,k}})+\Delta q^{(2)}\cdot T_{\Delta \psi^{(1)}}({\boldsymbol{x}_{0,k}})+ $$
##### $$\Delta q^{(1)}\cdot DT_{\psi_0}|_{{\boldsymbol{x}_{0,k}}}( \Delta \boldsymbol{x}^{(2)})+\Delta q^{(2)}\cdot DT_{\psi_0}|_{{\boldsymbol{x}_{0,k}}}( \Delta \boldsymbol{x}^{(1)})\Big)_{k=1}^K.$$


In [None]:
def D2M(psi,q,x,dpsi1,dq1,dx1,dpsi2,dq2,dx2):
    """Second derivative of the model q*Sop(psi, x) for two directions.

    Sums three groups of mixed terms, accumulated in this order:
      1. object/shift: q*DT(dpsi1, x, dx2) + q*DT(dpsi2, x, dx1) + q*D2T(psi, x, dx1, dx2)
      2. probe/object: dq1*Sop(dpsi2, x) + dq2*Sop(dpsi1, x)
      3. probe/shift:  dq1*DT(psi, x, dx2) + dq2*DT(psi, x, dx1)
    """
    total = q*DT(dpsi1,x,dx2) + q*DT(dpsi2,x,dx1) + q*D2T(psi,x,dx1,dx2)
    prb_obj = dq1*Sop(dpsi2,x) + dq2*Sop(dpsi1,x)
    total += prb_obj
    prb_shift = dq1*DT(psi,x,dx2) + dq2*DT(psi,x,dx1)
    total += prb_shift
    return total

##### $$\mathcal{H}^G|_{ (q_0,\psi_0,\boldsymbol{x}_0)}\Big((\Delta q^{(1)},\Delta \psi^{(1)},\Delta \boldsymbol{x}^{(1)}),(\Delta q^{(2)},\Delta \psi^{(2)},\Delta \boldsymbol{x}^{(2)})\Big)=$$
##### $$\Big\langle \nabla F|_{M(q_0,\psi_0,\boldsymbol{x}_0)}, D^2M|_{(q_0,\psi_0,\boldsymbol{x}_0)}\Big((\Delta q^{(1)},\Delta \psi^{(1)},\Delta \boldsymbol{x}^{(1)}),(\Delta q^{(2)},\Delta \psi^{(2)},\Delta \boldsymbol{x}^{(2)})\Big)\Big\rangle +$$
##### $$\mathcal{H}^F|_{M(q_0,\psi_0,\boldsymbol{x}_0)}\Big(DM|_{(q_0,\psi_0,\boldsymbol{x}_0)}(\Delta q^{(1)},\Delta \psi^{(1)},\Delta \boldsymbol{x}^{(1)}),DM|_{(q_0,\psi_0,\boldsymbol{x}_0)}(\Delta q^{(2)},\Delta \psi^{(2)},\Delta \boldsymbol{x}^{(2)})\Big)+$$
##### $$+\underline{2\lambda \textsf{Re}\langle \nabla (C (\Delta \psi^{(1)})),\nabla (C( \Delta \psi^{(2)}))\rangle}$$

In [None]:
def calc_beta(vars,grads,etas,d,gradF):
    """Ratio H(g, eta) / H(eta, eta) for the conjugate-direction update.

    H(a, b) is evaluated as redot(gradF, D2M(a, b)) + hessianF(DM(a), DM(b)),
    with g = grads*rho and eta = etas*rho taken component-wise.

    Args:
        vars: dict providing 'psi', 'prb', 'shift', 'rho'.
        grads, etas: dicts with keys 'psi', 'prb', 'shift'.
        d: unused; kept for a uniform call signature.
        gradF: result of gradientF for the current variables.
    """
    psi, q, x, rho = vars['psi'], vars['prb'], vars['shift'], vars['rho']
    keys = ('psi', 'prb', 'shift')
    grad_dir = [grads[key]*rho[k] for k, key in enumerate(keys)]
    eta_dir = [etas[key]*rho[k] for k, key in enumerate(keys)]

    dm_grad = DM(psi,q,x,*grad_dir)
    dm_eta = DM(psi,q,x,*eta_dir)
    d2m_mixed = D2M(psi,q,x,*grad_dir,*eta_dir)
    d2m_eta = D2M(psi,q,x,*eta_dir,*eta_dir)

    top = redot(gradF,d2m_mixed)
    top += hessianF(dm_grad, dm_eta)

    bottom = redot(gradF,d2m_eta)
    bottom += hessianF(dm_eta, dm_eta)
    return top/bottom

def calc_alpha(vars,grads,etas,d,gradF):
    """Step length top/bottom along the direction etas.

    top    = -(redot over psi, prb and shift of grads against etas)
    bottom = redot(gradF, D2M(eta, eta)) + hessianF(DM(eta), DM(eta)),
             with eta = etas*rho taken component-wise.

    `top` and the two stages of `bottom` are printed for debugging.
    `d` is unused; kept for a uniform call signature.

    Returns:
        (top/bottom, top, bottom)
    """
    psi, q, x, rho = vars['psi'], vars['prb'], vars['shift'], vars['rho']
    top = -redot(grads['psi'],etas['psi'])-redot(grads['prb'],etas['prb'])-redot(grads['shift'],etas['shift'])

    eta_dir = [etas[key]*rho[k] for k, key in enumerate(('psi', 'prb', 'shift'))]
    dm_eta = DM(psi,q,x,*eta_dir)
    d2m_eta = D2M(psi,q,x,*eta_dir,*eta_dir)

    print(f'{top=}')
    bottom = redot(gradF,d2m_eta)
    print(bottom)
    bottom += hessianF(dm_eta, dm_eta)
    print(bottom)
    return top/bottom, top, bottom

## debug functions

In [None]:
def plot_debug2(vars,etas,top,bottom,alpha,data):
    """Plot the true error along the search direction next to its quadratic model.

    For 2*npp step sizes t_k = alpha*k/(npp-1) the error
    ||q_t*Sop(psi_t, x_t) - data||^2 is evaluated at the updated variables and
    compared with f(0) - top*t + 0.5*bottom*t^2. Does nothing when the global
    `show` is False. `alpha` must provide .get() (e.g. a cupy scalar array).
    """
    if show==False:
        return
    psi, q, x, rho = vars['psi'], vars['prb'], vars['shift'], vars['rho']
    dpsi2, dq2, dx2 = etas['psi'], etas['prb'], etas['shift']
    npp = 7
    nsteps = npp*2

    errt = cp.zeros(nsteps)
    for k in range(nsteps):
        step = alpha*k/(npp-1)
        psit = psi+step*rho[0]*dpsi2
        qt = q+step*rho[1]*dq2
        xt = x+step*rho[2]*dx2
        errt[k] = np.linalg.norm(qt*Sop(psit,xt)-data)**2

    # Quadratic model of the error around the current point.
    t = alpha*(cp.arange(nsteps))/(npp-1)
    errt2 = np.linalg.norm(q*Sop(psi,x)-data)**2
    errt2 = errt2 -top*t+0.5*bottom*t**2

    steps = alpha.get()*cp.arange(nsteps).get()/(npp-1)
    plt.plot(steps,errt.get(),'.')
    plt.plot(steps,errt2.get(),'.')
    plt.show()

def plot_debug3(shifts):
    """Plot both shift components of shifts[0] (first in red, second in blue).

    Does nothing when the global `show` is False.
    """
    if show==False:
        return
    for component, fmt in enumerate(('r.', 'b.')):
        plt.plot(shifts[0,:,component].get(),fmt)
    plt.show()

def vis_debug(vars,i):
    """Display the current object ('psi') and probe ('prb') with mshow_complex.

    `i` is only referenced by the disabled export code kept below.
    """
    for key in ('psi', 'prb'):
        mshow_complex(vars[key][0],show)
    # Disabled export of intermediate results (refers to `flg` and `grads`,
    # which are not defined in this scope):
    # dxchange.write_tiff(np.angle(vars['psi'][0]).get(),f'{path_out}/crec_code_angle{flg}/{i:03}',overwrite=True)
    # dxchange.write_tiff(np.angle(vars['prb'][0]).get(),f'{path_out}/crec_prb_angle{flg}/{i:03}',overwrite=True)
    # dxchange.write_tiff(np.abs(vars['psi'][0]).get(),f'{path_out}/crec_code_abs{flg}/{i:03}',overwrite=True)
    # dxchange.write_tiff(np.abs(vars['prb'][0]).get(),f'{path_out}/crec_prb_abs{flg}/{i:03}',overwrite=True)
    # dxchange.write_tiff(np.abs(grads['rpsi'][0]).get(),f'{path_out}/crec_rpsi_abs{flg}/{i:03}',overwrite=True)
    # dxchange.write_tiff(np.angle(grads['rpsi'][0]).get(),f'{path_out}/crec_rpsi_angle{flg}/{i:03}',overwrite=True)
    # np.save(f'{path_out}/crec_shift{flg}_{i:03}',vars['shift'])

    
def err_debug(vars, data):
    """Return the squared norm of the residual q*Sop(psi, x) - data."""
    residual = vars['prb']*Sop(vars['psi'],vars['shift'])-data
    # Disabled debug output (refers to `grads`, which is not in scope here):
    # print(f'gradient norms (psi, prb, shift): {np.linalg.norm(grads['psi']):.2f}, {np.linalg.norm(grads['prb']):.2f}, {np.linalg.norm(grads['shift']):.2f}')
    return np.linalg.norm(residual)**2

# Main CG loop (fifth rule)

In [None]:
def cg_holo(data, vars, pars):
    """Iteratively update psi, prb and shift to fit q*Sop(psi, x) to `data`.

    Each iteration computes the gradients, takes the negative gradient as the
    search direction (the conjugate-direction update via calc_beta is kept
    below but disabled), obtains the step length from calc_alpha, and updates
    the variables in place, weighted by vars['rho'].

    Args:
        data: array compared against q*Sop(psi, x).
        vars: dict with 'psi', 'prb', 'shift' (updated in place) and 'rho'.
        pars: dict with 'niter', 'err_step' (error is evaluated and printed
            every err_step iterations) and 'vis_step' (visualisation period;
            -1 disables it).

    Returns:
        (vars, erra, alphaa): erra and alphaa have length niter and are only
        filled at iterations where i % err_step == 0; other entries stay 0.
    """

    # data = np.sqrt(data)    
    
    erra = cp.zeros(pars['niter'])
    alphaa = cp.zeros(pars['niter'])    
    for i in range(pars['niter']):    
        
        gradF = gradientF(vars,data)        
        grads = gradients(vars,data,gradF)
        
        # if i==0:
        etas = {}
        etas['psi'] = -grads['psi']
        etas['prb'] = -grads['prb']
        etas['shift'] = -grads['shift']
        # else:      
        #     beta = calc_beta(vars, grads, etas, data, gradF)
        #     etas['psi'] = -grads['psi'] + beta*etas['psi']
        #     etas['prb'] = -grads['prb'] + beta*etas['prb']
        #     etas['shift'] = -grads['shift'] + beta*etas['shift']

        alpha,top,bottom = calc_alpha(vars, grads, etas, data, gradF) 
        # if i % pars['vis_step'] == 0:
        #     plot_debug2(vars,etas,top,bottom,alpha,data)
        # print(alpha)
        # alpha=1e-4

        vars['psi'] += vars['rho'][0]*alpha*etas['psi']
        vars['prb'] += vars['rho'][1]*alpha*etas['prb']
        vars['shift'] += vars['rho'][2]*alpha*etas['shift']
        
        if i % pars['err_step'] == 0:
            err = err_debug(vars, data)    
            
            # Double quotes inside the replacement field: reusing the outer
            # single quote is a SyntaxError before Python 3.12.
            print(f'{i}) {alpha=:.5f}, {vars["rho"]} {err=:1.5e}',flush=True)
            erra[i] = err
            alphaa[i] = alpha

        if i % pars['vis_step'] == 0 and pars['vis_step'] != -1:
            vis_debug(vars, i)
            plot_debug3(vars['shift'])
            
        # t={}
        # t[0]=np.linalg.norm(grads['psi'])
        # t[1]=np.linalg.norm(grads['prb'])
        # t[2]=np.linalg.norm(grads['shift'])
        # for k in range(1,3):
        #     if t[k]>2*t[0]:
        #         vars['rho'][k]/=2
        #     elif t[k]<t[0]/2:
        #         vars['rho'][k]*=2        
            
    return vars,erra,alphaa

# Initial guess. rho = [0, 0, 1] zeroes the psi and prb gradients and updates,
# so only the shifts are optimised in this run.
vars = {
    'psi': cp.ones([1,n,n],dtype='complex64'),
    'prb': cp.ones([1,n,n],dtype='complex64'),
    'shift': cp.zeros([1,npos,2],dtype='float32'),
    'rho': [0,0,1],
}
vars['psi'][:] = data[0,0]     # object starts from the first data frame
vars['shift'][:,-1] = -0.5     # starting shift of the last position
data_rec = cp.array(data)#p.pad(cp.array(data),((0,0),(0,0),(pad,pad),(pad,pad)))
pars = {'niter': 32, 'err_step': 1, 'vis_step':-1}
vars,erra,alphaa = cg_holo(data_rec, vars, pars)