In [None]:
import numpy as np
import cupy as cp
import h5py
import matplotlib.pyplot as plt
import cupyx.scipy.ndimage as ndimage
from types import SimpleNamespace

# NOTE(review): nothing in this cell enables (CUDA) managed memory — this comment looks stale; confirm
import h5py
import sys
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *
from rec import Rec

# Initialize data sizes and parameters of the PXM at ID16A

In [None]:
step: int = 1  # angular subsampling stride, used as [::step] on every per-angle array below

In [None]:
# Load acquisition metadata (geometry, angles, shifts, dataset shapes) from the raw HDF5 file.
pfile = f'Y350c_HT_015nm'
path_out = f'/data/vnikitin/ESRF/ID16A/brain_rec/20240515/Y350c'
# Open read-only explicitly: older h5py versions defaulted to read/write mode 'a'.
with h5py.File(f'{path_out}/{pfile}.h5', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusToDetectorDistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:]
    theta = fid['/exchange/theta'][::step]
    shifts = fid['/exchange/shifts'][::step]
    attrs = fid['/exchange/attrs'][::step]
    pos_shifts = fid['/exchange/pos_shifts'][::step]*1e-6
    # Only the shape of the (subsampled) projection stack is needed: derive it from the
    # dataset metadata instead of reading the whole stack into memory and discarding it.
    data0_shape = fid['/exchange/data0'].shape
    shape = (len(range(data0_shape[0])[::step]),) + tuple(data0_shape[1:])
    shape_ref = fid['/exchange/data_white_start0'].shape
    shape_dark = fid['/exchange/data_dark0'].shape
    # pos_shifts -= pos_shifts[0]  # disabled: optional re-referencing to the first angle


In [None]:
# Dataset dimensions: number of distances (fixed), number of angles, detector size,
# and number of reference (white) and dark frames.
ndist = 4
ntheta, n = shape[0], shape[1]
nref, ndark = shape_ref[0], shape_dark[0]

In [None]:
# Quick sanity print of the main dimensions (one line per group).
for dims in ((ndist, ntheta, n), (nref, ndark)):
    print(*dims)

In [None]:
# Beam and cone-beam geometry parameters.
# (An earlier duplicate of this computation, whose `distances` was immediately
#  overwritten by the formula below, has been removed.)
energy = 33.35  # [keV] x-ray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wavelength, hc/E with hc in keV*m
z2 = focusToDetectorDistance-z1
magnifications = focusToDetectorDistance/z1
norm_magnifications = magnifications/magnifications[0]
# Effective propagation distances, rescaled by the squared normalized magnification.
distances = (z1*z2)/focusToDetectorDistance*norm_magnifications**2
# Object voxel size; the 2048/n factor presumably accounts for binning a 2048-pixel
# detector down to n pixels (assumption — confirm).
voxelsize = detector_pixelsize/magnifications[0]*2048/n
show = True

pad = 0
# Object grid size at full resolution: padded detector width divided by the last
# distance's normalized magnification, rounded up to a multiple of 16.
npsi = int(np.ceil((2048+2*pad)/norm_magnifications[-1]/16))*16
# NOTE(review): 879, 1616 and 2.5 are empirical constants; document their origin.
rotation_axis = (879-(1616-npsi//2)//2+2.5)*n/1024

print(rotation_axis)
# Rescale the object grid from the 2048-pixel reference size to the working size n.
npsi //= (2048//n)

In [None]:
# Reconstruction parameters passed to Rec (GPU count, grid sizes, geometry).
args = SimpleNamespace(
    ngpus=4,
    n=n,
    ndist=ndist,
    ntheta=ntheta,
    pad=pad,
    npsi=npsi,
    nq=n + 2 * pad,
    nchunk=1,
    voxelsize=voxelsize,
    wavelength=wavelength,
    distance=distances,
    show=True,
    norm_magnifications=norm_magnifications,
)

# Build the reconstruction operator object used by the alignment cells below.
cl_rec = Rec(args)

In [None]:
# Load the projections of every distance and the reference image from the *_corr file
# (presumably flat/dark-corrected data — confirm).
data = np.empty([ntheta, ndist, n, n], dtype='float32')
# Read-only: this cell only reads; older h5py versions defaulted to read/write.
with h5py.File(f'{path_out}/{pfile}_corr.h5', 'r') as fid:
    for k in range(ndist):
        data[:, k] = fid[f'/exchange/data{k}'][::step]
    ref = fid[f'/exchange/ref'][:]

In [None]:
# Per-distance random shifts: swap the two shift components and divide each distance
# by its normalized magnification.
shifts_random = np.zeros([ntheta, ndist, 2], dtype='float32')
for dist_idx in range(ndist):
    shifts_random[:, dist_idx, :] = shifts[:, dist_idx, ::-1] / norm_magnifications[dist_idx]

plt.plot(shifts_random[:, -1, 1])
plt.show()

In [None]:
# Divide by the reference. Then, for each distance (last to first): shift by its random
# shift via cl_rec.STa, apply cl_rec.MT, scale by 1/mag**2, keep the real part and
# crop the central n x n region of the npsi grid.
rdata = data / ref

crop = slice(npsi // 2 - n // 2, npsi // 2 + n // 2)
rdata_scaled = np.zeros([ntheta, ndist, args.n, args.n], dtype='float32')
for j in np.arange(ndist)[::-1]:
    mag = norm_magnifications[j]
    shifted = cl_rec.STa(shifts_random[:, j] * mag, rdata[:, j].astype('complex64'), 'edge')
    demagnified = (cl_rec.MT(shifted, j) / mag**2).real
    rdata_scaled[:, j] = demagnified[:, crop, crop]


In [None]:
def _upsampled_dft(data, ups,
                   upsample_factor=1, axis_offsets=None):
    """Batched matrix-multiply DFT evaluated on a small, offset, upsampled grid.

    For each 2-D slice ``data[i]`` this evaluates sum over (p, k) of
    ``data[i, p, k] * exp(-2*pi*1j * (r_j * fy_p + c_m * fx_k))``. Here ``fy``/``fx``
    are ``fftfreq(size, upsample_factor)``; ``r_j``/``c_m`` run over
    ``0..ups-1`` minus the per-slice row/column offsets. It avoids a full-size
    upsampled FFT when only a small neighborhood of samples is needed.

    Parameters
    ----------
    data : cupy.ndarray, shape (nbatch, ny, nx)
        Input slices (indexed along axes 0, 1, 2 below).
    ups : int or 0-d array
        Number of output samples along each axis.
    upsample_factor : float, optional
        Passed as the sample spacing ``d`` to ``fftfreq``.
    axis_offsets : cupy.ndarray, shape (nbatch, 2), optional
        Per-slice (row, column) offsets of the output grid. Defaults to zeros
        (previously ``None`` raised a TypeError).

    Returns
    -------
    cupy.ndarray, shape (nbatch, ups, ups)
        Output indexed as [slice, row, column].
    """
    im2pi = 1j * 2 * cp.pi
    nbatch = data.shape[0]
    if axis_offsets is None:
        axis_offsets = cp.zeros((nbatch, 2))
    # Output sample indices, one row per slice.
    out_idx = cp.tile(cp.arange(ups), (nbatch, 1))

    # Column pass: contract over the last axis (nx) -> (nbatch, ups, ny).
    kernel = (out_idx - axis_offsets[:, 1:2])[
        :, :, None]*cp.fft.fftfreq(data.shape[2], upsample_factor)
    kernel = cp.exp(-im2pi * kernel)
    # einsum does not modify its operands, so no defensive copy of `data` is needed.
    tdata = cp.einsum('ijk,ipk->ijp', kernel, data)

    # Row pass: contract over ny -> (nbatch, ups_row, ups_col).
    kernel = (out_idx - axis_offsets[:, 0:1])[
        :, :, None]*cp.fft.fftfreq(data.shape[1], upsample_factor)
    kernel = cp.exp(-im2pi * kernel)
    rec = cp.einsum('ijk,ipk->ijp', kernel, tdata)

    return rec

def registration_shift(src_image, target_image, upsample_factor=1, space="real"):
    """Estimate per-slice translations between two image stacks by cross-correlation.

    Each slice ``src_image[i]`` is registered against ``target_image[i]``. The
    integer-pixel peak of the FFT-based cross-correlation is found first. When
    ``upsample_factor > 1``, it is then refined on a 1.5-pixel neighborhood with a
    matrix-multiply DFT (``_upsampled_dft``) to ``1/upsample_factor`` precision.
    NOTE(review): the structure mirrors scikit-image's ``phase_cross_correlation``;
    confirm the sign convention against it before reuse.

    Parameters
    ----------
    src_image, target_image : array-like, shape (nbatch, ny, nx)
        Image stacks. Converted to cupy arrays.
    upsample_factor : int, optional
        Sub-pixel refinement factor; 1 means integer-pixel precision only.
    space : {"real", "fourier"}, optional
        "real": inputs are images and are FFT'd here.
        "fourier": inputs are already in Fourier space.

    Returns
    -------
    numpy.ndarray, shape (nbatch, 2), float64
        Per-slice (row, column) shifts.

    Raises
    ------
    ValueError
        If ``space`` is neither "real" nor "fourier" (previously a NameError).
    """
    src_image = cp.array(src_image)
    target_image = cp.array(target_image)
    space_l = space.lower()
    if space_l == 'fourier':
        # assume complex data is already in Fourier space
        src_freq = src_image
        target_freq = target_image
    elif space_l == 'real':
        # real data needs to be fft'd.
        src_freq = cp.fft.fft2(src_image)
        target_freq = cp.fft.fft2(target_image)
    else:
        raise ValueError(
            f"space must be 'real' or 'fourier', got {space!r}")

    # Whole-pixel shift - Compute cross-correlation by an IFFT
    shape = src_freq.shape
    image_product = src_freq * target_freq.conj()
    cross_correlation = cp.fft.ifft2(image_product)
    A = cp.abs(cross_correlation)
    maxima = A.reshape(A.shape[0], -1).argmax(1)
    maxima = cp.column_stack(cp.unravel_index(maxima, A[0, :, :].shape))

    midpoints = cp.array([cp.fix(axis_size / 2)
                          for axis_size in shape[1:]])

    # Wrap peaks past the midpoint to negative shifts.
    shifts = cp.array(maxima, dtype=cp.float64)
    ids = cp.where(shifts[:, 0] > midpoints[0])
    shifts[ids[0], 0] -= shape[1]
    ids = cp.where(shifts[:, 1] > midpoints[1])
    shifts[ids[0], 1] -= shape[2]

    if upsample_factor > 1:
        # Initial shift estimate in upsampled grid
        shifts = cp.round(shifts * upsample_factor) / upsample_factor
        upsampled_region_size = cp.ceil(upsample_factor * 1.5)
        # Center of output array at dftshift + 1
        dftshift = cp.fix(upsampled_region_size / 2.0)

        normalization = (src_freq[0].size * upsample_factor ** 2)
        # Matrix multiply DFT around the current shift estimate
        sample_region_offset = dftshift - shifts*upsample_factor
        cross_correlation = _upsampled_dft(image_product.conj(),
                                           upsampled_region_size,
                                           upsample_factor,
                                           sample_region_offset).conj()
        cross_correlation /= normalization
        # Locate maximum and map back to original pixel grid
        A = cp.abs(cross_correlation)
        maxima = A.reshape(A.shape[0], -1).argmax(1)
        maxima = cp.column_stack(
            cp.unravel_index(maxima, A[0, :, :].shape))

        maxima = cp.array(maxima, dtype=cp.float64) - dftshift

        shifts = shifts + maxima / upsample_factor

    return shifts.get()

# Register every distance against the first one (j=0), in chunks of consecutive angles.
# mstep keeps the original chunk size (30 angles for step=1). Iterating over ntheta,
# instead of a fixed 100 chunks, avoids empty chunks (argmax on empty input) and
# unregistered angles when ntheta != 3000//step.
mstep = max(1, 3000//100//step)
shifts_my = shifts.copy()
for st in range(0, ntheta, mstep):
    chunk = slice(st, st + mstep)
    for j in range(ndist):
        shifts_my[chunk, j] = registration_shift(
            rdata_scaled[chunk, j], rdata_scaled[chunk, 0], upsample_factor=1000)


In [None]:
# Compare our per-angle registration with the reference shifts stored in rhapp_fixed.mat.
# The sign is flipped and axes 0/2 swapped so the array indexes like shifts_my.
import scipy.io  # `import scipy` alone does not load scipy.io on older SciPy versions
pshifts = scipy.io.loadmat('/data/vnikitin/ESRF/ID16A/brain/20240515/Y350c/Y350c_HT_015nm_/rhapp_fixed.mat')['pshifts'][0,0][0]
pshifts = -pshifts.swapaxes(0,2)[:3000:step]

for k in range(1, ndist):
    for comp in (1, 0):
        plt.plot(pshifts[:, k, comp], label=f'pshifts[:, {k}, {comp}]')
        plt.plot(shifts_my[:, k, comp], label=f'shifts_my[:, {k}, {comp}]')
    plt.title(f'distance {k}: reference vs. estimated shifts')
    plt.legend()
    plt.show()


In [None]:
# Collapse the per-angle estimates to their median over angles and repeat it for every angle.
shifts_my = np.repeat(np.median(shifts_my, axis=0)[np.newaxis], ntheta, axis=0)

In [None]:
# Re-load the reference shifts and compare them with the median-based shifts_my.
# The sign is flipped and axes 0/2 swapped so the array indexes like shifts_my.
import scipy.io  # `import scipy` alone does not load scipy.io on older SciPy versions
pshifts = scipy.io.loadmat('/data/vnikitin/ESRF/ID16A/brain/20240515/Y350c/Y350c_HT_015nm_/rhapp_fixed.mat')['pshifts'][0,0][0]
pshifts = -pshifts.swapaxes(0,2)[:3000:step]

for k in range(1, ndist):
    for comp in (1, 0):
        plt.plot(pshifts[:, k, comp], label=f'pshifts[:, {k}, {comp}]')
        plt.plot(shifts_my[:, k, comp], label=f'shifts_my[:, {k}, {comp}]')
    plt.title(f'distance {k}: reference vs. median shifts')
    plt.legend()
    plt.show()


In [None]:
# shifts_my = shifts_my/norm_magnifications[:,np.newaxis]
# pshifts = pshifts/norm_magnifications[:,np.newaxis]

In [None]:
# Plot reference (pshifts) vs. estimated (shifts_my) shifts for every non-reference distance.
for k in range(1, ndist):
    for comp in (1, 0):
        plt.plot(pshifts[:, k, comp], label=f'pshifts[:, {k}, {comp}]')
        plt.plot(shifts_my[:, k, comp], label=f'shifts_my[:, {k}, {comp}]')
    plt.title(f'distance {k}: reference vs. estimated shifts')
    plt.legend()
    plt.show()

In [None]:
# Total alignment from our estimate: median shift + shifts_random.
# NOTE: shifts_my is reassigned, so re-running this cell adds shifts_random again.
shifts_my = shifts_my+shifts_random
# Every distance slice is overwritten below, so allocate instead of copying rdata_scaled.
rdata_scaled_shifted_check_my = np.empty_like(rdata_scaled)
for j in np.arange(ndist)[::-1]:
    tmp = cl_rec.STa(shifts_my[:,j]*norm_magnifications[j],rdata[:,j].astype('complex64'),'edge')
    tmp = (cl_rec.MT(tmp,j)/norm_magnifications[j]**2).real
    tmp = tmp[:,npsi//2-n//2:npsi//2+n//2,npsi//2-n//2:npsi//2+n//2]
    rdata_scaled_shifted_check_my[:,j] = tmp
    print(np.linalg.norm(rdata_scaled_shifted_check_my[:,j]))

    

In [None]:
# Total alignment from the reference: pshifts + shifts_random.
# NOTE: this rebinds `shifts` (originally loaded from the HDF5 file); later cells use this value.
shifts = pshifts+shifts_random
# Every distance slice is overwritten below, so allocate instead of copying rdata_scaled.
rdata_scaled_shifted_check = np.empty_like(rdata_scaled)

for j in np.arange(ndist)[::-1]:
    tmp = cl_rec.STa(shifts[:,j]*norm_magnifications[j],rdata[:,j].astype('complex64'),'edge')
    tmp = (cl_rec.MT(tmp,j)/norm_magnifications[j]**2).real
    tmp = tmp[:,npsi//2-n//2:npsi//2+n//2,npsi//2-n//2:npsi//2+n//2]
    rdata_scaled_shifted_check[:,j] = tmp


In [None]:
# tifffile.imwrite('/data/tmp/rrr0',rdata_scaled_shifted_check[5,0])
# tifffile.imwrite('/data/tmp/rrr1',rdata_scaled_shifted_check[5,3])
# tifffile.imwrite('/data/tmp/rrr2',rdata_scaled_shifted_check[5,0])
# tifffile.imwrite('/data/tmp/rrr3',rdata_scaled_shifted_check_my[5,3])
# # tifffile.imwrite('/data/tmp/r1',rdata_scaled_shifted_check_my[50,0]-rdata_scaled_shifted_check_my[50,3])
# # ss

In [None]:
# Load the 3D-correction shifts (first 3000 rows, subsampled by `step`) with the column order reversed.
s = np.loadtxt('/data/vnikitin/ESRF/ID16A/brain/20240515/Y350c/Y350c_HT_015nm_/correct_correct3D.txt')[:3000:step][:,::-1]

for col in (1, 0):
    plt.plot(s[:, col])
plt.show()

In [None]:
# Repeat the per-angle correction for every distance: (ntheta, 2) -> (ntheta, ndist, 2).
# Guarded so re-running the cell does not tile an already-tiled array (which would
# otherwise produce shape (ntheta, 1, ndist*ndist, 2)).
if s.ndim == 2:
    s = np.tile(s[:, np.newaxis], (1, ndist, 1))
# for k in range(ndist):
#     s[:,k]/=norm_magnifications[k]

plt.plot(s[:,:,1])
plt.plot(s[:,:,0])
plt.show()

In [None]:
# Final alignment = reference-based shifts + 3D correction; rebuild the aligned data with it.
shifts_final = shifts+s
shifts_final_my = shifts_my+s
# Fresh array: `final = rdata_scaled_shifted_check` would alias it, so the loop below
# would silently overwrite the pshifts-based check result computed earlier.
final = np.empty_like(rdata_scaled_shifted_check)

for j in np.arange(ndist)[::-1]:
    tmp = cl_rec.STa(shifts_final[:,j]*norm_magnifications[j],rdata[:,j].astype('complex64'),'edge')
    tmp = (cl_rec.MT(tmp,j)/norm_magnifications[j]**2).real
    tmp = tmp[:,npsi//2-n//2:npsi//2+n//2,npsi//2-n//2:npsi//2+n//2]
    final[:,j] = tmp

# for j in np.arange(ndist)[::-1]:
#     tmp = cl_rec.STa(shifts_final_my[:,j]*norm_magnifications[j],rdata[:,j].astype('complex64'),'edge')    
#     tmp = (cl_rec.MT(tmp,j)/norm_magnifications[j]**2).real    
#     tmp=tmp[:,npsi//2-n//2:npsi//2+n//2,npsi//2-n//2:npsi//2+n//2]
#     final_my[:,j] = tmp   

        



In [None]:
# Store the final shifts and the per-distance aligned data in the *_corr file.
out_file = f'{path_out}/{pfile}_corr.h5'
print(out_file)
with h5py.File(out_file, 'a') as fid:
    # Remove results from a previous run so create_dataset does not fail on existing names.
    # Each key is deleted only if present. (The old try/except deleted cshifts_final inside
    # the per-k loop; the second attempt raised KeyError, so later check_shifts3d{k}
    # were never deleted.)
    stale_keys = [f'/exchange/cshifts_final'] + [f'/exchange/check_shifts3d{k}' for k in range(ndist)]
    for key in stale_keys:
        if key in fid:
            del fid[key]
    fid.create_dataset(f'/exchange/cshifts_final', data=shifts_final)
    for k in range(ndist):
        fid.create_dataset(f'/exchange/check_shifts3d{k}', data=final[:, k])

In [None]:
# with h5py.File(f'{path_out}/{pfile}_corr.h5','a') as fid:
#     fid.create_dataset(f'/exchange/check_shifts',data = rdata_scaled_shifted_check)
#     fid.create_dataset(f'/exchange/check_shifts_my',data = rdata_scaled_shifted_check_my)