In [None]:
import numpy as np
import cupy as cp
import h5py
from holotomocupy.holo import G, GT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import multiPaganin
from holotomocupy.utils import *
from holotomocupy.proc import remove_outliers
# Use managed memory
# cp.cuda.set_allocator(cp.cuda.MemoryPool(cp.cuda.malloc_managed).malloc)

# Initialize data sizes and parameters of the PXM setup at ID16A

In [None]:
# --- object grid -----------------------------------------------------------
n = 2048  # object size in each dimension

# --- beam / geometry -------------------------------------------------------
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length
detector_pixelsize = 3.03751e-6
focusToDetectorDistance = 1.28  # [m]
sx0 = 1.286e-3
# z1: positions from the list below minus sx0; z2: the remainder of the focus-to-detector distance
z1 = np.array([4.236e-3, 4.3625e-3, 4.86850e-3, 5.91950e-3]) - sx0
z2 = focusToDetectorDistance - z1
distances = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1
norm_magnifications = magnifications/magnifications[0]  # relative to the first distance
voxelsize = np.abs(detector_pixelsize/magnifications[0]*2048/n)  # object voxel size

# --- input / output locations ------------------------------------------------
path = '/data/vnikitin/ESRF/ID16A/20240924/AtomiumS2/'
pfile = 'AtomiumS2_HT_007nm'
path_out = '/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/'

# --- run options -------------------------------------------------------------
show = True   # passed to the plotting helpers
ntheta = 1800
ndist = 4     # number of distances; one data folder per distance is read below
st = 0
print(f'{voxelsize=}')

## Read data: averaged reference (flat-field) and dark frames

In [None]:
import scipy.ndimage as ndimage


def remove_outliers(data, dezinger, dezinger_threshold, verbose=True):
    """Replace outlier pixels (zingers) with the local median value.

    NOTE: this definition shadows ``remove_outliers`` imported from
    ``holotomocupy.proc`` in the first cell.

    A pixel counts as an outlier when its absolute deviation from the
    median-filtered data exceeds ``fdata * dezinger_threshold``. The threshold
    is therefore relative to the local median, not absolute.

    Parameters
    ----------
    data : np.ndarray, ndim >= 2
        Image or stack of images. The median filter uses a ``w x w`` window in
        the last two axes and window 1 along any leading axes, so the images of
        a stack are filtered independently.
    dezinger : int-like
        Median window width ``w`` in pixels; ``int(dezinger) <= 0`` disables
        filtering and a plain copy is returned.
    dezinger_threshold : float
        Relative deviation threshold.
    verbose : bool, optional
        If True (default), print the number of replaced pixels.

    Returns
    -------
    np.ndarray
        New array with the same shape and dtype as ``data``; ``data`` itself is
        not modified.

    Raises
    ------
    ValueError
        If filtering is enabled and ``data`` has fewer than 2 dimensions.
    """
    w = int(dezinger)
    if w <= 0:
        return data.copy()
    if data.ndim < 2:
        raise ValueError(f"remove_outliers expects ndim >= 2, got {data.ndim}")
    # window 1 along leading (stack) axes, w x w in the image plane
    size = [1] * (data.ndim - 2) + [w, w]
    fdata = ndimage.median_filter(data, size)
    # compute the outlier mask once and reuse it
    mask = np.abs(data - fdata) > fdata * dezinger_threshold
    if verbose:
        print(np.sum(mask))
    # keep the input dtype, as the original in-place assignment into a copy did
    return np.where(mask, fdata, data).astype(data.dtype, copy=False)

In [None]:
# Raw detector frames are read at full size; binning down to the object size n
# happens in a later cell, so accumulators must use the raw size, not n.
nraw = 2048
nframes = 20  # number of reference/dark frames averaged per distance


def _mean_edf(fnames):
    """Return the float32 average of the first image stored in each EDF file.

    NOTE(review): ``dxchange`` is not imported explicitly in this notebook;
    it is assumed to come from ``from holotomocupy.utils import *`` — confirm.
    """
    acc = np.zeros([nraw, nraw], dtype='float32')
    for fname in fnames:
        acc += dxchange.read_edf(fname)[0]
    acc /= len(fnames)
    return acc


ref00 = np.zeros([ndist, nraw, nraw], dtype='float32')   # refs named *_0000.edf
ref01 = np.zeros([ndist, nraw, nraw], dtype='float32')   # refs named *_1800.edf
dark00 = np.zeros([ndist, nraw, nraw], dtype='float32')  # darks named darkend*.edf

mmeans = np.zeros(8)  # NOTE(review): not used in the visible cells; kept in case later cells rely on it

for k in range(ndist):
    # one acquisition folder per distance, numbered from 1
    scan_dir = f'{path}{pfile}_{k+1}_'
    ref00[k] = _mean_edf([f'{scan_dir}/ref{l:04}_0000.edf' for l in range(nframes)])
    ref01[k] = _mean_edf([f'{scan_dir}/ref{l:04}_1800.edf' for l in range(nframes)])
    dark00[k] = _mean_edf([f'{scan_dir}/darkend{l:04}.edf' for l in range(nframes)])



In [None]:
# Flat-field preparation: dark subtraction, clipping, patching, dezingering,
# normalization, and binning from the raw 2048 grid down to n.
ref = ref00.copy()
dark = dark00.copy()
ref -= dark
ref[ref < 0] = 0

# Overwrite a small block with the block just above it in the same columns
# (presumably a detector defect — confirm).
rows_bad = slice(1320//3, 1320//3 + 25//3)
rows_src = slice(1280//3, 1280//3 + 25//3)
cols_patch = slice(890//3, 890//3 + 25//3)
ref[:, rows_bad, cols_patch] = ref[:, rows_src, cols_patch]

ref[:] = remove_outliers(ref[:], 3, 0.9)

# Both arrays are scaled by the same factor: the mean of the corrected refs.
ref_mean = np.mean(ref)
dark /= ref_mean
ref /= ref_mean


def _bin2(a):
    """Average 2x2 neighbourhoods over the last two axes of a 3D stack."""
    rows = (a[:, ::2] + a[:, 1::2])*0.5
    return (rows[:, :, ::2] + rows[:, :, 1::2])*0.5


for _ in range(int(np.log2(2048//n))):
    ref = _bin2(ref)
    dark = _bin2(dark)



In [None]:
def apply_shift(psi, p):
    """Apply a sub-pixel shift to every image of a stack via a Fourier phase ramp.

    Parameters
    ----------
    psi : array-like, shape (nproj, h, w)
        Real-valued images; copied to the GPU with ``cp.array``.
    p : array-like, shape (nproj, 2)
        Per-image shifts in pixels: ``p[:, 0]`` along rows (axis 1) and
        ``p[:, 1]`` along columns (axis 2).

    Returns
    -------
    numpy.ndarray, shape (nproj, h, w)
        The shifted images, copied back to the host.
    """
    psi = cp.array(psi)
    p = cp.array(p)
    # Sizes come from the input itself rather than the global ``n``.
    h, w = psi.shape[-2:]
    # Symmetric padding by half the size on each side reduces wrap-around of
    # the periodic FFT-based shift.
    tmp = cp.pad(psi, ((0, 0), (h//2, h//2), (w//2, w//2)), 'symmetric')
    ph, pw = tmp.shape[-2:]
    # The frequency grids must match the padded size exactly (also for odd h/w).
    [x, y] = cp.meshgrid(cp.fft.rfftfreq(pw),
                         cp.fft.fftfreq(ph))
    shift = cp.exp(-2*cp.pi*1j *
                   (x*p[:, 1, None, None]+y*p[:, 0, None, None]))
    # s= is required to recover an odd padded width from the half spectrum.
    res0 = cp.fft.irfft2(shift*cp.fft.rfft2(tmp), s=(ph, pw))
    return res0[:, h//2:h//2+h, w//2:w//2+w].get()

def _upsampled_dft(data, ups,
                   upsample_factor=1, axis_offsets=None):
    """Evaluate a batched 2D DFT of ``data`` on a small ``ups x ups`` output grid.

    The transform is done as two matrix products (columns first, then rows).
    The output samples are spaced by ``1/upsample_factor`` and shifted per
    image by ``axis_offsets``.

    Parameters
    ----------
    data : cupy.ndarray, shape (nbatch, nrows, ncols)
        Complex input.
    ups : number
        Number of output samples per axis.
    upsample_factor : number
        Upsampling factor (passed as the ``d`` argument of ``fftfreq``).
    axis_offsets : cupy.ndarray, shape (nbatch, 2)
        Per-image offsets of the output grid, ``[:, 0]`` for rows and
        ``[:, 1]`` for columns. It is required despite the ``None`` default.

    Returns
    -------
    cupy.ndarray, shape (nbatch, ups, ups)
    """
    im2pi = 1j * 2 * cp.pi
    nbatch = data.shape[0]

    def _dft_kernel(axis_len, offsets):
        # (nbatch, ups, axis_len) phase-factor matrix for one axis
        grid = cp.tile(cp.arange(ups), (nbatch, 1)) - offsets
        return cp.exp(-im2pi * (grid[:, :, None] * cp.fft.fftfreq(axis_len, upsample_factor)))

    # contract the column axis first -> (nbatch, ups, nrows)
    along_cols = cp.einsum('ijk,ipk->ijp',
                           _dft_kernel(data.shape[2], axis_offsets[:, 1:2]), data)
    # then the row axis -> (nbatch, ups, ups)
    return cp.einsum('ijk,ipk->ijp',
                     _dft_kernel(data.shape[1], axis_offsets[:, 0:1]), along_cols)

def registration_shift(src_image, target_image, upsample_factor=1, space="real"):
    """Estimate the translation between two image stacks by cross-correlation.

    The integer-pixel estimate comes from the peak of the FFT-based
    cross-correlation. When ``upsample_factor > 1`` it is refined on a local
    upsampled grid computed with ``_upsampled_dft``, in the matrix-multiply
    DFT style of ``skimage.registration.phase_cross_correlation``.

    Parameters
    ----------
    src_image, target_image : array-like, shape (nimages, h, w)
        Image stacks; copied to the GPU with ``cp.array``.
    upsample_factor : int, optional
        Sub-pixel precision is ``1/upsample_factor`` pixels; 1 means integer only.
    space : {"real", "fourier"}, optional
        ``"fourier"`` means the inputs are already 2D FFTs.

    Returns
    -------
    cupy.ndarray, shape (nimages, 2), float64
        Per-image (row, col) location of the cross-correlation peak of
        ``src`` against ``target``. Locations beyond half the image size are
        wrapped to negative values.

    Raises
    ------
    ValueError
        If ``space`` is neither "real" nor "fourier" (case-insensitive).
    """
    src_image = cp.array(src_image)
    target_image = cp.array(target_image)
    mode = space.lower()
    if mode == 'fourier':
        # complex data is assumed to already be in Fourier space
        src_freq = src_image
        target_freq = target_image
    elif mode == 'real':
        src_freq = cp.fft.fft2(src_image)
        target_freq = cp.fft.fft2(target_image)
    else:
        # previously fell through and failed later with a NameError
        raise ValueError(f"space must be 'real' or 'fourier', got {space!r}")

    # Whole-pixel shift: cross-correlation via an inverse FFT
    shape = src_freq.shape
    image_product = src_freq * target_freq.conj()
    cross_correlation = cp.fft.ifft2(image_product)
    A = cp.abs(cross_correlation)
    maxima = A.reshape(A.shape[0], -1).argmax(1)
    maxima = cp.column_stack(cp.unravel_index(maxima, A[0, :, :].shape))

    midpoints = cp.array([cp.fix(axis_size / 2)
                          for axis_size in shape[1:]])

    # wrap peak positions past the midpoint to negative shifts
    shifts = cp.array(maxima, dtype=cp.float64)
    ids = cp.where(shifts[:, 0] > midpoints[0])
    shifts[ids[0], 0] -= shape[1]
    ids = cp.where(shifts[:, 1] > midpoints[1])
    shifts[ids[0], 1] -= shape[2]

    if upsample_factor > 1:
        # Initial shift estimate on the upsampled grid
        shifts = cp.round(shifts * upsample_factor) / upsample_factor
        upsampled_region_size = cp.ceil(upsample_factor * 1.5)
        # Center of the output array at dftshift + 1
        dftshift = cp.fix(upsampled_region_size / 2.0)

        normalization = (src_freq[0].size * upsample_factor ** 2)

        # Matrix-multiply DFT around the current shift estimate
        sample_region_offset = dftshift - shifts*upsample_factor
        cross_correlation = _upsampled_dft(image_product.conj(),
                                           upsampled_region_size,
                                           upsample_factor,
                                           sample_region_offset).conj()
        cross_correlation /= normalization
        # Locate the maximum and map it back to the original pixel grid
        A = cp.abs(cross_correlation)
        maxima = A.reshape(A.shape[0], -1).argmax(1)
        maxima = cp.column_stack(
            cp.unravel_index(maxima, A[0, :, :].shape))

        maxima = cp.array(maxima, dtype=cp.float64) - dftshift

        shifts = shifts + maxima / upsample_factor

    return shifts

# Register each distance's reference image against the first one with
# 1/1000-pixel precision.
shifts = cp.zeros([ndist, 2], dtype='float32')
ref_shifted = ref.copy()
print(ref.shape)
for k in range(ndist):
    # registration_shift returns shape (1, 2) for a single-image stack
    shifts[k] = registration_shift(ref_shifted[k:k+1], ref[0:1], upsample_factor=1000)[0]
print(shifts)

# Visual check: the real part shows the difference to distance 0 before
# alignment, the imaginary part shows it after applying -shifts.
ref_shifted_check = ref_shifted.copy()
for k in range(ndist):
    ref_shifted_check[k:k+1] = apply_shift(ref_shifted_check[k:k+1], -shifts[k:k+1])
    mshow_complex(ref_shifted[k]-ref_shifted[0]+1j*(ref_shifted_check[k]-ref_shifted_check[0]),
                  show, vmax=0.05, vmin=-0.05)
