# Initialize libraries

## Import libraries

In [None]:
import sys, os
import time
from os.path import join
from os import path
from importlib import reload
from getpass import getuser
from glob import glob

import xarray as xr
import h5py
from tqdm.auto import tqdm

import numpy as np
import matplotlib.pyplot as plt
import scipy
import fabio
import skimage.morphology

# Open nexus files
from nexusformat.nexus import *

from scipy.ndimage.filters import gaussian_filter

# pyFAI
import pyFAI
from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
from pyFAI.detectors import Detector

# Self-written libraries
sys.path.append(join(os.getcwd(), "library"))
import reconstruct as reco
import fthcore as fth
import support_functions as sup
from interactive import cimshow
import interactive

# Gifs
import imageio

from scipy import stats

plt.rcParams["figure.constrained_layout.use"] = True  # replaces plt.tight_layout

import reconstruct as reco
import reconstruct_rb as rec

In [None]:
# Is there a GPU?
# Try the CuPy-based stack first; if any import in that chain fails, fall back
# to the CPU (NumPy) versions of the self-written libraries.
try:
    # Cupy
    import cupy as cp
    import cupyx as cpx

    # Self-written library
    import CCI_core_cupy as cci
    import Phase_Retrieval as PhR

    # Only report the GPU path once every GPU import has succeeded
    GPU = True
    print("GPU available")
except Exception:
    # `except Exception` instead of a bare `except:` still catches ImportError
    # and any other error raised while importing the GPU stack, but no longer
    # swallows KeyboardInterrupt / SystemExit.
    GPU = False
    import CCI_core as cci
    import Phase_Retrieval_noGPU as PhR

    print("GPU unavailable")

In [None]:
# interactive plotting
import ipywidgets

%matplotlib widget

# Auto formatting of cells
#%load_ext jupyter_black

## Experiment-specific functions

In [None]:
# Beamtime proposal number; part of the raw-data base path (BASEFOLDER).
PROPOSAL = 11018955
# Login name of the current user; appended to output file names (Logs, Topos).
USER = getuser()

### Loading data

In [None]:
# Root directory of the beamtime: raw scan files are read from BASEFOLDER/raw,
# processed output is written below BASEFOLDER/processed.
BASEFOLDER = "/asap3/petra3/gpfs/p04/2024/data/%s/" % PROPOSAL
# File-name prefix of all raw files ("<sample_name>_<scan:05d>.h5" and
# "<sample_name>  <spe_id>.spe").
sample_name = "2403_tomo"


# Load any kind of data from collection
def load_pre_scan_snapshot(scan_id, key):
    """
    Read one dataset of the pre-scan snapshot stored in a raw scan file.

    Parameter
    =========
    scan_id : int
        scan number; selects BASEFOLDER/raw/<sample_name>_<scan_id:05d>.h5
        and the group "entry<scan_id>" inside it
    key : str
        dataset name below measurement/pre_scan_snapshot

    Output
    ======
    data : array
        dataset content with singleton dimensions removed
    """
    scan_file = join(BASEFOLDER, "raw", "%s_%05d.h5" % (sample_name, scan_id))
    group_name = "entry%d" % scan_id

    with h5py.File(scan_file, "r") as h5:
        snapshot = h5[group_name]["measurement"]["pre_scan_snapshot"]
        values = np.array(snapshot[key][()])
        print("Loaded: %s" % (scan_file))
    return np.squeeze(values)

# Load any kind of data from measurements
def load_data(scan_id, key):
    """
    Read one dataset from the measurement group of a raw scan file.

    Parameter
    =========
    scan_id : int
        scan number; selects BASEFOLDER/raw/<sample_name>_<scan_id:05d>.h5
        and the group "entry<scan_id>" inside it
    key : str
        dataset name below measurement

    Output
    ======
    data : array
        dataset content with singleton dimensions removed
    """
    group_name = "entry%d" % scan_id
    scan_file = join(BASEFOLDER, "raw", "%s_%05d.h5" % (sample_name, scan_id))

    with h5py.File(scan_file, "r") as h5:
        measurement = h5[group_name]["measurement"]
        values = np.array(measurement[key][()])
        print("Loaded key %s of %s" % (key, group_name))
    return np.squeeze(values)


def get_image_id(im_id):
    """
    Return the id of the .spe camera file referenced by a scan.

    The value of measurement/ccd2 is converted with str() and characters
    [-10:-6] of that string are returned.
    NOTE(review): this depends on the exact string representation of the
    stored value (presumably a one-element array of bytes ending in
    "<id>.spe']"). Confirm the slice if the file layout changes.

    Parameter
    =========
    im_id : int
        scan number of the raw h5 file

    Output
    ======
    spe_id : str
        four-character id used by load_spe_files
    """
    scan_file = join(BASEFOLDER, "raw", "%s_%05d.h5" % (sample_name, im_id))
    with h5py.File(scan_file, "r") as h5:
        ccd_entry = h5["entry%d" % im_id]["measurement"]["ccd2"][()]
        return str(ccd_entry)[-10:-6]

def load_spe_files(spe_id):
    """
    Read all frames of a .spe camera file.

    The file is BASEFOLDER/raw/"<sample_name>  <spe_id>.spe" (two spaces
    between sample name and id). imageio's memory limit is raised to 5000MB
    so that large stacks can be read.

    Output
    ======
    frames : array
        frames stacked along the first axis
    """
    spe_name = sample_name + "  " + str(spe_id) + ".spe"
    frames = imageio.mimread(join(BASEFOLDER, "raw", spe_name), memtest="5000MB")
    return np.array(frames)
    
# Load image files
def load_images(im_id):
    """
    Load the ccd image stack belonging to a scan.

    The .spe file id is looked up in the scan's h5 file (get_image_id) and
    the corresponding .spe file is read (load_spe_files).
    """
    return load_spe_files(get_image_id(im_id))


# Full image loading procedure
def load_processing(im_id, crop=None):
    """
    Load a scan's images, average them and optionally crop the result.

    Parameter
    =========
    im_id : int
        scan number passed to load_images
    crop : int or None
        if given, only image[:crop, :crop] is returned

    Output
    ======
    image : array
        mean over the first axis of the loaded stack, or a copy of the data
        if only a single 2D frame was loaded; optionally cropped
    """

    # Load data
    images = load_images(im_id)

    # Zeropad to get square shape (currently disabled)
    #images = sup.padding(images)

    # Average the stack of frames; a single 2D frame is simply copied
    if images.ndim > 2:
        image = np.mean(images, axis=0)
    else:
        image = images.copy()

    # Optional cropping. Only the averaged image is returned, so only it is
    # cropped. Cropping the raw stack with [:, :crop, :crop] raised an
    # IndexError for a single 2D frame, and its result was never used.
    if crop is not None:
        image = image[:crop, :crop]

    return image

### Loading, saving fth & cdi data

In [None]:
# Saving of log files for fth and cdi recos
def save_fth_h5():
    """
    Write the current FTH reconstruction state to an hdf5 log file.

    Reads the notebook-level variables im_id, topo_id, topo_c, im_c, holo,
    sum_b, recon, factor, offset, center, roi, prop_dist, phase, mask_bs,
    bs_diam and experimental_setup, and stores them with cci.create_hdf5 as
    folder_general/Logs/Data_ImId_<im_id:04d>_RefId_<topo_id>_<USER>.
    """
    data = {
        "im_id": im_id,
        "topo_id": topo_id,
        "topo_centered": topo_c,
        "im_centered": im_c,
        "holo": holo,
        "sum_b": sum_b,
        "recon": recon,
        "factor": factor,
        "offset": offset,
        "center": center,
        "roi": roi,
        "prop_dist": prop_dist,
        "phase": phase,
        "mask_bs": mask_bs,
        "bs_diam": bs_diam,
        "experimental_setup": experimental_setup,
    }

    log_name = "Data_ImId_%04d_RefId_%s_%s" % (im_id, topo_id, USER)
    filename = join(folder_general, "Logs", log_name)
    print("Now Saving: %s" % filename)
    cci.create_hdf5(data, filename)


def save_cdi_h5():
    """
    Write the current CDI reconstruction state to an hdf5 log file.

    Reads the notebook-level variables im_id, topo_id, pos, neg, factor,
    offset, center, roi, prop_dist_cdi, phase_cdi, mask_bs_cdi, bs_diam_cdi,
    supportmask, mask_pixel, p, n, p_pc, n_pc, gamma_p, gamma_n and
    experimental_setup. The *_cdi variables are stored under the key without
    the suffix. The file is written with cci.create_hdf5 as
    folder_general/Logs/Data_ImId_<im_id:04d>_RefId_<topo_id>_cdi_<USER>.
    """
    data = {
        "im_id": im_id,
        "topo_id": topo_id,
        "pos": pos,
        "neg": neg,
        "factor": factor,
        "offset": offset,
        "center": center,
        "roi": roi,
        "prop_dist": prop_dist_cdi,
        "phase": phase_cdi,
        "mask_bs": mask_bs_cdi,
        "bs_diam": bs_diam_cdi,
        "supportmask": supportmask,
        "mask_pixel": mask_pixel,
        "p": p,
        "n": n,
        "p_pc": p_pc,
        "n_pc": n_pc,
        "gamma_p": gamma_p,
        "gamma_n": gamma_n,
        "experimental_setup": experimental_setup,
    }

    log_name = "Data_ImId_%04d_RefId_%s_cdi_%s" % (im_id, topo_id, USER)
    filename = join(folder_general, "Logs", log_name)
    print("Now Saving: %s" % filename)
    cci.create_hdf5(data, filename)


def save_topo_holo(topo_c, pos_id, neg_id):
    """
    Store a topography hologram for later single-helicity reconstructions.

    Writes {"pos_id", "neg_id", "topo"} with cci.create_hdf5 as
    folder_general/Topos/Topo_ImId_<pos_id:04d>_RefId_<neg_id:04d>_cdi_<USER>.
    """
    topo_name = "Topo_ImId_%04d_RefId_%04d_cdi_%s" % (pos_id, neg_id, USER)
    filename = join(folder_general, "Topos", topo_name)
    print("Now Saving: %s" % filename)
    cci.create_hdf5({"pos_id": pos_id, "neg_id": neg_id, "topo": topo_c}, filename)


def load_topo_holo(pos_id, neg_id):
    """
    Load a topography hologram stored by save_topo_holo.

    Reads the dataset "topo" from
    folder_general/Topos/Topo_ImId_<pos_id:04d>_RefId_<neg_id:04d>_cdi_<USER>.hdf5.
    """
    topo_file = join(
        folder_general,
        "Topos",
        "Topo_ImId_%04d_RefId_%04d_cdi_%s.hdf5" % (pos_id, neg_id, USER),
    )
    with h5py.File(topo_file, "r") as h5:
        return h5["topo"][()]


def load_cdi(im_id, topo_id):
    """
    Load a cdi dataset from folder_general/Logs.

    Parameter
    =========
    im_id : int
        image scan id in the file name (formatted with %04d)
    topo_id : int or str
        topo scan id in the file name (formatted with %s)

    Output
    ======
    data_cdi : dict
        dataset name -> stored value

    Notes
    =====
    save_cdi_h5 does not write "recon" or "shift", so requiring them made
    loading its output fail with a KeyError. These two datasets are now read
    only when present. All other listed datasets are still required.
    """
    fname = join(
        folder_general,
        "Logs",
        "Data_ImId_%04d_RefId_%s_cdi_%s.hdf5" % (im_id, topo_id, USER),
    )

    # Datasets to read, in the order of the returned dict
    keys = (
        "im_id",
        "topo_id",
        "pos",
        "neg",
        "recon",
        "factor",
        "offset",
        "shift",
        "center",
        "prop_dist",
        "phase",
        "mask_bs",
        "supportmask",
        "mask_pixel",
        "p",
        "n",
        "p_pc",
        "n_pc",
    )
    # Datasets not written by save_cdi_h5; only some log files contain them
    optional_keys = {"recon", "shift"}

    data_cdi = {}
    with h5py.File(fname, "r") as f:
        for key in keys:
            if key in optional_keys and key not in f:
                continue
            data_cdi[key] = f[key][()]
    return data_cdi

### Masking

In [None]:
from matplotlib.path import Path


def create_single_polygon_mask(shape, coordinates):
    """
    Creates a polygon mask from coordinates of corner points

    Parameter
    =========
    shape : int tuple
        shape/dimension of output array (rows, columns)
    coordinates: nested list
        coordinates of polygon corner points. For every point, the first value
        is compared with the column index (axis 1) and the second with the row
        index (axis 0) of the output array, i.e. (x, y) order.

    Output
    ======
    mask: array
        boolean mask of shape `shape` where the filled polygon is True
    ======
    author: ck 2023
    """

    # Pixel grid laid out like the output array: x = column index, y = row
    # index, both of shape (shape[0], shape[1]). Pixel (row, col) therefore sits
    # at flat index row * shape[1] + col, which is what reshape(shape) expects.
    # The previous argument order meshgrid(arange(shape[0]), arange(shape[1]))
    # only lined up for square shapes.
    x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
    points = np.column_stack((x.ravel(), y.ravel()))

    path = Path(coordinates)
    mask = path.contains_points(points)
    return mask.reshape(shape)


def create_polygon_mask(shape, coordinates):
    """
    Creates multiple polygon masks from set of coordinates of corner points

    Parameter
    =========
    shape : int tuple
        shape/dimension of output array
    coordinates: nested list
        coordinates of polygon corner points for multiple polygons
        [[(p1_a, p1_b), (p2_a, p2_b), ...], [...]], point order as expected
        by create_single_polygon_mask

    Output
    ======
    mask: array
        mask where filled polygons are "1": the boolean mask of
        create_single_polygon_mask for exactly one polygon, otherwise a float
        array of 0/1 (all zeros for an empty coordinate list)
    ======
    author: ck 2023
    """

    # A single polygon is returned as is (boolean), as before
    if len(coordinates) == 1:
        return create_single_polygon_mask(shape, coordinates[0])

    # Union of all polygons, returned as float 0/1. An empty list now yields an
    # all-zero mask instead of failing on an unassigned `mask`.
    union = np.zeros(shape, dtype=bool)
    for coord in coordinates:
        union |= create_single_polygon_mask(shape, coord).astype(bool)

    return union.astype(float)


def load_poly_masks(shape, polygon_name_list):
    """
    Loads set of polygon masks based on stored coordinates

    Parameter
    =========
    shape : tuple
        shape of output mask
    polygon_name_list : list
        keys of different mask coordinates to load

    Output
    ======
    mask: array
        float mask of shape `shape` where filled polygons are "1"
        (all zeros for an empty name list)
    ======
    author: ck 2023
    """

    # Load dictionary of coordinates
    # NOTE(review): load_poly_coordinates is defined outside this section;
    # presumably it returns {polygon_name: list of polygons}.
    mask_coordinates = load_poly_coordinates()

    # Accumulate all requested mask layers. Starting from zeros of the requested
    # shape keeps the output shape well defined for an empty name list.
    mask = np.zeros(shape)
    for polygon_name in polygon_name_list:
        coord = mask_coordinates[polygon_name]
        # Use the `shape` argument; the global im_c.shape was used before,
        # silently ignoring the requested shape.
        mask = mask + create_polygon_mask(shape, coord).astype(float)

    # Overlapping polygons count only once
    mask[mask > 1] = 1

    return mask

In [None]:
def auto_shift_mask(
    mask,
    image,
    shift_range_y=(-10, 10),
    shift_range_x=(-10, 10),
    step_size=1,
    crop=None,
    method="minimize",
):
    """
    Automatically shifts a binary (beamstop) mask to the optimal position

    Parameter
    =========
    mask : array
        binary mask, e.g., beamstop mask
    image : array
        image partially covered by mask
    shift_range_y, _x : tuple
        Min (inclusive) and max (exclusive) limit for search area in y- and
        x-direction
    step_size : scalar
        step size of scan area
    crop : int or None
        additional cropping of arrays to speed up calculations
        Crops according to [crop:-crop,crop:-crop]; None or 0 disables cropping
    method : string
        method used for optimization. 'minimize' or 'maximize' metric <mask,image>

    Output
    ======
    optimized shift: tuple
        (y, x) shift vector for optimized position
    mask_shifted: array
        mask shifted according to best shift vector, binarized at 0.5
    overlap : array
        computed evaluation metric, indexed [y shift, x shift]

    Raises
    ======
    ValueError
        if `method` is neither 'minimize' nor 'maximize'
    ======
    author: ck 2023
    """
    # Validate before the (potentially long) scan instead of failing afterwards
    if method not in ("minimize", "maximize"):
        raise ValueError(
            "method must be 'minimize' or 'maximize', got %r" % (method,)
        )

    # Create set of candidate shifts
    yshift = np.arange(shift_range_y[0], shift_range_y[1], step_size)
    xshift = np.arange(shift_range_x[0], shift_range_x[1], step_size)

    # Evaluation metric for every (y, x) candidate
    overlap = np.zeros((yshift.shape[0], xshift.shape[0]))

    # Optional cropping to reduce computation time. crop=0 is treated like None
    # because arr[0:-0] would be an empty array.
    if crop:
        tmask, timage = mask[crop:-crop, crop:-crop], image[crop:-crop, crop:-crop]
    else:
        # (previously `timage.copy()`, a reference to a not-yet-assigned name)
        tmask, timage = mask.copy(), image.copy()

    # Loop over all combinations of shifts
    for i, y in enumerate(yshift):
        for j, x in enumerate(xshift):
            # Shift mask to new position
            mask_shift = cci.shift_image(tmask, [y, x])

            # Calculate overlap
            overlap[i, j] = np.sum(mask_shift * timage)

    # Get best shift
    best = np.argmin(overlap) if method == "minimize" else np.argmax(overlap)
    idx = np.unravel_index(best, overlap.shape)

    # Rows of `overlap` index y shifts, columns index x shifts
    # (previously yshift was used for both components)
    optimized_shift = (yshift[idx[0]], xshift[idx[1]])
    print(
        "Best mask overlap for shift: [%.2f,%.2f]"
        % (optimized_shift[0], optimized_shift[1])
    )

    # Shift
    mask_shifted = cci.shift_image(mask, optimized_shift)

    # Binarize mask (necessary for sub-px shift); values of exactly 0.5 now
    # map to 1 instead of being left at 0.5
    mask_shifted[mask_shifted >= 0.5] = 1
    mask_shifted[mask_shifted < 0.5] = 0

    return optimized_shift, mask_shifted, overlap

In [None]:
# Still experimental
def automated_beamstop(hologram, grad_thres, radius, expand, method="intensity"):
    """
    Estimate a beamstop mask by thresholding a hologram (experimental).

    Parameter
    =========
    hologram : 2D array
        hologram to analyse
    grad_thres : scalar
        pixels whose (filtered) value is below this threshold become mask candidates
    radius : scalar
        candidates are kept only inside a circle of this radius around the array center
    expand : int
        radius of the disk used to dilate the final mask
    method : string
        "intensity": threshold the hologram itself,
        "gradient": threshold the mean absolute gradient of the hologram

    Output
    ======
    hologram_mask : boolean array
        estimated beamstop mask

    Raises
    ======
    ValueError
        for an unknown `method`
    """
    # Different methods for filtering of beamstop
    if method == "intensity":
        # Some filters
        pass
    elif method == "gradient":
        hologram = np.mean(np.abs(np.gradient(hologram)), axis=0)
    else:
        raise ValueError(
            "method must be 'intensity' or 'gradient', got %r" % (method,)
        )

    # Thresholding
    hologram_mask = hologram < grad_thres

    # Draw beamstop only up to given radius around the array center.
    # Uses the hologram's own shape; the global `image` was used before.
    hologram_mask = hologram_mask * cci.circle_mask(
        hologram.shape,
        [hologram.shape[0] / 2, hologram.shape[1] / 2],
        radius,
        sigma=None,
    )
    hologram_mask = hologram_mask.astype(bool)

    # Morphological operations to filter reference modulations as these also lead to strong intensity gradients
    # close the "dots" of the ref modulations
    footprint = skimage.morphology.disk(2)
    hologram_mask = skimage.morphology.erosion(hologram_mask, footprint)

    # Filter remainings of ref modulations (objects smaller than 2000 px)
    hologram_mask = skimage.morphology.remove_small_objects(
        hologram_mask, min_size=2000
    )

    # Expand Mask
    footprint = skimage.morphology.disk(expand)
    hologram_mask = skimage.morphology.dilation(hologram_mask, footprint)

    # Fill up small holes in the mask
    hologram_mask = scipy.ndimage.binary_fill_holes(
        hologram_mask, structure=np.ones((5, 5))
    )

    return hologram_mask

In [None]:
from dipy.segment.mask import median_otsu


def automated_beamstop_center(hologram, expand):
    """
    Estimate the central beamstop mask of a hologram with Otsu thresholding.

    Steps:
    1. Compress the hologram to log10(max(hologram, 0) + 1).
    2. Threshold it with dipy's median_otsu.
    3. Invert the result and erode it with disk(2).
    4. Remove objects smaller than 2000 px from both the mask and its complement.
    5. Dilate with disk(expand).

    Parameter
    =========
    hologram : 2D array
        hologram; the caller's array is not modified
    expand : int
        radius of the disk used to dilate the final mask

    Output
    ======
    hologram_mask : array
        0/1 beamstop mask
    """
    # "Flatten" hologram on a log scale. np.maximum returns a new array, so the
    # caller's hologram is no longer clipped in place as a side effect.
    hologram = np.log10(np.maximum(hologram, 0) + 1)

    # Prepare raw mask using otsu threshold method
    _, hologram_mask = median_otsu(hologram, median_radius=1, numpass=1)

    # Morphological operations to filter reference modulations
    # close the "dots" of the ref modulations
    footprint = skimage.morphology.disk(2)
    hologram_mask = skimage.morphology.erosion(
        (1 - hologram_mask).astype(bool), footprint
    )

    # Filter remainings of ref modulations: drop small objects, then fill
    # small holes by dropping small objects of the complement
    hologram_mask = skimage.morphology.remove_small_objects(
        hologram_mask.astype(bool), min_size=2000
    )
    hologram_mask = 1 - skimage.morphology.remove_small_objects(
        (1 - hologram_mask).astype(bool), min_size=2000
    )

    # Expand mask to desired size
    footprint = skimage.morphology.disk(expand)
    hologram_mask = skimage.morphology.dilation(hologram_mask, footprint)

    return hologram_mask

In [None]:
def create_supportmask(support_coordinates, shape):
    """
    Create cdi support mask from a combination of multiple circular apertures

    Parameter
    =========
    support_coordinates: nested list
        Contains center coordinates and radius of each aperture [[yc_1,xc_1,r_1],[yc_2,xc_2,r_2],...]
    shape : int tuple
        shape/dimension of output array

    Output
    ======
    supportmask: array
        composed mask where circular apertures are "1"; overlapping apertures
        are clipped to 1 instead of adding up to 2, 3, ...
    ======
    author: ck 2023

    """

    # Create support mask as the sum of all apertures
    supportmask = np.zeros(shape)
    for aperture in support_coordinates:
        supportmask += cci.circle_mask(
            supportmask.shape,
            [aperture[0], aperture[1]],
            aperture[2],
        )

    # Overlapping apertures must not exceed 1
    supportmask[supportmask > 1] = 1

    return supportmask

### Other

In [None]:
# Worker which performs complete fth reconstruction process
def worker(image, topo):
    """
    Perform the complete FTH reconstruction for one image/topo pair.

    Uses the notebook-level variables center, mask_bs, polygon_names, bs,
    roi_im_reg, prop_dist, phase and experimental_setup.

    Parameter
    =========
    image : array
        hologram of interest
    topo : array
        reference hologram, same shape as image

    Output
    ======
    worker_dict : dict
        "topo_c", "im_c": centered holograms; "holo": hologram used for the
        reconstruction; "recon": reconstruction; "factor", "offset": intensity
        scaling; "diff_c", "diff_b", "sum_c", "sum_b": difference and sum
        holograms (_c: centered, _b: with beamstop); "sum": same as "sum_b"
    """
    # Centering: move `center` to the middle of the array
    shift_c = np.array(topo.shape) / 2 - center
    topo_c = cci.shift_image(topo, shift_c)
    im_c = cci.shift_image(image, shift_c)

    # Apply beamstop to image
    im_b = im_c * mask_bs

    # Polygon masks (the shape argument was missing, which raised a TypeError)
    mask = load_poly_masks(im_c.shape, polygon_names)
    footprint = skimage.morphology.disk(10)
    mask = skimage.morphology.dilation(mask, footprint)

    # BS center part
    # NOTE(review): `bs` is not defined in this section; assumed to be a
    # notebook-level beamstop object exposing its radius as `rBS`.
    mask_bs_center = cci.circle_mask(
        topo.shape, np.array(topo.shape) / 2, bs.rBS, sigma=None
    )

    # Image specific mask: polygons + pixels above 65000 counts (presumably
    # detector saturation) + beamstop center
    mask_im = mask + (im_c > 65000) + mask_bs_center
    mask_im[mask_im > 1] = 1

    # Topo specific mask, built the same way
    mask_topo = mask + (topo_c > 65000) + mask_bs_center
    mask_topo[mask_topo > 1] = 1

    # Combine both (a blurred copy was computed here before but never used)
    mask_pixel_sharp = mask_im + mask_topo
    mask_pixel_sharp[mask_pixel_sharp > 1] = 1

    # Calc  shift
    shift = cci.image_registration(
        im_c[roi_im_reg],
        topo_c[roi_im_reg],
        method="dipy",
        static_mask=mask_topo[roi_im_reg],
        moving_mask=mask_im[roi_im_reg],
    )
    print("Shift: %s" % shift)

    # Shift and apply beamstop
    topo_c = cci.shift_image(topo, shift_c + shift)  # centered image
    topo_b = topo_c * mask_bs  # centered image with beamstop

    # Get scaling factor and offset from the unmasked pixels
    factor, offset = cci.dyn_factor(
        im_c * (1 - mask_pixel_sharp),
        topo_c * (1 - mask_pixel_sharp),
        method="correlation",
        print_out=True,
        plot=True,
    )

    # Calculate differences (magnetic) and sums (topographic) contrast holograms.
    # _c: centered, without beamstop, _b: centered, with beamstop
    diff_c = im_c / factor - topo_c - offset
    diff_b = im_b / factor - topo_b - offset
    sum_c = im_c / factor + topo_c - offset
    sum_b = im_b / factor + topo_b - offset

    # Reconstruct from the beamstopped difference hologram. (The former
    # placeholder np.zeros(..., dtype=np.complex_) was always overwritten, and
    # np.complex_ no longer exists in NumPy 2.0.)
    holo = diff_b.copy()
    recon = cci.reconstruct(
        fth.propagate(holo, prop_dist * 1e-6, experimental_setup=experimental_setup)
        * np.exp(1j * phase)
    )

    # worker dictionary
    worker_dict = {}
    worker_dict["topo_c"] = topo_c
    worker_dict["im_c"] = im_c
    worker_dict["holo"] = holo
    worker_dict["sum"] = sum_b
    worker_dict["recon"] = recon
    worker_dict["factor"] = factor
    worker_dict["offset"] = offset
    worker_dict["diff_c"] = diff_c
    worker_dict["diff_b"] = diff_b
    worker_dict["sum_c"] = sum_c
    worker_dict["sum_b"] = sum_b

    return worker_dict

In [None]:
# Setup phase and propagation for cdi once
# Module-level defaults. phase_cdi and prop_dist_cdi are read by save_cdi_h5
# (stored as "phase" and "prop_dist"); dx and dy are not used in this section.
phase_cdi = 0
prop_dist_cdi = 0
dx = 0
dy = 0


def phase_retrieval(
    pos, neg, mask_pixel, supportmask, vmin = 0, Startimage=None, Startgamma=None
):
    """
    Iterative phase retrieval (CDI) of both helicities, followed by a
    partial-coherence refinement (PhR.PhaseRtrv_with_RL).

    Parameter
    =========
    pos, neg : array
        holograms of the two helicities. From each, the `vmin` percentile of
        its non-zero pixels is subtracted and negative values are set to 0.
    mask_pixel : array
        pixel mask (1 = masked). Pixels where the prepared hologram is <= 0
        are added to it, giving bsmask_p / bsmask_n.
    supportmask : array
        real-space support passed as `mask` to the PhR routines
    vmin : scalar
        percentile (0-100) used as background level, see above
    Startimage : complex array or None
        initial guess. Default: centered inverse FFT of supportmask.
        In both cases it is linearly rescaled (scipy.stats.linregress on the
        pixels where mask_pixel == 0) to the scale of sqrt(pos).
    Startgamma : array or None
        initial partial-coherence kernel for PhR.PhaseRtrv_with_RL.
        Default: 2e-6 everywhere and 0.7 at the array center.

    Output
    ======
    tuple
        (retrieved_p, retrieved_n, retrieved_p_pc, retrieved_n_pc,
         bsmask_p, bsmask_n, gamma_p, gamma_n)

    Side effects: sets plt.rcParams["figure.dpi"] = 100 and prints progress.
    pos, neg, mask_pixel, Startimage and Startgamma are copied before this
    function modifies them.
    """
    # Prepare Input holograms
    pos2 = pos.copy()
    neg2 = neg.copy()

    # Subtract the vmin-th percentile of the non-zero pixels as background
    mi, _ = np.percentile(pos2[pos2 != 0], [vmin, 99.9])
    pos2 = pos2 - mi
    mi, _ = np.percentile(neg2[neg2 != 0], [vmin, 99.9])
    neg2 = neg2 - mi

    pos2[pos2 < 0] = 0
    neg2[neg2 < 0] = 0
    pos2 = pos2.astype(complex)
    neg2 = neg2.astype(complex)

    # Helicity specific masks: given pixel mask plus all pixels without signal
    bsmask_p = mask_pixel.copy()
    bsmask_p[pos2 <= 0] = 1
    bsmask_n = mask_pixel.copy()
    bsmask_n[neg2 <= 0] = 1

    # Setup start image and startgamma
    if Startimage is None:
        Startimage = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(supportmask)))
    else:
        Startimage = Startimage.copy()
    if Startgamma is None:
        Startgamma = np.ones(pos.shape) * 1e-6 * 2
        Startgamma[pos.shape[0] // 2, pos.shape[1] // 2] = 0.7
    else:
        Startgamma = Startgamma.copy()

    # Settings for phase retrieval reconstructions
    # (hard-coded to True; retrieved_*_pc and gamma_* returned below are only
    # defined when this is True)
    partial_coherence = True

    # Setup
    retrieved_p = np.zeros(pos2.shape, np.cdouble)
    retrieved_n = np.zeros(pos2.shape, np.cdouble)

    # Algorithms and Inital guess
    plt.rcParams["figure.dpi"] = 100
    print("CDI - larger mask")

    algorithm_list = ["mine", "mine", "mine"]
    Nit_list = [700, 50, 50]  # iterations for algorithm_list

    # Linearly rescale the start image to the measured amplitudes (unmasked pixels only)
    x = (np.sqrt(np.maximum(pos2, np.zeros(pos2.shape)))[mask_pixel == 0]).flatten()
    y = ((np.abs(Startimage))[mask_pixel == 0]).flatten()
    res = stats.linregress(x, y)
    Startimage -= res.intercept
    Startimage /= res.slope

    average_img = 30
    real_object = False  # always set to False

    if partial_coherence:
        RL_freq = 20
        RL_it = 50

        algorithm_list_pc = ["mine", "ER", "ER"]
        Nit_list_pc = [700, 50, 50]

    # Execute Phase retrieval
    # Each pass uses three consecutive entries of the algorithm/Nit lists;
    # with three entries the loop runs exactly once (i = 0).
    start_time = time.time()
    for i in range(len(Nit_list) // 3):
        print("############ -   CDI")

        # Positive helicity - beta_mode="arctan"
        retrieved_p, Error_diff_p, Error_supp = PhR.PhaseRtrv_GPU(
            diffract=np.sqrt(np.maximum(pos2, np.zeros(pos2.shape))),
            mask=supportmask,
            mode=algorithm_list[3 * i],
            beta_zero=0.5,
            Nit=Nit_list[3 * i],
            beta_mode="arctan",
            plot_every=349,
            Phase=Startimage,
            seed=False,
            real_object=real_object,
            bsmask=bsmask_p,
            average_img=average_img,
            Fourier_last=True,
        )

        # Positive helicity - beta_mode="const"
        retrieved_p, Error_diff_p2, Error_supp = PhR.PhaseRtrv_GPU(
            diffract=np.sqrt(np.maximum(pos2, np.zeros(pos2.shape))),
            mask=supportmask,
            mode=algorithm_list[3 * i + 1],
            beta_zero=0.5,
            Nit=Nit_list[3 * i + 1],
            beta_mode="const",
            plot_every=24,
            Phase=retrieved_p,
            seed=False,
            real_object=real_object,
            bsmask=bsmask_p,
            average_img=average_img,
            Fourier_last=True,
        )

        # Negative helicity - beta_mode="const", started from the positive
        # result scaled by the ratio of the total amplitudes
        retrieved_n, Error_diff_n2, Error_supp = PhR.PhaseRtrv_GPU(
            diffract=np.sqrt(np.maximum(neg2, np.zeros(neg2.shape))),
            mask=supportmask,
            mode=algorithm_list[3 * i + 2],
            beta_zero=0.5,
            Nit=Nit_list[3 * i + 2],
            beta_mode="const",
            plot_every=24,
            Phase=retrieved_p * np.sqrt(np.sum(neg2) / np.sum(pos2)),
            seed=False,
            real_object=real_object,
            bsmask=bsmask_n,
            average_img=average_img,
            Fourier_last=True,
        )

        print("--- %s seconds ---" % np.round((time.time() - start_time), 2))

        Startimage = retrieved_p.copy()

        # Partial coherence phase retrieval
        if partial_coherence:
            # CDI_PC
            print("############   -   CDI_pc")
            # Fill masked pixels with the retrieved intensity, keep measured
            # intensity elsewhere
            pos3 = (np.abs(retrieved_p) ** 2) * bsmask_p + np.maximum(
                pos2, np.zeros(pos2.shape)
            ) * (1 - bsmask_p)
            neg3 = (np.abs(retrieved_n) ** 2) * bsmask_n + np.maximum(
                neg2, np.zeros(neg2.shape)
            ) * (1 - bsmask_n)

            # retrieve pos image
            (
                retrieved_p_pc,
                Error_diff_p_pc,
                Error_supp,
                gamma_p,
            ) = PhR.PhaseRtrv_with_RL(
                diffract=np.sqrt(pos3),
                mask=supportmask,
                mode=algorithm_list_pc[3 * i],
                beta_zero=0.5,
                Nit=Nit_list_pc[3 * i],
                beta_mode="arctan",
                gamma=Startgamma,
                RL_freq=RL_freq,
                RL_it=RL_it,
                plot_every=349,
                Phase=Startimage,
                seed=False,
                real_object=False,
                bsmask=np.zeros(bsmask_p.shape),
                average_img=average_img,
                Fourier_last=True,
            )

            # NOTE(review): this and the next call take `mode` from
            # algorithm_list ("mine") although algorithm_list_pc lists "ER" for
            # these stages; confirm which is intended.
            (
                retrieved_p_pc,
                Error_diff_p_pc2,
                Error_supp,
                gamma_p,
            ) = PhR.PhaseRtrv_with_RL(
                diffract=np.sqrt(pos3),
                mask=supportmask,
                mode=algorithm_list[3 * i + 1],
                beta_zero=0.5,
                Nit=Nit_list_pc[3 * i + 1],
                beta_mode="const",
                gamma=gamma_p,
                RL_freq=RL_freq,
                RL_it=RL_it,
                plot_every=24,
                Phase=retrieved_p_pc,
                real_object=False,
                bsmask=np.zeros(bsmask_p.shape),
                average_img=average_img,
                Fourier_last=True,
            )
            # Negative helicity starts from the positive coherence kernel gamma_p
            (
                retrieved_n_pc,
                Error_diff_n_pc2,
                Error_supp,
                gamma_n,
            ) = PhR.PhaseRtrv_with_RL(
                diffract=np.sqrt(neg3),
                mask=supportmask,
                mode=algorithm_list[3 * i + 2],
                beta_zero=0.5,
                Nit=Nit_list_pc[3 * i + 2],
                beta_mode="const",
                gamma=gamma_p,
                RL_freq=RL_freq,
                RL_it=RL_it,
                plot_every=24,
                Phase=retrieved_p_pc * np.sqrt(np.sum(neg2) / np.sum(pos2)),
                real_object=False,
                bsmask=np.zeros(bsmask_n.shape),
                average_img=average_img,
                Fourier_last=True,
            )

            print("--- %s seconds ---" % np.round((time.time() - start_time), 2))

            Startimage = retrieved_p_pc.copy()
            Startgamma = gamma_p.copy()

    print("Phase Retrieval Done!")

    return (
        retrieved_p,
        retrieved_n,
        retrieved_p_pc,
        retrieved_n_pc,
        bsmask_p,
        bsmask_n,
        gamma_p,
        gamma_n,
    )

In [None]:
def optimize_phase_contrast(recon, supportmask, method="contrast", prefered_color=None):
    """
    Automatically shifts contrast of phase retrieval reconstruction into real part

    Parameter
    =========
    recon : complex array
        FTH/CDI reconstruction plane (Patterson map)
    supportmask : array
        Supportmask of Patterson map for phase retrieval
    method : string
        Choose method for phase optimization:
        "contrast": minimize the 1-99 percentile spread of the imaginary part
        "minima": minimize the 0.01 percentile of the real part
        "maxima": maximize the 99 percentile of the real part
        "max": legacy alias, behaves like "contrast"
    prefered_color : string or None
        Shift contrast such that color of domains with largest area is white ("white"),
        black ("black") or non-specific (None)

    Output
    ======
    recon_optimized: complex array
        reconstruction with optimized contrast
    optimized_phase: float
        phase (rad) corresponding to optimized reconstruction

    Raises
    ======
    ValueError
        for an unknown `method`
    ======
    author: ck 2023
    """
    # Explicit submodule imports; `import scipy` alone does not load them on
    # older SciPy versions
    import scipy.ndimage
    import scipy.optimize

    # filter references (objects < 200 px) from supportmask
    mask = supportmask.copy()
    mask = mask.astype(bool)
    mask = skimage.morphology.remove_small_objects(mask, min_size=200)
    mask = mask.astype(float)

    # Make object aperture smaller to minimize edge effects
    footprint = skimage.morphology.disk(4)
    mask = skimage.morphology.erosion(mask, footprint)

    # Gaussian filter to remove high intensity peaks. The filtered copy is what
    # the phase is optimized on; it was computed but never used before.
    reco = scipy.ndimage.gaussian_filter(recon, 1)

    # Different functions for optimization (all minimized by fminbound)
    def contrast(phi, reco, tmask):
        temp = np.imag(reco * np.exp(1j * phi)) * tmask
        mi, ma = np.percentile(temp[temp != 0], [1, 99])
        return ma - mi

    def minima(phi, reco, tmask):
        tmp = np.real(reco * np.exp(1j * phi)) * tmask
        return np.percentile(tmp[tmp != 0], 0.01)

    def negative_maxima(phi, reco, tmask):
        # Maximizing the 99th percentile == minimizing its negative
        tmp = np.real(reco * np.exp(1j * phi)) * tmask
        return -np.percentile(tmp[tmp != 0], 99)

    objectives = {
        "contrast": contrast,
        "max": contrast,  # legacy name, always optimized the contrast
        "minima": minima,
        "maxima": negative_maxima,
    }
    if method not in objectives:
        raise ValueError(
            "method must be one of %s, got %r" % (sorted(objectives), method)
        )

    # float() works whether fminbound returns a Python or a NumPy scalar
    optimized_phase = float(
        scipy.optimize.fminbound(
            objectives[method], -np.pi, np.pi, args=(reco, mask), disp=False
        )
    )

    # Calc optimized reconstruction
    recon_optimized = recon * np.exp(1j * optimized_phase)

    # Optional: Shift phase such that "background" are white or black
    mi_p, ma_p = np.percentile(np.real(recon_optimized[mask == 1]), [1, 99])
    mean = np.mean(np.real(recon_optimized[mask == 1]))
    if prefered_color == "white":
        # Make white domains the dominant domains
        if mean < (mi_p + ma_p) / 2:
            optimized_phase = optimized_phase + np.pi
            recon_optimized = recon * np.exp(1j * optimized_phase)
    elif prefered_color == "black":
        # Make black domains the dominant domains
        if mean > (mi_p + ma_p) / 2:
            optimized_phase = optimized_phase + np.pi
            recon_optimized = recon * np.exp(1j * optimized_phase)

    return recon_optimized, optimized_phase

In [None]:
def correct_background(image):
    """Subtract column-wise background offsets estimated from reference rows.

    Two passes are applied to a copy of ``image``:

    1. The 10 % quantile of every column over rows 80..119 is subtracted
       from the whole image.
    2. The 10 % quantile of every column over the last 40 rows of the
       already corrected image is subtracted from the top half only
       (rows ``0 .. int(shape[0] / 2) - 1``).

    Finally rows 56..58 are set to zero.

    Parameters
    ----------
    image : ndarray
        Image with rows along axis 0. It is not modified in place.

    Returns
    -------
    ndarray
        Background-corrected copy of ``image``.
    """
    split_row = int(image.shape[0] / 2)

    # Pass 1: offset from rows 80:120, applied to every row
    everywhere = np.ones(image.shape)
    reference_offset = np.quantile(image[80:120], 0.1, axis=0)
    corrected = image - reference_offset * everywhere

    # Pass 2: offset from the last 40 rows, applied to the top half only
    top_half_only = np.ones(image.shape)
    top_half_only[split_row:] = 0
    tail_offset = np.quantile(corrected[-40:], 0.1, axis=0)
    corrected = corrected - tail_offset * top_half_only

    # Blank rows 56..58
    corrected[56:59, :] = 0
    return corrected

# Experimental Details

In [None]:
# Dict with most basic experimental parameter
# NOTE(review): lengths look like metres (pyFAI convention) - confirm units.
# "energy" and "lambda" are added to this dict once data has been loaded.
experimental_setup = {
    "ccd_dist": 0.09,  # ccd to sample distance
    "px_size": 13.5e-6,  # pixel_size of camera
    "binning": 1,  # Camera binning
}

# Setup for azimuthal integrator
# The effective (binned) pixel size is passed for both detector axes
detector = Detector(
    experimental_setup["binning"] * experimental_setup["px_size"],
    experimental_setup["binning"] * experimental_setup["px_size"],
)

# General saving folder and log folder
# The return value of sup.create_folder is used as the base output path
folder_general = sup.create_folder(join(BASEFOLDER, "processed"))
sup.create_folder(join(folder_general, "Logs"))
sup.create_folder(join(folder_general, "Topos"))
print("Output Folder: %s" % folder_general)

# Load images

Start by loading the images: the image of interest (im), the reference containing only charge scattering (topo), and a dark image (dark).

We use the following convention: the difference hologram, which contains only the magnetic scattering, is calculated as

$\mathrm{Diff} = \frac{\mathrm{Image}}{\mathrm{factor}} - \mathrm{Topo} - \mathrm{offset}$,

where the factor (and offset) correct for intensity differences between the two recordings. If you recorded scans of the same magnetic state with both helicities, use the image with negative helicity as topo and the one with positive helicity as image.

In [None]:
# Define scan ids for each image
# The lists are paired by position: im_id_set[i] is dark-corrected with
# dark_id_im_set[i], topo_id_set[i] with dark_id_topo_set[i].
im_id_set = [749]  # single helicity mode: image with magnetic contrast, double helicity: pos
topo_id_set = (
    [747]  # single helicity mode: image without magnetic contrast, double helicity: neg
)
dark_id_im_set = [748]
dark_id_topo_set = dark_id_im_set  # same dark scans for image and topo

# Load energy and add to experimental setup
# Energy is read from the pre-scan snapshot of the first image scan and
# converted to a wavelength (input unit eV).
experimental_setup["energy"] = load_pre_scan_snapshot(im_id_set[0],"energy")
experimental_setup["lambda"] = cci.photon_energy_wavelength(
    experimental_setup["energy"], input_unit="eV"
)

print("Image Id: %s" % im_id_set)
print("Topo Id: %s" % topo_id_set)
print("Dark Id: %s" % dark_id_im_set)



## Load image of interest

In [None]:
# Load image
# For every image scan: load it, subtract its dark scan, remove the
# column-wise background (correct_background) and blank rows 56:59.
# NOTE(review): rows 56:59 are zeroed repeatedly (presumably a defective
# detector region - confirm). correct_background already zeroes them, so the
# zeroing directly after it is redundant but harmless.
# After the loop `im_id` holds the last scan id; later cells print it.
images = []
for i, im_id in enumerate(im_id_set):
    image = load_processing(im_id)
    image[56:59,:]  = 0

    dark= load_processing(dark_id_im_set[i])  # dark paired by list position
    dark[56:59,:]  = 0
    image = image - dark

    # Only if this is not image with large beamstop
    image = correct_background(image)
    image[56:59,:]  = 0

    images.append(image)
    
images = np.stack(images)
    
# Plot
fig, ax = cimshow(images)
ax.set_title("Images")

# Average over all image scans; `image` is used by the following cells
image = np.mean(images,axis=0)

## Load topo data set and average

In [None]:
# Load topo
# For every topo scan: load it, subtract its dark scan, remove the
# column-wise background (correct_background) and blank rows 56:59
# (correct_background already zeroes those rows; the repeat is harmless).
# After the loop `topo_id` holds the last topo scan id.
topos = []
for i, topo_id in enumerate(topo_id_set):
    topo = load_processing(topo_id)
    dark = load_processing(dark_id_topo_set[i])  # dark paired by list position

    topo = topo - dark
    topo[56:59,:]  = 0

    # Only if this is not image with large beamstop
    topo = correct_background(topo)
    topo[56:59,:]  = 0

    topos.append(topo)

# Average over all topo scans; `topo` is used by the following cells
topos = np.array(topos)
topo = np.mean(topos,axis=0)
    
# Plot
fig, ax = cimshow(topos)
ax.set_title("Topo")

# Load holograms for stitching

In [None]:
# Define scan ids for the holograms used for stitching
# The lists are paired by position with their dark scans
# (im_id_stitch_set[i] <-> dark_id_im[i], topo_id_stitch_set[i] <-> dark_id_topo[i]).
im_id_stitch_set = [744]
topo_id_stitch_set = [746]

dark_id_im = [745]
dark_id_topo = [745]

# Print the stitching ids defined above (not loop variables left over from
# earlier cells, which would show the wrong scans or raise NameError)
print("Image Id: %s" % im_id_stitch_set)
print("Topo Id: %s" % topo_id_stitch_set)
print("Dark Id (image): %s" % dark_id_im)
print("Dark Id (topo): %s" % dark_id_topo)

## Load image of interest

In [None]:
# Load image
# Same processing as for the main images: dark subtraction, column-wise
# background removal and blanking of rows 56:59, then averaging.
image_stitches = []
for i,im_id_stitch in enumerate(im_id_stitch_set):
    image_stitch = load_processing(im_id_stitch)
    dark = load_processing(dark_id_im[i])  # dark paired by list position
    image_stitch = image_stitch - dark
    image_stitch[56:59,:]  = 0
    image_stitch = correct_background(image_stitch)
    image_stitch[56:59,:]  = 0

    image_stitches.append(image_stitch)
    
image_stitches = np.array(image_stitches)
image_stitch = np.mean(image_stitches,axis=0)
    
# Plot
fig, ax = cimshow(image_stitches)
ax.set_title("Image for stitching")

## Load topo data set and average

In [None]:
# Load topo
# Same processing as for the main topos: dark subtraction, column-wise
# background removal and blanking of rows 56:59, then averaging.
topo_stitches = []
for i, topo_id_stitch in enumerate(topo_id_stitch_set):
    topo_stitch= load_processing(topo_id_stitch)
    dark= load_processing(dark_id_topo[i])  # dark paired by list position
    topo_stitch = topo_stitch - dark
    topo_stitch[56:59,:]  = 0
    topo_stitch = correct_background(topo_stitch)
    topo_stitch[56:59,:]  = 0
    
    topo_stitches.append(topo_stitch)

topo_stitches = np.array(topo_stitches)
topo_stitch = np.mean(topo_stitches,axis=0)
    
# Plot
fig, ax = cimshow(topo_stitches)
ax.set_title("Topo for stitching")

In [None]:
# Compare the main holograms (left column) with the holograms recorded for
# stitching (right column); each panel is clipped to its own percentile range.
fig, ax = plt.subplots(2, 2, figsize=(8, 8), sharex=True, sharey=True)
for _panel, _data, _pct, _title in (
    ((0, 0), image, [0.1, 99.9], "Image"),
    ((0, 1), image_stitch, [0.01, 99.9], "Image for stitching"),
    ((1, 0), topo, [0.1, 99.9], "Topo"),
    ((1, 1), topo_stitch, [0.01, 99.9], "Topo for stitching"),
):
    mi, ma = np.percentile(_data, _pct)
    ax[_panel].imshow(_data, vmin=mi, vmax=ma)
    ax[_panel].set_title(_title)
# Drop the loop temporaries so no extra globals are left behind
del _panel, _data, _pct, _title

# Center holograms

* Find center of the hologram to get a well-defined q-space. 
* Create smooth mask for beamstop or overexposed areas in direct beam

## Basic widget to find center

Try to **align** the circles to the **center of the scattering pattern**. Caution: the position of the beamstop can be misleading and does not necessarily coincide with the actual center of the hologram.

In [None]:
# Previously determined center as [row, column]; later cells reassign `center`
center = [1011.5,1101]

In [None]:
# Find center position via widget
# The widget is shown on a Gaussian-smoothed copy (sigma=5) of the image
c0, c1 = [1010.25,1098.25]  # initial values
ic = interactive.InteractiveCenter(gaussian_filter(image,5), c0=c0, c1=c1)

In [None]:
# Get center positions
center = [ic.c0, ic.c1]
# NOTE(review): the next line overrides the widget result with hard-coded
# values; delete it to use the center selected in the widget.
center = [1010.25,1098.25]
print(f"Center:", center)
# Shift that moves `center` onto the middle of the image array (row, col)
shift_c = np.array(image.shape) / 2 - center
print(f"Shift vector:", shift_c)

## Azimuthal integrator widget for finetuning
More of an "expert widget", which works very well for alignment if the scattering image shows an Airy pattern. It transforms the image from Cartesian into polar coordinates, with the angle `phi` and the radial distance `q` as axes (azimuthal transformation). If the center is set correctly, every ring of the Airy pattern becomes a straight line spanning the full `phi` range at a fixed `q`.

In [None]:
# Setup azimuthal integrator for virtual geometry
# The point of normal incidence (poni) is the center pixel index multiplied
# by the effective (binned) pixel size.
ai = AzimuthalIntegrator(
    dist=experimental_setup["ccd_dist"],
    detector=detector,
    wavelength=experimental_setup["lambda"],
    poni1=center[0]
    * experimental_setup["px_size"]
    * experimental_setup["binning"],  # y (vertical)
    poni2=center[1]
    * experimental_setup["px_size"]
    * experimental_setup["binning"],  # x (horizontal)
)

In [None]:
# Not the widget, just for double checking to find correct radial range.
# Vertical reference lines are drawn at these q values (same unit as the
# q axis, nm^-1).
q_lines = [0.02, 0.02]

# Perform azimuthal transformation.
# pyFAI's integrate2d takes the number of *radial* points as its second
# argument; the number of azimuthal (phi) bins keeps its default.
I_t, q_t, phi_t = ai.integrate2d(
    image,
    500,  # number of radial (q) points
    radial_range=(0.01, 1),  # relevant q-range
    unit="q_nm^-1",
    correctSolidAngle=False,
    method = "BBox"
)
# Combine in an xarray for plotting
az2d = xr.DataArray(I_t, dims=("phi", "q"), coords={"q": q_t, "phi": phi_t})

# Plot
fig, ax = plt.subplots()
mi, ma = np.percentile(I_t, [1, 95])
az2d.plot.imshow(ax=ax, vmin=mi, vmax=ma)
ax.set_title("Azimuthal integration")

# Vertical lines spanning the full axes height (axvline's ymin/ymax are
# fractions of the axes height, so the defaults 0..1 are what is wanted)
for qt in q_lines:
    ax.axvline(qt, c="red")

In [None]:
# The widget
# Shown on a log-scaled copy of the image (shifted so its minimum maps to
# log10(1) = 0); the start center is the current `center`.
aic = interactive.AzimuthalIntegrationCenter(
    np.log10(image - np.min(image) + 1),
    # image,
    ai,
    c0=center[0],
    c1=center[1],
    im_data_range=[1, 95],
    radial_range=(0.025, 0.1),
    qlines=[100, 110],
)

In [None]:
# Get center positions from widget (overwrites the previous `center`)
center = [aic.c0, aic.c1]
print(f"Center:", center)

## Center image hologram

In [None]:
# Apply to topo and image
# shift_c is recomputed from the current `center` so that the center moves
# to the middle of the array.
shift_c = np.array(image.shape) / 2 - center
im_c = cci.shift_image(image, shift_c)
topo_c = cci.shift_image(topo, shift_c)  # centered image

## Calculate Projection

In [None]:
%%time
# Reproject the centered holograms with PhR.inv_gnomonic (an inverse gnomonic
# projection, judging by the name - confirm geometry in Phase_Retrieval).
proj_im = PhR.inv_gnomonic(im_c, center=None, experimental_setup = experimental_setup, method='cubic' , mask=None)
proj_topo = PhR.inv_gnomonic(topo_c, center=None, experimental_setup = experimental_setup, method='cubic' , mask=None)

In [None]:
# From here on the projected holograms replace the centered ones
im_c = proj_im.copy()
topo_c = proj_topo.copy()

## Center Stitching holograms

In [None]:
# Center the stitching holograms with the same shift as the main holograms
image_stitch_c = cci.shift_image(image_stitch, shift_c)  # centered image
topo_stitch_c = cci.shift_image(topo_stitch, shift_c)  # centered image

In [None]:
%%time
# Same reprojection as for the main holograms (reuses the proj_* names)
proj_im = PhR.inv_gnomonic(image_stitch_c, center=None, experimental_setup = experimental_setup, method='cubic' , mask=None)
proj_topo = PhR.inv_gnomonic(topo_stitch_c, center=None, experimental_setup = experimental_setup, method='cubic' , mask=None)

In [None]:
# From here on the projected stitching holograms replace the centered ones
image_stitch_c = proj_im.copy()
topo_stitch_c = proj_topo.copy()

# Create beamstops

We want to cover the beamstop with a smooth circle to cover its sharp edges as these would create ringing-like artifacts in the reconstruction plane. Make it only as large as necessary to keep as much information as possible.

## Manual masking

In [None]:
# Interactive polygon drawing on the centered image
poly_mask = interactive.draw_polygon_mask(im_c)

In [None]:
# Take poly coordinates and mask from widget
p_coord = poly_mask.coordinates
mask_draw = poly_mask.full_mask.astype(int)

print("Copy these coordinates into the 'load_poly_coordinates()' function:")
print(p_coord)

# Plot image with beamstop and valid pixel mask
# Panels: unmasked pixels, masked pixels, and the inverted mask
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(im_c * (1 - mask_draw), [0.1, 99.9])
ax[0].imshow(im_c * (1 - mask_draw), cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Image * (1-mask_draw)")

mi, ma = np.percentile(im_c * mask_draw, [0.1, 99.9])
ax[1].imshow(im_c * mask_draw, vmin=mi, vmax=ma)
ax[1].set_title("Image * mask_draw")

ax[2].imshow(1 - mask_draw)
ax[2].set_title("1 - mask_draw")

In [None]:
def load_poly_coordinates():
    """
    Return the polygon corner coordinates of all drawn masks.

    The result is a dict mapping a mask name to a list of polygons; each
    polygon is a list of 2-tuples of corner coordinates as printed by
    ``interactive.draw_polygon_mask(...).coordinates`` (coordinate order as
    produced by that widget - presumably (x, y), confirm).

    Example: how to add a mask with name "test":
    mask_coordinates["test"] = <coordinates printed by the drawing cell>
    """

    # Setup dictionary
    mask_coordinates = dict()

    # Mask #1
    mask_coordinates["bs_medium"] = [[(1017.394483698999, -35.46145449788588), (1035.2276956836204, 505.3876493026536), (1047.325406869094, 994.0097061491854), (1059.1093028377018, 1016.5955067556837), (1459.190721041451, 991.086784173353), (2068.7721840663694, 944.2269827091939), (2067.1563286309156, 987.8550794664445), (1460.4026126180413, 1031.8871339185596), (1053.5446852969703, 1049.0012206693552), (1101.2358933146809, 2097.0416552096503), (1041.885818564897, 2104.3587877130485), (1023.8641950107121, 1398.6985283148524), (1012.3010494068428, 1055.2204990972316), (529.7787049731799, 1059.421969824299), (-49.71919730977076, 1055.9804064833297), (-45.08612508776184, 1014.2827564852495), (532.8674197878524, 1015.4077837152142), (1010.9917276325531, 1011.3582196585246), (992.2042825905437, 510.2399139372111), (963.1751999921164, -33.513456280872134)]]
    mask_coordinates["bs_medium_2"] = [[(972.9776892552021, 1004.2743085178313), (683.0854056548453, 1011.0191797381575), (429.3488120641165, 1014.0243786386102), (-5.38786287953252, 1010.0262182986162), (-11.784588413299048, 1027.546033470307), (-6.456709339024059, 1050.189519535976), (407.3336253818711, 1051.9815970562745), (785.5107592764842, 1048.7548363356034), (970.4099189292948, 1041.9313069708498), (977.0697677721386, 1049.9231255822624), (987.7255259206885, 1067.238732573656), (1005.041132912082, 1075.2305511850686), (1010.3690119863572, 1093.8781279450309), (1022.57787485063, 1452.2001151306752), (1023.891500713352, 1538.8489667863084), (1029.7269726228894, 1668.2234921518748), (1036.1520087654237, 1833.2762608301566), (1039.0507610028706, 2010.5397226740427), (1041.0931146480093, 2061.6873617870824), (1083.7161472422092, 2055.027512944239), (1082.3841774736404, 1997.7528128957829), (1061.329837800928, 1535.5648530012597), (1045.621811861143, 1081.8904000279126), (1059.0201832634914, 1068.6419757912079), (1080.0381970670428, 1045.6535231935734), (1250.4043869753773, 1036.9986877096228), (1711.4432772620323, 1003.3274407611157), (1930.818796336602, 981.3023122296238), (1968.2328715564279, 975.4216165269843), (2057.474846050534, 986.0773746755342), (2052.9860015697454, 945.8344154978255), (1975.482075669149, 939.2662861903159), (1901.4373645403866, 946.447411617374), (1619.8397671015532, 973.7708588493829), (1382.0734859488762, 992.1616209292713), (1192.6712743111323, 1001.972213913788), (1074.126880684794, 1006.2447473119142), (1066.1341462971013, 991.3164557652381), (1042.158690462864, 972.6688790052756), (1026.505884395897, 560.1120440507323), (1007.553206285446, -9.489556366445356), (970.9641323120604, -5.802018358814962), (976.3933883688477, 150.2045104582064), (980.8479680128388, 277.27966418395783), (990.3028627366012, 506.77436338214807), (997.1844288881059, 749.2817066564569), (1006.4751487546124, 969.463223155699), (984.8003220196999, 986.5403593710846)]]
    mask_coordinates["bs_large"] = [[(-9.52514267623522, 988.0244972208288), (133.8623449984891, 1003.0301645356255), (444.7643270650601, 1004.8352658735243), (677.4009610677929, 1006.3647572722471), (999.1891601517673, 998.0282754306933), (1008.9687632034429, 924.746257259271), (1002.856550157784, 753.6042842087398), (990.8526783102134, 392.799693733892), (977.5143073637275, -17.355212870551895), (1009.1929383616316, -5.684138292376701), (1027.5331984130498, 374.4594336824738), (1040.871569359536, 736.2627456059059), (1047.540754832779, 842.9697131777938), (1049.2080512010896, 983.0226081158966), (1113.8991502915467, 991.4702430486711), (1282.2960834909322, 986.4683539437389), (1529.0559460009226, 968.1280938923206), (1872.5189978729368, 933.1148701577949), (1955.8838162884742, 928.1129810528627), (2022.575671020904, 939.784055631038), (2029.2448564941471, 996.4721321536034), (1947.5473344469203, 974.7972793655637), (1627.4264317312568, 998.139428521914), (1333.9822709085652, 1019.8142813099539), (1133.9067067112755, 1031.485355888129), (955.5059953020256, 1034.8199486247506), (921.8266086621485, 1036.8207042667236), (678.7348313316545, 1045.0328414964195), (469.98929284993585, 1046.824482476588), (116.52246276805741, 1046.824482476588), (-6.8574684869379325, 1023.4823333202376)], [(-18.49404559764608, 785.86607528193), (-15.741149396111126, -20.73251176780832), (1721.3363537724376, -15.226719364738301), (2061.961377109028, -16.878457085658965), (2070.998269815253, 2083.925571530775), (-33.65294626451481, 2100.746691469454), (-41.88066798492048, 1173.7567109934555), (-50.10838970532615, 731.8366354770906), (43.139123125938, 737.3217832910489), (81.53515782116438, 62.64860221779031), (960.9090655017594, 46.10727357274732), (1962.9632828604786, 59.87175458042202), (2008.0241868014546, 946.3583113396406), (1981.6954772961565, 1969.1555403163811), (838.0421581597698, 2002.0664271980036), (56.04291820032506, 1965.3159367546796), (34.102326945909965, 1019.127939000208), (36.84490085271187, 742.1279744132175)]]
    mask_coordinates["bs_large_2"] = [[(-2.5086775209114265, 1033.1521507595412), (462.5106850256015, 1047.2212079066492), (716.4941903654963, 1055.3664515181326), (821.4295329342035, 1050.7075435593438), (1315.2912422480256, 1058.9385720479077), (1567.9431462107898, 1056.584405702532), (2063.325378790087, 1047.3157061119305), (2059.255245247832, 986.3969629829544), (1574.0729004151235, 989.2590341562225), (1071.5700027986707, 997.6082929549916), (1058.7059864142693, 849.2688655802834), (1065.5626473107302, 266.452689381122), (1066.5421702959388, -9.772792447724441), (1001.8936532721664, -7.81374647730695), (998.9550843165403, 420.23779805888387), (998.3722543978254, 831.5693599969302), (981.7208968914667, 973.1058988009786), (865.4419098489327, 984.3308651809253), (813.1192991686708, 984.3308651809253), (463.38395409533996, 981.5770435661748), (-1.027724137005336, 967.2497251757201)]]
    mask_coordinates["bs_tiny"] = [[(1031.1221816767775, 977.4340148145258), (1034.0037306716035, 977.2570775955452), (1037.7326765093844, 971.856949685674), (1031.4787790844405, 968.2530087967233), (1028.192832979809, 972.5989375157521), (1029.8583443983453, 977.0801403765646)], [(1030.2067999471637, 998.8865063528044), (1026.1788660124541, 1002.3844489803155), (1028.2988312412488, 1007.6843620523017), (1032.9627547445966, 1008.3203516209401), (1035.4007147577104, 1005.1404037777482), (1034.8707234505118, 999.2045011371235)]]

    return mask_coordinates

### Image

In [None]:
# Quick look at the current mask
cimshow(mask_draw)

In [None]:
# Re-import the self-written `interactive` module to pick up source edits
reload(interactive)

In [None]:
# Build the mask from the stored polygons and open the shift/scale widget.
# NOTE(review): `polygon_names` is first assigned in the following cell; when
# the notebook runs top to bottom this cell needs that assignment first.
mask_draw = load_poly_masks(
    im_c.shape,
    polygon_names,
)
ssm = interactive.Shift_Scale_Mask(im_c,mask_draw)

In [None]:
# Which drawn masks do you want to load? You can combine multiple masks from
# load_poly_coordinates(). Just add names of mask as strings to list like
# ["bs_small","bs_medium"]
# NOTE(review): the next cell reassigns mask_draw, so this variant is only
# kept if that cell is skipped.
polygon_names = ["bs_large_2"]
mask_draw = load_poly_masks(
    im_c.shape,
    polygon_names,
)
#mask_draw = np.zeros(image.shape)
mask_draw[mask_draw>1] = 1  # clip overlapping polygons back to 1
# Shift by (0, -34) and round back to integer values after the shift
mask_draw = np.round(cci.shift_image(mask_draw,[0,-34]))
# Add a circular region (same horizontal shift of -34 applied to its center)
mask_draw = mask_draw + cci.circle_mask(image.shape,[1022,1024-34],182)
mask_draw[mask_draw>1] = 1

# Expand Mask if necessary
#footprint = skimage.morphology.disk(1)
#mask_draw = skimage.morphology.erosion(mask_draw, footprint)

# Shift mask is necessary
# mask_draw = cci.shift_image(mask_draw, [50, 49])
#mask_draw = cci.shift_image(mask_draw, [-3, 0])

# Plot image with beamstop and valid pixel mask
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(im_c * (1 - mask_draw), [0.1, 99.9])
ax[0].imshow(im_c * (1 - mask_draw), cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Image * (1-mask_draw)")

mi, ma = np.percentile(im_c * mask_draw, [0.1, 99.9])
ax[1].imshow(im_c * mask_draw, vmin=mi, vmax=ma)
ax[1].set_title("Image * mask_draw")

ax[2].imshow(1 - mask_draw)
ax[2].set_title("1 - mask_draw")
plt.tight_layout()

In [None]:
# Which drawn masks do you want to load for images?
# This mask_draw is the one later used as the image mask (mask_im).
polygon_names = ["bs_medium"]
mask_draw = load_poly_masks(
    im_c.shape,
    polygon_names,
)

# Shift mask is necessary
# Shifted polygon mask plus a circular region, clipped back to 0/1
mask_draw = cci.shift_image(mask_draw, [-5,-12 ]) + cci.circle_mask(image.shape,[1024,1015],54)
mask_draw[mask_draw > 1] = 1

# Expand Mask if necessary
#footprint = skimage.morphology.disk(1)
#mask_draw = skimage.morphology.erosion(mask_draw, footprint)

# Shift mask is necessary
# mask_draw = cci.shift_image(mask_draw, [50, 49])
#mask_draw = cci.shift_image(mask_draw, [-3, 0])

# Plot image with beamstop and valid pixel mask
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(im_c * (1 - mask_draw), [0.1, 99.9])
ax[0].imshow(im_c * (1 - mask_draw), cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Image * (1-mask_draw)")

mi, ma = np.percentile(im_c * mask_draw, [0.1, 99.9])
ax[1].imshow(im_c * mask_draw, vmin=mi, vmax=ma)
ax[1].set_title("Image * mask_draw")

ax[2].imshow(1 - mask_draw)
ax[2].set_title("1 - mask_draw")
plt.tight_layout()

### Stitch

In [None]:
# Which drawn masks do you want to load for stitching?
# NOTE(review): the next cell reassigns mask_draw_stitch, so this variant is
# only kept if that cell is skipped.
polygon_names = ["bs_medium"]
mask_draw_stitch = load_poly_masks(
    im_c.shape,
    polygon_names,
)

# Shift mask is necessary
# Shifted polygon mask plus a circular region, clipped back to 0/1
mask_draw_stitch = cci.shift_image(mask_draw_stitch, [-8,-8 ]) + cci.circle_mask(image.shape,[1024,1020],54)
mask_draw_stitch[mask_draw_stitch > 1] = 1

# Expand Mask if necessary
# (footprint is created but the erosion below is commented out)
footprint = skimage.morphology.disk(1)
#mask_draw_stitch = skimage.morphology.erosion(mask_draw_stitch, footprint)

# Shift mask is necessary
# mask_draw_stitch = cci.shift_image(mask_draw_stitch, [50, 49])
#mask_draw_stitch = cci.shift_image(mask_draw_stitch, [-3, 0])

# Plot image with beamstop and valid pixel mask
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(image_stitch_c * (1 - mask_draw_stitch), [0.1, 99.9])
ax[0].imshow(image_stitch_c * (1 - mask_draw_stitch), cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Image * (1-mask_draw_stitch)")

mi, ma = np.percentile(image_stitch_c * mask_draw_stitch, [0.1, 99.9])
ax[1].imshow(image_stitch_c * mask_draw_stitch, vmin=mi, vmax=ma)
ax[1].set_title("Image * mask_draw_stitch")

ax[2].imshow(1 - mask_draw_stitch)
ax[2].set_title("1 - mask_draw_stitch")
plt.tight_layout()

In [None]:
# Which drawn masks do you want to load for stitching?
# This mask_draw_stitch is the one later used as the stitch mask (mask_stitch).
polygon_names = ["bs_tiny"]
mask_draw_stitch = load_poly_masks(
    im_c.shape,
    polygon_names,
)

# Shift mask is necessary
# Add a small circular region, clip to 0/1, shift by (-5, -5) and round back
# to integer values after the shift
mask_draw_stitch = mask_draw_stitch + cci.circle_mask(image.shape,[1029,1028],25)
mask_draw_stitch[mask_draw_stitch>1] = 1
mask_draw_stitch = np.round(cci.shift_image(mask_draw_stitch,[-5,-5]))
mask_draw_stitch[mask_draw_stitch>1] = 1

# Expand Mask if necessary
#footprint = skimage.morphology.disk(1)
#mask_draw_stitch = skimage.morphology.erosion(mask_draw_stitch, footprint)

# Shift mask is necessary
# mask_draw_stitch = cci.shift_image(mask_draw_stitch, [50, 49])
#mask_draw_stitch = cci.shift_image(mask_draw_stitch, [-3, 0])

# Plot image with beamstop and valid pixel mask
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(image_stitch_c * (1 - mask_draw_stitch), [0.1, 99.9])
ax[0].imshow(image_stitch_c * (1 - mask_draw_stitch), cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Image * (1-mask_draw_stitch)")

mi, ma = np.percentile(image_stitch_c * mask_draw_stitch, [0.1, 99.9])
ax[1].imshow(image_stitch_c * mask_draw_stitch, vmin=mi, vmax=ma)
ax[1].set_title("Image * mask_draw_stitch")

ax[2].imshow(1 - mask_draw_stitch)
ax[2].set_title("1 - mask_draw_stitch")
plt.tight_layout()

# Image Registration

The relative drift between the data holograms and their corresponding topo holograms is calculated with an image registration algorithm. This is necessary to obtain a well-defined difference hologram. The reference is always the static background image (topo).

## Set Alignment ROI 

Set the region of interest (ROI) of the reference (topo) in which the image registration is performed. Don't select an ROI that contains the beamstop, as this leads to wrong results.

How to use:
1. Zoom into the image and adjust your FOV until you are satisfied.
2. Save the axes coordinates.

In [None]:
# Zoom to the region to use for registration; the next cell reads the limits
fig, ax = cimshow(im_c * (1 - mask_draw))
ax.set_title("Can include beamstop")

In [None]:
# Takes start and end of x and y axis
# The y limits are unpacked as (y2, y1) because imshow shows row 0 at the top,
# so get_ylim() returns (bottom, top).
x1, x2 = ax.get_xlim()
y2, y1 = ax.get_ylim()
roi_im_reg = np.array([y1, y2, x1, x2]).astype(int)
roi_im_reg = np.s_[roi_im_reg[0] : roi_im_reg[1], roi_im_reg[2] : roi_im_reg[3]]

print(f"Image registration roi:", roi_im_reg)

## Calculate drift of images

In [None]:
# Estimate the relative shift between image and topo inside the ROI
shift = cci.image_registration(
    im_c[roi_im_reg],
    topo_c[roi_im_reg],
    method="dipy",
    #static_mask=mask_draw[roi_im_reg],
    #moving_mask=mask_draw[roi_im_reg],
)
print(shift)

In [None]:
# NOTE(review): overrides the registration result above; skip to use it
shift = [0, 0]

## Center and apply circular beamstop to topo image

In [None]:
# Shift and apply beamstop
# Only the shift is applied here; the beamstop multiplication is commented out.
topo_c = cci.shift_image(topo_c, shift)  # centered image
topo_stitch_c = cci.shift_image(topo_stitch_c, shift)
# topo_b = topo_c * (1-mask_draw)  # centered image with beamstop

# Plot original and shifted holos
# NOTE(review): the right panel title says "with beamstop", but im_c is shown
# without any beamstop mask applied.
mi, ma = np.percentile(np.real(im_c[im_c != 0]), (0.1, 99.9))
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(8, 4))
ax[0].imshow(np.real(image), cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Uncentered image")
ax[1].imshow(np.real(im_c), cmap="viridis", vmin=mi, vmax=ma)
ax[1].set_title("Centered image with beamstop")

# Add circles with different radi r
# Circles are centered on the middle of the array, where centering puts the
# hologram center.
tmp = np.array(image.shape) / 2
for r in np.arange(50, 200, 50):
    ax[0].add_artist(plt.Circle((tmp[1], tmp[0]), r, fill=None, ec="red"))
    ax[1].add_artist(plt.Circle((tmp[1], tmp[0]), r, fill=None, ec="red"))

# Execute Stitching

## Define Rois to calc scaling between holograms

In [None]:
# Define mask for each image
#mask_im = load_poly_masks(
#    im_c.shape,
#    ["bs_medium"],
#)
mask_im = mask_draw.copy()

#mask_stitch = load_poly_masks(
#    im_c.shape,
#    ["bs_bar_ver"],
#)
mask_stitch = mask_draw_stitch.copy()
# mask_hor = cci.shift_image(mask_hor, [-2, 0])

# Union of both masks, clipped to 0/1
mask_both = mask_im + mask_stitch
mask_both[mask_both>1] = 1 
# Expand Mask if necessary
#footprint = skimage.morphology.disk(2)
#mask_im = skimage.morphology.dilation(mask_ver, footprint)
#mask_stitch = skimage.morphology.dilation(mask_hor, footprint)

In [None]:
# Log-scaled stitching image with the combined mask overlaid; zoom to the
# ROI to use for the scaling fit (read by the next cell)
fig, ax = cimshow(np.log10(image_stitch_c - np.min(image_stitch_c) + 1))
ax.set_title("Choose roi with relevant statistics for scaling")
ax.imshow(mask_both, alpha=0.3)

In [None]:
# Takes start and end of x and y axis
# (y limits come back as (bottom, top) for imshow, hence the y2, y1 order)
x1, x2 = ax.get_xlim()
y2, y1 = ax.get_ylim()
roi_stitch = np.array([y1, y2, x1, x2]).astype(int)
roi_stitch = np.s_[roi_stitch[0] : roi_stitch[1], roi_stitch[2] : roi_stitch[3]]
print(f"Stitching roi:", roi_stitch)

## Calc scaling

### Define additional masks

In [None]:
# Log-log histograms of the stitching holograms inside the combined mask
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].hist(image_stitch_c[(mask_both).astype(bool)], 200)
ax[0].set_xscale("log")
ax[0].set_yscale("log")
ax[0].set_title("Stitch Image")
ax[1].hist(topo_stitch_c[(mask_both).astype(bool)], 200)
ax[1].set_title("Stitch Topo")
ax[1].set_xscale("log")
ax[1].set_yscale("log")

In [None]:
# Show comparision between data for stitching and normal holograms
fig, ax = plt.subplots(2, 2, figsize=(8, 8), sharex=True, sharey=True)
mi, ma = np.percentile((im_c * (1 - mask_both))[roi_stitch], [0.1, 99.99])
ax[0, 0].imshow((im_c * (1 - mask_both))[roi_stitch], vmin=530, vmax=ma)
ax[0, 0].set_title("Image")
mi, ma = np.percentile((image_stitch_c * (1 - mask_both))[roi_stitch], [0.1, 99.99])
ax[0, 1].imshow((image_stitch_c * (1 - mask_both))[roi_stitch], vmin=mi, vmax=ma)
ax[0, 1].set_title("Image for stitching")


mi, ma = np.percentile((topo_c * (1 - mask_both))[roi_stitch], [0.1, 99.99])
ax[1, 0].imshow((topo_c * (1 - mask_both))[roi_stitch], vmin=530, vmax=ma)
ax[1, 0].set_title("Topo")
mi, ma = np.percentile((topo_stitch_c * (1 - mask_both))[roi_stitch], [0.1, 99.99])
ax[1, 1].imshow((topo_stitch_c * (1 - mask_both))[roi_stitch], vmin=mi, vmax=ma)
ax[1, 1].set_title("Topo for stitching")

### Fitting

In [None]:
# Get scaling factor and offset of image
# Backgrounds are low percentiles of each hologram (the 100th-percentile
# result is discarded); the fit uses the background-subtracted holograms
# inside the stitching ROI, excluding the combined mask.
back_im_c, _ = np.percentile(im_c, [0.1, 100])
back_im_stitch, _ = np.percentile(image_stitch_c, [1, 100])

factor_stitch_im, offset_stitch_im = cci.dyn_factor(
    ((im_c - back_im_c) * (1 - mask_both))[roi_stitch],
    ((image_stitch_c - back_im_stitch) * (1 - mask_both))[roi_stitch],
    method="correlation",
    print_out=True,
    plot=True,
)

# Get scaling factor and offset of topo
back_topo_c, _ = np.percentile(topo_c, [0.1, 100])
back_topo_stitch, _ = np.percentile(topo_stitch_c, [10, 100])

factor_stitch_topo, offset_stitch_topo = cci.dyn_factor(
    ((topo_c - back_topo_c) * (1 - mask_both))[roi_stitch],
    ((topo_stitch_c - back_topo_stitch) * (1 - mask_both))[roi_stitch],
    method="correlation",
    print_out=True,
    plot=True,
)

print(
    "Backgrounds im: %.1f im_stitch: %.1f topo: %.1f topo_stitch: %.1f"
    % (back_im_c, back_im_stitch, back_topo_c, back_topo_stitch)
)

In [None]:
# Pixels of the stitching image above 300 (binary map), with the combined
# mask overlaid for comparison
tmp = image_stitch_c > 300

fig ,ax = cimshow(tmp.astype(int))
ax.imshow(mask_both,alpha=0.3)

## Stitch

In [None]:
# Zoom to the region to inspect; the next cell converts the view into `roi`
fig, ax = cimshow(im_c)

In [None]:
# Convert the current axis view into an ROI used by the stitching plots
roi = interactive.axis_to_roi(ax)

In [None]:
correction = 1  # 0.5

# q range for stitching
# radi_stitch = bs_diam - 8  # + 5
# mask_stitch = +cci.circle_mask(
#    image.shape,
#    [np.array(image.shape[0]) / 2, np.array(image.shape[0]) / 2 + 4],
#    38,
# )
# mask_stitch = mask_im
# mask_stitch = gaussian_filter(mask_stitch, 1)

# Stitch together
# Outside mask_im the main hologram is kept; inside mask_im it is replaced by
# the scaled, background-corrected stitching hologram.
# NOTE(review): the stitching background (back_*_stitch) is added back, not
# the main hologram's background (back_im_c / back_topo_c) - confirm this is
# intended for the dyn_factor convention used.
im_stitched = (
    im_c * (1 - mask_im)
    + (
        correction * factor_stitch_im * (image_stitch_c - back_im_stitch)
        + back_im_stitch
    )
    * mask_im
)
topo_stitched = (
    topo_c * (1 - mask_im)
    + (
        correction * factor_stitch_topo * (topo_stitch_c - back_topo_stitch)
        + back_topo_stitch
    )
    * mask_im
)

#mask_stitch = mask_im.copy()

# Plot
# Top row: image (stitching input, main input, result); bottom row: topo
fig, ax = plt.subplots(2, 3, figsize=(12, 8), sharex=True, sharey=True)
mi, ma = np.percentile(image_stitch_c[roi], [0.1, 99.99])
ax[0, 0].imshow(image_stitch_c[roi], vmin=mi, vmax=ma)
ax[0, 0].imshow(mask_stitch[roi], alpha=0.2)
ax[0, 0].set_title("Centered Stitching Image")
mi, ma = np.percentile(im_c[roi], [0.1, 99.99])
ax[0, 1].imshow(im_c[roi], vmin=mi, vmax=ma)
ax[0, 1].imshow(mask_stitch[roi], alpha=0.2)
ax[0, 1].set_title("Centered Image")
mi, ma = np.percentile(im_stitched[roi], [0.1, 99.95])
ax[0, 2].imshow(im_stitched[roi], vmin=mi, vmax=ma)
# ax[0, 2].imshow(mask_stitch, alpha=0.2)
ax[0, 2].set_title("Stitched Image")

mi, ma = np.percentile(topo_stitch_c[roi], [0.1, 99.99])
ax[1, 0].imshow(topo_stitch_c[roi], vmin=mi, vmax=ma)
# ax[1, 0].imshow(mask_stitch, alpha=0.2)
ax[1, 0].set_title("Centered Stitching Topo")
mi, ma = np.percentile(topo_c[roi], [0.1, 99.9])
ax[1, 1].imshow(topo_c[roi], vmin=mi, vmax=ma)
ax[1, 1].imshow(mask_stitch[roi], alpha=0.2)
ax[1, 1].set_title("Centered Topo")
mi, ma = np.percentile(topo_stitched[roi], [0.1, 99.9])
ax[1, 2].imshow(topo_stitched[roi], vmin=mi, vmax=ma)
# ax[1, 2].imshow(mask_stitch, alpha=0.2)
ax[1, 2].set_title("Stitched Topo")

## Add circles with different radi r
# tmp = np.array(image.shape) / 2
# for r in np.arange(0, 5, 10):
#   ax[0, 0].add_artist(plt.Circle((tmp[1], tmp[0]), r, fill=None, ec="red"))

In [None]:
# Log view of the stitched image (shifted so its minimum maps to 0)
cimshow(np.log10(im_stitched-np.min(im_stitched)+1))

# Calculate difference holograms

The reconstruction of the magnetization becomes visible only after subtracting the large background caused by diffraction at the circular object aperture (Airy pattern). This may require a scaling factor that corrects intensity changes between the hologram and the topo. The scaling factor is determined automatically by a linear fit; if the fit looks off, there might be an issue with the data.

# Create beamstops

We want to cover the beamstop with a smooth circle to cover its sharp edges as these would create ringing-like artifacts in the reconstruction plane. Make it only as large as necessary to keep as much information as possible.

In [None]:
# Pixels masked in BOTH the image mask and the stitch mask (intersection)
mask_pixel = mask_im * mask_stitch
# Create smooth mask
# Dilate by a disk of radius 6, then blur with a Gaussian (sigma=2)
footprint = skimage.morphology.disk(6)
mask_pixel_smooth = skimage.morphology.dilation(mask_pixel, footprint)
mask_pixel_smooth = gaussian_filter(mask_pixel_smooth, 2)

cimshow((mask_pixel)*im_stitched)

In [None]:
# Set beamstop diameter and std for smoothing filter.
# Higher values mean stronger smoothing.
# If you have a very small beamstop you might need to reduce the smoothing value.
# Otherwise the beamstop center won't mask the direct beam with zeros
# The widget is centered on the middle of the array.
bs = interactive.InteractiveBeamstop(
    im_stitched, im_c.shape[0] / 2, im_c.shape[1] / 2, rBS=32, stdBS=3
)

In [None]:
# Take value from widget and create beamstop mask
# mask_bs is 1 outside and ~0 inside the smoothed circular beamstop
bs_diam = bs.rBS
bs_smoothing = bs.stdBS
mask_bs = 1 - cci.circle_mask(
    topo.shape, np.array(topo.shape) / 2, bs.rBS, sigma=bs_smoothing
)

# Apply beamstop to images
# From here on im_c / topo_c refer to the stitched holograms
im_c = im_stitched.copy()
im_b = im_c * mask_bs

topo_c = topo_stitched.copy()
topo_b = topo_c * mask_bs

# Plot image with beamstop and valid pixel mask
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(9, 3))
mi, ma = np.percentile(im_b, [0.1, 99.9])
ax[0].imshow(im_b, cmap="viridis", vmin=mi, vmax=ma)
ax[0].set_title("Masked image")

ax[1].imshow(mask_bs)
ax[1].set_title("Beamstop mask")

ax[2].imshow(1 - mask_bs)
ax[2].set_title("1 - Beamstop mask")
plt.tight_layout()

In [None]:
# Get scaling factor and offset between image and topo, excluding the pixels
# that are missing in both holograms (mask_pixel == 1)
factor, offset = cci.dyn_factor(
    im_c * (1 - mask_pixel),
    topo_c * (1 - mask_pixel),
    method="correlation",
    print_out=True,
    plot=True,
)

# Calculate differences (magnetic) and sums (topographic) contrast holograms.
# _c: centered, without beamstop, _b: centered, with beamstop
diff_c = im_c / factor - topo_c - offset
diff_b = diff_c * mask_bs
sum_c = im_c / factor + topo_c - offset
sum_b = sum_c * mask_bs

# For comparison only: fit on the full (unmasked) stitched holograms.
# Stored under separate names so that `factor` / `offset` keep the values that
# were actually used for diff_c / sum_c (later cells, e.g. the finetuning
# widget, start from `factor`).
factor_unmasked, offset_unmasked = cci.dyn_factor(
    im_stitched,
    topo_stitched,
    method="correlation",
    print_out=True,
    plot=True,
)

# diff_stitched = im_stitched - factor_unmasked * topo_stitched - offset_unmasked
# sum_stitched = im_stitched + factor_unmasked * topo_stitched - offset_unmasked

In [None]:
# Plot an example of the difference or sum hologram
# `im_id` is the last image scan id from the image-loading loop
fig, ax = cimshow(diff_c)
ax.set_title(f" Diff Id %d" % im_id)

# fig, ax = cimshow(sum_b)
# ax.set_title(f" Sum Id %d" % im_id)

In [None]:
# Display the difference hologram with a percentile-clipped color scale.
# Set log_scale = True for a signed logarithmic view,
# sign(x) * log10(|x| + 1), which stays finite where diff_c == 0 (e.g. in
# masked regions), unlike sign(x) * log10(|x|).
log_scale = False

fig, ax = plt.subplots(figsize=(8,8))
if log_scale:
    tmp = np.sign(diff_c) * np.log10(np.abs(diff_c) + 1)
else:
    tmp = diff_c.copy()
mi, ma = np.percentile(tmp,[1,99])
ax.imshow(tmp,vmin=mi,vmax=ma)

# Reconstruct Diff Holos (FTH)

Reconstruct the hologram.
1. Choose a region of interest (ROI), i.e. select one reconstruction from the reconstruction plane.
2. Propagate the image and shift the phase for maximal contrast and sharpness in your ROI
3. Optional finetuning with a widget

### Set Patterson Map ROI

Choose the reconstructions as the ROI.

1. Zoom into the image and adjust your FOV until you are satisfied.
2. Save the axes coordinates.

In [None]:
# Smooth mask that suppresses a rectangular region (rows < 435,
# columns 970..1034), blurred with a Gaussian (sigma=4)
tmp_mask = np.ones(sum_c.shape)
tmp_mask[:435,970:1035] = 0
tmp_mask = gaussian_filter(tmp_mask,4)

cimshow(tmp_mask)

In [None]:
# Reconstruct a 200-pixel-cropped sum hologram with the smooth pixel mask and
# tmp_mask applied (real part shown)
cimshow(np.real(cci.reconstruct(sum_c[200:-200,200:-200] * (1 - mask_pixel_smooth[200:-200,200:-200])*tmp_mask[200:-200,200:-200])),cmap='gray')

In [None]:
# Choose contrast mode
# The sum hologram carries the topographic contrast (the difference hologram
# diff_c carries the magnetic contrast).
tmp = np.real(cci.reconstruct(sum_c* (1 - mask_pixel_smooth)))  # topographic contrast (sum hologram)
#tmp = np.abs(cci.reconstruct(image_stitch))  # magnetic contrast only
#tmp = np.real(cci.reconstruct(sum_c[200:-200,200:-200] * (1 - mask_pixel_smooth[200:-200,200:-200])))  # topographic contrast only

fig, ax = cimshow(tmp, cmap="gray")

In [None]:
# Execute to get roi
# Read the current (zoomed) view of `ax` from the previous cell. get_ylim() is
# unpacked as (y2, y1), presumably because imshow's default origin="upper"
# returns (bottom, top) — confirm for cimshow.
x1, x2 = ax.get_xlim()
y2, y1 = ax.get_ylim()
roi = np.array([y1, y2, x1, x2]).astype(int)  # ystart, ystop, xstart, xstop
# Previously used ROIs. NOTE: the last uncommented assignment below overrides
# the ROI just read from the axes; comment it out to use the zoomed view.
# roi = [576, 871, 519, 811]
# roi = [802, 1014, 283, 488]
# roi = [613, 794, 184, 363]
roi = [354, 515, 785, 953]
roi = [419, 929,   5, 425]
roi = [389, 664, 657, 924]
roi_s = np.s_[roi[0] : roi[1], roi[2] : roi[3]]
print(f"Roi Reco:{roi}")

## Tune propagation and phase
Focus the image by tuning the propagation distance; this works just like focusing a microscope.
The phase slider shifts contrast between the real and imaginary parts. Usually we use the phase that maximizes the contrast in the real part.

In [None]:
# Widget
# Interactive propagation/phase tuning on the masked difference (magnetic)
# hologram, displayed inside roi_s. The next cell reads the slider values.
slider_prop, slider_phase, button = reco.propagate(
    diff_c* (1 - mask_pixel_smooth),
    roi_s,
    experimental_setup=experimental_setup,
    scale=(1, 99),
)

In [None]:
# Read prop dist and phase from widget
prop_dist = slider_prop.value
phase = slider_phase.value

# Same output as "%0.2f" formatting, now as proper f-strings.
print(f"Propagation distance: {prop_dist:0.2f}")
print(f"Phase: {phase:0.2f}")

## Ugly Widget for finetuning of shift and factor
Use this widget to find a better scaling factor and to optimize the relative shift between the image and topo holograms. Execute it, find your values, and then execute the next cell to apply them.

In [None]:
# Take initial values (kept as-is if the widget is never moved)
temp_factor = factor
temp_shift = shift.copy()
temp_phase = phase

# Open figure
# No fig.tight_layout(): constrained layout is enabled via rcParams at the top
# of the notebook, and mixing layout engines makes matplotlib warn and switch.
fig, [ax, ax2] = plt.subplots(1, 2, figsize=(10, 5))

# Widget
layout = ipywidgets.Layout(width="100%")
opts = dict(layout=layout)


@ipywidgets.interact(
    centery=(center[0] - 10, center[0] + 10, 0.25),
    centerx=(center[1] - 10, center[1] + 10, 0.25),
    tfactor=(
        0.6,
        1.4,
        0.0025,
    ),
    xshift=(-4, 4, 0.01),
    yshift=(-4, 4, 0.01),
    tphase=(-2 * np.pi, 2 * np.pi, 0.05),
    options=opts,
)
def update(
    centery=center[0],
    centerx=center[1],
    tfactor=factor,
    yshift=shift[0],
    xshift=shift[1],
    tphase=phase,
):  # initial values
    """Recompute diff/sum holograms for the slider values and redraw the reco.

    Results are published through module-level ``temp_*`` globals, which the
    next cell copies into ``factor``, ``shift``, ``phase``, ``center``,
    ``diff_c`` and ``sum_c``.
    """
    global temp_shift, temp_factor, temp_center, temp_diff, temp_sum, temp_phase  # ugly writing as global variable

    # Couple widget only and global variables
    temp_shift = [yshift, xshift]
    temp_factor = tfactor
    temp_center = [centery, centerx]
    temp_phase = tphase

    # Calc diff and sum holo
    # NOTE(review): the factor multiplies the topo here and no offset is
    # subtracted, unlike the dyn_factor cell (image / factor - offset).
    shifted = cci.shift_image(im_c, temp_shift)
    temp_diff = shifted - temp_factor * topo_c
    temp_sum = shifted + temp_factor * topo_c

    # Move the selected center onto the reference center
    center_shift = -np.array([centery - center[0], centerx - center[1]])
    temp_diff = cci.shift_image(temp_diff, center_shift)
    temp_sum = cci.shift_image(temp_sum, center_shift)

    temp_diff = temp_diff * mask_bs * (1 - mask_pixel_smooth)
    temp_sum = temp_sum * mask_bs * (1 - mask_pixel_smooth)

    # Reconstruction
    # NOTE(review): the sum (topographic) hologram is reconstructed here; use
    # temp_diff instead to tune on magnetic contrast.
    recon = cci.reconstruct(
        fth.propagate(
            temp_sum,
            prop_dist * 1e-6,
            experimental_setup=experimental_setup,
        )
        * np.exp(1j * temp_phase)
    )[roi_s]

    # Plots
    # Clear first: otherwise every slider move stacks another image artist on
    # each axis and keeps all previous ones in memory.
    for axis, part in ((ax, np.real(recon)), (ax2, np.imag(recon))):
        axis.clear()
        vmin, vmax = np.percentile(part, (1, 99))
        axis.imshow(part, vmin=vmin, vmax=vmax, cmap="gray")

In [None]:
# Copy values from widget
# Commit the fine-tuned temp_* globals set by the widget's `update`.
factor = temp_factor
shift = temp_shift
phase = temp_phase
center = temp_center.copy()
# NOTE: temp_diff/temp_sum already include mask_bs * (1 - mask_pixel_smooth),
# so from here on diff_c/sum_c are masked, and diff_b/sum_b apply mask_bs a
# second time (a no-op only if mask_bs is binary — confirm).
diff_c = temp_diff.copy()
sum_c = temp_sum.copy()
diff_b = diff_c * mask_bs
sum_b = sum_c * mask_bs

## Save reconstruction

Save PNG files of the images and an HDF5 file with all important variables.

In [None]:
# Saves only real and imaginary part
holo = diff_b  # * (1 - mask_pixel_smooth)

# Reconstruct: propagate the difference hologram, apply the tuned phase, and
# reconstruct. No pre-allocation is needed because `recon` is assigned directly.
recon = fth.reconstruct(
    fth.propagate(holo, prop_dist * 1e-6, experimental_setup=experimental_setup)
    * np.exp(1j * phase)
)

# Plot
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
fig.suptitle("Image %d - %d Stitching" % (im_id, topo_id))

vmin, vmax = np.percentile(np.real(recon[roi_s]), (0.1, 99))
t_im1 = ax[0].imshow(np.real(recon[roi_s]), vmin=vmin, vmax=vmax, cmap="gray")
ax[0].set_title("Real")
plt.colorbar(t_im1, ax=ax[0], aspect=50)

vmin, vmax = np.percentile(np.imag(recon[roi_s]), (0.1, 99))
t_im2 = ax[1].imshow(np.imag(recon[roi_s]), vmin=vmin, vmax=vmax, cmap="gray")
ax[1].set_title("Imag")
plt.colorbar(t_im2, ax=ax[1], aspect=50)

# Save images
fname = join(
    folder_general,
    "Recon_ImId_%04d_RefId_%04d_stitching_%s.png" % (im_id, topo_id, USER),
)
print("Saving: %s" % fname)
plt.savefig(fname, bbox_inches="tight", transparent=False)

# Save hdf5 file
save_fth_h5()

In [None]:
# Closes all existing plots
# (releases the figures created by the widgets/plots above)
plt.close("all")

# CDI Reconstruction

## Create set of pos and neg helicity holograms

In [None]:
# Copy values from FTH reco
# Recover single-helicity holograms from sum = pos + neg and diff = pos - neg.
pos = (sum_c + diff_c) / 2
neg = (sum_c - diff_c) / 2

## Create hologram mask for phase retrieval
``mask_pixel[px,px] == 1`` means invalid, i.e., will be replaced by phase retrieval

In [None]:
def create_auto_beamstop(im_c, topo_c, mask_draw, use_bs, bs_param):
    """Build the invalid-pixel masks for an image/topo hologram pair.

    A value of 1 marks an invalid pixel (to be replaced by phase retrieval).

    Parameters
    ----------
    im_c, topo_c : ndarray
        Centered holograms (the batch cell passes the pos/neg helicity ones).
    mask_draw : ndarray
        Manually drawn mask. Used as the base mask when
        ``use_bs["use_bs_auto"]`` is false and, with ``use_bs["use_draw"]``,
        outside the circle of radius ``bs_param["bs_radius"]``.
    use_bs : dict
        Flags ``"use_bs_auto"``, ``"use_otsu"`` and ``"use_draw"``.
    bs_param : dict
        ``"bs_thres"``, ``"bs_radius"``, ``"bs_expand"``, ``"method"`` for
        ``automated_beamstop``; ``"rBS"`` (legacy spelling ``"rBs"`` is also
        accepted) and ``"stdBS"`` for the fixed central circle used when
        ``use_otsu`` is false; optional ``"saturation"`` (default 40000), the
        value above which pixels are treated as saturated.

    Returns
    -------
    mask_im, mask_topo, mask_pixel, mask_pixel_smooth
        Per-hologram masks, their union, and the union dilated with a 6 px
        disk and Gaussian-blurred (sigma 2).

    Notes
    -----
    With ``use_otsu`` the radius and expansion are read from the module-level
    ``auto_bs_otsu`` widget.
    """
    # Use automatically determined beamstop mask?
    if use_bs["use_bs_auto"]:
        mask_im = automated_beamstop(
            im_c,
            bs_param["bs_thres"],
            bs_param["bs_radius"],
            bs_param["bs_expand"],
            method=bs_param["method"],
        )
        mask_topo = automated_beamstop(
            topo_c,
            bs_param["bs_thres"],
            bs_param["bs_radius"],
            bs_param["bs_expand"],
            method=bs_param["method"],
        )
    else:
        mask_im = mask_draw.copy()
        mask_topo = mask_draw.copy()

    # Use otsu mask as center?
    if use_bs["use_otsu"]:
        tmp = cci.circle_mask(
            im_c.shape, np.array(im_c.shape) / 2, auto_bs_otsu.radius, sigma=None
        )
        # Same widget expansion for both holograms so they are masked alike.
        mask_im = (1 - tmp) * mask_im + tmp * automated_beamstop_center(
            im_c, auto_bs_otsu.expand
        )
        mask_topo = (1 - tmp) * mask_topo + tmp * automated_beamstop_center(
            topo_c, auto_bs_otsu.expand
        )
    else:
        # Fixed central circle of radius rBS + stdBS.
        r_bs = bs_param["rBS"] if "rBS" in bs_param else bs_param["rBs"]
        mask_bs_center = cci.circle_mask(
            im_c.shape,
            np.array(im_c.shape) / 2,
            r_bs + bs_param["stdBS"],
            sigma=None,
        )
        mask_im = mask_im + mask_bs_center
        mask_topo = mask_topo + mask_bs_center

    # Use drawn beamstop outside the central circle?
    if use_bs["use_draw"]:
        tmp = cci.circle_mask(
            im_c.shape, np.array(im_c.shape) / 2, bs_param["bs_radius"], sigma=None
        )
        mask_im = tmp * mask_im + (1 - tmp) * mask_draw
        mask_topo = tmp * mask_topo + (1 - tmp) * mask_draw

    # Saturated pixel
    saturation = bs_param.get("saturation", 40000)
    mask_im = mask_im + (im_c > saturation)
    mask_topo = mask_topo + (topo_c > saturation)

    # normalize: summed masks back to 0/1
    mask_im[mask_im > 1] = 1
    mask_topo[mask_topo > 1] = 1

    # Combine both
    mask_pixel = mask_im + mask_topo
    mask_pixel[mask_pixel > 1] = 1

    # Sharp and smooth mask
    # Cast to float: gaussian_filter keeps the input dtype, so an int/bool mask
    # would have its soft edge truncated away.
    footprint = skimage.morphology.disk(6)
    mask_pixel_smooth = skimage.morphology.dilation(mask_pixel, footprint)
    mask_pixel_smooth = gaussian_filter(mask_pixel_smooth.astype(float), 2)

    return mask_im, mask_topo, mask_pixel, mask_pixel_smooth

In [None]:
# Initial parameter for automatic beamstop detection
thres = 700
radi_thres = 140
expand = 5
method_bs = "intensity"
# Pass method_bs so the widget and bs_param["method"] (set below) always agree.
auto_bs = interactive.InteractiveAutoBeamstop(
    neg, thres, radi_thres, expand, method=method_bs
)

In [None]:
# Initial parameter for automatic beamstop detection
# Otsu-based detector for the central beamstop; create_auto_beamstop reads
# auto_bs_otsu.radius and auto_bs_otsu.expand when use_bs["use_otsu"] is set.
# (Overwrites thres/radi_thres/expand of the previous cell.)
thres = 50
radi_thres = 30
expand = 4
auto_bs_otsu = interactive.InteractiveAutoBeamstop(
    im_c, thres, radi_thres, expand, method="otsu"
)

In [None]:
# Parameter for automatically determined beamstop
use_bs = {"use_bs_auto": False, "use_otsu": True, "use_draw": False}
bs_param = {
    "bs_thres": auto_bs.thres,
    "bs_radius": auto_bs.radius,
    "bs_expand": auto_bs.expand,
    "method": method_bs,
    # create_auto_beamstop reads this entry as bs_param["rBS"]
    "rBS": bs.rBS,
    "stdBS": bs.stdBS,
}

## Create beamstop automatically
# mask_im, mask_topo, mask_pixel, mask_pixel_smooth = create_auto_beamstop(
#    pos, neg, mask_draw, use_bs, bs_param
# )

# Currently used instead: mask built from the stored polygon mask(s)
mask_pixel = load_poly_masks(
    im_c.shape,
    ["bs_bar_ver"],
)
# footprint = skimage.morphology.disk(6)
# mask_pixel = skimage.morphology.dilation(mask_pixel, footprint)

# Cast to float: gaussian_filter keeps the input dtype, so an int/bool mask
# would lose its smooth edge.
mask_pixel_smooth = gaussian_filter(mask_pixel.astype(float), 1)

mask_im = mask_pixel.copy()
mask_topo = mask_pixel.copy()

# Plot both
fig, ax = plt.subplots(2, 4, figsize=(12, 6), sharex=True, sharey=True)
mi, ma = np.percentile(pos, [1, 99.9])
ax[0, 0].imshow(pos, vmin=mi, vmax=ma)
ax[0, 0].set_title("Image")
mi, ma = np.percentile(pos * mask_im, [0.1, 99.99])
ax[0, 1].imshow(pos * mask_im, vmin=mi, vmax=ma)
ax[0, 1].set_title("Image*mask")
mi, ma = np.percentile(pos * (1 - mask_im), [0.1, 99.9])
ax[0, 2].imshow(pos * (1 - mask_im), vmin=mi, vmax=ma)
ax[0, 2].set_title("Image*(1-mask)")
ax[0, 3].imshow(mask_pixel)
ax[0, 3].set_title("Combined Mask")

mi, ma = np.percentile(neg, [1, 99.9])
ax[1, 0].imshow(neg, vmin=mi, vmax=ma)
ax[1, 0].set_title("Neg")
mi, ma = np.percentile(neg * mask_topo, [0.1, 99.99])
ax[1, 1].imshow(neg * mask_topo, vmin=mi, vmax=ma)
ax[1, 1].set_title("Neg*mask")
mi, ma = np.percentile(neg * (1 - mask_topo), [0.1, 99.9])
ax[1, 2].imshow(neg * (1 - mask_topo), vmin=mi, vmax=ma)
ax[1, 2].set_title("Neg*(1-mask)")
mi, ma = np.percentile((pos - neg) * (1 - mask_pixel_smooth), [0.1, 99.9])
ax[1, 3].imshow((pos - neg) * (1 - mask_pixel_smooth), vmin=mi, vmax=ma)
ax[1, 3].set_title("Pos-Neg")

## Draw support mask

## Execute if you want to create a new support mask

In [None]:
# How many references do you have?
nr_ref = 4

# One aperture per reference plus the object aperture, all starting at the
# hologram center with radius 7; move them in the circle widget afterwards.
n_apertures = nr_ref + 1
support_coordinates = [
    [pos.shape[-2] // 2, pos.shape[-1] // 2, 7] for _ in range(n_apertures)
]

## Execute if you want to load an existing support mask

In [None]:
def get_supportmask_coordinates(sample):
    """
    Return the stored circular support-mask apertures for ``sample``.

    Parameters
    ----------
    sample : str
        Sample key, e.g. ``"s2304a-B2"`` or ``"FBEb21"``.

    Returns
    -------
    list of tuple
        One ``(coord, coord, radius)`` tuple per aperture — presumably
        (row, column, radius), matching how new coordinates are initialised
        from ``pos.shape[-2]``/``pos.shape[-1]``. A fresh list is built on
        every call, so callers may modify it.

    Raises
    ------
    KeyError
        If no coordinates are stored for ``sample``; the message lists the
        available sample keys.
    """

    # Setup dictionary
    support_coord = dict()

    # coordinates
    support_coord["s2304a_F4"] = [(790.5, 471.0, 132.5), (581.7, 376.0, 14.0), (571.0, 537.7, 16.0), (993.0, 295.5, 18.0), (1041.6, 380.0, 15.0), (1007.1, 484.0, 16.0), (1026.0, 595.0, 16.0), (968.0, 671.5, 17.0), (591.7, 646.0, 16.0), (487.1, 990.3, 12.0), (1024.0, 1024.0, 14.0)]
    support_coord["s2304a_E4"] = [(788.0, 470.5, 130.0), (581.7, 376.0, 16.0), (571.0, 537.7, 17.0), (993.0, 295.5, 20.0), (1040.0, 380.5, 17.0), (1007.1, 486.0, 17.0), (1026.0, 595.0, 17.0), (966.0, 669.5, 18.0), (591.7, 646.0, 17.0), (487.1, 989.0, 14.0), (1024.0, 1024.0, 14.0)]
    support_coord["s2304a_E4_500eV"] = [(850.5, 612.5, 98.0), (694.2, 542.3, 12.0), (687.5, 663.0, 12.0), (702.1, 742.7, 13.0), (999.5, 481.0, 17.0), (1036.0, 544.9, 14.0), (1010.8, 623.5, 13.0), (1025.0, 704.4, 15.0), (981.0, 760.5, 17.0), (623.3, 998.2, 11.0), (1024.0, 1024.0, 11.0)]
    support_coord["FBEb21"] = [(1325.0, 1650.5, 156.0), (1653.7, 1529.8, 13.0), (1670.0, 1640.7, 14.0), (1674.5, 1794.0, 13.0), (1590.5, 1873.0, 12.5), (1689.5, 1060.0, 13.5), (1024.0, 1024.0, 13.0)]
    support_coord["s2304a-B2"] = [(525.5, 794.5, 124.0), (520.5, 617.5, 8.5), (670.4, 641.0, 12.5), (677.5, 937.0, 13.5), (578.6, 988.0, 7.0), (480.0, 990.5, 7.0), (381.5, 944.4, 14.0), (1011.5, 530.3, 6.0), (1024.0, 1024.0, 6.0)]

    try:
        return support_coord[sample]
    except KeyError:
        # Still a KeyError for callers, but with the valid choices spelled out.
        raise KeyError(
            "No support mask coordinates stored for sample %r; available: %s"
            % (sample, ", ".join(sorted(support_coord)))
        ) from None

### Create new support mask

In [None]:
# Which supportmask to load? (a key of get_supportmask_coordinates, e.g.
# "s2304a-B2", "s2304a_E4", "FBEb21", ...)
sample = "s2304a-B2"

# Get coordinates and create supportmask
# (only the coordinates are fetched here; the mask itself is built with
# create_supportmask after the interactive circle widget below)
support_coordinates = get_supportmask_coordinates(sample)

In [None]:
# Scratch arithmetic; the result is only displayed, nothing is stored.
110/6

In [None]:
# Scratch arithmetic; the result is only displayed, nothing is stored.
40/6

In [None]:
# Scratch arithmetic; the result is only displayed, nothing is stored.
30/6

In [None]:
# Interactive placement of the support apertures on top of the real part of
# the FTH reconstruction of the positive-helicity hologram; the circles start
# at `support_coordinates`.
ds = interactive.InteractiveCircleCoordinates(
    np.real(
        fth.reconstruct(pos * mask_bs * np.exp(1j * phase) * (1 - mask_pixel_smooth))
    ),
    len(support_coordinates),
    coordinates=support_coordinates,
)

In [None]:
# Take coordinates of circles from widget
support_coordinates = ds.get_mask()

# Create supportmask
supportmask = create_supportmask(support_coordinates, pos.shape)

# What to plot?
tmp = np.real(
    fth.reconstruct(pos * mask_bs * np.exp(1j * phase) * (1 - mask_pixel_smooth))
)

# Plot supportmask as overlay
# Keep `ax`: the next cell derives roi_cdi from its (zoomed) limits.
fig, ax = plt.subplots(figsize=(6, 6))
mi, ma = np.percentile(tmp, (1, 99))
ax.imshow(tmp, vmin=mi, vmax=ma, cmap="gray")
ax.imshow(supportmask, alpha=0.3, cmap="binary")
ax.set_title("Image with overlayed mask")

### Take Roi

In [None]:
roi_cdi = interactive.axis_to_roi(ax)
# Previously used ROIs. NOTE: the uncommented hard-coded assignment below
# overrides the ROI just read from the axes; comment it out to use the zoom.
# roi_cdi = [807, 1006, 285, 481]
# roi_cdi = [802, 1012, 279, 489]
# roi_cdi = [334, 488, 725, 884]
# roi_cdi = [1204, 1474, 539, 809]
roi_cdi = [393,657,661,923]
roi_cdi = np.s_[roi_cdi[0] : roi_cdi[1], roi_cdi[2] : roi_cdi[3]]
print("Sliced roi:", roi_cdi)

## Do Phase Retrieval

In [None]:
# Executes the algorithm
offset_vmin = .5
Startimage = None
Startgamma = None

(
    retrieved_p,
    retrieved_n,
    retrieved_p_pc,
    retrieved_n_pc,
    bsmask_p,
    bsmask_n,
    gamma_p,
    gamma_n,
) = phase_retrieval(pos, neg, mask_pixel.astype(int), supportmask, vmin = offset_vmin, Startimage=Startimage, Startgamma=Startgamma)

## Reconstruct images

In [None]:
# New beamstop for CDI recos as phase retrieval of low-q might be insufficient. If phase retrieval worked well
# Try without beamstop: `use_bs_cdi = False`
# Deliberately not named `use_bs`: that name holds the flag dict that
# create_auto_beamstop reads in the batch-processing section.
use_bs_cdi = True
bs_diam_cdi = 10  # size passed to circle_mask (used as a radius elsewhere — TODO confirm)
# True: mask with the smoothed pixel mask (current choice);
# False: use a circular central beamstop of size bs_diam_cdi instead.
use_pixel_mask_cdi = True

# Create beamstop
if not use_bs_cdi:
    mask_bs_cdi = np.ones(pos.shape)  # if you don't want a beamstop
elif use_pixel_mask_cdi:
    mask_bs_cdi = 1 - mask_pixel_smooth.copy()
else:
    mask_bs_cdi = 1 - cci.circle_mask(
        pos.shape, np.array(pos.shape) / 2, bs_diam_cdi, sigma=4
    )

# Full-coherence recos (disabled; re-enable to get `p` and `n`):
# p = cci.reconstruct(
#     fth.propagate(
#         retrieved_p * mask_bs_cdi,
#         prop_dist_cdi * 1e-6,
#         experimental_setup=experimental_setup,
#     )
# )
# n = cci.reconstruct(
#     fth.propagate(
#         retrieved_n * mask_bs_cdi,
#         prop_dist_cdi * 1e-6,
#         experimental_setup=experimental_setup,
#     )
# )

# Get Recos partial coherence
# Positive partial coherence
p_pc = fth.reconstruct(
    fth.propagate(
        retrieved_p_pc * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
)

# Negative partial coherence
n_pc = fth.reconstruct(
    fth.propagate(
        retrieved_n_pc * mask_bs_cdi,
        prop_dist_cdi * 1e-6,
        experimental_setup=experimental_setup,
    )
)

# optimize phase on the difference (magnetic) reconstruction inside the support
recon = p_pc - n_pc
_, phase_cdi = optimize_phase_contrast(recon, supportmask, method="contrast",prefered_color="white")

# Plotting
mode = "+"
print("Fine-tuning of reconstruction parameter:")
slider_prop, slider_phase, slider_dx, slider_dy = rec.focusCDI(
    retrieved_p_pc * mask_bs_cdi,
    retrieved_n_pc * mask_bs_cdi,
    #retrieved_p * mask_bs_cdi,
    #retrieved_n * mask_bs_cdi,
    roi_cdi,
    mask=supportmask,
    phase=phase_cdi,
    dx=dx,
    dy=dy,
    prop_dist=prop_dist_cdi,
    experimental_setup=experimental_setup,
    operation=mode,
    max_prop_dist=10,
    scale=(1, 99),
)

In [None]:
# Get phase from slider
phase_cdi = slider_phase.value
prop_dist_cdi = slider_prop.value

# Apply phase
p = p * np.exp(1j * phase_cdi)
n = p * np.exp(1j * phase_cdi)
p_pc = p_pc * np.exp(1j * phase_cdi)
n_pc = n_pc * np.exp(1j * phase_cdi)

print("Phase CDI: %s" % phase_cdi)
print("Prop_dist: %s" % prop_dist_cdi)

In [None]:
# Confirm that offset subtraction works, i.e., only small fraction of hologram is actually masked
fig, ax = plt.subplots(2, 2, figsize=(8, 8), sharex=True, sharey=True)
tmp = np.abs(retrieved_p_pc * mask_bs_cdi)
mi, ma = np.percentile(tmp, [0.1, 99.9])
ax[0, 0].imshow(tmp, vmin=mi, vmax=ma)
ax[0, 0].set_title("Pos holo")

tmp = np.abs(retrieved_n_pc)
mi, ma = np.percentile(tmp, [0.1, 99.9])
ax[0, 1].imshow(tmp, vmin=mi, vmax=ma)
ax[0, 1].set_title("Neg holo")
ax[1, 0].imshow(bsmask_p)
ax[1, 0].set_title("Pos holo mask")
ax[1, 1].imshow(bsmask_n)
ax[1, 1].set_title("Neg holo mask")

In [None]:
# CDI centering check: re-center the retrieved holograms, reconstruct, and
# overlay rings on the difference hologram.
# Open figure
fig, [ax, ax2, ax3] = plt.subplots(1, 3, figsize=(12, 4))

# Widget
layout = ipywidgets.Layout(width="100%")
opts = dict(layout=layout)


@ipywidgets.interact(
    centery=(center[0] - 10, center[0] + 10, .25),
    centerx=(center[1] - 10, center[1] + 10, .25),
    tphase = (-np.pi, np.pi, .1),
    options=opts,
)
def update(
    centery=center[0],
    centerx=center[1],
    tphase=phase,
):  # initial values
    """Re-center the retrieved CDI holograms and redraw the reconstruction.

    Publishes the slider values as the globals ``temp_center`` and
    ``temp_phase``. On top of the slider phase, a contrast-optimised phase is
    applied locally for display only.
    """
    global temp_center, temp_phase  # ugly writing as global variable

    # Couple widget only and global variables
    temp_center = [centery, centerx]
    temp_phase = tphase

    center_shift = -np.array([centery - center[0], centerx - center[1]])
    temp_diff = cci.shift_image(retrieved_p_pc - retrieved_n_pc, center_shift)
    temp_sum = cci.shift_image(retrieved_p_pc + retrieved_n_pc, center_shift)

    temp_sum_b = temp_sum * (1 - mask_pixel_smooth)

    # Reconstruction
    # NOTE(review): propagates by -prop_dist_cdi with cci.reconstruct, while the
    # other CDI cells use +prop_dist_cdi with fth.reconstruct — confirm the sign.
    recon = cci.reconstruct(
        fth.propagate(
            temp_sum_b,
            -prop_dist_cdi * 1e-6,
            experimental_setup=experimental_setup,
        )
        * np.exp(1j * temp_phase)
    )[roi_s]

    # Display-only phase optimisation (local; the global phase_cdi is untouched)
    _, phase_opt = optimize_phase_contrast(recon, supportmask[roi_s], method="contrast", prefered_color="white")
    recon = recon * np.exp(1j * phase_opt)

    # Clear first: otherwise every slider move stacks more images and circles.
    for axis in (ax, ax2, ax3):
        axis.clear()

    # Plots
    vmin, vmax = np.percentile(np.real(recon), (1, 99))
    ax.imshow(np.real(recon), vmin=vmin, vmax=vmax, cmap="gray")
    vmin, vmax = np.percentile(np.imag(recon), (1, 99))
    ax2.imshow(np.imag(recon), vmin=vmin, vmax=vmax, cmap="gray")

    # Difference hologram with circles of radius r to judge the centering.
    # Percentiles are taken from the same |.| data that is displayed.
    holo_abs = np.abs(temp_diff * mask_bs)
    vmin, vmax = np.percentile(holo_abs, (0.1, 99.9))
    ax3.imshow(holo_abs, vmin=vmin, vmax=vmax)
    ax3.set_xlim([600, 1400])
    ax3.set_ylim([600, 1400])
    # plt.Circle expects (x, y) = (column, row), hence the reversed shape.
    ring_center = np.array(image.shape)[::-1] / 2
    for r in np.arange(175, 450, 25):
        ax3.add_artist(plt.Circle(ring_center, r, fill=None, ec="red"))

In [None]:
# Reference center (overwrites the current global `center`).
center= [1007.75, 1097.75]

In [None]:
# Center found with the centering widget above — typed in by hand.
new_center = [1007.75, 1097.75]

In [None]:
# Shift applied to pos/neg below; zero when both centers are equal.
center_correct = np.array(center) - np.array(new_center)

In [None]:
center_correct

In [None]:
# Sum of the CDI recos (topographic contrast)
cimshow(np.real(p_pc+n_pc),cmap="gray")

In [None]:
# NOTE: rebinds the global pos/neg; re-running this cell shifts them again.
# p_pc/n_pc are not updated — re-run the phase retrieval to use the shift.
pos = cci.shift_image(pos,center_correct)
neg = cci.shift_image(neg,center_correct)
cimshow(pos)

In [None]:
# Difference of the CDI recos (magnetic contrast)
cimshow(np.real(p_pc-n_pc),cmap="gray")

## Save

In [None]:
def get_title(data_key, im_id, topo_id, CDI=False):
    """Return the figure title for an FTH or CDI reconstruction.

    Parameters
    ----------
    data_key : str
        Currently unused; kept for the disabled magnetic-field lookup
        (``np.mean(load_data(im_id, data_key)) * 1000`` in mT).
    im_id, topo_id
        Scan ids, formatted with ``%s`` (``topo_id`` may be a list).
    CDI : bool, optional
        False gives ``"Image <im> - <topo> FTH"``; any truthy value gives
        ``"Image <im> - <topo> CDI Stitching"``.

    Returns
    -------
    str
        The title.
    """
    # Plain if/else so that a non-bool flag (e.g. 1) no longer leaves the
    # title unset.
    kind = "CDI Stitching" if CDI else "FTH"
    return "Image %s - %s %s" % (im_id, topo_id, kind)

In [None]:
def plot_recon(recon, title, percentiles=(0.1, 99.9)):
    """Plot the real and imaginary parts of a reconstruction side by side.

    Parameters
    ----------
    recon : ndarray (complex)
        Reconstruction, typically already cropped to the region of interest.
    title : str
        Figure suptitle.
    percentiles : tuple of float, optional
        Lower/upper percentiles used as grey-scale limits of each panel.

    Returns
    -------
    fig, ax
        The created figure and its two axes (real, imag), so callers can
        save or annotate the figure explicitly.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 4))
    fig.suptitle(title)

    for axis, part, label in (
        (ax[0], np.real(recon), "Real"),
        (ax[1], np.imag(recon), "Imag"),
    ):
        vmin, vmax = np.percentile(part, percentiles)
        image_artist = axis.imshow(part, vmin=vmin, vmax=vmax, cmap="gray")
        axis.set_title(label)
        plt.colorbar(image_artist, ax=axis, aspect=50)

    return fig, ax

In [None]:
# Saves only real and imaginary part
# Sum (topographic) CDI reconstruction from the partial-coherence recos.
# (The full-coherence p/n are not computed, so they are not used here.)
recon = p_pc + n_pc

# Plot (rotated by 180 degrees before cropping to roi_cdi)
title = get_title("m_magnett_read", im_id, topo_id, CDI=True)
plot_recon(np.flipud(np.fliplr(recon))[roi_cdi], title)

# Save images
fname = join(
    folder_general,
    "Recon_ImId_%04d_RefId_%s_cdi_stitching_sum_%s.png" % (im_id, topo_id, USER),
)
print("Saving: %s" % fname)
plt.savefig(fname, bbox_inches="tight", transparent=False)

# Save h5
save_cdi_h5()

In [None]:
# Saves only real and imaginary part
# Difference (magnetic) CDI reconstruction from the partial-coherence recos.
# (The full-coherence p/n are not computed, so they are not used here.)
recon = p_pc - n_pc

# Plot (rotated by 180 degrees before cropping to roi_cdi)
title = get_title("m_magnett_read", im_id, topo_id, CDI=True)
plot_recon(np.flipud(np.fliplr(recon))[roi_cdi], title)

# Save images
fname = join(
    folder_general,
    "Recon_ImId_%04d_RefId_%s_cdi_stitching_diff_%s.png" % (im_id, topo_id, USER),
)
print("Saving: %s" % fname)
plt.savefig(fname, bbox_inches="tight", transparent=False)

# Save h5
save_cdi_h5()

In [None]:
# Magnetic field value (mT)
data_key = "magnett_read"
values = np.mean(np.array(load_data(im_id, data_key)) * 1000)
values = [np.round(values, 2)]

# Cryostat temperature if saved in nxs file (optional). Only the temperature
# lookup is guarded; formatting errors and interrupts are not swallowed.
# topo_id uses %s so a [pos_id, neg_id] pair is accepted as well.
try:
    data_key = "cryob"
    values.append(np.mean(np.array(load_collection(im_id, data_key))))
except Exception:
    title = "Image %d - %s CDI @%.2f mT" % (im_id, topo_id, values[0])
else:
    title = "Image %d - %s CDI @%.2f mT, %0.1f K" % (
        im_id,
        topo_id,
        values[0],
        values[1],
    )

In [None]:
# Saves only real and imaginary part
# Sum (topographic) CDI reconstruction from the partial-coherence recos.
# (The full-coherence p/n are not computed, so they are not used here.)
recon = p_pc + n_pc

# Plot (rotated by 180 degrees before cropping to roi_cdi)
# NOTE: get_title replaces the field/temperature title built in the previous cell.
title = get_title("m_magnett_read", im_id, topo_id, CDI=True)
plot_recon(np.flipud(np.fliplr(recon))[roi_cdi], title)

# Save images
fname = join(
    folder_general,
    "Recon_ImId_%04d_RefId_%s_cdi_stitching_sum_%s.png" % (im_id, topo_id, USER),
)
print("Saving: %s" % fname)
plt.savefig(fname, bbox_inches="tight", transparent=False)

# Save h5
save_cdi_h5()

# Batch processing CDI

## Define Scan Ids

In [None]:
# Load support mask of which sample?
# NOTE(review): "FBIe14" is not among the keys stored in
# get_supportmask_coordinates, so the batch CDI step raises KeyError unless
# its coordinates are added there.
sample = "FBIe14"

In [None]:
# Define the sets for reconstructions. You can make a list or use np.arange
# im_id_set should always have ids of positive helicity holograms,
# topo_id_set those of negative helicity or a proper topo
# (an entry may also be a [pos_id, neg_id] pair: the topo is then built from
# the sum of the two helicity holograms)

# NOTE: the second assignment overrides the np.arange range.
im_id_set = np.arange(3365, 3367 + 1)
im_id_set = [3373, 3374]
topo_id_set = 3370 * np.ones(len(im_id_set), dtype=int)

# In case of single helicity reconstructions, adapt the helicity
# for contrast inversion
helicity = 1 * np.ones(len(im_id_set), dtype=int)  # [1,-1]

# Do cdi?
do_cdi = True

print("Dynamics Set:  %s" % im_id_set)
print("Reference Set: %s" % topo_id_set)
print("Helicity: %s" % helicity)

## Execute Stack Reconstruction

In [None]:
# Ugly Automatic processing of image stacks
# For every image id: FTH reconstruction (+ PNG) and, if do_cdi, CDI phase
# retrieval (+ PNG + h5). Relies on globals set in the cells above (prop_dist,
# phase, roi, roi_cdi, prop_dist_cdi, mask_bs_cdi, use_bs, bs_param, mask_draw,
# sample, shift_c, ...).
recons_name = []  # for gifs
for it, im_id in enumerate(im_id_set):
    # Load images
    image, _ = load_processing(im_id)

    # Get topo
    # A scalar entry is a topo scan id; a sequence [pos_id, neg_id] means the
    # topo hologram is built from the sum of two helicity holograms.
    topo_id = topo_id_set[it]
    if np.ndim(topo_id) == 0:
        # Usual case
        print("Loading imageId: %04d, topoId: %04d" % (im_id, topo_id))
        topo, _ = load_processing(topo_id)

        # Process images
        worker_dict = worker(image, topo)

        # Save topo hologram
        save_topo_holo(worker_dict["sum_c"], im_id, topo_id)
    else:
        print("Using Topo from sum of two helicity holograms")
        pos_id, neg_id = topo_id[0], topo_id[1]

        # The stored topo may have been saved under either id order.
        try:
            topo = load_topo_holo(pos_id, neg_id) / 2
        except Exception:
            topo = load_topo_holo(neg_id, pos_id) / 2
        topo = cci.shift_image(
            topo, -shift_c
        )  # shift out of center so you don't need to change the worker

        # Process images
        worker_dict = worker(image, topo)

    # Save FTH reco
    # Magnetic field value (mT)
    data_key = "magnett_read"
    values = np.mean(np.array(load_data(im_id, data_key)) * 1000)
    values = [np.round(values, 2)]

    # Cryostat temperature if saved in nxs file (optional)
    try:
        data_key = "cryob"
        values.append(np.mean(np.array(load_collection(im_id, data_key))))
    except Exception:
        title = "Image %d - %s @%.2f mT" % (im_id, topo_id, values[0])
    else:
        title = "Image %d - %s @%.2f mT, %0.1f K" % (
            im_id,
            topo_id,
            values[0],
            values[1],
        )

    # Reconstruct
    recon = fth.reconstruct(
        fth.propagate(
            worker_dict["holo"], prop_dist * 1e-6, experimental_setup=experimental_setup
        )
        * np.exp(1j * phase)
    )
    recon_roi = recon[roi[0] : roi[1], roi[2] : roi[3]]

    # Plot
    fig, ax = plt.subplots(1, 2, figsize=(10, 4))
    fig.suptitle(title)

    vmin, vmax = np.percentile(np.real(recon_roi), (1, 99))
    t_im1 = ax[0].imshow(np.real(recon_roi), vmin=vmin, vmax=vmax, cmap="gray")
    ax[0].set_title("Real")
    plt.colorbar(t_im1, ax=ax[0], aspect=50)

    vmin, vmax = np.percentile(np.imag(recon_roi), (1, 99))
    t_im2 = ax[1].imshow(np.imag(recon_roi), vmin=vmin, vmax=vmax, cmap="gray")
    ax[1].set_title("Imag")
    plt.colorbar(t_im2, ax=ax[1], aspect=50)

    # Save images
    fname = join(
        folder_general,
        "Recon_ImId_%04d_RefId_%s_fth_diff_stack_%s.png" % (im_id, topo_id, USER),
    )
    print("Saving: %s" % fname)
    plt.savefig(fname, bbox_inches="tight", transparent=False)

    ################ CDI ###############
    if do_cdi:
        # Create pos and neg helicity set (sum = pos + neg, diff = pos - neg)
        pos = (worker_dict["sum_c"] + worker_dict["diff_c"]) / 2
        neg = (worker_dict["sum_c"] - worker_dict["diff_c"]) / 2

        # Create beamstop automatically
        # NOTE(review): needs `use_bs` to be the flag dict from the mask cell
        # above, not a bool.
        mask_im, mask_topo, mask_pixel, mask_pixel_smooth = create_auto_beamstop(
            pos, neg, mask_draw, use_bs, bs_param
        )

        fig, ax = cimshow(mask_pixel.astype(int))
        ax.set_title("Verify that this looks like an acceptable beamstop")

        # Get coordinates and create supportmask
        support_coordinates = get_supportmask_coordinates(sample)
        supportmask = create_supportmask(support_coordinates, pos.shape)

        # Do phase retrieval
        # NOTE(review): unlike the interactive cell, no vmin=offset_vmin and no
        # int cast of mask_pixel are passed here — confirm this is intended.
        (
            retrieved_p,
            retrieved_n,
            retrieved_p_pc,
            retrieved_n_pc,
            bsmask_p,
            bsmask_n,
            gamma_p,
            gamma_n,
        ) = phase_retrieval(
            pos, neg, mask_pixel, supportmask, Startimage=None, Startgamma=None
        )

        # Get Recos partial coherence
        # NOTE(review): mask_bs_cdi is the global from the interactive CDI cell,
        # not rebuilt from this image's mask_pixel_smooth; and these calls use
        # fth.reconstructCDI where the interactive cells use fth.reconstruct.
        # Positive partial coherence
        p_pc = fth.reconstructCDI(
            fth.propagate(
                retrieved_p_pc * mask_bs_cdi,
                prop_dist_cdi * 1e-6,
                experimental_setup=experimental_setup,
            )
        )
        # Negative partial coherence
        n_pc = fth.reconstructCDI(
            fth.propagate(
                retrieved_n_pc * mask_bs_cdi,
                prop_dist_cdi * 1e-6,
                experimental_setup=experimental_setup,
            )
        )

        ##### Calc reco and optimize contrast
        recon = helicity[it] * (p_pc - n_pc)
        _, phase_cdi = optimize_phase_contrast(
            recon,
            supportmask,
            method="contrast",
            prefered_color="white",
        )
        recon = recon * np.exp(1j * phase_cdi)
        print("Phase is:", np.round(phase_cdi, 2))

        # Plot
        fig, ax = plt.subplots(1, 2, figsize=(10, 4))
        fig.suptitle(title)

        vmin, vmax = np.percentile(np.real(recon[roi_cdi]), (0.5, 99.5))
        t_im1 = ax[0].imshow(
            np.real(recon[roi_cdi]),
            vmin=vmin,
            vmax=vmax,
            cmap="gray",
        )
        ax[0].set_title("Real")
        plt.colorbar(t_im1, ax=ax[0], aspect=50)

        vmin, vmax = np.percentile(np.imag(recon[roi_cdi]), (0.5, 99.5))
        t_im2 = ax[1].imshow(np.imag(recon[roi_cdi]), vmin=vmin, vmax=vmax, cmap="gray")
        ax[1].set_title("Imag")
        plt.colorbar(t_im2, ax=ax[1], aspect=50)

        # Save images
        fname = join(
            folder_general,
            "Recon_ImId_%04d_RefId_%s_cdi_stack_%s.png" % (im_id, topo_id, USER),
        )

        print("Saving: %s" % fname)
        plt.savefig(fname, bbox_inches="tight", transparent=False)
        recons_name.append(fname)

        # Save files as h5
        save_cdi_h5()
    else:
        print("Phase retrieval disabled!")

    print(" ")
print("CDI stack processing finished")

In [None]:
# Close all figures created by the batch loop (one or more per image).
plt.close("all")

# Gifs

In [None]:
im_id_set = [
    1268,
    1269,
    1272,
    1273,
    1276,
    1277,
    1280,
    1281,
    1284,
    1285,
    1288,
    1289,
    1292,
    1293,
    1296,
    1297,
    1300,
    1301,
    1303,
    1306,
    1308,
    1313,
    1314,
    1315,
    1316,
    1317,
]
topo_id_set = [
    1267,
    1270,
    1271,
    1274,
    1275,
    1278,
    1279,
    1282,
    1283,
    1286,
    1287,
    1290,
    1291,
    1294,
    1295,
    1298,
    1299,
    1302,
    [1299, 1300],
    1307,
    [1306, 1307],
    1312,
    [1312, 1313],
    [1312, 1313],
    [1312, 1313],
    [1312, 1313],
]
if len(im_id_set) != len(topo_id_set):
    raise ValueError("im_id_set and topo_id_set must have the same length")

# Collect the CDI PNGs of the stack. The batch cell saves them as
# "Recon_ImId_<im>_RefId_<topo>_cdi_stack_<user>.png"; the older pattern
# "Recon_ImId_<im>_RefId_<topo>_<user>_cdi_diff_stack.png" is tried as a
# fallback. Missing images are reported and skipped.
recons_name = []
for im_id, topo_id in zip(im_id_set, topo_id_set):
    candidates = [
        join(
            folder_general,
            "Recon_ImId_%04d_RefId_%s_cdi_stack_%s.png" % (im_id, topo_id, USER),
        ),
        join(
            folder_general,
            "Recon_ImId_%04d_RefId_%s_%s_cdi_diff_stack.png" % (im_id, topo_id, USER),
        ),
    ]
    existing = [candidate for candidate in candidates if path.exists(candidate)]
    if existing:
        recons_name.append(existing[0])
    else:
        print("No reconstruction image found for ImId %04d, skipping" % im_id)

# Create gif of last scan
if recons_name:
    var = [imageio.imread(file) for file in recons_name]
    fname = "ImId_%04d_%04d_%s.gif" % (im_id_set[0], im_id_set[-1], USER)
    gif_path = path.join(folder_general, fname)
    print("Saving gif:%s" % gif_path)
    imageio.mimsave(gif_path, var, fps=2)
    print("Done!")
else:
    print("No reconstruction images found; no gif written.")
print("Done!")