In [None]:
import numpy as np
import cupy as cp
import h5py
import matplotlib.pyplot as plt
from types import SimpleNamespace
import pandas as pd
import h5py
import sys
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *
from rec import Rec
import os
import psutil
process = psutil.Process(os.getpid())

# Initialize data sizes and parameters of the PXM of ID16A

In [None]:
# Projection selection: projections st, st+step, ..., st+step*(ntheta-1) are loaded.
st, step, ntheta = 0, 1, 100
# Number of distances used (z1[:ndist], datasets data0 .. data{ndist-1}).
ndist = 4
# Binning level: each level halves both image dimensions (note: shadows the builtin `bin`).
bin = 1

In [None]:
# Input file and experiment metadata.
pfile = 'Y350c_HT_015nm'
path_out = '/data/vnikitin/ESRF/ID16A/brain_rec/20240515/Y350c2'
# Open explicitly read-only: h5py < 3.0 defaults to mode 'a', which would create an
# empty file (or allow writes) if the path were wrong.
with h5py.File(f'{path_out}/{pfile}.h5', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusToDetectorDistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:ndist]  # one entry per distance actually used
    # Only dataset shapes are read here (no pixel data).
    shape = fid['/exchange/data0'].shape
    shape_ref = fid['/exchange/data_white_start0'].shape
    shape_dark = fid['/exchange/data_dark0'].shape
    

In [None]:
# Detector width after binning: every binning level halves it.
n = shape[-1] // 2**bin

In [None]:
# Beam and imaging-geometry parameters.
energy = 33.35  # [keV] X-ray energy
# lambda = h*c/E with h*c = 1.2398419840550367e-09 keV*m, so lambda is in meters.
wavelength = 1.2398419840550367e-09/energy  # [m] wavelength
# Remainder of the focus-to-detector distance after z1 (z1 + z2 == focusToDetectorDistance).
z2 = focusToDetectorDistance-z1

# Magnification per distance (focus-to-detector / z1), and the same values
# relative to the first distance (so norm_magnifications[0] == 1).
magnifications = focusToDetectorDistance/z1
norm_magnifications = magnifications/magnifications[0]
# z1*z2/(z1+z2), scaled by the squared relative magnification.
# NOTE(review): presumably the effective (Fresnel-scaled) propagation distance expressed
# on the first distance's sampling grid; the original author flagged this line with
# "!!!!" - confirm.
distances = (z1*z2)/focusToDetectorDistance*norm_magnifications**2
# Pixel size at the first distance, rescaled from 2048 px to n px.
# NOTE(review): 2048 is assumed to be the unbinned detector width - confirm.
voxelsize = detector_pixelsize/magnifications[0]*2048/n  # object voxel size
show = True  # NOTE(review): not read by any visible cell (args.show is set directly)


In [None]:
# Size of the reconstructed psi grid.
# At full resolution, the 2048-px width (NOTE(review): assumed unbinned detector width)
# divided by the last distance's relative magnification is rounded UP to a multiple of 16.
# Dividing by the binning factor 2048//n then yields a multiple of 16/(2048//n)
# (e.g. a multiple of 8 when n == 1024).
npsi = int(np.ceil(2048/norm_magnifications[-1]/16))*16  # multiple of 16 before binning
npsi//=(2048//n)

In [None]:

def _downsample(data, binning):
    res = data.copy()
    for k in range(binning):
        res = 0.5*(res[..., ::2,:]+res[..., 1::2,:])
        res = 0.5*(res[..., :, ::2]+res[..., :, 1::2])
    return res

In [None]:


# Load shifts, the initial psi, the frames and the reference, all binned by `bin` levels.
# Projections st, st+step, ..., st+step*(ntheta-1) are selected from every per-projection dataset.
# Opened read-only so the input file can never be modified or accidentally created.
with h5py.File(f'{path_out}/{pfile}_corr.h5', 'r') as fid:
    theta_sel = slice(st, st+step*ntheta, step)
    r = fid['/exchange/cshifts_final'][theta_sel, :ndist].astype('float32')
    # h5py silently truncates out-of-range slices; fail here with a clear message
    # instead of a cryptic broadcast error when filling `data` below.
    if r.shape[0] != ntheta:
        raise ValueError(
            f"requested {ntheta} projections (st={st}, step={step}) "
            f"but the file provides only {r.shape[0]}"
        )
    psi_abs = fid['/exchange/psi_init_abs'][theta_sel, :].astype('float32')
    psi_angle = fid['/exchange/psi_init_angle'][theta_sel, :].astype('float32')
    psi = psi_abs*np.exp(1j*psi_angle)  # complex initial psi from modulus and phase
    psi = _downsample(psi, bin)

    data = np.empty([ntheta, ndist, n, n], dtype='float32')
    for k in range(ndist):
        # Square roots of the binned frames are stored (presumably intensities -> amplitudes).
        data[:, k] = np.sqrt(_downsample(fid[f'/exchange/data{k}'][theta_sel], bin))
    ref = fid['/exchange/ref'][:ndist]
    ref = _downsample(ref, bin)




In [None]:
# Single reconstruction with the chosen solver parameters.
args = SimpleNamespace()
args.ngpus = 4

args.n = n
args.npsi = npsi
args.ndist = ndist
args.ntheta = ntheta
args.pad = 0

args.nq = n + 2 * args.pad  # tie nq to the padding instead of a literal 0
args.nchunk = 32

args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distances
args.norm_magnifications = norm_magnifications

args.niter = 257
args.vis_step = 4
args.err_step = 4
args.lam = 0
args.path_out = f"{path_out}/{pfile}/rec_psi_{ndist}_{ntheta}_{st}"
args.show = True
args.rho = [1, 0.25, 0.25]

cl_rec = Rec(args)
# Per-distance initial q: cl_rec.DT applied to the square root of each reference.
q = np.empty([ndist, args.nq, args.nq], dtype='complex64')
for j in range(ndist):
    q[j] = cl_rec.DT(np.sqrt(ref[j:j+1]), j)[0]

vars = {}
vars["q"] = cp.array(q)  # cp.array already copies the host array to the GPU
# Pass copies (as the sweep cells do) so any in-place updates inside BH cannot
# alter the global r/psi that other cells start from.
vars["r"] = r.copy()
vars["r_init"] = r.copy()
vars["psi"] = psi.copy()
vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])
vars = cl_rec.BH(data, vars)

# Deliberate stop (previously an undefined name `ss`): the cells below are
# parameter sweeps that should not run as part of "Run All".
raise SystemExit("Stopped after the main reconstruction; run the sweep cells manually if needed.")

In [None]:
# args = SimpleNamespace()
# args.ngpus = 4

# args.n = n
# args.npsi = npsi
# args.ndist = ndist
# args.ntheta = ntheta
# args.pad = 0

# args.nq = n + 2 * 0
# args.nchunk = 2

# args.voxelsize = voxelsize
# args.wavelength = wavelength
# args.distance = distances
# args.norm_magnifications = norm_magnifications

# args.niter = 65
# args.vis_step = -1
# args.err_step = 16
# args.rho = [1,2,0.1]
# args.lam = 0
# # args.rho = [1, 5, 3]
# args.path_out = f"{path_out}/{pfile}/rec_psi_{ndist}"
# args.show = True


# rrrr=[]
# for rr in [5,10]:
#     print(f'{rr=}')
#     args.rho = [1,0.25,rr]
#     cl_rec = Rec(args)    
#     q = np.empty([ndist,args.nq,args.nq],dtype='complex64')        
#     for j in range(ndist):
#         q[j] = cl_rec.DT(np.sqrt(ref[j:j+1]),j)[0]
#     vars={}
#     vars["q"] = cp.array(q.copy())
#     vars["r"] = r.copy()
#     vars["r_init"] = r.copy()
#     vars["psi"] = psi.copy()
#     # data=cp.array(data)
#     vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])    
#     vars = cl_rec.BH(data, vars)  
#     rrrr.append(vars["table"]['err'])
#     #1:1.41755e+05
# ss

In [None]:
# args = SimpleNamespace()
# args.ngpus = 4

# args.n = n
# args.npsi = npsi
# args.ndist = ndist
# args.ntheta = ntheta
# args.pad = 0

# args.nq = n + 2 * 0
# args.nchunk = 2

# args.voxelsize = voxelsize
# args.wavelength = wavelength
# args.distance = distances
# args.norm_magnifications = norm_magnifications

# args.niter = 65
# args.vis_step = -1
# args.err_step = 16
# args.rho = [1,2,0.1]
# args.lam = 0
# # args.rho = [1, 5, 3]
# args.path_out = f"{path_out}/{pfile}/rec_psi_{ndist}"
# args.show = True


# rrrr=[]
# for rr in [0.25,0.5,1,2]:
#     print(f'{rr=}')
#     args.rho = [1,rr,0]
#     cl_rec = Rec(args)    
#     q = np.empty([ndist,args.nq,args.nq],dtype='complex64')        
#     for j in range(ndist):
#         q[j] = cl_rec.DT(np.sqrt(ref[j:j+1]),j)[0]
#     vars={}
#     vars["q"] = cp.array(q.copy())
#     vars["r"] = r.copy()
#     vars["r_init"] = r.copy()
#     vars["psi"] = psi.copy()
#     # data=cp.array(data)
#     vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])    
#     vars = cl_rec.BH(data, vars)  
#     rrrr.append(vars["table"]['err'])
#     #1:1.41755e+05

In [None]:
# Sweep over rho[2] with rho[0] = 1 and rho[1] = 0.25 fixed. Every run starts from
# copies of the same initial q, r and psi, so the error histories are comparable.
args = SimpleNamespace()
args.ngpus = 4

args.n = n
args.npsi = npsi
args.ndist = ndist
args.ntheta = ntheta
args.pad = 0

args.nq = n + 2 * args.pad
args.nchunk = 2

args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distances
args.norm_magnifications = norm_magnifications

args.niter = 65
args.vis_step = 4
args.err_step = 16
args.lam = 0
args.show = True

rrrr = []  # one 'err' history per rho[2] value, each named by that value
for rr in [0.25, 0.5, 1, 2, 5, 10]:
    print(f'{rr=}')
    args.rho = [1, 0.25, rr]
    # Separate output folder per run so one rr value does not overwrite the previous
    # run's results. NOTE(review): assumes Rec writes its outputs under args.path_out.
    args.path_out = f"{path_out}/{pfile}/rec_psi_{ndist}_rho{rr}"
    cl_rec = Rec(args)
    q = np.empty([ndist, args.nq, args.nq], dtype='complex64')
    for j in range(ndist):
        q[j] = cl_rec.DT(np.sqrt(ref[j:j+1]), j)[0]
    vars = {}
    vars["q"] = cp.array(q)  # cp.array already copies to the GPU
    vars["r"] = r.copy()
    vars["r_init"] = r.copy()
    vars["psi"] = psi.copy()
    vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])
    vars = cl_rec.BH(data, vars)
    # Name the series by rr so the collected histories stay identifiable.
    rrrr.append(vars["table"]['err'].rename(rr))
    # earlier note: 1:1.41755e+05 (presumably the final err for rr=1)

In [None]:
# Error histories collected by the rho sweep: one entry per rr value, each being
# the 'err' column of that run's vars["table"].
rrrr