In [None]:
import numpy as np
import cupy as cp
import sys
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
from types import SimpleNamespace
import h5py
import dxchange
import warnings
import pandas as pd
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *
from rec import Rec

In [None]:
# Binning exponent: arrays below are downsampled by 2 along each image axis, `bin` times.
bin = 3
# Number of projection angles to keep.
ntheta = 360
path = '/data/vnikitin/ESRF/ID16A/20240924_rec_ca/data/'

# Read the code, probe, per-angle shift tables and scalar acquisition values.
with h5py.File(f'{path}/data_atomium.h5', 'r') as fid:
    code = fid['/exchange/code'][:].astype('complex64')
    q = fid['/exchange/prb'][:].astype('complex64')
    # keep every (7200 // ntheta)-th entry of the shift tables
    shifts_cor = fid['/exchange/shifts_cor'][::7200 // ntheta].astype('float32')
    shifts_code = fid['/exchange/shifts_code'][::7200 // ntheta].astype('float32')
    z1 = fid['/exchange/z1'][0]
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusdetectordistance = fid['/exchange/focusdetectordistance'][0]
    energy = fid['/exchange/energy'][0]

# Only the z1 value is taken from the second data file.
with h5py.File(f'{path}/data_ca.h5', 'r') as fid:
    z1c = fid['/exchange/z1'][0]

# Each pass averages neighbouring pairs along the first and then the second axis.
for k in range(bin):
    q = 0.5 * (q[0::2] + q[1::2])
    q = 0.5 * (q[:, 0::2] + q[:, 1::2])
    code = 0.5 * (code[0::2] + code[1::2])
    code = 0.5 * (code[:, 0::2] + code[:, 1::2])

# Rescale the shifts by the same overall binning factor.
shifts_cor /= 2 ** bin
shifts_code /= 2 ** bin
# v = np.arange(-q.shape[-1]//2,q.shape[-1]//2)/q.shape[-1]
# [vx,vy] = np.meshgrid(v,v)
# v=np.exp(-20*(vx**2+vy**2))
# q = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(q)))
# q = np.fft.fftshift(np.fft.ifftn(np.fft.fftshift(q*v)))
# q = q.astype('complex64')

# v = np.arange(-code.shape[-1]//2,code.shape[-1]//2)/q.shape[-1]
# [vx,vy] = np.meshgrid(v,v)
# v=np.exp(-0.2*(vx**2+vy**2))
# code = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(code)))
# code = np.fft.fftshift(np.fft.ifftn(np.fft.fftshift(code*v)))
# code = code.astype('complex64')
# shifts_code, code and q are still host (NumPy) arrays at this point (read via
# h5py and converted with astype), so they are written with np.save rather than
# cp.save. The target files are unchanged: '.npy' is appended to each name.
np.save('shifts_code', shifts_code)
np.save('code', code)
np.save('q', q)

# Echo the scalar acquisition values read from the data file.
print(energy)
print(z1)
print(detector_pixelsize)

In [None]:
# Wavelength from the photon energy — assumes `energy` is stored in keV (TODO confirm).
wavelength = 1.24e-09 / energy  # [m] wavelength

# NOTE(review): a `focusdetectordistance` value is also read from the data file
# earlier in the notebook; confirm that the hard-coded 1.28 is intended.
focusToDetectorDistance = 1.28  # [m]
z2 = focusToDetectorDistance - z1

magnification = focusToDetectorDistance / z1
distance = z1 * z2 / focusToDetectorDistance
voxelsize = np.abs(detector_pixelsize / magnification)  # object voxel size

# Quantities relating the two z1 positions (z1 and z1c).
magnifications2 = z1 / z1c  # not referenced in the cells visible here
distancec = (z1 - z1c) / (z1c / z1)

# Global display toggle, later copied into args.show.
show = True


In [None]:
# Parameters handed to Rec.
args = SimpleNamespace()
args.ngpus = 1
args.lam = 50

# Image width after binning — assumes 2048 is the unbinned frame width (TODO confirm).
args.n = 2048 // 2**bin
# Voxel size at the binned resolution. It is recomputed from the detector pixel
# size and magnification rather than scaled in place (`voxelsize *= ...`), so
# re-running this cell no longer compounds the binning factor. The value after
# a single run is the same as before.
voxelsize = np.abs(detector_pixelsize / magnification) * (2048 / args.n)

args.ntheta = ntheta
args.ncode = 8192 * args.n // 2048
args.pad = 32
# Object and probe grids are the image width plus args.pad on each side.
args.npsi = args.n + 2 * args.pad
args.nq = args.n + 2 * args.pad
args.ex = 0
args.npatch = args.nq + 2 * args.ex
args.nchunk = 32

args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distance
args.distancec = distancec
args.eps = 1e-8
args.rho = [1, 20, 10]
args.path_out = f"/data/vnikitin/ESRF/ID16A/20240924_rec_ca/rec_atomium_syn/r_{args.lam}_{args.pad}"
args.show = show

args.niter = 128
args.vis_step = 4
args.err_step = 4
args.rotation_axis = args.npsi / 2

# NOTE(review): linspace includes its endpoint by default, so the first and
# last angles are 0 and 2*pi — confirm that the duplicated view is intended.
args.theta = cp.linspace(0, 2 * np.pi, ntheta).astype('float32')
# create class
cl_rec = Rec(args)
print(voxelsize, distance, distancec)

In [None]:
# Extend the probe by args.pad pixels on every side of both axes (mirror padding).
# Not idempotent: re-running this cell pads q again.
q = np.pad(q, ((args.pad, args.pad),) * 2, 'symmetric')

In [None]:
# Synthetic object: loaded with cupy, scaled by 4e3 and edge-padded by args.pad on all three axes.
u = 4e3 * cp.load(f'/data/vnikitin/syn_3d_ald/u{args.n}.npy').astype('complex64')
u = np.pad(u, ((args.pad, args.pad),) * 3, 'edge')

# Move the remaining inputs to the GPU.
shifts_code = cp.array(shifts_code)
shifts_cor = cp.array(shifts_cor)
code = cp.array(code)
q = cp.array(q)

# Split the code shifts into an integer part and the remaining fractional part.
ri = shifts_code.astype('int32')
r = shifts_code - ri
# Object shifts are zeroed for the simulation.
rpsi = shifts_cor * 0

# Simulated intensities: squared magnitude of the Rec operator chain applied to
# the probe, the shifted code and the object.
data = cp.abs(
    cl_rec.D(
        cl_rec.Dc(q * cl_rec.S(ri, r, code))
        * cl_rec.expR(cl_rec.Spsi(cl_rec.R(u), rpsi))
    )
) ** 2


mshow(data[0], args.show)
mshow_complex(cp.fft.fftshift(cl_rec.fker[0]), args.show)

In [None]:
# Same operator chain as for `data`, but with the object zeroed (u*0).
data_ref = cp.abs(
    cl_rec.D(
        cl_rec.Dc(q * cl_rec.S(ri, r, code))
        * cl_rec.expR(cl_rec.Spsi(cl_rec.R(u * 0), rpsi))
    )
) ** 2
mshow(data_ref[0], args.show)

In [None]:
# Element-wise ratio of the simulated data to the zero-object reference.
srrdata = data / data_ref
# NOTE(review): display is forced with a literal True here, while other cells
# pass args.show — confirm this is intended.
mshow(srrdata[0], True)

In [None]:
def multiPaganin(data, distances, wavelength, voxelsize, delta_beta,  alpha):
    """Phase estimate from intensity data via a regularized Fourier filter.

    The 2D FFT of ``data`` (taken over its last two axes) is multiplied by
    ``T / (T**2 + alpha)`` with
    ``T = 1 + wavelength * distances * pi * delta_beta * (fx**2 + fy**2)``,
    transformed back, and the logarithm of the real part is scaled by
    ``delta_beta / 2``.

    Args:
        data: intensity array; the last two axes are the image axes. They no
            longer have to be the same length.
        distances: scalar factor multiplying the frequency term
            (presumably the propagation distance — confirm with callers).
        wavelength: scalar factor multiplying the frequency term.
        voxelsize: sample spacing used for the frequency grid (``d`` of fftfreq).
        delta_beta: scalar used both inside ``T`` and in the final scaling.
        alpha: constant added to the filter denominator.

    Returns:
        Array with the same shape as ``data``: ``delta_beta * 0.5 * log(...)``.
        Entries are NaN/-inf wherever the filtered image is not positive.
    """
    # One frequency axis per image axis, so the grid matches data.shape[-2:]
    # even when the image is not square.
    fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    fy = cp.fft.fftfreq(data.shape[-2], d=voxelsize).astype('float32')
    fx, fy = cp.meshgrid(fx, fy)

    rad_freq = cp.fft.fft2(data)
    taylorExp = 1 + wavelength * distances * cp.pi * delta_beta * (fx**2 + fy**2)
    numerator = taylorExp * rad_freq
    denominator = taylorExp**2 + alpha

    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    return delta_beta * 0.5 * phase

def rec_init(rdata):
    """Build a phase-only initial guess from normalized intensity frames.

    Every frame ``rdata[j]`` is copied to the GPU, edge-padded by ``args.pad``
    on both image axes and passed to ``multiPaganin`` together with the
    notebook-level ``distance``, ``wavelength`` and ``voxelsize``
    (delta_beta = 100, alpha = 1e-8). The resulting phases are gathered in a
    host array.

    Returns:
        NumPy array of shape (args.ntheta, args.nq, args.nq) holding
        ``exp(1j * phase)``; slots past ``rdata.shape[0]`` keep phase 0.
    """
    pad_2d = ((args.pad, args.pad), (args.pad, args.pad))
    phases = np.zeros([args.ntheta, args.nq, args.nq], dtype="float32")

    nframes = rdata.shape[0]
    for j in range(nframes):
        frame = cp.pad(cp.array(rdata[j]), pad_2d, 'edge')
        frame_phase = multiPaganin(frame, distance, wavelength, voxelsize, 100, 1e-8)
        # bring the result back to host memory
        phases[j] = frame_phase.get()
        # recMultiPaganin[j]-=np.mean(recMultiPaganin[j,:32,:32])

    return np.exp(1j * phases)

# Initial guess built from all frames of the normalized data.
psi_init = rec_init(srrdata[:])
# mpad = args.npsi//2-args.nq//2
# psi_init = np.pad(psi_init,((0,0),(mpad,mpad),(mpad,mpad)),'edge')
mshow_polar(psi_init[0], args.show)

In [None]:

# Initial object from the initial guess via Rec's logR and rec_tomo
# (the trailing 16 is presumably an iteration count — confirm against Rec).
u_init = cl_rec.rec_tomo(cl_rec.logR(psi_init), rpsi, 16)
# Show the slice at index args.npsi // 2 along the first axis.
mshow_complex(u_init[args.npsi // 2], args.show)

In [None]:
# Preview of the shifted code: top-left 100x100 corner of entry 3.
t = cl_rec.S(ri, r, code)
# NOTE(review): display is forced with a literal True here, while other cells
# pass args.show — confirm this is intended.
mshow_polar(t[3, :100, :100], True)

In [None]:
# Reference intensities with the code replaced by ones (code*0+1) and the object zeroed (u*0).
ref = cp.abs(
    cl_rec.D(
        cl_rec.Dc(q * cl_rec.S(ri, r, code * 0 + 1))
        * cl_rec.expR(cl_rec.Spsi(cl_rec.R(u * 0), rpsi))
    )
) ** 2

# Border width that is cropped away and then refilled by mirror padding.
cpad = args.pad * 3 // 2
# Initial probe: DT then DcT applied to sqrt(ref), first entry, borders cropped.
q_init = cl_rec.DcT(cl_rec.DT(cp.sqrt(ref)))[0, cpad:-cpad, cpad:-cpad]
q_init = cp.pad(q_init, ((cpad, cpad),) * 2, 'symmetric')
mshow_polar(q_init, args.show)

In [None]:

# State passed to cl_rec.BH: the code, the initial object and probe, and the shifts.
vars = {
    "code": cp.array(code),
    "u": cp.array(u_init),
    "q": cp.array(q_init),
    "ri": shifts_code.astype("int32"),
}
# Fractional part of the code shifts; the starting value is half of it.
vars["r_init"] = shifts_code - vars["ri"]
vars["r"] = vars["r_init"] / 2
vars["rpsi"] = cp.array(shifts_cor).astype("float32")
vars["Ru"] = cl_rec.Spsi(cl_rec.R(vars["u"]), vars["rpsi"])
vars["psi"] = cl_rec.expR(vars["Ru"])
# Empty table with one column each for iteration, error and time.
vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])

# Override the settings given to Rec at construction for this run.
cl_rec.rho = [1, 10, 10]
cl_rec.lam = 0
cl_rec.vis_step = 32
cl_rec.err_step = 32
cl_rec.eps = 0
cl_rec.niter = 8000

vars = cl_rec.BH(data, vars)



In [None]:
# Plain-text dump of the table stored under vars["table"].
print(vars["table"])