In [None]:
import cupy as cp
import numpy as np
import cv2
import xraylib
import matplotlib.pyplot as plt


%matplotlib inline

# Acquisition parameters

In [None]:
n = 512                         # object size (pixels) in each dimension
voxelsize = 1e-6                # [m] pixel size (sample spacing used by the FFT grids)
energy = 30                     # [keV] X-ray energy
wavelength = 1.24e-09 / energy  # [m] wavelength, lambda[m] ~= 1.24e-9 / E[keV]
# [m] propagation distances; one hologram is simulated per entry
distance = cp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
d = 2048 * 1e-6
# NOTE(review): stray diagnostic displayed as the cell output. It has the form
# a^2 / (wavelength * N) with a = d/2 and N = 50 (like a Fresnel-number relation),
# but its purpose is not documented here -- confirm or remove.
(d / 2) ** 2 / wavelength / 50

### Visualization of complex arrays (magnitude and phase): full view and zoomed crop

In [None]:
def mshow_polar(a, **args):
    """Show magnitude and phase of a 2-D array, then the same for a zoomed crop.

    Two figures are drawn: one for the full array, and one for a crop covering
    the central quarter of the rows and the leftmost quarter of the columns.

    Parameters
    ----------
    a : 2-D array (CuPy, or anything ``cp.asarray`` accepts, e.g. NumPy)
        Array to display. Real arrays are allowed (their phase is 0 or pi).
    **args
        Extra keyword arguments forwarded to ``imshow`` (e.g. ``vmin``, ``vmax``).
    """
    # Accept NumPy input as well; the plotting helper calls .get(), which
    # exists only on CuPy arrays. No copy is made if `a` is already on the GPU.
    a = cp.asarray(a)
    _show_abs_phase(a, **args)

    # Use the row count for the row range and the column count for the
    # column range, so non-square arrays are cropped correctly.
    rows, cols = a.shape[-2], a.shape[-1]
    zoom = a[rows//2 - rows//8:rows//2 + rows//8, 0:cols//4]
    _show_abs_phase(zoom, **args)


def _show_abs_phase(a, **args):
    """Draw |a| and angle(a) side by side, each with its own colorbar."""
    fig, axs = plt.subplots(1, 2, figsize=(9, 3))
    panels = ((axs[0], cp.abs(a), 'abs'), (axs[1], cp.angle(a), 'phase'))
    for ax, values, title in panels:
        im = ax.imshow(values.get(), cmap='gray', **args)
        ax.set_title(title)
        # Attach the colorbar to this panel explicitly instead of relying on
        # matplotlib's default choice of parent axes.
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.show()

# Siemens star test object

In [None]:
# Rasterize a 24-spoke Siemens star: one filled triangle every 15 degrees,
# each with its tip n//128 pixels from the image centre.
triangle = np.array([(n//16, n//2-n//32), (n//16, n//2+n//32), (n//2-n//128, n//2)], np.float32)
star = np.zeros((n, n), np.uint8)
for i in range(0, 360, 15):
    img = np.zeros((n, n, 3), np.uint8)
    theta = i * np.pi / 180
    rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta), np.cos(theta)]], np.float32)
    # rotate the triangle about the image centre: (p - c) @ R^T + c
    rotated = cv2.gemm(triangle-n//2, rot_mat, 1, None, 1, flags=cv2.GEMM_2_T)+n//2
    cv2.fillPoly(img, [np.int32(rotated)], (255, 0, 0))
    # Union of the spokes. A uint8 `+=` would wrap around (255+255 -> 254) on
    # any pixel covered by two triangles, e.g. where the tips converge.
    star = np.maximum(star, img[:, :, 0])

[x, y] = np.meshgrid(np.arange(-n//2, n//2), np.arange(-n//2, n//2))
x = x/n*2
y = y/n*2
# add holes in triangles: three thin rings (bands in r^2) are cut out of every spoke
r2 = x**2 + y**2
circ = (r2 > 0.355) | (r2 < 0.345)
circ &= (r2 > 0.083) | (r2 < 0.08)
circ &= (r2 > 0.0085) | (r2 < 0.008)
star = star*circ/255

# smooth: multiply the centred spectrum by a Gaussian low-pass filter
v = np.arange(-n//2, n//2)/n
[vx, vy] = np.meshgrid(v, v)
v = np.exp(-5*(vx**2+vy**2))
# ifftshift moves the image centre to the origin before the FFT; the inverse
# transform undoes each shift in reverse order (identical to fftshift for even n).
fu = np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(star)))
star = np.fft.fftshift(np.fft.ifftn(np.fft.ifftshift(fu*v))).real

# define the complex refractive index n = 1 - delta + i*beta of the star material
# (delta is taken as 1 minus xraylib's real part)
material = 'Au'
density = 19.3             # [g/cm^3] material density passed to xraylib
star_thickness = 3e-6      # [m] maximal thickness of the star
delta = 1-xraylib.Refractive_Index_Re(material, energy, density)
beta = xraylib.Refractive_Index_Im(material, energy, density)

thickness = star_thickness/voxelsize  # siemens star thickness in pixels
# Form the transmittance psi = exp(i*k*t*(-delta + i*beta)), with k = 2*pi/wavelength
# and t = star profile * thickness in metres. The phase shift is -k*delta*t and the
# amplitude is exp(-k*beta*t).
u = star*(-delta+1j*beta)  # note -delta
Ru = u*thickness
psi = np.exp(1j * Ru * voxelsize * 2 * np.pi / wavelength).astype('complex64')

psi = cp.array(psi)  # move to the GPU
mshow_polar(psi)


# Define the multi-distance Fresnel transform $D$ and its adjoint $D^*$ (with an adjoint test)

In [None]:
def D(psi):
    """Fresnel-propagate ``psi`` to every entry of the global ``distance``.

    The n x n field is zero-padded by n//2 on each side (presumably to limit
    wrap-around from the periodic FFT). It is then multiplied in Fourier space by
    the chirp exp(i*pi*wavelength*z*(fx^2+fy^2)), transformed back, and cropped
    to n x n.

    Reads the globals ``n``, ``voxelsize``, ``wavelength`` and ``distance``.

    Parameters
    ----------
    psi : (n, n) complex CuPy array

    Returns
    -------
    (len(distance), n, n) complex64 CuPy array, one propagated field per distance.
    """
    pad = n//2
    npad = n + 2*pad  # padded size; equals 2*n for even n, 2*n-1 for odd n
    fpsi = cp.zeros([len(distance), n, n], dtype='complex64')
    fx = cp.fft.fftfreq(npad, d=voxelsize).astype('float32')
    [fx, fy] = cp.meshgrid(fx, fx)
    # The spectrum of the padded input is the same for every distance, so
    # compute it once instead of once per distance.
    psi_hat = cp.fft.fft2(cp.pad(psi, ((pad, pad), (pad, pad))))
    for k in range(len(distance)):
        # NOTE(review): with the exp(+ikz) convention implied by a transmittance
        # exp(i*k*(-delta+i*beta)*t), the paraxial transfer function is usually
        # exp(-i*pi*wavelength*z*f^2); the + sign here amounts to propagating by -z.
        # D and DT are consistent with each other either way -- confirm intent.
        fP = cp.exp(1j*cp.pi*wavelength*distance[k]*(fx**2+fy**2))
        fpsi0 = cp.fft.ifft2(psi_hat*fP)
        # crop with explicit bounds (a negative stop like -n//2 is wrong for odd n)
        fpsi[k] = fpsi0[pad:pad+n, pad:pad+n]
    return fpsi

def DT(fpsi):
    """Adjoint of the Fresnel transform ``D``: back-propagate and sum over distances.

    For each distance, the (n, n) field is zero-padded by n//2 on each side. It is
    then multiplied in Fourier space by the conjugate chirp
    exp(-i*pi*wavelength*z*(fx^2+fy^2)), inverse transformed, and cropped back to
    n x n; zero-padding and cropping are adjoint to each other. The per-distance
    results are summed.

    Reads the globals ``n``, ``voxelsize``, ``wavelength`` and ``distance``.

    Parameters
    ----------
    fpsi : (len(distance), n, n) complex CuPy array

    Returns
    -------
    (n, n) complex64 CuPy array.
    """
    pad = n//2
    npad = n + 2*pad  # padded size; equals 2*n for even n, 2*n-1 for odd n
    psi = cp.zeros([n, n], dtype='complex64')
    fx = cp.fft.fftfreq(npad, d=voxelsize).astype('float32')
    [fx, fy] = cp.meshgrid(fx, fx)
    for k in range(len(distance)):
        fpsi0 = cp.pad(fpsi[k], ((pad, pad), (pad, pad)))
        fP = cp.exp(-1j*cp.pi*wavelength*distance[k]*(fx**2+fy**2))
        psi0 = cp.fft.ifft2(cp.fft.fft2(fpsi0)*fP)
        # crop with explicit bounds (a negative stop like -n//2 is wrong for odd n)
        psi += psi0[pad:pad+n, pad:pad+n]
    return psi
# Adjoint (dot-product) test: <psi, D^* D psi> should equal <D psi, D psi>.
# (A random vector, e.g. cp.random.random([n,n]).astype('complex64'), can replace psi.)
Dpsi = D(psi)
DTpsi = DT(Dpsi)

adj_lhs = cp.sum(psi*cp.conj(DTpsi))  # <psi, D^* D psi>
adj_rhs = cp.sum(Dpsi*cp.conj(Dpsi))  # <D psi, D psi> = ||D psi||^2
print(adj_lhs)
print(adj_rhs)

# Simulate the data: hologram intensities $|D\psi|^2$ at each distance

In [None]:
fpsi = D(psi)            # propagated wavefields, one per distance
data = cp.abs(fpsi)**2   # simulated intensities (holograms)
mshow_polar(data[-1])    # hologram at the last (largest) distance

In [None]:
def multiPaganin(data, distances, wavelength, voxelsize, delta_beta,  alpha):
    """Multi-distance Paganin phase retrieval for a single-material object.

    Each intensity ``data[j]`` is Fourier transformed and weighted by the filter
    ``1 + pi*wavelength*distances[j]*delta_beta*(fx^2+fy^2)``. The weighted
    spectra are averaged and divided by the averaged squared filter plus
    ``alpha`` (Tikhonov-style regularization). The phase is
    ``0.5*delta_beta*log(real(ifft2(...)))``.

    Parameters
    ----------
    data : array of shape (ndist, ny, nx)
        Intensities, one image per distance. Only the first ``len(distances)``
        images are used.
    distances : sequence of length ndist
        Propagation distances [m], matching ``data``.
    wavelength : float
        Wavelength [m].
    voxelsize : float
        Pixel size [m].
    delta_beta : float
        delta/beta ratio of the material.
    alpha : float
        Regularization added to the denominator.

    Returns
    -------
    (ny, nx) real array: the retrieved phase. Non-positive filtered values
    produce nan/-inf from the logarithm (not guarded).
    """
    # Frequency grid from the data itself (not the global `n`), so that
    # non-square images are supported; meshgrid gives shape (ny, nx).
    ny, nx = data.shape[-2], data.shape[-1]
    fx = cp.fft.fftfreq(nx, d=voxelsize).astype('float32')
    fy = cp.fft.fftfreq(ny, d=voxelsize).astype('float32')
    [fx, fy] = cp.meshgrid(fx, fy)
    f2 = fx**2 + fy**2

    numerator = 0
    denominator = 0
    # Loop over the distances passed in, not over the global `distance`.
    for j in range(len(distances)):
        rad_freq = cp.fft.fft2(data[j])
        taylorExp = 1 + wavelength * distances[j] * cp.pi * (delta_beta) * f2
        numerator = numerator + taylorExp * (rad_freq)
        denominator = denominator + taylorExp**2

    numerator = numerator / len(distances)
    denominator = (denominator / len(distances)) + alpha

    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    phase = (delta_beta) * 0.5 * phase

    return phase

# Paganin phase (material delta/beta, regularization 1e-5) turned into a pure-phase object
paganin_phase = multiPaganin(data, distance, wavelength, voxelsize, delta / beta, 1e-5)
rec_paganin = cp.exp(1j * paganin_phase)
mshow_polar(rec_paganin)


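### Construct a conjugate-gradient (CG) solver for
### $\operatorname*{argmin}_\psi F(\psi) =\operatorname*{argmin}_\psi\left\||D\psi|-\sqrt{d}\right\|_2^2$, where $d$ denotes the measured intensities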

#### Gradient: $$\nabla F_\psi=2\, D^*\left( D\psi-\sqrt{d}\,\frac{D\psi}{|D\psi|}\right).$$
#### Hessian:
##### $$\frac{1}{2}\mathcal{H}|_{\psi_0}(\psi_1,\psi_2)= \left\langle \mathbf{1}-d_{0}, \mathsf{Re}({D(\psi_1)}\overline{D(\psi_2)})\right\rangle+\left\langle d_{0},(\mathsf{Re} (\overline{l_0}\cdot D(\psi_1)))\cdot (\mathsf{Re} (\overline{l_0}\cdot D(\psi_2)))\right\rangle.$$
##### $$l_0=D(\psi_0)/|D(\psi_0)|$$
##### $$d_0=\sqrt{d}/|D(\psi_0)|$$


In [None]:
def gradientF(psi,d):
    """Gradient of F(psi) = || |D psi| - d ||^2, where d holds the measured amplitudes.

    Returns 2 * D^*( D psi - d * D psi / |D psi| ). Pixels where D psi == 0
    produce nan (division by zero is not guarded).
    """
    Dpsi = D(psi)
    residual = Dpsi - d * (Dpsi / cp.abs(Dpsi))
    return 2 * DT(residual)

def hessianF(psi,psi1,psi2,data):
    """Hessian of F at ``psi`` applied to the pair of directions (psi1, psi2).

    With l0 = D(psi)/|D(psi)| and d0 = data/|D(psi)|, returns
        2 * ( <1 - d0, Re(D psi1 * conj(D psi2))>
              + <d0, Re(l0 * conj(D psi1)) * Re(l0 * conj(D psi2))> ).

    Parameters
    ----------
    psi : current estimate (expansion point).
    psi1, psi2 : directions.
    data : measured amplitudes (cg_holo passes the square root of the intensities).

    Returns
    -------
    Scalar (0-d array): the Hessian form H(psi1, psi2).
    """
    Dpsi = D(psi)
    Dpsi1 = D(psi1)
    # cg_holo calls this with psi1 and psi2 being the same array; reuse the
    # propagation instead of computing it a second time.
    Dpsi2 = Dpsi1 if psi2 is psi1 else D(psi2)
    absDpsi = cp.abs(Dpsi)
    l0 = Dpsi/absDpsi
    d0 = data/absDpsi
    v1 = cp.sum((1-d0)*cp.real(Dpsi1*cp.conj(Dpsi2)))
    v2 = cp.sum(d0*cp.real(l0*cp.conj(Dpsi1))*cp.real(l0*cp.conj(Dpsi2)))
    return 2*(v1+v2)

In [None]:
def cg_holo(data, psi, niter, show_every=32):
    """Nonlinear conjugate-gradient reconstruction of psi from intensities.

    Minimizes F(psi) = || |D psi| - sqrt(data) ||^2.
    - The search direction is updated with a Hessian-based beta.
    - The step length alpha = -Re<grad, eta> / H(eta, eta) minimizes the
      second-order model of F along the search direction.
    - The squared amplitude error is printed every iteration.

    Parameters
    ----------
    data : measured intensities (one image per distance, as produced by D).
    psi : initial guess. It is NOT modified; a copy is iterated.
    niter : number of iterations.
    show_every : plot the current estimate every ``show_every`` iterations
        (default 32).

    Returns
    -------
    The reconstructed psi.
    """
    data = cp.sqrt(data)  # the objective works with amplitudes
    # Iterate on a private copy: the in-place update below would otherwise
    # overwrite the caller's array (e.g. the Paganin initial guess). The copy
    # keeps psi's dtype, so the in-place update casts exactly as before.
    psi = psi.copy()

    for i in range(niter):
        gradpsi = gradientF(psi, data)

        if i == 0:
            etapsi = -gradpsi
        else:
            # named beta_cg so it does not shadow the refractive-index `beta`
            beta_cg = hessianF(psi, gradpsi, etapsi, data) / \
                      hessianF(psi, etapsi, etapsi, data)
            etapsi = -gradpsi + beta_cg*etapsi

        top = -cp.sum(cp.real(gradpsi*cp.conj(etapsi)))
        bottom = hessianF(psi, etapsi, etapsi, data)
        alpha = top/bottom

        psi += alpha*etapsi

        if i % show_every == 0:
            mshow_polar(psi)

        Dpsi = D(psi)
        err = cp.linalg.norm(cp.abs(Dpsi)-data)**2
        print(f'{i}) {alpha=:.5f}, {err=:1.5e}')

    return psi

# Start CG from the Paganin estimate; cp.ones([n,n],dtype='complex64') is an
# alternative starting point. 129 iterations plots iterations 0, 32, ..., 128.
psi = cg_holo(data, rec_paganin, 129)