In [None]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time
import scipy.ndimage as ndimage
import matplotlib.pyplot as plt
import h5py
sys.path.insert(0, '..')
from utils import *

# Init data sizes and parameters of the PXM of ID16A

In [None]:
# Read the measured frames, the reference images, the scan shifts and the
# acquisition geometry from the pre-processed HDF5 file.
path = '/data/vnikitin/ESRF/ID16A/20240924_rec_ca/data/'
with h5py.File(f'{path}/data_ca.h5', 'r') as fid:
    # Array datasets are read in full and converted to float32.
    data, ref, shifts = (
        fid[name][:].astype('float32')
        for name in ('/exchange/pdata', '/exchange/pref', '/exchange/shifts')
    )

    # Scalar quantities: only the first element of each dataset is used.
    z1, detector_pixelsize, focusdetectordistance, energy = (
        fid[name][0]
        for name in (
            '/exchange/z1',
            '/exchange/detector_pixelsize',
            '/exchange/focusdetectordistance',
            '/exchange/energy',
        )
    )


In [None]:
# Beam and cone-beam geometry derived from the values read from the file.
# NOTE(review): 1.24e-09/energy gives metres only if `energy` is in keV — confirm.
wavelength = 1.24e-09 / energy  # [m] wave length

# NOTE(review): hard-coded here even though `focusdetectordistance` is also
# read from the HDF5 file — confirm which value is intended.
focusToDetectorDistance = 1.28  # [m]

z2 = focusToDetectorDistance - z1  # remaining distance after z1
magnification = focusToDetectorDistance / z1
distances = (z1 * z2) / focusToDetectorDistance  # distance rescaled by the magnification
voxelsize = np.abs(detector_pixelsize / magnification)  # object voxel size


In [None]:
# Problem dimensions consumed by rec_init:
#   n    - side length (pixels) of the square patch each frame is written into
#   npsi - side length (pixels) of the square array holding the assembled result
#   npos - number of frames (scan positions) that are processed
n, npsi, npos = 2048, 8192, 324

In [None]:
def Paganin(data, wavelength, voxelsize, delta_beta, alpha, distance=None):
    """Paganin-type phase retrieval with a regularised Fourier-space division.

    The image spectrum is multiplied by ``T / (T**2 + alpha)`` where
    ``T = 1 + wavelength * distance * pi * delta_beta * (fx**2 + fy**2)``,
    transformed back, and the logarithm of the real part is scaled by
    ``delta_beta / 2``.

    Args:
        data: array whose last two axes form the image (``fft2`` is applied to
            them). The image may be rectangular.
        wavelength: wavelength entering ``T``; same length unit as ``distance``
            and ``voxelsize``.
        voxelsize: sample spacing handed to ``fftfreq`` for both image axes.
        delta_beta: factor scaling the quadratic frequency term and the output
            (presumably the delta/beta ratio of the sample — confirm).
        alpha: non-negative constant added to ``T**2`` in the denominator.
        distance: propagation distance entering ``T``. When omitted, the
            notebook-level variable ``distances`` is used, which is what the
            function always did before this parameter existed.

    Returns:
        Array shaped like ``data`` holding the retrieved phase. Pixels where the
        filtered image is not strictly positive come out as NaN or -inf because
        of the logarithm.
    """
    if distance is None:
        # Backward-compatible fallback to the module-level effective distance.
        distance = distances

    # Build one frequency axis per image axis so non-square images work too;
    # for square input this is identical to using a single shared axis.
    ny, nx = data.shape[-2], data.shape[-1]
    fx = cp.fft.fftfreq(nx, d=voxelsize).astype("float32")
    fy = cp.fft.fftfreq(ny, d=voxelsize).astype("float32")
    fx, fy = cp.meshgrid(fx, fy)  # both shaped (ny, nx)

    rad_freq = cp.fft.fft2(data)
    taylorExp = 1 + wavelength * distance * cp.pi * delta_beta * (fx**2 + fy**2)

    # T / (T**2 + alpha) equals 1 / T when alpha == 0 and stays bounded otherwise.
    numerator = taylorExp * rad_freq
    denominator = taylorExp**2 + alpha

    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    return delta_beta * 0.5 * phase


def rec_init(rdata, ishifts, delta_beta=24.05, alpha=1e-1):
    """Assemble an initial complex object guess from per-position Paganin phases.

    Each of the first ``npos`` frames is phase-retrieved with ``Paganin`` and
    added into an ``npsi`` x ``npsi`` array at the location given by its integer
    shift. Overlapping contributions are averaged, and the averaged phase
    ``p`` is returned as ``exp(1j * p)``.

    Relies on the notebook-level variables ``n``, ``npsi``, ``npos``,
    ``wavelength`` and ``voxelsize``.

    Args:
        rdata: indexable collection of frames; ``rdata[j]`` is expected to be an
            ``n`` x ``n`` image (assumed from the slice size — confirm).
        ishifts: integer array indexed as ``ishifts[j, 0]`` (row shift) and
            ``ishifts[j, 1]`` (column shift) for frame ``j``.
        delta_beta: forwarded to ``Paganin``; defaults to the value that was
            previously hard-coded.
        alpha: forwarded to ``Paganin``; defaults to the value that was
            previously hard-coded.

    Returns:
        Complex ``npsi`` x ``npsi`` array with unit modulus wherever the phase is
        finite; pixels not covered by any frame have phase 0.

    Raises:
        ValueError: if a shifted patch does not fit inside the ``npsi`` x
            ``npsi`` array. Plain slicing would either wrap around silently
            (negative indices) or fail with an opaque shape error.
    """
    recMultiPaganin = cp.zeros([npsi, npsi], dtype="float32")
    # Number of frames contributing to each pixel, to compensate for overlap.
    recMultiPaganinr = cp.zeros([npsi, npsi], dtype="float32")

    for j in range(npos):
        r = Paganin(cp.array(rdata[j]), wavelength, voxelsize, delta_beta, alpha)
        # Accumulate in float32 exactly as a float32 scratch array would.
        r = r.astype("float32", copy=False)

        sty = npsi // 2 - int(ishifts[j, 0]) - n // 2
        stx = npsi // 2 - int(ishifts[j, 1]) - n // 2
        endy = sty + n
        endx = stx + n
        if sty < 0 or stx < 0 or endy > npsi or endx > npsi:
            raise ValueError(
                f"position {j}: patch [{sty}:{endy}, {stx}:{endx}] "
                f"lies outside the {npsi}x{npsi} object array"
            )

        # Add straight into the target window instead of building two
        # full-size temporaries per position.
        recMultiPaganin[sty:endy, stx:endx] += r
        recMultiPaganinr[sty:endy, stx:endx] += 1

    # Uncovered pixels have a count of exactly 0; divide those by 1 instead.
    recMultiPaganinr[recMultiPaganinr == 0] = 1
    recMultiPaganin /= recMultiPaganinr
    return cp.exp(1j * recMultiPaganin)


# Round the scan shifts to whole pixels and divide every frame by the
# reference image; the 1e-5 offset keeps the division finite where ref is 0.
ishifts = np.round(np.array(shifts)).astype("int32")
rdata = np.array(data / (ref + 1e-5))

# Stitch the per-position Paganin phases into the initial object guess and
# display the full array and its top-left 1000x1000 corner.
psi_init = rec_init(rdata, ishifts)
for view in (psi_init, psi_init[:1000, :1000]):
    mshow_polar(view, True)

# smooth borders: multiply the centred spectrum of psi_init by a Gaussian
# window, which acts as a low-pass filter on the stitched image.
v = cp.arange(-npsi // 2, npsi // 2) / npsi  # centred frequency axis
vx, vy = cp.meshgrid(v, v)
v = cp.exp(-1000 * (vx**2 + vy**2)).astype("float32")  # v now holds the 2D window

psi_init = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(psi_init)))
psi_init = psi_init * v
psi_init = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(psi_init)))
psi_init = psi_init.astype("complex64")

for view in (psi_init, psi_init[:1000, :1000]):
    mshow_polar(view, True)


In [None]:
# Store the initial guess in the same HDF5 file the input data came from,
# replacing any psi_init left over from a previous run.
path_out = '/data/vnikitin/ESRF/ID16A/20240924_rec_ca/data/'
with h5py.File(f'{path_out}/data_ca.h5', 'a') as fid:
    # create_dataset refuses to overwrite an existing name, so drop the old
    # dataset first. An explicit membership test replaces the former bare
    # `except: pass`, which also hid unrelated errors.
    if '/exchange/psi_init' in fid:
        del fid['/exchange/psi_init']
    # .get() copies the CuPy array to host memory before it is written.
    fid.create_dataset('/exchange/psi_init', data=psi_init.get())