In [None]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time
import matplotlib.pyplot as plt
import h5py
from types import SimpleNamespace
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *
from rec import Rec

# Initialize data sizes and parameters of the PXM setup at ID16A

In [None]:
# --- object / scan size ------------------------------------------------------
n = 256        # object size (voxels) along each dimension
ntheta = 180   # number of rotation angles

# projection angles, evenly spaced over [0, pi] with both endpoints included
theta = np.linspace(0, np.pi, ntheta).astype('float32')

# --- ID16A holotomography geometry --------------------------------------------
ndist = 4  # number of propagation distances used (leading entries of the table below)

detector_pixelsize = 3e-6                       # [m]
energy = 17.35                                  # [keV] X-ray energy
wavelength = 1.2398419840550367e-09 / energy    # [m]; numerator is h*c in keV*m

focusToDetectorDistance = 1.208  # [m]
sx0 = -2.493e-3                  # [m] offset subtracted from the tabulated positions
z1 = np.array([1.5335e-3, 1.7065e-3, 2.3975e-3, 3.8320e-3])[:ndist] - sx0  # focus-to-sample [m]
z2 = focusToDetectorDistance - z1                                           # sample-to-detector [m]
distances = (z1 * z2) / focusToDetectorDistance                             # effective propagation distances [m]
magnifications = focusToDetectorDistance / z1
# object voxel size; the 2048/n factor presumably bins a 2048-px detector to n px (TODO confirm)
voxelsize = detector_pixelsize / magnifications[0] * 2048 / n

# magnifications relative to the first distance (entry 0 is exactly 1)
norm_magnifications = magnifications / magnifications[0]
# rescale the propagation distances to account for the magnified probes
distances = distances * norm_magnifications**2

pad = 0      # extra padding reserved for probe shifts
show = True  # display flag passed to the mshow* helpers

# object size after demagnification, rounded up to a multiple of 8
npsi = int(np.ceil((n + 2 * pad) / norm_magnifications[-1] / 8)) * 8
print(distances)

In [None]:
# Reconstruction parameters consumed by Rec. Sizes and flags are taken from the
# configuration cell so that nq stays consistent with the npsi computed there.
args = SimpleNamespace()
args.npos = 1
args.ngpus = 2

# sizes
args.n = n
args.ndist = ndist
args.ntheta = ntheta
args.pad = pad  # must match the padding used to compute npsi (was hard-coded to 0)
args.npsi = npsi
args.nq = args.n + 2 * args.pad
args.nchunk = 32

# geometry / physics
args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distances
args.theta = theta
args.norm_magnifications = norm_magnifications
args.rotation_axis = args.npsi / 2

# solver settings
args.eps = 1e-12
args.rho = [1, 20, 10]
args.niter = 10000
args.vis_step = 1
args.err_step = 1
args.method = "BH-CG"

# output / display
args.path_out = "/data/vnikitin/ESRF/ID16A/20240924_rec0224//AtomiumS2/HT/s1"
args.show = show  # follow the notebook-wide display flag

# create the reconstruction object (reads all fields above)
cl_rec = Rec(args)


## Read the synthetic measurements: projection data, reference images, and shifts `r` (the object to reconstruct is the refractive index u = δ + iβ)

In [None]:
# Buffers: projections indexed [angle, distance, y, x]; references indexed [distance, y, x]
data = np.zeros([ntheta, ndist, n, n], dtype='float32')
ref = np.zeros([ndist, n, n], dtype='float32')

# one projection stack per propagation distance ...
for k in range(ndist):
    data[:, k] = read_tiff(f'/data/vnikitin/syn/data/data_{k}.tiff')
# ... and the matching reference image
for k in range(ndist):
    ref[k] = read_tiff(f'/data/vnikitin/syn/ref_{k}.tiff')

# shifts used with the S/ST operators (indexed [angle, distance, component] - TODO confirm)
r = np.load('/data/vnikitin/syn/r.npy')

# quick look: first projection divided by its reference, for every distance
for k in range(ndist):
    mshow(data[0, k] / ref[k], show)

In [None]:
def multiPaganin(data, distances, wavelength, voxelsize, delta_beta,  alpha):
    """Multi-distance Paganin phase retrieval.

    Each distance j is filtered with the Paganin factor
    ``1 + wavelength * distances[j] * pi * delta_beta * |f|^2``.
    The filtered spectra and squared factors are averaged over distances, and
    ``alpha`` regularizes the denominator.

    Parameters
    ----------
    data : array, shape (ndist, ny, nx)
        Normalized intensities, one image per distance (square or not).
    distances : sequence of float, length ndist
        Propagation distance for each image in ``data``.
    wavelength : float
        X-ray wavelength (same length unit as ``distances`` and ``voxelsize``).
    voxelsize : float
        Pixel size used to build the spatial-frequency grid.
    delta_beta : float
        delta/beta ratio assumed for the (homogeneous) material.
    alpha : float
        Regularization added to the averaged denominator.

    Returns
    -------
    array, shape (ny, nx)
        ``0.5 * delta_beta * log(filtered intensity)``.
    """
    nimages = data.shape[0]
    # separate frequency axes so that non-square images are handled correctly
    fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    fy = cp.fft.fftfreq(data.shape[-2], d=voxelsize).astype('float32')
    [fx, fy] = cp.meshgrid(fx, fy)  # both (ny, nx)
    freq2 = fx**2 + fy**2  # distance-independent, hoisted out of the loop

    numerator = 0
    denominator = 0
    for j in range(nimages):
        rad_freq = cp.fft.fft2(data[j])
        taylorExp = 1 + wavelength * distances[j] * cp.pi * (delta_beta) * freq2
        numerator = numerator + taylorExp * rad_freq
        denominator = denominator + taylorExp**2

    # average over the images actually processed
    numerator = numerator / nimages
    denominator = (denominator / nimages) + alpha

    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    phase = (delta_beta) * 0.5 * phase

    return phase


def rec_init(rdata, delta_beta=15, alpha=1e-4, bg_size=32):
    """Initial guess for psi from multi-distance Paganin retrieval per angle.

    Parameters
    ----------
    rdata : array, shape (ntheta, ndist, ny, nx)
        Aligned, normalized intensities for every angle and distance.
    delta_beta, alpha : float
        Forwarded to ``multiPaganin``. The defaults equal the previously
        hard-coded values.
    bg_size : int
        Side of the top-left patch whose mean is subtracted from each
        retrieved phase (presumably a background region - TODO confirm).

    Returns
    -------
    complex ndarray, shape (ntheta, ny, nx)
        ``exp(1j * phase)`` for every angle.
    """
    # undo the magnification scaling of the propagation distances (loop-invariant)
    distances_pag = distances / norm_magnifications**2
    recMultiPaganin = np.zeros([rdata.shape[0], *rdata.shape[-2:]], dtype="float32")
    for j in range(rdata.shape[0]):
        frames = cp.array(rdata[j])
        recMultiPaganin[j] = multiPaganin(frames, distances_pag, wavelength, voxelsize,
                                          delta_beta, alpha).get()
        recMultiPaganin[j] -= np.mean(recMultiPaganin[j, :bg_size, :bg_size])
    return np.exp(1j * recMultiPaganin)

# Normalize by the references, then bring every distance into a common frame by
# applying the adjoint shift (ST, with shifts r scaled by the normalized
# magnification) and the adjoint magnification (MT), and crop the central
# nq x nq window out of the npsi x npsi result.
rdata = data / ref
mpad = args.npsi // 2 - args.nq // 2       # start of the nq window inside npsi
pad_hi = args.npsi - args.nq - mpad         # trailing pad; equals mpad when npsi - nq is even

srdata = np.zeros([ntheta, ndist, args.nq, args.nq], dtype='complex64')
for j in range(ndist):
    aligned = cl_rec.ST(r[:, j] * norm_magnifications[j], rdata[:, j].astype('complex64'))
    aligned = cl_rec.MT(aligned, j)
    # explicit window length nq (the old npsi//2 +- nq//2 bounds lose a row/column for odd nq)
    srdata[:, j] = aligned[:, mpad:mpad + args.nq, mpad:mpad + args.nq]
srdata = srdata.real

# Paganin-based initial psi on the nq grid, edge-padded back to npsi x npsi
psi_init = rec_init(srdata)
psi_init = np.pad(psi_init, ((0, 0), (mpad, pad_hi), (mpad, pad_hi)), 'edge')
mshow_polar(psi_init[0], args.show)

In [None]:
# psi_data = np.log(psi_init)/1j
# u_init = cl_rec.rec_tomo(psi_data,32)


In [None]:
# Initial object u: load the cached result when available; otherwise compute it
# from the Paganin initial guess with cl_rec.rec_tomo and cache it, so the
# notebook also runs from a fresh kernel/environment.
u_cache = '/local/tmp/u.npy'
try:
    u_init = np.load(u_cache)
except FileNotFoundError:
    psi_data = np.log(psi_init) / 1j  # phase of psi_init (psi_init = exp(1j * phase))
    # cp.asnumpy accepts both NumPy and CuPy arrays; the return type of rec_tomo is not visible here
    u_init = cp.asnumpy(cl_rec.rec_tomo(psi_data, 32))
    np.save(u_cache, u_init)

In [None]:
# Initial probe per distance: DT applied to the square root of the reference image.
# Sized by ndist (was a hard-coded 4) so changing the number of distances works.
q_init = np.ones([ndist, args.nq, args.nq], dtype='complex64')
for j in range(ndist):
    q_init[j] = cl_rec.DT(np.sqrt(ref[j:j+1]), j)[0]

mshow_polar(q_init[0], show)
mshow_polar(q_init[-2], show)

# Probe stored on disk as amplitude/phase TIFFs per distance, shown for comparison
q_init0 = np.ones([ndist, args.nq, args.nq], dtype='complex64')
for k in range(ndist):
    q_abs = read_tiff(f'/data/vnikitin/syn/q_abs_{k}.tiff')
    q_angle = read_tiff(f'/data/vnikitin/syn/q_angle_{k}.tiff')
    q_init0[k] = q_abs*np.exp(1j*q_angle)
mshow_polar(q_init0[0], show)
mshow_polar(q_init0[-2], show)

In [None]:
def check_adjoint(forward, adjoint, x):
    """Dot-product test for an operator/adjoint pair.

    Prints <x, adjoint(forward(x))> and <forward(x), forward(x)>. The two
    agree up to rounding when ``adjoint`` is the adjoint of ``forward``.
    Returns ``forward(x)``.
    """
    y = forward(x)
    print(np.sum(x * np.conj(adjoint(y))))
    print(np.sum(y * np.conj(y)))
    return y


def random_complex(shape):
    """complex64 array with real and imaginary parts drawn uniformly from [0, 1)."""
    return np.random.random(shape).astype('float32') + 1j * np.random.random(shape).astype('float32')


# Explicit distance index: the D/DT test previously used a `j` leaked from an earlier cell.
idist = 2

a = random_complex([ntheta, npsi, npsi])
check_adjoint(lambda x: cl_rec.M(x, idist), lambda y: cl_rec.MT(y, idist), a)
check_adjoint(lambda x: cl_rec.S(r[:, idist], x), lambda y: cl_rec.ST(r[:, idist], y), a)
check_adjoint(lambda x: cl_rec.D(x, idist), lambda y: cl_rec.DT(y, idist),
              random_complex([ntheta, args.nq, args.nq]))

# R/RT on a 3D volume; b = R(a) is kept for display in the next cell
a = random_complex([npsi, npsi, npsi])
b = check_adjoint(cl_rec.R, cl_rec.RT, a)




In [None]:
# Angles held by the reconstructor, and the middle slice (axis 1) of b = R(a)
# from the adjoint-test cell; the display respects the notebook-wide `show` flag.
print(cl_rec.theta)
mshow_complex(b[:, npsi // 2], show)

In [None]:
# Object loaded from disk and scaled by 5000 (NOTE(review): meaning of the factor not documented here - confirm)
u_init0 = np.load('data/u.npy') * 5000

In [None]:
# Synthetic position errors added to r before the BH run:
# uniform noise in [-0.5, 0.5) on both components, plus a quadratic-in-angle
# drift on component 0 that is scaled per distance by the normalized magnification.
np.random.seed(10)
t = np.linspace(-1, 1, ntheta).astype('float32')
t = t**2 * 3

rerr = np.random.random(r.shape).astype('float32') - 0.5
rerr[:, :, 0] += t[:, np.newaxis] * norm_magnifications[np.newaxis]
print(norm_magnifications)

fig, ax = plt.subplots()
ax.plot(rerr[:, :, 0], '.')
ax.plot(rerr[:, :, 1], 'x')
ax.set_xlabel('angle index')
ax.set_ylabel('rerr')
ax.set_title("rerr per distance: component 0 ('.'), component 1 ('x')")
plt.show()
# cp.random.seed(10)

In [None]:

# args.rho = [1,2,2]
# args.vis_step=16
# args.err_step=8
# args.niter=1280
# args.ngpus=4
# args.nchunk=32 
# cl_rec = Rec(args)
# vars = {}
# vars["u"] = cp.array(u_init.copy())
# vars["Ru"] = cl_rec.R(vars['u'])
# vars["psi"] = cl_rec.expR(vars['Ru'])

# vars["q"] = cp.array(q_init.copy())
# vars["r"] = cp.array(r.copy()+rerr)
# vars["r_init"] = cp.array(r.copy())
# vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])
# data=cp.array(data)
# vars = cl_rec.BH(data, vars)

In [None]:
# Solver settings for this run (Rec is re-created so that it picks them up)
args.rho = [1, 2, 2]
args.vis_step = 16
args.err_step = 8
args.niter = 4096
args.ngpus = 4
args.nchunk = 32
cl_rec = Rec(args)

# Reconstruction state passed to cl_rec.BH. It is named `state` rather than
# `vars` so the builtin vars() is not shadowed in the kernel.
state = {}
state["u"] = u_init.copy()
state["Ru"] = cl_rec.R(state['u'])
state["psi"] = cl_rec.expR(state['Ru'])

state["q"] = cp.array(q_init)       # cp.array already makes a device copy
state["r"] = r + rerr               # perturbed positions (new array; r is untouched)
state["r_init"] = r.copy()
state["table"] = pd.DataFrame(columns=["iter", "err", "time"])
state = cl_rec.BH(data, state)

\begin{align*}
 & D^2V|_{(q_0,u_0,{z}_0)}\big((\Delta q^{(1)}, \Delta u^{(1)},\Delta{z}^{(1)}),(\Delta q^{(2)}, \Delta u^{(2)},\Delta{z}^{(2)})\big)=\\&L_1(q_0)\cdot M_j\Big(2T_{e^{i R (u_0)}\cdot(-\frac{1}{2}R({\Delta u^{(1)}})R({\Delta u^{(2)}}))}({{z}_0})+DT_{e^{i R (u_0)}\cdot\big(iR({\Delta u^{(1)}})\big)}|_{{{z}_0}}( \Delta {z}^{(2)})+DT_{e^{i R (u_0)}\cdot\big(iR({\Delta u^{(2)}})\big)}|_{{{z}_0}}( \Delta {z}^{(1)})+D^2{T_{e^{iR(u_0)}}}|_{{{z}_0}}(\Delta z^{(1)},\Delta z^{(2)})\Big)+\\&L_1(\Delta q^{(1)})\cdot M_j\big(T_{e^{i R (u_0)}\cdot(iR({\Delta u^{(2)}}))}({{z}_0})+ DT_{{e^{iR(u_0)}}}|_{{{z}_0}}( \Delta {z}^{(2)})\big)+L_1(\Delta q^{(2)})\cdot M_j\big(T_{e^{i R (u_0)}\cdot(iR({\Delta u^{(1)}}))}({{z}_0})+ DT_{{e^{iR(u_0)}}}|_{{{z}_0}}( \Delta {z}^{(1)})\big)
\end{align*}

\begin{align*}
 & DV|_{(q_0,u_0,{z}_0)}(\Delta q, \Delta u,\Delta{z})=L_1(q_0)\cdot M_j(T_{e^{i R (u_0)}\cdot(iR({\Delta u}))}({z}_0)+ DT_{{e^{iR(u_0)}}}|_{{{z}_0}}( \Delta {z}))+L_1(\Delta q)\cdot M_j(T_{{e^{iR(u_0)}}}({{z}_0}))
\end{align*}