# Imports

Import the Python libraries as well as the self-written FERMI library.

In [None]:
import sys, os
from os.path import join, split
from getpass import getuser
from glob import glob
from time import strftime
from importlib import reload
from tqdm.auto import tqdm
import gc

# data
import numpy as np
import xarray as xr
import pandas as pd
import h5py

# Images
import imageio
from imageio import imread

# Plotting
import matplotlib.pyplot as plt
from matplotlib.image import NonUniformImage
import matplotlib.gridspec as gridspec
from matplotlib.path import Path

# pyFAI
import pyFAI

pyFAI.disable_opencl = True  # get rid of annoying warning ;)
from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
from pyFAI.detectors import Detector

# Scipy
from scipy.ndimage import median_filter

# Self-written libraries
sys.path.append(os.path.abspath(join(os.pardir,"process_FERMI")))
import helper_functions as helper
import mask_lib
import process_FERMI as pf
import interactive
from interactive import cimshow

In [None]:
# interactive plotting
import ipywidgets

%matplotlib widget
plt.rcParams["figure.constrained_layout.use"] = True

# Auto formatting of cells
#%load_ext jupyter_black

pd.set_option("display.max_colwidth", None)

## Functions

### Other

In [None]:
def create_ringmask(ring_coordinates, shape):
    """
    Builds a labelled ring mask from a set of circles.

    ring_coordinates: sequence of [coord_0, coord_1, radius] entries, each
        handed to mask_lib.circle_mask as (shape, [coord_0, coord_1], radius)
    shape: shape of the two-dimensional mask

    Returns (ringmask, masks_ring): ringmask holds one label per pixel,
    masks_ring[i] is the boolean mask "ringmask == i" for
    i = 0 ... len(ring_coordinates) - 1
    """
    nr_circles = len(ring_coordinates)

    # Sum up the masks of all circles
    # (assumes circle_mask returns 1 inside and 0 outside - TODO confirm)
    ringmask = np.zeros(shape)
    for circle in ring_coordinates:
        ringmask += mask_lib.circle_mask(
            ringmask.shape, [circle[0], circle[1]], circle[2]
        )

    # Convert the summed value s into the label nr_circles + 1 - s;
    # pixels with s = 0 (label nr_circles + 1) are reset to label 0
    ringmask = np.abs(ringmask - nr_circles - 1)
    ringmask[ringmask == nr_circles + 1] = 0

    # Compare the label image against every label 0 ... nr_circles - 1 at once
    labels = np.arange(nr_circles).reshape(nr_circles, 1, 1)
    masks_ring = np.zeros((nr_circles, shape[0], shape[1]), dtype=bool)
    masks_ring[...] = ringmask[np.newaxis] == labels

    return ringmask, masks_ring

In [None]:
def calc_delay_ps(exp):
    """
    Adds the column "delay_ps" to exp and returns exp (modified in place).
    The value is derived from the column "delay" and the notebook-level
    time-zero variable t0.
    """
    # Delay relative to time-zero
    relative_delay = exp["delay"] - t0

    # Scale by 6.67, round to one decimal and shift by -1
    # NOTE(review): 6.67 looks like a stage-position-to-ps conversion and -1
    # like a manual offset - confirm
    # (alternative correction once considered: + (640 - exp["global_delay"]))
    exp["delay_ps"] = np.round(6.67 * relative_delay, 1) - 1
    return exp

### Loading

In [None]:
def preprocess_exp(datafolder, extension, keys=None, sort=False):
    """
    Loads log parameter and image file names, not the actual images

    Parameters
    ----------
    datafolder : str
        Folder of the scan; extension is appended to it without separator
    extension : str
        Suffix of the folder name (e.g. "_BG"), may be empty
    keys : optional
        Passed through unchanged to pf.get_exp_dataframe
    sort : bool
        If truthy, the rows are sorted by the notebook-level variable scan_axis

    Returns
    -------
    exp : table returned by pf.get_exp_dataframe, extended by the columns
        "<k>_sum" for xgm_UH, xgm_SH and diode_sum, "diode_sum_mean",
        "diode_sum_std", "IR_mean", "IR_std", "magnet_mean", "bunchid"
        and "delay_ps"
    """

    # Loading experiment data
    source = datafolder + extension
    print("Loading: %s" % source)
    exp = pf.get_exp_dataframe(source, keys=keys)

    # Sum of each entry of the intensity monitors (includes "diode_sum_sum")
    for k in ["xgm_UH", "xgm_SH", "diode_sum"]:
        exp[k + "_sum"] = exp[k].apply(np.sum)

    # Statistics of each entry
    exp["diode_sum_mean"] = exp.diode_sum.apply(np.mean)
    exp["diode_sum_std"] = exp.diode_sum.apply(np.std)
    exp["IR_mean"] = exp.IR.apply(np.mean)
    exp["IR_std"] = exp.IR.apply(np.std)

    # Mean magnet value, rounded to three decimals
    exp["magnet_mean"] = exp.magnet.apply(np.mean)
    exp["magnet_mean"] = exp.magnet_mean.apply(np.round, args=(3,))

    # Last element of each "bunches" entry
    exp["bunchid"] = exp.bunches.apply(lambda l: l[-1])

    # Add delay
    exp = calc_delay_ps(exp)

    # Sort according to scan axis; plain truthiness so that numpy booleans
    # are honoured as well
    if sort:
        exp = exp.sort_values(scan_axis)

    return exp

In [None]:
def load_images_into_xarry(exp, extra_keys=()):
    """
    Loads only images, scan_axis (e.g.,delay), basic filter keys (diode for I0)
    and some specified extra keys into xarray

    Parameters
    ----------
    exp : table with at least the columns "filename", scan_axis and filter_key
        (scan_axis and filter_key are notebook-level variables)
    extra_keys : iterable of str
        Additional columns of exp to copy. Columns stacking to one dimension
        get the dimension file_idx, columns stacking to two dimensions get
        (file_idx, frame_idx); everything else is skipped with a message

    Returns
    -------
    data : xr.Dataset with "images" (file_idx, frame_idx, y, x) as float32
    """

    # Loading image data (first element returned by pf.loadh5 for every file)
    data = xr.Dataset()
    data["images"] = xr.DataArray(
        [pf.loadh5(fname)[0].astype(np.float32) for fname in tqdm(exp["filename"])],
        dims=["file_idx", "frame_idx", "y", "x"],
    )

    # NOTE(review): the value is rounded to one decimal and then cast to int,
    # which discards the decimal again - confirm that this is intended
    data[scan_axis] = xr.DataArray(
        np.array(np.round(exp[scan_axis], 1)).astype(int), dims=["file_idx"]
    )
    data[filter_key] = xr.DataArray(
        np.stack(exp[filter_key]), dims=["file_idx", "frame_idx"]
    )

    # Dimension names depend on the dimensionality of the stacked column
    dims_by_ndim = {1: ["file_idx"], 2: ["file_idx", "frame_idx"]}
    for key in extra_keys:
        stack = np.stack(exp[key])
        dims = dims_by_ndim.get(stack.ndim)
        if dims is None:
            print(
                "Key: %s has an unsupported number of dimensions (%d), skipped"
                % (key, stack.ndim)
            )
            continue
        data[key] = xr.DataArray(stack, dims=dims)

    return data

In [None]:
def process_images(exp, extra_keys=[]):
    """
    Loads the images of all files listed in exp, filters empty frames and
    adds frame-averaged images, the summed filter key and the recording time
    """
    # Images and log values as xarray, followed by the basic frame filter
    data = filter_empty_images(load_images_into_xarry(exp, extra_keys=extra_keys))

    # Averages / sums over the frames of each file
    data["images_mean"] = data["images"].mean("frame_idx")
    data[filter_key + "_sum"] = data[filter_key].sum("frame_idx")

    # Every frame of a file gets the "time" entry of that file
    nr_frames = len(data["frame_idx"])
    time_per_file = np.array(exp["time"])[:, np.newaxis]
    data["time"] = xr.DataArray(
        np.repeat(time_per_file, nr_frames, axis=1),
        dims=["file_idx", "frame_idx"],
    )

    return data

## Filtering of images

In [None]:
def filter_empty_images(data):
    """
    Masks all frames whose brightest pixel does not exceed a value of 1.
    Adds the variable "images_max" and returns the masked dataset.
    """
    print("Filtering empty images")

    # Brightest pixel of every frame serves as emptiness criterion
    data["images_max"] = data.images.max(["y", "x"])
    has_counts = data["images_max"] > 1

    # Entries not fulfilling the criterion are masked, nothing is dropped
    return data.where(has_counts)

In [None]:
def filter_false_images(images, filter_thres, NaNs=False):
    """
    Flags frames whose mean intensity deviates too much from the ensemble

    Parameters
    ----------
    images : stack of frames, the last two axes are the image axes
    filter_thres : float
        Allowed deviation in units of the standard deviation of the frame means
    NaNs : bool
        If truthy, the NaN-ignoring numpy functions are used

    Returns
    -------
    valid : boolean per frame, True if the frame mean is closer than
        filter_thres standard deviations to the median of all frame means

    Also opens a figure showing the frame means and the filter limits.
    """
    # Select the statistics functions once, so that any truthy / falsy value
    # of NaNs leads to a defined result
    if NaNs:
        mean_fn, median_fn, std_fn = np.nanmean, np.nanmedian, np.nanstd
    else:
        mean_fn, median_fn, std_fn = np.mean, np.median, np.std

    # Calc Monitoring parameter
    image_mean = mean_fn(images, axis=(-2, -1))
    ensemble_mean = median_fn(image_mean)
    image_std = std_fn(image_mean)  # spread of the frame means

    # Filter
    limit = filter_thres * image_std
    valid = np.abs(image_mean - ensemble_mean) < limit

    # Plot filter condition
    fig, ax = plt.subplots()
    ax.plot(image_mean, "o-")
    ax.grid()
    ax.set_xlabel("Image Index")
    ax.set_ylabel("Image Mean")
    ax.set_title("Check for inconsistencies of the averaged intensity")

    # axhline spans the full axes width by default (its xmin/xmax arguments
    # are axes fractions, not data coordinates)
    ax.axhline(ensemble_mean, color="g", linestyle="--")
    ax.axhline(ensemble_mean + limit, color="r", linestyle="--")
    ax.axhline(ensemble_mean - limit, color="r", linestyle="--")

    return valid

# Experimental details

In [None]:
# Define basic folders
BASEFOLDER = r"/data/beamtimes/FERMI/2310_XPCS"
PROPOSAL = "20224053"
USER = getuser()

In [None]:
# Dict with most basic experimental parameter
# (extended by "lambda" and "ccd_dist" once the pumped data are loaded)
experimental_setup = {
    "px_size": 11e-6,  # pixel_size of camera (presumably in m - confirm)
    "binning": 1,  # Camera binning
}

# Setup for azimuthal integrator
# Both detector arguments are the effective pixel size binning * px_size
detector = Detector(
    experimental_setup["binning"] * experimental_setup["px_size"],
    experimental_setup["binning"] * experimental_setup["px_size"],
)

# General saving folder
# NOTE(review): the path is hard-coded to one user's export folder and does
# not use the USER variable defined above - confirm before running as
# another user
folder_target = pf.create_folder(join("/data/export/cklose/2310_FERMI_Skyrmion", "Results"))
print("Output Folder: %s" % folder_target)

# Load Data

## Define Scan ids for loading

In [None]:
# Define for loading
sample = "Sample54"
membrane = "H4"
scan_id = 161
ACDC_delay = "2p0ps"

# Folder name of the scan; scan_id is zero-padded to three digits
scan = f"{membrane}_XPCS_FF_Delay_{ACDC_delay}_Scan_Both_{scan_id:03d}"

# Column used as scan axis (read as global by the loading functions)
scan_axis = "delay_ps"

# Time-zero for delay scans (read as global by calc_delay_ps)
t0 = 7.25

# Folder for loading
samplefolder = join(sample, scan)
datafolder = join(BASEFOLDER, samplefolder)

# Additional log entries to load, given as {column name: entry in the file};
# handed to preprocess_exp via its keys argument
extra_keys = {
    "diode_sum": "PAM/FQPDSum",
    "IR": "Laser/Energy1",
    "magnet": "DPI/CoilCurrent",
    "magnet_waveform": "DPI/CoilWaveform",
    "bunches": "bunches",
    "time": "",
    "samplex": "DPI/SampleX",
    "sampley": "DPI/SampleY",
    "ccdz": "DPI/CcdZ",
    "global_delay": "Laser/DelayTotem",
}

# Create savefolder
fsave = helper.create_folder(join(folder_target, sample, membrane, "XSVS"))

## Pumped images

In [None]:
# Loading experiment data
# filter_key and norm_key are notebook-level variables: filter_key is read by
# the loading functions, norm_key by the normalization cell further below
filter_key = "diode_sum"
norm_key = "diode_sum"
extension = ""
exp = preprocess_exp(datafolder, extension, keys=extra_keys, sort = True)

# Add wavelength
# NOTE(review): exp was sorted above, so [0] is presumably resolved as index
# label and not as first row of the sorted table - confirm
# NOTE(review): the factors 1e-9 / 1e-3 and the +50 offset suggest nm -> m and
# mm -> m with a fixed camera offset - confirm units
experimental_setup["lambda"] = exp["wavelength"][0] * 1e-9
experimental_setup["ccd_dist"] = (exp["ccdz"][0] + 50) * 1e-3

# Load data
# Only rows 0-59 and 90-119 of the sorted table are loaded
data = process_images(exp.iloc[np.hstack([np.arange(0,60),np.arange(90,120)])], extra_keys = ["diode_sum","magnet","diode_sum_mean"])

# All frames of a given file correspond to the same delay and recording time
data[scan_axis] = data[scan_axis].mean("frame_idx")
data["time"] = data["time"].mean(["frame_idx"])
print("Data loaded!")
data

In [None]:
# What did you scan?
fig, ax = plt.subplots()
ax.plot(np.arange(len(data[scan_axis])), data[scan_axis], "-o")
ax.set_xlabel("file_idx")
ax.set_ylabel(scan_axis)
ax.grid()

In [None]:
# Monitoring plots
fig, ax = plt.subplots(figsize=(8,6))

for diode_sum in data.diode_sum.values:
    ax.plot(diode_sum)

ax.grid()

In [None]:
# Plot images
fig, ax = cimshow(data["images"][-1])
fig.set_size_inches(6, 6)
ax.set_title("Pumped Images")

## Dark images

In [None]:
# Loading experiment data
extension = "_BG"
exp_bg = preprocess_exp(datafolder, extension, keys=extra_keys)
exp_bg = exp_bg.sort_values("time")

# Loading image data
dark = process_images(exp_bg)
print("Data loaded!")

In [None]:
# Plot image sequence
fig, ax = cimshow(dark["images"][0])
fig.set_size_inches(6, 6)
ax.set_title("Dark Images")

In [None]:
# Filter wrongly assigned frames, i.e., when shutter was not working correctly
valid = filter_false_images(dark["images"][0],2)

In [None]:
# Drop invalid frames
# valid comes from the previous cell, where it was determined from the frames
# of the first dark file only
dark = dark.where(valid).dropna(dim="frame_idx")
dark["images_mean"] = dark["images"].mean(["frame_idx"])

# Drop single frames
# One recording time per file; "time" replaces file_idx as dimension so that
# the background can later be selected by nearest recording time
dark["time"] = dark["time"].mean("frame_idx")
#dark = dark.drop_dims("frame_idx")
dark = dark.swap_dims({"file_idx": "time"})

# Plot image
fig, ax = cimshow(dark["images_mean"])
fig.set_size_inches(6, 6)
ax.set_title("Dark Image")

## Laser only

In [None]:
# Loading experiment data
extension = "_OL"
exp_ol = preprocess_exp(datafolder, extension, keys=extra_keys)
exp_ol = exp_ol.sort_values("time")

# Loading image data
dark_ol = process_images(exp_ol)
print("Data loaded!")

In [None]:
# Plot image sequence
fig, ax = cimshow(dark_ol["images"][0])
fig.set_size_inches(6, 6)
ax.set_title("Only Laser Images")

In [None]:
# Filter wrongly assigned frames, i.e., when shutter was not working correctly
valid = filter_false_images(dark_ol["images"][0],2,NaNs=True)

In [None]:
# Drop invalid frames
# valid comes from the previous cell (first only-laser file); it is wrapped
# into a DataArray so that it is applied along frame_idx
valid = xr.DataArray(valid,dims=["frame_idx"])
dark_ol = dark_ol.where(valid).dropna(dim="frame_idx")
dark_ol["images_mean"] = dark_ol["images"].mean(["frame_idx"])

# Drop single frames
# One recording time per file; "time" replaces file_idx as dimension so that
# the background can later be selected by nearest recording time
dark_ol["time"] = dark_ol["time"].mean("frame_idx")
#dark_ol = dark_ol.drop_dims("frame_idx")
dark_ol = dark_ol.swap_dims({"file_idx": "time"})

# Plot image
fig, ax = cimshow(dark_ol["images_mean"])
fig.set_size_inches(6, 6)
ax.set_title("Only Laser Image")

## FEL only

In [None]:
# Loading experiment data
extension = "_OF"
exp_of = preprocess_exp(datafolder, extension, keys=extra_keys)
exp_of = exp_of.sort_values("time")

# Loading image data
dark_of = process_images(exp_of)
print("Data loaded!")

In [None]:
# Plot image sequence
fig, ax = cimshow(dark_of["images"][0])
fig.set_size_inches(6, 6)
ax.set_title("Only FEL Images")

In [None]:
# Filter wrongly assigned frames, i.e., when shutter was not working correctly
valid = filter_false_images(dark_of["images"][0],2,NaNs=True)

In [None]:
# Drop invalid frames
# valid comes from the previous cell (first only-FEL file); it is wrapped
# into a DataArray so that it is applied along frame_idx
valid = xr.DataArray(valid,dims=["frame_idx"])
dark_of = dark_of.where(valid).dropna(dim="frame_idx")

# Recompute the per-file diode statistics and the averaged image from the
# remaining frames
dark_of["diode_sum_sum"] = dark_of["diode_sum"].sum(["frame_idx"])
dark_of["diode_sum_mean"] = dark_of["diode_sum"].mean(["frame_idx"])
dark_of["images_mean"] = dark_of["images"].mean(["frame_idx"])

# Drop single frames
# One recording time per file; "time" replaces file_idx as dimension so that
# the background can later be selected by nearest recording time
dark_of["time"] = dark_of["time"].mean("frame_idx")
#dark_of = dark_of.drop_dims("frame_idx")
dark_of = dark_of.swap_dims({"file_idx": "time"})

# Plot image
fig, ax = cimshow(dark_of["images_mean"])
fig.set_size_inches(6, 6)
ax.set_title("Only FEL Image")

# Preprocessing

## Subtract Dark Background

In [None]:
# Subtract only dark background from pumped images
# For every file the averaged dark image recorded closest in time is
# subtracted from all frames of that file (the loop variable delay is unused)
for i, delay in enumerate(tqdm(data[scan_axis].values)):
    data["images"][i] = data["images"][i] - (
        dark.sel(time=data.time[i], method="nearest").images_mean
    )
# Just for a feedback plot
data["images_mean"] = data["images"].mean(["frame_idx"])

# Subtract only dark background from only laser image
# Same nearest-in-time dark subtraction for every only-laser file
for i, delay in enumerate(tqdm(dark_ol["time"].values)):
    dark_ol["images"][i] = dark_ol["images"][i] - (
        dark.sel(time=dark_ol.time[i], method="nearest").images_mean
    )
# Just for a feedback plot
dark_ol["images_mean"] = dark_ol["images"].mean(["frame_idx"])

# Subtract only dark background from only FEL image
# Same nearest-in-time dark subtraction for every only-FEL file
for i, delay in enumerate(tqdm(dark_of["time"].values)):
    dark_of["images"][i] = dark_of["images"][i] - (
        dark.sel(time=dark_of.time[i], method="nearest").images_mean
    )
# Just for a feedback plot
dark_of["images_mean"] = dark_of["images"].mean(["frame_idx"])

### Feedback Images

In [None]:
# Plot images
fig, ax = cimshow(data["images"][-1])
fig.set_size_inches(6, 6)
ax.set_title("Pumped Single Images for one delay")

In [None]:
# Plot images
fig, ax = cimshow(data["images_mean"])
fig.set_size_inches(6, 6)
ax.set_title("Pumped Images for all delays as slideshow")

In [None]:
# Plot image sequence
fig, ax = cimshow(dark_ol["images"][0])
fig.set_size_inches(6, 6)
ax.set_title("Only Laser Images")

In [None]:
# Plot image sequence
fig, ax = cimshow(dark_ol["images_mean"][0])
fig.set_size_inches(6, 6)
ax.set_title("Only Laser Images")

In [None]:
# Plot image sequence
fig, ax = cimshow(dark_of["images"][0])
fig.set_size_inches(6, 6)
ax.set_title("Only FEL Images")

In [None]:
# Plot image sequence
fig, ax = cimshow(dark_of["images_mean"][0])
fig.set_size_inches(6, 6)
ax.set_title("Only FEL averaged Image")

### Feedback Plots

In [None]:
# Histograms of all image types
# Histogram parameter
delay_idx_list = [-1, 50, 1]  # file indices of the pumped images to compare
start = -40
end = 40
nr_steps = 200  # number of bin edges, i.e., nr_steps - 1 bins
bins = np.linspace(start, end, nr_steps)

# Bar geometry derived from the actual bin edges, so that bars and curves sit
# at the bin centers and keep their width if nr_steps is changed
bin_width = bins[1] - bins[0]
bin_centers = bins[:-1] + bin_width / 2

# Calc Histograms (probability densities of the frames of the first file)
# hist_dark, _ = np.histogram(dark["images_mean"][0],bins = bins,density=True)
hist_of, _ = np.histogram(dark_of["images"][0].values.ravel(), bins=bins, density=True)
hist_ol, _ = np.histogram(dark_ol["images"][0].values.ravel(), bins=bins, density=True)

# Plotting
fig, ax = plt.subplots(2, 1, figsize=(8, 6))
ax[0].set_title("Histogram of frames")
ax[1].set_title("Deviation from only-FEL histogram")
for delay_idx in delay_idx_list:
    label = "Image Delay: %d ps" % data["delay_ps"][delay_idx]
    hist_image, _ = np.histogram(
        data["images"][delay_idx].values.ravel(), bins=bins, density=True
    )
    ax[0].bar(bin_centers, hist_image, width=bin_width, label=label, alpha=0.25)
    ax[1].plot(bin_centers, hist_image - hist_of, label=label)
ax[0].bar(bin_centers, hist_ol, width=bin_width, label="Only Laser", alpha=0.25)
ax[0].bar(bin_centers, hist_of, width=bin_width, label="Only FEL", alpha=0.25)
ax[1].plot(bin_centers, hist_ol - hist_of, label="Only Laser")
# ax.hist(dark["images_mean"][0].values.ravel(),500, label="Dark")

ax[0].set_ylabel("Frequency")
ax[1].set_xlabel("Counts")
ax[1].set_ylabel("Probability deviation")
ax[0].grid()
ax[1].grid()
ax[0].set_yscale("log")
ax[0].legend()
ax[1].legend()

## Normalize and remove all backgrounds from pumped images 

In [None]:
# Subtract only laser background from pumped image
# For every file the averaged only-laser image recorded closest in time is
# subtracted from all frames (the loop variable delay is unused)
for i, delay in enumerate(tqdm(data[scan_axis].values)):
    data["images"][i] = data["images"][i] - (
        dark_ol.sel(time=data.time[i], method="nearest").images_mean
    )

# Normalize to laser fluences
# Images are divided by the variable named by norm_key
# NOTE(review): norm_key is set to "diode_sum" when loading the pumped images;
# confirm that this is the intended normalization quantity
data["images"] = data["images"] / data[norm_key]
dark_of["images"] = dark_of["images"] / dark_of[norm_key] #all frames are normalized individually
dark_of["images_mean"] = dark_of["images"].mean(["frame_idx"])

# Subtract final background
# Nearest-in-time averaged and normalized only-FEL image
for i, delay in enumerate(tqdm(data[scan_axis].values)):
    data["images"][i] = data["images"][i] - (
        dark_of.sel(time=data.time[i], method="nearest").images_mean
    )

In [None]:
# Show images of which file?
file_idx = -1

# Plot images
fig, ax = cimshow(data.images[file_idx])
fig.set_size_inches(6, 6)
ax.set_title("Images of delay: %.2f ps" % data[scan_axis][file_idx].values)

In [None]:
# Calc average image for each delay
data["images_mean"] = data["images"].mean("frame_idx")
data["im_mean"] = data.images_mean.mean("file_idx")
im_mean = data["im_mean"].values

print("Done!")

## Draw beamstop mask

In [None]:
poly_mask = interactive.draw_polygon_mask(data.im_mean)

In [None]:
# Take poly coordinates and mask from widget
p_coord = poly_mask.get_vertice_coordinates()
mask = poly_mask.full_mask.astype(int)

cimshow(mask)

print("Mask coordinates: %s" % p_coord)

In [None]:
def load_poly_coordinates():
    """
    Returns the stored polygon masks as dictionary
    {mask name: list of polygons}, each polygon being a list of corner
    coordinate tuples.

    To store a new mask (e.g. with name "test"), add an entry "test" holding
    the coordinates printed after drawing the mask.
    """
    return {
        "bs_cross": [
            [
                (2064.9, -23.5),
                (-23.3, -34.8),
                (-53.0, 10.5),
                (1977.6, 69.7),
                (1978.7, 613.6),
                (569.6, 610.3),
                (-58.1, 599.2),
                (-49.9, 744.6),
                (1632.1, 739.1),
                (2166.8, 760.3),
                (2166.8, 367.2),
            ]
        ],
    }

In [None]:
# Which drawn masks do you want to load? (you can add multiple masks in list e.g. ["bs_cross","bs_bar_delayscans"])
polygon_names = ["bs_cross"] 
mask = mask_lib.load_poly_masks(im_mean.shape,load_poly_coordinates(),polygon_names)

# Additional manual masking of vertical stripes
mask[:,45:55] = 1
mask[:,160:170] = 1
mask[:,182:198] = 1
mask[:,1046:1047] = 1
mask[:,1215:1216] = 1
mask[:,1672:1673] = 1
mask[:,1784:1795] = 1

In [None]:
data["mask"] = xr.DataArray(mask, dims=["y", "x"])
print("Mask == 1 areas will be excluded in azimuthal integrator")

fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
mi, ma = np.percentile(im_mean, [30, 70])
ax[0].imshow(im_mean * (1 - mask), vmin=mi, vmax=ma)
ax[0].set_title("(1-mask)")
ax[1].imshow(im_mean * mask, vmin=mi, vmax=ma)
ax[1].set_title("mask")

In [None]:
# Apply mask
sel = np.logical_not(data["mask"])
data["images"] = data["images"].where(sel, other = np.nan)
dark_ol["images"] = dark_ol["images"].where(sel, other = np.nan)
dark_of["images"] = dark_of["images"].where(sel, other = np.nan)

In [None]:
# Plot images
fig, ax = cimshow(data.images[0])
fig.set_size_inches(6, 6)
ax.set_title("Averaged image of each delay")

In [None]:
# Histograms of all image types
# Histogram parameter
delay_idx_list = [-1, 50, 1]  # file indices of the pumped images to compare
start = -0.1
end = 0.1
nr_steps = 200  # number of bin edges, i.e., nr_steps - 1 bins
bins = np.linspace(start, end, nr_steps)

# Bar geometry derived from the actual bin edges, so that bars and curves sit
# at the bin centers and keep their width if nr_steps is changed
bin_width = bins[1] - bins[0]
bin_centers = bins[:-1] + bin_width / 2

# Calc Histograms (NaN entries, e.g. masked pixels, are removed beforehand)
# hist_dark, _ = np.histogram(dark["images_mean"][0],bins = bins,density=True)
temp_data = dark_of["images"][0].values.ravel()
hist_of, _ = np.histogram(temp_data[~np.isnan(temp_data)], bins=bins, density=True)
# hist_ol, _ = np.histogram(dark_ol["images"][0].values.ravel(),bins = bins,density=True)

# Plotting
fig, ax = plt.subplots(2, 1, figsize=(8, 6))
ax[0].set_title("Histogram of frames")
ax[1].set_title("Deviation from only-FEL histogram")
for delay_idx in delay_idx_list:
    label = "Image Delay: %d ps" % data["delay_ps"][delay_idx]
    temp_data = data["images"][delay_idx].values.ravel()
    hist_image, _ = np.histogram(
        temp_data[~np.isnan(temp_data)], bins=bins, density=True
    )
    ax[0].bar(bin_centers, hist_image, width=bin_width, label=label, alpha=0.25)
    ax[1].plot(bin_centers, hist_image - hist_of, label=label)
# ax[0].bar(bin_centers, hist_ol, width=bin_width, label="Only Laser", alpha=0.25)
ax[0].bar(bin_centers, hist_of, width=bin_width, label="Only FEL", alpha=0.25)
# ax[1].plot(bin_centers, hist_ol - hist_of, label="Only Laser")
# ax.hist(dark["images_mean"][0].values.ravel(),500, label="Dark")

ax[0].set_ylabel("Frequency")
ax[1].set_xlabel("Counts")
ax[1].set_ylabel("Probability deviation")
ax[0].grid()
ax[1].grid()
ax[0].set_yscale("log")
ax[0].legend()
ax[1].legend()

## Reorganize data xarray

In [None]:
# Group according to new coordinate delay in ps
# Every group of equal delay_ps gets a new leading dimension delay_ps_coord
# of length one, labelled with the delay; the groups are then concatenated
# along that dimension
# NOTE(review): presumably requires the same number of files for every delay
# to concatenate without padding - confirm
new_data = []
for label, group in list(data.groupby("delay_ps")):
    group = group.expand_dims(dim="delay_ps_coord")
    group["delay_ps_coord"] = [label]
    new_data.append(group)
data = xr.concat(new_data,dim="delay_ps_coord")

# Combine file_idx and frame_idx
data = data.stack(delay_frame_tuple = ("file_idx","frame_idx"))

# Replace entangled variables
# The stacked (file_idx, frame_idx) index is replaced by a running integer
# index delay_frame_idx
data["delay_frame_idx"] = xr.DataArray(np.arange(data["delay_frame_tuple"].shape[0]),dims=["delay_frame_tuple"])
data = data.swap_dims({"delay_frame_tuple":"delay_frame_idx"})

# Process images
# Averages per delay and over all delays; images are brought into the axis
# order (delay_ps_coord, delay_frame_idx, y, x) used by the cells below
data["images_mean"] = data["images"].mean("delay_frame_idx")
data["im_mean"] = data["images_mean"].mean("delay_ps_coord")
data["images"] = data["images"].transpose("delay_ps_coord","delay_frame_idx","y","x")

# Release the temporary per-delay copies
del new_data, group
gc.collect()

data

In [None]:
fig, ax = cimshow(data["images_mean"])
ax.set_title("Pumped Images grouped for different delays")

## Filtering of images with too high or low intensity

### Find Standard Deviation threshold

In [None]:
# Filter wrongly assigned frames, i.e., when shutter was not working correctly
# One threshold per delay (in units of the standard deviation of the frame
# means); this cell only produces the feedback plots, nothing is filtered yet
# NOTE(review): std_thres needs one entry per value of delay_ps_coord - the
# list is hard-coded to three delays
std_thres = [2,1,1]
for i, delay in enumerate(data["delay_ps_coord"].values):
    valid = filter_false_images(data["images"][i],std_thres[i],NaNs=True)

In [None]:
# View faulty frames
fig, ax = cimshow(data["images"][-1])

### Apply filtering

In [None]:
# Filter wrongly assigned frames, i.e., when shutter was not working correctly
for i, delay in enumerate(data["delay_ps_coord"].values):
    valid = filter_false_images(data["images"][i],std_thres[i],NaNs=True)
    valid = xr.DataArray(valid,dims=["delay_frame_idx"])

    # Drop invalid frames
    data["images"][i] = data["images"][i].where(valid)
    
data["images_mean"] = data["images"].mean(["delay_frame_idx"])

## Find center

### Basic widget to find center

Try to **align** the circles with the **center of the scattering pattern**. Caution: the position of the beamstop can be misleading and may not coincide with the actual center of the hologram.

In [None]:
# Set center position via widget
ic = interactive.InteractiveCenter(data.images_mean[-1], c0 = -77, c1 = 2081, rBS = 250)

In [None]:
# Get center positions
center = [ic.c0, ic.c1]
print(f"Center:", center)

### Azimuthal integrator widget for finetuning

In [None]:
# Setup azimuthal integrator for virtual geometry
ai = interactive.AzimuthalIntegrator(
    dist=experimental_setup["ccd_dist"],
    detector=detector,
    wavelength=experimental_setup["lambda"],
    poni1=center[0]
    * experimental_setup["px_size"]
    * experimental_setup["binning"],  # y (vertical)
    poni2=center[1]
    * experimental_setup["px_size"]
    * experimental_setup["binning"],  # x (horizontal)
)

In [None]:
# Plotting to find  relevant q range
I_t, q_t, phi_t = ai.integrate2d(
    np.mean(data.images_mean.values[-29:],axis=0),
    200,
    radial_range=(0, 0.11),
    unit="q_nm^-1",
    correctSolidAngle=False,
    dummy=np.nan,
    mask=mask,
    method= "bbox"
)
az2d = xr.DataArray(I_t, dims=("phi", "q"), coords={"q": q_t, "phi": phi_t})

# Plot
fig, ax = plt.subplots()
mi, ma = np.nanpercentile(I_t, [10, 90])
az2d.plot.imshow(ax=ax, vmin=mi, vmax=ma)
plt.title(f"Azimuthal integration")

# Vertical lines
# q_lines = [0.025, 0.05]
# for qt in q_lines:
#    ax.axvline(qt, ymin=0, ymax=180, c="red")

In [None]:
aic = interactive.AzimuthalIntegrationCenter(
    # np.log10(im_mean - np.min(im_mean) + 1),
    data.images_mean[-1].values,
    ai,
    c0=center[0],
    c1=center[1],
    mask=mask,
    im_data_range=[20, 80],
    radial_range=(0.015, 0.08),
    qlines=[40, 60],
)

In [None]:
# Get center positions
center = [aic.c0, aic.c1]
data = data.assign_attrs({"center": center})
print(f"Center:", center)

# X-ray speckle visibility spectroscopy (XSVS)

## Setup ring mask

In [None]:
# How many rings do you want?
nr_rings = 14

# Setup coordinates
ring_coordinates = [[center[0], center[1], (k + 1) * 150] for k in range(nr_rings + 1)]

In [None]:
def get_ringmask_coordinates(ring_name):
    """
    Returns the stored ring coordinates for a given name

    Parameters
    ----------
    ring_name : str
        Name of a stored set of rings (currently only "xsvs_test")

    Returns
    -------
    ring_coord : list of [coord_0, coord_1, radius] entries, newly created
        on every call so that callers may modify it

    Raises
    ------
    ValueError
        If no rings are stored under ring_name
    """
    stored_rings = {
        "xsvs_test": [
            [256.0, 1024.0, 200.0],
            [256, 1024, 400],
            [256, 1024, 600],
            [256, 1024, 800],
            [256, 1024, 1000],
            [256, 1024, 1200],
        ],
    }

    # Fail with a clear message instead of an unbound local variable
    if ring_name not in stored_rings:
        raise ValueError(
            "Unknown ring name %r, available names: %s"
            % (ring_name, sorted(stored_rings))
        )

    return stored_rings[ring_name]

In [None]:
# Which sample do you study? ("GdFe", "Permalloy", "YIG")
ring_name = "xsvs_test"

# Get coordinates
ring_coordinates = get_ringmask_coordinates(ring_name)
nr_rings = len(ring_coordinates)

In [None]:
# Widget to find the positions and sizes of the different apertures
ds = interactive.InteractiveCircleCoordinates(
    data.images_mean[-1].values,
    nr_rings,
    coordinates=ring_coordinates.copy(),
)

In [None]:
# Take coordinates of circles from widget
ring_coordinates = ds.c_yxr

# Create mask
ringmask, masks_ring = create_ringmask(ring_coordinates, data.im_mean.shape)

# Add beamstop mask
ringmask = ringmask * (1 - mask)
masks_ring = masks_ring * (1 - mask)

# Add to xarray
data["ringmask"] = xr.DataArray(ringmask, dims=["y", "x"])

# Plot
tmp = data.im_mean.values
fig, ax = plt.subplots(figsize=(6, 6))
mi, ma = np.nanpercentile(
    tmp[tmp != 0],
    (1, 85),
)
ax.imshow(
    tmp,
    vmin=mi,
    vmax=ma,
)
ax.imshow(ringmask, alpha=0.2, cmap="flag")
ax.set_title("Image with overlayed mask")

# Save figure
fname = join(fsave, "XSVS_ID_%s_qrings_%s.png" % (scan_id, USER))
print("Saving: %s" % fname)
plt.savefig(fname)

## Mean intensity and brightest pixel of each ring

In [None]:
# Setup
# One value per q-ring, delay and frame
ring_mean = np.zeros((nr_rings, len(data["delay_ps_coord"]), len(data["delay_frame_idx"])))
max_cts = np.zeros_like((ring_mean))
min_cts = np.zeros_like((ring_mean))

# Loop over rings
# NOTE(review): the loop runs over the labels 0 ... nr_rings - 1 of ringmask;
# label 0 also contains every pixel that was set to zero when the beamstop
# mask was multiplied onto ringmask - confirm that the label range is intended
for i in tqdm(range(nr_rings)):
    # Indices of ringmask
    idx = np.argwhere((ringmask == i))

    # Calc only if slices are not empty
    if idx.size > 0:
        # Pixels of the ring for all delays and frames at once
        array_slices = data["images"].values[...,idx[:,0],idx[:,1]]
    
        # Mean intensity
        ring_mean[i] = np.nanmean(array_slices,axis=-1)
    
        # Brightest pixel
        max_cts[i] = np.nanmax(array_slices,axis=-1)
    
        # Darkest pixel
        min_cts[i] = np.nanmin(array_slices,axis=-1)


# Remove infs
#max_cts[np.isinf(max_cts)] = 0
#min_cts[np.isinf(min_cts)] = 0

# Assign to xarray
data["ring_mean"] = xr.DataArray(ring_mean, dims=["q_ring", "delay_ps_coord", "delay_frame_idx"])
data["max_cts"] = xr.DataArray(max_cts, dims=["q_ring", "delay_ps_coord", "delay_frame_idx"])
data["min_cts"] = xr.DataArray(min_cts, dims=["q_ring", "delay_ps_coord", "delay_frame_idx"])

In [None]:
# Nr of Columns and rows
columns = 3
row = len(data["delay_ps_coord"])

# squeeze=False keeps ax two-dimensional even if there is only one delay
fig, ax = plt.subplots(
    row, columns, figsize=(4 * columns, 4 * row), squeeze=False
)

# One column per quantity: (variable, x label, y label)
panels = [
    ("ring_mean", "Delay frame_idx", "Mean Intensity"),
    ("max_cts", "frame_idx", "Max Intensity"),
    ("min_cts", "frame_idx", "Min Intensity"),
]

# Loop over different delays (one row per delay)
for i, delay in enumerate(tqdm(data["delay_ps_coord"].values)):
    for col, (field, xlabel, ylabel) in enumerate(panels):
        axis = ax[i, col]
        # delay is the value of the delay coordinate, not an index
        axis.set_title("Delay: %d ps" % delay)
        axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)

        # One curve per q-ring, the first two ring indices are skipped
        for ring in data["q_ring"].values[2:]:
            axis.plot(
                data["delay_frame_idx"].values,
                data[field].sel(q_ring=ring, delay_ps_coord=delay).values,
                label="qring = %d" % ring,
            )

        axis.legend(prop={"size": 6})
        axis.grid()

In [None]:
# Nr of Columns and rows
columns = 3
row = len(data["q_ring"][2:])

# squeeze=False keeps ax two-dimensional even if only one ring is plotted
fig, ax = plt.subplots(
    row, columns, figsize=(4 * columns, 4 * row), squeeze=False
)

# One column per quantity: (variable, x label, y label)
panels = [
    ("ring_mean", "Delay frame_idx", "Mean Intensity"),
    ("max_cts", "frame_idx", "Max Intensity"),
    ("min_cts", "frame_idx", "Min Intensity"),
]

# Loop over q-rings (one row per ring), the first two ring indices are skipped
for i, ring in enumerate(data["q_ring"].values[2:]):
    for col, (field, xlabel, ylabel) in enumerate(panels):
        axis = ax[i, col]
        # Every column gets its title (previously only the first one did)
        axis.set_title("QRing: %d" % ring)
        axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)

        # One curve per delay
        for delay in data["delay_ps_coord"].values:
            axis.plot(
                data["delay_frame_idx"].values,
                data[field].sel(q_ring=ring, delay_ps_coord=delay).values,
                label="Delay = %d ps" % delay,
            )

        axis.legend(prop={"size": 6})
        axis.grid()

## Create bin edges

In [None]:
# Constant spacing
nr_bins = 100  # number of bin edges, i.e., nr_bins - 1 bins
photon_bins = 1

# Data-driven range spanning the darkest and brightest ring pixel
bin_edges = np.linspace(
    data["min_cts"].min().values,
    data["max_cts"].max().values,
    nr_bins,
)

# Manual range: deliberately overrides the data-driven edges above.
# Comment this line out to use the data-driven range instead.
bin_edges = np.linspace(-0.1, 0.5, nr_bins)


if len(bin_edges) > 1e5:
    print("Warning! Too many bins for plotting!")

# Clear existing bin axis (also removes all variables depending on it)
if "bins" in data.keys():
    data = data.drop_dims("bins")

# Add as new xarray coordinate (upper edge of every bin)
data["bins"] = bin_edges[1:]

# Feedback plot of the bin edges versus their running index
fig, ax = plt.subplots()
ax.plot(bin_edges, "o-")
ax.set_xlabel("Bin edge index")
ax.set_ylabel("Bin edge")

## Do the x-ray speckle visibility spectroscopy

In [None]:
# One normalized histogram per delay, frame and q-ring
photon_stat = np.zeros(
    (
        len(data["delay_ps_coord"]),
        len(data["delay_frame_idx"]),
        len(data["q_ring"]),
        len(data["bins"]),
    )
)

# Fetch the image array once instead of once per histogram
image_stack = data["images"].values
nr_delays, nr_frames = photon_stat.shape[:2]

# Loop over q-rings
for k, ring in enumerate(tqdm(data["q_ring"].values)):
    # Pixel indices of the ring
    ring_rows, ring_cols = np.nonzero(ringmask == ring)

    # Calc only if the ring contains pixels
    if ring_rows.size == 0:
        continue

    # Loop over delays and frames
    for i in range(nr_delays):
        for j in range(nr_frames):
            photon_stat[i, j, k], _ = np.histogram(
                image_stack[i, j, ring_rows, ring_cols],
                bin_edges,
                density=True,
            )
            # photon_stat[i, j, k] = photon_stat[i, j, k] / np.sum(photon_stat[i, j, k])

# Assign to xarray and average over all frames of a delay
data["photon_stat"] = xr.DataArray(
    photon_stat, dims=["delay_ps_coord", "delay_frame_idx", "q_ring", "bins"]
)
data["photon_stat_mean"] = data["photon_stat"].mean("delay_frame_idx")
data["photon_stat_std"] = data["photon_stat"].std("delay_frame_idx")

In [None]:
fields = [
    "photon_stat",
    "photon_stat_mean",
    "photon_stat_std",
    "ring_mean",
    str(scan_axis),
]
plotting = data[fields]
plotting["photon_stat"] = plotting["photon_stat"].mean("delay_frame_idx")
plotting["ring_mean"] = plotting["ring_mean"].mean("delay_frame_idx")
#plotting = plotting.groupby(scan_axis).mean()
plotting

## Plot results of each q-ring

In [None]:
# Plot with errorbars?
error_bars = False

# Columns and rows
columns = 3
row = int(np.ceil(nr_rings / columns))
# row = 1

fig = plt.figure(figsize=(3.5 * columns, 3.5 * row))
# fig.suptitle("%s: %s" % (sample_scan, "{:.0e}".format(sample[sample_scan].values)))
# fig.suptitle("%s: %s" % (sample_scan, "{:.0e}".format(setup[sample_scan])))

# Loop over rings (one subplot per ring)
for i in range(nr_rings):
    axes = fig.add_subplot(row, columns, i + 1)
    # axes.set_title("Q "+ '%.3f  '%(q_ring_center[i])+ r'$\AA^{-1}$')
    axes.set_title("q-Ring: %s" % i)
    axes.set_xlabel("K/<K>")
    axes.set_ylabel("P(K)")

    # Loop over delays
    for j, delay in enumerate(plotting["delay_ps_coord"].values):
        # Bin values scaled by the mean intensity of the ring at this delay
        # (ring_mean is indexed [ring, delay], the statistics [delay, ring])
        scaled_bins = plotting["bins"].values / plotting.ring_mean[i, j].values
        probability = plotting["photon_stat_mean"][j, i].values
        label = "%.2f t.u." % delay

        if error_bars:
            axes.errorbar(
                scaled_bins,
                probability,
                plotting["photon_stat_std"][j, i].values,
                fmt="-o",
                label=label,
            )
        else:
            axes.plot(scaled_bins, probability, "-o", label=label)

    axes.set_xlim(-400, 600)
    # axes.set_xlim(-0.1, 0.2)
    # axes.set_xscale("log")
    axes.legend()
    axes.grid()

# Save this figure explicitly and before showing it: depending on the
# backend the current figure may be gone once plt.show() has returned
fname = join(fsave, "XSVS_%s.png" % scan)
print("Saving: %s" % fname)
fig.savefig(fname)
plt.show()