# Initialize libraries

## Import libraries

In [2]:
import sys, os
import time
from os.path import join
from os import path
from importlib import reload
from getpass import getuser
from glob import glob
from tqdm.auto import tqdm

# Data
import xarray as xr
import h5py
import numpy as np
from nexusformat.nexus import *

# Plotting
import imageio
import matplotlib.pyplot as plt
import fabio
import skimage.morphology

# skimage
from skimage.draw import ellipse

# scipy
from scipy.ndimage.filters import gaussian_filter
import scipy
from scipy import stats
from scipy import ndimage

# pyFAI
import pyFAI
from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
from pyFAI.detectors import Detector

# Self-written libraries
sys.path.append(join(os.getcwd(), "library"))
import reconstruct as reco
import fthcore as fth
import support_functions as sup
import interactive
from interactive import cimshow
import reconstruct_rb as rec
import reconstruct as reco

plt.rcParams["figure.constrained_layout.use"] = True  # replaces plt.tight_layout


In [3]:
# Is there a GPU?
# Try the CuPy based stack first and fall back to the NumPy implementation of
# CCI_core if anything in the GPU stack cannot be imported/initialized.
try:
    # Cupy
    import cupy as cp
    import cupyx as cpx

    # Self-written library (GPU versions)
    import CCI_core_cupy as cci
    import Phase_Retrieval as PhR

    # Only report success once every GPU module has been imported
    GPU = True
    print("GPU available")
except Exception as err:
    # `except Exception` instead of a bare `except:` so that KeyboardInterrupt
    # and SystemExit are not swallowed.
    # NOTE(review): Phase_Retrieval (PhR) is not imported on this path, so the
    # phase-retrieval / Ewald-projection code needs the GPU stack.
    GPU = False
    import CCI_core as cci

    print("GPU unavailable")
    print("Reason: %r" % (err,))

GPU available


In [4]:
# interactive plotting
import ipywidgets

%matplotlib widget

# Auto formatting of cells
#%load_ext jupyter_black

## Experiment specific Functions

In [5]:
# Beamtime proposal number; used to build BASEFOLDER
PROPOSAL = 11018955
# Current login name; appended to the log-file names written/read by the save_*/load_* helpers
USER = getuser()

### Loading data

In [6]:
# Root of the beamtime data directory for this proposal (raw data lives in <BASEFOLDER>/raw)
BASEFOLDER = "/asap3/petra3/gpfs/p04/2024/data/%s/" % PROPOSAL
# Prefix of the raw files: "<sample_name>_<scan:05d>.h5" and "<sample_name>  <spe_id>.spe"
sample_name = "2403_tomo"


# Load any kind of data from collection
def load_pre_scan_snapshot(scan_id, key):
    """
    Read one value of the pre-scan snapshot of a raw scan file.

    Opens <BASEFOLDER>/raw/<sample_name>_<scan_id:05d>.h5 and returns
    entry<scan_id>/measurement/pre_scan_snapshot/<key> as a squeezed numpy array.
    """
    raw_path = join(BASEFOLDER, "raw", "%s_%05d.h5" % (sample_name, scan_id))
    entry_name = "entry%d" % scan_id

    with h5py.File(raw_path, "r") as h5:
        snapshot = h5[entry_name]["measurement"]["pre_scan_snapshot"]
        values = np.array(snapshot[key][()])
        print("Loaded: %s" % (raw_path))
    return np.squeeze(values)

# Load any kind of data from measurements
def load_data(scan_id, key):
    """
    Read one measurement dataset of a raw scan file.

    Opens <BASEFOLDER>/raw/<sample_name>_<scan_id:05d>.h5 and returns
    entry<scan_id>/measurement/<key> as a squeezed numpy array.
    """
    raw_path = join(BASEFOLDER, "raw", "%s_%05d.h5" % (sample_name, scan_id))
    entry_name = "entry%d" % scan_id

    with h5py.File(raw_path, "r") as h5:
        measurement = h5[entry_name]["measurement"]
        values = np.array(measurement[key][()])
        print("Loaded key %s of %s" % (key, entry_name))
    return np.squeeze(values)


def get_image_id(im_id):
    """
    Return the id of the .spe image file referenced by scan `im_id`.

    Reads entry<im_id>/measurement/ccd2 of the scan's raw .h5 file, converts
    the value with str() and returns characters [-10:-6] of that string.

    NOTE(review): this slice depends on the exact textual repr of the ccd2
    value (e.g. bytes vs. one-element array, trailing quote/bracket
    characters) and assumes a 4-character id — confirm against the data layout.
    """
    fname = join(BASEFOLDER, "raw", "%s_%05d.h5" % (sample_name,im_id))
    entry = "entry%d"%im_id

    with h5py.File(fname, "r") as f:
        spe_id = str(f[entry]["measurement"]["ccd2"][()])[-10:-6]
    return spe_id

def load_spe_files(spe_id):
    """
    Read all frames of the .spe file "<sample_name>  <spe_id>.spe" (note the two
    spaces) from <BASEFOLDER>/raw and stack them into one numpy array.

    imageio's memory guard is raised to 5000MB to allow large frame stacks.
    """
    spe_name = "".join([sample_name, "  ", str(spe_id), ".spe"])
    frames = imageio.mimread(join(BASEFOLDER, "raw", spe_name), memtest="5000MB")
    return np.array(frames)
    
# Load image files
def load_images(im_id):
    """
    Load the ccd frames belonging to scan `im_id`.

    The scan's .h5 file only references the .spe file, so the .spe id is
    looked up first and the frames are then read from that .spe file.
    """
    return load_spe_files(get_image_id(im_id))


# Full image loading procedure
def load_processing(im_id, crop=None):
    """
    Loads images, averaging of individual images (scans in tango consist of two images),
    padding to square shape, additional cropping (optional)

    Parameter
    =========
    im_id : int
        scan id of the image to load
    crop : int or None
        if given, the averaged image is cropped to [:crop, :crop]

    Output
    ======
    image : array
        averaged, zero-padded (square) and optionally cropped image
    """
    # Load data and zero-pad to a square shape
    images = sup.padding(load_images(im_id))

    # Average over the frame axis if a stack of frames was loaded
    if images.ndim > 2:
        image = np.mean(images, axis=0)
    else:
        image = images.copy()

    # Optional cropping. Only the returned image is cropped, so this also
    # works when a single 2D frame was loaded (a 3-index crop would fail there).
    if crop is not None:
        image = image[:crop, :crop]

    return image

### Loading, saving fth & cdi data

In [7]:
# Saving of log files for fth and cdi recos
def save_fth_h5():
    """
    Write the current FTH reconstruction state to an HDF5 log file.

    All values come from notebook-level globals (im_id, dark_id, im_c, recon,
    center, roi, prop_dist, phase, mask_pixel, experimental_setup). The target
    is <folder_general>/Logs/Data_ImId_<im_id:04d>_fth_<USER>; the actual
    writing is delegated to cci.create_hdf5.
    """
    data = {
        "im_id": im_id,
        "dark_id": dark_id,
        "im_centered": im_c,
        "recon": recon,
        "center": center,
        "roi": roi,
        "prop_dist": prop_dist,
        "phase": phase,
        # stored under "mask_bs" although it is the image specific pixel mask
        "mask_bs": mask_pixel,
        "experimental_setup": experimental_setup,
    }

    log_name = "Data_ImId_%04d_fth_%s" % (im_id, USER)
    filename = join(folder_general, "Logs", log_name)
    print("Now Saving: %s" % filename)
    cci.create_hdf5(data, filename)


def save_cdi_h5():
    """
    Write the current CDI (phase retrieval) state to an HDF5 log file.

    All values come from notebook-level globals (im_id, dark_id, pos, center,
    roi, prop_dist_cdi, phase_cdi, mask_bs_cdi, supportmask, mask_pixel, p,
    p_pc, experimental_setup). The target is
    <folder_general>/Logs/Data_ImId_<im_id:04d>_cdi_<USER>; writing is
    delegated to cci.create_hdf5.
    """
    data = {
        "im_id": im_id,
        "dark_id": dark_id,
        "pos": pos,
        "center": center,
        "roi": roi,
        "prop_dist": prop_dist_cdi,
        "phase": phase_cdi,
        "mask_bs": mask_bs_cdi,
        "supportmask": supportmask,
        "mask_pixel": mask_pixel,
        "reco": p,
        "reco_pc": p_pc,
        "experimental_setup": experimental_setup,
    }

    log_name = "Data_ImId_%04d_cdi_%s" % (im_id, USER)
    filename = join(folder_general, "Logs", log_name)
    print("Now Saving: %s" % filename)
    cci.create_hdf5(data, filename)

def load_cdi(im_id):
    """
    Load a CDI log file.

    Reads <folder_general>/Logs/Data_ImId_<im_id:04d>_cdi_<USER>.hdf5
    (presumably the file written by save_cdi_h5 — the .hdf5 suffix is assumed
    to be appended by cci.create_hdf5, confirm) and returns a dict with the
    datasets im_id, roi, prop_dist, phase, supportmask, reco and reco_pc,
    plus "srotz" taken from the experimental_setup group.
    """
    log_file = join(
        folder_general,
        "Logs",
        "Data_ImId_%04d_cdi_%s.hdf5" % (im_id, USER),
    )
    print("Loading: %s" % log_file)

    dataset_keys = ("im_id", "roi", "prop_dist", "phase", "supportmask", "reco", "reco_pc")
    with h5py.File(log_file, "r") as h5:
        data = {key: h5[key][()] for key in dataset_keys}
        # Only srotz is read from the experimental setup group (srotx is not loaded)
        data["srotz"] = h5["experimental_setup"]["srotz"][()]
    return data

def load_fth(im_id):
    """
    Load an FTH log file.

    Reads <folder_general>/Logs/Data_ImId_<im_id:04d>_fth_<USER>.hdf5
    (presumably the file written by save_fth_h5 — the .hdf5 suffix is assumed
    to be appended by cci.create_hdf5, confirm) and returns its datasets as a dict.

    "experimental_setup" is written as an HDF5 group (load_cdi reads
    f["experimental_setup"]["srotz"]). A group cannot be read with [()], so
    groups are read back recursively into (nested) dicts. A plain dataset
    stored under that name is still read with [()].
    """
    fname = join(
        folder_general,
        "Logs",
        "Data_ImId_%04d_fth_%s.hdf5" % (im_id, USER),
    )
    print("Loading: %s" % fname)

    def _read(node):
        # Group -> dict of its members, dataset -> numpy value/scalar
        if isinstance(node, h5py.Group):
            return {name: _read(child) for name, child in node.items()}
        return node[()]

    dataset_keys = (
        "im_id",
        "dark_id",
        "im_centered",
        "center",
        "roi",
        "prop_dist",
        "phase",
        "mask_bs",
        "recon",
    )
    with h5py.File(fname, "r") as f:
        data = {key: f[key][()] for key in dataset_keys}
        data["experimental_setup"] = _read(f["experimental_setup"])
    return data

### Masking

In [8]:
from matplotlib.path import Path


def create_single_polygon_mask(shape, coordinates):
    """
    Creates a polygon mask from coordinates of corner points

    Parameter
    =========
    shape : int tuple
        shape/dimension (rows, columns) of output array; need not be square
    coordinates: nested list
        coordinates of polygon corner points [(xc_1,yc_1),(xc_2,yc_2),...].
        The first value of each pair is compared against the column index and
        the second against the row index (matplotlib (x, y) convention).

    Output
    ======
    mask: array
        boolean mask of the given shape where the filled polygon is True
    ======
    author: ck 2023
    """
    # Row (y) and column (x) index of every pixel. indexing="ij" keeps the grid
    # in (rows, columns) order, so the final reshape is also valid for
    # non-square shapes.
    y, x = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    points = np.column_stack((x.ravel(), y.ravel()))

    mask = Path(coordinates).contains_points(points)
    return mask.reshape(shape)


def create_polygon_mask(shape, coordinates):
    """
    Creates multiple polygon masks from set of coordinates of corner points

    Parameter
    =========
    shape : int tuple
        shape/dimension of output array
    coordinates: nested list
        coordinates of polygon corner points for multiple polygons, as
        (x, y) = (column, row) pairs:
        [[(xc_1,yc_1),(xc_2,yc_2),...],[(xc_1,yc_1),(xc_2,yc_2),...]]

    Output
    ======
    mask: array
        binary mask where filled polygons are "1". A single polygon yields a
        boolean array, several polygons a float array of 0/1 and an empty
        list an all-zero float array.
    ======
    author: ck 2023
    """
    # No polygons: nothing to mask, return an empty mask instead of failing
    if len(coordinates) == 0:
        return np.zeros(shape)

    if len(coordinates) == 1:
        return create_single_polygon_mask(shape, coordinates[0])

    # Union of all polygons; np.maximum keeps the result binary (0/1)
    mask = np.zeros(shape)
    for coord in coordinates:
        mask = np.maximum(mask, create_single_polygon_mask(shape, coord))
    return mask


def load_poly_masks(shape, polygon_name_list):
    """
    Loads set of polygon masks based on stored coordinates

    Parameter
    =========
    shape : tuple
        shape of output mask (used for every polygon layer)
    polygon_name_list : list
        keys of different mask coordinates to load (see load_poly_coordinates)

    Output
    ======
    mask: array
        binary float mask where filled polygons are "1"; all zeros if
        polygon_name_list is empty
    ======
    author: ck 2023
    """
    # Nothing requested -> empty mask of the requested shape
    if len(polygon_name_list) == 0:
        return np.zeros(shape)

    # Load dictionary of coordinates
    mask_coordinates = load_poly_coordinates()

    # Union of all requested mask layers, built with the requested shape
    mask = np.zeros(shape)
    for polygon_name in polygon_name_list:
        layer = create_polygon_mask(shape, mask_coordinates[polygon_name]).astype(float)
        mask = np.maximum(mask, layer)

    return mask

In [9]:
def auto_shift_mask(
    mask,
    image,
    shift_range_y=(-10, 10),
    shift_range_x=(-10, 10),
    step_size=1,
    crop=None,
    method="minimize",
):
    """
    Automatically shifts a binary (beamstop) mask to the optimal position

    All (y, x) shifts on the grids np.arange(min, max, step_size) of both
    ranges are tested. For each shift the metric sum(shifted_mask * image)
    is evaluated, and the shift with the smallest ("minimize") or largest
    ("maximize") value is selected.

    Parameter
    =========
    mask : array
        binary mask, e.g., beamstop mask
    image : array
        image partially covered by mask
    shift_range_y, _x : tuple
        Min and max limit for search area in y- and x-direction
    step_size : scalar
        step size of scan area
    crop : int or None
        additional cropping of arrays to speed up calculations
        Crops according to [crop:-crop,crop:-crop]; None or 0 disables cropping
    method : string
        method used for optimization. 'minimize' or 'maximize' metric <mask,image>

    Output
    ======
    optimized shift: tuple
        (y, x) shift vector for optimized position
    mask_shifted: array
        full-size mask shifted by the optimized shift and binarized at 0.5
    overlap : array
        computed evaluation metric, shape (len(yshift), len(xshift))

    Raises
    ======
    ValueError
        if method is neither 'minimize' nor 'maximize'
    ======
    author: ck 2023
    """
    if method == "minimize":
        pick_best = np.argmin
    elif method == "maximize":
        pick_best = np.argmax
    else:
        raise ValueError(
            "method must be 'minimize' or 'maximize', got %r" % (method,)
        )

    # Candidate shifts in y and x
    yshift = np.arange(shift_range_y[0], shift_range_y[1], step_size)
    xshift = np.arange(shift_range_x[0], shift_range_x[1], step_size)

    # Evaluation metric for every (y, x) combination
    overlap = np.zeros((yshift.shape[0], xshift.shape[0]))

    # Optional cropping to reduce computation time. crop=0 would produce empty
    # arrays via [0:-0], so it is treated like None.
    if crop:
        tmask = mask[crop:-crop, crop:-crop]
        timage = image[crop:-crop, crop:-crop]
    else:
        tmask, timage = mask.copy(), image.copy()

    # Loop over all combinations of shifts
    for i, y in enumerate(yshift):
        for j, x in enumerate(xshift):
            mask_shift = cci.shift_image(tmask, [y, x])
            overlap[i, j] = np.sum(mask_shift * timage)

    # Best shift: row index selects the y shift, column index the x shift
    idx = np.unravel_index(pick_best(overlap), overlap.shape)
    optimized_shift = (yshift[idx[0]], xshift[idx[1]])
    print(
        "Best mask overlap for shift: [%.2f,%.2f]"
        % (optimized_shift[0], optimized_shift[1])
    )

    # Shift the full (uncropped) mask
    mask_shifted = cci.shift_image(mask, optimized_shift)

    # Binarize mask (necessary for sub-px shift); exactly 0.5 is mapped to 1
    mask_shifted[mask_shifted >= 0.5] = 1
    mask_shifted[mask_shifted < 0.5] = 0

    return optimized_shift, mask_shifted, overlap

In [10]:
def create_supportmask(support_coordinates, shape):
    """
    Create cdi support mask from a combination of multiple circular apertures

    Parameter
    =========
    support_coordinates: nested list
        Contains center coordinates and radius of each aperture [[yc_1,xc_1,r_1],[yc_2,xc_2,r_2],...]
    shape : int tuple
        shape/dimension of output array

    Output
    ======
    supportmask: array
        composed mask where circular apertures are "1"; values of overlapping
        apertures are clipped to 1 so the mask stays binary
    ======
    author: ck 2023

    """
    # Create support mask as the sum of all apertures
    supportmask = np.zeros(shape)
    for aperture in support_coordinates:
        supportmask += cci.circle_mask(
            supportmask.shape,
            [aperture[0], aperture[1]],
            aperture[2],
        )

    # Overlapping apertures would sum to values > 1
    supportmask[supportmask > 1] = 1

    return supportmask

In [11]:
def create_ellipse_supportmask(support_coordinates, shape):
    """
    Create cdi support mask as a combination of multiple ellipses

    Parameter
    =========
    support_coordinates: nested list
        Contains center coordinates, height, width and rotation angle of each aperture
        [[(xc_1,yc_1),height_1,width_1,angle_1],[(xc_2,yc_2),height_2,width_2,angle_2],...]
        The center is used as center[1] -> row and center[0] -> column.
        height/width are full extents along rows/columns (halved to radii).
        angle is in radians and passed negated as `rotation` to skimage.draw.ellipse.
    shape : int tuple
        shape/dimension of output array

    Output
    ======
    supportmask: array
        composed binary mask where ellipses apertures are "1"; parts of an
        ellipse outside the array are clipped
    ======
    author: ck/sg 2024

    """

    # Create support mask
    supportmask = np.zeros(shape)
    for aperture in support_coordinates:
        center, height, width, angle = aperture[0], aperture[1], aperture[2], aperture[3]
        # shape= clips pixels outside the array instead of wrapping around via
        # negative indices or raising IndexError at the far border
        rr, cc = ellipse(
            center[1], center[0], height / 2, width / 2, shape=shape, rotation=-angle
        )
        supportmask[rr, cc] = 1

    return supportmask

### Other

In [12]:
def differential_operator(shape, center, experimental_setup, angle=0):
    '''
    Calculates Fourier-space differential operator for heraldo reconstruction

    Parameter
    ---------
    shape: int tuple
        shape (rows, columns) of output array; non-square shapes are supported
    center: tuple
        array center coordinates (y,x)
    experimental_setup: dict
        must contain detector pixel_size ["px_size"], distance ["ccd_dist"],
        wavelength ["lambda"]
    angle: float
        rotation angle of heraldo slit in degrees

    Returns:
    --------
    return: complex array
        2j*pi*qx*cos(angle) + 2j*pi*qy*sin(angle), where qx, qy are the
        momentum transfers scaled to [-1, 1]
    '''
    # Convert deg to rad
    angle = np.deg2rad(angle)

    # Pixel index grids: y runs over rows (shape[0]), x over columns (shape[1])
    y, x = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")

    # Center meshgrid
    y, x = y - center[0], x - center[1]

    # Multiply with pixel size
    y, x = y * experimental_setup["px_size"], x * experimental_setup["px_size"]

    # Convert to q-space
    prefactor = 4 * np.pi / experimental_setup["lambda"]
    qy = prefactor * np.sin(0.5 * np.arctan(y / experimental_setup["ccd_dist"]))
    qx = prefactor * np.sin(0.5 * np.arctan(x / experimental_setup["ccd_dist"]))

    # Normalize q space to [-1,1]. Guard against 0/0 when an axis has no
    # extent (q identically zero along it).
    qy_max, qx_max = np.max(np.abs(qy)), np.max(np.abs(qx))
    if qy_max > 0:
        qy = qy / qy_max
    if qx_max > 0:
        qx = qx / qx_max

    return 2j * np.pi * qx * np.cos(angle) + 2j * np.pi * qy * np.sin(angle)

In [13]:
def reconstruct_heraldo(holo, experimental_setup, center=None, prop_dist=0, phase=0, angle=0):
    '''
    Reconstruct a hologram recorded in the heraldo reference scheme.

    The hologram is multiplied by the Fourier-space differential operator of
    the slit. It is then propagated with fth.propagate by prop_dist * 1e-6
    (presumably micrometres -> metres), multiplied by the global phase factor
    exp(1j*phase), and reconstructed with fth.reconstruct.

    Parameter:
    ----------
    holo: array
        Centered input hologram
    experimental_setup: dict
        must contain detector pixel_size ["px_size"], distance ["ccd_dist"],
        wavelength ["lambda"]
    center: tuple or None
        array center coordinates (y,x); None uses shape / 2
    prop_dist: float
        propagation distance
    phase: float
        global phase shift of complex array
    angle: float
        rotation angle of heraldo slit (degrees, converted in differential_operator)

    returns:
    image: array
        reconstructed image
    heraldo_operator: complex array
        differential operator in Fourier space
    '''
    if center is None:
        center = np.array(holo.shape) / 2

    heraldo_operator = differential_operator(
        holo.shape, center, experimental_setup, angle=angle
    )
    filtered = holo * heraldo_operator
    propagated = fth.propagate(
        filtered, prop_dist * 1e-6, experimental_setup=experimental_setup
    )
    image = fth.reconstruct(propagated * np.exp(1j * phase))

    return image, heraldo_operator

In [14]:
# Worker which performs complete fth reconstruction process
def worker(image):
    """
    Run the complete FTH reconstruction chain for one image.

    Steps:
    - centering
    - optional Ewald-sphere projection
    - drawn polygon mask (shifted, optionally dilated/eroded and position-optimized)
    - pixel mask (drawn mask + saturated pixels + beamstop center)
    - smooth masking of the hologram and optional HERALDO operator
    - propagation and reconstruction

    Relies on notebook-level globals: center, project_ewalds_sphere,
    experimental_setup, polygon_names, mask_shift, mask_scale,
    optimize_position, bs, mask_bs, heraldo, heraldo_operator, diff_c,
    prop_dist and phase.

    Parameter
    =========
    image : array
        uncentered input image

    Output
    ======
    worker_dict : dict
        "im_c": centered (optionally projected) image
        "holo": masked hologram
        "recon": reconstruction
        "mask_pixel": binary pixel mask
    """
    # Centering
    shift_c = np.array(image.shape) / 2 - center
    im_c = cci.shift_image(image, shift_c)

    # Optional Ewald sphere projection. Timed explicitly because the `%%time`
    # cell magic is only valid as the first line of a notebook cell.
    if project_ewalds_sphere is True:
        t_start = time.time()
        proj_im = PhR.inv_gnomonic(
            im_c,
            center=None,
            experimental_setup=experimental_setup,
            method="cubic",
            mask=None,
        )
        im_c = proj_im.copy()
        print("Ewald sphere projection: %.2f s" % (time.time() - t_start))

    # Create masks from the stored polygon coordinates and shift them
    mask = load_poly_masks(im_c.shape, polygon_names)
    mask = np.round(cci.shift_image(mask, mask_shift))

    # Increase/decrease mask size. disk() needs a non-negative radius; the
    # sign of mask_scale only selects dilation vs. erosion.
    footprint = skimage.morphology.disk(abs(mask_scale))
    if mask_scale > 0:
        mask = skimage.morphology.dilation(mask, footprint)  # increase size
    elif mask_scale < 0:
        mask = skimage.morphology.erosion(mask, footprint)  # decrease size

    # Optimize position of drawn mask relative to target image
    if optimize_position is True:
        # Level 1 (rough)
        optimized_shift, _, _ = auto_shift_mask(
            mask,
            im_c,
            shift_range_y=[-10, 10],
            shift_range_x=[-10, 10],
            step_size=2,
            crop=300,
        )
        # Level 2 (fine); the optimally shifted mask is used from here on
        optimized_shift, mask, _ = auto_shift_mask(
            mask,
            im_c,
            shift_range_y=[optimized_shift[0] - 4, optimized_shift[0] + 4],
            shift_range_x=[optimized_shift[1] - 4, optimized_shift[1] + 4],
            step_size=0.5,
            crop=300,
        )

    # BS center part: disk 3 px smaller than the beamstop radius, no smoothing
    mask_bs_center = cci.circle_mask(
        image.shape, np.array(image.shape) / 2, bs.rBS - 3, sigma=None
    )

    # Image specific pixel mask: drawn mask + pixels above 40000
    # (presumably close to detector saturation) + beamstop center
    mask_pixel = mask.copy()
    mask_pixel = mask_pixel + (im_c > 40000) + mask_bs_center
    mask_pixel[mask_pixel > 1] = 1

    # Create smooth mask (dilate by 6 px, then Gaussian blur with sigma 4)
    footprint = skimage.morphology.disk(6)
    mask_pixel_smooth = skimage.morphology.dilation(mask_pixel, footprint)
    mask_pixel_smooth = gaussian_filter(mask_pixel_smooth, 4)

    # NOTE(review): the hologram is built from the global `diff_c`, not from
    # this call's `im_c` — confirm this is intended.
    if heraldo:
        holo = diff_c * heraldo_operator * mask_bs * (1 - mask_pixel_smooth)
    else:
        holo = diff_c * mask_bs * (1 - mask_pixel_smooth)

    # Reconstruct
    recon = cci.reconstruct(
        fth.propagate(holo, prop_dist * 1e-6, experimental_setup=experimental_setup)
        * np.exp(1j * phase)
    )

    return {
        "im_c": im_c,
        "holo": holo,
        "recon": recon,
        "mask_pixel": mask_pixel,
    }

In [15]:
# Setup phase and propagation for cdi once
# (module-level defaults; phase_cdi / prop_dist_cdi are the values save_cdi_h5 stores)
phase_cdi = 0
prop_dist_cdi = 0
# NOTE(review): dx / dy are not used in the visible code — confirm whether still needed
dx = 0
dy = 0


def phase_retrieval(
    pos, mask_pixel, supportmask, vmin = 0, Startimage=None, Startgamma=None
):
    """
    Run CDI phase retrieval (PhR.PhaseRtrv_GPU) followed by a partial-coherence
    refinement (PhR.PhaseRtrv_with_RL) on a single diffraction pattern.

    Parameter
    =========
    pos : array
        diffraction intensity; its square root is passed as `diffract`
    mask_pixel : array
        pixel mask. It is the starting point of bsmask_p, and its zero pixels
        (mask_pixel == 0) are used to scale the start image.
    supportmask : array
        support mask passed as `mask` to the PhR routines
    vmin : float
        percentile (0-100) of the non-zero pixels of `pos` that is subtracted
        as an offset; negative values are clipped to 0 afterwards
    Startimage : complex array or None
        initial guess; None uses the centered inverse FFT of supportmask
    Startgamma : array or None
        initial partial-coherence kernel; None uses a uniform 2e-6 map with
        0.7 at the array center

    Output
    ======
    retrieved_p : complex array
        result of the coherent phase retrieval
    retrieved_p_pc : complex array
        result of the partial-coherence phase retrieval
    bsmask_p : array
        mask used as `bsmask`: mask_pixel plus all pixels <= 0 after the
        offset subtraction
    gamma_p : array
        retrieved partial-coherence kernel

    Side effect: sets plt.rcParams["figure.dpi"] = 100.
    """
    # Prepare Input holograms
    pos2 = pos.copy()

    # Subtract the vmin-th percentile of the non-zero pixels as an offset
    mi, _ = np.percentile(pos2[pos2 != 0], [vmin, 99.9])
    pos2 = pos2 - mi

    pos2[pos2 < 0] = 0
    pos2 = pos2.astype(complex)

    # Pixels without signal after the offset subtraction are treated as masked
    bsmask_p = mask_pixel.copy()
    bsmask_p[pos2 <= 0] = 1

    # Setup start image and startgamma
    if Startimage is None:
        Startimage = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(supportmask)))
    else:
        Startimage = Startimage.copy()
    if Startgamma is None:
        Startgamma = np.ones(pos.shape) * 1e-6 * 2
        Startgamma[pos.shape[0] // 2, pos.shape[1] // 2] = 0.7
    else:
        Startgamma = Startgamma.copy()

    # Settings for phase retrieval reconstructions
    # (hard-coded True; retrieved_p_pc / gamma_p returned below only exist in this case)
    partial_coherence = True

    # Setup (retrieved_n is not used further)
    retrieved_p = np.zeros(pos2.shape, np.cdouble)
    retrieved_n = np.zeros(pos2.shape, np.cdouble)

    # Algorithms and initial guess
    plt.rcParams["figure.dpi"] = 100
    print("CDI - larger mask")

    # Only entries 3*i and 3*i + 1 are used below; the third entry is unused
    algorithm_list = ["mine", "mine", "mine"]
    Nit_list = [700, 50, 50]  # iterations for algorithm_list

    # Linear fit |Startimage| ~ slope * sqrt(pos2) + intercept on the pixels
    # with mask_pixel == 0, then map Startimage onto the amplitude scale of the data
    x = (np.sqrt(np.maximum(pos2, np.zeros(pos2.shape)))[mask_pixel == 0]).flatten()
    y = ((np.abs(Startimage))[mask_pixel == 0]).flatten()
    res = stats.linregress(x, y)
    Startimage -= res.intercept
    Startimage /= res.slope

    average_img = 30
    real_object = False  # always set to False

    if partial_coherence:
        RL_freq = 20
        RL_it = 50

        algorithm_list_pc = ["mine", "ER", "ER"]
        Nit_list_pc = [700, 50, 50]

    # Execute Phase retrieval
    start_time = time.time()
    # len(Nit_list) // 3 == 1, i.e. a single pass with i = 0
    for i in range(len(Nit_list) // 3):
        print("############ -   CDI")

        # Positive helicity - beta_mode="arctan"
        retrieved_p, Error_diff_p, Error_supp = PhR.PhaseRtrv_GPU(
            diffract=np.sqrt(np.maximum(pos2, np.zeros(pos2.shape))),
            mask=supportmask,
            mode=algorithm_list[3 * i],
            beta_zero=0.5,
            Nit=Nit_list[3 * i],
            beta_mode="arctan",
            plot_every=349,
            Phase=Startimage,
            seed=False,
            real_object=real_object,
            bsmask=bsmask_p,
            average_img=average_img,
            Fourier_last=True,
        )

        # Positive helicity - beta_mode="const", continuing from the previous result
        retrieved_p, Error_diff_p2, Error_supp = PhR.PhaseRtrv_GPU(
            diffract=np.sqrt(np.maximum(pos2, np.zeros(pos2.shape))),
            mask=supportmask,
            mode=algorithm_list[3 * i + 1],
            beta_zero=0.5,
            Nit=Nit_list[3 * i + 1],
            beta_mode="const",
            plot_every=24,
            Phase=retrieved_p,
            seed=False,
            real_object=real_object,
            bsmask=bsmask_p,
            average_img=average_img,
            Fourier_last=True,
        )


        print("--- %s seconds ---" % np.round((time.time() - start_time), 2))

        Startimage = retrieved_p.copy()

        # Partial coherence phase retrieval
        if partial_coherence:
            # CDI_PC
            print("############   -   CDI_pc")
            # Inside bsmask_p use the retrieved intensity |retrieved_p|^2,
            # elsewhere keep the measured (offset-corrected) intensity
            pos3 = (np.abs(retrieved_p) ** 2) * bsmask_p + np.maximum(
                pos2, np.zeros(pos2.shape)
            ) * (1 - bsmask_p)

            # retrieve pos image (no masked pixels: bsmask is all zeros here)
            (
                retrieved_p_pc,
                Error_diff_p_pc,
                Error_supp,
                gamma_p,
            ) = PhR.PhaseRtrv_with_RL(
                diffract=np.sqrt(pos3),
                mask=supportmask,
                mode=algorithm_list_pc[3 * i],
                beta_zero=0.5,
                Nit=Nit_list_pc[3 * i],
                beta_mode="arctan",
                gamma=Startgamma,
                RL_freq=RL_freq,
                RL_it=RL_it,
                plot_every=349,
                Phase=Startimage,
                seed=False,
                real_object=False,
                bsmask=np.zeros(bsmask_p.shape),
                average_img=average_img,
                Fourier_last=True,
            )

            # NOTE(review): this call takes its mode from algorithm_list
            # ("mine") rather than algorithm_list_pc ("ER"), and unlike the
            # other calls it does not pass seed=False — confirm both are intended.
            (
                retrieved_p_pc,
                Error_diff_p_pc2,
                Error_supp,
                gamma_p,
            ) = PhR.PhaseRtrv_with_RL(
                diffract=np.sqrt(pos3),
                mask=supportmask,
                mode=algorithm_list[3 * i + 1],
                beta_zero=0.5,
                Nit=Nit_list_pc[3 * i + 1],
                beta_mode="const",
                gamma=gamma_p,
                RL_freq=RL_freq,
                RL_it=RL_it,
                plot_every=24,
                Phase=retrieved_p_pc,
                real_object=False,
                bsmask=np.zeros(bsmask_p.shape),
                average_img=average_img,
                Fourier_last=True,
            )

            print("--- %s seconds ---" % np.round((time.time() - start_time), 2))

            Startimage = retrieved_p_pc.copy()
            Startgamma = gamma_p.copy()

    print("Phase Retrieval Done!")

    return (
        retrieved_p,
        retrieved_p_pc,
        bsmask_p,
        gamma_p,
    )

In [16]:
def optimize_phase_contrast(recon, supportmask, method="contrast", prefered_color=None):
    """
    Automatically shifts contrast of phase retrieval reconstruction into real part

    Parameter
    =========
    recon : complex array
        FTH/CDI reconstruction plane (Patterson map)
    supportmask : array
        Supportmask of Patterson map for phase retrieval
    method : string
        Choose method for phase optimization:
        "contrast" - minimize the 1-99 percentile spread of the imaginary part
                     (i.e. move the contrast into the real part)
        "minima"   - minimize the 0.01 percentile of the real part
        "maxima"   - maximize the 99 percentile of the real part
        "max"      - kept for compatibility, behaves like "contrast"
    prefered_color : string or None
        Shift contrast such that the color of the domains covering the largest
        area is white ("white"), black ("black") or non-specific (None)

    Output
    ======
    recon_optimized: complex array
        reconstruction with optimized contrast
    optimized_phase: float
        phase (rad) corresponding to optimized reconstruction

    Raises
    ======
    ValueError
        for an unknown method
    ======
    author: ck 2023
    """
    # Filter references from supportmask: keep only objects >= 200 px
    mask = supportmask.astype(bool)
    mask = skimage.morphology.remove_small_objects(mask, min_size=200)
    mask = mask.astype(float)

    # Make object aperture smaller to minimize edge effects
    footprint = skimage.morphology.disk(4)
    mask = skimage.morphology.erosion(mask, footprint)

    # Gaussian filter to remove high intensity peaks; the phase is optimized
    # on this filtered copy, but applied to the unfiltered recon
    reco = scipy.ndimage.gaussian_filter(recon, 1)

    # Objective functions (all minimized by fminbound)
    def _imag_spread(phi, data, tmask):
        temp = np.imag(data * np.exp(1j * phi)) * tmask
        low, high = np.percentile(temp[temp != 0], [1, 99])
        return high - low

    def _real_low(phi, data, tmask):
        temp = np.real(data * np.exp(1j * phi)) * tmask
        return np.percentile(temp[temp != 0], 0.01)

    def _neg_real_high(phi, data, tmask):
        temp = np.real(data * np.exp(1j * phi)) * tmask
        return -np.percentile(temp[temp != 0], 99)

    objectives = {
        "contrast": _imag_spread,
        "max": _imag_spread,
        "minima": _real_low,
        "maxima": _neg_real_high,
    }
    if method not in objectives:
        raise ValueError(
            "method must be one of %s, got %r" % (sorted(objectives), method)
        )

    optimized_phase = float(
        scipy.optimize.fminbound(
            objectives[method], -np.pi, np.pi, args=(reco, mask), disp=False
        )
    )

    # Calc optimized reconstruction
    recon_optimized = recon * np.exp(1j * optimized_phase)

    # Optional: shift phase by pi such that the dominant domains are white or
    # black. Statistics use the real part, which now carries the contrast
    # (np.percentile does not accept complex input).
    if prefered_color in ("white", "black"):
        real_part = np.real(recon_optimized[mask == 1])
        mi_p, ma_p = np.percentile(real_part, [1, 99])
        midpoint = (mi_p + ma_p) / 2
        mean = np.mean(real_part)
        if prefered_color == "white":
            flip = mean < midpoint
        else:
            flip = mean > midpoint
        if flip:
            optimized_phase = optimized_phase + np.pi
            recon_optimized = recon * np.exp(1j * optimized_phase)

    return recon_optimized, optimized_phase

In [17]:
def correct_background(image, ref_rows=slice(80, 120), zero_rows=slice(56, 59), quantile=0.1):
    """
    Subtract a column-wise background offset and blank a band of rows.

    Parameters
    ----------
    image : 2D array
        input image (not modified)
    ref_rows : slice
        rows used to estimate the background; for every column the `quantile`
        of these rows is subtracted. Example: slice(-40, None) uses the last
        40 rows.
    zero_rows : slice or None
        rows set to 0 after the subtraction; None disables this step
    quantile : float
        quantile (0-1) of the reference rows used as the offset

    Returns
    -------
    array
        background corrected copy of `image`
    """
    offset = np.quantile(image[ref_rows], quantile, axis=0)
    # Multiplying by a float64 ones-array keeps the float64 result dtype
    corrected = image - offset * np.ones(image.shape)

    if zero_rows is not None:
        corrected[zero_rows, :] = 0

    return corrected

In [18]:
import tomopy
def get_mask_projections(mask, theta):
    '''
    Compute binary tomographic projections of a 2D mask.

    mask: 2D array
    theta: 1D list of angles in rad

    The mask is placed between two empty slices, rotated with
    np.rot90(axes=(1, 2)) and projected with tomopy.project (pad=True).
    Every positive value is set to 1, two pixels are trimmed from both ends
    of the last axis, and the result is squeezed.
    '''
    blank = np.zeros(mask.shape)
    volume = np.dstack((blank, mask, blank))

    projections = tomopy.project(np.rot90(volume, axes=(1, 2)), theta, pad=True)
    projections[projections > 0] = 1

    return np.squeeze(projections[:, :, 2:-2])

# Experimental Details

In [19]:
# Dict with the most basic experimental parameters.
# "energy", "lambda", "srotx" and "srotz" are added once a scan is loaded.
experimental_setup = {
    "ccd_dist": 0.20,  # ccd to sample distance (presumably in m — TODO confirm)
    "px_size": 13.5e-6,  # pixel_size of camera (presumably in m)
    "binning": 1,  # Camera binning
}

# Setup for azimuthal integrator: pyFAI Detector with the binned pixel size in both directions
detector = Detector(
    experimental_setup["binning"] * experimental_setup["px_size"],
    experimental_setup["binning"] * experimental_setup["px_size"],
)

# NOTE(review): meaning/units of this offset are not visible in this part of the notebook — confirm
experimental_setup["z_angle_offset"] = 67

# General saving folder and log folder
# (sup.create_folder presumably creates the directory if needed and returns its path)
folder_general = sup.create_folder(join(BASEFOLDER, "processed"))
sup.create_folder(join(folder_general, "Logs"))
sup.create_folder(join(folder_general, "Topos"))
print("Output Folder: %s" % folder_general)

Output Folder: /asap3/petra3/gpfs/p04/2024/data/11018955/processed


# Load images

Start by loading the images: the image of interest (im) and, optionally, a dark image of any kind (dark).

In [None]:
# Define scan ids for each image
im_id = 1024
dark_id = 1025

# Are you using a HERALDO dataset?
heraldo = True

# Are you going to use ewalds sphere projection?
project_ewalds_sphere = False

# Load energy and add to experimental setup
experimental_setup["energy"] = load_pre_scan_snapshot(im_id,"energy")
experimental_setup["lambda"] = cci.photon_energy_wavelength(
    experimental_setup["energy"], input_unit="eV"
)

# Load angles
experimental_setup["srotx"] = load_pre_scan_snapshot(im_id,"srotx")
experimental_setup["srotz"] = load_pre_scan_snapshot(im_id,"srotz")

print(experimental_setup)
print("Image Id: %s" % im_id)



## Load image of interest

In [None]:
# Load image: frame average, zero-padded to a square, no cropping
image = load_processing(im_id, crop=None)

# Plot
fig, ax = cimshow(image)
ax.set_title("Image")

## Load dark image

In [None]:
# Load the dark image and subtract it from the image of interest.
# NOTE: re-running this cell subtracts the dark again, because `image` is overwritten.
if dark_id is not None:
    dark = load_processing(dark_id, crop=None)
    image = image - dark
    #image = correct_background(image)

    # Plot
    fig, ax = cimshow(dark)
    ax.set_title("Dark")

In [None]:
# Show the image after the (optional) dark subtraction
fig, ax = cimshow(image)
ax.set_title("Background subtracted image")

# Center holograms

* Find the center of the hologram to obtain a well-defined q-space.
* Create a smooth mask for the beamstop and for overexposed areas in the direct beam.

## Basic widget to find center

Try to **align** the circles to the **center of the scattering pattern**. Caution: the position of the beamstop may be misleading and need not coincide with the actual center of the hologram. The circles are only a guide to the eye and are not used otherwise.

In [None]:
# Find center position via widget (shown on a Gaussian-smoothed copy, sigma = 3 px)
c0, c1 = [1009, 1099]  # initial values
ic = interactive.InteractiveCenter(gaussian_filter(image,3), c0=c0, c1=c1)

In [None]:
# Get center positions from the widget.
# center[0] is paired with the row axis and center[1] with the column axis in the centering cell.
center = [ic.c0, ic.c1]
print("Center:", center)

## Azimuthal integrator widget for fine-tuning
More of an "expert widget", which works very well for alignment if the scattering image is an Airy pattern. pyFAI transforms images from the Cartesian detector coordinate system into a polar coordinate system with the azimuthal angle `chi` and the radial distance `q` as axes (azimuthal transformation). The center of this coordinate system is defined in the azimuthal integrator class and does not necessarily coincide with the center coordinates of your image array. If the center is set correctly, every ring of the Airy pattern is transformed into a straight line in the I(q, chi) plot, because each ring appears at a fixed q for all angles chi.

## Centering of image hologram

In [None]:
# Shift the image so that the selected center lands on the array center (shape / 2)
shift_c = np.array(image.shape) / 2 - center
im_c = cci.shift_image(image, shift_c)

## Calculate Projection

In [None]:
# Optional Ewald sphere projection of the centered image.
# `%%time` is a cell magic and only works as the very first line of a cell,
# so the call is timed explicitly here.
if project_ewalds_sphere is True:
    _t_start = time.time()
    proj_im = PhR.inv_gnomonic(im_c, center=None, experimental_setup = experimental_setup, method='cubic' , mask=None)
    im_c = proj_im.copy()
    print("Wall time: %.2f s" % (time.time() - _t_start))

# Create beamstops

We cover the beamstop with a smooth circle to hide its sharp edges, as these would otherwise create ringing-like artifacts in the reconstruction plane. Make the circle only as large as necessary to keep as much information as possible.

## Draw Circle beamstop

Set the beamstop radius (`rBS`) and the standard deviation of the smoothing filter (`stdBS`). Higher values mean stronger smoothing. If you have a very small beamstop, you might need to reduce the smoothing value; otherwise the sharp gradient of the real beamstop will still be visible.

In [None]:
# Interactive beamstop widget centered on the array center (start: rBS=50 px, stdBS=4)
bs = interactive.InteractiveBeamstop(
    im_c, im_c.shape[0] / 2, im_c.shape[1] / 2, rBS=50, stdBS=4
)

In [None]:
# Take value from widget and create beamstop mask.
# NOTE(review): bs.rBS is passed to circle_mask as the radius despite the name bs_diam.
bs_diam = bs.rBS
bs_smoothing = bs.stdBS
# 1 outside the (smoothed) beamstop disk, 0 inside
mask_bs = 1 - cci.circle_mask(
    image.shape, np.array(image.shape) / 2, bs.rBS, sigma=bs_smoothing
)

# Apply beamstop to image
im_b = im_c * mask_bs

# Plot masked image, beamstop mask and its complement
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(im_b, [0.1, 99.9])
ax[0].imshow(im_b, cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Masked image")

ax[1].imshow(mask_bs)
ax[1].set_title("Beamstop mask")

ax[2].imshow(1 - mask_bs)
ax[2].set_title("1 - Beamstop mask")


## Manual masking of beamstop wires

Just mask the beamstop wires, broken pixels, etc. 

In [None]:
# Interactive widget to draw polygon masks on the centered image
poly_mask = interactive.draw_polygon_mask(im_c)

In [None]:
# Take poly coordinates and mask from widget
p_coord = poly_mask.coordinates
mask_draw = poly_mask.full_mask.astype(int)

print("Copy these coordinates into the 'load_poly_coordinates()' function:")
print(p_coord)

# Plot image outside / inside the drawn mask and the inverted mask
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(im_c * (1 - mask_draw), [0.1, 99.9])
ax[0].imshow(im_c * (1 - mask_draw), cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Image * (1-mask_draw)")

mi, ma = np.percentile(im_c * mask_draw, [0.1, 99.9])
ax[1].imshow(im_c * mask_draw, vmin=mi, vmax=ma)
ax[1].set_title("Image * mask_draw")

ax[2].imshow(1 - mask_draw)
ax[2].set_title("1 - mask_draw")


In [None]:
def load_poly_coordinates():
    """
    Return the stored polygon corner coordinates of all drawn masks.

    Keys are mask names. Each value is a list of polygons, and each polygon is
    a list of corner-coordinate tuples as printed by the drawing widget
    (``print(p_coord)``).

    Example: how to add a mask with name "test":
        mask_coordinates["test"] = <coordinates printed by the widget>
    """
    # Setup dictionary
    mask_coordinates = dict()

    # Mask #1
    mask_coordinates["bs_tiny"] = [[(1017.866806601309, 149.77970139629696), (1017.9864760868489, 153.21022664844236), (1024.2092893349268, 153.0506673343891), (1024.0896198493867, 149.97915053886354)], [(1031.6293165413158, 965.2245984409541), (1026.8137667514804, 968.289039216304), (1029.6593189000196, 973.5423662597608), (1032.94264830218, 973.7612548865716), (1035.5693118239085, 969.3834823503574)], [(1030.3159847804516, 993.2423426727241), (1027.689321258723, 995.2123403140205), (1026.5948781246698, 998.495669716181), (1027.908209885534, 1002.2167763719631), (1030.7537620340731, 1003.9678853864486), (1033.3804255558016, 1001.3412218647202), (1034.0370914362336, 996.5256720748847)], [(1028.3459871391553, 1061.0978169840423), (1025.9382122442375, 1061.0978169840423), (1025.2815463638055, 1071.6044710709562), (1025.0626577369949, 1083.2055682919233), (1027.689321258723, 1082.9866796651127), (1028.7837643927767, 1074.668911846306)], [(1021.7379140721885, 978.9420126138665), (1018.8591114173319, 979.1819128351044), (1019.1469916828174, 982.0127354457135), (1020.8742732757314, 983.308196640399), (1022.9853952226263, 982.9243562864181), (1022.8894351341311, 979.0859527466092)], [(1033.1091845588721, 977.9824117289141), (1027.92733978013, 977.9824117289143), (1028.1672400013683, 981.9647554014658), (1029.2228009748158, 983.2602165961514), (1032.9652444261292, 983.2122365519037)], [(1029.2228009748155, 972.0808662864582), (1029.0308807978251, 980.2374738085518), (1029.0788608420728, 987.2425602687031), (1030.4702821252536, 988.8738817731218), (1032.1975637181674, 988.7299416403789), (1031.9576634969294, 983.5001168173894), (1032.9172643818818, 972.6086467731818)]]
    return mask_coordinates

In [None]:
# Which drawn masks do you want to load? You can combine multiple masks from
# load_poly_coordinates(). Just add names of mask as strings to list like
# ["bs_small","bs_medium"]
# Every name must be a key of load_poly_coordinates() (currently only "bs_tiny").
polygon_names = ["bs_tiny"]
mask_draw = load_poly_masks(
    im_c.shape,
    polygon_names,
)

#footprint = skimage.morphology.disk(20)
#mask_draw = skimage.morphology.dilation(mask_draw, footprint) # increase size
#mask_draw = skimage.morphology.erosion(mask_draw, footprint) # decrease size

# The relative position of the drawn beamstop with respect to the actual beamstop
# might change due to sample change, sample drift, realignment, etc. The drawn
# beamstop mask would therefore not cover the actual beamstop position. This function
# aligns the beamstop mask with respect to the actual position
optimize_position = False

# Optimize position of drawn mask relative to target image
if optimize_position:
    # Level 1 (rough): +-10 px in steps of 2 px
    optimized_shift, _, _ = auto_shift_mask(
        mask_draw,
        im_c,
        shift_range_y=[-10, 10],
        shift_range_x=[-10, 10],
        step_size=2,
        crop=300,
    )
    # Level 2 (fine): +-4 px around the rough result in steps of 0.5 px
    optimized_shift, mask_shifted, overlap = auto_shift_mask(
        mask_draw,
        im_c,
        shift_range_y=[optimized_shift[0] - 4, optimized_shift[0] + 4],
        shift_range_x=[optimized_shift[1] - 4, optimized_shift[1] + 4],
        step_size=0.5,
        crop=300,
    )

    mask_draw = mask_shifted.copy()

# Shift mask
#mask_draw = np.round(cci.shift_image(mask_draw,[2,0]))

# Plot image with beamstop and valid pixel mask
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(im_c * (1 - mask_draw), [0.1, 99.9])
ax[0].imshow(im_c * (1 - mask_draw), cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Image * (1-mask_draw)")

mi, ma = np.percentile(im_c * mask_draw, [0.1, 99.9])
ax[1].imshow(im_c * mask_draw, vmin=mi, vmax=ma)
ax[1].set_title("Image * mask_draw")

ax[2].imshow(1 - mask_draw)
ax[2].set_title("1 - mask_draw")

## Finetuning of mask position

In [None]:
# Use widget to shift and expand or shrink the mask
# (interactive adjustment of mask_draw on top of im_c; read back in the next cell)
ss_mask = interactive.Shift_Scale_Mask(im_c,mask_draw)

In [None]:
# Take mask, shift and scaling from widget
# (mask_draw is overwritten with the adjusted mask)
mask_draw, mask_shift, mask_scale = ss_mask.get_mask()

## Overview beamstops
Verify that the beamstop mask is well aligned.

In [None]:
# Mask over-saturated pixels (counts above 40000)
mask_im = (im_c > 40000)

# Combine drawn mask, saturated pixels and a central disk into one binary mask.
# The disk is built on im_c's grid like the other two masks.
# NOTE(review): hard-coded disk center [1026, 1026] and size 24 — confirm they
# still match the current beamstop position.
mask_pixel = mask_draw + mask_im
mask_pixel = mask_pixel + cci.circle_mask(im_c.shape, [1026, 1026], 24)
# Overlapping masks add up to >1; clip back to a binary mask
mask_pixel[mask_pixel > 1] = 1

# Create smooth mask: grow the mask by a disk of radius 6, then blur its edges
footprint = skimage.morphology.disk(6)
mask_pixel_smooth = skimage.morphology.dilation(mask_pixel, footprint)
mask_pixel_smooth = gaussian_filter(mask_pixel_smooth, 4)

# Plot both
fig, ax = plt.subplots(2, 2, figsize=(8, 8), sharex=True, sharey=True)
mi, ma = np.percentile(im_c, [1, 99.9])
ax[0, 0].imshow(im_c*(1-mask_pixel), vmin=mi, vmax=ma)
ax[0, 0].set_title("Image")
mi, ma = np.percentile(im_c * mask_pixel, [1, 99.99])
ax[0, 1].imshow(im_c * mask_pixel, vmin=mi, vmax=ma)
ax[0, 1].set_title("Image*mask")
ax[1, 0].imshow(mask_pixel)
ax[1, 0].set_title("Mask_pixel")
ax[1, 1].imshow(mask_pixel_smooth)
ax[1, 1].set_title("mask_pixel_smooth")

# Reconstruct Holo (FTH)

Reconstruct the hologram.

0. If you are doing HERALDO, determine the rotation angle of the hologram.
1. Choose a region of interest (ROI) which means selecting one reconstruction from the reconstruction plane.
2. Propagate the image and shift the phase for maximal contrast and sharpness in your ROI

In [None]:
# Interactive widget to shift/rotate im_c; its angle is read in the next cell
# as the HERALDO rotation (shift_rotate.get_parameter()).
# The ticks are passed to the widget as (yticks, xticks) — presumably guide
# marks at these pixel positions; TODO confirm in interactive.Shift_Rotate.
xticks, yticks = [1014,1036], []
shift_rotate = interactive.Shift_Rotate(im_c, shift = [0,0], angle = -1.25, ticks = (yticks,xticks))

## Set Patterson Map ROI

Choose the reconstructions as the ROI.

1. Zoom into the image and adjust your FOV until you are satisfied.
2. Save the axes coordinates.

In [None]:
# Unpropagated reconstruction (prop_dist=0, phase=0) of the masked hologram,
# shown so the ROI can be chosen by zooming.
if heraldo is True:
    # Rotation angle chosen in the Shift_Rotate widget
    _, _, heraldo_rotation = shift_rotate.get_parameter()
    # Also keep the HERALDO operator; the propagation cell below reuses it
    tmp, heraldo_operator  = reconstruct_heraldo(im_c *mask_bs* (1 - mask_pixel_smooth), experimental_setup, 
                              center = None, prop_dist = 0, phase = 0,angle = heraldo_rotation)
else:
    tmp = fth.reconstruct(im_c * mask_bs*(1 - mask_pixel_smooth))
    
fig, ax = cimshow(np.real(tmp), cmap="gray")

In [None]:
# Execute to get roi
# Read the current zoom of the plot above. With the image origin at the top
# (presumably imshow's default inside cimshow), get_ylim() returns
# (bottom, top), hence the swapped order y2, y1.
x1, x2 = ax.get_xlim()
y2, y1 = ax.get_ylim()

roi = np.array([y1, y2, x1, x2]).astype(int)  # ystart, ystop, xstart, xstop
#roi = [ 478,  801,  930, 1365]
roi_s = np.s_[roi[0] : roi[1], roi[2] : roi[3]]
print(f"Roi Reco:{roi}")

## Tune propagation and phase
Focus the image by tuning the propagation distance. This works just like focusing a microscope.
Phase slider will move contrast between real and imaginary part. Usually we use the phase which maximizes the contrast in the real part.

In [None]:
# Widget
# Masked hologram for the propagation widget, multiplied by the HERALDO
# operator when heraldo is enabled. `else` (like the ROI cell above) ensures
# holo is always rebuilt here instead of keeping a stale value.
if heraldo is True:
    holo = im_c * heraldo_operator * mask_bs*(1 - mask_pixel_smooth)
else:
    holo = im_c * mask_bs*(1 - mask_pixel_smooth)

slider_prop, slider_phase, button = reco.propagate(
    holo,
    roi_s,
    phase = 0, #Initial value
    prop_dist = 0, #Initial value
    experimental_setup=experimental_setup,
    scale=(1, 99),
)

In [None]:
# Read prop dist and phase from widget
prop_dist = slider_prop.value
phase = slider_phase.value

# Use f-string formatting directly instead of an f-prefix combined with "%"
print(f"Propagation distance: {prop_dist:0.2f}")
print(f"Phase: {phase:0.2f}")

## Save reconstruction

Save PNG files of the images and an HDF5 file containing all important variables.

In [None]:
# Style of reconstruction plot
def plot_recon(recon, title):
    """
    Plot the real and imaginary parts of a complex reconstruction side by side.

    Each panel is drawn in grayscale with its color range clipped to the
    1st-99th percentile of that part.

    Parameters
    ----------
    recon : array
        Complex reconstruction (typically already cropped to the ROI).
    title : str
        Figure super-title.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    fig.suptitle(title)

    panels = ((axes[0], np.real(recon), "Real"), (axes[1], np.imag(recon), "Imag"))
    for panel_ax, part, label in panels:
        low, high = np.percentile(part, (1, 99))
        panel_ax.imshow(part, vmin=low, vmax=high, cmap="gray")
        panel_ax.set_title(label)

In [None]:
def get_title(data_key,im_id,CDI=False):
    """
    Build the figure title for image `im_id`.

    The value shown in the title is read with
    ``load_pre_scan_snapshot(im_id, data_key)``.

    Parameters
    ----------
    data_key : str
        Snapshot key to show in the title, e.g. "srotz".
    im_id : int
        Scan id of the image.
    CDI : bool, optional
        If truthy the title is labelled "CDI", otherwise "FTH".

    Returns
    -------
    str
        For data_key == "srotz", the value (1 decimal) and the value relative
        to experimental_setup["z_angle_offset"] in deg; for any other key the
        value with 2 decimals.
    """
    # Snapshot value of `data_key` for this scan
    values = load_pre_scan_snapshot(im_id, data_key)

    # Any non-CDI call is labelled FTH, so `mode` is always defined
    mode = "CDI" if CDI else "FTH"

    if data_key == "srotz":
        title = "Image %d - %s: %s = %.1f (%.1f deg)" % (
            im_id, mode, data_key, values, values - experimental_setup["z_angle_offset"]
        )
    else:
        # Four arguments need four format specifiers ("%s" for data_key)
        title = "Image %d - %s: %s = %.2f " % (im_id, mode, data_key, values)

    return title

In [None]:
# Masked hologram used for the final FTH reconstruction
holo = im_c * mask_bs* (1 - mask_pixel_smooth)

# Reconstruct with the propagation distance and phase chosen in the widget.
# `else` guarantees recon is recomputed here (never a stale value).
if heraldo is True:
    recon, _ = reconstruct_heraldo(holo,experimental_setup,prop_dist = prop_dist, phase = phase, angle = heraldo_rotation)
else:
    recon = fth.reconstruct(
        fth.propagate(holo, prop_dist * 1e-6, experimental_setup=experimental_setup)
        * np.exp(1j * phase)
    )

# Create plot
title = get_title("srotz",im_id)
plot_recon(recon[roi_s], title)

# Save images
fname = join(
    folder_general,
    "Recon_ImId_%04d_%s.png" % (im_id, USER),
)
print("Saving: %s" % fname)
plt.savefig(fname, bbox_inches="tight", transparent=False)

# Save hdf5 file
save_fth_h5()

In [None]:
# Closes all existing plots (frees the figures created by the cells above)
plt.close("all")

# Batch processing FTH (needs to be updated)

In [None]:
# Define the sets for reconstructions. You can make an iterable list or use np.arange
# dark_id_set is indexed by position in the batch loop, so it needs one entry
# per image id (None skips dark subtraction).
im_id_set = [954]
dark_id_set = [953]

print("Image Set:  %s" % im_id_set)
print("Dark Id Set:  %s" % dark_id_set)

In [None]:
# Less ugly Automatic processing of image stacks
# For each image id: update experimental_setup from the pre-scan snapshot,
# load (and dark-subtract) the image, process it with worker(), then plot,
# save a PNG and save the hdf5 file.
for it, im_id in enumerate(tqdm(im_id_set)):
    # Load energy and add to experimental setup
    experimental_setup["energy"] = load_pre_scan_snapshot(im_id,"energy")
    experimental_setup["lambda"] = cci.photon_energy_wavelength(
        experimental_setup["energy"], input_unit="eV"
    )

    # Load angles
    experimental_setup["srotx"] = load_pre_scan_snapshot(im_id,"srotx")
    experimental_setup["srotz"] = load_pre_scan_snapshot(im_id,"srotz")

    # Load images
    image = load_processing(im_id)

    # Load dark image for this position in dark_id_set (None = no subtraction)
    # NOTE(review): the image uses load_processing's default crop while the
    # dark uses crop=None — confirm both have the same shape.
    dark_id = dark_id_set[it]
    if dark_id is not None:
        dark = load_processing(dark_id, crop=None)
        image = image - dark

    # Process images
    worker_dict = worker(image)

    # Take values from worker
    recon = worker_dict["recon"]

    # Plotting
    # NOTE(review): uses the global roi_s and whichever plot_recon was defined
    # last; the CDI section redefines plot_recon to weight by supportmask[roi_cdi].
    title = get_title("srotz",im_id)
    plot_recon(recon[roi_s], title)

    # Save images
    fname = join(
        folder_general,
        "Recon_ImId_%04d_fth_stack_%s.png" % (im_id, USER),
    )
    print("Saving: %s" % fname)
    plt.savefig(fname, bbox_inches="tight", transparent=False)

    # Save hdf5 file
    # save_fth_h5() takes no arguments; presumably it reads globals set in
    # this loop (im_id, recon, ...) — TODO confirm.
    save_fth_h5()

# CDI Reconstruction

In [None]:
# Copy values from FTH reco
# `pos` is a copy of the processed hologram im_c and is the input for CDI below
pos = im_c.copy()

## Draw support mask

The support mask is the real-space constraint used for the (holographically aided) phase retrieval, i.e., prior knowledge about the sample such as its geometry. For our samples we can directly derive a very strong constraint: the FTH reconstructions show us precisely the actual real-space sample structure, i.e., the arrangement of the apertures where x-rays are transmitted ("1"), while the masked areas show no transmission ("0"). We will therefore create a binary mask that reflects this transmission as an input for the phase retrieval.

How to draw a support mask: create a binary mask of the locations of the sample apertures in the FTH reconstruction. Areas with apertures are "1". Select only a single set of reconstructions (object & reference apertures) that originate from a single reference. Use the widget!

### Option 1: Execute if you want to create a new support mask with widget

If you really want to create a new support mask, execute next cell and then the "InteractiveCircleCoordinates"-widget

In [None]:
# How many references do you have?
nr_ref = 2

# Setup coordinates (nr_ref + 1 coordinates, as there is always the object aperture).
# Every aperture starts at the image center with the parameters 10, 10, 0.
support_coordinates = [
    [(pos.shape[-2] // 2, pos.shape[-1] // 2), 10, 10, 0] for _ in range(nr_ref + 1)
]

# Widget to find the positions and sizes of the different apertures
print("Cover the object & reference apertures for each set of reconstructions that originates from the same reference with circles.")
print("Optimization: Change one circle parameter, calc phase retrieval image, compare with images reconstructed with old circle parameter. Repeat!")

# Masked hologram reconstructed as background for the widget
holo = pos * mask_bs* (1 - mask_pixel_smooth)

# Reconstruct (`else` guarantees recon is recomputed, never stale)
if heraldo is True:
    recon, _ = reconstruct_heraldo(holo,experimental_setup,prop_dist = prop_dist, phase = phase, angle = heraldo_rotation)
else:
    recon = fth.reconstruct(
        fth.propagate(holo, prop_dist * 1e-6, experimental_setup=experimental_setup)
        * np.exp(1j * phase)
    )
recon = np.real(recon)

ds_ellipse = interactive.InteractiveEllipseCoordinates(recon,
    len(support_coordinates),
    coordinates=support_coordinates,
)

In [None]:
# Take coordinates of circles from widget
support_coordinates = ds_ellipse.get_params()

# Create supportmask from coordinates
supportmask = create_ellipse_supportmask(support_coordinates, pos.shape)

# Add heraldo slit
# NOTE(review): hard-coded slit region; Option 2 uses rows 1020:1028 instead
# of 1018:1030 — confirm which matches the sample.
supportmask[1018:1030,880:1024] = 1

# Plot supportmask as overlay
fig, ax = plt.subplots(figsize=(6, 6))
mi, ma = np.percentile(recon, (1, 99))
ax.imshow(recon, vmin=mi, vmax=ma, cmap="gray")
ax.imshow(supportmask, alpha=0.4, cmap="binary")
ax.set_title("Image with overlayed mask")

### Option 2: Execute if you want to load an existing support mask

In [None]:
def get_supportmask_coordinates(sample):
    """
    Look up the stored support-mask aperture parameters of a sample.

    Each aperture entry has the same layout as the parameters of the
    InteractiveEllipseCoordinates widget: a center tuple followed by three
    numbers.

    Parameters
    ----------
    sample : str
        Name of the stored support mask.

    Returns
    -------
    list
        One entry per aperture.

    Raises
    ------
    KeyError
        If no coordinates are stored for `sample`.
    """
    stored = {
        # Alternative set kept for reference:
        # "s2402f_2-35deg": [((1053.5, 839.0), 117.0, 100.0, 0.0), [(1187,1018),12,8,0], ((1024, 1024), 10, 10, 0)]
        "s2402f_2-35deg": [
            ((978.0, 845.0), 110.0, 110.0, 0.0),
            ((722.0, 941.0), 10.0, 10.0, 0.0),
        ],
    }
    return stored[sample]

In [None]:
# Which supportmask to load? ("s2306a-C1", "s2308a-B1", ...)
# Must be a key of get_supportmask_coordinates() (KeyError otherwise).
sample = "s2402f_2-35deg"

# Get coordinates (the supportmask itself is created after the widget below)
support_coordinates = get_supportmask_coordinates(sample)

In [None]:
# Widget to find the positions and sizes of the different apertures
print("Cover the object & reference apertures for each set of reconstructions that originates from the same reference with circles.")
print("Optimization: Change one circle parameter, calc phase retrieval image, compare with images reconstructed with old circle parameter. Repeat!")

# Plotting image
holo = pos * mask_bs* (1 - mask_pixel_smooth)

# Reconstruct (`else` guarantees recon is recomputed, never stale)
if heraldo is True:
    recon, _ = reconstruct_heraldo(holo,experimental_setup,prop_dist = prop_dist, phase = phase, angle = heraldo_rotation)
else:
    recon = fth.reconstruct(
        fth.propagate(holo, prop_dist * 1e-6, experimental_setup=experimental_setup)
        * np.exp(1j * phase)
    )
recon = np.real(recon)

# Widget starts from the stored coordinates loaded in the previous cell
ds_ellipse = interactive.InteractiveEllipseCoordinates(recon,
    len(support_coordinates),
    coordinates=support_coordinates,
)

In [None]:
# Take coordinates of circles from widget
support_coordinates = ds_ellipse.get_params()

# Create supportmask from coordinates
supportmask = create_ellipse_supportmask(support_coordinates, pos.shape)

# Add heraldo slit
# NOTE(review): hard-coded slit region; Option 1 uses rows 1018:1030 instead
# of 1020:1028 — confirm which matches the sample.
supportmask[1020:1028,880:1024] = 1

# Plot supportmask as overlay
fig, ax = plt.subplots(figsize=(6, 6))
mi, ma = np.percentile(recon, (1, 99))
ax.imshow(recon, vmin=mi, vmax=ma, cmap="gray")
ax.imshow(supportmask, alpha=0.4, cmap="binary")
ax.set_title("Image with overlayed mask")

### Option 3: Load image under normal incidence, create mask and project mask

In [None]:
# Define scan id of the image under normal incidence
# (its saved CDI result is loaded with load_cdi in the next cell)
load_id = 1024

print("Reference CDI Image Id: %s" % load_id)

#### Load normal-incidence image and create support mask for that image

In [None]:
# Load phase retrieval image and support mask
reference_dict = load_cdi(load_id)
# Magnitude of the partial-coherence reconstruction inside its support
reference_cdi = reference_dict["reco_pc"] * reference_dict["supportmask"]
reference_cdi = np.abs(reference_cdi)

# Segment support mask into object and reference apertures.
# In the inverted mask the apertures are holes; filling holes smaller than
# 5000 px removes the small (reference) apertures, leaving only the big ones.
big_apertures = np.logical_not(skimage.morphology.remove_small_holes(np.logical_not(reference_dict["supportmask"]),area_threshold = 5000))
# NOTE(review): this subtraction raises TypeError if the stored supportmask is
# boolean (numpy does not support bool - bool); it works for int/float masks.
small_apertures = reference_dict["supportmask"] - big_apertures

# Improve support mask of big apertures
# Threshold filtering: keep big-aperture pixels with magnitude above 200000
supportmask = big_apertures*reference_cdi > 200000

# Remove artifacts from threshold filtering
# (opening removes specks smaller than the disk, then small holes are filled)
footprint = skimage.morphology.disk(3)
supportmask= skimage.morphology.opening(supportmask, footprint)
supportmask = skimage.morphology.remove_small_holes(supportmask,area_threshold = 30)

# Smoothing of mask: blur and re-threshold to smooth the edges
supportmask = gaussian_filter(supportmask.astype(float),1.3)
supportmask = supportmask > 0.4

# Add references
supportmask = supportmask + small_apertures

# Plotting
# Top: reference image (zoomed to 700-1100 px) and the new mask;
# bottom: reference image inside and outside the mask
fig, ax = plt.subplots(2,2,figsize=(12,12),sharex=True,sharey=True)
mi, ma = np.percentile(reference_cdi[reference_cdi!=0],[.1,99.9])
ax[0,0].imshow(reference_cdi,vmin=mi,vmax=ma,cmap="gray")
ax[0,0].set_xlim([700,1100])
ax[0,0].set_ylim([700,1100])
ax[0,0].invert_yaxis()
ax[0,1].imshow(supportmask)
ax[1,0].imshow(reference_cdi*supportmask,vmin=mi,vmax=ma,cmap="gray")
ax[1,1].imshow(reference_cdi*(1-supportmask),vmin=mi,vmax=ma,cmap="gray")

#### Project image onto new reconstruction of tilted sample

In [None]:
# Plotting image
holo = pos * mask_bs* (1 - mask_pixel_smooth)

# Reconstruct (`else` guarantees recon is recomputed, never stale)
if heraldo is True:
    recon, _ = reconstruct_heraldo(holo,experimental_setup,prop_dist = prop_dist, phase = phase, angle = heraldo_rotation)
else:
    recon = fth.reconstruct(
        fth.propagate(holo, prop_dist * 1e-6, experimental_setup=experimental_setup)
        * np.exp(1j * phase)
    )
recon = np.real(recon)

# Projection of the normal-incidence support mask to the current tilt:
# srotz relative to z_angle_offset, converted to radians
mask = get_mask_projections(supportmask, [np.deg2rad(experimental_setup["srotz"]-experimental_setup["z_angle_offset"])])

# Top: tilted reconstruction (zoomed to 700-1100 px) and projected mask;
# bottom: reconstruction inside and outside the projected mask
fig, ax = plt.subplots(2,2,figsize=(12,12),sharex=True,sharey=True)
mi, ma = np.percentile(recon[recon!=0],[.1,99.9])
ax[0,0].imshow(recon,vmin=mi,vmax=ma,cmap="gray")
ax[0,0].set_xlim([700,1100])
ax[0,0].set_ylim([700,1100])
ax[0,0].invert_yaxis()
ax[0,1].imshow(mask)
ax[1,0].imshow(recon*mask,vmin=mi,vmax=ma,cmap="gray")
ax[1,1].imshow(recon*(1-mask),vmin=mi,vmax=ma,cmap="gray")

In [None]:
# Use widget to shift and expand or shrink the mask.
# Start from the projected mask `mask` (computed for the current tilt in the
# projection cell above). The unprojected normal-incidence supportmask does
# not match the tilted reconstruction, and `mask` would otherwise go unused.
ss_mask = interactive.Shift_Scale_Mask(recon,mask,cmap="gray")

In [None]:
# Take mask, shift and scaling from widget
# (supportmask is overwritten with the adjusted mask)
supportmask, mask_shift, mask_scale = ss_mask.get_mask()

### Option 4: Draw manually

In [None]:
# Plotting image
holo = pos * mask_bs* (1 - mask_pixel_smooth)

# Reconstruct (`else` guarantees recon is recomputed, never stale)
if heraldo is True:
    recon, _ = reconstruct_heraldo(holo,experimental_setup,prop_dist = prop_dist, phase = phase, angle = heraldo_rotation)
else:
    recon = fth.reconstruct(
        fth.propagate(holo, prop_dist * 1e-6, experimental_setup=experimental_setup)
        * np.exp(1j * phase)
    )
recon = np.real(recon)

# Draw the support apertures as polygons on the reconstruction
poly_mask = interactive.draw_polygon_mask(recon,cmap='gray')

In [None]:
# Take poly coordinates and mask from widget
p_coord = poly_mask.coordinates

# Polygon mask
manual_supportmask = poly_mask.full_mask.astype(int)

# These are support-mask polygons, so they belong in load_poly_support()
# (load_poly_coordinates() holds the beamstop / pixel masks).
print("Copy these coordinates into the 'load_poly_support()' function:")
print(p_coord)


In [None]:
def load_poly_support():
    """
    Return the stored polygon corner coordinates of drawn support masks.

    Keys are mask names. Each value is a list of polygons, and each polygon is
    a list of corner-coordinate tuples as printed by the support drawing
    widget (``print(p_coord)``).

    To add a mask named "test", add an entry
        "test": <coordinates printed by the widget>
    to the dictionary below.
    """
    return {
        "drawn_mask_heraldo_0deg": [
            [(964.5627505848117, 790.6253886706052), (958.9502425001773, 794.4839879787913), (957.0209428460842, 798.3425872869776), (950.005307740291, 800.4472778187155), (941.5865456133392, 804.1304862492569), (939.4818550816012, 806.4105676586397), (938.0787280604427, 813.4262027644329), (935.2724740181254, 815.5308932961708), (933.8693469969668, 818.512538216133), (931.2384838322943, 821.494183136095), (928.4322297899771, 825.8789550772157), (925.450584870015, 831.8422449171399), (926.6783210135287, 837.805534757064), (922.4689399500529, 844.2949972299227), (927.3798845241081, 853.5907137450987), (927.2044936464632, 866.7450295684608), (931.0630929546495, 871.4805832648711), (940.5342003474702, 882.5302085564954), (943.3404543897875, 889.0196710293541), (948.7775715967772, 889.0196710293541), (961.0549330319152, 896.5614787680817), (970.1752586694463, 897.6138240339507), (978.5940207963981, 900.4200780762679), (984.5573106363223, 896.210697012792), (995.9577176832362, 895.8599152575024), (1001.7456166455155, 893.22905209283), (1010.6905514054017, 892.176706826961), (1016.1276686123914, 886.5641987423264), (1016.3030594900363, 881.8286450459161), (1028.5804209251744, 872.1821467754505), (1027.5280756593054, 865.3419025473022), (1031.912847600426, 855.6954042768366), (1032.965192866295, 846.5750786393055), (1028.931202680464, 825.8789550772157), (1023.6694763511191, 819.3894926043571), (1020.3370496758673, 816.9340203173294), (1011.5675057936259, 805.8843950257052), (1003.1487436666741, 798.5179781646224), (995.7823268055913, 795.88711499995), (993.3268545185637, 791.5023430588293), (984.7327015139671, 789.5730434047362), (981.9264474716498, 791.5023430588293), (973.6830762223428, 791.677733936474)],
            [(724.7618360616121, 936.5561838840708), (721.176462855007, 935.1476444100474), (717.7191386914949, 935.9159386686057), (715.7984030450992, 938.4769195304665), (715.6703540020062, 942.7025379525369), (718.4874329500531, 945.6476659436768), (723.3532965875886, 945.9037640298628), (726.0423264925425, 944.6232735989324), (728.2191602251243, 942.0622927370716), (727.578915009659, 938.4769195304665)],
            [(1023.3663313017321, 1020.2409014928181), (877.1088756413031, 1023.0004761279206), (877.3848331048134, 1032.1070724237586), (1023.3663313017321, 1029.0715403251459)],
        ],
        "drawn_mask_heraldo_-35deg": [
            [(983.4763791077349, 782.8878502755655), (973.1682003936147, 785.9421254501196), (962.4782372826752, 790.5235382119508), (958.6603933144825, 793.9595977833243), (954.0789805526513, 798.1592261483363), (948.7339989971815, 803.5042077038061), (943.3890174417118, 808.0856204656372), (933.4626231244107, 823.356996338408), (930.0265635530372, 833.6651750525282), (930.0265635530372, 842.4462161793715), (931.5537011403144, 851.6090417030339), (934.6079763148686, 861.1536516235155), (938.8076046798805, 870.6982615439973), (943.3890174417118, 879.4793026708405), (951.0247053780971, 887.1149906072259), (959.8057465049403, 892.8417565595149), (969.350356425422, 899.3320913054425), (976.9860443618074, 902.3863664799967), (1067.0871620111548, 902.768150876816), (1079.6860471061907, 900.8592288927196), (1089.994225820311, 897.8049537181654), (1097.6299137566964, 893.2235409563342), (1101.0659733280697, 888.2603437976837), (1105.2656016930819, 885.5878530199489), (1110.6105832485514, 874.51610551219), (1112.1377208358285, 868.4075551630817), (1117.100917994479, 856.5722388616844), (1119.773408772214, 845.5004913539256), (1119.0098399785754, 835.9558814334439), (1114.4284272167442, 824.1205651320465), (1109.4652300580938, 813.8123864179263), (1101.0659733280697, 800.0681481324326), (1089.2306570266724, 790.9053226087701), (1080.8314002966486, 786.7056942437582), (1070.141437185709, 782.5060658787462)],
        ],
    }

In [None]:
# Which drawn support mask do you want to load? Use one key of
# load_poly_support(); here a single name (string) is used and its polygons
# are rasterized with create_polygon_mask.
coord = load_poly_support()
polygon_names = "drawn_mask_heraldo_0deg"
supportmask = create_polygon_mask(
    im_c.shape,
    coord[polygon_names],
)

# More preprocessing
# Dilation followed by erosion with the same footprint (a morphological
# closing) closes small gaps in the mask.
footprint = skimage.morphology.disk(1)
supportmask= skimage.morphology.dilation(supportmask, footprint) # expand
supportmask= skimage.morphology.erosion(supportmask, footprint) # shrink

# Smoothing of mask: blur and re-threshold to smooth the edges
supportmask = gaussian_filter(supportmask.astype(float),2)
supportmask = supportmask > 0.4

# Plotting
# Top: reconstruction (zoomed to 700-1100 px) and mask;
# bottom: reconstruction inside and outside the mask
fig, ax = plt.subplots(2,2,figsize=(12,12),sharex=True,sharey=True)
mi, ma = np.percentile(recon[recon!=0],[1,99])
ax[0,0].imshow(recon,vmin=mi,vmax=ma,cmap="gray")
ax[0,0].set_xlim([700,1100])
ax[0,0].set_ylim([700,1100])
ax[0,0].invert_yaxis()
ax[0,1].imshow(supportmask)

ax[1,0].imshow(recon*supportmask,vmin=mi,vmax=ma,cmap="gray")
ax[1,1].imshow(recon*(1-supportmask),vmin=mi,vmax=ma,cmap="gray")


### Take ROI
Choose the reconstructions as the ROI.

1. Zoom into the image and adjust your FOV until you are satisfied.
2. Save the axes coordinates.

In [None]:
# Show the support mask; zoom to the CDI ROI, then run the next cell
fig, ax = cimshow(supportmask.astype(int))

In [None]:
# ROI from the current zoom of the plot above
roi_cdi = interactive.axis_to_roi(ax)
# NOTE(review): hard-coded override — this replaces the ROI read from the
# plot; comment it out to use the interactively chosen ROI.
roi_cdi = [751,1059,693,1064]
# Slice as [rows, columns]
roi_cdi = np.s_[roi_cdi[0] : roi_cdi[1], roi_cdi[2] : roi_cdi[3]]
print("Sliced roi:", roi_cdi)

## Do Phase Retrieval

In [None]:
# Executes the algorithm
#fth_reco = fth.reconstruct(pos * mask_bs * np.exp(1j * phase) * (1 - mask_pixel_smooth))
#startimage = cci.reconstruct(supportmask*np.abs(fth_reco))

# Phase retrieval of hologram `pos` using mask_pixel and the support mask.
# Returns the retrieved hologram, its partial-coherence version (_pc), the mask
# used by the algorithm and gamma (presumably the coherence estimate — TODO
# confirm in phase_retrieval).
(
    retrieved_p,
    retrieved_p_pc,
    bsmask_p,
    gamma_p,
) = phase_retrieval(pos, mask_pixel, supportmask, vmin = .5, Startimage=None, Startgamma=None)


## Reconstruct images from phase retrieval

In [None]:
# New beamstop for CDI recos as phase retrieval of low-q might be insufficient. If phase retrieval worked well
# Try without beamstop: `use_bs = False`
use_bs = False
bs_diam_cdi = 32  # diameter of beamstop (NOTE(review): not used in this cell)

# Create beamstop
if use_bs is True:
    # Use the smoothed pixel mask as the CDI beamstop
    mask_bs_cdi = 1 - mask_pixel_smooth.copy()
else:
    mask_bs_cdi = np.ones(pos.shape)  # if you don't want a beamstop

# Propagation distance and phase are tuned with the focusCDI widget below and
# read back in the next cell. Keep existing values if present; otherwise start
# from 0 (like the FTH propagation widget) so a fresh kernel does not fail.
prop_dist_cdi = globals().get("prop_dist_cdi", 0)
phase_cdi = globals().get("phase_cdi", 0)

# Get Reco
p = fth.reconstructCDI(
    fth.propagate(
        retrieved_p * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
)

# Get Reco partial coherence
p_pc = fth.reconstructCDI(
    fth.propagate(
        retrieved_p_pc * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
)

# optimize phase
#_, phase_cdi = optimize_phase_contrast(p_pc, supportmask, method="contrast")

# Plotting
# NOTE(review): dx and dy are not defined in this part of the notebook;
# presumably they are set in an earlier cell — confirm.
mode = "+"
print("Fine-tuning of reconstruction parameter:")
slider_prop, slider_phase, _, _= rec.focusCDI(
    retrieved_p_pc * mask_bs_cdi,
    np.zeros(retrieved_p_pc.shape),
    roi_cdi,
    mask=supportmask,
    phase=phase_cdi,
    dx=dx,
    dy=dy,
    prop_dist=prop_dist_cdi,
    experimental_setup=experimental_setup,
    operation=mode,
    max_prop_dist=2,
    scale=(1, 99.9),
)

In [None]:
# Get phase from slider
phase_cdi = slider_phase.value
prop_dist_cdi = slider_prop.value

# Reconstruct images with new parameter
# (propagate by prop_dist_cdi, reconstruct, then apply the global phase)
p_pc = fth.reconstructCDI(
    fth.propagate(
        retrieved_p_pc * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
) * np.exp(1j * phase_cdi)

print("Phase CDI: %s" % phase_cdi)
print("Prop_dist: %s" % prop_dist_cdi)

In [None]:
# Quick look at the magnitude of the final CDI reconstruction
cimshow(np.abs(p_pc),cmap='gray')

In [None]:
# Confirm that offset subtraction in cdi function works, i.e., only small fraction of hologram is actually masked
# Left: magnitude of the retrieved hologram; right: mask returned by phase_retrieval
fig, ax = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True)
tmp = np.abs(retrieved_p_pc)
mi, ma = np.percentile(tmp, [0.1, 99.9])
ax[0].imshow(tmp, vmin=mi, vmax=ma)
ax[0].set_title("Retrieved Holo")
ax[1].imshow(bsmask_p)
ax[1].set_title("Phase Retrieval mask")


## Save reconstructions

In [None]:
# Style of reconstruction plot
def plot_recon(recon, title):
    """
    Plot the magnitude and real part of a CDI reconstruction side by side.

    Color limits of each panel are the 20th-99.9th percentile of that part
    multiplied by the global ``supportmask[roi_cdi]``. `recon` must therefore
    have the shape of an roi_cdi crop.

    NOTE(review): pixels outside the support contribute zeros to the
    percentile, which lowers the lower limit — presumably intended.
    """
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    fig.suptitle(title)

    support = supportmask[roi_cdi]
    panels = ((np.abs, "Abs"), (np.real, "Real"))
    for panel_ax, (extract, label) in zip(axes, panels):
        part = extract(recon)
        low, high = np.percentile(part * support, (20, 99.9))
        panel_ax.imshow(part, vmin=low, vmax=high, cmap="gray")
        panel_ax.set_title(label)

In [None]:
# Saves only real and imaginary part
# recon is the complex CDI reconstruction p_pc; what save_cdi_h5() writes is
# defined elsewhere.
recon = p_pc.copy()

# Plot
title = get_title("srotz", im_id, CDI=True)
plot_recon(recon[roi_cdi], title)

# Save images
fname = join(
    folder_general,
    "Recon_ImId_%04d_cdi_%s.png" % (im_id, USER),
)
print("Saving: %s" % fname)
plt.savefig(fname, bbox_inches="tight", transparent=False)

# Save h5
save_cdi_h5()

# Batch processing CDI (not tested)

## Define Scan Ids

In [None]:
# Load support mask of which sample?
# NOTE(review): must be a key of get_supportmask_coordinates(); the version
# defined in this notebook only contains "s2402f_2-35deg", so this lookup
# raises KeyError as written.
sample = "s2305e-C5"

In [None]:
# Define the sets for reconstructions. You can make an iterable list or use np.arange
# dark_id_set is indexed by position in the batch loop, so it needs one entry
# per image id (None skips dark subtraction).
im_id_set = [911]
dark_id_set = [912]

print("Image Set:  %s" % im_id_set)

## Execute Stack Reconstruction

In [None]:
# Less ugly Automatic processing of image stacks
# For each image id: FTH processing as in the FTH batch loop, then phase
# retrieval and CDI reconstruction with the parameters (mask_bs_cdi,
# prop_dist_cdi, phase_cdi, roi_cdi) chosen in the single-image CDI cells.
for it, im_id in enumerate(tqdm(im_id_set)):
    # Load energy and add to experimental setup
    experimental_setup["energy"] = load_pre_scan_snapshot(im_id,"energy")
    experimental_setup["lambda"] = cci.photon_energy_wavelength(
        experimental_setup["energy"], input_unit="eV"
    )

    # Load angles
    experimental_setup["srotx"] = load_pre_scan_snapshot(im_id,"srotx")
    experimental_setup["srotz"] = load_pre_scan_snapshot(im_id,"srotz")

    # Load images
    image = load_processing(im_id)

    # Load dark image for this position in dark_id_set (None = no subtraction)
    # NOTE(review): the image uses load_processing's default crop while the
    # dark uses crop=None — confirm both have the same shape.
    dark_id = dark_id_set[it]
    if dark_id is not None:
        dark = load_processing(dark_id, crop=None)
        image = image - dark

    # Process images
    worker_dict = worker(image)

    # Take values from worker
    recon = worker_dict["recon"]

    # Plotting
    # NOTE(review): at this point plot_recon is the CDI version, which
    # multiplies by supportmask[roi_cdi]; an roi_s crop only works if roi_s
    # and roi_cdi have the same size.
    title = get_title("srotz",im_id)
    plot_recon(worker_dict["recon"][roi_s], title)
    # Save images
    fname = join(
        folder_general,
        "Recon_ImId_%04d_fth_stack_%s.png" % (im_id, USER),
    )
    print("Saving: %s" % fname)
    plt.savefig(fname, bbox_inches="tight", transparent=False)

    # Save hdf5 file
    save_fth_h5()

    ################ CDI ###############
    # Hologram for CDI from the worker (only `pos` is set here)
    pos = worker_dict["im_c"]

    # Get coordinates and create supportmask
    support_coordinates = get_supportmask_coordinates(sample)
    supportmask = create_ellipse_supportmask(support_coordinates, pos.shape)

    # NOTE(review): vmin = 0 here vs vmin = .5 in the single-image phase
    # retrieval cell — confirm which is intended.
    (
        retrieved_p,
        retrieved_p_pc,
        bsmask_p,
        gamma_p,
    ) = phase_retrieval(pos, mask_pixel, supportmask, vmin = 0, Startimage=None, Startgamma=None)

    # Get Reco partial coherence
    p_pc = fth.reconstructCDI(
        fth.propagate(
            retrieved_p_pc * mask_bs_cdi,
            prop_dist_cdi * 1e-6,
            experimental_setup=experimental_setup,
        )
    )

    # Apply phase
    p_pc = p_pc * np.exp(1j * phase_cdi)

    # Saves only real and imaginary part
    recon = p_pc.copy()

    # Plot
    # NOTE(review): the single-image CDI cell uses "srotz" here; confirm that
    # "m_magnett_read" exists in the pre-scan snapshot.
    title = get_title("m_magnett_read", im_id, CDI=True)
    plot_recon(recon[roi_cdi], title)

    # Save images
    fname = join(
        folder_general,
        "Recon_ImId_%04d_cdi_stack_%s.png" % (im_id, USER),
    )
    print("Saving: %s" % fname)
    print("")
    plt.savefig(fname, bbox_inches="tight", transparent=False)

    # Save h5
    save_cdi_h5()