In [None]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
import h5py
sys.path.insert(0, '..')
from utils import *

# Initialize data sizes and parameters of the PXM at ID16A

In [None]:
path = f'/data/vnikitin/ESRF/ID16A/20240924_rec_ca/data/'
with h5py.File(f'{path}/data_ca.h5', 'r') as fid:
    # Image stacks: projections, white (flat) fields and dark fields, promoted to float32.
    data, ref, dark = (
        fid[key][:].astype('float32')
        for key in ('/exchange/data', '/exchange/data_white', '/exchange/data_dark')
    )
    # Scalar acquisition parameters, each stored as the first element of a dataset.
    z1, detector_pixelsize, focusdetectordistance, energy = (
        fid[key][0]
        for key in (
            '/exchange/z1',
            '/exchange/detector_pixelsize',
            '/exchange/focusdetectordistance',
            '/exchange/energy',
        )
    )


In [None]:
# 1.24e-9 is hc expressed in keV*m, so `energy` is presumably in keV.
wavelength = 1.24e-09 / energy  # [m]
focusToDetectorDistance = 1.28  # [m]
# NOTE(review): `focusdetectordistance` read from the HDF5 file is not used here;
# confirm that the hard-coded 1.28 m is intended.
z2 = focusToDetectorDistance - z1  # sample -> detector
magnification = focusToDetectorDistance / z1  # point-source magnification (z1 + z2) / z1
# Effective propagation distance z1*z2 / (z1 + z2).
distances = z1 * z2 / focusToDetectorDistance
voxelsize = np.abs(detector_pixelsize / magnification)  # pixel size in the object plane


# Pre-processing

In [None]:
import cupyx.scipy.ndimage as ndimage
def remove_outliers(data, dezinger, dezinger_threshold, verbose=True):
    """Replace outlier pixels in every frame with a local median (GPU, CuPy).

    Each frame is median-filtered with a ``dezinger`` x ``dezinger`` window.
    Pixels whose absolute deviation from that median exceeds
    ``dezinger_threshold * median`` are replaced by the median value.

    Parameters
    ----------
    data : numpy.ndarray, shape (nframes, h, w)
        Stack of frames. It is not modified.
    dezinger : int
        Side length of the square median window.
    dezinger_threshold : float
        Relative deviation (fraction of the local median) above which a pixel
        counts as an outlier.
    verbose : bool, optional
        If True (default, the historical behaviour), print the number of
        replaced pixels for each frame.

    Returns
    -------
    numpy.ndarray
        A copy of ``data`` with the outliers replaced.
    """
    # Import explicitly: the notebook-level name `ndimage` is rebound between
    # scipy.ndimage and cupyx.scipy.ndimage depending on which cells ran.
    import cupyx.scipy.ndimage as cndimage

    res = data.copy()
    window = (dezinger, dezinger)
    for k in range(data.shape[0]):
        frame = cp.asarray(data[k])
        fdata = cndimage.median_filter(frame, window)
        # Compute the outlier mask once and reuse it for the report and the replacement.
        mask = cp.abs(frame - fdata) > fdata * dezinger_threshold
        if verbose:
            print(int(cp.count_nonzero(mask)))
        res[k] = cp.where(mask, fdata, frame).get()
    return res

# Dark-field and flat-field (white) correction.
# NOTE: this cell mutates `dark`, `ref` and `data` in place. Re-running it without
# first re-running the loading cell applies the corrections twice, and
# `np.mean(dark, axis=0)` would then average an already averaged 2-D frame.
dark = np.mean(dark, axis=0)
ref = np.mean(ref, axis=0)
data -= dark
ref -= dark

# Clip negative values produced by the dark subtraction.
data[data < 0] = 0
ref[ref < 0] = 0

# Overwrite an 8x8 patch with the patch 14 rows above it. This presumably masks a
# defective detector region. The coordinates look like full-resolution values
# divided by 3 (binning?) -- TODO confirm.
bad_rows = slice(1320//3, 1320//3 + 25//3)
good_rows = slice(1280//3, 1280//3 + 25//3)
patch_cols = slice(890//3, 890//3 + 25//3)
data[:, bad_rows, patch_cols] = data[:, good_rows, patch_cols]
ref[bad_rows, patch_cols] = ref[good_rows, patch_cols]

# remove_outliers works on stacks, so the single reference frame gets a leading axis.
data = remove_outliers(data, 3, 0.8)
ref = remove_outliers(ref[None], 3, 0.8)[0]

# Normalize both by the mean reference intensity (computed before `ref` is rescaled).
ref_mean = np.mean(ref)
data /= ref_mean
ref /= ref_mean

data[np.isnan(data)] = 1
ref[np.isnan(ref)] = 1

# The second argument is a display flag (cf. `mshow(dif, True, ...)` below); previously
# the function object `mshow` itself was passed here by mistake.
mshow(data[0], True)
mshow(ref, True)

# Find shifts with 1-pixel accuracy

In [None]:
# Nominal scan positions. The position file uses a different axis convention than
# the data: swap its two columns, then negate the new second column.
# The scan starts at the top-left corner.
shifts = np.loadtxt('/data/vnikitin/ESRF/ID16A/20240924/positions/shifts_code_nfp18x18ordered.txt')
shifts = shifts[:, ::-1]
shifts[:, 1] = -shifts[:, 1]
# Presumably micrometres -> metres (1e-6) -> object pixels (/voxelsize); confirm file units.
shifts = shifts / voxelsize * 1e-6
print(shifts[:4])

In [None]:

# Number of scan positions and frame height; frames are assumed square below -- confirm.
npos, n = data.shape[0], data.shape[1]
# Pairwise shift between consecutive positions (same shape/dtype as `shifts`).
shifts_relative = 0 * shifts

def my_phase_corr(d1, d2):
    """Estimate the integer-pixel shift between two 2-D images.

    The estimate comes from the peak of their FFT-based (circular)
    cross-correlation. Despite the name, the spectrum is not normalized, so this
    is plain cross-correlation rather than phase correlation.

    Parameters
    ----------
    d1, d2 : numpy.ndarray, shape (h, w)
        Images of identical shape; ``h`` and ``w`` may differ.

    Returns
    -------
    numpy.ndarray of int, shape (2,)
        Shift ``s`` (axis 0, axis 1) such that ``d1`` is approximately ``d2``
        circularly shifted by ``s``, i.e. ``np.roll(d2, s, axis=(0, 1))``.
    """
    image_product = np.fft.fft2(d1) * np.fft.fft2(d2).conj()
    cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
    ind = np.unravel_index(np.argmax(cc_image.real), cc_image.shape)
    # After fftshift, zero lag sits at index shape//2 along EACH axis. Using only
    # the last axis size was wrong for non-square images.
    return np.subtract(ind, np.array(cc_image.shape) // 2)

def S(psi, p):
    """Subpixel shift of a stack of 2-D images via the Fourier shift theorem (GPU, CuPy).

    Each image is zero-padded by half its size on every side before the FFT.
    Content moved by less than the padding therefore does not wrap around into
    the cropped window.

    Parameters
    ----------
    psi : array_like, shape (B, n0, n1)
        Stack of images, copied to the GPU. n0 and n1 may differ and may be odd.
    p : array_like, shape (B, 2)
        Per-image shifts in pixels. ``p[:, 0]`` moves along axis 0 and
        ``p[:, 1]`` along axis 1. Positive values move content toward higher indices.

    Returns
    -------
    numpy.ndarray, complex64, shape (B, n0, n1)
        Shifted images, copied back to host memory.
    """
    psi = cp.asarray(psi)
    p = cp.asarray(p)
    n0, n1 = psi.shape[-2:]
    pad0, pad1 = n0 // 2, n1 // 2
    psi = cp.pad(psi, ((0, 0), (pad0, pad0), (pad1, pad1)))
    # Frequencies of the *padded* grid, per axis. The old fftfreq(2*n) only matched
    # the padded size for square, even-sized images.
    k0 = cp.fft.fftfreq(psi.shape[-2]).astype('float32')[:, None]
    k1 = cp.fft.fftfreq(psi.shape[-1]).astype('float32')[None, :]
    phase = cp.exp(-2*np.pi*1j * (k1*p[:, 1, None, None] + k0*p[:, 0, None, None])).astype('complex64')
    res = cp.fft.ifft2(phase * cp.fft.fft2(psi))
    # Explicit stop index. The old `-n//2` parses as `(-n)//2` and dropped a row/column for odd n.
    res = res[:, pad0:pad0 + n0, pad1:pad1 + n1]
    return res.get()

# Flat-field corrected frames; the epsilon guards against division by zero.
rdata = data/(ref+1e-6)
# Copy in which frame k is re-shifted onto frame k+1 to check the alignment.
rdatat = rdata.copy()
# Central square window used to measure the residual. NOTE: uses `n` for both
# axes, i.e. assumes square frames.
roi = slice(n//2 - n//16, n//2 + n//16)
for k in range(npos - 1):
    shifts_relative[k] = my_phase_corr(rdata[k], rdata[k + 1])
    aligned = S(rdata[k:k + 1].astype('complex64'), -shifts_relative[k:k + 1])
    rdatat[k:k + 1] = aligned.real
    dif = rdatat[k] - rdatat[k + 1]
    nn = np.linalg.norm(dif[roi, roi])
    # Empirical residual threshold: show pairs that did not align well.
    if nn > 80:
        print('WARNING')
        mshow(dif, True, vmax=1, vmin=-1)
    print(k, shifts_relative[k], nn)


In [None]:
# Reference frame for the chained positions: near the middle of the scan.
# NOTE(review): the `+9` offset is a magic number -- confirm which frame is intended.
ipos = npos//2 + 9
shifts_new = shifts*0

# shifts_relative[j] ~ P[j] - P[j+1], so summing the pairwise shifts gives each
# position relative to the reference frame: shifts_new[k] = P[k] - P[ipos].
for k in range(ipos):
    shifts_new[k] = np.sum(shifts_relative[k:ipos], axis=0)
# For k == ipos the slice is empty, so the reference frame sits at the origin.
# (An earlier `shifts_new[ipos] = shifts[ipos]` was a dead store, immediately
# overwritten here. To anchor to the nominal position instead, add shifts[ipos]
# to every row.)
for k in range(ipos, npos):
    shifts_new[k] = np.sum(-shifts_relative[ipos:k], axis=0)

print(shifts[-10:])
print(shifts_new[-10:])

fig, ax = plt.subplots()
ax.plot(shifts[:, 1], shifts[:, 0], '.', label='nominal')
ax.plot(shifts_new[:, 1], shifts_new[:, 0], '.', label='from cross-correlation')
ax.plot(shifts_new[ipos, 1], shifts_new[ipos, 0], 'rx', label='reference frame')
ax.set(xlabel='shift along axis 1 [px]', ylabel='shift along axis 0 [px]',
       title='Scan positions')
ax.legend()
plt.show()

np.save('shifts_new.npy', shifts_new)


In [None]:
path_out = f'/data/vnikitin/ESRF/ID16A/20240924_rec_ca/data/'
with h5py.File(f'{path_out}/data_ca.h5', 'a') as fid:
    # Remove results from a previous run, each dataset independently. Previously all
    # three deletes shared one try/bare-except: a missing first dataset skipped the
    # other deletes, and `create_dataset` then failed on the leftovers.
    for name in ('/exchange/shifts', '/exchange/pdata', '/exchange/pref'):
        if name in fid:
            del fid[name]
    fid.create_dataset('/exchange/shifts', data=shifts_new)
    fid.create_dataset('/exchange/pdata', data=data)
    fid.create_dataset('/exchange/pref', data=ref)