In [None]:
import os

import numpy as np
import cupy as cp
from holotomocupy.utils import *
from holotomocupy.holo import G,GT
from holotomocupy.shift import S
from holotomocupy.recon_methods import CTFPurePhase, multiPaganin

%matplotlib inline

# Seed NumPy's global RNG so any NumPy random draws are reproducible.
# NOTE(review): no cell in this notebook visibly draws random numbers; kept as a safeguard.
np.random.seed(1) # fix randomness

# Initialize data sizes and parameters of the ID16A PXM

In [None]:
# Sizes
n = 2048  # object size in each dimension
npos = 4  # number of code positions

# Beam and detector
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length
detector_pixelsize = 3e-6
focusToDetectorDistance = 1.28  # [m]

# Distances along the beam [m]. z1c is used below as the reference for the
# probe-plane propagation (presumably the code-plane distance -- confirm).
z1c = -19.5e-3
sx0 = 3.7e-4
z1 = 4.584e-3-sx0
z1 = np.tile(z1, [npos])  # the same z1 for every code position

z2 = focusToDetectorDistance-z1
distances = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1
voxelsize = detector_pixelsize/magnifications[0]*2048/n  # object voxel size

# magnification when propagating from the probe plane to the detector
magnifications2 = z1/z1c
distances2 = (z1-z1c)/(z1c/z1)  # equivalently (z1-z1c)*magnifications2
# allow padding if there are shifts of the probe
pad = n//8
# sample size after demagnification
ne = n+2*pad

show = True
flg = f'{n}_{z1c}'  # suffix shared by all input file names
# Input directory; set NFP_CODES_PATH to read the data from somewhere else.
path = os.environ.get('NFP_CODES_PATH', '/data2/vnikitin/nfp_codes')



## Read the reconstructed probe

In [None]:
# Load the reconstructed probe onto the GPU (prb_{flg}.npy holds the original probe instead).
prb = cp.array(np.load(f'{path}/rec_prb_{flg}.npy'))
mshow_polar(prb[0], show)

### Read reconstructed coded aperture

In [None]:
# Load the reconstructed coded aperture onto the GPU (code_{flg}.npy holds the original code).
code = cp.array(np.load(f'{path}/rec_code_{flg}.npy'))
mshow_polar(code[0], show)

### Read code shifts

In [None]:
# Per-position shifts of the code, converted to a CuPy array.
shifts_code = cp.array(np.load(f'{path}/shifts_code_{flg}.npy'))

### Read data

In [None]:
# Measured frames and reference frames (used below to normalize the data), one per code position.
data = cp.empty([1, npos, n, n], dtype='float32')
ref = cp.empty_like(data)
for k in range(npos):
    data[:, k] = cp.array(read_tiff(f'{path}/data_{k}_{flg}.tiff'))
    ref[:, k] = cp.array(read_tiff(f'{path}/ref_{k}_{flg}.tiff'))
mshow(data[0, 0], show)

#### Forward operators, their adjoints, and adjoint (dot-product) tests

In [None]:
def Lop(psi):
    """Propagate every position slice to the detector and crop to n x n.

    For each position i, psi[:, i] is propagated with G over distances[i], and the
    window [pad:n+pad, pad:n+pad] of the result is kept.
    Returns a complex64 array of shape [psi.shape[0], npos, n, n].
    """
    out = cp.zeros([psi.shape[0], npos, n, n], dtype='complex64')
    window = slice(pad, n + pad)
    for ipos in range(npos):
        propagated = G(psi[:, ipos].copy(), wavelength, voxelsize, distances[ipos])
        out[:, ipos] = propagated[:, window, window]
    return out

def LTop(data):
    """Adjoint of Lop: zero-pad each n x n frame by pad and back-propagate with GT.

    Returns a complex64 array of shape [data.shape[0], npos, n+2*pad, n+2*pad].
    """
    size = n + 2*pad
    psi = cp.zeros([data.shape[0], npos, size, size], dtype='complex64')
    widths = ((0, 0), (pad, pad), (pad, pad))
    for ipos in range(npos):
        padded = cp.pad(data[:, ipos], widths).astype('complex64')
        psi[:, ipos] = GT(padded, wavelength, voxelsize, distances[ipos])
    return psi

def Lwop(psi):
    """Propagate every position slice psi[:, i] with G over distances2[i] (no cropping).

    Returns a complex64 array of shape [psi.shape[0], npos, n+2*pad, n+2*pad].
    """
    size = n + 2*pad
    out = cp.zeros([psi.shape[0], npos, size, size], dtype='complex64')
    for ipos in range(npos):
        out[:, ipos] = G(psi[:, ipos].copy(), wavelength, voxelsize, distances2[ipos])
    return out

def LTwop(data):
    """Adjoint of Lwop: back-propagate every position slice with GT over distances2[i].

    Returns a complex64 array of shape [data.shape[0], npos, n+2*pad, n+2*pad].
    """
    size = n + 2*pad
    psi = cp.zeros([data.shape[0], npos, size, size], dtype='complex64')
    for ipos in range(npos):
        slice_c64 = data[:, ipos].astype('complex64')
        psi[:, ipos] = GT(slice_c64, wavelength, voxelsize, distances2[ipos])
    return psi

def Sop(psi):
    """Shift the input for every code position and crop the central ne x ne window.

    For position i, a copy of psi is shifted with S by shifts_code[:, i], and the
    centred ne x ne region of the result is stored in out[:, i].
    Returns a complex64 array of shape [psi.shape[0], npos, ne, ne].
    """
    data = cp.zeros([psi.shape[0], npos, ne, ne], dtype='complex64')
    for i in range(npos):
        psir = S(psi.copy(), shifts_code[:, i])
        # Centred crop written as start:start+ne so it always yields exactly ne
        # samples per axis, whatever the parity of ne or of the input size.
        r0 = psir.shape[-2]//2 - ne//2
        c0 = psir.shape[-1]//2 - ne//2
        data[:, i] = psir[:, r0:r0+ne, c0:c0+ne]
    return data

# Dot-product test for Lwop/LTwop: <x, LTwop(Lwop(x))> should equal <Lwop(x), Lwop(x)>.
arr1 = Sop(code) * prb
arr2 = Lwop(arr1)
arr3 = LTwop(arr2)
print('{}==\n{}'.format(cp.sum(arr1 * arr3.conj()), cp.sum(arr2 * arr2.conj())))

# NOTE(review): prb1 and psi1 are not used anywhere else in this notebook.
prb1 = prb.copy()
psi1 = cp.pad(data[:, 0], ((0, 0), (pad, pad), (pad, pad)), 'symmetric').astype('complex64')
# Same test for Lop/LTop, reusing arr2 as the input.
arr3 = Lop(arr2)
arr4 = LTop(arr3)
print('{}==\n{}'.format(cp.sum(arr2 * arr4.conj()), cp.sum(arr3 * arr3.conj())))

## Reconstruction with multi-distance Paganin

In [None]:
# Detector-plane intensity of the shifted code alone (no probe, no object).
d = cp.abs(Lop(Lwop(Sop(code)))) ** 2
mshow(d[0, 0], show)

# Normalize the measurements by the code intensity and then by the reference frames.
rdata = data / d / ref
mshow(rdata[0, 0], show, vmax=2)

In [None]:
# Multi-distance Paganin phase retrieval on the normalized data.
# NOTE(review): 24.05 and 2e-2 are passed positionally; they look like a delta/beta
# ratio and a regularization weight -- confirm against the multiPaganin signature.
r = multiPaganin(rdata, distances, wavelength, voxelsize, 24.05, 2e-2)
# Pad the retrieved phase to the ne x ne object grid.
r = cp.pad(r, ((0, 0), (pad, pad), (pad, pad)))
# r is a CuPy array: call cp.exp explicitly instead of relying on NumPy ufunc dispatch.
recMultiPaganin = cp.exp(1j*r)


In [None]:
def redot(a,b):
    """Real inner product <a, b> = sum(Re(a)*Re(b) + Im(a)*Im(b)) = Re(sum(a * conj(b))).

    Uses the arrays' own .sum() rather than cp.sum, so it works for both CuPy and NumPy
    inputs; for CuPy inputs the result is a 0-d CuPy array, as with cp.sum.
    """
    return (a.real*b.real + a.imag*b.imag).sum()

#### $$\nabla F=2 \left(L^*\left( L(\psi)-\tilde D\right)\right).$$
#### where $$\tilde D = D \frac{L(\psi)}{|L(\psi)|}$$



In [None]:
def gradientF(psi,d):
    """Gradient of the amplitude misfit || |Lop(psi)| - d ||^2 with respect to psi.

    Implements 2 * LTop(L(psi) - d * L(psi)/|L(psi)|), i.e. the back-projected residual
    between the propagated field and the field rescaled to the target amplitude d.
    """
    Lpsi = Lop(psi)
    residual = Lpsi - d*(Lpsi/cp.abs(Lpsi))
    return 2*LTop(residual)

#### $$\frac{1}{2}\mathcal{H}^F|_{x_0}(y,z)= \left\langle \mathbf{1}-d_{0}, \mathsf{Re}({L(y)}\overline{L(z)})\right\rangle+\left\langle d_{0},(\mathsf{Re} (\overline{l_0}\cdot L(y)))\cdot (\mathsf{Re} (\overline{l_0}\cdot L(z)))\right\rangle.$$
#### $$l_0=L(x_0)/|L(x_0)|$$
#### $$d_0=d/|L(x_0)|$$


In [None]:
def hessianF(psi,psi1,psi2,d):
    """Bilinear Hessian of the amplitude misfit at psi, applied to directions psi1 and psi2.

    Returns 2*(<1 - d0, Re(L(psi1) conj(L(psi2)))> + <d0, Re(conj(l0) L(psi1)) * Re(conj(l0) L(psi2))>)
    with l0 = L(psi)/|L(psi)|, d0 = d/|L(psi)| and <.,.> the real inner product redot.
    """
    Lpsi = Lop(psi)
    amplitude = cp.abs(Lpsi)
    l0 = Lpsi/amplitude
    d0 = d/amplitude
    Ly, Lz = Lop(psi1), Lop(psi2)
    l0c = cp.conj(l0)
    term1 = redot(1 - d0, (Ly*cp.conj(Lz)).real)
    term2 = redot(d0, (l0c*Ly).real * (l0c*Lz).real)
    return 2*(term1 + term2)


#### $$\nabla G|_q = \overline{S(c)} L_\omega^*\left(\overline{\psi}\,\nabla F|_{\psi L_\omega(S(c) q)}\right)$$
#### $$\nabla G|_\psi = \overline{L_\omega(S(c) q)}\,\nabla F|_{\psi L_\omega(S(c) q)}$$

In [None]:
def gradients(psi,q,c,d):
    """Gradients of G(psi, q) = F(psi * Lwop(Sop(c) * q)) with respect to psi and q.

    Parameters: psi object, q probe, c coded aperture, d target amplitudes.
    Returns (gradpsi, gradq); both are summed over the code positions (axis 1).
    """
    Sc = Sop(c)  # shifted code, computed once and reused for both gradients
    Lwcq = Lwop(Sc*q)
    gradF = gradientF(psi*Lwcq, d)
    gradq = cp.sum(cp.conj(Sc)*LTwop(cp.conj(psi)*gradF), axis=1)
    gradpsi = cp.sum(cp.conj(Lwcq)*gradF, axis=1)
    return gradpsi, gradq

#### $$H^G(\Delta\psi_1,\Delta\psi_2,\Delta q_1,\Delta q_2) =$$
#### $$H^F_{\psi L_\omega(S(c) q)}$$
#### $$\Big(\Delta\psi_1\, L_\omega(S(c) q) +\psi\, L_\omega(S(c) \Delta q_1)+\Delta\psi_1\, L_\omega(S(c) \Delta q_1), $$
#### $$\Delta\psi_2\, L_\omega(S(c) q) +\psi\, L_\omega(S(c) \Delta q_2)+\Delta\psi_2\, L_\omega(S(c) \Delta q_2)\Big)+$$
#### $$\Big\langle \nabla F|_{\psi L_\omega(S(c) q)}, \Delta\psi_1\, L_\omega(S(c) \Delta q_2)+\Delta\psi_2\, L_\omega(S(c) \Delta q_1)\Big\rangle$$

In [None]:
def hessian2(psi,q,c,dpsi1,dq1,dpsi2,dq2,d):
    """Bilinear Hessian of G(psi, q) = F(psi * Lwop(Sop(c) * q)) on two directions.

    With M(psi, q) = psi * Lwop(Sop(c) * q), the chain rule gives
        H^G = H^F|_M(DM_1, DM_2) + <grad F|_M, dpsi1*Lwop(Sop(c)*dq2) + dpsi2*Lwop(Sop(c)*dq1)>,
    where DM_k = dpsi_k*Lwop(Sop(c)*q) + psi*Lwop(Sop(c)*dq_k) is the first-order change of M
    along direction k. Returns a scalar (0-d array) used for CG coefficients and step sizes.

    Note: the markdown formula above also puts dpsi_k*Lwop(Sop(c)*dq_k) inside each H^F
    argument. That product is second order in the direction, so keeping it would make
    H(eta, eta) non-quadratic in eta. It is omitted here; it enters only through the
    gradient term.
    """
    Sc = Sop(c)  # shifted code, reused for all three propagations
    Lwcq = Lwop(Sc*q)
    Lwcdq1 = Lwop(Sc*dq1)
    Lwcdq2 = Lwop(Sc*dq2)
    M = psi*Lwcq

    # first-order directional derivatives of M
    h1 = dpsi1*Lwcq + psi*Lwcdq1
    h2 = dpsi2*Lwcq + psi*Lwcdq2
    v1 = hessianF(M, h1, h2, d)

    # gradient of F contracted with the mixed second derivative of M
    gradF = gradientF(M, d)
    v2 = redot(gradF, dpsi1*Lwcdq2 + dpsi2*Lwcdq1)
    return v1+v2

#### Reconstruction with conjugate gradients (Carlsson) using Hessians

In [None]:

def minf(fpsi,data):
    """Squared L2 misfit between the amplitude |fpsi| and the target amplitudes data."""
    residual = np.abs(fpsi) - data
    return np.linalg.norm(residual)**2

def cg_holo(data, init_psi,init_prb,  pars):
    """Joint CG reconstruction of the object and probe with Hessian-based steps.

    Parameters
    ----------
    data : measured intensities; the solver fits their square root (amplitudes).
    init_psi, init_prb : initial object and probe (copied, not modified).
    pars : dict with
        'niter'    -- number of iterations,
        'err_step' -- print/record the misfit every err_step iterations (<= 0 disables),
        'vis_step' -- plot psi and prb every vis_step iterations (<= 0 disables).

    Uses the module-level coded aperture `code`, plus `show` and `ne` for plotting.

    Returns
    -------
    psi, prb, erra, alphaa : reconstructions, and the misfit and step size recorded at
    logged iterations (zeros elsewhere).
    """
    data = np.sqrt(data)

    psi = init_psi.copy()
    prb = init_prb.copy()

    niter = pars['niter']
    err_step = pars['err_step']
    vis_step = pars['vis_step']
    erra = np.zeros(niter)
    alphaa = np.zeros(niter)

    for i in range(niter):
        gradpsi, gradprb = gradients(psi, prb, code, data)
        if i == 0:
            # steepest descent on the first iteration
            etapsi = -gradpsi
            etaprb = -gradprb
        else:
            # Hessian-conjugate direction: beta = H(grad, eta) / H(eta, eta)
            top = hessian2(psi, prb, code, gradpsi, gradprb, etapsi, etaprb, data)
            bottom = hessian2(psi, prb, code, etapsi, etaprb, etapsi, etaprb, data)
            beta = top/bottom

            etapsi = -gradpsi + beta*etapsi
            etaprb = -gradprb + beta*etaprb

        # step length minimizing the quadratic model along eta
        top = -redot(gradpsi, etapsi) - redot(gradprb, etaprb)
        bottom = hessian2(psi, prb, code, etapsi, etaprb, etapsi, etaprb, data)
        alpha = top/bottom

        psi += alpha*etapsi
        prb += alpha*etaprb

        # Check the step sizes before using them as moduli, so 0 disables the output
        # instead of raising ZeroDivisionError.
        if err_step > 0 and i % err_step == 0:
            fpsi = Lop(psi*Lwop(Sop(code)*prb))
            err = float(minf(fpsi, data))
            erra[i] = err
            alphaa[i] = float(alpha)
            print(f'{i}) alpha={float(alpha):.5f}, {err=:1.5e}')

        if vis_step > 0 and i % vis_step == 0:
            mshow_polar(psi[0], show)
            # central 512 x 512 zoom of the object
            mshow_polar(psi[0, ne//2-256:ne//2+256, ne//2-256:ne//2+256], show)
            mshow_polar(prb[0], show)

    return psi, prb, erra, alphaa

print(shifts_code)
# All arrays are CuPy arrays: the whole reconstruction runs on the GPU.
rec_psi = cp.ones([1, ne, ne], dtype='complex64')  # flat initial object
rec_prb = prb.copy()  # start from the loaded probe
pars = dict(niter=512, err_step=1, vis_step=32)
rec_psi, rec_prb, erra, alphaa = cg_holo(data, rec_psi, rec_prb, pars)