In [None]:
import torch
import ptychi.api as api
from ptychi.api.task import PtychographyTask
from ptychi.utils import get_default_complex_dtype
import matplotlib.pyplot as plt
import numpy as np
from utils_ptychi import *

In [None]:
# --- Sizes -----------------------------------------------------------------
n = 1024  # data (detector frame) size in each dimension [px]
nobj = n+n//4  # object size in each dimension [px]
pad = 0  # padding for the reconstructed probe (alternative previously tried: n//16)
nprb = n+2*pad  # probe size [px]
extra = 0  # extra padding for shifts
npatch = nprb+2*extra  # patch size for shifts [px]

# --- Acquisition geometry ----------------------------------------------------
npos = 16  # total number of positions
z1 = 4.267e-3  # [m] focus-to-sample distance (position of the sample)
detector_pixelsize = 3.0e-6  # [m]
energy = 33.35  # [keV] x-ray energy
wavelength = 1.24e-09/energy  # [m] wavelength (lambda[nm] ~= 1.24 / E[keV])
focusToDetectorDistance = 1.28  # [m]
# Cone-beam geometry mapped to an equivalent parallel-beam one
# (Fresnel scaling theorem): effective distance z1*z2/(z1+z2) and
# magnification (z1+z2)/z1, where z1+z2 == focusToDetectorDistance.
z2 = focusToDetectorDistance-z1  # [m] sample-to-detector distance
distance = (z1*z2)/focusToDetectorDistance  # [m] effective propagation distance
magnification = focusToDetectorDistance/z1
voxelsize = float(np.abs(detector_pixelsize/magnification))  # [m] effective pixel size at the sample

# --- Reconstruction parameters -----------------------------------------------
cpag = 5e-3  # regularization (alpha) for the Paganin initial guess
method = 'BH'  # reconstruction method label (used in flg)
niter = 32  # number of iterations
eps = 1e-8  # to avoid division by 0
noise = False  # use the noisy data set

err_step = 1  # iteration step to calculate the minimization functional
vis_step = -1  # iteration step to visualize reconstructions
show = True  # do visualization or not at all
flg = f'{method}_{noise}'  # tag for saved data, e.g. 'BH_False'

# NOTE(review): absolute, machine-specific paths; adjust for your environment.
path = '/data/vnikitin/paper/near_field'  # input data path
path_out = '/data/vnikitin/paper/near_field/rec0'  # output data path

In [None]:
# Load inputs. Only the first `npos` positions of the shifts and data are kept.
# `shifts` are the nominal scan shifts. `shifts_wrong` is a randomly perturbed
# copy that later serves as the initial position guess.
shifts = np.load(f'{path}/data/gen_shifts.npy')[:npos]
shifts_wrong = np.load(f'{path}/data/gen_shifts_random.npy')[:npos]
prb = np.load(f'{path}/data/gen_prb.npy')
# Measured intensities; the noisy variant is selected when `noise` is set.
data = np.load(f"{path}/data/{'ndata' if noise else 'data'}.npy")[:npos]
# Reference image used later to normalize the data (data / ref).
ref = np.load(f'{path}/data/ref.npy')
psi = np.load(f'{path}/data/psi.npy')  # presumably the ground-truth object

# Quick look at the probe and at the first frame
# (real part: raw intensity, imaginary part: ref-normalized intensity).
mshow_polar(prb, show)
mshow_complex(data[0] + 1j*data[0]/ref, show, vmax=3)

In [None]:
def Paganin(data, wavelength, voxelsize, delta_beta,  alpha, distance_m=None):
    """Single-distance Paganin phase retrieval from a normalized intensity image.

    The low-pass filter ``T(f) = 1 + pi * wavelength * z * delta_beta * |f|^2``
    is inverted in Fourier space with Tikhonov regularization
    (``T / (T**2 + alpha)`` instead of ``1 / T``). The filtered intensity is
    converted to phase as ``0.5 * delta_beta * log(I_filtered)``.

    Parameters
    ----------
    data : ndarray
        Normalized intensity over the last two axes (rows, columns). Any
        leading axes are treated as a batch. The frame need not be square.
    wavelength : float
        X-ray wavelength [m].
    voxelsize : float
        Pixel size in the sample plane [m]; sets the frequency grid.
    delta_beta : float
        The ratio delta/beta of Paganin's method.
    alpha : float
        Tikhonov regularization constant for the filter inversion.
    distance_m : float, optional
        Propagation distance z [m]. Defaults to the notebook-level
        ``distance`` global, which preserves the original call signature.

    Returns
    -------
    ndarray
        Real-valued phase map with the same shape as ``data``. Entries are
        NaN / -inf where the filtered intensity is non-positive.
    """
    if distance_m is None:
        distance_m = distance  # fall back to the notebook-level global
    # Separate row/column frequency axes so non-square frames are supported.
    fx = np.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    fy = np.fft.fftfreq(data.shape[-2], d=voxelsize).astype('float32')
    fx, fy = np.meshgrid(fx, fy)  # both shaped (rows, cols), like data's last two axes
    rad_freq = np.fft.fft2(data)
    taylorExp = 1 + wavelength * distance_m * np.pi * delta_beta * (fx**2 + fy**2)
    numerator = taylorExp * rad_freq
    denominator = taylorExp**2 + alpha
    filtered = np.real(np.fft.ifft2(numerator / denominator))
    return delta_beta * 0.5 * np.log(filtered)

def rec_init(rdata,ishifts,delta_beta=24.05):
    """Stitch Paganin phase retrievals of all frames into an initial object guess.

    Each frame ``rdata[j]`` (``n x n``) is embedded in an ``nobj x nobj``
    canvas of ones. Its top-left corner is at row
    ``nobj//2 + ishifts[j, 0] - n//2`` and column
    ``nobj//2 + ishifts[j, 1] - n//2``. The canvas is phase-retrieved with
    ``Paganin`` using the notebook globals ``wavelength``, ``voxelsize`` and
    ``cpag``, and the resulting phases are summed.

    The sum is then divided by the number of frames covering each pixel, so
    overlapping regions are averaged rather than over-weighted. Pixels covered
    by no frame are left undivided.

    Parameters
    ----------
    rdata : ndarray
        Normalized frames, iterated along the first axis.
    ishifts : ndarray of int
        Per-frame integer (row, column) offsets, one row per frame.
    delta_beta : float, optional
        delta/beta passed to ``Paganin`` (default 24.05, the previously
        hard-coded value).

    Returns
    -------
    ndarray
        Complex ``nobj x nobj`` object ``exp(1j * phase)``.
    """
    phase_sum = np.zeros([nobj,nobj],dtype='float32')
    # Number of frames covering each pixel. It starts at zero so the division
    # below really compensates for overlap (starting from ones would divide
    # every pixel by the same constant).
    coverage = np.zeros([nobj,nobj],dtype='float32')
    for frame, (dy, dx) in zip(rdata, ishifts):
        sty = nobj//2+dy-n//2
        stx = nobj//2+dx-n//2
        canvas = np.ones([nobj,nobj],dtype='float32')  # unit intensity outside the frame
        canvas[sty:sty+n, stx:stx+n] = frame
        phase_sum += Paganin(canvas, wavelength, voxelsize, delta_beta, cpag)
        coverage[sty:sty+n, stx:stx+n] += 1

    coverage[coverage<5e-2] = 1  # uncovered pixels: leave the sum as is
    phase_sum /= coverage
    return np.exp(1j*phase_sum)

# Integer pixel offsets used to place each frame inside the object canvas.
ishifts = shifts.round().astype('int32')
# Flat-field normalization (eps guards against division by zero). A border of
# width n//16 is then discarded and refilled by mirror-reflecting the interior.
rdata = np.pad(
    (data / (ref + eps))[:, n//16:-n//16, n//16:-n//16],
    ((0, 0), (n//16, n//16), (n//16, n//16)),
    mode='symmetric',
)
rec_paganin = rec_init(rdata, ishifts)
mshow_polar(rec_paganin, show)

In [None]:
# Joint object / probe / position refinement with ptychi's BH reconstructor.
options = api.BHOptions()
# NOTE(review): data are fftshift-ed over the last two axes before being
# handed to the reconstructor -- confirm this matches its expected layout.
options.data_options.data = np.fft.fftshift(data,axes=(-2,-1))
options.data_options.wavelength_m = wavelength
options.data_options.free_space_propagation_distance_m = distance

# Object: start from the stitched Paganin estimate (leading axis added).
options.object_options.initial_guess = torch.from_numpy(rec_paganin[None])
options.object_options.pixel_size_m = voxelsize
options.object_options.optimizable = True

# Probe: flat, unit-amplitude initial guess of size nprb x nprb.
options.probe_options.initial_guess = torch.ones([1,1,nprb, nprb],dtype=torch.complex64)
options.probe_options.optimizable = True
options.probe_options.rho = 1

# Positions: start from the perturbed shifts (column 1 -> x, column 0 -> y)
# so that their refinement can be compared with `shifts` in the plot below.
options.probe_position_options.position_x_px = shifts_wrong[:, 1]
options.probe_position_options.position_y_px = shifts_wrong[:, 0]
options.probe_position_options.optimizable = True
options.probe_position_options.rho = 0.1

options.reconstructor_options.batch_size = npos  # all positions in one batch
options.reconstructor_options.num_epochs = niter

# options.reconstructor_options.csv = flg 
# NOTE(review): the config cell sets method = 'BH' (and flg is built from it),
# but 'CG' is requested here -- confirm which is intended.
options.reconstructor_options.method = 'CG'
task = PtychographyTask(options)

task.run()

recon = task.get_data_to_cpu('object', as_numpy=True)[0]
rec_prb = task.get_data_to_cpu('probe', as_numpy=True)[0,0]
rec_pos = task.get_data_to_cpu('probe_positions', as_numpy=True)

mshow_polar(recon,show)
mshow_polar(rec_prb,show)

if show:
    # Compare the reference shifts with the refined positions.
    # Assumes rec_pos uses the same (y, x) column order as shifts -- TODO confirm.
    fig, ax = plt.subplots()
    ax.plot(shifts[:,0],shifts[:,1],'.',label='reference shifts')
    ax.plot(rec_pos[:,0],rec_pos[:,1],'.',label='refined positions')
    ax.axis('square')
    ax.set_xlabel('y [px]')
    ax.set_ylabel('x [px]')
    ax.set_title('Probe positions')
    ax.grid()
    ax.legend()
    plt.show()

# dxchange.write_tiff(np.abs(recon),f'{path_out}_{flg}/crec_psi_abs/0',overwrite=True)
# dxchange.write_tiff(np.angle(recon),f'{path_out}_{flg}/crec_psi_angle/0',overwrite=True)
# dxchange.write_tiff(np.abs(rec_prb[0]),f'{path_out}_{flg}/crec_prb_abs/0',overwrite=True)
# dxchange.write_tiff(np.angle(rec_prb[0]),f'{path_out}_{flg}/crec_prb_angle/0',overwrite=True)

