In [None]:
import numpy as np
import cupy as cp
from holotomocupy.holo import G, GT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import multiPaganin
from holotomocupy.utils import *

# Init data sizes and parameters of the PXM of ID16A

In [None]:
# --- grid sizes ---
n, pad, npos = 1024, 512, 16      # object size per dimension, padding per side, number of positions
ne = n + 2*pad                    # padded object size

# --- beam and geometry ---
detector_pixelsize = 3.03751e-6
energy = 33.35                                    # [keV] x-ray energy
wavelength = 1.2398419840550367e-09 / energy      # [m] wavelength
focusToDetectorDistance = 1.28                    # [m]
sx0 = 1.286e-3
# the same focus-to-sample distance is repeated for every position
z1 = np.full(npos, 5.5e-3 - sx0)
z2 = focusToDetectorDistance - z1
distances = z1 * z2 / focusToDetectorDistance
magnifications = focusToDetectorDistance / z1
# object voxel size: detector pixel divided by the magnification of the first position
voxelsize = np.abs(detector_pixelsize / magnifications[0])

# --- output control ---
show = True                       # plotting switch read by the debug/visualisation helpers
path = f'/data/vnikitin/modeling/siemens{n}'
path_out = f'/data/vnikitin/modeling/siemens_recn_v1{n}_{pad}'


# Construct operators


In [None]:
def Lop(psi):
    """Forward operator: apply G per position and keep the central n x n window.

    `psi` is indexed as psi[:, k] with k running over axis 1; each slice is
    passed to G (imported from holotomocupy.holo) with the k-th entry of
    `distances`, and the result is cropped from the padded size `ne` to `n`.
    Returns a complex64 array of shape [*psi.shape[:2], n, n].
    """
    lo, hi = ne//2 - n//2, ne//2 + n//2
    out = cp.zeros([*psi.shape[:2], n, n], dtype='complex64')
    for k in range(psi.shape[1]):
        propagated = G(cp.array(psi[:, k]), wavelength, voxelsize, distances[k], 'constant')
        out[:, k] = propagated[:, lo:hi, lo:hi]
    return out

def LTop(data):
    """Counterpart of Lop: zero-pad each n x n slice back to ne x ne and apply GT.

    Returns a complex64 array of shape [*data.shape[:2], ne, ne].
    """
    margin = ne//2 - n//2
    widths = ((0, 0), (margin, margin), (margin, margin))
    out = cp.zeros([*data.shape[:2], ne, ne], dtype='complex64')
    for k in range(data.shape[1]):
        padded = cp.array(cp.pad(data[:, k], widths)).astype('complex64')
        out[:, k] = GT(padded, wavelength, voxelsize, distances[k], 'constant')
    return out

def Sop(psi,shifts):
    """Build one shifted copy of `psi` per position.

    For every position k in range(npos) the object is passed to S together
    with shifts[:, k] (mode='constant'); the results are stacked along axis 1
    of a complex64 array of shape [1, npos, ne, ne].
    """
    obj = cp.array(psi)
    out = cp.zeros([1, npos, ne, ne], dtype='complex64')
    for k in range(npos):
        # S receives a private copy so the source object is never touched
        out[:, k] = S(obj.copy(), cp.array(shifts[:, k]), mode='constant')
    return out

def STop(data,shifts):
    """Counterpart of Sop: apply ST to every position slice and sum the results.

    Returns a complex64 array of shape [1, ne, ne].
    """
    acc = cp.zeros([1, ne, ne], dtype='complex64')
    for k in range(npos):
        acc += ST(cp.array(data[:, k]), cp.array(shifts[:, k]), mode='constant')
    return acc

# adjoint tests: for an operator A with adjoint A*, <x, A*Ax> must equal <Ax, Ax>.
# Checked first for the shift pair (Sop/STop), then for the propagation pair (Lop/LTop).
arr1 = cp.random.random([1,ne,ne]).astype('float32')+1j*cp.random.random([1,ne,ne]).astype('float32')
prb1 = cp.ones([1,n+2*pad,n+2*pad],dtype='complex64')
shifts = cp.random.random([1,npos,2]).astype('float32')

arr2 = Sop(arr1,shifts)    # S x
arr3 = STop(arr2,shifts)   # S* S x
arr4 = Lop(arr2)           # L S x
arr5 = LTop(arr4)          # L* L S x

# each printed pair should agree up to floating-point error
for vec, roundtrip, image in ((arr1, arr3, arr2), (arr2, arr5, arr4)):
    print(f'{np.sum(vec*np.conj(roundtrip))}==\n{np.sum(image*np.conj(image))}')

In [None]:
import cv2
import xraylib
# Build a synthetic Siemens-star object, then (further down) replace it by an
# object derived from a previously reconstructed phase image read from disk.

# Work temporarily on a grid twice as large; `ne` is restored by `ne//=2` below.
ne*=2
img = np.zeros((ne, ne, 3), np.uint8)
# one thin spoke (triangle) pointing towards the image centre
triangle = np.array([(ne//16, ne//2-ne//64), (ne//16, ne//2+ne//64), (ne//2-ne//64, ne//2)], np.float32)
star = img[:,:,0]*0
# rotate the spoke about the centre in 8-degree steps and accumulate the filled polygons
for i in range(0, 360, 8):
    img = np.zeros((ne, ne, 3), np.uint8)
    degree = i
    theta = degree * np.pi / 180
    rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta), np.cos(theta)]], np.float32)    
    # (triangle - centre) @ rot_mat.T + centre
    rotated = cv2.gemm(triangle-ne//2, rot_mat, 1, None, 1, flags=cv2.GEMM_2_T)+ne//2
    cv2.fillPoly(img, [np.int32(rotated)], (255, 0, 0))
    star+=img[:,:,0]
# normalised coordinates in [-1, 1)
[x,y] = np.meshgrid(np.arange(-ne//2,ne//2),np.arange(-ne//2,ne//2))
x = x/ne*2
y = y/ne*2
# add holes in triangles
# each factor is a boolean mask that is False only inside a thin ring
# (`+` on boolean arrays acts as OR, `*=` as AND)
circ = (x**2+y**2>0.145)+(x**2+y**2<0.14)
circ *= (x**2+y**2>0.053)+(x**2+y**2<0.051)
circ *= (x**2+y**2>0.0085)+(x**2+y**2<0.008)
circ *= (x**2+y**2>0.3)+(x**2+y**2<0.295)
# circ *= (x**2+y**2>0.62)+(x**2+y**2<0.6)
# discard everything outside radius 0.8
circ *= (x**2+y**2<0.8**2)
# filled pixels were drawn with value 255, so this rescales them to 1
star = star*circ/255
# keep the central half of the doubled grid, then restore `ne`
star = star[ne//4:-ne//4,ne//4:-ne//4]
ne//=2


# star[ne//2-n//2:ne//2+n//2,ne//2-n//2:ne//2+n//2]=0

# smooth the star by multiplying its centred spectrum with a Gaussian window
v = np.arange(-ne//2,ne//2)/ne
[vx,vy] = np.meshgrid(v,v)
v = np.exp(-5*(vx**2+vy**2))
fu = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(star)))
star = np.fft.fftshift(np.fft.ifftn(np.fft.fftshift(fu*v))).real

# refractive index decrement / absorption index of Au at `energy` (density 19.3)
delta = 1-xraylib.Refractive_Index_Re('Au',energy,19.3)
beta = xraylib.Refractive_Index_Im('Au',energy,19.3)

print(delta/beta)
thickness = 2e-6/voxelsize/10 # siemens star thickness in pixels
# form Transmittance function
u = star*(-delta+1j*beta) # note -delta
Ru = u*thickness 
psi = np.exp(1j * Ru * voxelsize * 2 * np.pi / wavelength)[np.newaxis].astype('complex64')
mshow_polar(psi[0],show)

# NOTE(review): from here on the synthetic `psi` above is discarded and replaced
# by an object built from a reconstructed phase image stored at an absolute path.
# `psi_abs` is read but only referenced in commented-out code on the next statement.
psi_abs = dxchange.read_tiff(f'/data/vnikitin/ESRF/ID16A/20240924_rec/SiemensLH/SiemensLH_010nm_nfp_reg_new_pos0_crop768_1.0e-01_0.0e+00_2048_768/crec_code_abs2048/1024.tiff')
psi_angle = dxchange.read_tiff(f'/data/vnikitin/ESRF/ID16A/20240924_rec/SiemensLH/SiemensLH_010nm_nfp_reg_new_pos0_crop768_1.0e-01_0.0e+00_2048_768/crec_code_angle2048/1024.tiff')

# amplitude and phase are both derived from the loaded phase image
psi = (1+psi_angle/24)*np.exp(1j*(psi_angle-0.07)/2)#/np.mean(psi_abs)#-np.amax(psi_angle)
# central 2048 x 2048 crop with a leading axis added.
# NOTE(review): the half-width 1024 is hard-coded (equals ne//2 for the current n/pad)
# and both axes are centred using psi.shape[1] — assumes a square image; confirm.
psi = psi[np.newaxis,psi.shape[1]//2-1024:psi.shape[1]//2+1024,psi.shape[1]//2-1024:psi.shape[1]//2+1024]
mshow_polar(psi[0],show)

In [None]:
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_abs_2048.tiff -P ../data/prb_id16a
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_phase_2048.tiff -P ../data/prb_id16a

import scipy.ndimage as ndimage
# Load the probe (amplitude and phase stored as separate TIFFs); keep the first slice only.
prb_abs = dxchange.read_tiff(f'../data/prb_id16a/prb_abs_2048.tiff')[0:1]
prb_phase = dxchange.read_tiff(f'../data/prb_id16a/prb_phase_2048.tiff')[0:1]
prb = prb_abs*np.exp(1j*prb_phase).astype('complex64')

# Rescale the 2048-pixel probe to 2*n pixels, then crop the central (n+2*pad) window.
prb = ndimage.zoom(prb,(1,2*n/2048,2*n/2048))
prb = prb[:,n-n//2-pad:n+n//2+pad,n-n//2-pad:n+n//2+pad]

# Normalise to unit mean amplitude.
prb /= np.mean(np.abs(prb))
# prb[:]=1

# Show amplitude and phase side by side. The parent Axes of every colorbar is
# passed explicitly: Matplotlib releases that default to the current Axes would
# otherwise take the space for both bars from the right-hand panel.
fig, axs = plt.subplots(1, 2, figsize=(9, 4))
im=axs[0].imshow(np.abs(prb[0]),cmap='gray')
axs[0].set_title('abs prb')
fig.colorbar(im, ax=axs[0])
im=axs[1].imshow(np.angle(prb[0]),cmap='gray')
axs[1].set_title('angle prb')
fig.colorbar(im, ax=axs[1])

In [None]:
import h5py
# Read the scan positions from the acquisition file header and model the data.
with h5py.File(f'/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/SiemensLH_010nm_nfp_01/SiemensLH_010nm_nfp_010000.h5','r') as fid:
    # The header entry is converted to text, its leading/trailing wrapper characters
    # are sliced off, and the space-separated numbers are parsed as float32.
    # NOTE(review): presumably the stored value is a bytes string such as b'1.0 2.0 ...'
    # and the 1e-6/voxelsize factor converts micrometres to pixels — confirm.
    spz = np.array(str(np.array(str(np.array(fid['/entry_0000/instrument/PCIe/header/spz']))[1:]))[1:-1].split(' '),dtype='float32')*1e-6/voxelsize
    spy = np.array(str(np.array(str(np.array(fid['/entry_0000/instrument/PCIe/header/spy']))[1:]))[1:-1].split(' '),dtype='float32')*1e-6/voxelsize

# shifts[..., 1] takes spy and shifts[..., 0] takes -spz for the first npos positions
shifts_code0 = np.zeros([1,npos,2],dtype='float32')
shifts_code0[:,:,1] = spy[:npos]
shifts_code0[:,:,0] = -spz[:npos]
shifts=cp.array(shifts_code0)
#shifts = (cp.random.random(shifts.shape).astype('float32')-0.5)*200
psi = cp.array(psi)
prb = cp.array(prb)
# modelled intensities: |L(prb * S(psi))|^2
data = np.abs(Lop(prb*Sop(psi,shifts)))**2
# reference intensities: same model with the object replaced by ones
ref = np.abs(Lop(prb*Sop(psi*0+1,shifts)))**2
# drop the leading axis of the reference
ref = ref[0]

# data normalised by the reference; the small constant guards against division by zero
rdata = data/(ref+1e-11)#/0.76**2

mshow_complex(data[0,0]+1j*rdata[0,0],show)#,show,vmax=1.3,vmin=0.8)
mshow(ref[0],show)

## Reconstruction with the CG (Carlsson) with Hessians

## Gradients

#### $$\nabla F=2 \left(L^*\left( L(M(q_0,\psi_0,\boldsymbol{x}_0))-\tilde D\right)\right).$$
#### where $$\tilde D = D \frac{L(M(q_0,\psi_0,\boldsymbol{x}_0))}{|L(M(q_0,\psi_0,\boldsymbol{x}_0))|}$$



In [None]:
def gradientF(vars,d):
    """Gradient of the data-fit term with respect to the modelled wave field.

    Evaluates 2 * LTop(Lpsi - d * Lpsi/|Lpsi|) with Lpsi = Lop(Sop(psi, x) * q),
    where psi, q, x are taken from vars['psi'], vars['prb'], vars['shift'].
    """
    wave = Lop(Sop(vars['psi'], vars['shift'])*vars['prb'])
    # d carrying the phase of the current model (the D-tilde of the formula above)
    projected = d*(wave/np.abs(wave))
    return 2*LTop(wave - projected)



##### $$\nabla_{\psi} G|_{(q_0,\psi_0,\boldsymbol{x}_0)}=S_{\boldsymbol{x}_{0}}^*\left(\overline{J(q_0)}\cdot \nabla F\right).$$
##### $$\nabla_{q} G|_{(q_0,\psi_0,\boldsymbol{x}_0)}=J^*\left( \overline{S_{\boldsymbol{x}_{0}}(\psi_0)}\cdot \nabla F\right).$$



In [None]:
def gradientpsi(q,x,gradF):
    """Gradient with respect to the object: STop(conj(q) * gradF, x)."""
    weighted = np.conj(q)*gradF
    return STop(weighted, x)

def gradientq(psi,x,gradF):
    """Gradient with respect to the probe: sum over axis 1 of conj(Sop(psi, x)) * gradF."""
    shifted_conj = np.conj(Sop(psi,x))
    return np.sum(shifted_conj*gradF, axis=1)

def gradients(vars,d,gradF):
    """Collect the object and probe gradients in a dict with keys 'psi' and 'prb'.

    `d` is accepted for signature symmetry with the other helpers but is not used.
    """
    psi, q, x = vars['psi'], vars['prb'], vars['shift']
    return {
        'psi': gradientpsi(q, x, gradF),
        'prb': gradientq(psi, x, gradF),
    }



##### $$\frac{1}{2}\mathcal{H}|_{x_0}(y,z)= \left\langle \mathbf{1}-d_{0}, \mathsf{Re}({L(y)}\overline{L(z)})\right\rangle+\left\langle d_{0},(\mathsf{Re} (\overline{l_0}\cdot L(y)))\cdot (\mathsf{Re} (\overline{l_0}\cdot L(z))) \right\rangle$$
##### $$l_0=L(x_0)/|L(x_0)|$$
##### $$d_0=d/|L(x_0)|$$


In [None]:
def hessianF(hpsi,hpsi1,hpsi2,data):
    """Bilinear Hessian form of the data-fit term at `hpsi`, applied to (hpsi1, hpsi2).

    Follows the formula above: with l0 = L/|L| and d0 = data/|L| (L = Lop(hpsi)),
    returns 2 * ( sum((1-d0) * reprod(L1, L2)) + sum(d0 * reprod(l0, L1) * reprod(l0, L2)) ).
    `reprod` comes from holotomocupy.utils.
    """
    L0, L1, L2 = (Lop(field) for field in (hpsi, hpsi1, hpsi2))
    amplitude = np.abs(L0)
    l0 = L0/amplitude
    d0 = data/amplitude
    first = np.sum((1-d0)*reprod(L1,L2))
    second = np.sum(d0*reprod(l0,L1)*reprod(l0,L2))
    return 2*(first+second)

#### $$ DM|_{(q_0,\psi_0,\boldsymbol{x})}(\Delta q, \Delta \psi,\Delta\boldsymbol{x})=$$
#### $$ \Big(\Delta q\cdot T_{\psi_0}({\boldsymbol{x}_{0,k}})+ q_0\cdot T_{\Delta \psi}({\boldsymbol{x}_{0,k}}) \Big)_{k=1}^K=$$
#### $$ \Big(J(\Delta q)\cdot S_{\boldsymbol{x}_{0,k}}(\psi_0)+ J(q_0)\cdot S_{\boldsymbol{x}_{0,k}}{(\Delta \psi)}\Big)_{k=1}^K$$


In [None]:
def DM(psi,q,x,dpsi,dq):
    """First derivative of the model q * Sop(psi, x) in the direction (dq, dpsi)."""
    probe_term = dq*Sop(psi,x)
    object_term = q*Sop(dpsi,x)
    return probe_term + object_term

##### $$ D^2M|_{(q_0,\psi_0,\boldsymbol{x})}\big((\Delta q^{(1)}, \Delta \psi^{(1)}),(\Delta q^{(2)}, \Delta \psi^{(2)})\big)= $$
##### $$\Delta q^{(1)}\cdot T_{\Delta \psi^{(2)}}({\boldsymbol{x}_{0,k}})+\Delta q^{(2)}\cdot T_{\Delta \psi^{(1)}}({\boldsymbol{x}_{0,k}}) $$



In [None]:
def D2M(psi,q,x,dpsi1,dq1,dpsi2,dq2):
    """Second derivative of the model for directions (dq1, dpsi1) and (dq2, dpsi2).

    Only the cross terms appear; `psi` and `q` are not used in the computation.
    """
    cross12 = dq1*Sop(dpsi2,x)
    cross21 = dq2*Sop(dpsi1,x)
    return cross12 + cross21

##### $$\mathcal{H}^G|_{ (q_0,\psi_0,\boldsymbol{x}_0)}\Big((\Delta q^{(1)},\Delta \psi^{(1)}),(\Delta q^{(2)},\Delta \psi^{(2)})\Big)=$$
##### $$\Big\langle \nabla F|_{M(q_0,\psi_0,\boldsymbol{x}_0)}, D^2M|_{(q_0,\psi_0,\boldsymbol{x}_0)}\Big((\Delta q^{(1)},\Delta \psi^{(1)},\Delta \boldsymbol{x}^{(1)}),(\Delta q^{(2)},\Delta \psi^{(2)},\Delta \boldsymbol{x}^{(2)})\Big)\Big\rangle +$$
##### $$\mathcal{H}^F|_{M(q_0,\psi_0,\boldsymbol{x}_0)}\Big(DM|_{(q_0,\psi_0,\boldsymbol{x}_0)}(\Delta q^{(1)},\Delta \psi^{(1)},\Delta \boldsymbol{x}^{(1)}),DM|_{(q_0,\psi_0,\boldsymbol{x}_0)}(\Delta q^{(2)},\Delta \psi^{(2)},\Delta \boldsymbol{x}^{(2)})\Big).$$

In [None]:
def calc_beta(vars,grads,etas,d,gradF):
    """Conjugate-direction coefficient: H(grad, eta) / H(eta, eta).

    H is the Hessian form of the composite functional:
    redot(gradF, D2M(...)) + hessianF(model, DM(...), DM(...), d).
    """
    psi, q, x = vars['psi'], vars['prb'], vars['shift']
    g_psi, g_q = grads['psi'], grads['prb']
    e_psi, e_q = etas['psi'], etas['prb']

    # first and second derivatives of the model along the gradient and along eta
    dm_grad = DM(psi,q,x,g_psi,g_q)
    dm_eta = DM(psi,q,x,e_psi,e_q)
    d2m_grad_eta = D2M(psi,q,x,g_psi,g_q,e_psi,e_q)
    d2m_eta_eta = D2M(psi,q,x,e_psi,e_q,e_psi,e_q)
    model = Sop(psi,x)*q

    numerator = redot(gradF,d2m_grad_eta)
    numerator += hessianF(model, dm_grad, dm_eta, d)
    denominator = redot(gradF,d2m_eta_eta)
    denominator += hessianF(model, dm_eta, dm_eta, d)
    return numerator/denominator

def calc_alpha(vars,grads,etas,d,gradF):
    """Step length along eta: -<grad, eta> / H(eta, eta).

    Returns the tuple (alpha, top, bottom) so that callers can also plot the
    quadratic model built from the numerator and denominator.
    """
    psi, q, x = vars['psi'], vars['prb'], vars['shift']
    g_psi, g_q = grads['psi'], grads['prb']
    e_psi, e_q = etas['psi'], etas['prb']

    top = -redot(g_psi,e_psi)-redot(g_q,e_q)

    dm_eta = DM(psi,q,x,e_psi,e_q)
    d2m_eta = D2M(psi,q,x,e_psi,e_q,e_psi,e_q)
    model = Sop(psi,x)*q
    bottom = redot(gradF,d2m_eta)+hessianF(model, dm_eta, dm_eta, d)

    alpha = top/bottom
    return alpha, top, bottom

### Initial guess for reconstruction (Paganin)

In [None]:
def rec_init(rdata,shifts):
    """Build an initial phase-only object estimate with multiPaganin.

    rdata  -- normalised data indexed as rdata[:, j]; already padded to ne x ne by the caller
    shifts -- per-position shifts indexed as shifts[:, j]

    Returns exp(1j*phase), where `phase` is the overlap-compensated and
    edge-tapered result with shape [1, ne, ne].
    """
    recMultiPaganin = cp.zeros([1,npos,ne,ne],dtype='float32')
    recMultiPaganinr = cp.zeros([1,npos,ne,ne],dtype='float32')# to compensate for overlap
    # NOTE(review): range(0,1) processes only position j=0; the remaining npos-1
    # slices stay zero. Confirm whether range(0,npos) was intended.
    for j in range(0,1):
        rdatar = cp.array(rdata[:,j:j+1])
        # NOTE(review): 24.05 and 1e-8 are presumably the delta/beta ratio and a
        # regularisation constant — verify against multiPaganin's signature.
        r = multiPaganin(rdatar,
                            distances[j:j+1], wavelength, voxelsize,  24.05, 1e-8)    
        rr = r*0+1 # to compensate for overlap
        shiftsr = cp.array(shifts[:,j])
        # ST is called with its default mode here (STop passes mode='constant')
        recMultiPaganin[:,j] = ST(r,shiftsr).real
        recMultiPaganinr[:,j] = ST(rr,shiftsr).real
        
    # combine all positions
    recMultiPaganin = np.sum(recMultiPaganin,axis=1)
    recMultiPaganinr = np.sum(recMultiPaganinr,axis=1)

    # avoid division by 0
    recMultiPaganinr[np.abs(recMultiPaganinr)<5e-2] = 1

    # compensate for overlap
    recMultiPaganin /= recMultiPaganinr
    # separable edge taper: sine ramp over the first `pad` samples, cosine ramp
    # over the last `pad` samples, ones in between
    v = cp.ones(ne,dtype='float32')
    v[:pad] = np.sin(cp.linspace(0,1,pad)*np.pi/2)
    v[ne-pad:] = np.cos(cp.linspace(0,1,pad)*np.pi/2)
    v = np.outer(v,v)
    recMultiPaganin*=v
    # interpret the tapered result as a phase and form a unit-amplitude object
    recMultiPaganin = np.exp(1j*recMultiPaganin)

    return recMultiPaganin

# Normalise the data by the reference. The small constant keeps the ratio finite
# where the reference vanishes and matches the definition of `rdata` used when
# the data were modelled.
rdata = data/(ref+1e-11)
rdatap = rdata.copy()
# pad the normalised data to ne x ne with ones (i.e. "no object" outside the detector)
rdatap = np.pad(rdatap,((0,0),(0,0),(ne//2-n//2,ne//2-n//2),(ne//2-n//2,ne//2-n//2)),'constant',constant_values=1)
mshow(rdata[0,0],show,vmax=1.2,vmin=0.8)
rec_paganin = rec_init(rdatap,shifts)
mshow_polar(rec_paganin[0],show,vmax=0.25)
# zoom on the central 256 x 256 region
mshow_polar(rec_paganin[0,ne//2-128:ne//2+128,ne//2-128:ne//2+128],show)

## debug functions

In [None]:
import matplotlib.patches as patches

def plot_debug2(vars,etas,top,bottom,alpha,data):
    """Compare the true functional along the search direction with its quadratic model.

    Samples 2*npp step lengths t = alpha*k/(npp-1), evaluates
    || |Lop(Sop(psi + t*dpsi, x) * (q + t*dq))| - data ||^2 at each of them and
    plots it together with F(0) - top*t + 0.5*bottom*t^2. Does nothing when the
    global `show` is False.
    """
    if show==False:
        return
    psi, q, x = vars['psi'], vars['prb'], vars['shift']
    dpsi2, dq2 = etas['psi'], etas['prb']
    npp = 17
    nsteps = npp*2

    # exact functional values along the line
    errt = cp.zeros(nsteps)
    for k in range(0,nsteps):
        step = alpha*k/(npp-1)
        residual = np.abs(Lop(Sop(psi+step*dpsi2,x)*(q+step*dq2)))-data
        errt[k] = np.linalg.norm(residual)**2

    # quadratic model around the current point
    t = alpha*(cp.arange(nsteps))/(npp-1)
    residual0 = np.abs(Lop(Sop(psi,x)*q))-data
    errt2 = np.linalg.norm(residual0)**2-top*t+0.5*bottom*t**2

    for curve in (errt, errt2):
        plt.plot(alpha.get()*cp.arange(nsteps).get()/(npp-1),curve.get(),'.')
    plt.show()

def vis_debug(vars,data,i):
    """Display the current object/probe estimates and write them under `path_out`.

    `i` is the iteration number used (zero-padded to 3 digits) in the output
    file names; `data` is accepted but not used.
    """
    obj = vars['psi'][0]
    probe = vars['prb'][0]

    # full object, a zoomed region of it, and the probe
    mshow_polar(obj,show)
    mshow_polar(vars['psi'][0,ne//2-n//8:ne//2+n//8,ne//2+n//4:ne//2+n//2],show)
    mshow_polar(probe,show)

    # phase and amplitude of both unknowns, copied to host memory before writing
    save = dxchange.write_tiff
    save(np.angle(obj).get(),f'{path_out}/crec_code_angle/{i:03}',overwrite=True)
    save(np.angle(probe).get(),f'{path_out}/crec_prb_angle/{i:03}',overwrite=True)
    save(np.abs(obj).get(),f'{path_out}/crec_code_abs/{i:03}',overwrite=True)
    save(np.abs(probe).get(),f'{path_out}/crec_prb_abs/{i:03}',overwrite=True)
    np.save(f'{path_out}/crec_shift_{i:03}',vars['shift'])

    
def err_debug(vars, grads, data):
    """Return the current misfit and print the gradient norms.

    vars  -- dict with the current 'psi', 'prb' and 'shift'
    grads -- dict with the gradients 'psi' and 'prb' (only their norms are printed)
    data  -- amplitudes the modelled |Lop(Sop(psi, x) * q)| is compared against

    Returns || |Lop(Sop(psi, x) * q)| - data ||^2.
    """
    (psi,q,x) = (vars['psi'], vars['prb'], vars['shift'])
    tmp = np.abs(Lop(Sop(psi,x)*q))-data
    err = np.linalg.norm(tmp)**2
    # The norms are computed outside the f-string: reusing the enclosing quote
    # character inside a replacement field is a SyntaxError before Python 3.12.
    norm_psi = np.linalg.norm(grads['psi'])
    norm_prb = np.linalg.norm(grads['prb'])
    print(f'gradient norms (psi, prb): {norm_psi:.2f}, {norm_prb:.2f}')
    return err

# Main CG loop (fifth rule)

In [None]:
def cg_holo(data, vars, pars):
    """Conjugate-gradient reconstruction of object and probe with Hessian-based steps.

    data -- measured intensities; their square root (amplitudes) is used internally
    vars -- dict with 'psi', 'prb', 'shift'; 'psi' and 'prb' are updated IN PLACE
    pars -- dict with
            'niter'    number of iterations,
            'err_step' the misfit is evaluated and printed every err_step iterations,
            'vis_step' visual output every vis_step iterations, -1 disables it

    Returns (vars, erra, alphaa). erra and alphaa have length niter and are
    filled only at the iterations where the misfit was evaluated (zero elsewhere).
    """
    data = np.sqrt(data)
    erra = cp.zeros(pars['niter'])
    alphaa = cp.zeros(pars['niter'])
    for i in range(pars['niter']):
        # One switch for both kinds of visual output, so that vis_step == -1
        # silences the line-search plot as well as the image dump.
        do_vis = pars['vis_step'] != -1 and i % pars['vis_step'] == 0

        if do_vis:
            vis_debug(vars, data, i)
        gradF = gradientF(vars,data)
        grads = gradients(vars,data,gradF)
        if i==0:
            # first iteration: steepest descent
            etas = {}
            etas['psi'] = -grads['psi']
            etas['prb'] = -grads['prb']
        else:
            # later iterations: combine the new gradient with the previous direction
            beta = calc_beta(vars, grads, etas, data, gradF)
            etas['psi'] = -grads['psi'] + beta*etas['psi']
            etas['prb'] = -grads['prb'] + beta*etas['prb']

        alpha,top,bottom = calc_alpha(vars, grads, etas, data, gradF)
        if do_vis:
            plot_debug2(vars,etas,top,bottom,alpha,data)

        vars['psi'] += alpha*etas['psi']
        vars['prb'] += alpha*etas['prb']

        if i % pars['err_step'] == 0:
            err = err_debug(vars, grads, data)
            print(f'{i}) {alpha=:.5f}, {err=:1.5e}',flush=True)
            erra[i] = err
            alphaa[i] = alpha

    return vars,erra,alphaa

# Initial guess: `1+0*...` makes the object identically one while keeping the
# shape and dtype of the Paganin result; the probe starts as ones as well.
vars = {
    'psi': 1+0*cp.array(rec_paganin).copy(),
    'prb': cp.ones([1,n+2*pad,n+2*pad],dtype='complex64'),
    'shift': cp.array(shifts),
}
data_rec = cp.array(data)

pars = {
    'niter': 10000,
    'err_step': 128,
    'vis_step': 256,
}
vars,erra,alphaa = cg_holo(data_rec, vars, pars)