In [None]:
import numpy as np
import cupy as cp
import dxchange
import matplotlib.pyplot as plt
from holotomocupy.holo import G, GT
from holotomocupy.tomo import R, RT, exp1j
from holotomocupy.magnification import M, MT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import multiPaganin
from holotomocupy.utils import *
import holotomocupy.chunking as chunking


# Batch size used by the chunked GPU loops defined later in this notebook
# (Rop, RTop, gradientF, gradients, ...).
chunk = 16
# Hand the same value to the holotomocupy chunking module
# (presumably the batch size of its own chunked operators — confirm).
chunking.global_chunk = chunk


# Render matplotlib figures inside the notebook.
%matplotlib inline

# Init data sizes and parameters of the PXM of ID16A

In [None]:
n = 1024  # object size in each dimension
ntheta = 720  # number of angles (rotations)


# NOTE(review): linspace includes the endpoint pi, so the first and last angles
# are half a turn apart — confirm this matches how the data were generated.
theta = cp.linspace(0, np.pi, ntheta).astype('float32')  # projection angles

# ID16a setup
ndist = 4  # number of sample positions (distances) used

detector_pixelsize = 3e-6/1.5
energy = 17.05  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length

focusToDetectorDistance = 1.208  # [m]
sx0 = -2.493e-3
# Sample positions measured relative to sx0
# (presumably focus-to-sample distances in metres — confirm).
z1 = np.array([1.5335e-3, 1.7065e-3, 2.3975e-3, 3.8320e-3])[:ndist]-sx0
# Remaining distance from each sample position to the detector.
z2 = focusToDetectorDistance-z1
# Effective propagation distance z1*z2/(z1+z2) for each sample position.
distances = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1
voxelsize = detector_pixelsize/magnifications[0]*2048/n  # object voxel size

# Magnifications relative to the first distance (first entry is 1).
norm_magnifications = magnifications/magnifications[0]
# scaled propagation distances due to magnified probes
distances = distances*norm_magnifications**2

z1p = z1[0]  # positions of the probe for reconstruction
# Offset of every sample position from the probe plane (zero for the first distance).
z2p = z1-np.tile(z1p, len(z1))
# magnification when propagating from the probe plane to the detector
magnifications2 = (z1p+z2p)/z1p
# propagation distances after switching from the point source wave to plane wave,
distances2 = (z1p*z2p)/(z1p+z2p)
# The divisor z1p/z1[0] equals 1 because z1p is z1[0].
norm_magnifications2 = magnifications2/(z1p/z1[0])  # normalized magnifications
# scaled propagation distances due to magnified probes
distances2 = distances2*norm_magnifications2**2
distances2 = distances2*(z1p/z1)**2

# allow padding if there are shifts of the probe
pad = n//8
show=True  # passed to the mshow* helpers throughout the notebook
# sample size after demagnification
ne = int(np.ceil((n+2*pad)/norm_magnifications[-1]/8))*8  # make multiple of 8
center = ne/2  # rotation axis

## Read data

In [None]:
# Load the pre-generated phantom arrays for the chosen size n from disk.
data = np.load(f'/data/vnikitin/phantoms/data{n}.npy')
ref = np.load(f'/data/vnikitin/phantoms/ref{n}.npy')
shifts = np.load(f'/data/vnikitin/phantoms/shifts{n}.npy')
shifts_ref = np.load(f'/data/vnikitin/phantoms/shifts_ref{n}.npy')
shifts_err = np.load(f'/data/vnikitin/phantoms/shifts_err{n}.npy')

# Keep an untouched copy of the loaded shifts; plot_debug3 compares against it.
shifts_correct = shifts.copy()
# Perturb the working shifts by half of the stored error array.
shifts += shifts_err/2
print(data.shape)

### Correction with the reference

In [None]:
# Normalize the data by the reference images.
rdata = data/ref
# For the first angle, show raw data (real part) next to the normalized data
# (imaginary part) at each distance.
for k in range(ndist):
    mshow_complex(data[0,k]+1j*rdata[0,k],show)

### Operators

In [None]:
def L2op(psi,j):
    """Apply G to psi for distances[j] and crop `pad` pixels from each side.

    Only the window [pad, n+pad) of the last two axes is returned.
    """
    propagated = G(psi, wavelength, voxelsize, distances[j], 'symmetric')
    window = slice(pad, n+pad)
    return propagated[:, window, window]

def LT2op(data,j):
    """Zero-pad data by `pad` on the last two axes, cast to complex64, and apply GT for distances[j]."""
    border = ((0, 0), (pad, pad), (pad, pad))
    padded = np.pad(data, border).astype('complex64')
    return GT(padded, wavelength, voxelsize, distances[j], 'symmetric')

def L1op(psi,j):
    """Apply G to psi for the probe-plane distance distances2[j]."""
    dist = distances2[j]
    return G(psi, wavelength, voxelsize, dist, 'symmetric')
    
def LT1op(data,j):
    """Apply GT to data for distances2[j]; paired with L1op in the adjoint check of this cell."""
    return GT(data, wavelength, voxelsize, distances2[j], 'symmetric')

def S2op(psi,shift):
    """Thin wrapper: apply the holotomocupy operator S to psi with the given shift."""
    shifted = S(psi, shift)
    return shifted

def ST2op(data,shift):
    """Thin wrapper: apply the holotomocupy operator ST to data with the given shift."""
    unshifted = ST(data, shift)
    return unshifted
    
def Mop(psi, j):
    """Apply M to psi with the j-th normalized magnification rescaled by ne/(n+2*pad).

    The third argument of M is n+2*pad (presumably the output size — confirm).
    """
    padded_size = n+2*pad
    scale = norm_magnifications[j]*ne/padded_size
    return M(psi, scale, padded_size)

def MTop(psi,j):
    """Apply MT to psi with the same scale factor as Mop.

    The third argument of MT is ne (presumably the output size — confirm).
    """
    scale = norm_magnifications[j]*ne/(n+2*pad)
    return MT(psi, scale, ne)

def Rop(u):
    """Apply R to the host volume u in GPU batches of `chunk` slices along axis 0.

    Returns a host complex64 array of shape [ntheta, ne, ne]; each batch result is
    transposed (axes 0 and 1 swapped) before being stored.
    """
    proj = np.empty([ntheta,ne,ne],dtype='complex64')
    for st in range(0, ne, chunk):
        end = min(st+chunk, ne)
        slab_gpu = cp.asarray(u[st:end])
        proj[:,st:end] = cp.asnumpy(R(slab_gpu,theta,center).swapaxes(0,1))
    return proj

def RTop(data):
    """Apply RT to host projection data in GPU batches of `chunk` rows.

    data is indexed as [angle, row, column]; rows are processed in batches and the
    result is a host complex64 volume of shape [ne, ne, ne].
    """
    vol = np.empty([ne,ne,ne],dtype='complex64')
    for st in range(0, ne, chunk):
        end = min(st+chunk, ne)
        rows_gpu = cp.asarray(data[:,st:end])
        vol[st:end] = cp.asnumpy(RT(rows_gpu.swapaxes(0,1),theta,center))
    return vol

# Adjoint checks: for each operator pair (A, A^T) the two printed numbers,
# <x, A^T A x> and <A x, A x>, should agree.

# L1op / LT1op
arr1 = cp.random.random([1,n+2*pad,n+2*pad]).astype('complex64')
arr2 = L1op(arr1,0)
arr3 = LT1op(arr2,0)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')

# S2op / ST2op with random shifts
arr1 = cp.random.random([chunk,ne,ne]).astype('complex64')
shifts_test = cp.random.random([chunk,2]).astype('float32')
arr2 = S2op(arr1,shifts_test)
arr3 = ST2op(arr2,shifts_test)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')

# Mop / MTop for the second distance
arr1 = cp.random.random([chunk,ne,ne]).astype('complex64')
arr2 = Mop(arr1,1)
arr3 = MTop(arr2,1)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')

# L2op / LT2op
arr1 = cp.random.random([chunk,n+2*pad,n+2*pad]).astype('complex64')
arr2 = L2op(arr1,0)
arr3 = LT2op(arr2,0)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')

# Rop / RTop on a full host volume
arr1 = np.random.random([ne,ne,ne]).astype('complex64')
arr2 = Rop(arr1)
arr3 = RTop(arr2)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')
# Drop the references to the large test arrays.
arr1=arr2=arr3=[]

### Scale images and shift them based on random shifts

In [None]:
# Rescale every distance with M using the inverse normalized magnification, then
# apply ST with the working shifts (presumably to register all distances to the
# first one — confirm). Only the real part is written back into rdata.
for k in range(ndist):
    r = M(rdata[:,k].astype('complex64'),1/norm_magnifications[k])
    rdata[:,k] = ST(r,shifts[:,k],'symmetric').real

# Show, for the first angle, the difference between the first distance and each distance.
for k in range(ndist):
    mshow(rdata[0,0]-rdata[0,k],show)


In [None]:
# distances should not be normalized
# (this divides out the norm_magnifications**2 scaling applied in the setup cell)
distances_pag = (distances/norm_magnifications**2)[:ndist]
# NOTE(review): the meaning of the last two arguments (100, 1.2e-2) is defined by
# holotomocupy.recon_methods.multiPaganin and is not visible here — confirm.
recMultiPaganin = multiPaganin(rdata, distances_pag, wavelength, voxelsize, 100, 1.2e-2)
mshow(recMultiPaganin[0],show)

# Drop the reference to the normalized data.
rdata = []

### initial guess for the object

In [None]:
from holotomocupy.proc import dai_yuan, linear
def cg_tomo(data, u, pars):
    """Iteratively fit Rop(u) to data with conjugate-gradient directions and a fixed step.

    pars must provide 'niter' (iteration count), 'err_step' (print the squared
    residual norm every err_step iterations) and 'vis_step' (show the central slice
    every vis_step iterations). The search direction is -grad on the first iteration
    and dai_yuan(...) afterwards; the update is linear(u, d, 1, gamma) with
    gamma = 0.5 (presumably 1*u + gamma*d — confirm against holotomocupy.proc).
    Returns the final u.
    """
    def residual_norm2(vol):
        # squared l2 misfit between the projections of vol and the data
        return np.linalg.norm(Rop(vol)-data)**2

    gamma = 0.5  # fixed step length used on every iteration
    direction = None
    prev_grad = None
    for i in range(pars['niter']):
        grad = 2*RTop(Rop(u)-data)
        direction = -grad if i == 0 else dai_yuan(direction,grad,prev_grad)
        prev_grad = grad.copy()
        u = linear(u,direction,1,gamma)

        if i % pars['err_step'] == 0:
            err = residual_norm2(u)
            print(f'{i}) {gamma=}, {err=:1.5e}',flush=True)

        if i % pars['vis_step'] == 0:
            mshow_complex(u[ne//2],show)
    return u


# Start the tomographic fit from an all-zero complex volume of size ne^3.
rec_init = np.zeros([ne,ne,ne],dtype='complex64')
# Edge-pad the multiPaganin result on the last two axes so they grow from n to ne.
recMultiPaganin = np.pad(recMultiPaganin,((0,0),(ne//2-n//2,ne//2-n//2),(ne//2-n//2,ne//2-n//2)),'edge')

# 49 iterations; error print and visualization happen at iterations 0 and 48.
pars = {'niter': 49, 'err_step': 48, 'vis_step': 48}
rec_init = cg_tomo(recMultiPaganin, rec_init, pars)   
# Show two orthogonal central slices (real and imaginary channels of the display).
mshow_complex(rec_init[ne//2]+1j*rec_init[:,ne//2],show)

# Drop the reference to the padded multiPaganin result.
recMultiPaganin = []

# Fifth rule

##### $D T_c|_{{{z}_0}}(\Delta {z})=-2\pi iC\Big(\mathcal{F}^{-1}\big(({\Delta z \cdot \xi}) e^{-2\pi i  {z}_0\cdot {\xi}}\hat{c}({\xi})\big)\Big)=-2\pi i C\Big(\mathcal{F}^{-1}\big((\Delta z_1 {\xi_1}+\Delta z_2 {\xi_2}) e^{-2\pi i  {z}_0\cdot {\xi}}\hat{c}({\xi})\big)\Big)$
##### $ D^2{T_c}|_{{{z}_0}}(\Delta{z},\Delta{w})=-4\pi^2C(\mathcal{F}^{-1}((\Delta{z}\cdot\xi)(\Delta{w}\cdot\xi)e^{-2\pi i  {z}_0\cdot {\xi}}\hat{c}))$
##### $=-4\pi^2C(\mathcal{F}^{-1}((\Delta{z_1}\Delta{w_1}\xi_1^2 + (\Delta{z_1}\Delta{w_2}+\Delta{z_2}\Delta{w_1})\xi_1\xi_2+\Delta{z_2}\Delta{w_2}\xi_2^2)e^{-2\pi i  {z}_0\cdot {\xi}}\hat{c}))$

In [None]:
def DT(psi,x,dx):
    """First derivative of the Fourier-domain shift of psi at x in direction dx.

    Computes -2*pi*i * ifft2((xi . dx) * exp(-2*pi*i * xi . x) * fft2(psi)) on an
    ne x ne frequency grid. x and dx are indexed as [batch, 2]; the two components
    multiply the first-axis and second-axis frequencies respectively.
    """
    freq = cp.fft.fftfreq(ne).astype('float32')
    # xi1 varies along the first image axis, xi2 along the second
    xi1, xi2 = cp.meshgrid(freq, freq, indexing='ij')

    x0 = x[:,:,np.newaxis,np.newaxis]
    step = dx[:,:,np.newaxis,np.newaxis]

    phase = cp.exp(-2*cp.pi*1j*(xi1*x0[:, 0]+xi2*x0[:, 1]))
    weight = xi1*step[:,0]+xi2*step[:,1]

    deriv = cp.fft.ifft2(phase*weight*cp.fft.fft2(psi))
    return -2*np.pi*1j*deriv

def D2T(psi,x,dx1,dx2):
    """Second derivative of the Fourier-domain shift of psi at x in directions dx1, dx2.

    Computes -4*pi^2 * ifft2((xi . dx1)(xi . dx2) * exp(-2*pi*i * xi . x) * fft2(psi)),
    with the product of the two directional factors expanded into its
    xi1^2, xi1*xi2 and xi2^2 terms. x, dx1 and dx2 are indexed as [batch, 2].
    """
    freq = cp.fft.fftfreq(ne).astype('float32')
    # xi1 varies along the first image axis, xi2 along the second
    xi1, xi2 = cp.meshgrid(freq, freq, indexing='ij')

    x0 = x[:,:,np.newaxis,np.newaxis]
    a = dx1[:,:,np.newaxis,np.newaxis]
    b = dx2[:,:,np.newaxis,np.newaxis]

    phase = cp.exp(-2*cp.pi*1j*(xi1*x0[:, 0]+xi2*x0[:, 1]))
    weight = xi1**2*a[:,0]*b[:,0]+ \
             xi1*xi2*(a[:,0]*b[:,1]+a[:,1]*b[:,0])+ \
             xi2**2*a[:,1]*b[:,1]

    deriv = cp.fft.ifft2(phase*weight*cp.fft.fft2(psi))
    return -4*np.pi**2*deriv

\begin{align*}
 & DV|_{(q_0,u_0,{x}_0)}(\Delta q, \Delta u,\Delta{x})=L_1(q_0)\cdot M_j(T_{e^{i R (u_0)}\cdot(iR({\Delta u}))}(z_0)+ DT_{{e^{iR(u_0)}}}|_{{{z}_0}}( \Delta {z}))+L_1(\Delta q)\cdot M_j(T_{{e^{iR(u_0)}}}({{z}_0}))
\end{align*}

\begin{align*}
 & D^2V|_{(q_0,u_0,{x}_0)}\big((\Delta q^{(1)}, \Delta u^{(1)},\Delta{x^{(1)}}),(\Delta q^{(2)}, \Delta u^{(2)},\Delta{x}^{(2)})\big)=\\&L_1(q_0)\cdot M_j(T_{e^{i R (u_0)}\cdot(-\frac{1}{2}(R({\Delta u^{(1)}})R({\Delta u^{(2)}})))}({{z}_0})+DT_{e^{i R (u_0)}\cdot\big(iR({\Delta u^{(1)}})\big)}|_{{{z}_0}}( \Delta {z}^{(2)})+DT_{e^{i R (u_0)}\cdot\big(iR({\Delta u^{(2)}})\big)}|_{{{z}_0}}( \Delta {z}^{(1)})+\frac{1}{2}\left(D^2{T_{e^{iR(u_0)}}}|_{{{z}_0}}(\Delta z^{(1)},\Delta z^{(2)})\right))+\\&L_1(\Delta q^{(1)})\cdot M_j(T_{e^{i R (u_0)}\cdot(iR({\Delta u^{(2)}}))}({{z}_0})+ DT_{{e^{iR(u_0)}}}|_{{{z}_0}}( \Delta {z}^{(2)}))+L_1(\Delta q^{(2)})\cdot M_j(T_{e^{i R (u_0)}\cdot(iR({\Delta u^{(1)}}))}({{z}_0})+ DT_{{e^{iR(u_0)}}}|_{{{z}_0}}( \Delta {z}^{(1)}))
\end{align*}

In [None]:
def DV(q,eRu,x,dq,Rdu,dx,j):
    """First-order variation of V = L1(q) * M_j(S_x(eRu)) for distance j.

    dq, Rdu and dx are the variations of the probe, of the projected object and of
    the shifts. The result is
    L1(q)*M_j(S_x(eRu*(i*Rdu)) + DT(eRu, x, dx)) + L1(dq)*M_j(S_x(eRu)).
    """
    obj_variation = S2op(eRu*(1j*Rdu),x)+DT(eRu,x,dx)
    obj_current = S2op(eRu,x)
    return L1op(q,j)*Mop(obj_variation,j)+L1op(dq,j)*Mop(obj_current,j)

def D2V(q,eRu,x,dq1,Rdu1,dx1,dq2,Rdu2,dx2,j):
    """Second-order variation of V for distance j along two directions.

    Direction k is (dqk, Rduk, dxk). The result is the sum of three products:
    L1(q) with the pure object/shift second-order term, L1(dq1) with the first-order
    object/shift term of direction 2, and L1(dq2) with that of direction 1.
    """
    # object-side factors, evaluated before magnification
    obj_second = S2op(-0.5*eRu*Rdu1*Rdu2,x)+DT(eRu*1j*Rdu1,x,dx2)+DT(eRu*1j*Rdu2,x,dx1)+0.5*D2T(eRu,x,dx1,dx2)
    obj_dir2 = S2op(eRu*1j*Rdu2,x)+DT(eRu,x,dx2)
    obj_dir1 = S2op(eRu*1j*Rdu1,x)+DT(eRu,x,dx1)

    term_q = L1op(q,j)*Mop(obj_second,j)
    term_dq1 = L1op(dq1,j)*Mop(obj_dir2,j)
    term_dq2 = L1op(dq2,j)*Mop(obj_dir1,j)
    return term_q+term_dq1+term_dq2

\begin{equation*}
     \nabla_{q} H|_{(q_0,u_0,{x}_0)}= L_1^*\left(\overline{M_jS_{{x}_{0}}(e^{iRu})}\cdot \nabla F|_{V(q_0,u_0,{x}_0)}\right)
\end{equation*}

\begin{equation*}
  \nabla_{u} H|_{(q_0,u_0,{x}_0)}=-iR^*(\overline{e^{iRu_0}}S_{{x}_{0}}^*M_j^*\left(\overline{L_1(q_0)}\cdot \nabla F|_{V(q_0,u_0,{x}_0)}\right))
\end{equation*}

\begin{equation*}
\begin{aligned}
  \nabla_{{x}} H|_{(q_0,u_0,{x}_0)}=-2\pi \mathsf{Im} \Big(\big( &\Big\langle (\nabla F|_{V(q_0,u_0,{x}_0)}), L_1(q_0)\cdot  M_j(C(\mathcal{F}^{-1}(\xi_1 e^{-2\pi i x_{0}\cdot {\xi}}\widehat{e^{iRu}})))\Big\rangle,\\&\Big\langle (\nabla F|_{V(q_0,u_0,{x}_0)}),L_1(q_0)\cdot M_j(C(\mathcal{F}^{-1}(\xi_2 e^{-2\pi i x_{0}\cdot {\xi}}\widehat{e^{iRu}}))) \Big\rangle\big)\Big).
\end{aligned}
\end{equation*}
##### $$\tilde{L} = L_{2,j}(L_{1,j}(q)\cdot M_j(S_{{x}_{j,k}}(e^{iR_k(u)})))$$
##### $$\nabla F=2 \left(L^*_2\left( (L_2(\tilde{L}))-\tilde D\right)\right), \text{where } \tilde D = D \frac{(L_2(\tilde{L}))}{|L_2(\tilde{L})|}$$

   

In [None]:
def gradientq(q,x,gradF,eRu,j):
    """Probe gradient contribution for one batch and distance j.

    Multiplies gradF by the conjugate of M_j(S_x(eRu)), applies LT1op, and sums over
    the batch axis, returning an array with a leading axis of length 1.
    q is accepted for signature parity with the other gradient helpers and is unused.
    """
    obj_term = Mop(S2op(eRu,x),j)
    weighted = np.conj(obj_term)*gradF
    summed = np.sum(LT1op(weighted,j),axis=0)
    return summed[np.newaxis]

def gradientu(q,x,gradF,eRu,j):
    """Object gradient contribution (before -i*R^*) for one batch and distance j.

    Returns conj(eRu) * ST_x(MT_j(conj(L1(q)) * gradF)).
    """
    back_projected = MTop(np.conj(L1op(q,j))*gradF,j)
    return np.conj(eRu)*ST2op(back_projected,x)
  
def gradientx(q,x,gradF,eRu,j):
    """Shift gradient for one batch and distance j.

    eRu is zero-padded by ne//2 on each side of the last two axes, differentiated in
    the Fourier domain on a 2*ne grid along each axis, cropped back, magnified with
    Mop, and paired with gradF through imdot. Returns a float32 cupy array indexed
    as [batch, 2].
    """
    Lq = L1op(q, j)

    gradx = cp.zeros([x.shape[0],2],dtype='float32')
    freq = cp.fft.fftfreq(2*ne).astype('float32')
    # xi1 varies along the first image axis, xi2 along the second
    xi1, xi2 = cp.meshgrid(freq, freq, indexing='ij')
    padded = cp.pad(eRu, ((0, 0), (ne//2, ne//2), (ne//2, ne//2)))

    x0 = cp.asarray(x[:,:,np.newaxis,np.newaxis])
    phase = cp.exp(-2*cp.pi*1j*(xi1*x0[:, 0]+xi2*x0[:, 1]))

    # same computation for both shift components, only the frequency factor changes
    for comp, xi in enumerate((xi1, xi2)):
        deriv = cp.fft.ifft2(phase*xi*cp.fft.fft2(padded))[:, ne//2:-ne//2, ne//2:-ne//2]
        deriv = Mop(deriv, j)
        gradx[:,comp] = -2*np.pi*imdot(gradF,Lq*deriv,axis=(1,2))

    return gradx

def gradientF(vars,d):
    """Evaluate 2*LT2op(L2psi - d*L2psi/|L2psi|) for every angle and distance.

    Reads 'prb', 'eRu' and 'shift' from vars; d is indexed as [angle, distance, ...].
    Angles are processed on the GPU in batches of `chunk`. Returns a host complex64
    array of shape [ntheta, ndist, n+2*pad, n+2*pad].
    """
    q, eRu, x = vars['prb'], vars['eRu'], vars['shift']
    side = n+2*pad
    res = np.zeros([ntheta,ndist,side,side],dtype='complex64')
    for st in range(0, ntheta, chunk):
        end = min(st+chunk, ntheta)
        eRu_gpu = cp.asarray(eRu[st:end])
        for j in range(ndist):
            shifted = S2op(eRu_gpu, cp.asarray(x[st:end,j]))
            L2psi = L2op(L1op(q, j)*Mop(shifted,j),j)
            d_gpu = cp.asarray(d[st:end,j])
            # data amplitude carried on the phase of the current model
            td = d_gpu*(L2psi/np.abs(L2psi))
            res[st:end,j] = cp.asnumpy(2*LT2op(L2psi - td,j))
    return res

def gradients(vars,gradF):
    """Assemble the gradients with respect to the probe, the object and the shifts.

    Reads 'prb', 'eRu', 'shift' and 'rho' from vars. The probe gradient is scaled by
    rho[0] and accumulated over all angles and distances; the shift gradient is
    scaled by rho[1]; the object contribution is accumulated per angle and then
    mapped through -1j*RTop. Returns a dict with keys 'prb' (cupy), 'u', 'shift'
    and 'Ru' = Rop of the object gradient (host arrays).
    """
    q, eRu, x, rho = vars['prb'], vars['eRu'], vars['shift'], vars['rho']

    side = n+2*pad
    grad_prb = cp.zeros([1,side,side],dtype='complex64')
    grad_u = np.zeros([ntheta,ne,ne],dtype='complex64')
    grad_shift = np.zeros([ntheta,ndist,2],dtype='float32')

    for st in range(0, ntheta, chunk):
        end = min(st+chunk, ntheta)
        eRu_gpu = cp.asarray(eRu[st:end])
        for j in range(ndist):
            x_gpu = cp.asarray(x[st:end,j])
            gradF_gpu = cp.asarray(gradF[st:end,j])
            grad_prb += rho[0]*gradientq(q,x_gpu,gradF_gpu,eRu_gpu,j)
            grad_u[st:end] += cp.asnumpy(gradientu(q,x_gpu,gradF_gpu,eRu_gpu,j))
            grad_shift[st:end,j] = cp.asnumpy(rho[1]*gradientx(q,x_gpu,gradF_gpu,eRu_gpu,j))

    grad_u = -1j*RTop(grad_u)
    return {'prb': grad_prb, 'u': grad_u, 'shift': grad_shift, 'Ru': Rop(grad_u)}

##### $$\frac{1}{2}\mathcal{H}^F|_{x_0}(y,z)= \left\langle \mathbf{1}-d_{0}, \mathsf{Re}({L_2(y)}\overline{L_2(z)})\right\rangle+\left\langle d_{0},(\mathsf{Re} (\overline{l_0}\cdot L_2(y)))\cdot (\mathsf{Re} (\overline{l_0}\cdot L_2(z)))\right\rangle.$$
##### $$l_0=L_2(x_0)/|L_2(x_0)|$$
##### $$d_0=d/|L_2(x_0)|$$


In [None]:
def hessianF(hpsi,hpsi1,hpsi2,data,j):
    """Bilinear Hessian form of the data term at hpsi applied to (hpsi1, hpsi2).

    With L = L2op(., j), l0 = L(hpsi)/|L(hpsi)| and d0 = data/|L(hpsi)|, returns
    2*( sum((1-d0)*reprod(L(hpsi1), L(hpsi2)))
        + sum(d0*reprod(l0, L(hpsi1))*reprod(l0, L(hpsi2))) ).
    """
    Lpsi, Lpsi1, Lpsi2 = (L2op(field,j) for field in (hpsi,hpsi1,hpsi2))
    amplitude = np.abs(Lpsi)
    l0 = Lpsi/amplitude
    d0 = data/amplitude
    cross_term = np.sum((1-d0)*reprod(Lpsi1,Lpsi2))
    phase_term = np.sum(d0*reprod(l0,Lpsi1)*reprod(l0,Lpsi2))
    return 2*(cross_term+phase_term)

\begin{equation*}\begin{aligned}
&\mathcal{H}^H|_{ (q_0,u_0,{x}_0)}\Big((\Delta q^{(1)},\Delta u^{(1)},\Delta {x}^{(1)}),(\Delta q^{(2)},\Delta u^{(2)},\Delta {x}^{(2)})\Big)=\\&\Big\langle \nabla F|_{V(q_0,u_0,{x}_0)}, D^2V|_{(q_0,u_0,{x}_0)}\Big((\Delta q^{(1)},\Delta u^{(1)},\Delta {x}^{(1)}),(\Delta q^{(2)},\Delta u^{(2)},\Delta {x}^{(2)})\Big)\Big\rangle +\\&\mathcal{H}^F|_{V(q_0,u_0,{x}_0)}\Big(DV|_{(q_0,u_0,{x}_0)}(\Delta q^{(1)},\Delta u^{(1)},\Delta {x}^{(1)}),DV|_{(q_0,u_0,{x}_0)}(\Delta q^{(2)},\Delta u^{(2)},\Delta {x}^{(2)})\Big).
\end{aligned}
\end{equation*}

In [None]:
def calc_beta(vars,grads,etas,data,gradF):
    """Conjugate-gradient beta as a ratio of two Hessian forms.

    Direction 1 is the current gradient, direction 2 the previous search direction;
    probe and shift components of both are scaled by rho[0] and rho[1]. Returns
    H(grad, eta) / H(eta, eta), where each form is accumulated over all angle
    batches and distances as redot(gradF, D2V) + hessianF(V, DV, DV).
    """
    q, eRu, x, rho = vars['prb'], vars['eRu'], vars['shift'], vars['rho']
    dq1, Rdu1, dx1 = grads['prb']*rho[0], grads['Ru'], grads['shift']*rho[1]
    dq2, Rdu2, dx2 = etas['prb']*rho[0], etas['Ru'], etas['shift']*rho[1]

    top = 0
    bottom = 0
    for st in range(0, ntheta, chunk):
        end = min(st+chunk, ntheta)
        batch = slice(st, end)
        eRu_gpu = cp.asarray(eRu[batch])
        Rdu1_gpu = cp.asarray(Rdu1[batch])
        Rdu2_gpu = cp.asarray(Rdu2[batch])

        for j in range(ndist):
            data_gpu = cp.asarray(data[batch,j])
            x_gpu = cp.asarray(x[batch,j])
            dx1_gpu = cp.asarray(dx1[batch,j])
            dx2_gpu = cp.asarray(dx2[batch,j])
            gradF_gpu = cp.asarray(gradF[batch,j])

            L1psi = L1op(q,j)*Mop(S2op(eRu_gpu,x_gpu),j)

            dir1 = (dq1,Rdu1_gpu,dx1_gpu)
            dir2 = (dq2,Rdu2_gpu,dx2_gpu)
            dv1 = DV(q,eRu_gpu,x_gpu,*dir1,j)
            d2v1 = D2V(q,eRu_gpu,x_gpu,*dir1,*dir2,j)
            dv2 = DV(q,eRu_gpu,x_gpu,*dir2,j)
            d2v2 = D2V(q,eRu_gpu,x_gpu,*dir2,*dir2,j)

            top += redot(gradF_gpu,d2v1)+hessianF(L1psi,dv1,dv2,data_gpu,j)
            bottom += redot(gradF_gpu,d2v2)+hessianF(L1psi,dv2,dv2,data_gpu,j)
    return float(top/bottom)

def _redot(a,b):
    """Chunked redot: accumulate redot over GPU batches of `chunk` entries along axis 0."""
    total = 0
    nrows = a.shape[0]
    for st in range(0, nrows, chunk):
        end = min(st+chunk, nrows)
        total += redot(cp.asarray(a[st:end]), cp.asarray(b[st:end]))
    return total

def calc_alpha(vars,grads,etas,data,gradF):
    """Step length along the search direction etas from a quadratic model.

    top is minus the inner product of the gradient with the search direction
    (probe, object and shift parts); bottom is the Hessian form H(eta, eta)
    accumulated over all angle batches and distances, with the probe and shift
    components of eta scaled by rho[0] and rho[1].

    Returns (top/bottom, top, bottom) as Python floats.
    """
    (q,eRu,x,rho) = (vars['prb'], vars['eRu'], vars['shift'], vars['rho'])
    (dq1,du1,dx1) = (grads['prb'], grads['u'], grads['shift'])
    (dq2,du2,dx2,Rdu2) = (etas['prb'], etas['u'], etas['shift'], etas['Ru'])

    top = -redot(dq1,dq2)-_redot(du1,du2)-redot(dx1,dx2)

    # Loop-invariant quantities: convert and scale the probe terms once instead of
    # on every chunk/distance iteration.
    q = cp.asarray(q)
    rdq2 = cp.asarray(dq2)*rho[0]

    bottom = 0
    for st in range(0, ntheta, chunk):
        end = min(st+chunk, ntheta)
        eRu_gpu = cp.asarray(eRu[st:end])
        Rdu2_gpu = cp.asarray(Rdu2[st:end])
        for j in range(ndist):
            data_gpu = cp.asarray(data[st:end,j])
            x_gpu = cp.asarray(x[st:end,j])
            gradF_gpu = cp.asarray(gradF[st:end,j])
            rdx2_gpu = cp.asarray(dx2[st:end,j])*rho[1]

            L1psi = L1op(q,j)*Mop(S2op(eRu_gpu,x_gpu),j)
            d2v2 = D2V(q,eRu_gpu,x_gpu,rdq2,Rdu2_gpu,rdx2_gpu,rdq2,Rdu2_gpu,rdx2_gpu,j)
            dv2 = DV(q,eRu_gpu,x_gpu,rdq2,Rdu2_gpu,rdx2_gpu,j)

            bottom += redot(gradF_gpu,d2v2)+hessianF(L1psi,dv2,dv2,data_gpu,j)

    return float(top/bottom), float(top), float(bottom)

In [None]:
def minf(q,u,x,data):
    """Objective value: sum over angles and distances of || |L2psi| - data ||^2.

    L2psi = L2op(L1op(q) * Mop(S2op(exp1j(Rop(u)), x))). Angles are processed on the
    GPU in batches of `chunk`. Returns a Python float.
    """
    eRu = exp1j(Rop(u))
    total = 0
    for st in range(0, ntheta, chunk):
        end = min(st+chunk, ntheta)
        eRu_gpu = cp.asarray(eRu[st:end])
        for j in range(ndist):
            shifted = S2op(eRu_gpu, cp.asarray(x[st:end,j]))
            L2psi = L2op(L1op(q,j)*Mop(shifted,j),j)
            data_gpu = cp.asarray(data[st:end,j])
            total += np.linalg.norm(np.abs(L2psi)-data_gpu)**2
    return float(total)
    
def plot_debug2(vars,etas,top,bottom,alpha,data):
    """Plot the objective along the search direction against its quadratic model.

    The objective minf is evaluated at 2*npp step sizes t = alpha*k/(npp-1)
    (k = 0..2*npp-1) along etas, with the probe and shift components scaled by
    rho[0] and rho[1]. The model curve is minf(current) - top*t + 0.5*bottom*t**2,
    where top and bottom are the values returned by calc_alpha.
    """
    (q,u,x,rho) = (vars['prb'], vars['u'], vars['shift'],vars['rho'])
    (dq2,du2,dx2) = (etas['prb'], etas['u'], etas['shift'])
    npp = 3
    # step sizes shared by the sampled objective and the quadratic model
    t = alpha*(np.arange(2*npp))/(npp-1)

    errt = np.zeros(npp*2)
    for k in range(0,npp*2):
        step = alpha*k/(npp-1)
        ut = u+step*du2
        qt = q+step*dq2*rho[0]
        xt = x+step*dx2*rho[1]
        errt[k] = minf(qt,ut,xt,data)

    errt2 = minf(q,u,x,data)-top*t+0.5*bottom*t**2

    plt.plot(t,errt,'.',label='objective')
    plt.plot(t,errt2,'.',label='quadratic model')
    plt.xlabel('step size')
    plt.legend()
    plt.show()

def plot_debug3(shifts):
    """Plot shifts_correct - shifts for every distance, one panel per shift component.

    shifts_correct is read from the notebook's global scope.
    """
    fig, axs = plt.subplots(1, 2, figsize=(9, 3))
    for k in range(ndist):
        for comp in (0, 1):
            axs[comp].plot(shifts_correct[:,k,comp]-shifts[:,k,comp],'.')
    plt.show()

def vis_debug(vars,i):
    """Show and save the current reconstruction for iteration i.

    Displays two orthogonal central slices of the real part of vars['u'] (cropped to
    the central n x n window) and the probe, then writes the real part of the object
    (central slices and full volume) and the probe phase/amplitude as TIFF files
    named by the zero-padded iteration number.
    """
    u = vars['u']
    prb = vars['prb']
    mid = ne//2
    win = slice(ne//2-n//2, ne//2+n//2)

    mshow_complex(u[mid,win,win].real+1j*
                  u[win,mid,win].real,show)
    mshow_polar(prb[0],show)

    u_re = np.real(u)
    dxchange.write_tiff(u_re[mid],f'/data/vnikitin/phantoms/urech_re{n}/{i:03}',overwrite=True)
    dxchange.write_tiff(u_re[:,mid],f'/data/vnikitin/phantoms/urecv_re{n}/{i:03}',overwrite=True)
    dxchange.write_tiff(u_re,f'/data/vnikitin/phantoms/urec_re{n}/{i:03}',overwrite=True)
    dxchange.write_tiff(cp.asnumpy(np.angle(prb)),f'/data/vnikitin/phantoms/prbrec_angle{n}/{i:03}',overwrite=True)
    # NOTE(review): this writes the same full volume as urec_re above, to a second directory.
    dxchange.write_tiff(u_re,f'/data/vnikitin/phantoms/urec_real{n}/{i:03}',overwrite=True)
    dxchange.write_tiff(cp.asnumpy(np.abs(prb)),f'/data/vnikitin/phantoms/prbrec_abs{n}/{i:03}',overwrite=True)
    
def err_debug(vars, grads, data):
    """Return the current objective value and print the norms of the three gradients."""
    err = minf(vars['prb'], vars['u'], vars['shift'], data)
    norm_prb, norm_u, norm_shift = (np.linalg.norm(grads[key]) for key in ('prb', 'u', 'shift'))
    print(f'gradient norms (prb, u, shift): {norm_prb:.2f}, {norm_u:.2f}, {norm_shift:.2f}',flush=True)
    return err

## Main CG loop (fifth rule)

In [None]:
def cg_holo(data, vars, pars):
    """Joint conjugate-gradient reconstruction of object, probe and shifts.

    Parameters
    ----------
    data : intensity data indexed as [angle, distance, ...]; its square root is
        taken once and used as the amplitude target.
    vars : dict with keys 'u' (object), 'prb' (probe), 'shift' and 'rho' (two
        scaling weights for the probe and shift variables). It is updated in place;
        'Ru' and 'eRu' entries are added.
    pars : dict with 'niter', 'err_step' and 'vis_step'; a step value of -1
        disables the corresponding output.

    Returns
    -------
    (vars, erra) where erra holds the objective value at the iterations where it
    was evaluated (every err_step iterations) and zeros elsewhere.
    """
    data = np.sqrt(data)
    erra = np.zeros(pars['niter'])

    # Keep the projections of u alongside u so R is not re-applied every iteration.
    vars['Ru'] = Rop(vars['u'])
    for i in range(pars['niter']):
        vars['eRu'] = exp1j(vars['Ru'])
        gradF = gradientF(vars,data)
        grads = gradients(vars,gradF)
        if i==0:
            # first iteration: steepest descent direction
            etas = {}
            etas['u'] = -grads['u']
            etas['prb'] = -grads['prb']
            etas['shift'] = -grads['shift']
            etas['Ru'] = -grads['Ru']
        else:
            beta = calc_beta(vars, grads, etas, data, gradF)
            etas['u'] = -grads['u'] + beta*etas['u']
            etas['prb'] = -grads['prb'] + beta*etas['prb']
            etas['shift'] = -grads['shift'] + beta*etas['shift']
            etas['Ru'] = -grads['Ru'] + beta*etas['Ru']

        alpha,top,bottom = calc_alpha(vars, grads, etas, data, gradF)
        if i % pars['vis_step'] == 0 and pars['vis_step'] != -1:
            plot_debug2(vars,etas,top,bottom,alpha,data)
        # Probe and shift steps are additionally scaled by their rho weights.
        vars['u'] += alpha*etas['u']
        vars['Ru'] += alpha*etas['Ru']
        vars['prb'] += vars['rho'][0]*alpha*etas['prb']
        vars['shift'] += vars['rho'][1]*alpha*etas['shift']

        if i % pars['err_step'] == 0 and pars['err_step'] != -1:
            err = err_debug(vars, grads, data)
            # Double-quoted f-string: reusing the outer quote character inside a
            # replacement field is a SyntaxError before Python 3.12.
            print(f"{i}) {vars['rho']} {alpha=:.5f}, {err=:1.5e}",flush=True)
            erra[i] = err

        if i % pars['vis_step'] == 0 and pars['vis_step'] != -1:
            vis_debug(vars, i)
            plot_debug3(vars['shift'])

        # Rebalance rho: halve a weight when its gradient norm exceeds twice the
        # object gradient norm, double it when it falls below half of it.
        t={}

        t[0]=np.linalg.norm(grads['prb'])
        t[1]=np.linalg.norm(grads['shift'])
        t[2]=np.linalg.norm(grads['u'])

        for k in range(2):
            if t[k]>2*t[2]:
                vars['rho'][k]/=2
            elif t[k]<t[2]/2:
                vars['rho'][k]*=2

    return vars,erra

# Initial state for the joint reconstruction.
vars={}
vars['u'] = rec_init.copy()  # object: result of the tomographic initial guess
vars['prb'] = cp.ones([1,n+2*pad,n+2*pad],dtype='complex64')  # probe: constant 1 on the GPU
vars['shift'] = shifts.copy()  # perturbed shifts from the data-loading cell
vars['rho'] = [0.05,0.1]  # initial scaling weights for the probe and shift variables
data_rec = data.copy()

# Print the error every 8 iterations, visualize/save every 16.
pars = {'niter': 2049, 'err_step': 8, 'vis_step': 16}
# %load_ext line_profiler
# %lprun -f cg_holo cg_holo(data_rec, vars, pars)   

vars,erra = cg_holo(data_rec, vars, pars)   