In [None]:
import numpy as np
import cupy as cp
import h5py
import matplotlib.pyplot as plt
import cupyx.scipy.ndimage as ndimage
from types import SimpleNamespace

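# NOTE(review): "Use managed memory" — no managed-memory allocator is configured in this cell; TODO confirm whether one was intended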
import h5py
import sys
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *
from rec import Rec

# Init data sizes and parameters of the PXM of ID16A

In [None]:
# Angular subsampling step and number of 2x2 binning passes (image size / 2**bin).
# NOTE: `bin` shadows the Python builtin; the name is kept because later cells use it.
step, bin = 1, 3


In [None]:
# Acquisition metadata: geometry, projection angles (in degrees; converted to
# radians in a later cell), alignment shifts, and the shapes of the raw stacks.
with h5py.File('/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/AtomiumS2_HT_007nm.h5', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusToDetectorDistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:]
    theta = fid['/exchange/theta'][::step, 0]
    shifts = fid['/exchange/shifts'][::step]
    attrs = fid['/exchange/attrs'][::step]
    pos_shifts = fid['/exchange/pos_shifts'][::step] * 1e-6  # presumably um -> m; TODO confirm
    # Shape of data0[::step] computed from the dataset header: reading the
    # projection stack itself just to learn its shape is expensive.
    data0_shape = fid['/exchange/data0'].shape
    shape = (len(range(data0_shape[0])[::step]),) + data0_shape[1:]
    shape_ref = fid['/exchange/data_white_start0'].shape
    shape_dark = fid['/exchange/data_dark0'].shape
    #pos_shifts-=pos_shifts[0]


In [None]:
# Convert projection angles from degrees to radians (not idempotent: run once).
theta = np.pi * (theta / 180)

In [None]:
# Number of propagation distances (fixed for this scan), number of angles,
# and detector width after 2**bin downsampling.
ndist = 4
ntheta, n = shape[:2]
n = n // 2**bin
nref, ndark = shape_ref[0], shape_dark[0]

In [None]:
# Quick sanity print of the sizes derived above.
print(f"{ndist} {ntheta} {n}")
print(f"{nref} {ndark}")

In [None]:
# X-ray and cone-beam geometry. Lengths are presumably in meters (the
# wavelength comment says [m]) — TODO confirm units of z1 and
# focusToDetectorDistance stored in the HDF5 file.
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length; lambda = h*c/E with h*c ~ 1.23984e-9 keV*m
z2 = focusToDetectorDistance-z1  # presumably sample-to-detector distance, one per distance (z1 is an array)
magnifications = focusToDetectorDistance/z1  # cone-beam magnification for each distance
norm_magnifications = magnifications/magnifications[0]  # magnifications relative to the first distance
distances = (z1*z2)/focusToDetectorDistance*norm_magnifications**2  # effective propagation distances, scaled by (M_j/M_0)**2
voxelsize = detector_pixelsize/magnifications[0]*2048/n  # object voxel size; 2048 presumably the unbinned detector width
show = True


In [None]:
voxelsize * 0.5  # display half of the object voxel size

In [None]:
# Exploratory display only: object grid size for pad=0, rounded up to a
# multiple of 8 full-resolution pixels (npsi itself is computed separately,
# rounded to a multiple of 16).
int(np.ceil((2048+2*0)/norm_magnifications[-1]/8))*8


In [None]:
# Object (psi) grid size. Start from the 2048-pixel full-resolution field of
# view (presumably the unbinned detector width) plus padding, enlarge it by
# 1/norm_magnifications[-1], round up to a multiple of 16 full-resolution
# pixels, then convert to the binned grid by floor-dividing by 2048//n.
pad = 0
npsi = int(np.ceil((2048+2*pad)/norm_magnifications[-1]/16))*16  # multiple of 16 at full resolution
npsi//=(2048//n)  # binned size; a multiple of 16//(2048//n) only when that divides evenly

In [None]:
# Reconstruction settings consumed by Rec: sizes, physics, solver and output.
args = SimpleNamespace(
    ngpus=4,
    n=n,
    ndist=ndist,
    ntheta=ntheta,
    pad=pad,
    npsi=npsi,
    nq=n + 2 * pad,
    nchunk=8,
    lam=0,
    voxelsize=voxelsize,
    wavelength=wavelength,
    distance=distances,
    eps=1e-12,
    rho=[1, 2, 1],
    path_out=f"/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/AtomiumS2_HT_007nm/s1",
    show=True,
    niter=10000,
    vis_step=1,
    err_step=1,
    method="BH-CG",
    # Rotation axis scaled from a 1024-pixel reference grid to the current n.
    # Earlier candidates: 397.5*2, 499.75*n//1024+npsi//2-n//2
    rotation_axis=(796.25 + 2) * n / 1024,
    theta=theta,
    norm_magnifications=norm_magnifications,
)
# Reconstruction object built from these settings.
cl_rec = Rec(args)

In [None]:
def _bin2x2(arr, nbin):
    """Downsample the last two axes of ``arr`` by ``2**nbin``.

    Each pass averages neighbouring pixel pairs along the last axis, then
    along the second-to-last axis. Both axis lengths must be divisible by
    ``2**nbin`` (otherwise the pair slices have mismatched shapes).
    """
    for _ in range(nbin):
        arr = 0.5 * (arr[..., ::2] + arr[..., 1::2])
        arr = 0.5 * (arr[..., ::2, :] + arr[..., 1::2, :])
    return arr


# Load corrected projections for every distance, the flat field and the
# per-angle/per-distance correction shifts, all binned to the n x n grid.
data = np.zeros([ntheta, ndist, n, n], dtype='float32')
with h5py.File('/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/AtomiumS2_HT_007nm_corr.h5', 'r') as fid:
    # ndist (not a hard-coded 4) so the loop always matches data's second axis.
    for k in range(ndist):
        data[:, k] = _bin2x2(fid[f'/exchange/data{k}'][::step], bin)
    ref = _bin2x2(fid['/exchange/ref'][:], bin)
    # Shifts rescaled from the 2048-pixel grid to n pixels; in init coordinates, not scaled
    # by norm_magnifications.
    r = fid['/exchange/cshifts_final'][::step] * n / 2048

In [None]:
# Flat-field correction, then bring each distance onto the common detector grid.
rdata = data/ref  # presumably ref broadcasts as (ndist, n, n) against data (ntheta, ndist, n, n) — TODO confirm
srdata = np.zeros([ntheta,ndist,args.nq,args.nq],dtype='complex64')
for j in range(ndist):
    # Apply cl_rec.ST (presumably the adjoint of the shift S) with the shifts r scaled
    # by this distance's norm_magnification, then MT (adjoint of M, cf. the dot tests
    # below) and the 1/norm_magnification**2 rescaling.
    tmp=cl_rec.ST(r[:,j]*norm_magnifications[j],rdata[:,j].astype('complex64'))
    tmp = cl_rec.MT(tmp,j)/norm_magnifications[j]**2    
    # Crop the central nq x nq window (MT output presumably npsi x npsi).
    tmp = tmp[:,args.npsi//2-args.nq//2:args.npsi//2+args.nq//2,args.npsi//2-args.nq//2:args.npsi//2+args.nq//2]
    srdata[:,j]=tmp#np.pad(tmp,((0,0),(mpad,mpad),(mpad,mpad)),'edge')
    
# NOTE: the loop variable `j` (ndist-1 after the loop) is read again later by the
# cl_rec.D(a,j) dot test.
srdata=srdata.real  # keep the real part only (float32)

# Visual check for the first angle: distance 0, distance 3 (the last of 4), and their difference.
mshow(srdata[0,0],args.show)
mshow(srdata[0,3],args.show)
mshow(srdata[0,0]-srdata[0,3],args.show)


In [None]:
def multiPaganin(data, distances, wavelength, voxelsize, delta_beta, alpha):
    """Multi-distance Paganin phase retrieval for one projection.

    Parameters
    ----------
    data : array (cupy), shape (nframes, ny, nx)
        Flat-field-corrected intensities, one frame per propagation distance.
    distances : sequence of float, length nframes
        Propagation distance of each frame (same length unit as
        ``wavelength`` and ``voxelsize``).
    wavelength : float
        X-ray wavelength.
    voxelsize : float
        Pixel size defining the Fourier frequency grid.
    delta_beta : float
        delta/beta ratio of the refractive index.
    alpha : float
        Regularization added to the averaged denominator.

    Returns
    -------
    array (cupy), shape (ny, nx)
        ``0.5 * delta_beta * log(real(ifft2(mean(T*FT(I)) / (mean(T**2) + alpha))))``
        with ``T = 1 + pi * wavelength * z * delta_beta * |f|**2``.

    Raises
    ------
    ValueError
        If ``data`` has no frames or the frame count differs from
        ``len(distances)``.
    """
    nframes = data.shape[0]
    if nframes == 0:
        raise ValueError("data must contain at least one frame")
    if len(distances) != nframes:
        raise ValueError(f"got {nframes} frames but {len(distances)} distances")

    fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    fy = cp.fft.fftfreq(data.shape[-2], d=voxelsize).astype('float32')
    # meshgrid(fx, fy) gives (ny, nx) arrays, matching the fft2 layout of a frame.
    fx, fy = cp.meshgrid(fx, fy)
    freq2 = fx**2 + fy**2  # |f|^2 does not depend on the distance: compute once

    numerator = 0
    denominator = 0
    for j in range(nframes):
        rad_freq = cp.fft.fft2(data[j])
        taylorExp = 1 + wavelength * distances[j] * cp.pi * delta_beta * freq2
        numerator = numerator + taylorExp * rad_freq
        denominator = denominator + taylorExp**2

    numerator = numerator / nframes
    denominator = denominator / nframes + alpha

    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    return delta_beta * 0.5 * phase

def CTFPurePhase(data, distances, wavelength, voxelsize, alpha):
    """Multi-distance CTF phase retrieval assuming a pure-phase object.

    Parameters
    ----------
    data : array (cupy), shape (nframes, ny, nx)
        Intensities, one frame per propagation distance.
    distances : sequence of float, length nframes
        Propagation distance of each frame.
    wavelength : float
        X-ray wavelength (same length unit as ``distances``/``voxelsize``).
    voxelsize : float
        Pixel size defining the Fourier frequency grid.
    alpha : float
        Regularization added to the averaged denominator; must be > 0 to
        avoid 0/0 at zero frequency, where the sine vanishes.

    Returns
    -------
    array (cupy), shape (ny, nx)
        ``0.5 * real(ifft2(mean(S*FT(I)) / (mean(2*S**2) + alpha)))`` with
        ``S = sin(pi * wavelength * z * |f|**2)``.

    NOTE(review): with the factor 2 in the denominator and the final 0.5,
    this is half of the textbook estimate sum(S*FT(I)) / (2*sum(S**2)) for
    alpha=0; confirm the intended phase convention.

    Raises
    ------
    ValueError
        If ``data`` has no frames or the frame count differs from
        ``len(distances)``.
    """
    nframes = data.shape[0]
    if nframes == 0:
        raise ValueError("data must contain at least one frame")
    if len(distances) != nframes:
        raise ValueError(f"got {nframes} frames but {len(distances)} distances")

    fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    fy = cp.fft.fftfreq(data.shape[-2], d=voxelsize).astype('float32')
    # meshgrid(fx, fy) gives (ny, nx) arrays, matching the fft2 layout of a frame.
    fx, fy = cp.meshgrid(fx, fy)
    freq2 = fx**2 + fy**2  # distance-independent, computed once

    numerator = 0
    denominator = 0
    for j in range(nframes):
        rad_freq = cp.fft.fft2(data[j])
        ctf = cp.sin(cp.pi * wavelength * distances[j] * freq2)
        numerator = numerator + ctf * rad_freq
        denominator = denominator + 2 * ctf**2
    numerator = numerator / nframes
    denominator = denominator / nframes + alpha

    phase = cp.real(cp.fft.ifft2(numerator / denominator))
    return 0.5 * phase

def rec_init(rdata, delta_beta=100, alpha=5e-3, ncorner=32):
    """Multi-distance Paganin initial guess for every projection angle.

    Parameters
    ----------
    rdata : ndarray, shape (ntheta, ndist, ny, nx)
        Flat-field-corrected intensities on a common grid (here ``srdata``).
    delta_beta, alpha : float
        Passed to ``multiPaganin``; the defaults are the previously
        hard-coded values.
    ncorner : int
        Side of the top-left corner patch whose mean phase is subtracted
        from each angle, to remove a constant phase offset.

    Returns
    -------
    ndarray (complex), shape (ntheta, ny, nx)
        ``exp(1j * phase)``.

    Uses the notebook globals ``distances``, ``norm_magnifications``,
    ``wavelength``, ``voxelsize`` and the CuPy module ``cp``.
    """
    # Divide out the norm_magnifications**2 factor included in `distances`;
    # identical for all angles, so computed once.
    distances_pag = distances / norm_magnifications**2
    nangles = rdata.shape[0]
    phase = np.zeros([nangles, *rdata.shape[-2:]], dtype="float32")
    for itheta in range(nangles):
        frames = cp.array(rdata[itheta])
        phase[itheta] = multiPaganin(frames, distances_pag, wavelength, voxelsize,
                                     delta_beta, alpha).get()
        phase[itheta] -= np.mean(phase[itheta, :ncorner, :ncorner])
    return np.exp(1j * phase)

# Paganin initial guess on the nq x nq detector grid, shape (ntheta, nq, nq).
psi_init = rec_init(srdata)
# Edge-pad up to the npsi x npsi object grid. The leading pad matches the
# centered crop offset used when building srdata; the trailing pad absorbs the
# extra pixel when npsi - nq is odd, so the result is always exactly npsi wide.
mpad = args.npsi // 2 - args.nq // 2
mpad_end = args.npsi - args.nq - mpad
psi_init = np.pad(psi_init, ((0, 0), (mpad, mpad_end), (mpad, mpad_end)), 'edge')
mshow_polar(psi_init[0], args.show)
mshow_polar(psi_init[1], args.show)

In [None]:
# Adjoint (dot-product) test for the magnification M / MT at distance 2:
# <a, MT(M(a))> should equal <M(a), M(a)> up to rounding.
rng = np.random.default_rng(0)  # fixed seed so printed values are reproducible
a = (rng.random((ntheta, npsi, npsi), dtype=np.float32)
     + 1j * rng.random((ntheta, npsi, npsi), dtype=np.float32))
b = cl_rec.M(a, 2)
c = cl_rec.MT(b, 2)
lhs = np.sum(a * np.conj(c))
rhs = np.sum(b * np.conj(b))
print(lhs)
print(rhs)
print('relative mismatch:', abs(lhs - rhs) / abs(rhs))

In [None]:

# Adjoint (dot-product) test for the shift S / ST using the shifts of distance 2.
# Reuses the random complex volume `a` from the previous cell; the two printed
# sums should agree if ST is the adjoint of S.
b = cl_rec.S(r[:,2],a)
c = cl_rec.ST(r[:,2],b)
print(np.sum(a*np.conj(c)))
print(np.sum(b*np.conj(b)))


In [None]:

# Adjoint (dot-product) test for D / DT on the nq x nq detector grid:
# <a, DT(D(a))> should equal <D(a), D(a)> up to rounding.
# Explicit distance index (presumably what D's second argument means, as in
# cl_rec.MT(tmp, j)); ndist - 1 is the value the leftover loop variable `j`
# held in a top-to-bottom run, which this cell previously relied on implicitly.
idist = ndist - 1
rng = np.random.default_rng(1)  # fixed seed so printed values are reproducible
a = (rng.random((ntheta, args.nq, args.nq), dtype=np.float32)
     + 1j * rng.random((ntheta, args.nq, args.nq), dtype=np.float32))
b = cl_rec.D(a, idist)
c = cl_rec.DT(b, idist)
lhs = np.sum(a * np.conj(c))
rhs = np.sum(b * np.conj(b))
print(lhs)
print(rhs)
print('relative mismatch:', abs(lhs - rhs) / abs(rhs))

In [None]:
# Constant complex64 test volume (1 + 1j everywhere), 4 x npsi x npsi, for the R/RT dot test.
# Random alternative: a = np.random.random([npsi,npsi,npsi]).astype('float32')+1j*np.random.random([npsi,npsi,npsi]).astype('float32')
a = np.full([4, npsi, npsi], 1 + 1j, dtype='complex64')

In [None]:

# Adjoint (dot-product) test for cl_rec.R / cl_rec.RT (presumably the tomographic
# projection operator and its adjoint) on the constant volume `a`; the two printed
# sums should agree.
b = cl_rec.R(a)
c = cl_rec.RT(b)
print(np.sum(a*np.conj(c)))
print(np.sum(b*np.conj(b)))
# mshow_complex(c[2],True)


In [None]:
# Phase-like projections from the complex transmission:
# log(psi)/1j = angle(psi) - 1j*log|psi|. np.log uses the principal branch, so
# the phase part is wrapped to (-pi, pi].
# NOTE(review): psi_init is exp(1j*phase); if |phase| exceeds pi anywhere, this
# wraps it — confirm that is acceptable or keep the unwrapped phase instead.
psi_data = np.log(psi_init)/1j

In [None]:

# Initial tomographic reconstruction from the phase projections.
#psi_data_cen=psi_data[:,npsi//2:npsi//2+2]
cl_rec.theta = np.ascontiguousarray(theta)
psi_data = np.ascontiguousarray(psi_data)
u_init = cl_rec.rec_tomo(psi_data,32)  # NOTE(review): meaning of 32 is not visible here — confirm
# np.save(f'{args.path_out}/s1/u_init.npy',u_init)
# NOTE(review): args.path_out already ends in '/s1', so the save above would write to .../s1/s1/.
mshow_complex(u_init[u_init.shape[0]//2],True)  # middle slice along axis 0; always shown (ignores args.show)

In [None]:
# Persist the initial reconstruction and phase projections to the corrected
# data file. Complex arrays are stored as two real datasets each. Existing
# datasets are replaced; missing ones are simply created, so this also works
# on a file that does not contain them yet (a bare `del` raised KeyError).
init_datasets = {
    '/exchange/u_init_re': u_init.real,
    '/exchange/u_init_imag': u_init.imag,
    '/exchange/psi_data_abs': np.abs(psi_data),
    '/exchange/psi_data_angle': np.angle(psi_data),
}
with h5py.File('/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/AtomiumS2_HT_007nm_corr.h5', 'a') as fid:
    for key, values in init_datasets.items():
        if key in fid:
            del fid[key]
        fid.create_dataset(key, data=values)

In [None]:
# Disabled: rotation-axis scan. Reconstructs a 16-row slab from the middle of
# psi_data (axis 1) for candidate centers args.rotation_axis-2 ... +2.25 in
# steps of 0.25, then shows and writes the central slice of each to TIFF for
# visual comparison.
# psi_data_cen = psi_data[:,psi_data.shape[1]//2:psi_data.shape[1]//2+16]
# center = args.rotation_axis

# for k in np.arange(center-2,center+2.5,0.25):
#     print(k)
#     cl_rec.rotation_axis = k
#     u = cl_rec.rec_tomo(psi_data_cen,64)
#     mshow(u[u.shape[0]//2].real,show)
#     write_tiff(u[u.shape[0]//2].real, f'{args.path_out}/test_center_new/r{k}', overwrite=True)