In [None]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time
import matplotlib.pyplot as plt
import h5py
from types import SimpleNamespace
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")


sys.path.insert(0, '..')
from utils import *
from rec import Rec

%matplotlib inline

In [None]:
# Beam and cone-beam geometry used to configure the first Rec instance below.
# NOTE(review): a later cell sets energy = 33.35 and uses the constant
# 1.2398419840550367e-09 for the wavelength - confirm that 33.5 and 1.24e-09
# are intended here.
energy = 33.5
wavelength = 1.24e-09 / energy
z1 = -17.75e-3  # [m] position of the sample
detector_pixelsize = 3.03751e-6
focusToDetectorDistance = 1.28  # [m]
# adjustments for the cone beam
z2 = focusToDetectorDistance - z1
distance = (z1 * z2) / focusToDetectorDistance
magnification = focusToDetectorDistance / z1
# z1 is negative, so the magnification is negative; the voxel size is its
# magnitude. Built-in abs() is used because the operand is a plain Python
# float: no GPU call is needed, and it matches the host-side abs used for the
# second geometry cell.
voxelsize = float(abs(detector_pixelsize / magnification))

In [None]:
# Configuration for the first Rec instance. Every size field is 8704 and
# pad/ex are zero; the namespace is filled in a single constructor call, with
# attributes in the same order as before.
args = SimpleNamespace(
    ngpus=1,  # hard-coded; could instead be read from the command line (sys.argv[1])
    n=8704,
    ncode=8704,
    npsi=8704,
    pad=0,
    nq=8704,
    ex=0,
    npatch=8704,
    npos=1,
    nchunk=1,
    # geometry computed in the previous cell
    voxelsize=voxelsize,
    wavelength=wavelength,
    distance=distance,
    distancec=distance,
    # the remaining fields were marked "doesn't matter" for this step in the
    # original notebook; they are still set on the namespace
    lam=0,
    eps=1e-8,
    rho=[1, 0.01, 0.1],
    crop=0,
    path_out="",
    niter=2049,
    err_step=1,
    vis_step=8,
    method="BH-CG",
    show=True,
)

# instantiate the reconstruction class with this configuration
cl_rec = Rec(args)

In [None]:
# Folder holding the rec_psi_* / rec_prb_* TIFFs read in the next cell, and the
# iteration number used to build their file names.
path_code = (
    "/data/vnikitin/ESRF/ID16A/20240924_rec0224/SiemensLH/"
    "code2um_nfp18x18_01/bets_final_256_0.1_2"
)
# Note: this name shadows the built-in iter(); it is kept because the next
# cell formats it into the file names.
iter = 768

In [None]:
# Load the phase and amplitude images stored under rec_psi_* for iteration
# `iter` and combine them into one complex array, code = abs * exp(1j * angle).
# read_tiff / mshow_polar are not defined in this notebook; presumably they
# come from the `from utils import *` star import - confirm.
code_angle = read_tiff(f'{path_code}/rec_psi_angle/{iter:04}.tiff')
code_abs = read_tiff(f'{path_code}/rec_psi_abs/{iter:04}.tiff')
code = code_abs*np.exp(1j*code_angle)
mshow_polar(code,args.show)

# Same construction for `q`, read from the rec_prb_* folders
# (presumably the reconstructed probe - confirm).
q_angle = read_tiff(f'{path_code}/rec_prb_angle/{iter:04}.tiff')
q_abs = read_tiff(f'{path_code}/rec_prb_abs/{iter:04}.tiff')
q = q_abs*np.exp(1j*q_angle)
mshow_polar(q,args.show)

In [None]:
# Upload `code` with a leading singleton axis, apply cl_rec.D, take the squared
# modulus of the first (only) result and copy it back to the host.
# NOTE(review): D is presumably the propagator to the detector - confirm in rec.py.
code_data = (cp.abs(cl_rec.D(cp.array(code[cp.newaxis]))[0])**2).get()

In [None]:
# Show the full code_data image and a 500 x 500 crop, both with the display
# range fixed to [0.7, 1.7].
mshow(code_data,args.show,vmax=1.7,vmin=0.7)
mshow(code_data[2000:2500,2000:2500],args.show,vmax=1.7,vmin=0.7)

## Read data

In [None]:
# Read the measured frames, the reference frames and the dark frames as float32.
# The files are opened with an explicit read-only mode: raw measurements must
# never be modified, and the default mode of h5py.File was not read-only in
# h5py releases before 3.0.
# NOTE(review): `ref` is read from a different scan folder (..._nfp_02) than
# `data` and `dark` (..._code2um_nfp9x9_01) - confirm this is intended.
with h5py.File('/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/SiemensLH_010nm_code2um_nfp9x9_01/SiemensLH_010nm_code2um_nfp9x9_010000.h5', 'r') as fid:
    data = fid['/entry_0000/measurement/data'][:].astype('float32')
with h5py.File('/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/SiemensLH_010nm_nfp_02/ref_0000.h5', 'r') as fid:
    ref = fid['/entry_0000/measurement/data'][:].astype('float32')
with h5py.File('/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/SiemensLH_010nm_code2um_nfp9x9_01/dark_0000.h5', 'r') as fid:
    dark = fid['/entry_0000/measurement/data'][:].astype('float32')


# Pre-processing

In [None]:
import cupyx.scipy.ndimage as ndimage
def remove_outliers(data, dezinger, dezinger_threshold):
    """Replace outlier pixels in a stack of frames by their local median.

    Each frame ``data[k]`` is uploaded to the GPU and median-filtered with a
    ``dezinger x dezinger`` window. A pixel is treated as an outlier when it
    differs from the filtered value by more than
    ``dezinger_threshold * filtered value``; outliers take the filtered value,
    all other pixels are kept. The number of replaced pixels per frame is
    printed on one line.

    Parameters
    ----------
    data : array indexable along axis 0, providing ``.copy()``
        Stack of 2D frames. The input is not modified.
    dezinger : int
        Side length of the square median-filter window.
    dezinger_threshold : float
        Relative deviation from the local median above which a pixel is replaced.

    Returns
    -------
    Copy of ``data`` with the outliers replaced.
    """
    res = data.copy()
    w = [dezinger, dezinger]
    for k in range(data.shape[0]):
        frame = cp.array(data[k])
        median = ndimage.median_filter(frame, w)
        # Build the outlier mask once and reuse it for both the count and the
        # replacement (it was previously evaluated twice per frame).
        outliers = cp.abs(frame - median) > median * dezinger_threshold
        print(int(cp.count_nonzero(outliers)), end=" ")
        res[k] = cp.where(outliers, median, frame).get()
    return res

# Average the dark and reference stacks over their first axis, then subtract
# the averaged dark frame from the data and from the reference.
dark = np.mean(dark,axis=0)
ref = np.mean(ref,axis=0)
data -= dark
ref -= dark

# Clip the negative values left by the dark subtraction.
data[data<0]=0
ref[ref<0]=0
# Overwrite the 8 x 8 window at rows 440:448, cols 296:304 with the window
# 14 rows above it (rows 426:434), in every data frame and in the reference.
# NOTE(review): presumably this patches a defective detector region - confirm.
data[:,1320//3:1320//3+25//3,890//3:890//3+25//3] = data[:,1280//3:1280//3+25//3,890//3:890//3+25//3]
ref[1320//3:1320//3+25//3,890//3:890//3+25//3] = ref[1280//3:1280//3+25//3,890//3:890//3+25//3]

# Replace outliers using a 3 x 3 median and a relative threshold of 0.8.
data = remove_outliers(data, 3, 0.8)
ref = remove_outliers(ref[None], 3, 0.8)[0]

# Normalise by the mean of the reference. `data` must be divided first,
# because the second line changes np.mean(ref).
data /= np.mean(ref)
ref /= np.mean(ref)

data[np.isnan(data)] = 1
ref[np.isnan(ref)] = 1

# The second argument of mshow is the show flag (see the other mshow calls).
# The function object `mshow` itself used to be passed here, which only worked
# because a function is truthy.
mshow(data[0],args.show)
mshow(ref,args.show)

# Find the shifts with 1-pixel accuracy

In [None]:
# Geometry for the measured scan. These assignments overwrite energy,
# wavelength, z1, z2, distance and voxelsize from the first geometry cell.
z1c = -17.75e-3  # same value as z1 in the first geometry cell
detector_pixelsize = 3.03751e-6
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length
focusToDetectorDistance = 1.28  # [m]
sx0 = 1.286e-3
z1 = 5.5e-3-sx0
z2 = focusToDetectorDistance-z1
distance = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1
voxelsize = np.abs(detector_pixelsize/magnifications)  # object voxel size

# magnification when propagating from the probe plane to the detector
magnifications2 = z1/z1c
# Algebraically this equals (z1-z1c)*magnifications2; magnifications2 itself is
# not used again in the visible cells.
# NOTE(review): the trailing "#magnifications2" remark suggests the divisor was
# once magnifications2 - confirm that multiplying by z1/z1c is intended.
distancec = (z1-z1c)/(z1c/z1)#magnifications2

# display flag passed to the mshow calls in the following cells
show = True

In [None]:
# Configuration for the second Rec instance (replaces both `args` and `cl_rec`
# created for the first one).
args = SimpleNamespace()
args.npos = 81

args.ngpus = 3
args.lam = 0.2

# Sizes: pad, npsi, nq and npatch are derived from n and ex, so the order of
# these assignments matters.
args.n = 2048
args.ncode = 8704
args.pad = args.n // 8
args.npsi = args.n + 2 * args.pad
args.nq = args.n + 2 * args.pad
args.ex = 8
args.npatch = args.nq + 2 * args.ex
args.nchunk = 9

# Geometry from the previous cell.
args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distance
args.distancec = distancec
args.eps = 1e-12
args.rho = [1, 0.5, 0.1]
args.crop = 0#2 * args.pad
# The output folder name encodes lam and the last two rho entries.
args.path_out = f"/data/vnikitin/ESRF/ID16A/20240924_rec0224/SiemensLH/SiemensLH_010nm_code2um_nfp9x9_01/{args.lam}_{args.rho[1]}_{args.rho[2]}"

args.niter = 2049
args.err_step = 1
args.vis_step = 32
args.method = "BH-CG"
args.show = True

# create class
cl_rec = Rec(args)


In [None]:
# Divide every data frame by the reference (ref broadcasts over the first
# axis) and display the last frame.
# NOTE(review): ref was clipped at 0 during pre-processing, so zero pixels in
# ref would give inf/NaN here - confirm none remain.
rdata = data/ref
mshow(rdata[-1],show)

In [None]:
def my_phase_corr(d1, d2):
    """Locate the peak of the circular cross-correlation of two 2D images.

    The correlation is computed as ``ifft2(fft2(d1) * conj(fft2(d2)))`` and
    fft-shifted, so a zero shift maps to the centre of the array. The function
    returns the peak position relative to that centre, with whole-pixel
    resolution.

    Parameters
    ----------
    d1, d2 : 2D arrays of identical shape (NumPy, or CuPy via NumPy dispatch).

    Returns
    -------
    numpy.ndarray, shape (2,), dtype float32
        ``[row_shift, col_shift]`` of the correlation peak.
    """
    image_product = np.fft.fft2(d1) * np.fft.fft2(d2).conj()
    cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
    # int() brings the flat peak index to the host, so the remaining index
    # arithmetic is plain Python and needs no device scratch array.
    peak = int(np.argmax(cc_image.real, axis=None))
    row, col = np.unravel_index(peak, cc_image.shape)
    # Centre each axis with its own length; using shape[-1] for both axes is
    # only correct for square inputs.
    nrows, ncols = cc_image.shape
    return np.array([row - nrows // 2, col - ncols // 2], dtype='float32')

# For every frame: zero-pad the ratio image to the size of code_data, find the
# whole-pixel shift against code_data by cross-correlation, then divide the
# frame by the matching n x n window of code_data.
shifts_code = np.zeros([args.npos,2],dtype='float32')
a = cp.array(code_data)
nn = code_data.shape[-1]
rrdata0=rdata.copy()
for k in range(rdata.shape[0]):
    # upload the frame once and reuse it for both the padding and the division
    bb = cp.array(rdata[k])
    b = cp.pad(bb,((nn//2-args.n//2,nn//2-args.n//2),(nn//2-args.n//2,nn//2-args.n//2)))
    shift = -my_phase_corr(a,b)
    shifts_code[k] = shift
    # my_phase_corr returns float32 values; slice bounds must be integers.
    # The shifts are whole numbers, so the conversion is exact.
    sy = int(shift[0])
    sx = int(shift[1])
    aa = a[nn//2-sy-args.n//2:nn//2-sy+args.n//2,
           nn//2-sx-args.n//2:nn//2-sx+args.n//2]
    rrdata0[k] = (bb/aa).get()


In [None]:
ri = shifts_code.astype('int32')
r = (shifts_code-ri).astype('float32')
scode = cl_rec.S(ri,r,code)
Dscode = cl_rec.Dc(scode*q)
DDscode =  cl_rec.D(Dscode)
rrdata = data/np.abs(DDscode)**2
mshow(rrdata0[0],show,vmin=0.5,vmax=1.5)
mshow(rrdata[0],show,vmax=1.3,vmin=0.7)

In [None]:
# Show the mean of rrdata over its first axis.
# NOTE(review): the next cell computes the same mean with np.mean; whether
# rrdata is a NumPy or CuPy array depends on what cl_rec.D returns - confirm.
mshow(cp.mean(rrdata[:],axis=0),show,vmax=1.1,vmin=0.8)

In [None]:
def Paganin(data, wavelength, voxelsize, delta_beta, alpha, prop_distance=None):
    """Paganin-type filter applied to a 2D image on the GPU.

    With ``T = 1 + wavelength * prop_distance * pi * delta_beta * (fx**2 + fy**2)``
    the function returns
    ``delta_beta * 0.5 * log(real(ifft2(T * fft2(data) / (T**2 + alpha))))``.

    Parameters
    ----------
    data : cupy.ndarray
        2D input image. The same frequency axis, built from ``data.shape[-1]``,
        is used for both dimensions, so the image is expected to be square.
    wavelength : float
        Wavelength, in the same length unit as ``voxelsize`` and the distance.
    voxelsize : float
        Sample spacing passed to ``fftfreq``.
    delta_beta : float
        Factor multiplying the frequency term of the filter and scaling the result.
    alpha : float
        Regularisation constant added to ``T**2`` in the denominator.
    prop_distance : float, optional
        Propagation distance. When omitted, the module-level ``distance`` that
        is current at call time is used, as in the original implementation.

    Returns
    -------
    cupy.ndarray
        Filtered image; non-positive values before the logarithm give NaN/-inf.
    """
    if prop_distance is None:
        # Backward compatible default: fall back to the notebook-level variable.
        prop_distance = distance
    fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype("float32")
    [fx, fy] = cp.meshgrid(fx, fx)
    rad_freq = cp.fft.fft2(data)
    taylorExp = 1 + wavelength * prop_distance * cp.pi * (delta_beta) * (fx**2 + fy**2)
    numerator = taylorExp * (rad_freq)
    denominator = taylorExp**2 + alpha
    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    phase = delta_beta * 0.5 * phase
    return phase

# Build the initial guess psi_init from the mean ratio image.
# Note that rrdata0 is rebound here: from now on it is the frame-averaged
# ratio, normalised to unit mean and clipped at 1.5.
rrdata0 = np.mean(rrdata,axis=0)
rrdata0/=np.mean(rrdata0)
rrdata0[rrdata0>1.5]=1.5

# Move to the GPU and run the Paganin filter with delta_beta=4.05, alpha=2e-2.
rrdata0 = cp.array(rrdata0)
psi_init = Paganin(rrdata0, wavelength, voxelsize, 4.05, 2e-2)
# Pad from n to npsi with np.pad's default constant zeros, so the padded
# border becomes exp(0) = 1 after the exponential below.
psi_init = np.pad(psi_init,((args.npsi//2-args.n//2,args.npsi//2-args.n//2),
                                         (args.npsi//2-args.n//2,args.npsi//2-args.n//2)
                                         ))
psi_init = np.exp(1j * psi_init)
mshow_polar(psi_init,args.show)

# smooth borders: multiply the centred spectrum of psi_init by a Gaussian
# window exp(-100 * (vx**2 + vy**2)) and transform back
v = cp.arange(-args.npsi // 2,args.npsi // 2) / args.npsi
[vx, vy] = cp.meshgrid(v, v)
v = cp.exp(-100 * (vx**2 + vy**2)).astype("float32")

psi_init = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(psi_init)))
psi_init = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(psi_init * v))).astype(
    "complex64"
)
mshow_polar(psi_init,args.show)

# Drop the references to the large arrays; both names now point to the same
# empty list.
rdata = v = []

In [None]:

# variables
# Initial state handed to the solver: code, object guess and probe on the GPU,
# plus the integer part of the shifts. The name `vars` shadows the built-in
# and is kept as in the original cell.
vars = {
    "code": cp.array(code),
    "psi": cp.array(psi_init),
    "q": cp.array(q),
    "ri": shifts_code.astype("int32"),
}
# fractional remainder of the shifts, and an empty table for iter/err/time
vars["r"] = np.array(shifts_code - vars["ri"]).astype("float32")
vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])

# run the reconstruction; the returned state replaces `vars`
vars = cl_rec.BH(data, vars)
