In [None]:
import numpy as np
import cupy as cp
import h5py
import matplotlib.pyplot as plt
import cupyx.scipy.ndimage as ndimage
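# NOTE(review): originally "Use managed memory", but this cell configures no CuPy allocator — confirm whether `utils` does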
import h5py
import sys
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *


# Initialize data sizes and parameters of the PXM of ID16A

In [None]:
# Subsampling stride over projections: the next cell loads every `step`-th angle/projection.
step = 1

In [None]:
pfile = 'Y350c_HT_015nm'
path = '/data/vnikitin/ESRF/ID16A/brain/20240515/Y350c'
path_out = '/data/vnikitin/ESRF/ID16A/brain_rec/20240515/Y350c'
# Open read-only explicitly: h5py < 3.0 defaulted to mode 'a' (read/write, create if missing).
with h5py.File(f'{path_out}/{pfile}.h5', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusToDetectorDistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:]
    theta = fid['/exchange/theta'][::step]
    shifts = fid['/exchange/shifts'][::step]
    attrs = fid['/exchange/attrs'][::step]
    pos_shifts = fid['/exchange/pos_shifts'][::step]*1e-6
    # Only the shape of the projection stack is needed here. Derive it from the
    # dataset metadata instead of reading the whole stack with data0[::step];
    # len(range(0, N, step)) is the length of a [::step] slice.
    shape = fid['/exchange/data0'].shape
    shape = (len(range(0, shape[0], step)), *shape[1:])
    shape_ref = fid['/exchange/data_white_start0'].shape
    shape_dark = fid['/exchange/data_dark0'].shape


In [None]:
# Number of detector distances (hard-coded for this scan).
ndist = 4
# Projection count and first frame dimension of the (subsampled) data0 stack.
ntheta, n = shape[0], shape[1]
# Number of dark and white (reference) frames.
ndark, nref = shape_dark[0], shape_ref[0]

In [None]:
# Sanity check of the loaded sizes: "ndist ntheta n", then "nref ndark".
for size_row in ((ndist, ntheta, n), (nref, ndark)):
    print(*size_row)

In [None]:
show = True  # display flag forwarded to mshow_complex in later cells

energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length = (h*c in keV*m) / E

# Focus/sample/detector geometry. Because z2 = focusToDetectorDistance - z1,
# magnifications = (z1+z2)/z1 and distances = z1*z2/(z1+z2).
magnifications = focusToDetectorDistance/z1
norm_magnifications = magnifications/magnifications[0]  # relative to the first distance
z2 = focusToDetectorDistance-z1
distances = (z1*z2)/focusToDetectorDistance


In [None]:
# Shifts read from the file, with their two components swapped and divided by the
# relative magnification of each distance.
shifts_random = np.zeros([ntheta,ndist,2],dtype='float32')
for k in range(ndist):
    # (destination column, source column): column 0 <- 1, column 1 <- 0.
    for dst_col, src_col in ((0, 1), (1, 0)):
        shifts_random[:, k, dst_col] = shifts[:, k, src_col] / norm_magnifications[k]
print(shifts_random[:,0])

In [None]:
def apply_shift(psi, p):
    """Apply sub-pixel shifts to every image of a stack (Fourier shift theorem).

    Parameters
    ----------
    psi : array_like, shape (nproj, ny, nx)
        Real-valued image stack (NumPy or CuPy); it is copied to the GPU.
    p : array_like, shape (nproj, 2)
        Per-image shifts in pixels. ``p[:, 0]`` acts along the rows (axis 1) and
        ``p[:, 1]`` along the columns (axis 2). A positive value moves the
        content toward larger indices, i.e. the result is ``psi(x - p)``.

    Returns
    -------
    numpy.ndarray, shape (nproj, ny, nx)
        Shifted images, copied back to host memory.
    """
    psi = cp.array(psi)
    p = cp.array(p)
    # Use the input's own size instead of the global `n`, so non-square or
    # differently sized stacks work too.
    ny, nx = psi.shape[-2:]
    pady, padx = ny // 2, nx // 2
    # Symmetric padding by half the image size limits wrap-around from the periodic FFT shift.
    tmp = cp.pad(psi, ((0, 0), (pady, pady), (padx, padx)), 'symmetric')
    my, mx = tmp.shape[-2:]
    # Frequency grids matching rfft2 of the padded stack: half spectrum along
    # the columns (x), full spectrum along the rows (y).
    [x, y] = cp.meshgrid(cp.fft.rfftfreq(mx),
                         cp.fft.fftfreq(my))
    shift = cp.exp(-2*cp.pi*1j *
                   (x*p[:, 1, None, None]+y*p[:, 0, None, None]))
    # Give irfft2 the padded size explicitly; without it an odd width would be
    # reconstructed one pixel short.
    res0 = cp.fft.irfft2(shift*cp.fft.rfft2(tmp), s=(my, mx))
    res = res0[:, pady:pady + ny, padx:padx + nx].get()
    return res

def _upsampled_dft(data, ups,
                   upsample_factor=1, axis_offsets=None):
    """Evaluate batched 2D DFT sums of ``data`` on a small upsampled patch.

    Each axis is contracted with a kernel ``exp(-2*pi*i*(u - offset)*f)``.
    Here ``u`` runs over ``cp.arange(ups)`` and ``f`` is
    ``cp.fft.fftfreq(axis_len, upsample_factor)``. The columns are contracted
    first, then the rows.

    Parameters
    ----------
    data : cupy.ndarray, shape (nproj, ny, nx)
        Batch of 2D arrays (spectra, at the call site in this file).
    ups : scalar or 0-d array
        Side length of the output patch; passed to ``cp.arange``.
    upsample_factor : int, optional
        Sampling density of the output grid relative to the input grid.
    axis_offsets : cupy.ndarray, shape (nproj, 2)
        Per-image patch offsets: row offset in column 0, column offset in
        column 1. Required despite the ``None`` default.

    Returns
    -------
    cupy.ndarray, shape (nproj, ups, ups)
    """
    im2pi = 1j * 2 * cp.pi
    out = data
    # (length of the axis being contracted, column of axis_offsets for that axis)
    for axis_len, offset_col in ((data.shape[2], 1), (data.shape[1], 0)):
        grid = cp.tile(cp.arange(ups), (data.shape[0], 1)) - axis_offsets[:, offset_col:offset_col + 1]
        phase = grid[:, :, None] * cp.fft.fftfreq(axis_len, upsample_factor)
        # After the first pass `out` has shape (nproj, ups, ny), so the same
        # einsum contracts the row axis on the second pass.
        out = cp.einsum('ijk,ipk->ijp', cp.exp(-im2pi * phase), out)
    return out

def registration_shift(src_image, target_image, upsample_factor=1, space="real"):
    """Estimate per-image translations between two image stacks by cross-correlation.

    The algorithm appears to be a batched CuPy port of scikit-image's
    ``register_translation`` / ``phase_cross_correlation``. It does a
    whole-pixel peak search, then optionally refines the peak with a
    matrix-multiply DFT on an upsampled neighbourhood.

    Parameters
    ----------
    src_image, target_image : array_like, shape (nproj, ny, nx)
        Image stacks (NumPy or CuPy), compared pairwise along axis 0. With
        ``space="fourier"`` they must already be 2D FFTs.
    upsample_factor : int, optional
        Shifts are resolved to ``1/upsample_factor`` pixel. ``1`` gives
        whole-pixel shifts only.
    space : {"real", "fourier"}, optional
        Domain of the inputs (case-insensitive).

    Returns
    -------
    cupy.ndarray, shape (nproj, 2), float64
        Peak location of the (circular) cross-correlation of ``src_image``
        against ``target_image``, in pixels: row shift in column 0, column
        shift in column 1, wrapped to the interval around zero. If
        ``src_image`` is ``target_image`` translated by ``d``, the result is ``d``.

    Raises
    ------
    ValueError
        If ``space`` is neither "real" nor "fourier".
    """
    src_image=cp.array(src_image)
    target_image=cp.array(target_image)
    mode = space.lower()
    if mode == 'fourier':
        # Complex data is assumed to already be in Fourier space.
        src_freq = src_image
        target_freq = target_image
    elif mode == 'real':
        src_freq = cp.fft.fft2(src_image)
        target_freq = cp.fft.fft2(target_image)
    else:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(f"space must be 'real' or 'fourier', got {space!r}")

    # Whole-pixel shift: cross-correlation via an inverse FFT, then its argmax per image.
    shape = src_freq.shape
    image_product = src_freq * target_freq.conj()
    cross_correlation = cp.fft.ifft2(image_product)
    A = cp.abs(cross_correlation)
    maxima = A.reshape(A.shape[0], -1).argmax(1)
    maxima = cp.column_stack(cp.unravel_index(maxima, A[0, :, :].shape))

    midpoints = cp.array([cp.fix(axis_size / 2)
                          for axis_size in shape[1:]])

    # Peaks past the midpoint correspond to negative shifts (FFT wrap-around).
    shifts = cp.array(maxima, dtype=cp.float64)
    ids = cp.where(shifts[:, 0] > midpoints[0])
    shifts[ids[0], 0] -= shape[1]
    ids = cp.where(shifts[:, 1] > midpoints[1])
    shifts[ids[0], 1] -= shape[2]

    if upsample_factor > 1:
        # Initial shift estimate in upsampled grid
        shifts = cp.round(shifts * upsample_factor) / upsample_factor
        upsampled_region_size = cp.ceil(upsample_factor * 1.5)
        # Center of output array at dftshift + 1
        dftshift = cp.fix(upsampled_region_size / 2.0)

        normalization = (src_freq[0].size * upsample_factor ** 2)
        # Matrix multiply DFT around the current shift estimate
        sample_region_offset = dftshift - shifts*upsample_factor
        cross_correlation = _upsampled_dft(image_product.conj(),
                                           upsampled_region_size,
                                           upsample_factor,
                                           sample_region_offset).conj()
        cross_correlation /= normalization
        # Locate maximum and map back to original pixel grid
        A = cp.abs(cross_correlation)
        maxima = A.reshape(A.shape[0], -1).argmax(1)
        maxima = cp.column_stack(
            cp.unravel_index(maxima, A[0, :, :].shape))

        maxima = cp.array(maxima, dtype=cp.float64) - dftshift

        shifts = shifts + maxima / upsample_factor

    return shifts


# NOTE(review): `rdata_scaled` is not created in any visible cell. It may come
# from `from utils import *` or a removed cell; confirm before Restart & Run All.
# It is assumed to hold ntheta projections per distance.

# Batch size: number of projections sent to the GPU per apply_shift /
# registration_shift call (never 0, even for large `step`).
mstep = max(1, 3000//100//step)

# Undo the random shifts for every distance. Iterating up to ntheta, instead of
# a fixed 100 batches, means trailing projections are no longer skipped when
# ntheta != 100*mstep.
rdata_scaled_shifted = rdata_scaled.copy()
for st in range(0, ntheta, mstep):
    for k in range(ndist):
        rdata_scaled_shifted[st:st+mstep,k] = apply_shift(rdata_scaled_shifted[st:st+mstep,k],-shifts_random[st:st+mstep,k])

write_tiff(rdata_scaled[:,0],'/data/tmp/rdata_scaled',overwrite=True)
write_tiff(rdata_scaled_shifted[:,0],'/data/tmp/rdata_scaled_shifted0',overwrite=True)
write_tiff(rdata_scaled[:,2],'/data/tmp/rdata_scaled2',overwrite=True)
write_tiff(rdata_scaled_shifted[:,2],'/data/tmp/rdata_scaled_shifted2',overwrite=True)

# Residual shift of each distance relative to distance 0 (k = 0 is registered
# against itself), resolved to 1/1000 pixel. Note: this replaces the `shifts`
# array that was loaded from the HDF5 file.
shifts = cp.zeros([ntheta,ndist,2],dtype='float32')
for st in range(0, ntheta, mstep):
    for k in range(ndist):
        shifts[st:st+mstep,k] = registration_shift(rdata_scaled_shifted[st:st+mstep,k],rdata_scaled_shifted[st:st+mstep,0],upsample_factor=1000)

In [None]:
# Registered residual shifts per projection, one line per distance, one figure per component.
shifts_host = shifts.get()
for comp, comp_label in ((0, 'component 0 (image rows)'), (1, 'component 1 (image columns)')):
    fig, ax = plt.subplots()
    ax.plot(shifts_host[:, :, comp])
    ax.set(title=f'Registered shift, {comp_label}', xlabel='projection index', ylabel='shift [px]')
    ax.legend([f'distance {d}' for d in range(shifts_host.shape[1])])
    plt.show()

In [None]:
import scipy.io  # `import scipy` alone does not expose scipy.io on older SciPy releases

# Reference per-distance shifts from the .mat file. The axis swap and sign flip
# presumably bring them to the (projection, distance, component) layout and the
# sign convention of `shifts`; TODO confirm. 3000 is presumably the full
# projection count of the scan.
pshifts = scipy.io.loadmat(f'{path}/{pfile}_/rhapp_fixed.mat')['pshifts'][0,0][0]
pshifts = -pshifts.swapaxes(0,2)[:3000:step]

# Compare the reference shifts with the registered ones for every distance except 0.
shifts_host = shifts.get()
for k in range(1, ndist):
    fig, ax = plt.subplots()
    ax.plot(pshifts[:,k,1], label='pshifts, component 1')
    ax.plot(shifts_host[:,k,1], label='registered, component 1')
    ax.plot(pshifts[:,k,0], label='pshifts, component 0')
    ax.plot(shifts_host[:,k,0], label='registered, component 0')
    ax.set(title=f'Distance {k}', xlabel='projection index', ylabel='shift')
    ax.legend()
    plt.show()

In [None]:
# Total per-distance shift = reference shifts from the .mat file + random shifts.
shifts = cp.array(pshifts)+cp.array(shifts_random)
rdata_scaled_shifted_check=rdata_scaled.copy()
# Batch size for GPU calls (never 0). Loop up to ntheta so no trailing
# projections are skipped when ntheta != 100*mstep.
mstep = max(1, 3000//100//step)
for st in range(0, ntheta, mstep):
    for k in range(ndist):
        rdata_scaled_shifted_check[st:st+mstep,k] = apply_shift(rdata_scaled[st:st+mstep,k],-shifts[st:st+mstep,k])


In [None]:
# Projection 0 of each distance against distance 0, packed as real + 1j*imag
# pairs for mshow_complex (from utils):
#  1) random shifts removed, 2) random + reference shifts removed,
#  3) the difference images of (1) vs (2).
rand_removed = rdata_scaled_shifted[0]
all_removed = rdata_scaled_shifted_check[0]
for k in range(ndist):
    mshow_complex(rand_removed[k]+1j*rand_removed[0],show)
    mshow_complex(all_removed[k]+1j*all_removed[0],show)
    mshow_complex(rand_removed[k]-rand_removed[0]+1j*(all_removed[k]-all_removed[0]),show)

In [None]:
# Per-projection 3D-alignment correction. The column order is reversed, presumably
# to match the component order of `shifts`; TODO confirm. The path is built from
# `path`/`pfile` so it follows the sample selected at the top.
s = np.loadtxt(f'{path}/{pfile}_/correct_correct3D.txt')[:3000:step][:,::-1]

fig, ax = plt.subplots()
ax.plot(s[:,1], label='s[:, 1]')
ax.plot(s[:,0], label='s[:, 0]')
ax.set(title='3D alignment correction', xlabel='projection index', ylabel='shift')
ax.legend()
plt.show()

In [None]:

# Apply the 3D-alignment correction `s` on top of the fully shifted distance-0 data.
rdata_scaled_shifted_check3d = rdata_scaled_shifted_check[:,0].copy()
# Batch size for GPU calls (never 0). Loop up to ntheta so no trailing
# projections are skipped when ntheta != 100*mstep.
mstep = max(1, 3000//100//step)
for st in range(0, ntheta, mstep):
    rdata_scaled_shifted_check3d[st:st+mstep] = apply_shift(rdata_scaled_shifted_check3d[st:st+mstep],-s[st:st+mstep])
write_tiff(rdata_scaled_shifted_check3d[:],'/data/tmp/rdata_scaled_shifted_check3d',overwrite=True)


In [None]:
# Final per-distance shifts: total shifts plus the 3D-alignment correction `s`,
# broadcast over distances.
shifts_final = shifts.get()+s[:,np.newaxis]
final = rdata_scaled.copy()
# Batch size for GPU calls (never 0). Loop up to ntheta so no trailing
# projections are skipped when ntheta != 100*mstep.
mstep = max(1, 3000//100//step)
for st in range(0, ntheta, mstep):
    for k in range(ndist):
        final[st:st+mstep,k] = apply_shift(rdata_scaled[st:st+mstep,k],-shifts_final[st:st+mstep,k])

# One TIFF stack per distance: /data/tmp/ftmp0, ftmp1, ...
for k in range(ndist):
    write_tiff(final[:,k],f'/data/tmp/ftmp{k}',overwrite=True)

In [None]:
# Display the combined shifts that the next cell stores in the *_corr.h5 file.
shifts_final

In [None]:
# Store the final shifts in the corrected-data file, replacing any previous version.
with h5py.File(f'{path_out}/{pfile}_corr.h5','a') as fid:
    # Explicit existence check instead of a bare `except: pass`, which would
    # also hide real failures (I/O errors, KeyboardInterrupt, ...).
    if '/exchange/cshifts_final' in fid:
        del fid['/exchange/cshifts_final']
    fid.create_dataset('/exchange/cshifts_final',data = shifts_final)