In [None]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time
import matplotlib.pyplot as plt
import h5py
from types import SimpleNamespace
import warnings
# silence warnings whose message matches the regex ".*peer.*"
warnings.filterwarnings("ignore", message=f".*peer.*")

# make the parent directory importable (presumably where the project modules utils and rec live)
sys.path.insert(0, '..')
# NOTE(review): star import — presumably provides mshow, mshow_complex, mshow_polar,
# read_tiff and write_tiff used below; explicit names would make the dependency visible
from utils import *
from rec import Rec

# Initialize data sizes and parameters of the PXM setup at ID16A

In [None]:
n = 256  # object size in each dimension
ntheta = 4  # number of angles (rotations)

theta = np.linspace(0, np.pi, ntheta).astype('float32')  # projection angles in [0, pi], both endpoints included

# ID16a setup
ndist = 4  # number of distances (planes); at most 4, the length of the z1 table below

detector_pixelsize = 3e-6  # presumably [m]
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length, lambda = h*c/E with h*c = 1.23984e-9 keV*m

focusToDetectorDistance = 1.208  # [m]
sx0 = -2.493e-3  # offset subtracted from the z1 table below; NOTE(review): presumably a motor-origin offset [m] — confirm
# focus-to-sample distance of each plane [m]
z1 = np.array([1.5335e-3, 1.7065e-3, 2.3975e-3, 3.8320e-3])[:ndist]-sx0
# z1[-1]*=4
z2 = focusToDetectorDistance-z1  # sample-to-detector distances [m]
# equivalent plane-wave propagation distances z1*z2/(z1+z2) (Fresnel scaling for a point source)
distances = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1  # geometric magnification (z1+z2)/z1 of each plane
# NOTE(review): the 2048/n/2 factor is not explained here — presumably the 2048-pixel
# detector resampled to n pixels plus an extra factor of 2; confirm
voxelsize = detector_pixelsize/magnifications[0]*2048/n/2  # object voxel size

# magnification of each plane relative to the first one (the largest, since z1 increases)
norm_magnifications = magnifications/magnifications[0]
# scaled propagation distances due to magnified probes
distances = distances*norm_magnifications**2

# z1p = z1[0]  # positions of the probe for reconstruction
# z2p = z1-np.tile(z1p, len(z1))
# # magnification when propagating from the probe plane to the detector
# magnifications2 = (z1p+z2p)/z1p
# # propagation distances after switching from the point source wave to plane wave,
# distances2 = (z1p*z2p)/(z1p+z2p)
# norm_magnifications2 = magnifications2/(z1p/z1[0])  # normalized magnifications
# # scaled propagation distances due to magnified probes
# distances2 = distances2*norm_magnifications2**2
# distances2 = distances2*(z1p/z1)**2

# allow padding if there are shifts of the probe
pad = 0
show=True  # display switch passed to the mshow* helpers
# sample size after demagnification: the object grid has to cover the field of view of the
# least magnified (last) plane; rounded up to a multiple of 8
npsi = int(np.ceil((n+2*pad)/norm_magnifications[-1]/8))*8  # make multiple of 8
print(voxelsize)

In [None]:
# Parameters for the reconstruction / forward-model class Rec.
args = SimpleNamespace()
args.npos = 1
args.ngpus = 1

args.n = n
args.ndist = ndist
args.ntheta = ntheta
# Use the same padding that was used to compute npsi above, so that args.nq and npsi
# stay consistent if `pad` is changed.
args.pad = pad
args.npsi = npsi
args.nq = args.n + 2 * args.pad
args.ex = 0
args.npatch = args.nq + 2 * args.ex
args.nchunk = 4

args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distances  # scaled propagation distances from the setup cell
args.eps = 1e-12
args.rho = [1, 20, 10]
args.path_out = f"/data/vnikitin/ESRF/ID16A/20240924_rec0224//AtomiumS2/HT/s1"
args.show = show  # honour the global display switch instead of hard-coding True
args.lam=0

args.niter=10000
args.vis_step=1
args.err_step=1
args.method = "BH-CG"
args.rotation_axis=args.npsi/2  # rotation axis at the center of the npsi grid

args.theta = theta
args.norm_magnifications = norm_magnifications
# create class
cl_rec = Rec(args)


## Read the real and imaginary parts of the refractive index, $u = \delta + i\beta$

In [None]:
# Code used to generate data/u.npy (kept for provenance; not executed):
# from scipy import ndimage

# cube_all = np.zeros([args.npsi,args.npsi,args.npsi],dtype='float32')
# rr = (np.ones(8)*408*0.2).astype(np.int32)
# amps = [3, 2, -3, 1, 2,-4,2]#, -2, -4, 5 ]
# dil = [30, 28, 25, 21, 16,10,3]#, 6, 3,1]
# for kk in range(len(amps)):
#     cube = np.zeros([args.npsi,args.npsi,args.npsi],dtype='bool')
#     r = rr[kk]
#     p1 = args.npsi//2-r//2
#     p2 = args.npsi//2+r//2
#     for k in range(3):    
#         cube = cube.swapaxes(0,k)
#         cube[p1:p2,p1,p1] = True
#         cube[p1:p2,p1,p2] = True
#         cube[p1:p2,p2,p1] = True
#         cube[p1:p2,p2,p2] = True        
#         #cube[p1:p2,p2,p2] = True        
        
#     [x,y,z] = np.meshgrid(np.arange(-args.npsi//2,args.npsi//2),np.arange(-args.npsi//2,args.npsi//2),np.arange(-args.npsi//2,args.npsi//2))
#     circ = (x**2+y**2+z**2)<dil[kk]**2        
#     fcirc = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(circ)))
#     fcube = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(cube.astype('float32'))))
#     cube = np.fft.fftshift(np.fft.ifftn(np.fft.fftshift(fcube*fcirc))).real
#     cube = cube>1
#     cube_all+=amps[kk]*cube

# cube_all = ndimage.rotate(cube_all,52,axes=(1,2),reshape=False,order=1)
# cube_all = ndimage.rotate(cube_all,38,axes=(0,1),reshape=False,order=1)
# cube_all = ndimage.rotate(cube_all,10,axes=(0,2),reshape=False,order=1)
# cube_all[cube_all<0] = 0



# u0 = (-1*cube_all*1e-6+1j*cube_all*1e-8)/1.5
# u0=np.roll(u0,-15,axis=2)
# u0=np.roll(u0,-10,axis=1)
# v = np.arange(-args.npsi//2,args.npsi//2)/args.npsi
# [vx,vy,vz] = np.meshgrid(v,v,v)
# v = np.exp(-10*(vx**2+vy**2+vz**2))
# fu = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(u0)))
# u0 = np.fft.fftshift(np.fft.ifftn(np.fft.fftshift(fu*v))).astype('complex64')

# !mkdir -p data
# np.save('data/u',u0)
# u = u0

# Load the precomputed object saved by the code above.
u = np.load('data/u.npy')
# Center-crop every axis to npsi. Each axis uses its own length (not shape[0] for all
# three), so a non-cubic volume is cropped correctly.
u = u[tuple(slice(s//2-args.npsi//2, s//2+args.npsi//2) for s in u.shape)]
# Scale up the (small, ~1e-6 per the generator above) refractive-index values.
# NOTE(review): the rationale for the factor 5000 is not documented here.
u*=5000
# show the central slice along axis 0
mshow_complex(u[args.npsi//2],show)


## Compute the tomographic projection data $\mathcal{R}u$ via the Fourier-based method

In [None]:

# Project the object with cl_rec.R and map the projections through cl_rec.expR
# (presumably exp(i*k*projection), i.e. the sample transmission — TODO confirm in rec.Rec);
# the first angle is displayed.
psi = cl_rec.expR(cl_rec.R(u))
mshow_polar(psi[0],show)

## Read the probe (illumination) previously recovered at ID16A by the NFP (near-field ptychography) method.

In [None]:
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_abs_2048.tiff -P ../data/prb_id16a
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_phase_2048.tiff -P ../data/prb_id16a

q_abs = read_tiff(f'../data/prb_id16a/prb_abs_2048.tiff')[0:ndist]
q_phase = read_tiff(f'../data/prb_id16a/prb_phase_2048.tiff')[0:ndist]
q = q_abs*np.exp(1j*q_phase).astype('complex64')


for k in range(3):
    q = q[:, ::2]+q[:, 1::2]
    q = q[:, :, ::2]+q[:, :, 1::2]/4

# q = q[:, 128-pad:-128+pad, 128-pad:-128+pad]
q /= np.mean(np.abs(q))
# q[:]=1

mshow_polar(q[0],show)
mshow_polar(q[-1],show)

# Smooth the probe (the loaded one is too noisy)

In [None]:
# Low-pass filter every probe independently: multiply its centered 2D spectrum by a
# Gaussian window over normalized frequencies in [-0.5, 0.5).
# The FFTs and shifts act only on the two image axes (-2, -1). Transforming along the
# distance axis 0 as well would be a no-op only for an even number of distances; for odd
# ndist >= 3 it multiplies each probe by a spurious phase factor.
# NOTE: q is overwritten, so re-running this cell smooths the probe again.
v = np.arange(-(n+2*pad)//2,(n+2*pad)//2)/(n+2*pad)
[vx,vy] = np.meshgrid(v,v)
v=np.exp(-5*(vx**2+vy**2))
q = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(q, axes=(-2, -1))), axes=(-2, -1))
q = np.fft.fftshift(np.fft.ifft2(np.fft.fftshift(q*v, axes=(-2, -1))), axes=(-2, -1))
q = q.astype('complex64')

mshow_polar(q[0],show)

# Shifts/drifts

In [None]:
# Random sample shifts for each projection and distance (note: before magnification),
# drawn uniformly from [-n/64, n/64). A fixed seed keeps the saved synthetic data reproducible.
rng = np.random.default_rng(0)
shifts = (rng.random([ntheta, ndist, 2]).astype('float32')-0.5)*n/32

# Additional fixed offset per distance. The first plane gets no offset and serves as the
# global reference for the illumination. Only the first ndist rows are applied, so any
# ndist <= 4 works (indexing planes 1..3 whenever ndist > 1 fails for ndist = 2 or 3).
plane_offsets = np.array([[0, 0], [0.6, 0.3], [-1.3, 1.5], [2.3, -3.5]])
shifts += plane_offsets[:ndist]

np.save('data/shifts', shifts)

In [None]:
# Dot-product (adjoint) test for cl_rec.M / cl_rec.MT: with b = M(a) and c = MT(b),
# sum(a * conj(c)) should equal sum(|b|^2) when MT is the adjoint of M.
# NOTE(review): the second argument presumably selects the distance index; 2 requires ndist >= 3.
a=psi
b = cl_rec.M(a,2)
c = cl_rec.MT(b,2)
print(np.sum(a*c.conj()))  # <a, MT M a>
print(np.sum(b*b.conj()))  # <M a, M a>
mshow_polar(a[0],show)
mshow_polar(b[0],show)
mshow_polar(c[0],show)

In [None]:
# Simulate the complex detector wavefields with the sample (data) and with an empty
# object (ref); shifts, object and probe are passed to cl_rec.fwd as CuPy arrays.
r = shifts
data = cl_rec.fwd(cp.array(r),cp.array(u),cp.array(q))
ref = cl_rec.fwd(cp.array(r),cp.array(np.zeros_like(u)),cp.array(q))
# display distance index 2, clamped so that ndist < 3 does not raise IndexError
mshow_polar(data[0,min(2,ndist-1)],show)
mshow_polar(ref[0,min(2,ndist-1)],show)

### Take the squared absolute value to simulate the detector data and the reference images

In [None]:
# Convert the complex wavefields to detector intensities |field|^2.
# Guarded by the dtype so that re-running the cell does not square the (already real)
# intensities a second time.
if np.iscomplexobj(data):
    data = np.abs(data)**2
if np.iscomplexobj(ref):
    ref = np.abs(ref)**2

### Visualize data

In [None]:
# Show the simulated intensity of the first angle at every distance.
for k in range(ndist):
    mshow(data[0,k],show)


In [None]:
# Map every measured plane back to the common object grid and check the consistency of
# neighbouring planes by Fresnel-propagating one plane to the next.
rdata = data/ref  # data normalized by the empty-object reference
srdata = np.zeros([ntheta,ndist,args.npsi,args.npsi],dtype='float32')
# unscaled propagation distances (undo the norm_magnifications**2 scaling of the setup cell)
distances_pag = (distances/norm_magnifications**2)
# Intensity of the next-farther plane propagated to the current one. It starts as None so
# that the farthest plane never reads an undefined or stale value from an earlier run.
prev = None
for j in np.arange(ndist)[::-1]:
    # NOTE(review): ST/MT presumably undo the shift and magnification of plane j — confirm in rec.Rec
    tmp = cl_rec.ST(r[:,j]*norm_magnifications[j],rdata[:,j].astype('complex64'))
    tmp = np.abs(cl_rec.MT(tmp,j)/norm_magnifications[j]**2)
    tmp=tmp.astype('float32')

    if prev is not None:
        # prediction propagated from plane j+1 vs. the data of plane j
        mshow_complex(prev[0]+1j*tmp[0],show)
        mshow(prev[0]-tmp[0],show)
    if j>0:
        dist = distances_pag[j-1]-distances_pag[j]
        print(dist)
        # propagate the amplitude sqrt(tmp) of plane j to plane j-1
        prev = np.abs(cl_rec.Da(np.sqrt(tmp).astype('complex64'),dist))**2
        # effect of that propagation relative to plane j itself
        mshow_complex(prev[0]+1j*tmp[0],show)
        mshow(prev[0]-tmp[0],show)
    # t = cl_rec.ST(r[:,j]*norm_magnifications[j],rdata[:,j].astype('complex64')*0+1)
    # t = cl_rec.MT(t,j)/norm_magnifications[j]**2    

    # tmp[t<1e-3]=0
    # mshow_complex(t[0],True)
    # mshow_complex(tmp[0]/(t[0]+1e-5),True)

    # tmp = tmp[:,args.npsi//2-args.nq//2:args.npsi//2+args.nq//2,args.npsi//2-args.nq//2:args.npsi//2+args.nq//2]
    srdata[:,j]=tmp#np.pad(tmp,((0,0),(mpad,mpad),(mpad,mpad)),'edge')


### Visualize reference images

In [None]:
# Show the simulated reference (empty-object) intensity of the first angle at every distance.
for k in range(ndist):
    mshow(ref[0,k],show)

### Save data, reference images, probes, and shifts

In [None]:
# One pass over the distances: for each one, write the data stack (all angles), the
# reference image of the first angle, and the probe amplitude and phase as TIFFs.
# The shifts are then saved as a .npy file.
for k in range(len(distances)):
    write_tiff(data[:,k],f'/data/vnikitin/syn/data/data_{k}')
    write_tiff(ref[0,k],f'/data/vnikitin/syn/ref_{k}')
    write_tiff(np.abs(q[k]),f'/data/vnikitin/syn/q_abs_{k}')
    write_tiff(np.angle(q[k]),f'/data/vnikitin/syn/q_angle_{k}')

np.save(f'/data/vnikitin/syn/r',r)