In [None]:
import numpy as np
import cupy as cp
import h5py
from holotomocupy.holo import G, GT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import multiPaganin
from holotomocupy.utils import *
from holotomocupy.proc import remove_outliers
import cv2
import xraylib
##!jupyter nbconvert --to script config_template.ipynb

# Initialize data sizes and parameters of the PXM at ID16A

In [None]:
import os

n = 512  # detector size [pixels] used in the simulation
ne = n+n//4  # extended field size; (ne-n)//2 extra pixels on each side
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length (hc = 1.2398e-9 keV*m)
focusToDetectorDistance = 1.28  # [m]
ndist = 16  # number of propagation distances
# distances = np.array([0.0029432,0.00306911,0.00357247,0.00461673])[:ndist] # [m]
distances = np.linspace(3e-3, 5e-3, ndist)  # [m] object propagation distances
magnification = 400
detector_pixelsize = 3.03751e-6  # [m]
detector_native_size = 2048  # [pixels] assumed full detector width; binning to n enlarges the voxel
voxelsize = detector_pixelsize/magnification*detector_native_size/n  # [m] object voxel size

# distances used by L1op for the probe: distances2[i] = distances[-1] - distances[i]
distances2 = distances[-1]-distances
# output directory; set HOLO_MODELING_DIR to write somewhere else
path = os.environ.get('HOLO_MODELING_DIR', f'/data/vnikitin/modeling/siemens{n}')
show = True  # passed to the mshow* display helpers
print(distances+distances2)  # sanity check: every entry equals distances[-1]

## Modeling data

In [None]:
# Siemens star: 24 thin triangular spokes, one every 15 degrees, rasterised with OpenCV.
img = np.zeros((n, n, 3), np.uint8)
# spoke template given as (x, y) points: base near the left edge, apex near the image centre
triangle = np.array([(n//32, n//2-n//32), (n//32, n//2+n//32), (n//2-n//128, n//2)], np.float32)
star = img[:,:,0]*0
for i in range(0, 360, 15):
    img = np.zeros((n, n, 3), np.uint8)
    degree = i
    theta = degree * np.pi / 180
    rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta), np.cos(theta)]], np.float32)
    # rotate the template about (n//2, n//2); GEMM_2_T multiplies by rot_mat transposed
    rotated = cv2.gemm(triangle-n//2, rot_mat, 1, None, 1, flags=cv2.GEMM_2_T)+n//2
    cv2.fillPoly(img, [np.int32(rotated)], (255, 0, 0))
    # Union of spokes. Summing into the uint8 `star` would wrap around
    # (255 + 255 -> 254) wherever spokes overlap near the centre.
    star = np.maximum(star, img[:,:,0])
[x,y] = np.meshgrid(np.arange(-n//2,n//2),np.arange(-n//2,n//2))
x = x/n*2
y = y/n*2
# cut thin ring-shaped gaps through the spokes (circ is False on each annulus)
circ = (x**2+y**2>0.145)+(x**2+y**2<0.135)
circ *= (x**2+y**2>0.053)+(x**2+y**2<0.05)
circ *= (x**2+y**2>0.0085)+(x**2+y**2<0.008)
circ *= (x**2+y**2>0.82)+(x**2+y**2<0.8)

star = star*circ/255  # {0, 1} pattern with ring gaps
star = np.pad(star,((ne//2-n//2,ne//2-n//2),(ne//2-n//2,ne//2-n//2)),'edge')

# smooth edges with a Gaussian low-pass filter applied in Fourier space
v = np.arange(-ne//2,ne//2)/ne
[vx,vy] = np.meshgrid(v,v)
v = np.exp(-5*(vx**2+vy**2))
fu = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(star)))
star = np.fft.fftshift(np.fft.ifftn(np.fft.fftshift(fu*v))).real

# complex refractive index decrement of gold (density 19.3 g/cm^3) at the beam energy
delta = 1-xraylib.Refractive_Index_Re('Au',energy,19.3)
beta = xraylib.Refractive_Index_Im('Au',energy,19.3)

thicknss = 2e-6/voxelsize # siemens star thickness (2e-6 m) expressed in pixels
# form the transmittance function exp(i*k*t*(-delta + i*beta)), k = 2*pi/wavelength
u = star*(-delta+1j*beta) # note -delta
Ru = u*thicknss
psi = np.exp(1j * Ru * voxelsize * 2 * np.pi / wavelength)[np.newaxis].astype('complex64')
fig, axs = plt.subplots(1, 2, figsize=(9, 4))
im=axs[0].imshow(np.abs(psi[0]),cmap='gray')
axs[0].set_title('amplitude')
fig.colorbar(im, ax=axs[0])
im=axs[1].imshow(np.angle(psi[0]),cmap='gray')
axs[1].set_title('phase')
fig.colorbar(im, ax=axs[1])


from scipy.io import savemat
# save the central n x n crop of the ne x ne object
# (a start index of n//2-n//2 == 0 would take the top-left corner instead)
c0 = ne//2-n//2
savemat('data.mat',{'psi': psi[0,c0:c0+n,c0:c0+n]})

In [None]:
psi = cp.array(psi)
# Separable edge taper over the (ne-n)//2 extra pixels on each side:
# sin ramp 0 -> 1 on the left/top, cos ramp 1 -> 0 on the right/bottom, ones in between.
fs = (ne-n)//2
v = cp.ones(ne, dtype='float32')
v[:fs] = cp.sin(cp.linspace(0, 1, fs)*cp.pi/2)
# `ne-fs` rather than `-(ne-n)//2`: the latter floors to -(fs+1) when ne-n is odd
v[ne-fs:] = cp.cos(cp.linspace(0, 1, fs)*cp.pi/2)
v = cp.outer(v, v)
mshow(v, show)
mshow_polar(v*psi[0], show)


# Construct operators


In [None]:
def L1op(psi):
    """Propagate `psi` (used for the probe) to every plane in `distances2`.

    For each index i < ndist, a fresh copy of `psi` is passed to
    G(..., distances2[i], 'symmetric') and the result is stored in slot i.

    Returns:
        complex64 cupy array of shape [1, ndist, ne, ne].
    """
    stack = cp.zeros([1, ndist, ne, ne], dtype='complex64')
    for plane_idx in range(ndist):
        # a fresh copy per plane; presumably guards against G modifying its input (unverified)
        stack[:, plane_idx] = G(cp.array(psi), wavelength, voxelsize,
                                distances2[plane_idx], 'symmetric')
    return stack

def L1Top(data):
    """Back-propagate every plane of `data` with GT and sum the results.

    Presumably the adjoint of L1op: plane j (j < ndist) is copied, passed to
    GT(..., distances2[j], 'symmetric'), and accumulated in place.

    Returns:
        complex64 cupy array of shape [1, ne, ne].
    """
    total = cp.zeros([1, ne, ne], dtype='complex64')
    for plane_idx in range(ndist):
        plane = cp.array(data[:, plane_idx])
        total += GT(plane, wavelength, voxelsize, distances2[plane_idx], 'symmetric')
    return total

def Lop(psi,prb):
    """Simulate detector-plane wavefields for all ndist distances.

    For each distance i, the following steps are applied in order:
      1. The probe propagated by L1op (to distances2[i]) multiplies `psi`.
      2. The product is multiplied by a separable sin/cos edge taper.
      3. The result is propagated by G over distances[i] ('constant' mode).
      4. The central n x n region is kept.

    Args:
        psi: object transmittance; assumed cupy complex array of shape [1, ne, ne] — TODO confirm.
        prb: probe, passed to L1op.

    Returns:
        complex64 cupy array of shape [1, ndist, n, n].
    """
    data = cp.zeros([1,ndist, n, n], dtype='complex64')
    Lprb = L1op(prb)

    # The taper does not depend on the distance, so build it once outside the loop.
    fs = (ne-n)//2
    v = cp.ones(ne,dtype='float32')
    v[:fs] = cp.sin(cp.linspace(0,1,fs)*cp.pi/2)
    # `ne-fs` rather than `-(ne-n)//2`, which floors to -(fs+1) when ne-n is odd
    v[ne-fs:] = cp.cos(cp.linspace(0,1,fs)*cp.pi/2)
    v = cp.outer(v,v)

    for i in range(ndist):
        psir = psi*Lprb[:,i]
        psir *= v
        psir = G(psir, wavelength, voxelsize, distances[i],'constant')
        data[:, i] = psir[:,ne//2-n//2:ne//2+n//2,ne//2-n//2:ne//2+n//2]
    return data

def LTop(data,prb):
    """Intended adjoint of Lop: each forward step is mirrored in reverse order.

    For each distance j, the following steps are applied in order:
      1. Zero-pad data[:, j] from n x n to ne x ne (exact transpose of Lop's central crop).
      2. Back-propagate with GT over distances[j] ('constant' mode).
      3. Apply the sin/cos edge taper.
      4. Multiply by the conjugate of L1op(prb)[:, j].
    The results are summed over j.

    Args:
        data: assumed cupy array of shape [1, ndist, n, n] — TODO confirm.
        prb: probe, passed to L1op.

    Returns:
        complex64 cupy array of shape [1, ne, ne].
    """
    psi = cp.zeros([1, ne, ne], dtype='complex64')
    Lprb = L1op(prb)

    # Lop keeps rows/cols [ne//2-n//2, ne//2-n//2+n); pad back to exactly ne.
    pad_before = ne//2-n//2
    pad_after = ne-n-pad_before
    pad_width = ((0,0),(pad_before,pad_after),(pad_before,pad_after))

    # The taper does not depend on the distance, so build it once outside the loop.
    fs = (ne-n)//2
    v = cp.ones(ne,dtype='float32')
    v[:fs] = cp.sin(cp.linspace(0,1,fs)*cp.pi/2)
    v[ne-fs:] = cp.cos(cp.linspace(0,1,fs)*cp.pi/2)
    v = cp.outer(v,v)

    for j in range(ndist):
        # cp.pad already returns a new cupy array; no extra cp.array copy is needed
        datar = cp.pad(data[:, j],pad_width).astype('complex64')
        datar = GT(datar, wavelength, voxelsize, distances[j],'constant')
        datar *= v
        psi += datar*cp.conj(Lprb[:,j])

    return psi
# def Cop(psi):
#     return psi[:,ne//2-n//2:ne//2+n//2,ne//2-n//2:ne//2+n//2]

# def CTop(psi):
#     return cp.pad(psi,((0,0),(ne//2-n//2,ne//2-n//2),(ne//2-n//2,ne//2-n//2)))

# # adjoint tests
# tmp = cp.array(data).copy().astype('complex64')
# arr1 = cp.pad(tmp[:,0],((0,0),(ne//2-n//2,ne//2-n//2),(ne//2-n//2,ne//2-n//2)),'symmetric')     

# arr2 = Lop(arr1)
# arr3 = LTop(arr2)


# print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')


In [None]:
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_abs_2048.tiff -P ../data/prb_id16a
!wget -nc https://g-110014.fd635.8443.data.globus.org/holotomocupy/examples_synthetic/data/prb_id16a/prb_phase_2048.tiff -P ../data/prb_id16a

prb_abs = read_tiff(f'../data/prb_id16a/prb_abs_2048.tiff')[0:1]
prb_phase = read_tiff(f'../data/prb_id16a/prb_phase_2048.tiff')[0:1]
prb = prb_abs*np.exp(1j*prb_phase).astype('complex64')


for k in range(1):
    prb = prb[:, ::2]+prb[:, 1::2]
    prb = prb[:, :, ::2]+prb[:, :, 1::2]/4

prb = prb[:, 512-ne//2:512+ne//2, 512-ne//2:512+ne//2]
prb /= np.mean(np.abs(prb))
prb=cp.array(prb)
mshow_polar(prb[0],show)

In [None]:
psi = cp.array(psi)
# Noise-free intensities: with the sample, and with an all-ones (empty) object as reference.
data = cp.abs(Lop(psi, prb))**2
ref = cp.abs(Lop(cp.ones_like(psi), prb))**2

# Show every distance alongside its difference from the first distance.
frames = data[0]
for frame in frames:
    mshow_complex(frame + 1j*(frame - frames[0]), show)



In [None]:
# Poisson (photon-counting) noise: scale intensities by photon_scale, sample, then rescale.
photon_scale = 200  # expected counts per unit of simulated intensity
noise_seed = 0
cp.random.seed(noise_seed)  # fixed seed so the saved noisy dataset is reproducible
ndata = (cp.random.poisson(data*photon_scale)/photon_scale).astype('float32')
nref = (cp.random.poisson(ref*photon_scale)/photon_scale).astype('float32')
mshow_complex(ndata[0,0]-data[0,0]+1j*ndata[0,0],show)

In [None]:
import os

# Create the output directory in Python rather than shelling out to `mkdir -p`:
# this avoids interpolating `path` into a shell command, and it raises on failure
# instead of silently returning a non-zero status.
os.makedirs(path, exist_ok=True)
# .get() copies the cupy arrays to host numpy arrays before saving
np.save(f'{path}/data',ndata.get())
np.save(f'{path}/psi',psi.get())
np.save(f'{path}/prb',prb.get())
np.save(f'{path}/ref',nref.get())
