In [None]:
import numpy as np
import cupy as cp
import h5py
from holotomocupy.holo import G, GT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import multiPaganin
from holotomocupy.utils import *
from holotomocupy.proc import remove_outliers
##!jupyter nbconvert --to script config_template.ipynb
import sys

# Init data sizes and parameters of the PXM of ID16A

In [None]:
# --- Reconstruction grid ------------------------------------------------------
n = 1024      # object size in each dimension
pad = 512     # padding added on each side of the object field
npos = 16     # number of positions read from the data file

# --- Setup geometry -----------------------------------------------------------
detector_pixelsize = 3.03751e-6
energy = 33.35                                   # [keV] xray energy
wavelength = 1.2398419840550367e-09 / energy     # [m] wave length
focusToDetectorDistance = 1.28                   # [m]
sx0 = 1.286e-3
# the same focus-to-sample distance is used for every position
z1 = np.tile(5.5e-3 - sx0, [npos])
z2 = focusToDetectorDistance - z1
# z1*z2/(z1+z2), since z1 + z2 == focusToDetectorDistance
distances = (z1 * z2) / focusToDetectorDistance
magnifications = focusToDetectorDistance / z1
voxelsize = np.abs(detector_pixelsize / magnifications[0] * 2048 / n)  # object voxel size

# --- Padded field size and run options ----------------------------------------
ne = 2048 // (2048 // n) + 2 * pad   # sample size after demagnification (incl. padding)
crop = ne // 2 - n // 2              # border width outside the central n x n window
show = True
lam1 = 1e-1
lam2 = 0
gpu = 1
# For script runs these may instead be taken from the command line:
#   lam1 = float(sys.argv[1]); lam2 = float(sys.argv[2]); gpu = int(sys.argv[3])
cp.cuda.Device(gpu).use()

# --- Input / output locations -------------------------------------------------
flg = f'{n}'
path = '/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/SiemensLH_010nm_nfp_02/'
path_ref = '/data/vnikitin/ESRF/ID16A/20240924/SiemensLH/SiemensLH_010nm_nfp_02/'
path_out = f'/data/vnikitin/ESRF/ID16A/20240924_rec2/SiemensLH/SiemensLH_010nm_nfp_02_reg_lap_crop{crop}_{lam1:1.1e}_{lam2:1.1e}_{n}_{pad}'

## Read data

In [None]:
def _read_header_floats(fid, key):
    """Parse a space-separated header dataset of an open HDF5 file into a float32 array.

    ``str(np.array(fid[key]))`` renders the stored value (e.g. ``"b'1.0 2.0'"``);
    the first two characters and the last one are stripped before splitting on ' '.
    """
    text = str(np.array(fid[key]))[2:-1]
    return np.array(text.split(' '), dtype='float32')


# Read the first `npos` sample frames plus the spz/spy header entries from the main
# file (opened once, read-only), then the reference and dark stacks.
# Explicit 'r': older h5py versions default to mode 'a', which may create/modify files.
with h5py.File(f'{path}SiemensLH_010nm_nfp_020000.h5', 'r') as fid:
    data0 = fid['/entry_0000/measurement/data'][:npos].astype('float32')
    # NOTE(review): the *1e-6/voxelsize factor suggests the header values are in
    # micrometres and are converted to object voxels here — confirm.
    spz = _read_header_floats(fid, '/entry_0000/instrument/PCIe/header/spz')*1e-6/voxelsize
    spy = _read_header_floats(fid, '/entry_0000/instrument/PCIe/header/spy')*1e-6/voxelsize
with h5py.File(f'{path_ref}ref_0000.h5', 'r') as fid:
    ref0 = fid['/entry_0000/measurement/data'][:].astype('float32')
with h5py.File(f'{path}/dark_0000.h5', 'r') as fid:
    dark0 = fid['/entry_0000/measurement/data'][:].astype('float32')
data0 = data0[np.newaxis]  # leading batch axis: (1, npos, ...)

# Per-position shifts: column 0 is -spz, column 1 is spy.
shifts_code0 = np.zeros([1, npos, 2], dtype='float32')
shifts_code0[:, :, 0] = -spz[:npos]
shifts_code0[:, :, 1] = spy[:npos]

In [None]:
import scipy.ndimage as ndimage
def remove_outliers(data, dezinger, dezinger_threshold):
    """Replace isolated outlier pixels by their local median.

    A pixel is replaced when ``|data - median| > median * dezinger_threshold``.
    ``median`` is a ``dezinger`` x ``dezinger`` median filter over the last two
    axes only; leading axes get filter width 1, so they are filtered independently.

    Parameters
    ----------
    data : numpy.ndarray
        Array with at least two dimensions. It is not modified.
    dezinger : int
        Median-filter width. Values <= 0 disable filtering and return a plain copy.
    dezinger_threshold : float
        Relative deviation from the local median above which a pixel is replaced.

    Returns
    -------
    numpy.ndarray
        A copy of ``data`` with outliers replaced. The number of replaced pixels is printed.
    """
    # Local import: the module-level name ``ndimage`` is rebound to cupyx.scipy.ndimage
    # later in this notebook. Using it here would break this NumPy function on re-runs.
    import scipy.ndimage as scipy_ndimage

    res = data.copy()
    w = int(dezinger)
    if w > 0:
        size = (1,) * (data.ndim - 2) + (w, w)
        fdata = scipy_ndimage.median_filter(data, size)
        outliers = np.abs(data - fdata) > fdata * dezinger_threshold
        print(np.sum(outliers))
        res[:] = np.where(outliers, fdata, data)
    return res

In [None]:
# Dark subtraction, defect patching, outlier removal, normalization and binning.
data = data0.copy()
ref = ref0.copy()
dark = dark0.copy()
# Average the dark and reference stacks into one frame each (leading axis kept).
dark = np.mean(dark, axis=0)[np.newaxis]
ref = np.mean(ref, axis=0)[np.newaxis]
data -= dark
ref -= dark

data[data < 0] = 0
ref[ref < 0] = 0

# Overwrite the 8x8 patch at rows 440..447, cols 296..303 with the patch 14 rows above.
# NOTE(review): presumably masks a known detector defect at these pixels — confirm.
data[:,:,1320//3:1320//3+25//3,890//3:890//3+25//3] = data[:,:,1280//3:1280//3+25//3,890//3:890//3+25//3]
ref[:,1320//3:1320//3+25//3,890//3:890//3+25//3] = ref[:,1280//3:1280//3+25//3,890//3:890//3+25//3]

# Outlier-removal settings are defined outside the loop so the `ref` call below
# works regardless of npos.
radius = 3
threshold = 0.8
for k in range(npos):
    data[:, k] = remove_outliers(data[:, k], radius, threshold)
ref[:] = remove_outliers(ref[:], radius, threshold)

# Normalize everything by the mean of the dark-corrected reference (computed once).
ref_mean = np.mean(ref)
data /= ref_mean
dark /= ref_mean
ref /= ref_mean

data[np.isnan(data)] = 1
ref[np.isnan(ref)] = 1

# 2x2 average binning, repeated log2(2048//n) times
# (presumably the raw frames are 2048 pixels wide).
for _ in range(int(np.log2(2048 // n))):
    data = (data[:, :, ::2] + data[:, :, 1::2]) * 0.5
    data = (data[:, :, :, ::2] + data[:, :, :, 1::2]) * 0.5
    ref = (ref[:, ::2] + ref[:, 1::2]) * 0.5
    ref = (ref[:, :, ::2] + ref[:, :, 1::2]) * 0.5
    dark = (dark[:, ::2] + dark[:, 1::2]) * 0.5
    dark = (dark[:, :, ::2] + dark[:, :, 1::2]) * 0.5

rdata = data / (ref + 1e-11)

mshow_complex(data[0, 0] + 1j * rdata[0, 0], show)
mshow_complex(ref[0] + 1j * dark[0], show)

# Construct operators


In [None]:
import cupyx.scipy.ndimage as ndimage
def _taper():
    """Return the separable (ne, ne) float32 window applied to padded fields.

    The window is 1 in the interior. Over the ``(ne - n) // 2`` border pixels it rises
    as sin on the leading side and falls as cos on the trailing side. With no border
    (``ne == n``) it is all ones; the guard is needed because ``v[-0:]`` would select
    the whole vector.
    """
    m = (ne - n) // 2
    v = cp.ones(ne, dtype='float32')
    if m > 0:
        ramp = cp.linspace(0, 1, m) * cp.pi / 2
        v[:m] = cp.sin(ramp)
        v[-m:] = cp.cos(ramp)
    return cp.outer(v, v)


def Lop(psi):
    """Taper, propagate and crop each position's field.

    psi is indexed as ``psi[:, i]`` for i in range(npos) (shape (1, npos, ne, ne) in
    this notebook). Each field is multiplied by the taper, passed through
    ``G(..., distances[i], 'constant')`` and cropped to the central n x n window.
    Returns a (1, npos, n, n) complex64 array.
    """
    data = cp.zeros([1, npos, n, n], dtype='complex64')
    v = _taper()  # loop-invariant
    lo, hi = ne//2 - n//2, ne//2 + n//2
    for i in range(npos):
        psir = psi[:, i] * v
        psir = G(psir, wavelength, voxelsize, distances[i], 'constant')
        data[:, i] = psir[:, lo:hi, lo:hi]
    return data


def L1op(q):
    """Taper a single (1, ne, ne) field, propagate it over distances[0] and crop to n x n."""
    lo, hi = ne//2 - n//2, ne//2 + n//2
    data = q * _taper()
    data = G(data, wavelength, voxelsize, distances[0], 'constant')
    return data[:, lo:hi, lo:hi]


def L1Top(data):
    """Adjoint of L1op: zero-pad n x n to ne x ne, apply GT over distances[0], then taper."""
    p = ne//2 - n//2
    q = cp.pad(data, ((0, 0), (p, p), (p, p))).astype('complex64')
    q = GT(q, wavelength, voxelsize, distances[0], 'constant')
    q *= _taper()
    return q


def LTop(data):
    """Adjoint of Lop: per position, zero-pad, apply GT over distances[j], then taper.

    Returns a (1, npos, ne, ne) complex64 array.
    """
    psi = cp.zeros([1, npos, ne, ne], dtype='complex64')
    v = _taper()  # loop-invariant
    p = ne//2 - n//2
    for j in range(npos):
        datar = cp.pad(data[:, j], ((0, 0), (p, p), (p, p))).astype('complex64')
        datar = GT(datar, wavelength, voxelsize, distances[j], 'constant')
        datar *= v
        psi[:, j] += datar
    return psi

def Sop(psi, shifts):
    """Shift the object once per position.

    Returns a (1, npos, ne, ne) complex64 array whose entry ``[:, j]`` is
    ``S(copy of psi, shifts[:, j])``; S receives a fresh copy for every position.
    """
    out = cp.zeros([1, npos, ne, ne], dtype='complex64')
    obj = cp.array(psi)
    for pos in range(npos):
        out[:, pos] = S(obj.copy(), cp.array(shifts[:, pos]))
    return out


def STop(data, shifts):
    """Adjoint of Sop: accumulate ``ST(data[:, j], shifts[:, j])`` over all positions
    into a (1, ne, ne) complex64 array."""
    acc = cp.zeros([1, ne, ne], dtype='complex64')
    for pos in range(npos):
        acc += ST(cp.array(data[:, pos]), cp.array(shifts[:, pos]))
    return acc

def Cop(psi):
    """Zero the central window of a (batch, ny, nx) field and keep only its border.

    The zeroed window spans ``crop:size-crop`` on both spatial axes, where ``crop`` is
    the global border width. Writing the upper bound as ``size - crop`` instead of
    ``-crop`` keeps the operator correct when ``crop == 0``: ``0:-0`` is an empty slice
    and would zero nothing, whereas the whole field should be zeroed.
    """
    res = psi.copy()
    res[:, crop:res.shape[1] - crop, crop:res.shape[2] - crop] = 0
    return res


def CTop(psi):
    """Adjoint of Cop. Cop multiplies by a 0/1 mask, so it is its own adjoint."""
    return Cop(psi)

def Gop_(psi):
    """Forward-difference gradient of a (batch, ny, nx) field.

    Returns a (2, batch, ny, nx) complex64 array. Component 0 holds differences along
    the last axis and component 1 along axis 1. The last column (component 0) and the
    last row (component 1) stay zero.
    """
    grad = cp.zeros([2, *psi.shape], dtype='complex64')
    d_x = psi[:, :, 1:] - psi[:, :, :-1]
    d_y = psi[:, 1:, :] - psi[:, :-1, :]
    grad[0, :, :, :-1] = d_x
    grad[1, :, :-1, :] = d_y
    return grad

def GTop_( gr):
    """Negative backward-difference divergence paired with Gop_.

    For a (2, batch, ny, nx) input, returns a (batch, ny, nx) complex64 array. It is
    the exact adjoint of Gop_ on inputs whose last column of gr[0] and last row of
    gr[1] are zero, as Gop_ produces.
    """
    gx, gy = gr[0], gr[1]
    div = cp.zeros(gr.shape[1:], dtype='complex64')
    div[:, :, 0] = gx[:, :, 0]
    div[:, :, 1:] = gx[:, :, 1:] - gx[:, :, :-1]
    div[:, 0, :] += gy[:, 0, :]
    div[:, 1:, :] += gy[:, 1:, :] - gy[:, :-1, :]
    return -div

def Gop(psi):
    """Apply the 5-point Laplacian stencil to the first batch entry of ``psi``.

    Only ``psi[0]`` is filtered, using ``ndimage.convolve`` with its default boundary
    mode (``ndimage`` is cupyx.scipy.ndimage, imported in this cell). Any further
    batch entries are returned unchanged. The input is not modified.
    """
    laplacian = cp.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
    out = psi.copy()
    out[0] = ndimage.convolve(out[0], laplacian)
    return out

def GTop(psi):
    """Adjoint of Gop.

    The stencil is symmetric and GTop performs exactly the same computation as Gop,
    so it simply delegates to it.
    """
    return Gop(psi)


# Adjoint tests: for each pair (A, A^T), <x, A^T A x> should equal <A x, A x>.
def _adjoint_check(fwd, adj, x):
    """Print <x, adj(fwd(x))> and <fwd(x), fwd(x)>; they agree when adj is fwd's adjoint."""
    y = fwd(x)
    x_back = adj(y)
    print(f'{cp.sum(x*cp.conj(x_back))}==\n{cp.sum(y*cp.conj(y))}')


def _random_complex(shape):
    """Random complex64 CuPy array with real and imaginary parts drawn by cp.random.random."""
    return (cp.random.random(shape) + 1j*cp.random.random(shape)).astype('complex64')


shifts = cp.array(shifts_code0)
_adjoint_check(lambda x: Sop(x, shifts), lambda y: STop(y, shifts), _random_complex([1, ne, ne]))
_adjoint_check(Lop, LTop, _random_complex([1, npos, ne, ne]))
_adjoint_check(Cop, CTop, _random_complex([1, ne, ne]))
_adjoint_check(Gop, GTop, _random_complex([1, ne, ne]))
_adjoint_check(L1op, L1Top, _random_complex([1, ne, ne]))

In [None]:
# Inspect one saved checkpoint: object (psi), probe (prb) and shifts.
# NOTE(review): `dxchange` is not imported in the visible cells; presumably it comes
# from `from holotomocupy.utils import *` — confirm.
i = 4096*2 + 1024 + 512  # checkpoint index (9728)

psi_angle = dxchange.read_tiff(f'{path_out}/crec_code_angle{flg}/{i:03}.tiff')
psi_abs = dxchange.read_tiff(f'{path_out}/crec_code_abs{flg}/{i:03}.tiff')
prb_angle = dxchange.read_tiff(f'{path_out}/crec_prb_angle{flg}/{i:03}.tiff')
prb_abs = dxchange.read_tiff(f'{path_out}/crec_prb_abs{flg}/{i:03}.tiff')
shifts = cp.array(np.load(f'{path_out}/crec_shift{flg}_{i:03}.npy'))

# Recombine magnitude and phase into complex fields; move everything to the GPU.
psi = cp.array(psi_abs*np.exp(1j*psi_angle)[np.newaxis])
prb = cp.array(prb_abs*np.exp(1j*prb_angle)[np.newaxis])
ref, data = cp.array(ref), cp.array(data)

# Model predictions for the sample frames and for the probe alone (the latter is
# compared against `ref` below).
Lpsi = Lop(prb*Sop(psi, shifts))
Lprb = L1op(prb)

def mmshow_polar(a, show=False, v=None, **args):
    """Plot a 2D complex array as magnitude and phase panels with optional color limits.

    Handles arrays on the GPU.

    Parameters
    ----------
    a : (ny, nx) complex64
        2D array for visualization. CuPy arrays are copied to the host first.
    show : bool
        If False, return immediately without plotting.
    v : sequence of 4 floats or None
        ``[abs_min, abs_max, phase_min, phase_max]`` color limits. With the default
        ``None`` both panels autoscale; indexing ``v`` unconditionally would raise
        TypeError for the default.
    args :
        Other parameters for imshow.
    """
    if not show:
        return
    if isinstance(a, cp.ndarray):
        a = a.get()
    if v is None:
        abs_min = abs_max = ph_min = ph_max = None
    else:
        abs_min, abs_max, ph_min, ph_max = v[0], v[1], v[2], v[3]
    # NOTE(review): `plt` is not imported in the visible cells; presumably provided by
    # `from holotomocupy.utils import *` — confirm.
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    im = axs[0].imshow(np.abs(a), cmap='gray', **args, vmax=abs_max, vmin=abs_min)
    axs[0].set_title('abs')
    fig.colorbar(im, fraction=0.046, pad=0.04)
    im = axs[1].imshow(np.angle(a), cmap='gray', **args, vmax=ph_max, vmin=ph_min)
    axs[1].set_title('phase')
    fig.colorbar(im, fraction=0.046, pad=0.04)
    plt.show()

# Full object and probe.
mmshow_polar(psi[0], show, v=[0.8, 1.2, -0.15, 0.15])
mmshow_polar(prb[0], show, v=[0.6, 3, -1.2, 1.2])

# Zoom around the right edge of the central n x n window.
_zoom_rows = slice(ne//2 - ne//8, ne//2 + ne//8)
_zoom_cols = slice(ne//2 + n//2 - n//8, ne//2 + n//2 + n//4)
mmshow_polar(psi[0, _zoom_rows, _zoom_cols], show, v=[0.8, 1.2, -0.1, 0.05])
mmshow_polar(prb[0, _zoom_rows, _zoom_cols], show, v=[0.6, 2, -1.2, 1.2])

# Tighter zoom on the object.
mmshow_polar(psi[0, ne//2-ne//16:ne//2+ne//16, ne//2+n//2-n//8+n//16:ne//2+n//2+n//4-n//16], show, v=[0.8, 1.2, -0.1, 0.05])

# Measured frame, then model intensity minus measurement: probe vs. reference, and
# full model vs. the last sample frame.
for _measured, _model in ((ref[0], Lprb[0]), (data[0, -1], Lpsi[0, -1])):
    mshow(_measured, show)
    mshow(np.abs(_model)**2 - _measured, show, vmin=-0.04, vmax=0.04)



In [None]:

# Data-fit error at every saved checkpoint (every 128th index).
# errprb / errpsi keep one slot per index (only every 128th is filled) so that the
# following cell can keep slicing them with [::128].
checkpoint_step = 128
errprb = np.zeros(10000)
errpsi = np.zeros(10000)

# The measurements do not change between checkpoints, so they are moved to the GPU once
# (cp.array on a CuPy array would otherwise make a fresh device copy every iteration).
ref = cp.array(ref)
data = cp.array(data)

for i in range(0, 10000, checkpoint_step):
    psi_angle = dxchange.read_tiff(f'{path_out}/crec_code_angle{flg}/{i:03}.tiff')
    psi_abs = dxchange.read_tiff(f'{path_out}/crec_code_abs{flg}/{i:03}.tiff')
    prb_angle = dxchange.read_tiff(f'{path_out}/crec_prb_angle{flg}/{i:03}.tiff')
    prb_abs = dxchange.read_tiff(f'{path_out}/crec_prb_abs{flg}/{i:03}.tiff')
    shifts = cp.array(np.load(f'{path_out}/crec_shift{flg}_{i:03}.npy'))

    psi = cp.array(psi_abs*np.exp(1j*psi_angle)[np.newaxis])
    prb = cp.array(prb_abs*np.exp(1j*prb_angle)[np.newaxis])

    Lpsi = Lop(prb*Sop(psi, shifts))
    Lprb = L1op(prb)
    # float(): store a host scalar explicitly instead of relying on implicit
    # conversion of a 0-d CuPy array when assigning into a NumPy array.
    errprb[i] = float(np.linalg.norm(np.abs(Lprb)**2 - ref)**2)
    errpsi[i] = float(np.linalg.norm(np.abs(Lpsi)**2 - data)**2)

checkpoints = np.arange(0, 10000, checkpoint_step)
plt.plot(checkpoints, errprb[::checkpoint_step])
plt.plot(checkpoints, errpsi[::checkpoint_step])
plt.grid(True)
plt.yscale('log')


In [None]:
# Squared residual norms ||abs(L(.))**2 - measurement||**2 per checkpoint, log scale.
_checkpoints = np.arange(0, 10000, 128)
plt.plot(_checkpoints, errprb[::128], label='prb')
plt.plot(_checkpoints, errpsi[::128], label='psi')
plt.grid(True)  # `visible` expects a bool
plt.yscale('log')
plt.xlabel('checkpoint index i')
plt.ylabel('squared residual norm')
plt.legend()
