In [None]:
import numpy as np
import cupy as cp
import cv2
import xraylib
import matplotlib.pyplot as plt
from utils import write_tiff, read_tiff
from utils import mshow, mshow_polar, mshow_complex
# Make CUDA device 1 the current CuPy device; subsequent cp allocations and
# kernels in this notebook run there.
# NOTE(review): hard-coded device index — change it on machines with a single GPU.
cp.cuda.Device(1).use()


In [None]:
# --- geometry (sizes in pixels) ---
n = 256                    # detector (data) size per dimension
nobj = 1024                # object size per dimension
pad = 0                    # extra probe padding beyond the detector size (was n//16)
nprb = n + 2 * pad         # probe size
extra = 8                  # margin around each patch that absorbs subpixel shifts
npatch = nprb + 2 * extra  # size of the extracted patches before cropping

# --- physics ---
energy = 9                      # [keV]; the wavelength formula below assumes keV
wavelength = 1.24e-09 / energy  # [m]
voxelsize = 8e-09               # [m] object pixel size

# --- scan ---
npos = 961  # total number of scan positions

# --- runtime options ---
show = True  # master switch for all visualization
path = '/data/vnikitin/paper/far_field'  # data root; generated files are written under {path}/data

In [None]:

def Lop(psi):
    """Forward far-field propagator.

    Applies an orthonormal 2D FFT over the last two axes of the stack `psi`
    (items along axis 0), then crops `pad` pixels from every border so the
    result has the detector size.
    """
    spectrum = cp.fft.fft2(psi, norm='ortho')
    lo, hi = pad, nprb - pad
    return spectrum[:, lo:hi, lo:hi]

def LTop(psi):
    """Adjoint of `Lop`.

    Zero-pads `pad` pixels on each side of the last two axes (undoing the crop
    done by `Lop`), then applies the orthonormal inverse 2D FFT.
    """
    border = ((0, 0), (pad, pad), (pad, pad))
    return cp.fft.ifft2(cp.pad(psi, border), norm='ortho')

def Ex(psi, ix):
    """Extract square patches of size `npatch` from the object.

    Parameters
    ----------
    psi : 2D array, assumed (nobj, nobj)
        Object to sample.
    ix : integer array of shape (N, 2)
        Integer shifts; column 0 moves the patch along axis 0 (rows) and
        column 1 along axis 1 (columns). A zero shift yields a patch centred
        on the object centre.

    Returns
    -------
    complex64 array of shape (N, npatch, npatch)
    """
    res = cp.empty([ix.shape[0], npatch, npatch], dtype='complex64')
    # Compute every corner once and move it to the host in a single transfer.
    # Indexing a device array element-by-element inside the loop would force a
    # separate device->host sync for each slice bound.
    sty = (nobj // 2 - ix[:, 0] - npatch // 2).tolist()
    stx = (nobj // 2 - ix[:, 1] - npatch // 2).tolist()
    for k, (y0, x0) in enumerate(zip(sty, stx)):
        res[k] = psi[y0:y0 + npatch, x0:x0 + npatch]
    return res

def ExT(psi, psir, ix):
    """Adjoint of `Ex`: accumulate patches back into the object.

    Adds each patch `psir[k]` into `psi` at the location `Ex` would have
    extracted it from; overlapping patches are summed. `psi` is modified in
    place and also returned.

    Parameters
    ----------
    psi : 2D array, assumed (nobj, nobj)
        Accumulator, updated in place.
    psir : array of shape (N, npatch, npatch)
        Patches to scatter back.
    ix : integer array of shape (N, 2)
        Same integer shifts that were passed to `Ex`.
    """
    # One host transfer per coordinate instead of per-element syncs in the loop.
    sty = (nobj // 2 - ix[:, 0] - npatch // 2).tolist()
    stx = (nobj // 2 - ix[:, 1] - npatch // 2).tolist()
    for k, (y0, x0) in enumerate(zip(sty, stx)):
        psi[y0:y0 + npatch, x0:x0 + npatch] += psir[k]
    return psi

def S(psi, p):
    """Subpixel shift of a stack of 2D arrays via the Fourier shift theorem.

    Parameters
    ----------
    psi : array of shape (N, H, W)
        Stack to shift (boundaries are periodic).
    p : float array of shape (N, 2)
        Shift per item: p[:, 0] is applied along axis -2 (rows) and p[:, 1]
        along axis -1 (columns). Positive values move content towards higher
        indices.

    Returns
    -------
    Complex array of shape (N, H, W).
    """
    # Frequency grids come from psi itself (not the global `npatch`), so stacks
    # of any size, including non-square ones, are handled.
    fy = cp.fft.fftfreq(psi.shape[-2]).astype('float32')[:, None]
    fx = cp.fft.fftfreq(psi.shape[-1]).astype('float32')[None, :]
    phase = fy * p[:, 0, None, None] + fx * p[:, 1, None, None]
    pp = cp.exp(-2 * cp.pi * 1j * phase).astype('complex64')
    return cp.fft.ifft2(pp * cp.fft.fft2(psi))

def Sop(psi, ix, x, ex):
    """Extract patches with subpixel shifts (forward operator).

    Each `npatch`-sized patch is cut from `psi` at integer shift `ix` (see
    `Ex`), shifted by the fractional offsets `x` (see `S`), and then `ex`
    pixels are cropped from every border.

    NOTE(review): the output has size `nprb` only when `ex` equals the global
    `extra`; `STop` relies on that too.
    """
    patches = S(Ex(psi, ix), x)
    return patches[:, ex:npatch - ex, ex:npatch - ex]

def STop(d, ix, x, ex):
    """Adjoint of `Sop`.

    Zero-pads every item of `d` by `ex` pixels per side, undoes the subpixel
    shift (shift by `-x`), and accumulates the patches into a fresh
    (nobj, nobj) complex64 object through `ExT`.
    """
    obj = cp.zeros([nobj, nobj], dtype='complex64')
    margin = ((0, 0), (ex, ex), (ex, ex))
    shifted = S(cp.pad(d, margin), -x)
    ExT(obj, shifted, ix)  # accumulates into obj in place
    return obj

# Adjoint tests: for an operator A with adjoint A^T, <x, A^T A x> must equal <A x, A x>.

def _random_complex(shape):
    """Random complex64 array whose real and imaginary parts are uniform in [0, 1)."""
    return (cp.random.random(shape)+1j*cp.random.random(shape)).astype('complex64')

def _report_adjoint(label, x, ax, atax):
    """Print both sides of the adjoint identity and their relative difference."""
    lhs = cp.sum(x*cp.conj(atax))
    rhs = cp.sum(ax*cp.conj(ax))
    rel = float(cp.abs(lhs-rhs)/cp.abs(rhs))
    print(f'{label}: {lhs}==\n{rhs}\nrelative difference: {rel:.2e}')

shifts_test = 30*(cp.random.random([npos,2])-0.5).astype('float32')
ishifts = shifts_test.astype('int32')
fshifts = shifts_test-ishifts

# Ex / ExT
arr1 = _random_complex([nobj,nobj])
arr2 = Ex(arr1,ishifts)
arr3 = ExT(arr1*0,arr2,ishifts)
_report_adjoint('Ex', arr1, arr2, arr3)

# Sop / STop
arr1 = _random_complex([nobj,nobj])
arr2 = Sop(arr1,ishifts,fshifts,extra)
arr3 = STop(arr2,ishifts,fshifts,extra)
_report_adjoint('Sop', arr1, arr2, arr3)

# Lop / LTop
arr1 = _random_complex([npos,nprb,nprb])
arr2 = Lop(arr1)
arr3 = LTop(arr2)
_report_adjoint('Lop', arr1, arr2, arr3)

# Generate object

In [None]:
import random

# --- Siemens star: 24 thin triangles rotated in 15-degree steps ---
# Vertices are (x, y) pixel coordinates as expected by cv2.fillPoly.
triangle = np.array([(nobj//16, nobj//2-nobj//32), (nobj//16, nobj//2+nobj//32), (nobj//2-nobj//64, nobj//2)], np.float32)
star = np.zeros((nobj, nobj), np.uint8)
for degree in range(0, 360, 15):
    img = np.zeros((nobj, nobj, 3), np.uint8)
    theta = degree * np.pi / 180
    rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta), np.cos(theta)]], np.float32)
    rotated = cv2.gemm(triangle-nobj//2, rot_mat, 1, None, 1, flags=cv2.GEMM_2_T)+nobj//2
    cv2.fillPoly(img, [np.int32(rotated)], (255, 0, 0))
    # Union of the triangles. A uint8 `+=` would wrap around (255+255 -> 254)
    # wherever two rasterized triangles share a pixel.
    np.maximum(star, img[:, :, 0], out=star)

star = cp.array(star)

# --- thin circular gaps cut into the triangles, clipped to a disk ---
[x,y] = cp.meshgrid(cp.arange(-nobj//2,nobj//2),cp.arange(-nobj//2,nobj//2))
x = x/nobj*2
y = y/nobj*2
r2 = x**2+y**2
circ = (r2 > 0.145) | (r2 < 0.135)
circ &= (r2 > 0.053) | (r2 < 0.05)
circ &= (r2 > 0.0085) | (r2 < 0.008)
circ &= (r2 > 0.32) | (r2 < 0.3)
circ &= r2 < 0.65**2

# --- smooth random background: Gaussian low-pass of uniform noise (all on GPU) ---
bg = cp.random.random(star.shape)-0.5
v = cp.arange(-nobj//2,nobj//2)/nobj
[vx,vy] = cp.meshgrid(v,v)
v = cp.exp(-4000*(vx**2+vy**2))
bg = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(bg)))
bg = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(bg*v))).real

star = star/255
star *= circ

# --- up to `nrect` random rectangles of value 0.5 lying fully inside the star ---
nrect = 400
max_size = 16
min_size = 3
irect = 0
for _ in range(10000):
    # randint is inclusive on both ends. Sampling the corner from [0, nobj-1]
    # guarantees at least one pixel after clipping; a corner at nobj would give
    # an empty slice that passes the check below and still counts.
    x = random.randint(0, nobj-1)
    y = random.randint(0, nobj-1)
    width = min(random.randint(min_size, max_size), nobj-x)
    height = min(random.randint(min_size, max_size), nobj-y)
    if bool((star[y:y+height,x:x+width] > 0).all()):
        star[y:y+height,x:x+width] = 0.5
        irect += 1
    if irect == nrect:
        break

bg -= cp.min(bg)
star += bg*30

# --- final smoothing with a broad Gaussian in Fourier space ---
v = cp.arange(-nobj//2,nobj//2)/nobj
[vx,vy] = cp.meshgrid(v,v)
v = cp.exp(-10*(vx**2+vy**2))
fu = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(star)))
star = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(fu*v))).real

# Gold refractive index decrement and absorption index at `energy`
# (19.3 is passed as the density, presumably g/cm^3 — confirm against xraylib).
delta = 1-xraylib.Refractive_Index_Re('Au',energy,19.3)
beta = xraylib.Refractive_Index_Im('Au',energy,19.3)

thickness = 1e-6/voxelsize/2 # siemens star thickness in pixels
# Transmittance function exp(i * 2*pi/lambda * (-delta + i*beta) * thickness)
u = star*(-delta+1j*beta) # note -delta
Ru = u*thickness
psi = cp.exp(1j * Ru * voxelsize * 2 * cp.pi / wavelength).astype('complex64')

mshow_polar(psi,show)

# Read the probe (loaded from probe.npy; no tapering is applied)

In [None]:
# NOTE(review): assumes probe.npy stores the probe behind two leading axes,
# of which the first entry of each is used — confirm against the file producer.
prb = cp.load('probe.npy')[0,0]

# The second argument is the visualization flag; honour the global `show` toggle.
mshow_polar(prb,show)

# Read probe shifts

In [None]:
# NOTE(review): 1.9 is an unexplained scaling of the stored scan positions.
shifts = cp.load('shifts.npy')*1.9

# Scatter plot of the scan positions (copied to the host once for matplotlib).
shifts_host = shifts.get()
fig, ax = plt.subplots()
ax.plot(shifts_host[:, 0], shifts_host[:, 1], '.')
ax.axis('square')
ax.grid()
plt.show()


# Generate far-field data and the reference probe intensity

In [None]:

# Split scan positions into integer and fractional parts for Ex and S.
ishifts = cp.floor(shifts).astype('int32')
fshifts = (shifts-ishifts).astype('float32')
psi = cp.array(psi)
prb = cp.array(prb)

# Far-field intensities for every scan position.
data = cp.abs(Lop(prb*Sop(psi,ishifts,fshifts,extra)))**2

# Reference: probe intensity without the object. Every position gives the same
# result, so propagate the probe once instead of building `npos` unit patches
# and keeping only the first.
ones = cp.ones((1, nprb, nprb), dtype='complex64')
ref = cp.abs(Lop(prb*ones))[0]**2

rdata = data/(ref+1e-11)
mshow(cp.fft.fftshift(data[0],axes=(-2,-1)),show)
mshow(cp.fft.fftshift(ref,axes=(-2,-1)),show)

In [None]:

from scipy.ndimage import geometric_transform


def deformation(coords):
    """Map an output pixel (row, col) to the input position it samples from.

    Adds a small sinusoidal displacement: amplitude 0.1 px, period 100 px.
    """
    row, col = coords
    return (row + 0.1 * np.sin(2 * np.pi * col / 100),
            col + 0.1 * np.cos(2 * np.pi * row / 100))


# geometric_transform is a SciPy (CPU) routine, so pass it the host copy of the probe.
deformed_prb = geometric_transform(prb.get(), deformation, order=1, mode='reflect')
print(np.amin(np.abs(prb)))
print(np.amin(np.abs(deformed_prb)))

# Save data

In [None]:
import os

# np.save does not create missing directories; make sure the target exists first.
os.makedirs(f'{path}/data', exist_ok=True)
np.save(f'{path}/data/data',data.get())
np.save(f'{path}/data/ref',ref.get())
np.save(f'{path}/data/gen_prb',prb.get())
np.save(f'{path}/data/deformed_prb',deformed_prb)  # already a host (NumPy) array
np.save(f'{path}/data/gen_shifts',shifts.get())
np.save(f'{path}/data/psi',psi.get())


# Add very low noise to avoid numerical precision errors (this overwrites the clean data saved above)

In [None]:
std_dev = 0.00001  # standard deviation of the noise

# Tiny Gaussian perturbation, clipped from below at the clean-data minimum so
# no intensity drops under it.
noise = cp.random.normal(0, std_dev, size=data.shape).astype('float32')
ndata = cp.maximum(data + noise, data.min())
# Same file name as the clean data saved earlier, so that file is replaced.
np.save(f'{path}/data/data',ndata.get())
mshow(data[0]-ndata[0],show)


# Add Gaussian noise

In [None]:
std_dev = 0.01  # standard deviation of the noise

# Gaussian noise on the clean data, never going below the clean-data minimum.
noise = cp.random.normal(0, std_dev, size=data.shape).astype('float32')
ndata = cp.maximum(data + noise, data.min())
np.save(f'{path}/data/ndata',ndata.get())
mshow(ndata[0],show)
mshow(data[0]-ndata[0],show)

# Generate shift errors

In [None]:
# Perturb every scan position by a uniform random error in [-2, 2) pixels.
# The error array follows the shape of the loaded `shifts` rather than the
# global `npos`, so this still works if the stored scan has a different length.
shifts_random = shifts.get()+4*(np.random.random(shifts.shape)-0.5)
shifts_random = shifts_random.astype('float32')
np.save(f'{path}/data/gen_shifts_random',shifts_random)

