In [None]:
import numpy as np
import cupy as cp
import h5py
from holotomocupy.holo import G, GT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import multiPaganin
from holotomocupy.utils import *
from holotomocupy.proc import remove_outliers
##!jupyter nbconvert --to script config_template.ipynb

# Init data sizes and parameters of the PXM of ID16A

In [None]:
# Acquisition / geometry configuration shared by every cell below.
n = 512  # detector size
# Size of the extended grid used by the operators: n plus a margin of n//4.
ne = n+n//4
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length
focusToDetectorDistance = 1.28  # [m]
# Number of propagation distances (planes).
ndist = 16
#distances = np.array([0.0029432,0.00306911,0.00357247,0.00461673])[:ndist] # [m]
# ndist equally spaced distances between 3e-3 and 5e-3 (inclusive).
distances = np.linspace(3e-3,5e-3,ndist)


magnification = 400
detector_pixelsize = 3.03751e-6
voxelsize = detector_pixelsize/magnification*2048/n  # object voxel size

# Offset of every distance from the last one; used as the propagation
# distance in L1op/L1Top below.
distances2 = distances[-1]-distances
# NOTE(review): absolute, user-specific path; adjust for the local machine.
path = f'/data/vnikitin/modeling/siemens{n}'
# Flag forwarded to the mshow helpers and checked by plot_debug2.
show=True
# By construction every printed entry equals distances[-1].
print(distances+distances2)

## Read data

In [None]:
# Load the arrays stored under `path` directly as CuPy arrays.
data = cp.load(f'{path}/data.npy')
ref = cp.load(f'{path}/ref.npy')
# NOTE(review): psi and prb are loaded here but not read again at module level
# in the visible cells (functions below use their own local psi) - confirm
# whether they are only kept for comparison.
psi = cp.load(f'{path}/psi.npy')
prb = cp.load(f'{path}/prb.npy')
# Data divided by the reference; the 1e-3 offset keeps the denominator away from zero.
rdata = data/(ref+1e-3)
# Display the [0,0] slice of the normalised data.
mshow(rdata[0,0],show,vmax=3)

# Construct operators


In [None]:
def L1op(psi):
    """Apply ``G`` to one field once per plane, using the ``distances2`` offsets.

    A fresh copy of ``psi`` is handed to ``G`` for every plane, and the results
    are stacked along axis 1.

    Returns a complex64 array of shape ``[1, ndist, ne, ne]``.
    """
    stack = cp.zeros([1,ndist, ne, ne], dtype='complex64')
    for k in range(ndist):
        # G is called on a copy, so the caller's array is never passed in directly.
        stack[:, k] = G(cp.array(psi), wavelength, voxelsize, distances2[k],'symmetric')
    return stack

def L1Top(data):
    """Counterpart of ``L1op``: apply ``GT`` per plane and sum over the planes.

    ``data`` is indexed as ``data[:, j]`` for ``j`` in ``range(ndist)``.
    Returns a complex64 array of shape ``[1, ne, ne]``.
    """
    total = cp.zeros([1, ne, ne], dtype='complex64')
    for k in range(ndist):
        plane = cp.array(data[:, k])
        total += GT(plane, wavelength, voxelsize, distances2[k],'symmetric')
    return total

def Lop(psi):
    """Taper, propagate and crop every plane of ``psi``.

    For each plane ``i`` in ``range(ndist)`` a copy of ``psi[:, i]`` is
    multiplied by a separable edge taper, passed through ``G`` with
    ``distances[i]`` and cropped to the central ``n x n`` region.

    Parameters
    ----------
    psi : array indexed as ``psi[:, i]`` with trailing ``(ne, ne)`` axes
        (the adjoint test passes shape ``[1, ndist, ne, ne]``). Not modified.

    Returns
    -------
    complex64 array of shape ``[1, ndist, n, n]``.
    """
    # The taper does not depend on the plane, so it is built once instead of
    # once per loop iteration.
    margin = (ne-n)//2
    v = cp.ones(ne,dtype='float32')
    if margin > 0:
        ramp = cp.linspace(0,1,margin)*cp.pi/2
        # sin ramp 0 -> 1 on the leading margin, cos ramp 1 -> 0 on the trailing one.
        v[:margin] = cp.sin(ramp)
        # Explicit start index: a negative slice would select the whole array
        # for margin == 0 and one element too many for an odd (ne-n).
        v[ne-margin:] = cp.cos(ramp)
    v = cp.outer(v,v)

    lo, hi = ne//2-n//2, ne//2+n//2
    data = cp.zeros([1,ndist, n, n], dtype='complex64')
    for i in range(ndist):
        psir = psi[:,i].copy()  # copy so the in-place taper never touches the input
        psir *= v
        psir = G(psir, wavelength, voxelsize, distances[i],'constant')
        data[:, i] = psir[:,lo:hi,lo:hi]
    return data

def LTop(data):
    """Counterpart of ``Lop``: zero-pad, apply ``GT`` and taper every plane.

    For each plane ``j`` in ``range(ndist)``, ``data[:, j]`` is zero-padded by
    ``ne//2-n//2`` on both sides of its last two axes, cast to complex64,
    passed through ``GT`` with ``distances[j]`` and multiplied by the same
    separable edge taper that ``Lop`` applies.

    Parameters
    ----------
    data : array indexed as ``data[:, j]`` with trailing ``(n, n)`` axes.

    Returns
    -------
    complex64 array of shape ``[1, ndist, ne, ne]``.
    """
    # The taper does not depend on the plane, so it is built once instead of
    # once per loop iteration.
    margin = (ne-n)//2
    v = cp.ones(ne,dtype='float32')
    if margin > 0:
        ramp = cp.linspace(0,1,margin)*cp.pi/2
        v[:margin] = cp.sin(ramp)
        # Explicit start index: a negative slice would select the whole array
        # for margin == 0 and one element too many for an odd (ne-n).
        v[ne-margin:] = cp.cos(ramp)
    v = cp.outer(v,v)

    w = ne//2-n//2
    psi = cp.zeros([1, ndist, ne, ne], dtype='complex64')
    for j in range(ndist):
        datar = cp.array(cp.pad(data[:, j],((0,0),(w,w),(w,w)))).astype('complex64')
        datar = GT(datar, wavelength, voxelsize, distances[j],'constant')
        datar *= v
        psi[:,j] += datar

    return psi

def Cop(psi):
    """Return the central ``n x n`` window of the last two axes of ``psi``."""
    lo = ne//2-n//2
    hi = ne//2+n//2
    return psi[:,lo:hi,lo:hi]

def CTop(psi):
    """Zero-pad the last two axes of ``psi`` by ``ne//2-n//2`` on every side."""
    w = ne//2-n//2
    widths = ((0,0),(w,w),(w,w))
    return cp.pad(psi,widths)

# Adjoint tests: for an operator A with adjoint AT, <x, AT(A(x))> must equal
# <A(x), A(x)>.  Each print shows both inner products, one per line, so they
# can be compared by eye (inputs are random, so the values change per run).

# L1op / L1Top on a random complex [1,ne,ne] field.
arr1 = (cp.random.random([1,ne,ne])+1j*cp.random.random([1,ne,ne])).astype('complex64')
arr2 = L1op(arr1)
arr3 = L1Top(arr2)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')

# Lop / LTop on a random complex [1,ndist,ne,ne] stack.
arr1 = (cp.random.random([1,ndist,ne,ne])+1j*cp.random.random([1,ndist,ne,ne])).astype('complex64')
arr2 = Lop(arr1)
arr3 = LTop(arr2)


print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')


## Reconstruction by minimizing
### $$G(\psi,q) = F(M(\psi,q)),\qquad F(x) = \||L(x)| -d\|^2_2,\qquad M(\psi,q) = L_1(q)\cdot(C^T(\psi)+\psi_{fr})$$


## Gradients

#### $$\nabla F=2 \left(L^*\left( L(\psi)-\tilde d\right)\right).$$
#### where $$\tilde d = d \frac{L(\psi_0)}{|L(\psi_0)|}$$

#### $$\nabla G_{\psi} = C\left(\overline{L_1(q)}\,\nabla F|_{M(\psi_0,q_0)}\right) $$
#### $$\nabla G_{q} = L_1^T\left(\overline{(C^T(\psi)+\psi_{fr})}\,\nabla F|_{M(\psi_0,q_0)}\right) $$


In [None]:
def gradientF(vars,d):
    """Evaluate ``2*LTop(Lpsi - td)`` at the current object/probe estimate.

    ``Lpsi`` is ``Lop`` applied to ``L1op(prb) * (CTop(psi) + psifr)`` and
    ``td`` is ``d`` carrying the phase of ``Lpsi``.

    Parameters
    ----------
    vars : dict providing 'psi', 'psifr' and 'prb'.
    d : array matching the output of ``Lop`` (amplitude data as used by cg_holo).
    """
    psi = vars['psi']
    psifr = vars['psifr']
    q = vars['prb']
    probe_planes = L1op(q)
    extended = CTop(psi)+psifr
    Lpsi = Lop(probe_planes*extended)
    # Unit-modulus phase factor of the propagated field.
    phase = Lpsi/np.abs(Lpsi)
    return 2*LTop(Lpsi - d*phase)

def gradientq(psi,psifr,gradF):
    """Gradient with respect to the probe.

    Multiplies ``gradF`` by the conjugate of ``CTop(psi) + psifr``, applies
    ``L1Top``, sums over axis 0 and restores a leading axis of length 1.
    """
    weighted = np.conj(CTop(psi)+psifr)*gradF
    back = L1Top(weighted)
    return np.sum(back,axis=0)[np.newaxis]

def gradientpsi(q,gradF):
    """Gradient with respect to the object.

    Multiplies ``gradF`` by the conjugate of ``L1op(q)``, sums over axis 1
    (the plane axis of ``L1op``'s output) and crops with ``Cop``.
    """
    summed = np.sum(np.conj(L1op(q))*gradF,axis=1)
    return Cop(summed)

def gradients(vars,gradF):
    """Build the dict of gradients for the object and the probe.

    The 'psi' entry is scaled by ``rho[0]`` and the 'prb' entry by ``rho[1]``.

    Parameters
    ----------
    vars : dict providing 'psi', 'psifr', 'prb' and 'rho'.
    gradF : array as returned by ``gradientF``.
    """
    psi, psifr, q, rho = vars['psi'], vars['psifr'], vars['prb'], vars['rho']
    return {
        'psi': rho[0]*gradientpsi(q,gradF),
        'prb': rho[1]*gradientq(psi,psifr,gradF),
    }

##### $$\frac{1}{2}\mathcal{H}|_{\psi_0}(\Delta\psi^{(1)},\Delta\psi^{(2)})= \left\langle \mathbf{1}-d_{0}, \mathsf{Re}({L(\Delta\psi^{(1)})}\overline{L(\Delta\psi^{(2)})})\right\rangle+\left\langle d_{0},(\mathsf{Re} (\overline{l_0}\cdot L(\Delta\psi^{(1)})))\cdot (\mathsf{Re} (\overline{l_0}\cdot L(\Delta\psi^{(2)}))) \right\rangle $$
##### 
##### $$l_0=L(\psi_0)/|L(\psi_0)|$$
##### $$d_0=d/|L(\psi_0)|$$


##### $$DM|_{\psi_0,q_0}(\Delta\psi,\Delta q) = L_1(q_0)\cdot C^T(\Delta\psi)+L_1(\Delta q)\cdot (C^T(\psi_0)+\psi_{fr})$$

##### $$D^2M|_{\psi_0,q_0}(\Delta\psi^{(1)},\Delta q^{(1)},\Delta\psi^{(2)},\Delta q^{(2)}) = L_1(\Delta q^{(1)})\cdot C^T(\Delta \psi^{(2)})+L_1(\Delta q^{(2)})\cdot C^T(\Delta \psi^{(1)})$$

\begin{equation}\begin{aligned}
H^G|_{\psi_0,q_0}(\Delta  \psi^{(1)},\Delta q^{(1)},\Delta \psi^{(2)},\Delta q^{(2)})=&\Big\langle \nabla F|_{M({\psi_0,q_0})}, D^2M|_{{\psi_0,q_0}}(\Delta\psi^{(1)},\Delta q^{(1)},\Delta\psi^{(2)},\Delta q^{(2)})\Big\rangle +\\&H^F|_{M({\psi_0,q_0})}\Big(DM|_{\psi_0,q_0}(\Delta \psi^{(1)},\Delta q^{(1)}),DM|_{\psi_0,q_0}(\Delta \psi^{(2)},\Delta q^{(2)})\Big).
\end{aligned}
\end{equation}

In [None]:
def hessianF(psi,dpsi1,dpsi2,data):
    """Bilinear Hessian form of F at ``psi`` for directions ``dpsi1``, ``dpsi2``.

    With ``l0 = Lop(psi)/|Lop(psi)|`` and ``d0 = data/|Lop(psi)|`` the value is
    ``2*(sum((1-d0)*reprod(L1, L2)) + sum(d0*reprod(l0, L1)*reprod(l0, L2)))``
    where ``L1 = Lop(dpsi1)`` and ``L2 = Lop(dpsi2)``.
    """
    Lpsi = Lop(psi)
    Ld1 = Lop(dpsi1)
    Ld2 = Lop(dpsi2)
    l0 = Lpsi/np.abs(Lpsi)
    d0 = data/np.abs(Lpsi)
    first = np.sum((1-d0)*reprod(Ld1,Ld2))
    second = np.sum(d0*reprod(l0,Ld1)*reprod(l0,Ld2))
    return 2*(first+second)

def DM(psi,q,psifr,dpsi,dq):
    """First differential of ``M(psi,q) = L1op(q)*(CTop(psi)+psifr)``.

    Returns ``L1op(q)*CTop(dpsi) + L1op(dq)*(CTop(psi)+psifr)``.
    """
    probe_term = L1op(q)
    dprobe_term = L1op(dq)
    wrt_object = probe_term*CTop(dpsi)
    wrt_probe = dprobe_term*(CTop(psi)+psifr)
    return wrt_object+wrt_probe

def D2M(dpsi1,dpsi2,dq1,dq2):
    """Second differential of M: the two cross terms between probe and object.

    Returns ``L1op(dq1)*CTop(dpsi2) + L1op(dq2)*CTop(dpsi1)``.
    """
    first_probe = L1op(dq1)
    second_probe = L1op(dq2)
    return first_probe*CTop(dpsi2)+second_probe*CTop(dpsi1)

In [None]:
def calc_beta(vars,grads,etas,d,gradF):
    """Conjugate-direction coefficient ``beta`` as a ratio of two Hessian forms.

    Both the gradient and the previous direction are scaled by ``rho`` before
    use. The numerator is the Hessian form of (gradient, direction) and the
    denominator that of (direction, direction); each is
    ``redot(gradF, D2M(...)) + hessianF(M, DM(...), DM(...), d)``.

    Parameters
    ----------
    vars : dict providing 'psi', 'psifr', 'prb' and 'rho'.
    grads, etas : dicts with 'psi' and 'prb' entries.
    d : data passed through to ``hessianF``.
    gradF : array as returned by ``gradientF``.
    """
    psi, psifr, q, rho = vars['psi'], vars['psifr'], vars['prb'], vars['rho']

    # rho-scaled gradient (index 1) and previous search direction (index 2).
    g_psi, g_q = rho[0]*grads['psi'], rho[1]*grads['prb']
    e_psi, e_q = rho[0]*etas['psi'], rho[1]*etas['prb']

    dm_grad = DM(psi,q,psifr,g_psi,g_q)
    dm_eta = DM(psi,q,psifr,e_psi,e_q)

    d2m_mixed = D2M(g_psi,e_psi,g_q,e_q)
    d2m_eta = D2M(e_psi,e_psi,e_q,e_q)

    model = L1op(q)*(CTop(psi)+psifr)

    numerator = redot(gradF,d2m_mixed)+hessianF(model, dm_grad, dm_eta, d)
    denominator = redot(gradF,d2m_eta)+hessianF(model, dm_eta, dm_eta, d)

    return numerator/denominator

def calc_alpha(vars,grads,etas,d,gradF):
    """Step length ``alpha`` along the search direction ``etas``.

    ``top`` is minus the real inner product of the (unscaled) gradients with
    the (unscaled) direction; ``bottom`` is the Hessian form of the rho-scaled
    direction with itself.

    Returns
    -------
    (top/bottom, top, bottom)
    """
    psi, psifr, q, rho = vars['psi'], vars['psifr'], vars['prb'], vars['rho']

    top = -redot(grads['psi'],etas['psi'])-redot(grads['prb'],etas['prb'])

    # From here on the direction is used in its rho-scaled form.
    e_psi = rho[0]*etas['psi']
    e_q = rho[1]*etas['prb']
    dm_eta = DM(psi,q,psifr,e_psi,e_q)
    model = L1op(q)*(CTop(psi)+psifr)
    d2m_eta = D2M(e_psi,e_psi,e_q,e_q)

    bottom = redot(gradF,d2m_eta)+hessianF(model, dm_eta, dm_eta, d)
    return top/bottom, top, bottom

### Initial guess for reconstruction (Paganin)

In [None]:
def rec_init(rdata):
    """Initial guess: per-plane ``multiPaganin`` result averaged over planes.

    ``multiPaganin`` is called once per plane with that single plane and its
    single distance; the ``ndist`` results are averaged and returned as
    ``exp(1j * mean)``, so the output has unit modulus.

    Parameters
    ----------
    rdata : array sliced as ``rdata[:, j:j+1]`` for ``j`` in ``range(ndist)``.
    """
    per_plane = cp.zeros([1,ndist,n,n],dtype='float32')
    for k in range(ndist):
        plane = cp.array(rdata[:,k:k+1])
        # NOTE(review): 24.05 and 1e-6 are passed positionally to multiPaganin;
        # their meaning is not visible in this file - confirm against its signature.
        per_plane[:,k] = multiPaganin(plane,
                            distances[k:k+1], wavelength, voxelsize,  24.05, 1e-6)

    mean_phase = np.sum(per_plane,axis=1)/ndist
    return np.exp(1j*mean_phase)

# Initial guess from the reference-normalised data.
rec_paganin = rec_init(rdata)
# Extend to the ne x ne grid, filling the margin with 1.
rec_paganin = np.pad(rec_paganin,((0,0),(ne//2-n//2,ne//2-n//2),(ne//2-n//2,ne//2-n//2)),'constant',constant_values=1)
# Show the whole field, then a fixed 256 x 256 window around the centre.
mshow_polar(rec_paganin[0],show)
mshow_polar(rec_paganin[0,ne//2-128:ne//2+128,ne//2-128:ne//2+128],show)

## debug functions

In [None]:
def plot_debug2(vars,etas,top,bottom,alpha,data):
    """Plot the true error along the search direction against its quadratic model.

    The error ``|| |Lop(L1op(q)*(CTop(psi)+psifr))| - data ||^2`` is evaluated
    at 14 step sizes ``alpha*k/6`` (k = 0..13) along the rho-scaled direction
    and plotted together with ``err(0) - top*t + 0.5*bottom*t**2``.
    Does nothing when the module-level ``show`` flag is False.

    Parameters
    ----------
    vars : dict providing 'psi', 'psifr', 'prb' and 'rho' (not modified).
    etas : dict with the search direction under 'psi' and 'prb'.
    top, bottom, alpha : values returned by ``calc_alpha``; ``alpha`` must
        offer ``.get()`` (a CuPy scalar).
    data : amplitudes compared against the modelled field.
    """
    if show==False:
        return
    (psi,psifr,q,rho) = (vars['psi'],vars['psifr'],vars['prb'],vars['rho'])
    (dpsi2,dq2) = (etas['psi'],etas['prb'])
    npp = 7
    errt = cp.zeros(npp*2)
    for k in range(0,npp*2):
        step = alpha*k/(npp-1)
        psit = psi+step*rho[0]*dpsi2
        qt = q+step*rho[1]*dq2
        fpsit = np.abs(Lop(L1op(qt)*(CTop(psit)+psifr)))-data
        errt[k] = np.linalg.norm(fpsit)**2

    # Quadratic model of the error along the same line (no preallocation
    # needed: the expression creates the array).
    t = alpha*(cp.arange(2*npp))/(npp-1)
    tmp = np.abs(Lop(L1op(q)*(CTop(psi)+psifr)))-(data)
    errt2 = np.linalg.norm(tmp)**2-top*t+0.5*bottom*t**2

    # Host-side x axis, computed once and shared by both curves.
    steps_host = alpha.get()*cp.arange(2*npp).get()/(npp-1)
    plt.plot(steps_host,errt.get(),'.')
    plt.plot(steps_host,errt2.get(),'.')
    plt.show()


def vis_debug(vars,data,i):
    """Display the current extended object, the probe and one object sub-window.

    ``data`` and ``i`` are accepted but not used.
    """
    extended = CTop(vars['psi'])+vars['psifr']
    probe = vars['prb']
    rows = slice(ne//2-n//4, ne//2+n//4)
    cols = slice(ne//2+n//4, ne//2+n//2+n//4)
    mshow_polar(extended[0],show)
    mshow_polar(probe[0],show)
    mshow_polar(extended[0,rows,cols],show)
    
def err_debug(vars, grads, data):
    """Squared norm of ``|Lop(L1op(prb)*(CTop(psi)+psifr))| - data``.

    ``grads`` is accepted but not used.
    """
    psi, psifr, q = vars['psi'], vars['psifr'], vars['prb']
    modelled = np.abs(Lop(L1op(q)*(CTop(psi)+psifr)))
    residual = modelled-(data)
    return np.linalg.norm(residual)**2

# Main CG loop (fifth rule)

In [None]:
def cg_holo(data, vars, pars):
    """Conjugate-gradient loop that jointly updates the object and the probe.

    Parameters
    ----------
    data : array
        Measured data; its square root is taken on entry and used as the
        target of the fit.
    vars : dict
        Must provide 'psi', 'prb', 'psifr' and 'rho'.  'psi' and 'prb' are
        updated in place on every iteration and ``vars['rho'][1]`` is rescaled
        in place, so the caller's objects are modified.
    pars : dict
        'niter'    - number of iterations;
        'err_step' - period of error evaluation/printing, -1 disables it;
        'vis_step' - period of visualisation, -1 disables it.

    Returns
    -------
    (vars, erra, alphaa)
        ``erra`` and ``alphaa`` have length ``niter`` and are only filled at
        the iterations where the error is evaluated (zero elsewhere).
    """

    data = np.sqrt(data)
    erra = cp.zeros(pars['niter'])
    alphaa = cp.zeros(pars['niter'])
    grads ={}
    for i in range(pars['niter']):
        vis_now = pars['vis_step'] != -1 and i % pars['vis_step'] == 0
        if vis_now:
            vis_debug(vars, data, i)
        gradF = gradientF(vars,data)
        grads = gradients(vars,gradF)
        if i==0:
            # First iteration: plain steepest descent direction.
            etas = {}
            etas['psi'] = -grads['psi']
            etas['prb'] = -grads['prb']
        else:
            beta = calc_beta(vars, grads, etas, data, gradF)
            etas['psi'] = -grads['psi'] + beta*etas['psi']
            etas['prb'] = -grads['prb'] + beta*etas['prb']
        alpha,top,bottom = calc_alpha(vars, grads, etas, data, gradF)

        # The line-search plot follows the visualisation cadence.  It used to
        # test only err_step != -1, so vis_step == -1 (i % -1 is always 0)
        # plotted on every iteration; it now requires both switches to be on.
        if vis_now and pars['err_step'] != -1:
            plot_debug2(vars,etas,top,bottom,alpha,data)

        vars['psi'] += alpha*vars['rho'][0]*etas['psi']
        vars['prb'] += alpha*vars['rho'][1]*etas['prb']

        if i % pars['err_step'] == 0 and pars['err_step'] != -1:
            err = err_debug(vars, grads, data)
            # Double-quoted outer string: reusing the outer quote inside a
            # replacement field is a SyntaxError before Python 3.12.
            print(f"{i}) {alpha=:.5f}, {vars['rho']=}, {err=:1.5e}",flush=True)
            erra[i] = err
            alphaa[i] = alpha
        # Rebalance the probe scaling so that the two gradient norms stay
        # within a factor of two of each other.
        t={}
        t[0]=np.linalg.norm(grads['psi'])
        t[1]=np.linalg.norm(grads['prb'])
        #print(t)
        for k in range(1,2):
            if t[k]>2*t[0]:
                vars['rho'][k]/=2
            elif t[k]<t[0]/2:
                vars['rho'][k]*=2
    return vars,erra,alphaa

# State dictionary for cg_holo.
# NOTE(review): the name `vars` shadows the Python builtin of the same name.
vars = {}
# Object: central n x n part of the padded Paganin guess.
vars['psi'] = cp.array(rec_paganin)[:,(ne-n)//2:(ne+n)//2,(ne-n)//2:(ne+n)//2]
# Probe: initialised to ones on the extended grid.
vars['prb'] = cp.ones([1,ne,ne],dtype='complex64')
# Frame: ones on the margin, zeros in the central n x n region, so that
# CTop(psi)+psifr equals psi in the centre and 1 outside.
vars['psifr'] = cp.ones([1,ne,ne],dtype='complex64')
vars['psifr'][:,(ne-n)//2:(ne+n)//2,(ne-n)//2:(ne+n)//2] = 0
# Scaling factors for [object, probe]; cg_holo adapts the second entry.
vars['rho'] = [1,1]
data_rec = cp.array(data)


# -1 for 'err_step' / 'vis_step' disables the corresponding output in cg_holo.
pars = {'niter': 4097, 'err_step':128, 'vis_step': 128}
vars,erra,alphaa = cg_holo(data_rec, vars, pars)    

In [None]:
# Final result: extended object (object embedded in the frame) and probe.
psie = CTop(vars['psi'])+vars['psifr']
q = vars['prb']
mshow_polar(psie[0],show)    
mshow_polar(q[0],show)    
# Same sub-window as in vis_debug, here with the display range fixed to [-1, 1].
mshow_polar(psie[0,ne//2-n//4:ne//2+n//4,ne//2+n//4:ne//2+n//2+n//4],show,vmax=1,vmin=-1)    
    