In [None]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time
import matplotlib.pyplot as plt
import h5py
from types import SimpleNamespace
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")

sys.path.insert(0, '..')
from utils import *
from rec import Rec

# Initialize data sizes and parameters of the PXM at ID16A

In [None]:
# --- Sample / scan ---
n = 256  # object size in each dimension
ntheta = 128  # number of angles (rotations)
theta = np.linspace(0, np.pi, ntheta).astype('float32')  # projection angles [rad], 0..pi inclusive

# --- Beam / detector ---
ndist = 4  # number of propagation distances
detector_pixelsize = 3e-6
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length (hc ~ 1.23984e-9 keV*m)

# --- Focus / sample / detector geometry ---
focusToDetectorDistance = 1.208  # [m]
sx0 = -2.493e-3
# focus-to-sample distance for each propagation distance (offset by sx0)
z1 = np.array([1.5335e-3, 1.7065e-3, 2.3975e-3, 3.8320e-3])[:ndist]-sx0
z2 = focusToDetectorDistance-z1  # sample-to-detector distances
magnifications = focusToDetectorDistance/z1
distances = (z1*z2)/focusToDetectorDistance  # effective distances z1*z2/(z1+z2)
# NOTE(review): 2048/n/2 presumably maps the 2048-px detector binned to n; confirm the extra /2
voxelsize = detector_pixelsize/magnifications[0]*2048/n/2  # object voxel size

# Express everything relative to the magnification of the first distance.
norm_magnifications = magnifications/magnifications[0]
distances *= norm_magnifications**2
# object canvas large enough for the least-magnified view, rounded up to a multiple of 8
npsi = int(np.ceil(n/norm_magnifications[-1]/8))*8

In [None]:
# Parameters consumed by the Rec operator class.
args = SimpleNamespace(
    ngpus=1,
    n=n,
    ndist=ndist,
    ntheta=ntheta,
    pad=0,
    npsi=npsi,
    nchunk=32,
    voxelsize=voxelsize,
    wavelength=wavelength,
    distance=distances,
    rotation_axis=npsi/2,
    theta=theta,
    norm_magnifications=norm_magnifications,
)
args.nq = args.n + 2*args.pad  # size after padding on both sides

# create class
cl_rec = Rec(args)


## Read real and imaginary parts of the refractive index $u = \delta + i\beta$

In [None]:
u = np.load('data/u.npy')
# Crop a centered npsi^3 cube. Each axis is centered on its own midpoint so that a
# non-cubic volume is cropped correctly (previously every axis used u.shape[0]).
u = u[tuple(slice(s//2-args.npsi//2, s//2+args.npsi//2) for s in u.shape[:3])]
# NOTE(review): scale factor presumably boosts contrast of the synthetic phantom — confirm
u *= 5000
mshow_complex(u[args.npsi//2], True)


## Compute tomographic projection data via the Fourier-based method, $\mathcal{R}u$:

In [None]:

# Project the object (R) and pass the projections through expR.
# NOTE(review): expR presumably maps projections to a complex transmission function — confirm in rec.Rec
proj = cl_rec.R(u)
psi = cl_rec.expR(proj)
del proj
mshow_polar(psi[0], True)

## Read the probe (illumination) previously recovered by the NFP (near-field ptychography) method at ID16A.

In [None]:
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_abs_2048.tiff -P ../data/prb_id16a
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_phase_2048.tiff -P ../data/prb_id16a

q_abs = read_tiff(f'../data/prb_id16a/prb_abs_2048.tiff')[0:ndist]
q_phase = read_tiff(f'../data/prb_id16a/prb_phase_2048.tiff')[0:ndist]
q = q_abs*np.exp(1j*q_phase).astype('complex64')


for k in range(3):
    q = q[:, ::2]+q[:, 1::2]
    q = q[:, :, ::2]+q[:, :, 1::2]/4

# q = q[:, 128-pad:-128+pad, 128-pad:-128+pad]
q /= np.mean(np.abs(q))

mshow_polar(q[0],True)
mshow_polar(q[-1],True)

# Smooth the probe; the loaded one is too noisy

In [None]:
# Centered frequency grid (cycles/pixel); arange(n)-n//2 equals arange(-n//2, n//2)
# for even n and is also correctly centered for odd n.
v = (np.arange(n)-n//2)/n
[vx, vy] = np.meshgrid(v, v)
# Gaussian low-pass filter, shape (n, n), broadcast over the leading (distance) axis.
v = np.exp(-5*(vx**2+vy**2))

# Filter each probe independently: transform only the two spatial axes. The previous
# fftn/fftshift also transformed and shifted the distance axis, which cancels only
# when ndist is even. ifftshift -> fft -> fftshift is exact for odd sizes as well.
fft_axes = (-2, -1)
q = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(q, axes=fft_axes)), axes=fft_axes)
q = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(q*v, axes=fft_axes)), axes=fft_axes)
q = q.astype('complex64')

mshow_polar(q[0], True)

# Random sample shifts and per-distance drifts

In [None]:
# random sample shifts for each projection (note:before magnification)
shifts = (np.random.random([ntheta, ndist, 2]).astype('float32')-0.5)*n/32

# use the first plane as the global reference for illumination
if ndist>1:
    shifts[:, 1] += np.array([0.6, 0.3])
    shifts[:, 2] += np.array([-1.3, 1.5])
    shifts[:, 3] += np.array([2.3, -3.5])

np.save('data/shifts', shifts)

In [None]:
# Dot-product check for the operator pair M / MT (second argument 2 — meaning defined
# in rec.Rec, TODO confirm): if MT is the adjoint of M, <a, MT(M a)> equals <M a, M a>.
a = psi
b = cl_rec.M(a, 2)
c = cl_rec.MT(b, 2)
lhs = np.sum(a*c.conj())
rhs = np.sum(b*b.conj())
print(lhs)
print(rhs)
for arr in (a, b, c):
    mshow_polar(arr[0], True)

In [None]:
r = shifts
# Same forward model twice: with the object u (sample data) and with a zero object
# (u*0) for the reference image. Fresh device copies are passed to each call.
data = cl_rec.fwd(cp.array(r), cp.array(u), cp.array(q))
ref = cl_rec.fwd(cp.array(r), cp.array(u*0), cp.array(q))
# first angle, distance index 2
for field in (data, ref):
    mshow_polar(field[0, 2], True)

### Take the squared absolute value to simulate detector data and a reference image

In [None]:
# Detector intensity = |field|^2. Not idempotent: re-running this cell squares again.
data, ref = np.abs(data)**2, np.abs(ref)**2

### Visualize data

In [None]:
# Show the simulated intensity for the first angle at every distance.
# (Previously passed the undefined name `show`; all other display calls pass True.)
for k in range(ndist):
    mshow(data[0, k], True)


### Visualize reference images

In [None]:
# Show the reference image for the first angle at every distance.
for dist_idx in range(ndist):
    ref_img = ref[0, dist_idx]
    mshow(ref_img, True)

### Save data, reference images, probes, and shifts

In [None]:
# Output location for the synthetic dataset; change this single setting to write elsewhere.
out_dir = '/data/vnikitin/syn'

for k in range(len(distances)):
    write_tiff(data[:, k], f'{out_dir}/data/data_{k}')  # all angles at distance k
    write_tiff(ref[0, k], f'{out_dir}/ref_{k}')  # reference image, first angle only
    write_tiff(np.abs(q[k]), f'{out_dir}/q_abs_{k}')  # probe amplitude
    write_tiff(np.angle(q[k]), f'{out_dir}/q_angle_{k}')  # probe phase

np.save(f'{out_dir}/r', r)  # shifts; np.save appends .npy