In [None]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time

import matplotlib.pyplot as plt
from utils import write_tiff, read_tiff
from utils import mshow, mshow_polar, mshow_complex
import h5py

# Initialize data sizes and parameters of the PXM at ID16A

In [None]:
n = 2048  # object size in each dimension
pad = n//8
pos_step = 1
npos = 18*18 # total number of positions
z1c = -17.75e-3 # [m] position of the CA
detector_pixelsize = 3.03751e-6
energy = 33.35  # [keV] xray energy
wavelength = 1.24e-09/energy  # [m] wave length
focusToDetectorDistance = 1.28  # [m]
# the same z1 (CA position) is used for every scan position
z1 = np.tile(z1c, [npos])
z2 = focusToDetectorDistance-z1
distances = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1
voxelsize = np.abs(detector_pixelsize/magnifications[0]*2048/n)  # object voxel size

# sample size after demagnification.
# Unused formula-based alternative, kept for reference only (it divides by
# zero for n > 2048):  ne = (4096+2048+1024+128)//(2048//n)+2*pad
ne = 3096+2*pad
show = True  # flag passed to the mshow helpers (presumably toggles display)

path = '/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/code2um_nfp18x18_01'
path_out = '/data/vnikitin/ESRF/ID16A/20240924_rec/SiemensLH/code2um_nfp18x18_01'

## Read data

In [None]:
def _read_h5_stack(fname, count=None):
    """Return the '/entry_0000/measurement/data' dataset of an HDF5 file as float32.

    Only the first `count` frames are read when `count` is given; otherwise all
    frames are read.
    """
    with h5py.File(fname) as fid:
        return fid['/entry_0000/measurement/data'][:count].astype('float32')

data = _read_h5_stack(f'{path}/code2um_nfp18x18_010000.h5', npos)
ref = _read_h5_stack(f'{path}/ref_0000.h5')
dark = _read_h5_stack(f'{path}/dark_0000.h5')

# Nominal scan positions: the two columns are swapped, converted to object
# pixels (file values presumably in micrometres: scaled by 1e-6 and divided by
# the voxel size), and the sign of column 0 is flipped.
shifts = np.loadtxt('/data/vnikitin/ESRF/ID16A/20240924/positions/shifts_code_nfp18x18ordered.txt')[:,::-1]
shifts = shifts/voxelsize*1e-6
shifts[:,0]*=-1

mshow(data[0],show)
mshow(ref[0],show)
mshow(dark[0],show)

plt.plot(shifts[:,0],shifts[:,1],'.')
plt.axis('square')
plt.grid()
plt.xlabel('shift, column 0')
plt.ylabel('shift, column 1')
plt.title('Nominal scan positions')
plt.show()


# Pre-processing

In [None]:
import cupyx.scipy.ndimage as ndimage
def remove_outliers(data, dezinger, dezinger_threshold):
    """Replace impulsive outliers ("zingers") in a stack of 2-D frames.

    Each frame is copied to the GPU and median-filtered with a
    ``dezinger x dezinger`` window. Pixels whose absolute deviation from the
    filtered value exceeds ``dezinger_threshold`` times that value are replaced
    by the filtered value. The number of replaced pixels is printed per frame.

    Parameters
    ----------
    data : numpy array indexed as data[k] -> 2-D frame; it is not modified.
    dezinger : int, median-filter window size along each axis.
    dezinger_threshold : float, relative deviation threshold.

    Returns
    -------
    numpy array with the same shape and dtype as ``data``, outliers replaced.

    Note: where the local median is 0, every non-zero pixel counts as an outlier.
    """
    res = data.copy()
    size = (dezinger, dezinger)
    for k in range(data.shape[0]):
        frame = cp.asarray(data[k])
        filtered = ndimage.median_filter(frame, size=size)
        # build the mask once and reuse it for both counting and replacing
        outliers = cp.abs(frame - filtered) > filtered * dezinger_threshold
        print(int(cp.count_nonzero(outliers)))
        res[k] = cp.where(outliers, filtered, frame).get()
    return res

# Dark/flat pre-processing: average the dark and flat (ref) stacks, subtract the
# dark, clip negatives, patch a small detector region, remove outliers and
# normalise by the mean flat intensity.
# NOTE: arrays are modified in place and `dark`/`ref` are rebound to their
# means, so this cell is not idempotent; reload the data before re-running it.
dark = np.mean(dark,axis=0)
ref = np.mean(ref,axis=0)
data -= dark
ref -= dark

data[data<0]=0
ref[ref<0]=0

# Overwrite a small square patch with the same-size patch starting at row
# 1280//3 instead of 1320//3. Presumably a detector defect, with coordinates
# given for full resolution and divided by a binning of 3; TODO confirm.
bad_rows = slice(1320//3, 1320//3+25//3)
src_rows = slice(1280//3, 1280//3+25//3)
patch_cols = slice(890//3, 890//3+25//3)
data[:, bad_rows, patch_cols] = data[:, src_rows, patch_cols]
ref[bad_rows, patch_cols] = ref[src_rows, patch_cols]

data = remove_outliers(data, 3, 0.8)
ref = remove_outliers(ref[None], 3, 0.8)[0]

# normalise data first, while ref still holds the un-normalised mean flat
data /= np.mean(ref)
ref /= np.mean(ref)

data[np.isnan(data)] = 1
ref[np.isnan(ref)] = 1

mshow(data[0],show)
mshow(ref,show)

# Find the relative shifts between consecutive positions with 1-pixel accuracy

In [None]:
# NOTE(review): this rebinds the global `ndimage` from cupyx.scipy.ndimage
# (imported in the pre-processing cell) to the CPU scipy module.
# remove_outliers looks `ndimage` up at call time and passes it CuPy arrays, so
# calling it after this cell without re-running the pre-processing cell
# would likely fail.
import scipy.ndimage as ndimage

# shifts_relative[k] will hold the integer shift between positions k and k+1.
# It has the same shape as `shifts`; the last row is never filled by the loop
# below and stays zero.
shifts_relative = shifts*0

def my_phase_corr(d1, d2):
    """Estimate the integer translation between two 2-D images.

    Computes the circular cross-correlation of `d1` and `d2` with FFTs and
    returns the position of the maximum of its real part, measured from the
    centre of the fftshift-ed correlation map. Despite the name, the
    cross-power spectrum is not normalised, so this is plain cross-correlation
    rather than normalised phase correlation.

    Parameters
    ----------
    d1, d2 : 2-D arrays of the same shape.

    Returns
    -------
    ndarray of two ints (row, column). For a pure circular translation,
    ``d1 == np.roll(d2, shifts, axis=(0, 1))``. Each component is wrapped into
    ``[-size // 2, size - size // 2)`` along its axis.
    """
    image_product = np.fft.fft2(d1) * np.fft.fft2(d2).conj()
    cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
    ind = np.unravel_index(np.argmax(cc_image.real, axis=None), cc_image.shape)
    # fftshift puts zero lag at index size//2 on *each* axis, so subtract the
    # per-axis centre; this is required for non-square images
    shifts = np.subtract(ind, np.array(cc_image.shape) // 2)
    return shifts

def S(psi,p):
    """Subpixel shift of a stack of 2-D images via the Fourier shift theorem.

    Each image is zero-padded by half its size on every side. The padded image
    is multiplied in Fourier space by a linear phase ramp and then cropped back.
    The padding keeps content shifted by up to half the image size from
    wrapping around into the opposite edge.

    Parameters
    ----------
    psi : array-like (NumPy or CuPy) of shape (nimg, nrows, ncols); complex64 is
        expected by the caller.
    p : array-like of shape (nimg, 2). p[:, 0] is the shift along rows
        (axis -2) and p[:, 1] the shift along columns (axis -1), in pixels.
        Positive values move content towards larger indices.

    Returns
    -------
    numpy.ndarray (complex) with the same shape as `psi`.
    """
    psi = cp.array(psi)
    p = cp.array(p)
    nrows, ncols = psi.shape[-2:]
    # pad per axis so odd and non-square sizes work
    prow, pcol = nrows//2, ncols//2
    psi = cp.pad(psi, ((0, 0), (prow, prow), (pcol, pcol)))
    # frequency grids sized to the padded image, one per axis
    frow = cp.fft.fftfreq(psi.shape[-2]).astype('float32')
    fcol = cp.fft.fftfreq(psi.shape[-1]).astype('float32')
    ramp = cp.exp(-2*cp.pi*1j * (frow[None, :, None]*p[:, 0, None, None]
                                 + fcol[None, None, :]*p[:, 1, None, None])).astype('complex64')
    res = cp.fft.ifft2(ramp*cp.fft.fft2(psi))

    res = res[:, prow:prow+nrows, pcol:pcol+ncols]
    return res.get()

# Flat-field correct every frame, then estimate the integer shift between
# consecutive positions k and k+1. As a sanity check, frame k is shifted back by
# that amount and compared with frame k+1 inside a central window; a large
# residual norm is reported and displayed.
rdata = data/(ref+1e-6)
rdatat = rdata.copy()
# central window of 1/8 of the frame per axis, taken from the frame's own shape
# (not from the object size `n`, which need not match the detector frames)
frame_rows, frame_cols = rdata.shape[-2:]
win_rows = slice(frame_rows//2-frame_rows//16, frame_rows//2+frame_rows//16)
win_cols = slice(frame_cols//2-frame_cols//16, frame_cols//2+frame_cols//16)
for k in range(npos-1):
    shifts_relative[k] = my_phase_corr(rdata[k],rdata[k+1])
    rdatat[k:k+1] = S(rdata[k:k+1].astype('complex64'),-shifts_relative[k:k+1]).real
    # rdatat[k+1] is still unshifted here; it is only shifted in the next iteration
    dif = rdatat[k]-rdatat[k+1]
    nn = np.linalg.norm(dif[win_rows, win_cols])
    if nn>80:  # threshold on the residual norm (magic number, presumably tuned by hand)
        print('WARNING')
        mshow(dif,show,vmax=1,vmin=-1)
    print(k,shifts_relative[k],nn)

np.save('shifts_relative',shifts_relative)


In [None]:
# Chain the pairwise shifts into the shift of every position relative to the
# reference position `ipos`, which is placed at (0, 0).
# npos//2 + 9 == 171 for the 18x18 scan; presumably the centre of the grid if
# positions are stored row by row (TODO confirm).
ipos = npos//2+9
shifts_relative = np.load('shifts_relative.npy')
shifts_new = np.zeros_like(shifts_relative)
# positions before ipos: sum of the pairwise shifts from k up to ipos
for k in range(ipos):
    shifts_new[k] = np.sum(shifts_relative[k:ipos],axis=0)
# positions from ipos on: minus the sum from ipos up to k (zero at k == ipos)
for k in range(ipos,npos):
    shifts_new[k] = np.sum(-shifts_relative[ipos:k],axis=0)

# Reload the nominal positions for comparison. Unlike the read-data cell, the
# sign flip is applied to column 1 here (the nominal layout needed adjusting).
shifts = np.loadtxt('/data/vnikitin/ESRF/ID16A/20240924/positions/shifts_code_nfp18x18ordered.txt')[:,::-1]
shifts[:,1]*=-1
shifts = shifts/voxelsize*1e-6

print(shifts[-10:])
print(shifts_new[-10:])
plt.plot(shifts[:,1],shifts[:,0],'.',label='nominal (from file)')
plt.plot(shifts_new[:,1],shifts_new[:,0],'.',label='chained pairwise shifts')
plt.plot(shifts_new[ipos,1],shifts_new[ipos,0],'rx',label='reference position')
plt.xlabel('shift along axis 1')
plt.ylabel('shift along axis 0')
plt.legend()
plt.show()

np.save('shifts_new.npy',shifts_new)
