In [None]:
import matplotlib.pyplot as plt
import csv
import numpy as np
import matplotlib.patches as patches
import matplotlib
%matplotlib inline

In [None]:
# --- sizes ------------------------------------------------------------------
n = 1024      # object size in each dimension
npos = 16     # total number of positions
pad = 0       # pad for the reconstructed probe
extra = 0

# --- beam and geometry ------------------------------------------------------
energy = 33.35                          # [keV] xray energy
wavelength = 1.24e-09/energy            # [m] wave length
z1 = 4.267e-3                           # [m] position of the sample
focusToDetectorDistance = 1.28          # [m]
detector_pixelsize = 3.0e-6             # [m]

z2 = focusToDetectorDistance-z1         # [m] remaining distance, sample to detector
magnification = focusToDetectorDistance/z1
# NOTE(review): presumably the effective (Fresnel-scaled) propagation distance
distance = (z1*z2)/focusToDetectorDistance
voxelsize = np.abs(detector_pixelsize/magnification)  # object voxel size

# --- derived array sizes ----------------------------------------------------
nobj = n+n//8
nprb = n+2*pad
npatch = nprb+2*extra

show = True
path = '/data/vnikitin/paper/near_field'
path_out = path+'/rec0'
print(f'{voxelsize=},{distance=}')

In [None]:
# Convergence of the MSE vs. iteration for the object + probe reconstruction,
# one figure per entry of the configuration lists below.
noisea = [False]
zooma = [True]
nitera = [127]

# CSV file stems and the legend label for each of them.
fnamea = ['EPIE', 'LSQML', 'DM', 'AD', 'BH-GD', 'BH-CG']
names = ['EPIE', 'LSQML', 'DM', 'AD', 'BH-GD', 'BH-CG']

plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 15

st = 5    # first iteration shown in the zoomed inset
end = 76  # one past the last iteration shown in the zoomed inset


def read_loss_column(fname, niter, col):
    """Return the first ``niter`` values of CSV column ``col`` (after the header row).

    Entries for rows missing from a short file stay 0.
    """
    values = np.zeros(niter)
    with open(fname, 'r') as file:
        reader = csv.reader(file)
        next(reader, None)  # skip the header row
        for k, row in zip(range(niter), reader):
            values[k] = float(row[col])
    return values


# zip keeps the loop in step with the length of the configuration lists.
for noise, zoom, niter in zip(noisea, zooma, nitera):
    fig, ax = plt.subplots()
    ax.set_xticks(np.arange(0, niter, niter//5))
    if zoom:
        ins = ax.inset_axes([0.25, 0.63, 0.4, 0.33])
        ins.tick_params(axis='y', labelsize=17)

    sloss = None  # final LSQML loss, drawn as a dashed reference line
    for stem, label in zip(fnamea, names):
        fname = f'{stem}_{noise}'
        # column 1 is squared and scaled by 32**2 to give the plotted MSE
        loss = read_loss_column(fname, niter, 1)**2*32**2
        if fname.startswith('LSQML'):
            sloss = loss[-1]
        loss[1] = loss[0]  # NOTE(review): replaces the 2nd value by the 1st -- confirm why
        ax.plot(loss, label=label)
        ax.set_yscale('log')
        if zoom:
            ins.plot(np.arange(st, end), loss[st:end])
            ins.set_xticks(np.arange(st, end, 20))
            ins.set_yscale('log')
            ins.grid(True)

    if zoom:
        # Same star in the inset and the main axes to tie the two together.
        ymin, ymax = ins.get_ylim()
        xstar = (end-st)/2
        ystar = 10**((np.log10(ymax)+np.log10(ymin))/2)
        ins.plot(xstar, ystar, 'r*', markersize=10, mec='black')
        ax.plot(xstar, ystar, 'r*', markersize=10, mec='black')
    ax.legend()
    ax.grid(True)
    if sloss is not None:
        ax.plot([0, niter], [sloss, sloss], '--', color='0.5')
    ax.set_xlabel('iteration', fontsize=17)
    ax.set_ylabel('MSE', fontsize=17)
    ax.tick_params(axis='y', labelsize=17)
    ax.tick_params(axis='x', labelsize=15)

    title = 'Optimization: object and probe.'
    title += '       Gaussian noise.' if noise else '        No noise.'
    ax.set_title(title, fontsize=17)
    fig.savefig(f'Gaussian_{noise}.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Wall-clock time vs. iteration for each method: column 2 of the per-run CSV
# logs (header row skipped), made relative to the first logged entry.
obj_opt = True
prb_opta = [True, True, True, True, False, False]
pos_opta = [True, True, False, False, False, False]
noisea = [False, True, True, False, True, False]
nitera = [4001, 501, 501, 4001, 501, 501]

# CSV file stems and matching legend labels.
fnamea = ['epie', 'lsqml', 'DY-LS', 'BH-GD', 'BH-CG']
names = ['EPIE', 'LSQML', 'DY-LS', 'BH-GD', 'BH-CG']

plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 15

# Only the first configuration is plotted; widen the range to plot more.
for jj in range(0, 1):
    prb_opt = prb_opta[jj]
    pos_opt = pos_opta[jj]
    noise = noisea[jj]
    niter = nitera[jj]
    run_id = f'{obj_opt}_{prb_opt}_{pos_opt}_{noise}'

    fig, ax = plt.subplots()
    ax.set_xticks(np.arange(0, niter, niter//5))
    for stem, label in zip(fnamea, names):
        elapsed = np.zeros(niter)
        with open(f'{stem}_{run_id}', 'r') as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header row
            for k, row in zip(range(niter), reader):
                elapsed[k] = float(row[2])
        elapsed -= elapsed[0]
        ax.plot(elapsed, label=label)

    # The y axis is time, so no MSE reference line is drawn here.
    ax.legend()
    ax.grid(True)
    ax.set_xlabel('iteration', fontsize=17)
    ax.set_ylabel('Time (s)', fontsize=17)
    ax.tick_params(axis='y', labelsize=17)
    ax.tick_params(axis='x', labelsize=15)

    title = 'Optimization: object'
    if prb_opt:
        title += ', probe'
    if pos_opt:
        title += ', positions'
    title += '.'
    title += '       Gaussian noise.' if noise else '        No noise.'
    ax.set_title(title, fontsize=17)
    fig.savefig(f'time_Gaussian_{run_id}.png', dpi=300, bbox_inches='tight')
    plt.show()


In [None]:

from matplotlib_scalebar.scalebar import ScaleBar
from mpl_toolkits.axes_grid1 import make_axes_locatable
# serif paper font for all following figures
plt.rcParams.update({'font.family': 'Times New Roman', 'font.size': 15})
 
def show_siemens(psi,vvmin=None,vvmax=None):
    """Show a 2-D map with two zoomed insets, matching markers and a scale bar.

    Parameters
    ----------
    psi : 2-D real array
        Map to display; the window ``psi[crop:1024-crop, crop:1024-crop]`` is
        used (``crop`` = 0). The caller's array is not modified.
    vvmin, vvmax : float or None
        Colour limits for the main image and both insets; None lets
        matplotlib autoscale the main image and the insets reuse that range.

    Creates a new figure and leaves it current so the caller can
    ``plt.savefig`` it. Reads the global ``voxelsize`` for the scale bar.
    """
    print(psi.shape)
    crop = 0
    nobj = 1024-2*crop
    psi = psi[crop:1024-crop,crop:1024-crop]
    ni = 100  # inset window size in pixels
    # centre window (same range in both axes)
    st1 = nobj//2-ni//2
    end1 = nobj//2+ni//2
    # second window: rows st2:end2, columns st22:end22
    st2 = 190-ni//2+30-crop
    end2 = 190+ni//2+30-crop
    st22 = 340-ni//2-crop
    end22 = 340+ni//2-crop

    fig, ax = plt.subplots()
    # the colour range is fixed through vmin/vmax, so no pixels are overwritten
    im = ax.imshow(psi,cmap='gray',vmax=vvmax,vmin=vvmin)
    fig.colorbar(im, fraction=0.046, pad=0.02)
    vvmin, vvmax = im.get_clim()  # resolve None limits so the insets match

    ins = ax.inset_axes([0.6,0,0.4,0.4])
    ins.set_xticks([])
    ins.set_yticks([])
    ins.imshow(psi[st1:end1,st1:end1],cmap='gray',vmax=vvmax,vmin=vvmin)
    # cyan star at the window centre, in inset and in full-image coordinates
    ins.plot((end1-st1)/2,(end1-st1)/2,'c*',markersize=10,mec='black')
    ax.plot((end1+st1)/2,(st1+end1)/2,'c*',markersize=10,mec='black')

    ins1 = ax.inset_axes([0.6,0.6,0.4,0.4])
    ins1.set_xticks([])
    ins1.set_yticks([])
    ins1.imshow(psi[st2:end2,st22:end22],cmap='gray',vmax=vvmax,vmin=vvmin)
    # red star at the window centre: x is the column, y the row
    ins1.plot((end22-st22)/2,(end2-st2)/2,'r*',markersize=10,mec='black')
    ax.plot((st22+end22)/2,(end2+st2)/2,'r*',markersize=10,mec='black')

    scalebar = ScaleBar(voxelsize, "m", length_fraction=0.4,height_fraction=0.03,font_properties={
                "family": "serif","size":16,
            },
            location="lower right")
    ins.add_artist(scalebar)


path = '/data/vnikitin/paper/near_field'
vvmin, vvmax = -0.8, 0.1
psi = np.load(path+'/data/psi.npy')
# phase of the phantom with 64 pixels removed from every border
show_siemens(np.angle(psi[64:-64, 64:-64]), vvmin, vvmax)
plt.savefig('siemens.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# NOTE(review): same factor 32 as in the loss scaling (32**2) -- confirm its origin
prb = np.load(f'{path}/data/gen_prb.npy')*32
plt.rcParams['font.size'] = 24


def plot_probe_component(component, fname):
    """Plot one real-valued probe map with colour bar and scale bar, save it to ``fname`` and show it."""
    fig, ax = plt.subplots()
    im = ax.imshow(component, cmap='gray')
    fig.colorbar(im, fraction=0.046, pad=0.02)
    scalebar = ScaleBar(voxelsize, "m", length_fraction=0.4, height_fraction=0.01,
                        font_properties={"family": "serif", "size": 18},
                        location="lower right")
    ax.add_artist(scalebar)
    fig.savefig(fname, dpi=300, bbox_inches='tight')
    plt.show()


plot_probe_component(np.angle(prb), 'prb_angle.png')
plot_probe_component(np.abs(prb), 'prb_amp.png')

In [None]:
data = np.load(f'{path}/data/data.npy')
plt.rcParams['font.size'] = 24

# Two measured frames (index 0 and 8) on a shared colour scale.
for frame_index, png_name in ((0, 'data0.png'), (8, 'data1.png')):
    fig, ax = plt.subplots()
    im = ax.imshow(data[frame_index]*1024, cmap='gray', vmax=4100, vmin=300)
    fig.colorbar(im, fraction=0.046, pad=0.02)
    # voxelsize*magnification gives back the detector pixel size
    scalebar = ScaleBar(voxelsize*magnification, "m",
                        length_fraction=0.2, height_fraction=0.02,
                        font_properties={"family": "serif", "size": 22},
                        location="lower right")
    ax.plot(512, 512, 'rx', markersize=10)
    ax.add_artist(scalebar)
    fig.savefig(png_name, dpi=300, bbox_inches='tight')

In [None]:
# Correct scan positions and their misaligned counterparts (first npos of each).
shifts, shifts_random = (np.load(f'{path}/data/{stem}.npy')[:npos]
                         for stem in ('gen_shifts', 'gen_shifts_random'))



In [None]:
# Per-position distance between the correct and the misaligned positions.
pos_dist = np.linalg.norm(shifts-shifts_random, axis=1)
std_dev = np.std(pos_dist)

# set before the figure is created so all of its text uses this size
plt.rcParams['font.size'] = 28
fig, ax = plt.subplots()
ax.plot(shifts[:, 0], shifts[:, 1], '.', color='red', label='correct', markersize=10)
ax.plot(shifts_random[:, 0], shifts_random[:, 1], '.', color='blue', label='misaligned', markersize=10)
ax.text(-16, -7, 'error', fontsize=26)
ax.text(-16, -12, f'STD = {std_dev:.2e}', fontsize=26)
ax.axis('square')
ax.legend(loc=(0.2, 0.45), fontsize=26)
ax.grid()
fig.savefig('positions.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import dxchange
import scipy.ndimage as ndimage
matplotlib.use('Agg')  # figures in this cell are only saved to disk
vvmax = 0.1
vvmin = -0.8

obj_opt = True
prb_opt = True
pos_opt = True

# (method, iteration) runs; None -> 512 with noise, 4096 without.
# The first DY-LS / BH-CG entries show the early iterate 128+32+32.
pos_runs = [('epie', None), ('lsqml', None), ('BH-GD', None),
            ('DY-LS', 128+32+32), ('DY-LS', None),
            ('BH-CG', 128+32+32), ('BH-CG', None)]

# Inset windows in the 64-pixel-cropped object: centre (st1:end1 in both
# axes) and a second region (rows st2:end2, columns st22:end22).
nobj = 1024
ni = 100
st1 = nobj//2-ni//2
end1 = nobj//2+ni//2
st2 = 190-ni//2+30
end2 = 190+ni//2+30
st22 = 340-ni//2
end22 = 340+ni//2

# Correct and misaligned starting positions, loaded once for all runs.
shifts = np.load(f'{path}/data/gen_shifts.npy')[:npos]
shifts_random = np.load(f'{path}/data/gen_shifts_random.npy')[:npos]

for noise in [False, True]:
    for method, i_run in pos_runs:
        if i_run is not None:
            i = i_run
        elif noise:
            i = 512
        else:
            i = 4096
        flg = f'{method}_{obj_opt}_{prb_opt}_{pos_opt}_{noise}'
        rec_dir = f'{path_out}_{flg}'

        if method in ('epie', 'lsqml'):
            psi_angle = dxchange.read_tiff(f'{rec_dir}/crec_psi_angle/0.tiff')[:]
            psi_abs = dxchange.read_tiff(f'{rec_dir}/crec_psi_abs/0.tiff')[:]
            if pos_opt:
                shifts_rec = np.load(f'{rec_dir}/crec_shift_0.npy')
            else:
                shifts_rec = shifts
        else:
            print(rec_dir)
            psi_angle = dxchange.read_tiff(f'{rec_dir}/crec_psi_angle/{i:03}.tiff')[:]
            psi_abs = dxchange.read_tiff(f'{rec_dir}/crec_psi_abs/{i:03}.tiff')[:]
            # the saved shifts are added to the rounded starting positions
            ishift = np.round(shifts_random if pos_opt else shifts).astype('int32')
            shifts_rec = ishift+np.load(f'{rec_dir}/crec_shift_{i:03}.npy')

        # Subtract the mean phase of the top-left 128x128 corner. The offset
        # must sit inside the 1j*(...) factor: outside it only rescales the
        # amplitude and leaves the displayed phase unchanged.
        psi = psi_abs*np.exp(1j*(psi_angle-np.mean(psi_angle[0:128, 0:128])))
        psi = psi[64:-64, 64:-64]

        # The colour range is fixed through vmin/vmax; psi is not overwritten.
        fig, ax = plt.subplots()
        ax.imshow(np.angle(ndimage.shift(psi, (0, -40), mode='reflect')),
                  cmap='gray', vmin=vvmin, vmax=vvmax)
        ax.set_xticks([])
        ax.set_yticks([])

        # Centre inset, created on this figure's axes.
        # NOTE(review): placed bottom-right to mirror ins1 -- adjust if needed.
        ins = ax.inset_axes([0.65, 0, 0.35, 0.35])
        ins.set_xticks([])
        ins.set_yticks([])
        ins.imshow(np.angle(psi[st1:end1, st1:end1]), cmap='gray', vmax=vvmax, vmin=vvmin)

        ins1 = ax.inset_axes([0.65, 0.65, 0.35, 0.35])
        ins1.set_xticks([])
        ins1.set_yticks([])
        ins1.imshow(np.angle(psi[st2:end2, st22:end22]), cmap='gray', vmax=vvmax, vmin=vvmin)

        scalebar = ScaleBar(voxelsize, "m", length_fraction=0.4, height_fraction=0.03,
                            font_properties={"family": "serif", "size": 12},
                            location="lower right")
        ins.add_artist(scalebar)
        fig.savefig(f'siemens_rec_{flg}.png', dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Position error: distance between correct and refined positions.
        pos_dist = np.linalg.norm(shifts-shifts_rec, axis=1)
        std_dev = np.std(pos_dist)
        mmean = np.mean(shifts-shifts_rec, axis=0)
        plt.rcParams['font.size'] = 28  # set before creating the figure
        fig, ax = plt.subplots()
        ax.plot(shifts[:, 0], shifts[:, 1], '.', color='red', label='correct', markersize=10)
        ax.plot(shifts_rec[:, 0], shifts_rec[:, 1], '.', color='blue', label='refined', markersize=10)
        ax.text(-16, -5, 'error', fontsize=26)
        ax.text(-16, -10, f'STD = {std_dev:.2e}', fontsize=26)
        ax.text(-16, -15, f'OFFSET = ({mmean[0]:.1f}, {mmean[1]:.1f})', fontsize=26)
        ax.axis('square')
        ax.legend(loc=(0.2, 0.45), fontsize=26)
        ax.grid()
        fig.savefig(f'positions_rec_{flg}_{i}.png', dpi=300, bbox_inches='tight')
        plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
obj_opt = True
prb_opt = True
pos_opt = True
# (method, iteration) runs; None -> 512 with noise, 4096 without.
pos_runs = [('epie', None), ('lsqml', None), ('BH-GD', None),
            ('DY-LS', 128+32+32), ('DY-LS', None),
            ('BH-CG', 128+32+32), ('BH-CG', None)]

# 2 x 7 grid: one row per noise setting, one column per run.
fig, axes = plt.subplots(2, 7, figsize=(16, 5))
for row, noise in enumerate([False, True]):
    for col, (method, i_run) in enumerate(pos_runs):
        if i_run is not None:
            i = i_run
        elif noise:
            i = 512
        else:
            i = 4096
        flg = f'{method}_True_{prb_opt}_{pos_opt}_{noise}'
        ax = axes[row, col]
        # own variable for the file name so the global data directory `path` survives
        png_path = f'positions_rec_{flg}_{i}.png'
        with Image.open(png_path) as img:
            ax.imshow(img)
        ax.set_title(f'{method} noise={noise} niter={i}', fontsize=10)
        ax.axis("off")

plt.tight_layout()
plt.savefig('rec_positions_all.png', dpi=600)
plt.show()



In [None]:
def show_siemens(psi,vvmin=None,vvmax=None):
    """Show a cropped map with an MSE label, two zoomed insets and a scale bar.

    Parameters
    ----------
    psi : 2-D real array
        Map to display; 64 pixels are cropped from every border (``nobj`` =
        896 assumes a 1024x1024 input -- TODO confirm for other sizes).
        The caller's array is not modified.
    vvmin, vvmax : float or None
        Colour limits; None lets matplotlib autoscale and the insets reuse it.

    Creates a new figure and leaves it current so the caller can
    ``plt.savefig`` it. Reads the global ``voxelsize`` for the scale bar.
    """
    print(psi.shape)
    psi = psi[64:-64,64:-64]
    nobj = 1024-1024//8
    ni = 100  # inset window size in pixels
    # centre window (same range in both axes)
    st1 = nobj//2-ni//2
    end1 = nobj//2+ni//2
    # second window: rows st2:end2, columns st22:end22 (crop-adjusted)
    st2 = 190-ni//2+30-64
    end2 = 190+ni//2+30-64
    st22 = 340-ni//2-64
    end22 = 340+ni//2-64

    fig, ax = plt.subplots()
    # the colour range is fixed through vmin/vmax, so no pixels are overwritten
    im = ax.imshow(psi,cmap='gray',vmax=vvmax,vmin=vvmin)
    # mean squared value over the displayed (cropped) map
    nerr = np.mean(np.abs(psi)**2)
    ax.text(15,864,f'MSE = {nerr:.2e}',bbox=dict(facecolor='white', alpha=1))
    fig.colorbar(im, fraction=0.046, pad=0.02)
    vvmin, vvmax = im.get_clim()  # resolve None limits so the insets match

    ax.set_xticks([])
    ax.set_yticks([])

    ins = ax.inset_axes([0.6,0,0.4,0.4])
    ins.set_xticks([])
    ins.set_yticks([])
    ins.imshow(psi[st1:end1,st1:end1],cmap='gray',vmax=vvmax,vmin=vvmin)

    ins1 = ax.inset_axes([0.6,0.6,0.4,0.4])
    ins1.set_xticks([])
    ins1.set_yticks([])
    ins1.imshow(psi[st2:end2,st22:end22],cmap='gray',vmax=vvmax,vmin=vvmin)

    scalebar = ScaleBar(voxelsize, "m", length_fraction=0.4,height_fraction=0.03,font_properties={
                "family": "serif","size":16,
            },
            location="lower right")
    ins.add_artist(scalebar)


In [None]:
from holotomocupy.utils import *
matplotlib.use('Agg')  # non-interactive backend for the batch of saved figures
plt.rcParams.update({'font.size': 20})

def S(psi,p):
    """Shift a 2-D array (or a stack over the last two axes) by ``p`` pixels.

    Uses the Fourier shift theorem, so sub-pixel shifts are allowed and the
    content wraps around periodically. ``p`` is an array ``(row_shift,
    col_shift)``; a positive shift moves content towards higher indices, like
    ``np.roll``. Rectangular inputs are supported because frequencies are
    built per axis.

    Returns the complex shifted array.
    """
    ny, nx = psi.shape[-2:]
    fy = np.fft.fftfreq(ny).astype('float32')[:, None]  # row frequencies
    fx = np.fft.fftfreq(nx).astype('float32')[None, :]  # column frequencies
    pp = np.exp(-2*np.pi*1j * (fx*p[1, None, None]+fy*p[0, None, None])).astype('complex64')
    res = np.fft.ifft2(pp*np.fft.fft2(psi))
    return res

def adjust(psi,psi_gt,shifts,shifts_gt):
    """Align a reconstruction to the ground truth and return the phase error map.

    Steps:
    1. Crop 64 pixels from every border of both inputs.
    2. Shift ``psi`` (via ``S``) by the mean position offset ``mean(shifts - shifts_gt)``.
    3. Remove the constant phase that best matches ``psi`` to ``psi_gt``.
    4. Replace the lowest spatial frequencies (first ``s`` columns and first/last
       ``s`` rows of the rFFT) of the phase with those of the ground truth, so
       that only higher-frequency differences remain.

    Returns ``angle(psi_aligned) - angle(psi_gt)`` after that substitution.
    """
    print(np.mean(shifts-shifts_gt,axis=0))
    psi = psi[64:-64,64:-64]
    psi_gt = psi_gt[64:-64,64:-64]
    psi = S(psi,np.mean(shifts-shifts_gt,axis=0))

    # constant phase minimising ||psi*exp(-1j*offset) - psi_gt||
    offset = np.angle(np.sum(psi*np.conj(psi_gt)))
    psi*=np.exp(-1j*offset)

    apsi = np.angle(psi)
    apsi_gt = np.angle(psi_gt)
    fpsi = np.fft.rfft2(apsi)
    fpsi_gt = np.fft.rfft2(apsi_gt)
    s=32
    fpsi[:s,:s]=fpsi_gt[:s,:s]
    fpsi[-s:,:s]=fpsi_gt[-s:,:s]
    # pass the shape explicitly: by default irfft2 returns an even width,
    # i.e. one column short for odd-width inputs
    apsi = np.fft.irfft2(fpsi, s=apsi.shape)

    err= apsi-apsi_gt

    return err

path = '/data/vnikitin/paper/near_field'
psi_gt = np.load(f'{path}/data/psi.npy')
# Correct and misaligned starting positions, loaded once for all runs.
shifts = np.load(f'{path}/data/gen_shifts.npy')[:npos]
shifts_random = np.load(f'{path}/data/gen_shifts_random.npy')[:npos]

# One entry per configuration: probe / position refinement, noise, default
# iteration count and the symmetric colour limit of the error map.
prb_opta = [True, True, True, True, False, False]
pos_opta = [True, True, False, False, False, False]
noisea = [False, True, True, False, True, False]
nitera = [4096, 512, 512, 4096, 512, 512]
va = [0.15, .15, .15, .15, 0.15, 0.001]

# (method, iteration) runs; None -> the configuration's nitera entry.
err_runs = [('epie', None), ('lsqml', None), ('BH-GD', None),
            ('DY-LS', 256), ('DY-LS', 512),
            ('BH-CG', 128+32+16), ('BH-CG', 512)]

for prb_opt, pos_opt, noise, niter, v in zip(prb_opta, pos_opta, noisea, nitera, va):
    vvmin = -v
    vvmax = v
    for method, i_run in err_runs:
        i = niter if i_run is None else i_run
        flg = f'{method}_True_{prb_opt}_{pos_opt}_{noise}'
        rec_dir = f'{path_out}_{flg}'

        if method in ('epie', 'lsqml'):
            psi_angle = dxchange.read_tiff(f'{rec_dir}/crec_psi_angle/0.tiff')[:]
            psi_abs = dxchange.read_tiff(f'{rec_dir}/crec_psi_abs/0.tiff')[:]
            if pos_opt:
                shifts_rec = np.load(f'{rec_dir}/crec_shift_0.npy')
            else:
                shifts_rec = shifts
        else:
            print(f'{rec_dir}/crec_psi_angle/{i:03}.tiff')
            psi_angle = dxchange.read_tiff(f'{rec_dir}/crec_psi_angle/{i:03}.tiff')[:]
            psi_abs = dxchange.read_tiff(f'{rec_dir}/crec_psi_abs/{i:03}.tiff')[:]
            # the saved shifts are added to the rounded starting positions
            ishift = np.round(shifts_random if pos_opt else shifts).astype('int32')
            shifts_rec = ishift+np.load(f'{rec_dir}/crec_shift_{i:03}.npy')

        psi = psi_abs*np.exp(1j*psi_angle)
        err = adjust(psi, psi_gt, shifts_rec, shifts)
        show_siemens(err, vvmin=vvmin, vvmax=vvmax)
        plt.savefig(f'err{flg}_{i}.png', dpi=300, bbox_inches='tight')
        print(f'err{flg}_{i}.png')
        plt.show()
        plt.close()
        

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline
prb_opta = [True, True, True, True, False, False]
pos_opta = [True, True, False, False, False, False]
noisea = [False, True, True, False, True, False]
nitera = [4096, 512, 512, 4096, 512, 512]
# (method, iteration) runs; None -> the configuration's nitera entry.
err_runs = [('epie', None), ('lsqml', None), ('BH-GD', None),
            ('DY-LS', 256), ('DY-LS', 512),
            ('BH-CG', 128+32+16), ('BH-CG', 512)]

# One 1x7 panel of saved error maps per configuration.
for prb_opt, pos_opt, noise, niter in zip(prb_opta, pos_opta, noisea, nitera):
    fig, axes = plt.subplots(1, 7, figsize=(17, 3))
    for ax, (method, i_run) in zip(axes, err_runs):
        i = niter if i_run is None else i_run
        flg = f'{method}_True_{prb_opt}_{pos_opt}_{noise}'
        with Image.open(f'err{flg}_{i}.png') as img:
            ax.imshow(img)
        ax.set_title(f'{method}, {i} iters', fontsize=10)
        ax.axis("off")
    flg = f'True_{prb_opt}_{pos_opt}_{noise}'
    fig.suptitle(f'{prb_opt=} {pos_opt=} {noise=}', x=0.5, y=0.9, fontsize=12)
    fig.savefig(f'err_all_{flg}.png', dpi=600)
    plt.show()



In [None]:
def show_siemens(psi,vvmin=None,vvmax=None):
    """Show a cropped map with two zoomed insets and a scale bar.

    Parameters
    ----------
    psi : 2-D real array
        Map to display; 64 pixels are cropped from every border (``nobj`` =
        896 assumes a 1024x1024 input -- TODO confirm for other sizes).
        The caller's array is not modified.
    vvmin, vvmax : float or None
        Colour limits; None lets matplotlib autoscale and the insets reuse it.

    Creates a new figure and leaves it current so the caller can
    ``plt.savefig`` it. Reads the global ``voxelsize`` for the scale bar.
    """
    psi = psi[64:-64,64:-64]
    nobj = 1024-1024//8
    ni = 100  # inset window size in pixels
    # centre window (same range in both axes)
    st1 = nobj//2-ni//2
    end1 = nobj//2+ni//2
    # second window: rows st2:end2, columns st22:end22 (crop-adjusted)
    st2 = 190-ni//2+30-64
    end2 = 190+ni//2+30-64
    st22 = 340-ni//2-64
    end22 = 340+ni//2-64

    fig, ax = plt.subplots()
    # the colour range is fixed through vmin/vmax, so no pixels are overwritten
    im = ax.imshow(psi,cmap='gray',vmax=vvmax,vmin=vvmin)
    fig.colorbar(im, fraction=0.046, pad=0.02)
    vvmin, vvmax = im.get_clim()  # resolve None limits so the insets match

    ax.set_xticks([])
    ax.set_yticks([])

    ins = ax.inset_axes([0.6,0,0.4,0.4])
    ins.set_xticks([])
    ins.set_yticks([])
    ins.imshow(psi[st1:end1,st1:end1],cmap='gray',vmax=vvmax,vmin=vvmin)

    ins1 = ax.inset_axes([0.6,0.6,0.4,0.4])
    ins1.set_xticks([])
    ins1.set_yticks([])
    ins1.imshow(psi[st2:end2,st22:end22],cmap='gray',vmax=vvmax,vmin=vvmin)

    scalebar = ScaleBar(voxelsize, "m", length_fraction=0.4,height_fraction=0.03,font_properties={
                "family": "serif","size":16,
            },
            location="lower right")
    ins.add_artist(scalebar)


In [None]:
from holotomocupy.utils import *
matplotlib.use('Agg')  # non-interactive backend for the batch of saved figures
plt.rcParams.update({'font.size': 20})
def S(psi,p):
    """Shift a 2-D array (or a stack over the last two axes) by ``p`` pixels.

    Uses the Fourier shift theorem, so sub-pixel shifts are allowed and the
    content wraps around periodically. ``p`` is an array ``(row_shift,
    col_shift)``; a positive shift moves content towards higher indices, like
    ``np.roll``. Rectangular inputs are supported because frequencies are
    built per axis.

    Returns the complex shifted array.
    """
    ny, nx = psi.shape[-2:]
    fy = np.fft.fftfreq(ny).astype('float32')[:, None]  # row frequencies
    fx = np.fft.fftfreq(nx).astype('float32')[None, :]  # column frequencies
    pp = np.exp(-2*np.pi*1j * (fx*p[1, None, None]+fy*p[0, None, None])).astype('complex64')
    res = np.fft.ifft2(pp*np.fft.fft2(psi))
    return res

def adjust(psi,psi_gt,shifts,shifts_gt):
    """Align a reconstruction to the ground truth and return its phase map.

    Steps:
    1. Crop 64 pixels from every border of both inputs.
    2. Shift ``psi`` (via ``S``) by the mean position offset ``mean(shifts - shifts_gt)``.
    3. Remove the constant phase that best matches ``psi`` to ``psi_gt``.
    4. Replace the lowest spatial frequencies (first ``s`` columns and first/last
       ``s`` rows of the rFFT) of the phase with those of the ground truth.

    Returns the resulting real phase map (not the difference to ``psi_gt``).
    """
    print(np.mean(shifts-shifts_gt,axis=0))
    psi = psi[64:-64,64:-64]
    psi_gt = psi_gt[64:-64,64:-64]
    psi = S(psi,np.mean(shifts-shifts_gt,axis=0))

    # constant phase minimising ||psi*exp(-1j*offset) - psi_gt||
    offset = np.angle(np.sum(psi*np.conj(psi_gt)))
    psi*=np.exp(-1j*offset)

    apsi = np.angle(psi)
    apsi_gt = np.angle(psi_gt)
    fpsi = np.fft.rfft2(apsi)
    fpsi_gt = np.fft.rfft2(apsi_gt)
    s=32
    fpsi[:s,:s]=fpsi_gt[:s,:s]
    fpsi[-s:,:s]=fpsi_gt[-s:,:s]
    # pass the shape explicitly: by default irfft2 returns an even width,
    # i.e. one column short for odd-width inputs
    apsi = np.fft.irfft2(fpsi, s=apsi.shape)
    return apsi



path = '/data/vnikitin/paper/near_field'
vvmax = 0.1
vvmin = -0.8
psi_gt = np.load(f'{path}/data/psi.npy')
# Correct and misaligned starting positions, loaded once for all runs.
shifts = np.load(f'{path}/data/gen_shifts.npy')[:npos]
shifts_random = np.load(f'{path}/data/gen_shifts_random.npy')[:npos]

prb_opta = [True, True, True, True, False, False]
pos_opta = [True, True, False, False, False, False]
noisea = [False, True, True, False, True, False]
nitera = [4096, 512, 512, 4096, 512, 512]

# (method, iteration) runs; None -> the configuration's nitera entry.
# The first DY-LS / BH-CG entries show the early iterate 128+32+32.
rec_runs = [('epie', None), ('lsqml', None), ('BH-GD', None),
            ('DY-LS', 128+32+32), ('DY-LS', None),
            ('BH-CG', 128+32+32), ('BH-CG', None)]

for prb_opt, pos_opt, noise, niter in zip(prb_opta, pos_opta, noisea, nitera):
    for method, i_run in rec_runs:
        i = niter if i_run is None else i_run
        flg = f'{method}_True_{prb_opt}_{pos_opt}_{noise}'
        rec_dir = f'{path_out}_{flg}'

        if method in ('epie', 'lsqml'):
            psi_angle = dxchange.read_tiff(f'{rec_dir}/crec_psi_angle/0.tiff')[:]
            psi_abs = dxchange.read_tiff(f'{rec_dir}/crec_psi_abs/0.tiff')[:]
            if pos_opt:
                shifts_rec = np.load(f'{rec_dir}/crec_shift_0.npy')
            else:
                shifts_rec = shifts
        else:
            print(f'{rec_dir}/crec_psi_angle/{i:03}.tiff')
            psi_angle = dxchange.read_tiff(f'{rec_dir}/crec_psi_angle/{i:03}.tiff')[:]
            psi_abs = dxchange.read_tiff(f'{rec_dir}/crec_psi_abs/{i:03}.tiff')[:]
            # the saved shifts are added to the rounded starting positions
            ishift = np.round(shifts_random if pos_opt else shifts).astype('int32')
            shifts_rec = ishift+np.load(f'{rec_dir}/crec_shift_{i:03}.npy')

        psi = psi_abs*np.exp(1j*psi_angle)
        psi = adjust(psi, psi_gt, shifts_rec, shifts)
        show_siemens(psi, vvmin=vvmin, vvmax=vvmax)
        plt.savefig(f'rec{flg}_{i}.png', dpi=300, bbox_inches='tight')
        print(f'rec{flg}_{i}.png')
        plt.show()
        plt.close()
        

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
prb_opta = [True, True, True, True, False, False]
pos_opta = [True, True, False, False, False, False]
noisea = [False, True, True, False, True, False]
nitera = [4096, 512, 512, 4096, 512, 512]
# (method, iteration) runs; None -> the configuration's nitera entry.
rec_runs = [('epie', None), ('lsqml', None), ('BH-GD', None),
            ('DY-LS', 128+32+32), ('DY-LS', None),
            ('BH-CG', 128+32+32), ('BH-CG', None)]

# One 1x7 panel of saved reconstruction maps per configuration.
for prb_opt, pos_opt, noise, niter in zip(prb_opta, pos_opta, noisea, nitera):
    fig, axes = plt.subplots(1, 7, figsize=(17, 3))
    for ax, (method, i_run) in zip(axes, rec_runs):
        i = niter if i_run is None else i_run
        flg = f'{method}_True_{prb_opt}_{pos_opt}_{noise}'
        with Image.open(f'rec{flg}_{i}.png') as img:
            ax.imshow(img)
        ax.set_title(f'{method}, {i} iters', fontsize=10)
        ax.axis("off")
    flg = f'True_{prb_opt}_{pos_opt}_{noise}'
    fig.suptitle(f'{prb_opt=} {pos_opt=} {noise=}', x=0.5, y=0.9, fontsize=12)
    fig.savefig(f'rec_all_{flg}.png', dpi=600)
    plt.show()

