In [None]:
import numpy as np
import cupy as cp
from holotomocupy.holo import G, GT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import CTFPurePhase, multiPaganin
from holotomocupy.utils import *
from holotomocupy.proc import remove_outliers
import h5py

# Init data sizes and parameters of the PXM of ID16A

In [None]:
# --- Acquisition geometry and reconstruction sizes ---
n = 2048  # object size in each dimension
pad = n//8  # padding width; added twice to the sample size `ne` below
pos_step = 1
npos = 18*18 # total number of positions
z1c = -17.75e-3 # [m] position of the CA
detector_pixelsize = 3.03751e-6  # presumably [m], like the other lengths here -- TODO confirm
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length
focusToDetectorDistance = 1.28  # [m]

# Per-position geometry: the same z1 is replicated for all npos positions
z1 = np.tile(z1c, [npos])
z2 = focusToDetectorDistance-z1
distances = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1
# z1 is negative, hence the absolute value
voxelsize = np.abs(detector_pixelsize/magnifications[0]*2048/n)  # object voxel size

# sample size after demagnification
# (an earlier formula, (4096+2048+1024+128)//(2048//n)+2*pad, was assigned and then
#  immediately overwritten by the value below; the dead assignment has been dropped)
ne = 3096+2*pad
show = True  # passed to the mshow* helpers to toggle plotting

# Input and output locations (plain strings: there is nothing to interpolate)
path = '/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/code2um_nfp18x18_01'
path_out = '/data/vnikitin/ESRF/ID16A/20240924_rec/SiemensLH/code2um_nfp18x18_01'

## Read data

In [None]:
# All three HDF5 files keep their frames under the same dataset key
h5_dataset = '/entry_0000/measurement/data'

# Measured frames: only the first npos frames are read
with h5py.File(f'{path}/code2um_nfp18x18_010000.h5') as h5_data:
    data0 = h5_data[h5_dataset][:npos].astype('float32')

# Reference frames: the whole stack is read
with h5py.File(f'{path}/ref_0000.h5') as h5_ref:
    ref0 = h5_ref[h5_dataset][:].astype('float32')

# Dark frames: the whole stack is read
with h5py.File(f'{path}/dark_0000.h5') as h5_dark:
    dark0 = h5_dark[h5_dataset][:].astype('float32')

# Prepend an axis of length 1 so data0 is indexed as data0[0, position]
data0 = data0[np.newaxis]

# Position shifts from the text file: prepend an axis and reverse the column order
positions_file = '/data/vnikitin/ESRF/ID16A/20240924/positions/shifts_code_nfp18x18ordered.txt'
shifts_code0 = np.loadtxt(positions_file)
shifts_code0 = shifts_code0[np.newaxis, :, ::-1]
# presumably converts micrometres to voxel units -- TODO confirm the units of the file
shifts_code0 = shifts_code0 / voxelsize * 1e-6
# flip the sign of the first component (after the column reversal above)
shifts_code0[:, :, 0] *= -1

# Quick visual check of the first frame of each stack
mshow(data0[0, 0], show)
mshow(ref0[0], show)
mshow(dark0[0], show)

# Pre-processing

In [None]:
# Work on copies so the raw arrays loaded earlier stay untouched and this cell can be re-run
data = data0.copy()
ref = ref0.copy()
dark = dark0.copy()

# Average the dark and reference stacks over their first axis, keeping a leading axis of length 1
dark = np.mean(dark,axis=0)[np.newaxis]
ref = np.mean(ref,axis=0)[np.newaxis]

# Dark-field subtraction, then clip negative values to zero
data-=dark
ref-=dark

data[data<0]=0
ref[ref<0]=0

# broken detector pixels: overwrite the patch starting at row 1320//3 with the
# same-sized patch starting at row 1280//3 (same columns)
data[:,:,1320//3:1320//3+25//3,890//3:890//3+25//3] = data[:,:,1280//3:1280//3+25//3,890//3:890//3+25//3]
ref[:,1320//3:1320//3+25//3,890//3:890//3+25//3] = ref[:,1280//3:1280//3+25//3,890//3:890//3+25//3]

# Outlier-removal parameters. Defined before the loop because they are also needed
# for the reference frame afterwards (previously they only existed if the loop ran).
radius = 3
threshold = 0.8
for k in range(npos):
    data[:,k] = remove_outliers(data[:,k], radius, threshold)
ref[:] = remove_outliers(ref[:], radius, threshold)

# normalization: scale everything by the mean of the reference, computed once
# before ref itself is rescaled
ref_mean = np.mean(ref)
data/=ref_mean
dark/=ref_mean
ref/=ref_mean

# replace NaNs with 1
data[np.isnan(data)] = 1
ref[np.isnan(ref)] = 1

# binning: halve both image axes by averaging neighbouring pairs, log2(2048//n) times
# (zero iterations when n == 2048)
for k in range(int(np.log2(2048//n))):
    data = (data[:,:,::2]+data[:,:,1::2])*0.5
    data = (data[:,:,:,::2]+data[:,:,:,1::2])*0.5
    ref = (ref[:,::2]+ref[:,1::2])*0.5
    ref = (ref[:,:,::2]+ref[:,:,1::2])*0.5
    dark = (dark[:,::2]+dark[:,1::2])*0.5
    dark = (dark[:,:,::2]+dark[:,:,1::2])*0.5

# Flat-field corrected data; the small constant guards against division by zero
rdata = data/(ref+1e-11)

# Visual check: each call shows two real images packed as real/imaginary parts
mshow_complex(data[0,0]+1j*rdata[0,0],show)
mshow_complex(ref[0]+1j*dark[0],show)

# Find shifts with 1-pixel accuracy

In [None]:
from skimage.registration import phase_cross_correlation
import scipy.ndimage as ndimage

# Array with the same shape as shifts_code0, zeroed by multiplication; it receives
# the shift measured between each pair of consecutive positions
shifts_relative = shifts_code0.copy()*0

def my_phase_corr(d1, d2):
    """Locate the peak of the circular cross-correlation of two images.

    The cross-correlation is computed through the FFT (product of the
    spectrum of ``d1`` with the conjugate spectrum of ``d2``), centred with
    ``fftshift``, and the position of its maximum real value is returned
    relative to the centre of the correlation image.

    Args:
        d1: 2D array.
        d2: 2D array of the same shape as ``d1``.

    Returns:
        Integer numpy array with one entry per axis: the peak offset from the
        zero-lag position. If ``d2`` equals ``d1`` circularly rolled by ``s``,
        the result is ``-s`` (for shifts smaller than half the image size).
    """
    image_product = np.fft.fft2(d1) * np.fft.fft2(d2).conj()
    cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
    ind = np.unravel_index(np.argmax(cc_image.real, axis=None), cc_image.real.shape)
    # fftshift puts zero lag at size//2 along EACH axis, so the centre must be
    # taken per axis; using only the last axis size is wrong for non-square images.
    center = [size // 2 for size in cc_image.shape]
    shifts = np.subtract(ind, center)
    return shifts

# Shifted frames are written to a copy (used only for the visual check), rdata stays intact
rdatat = rdata.copy()
for k in range(npos - 1):
    # integer-pixel offset between position k and the next one
    pair_shift = my_phase_corr(rdata[0, k], rdata[0, k + 1])
    shifts_relative[0, k] = pair_shift
    print(k, shifts_relative[0, k])

    # apply the measured shift to frame k with ST and keep the real part
    frame_k = rdata[:, k].astype('complex64')
    rdatat[:, k] = ST(frame_k, shifts_relative[:, k]).real

    # every 10th position, show the difference between the shifted frame and its successor
    if k % 10 == 0:
        residual = rdatat[0, k] - rdatat[0, k + 1]
        mshow(residual, show, vmax=1, vmin=-1)

# np.save appends the .npy extension, so this writes shifts_relative.npy
np.save('shifts_relative', shifts_relative)


In [None]:
# Turn the pairwise shifts between consecutive positions into shifts relative to
# one reference position (index ipos), then compare them with the shifts from the
# positions file and save the result.
ipos = npos//2+9 # align wrt the middle
shifts_relative = np.load('shifts_relative.npy')
shifts_code_new = shifts_code0*0
# positions before the reference: sum of the pairwise shifts k .. ipos-1
for k in range(ipos):
    shifts_code_new[:,k] = np.sum(shifts_relative[:,k:ipos],axis=1) 
# NOTE(review): this assignment has no effect -- the loop below starts at k == ipos
# and overwrites it with the sum of an empty slice (0). Confirm whether the loop
# was meant to start at ipos+1 or whether this line should be removed.
shifts_code_new[:,ipos] = shifts_code0[:,ipos]
# positions from the reference onwards: negated sum of the pairwise shifts ipos .. k-1
for k in range(ipos,npos):
    shifts_code_new[:,k] = np.sum(-shifts_relative[:,ipos:k],axis=1)

# back to original shape
# original shape is bad, adjust it
# NOTE(review): the positions file is reloaded here with the sign flip applied to
# index 1 of the last axis (and before scaling), whereas the first load in this
# notebook flips index 0 -- confirm which convention is intended.
shifts_code0 = np.loadtxt(f'/data/vnikitin/ESRF/ID16A/20240924/positions/shifts_code_nfp18x18ordered.txt')[np.newaxis,:,::-1]
shifts_code0[:,:,1]*=-1
shifts_code0 = shifts_code0/voxelsize*1e-6

# print the last 10 entries of both shift sets for comparison
print(shifts_code0[:,-10:])
print(shifts_code_new[:,-10:])
# scatter both sets: component 1 on the horizontal axis, component 0 on the vertical axis
# NOTE(review): plt is not imported explicitly in the visible cells; presumably it
# comes from `from holotomocupy.utils import *` -- verify.
plt.plot(shifts_code0[0,:,1],shifts_code0[0,:,0],'.')
plt.plot(shifts_code_new[0,:,1],shifts_code_new[0,:,0],'.')
plt.show()

np.save('shifts_code_new.npy',shifts_code_new)
