In [None]:
import numpy as np
import cupy as cp
import h5py
import matplotlib.pyplot as plt
import cupyx.scipy.ndimage as ndimage
from types import SimpleNamespace

# Use managed memory
import h5py
import sys
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *
from rec import Rec

# Init data sizes and parameters of the PXM of ID16A

In [None]:
# Projection selection and binning:
#   step   - take every `step`-th projection
#   ntheta - number of projections to load
#   st     - index of the first projection to load
#   bin    - number of 2x binning rounds applied to the images
#   ndist  - number of propagation distances used
step, ntheta, st, bin, ndist = 1, 3000, 0, 0, 4


In [None]:
# Raw-data file: read the acquisition geometry and the array shapes.
pfile = 'Y350c_HT_015nm'
path_out = '/data/vnikitin/ESRF/ID16A/brain_rec/20240515/Y350c2'
# Open read-only explicitly. h5py < 3.0 defaulted to mode 'a' (read/write,
# create if missing), which could modify or even create the raw-data file.
with h5py.File(f'{path_out}/{pfile}.h5', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusToDetectorDistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:ndist]
    shape = fid['/exchange/data0'].shape
    shape_ref = fid['/exchange/data_white_start0'].shape
    shape_dark = fid['/exchange/data_dark0'].shape

In [None]:

ndark = shape_dark[0]  # number of dark frames
nref = shape_ref[0]    # number of white (reference) frames
# Image width after `bin` rounds of 2x binning.
n = shape[-1] // 2**bin

In [None]:
# Beam and geometry parameters for every propagation distance.
energy = 33.35  # [keV] xray energy
# lambda = h*c / E, with h*c ~= 1.23984e-9 keV*m
wavelength = 1.2398419840550367e-09/energy  # [m] wave length
# NOTE(review): z1 is presumably focus-to-sample, so z2 is sample-to-detector — confirm
z2 = focusToDetectorDistance-z1
# Geometric magnification focusToDetectorDistance / z1, one value per distance.
magnifications = focusToDetectorDistance/z1
# Magnifications relative to the first distance.
norm_magnifications = magnifications/magnifications[0]
# Propagation distance z1*z2/focusToDetectorDistance per distance, rescaled by
# the squared relative magnification.
distances = (z1*z2)/focusToDetectorDistance*norm_magnifications**2
# Pixel size at the sample for the first distance, scaled by 2048/n for binning
# (2048 presumably the unbinned detector width — confirm).
voxelsize = detector_pixelsize/magnifications[0]*2048/n  # object voxel size
show = True

In [None]:
# Object grid size: the 2048-pixel field of view divided by the relative
# magnification of the last distance, rounded up to a multiple of 16, then
# reduced by the binning factor 2048 // n.
npsi = int(np.ceil(2048/norm_magnifications[-1]/16))*16 // (2048//n)

In [None]:
# Reconstruction parameters handed to the Rec class.
args = SimpleNamespace(
    ngpus=4,
    n=n,
    ndist=ndist,
    ntheta=ntheta,
    pad=0,
    npsi=npsi,
    nq=n + 2 * 0,  # NOTE(review): looks like n + 2*pad with pad fixed at 0 — confirm
    nchunk=2,
    lam=0,
    voxelsize=voxelsize,
    wavelength=wavelength,
    distance=distances,
    eps=1e-12,
    rho=[1, 2, 1],
    path_out=f"{path_out}/s1",
    show=True,
    norm_magnifications=norm_magnifications,
)
print(norm_magnifications)
# Build the reconstruction operator object.
cl_rec = Rec(args)


In [None]:
# a = np.random.random([ntheta,npsi,npsi]).astype('float32')+1j*np.random.random([ntheta,npsi,npsi]).astype('float32')
# b = cl_rec.M(a,2)
# c = cl_rec.MT(b,2)
# nnn=np.sum(a*np.conj(c))
# print(np.sum(a*np.conj(c)))
# print(np.sum(b*np.conj(b)))

# a = np.random.random([ntheta,npsi,npsi]).astype('float32')+1j*np.random.random([ntheta,npsi,npsi]).astype('float32')
# b = cl_rec.M(a,3)
# c = cl_rec.MT(b,3)
# nnnn=np.sum(a*np.conj(c))
# print(np.sum(a*np.conj(c)))
# print(np.sum(b*np.conj(b)))

# print(nnn/nnnn,1/norm_magnifications[-1])

In [None]:
# Load the corrected projections for every distance, the reference images and
# the per-projection shifts, applying `bin` rounds of 2x binning.
data = np.zeros([ntheta,ndist,n,n],dtype='float32')
# ntheta projections, every `step`-th, starting at `st`. The stop index must be
# offset by `st`: the former `st:step*ntheta` returned fewer than ntheta
# projections whenever st > 0.
proj_slice = slice(st, st + step*ntheta, step)
with h5py.File(f'{path_out}/{pfile}_corr.h5', 'r') as fid:
    for k in range(ndist):
        # h5py slicing already returns a fresh numpy array; no copy needed.
        tmp = fid[f'/exchange/data{k}'][proj_slice]
        for _ in range(bin):
            tmp = 0.5*(tmp[:,:,::2]+tmp[:,:,1::2])
            tmp = 0.5*(tmp[:,::2,:]+tmp[:,1::2,:])
        data[:,k] = tmp
    tmp = fid['/exchange/ref'][:ndist]
    for _ in range(bin):
        tmp = 0.5*(tmp[...,::2]+tmp[...,1::2])
        tmp = 0.5*(tmp[...,::2,:]+tmp[...,1::2,:])
    ref = tmp
    # Shifts are rescaled by n/2048 to the current image size. They are kept in
    # the initial (not magnification-normalized) coordinates; the merging cell
    # multiplies them by norm_magnifications.
    r = fid['/exchange/cshifts_final'][proj_slice, :ndist]*n/2048

In [None]:
# Merge all distances into one image per projection on the common
# npsi x npsi grid. Distances are processed from last to first:
# - the last distance fills the whole grid, with its invalid border replaced
#   by mirroring the interior;
# - each earlier distance is blended into the center with a smooth window,
#   keeping the outer part from distance j+1.
rdata = data/ref
srdata = np.zeros([ntheta,ndist,args.npsi,args.npsi],dtype='float32')
distances_pag = (distances/norm_magnifications**2)
npad=n//32
# Polynomial smoothstep ramp rising from 0 toward 1 over npad samples. It is
# the same for every distance, so build it once.
ramp = np.linspace(0, 1, npad, endpoint=False)
ramp = ramp**5*(126-420*ramp+540*ramp**2-315*ramp**3+70*ramp**4)
for j in np.arange(ndist)[::-1]:
    tmp = cl_rec.STa(r[:,j]*norm_magnifications[j],rdata[:,j].astype('complex64'),
                     'edge')
    tmp = (cl_rec.MT(tmp,j)/norm_magnifications[j]**2).real
    # Border width: first row of projection 0 holding a value above 0.1, plus an
    # 8-pixel margin. Named `crop` rather than `st` so the projection start
    # index `st` used by the loading cell is not overwritten.
    rows = np.where(np.abs(tmp[0])>1e-1)[0]
    if rows.size == 0:
        raise ValueError(f'distance {j}: no value above 0.1 in projection 0, cannot locate the valid region')
    crop = rows[0]+8
    if j==ndist-1:
        # Last distance: drop the border and refill it by mirroring the interior.
        tmp = np.pad(tmp[:,crop:-crop,crop:-crop],((0,0),(crop,crop),(crop,crop)),'symmetric')
    else:
        # 1-D window: 0 on the border, ramp up over npad pixels, 1 inside,
        # ramp down; its outer product gives the 2-D blending weight.
        window = np.ones([args.npsi],dtype='float32')
        window[:crop]=0
        window[crop:crop+npad] = ramp
        window[-crop-npad:-crop] = 1-ramp
        window[-crop:]=0
        window=np.outer(window,window)
        tmp=tmp*window+srdata[:,j+1]*(1-window)
    srdata[:,j]=tmp
    mshow(srdata[0,j],args.show)
    


In [None]:

# Quick look at projection 0 of the merged data: the first and last distance,
# then the difference between distances 0 and 2.
for dist_idx in (0, ndist - 1):
    mshow(srdata[0, dist_idx], args.show, vmin=0.7, vmax=1.3)
mshow(srdata[0, 0] - srdata[0, 2], args.show, vmin=-0.2, vmax=0.2)


In [None]:
def multiPaganin(data, distances, wavelength, voxelsize, delta_beta,  alpha, xp=None):
    """Multi-distance Paganin phase retrieval for a single projection.

    Each image j is weighted with the Paganin factor
        T_j(f) = 1 + wavelength * distances[j] * pi * delta_beta * |f|^2.
    The images are combined in Fourier space as
        mean_j(T_j * FFT(data[j])) / (mean_j(T_j**2) + alpha),
    and the phase is delta_beta * 0.5 * log(real(IFFT(...))).

    Parameters
    ----------
    data : array, shape (ndist, ny, nx)
        Normalized intensity images, one per distance. Rectangular images
        (ny != nx) are supported.
    distances : sequence of float
        Propagation distance for each image of `data` (same length unit as
        `wavelength` and `voxelsize`); needs at least data.shape[0] entries.
    wavelength : float
        X-ray wavelength.
    voxelsize : float
        Pixel size of `data`.
    delta_beta : float
        delta/beta ratio used in the Paganin filter.
    alpha : float
        Regularization added to the averaged denominator.
    xp : module, optional
        Array module to compute with (cupy or numpy). Defaults to cupy, as
        before; pass numpy to run on the CPU.

    Returns
    -------
    array, shape (ny, nx)
        Retrieved phase.

    Raises
    ------
    ValueError
        If `data` holds no images.
    """
    if xp is None:
        xp = cp
    ndist = data.shape[0]
    if ndist == 0:
        raise ValueError("multiPaganin needs at least one image")

    # Separate frequency axes so rectangular images work; for square images
    # this equals the former meshgrid(fx, fx).
    fx = xp.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    fy = xp.fft.fftfreq(data.shape[-2], d=voxelsize).astype('float32')
    fx, fy = xp.meshgrid(fx, fy)
    freq2 = fx**2 + fy**2

    numerator = 0
    denominator = 0
    for j in range(ndist):
        rad_freq = xp.fft.fft2(data[j])
        taylorExp = 1 + wavelength * distances[j] * xp.pi * delta_beta * freq2
        numerator = numerator + taylorExp * rad_freq
        denominator = denominator + taylorExp**2

    # Average over the images actually summed (not len(distances), which may differ).
    numerator = numerator / ndist
    denominator = (denominator / ndist) + alpha

    phase = xp.log(xp.real(xp.fft.ifft2(numerator / denominator)))
    phase = delta_beta * 0.5 * phase

    return phase

def rec_init(rdata, delta_beta=120, alpha=1e-5):
    """Initial guess for psi from multi-distance Paganin phase retrieval.

    Runs `multiPaganin` on the GPU for each of the first args.ntheta
    projections and returns exp(1j * phase).

    Relies on the notebook globals `args`, `distances`, `norm_magnifications`,
    `wavelength` and `voxelsize`.

    Parameters
    ----------
    rdata : array
        rdata[j] is the per-distance image stack of projection j, shaped
        (ndist, args.npsi, args.npsi) as built in the merging cell.
    delta_beta : float, optional
        delta/beta ratio passed to multiPaganin (previously hard-coded 120).
    alpha : float, optional
        Regularization passed to multiPaganin (previously hard-coded 1e-5).

    Returns
    -------
    ndarray, complex, shape (args.ntheta, args.npsi, args.npsi)
    """
    recMultiPaganin = np.zeros([args.ntheta, args.npsi, args.npsi], dtype="float32")
    # The Paganin distances are the same for every projection: compute them once.
    distances_pag = distances/norm_magnifications**2
    for j in range(args.ntheta):
        # Sparse progress report; printing every projection floods the output.
        if j % 100 == 0 or j == args.ntheta - 1:
            print(j)
        proj = cp.array(rdata[j])
        phase = multiPaganin(proj, distances_pag, wavelength, voxelsize, delta_beta, alpha)
        recMultiPaganin[j] = phase.get()
    return np.exp(1j * recMultiPaganin)
psi_init = rec_init(srdata)
# Show the initial guess for the first and the last projection.
for proj_idx in (0, -1):
    mshow_polar(psi_init[proj_idx], args.show)
# write_tiff(np.angle(psi_init),'/data/tmp/tmp5')

In [None]:
# a = np.random.random([ntheta,npsi,npsi]).astype('float32')+1j*np.random.random([ntheta,npsi,npsi]).astype('float32')
# for k in range(4):
#     b = cl_rec.M(a,k)
#     c = cl_rec.MT(b,k)
#     print(np.sum(a*np.conj(c)))
#     print(np.sum(b*np.conj(b)))

In [None]:

# b = cl_rec.S(r[:,ndist-1],a)
# c = cl_rec.ST(r[:,ndist-1],b)
# print(np.sum(a*np.conj(c)))
# print(np.sum(b*np.conj(b)))


In [None]:

# a = np.random.random([ntheta,args.nq,args.nq]).astype('float32')+1j*np.random.random([ntheta,args.nq,args.nq]).astype('float32')
# b = cl_rec.D(a,j)
# c = cl_rec.DT(b,j)
# print(np.sum(a*np.conj(c)))
# print(np.sum(b*np.conj(b)))

In [None]:


print(psi_init.shape)
# Store the initial guess (magnitude and phase) in the corrected-data file,
# replacing any previous version.
with h5py.File(f'{path_out}/{pfile}_corr.h5','a') as fid:
    for name in ('/exchange/psi_init_abs', '/exchange/psi_init_angle'):
        # Delete each old dataset independently. The former bare try/except
        # stopped at the first missing dataset, which could leave the other in
        # place and make create_dataset below fail ("name already exists").
        if name in fid:
            del fid[name]
    fid.create_dataset('/exchange/psi_init_abs', data=np.abs(psi_init))
    fid.create_dataset('/exchange/psi_init_angle', data=np.angle(psi_init))