In [1]:
import cupy as cp
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import xraylib
import random
import cupyx.scipy.ndimage as ndimage
from types import SimpleNamespace
import warnings
warnings.filterwarnings("ignore", message=f".*peer.*")

from utils import *
from rec import Rec

## Sizes and propagation settings

In [2]:
# --- array sizes and parallelisation -------------------------------------
n = 1024        # detector side length [pixels]
npsi = 2048     # object side length [pixels]
npos = 32       # number of coded-aperture positions
ngpus = 4       # number of GPUs
nchunk = 4      # chunks processed per GPU

# --- beam --------------------------------------------------------------------
energy = 33.35                      # [keV]
wavelength = 1.24e-09 / energy      # [m], lambda[m] ~= 1.24e-9 / E[keV]

# --- geometry ----------------------------------------------------------------
detector_pixelsize = 3.03751e-6     # [m]
focusToDetectorDistance = 1.28      # [m]
z1 = -17.75e-3                      # [m] position of the CA (relative to the focus)
z2 = focusToDetectorDistance - z1   # chosen so that z1 + z2 == focusToDetectorDistance
# z1*z2/(z1+z2): equivalent propagation distance of the magnified geometry
distance = z1 * z2 / focusToDetectorDistance
magnification = focusToDetectorDistance / z1
voxelsize = np.abs(detector_pixelsize / magnification)  # [m] object-plane pixel size

In [3]:
# Parameters consumed by the Rec class. `pad` and `ex` are 0 here, so the
# derived sizes nq (= n + 2*pad) and npatch (= nq + 2*ex) both equal n.
args = SimpleNamespace(
    ngpus=ngpus,
    n=n,
    npsi=npsi,
    pad=0,
    ex=0,
    npos=npos,
    nchunk=nchunk,
    voxelsize=voxelsize,
    wavelength=wavelength,
    distance=distance,
)
args.nq = args.n + 2 * args.pad
args.npatch = args.nq + 2 * args.ex

# Reconstruction / forward-model object used by all following cells.
cl_rec = Rec(args)

## Generate positions of the CA

In [None]:
np.random.seed(2)  # fixed seed so the positions (and later random draws) are reproducible
# Each coordinate is drawn uniformly from [-npsi/6, npsi/6) object pixels.
pos = (
    np.random.random([npos, 2]).astype('float32') - 0.5
) * npsi / 3
mplot_positions(pos)

## Load probe

In [None]:
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_abs_2048.tiff 
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_phase_2048.tiff 

q_abs = read_tiff(f'prb_abs_2048.tiff')[0,1024-args.nq//2:1024+args.nq//2,1024-args.nq//2:1024+args.nq//2]
q_phase = read_tiff(f'prb_phase_2048.tiff')[0,1024-args.nq//2:1024+args.nq//2,1024-args.nq//2:1024+args.nq//2]
q = cp.array(q_abs*np.exp(1j*q_phase)).astype('complex64')

#smooth it slightly
v = cp.arange(-args.nq//2,args.nq//2)/args.nq
[vx,vy] = cp.meshgrid(v,v)
v = cp.exp(-10*(vx**2+vy**2))
fq = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(q)))
q = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(fq*v))).astype('complex64')
q = q.get()
mshow_polar(q)


## Generate the coded aperture

In [None]:
bin_size = 2e-6          # [m] side length of one code bin
code_thickness = 1.5e-6  # [m] thickness of the gold code

# Global random binary pattern of (2*npsi)^2 bins with exactly
# int(0.55 * nill**2) bins set (sampled without replacement).
random.seed(10)
nill = 2 * args.npsi
ill0 = cp.zeros([nill * nill], dtype='bool')
ill_ids = random.sample(range(0, nill * nill), int(nill * nill * 0.55))
ill0[ill_ids] = 1
ill_global = ill0.reshape(nill, nill)

# Crop the central bins that cover the object field of view (npsi*voxelsize m),
# rounded down to an even number of bins.
nill = int((args.npsi * voxelsize) // (bin_size * 2)) * 2
ill = ill_global[
    ill_global.shape[0] // 2 - nill // 2:ill_global.shape[0] // 2 + nill // 2,
    ill_global.shape[1] // 2 - nill // 2:ill_global.shape[1] // 2 + nill // 2,
]

# Nearest-neighbour upsampling to npsi pixels, so one bin spans about
# bin_size / voxelsize object pixels.
ill = ndimage.zoom(ill, args.npsi / nill, order=0, grid_mode=True, mode='grid-wrap')

# Refractive-index decrement and absorption index of gold (19.3 g/cm^3) at `energy`.
delta = 1 - xraylib.Refractive_Index_Re('Au', energy, 19.3)
beta = xraylib.Refractive_Index_Im('Au', energy, 19.3)

thickness = code_thickness / voxelsize  # thickness in pixels

# Projected complex index of the code, (-delta + i*beta) * thickness,
# rotated by 45 degrees.
Rill = ill * (-delta + 1j * beta) * thickness
Rill = ndimage.rotate(Rill, 45, axes=(1, 0), reshape=False, order=3, mode='reflect',
                      prefilter=True)

# Gaussian low-pass filter in Fourier space to soften the code edges.
v = cp.arange(-args.npsi // 2, args.npsi // 2) / 2 / args.npsi
[vx, vy] = cp.meshgrid(v, v)
v = cp.exp(-10 * (vx**2 + vy**2))
fill = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(Rill)))
Rill = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(fill * v)))
Rill = Rill.astype('complex64')

# Transmittance exp(i * 2*pi/lambda * (-delta + i*beta) * t); t converted from pixels to meters.
psi = cp.exp(1j * Rill * voxelsize * 2 * np.pi / wavelength).astype('complex64')
psi = psi.get()
mshow_polar(psi, True)
mshow_polar(psi[:n // 5, :n // 5], True)  # zoom on a corner



### Check the code bin size in object pixels

In [None]:
# Width of one coded-aperture bin expressed in object pixels.
bin_size/voxelsize

## Simulate the data and the reference image

In [None]:
# Split positions into integer pixel shifts and sub-pixel remainders.
ri = np.round(pos).astype('int32')
r = (pos - ri).astype('float32')

# Intensities with the sample, and a reference computed with a unit object
# (only the first position's reference is kept).
data = np.abs(cl_rec.fwd(ri, r, psi, q)) ** 2
ref = (np.abs(cl_rec.fwd(ri, r, np.ones_like(psi), q)) ** 2)[0]

mshow(data[0])
mshow(ref)
mshow(data[0] / ref)  # reference-normalised first frame

# Reconstruction

## Initialize the probe by propagating the reference image

In [None]:
# Initial probe: cl_rec.DT applied to the square root of the reference
# intensity (presumably the adjoint propagator -- confirm in Rec).
q_init = cp.array(cl_rec.DT(np.sqrt(ref[np.newaxis]))[0])

# Replace a border of width ppad by a symmetric reflection of the interior.
# cp.pad is used explicitly instead of np.pad, which only worked on this CuPy
# array through NumPy's __array_function__ dispatch.
ppad = 3 * args.pad // 2
q_init = cp.pad(
    q_init[ppad : args.nq - ppad, ppad : args.nq - ppad],
    ((ppad, ppad), (ppad, ppad)),
    "symmetric",
)

# Separable sine window rising from 0 to 1 over ppad pixels at each edge
# (all ones when ppad == 0).
v = cp.ones(args.nq, dtype="float32")
vv = cp.sin(cp.linspace(0, cp.pi / 2, ppad))
v[:ppad] = vv
v[args.nq - ppad :] = vv[::-1]
v = cp.outer(v, v)

# Scale both amplitude and phase by the window so both vanish at the border.
q_init = cp.abs(q_init * v) * cp.exp(1j * cp.angle(q_init) * v)
q_init = q_init.get()
mshow_polar(q_init)

In [None]:
def Paganin(data, wavelength, voxelsize, delta_beta, alpha, propagation_distance=None):
    """Single-distance Paganin phase retrieval of one intensity frame.

    Parameters
    ----------
    data : cupy.ndarray
        Intensity image (presumably already divided by the reference); the
        filter acts on its last two axes.
    wavelength : float
        X-ray wavelength [m].
    voxelsize : float
        Pixel size of ``data`` [m]; sets the spatial-frequency grid.
    delta_beta : float
        Ratio delta/beta assumed for the material.
    alpha : float
        Regularisation added to the squared filter in the denominator.
    propagation_distance : float, optional
        Propagation distance [m]. When omitted, the notebook-level ``distance``
        is used, which was the previous implicit behaviour.

    Returns
    -------
    cupy.ndarray
        ``0.5 * delta_beta * log(filtered intensity)``. Non-positive filtered
        values give NaN / -inf.
    """
    if propagation_distance is None:
        propagation_distance = distance  # module-level value from the settings cell
    # Separate x / y frequency axes so non-square frames are handled correctly
    # (identical to the old square-only grid when the frame is square).
    fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype("float32")
    fy = cp.fft.fftfreq(data.shape[-2], d=voxelsize).astype("float32")
    [fx, fy] = cp.meshgrid(fx, fy)
    rad_freq = cp.fft.fft2(data)
    taylorExp = 1 + wavelength * propagation_distance * cp.pi * (delta_beta) * (fx**2 + fy**2)
    # Regularised division: taylorExp / (taylorExp**2 + alpha) ~= 1 / taylorExp.
    numerator = taylorExp * (rad_freq)
    denominator = taylorExp**2 + alpha
    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    phase = delta_beta * 0.5 * phase
    return phase


def rec_init(rdata, ishifts):
    """Stitch per-position Paganin phases into an initial object estimate.

    Each frame ``rdata[j]`` is phase-retrieved with ``Paganin`` (delta/beta =
    24.05, alpha = 1e-3). The result is added into an ``args.npsi x args.npsi``
    canvas at the window selected by ``ishifts[j]``, and overlapping
    contributions are averaged.

    Parameters
    ----------
    rdata : array-like
        Stack of reference-normalised frames, one per position (indexed by
        ``j``). Each frame must fit inside the canvas at its shift.
    ishifts : ndarray of int, shape (len(rdata), 2)
        Integer shifts. Column 0 moves the window along rows, column 1 along
        columns.

    Returns
    -------
    cupy.ndarray
        ``exp(1j * phase)`` on the ``args.npsi x args.npsi`` canvas. Pixels
        covered by no frame get phase 0.
    """
    recMultiPaganin = cp.zeros([args.npsi, args.npsi], dtype="float32")
    # Number of frames covering each canvas pixel, used to average overlaps.
    recMultiPaganinr = cp.zeros([args.npsi, args.npsi], dtype="float32")
    # Iterate over the frames actually provided (previously the global npos).
    for j in range(len(rdata)):
        r = Paganin(cp.array(rdata[j]), wavelength, voxelsize, 24.05, 1e-3)
        ny, nx = r.shape[-2], r.shape[-1]
        sty = args.npsi // 2 - ishifts[j, 0] - ny // 2
        stx = args.npsi // 2 - ishifts[j, 1] - nx // 2
        # Accumulate straight into the canvas window; there is no need to build
        # two full npsi x npsi frames per position.
        recMultiPaganin[sty:sty + ny, stx:stx + nx] += r
        recMultiPaganinr[sty:sty + ny, stx:stx + nx] += 1

    # Uncovered pixels have count 0; use 1 there to avoid dividing by zero.
    recMultiPaganinr[cp.abs(recMultiPaganinr) < 5e-2] = 1
    recMultiPaganin /= recMultiPaganinr
    return cp.exp(1j * recMultiPaganin)



# Divide every frame by the reference; the small epsilon guards against zeros.
rdata = np.asarray(data / (ref + 1e-5))
psi_init = rec_init(rdata, ri)
mshow_polar(psi_init, True)
mshow_polar(psi_init[:1000, :1000], True)

# Smooth the stitched object with a Gaussian low-pass filter in Fourier space.
v = cp.arange(-args.npsi // 2, args.npsi // 2) / args.npsi
vx, vy = cp.meshgrid(v, v)
v = cp.exp(-1000 * (vx**2 + vy**2)).astype("float32")

psi_init = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(psi_init)))  # centred spectrum
psi_init = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(psi_init * v)))
psi_init = psi_init.astype("complex64").get()  # filtered object, back on the host
mshow_polar(psi_init)
mshow_polar(psi_init[:1000, :1000])


# Reconstruction by the BH method

In [None]:
# Optimisation variables: initial object and probe, plus integer and sub-pixel
# positions. NOTE: the name `vars` shadows the Python builtin of the same name.
vars = {
    "psi": psi_init.copy(),
    "q": q_init.copy(),
    "ri": np.round(pos).astype("int32"),
}
vars["r_init"] = (pos - vars["ri"]).astype("float32")
# Start the sub-pixel shifts from perturbed values: uniform noise in
# [-0.5, 0.5) pixels (presumably to exercise position refinement).
vars["r"] = vars["r_init"] + (np.random.random([npos, 2]) - 0.5).astype("float32")
vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])

# Solver settings read by cl_rec.BH.
cl_rec.rho = [1, 2, 0.1]
cl_rec.niter = 257
cl_rec.vis_step = 32
cl_rec.err_step = 32
cl_rec.eps = 0
cl_rec.lam = 0
cl_rec.path_out = "/data/tmp/"
cl_rec.show = True

# Run the BH reconstruction; its result replaces `vars`.
vars = cl_rec.BH(data, ref, vars)