In [None]:
import os
import sys
import time
import warnings
from types import SimpleNamespace

import cupy as cp
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *
from rec import Rec

# Init data sizes and parameters of the PXM of ID16A

In [None]:
# Angular subsampling: keep every `step`-th projection when reading the scan (the `[::step]` slices below).
step = 36
# Binning level: the detector size `n` and the shifts are divided by 2**bin in a later cell.
# NOTE(review): `bin` shadows Python's built-in bin(); renaming it requires updating every cell that uses it.
bin = 1

In [None]:
# Read acquisition geometry, angles and shifts from the ID16A scan file, keeping every `step`-th angle.
with h5py.File('/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/AtomiumS2_HT_007nm.h5') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusToDetectorDistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:]
    theta = fid['/exchange/theta'][::step, 0]
    shifts = fid['/exchange/shifts'][::step]
    # Only the shape of the projection stack is needed. Derive it from the dataset metadata
    # instead of reading every `step`-th full projection into memory.
    data0_shape = fid['/exchange/data0'].shape
    shape = (len(range(0, data0_shape[0], step)), *data0_shape[1:])

In [None]:
# Convert projection angles from degrees to radians.
# NOTE(review): reassigns `theta` in place, so re-running this cell converts it a second time.
theta = np.pi * (theta / 180)

In [None]:
ndist = 4  # number of propagation distances used in the simulation

# Number of angles and binned detector size. `n` is computed directly from `shape`,
# so re-running this cell does not keep halving it.
ntheta = shape[0]
n = shape[1] // 2**bin

# Express the shifts in binned pixels.
# NOTE(review): in-place division; re-running this cell divides `shifts` again.
shifts /= 2**bin

# Seeded generator so the simulated shift errors are reproducible from run to run.
rng = np.random.default_rng(0)
# Uniform perturbation in [-2.5, 2.5) binned pixels; the perturbed shifts are saved as `shifts_error`.
error = 5 * (rng.random(shifts.shape, dtype=np.float32) - 0.5)
shifts_error = shifts + error

In [None]:
# X-ray energy [keV] and wavelength [m]: lambda = h*c / E, with h*c expressed in keV*m.
energy = 33.35
hc_kev_m = 1.2398419840550367e-09
wavelength = hc_kev_m / energy

# Geometry. z1 appears to be the focus-to-sample distance (magnification = focus-to-detector / z1),
# so z2 is the sample-to-detector distance.
z2 = focusToDetectorDistance - z1
magnifications = focusToDetectorDistance / z1
# Presumably the effective (Fresnel-scaled) propagation distances passed to the model.
distances = (z1 * z2) / focusToDetectorDistance
# Magnifications relative to the first distance.
norm_magnifications = magnifications / magnifications[0]

# Object voxel size on the binned grid: detector pixel demagnified at the first distance,
# rescaled from a 2048-pixel detector (assumed full width) to n pixels.
voxelsize = detector_pixelsize / magnifications[0] * 2048 / n

# Flag passed to the mshow* display helpers.
show = True


In [None]:
# Extra detector padding in pixels (0 = none).
pad = 0
# Object-grid size at full resolution: the padded 2048-pixel field of view divided by the last
# relative magnification (presumably the smallest), rounded up to a multiple of 16.
npsi = int(np.ceil((2048+2*pad)/norm_magnifications[-1]/16))*16  # round up to a multiple of 16
# Rescale to the binned grid; with 2048//n == 2 the result is a multiple of 8.
npsi//=(2048//n)

In [None]:
# Parameters consumed by the Rec simulation/reconstruction class.
args = SimpleNamespace()
args.ngpus = 4

# Problem sizes on the binned grid.
args.n = n
args.ndist = ndist
args.ntheta = ntheta
args.pad = pad
args.nq = n + 2 * pad  # probe/detector grid size including padding
args.nchunk = 2

# Physical parameters.
args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distances

# Solver/regularization settings and output location (interpreted by Rec).
args.eps = 1e-12
args.rho = [1, 5, 3]
args.lam = 0
args.path_out = "/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/AtomiumS2_HT_007nm/syn"
# Follow the notebook-wide `show` flag instead of hard-coding True.
args.show = show

args.niter = 1
args.vis_step = 1
args.err_step = 1
args.method = "BH-CG"

# Optional crop of the object grid on each side; 0 keeps the full npsi grid
# (a previously used alternative was npsi//2 - n//2).
ppad = 0

# Rotation axis: 796.25 px at n=1024, rescaled to the current n and shifted by the crop.
# NOTE(review): magic number, presumably calibrated for this scan; confirm.
args.rotation_axis = 796.25*n/1024 - ppad
npsi -= 2*ppad
args.npsi = npsi
print(ppad, npsi, n, args.rotation_axis)
args.theta = theta
args.norm_magnifications = norm_magnifications
# create class
cl_rec = Rec(args)

## Read real and imaginary parts of the refractive index $u = \delta + i\beta$

In [None]:
# Load the synthetic object u = delta + i*beta (presumably stored as an n^3 volume), if it already exists.
u = np.load(f'/data/vnikitin/syn_3d_ald/u{n}.npy').astype('complex64')
u *= 1000  # empirical contrast scaling of the object

# Zero-pad every axis up to the npsi object grid. The upper pad absorbs any odd difference,
# so the result is exactly npsi wide even when npsi - n is odd.
ppad = npsi//2 - n//2
u = np.pad(u, [(ppad, npsi - n - ppad)] * 3)
# Sanity check: the object grid is expected to match args.npsi.
assert u.shape == (npsi, npsi, npsi), u.shape

mshow_complex(u[args.npsi//2], show)  # central slice along the first axis
print(u.shape, npsi, u.dtype)


## Compute tomographic projection data via the Fourier based method, $\mathcal{R}u$:

In [None]:

# Project the object (cl_rec.R), then apply cl_rec.expR, which presumably yields the transmission function.
psi = cl_rec.expR(
    cl_rec.R(u),
)
mshow_polar(psi[0], show)  # first slice (presumably the first angle)

## Read a reference probe (illumination) image previously recovered by the NFP (near-field ptychography) method at ID16A.

In [None]:
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_abs_2048.tiff -P ../data/prb_id16a
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_phase_2048.tiff -P ../data/prb_id16a

q_abs = read_tiff(f'../data/prb_id16a/prb_abs_2048.tiff')[0:ndist]
q_phase = read_tiff(f'../data/prb_id16a/prb_phase_2048.tiff')[0:ndist]
q = q_abs*np.exp(1j*q_phase).astype('complex64')


for k in range(int(np.log2(2048//n))):
    q = q[:, ::2]+q[:, 1::2]
    q = q[:, :, ::2]+q[:, :, 1::2]/4
q /= np.mean(np.abs(q))
q[:]=1

mshow_polar(q[0],show)
mshow_polar(q[-1],show)

# Smooth the probe, since the loaded one is too noisy (currently disabled: the cell below is commented out)

In [None]:
# Disabled: optional Gaussian low-pass smoothing of the probe in Fourier space.
# NOTE(review): q is presumably (ndist, n, n) here; np.fft.fftn / np.fft.fftshift without `axes=`
# would also transform/shift along the distance axis. Pass axes=(-2, -1) if this is re-enabled.
# v = np.arange(-(n+2*pad)//2,(n+2*pad)//2)/(n+2*pad)
# [vx,vy] = np.meshgrid(v,v)
# v=np.exp(-5*(vx**2+vy**2))
# q = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(q)))
# q = np.fft.fftshift(np.fft.ifftn(np.fft.fftshift(q*v)))
# q = q.astype('complex64')

# mshow_polar(q[0],show)

# Simulate data and reference wavefields with the measured shifts/drifts

In [None]:
# Simulate detector wavefields using the measured shifts.
r = shifts  # alias of `shifts` (same array object, not a copy)
data = cl_rec.fwd(r, u, q)
ref = cl_rec.fwd(r, u * 0, q)  # zero object -> presumably the probe-only reference wavefield

# Show distance index 2 at the first angle for both.
for wave in (data, ref):
    mshow_polar(wave[0, 2], show)

In [None]:
# Quick look: apply cl_rec.DT (presumably back-propagation for distance index 3) to the simulated
# wavefield at distance index 3, and show the first angle.
# Leaves `r` (the shifts used to simulate `data`) and `q` untouched, so later cells save and
# print the true simulation shifts rather than zeros.
dd = cl_rec.DT(data[:, 3], 3)
mshow_polar(dd[0], show)
# d = np.abs(cl_rec.fwd(r,u_init,q))**2

In [None]:
mshow(abs(data[0, 2]), show)  # amplitude of the simulated wavefield: first angle, distance index 2

In [None]:
mshow(abs(data[0, 0]), show)  # amplitude of the simulated wavefield: first angle, distance index 0

### Take squared absolute value to simulate data on the detector and a reference image

In [None]:
# psi = cl_rec.D(cl_rec.M(cl_rec.expR(cl_rec.R(u)),3),0)

In [None]:
# mshow_polar(psi[0],True)

In [None]:
# mshow_polar(psi[-1],True)

In [None]:
# Detector intensities: squared modulus of the simulated wavefields.
# NOTE(review): overwrites the complex fields, so re-running this cell squares them again.
data = np.square(np.abs(data))
ref = np.square(np.abs(ref))

### Visualize data

In [None]:
# Flat-field-corrected view (data / ref) at the first angle, one figure per distance.
for dist_idx in range(ndist):
    ratio = data[0, dist_idx] / ref[0, dist_idx]
    mshow(ratio, show)


### Visualize reference images

In [None]:
# Reference intensity at the first angle, one figure per distance.
for dist_idx in range(ndist):
    mshow(ref[0, dist_idx], show)

### Save data, reference images

In [None]:
# Inspect the shift array `r` (displayed as the cell output) before saving.
r

In [None]:
# Save the synthetic object, geometry, simulated intensities and probe to `<path_out>/data.h5`.
os.makedirs(args.path_out, exist_ok=True)  # HDF5 does not create missing parent directories
with h5py.File(f'{args.path_out}/data.h5', 'w') as fid:
    # Ground-truth object, with real and imaginary parts stored separately.
    fid.create_dataset('/exchange/u_real', data=u.real)
    fid.create_dataset('/exchange/u_imag', data=u.imag)

    # Angles back in degrees. Shifts are rescaled from the binned grid by 2048/n
    # (presumably to full-resolution detector pixels).
    fid.create_dataset('/exchange/theta', data=theta/np.pi*180)
    # The shifts that were used to simulate `data`. Saving `shifts` directly makes this
    # robust to later reassignments of `r`.
    fid.create_dataset('/exchange/cshifts_final', data=shifts*2048/n)
    fid.create_dataset('/exchange/shifts_error', data=shifts_error*2048/n)
    fid.create_dataset('/exchange/voxelsize', data=np.array([voxelsize]))
    fid.create_dataset('/exchange/z1', data=z1)
    fid.create_dataset('/exchange/detector_pixelsize', data=np.array([detector_pixelsize]))
    fid.create_dataset('/exchange/focusdetectordistance', data=np.array([focusToDetectorDistance]))

    # Reference intensities of the first angle, then per-distance data and probe.
    fid.create_dataset('/exchange/ref', data=ref[0])
    for k in range(ndist):
        fid.create_dataset(f'/exchange/data{k}', data=data[:, k])
        fid.create_dataset(f'/exchange/q_abs{k}', data=np.abs(q[k]))
        fid.create_dataset(f'/exchange/q_angle{k}', data=np.angle(q[k]))
    
    

In [None]:
q_init = q.copy()  # snapshot of the probe used for the simulation

# Reload the object from the file just written.
with h5py.File(f'{args.path_out}/data.h5') as saved:
    real_part = saved[f'/exchange/u_real'][:]
    u_init = real_part + 1j * saved[f'/exchange/u_imag'][:]
del real_part

In [None]:
# Sanity check: norms of the shifts, object and probes, followed by the geometry used.
norms = [np.linalg.norm(arr) for arr in (r, u, q, q_init)]
print(*norms, z1, voxelsize)