In [1]:
import numpy as np
import cupy as cp
import sys
import pandas as pd
import time
import matplotlib.pyplot as plt
import h5py
from types import SimpleNamespace
import warnings
# Silence warnings whose message contains "peer" (presumably multi-GPU peer-access notices — TODO confirm)
warnings.filterwarnings("ignore", message=f".*peer.*")

# Make the project modules in the parent directory (utils, rec) importable
sys.path.insert(0, '..')
# NOTE(review): star import; read_tiff, write_tiff, mshow, mshow_complex and mshow_polar used below
# are not defined in this notebook and presumably come from here — prefer explicit imports
from utils import *
from rec import Rec

# Initialize data sizes and parameters of the PXM at ID16A

In [2]:
n = 256  # object size in each dimension
ntheta = 128  # number of angles (rotations)
# projection angles in [0, pi]; endpoint included, so the first and last angles are 180 degrees apart
theta = np.linspace(0, np.pi, ntheta).astype('float32')  # projection angles

ndist = 4  # number of distances used (the z1 list below is truncated to this many entries)
detector_pixelsize = 3e-6  # detector pixel size; lengths in this cell are in metres
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wave length (h*c = 1.23984e-9 keV*m)

focusToDetectorDistance = 1.208  # [m]
sx0 = -2.493e-3  # offset subtracted from the z1 positions below (presumably a motor zero — TODO confirm)
z1 = np.array([1.5335e-3, 1.7065e-3, 2.3975e-3, 3.8320e-3])[:ndist]-sx0  # focus-to-sample distance per plane
z2 = focusToDetectorDistance-z1  # sample-to-detector distance per plane
# Effective propagation distance z1*z2/(z1+z2) (focusToDetectorDistance == z1 + z2)
distances = (z1*z2)/focusToDetectorDistance
magnifications = focusToDetectorDistance/z1  # geometric magnification per plane (decreasing as z1 grows)
# object voxel size; NOTE(review): the *2048/n/2 factor presumably rescales from a 2048-pixel detector
# to n pixels with an extra factor-2 binning — TODO confirm
voxelsize = detector_pixelsize/magnifications[0]*2048/n/2  # object voxel size

# Magnifications relative to the first plane (first entry is 1, later ones are smaller)
norm_magnifications = magnifications/magnifications[0]
# Propagation distances rescaled by the squared relative magnification (common first-plane sampling)
distances = distances*norm_magnifications**2
# Object grid large enough to hold the least magnified plane (n / smallest relative magnification)
npsi = int(np.ceil(n/norm_magnifications[-1]/8))*8  # make multiple of 8

In [3]:
# Parameters for the Rec operator class: sizes first, then the acquisition geometry.
args = SimpleNamespace(
    ngpus=1,
    n=n,
    ndist=ndist,
    ntheta=ntheta,
    pad=0,
    npsi=npsi,
)
args.nq = args.n + 2 * args.pad  # detector grid size including the (optional) padding on both sides
args.nchunk = 32  # processing chunk size (presumably angles per batch — TODO confirm)

args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distances
args.rotation_axis = args.npsi / 2  # rotation axis at the centre of the object grid

args.theta = theta
args.norm_magnifications = norm_magnifications

# Operator object used by every cell below
cl_rec = Rec(args)


## Read the projection data, reference images and shifts `r` (the unknown object is the refractive index $u = \delta + i\beta$)

In [None]:
# Load, for every distance, the projections and the reference image used to normalize them,
# plus the array `r` that later cells pass to the shift operators (S / ST / STa).
data_dir = '/data/vnikitin/syn'
data = np.zeros([ntheta, ndist, n, n], dtype='float32')
ref = np.zeros([ndist, n, n], dtype='float32')
# Arrays are sized by ndist, so iterate over ndist (not len(distances)) and read both files per distance.
for k in range(ndist):
    data[:, k] = read_tiff(f'{data_dir}/data/data_{k}.tiff')
    ref[k] = read_tiff(f'{data_dir}/ref_{k}.tiff')
r = np.load(f'{data_dir}/r.npy')

# Quick look: first projection at each distance divided by its reference image
for k in range(ndist):
    mshow(data[0, k] / ref[k], True)

In [None]:
# Map every distance onto the common npsi x npsi object grid and stitch them.
# Planes are processed from the last (least magnified) to the first: the last one
# gets its border replaced by a symmetric reflection; every other plane is blended
# over the previous result with a window that is 1 in the centre and 0 at the border,
# so the centre ends up coming from the more magnified planes.
rdata = data / ref
srdata = np.zeros([ntheta, ndist, args.npsi, args.npsi], dtype='float32')
distances_pag = distances / norm_magnifications**2  # not used in this cell
npad = n // 16  # width of the blending ramp in pixels

# Loop-invariant blending ramp: smooth polynomial rising from 0 towards 1 over npad samples.
ramp = np.linspace(0, 1, npad, endpoint=False)
ramp = ramp**5 * (126 - 420*ramp + 540*ramp**2 - 315*ramp**3 + 70*ramp**4)

for j in reversed(range(ndist)):
    # Undo the shifts (STa, 'edge' mode — presumably edge-replicating boundaries, TODO confirm),
    # then apply MT (adjoint of M) scaled by the squared relative magnification.
    plane = cl_rec.STa(r[:, j] * norm_magnifications[j], rdata[:, j].astype('complex64'),
                       'edge')
    mshow_complex(plane[0], True)
    plane = (cl_rec.MT(plane, j) / norm_magnifications[j]**2).real
    mshow(plane[0], True)

    # First row of projection 0 with signal above 0.1, plus a 4-pixel safety margin.
    rows = np.nonzero(plane[0] > 1e-1)[0]
    if rows.size == 0:
        raise ValueError(f'distance {j}: no pixel above 0.1, cannot locate the valid region')
    st = rows[0] + 4

    if j == ndist - 1:
        # Base layer: replace the st-wide border by a symmetric reflection of the interior.
        plane = np.pad(plane[:, st:-st, st:-st], ((0, 0), (st, st), (st, st)), 'symmetric')
    else:
        # Separable window: 0 on the border, smooth ramp, 1 in the centre.
        w = np.ones([args.npsi], dtype='float32')
        w[:st] = 0
        w[st:st + npad] = ramp
        w[-st - npad:-st] = 1 - ramp
        w[-st:] = 0
        w = np.outer(w, w)
        plane = plane * w + srdata[:, j + 1] * (1 - w)
    srdata[:, j] = plane
    mshow(srdata[0, j], True)

In [None]:
def multiPaganin(data, distances, wavelength, voxelsize, delta_beta,  alpha):
    """Multi-distance Paganin phase retrieval for a single projection.

    Each image is filtered with the Paganin factor
    ``T_j(f) = 1 + pi * wavelength * distances[j] * delta_beta * |f|^2`` and the
    images are combined as

        phase = delta_beta / 2 * log(IFFT[mean_j(T_j * FFT(I_j)) / (mean_j(T_j**2) + alpha)])

    Parameters
    ----------
    data : array of shape (nimages, N, N), CuPy or NumPy
        Normalized intensities, one image per distance. Frequencies are built
        from the last axis only, so the images are assumed to be square.
    distances : sequence of float
        Propagation distance per image; needs at least ``data.shape[0]`` entries.
    wavelength, voxelsize : float
        Wavelength and pixel size, in the same length unit as ``distances``.
    delta_beta : float
        delta/beta ratio of the material.
    alpha : float
        Regularization added to the averaged denominator.

    Returns
    -------
    Real (N, N) array (same array library as ``data``) with the retrieved phase.
    Pixels where the filtered intensity is not positive become NaN (log).

    Raises
    ------
    ValueError
        If ``data`` contains no images.
    """
    nimages = data.shape[0]
    if nimages == 0:
        raise ValueError("multiPaganin needs at least one image")

    # Use NumPy for NumPy input; everything else goes through CuPy as before.
    xp = np if isinstance(data, np.ndarray) else cp

    fx = xp.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    fx, fy = xp.meshgrid(fx, fx)
    freq2 = fx**2 + fy**2  # |f|^2, shared by all distances

    numerator = 0
    denominator = 0
    for j in range(nimages):
        rad_freq = xp.fft.fft2(data[j])
        taylorExp = 1 + wavelength * distances[j] * xp.pi * delta_beta * freq2
        numerator = numerator + taylorExp * rad_freq
        denominator = denominator + taylorExp**2

    # Average over the images actually summed (len(distances) may be larger),
    # so the weighting against alpha is correct.
    numerator = numerator / nimages
    denominator = denominator / nimages + alpha

    phase = xp.log(xp.real(xp.fft.ifft2(numerator / denominator)))
    return delta_beta * 0.5 * phase


def rec_init(rdata, delta_beta=120, alpha=1e-3, nbg=50):
    """Initial object guess exp(1j * phase) from multi-distance Paganin retrieval.

    Parameters
    ----------
    rdata : array indexed as rdata[itheta] -> (ndist, args.npsi, args.npsi) presumably
        Normalized data on the common object grid (``srdata`` in this notebook).
    delta_beta, alpha : float
        Passed to ``multiPaganin`` (delta/beta ratio and regularization).
    nbg : int
        Number of leading columns (last axis) whose mean phase, over all
        projections, is subtracted as background.

    Returns
    -------
    complex ndarray of shape (args.ntheta, args.npsi, args.npsi), unit modulus.

    Reads the notebook globals ``args``, ``distances``, ``norm_magnifications``,
    ``wavelength`` and ``voxelsize``.
    """
    # Loop-invariant: divide out the norm_magnifications**2 factor applied to `distances`.
    distances_pag = distances / norm_magnifications**2
    phase = np.zeros([args.ntheta, args.npsi, args.npsi], dtype="float32")
    for itheta in range(args.ntheta):
        proj = cp.array(rdata[itheta])
        phase[itheta] = multiPaganin(proj, distances_pag, wavelength, voxelsize,
                                     delta_beta, alpha).get()

    phase -= np.mean(phase[:, :, :nbg])
    return np.exp(1j * phase)
# Paganin initial guess for all projections
psi_init = rec_init(srdata)
# Mean phase over the first 50 columns; ~0 after rec_init's background subtraction (unless phases wrap)
print(np.mean(np.angle(psi_init[:,:,:50])))
mshow_polar(psi_init[0],True)   # first projection
mshow_polar(psi_init[-1],True)  # last projection
# NOTE(review): scratch output location; file naming/extension is handled by write_tiff in utils
write_tiff(np.angle(psi_init),'/data/tmp/tmp5')
# done

In [None]:
# Initial 3D object: load the cached tomographic reconstruction if present,
# otherwise reconstruct it from the Paganin phases and cache it, so the
# notebook also runs on a fresh machine / kernel.
u_init_path = '/data/vnikitin/syn/u_init.npy'
try:
    u_init = np.load(u_init_path)
except FileNotFoundError:
    psi_data = np.log(psi_init) / 1j  # psi_init = exp(1j*phase) -> phase (complex dtype)
    u_init = cl_rec.rec_tomo(psi_data, 32)  # 32: second argument of rec_tomo (presumably iterations — TODO confirm)
    np.save(u_init_path, u_init)
mshow_complex(u_init[u_init.shape[0] // 2], True)  # central slice along the first axis

In [None]:
# Initial q per distance: DT (adjoint of D, presumably the Fresnel propagator — TODO confirm)
# applied to the amplitude sqrt(ref) of each reference image.
# Sized by ndist rather than a hardcoded 4 so changing ndist stays consistent.
q_init = np.ones([ndist, args.nq, args.nq], dtype='complex64')
for j in range(ndist):
    q_init[j] = cl_rec.DT(np.sqrt(ref[j:j+1]), j)[0]

mshow_polar(q_init[0], True)
mshow_polar(q_init[-1], True)


In [None]:
def check_adjoint(forward, adjoint, a):
    """Print <a, A^H A a> and <A a, A a>; they agree (up to rounding) when `adjoint` is the adjoint of `forward`."""
    b = forward(a)
    c = adjoint(b)
    print(np.sum(a * np.conj(c)))
    print(np.sum(b * np.conj(b)))


def random_complex(shape):
    """complex64 array with real and imaginary parts uniform in [0, 1)."""
    return (np.random.random(shape).astype('float32')
            + 1j * np.random.random(shape).astype('float32'))


# Explicit distance index for the propagator check (the last distance); the
# original cell silently relied on the loop variable `j` left over from the q_init cell.
jdist = ndist - 1

a = random_complex([ntheta, npsi, npsi])
check_adjoint(lambda x: cl_rec.M(x, 2), lambda y: cl_rec.MT(y, 2), a)
check_adjoint(lambda x: cl_rec.S(r[:, 2], x), lambda y: cl_rec.ST(r[:, 2], y), a)

check_adjoint(lambda x: cl_rec.D(x, jdist), lambda y: cl_rec.DT(y, jdist),
              random_complex([ntheta, args.nq, args.nq]))

check_adjoint(cl_rec.R, cl_rec.RT, random_complex([npsi, npsi, npsi]))




In [10]:
np.random.seed(10)  # fixed seed: the perturbed shifts are reproducible
# Loaded shifts `r` plus uniform noise in [-0.5, 0.5), in float32
rerr = r + (np.random.random(size=r.shape).astype(np.float32) - 0.5)

In [None]:
# Solver parameters
args.niter = 257
args.vis_step = 32
args.err_step = 4
args.lam = 0
args.path_out = '/data/vnikitin/syn'
args.show = True
args.rho = [1, 25, 15]  # NOTE(review): per-variable weights, meaning defined in Rec — TODO confirm
cl_rec = Rec(args)

# Optimization variables on the GPU. cp.array copies host data to the device,
# so neither .copy() nor a second cp.array(...) is needed.
# NOTE(review): the name `vars` shadows the built-in vars(); kept for compatibility with later cells.
vars = {
    "u": cp.array(u_init),
    "q": cp.array(q_init),
    "r": cp.array(rerr),
    "r_init": cp.array(r),
}

vars["psi"] = cl_rec.R(vars['u'])
vars["psi"][:] = cl_rec.expR(vars["psi"])
vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])
vars = cl_rec.BH(cp.array(np.sqrt(data)), vars)