# Initialize

## Import libraries

In [None]:
import sys
import time
import os
from os import path
from os.path import join
from importlib import reload
import logging
logging.basicConfig(level=logging.INFO)
from tqdm import tqdm

import numpy as np
import h5py

import ipywidgets
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.image as mpimg

import scipy as scp
import scipy.constants as con
from scipy import stats
from scipy.ndimage import gaussian_filter

from PIL import Image, ImageOps
from skimage.draw import (line, polygon, disk,
                          circle_perimeter,
                          ellipse, ellipse_perimeter,
                          bezier_curve, ellipsoid, rectangle)

# Self-written libraries
sys.path.append(join(os.getcwd(), "library"))
import mask_lib
import fthcore as fth
import helper_functions as helper
import interactive
from interactive import cimshow
import reconstruct_rb as rec

In [None]:
# Is there a GPU?
# Try the CuPy-backed stack first and fall back to the NumPy implementation of
# the CCI core if any part of it cannot be imported.
# NOTE(review): `PhR` (used by phase_retrieval) is only imported on the GPU path.
try:
    # Cupy
    import cupy as cp
    import cupyx as cpx

    # Self-written library
    import CCI_core_cupy as cci
    import Phase_Retrieval as PhR
except Exception as err:
    # Catch Exception rather than everything so Ctrl-C / SystemExit still work;
    # a broken CUDA setup may surface as errors other than ImportError.
    GPU = False
    import CCI_core as cci

    print(f"GPU unavailable ({type(err).__name__}: {err})")
else:
    GPU = True
    print("GPU available")

In [None]:
# Is there a GPU?
# Same detection as the cell above, without printing.
try:
    import cupy as cp
    import cupyx as cpx
    import CCI_core_cupy as cci
    import Phase_Retrieval as PhR
except Exception:
    # No usable CuPy/CUDA stack: fall back to the NumPy implementation.
    # Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate.
    import CCI_core as cci
    GPU = False
else:
    GPU = True

In [None]:
# interactive plotting
import ipywidgets

%matplotlib widget

# Auto formatting of cells (uncomment the next line to enable jupyter_black)
#%load_ext jupyter_black
# Enable constrained layout for every figure created from now on, so subplots
# and titles do not overlap without calling plt.tight_layout() per figure.
plt.rcParams["figure.constrained_layout.use"] = True  # replaces plt.tight_layout

## Experiment specific functions

In [None]:
# CDI reconstruction parameters, defined once here; they are handed to
# fth.propagate (prop_dist_cdi * 1e-6) and rec.focusCDI further below.
phase_cdi = prop_dist_cdi = 0
dx = dy = 0


def phase_retrieval(
    pos, mask_pixel, supportmask, vmin=0, Startimage=None, Startgamma=None
):
    """Run CDI phase retrieval followed by a partial-coherence refinement.

    Workflow:

    1. Subtract the ``vmin``-th percentile of the non-zero intensities of
       ``pos`` and clip negative values to zero.
    2. Extend ``mask_pixel`` by every pixel that is <= 0 after step 1.
    3. Linearly rescale the start image so that its magnitude matches the
       measured amplitudes on the pixels where ``mask_pixel == 0``.
    4. Run ``PhR.PhaseRtrv_GPU`` twice (``beta_mode="arctan"``, then
       ``"const"``).
    5. Replace the masked pixels by the retrieved intensity and run
       ``PhR.PhaseRtrv_with_RL`` twice; these passes also return an updated
       ``gamma``.

    Parameters
    ----------
    pos : np.ndarray
        2D measured intensity (hologram / diffraction pattern). Not modified.
    mask_pixel : np.ndarray
        2D mask of pixels to exclude (1 = excluded). Not modified.
    supportmask : np.ndarray
        2D real-space support passed to the retrieval routines.
    vmin : float, optional
        Percentile (0-100) of the non-zero intensities that is subtracted as
        a background offset. ``0`` subtracts the smallest non-zero value.
    Startimage : np.ndarray, optional
        Initial guess. Defaults to the centered inverse FFT of
        ``supportmask``. A copy is used; the argument is not modified.
    Startgamma : np.ndarray, optional
        Initial ``gamma`` for the partial-coherence passes. Defaults to
        ``2e-6`` everywhere and ``0.7`` at the central pixel. A copy is used.

    Returns
    -------
    tuple
        ``(retrieved_p, retrieved_p_pc, bsmask_p, gamma_p)``: the result of
        the coherent passes, the result of the partial-coherence passes, the
        mask that was actually used, and the final ``gamma``.
        ``retrieved_p_pc`` and ``gamma_p`` are ``None`` if the
        partial-coherence stage is disabled.

    Raises
    ------
    ValueError
        If ``pos`` has no non-zero pixel, so no offset can be estimated.
    """
    # --- Prepare input intensity --------------------------------------------
    pos2 = pos.copy()
    nonzero = pos2[pos2 != 0]
    if nonzero.size == 0:
        raise ValueError("phase_retrieval: `pos` has no non-zero pixels.")

    # Background offset: `vmin`-th percentile of the non-zero pixels
    mi = np.percentile(nonzero, vmin)
    pos2 = pos2 - mi
    pos2[pos2 < 0] = 0
    # Work in complex dtype; the amplitudes handed to PhR are derived from it
    pos2 = pos2.astype(complex)

    # Exclude the given pixels plus everything without signal after the offset
    bsmask_p = mask_pixel.copy()
    bsmask_p[pos2 <= 0] = 1

    # --- Start image and start gamma ----------------------------------------
    if Startimage is None:
        Startimage = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(supportmask)))
    else:
        Startimage = Startimage.copy()
    if Startgamma is None:
        Startgamma = np.ones(pos.shape) * 1e-6 * 2
        Startgamma[pos.shape[0] // 2, pos.shape[1] // 2] = 0.7
    else:
        Startgamma = Startgamma.copy()

    # Settings for phase retrieval reconstructions
    partial_coherence = True

    # Global side effect: DPI of all matplotlib figures created from now on
    plt.rcParams["figure.dpi"] = 100
    print("CDI - larger mask")

    # Algorithm schedule, consumed in groups of three per round; only the
    # first two entries of each group are used below.
    algorithm_list = ["mine", "ER", "ER"]
    Nit_list = [700, 50, 50]  # iterations for algorithm_list

    # Fit |Startimage| = slope * measured_amplitude + intercept on the unmasked
    # pixels and invert that relation to bring the start image to data scale.
    x = (np.sqrt(np.maximum(pos2, np.zeros(pos2.shape)))[mask_pixel == 0]).flatten()
    y = ((np.abs(Startimage))[mask_pixel == 0]).flatten()
    res = stats.linregress(x, y)
    Startimage -= res.intercept
    Startimage /= res.slope

    average_img = 30
    real_object = False  # always set to False

    # Partial-coherence settings (passed straight to PhR.PhaseRtrv_with_RL)
    RL_freq = 20
    RL_it = 50
    algorithm_list_pc = ["mine", "ER", "ER"]
    Nit_list_pc = [1000, 200, 200]

    # Results of the partial-coherence stage stay None if it is disabled
    retrieved_p_pc = None
    gamma_p = None

    # --- Execute phase retrieval --------------------------------------------
    start_time = time.time()
    for i in range(len(Nit_list) // 3):
        print("############ -   CDI")

        # Positive helicity - beta_mode="arctan"
        retrieved_p, _, _ = PhR.PhaseRtrv_GPU(
            diffract=np.sqrt(np.maximum(pos2, np.zeros(pos2.shape))),
            mask=supportmask,
            mode=algorithm_list[3 * i],
            beta_zero=0.5,
            Nit=Nit_list[3 * i],
            beta_mode="arctan",
            plot_every=349,
            Phase=Startimage,
            seed=False,
            real_object=real_object,
            bsmask=bsmask_p,
            average_img=average_img,
            Fourier_last=True,
        )

        # Positive helicity - beta_mode="const", continuing from the result
        retrieved_p, _, _ = PhR.PhaseRtrv_GPU(
            diffract=np.sqrt(np.maximum(pos2, np.zeros(pos2.shape))),
            mask=supportmask,
            mode=algorithm_list[3 * i + 1],
            beta_zero=0.5,
            Nit=Nit_list[3 * i + 1],
            beta_mode="const",
            plot_every=24,
            Phase=retrieved_p,
            seed=False,
            real_object=real_object,
            bsmask=bsmask_p,
            average_img=average_img,
            Fourier_last=True,
        )

        print("--- %s seconds ---" % np.round((time.time() - start_time), 2))

        Startimage = retrieved_p.copy()

        # Partial coherence phase retrieval
        if partial_coherence:
            print("############   -   CDI_pc")
            # Masked pixels take the retrieved intensity, all others the data
            pos3 = (np.abs(retrieved_p) ** 2) * bsmask_p + np.maximum(
                pos2, np.zeros(pos2.shape)
            ) * (1 - bsmask_p)

            # First partial-coherence pass, starting from the coherent result
            retrieved_p_pc, _, _, gamma_p = PhR.PhaseRtrv_with_RL(
                diffract=np.sqrt(pos3),
                mask=supportmask,
                mode=algorithm_list_pc[3 * i],
                beta_zero=0.5,
                Nit=Nit_list_pc[3 * i],
                beta_mode="arctan",
                gamma=Startgamma,
                RL_freq=RL_freq,
                RL_it=RL_it,
                plot_every=349,
                Phase=Startimage,
                seed=False,
                real_object=False,
                bsmask=np.zeros(bsmask_p.shape),
                average_img=average_img,
                Fourier_last=True,
            )

            # Second partial-coherence pass; uses the *_pc schedule like the first
            # NOTE(review): `seed` is not passed here, unlike the other calls.
            retrieved_p_pc, _, _, gamma_p = PhR.PhaseRtrv_with_RL(
                diffract=np.sqrt(pos3),
                mask=supportmask,
                mode=algorithm_list_pc[3 * i + 1],
                beta_zero=0.5,
                Nit=Nit_list_pc[3 * i + 1],
                beta_mode="const",
                gamma=gamma_p,
                RL_freq=RL_freq,
                RL_it=RL_it,
                plot_every=24,
                Phase=retrieved_p_pc,
                real_object=False,
                bsmask=np.zeros(bsmask_p.shape),
                average_img=average_img,
                Fourier_last=True,
            )

            print("--- %s seconds ---" % np.round((time.time() - start_time), 2))

            Startimage = retrieved_p_pc.copy()
            Startgamma = gamma_p.copy()

    print("Phase Retrieval Done!")

    return (
        retrieved_p,
        retrieved_p_pc,
        bsmask_p,
        gamma_p,
    )

In [None]:
def gaussian_2d(shape, mean_x, mean_y, sigma_x, sigma_y, amp, normalize=True):
    """Return a 2D Gaussian sampled on a pixel grid of the given shape.

    Parameters
    ----------
    shape : tuple of int
        ``(M, N)`` size of the output array (rows, columns).
    mean_x : float
        Center along x, i.e. the column index (axis 1).
    mean_y : float
        Center along y, i.e. the row index (axis 0).
    sigma_x, sigma_y : float
        Standard deviations along x (columns) and y (rows).
    amp : float
        Amplitude scaling. With ``normalize`` true this is the (continuous)
        integral of the Gaussian; otherwise it is its peak value.
    normalize : bool, optional
        If true, divide by ``2 * pi * sigma_x * sigma_y``.

    Returns
    -------
    np.ndarray
        ``(M, N)`` float array.
    """
    # Open grid: y varies along rows, x along columns; broadcasting in the
    # exponent produces the full (M, N) array.
    y = np.arange(shape[0], dtype=float)[:, np.newaxis]
    x = np.arange(shape[1], dtype=float)[np.newaxis, :]

    exponent = (x - mean_x) ** 2 / (2.0 * sigma_x**2) + (y - mean_y) ** 2 / (
        2.0 * sigma_y**2
    )
    gaussian = amp * np.exp(-exponent)
    if normalize:
        gaussian = gaussian / (2.0 * np.pi * sigma_x * sigma_y)
    return gaussian

def rgb2gray(rgb):
    """Convert an RGB(A) image to grayscale.

    Uses the luma weights 0.2989 R + 0.5870 G + 0.1140 B on the first three
    entries of the last axis; any further channel (e.g. alpha) is ignored.
    """
    luma_weights = np.array([0.2989, 0.5870, 0.1140])
    return rgb[..., :3] @ luma_weights

# Create real-space objects and holograms

In [None]:
from skimage.transform import resize

## Load the sample image (relative to the working directory) as grayscale
img = mpimg.imread(
    join(os.getcwd(), "paper_analysis_code", "sample_image.png")
)
image_background_original = rgb2gray(img)

# Optional: resample (this is a resize, not padding) to 1000 x 1000 px
# image_background_original = resize(image_background_original, [1000, 1000], anti_aliasing=True)

cimshow(image_background_original)

In [None]:
# Parameters for object generation
nr_objects = 100  # number of random object configurations generated in the object aperture

# Gaussian "particle" that moves inside the object aperture. It is built with
# gaussian_2d(..., normalize=False), so the amplitude is its peak value.
amplitude_object = 1  # peak value of the Gaussian particle
std_object = 2  # standard deviation of the Gaussian (px)

radi_ref = 3  # radius of the reference apertures (px)
radi_obj = 38  # radius of the object aperture (px)

amplitude_drift = 80  # full range (px) of the random particle displacement, i.e. +-amplitude_drift/2

In [None]:
# Create the FTH sample mask.
# Object aperture: filled disk of radius radi_obj at the array center whose
# edge is softened by a Gaussian filter (sigma = 1 px).
mask = np.zeros(image_background_original.shape)
rr, cc = disk(
    (mask.shape[0] // 2, mask.shape[1] // 2), radi_obj, shape=mask.shape
)
mask[rr, cc] = 1
mask = gaussian_filter(mask, 1)

# Reference apertures: disks of radius radi_ref + 1, given as
# (row, column) offsets from the array center.
mask_ref = np.zeros(image_background_original.shape)
for ref_offset in [(180, 150), (100, -150)]:
    rr, cc = disk(
        (mask.shape[0] // 2 + ref_offset[0], mask.shape[1] // 2 + ref_offset[1]),
        radi_ref + 1,
        shape=mask.shape,
    )
    mask_ref[rr, cc] = 1
# Optional smoothing of the references:
# mask_ref = gaussian_filter(mask_ref, 0.5)
mask_ref = mask_ref / np.max(mask_ref)

# Sample = aperture-cropped background image plus references; the combined
# mask is aperture plus references.
image_background = image_background_original * mask + mask_ref
mask = mask + mask_ref

# Fourier intensity of the mask (for plotting)
holo_mask = np.abs(cci.FFT(mask)) ** 2

cimshow(image_background)

In [None]:
# Load the sample image again and convert it to grayscale
img = mpimg.imread(join(os.getcwd(), "paper_analysis_code", "sample_image.png"))
image1 = rgb2gray(img)

## Define circular matrix (object aperture)
# Take the grid size from the real image shape so that non-square images
# work; sqrt(image1.size) only equals the side length for square images.
rows, cols = image1.shape
pixeln = int(np.sqrt(image1.size))  # only used to scale the phase ramp below
clip_radius = 40
row_vec = np.double(np.arange(0, rows))
col_vec = np.double(np.arange(0, cols))
# xx: row offset from the center (varies along axis 0)
# yy: column offset from the center (varies along axis 1)
yy, xx = np.meshgrid(cols // 2 - col_vec, rows // 2 - row_vec)
mask = 1 - ((xx) ** 2 + (yy) ** 2 > clip_radius**2)
mask = gaussian_filter(mask.astype(float), 1)

## Define holography holes (reference apertures). A hole sits where
## xx + location_rNx == 0 and yy + location_rNy == 0, i.e. at
## (center_row + location_rNx, center_col + location_rNy).
r_hole_radius1 = 2
r_hole_radius2 = 2
location_r1x = 180
location_r1y = 150
location_r2x = 100
location_r2y = -150
holography_hole1 = (
    1 - ((xx + location_r1x) ** 2 + (yy + location_r1y) ** 2 > r_hole_radius1**2)
) * 1
holography_hole2 = (
    1 - ((xx + location_r2x) ** 2 + (yy + location_r2y) ** 2 > r_hole_radius2**2)
) * 1
holography_hole = holography_hole1 + holography_hole2
holography_hole = gaussian_filter(holography_hole.astype(float), 1)
holography_hole = holography_hole / np.max(holography_hole)

## Add holography holes
image_background = image1 * mask + holography_hole

## Define phase: linear ramp along the diagonal, restricted to the aperture
phase1 = (100 * (xx + yy) / (pixeln * np.pi)) * mask

cimshow(image_background)

In [None]:
# Calc diffractions
# Accumulators. Except holo_avg (already divided by nr_objects), these hold
# sums over all configurations; divide by nr_objects for averages.
particle_avg = np.zeros(mask.shape)  # sum of all particles (real space)
obj_all = np.zeros(mask.shape, dtype="complex64")  # sum of all objects (incl. mask)

holo_avg = np.zeros(mask.shape)  # averaged hologram, i.e. squared abs value of the diffraction pattern
diffraction_fluctuation_avg = np.zeros(mask.shape, dtype="complex64")  # sum of particle-only amplitudes
intensity_fluctuation_avg = np.zeros(mask.shape)  # sum of particle-only intensities

# Put objects into mask
if nr_objects == 0:
    # No particles: the object is the static mask only
    obj = mask.copy()

    holo = np.abs(cci.FFT(obj)) ** 2
    holo_avg += holo
    obj_all += obj
    holo_mask = np.abs(cci.FFT(mask)) ** 2
else:
    for i in tqdm(range(nr_objects)):
        # Random particle center: uniform displacement in
        # [-amplitude_drift/2, amplitude_drift/2) px around the array center
        y = mask.shape[0] // 2 + (np.random.random() - 0.5) * amplitude_drift
        x = mask.shape[1] // 2 + (np.random.random() - 0.5) * amplitude_drift
        # gaussian_2d expects (mean_x, mean_y) = (column, row)
        particle = gaussian_2d(
            mask.shape, x, y, std_object, std_object, amplitude_object, normalize=False
        )
        particle = particle * mask

        # Object = static sample + particle
        obj = (image_background + particle).astype("complex64")

        # Real-space ensemble sums
        particle_avg += particle
        obj_all += obj

        # Centered FFT: fftshift(ifft2(ifftshift(.))) keeps the origin at the
        # array center for odd sizes as well
        diffraction = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(obj)))

        # Hologram, normalized by the number of configurations
        holo = np.abs(diffraction) ** 2
        holo_avg += holo / nr_objects

        # Scattering of the particle only (without mask)
        diffraction_fluctuation = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(particle)))
        diffraction_fluctuation_avg += diffraction_fluctuation
        intensity_fluctuation_avg += np.abs(diffraction_fluctuation) ** 2

# Reconstruct objects: Fourier transform of the last and of the averaged hologram
reco_obj = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(holo)))
reco_avg = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(holo_avg)))

# Stochastic contribution <|F|^2> - |<F>|^2 of the particle-only scattering;
# avoid dividing by zero when no configurations were generated.
n_configs = max(nr_objects, 1)
intensity_stochastic_numerical = (
    np.abs(intensity_fluctuation_avg / n_configs)
    - np.abs(diffraction_fluctuation_avg / n_configs) ** 2
)

In [None]:
# Overview: single configuration (top row), ensemble (middle row) and
# particle-only Fourier statistics (bottom row)
fig, ax = plt.subplots(3, 3, figsize=(9, 9), sharex=True, sharey=True)

ax[0, 0].set_title("Single fluctuation in real space", fontsize=8)
tmp = particle.copy()
ax[0, 0].imshow(tmp, vmin=0, vmax=amplitude_object)

ax[0, 1].set_title("Single pattern in real space", fontsize=8)
tmp = obj.real.copy()
vmin, vmax = np.percentile(tmp[tmp != 0], [0.1, 99.9])
ax[0, 1].imshow(tmp, vmin=vmin, vmax=vmax)

ax[1, 0].set_title("Ensemble fluctuations in real space", fontsize=8)
tmp = particle_avg.copy()
vmin, vmax = np.percentile(tmp[tmp != 0], [1, 100])
ax[1, 0].imshow(tmp, vmin=0, vmax=vmax)

ax[1, 1].set_title("Ensemble pattern in real space", fontsize=8)
tmp = obj_all.real.copy()
vmin, vmax = np.percentile(tmp[tmp > 0.25], [10, 100])
ax[1, 1].imshow(tmp, vmin=0, vmax=vmax)

ax[0, 2].set_title("Single diffraction in fourier space", fontsize=8)
tmp = holo.copy()
vmin, vmax = np.percentile(tmp[tmp != 0], [0.1, 99.9])
ax[0, 2].imshow(tmp, norm=LogNorm(vmin=1e-20, vmax=vmax))

ax[1, 2].set_title("Ensemble diffraction in fourier space", fontsize=8)
tmp = holo_avg.copy()
vmin, vmax = np.percentile(tmp[tmp != 0], [0.1, 99.9])
ax[1, 2].imshow(tmp, norm=LogNorm(vmin=1e-20, vmax=vmax))

# LogNorm requires strictly positive limits, so the percentiles of the
# bottom row are taken over positive pixels only.
ax[2, 0].set_title("Ensemble diffract amplitudes squared of gaussian fourier space", fontsize=6)
tmp = np.abs(diffraction_fluctuation_avg) ** 2
vmin, vmax = np.percentile(tmp[tmp > 0], [1, 99.9])
ax[2, 0].imshow(tmp, norm=LogNorm(vmin=vmin, vmax=vmax))

ax[2, 1].set_title("Ensemble diffract intensities of gaussian fourier space", fontsize=6)
tmp = np.abs(intensity_fluctuation_avg) / nr_objects
vmin, vmax = np.percentile(tmp[tmp > 0], [0.1, 99.9])
ax[2, 1].imshow(tmp, norm=LogNorm(vmin=vmin, vmax=vmax))

# The stochastic term is a difference of two intensities and can become
# slightly negative through round-off.
ax[2, 2].set_title("Stochastic contribution", fontsize=8)
tmp = intensity_stochastic_numerical.copy()
vmin, vmax = np.percentile(tmp[tmp > 0], [0.1, 99.9])
ax[2, 2].imshow(tmp, norm=LogNorm(vmin=vmin, vmax=vmax))


# Create masks for all cross correlations

### Masking of Object-reference cross-correlations

In [None]:
# How many object references do you have?
nr_ref = 2

# Setup coordinates: one circle per object-reference cross-correlation.
# Every reference yields a twin pair of reconstructions, hence 2 * nr_ref
# circles. Keep the two twins of one reference adjacent (indices 2*i and
# 2*i+1), as the autocorrelation extraction expects.
support_coordinates = [
    [mask.shape[-2] // 2, mask.shape[-1] // 2, 7] for _ in range(2 * nr_ref)
]

# Widget to find the positions and sizes of the different apertures
print(
    "Cover the object & reference apertures for each set of reconstructions that originates from the same reference with circles."
)
print(
    "Optimization: Change one circle parameter, calc phase retrieval image, compare with images reconstructed with old circle parameter. Repeat!"
)

# Backdrop: magnitude of the FTH reconstruction of the averaged hologram
recon = fth.reconstruct(holo_avg)
recon = np.abs(recon)

ds = interactive.InteractiveCircleCoordinates(
    recon,
    len(support_coordinates),
    coordinates=support_coordinates.copy(),
)

In [None]:
def get_supportmask_coordinates(sample):
    """Return the stored circular support-mask apertures for ``sample``.

    Parameters
    ----------
    sample : str
        Key of the stored coordinate set.

    Returns
    -------
    list of tuple
        One ``(y, x, radius)`` tuple per circle, in array coordinates.

    Raises
    ------
    KeyError
        If no coordinates are stored for ``sample``.
    """
    # All known coordinate sets live in one table, so re-running any cell that
    # (re)defines this function never hides the sets of the other cells.
    support_coord = {
        "simulation_obj_ref_crosscorr": [
            (189.0, 220.0, 48.0),
            (469.0, 220.0, 48.0),
            (269.5, 519.5, 48.0),
            (549.5, 520.0, 47.5),
        ],
        "simulation_ref_ref_crosscorr": [(290.0, 70.0, 12.0), (450.0, 670.0, 12.0)],
        "simulation_normal_cdi": [
            (191.5, 220.0, 45.0),
            (291.0, 70.0, 9.0),
            (371.0, 370.0, 9.0),
        ],
    }
    return support_coord[sample]

In [None]:
# Which supportmask to load? ("s2306a-C1", "s2308a-B1", ...)
sample = "simulation_obj_ref_crosscorr"

# One (y, x, radius) circle per object-reference cross-correlation
reco_coordinates = get_supportmask_coordinates(sample)

# Stack of masks, one slice per circle
mask_reco = np.zeros((len(reco_coordinates), *mask.shape))
for i, (y, x, radi) in enumerate(reco_coordinates):
    mask_reco[i] = mask_lib.circle_mask(mask.shape, (y, x), radi)

fig, ax = cimshow(mask_reco)
ax.set_title("Mask for objects")

### Masking of reference-reference cross-correlations

In [None]:
# How many object references do you have?
nr_ref = 2

# Setup coordinates: one circle per reference-reference cross-correlation,
# i.e. one per ordered pair of distinct references: nr_ref * (nr_ref - 1).
support_coordinates = [
    [mask.shape[-2] // 2, mask.shape[-1] // 2, 7] for _ in range(nr_ref * (nr_ref - 1))
]

# Widget to find the positions and sizes of the different apertures
print(
    "Cover the object & reference apertures for each set of reconstructions that originates from the same reference with circles."
)
print(
    "Optimization: Change one circle parameter, calc phase retrieval image, compare with images reconstructed with old circle parameter. Repeat!"
)

# Backdrop: magnitude of the FTH reconstruction of the averaged hologram
recon = fth.reconstruct(holo_avg)
recon = np.abs(recon)

ds = interactive.InteractiveCircleCoordinates(
    recon,
    len(support_coordinates),
    coordinates=support_coordinates.copy(),
)

In [None]:
def get_supportmask_coordinates(sample):
    """Return the stored circular support-mask apertures for ``sample``.

    Parameters
    ----------
    sample : str
        Key of the stored coordinate set.

    Returns
    -------
    list of tuple
        One ``(y, x, radius)`` tuple per circle, in array coordinates.

    Raises
    ------
    KeyError
        If no coordinates are stored for ``sample``.
    """
    # All known coordinate sets live in one table, so re-running any cell that
    # (re)defines this function never hides the sets of the other cells.
    support_coord = {
        "simulation_obj_ref_crosscorr": [
            (189.0, 220.0, 48.0),
            (469.0, 220.0, 48.0),
            (269.5, 519.5, 48.0),
            (549.5, 520.0, 47.5),
        ],
        "simulation_ref_ref_crosscorr": [(290.0, 70.0, 12.0), (450.0, 670.0, 12.0)],
        "simulation_normal_cdi": [
            (191.5, 220.0, 45.0),
            (291.0, 70.0, 9.0),
            (371.0, 370.0, 9.0),
        ],
    }
    return support_coord[sample]

In [None]:
# Which supportmask to load? ("s2306a-C1", "s2308a-B1", ...)
sample = "simulation_ref_ref_crosscorr"

# One (y, x, radius) circle per reference-reference cross-correlation
ref_coordinates = get_supportmask_coordinates(sample)

# Stack of masks, one slice per circle (this rebinds `mask_ref`)
mask_ref = np.zeros((len(ref_coordinates), *mask.shape))
for i, (y, x, radi) in enumerate(ref_coordinates):
    mask_ref[i] = mask_lib.circle_mask(mask.shape, (y, x), radi)

fig, ax = cimshow(mask_ref)
ax.set_title("Mask for references")

In [None]:
# Full cross-correlation mask: all object-reference plus all
# reference-reference circles
mask_cross = mask_reco.sum(axis=0) + mask_ref.sum(axis=0)

# Overlay the log-scaled averaged reconstruction on the mask
tmp = np.log10(1 + np.abs(reco_avg))
vmin, vmax = np.percentile(tmp[tmp != 0], [5, 60])
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(mask_cross)
ax.imshow(tmp, vmin=vmin, vmax=vmax, alpha=0.5)

# Extract cross-correlation values and calculate their FFT

In [None]:
# Cut the selected cross-correlation areas out of the averaged reconstruction
cross_reco = reco_avg[np.newaxis, :, :] * mask_reco
cross_ref = reco_avg[np.newaxis, :, :] * mask_ref

# Back to Fourier space, per mask slice (axes 1 and 2). Since
# reco_avg = fftshift(fft2(ifftshift(holo_avg))), the exact inverse is
# fftshift(ifft2(ifftshift(.))); this also holds for odd array sizes.
fft_axes = (1, 2)
holo_cross_reco = scp.fft.fftshift(
    scp.fft.ifft2(scp.fft.ifftshift(cross_reco, axes=fft_axes), axes=fft_axes),
    axes=fft_axes,
)
holo_cross_ref = scp.fft.fftshift(
    scp.fft.ifft2(scp.fft.ifftshift(cross_ref, axes=fft_axes), axes=fft_axes),
    axes=fft_axes,
)

In [None]:
# Series of slideshow plots: one frame per selected object-reference
# cross-correlation area (the slices of `mask_reco` applied to `reco_avg`)
fig, ax = cimshow(cross_reco)
ax.set_title("Selected cross-correlation areas of recos")

In [None]:
# Series of slideshow plots: log10 magnitude of the Fourier transform of each
# selected object-reference area (pixels with zero magnitude become -inf)
fig, ax = cimshow(np.log10(np.abs(holo_cross_reco)))
ax.set_title("Diffraction of selected cross-correlation areas of recos")

In [None]:
# Series of slideshow plots: one frame per selected reference-reference
# cross-correlation area (the slices of `mask_ref` applied to `reco_avg`)
fig, ax = cimshow(cross_ref)
ax.set_title("Selected cross-correlation areas of references")

In [None]:
# Series of slideshow plots: log10 magnitude of the Fourier transform of each
# selected reference-reference area (pixels with zero magnitude become -inf)
fig, ax = cimshow(np.log10(np.abs(holo_cross_ref)))
ax.set_title("Diffraction of selected cross-correlation areas of references")

# Extract autocorrelation terms

In [None]:
# Auto correlation contribution of object:
# sqrt(prod(object-reference terms) / prod(reference-reference terms))
holo_auto_reco = np.abs(
    np.sqrt(np.prod(holo_cross_reco, axis=0) / np.prod(holo_cross_ref, axis=0))
)

# Filter invalid values
#holo_auto_reco[holo_auto_reco>1e-4] = np.nan

# Auto correlation contribution of references. Care: Cross correlations of reco must be ordered pairwise
# which means that holo_cross_reco[2*i] and holo_cross_reco[2*i+1] are the twin image reconstruction of
# the same reference aperture. Define Masks accordingly!
n_pairs, remainder = divmod(holo_cross_reco.shape[0], 2)
if remainder:
    raise ValueError(
        "holo_cross_reco must contain twin pairs (an even number of masks), "
        f"got {holo_cross_reco.shape[0]}."
    )
# One reference autocorrelation per twin pair, independent of how many
# reference-reference masks were defined
holo_auto_ref = np.zeros((n_pairs, *holo_cross_reco.shape[1:]))
for i in range(n_pairs):
    holo_auto_ref[i] = np.abs(holo_cross_reco[2 * i] * holo_cross_reco[2 * i + 1]) / holo_auto_reco

In [None]:
# Single image: holo_auto_reco is a product over all mask slices, i.e. 2D
fig, ax = cimshow(np.abs(holo_auto_reco))
ax.set_title("Auto correlation contribution of object")

In [None]:
# Series of slideshow plots: one frame per reference autocorrelation estimate
# (slices of holo_auto_ref)
fig, ax = cimshow(holo_auto_ref)
ax.set_title("Auto correlation contribution of references")

# Calculate the time-averaged momentum-space spectrum

In [None]:
# Stochastic (fluctuation) term: remove the auto- and cross-correlation
# contributions from the averaged hologram, in this order
fluctuation_spectrum = (
    holo_avg
    - holo_auto_reco
    - holo_auto_ref.sum(axis=0)
    - holo_cross_reco.sum(axis=0)
    - holo_cross_ref.sum(axis=0)
)

# Suppress high-q artifacts: keep only a disk (presumably radius 200 px) at the center
fluctuation_spectrum *= mask_lib.circle_mask(
    fluctuation_spectrum.shape, np.array(fluctuation_spectrum.shape) // 2, 200
)

# Optional: blank out implausibly large values
# fluctuation_spectrum[np.abs(fluctuation_spectrum) > 1e-8] = np.nan

fig, ax = cimshow(fluctuation_spectrum)
ax.set_title("Stochastic term from CIDI")

# Normal phase retrieval

In [None]:
# Dummy dict with the most basic experimental parameters
experimental_setup = {
    "ccd_dist": 0.125,  # ccd to sample distance (presumably in m)
    "px_size": 15e-6,  # pixel size of the camera (presumably in m)
    "binning": 1,  # camera binning
    "energy": 778,  # photon energy in eV
}

# Wavelength derived from the photon energy
experimental_setup["lambda"] = cci.photon_energy_wavelength(
    experimental_setup["energy"], input_unit="eV"
)

## Create normal support mask

In [None]:
# How many object references do you have?
nr_ref = 2

# Starting circles for the widget, all at the array center with radius 7:
# one for the object aperture plus one per reference (nr_ref + 1 in total).
support_coordinates = [
    [mask.shape[-2] // 2, mask.shape[-1] // 2, 7] for _ in range(nr_ref + 1)
]

# Widget to find the positions and sizes of the different apertures
print("Cover the object & reference apertures for each set of reconstructions that originates from the same reference with circles.")
print("Optimization: Change one circle parameter, calc phase retrieval image, compare with images reconstructed with old circle parameter. Repeat!")

# Backdrop: magnitude of the FTH reconstruction of the averaged hologram
recon = np.abs(fth.reconstruct(holo_avg))

ds = interactive.InteractiveCircleCoordinates(
    recon, len(support_coordinates), coordinates=support_coordinates.copy()
)

In [None]:
def get_supportmask_coordinates(sample):
    """Return the stored circular support-mask apertures for ``sample``.

    Parameters
    ----------
    sample : str
        Key of the stored coordinate set.

    Returns
    -------
    list of tuple
        One ``(y, x, radius)`` tuple per circle, in array coordinates.

    Raises
    ------
    KeyError
        If no coordinates are stored for ``sample``.
    """
    # All known coordinate sets live in one table, so re-running any cell that
    # (re)defines this function never hides the sets of the other cells.
    support_coord = {
        "simulation_obj_ref_crosscorr": [
            (189.0, 220.0, 48.0),
            (469.0, 220.0, 48.0),
            (269.5, 519.5, 48.0),
            (549.5, 520.0, 47.5),
        ],
        "simulation_ref_ref_crosscorr": [(290.0, 70.0, 12.0), (450.0, 670.0, 12.0)],
        "simulation_normal_cdi": [
            (191.5, 220.0, 45.0),
            (291.0, 70.0, 9.0),
            (371.0, 370.0, 9.0),
        ],
    }
    return support_coord[sample]

In [None]:
# Which supportmask to load? ("s2306a-C1", "s2308a-B1", ...)
sample = "simulation_normal_cdi"

# Stored (y, x, radius) circles; they seed the widget in the next cell
support_coordinates = get_supportmask_coordinates(sample)

In [None]:
# Widget to find the positions and sizes of the different apertures
print("Cover the object & reference apertures for each set of reconstructions that originates from the same reference with circles.")
print("Optimization: Change one circle parameter, calc phase retrieval image, compare with images reconstructed with old circle parameter. Repeat!")

# Backdrop: magnitude of the FTH reconstruction of the averaged hologram;
# the circles start from the stored `support_coordinates`
recon = np.abs(fth.reconstruct(holo_avg))

ds = interactive.InteractiveCircleCoordinates(
    recon, len(support_coordinates), coordinates=support_coordinates.copy()
)

In [None]:
# Read the circles back from the widget: (y, x, radius) per aperture
support_coordinates = ds.get_params()

# Support = sum of one circular mask per aperture
supportmask = np.zeros(mask.shape)
for y, x, radi in support_coordinates:
    supportmask += mask_lib.circle_mask(mask.shape, (y, x), radi)

# Overlay the support on the log-scaled averaged reconstruction
tmp = np.log10(1 + np.abs(reco_avg))
vmin, vmax = np.percentile(tmp[tmp != 0], [5, 99])
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(tmp, vmin=vmin, vmax=vmax)
ax.imshow(supportmask, alpha=0.5)

In [None]:
# Region of interest derived from the current view (zoom) of `ax`.
# NOTE(review): the exact return format of interactive.axis_to_roi is not
# visible here — presumably something usable for indexing; verify.
roi_cdi = interactive.axis_to_roi(ax)
# Alternative hard-coded ROIs ([row_start, row_stop, col_start, col_stop]):
#roi_cdi = [138, 524, 477, 734]
#roi_cdi = [270, 396, 603, 730]
#roi_cdi = np.s_[roi_cdi[0] : roi_cdi[1], roi_cdi[2] : roi_cdi[3]]
print("Sliced roi:", roi_cdi)

## Execute phase retrieval

In [None]:
# Plain CDI phase retrieval of the averaged hologram with no excluded pixels
# and the default offset, start image and start gamma. Only the coherent
# result and the mask that was used are kept.
retrieved, _, bsmask, _ = phase_retrieval(holo_avg, np.zeros(mask.shape), supportmask)

In [None]:
# New beamstop for CDI recos as phase retrieval of low-q might be insufficient.
# If phase retrieval worked well, try without beamstop: `use_bs = False`
use_bs = False
bs_diam_cdi = 32  # diameter of beamstop
# NOTE(review): circle_mask's third argument is used as a radius elsewhere in
# this notebook (`radi`), so this beamstop may be twice as large as intended.

# Create beamstop: multiplicative mask applied to the retrieved field
# (presumably 0 inside the central circle), or all ones without beamstop
if use_bs:
    mask_bs_cdi = 1 - mask_lib.circle_mask(
        mask.shape, np.array(mask.shape) / 2, bs_diam_cdi, sigma=4
    )
else:
    mask_bs_cdi = np.ones(mask.shape)  # if you don't want a beamstop


# Get Reco
p = fth.reconstructCDI(
    fth.propagate(
        retrieved * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
)

# Get Reco partial coherence
# NOTE: the phase_retrieval call above discards the partial-coherence result,
# so this uses the same `retrieved` field as `p`.
p_pc = fth.reconstructCDI(
    fth.propagate(
        retrieved * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
)

# optimize phase
#_, phase_cdi = optimize_phase_contrast(p_pc, supportmask, method="contrast")

# Interactive fine-tuning of propagation distance, phase and shifts
mode = "-"
print("Fine-tuning of reconstruction parameter:")
slider_prop, slider_phase, slider_dx, slider_dy = rec.focusCDI(
    retrieved * mask_bs_cdi,
    np.zeros(mask.shape),
    roi_cdi,
    mask=supportmask,
    phase=phase_cdi,
    dx=dx,
    dy=dy,
    prop_dist=prop_dist_cdi,
    experimental_setup=experimental_setup,
    operation=mode,
    max_prop_dist=10,
    scale=(0.1, 99.9),
)

In [None]:
# Show the magnitude of the CDI reconstruction `p_pc`
cimshow(np.abs(p_pc))

In [None]:
# Compare the summed simulated objects with the CDI reconstruction
fig, ax = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True)

tmp = np.abs(obj_all)
vmin, vmax = np.percentile(tmp, [5, 99.9])
ax[0].imshow(tmp, vmin=vmin, vmax=vmax, cmap="gray")
ax[0].set_title("Real simulated objects")

tmp = np.abs(p_pc)  # optionally: * supportmask
vmin, vmax = np.percentile(tmp, [1, 99.9])
ax[1].imshow(tmp, vmin=vmin, vmax=vmax, cmap="gray")
ax[1].set_title("Retrieved objects")

# Optional zoom / orientation:
# ax[1].set_xlim([400, 730])
# ax[0].set_ylim([250, 600])
# ax[1].invert_yaxis()

# CIDI phase retrieval

## Create CIDI support mask

In [None]:
# Support for the phase retrieval of the fluctuation part: a single disk of
# radius 10 px at the array center (no reference apertures)
supportmask = np.zeros(mask.shape)
rr, cc = disk((obj.shape[0] // 2, obj.shape[1] // 2), 10, shape=obj.shape)
supportmask[rr, cc] = 1

# Overlay the support on the log-scaled averaged reconstruction
tmp = np.log10(1 + np.abs(reco_avg))
vmin, vmax = np.percentile(tmp[tmp != 0], [5, 99])
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(tmp, vmin=vmin, vmax=vmax)
ax.imshow(supportmask, alpha=0.5)

In [None]:
# The ROI from the current axis view is read but immediately overridden by
# hard-coded bounds [row_start, row_stop, col_start, col_stop], which are
# turned into a 2D slice (rows, columns) with np.s_.
roi_cdi = interactive.axis_to_roi(ax)
roi_cdi = [335, 409, 325, 421]
roi_cdi = np.s_[roi_cdi[0] : roi_cdi[1], roi_cdi[2] : roi_cdi[3]]
print("Sliced roi:", roi_cdi)

## Execute phase retrieval

In [None]:
# Phase retrieval of the magnitude of the extracted fluctuation spectrum with
# no excluded pixels and the default offset, start image and start gamma.
# Only the coherent result and the mask that was used are kept.
retrieved, _, bsmask, _ = phase_retrieval(
    np.abs(fluctuation_spectrum), np.zeros(mask.shape), supportmask
)

In [None]:
# New beamstop for CDI recos as phase retrieval of low-q might be insufficient.
# If phase retrieval worked well, try without beamstop: `use_bs = False`
use_bs = False
bs_diam_cdi = 32  # diameter of beamstop
# NOTE(review): circle_mask's third argument is used as a radius elsewhere in
# this notebook (`radi`), so this beamstop may be twice as large as intended.

# Create beamstop: multiplicative mask applied to the retrieved field
# (presumably 0 inside the central circle), or all ones without beamstop
if use_bs:
    mask_bs_cdi = 1 - mask_lib.circle_mask(
        mask.shape, np.array(mask.shape) / 2, bs_diam_cdi, sigma=4
    )
else:
    mask_bs_cdi = np.ones(mask.shape)  # if you don't want a beamstop


# Get Reco
p = fth.reconstructCDI(
    fth.propagate(
        retrieved * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
)

# Get Reco partial coherence
# NOTE: the phase_retrieval call above discards the partial-coherence result,
# so this uses the same `retrieved` field as `p`.
p_pc = fth.reconstructCDI(
    fth.propagate(
        retrieved * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
)

# optimize phase
#_, phase_cdi = optimize_phase_contrast(p_pc, supportmask, method="contrast")

# Interactive fine-tuning of propagation distance, phase and shifts
mode = "-"
print("Fine-tuning of reconstruction parameter:")
slider_prop, slider_phase, slider_dx, slider_dy = rec.focusCDI(
    retrieved * mask_bs_cdi,
    np.zeros(mask.shape),
    roi_cdi,
    mask=supportmask,
    phase=phase_cdi,
    dx=dx,
    dy=dy,
    prop_dist=prop_dist_cdi,
    experimental_setup=experimental_setup,
    operation=mode,
    max_prop_dist=10,
    scale=(1, 99),
)

In [None]:
# Compare the CIDI result with the simulated data. Each panel uses its own
# percentile-based color limits (1st-99th percentile).
fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(8, 8))

tmp = np.abs(retrieved)
vmin, vmax = np.percentile(tmp, [1, 99])
ax[0, 0].imshow(tmp, vmin=vmin, vmax=vmax)
ax[0, 0].set_title("diffraction pattern retrieved from CIDI", fontsize=8)

tmp = np.abs(fluctuation_spectrum)
vmin, vmax = np.percentile(tmp, [1, 99])
ax[0, 1].imshow(tmp, vmin=vmin, vmax=vmax)
ax[0, 1].set_title("diffraction pattern from numerical data", fontsize=8)

tmp = np.abs(p)
vmin, vmax = np.percentile(tmp, [1, 99])
ax[1, 0].imshow(tmp, vmin=vmin, vmax=vmax)
ax[1, 0].set_title("Absolute value of retrieved image (object plane) from CIDI", fontsize=8)

# Limits from the non-zero pixels only, since the particle is zero outside
# the aperture
tmp = np.abs(particle)  # optionally: * supportmask
vmin, vmax = np.percentile(tmp[tmp != 0], [1, 99])
ax[1, 1].imshow(tmp, vmin=vmin, vmax=vmax)
ax[1, 1].set_title("Input fluctuation", fontsize=8)