In [None]:
import numpy as np
import cupy as cp
import cv2
import xraylib
from holotomocupy.utils import *

In [None]:
# --- grid sizes ---------------------------------------------------------
ne = 768     # side length of the square object array (the phantom is ne x ne)
n = 256      # side length of one extracted patch / diffraction pattern
npos = 500   # number of scan positions read from the positions file

# --- notebook options ---------------------------------------------------
show = True  # flag forwarded to the mshow / mshow_polar helpers
path = '/data/vnikitin/paper/far_field'  # root folder holding data/positions_px.npy and data/probe.npy

# --- physics ------------------------------------------------------------
voxelsize = 1e-8  # presumably [m] per pixel (it multiplies a thickness in pixels) - TODO confirm
energy = 33.35  # [keV] x-ray energy
wavelength = 1.2398419840550367e-09/energy  # [m] wavelength, i.e. h*c [keV*m] / energy [keV]

In [None]:
# Build a Siemens-star-like pattern: one thin triangle ("spoke") is drawn every
# 15 degrees (24 spokes in total), all rotated about the image centre.
img = np.zeros((ne, ne, 3), np.uint8)
# Template spoke: a short base near the left border and a tip just left of the centre.
triangle = np.array([(ne//16, ne//2-ne//32), (ne//16, ne//2+ne//32), (ne//2-ne//128, ne//2)], np.float32)
# uint8 accumulator with the shape of a single image channel
star = img[:,:,0]*0
for i in range(0, 360, 15):
    # fresh canvas for every spoke so only the current triangle is added below
    img = np.zeros((ne, ne, 3), np.uint8)
    degree = i
    theta = degree * np.pi / 180
    rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta), np.cos(theta)]], np.float32)    
    # rotate about the centre: (triangle - c) @ rot_mat.T + c  (GEMM_2_T transposes rot_mat)
    rotated = cv2.gemm(triangle-ne//2, rot_mat, 1, None, 1, flags=cv2.GEMM_2_T)+ne//2
    # the fill colour puts 255 into channel 0, which is the channel accumulated into `star`
    cv2.fillPoly(img, [np.int32(rotated)], (255, 0, 0))
    star+=img[:,:,0]
# normalised coordinates in [-1, 1) along both axes
[x,y] = np.meshgrid(np.arange(-ne//2,ne//2),np.arange(-ne//2,ne//2))
x = x/ne*2
y = y/ne*2
# Cut ring-shaped gaps into the spokes: `circ` is a boolean mask that is False
# only where x**2+y**2 lies inside one of four narrow bands
# ([0.135, 0.145], [0.05, 0.053], [0.008, 0.0085], [0.5, 0.52]).
# On boolean arrays `+` acts as OR and `*=` as AND.
circ = (x**2+y**2>0.145)+(x**2+y**2<0.135)
circ *= (x**2+y**2>0.053)+(x**2+y**2<0.05)
circ *= (x**2+y**2>0.0085)+(x**2+y**2<0.008)
circ *= (x**2+y**2>0.52)+(x**2+y**2<0.5)
# circ *= (x**2+y**2<0.5**2)
# apply the mask and rescale the fill value 255 to 1 (result is floating point)
star = star*circ/255


# star[ne//2-n//2:ne//2+n//2,ne//2-n//2:ne//2+n//2]=0

# Smooth the binary star with a Gaussian low-pass filter applied in Fourier space.
# `v` is first the normalised frequency axis in [-0.5, 0.5) ...
v = np.arange(-ne//2,ne//2)/ne
[vx,vy] = np.meshgrid(v,v)
# ... and is then reused for the centred Gaussian window exp(-30*|freq|^2).
v = np.exp(-30*(vx**2+vy**2))
# centred spectrum of the star (fftshift before and after the transform)
fu = np.fft.fftshift(np.fft.fftn(np.fft.fftshift(star)))
# filter, transform back, and keep only the real part
star = np.fft.fftshift(np.fft.ifftn(np.fft.fftshift(fu*v))).real

# Refractive index of gold, n = 1 - delta + i*beta, queried from xraylib at `energy`
# (19.3 is presumably the density of gold in g/cm^3 - confirm against the xraylib docs).
delta = 1-xraylib.Refractive_Index_Re('Au',energy,19.3)
beta = xraylib.Refractive_Index_Im('Au',energy,19.3)

thickness = 2e-6/voxelsize # siemens star thickness in pixels
# form Transmittance function
u = star*(-delta+1j*beta) # note -delta
# scale the per-pixel index decrement by the sample thickness (in pixels)
Ru = u*thickness 
# complex transmittance exp(i * k * Ru * voxelsize) with k = 2*pi/wavelength, stored as complex64
psi = np.exp(1j * Ru * voxelsize * 2 * np.pi / wavelength).astype('complex64')
# Show amplitude (left) and phase (right) of the transmittance, each with its own colorbar.
# NOTE(review): `plt` is not imported explicitly in this notebook; presumably it comes
# from `from holotomocupy.utils import *` - confirm.
fig, axs = plt.subplots(1, 2, figsize=(9, 4))
im=axs[0].imshow(np.abs(psi),cmap='gray')
axs[0].set_title('amplitude')
fig.colorbar(im)
im=axs[1].imshow(np.angle(psi),cmap='gray')
axs[1].set_title('phase')
fig.colorbar(im)

In [None]:
# Scan positions in pixels: take entry 0 of the file and keep the first `npos` rows.
# Columns 0 and 1 are used as the two coordinates below.
shifts = np.load(f'{path}/data/positions_px.npy')[0,:npos]
# Probe: entry 0 of the file, scaled by n.
# NOTE(review): the reason for the factor n is not visible here (possibly FFT normalisation) - confirm.
prb = np.load(f'{path}/data/probe.npy')[0]*n
# scatter plot of the scan positions with equal axis scaling
plt.plot(shifts[:,0],shifts[:,1],'.')
plt.axis('square')
plt.show()
mshow_polar(prb,show)


In [None]:
# Margin (in pixels) added on every side of an extracted patch; it is passed as the
# `ex` argument of Sop / STop further down in this notebook.
extra = 8

def Lop(psi):
    """Forward far-field operator: orthonormal 2D FFT over the last two axes.

    Parameters
    ----------
    psi : array
        Wavefield(s); any leading axes are treated as a batch.

    Returns
    -------
    array
        ``fft2(psi, norm='ortho')``, same shape as ``psi``.
    """
    # The former `cp.zeros([npos, n, n])` pre-allocation was overwritten on the next
    # line, so it only wasted GPU memory and tied the function to notebook globals.
    # np.fft.fft2 is kept on purpose: CuPy arrays are dispatched to cupy.fft through
    # NumPy's array-function protocol, so the result stays on the input's device.
    return np.fft.fft2(psi,norm='ortho')

def LTop(d):
    """Adjoint of Lop: orthonormal inverse 2D FFT over the last two axes.

    Parameters
    ----------
    d : cupy array
        Far-field data; any leading axes are treated as a batch.

    Returns
    -------
    cupy array
        ``ifft2(d, norm='ortho')``, same shape as ``d``.
    """
    # The former `cp.zeros([npos, n, n])` pre-allocation was overwritten immediately
    # and has been dropped; with norm='ortho' no extra scaling is needed.
    return cp.fft.ifft2(d,norm='ortho')

def Ex(psi,ix,ex):
    """Cut one square window per scan position out of the object.

    Parameters
    ----------
    psi : 2D array
        Object to read from (indexed as psi[row, col]).
    ix : integer array, shape (npos, 2)
        Integer offsets; column 0 moves the window along rows, column 1 along
        columns. Offsets are subtracted from the centre ``ne//2``.
    ex : int
        Margin added on every side of the ``n`` x ``n`` window.

    Returns
    -------
    complex64 array, shape (npos, n+2*ex, n+2*ex)
    """
    size = n+2*ex
    # first / one-past-last row and column of every window
    row0 = ne//2-ix[:,0]-n//2-ex
    row1 = row0+size
    col0 = ne//2-ix[:,1]-n//2-ex
    col1 = col0+size

    patches = cp.empty([ix.shape[0],size,size],dtype='complex64')
    for k in range(patches.shape[0]):
        patches[k] = psi[row0[k]:row1[k],col0[k]:col1[k]]
    return patches

def ExT(psi,psir,ix,ex):
    """Adjoint of Ex: add every window back into the object at its position.

    Parameters
    ----------
    psi : 2D array
        Object that is updated IN PLACE (indexed as psi[row, col]).
    psir : array, shape (npos, n+2*ex, n+2*ex)
        Windows to accumulate.
    ix : integer array, shape (npos, 2)
        Integer offsets, same convention as in Ex (column 0 -> rows, column 1 -> columns).
    ex : int
        Margin added on every side of the ``n`` x ``n`` window.

    Returns
    -------
    The same ``psi`` object, after accumulation.
    """
    # The window size must follow the `ex` argument (not the notebook-level `extra`),
    # otherwise the slices only match psir when the two values happen to coincide.
    size = n+2*ex
    stx = ne//2-ix[:,1]-n//2-ex
    endx = stx+size
    sty = ne//2-ix[:,0]-n//2-ex
    endy = sty+size
    # Windows may overlap, so accumulate one at a time.
    for k in range(len(stx)):
        psi[sty[k]:endy[k],stx[k]:endx[k]] += psir[k]
    return psi

def S(psi,p):
    """Shift each image of a stack by a (possibly fractional) amount via the FFT.

    Parameters
    ----------
    psi : array, shape (batch, size, size)
        Images to shift; each one is zero-padded by size//2 on every side first.
    p : array, shape (batch, 2)
        Per-image shift; p[:, 0] acts along axis -2 and p[:, 1] along axis -1.

    Returns
    -------
    array, shape (batch, size, size)
        Shifted images, cropped back to the input size.
    """
    size = psi.shape[-1]
    half = size//2
    padded = cp.pad(psi, ((0, 0), (half, half), (half, half)),'constant')

    # frequency grids of the padded image: `fcol` varies along the last axis,
    # `frow` along the second-to-last one
    freq = cp.fft.fftfreq(2*size).astype('float32')
    fcol, frow = cp.meshgrid(freq, freq)

    # linear phase ramp = shift in real space
    ramp = cp.exp(-2*cp.pi*1j * (fcol*p[:, 1, None, None]+frow*p[:, 0, None, None]))
    shifted = cp.fft.ifft2(ramp*cp.fft.fft2(padded))

    # remove the padding again (same slice bounds as the pad widths for even sizes)
    return shifted[:, size//2:-size//2, size//2:-size//2]

def Sop(psi,ix,x,ex):
    """Extract the shifted object patch for every scan position.

    The integer part of the shift selects a window with an `ex`-pixel margin (Ex),
    the fractional part is applied with the Fourier shift (S), and the margin is
    cropped off again.

    Parameters
    ----------
    psi : 2D array
        Object.
    ix : integer array, shape (npos, 2)
        Integer parts of the shifts (forwarded to Ex).
    x : array, shape (npos, 2)
        Fractional parts of the shifts (forwarded to S).
    ex : int
        Margin in pixels; 0 is allowed.

    Returns
    -------
    array, shape (npos, n, n)
    """
    # The former `cp.zeros([x.shape[1], n, n])` pre-allocation was never used.
    psir = S(Ex(psi,ix,ex),x)
    # Explicit end index instead of `ex:-ex`, which yields an empty array for ex == 0.
    size = psir.shape[-1]
    return psir[:, ex:size-ex, ex:size-ex]

def STop(d,ix,x,ex):
    """Adjoint of Sop: pad, undo the fractional shift, and accumulate into the object.

    Parameters
    ----------
    d : array, shape (npos, n, n)
        Patches to scatter back.
    ix : integer array, shape (npos, 2)
        Integer parts of the shifts (forwarded to ExT).
    x : array-like, shape (npos, 2)
        Fractional parts of the shifts; the negated value is passed to S.
    ex : int
        Margin in pixels added around every patch before shifting.

    Returns
    -------
    complex64 array, shape (ne, ne)
    """
    margin = ((0, 0), (ex, ex), (ex, ex))
    padded = cp.pad(d, margin,'constant')
    # shifting by -x reverses the fractional shift applied in Sop
    unshifted = S(padded,-cp.asarray(x))

    obj = cp.zeros([ne, ne], dtype='complex64')
    ExT(obj,unshifted,ix,ex)  # accumulates into `obj` in place
    return obj

# adjoint tests
# For an operator A and its adjoint AT, <x, AT(A x)> must equal <A x, A x>;
# the two printed numbers of each test should therefore agree up to rounding.
# NOTE: the random inputs are not seeded, so the printed values change from run to run.
arr1 = (cp.random.random([ne,ne])+1j*cp.random.random([ne,ne])).astype('complex64')
# move the scan positions to the GPU and split them into integer and fractional parts
shifts = cp.array(shifts)
ishifts = shifts.astype('int32')
fshifts = shifts-ishifts
extra = 8
# Sop / STop pair
arr2 = Sop(arr1,ishifts,fshifts,extra)
arr3 = STop(arr2,ishifts,fshifts,extra)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')

# Lop / LTop pair
arr1 = (cp.random.random([npos,n,n])+1j*cp.random.random([npos,n,n])).astype('complex64')
arr2 = Lop(arr1)
arr3 = LTop(arr2)
print(f'{np.sum(arr1*np.conj(arr3))}==\n{np.sum(arr2*np.conj(arr2))}')




In [None]:
# split the scan positions into an integer part (window selection) and a
# fractional part (Fourier shift)
ishifts = shifts.astype('int32')
fshifts = shifts-ishifts
extra = 8
psi = cp.array(psi)
prb = cp.array(prb)

# The shifted object patches feed every quantity below; compute them once
# instead of re-running Sop (and the forward model) for each of them.
psi_patches = Sop(psi,ishifts,fshifts,extra)

# ground-truth far field (centred), its intensity, and the far field of the bare probe
xgt = np.fft.fftshift(Lop(prb*psi_patches),axes=(-1,-2))
# |.|**2 is element-wise, so it commutes with fftshift: same values as before
data = np.abs(xgt)**2
# (1+0*patches)**2 is an array of ones with the patches' shape and dtype
x0 = np.fft.fftshift(Lop(prb*(1+0*psi_patches)**2),axes=(-1,-2))
del psi_patches  # release the temporary stack

# keep only the first scan position for the experiments below
data=data[0]
x0=x0[0]
xgt=xgt[0]

mshow(data,show)
mshow_polar(x0,show)
mshow_polar(xgt,show)



In [None]:
def gradientF(d,psi):
    """Gradient of F(psi) = sum(|psi|^2 - 2*d*log|psi|), the functional evaluated by minf.

    Parameters
    ----------
    d : array
        Measured data, same shape as ``psi``.
    psi : complex array
        Current estimate.

    Returns
    -------
    complex array
        ``2*(psi - d/|psi| * psi/|psi|)``, with the notebook-level ``eps`` added
        to both denominators.
    """
    # Regularise both divisions with the same eps (as hessianF does); previously the
    # phase factor psi/|psi| was unprotected and became NaN wherever psi == 0.
    # With eps == 0 the result is unchanged.
    mag = cp.abs(psi)+eps
    d0n = d/mag
    psi0p = psi/mag
    return 2*(psi-d0n*psi0p)

In [None]:
def hessianF(psi,dpsi1,dpsi2,d):
    """Second-order term of the functional at ``psi`` for directions ``dpsi1``, ``dpsi2``.

    Parameters
    ----------
    psi : complex array
        Point at which the form is evaluated.
    dpsi1, dpsi2 : complex arrays
        The two directions.
    d : array
        Measured data.

    Returns
    -------
    scalar
        ``2*(v1+v2)``; both partial sums are also printed for debugging.
    """
    mag = cp.abs(psi)
    mag2 = mag**2
    unit = psi/(mag+eps)  # phase factor of psi

    v1 = cp.sum((1-d/(mag2+eps))*reprod(dpsi1,dpsi2))
    v2 = 2*cp.sum(d*reprod(unit,dpsi1)*reprod(unit,dpsi2)/(mag2+eps))
    print(v1,v2)
    return 2*(v1+v2)

In [None]:
def calc_beta(psi,dpsi1,dpsi2,d):
    """Ratio hessianF(psi, dpsi1, dpsi2, d) / hessianF(psi, dpsi2, dpsi2, d).

    Parameters
    ----------
    psi : complex array
        Current estimate.
    dpsi1, dpsi2 : complex arrays
        Directions forwarded to hessianF.
    d : array
        Measured data.

    Returns
    -------
    scalar
    """
    # Plain assignment: the former `+=` on never-initialised locals raised
    # UnboundLocalError on every call.
    top = hessianF(psi, dpsi1, dpsi2, d)
    bottom = hessianF(psi, dpsi2, dpsi2, d)
    return top/bottom

def calc_alpha(psi,dpsi1,dpsi2,d):
    """Step length along ``dpsi2`` from the quadratic model of the functional.

    Returns
    -------
    tuple
        ``(numer/denom, numer, denom)`` with ``numer = -redot(dpsi1, dpsi2)`` and
        ``denom = hessianF(psi, dpsi2, dpsi2, d)``.
    """
    numer = -redot(dpsi1,dpsi2)
    denom = hessianF(psi, dpsi2, dpsi2, d)
    step = numer/denom
    return step, numer, denom

## debug functions

In [None]:
def minf(psi,d):
    """Objective value: sum over all elements of |psi|^2 - 2*d*log|psi|."""
    mag = cp.abs(psi)
    return cp.sum(mag**2-2*d*cp.log(mag))

def plot_debug2(psi,dpsi2,top,bottom,alpha,data):
    '''Plot the functional along the search direction next to its quadratic model.

    Six step lengths t = alpha*k/2 (k = 0..5) are sampled. The first curve is
    minf(psi + t*dpsi2, data); the second is the model
    minf(psi, data) - top*t + 0.5*bottom*t**2. `alpha` must provide `.get()`
    (i.e. be a CuPy scalar/array).
    '''
    npp = 3
    nsteps = 2*npp

    # functional evaluated at every sampled step
    errt = cp.zeros(nsteps)
    for k in range(nsteps):
        errt[k] = minf(psi+(alpha*k/(npp-1))*dpsi2,data)

    # quadratic prediction at the same steps
    t = alpha*(cp.arange(nsteps))/(npp-1)
    errt2 = minf(psi,data)
    errt2 = errt2 -top*t+0.5*bottom*t**2

    # host-side abscissa shared by both curves
    steps = alpha.get()*cp.arange(nsteps).get()/(npp-1)
    plt.plot(steps,errt.get(),'.')
    plt.plot(steps,errt2.get(),'.')
    plt.show()


In [None]:
def cg_holo(data, psi, niter=32):
    """Iteratively minimise the functional, plotting a line-search check each step.

    Despite the name, the search direction is the negative gradient on every
    iteration (calc_beta is not used).

    Parameters
    ----------
    data : array
        Measured data passed to gradientF / calc_alpha.
    psi : complex array
        Initial guess; it is updated IN PLACE.
    niter : int, optional
        Number of iterations (default 32, the previously hard-coded value).

    Returns
    -------
    The same ``psi`` array after the final update (previously nothing was returned,
    so the result was only reachable through the mutated argument).
    """
    for i in range(niter):
        grad = gradientF(data,psi)
        eta = -grad

        alpha,top,bottom = calc_alpha(psi,grad,eta,data)
        print(f'{alpha=}')
        # compare the functional along eta with its quadratic model
        plot_debug2(psi,eta,top,bottom,alpha,data)

        psi += alpha*eta
    return psi

# Regularisation constant read as a global by gradientF / hessianF; 0 disables it.
eps = 0
# data=cp.array(data)
# psi0=cp.array(psi0)
# Start the iterations from the ground-truth far field `xgt`.
# NOTE: cg_holo updates its second argument in place, so `xgt` is modified by this call.
cg_holo(data,xgt)


