In [None]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
from holotomocupy.proc import linear, dai_yuan
from holotomocupy.tomo import R,RT
from holotomocupy.chunking import gpu_batch
from holotomocupy.utils import *

In [None]:
# st = int(sys.argv[1])
st = 0
n = 2048
ntheta = 900
show = True

# ---- ID16B setup ----
ndist = 4
energy = 29.63  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length
detector_pixelsize = 0.65e-6
focusToDetectorDistance = 0.704433  # [m]
sx0h = 0.8525605999567023e-3  # 1.077165773192669 for 75nm.
sx0v = 0.80170811624758109e-3  # 1.110243284221266 for 75nm.
sx0 = 0.5*(sx0h+sx0v)  # mean of the two offsets above

# z1: the ndist listed positions, each shifted by sx0; z2: remainder up to the detector
z1 = np.asarray([54.9598e-3, 55.96e-3, 59.1701e-3, 69.17e-3])[:ndist] - sx0
z2 = focusToDetectorDistance - z1

# magnification of every plane, absolute and relative to the first plane
magnifications = focusToDetectorDistance/z1
norm_magnifications = magnifications/magnifications[0]

voxelsize = detector_pixelsize/magnifications[0]*2048/n  # object voxel size
print(f"{voxelsize=}")

# propagation distances, rescaled because the probes are magnified
distances = (z1*z2)/focusToDetectorDistance*norm_magnifications**2

# ---- probe geometry ----
z1p = z1[0]  # positions of the probe for reconstruction
z2p = z1 - z1p  # offset of every plane from the probe plane (0 for the first)
# magnification when propagating from the probe plane to the detector
magnifications2 = (z1p+z2p)/z1p
norm_magnifications2 = magnifications2/(z1p/z1[0])  # normalized magnifications
# propagation distances after switching from the point source wave to plane wave,
# then scaled for the magnified probes and by (z1p/z1)**2
distances2 = (z1p*z2p)/(z1p+z2p)
distances2 = distances2*norm_magnifications2**2*(z1p/z1)**2

# ---- array sizes ----
# allow padding if there are shifts of the probe
pad = n//32
# sample size after demagnification, rounded up to a multiple of 32
ne = 32*int(np.ceil((n+2*pad)/norm_magnifications[-1]/32))

In [None]:
# Load the stack stored in r.tiff (presumably the retrieved phase projections,
# judging by the directory name -- confirm) and use it as the phase of a
# unit-modulus complex field psi.
psi_angle = read_tiff('/data/vnikitin/ESRF/ID16B/009/033_009_50nm_rec/recmultiPaganin/r.tiff')[:]
psi = np.exp(psi_angle*1j)


In [None]:
import h5py

# Export the loaded stack into an HDF5 file with an /exchange group.
# Axis 0 of psi_angle is treated as the projection axis (white/dark frames use
# shape[1:], and the reconstruction cell swaps axis 0 into the angle position),
# so the number of angles written is taken from the data itself instead of a
# hard-coded 900; this keeps theta and data consistent if the stack changes.
nproj = psi_angle.shape[0]
frame_shape = [1, *psi_angle.shape[1:]]  # a single white / dark frame
with h5py.File('/data/vnikitin/ESRF/ID16B/009/033_009_50nm_rec/recmultiPaganin/r.h5', 'w') as fid:
    fid.create_dataset('/exchange/data', data=psi_angle)
    # all-ones white and all-zeros dark field
    # (presumably so a downstream flat/dark correction leaves the data unchanged -- confirm)
    fid.create_dataset('/exchange/data_white', data=np.ones(frame_shape, dtype='float32'))
    fid.create_dataset('/exchange/data_dark', data=np.zeros(frame_shape, dtype='float32'))
    # angles from 0 to 360 inclusive, one per projection
    fid.create_dataset('/exchange/theta', data=np.linspace(0, 360, nproj).astype('float32'))




In [None]:
# Quick look at the phase of the last entry of psi
phase_last = np.angle(psi[-1])
plt.imshow(phase_last, cmap='gray')
plt.colorbar()

In [None]:
# Turn the complex field into the data fitted by the tomography solver:
# data = -i * wavelength / (2*pi) * log(psi) / voxelsize
# NOTE(review): psi was built as exp(1j*psi_angle), so np.log(psi) returns the
# phase wrapped into (-pi, pi]; values of psi_angle outside that interval are
# not recovered by this round trip -- confirm that this is intended.
log_psi = np.log(psi)
data = -1j*wavelength/(2*np.pi)*log_psi/voxelsize

mshow_complex(data[0])


In [None]:


def line_search(minf, gamma, fu, fd, gamma_min=1e-12):
    """Backtracking line search for the step size gamma.

    Starting from the given gamma, the step is halved until the functional
    at ``fu + gamma*fd`` is no larger than the functional at ``fu``.

    Parameters
    ----------
    minf : callable
        Functional to minimize; called as ``minf(x)`` and expected to return
        a scalar.
    gamma : float
        Initial (largest) step size to try.
    fu : array or scalar
        Current point at which the functional is evaluated.
    fd : array or scalar
        Search direction in the same space as ``fu``.
    gamma_min : float, optional
        Smallest step size considered; the default 1e-12 matches the
        previously hard-coded threshold.

    Returns
    -------
    float
        The accepted step size, or 0 if the step shrank to ``gamma_min``
        without finding a non-increasing point.
    """
    # The functional at the current point does not depend on gamma:
    # evaluate it once instead of on every backtracking step.
    f_current = minf(fu)
    while (f_current-minf(fu+gamma*fd) < 0 and gamma > gamma_min):
        gamma *= 0.5
    if (gamma <= gamma_min):  # direction not found
        # print('no direction')
        gamma = 0
    return gamma

def cg_tomo(data, init, pars):
    """Conjugate gradients method for tomography.

    Iteratively minimizes sum_k ||R(u)[k] - data[k]||^2 using Dai-Yuan
    search directions and a backtracking line search.

    Parameters
    ----------
    data : array
        Data the forward projection R(u) is compared against, slice by slice
        along the first axis.
    init : array
        Initial guess; it is copied, the argument itself is not modified here.
    pars : dict
        'niter'    - number of iterations,
        'err_step' - the functional is evaluated, stored and printed every
                     'err_step' iterations,
        'vis_step' - the slice u[ne//2] is displayed every 'vis_step' iterations,
        'gamma'    - initial step size handed to line_search.

    Returns
    -------
    u : array
        The final iterate.
    conv : numpy.ndarray
        Recorded functional values, of length 1 + niter//err_step. When niter
        is a multiple of err_step the last slot is never written and stays 0.

    Notes
    -----
    theta, center, ne, n and ntheta are not arguments: they are read from the
    enclosing (notebook) namespace and must exist when the function is called.
    """
    # per-slice squared misfit; gpu_batch comes from holotomocupy.chunking
    @gpu_batch
    def _minf(Ru, data):
        sq_norms = cp.empty(data.shape[0], dtype='float32')
        for j in range(data.shape[0]):
            sq_norms[j] = np.linalg.norm(Ru[j]-data[j])**2
        return sq_norms

    def minf(Ru):
        # total misfit against the data passed to cg_tomo
        return np.sum(_minf(Ru, data))

    def forward(vol):
        # R with the rotation center rescaled from size n to size ne
        return R(vol, theta, center*ne/n)

    def adjoint(residual):
        # RT with the same rescaled rotation center
        return RT(residual, theta, center*ne/n)

    niter = pars['niter']
    err_step = pars['err_step']

    u = init.copy()
    conv = np.zeros(1+niter//err_step)

    for it in range(niter):
        fu = forward(u)
        grad = adjoint(fu-data)/np.float32(ne*ntheta)

        # Dai-Yuan direction (plain negative gradient on the first iteration)
        d = -grad if it == 0 else dai_yuan(d, grad, grad0)
        grad0 = grad

        fd = forward(d)
        gamma = line_search(minf, pars['gamma'], fu, fd)
        u = linear(u, d, 1, gamma)

        if it % err_step == 0:
            # re-project the updated iterate to report the current misfit
            fu = forward(u)
            err = minf(fu)
            conv[it//err_step] = err
            print(f'{it}) {gamma=}, {err=:1.5e}')

        if it % pars['vis_step'] == 0:
            mshow_complex(u[ne//2])

    return u, conv


# Solver settings: iteration count, how often the error is reported and the
# current slice displayed, and the initial line-search step size.
pars = dict(niter=65, err_step=4, vis_step=16, gamma=1)

# NOTE(review): cg_tomo reads `theta` and `center` from the notebook namespace,
# but no cell visible here assigns them -- they have to be defined before this
# cell can run on a fresh kernel.

# if by chunk on gpu
# rec = np.zeros([ne,ne,ne],dtype='complex64')
# data_rec = data.swapaxes(0,1)

# if fully on gpu
# NOTE(review): this allocates ne**3 complex64 values on the GPU; with the
# configuration cell as written (ne = 2752) that is roughly 167 GB -- confirm
# the intended problem size.
rec = cp.zeros((ne, ne, ne), dtype='complex64')
# bring the second axis of data to the front before uploading
data_rec = cp.array(data.swapaxes(0, 1))
rec, conv = cg_tomo(data_rec, rec, pars)

In [None]:
# Real part of the middle slice of the reconstruction, cropped by a quarter on
# each side and copied from the GPU to the host for plotting
rec_center = rec[ne//2, ne//4:-ne//4, ne//4:-ne//4].real.get()
plt.imshow(rec_center, cmap='gray')
plt.colorbar()