# Imports

Import Python libraries as well as the self-written FERMI library.

In [None]:
import sys, os
from os.path import join, split
from getpass import getuser
from glob import glob
from time import strftime
from importlib import reload
from tqdm.auto import tqdm

# data
import numpy as np
import xarray as xr
import pandas as pd
import h5py

# Images
import imageio
from imageio import imread

# Plotting
import matplotlib.pyplot as plt
from matplotlib.image import NonUniformImage
import matplotlib.gridspec as gridspec
from matplotlib.path import Path

# pyFAI
import pyFAI

pyFAI.disable_opencl = True  # get rid of annoying warning ;)
from pyFAI.azimuthalIntegrator import AzimuthalIntegrator
from pyFAI.detectors import Detector

# Scipy
from scipy.ndimage import median_filter

# Self-written libraries
sys.path.append(os.path.abspath(join(os.pardir,"process_FERMI")))
import helper_functions as helper
import mask_lib
import process_FERMI as pf
import interactive
from interactive import cimshow

In [None]:
# Interactive plotting: ipywidgets plus the "widget" matplotlib backend, so figures
# can be zoomed/panned and their axes limits read back later in the notebook.
import ipywidgets

%matplotlib widget
# Use constrained layout by default for all new figures
plt.rcParams["figure.constrained_layout.use"] = True

# Auto formatting of cells (disabled)
#%load_ext jupyter_black

## Functions

In [None]:
def _load_scan_image(fname, full_rate):
    """Load the image data of one HDF5 file as float32.

    If ``full_rate`` is False, the first array returned by ``pf.loadh5`` is
    returned unchanged. If True, that array is treated as a stack of frames
    along axis 0. The mean over all frames whose maximum is > 0 is returned,
    or ``None`` when every frame is empty.
    """
    frames = pf.loadh5(fname, extra_keys=["alignz", "PAM/FQPDSum"])[0].astype("float32")
    if not full_rate:
        return frames
    nonempty = np.array([np.max(frame) > 0 for frame in frames], dtype=bool)
    print("Skipped %d empty frames" % (len(frames) - np.count_nonzero(nonempty)))
    if not nonempty.any():
        return None
    return frames[nonempty].mean(axis=0)


def preprocess_exp(datafolder, extension, keys=None, sort=False, full_rate=False, sort_key=None):
    """Load the experiment table of one (sub-)scan and attach its camera images.

    Parameters
    ----------
    datafolder : str
        Scan folder; ``extension`` is appended verbatim (string concatenation).
    extension : str
        Sub-scan suffix, e.g. ``""``, ``"_BG"``, ``"_OL"`` or ``"_OF"``.
    keys : dict, optional
        Extra keys forwarded to ``pf.get_exp_dataframe``.
    sort : bool
        Sort the table by ``sort_key`` before loading the images.
    full_rate : bool
        Average the frames of each file, ignoring frames whose maximum is not
        positive. Otherwise, store the array returned by ``pf.loadh5`` as-is.
    sort_key : str, optional
        Column used when ``sort`` is True. Defaults to the notebook-global
        ``scan_axis``, as before.

    Returns
    -------
    pandas.DataFrame
        The table with the summary columns ``<k>_sum`` (k in xgm_UH, xgm_SH,
        diode_sum), ``diode_sum_mean``/``_std``, ``IR_mean``/``_std``,
        ``magnet_mean`` (rounded to 3 decimals) and ``bunchid`` (last bunch
        number), plus an ``images`` column. Rows whose file could not be
        loaded are dropped with a message, so images stay aligned with their
        metadata.
    """
    # Loading experiment data
    print("Loading: %s" % (datafolder + extension))
    exp = pf.get_exp_dataframe(datafolder + extension, keys=keys)
    for k in ["xgm_UH", "xgm_SH", "diode_sum"]:
        exp[k + "_sum"] = exp[k].apply(np.sum)

    exp["diode_sum_mean"] = exp.diode_sum.apply(np.mean)
    exp["diode_sum_std"] = exp.diode_sum.apply(np.std)
    exp["IR_mean"] = exp.IR.apply(np.mean)
    exp["IR_std"] = exp.IR.apply(np.std)
    exp["magnet_mean"] = exp.magnet.apply(np.mean).apply(np.round, args=(3,))
    exp["bunchid"] = exp.bunches.apply(lambda bunches: bunches[-1])

    if sort:
        exp = exp.sort_values(scan_axis if sort_key is None else sort_key)

    # Iterate by position. After sorting, the index labels no longer match the
    # row positions, and the image list below is assigned positionally.
    load_images = []
    loaded_rows = []
    for pos, fname in enumerate(exp["filename"]):
        try:
            image = _load_scan_image(fname, full_rate)
        except Exception as err:  # best effort: report the broken file and continue
            print("Skipped %s (%s)" % (fname, err))
            continue
        if image is None:
            print("Skipped %s (no non-empty frames)" % fname)
            continue
        load_images.append(image)
        loaded_rows.append(pos)
        print("Loaded %s" % fname)

    # Keep only rows with an image so metadata and images stay aligned
    exp = exp.iloc[loaded_rows].copy()
    exp["images"] = load_images

    return exp

In [None]:
def filter_false_images(images, filter_thres, plot=True):
    """Flag images whose mean intensity is consistent with the rest of the stack.

    Each image is reduced to its mean over the last two axes. An image is
    kept when that value lies strictly within ``filter_thres`` standard
    deviations (std of the per-image means) of their median. This is used to
    reject frames where, for example, the shutter did not work correctly.

    Parameters
    ----------
    images : array_like, shape (N, ..., Y, X)
        Image stack; the last two axes are averaged.
    filter_thres : float
        Half-width of the acceptance band in units of the standard deviation.
    plot : bool, default True
        Draw the per-image means together with the acceptance band.

    Returns
    -------
    numpy.ndarray of bool
        One entry per image; True means the image is valid.
    """
    images = np.asarray(images)
    # Calc monitoring parameter
    image_mean = np.mean(images, axis=(-2, -1))
    ensemble_mean = np.median(image_mean)  # median: robust against the outliers we look for
    image_std = np.std(image_mean)

    if image_std == 0:
        # All means are identical: nothing to reject. A strict "<" against a
        # zero-width band would otherwise discard every image.
        valid = np.ones(image_mean.shape, dtype=bool)
    else:
        valid = np.abs(image_mean - ensemble_mean) < filter_thres * image_std

    if plot:
        upper = ensemble_mean + filter_thres * image_std
        lower = ensemble_mean - filter_thres * image_std
        fig, ax = plt.subplots()
        ax.plot(image_mean, "o-")
        ax.grid()
        ax.set_xlabel("Image Index")
        ax.set_ylabel("Image Mean")
        ax.set_title("Check for inconsistencies of the averaged intensity")
        # axhline's xmin/xmax are axes fractions (0..1); the defaults span the full width
        ax.axhline(ensemble_mean, color="g", linestyle="--")
        ax.axhline(upper, color="r", linestyle="--")
        ax.axhline(lower, color="r", linestyle="--")

    return valid

# Experimental details

In [None]:
# Base paths and identifiers of the beamtime
BASEFOLDER = "/data/beamtimes/FERMI/2310_XPCS"  # root folder of the raw data
PROPOSAL = "20224053"  # proposal ID (not used further in this notebook)
USER = getuser()  # login name, appended to output file names

In [None]:
# Most basic experimental parameters
experimental_setup = dict(
    px_size=11e-6,  # pixel size of the camera
    binning=1,  # camera binning factor
)

# Detector for the azimuthal integrator: square pixels of size binning * px_size
detector = Detector(
    experimental_setup["binning"] * experimental_setup["px_size"],  # first pixel dimension
    experimental_setup["binning"] * experimental_setup["px_size"],  # second pixel dimension
)

# General output folder for all results
folder_target = pf.create_folder(
    join("/data/export/cklose/2310_FERMI_Skyrmion", "Results")
)
print("Output Folder: %s" % folder_target)

# Load Data

## Define Scan ids for loading

In [None]:
# Define for loading
sample = "Sample54"
membrane = "C6"
scan_id = 188
scan = f"{membrane}_Slu_{scan_id:03d}"  # e.g. "C6_Slu_188"
scan_axis = "delay_ps"  # quantity along which the scan is analyzed
full_rate = True  # NOTE(review): currently not passed to preprocess_exp (argument commented out there)

# Folder for loading
samplefolder = join(sample, scan)
datafolder = join(BASEFOLDER, samplefolder)
# Extra columns for pf.get_exp_dataframe: column name -> key in the data files
# (presumably HDF5 paths; "time" uses an empty key)
extra_keys = {
    "diode_sum": "PAM/FQPDSum",
    "IR": "Laser/Energy1",
    "magnet": "DPI/CoilWaveform",
    "bunches": "bunches",
    "time": "",
    "samplex": "DPI/SampleX",
    "sampley": "DPI/SampleY",
    "ccdz": "DPI/CcdZ",
    "global_delay": "Laser/DelayTotem",
}

# Create savefolder
fsave = helper.create_folder(join(folder_target, sample, membrane))

## Pumped images

In [None]:
# Loading experiment data
extension = ""
exp = preprocess_exp(datafolder, extension, keys=extra_keys)  # , full_rate=full_rate)

# Add wavelength and distance from the first row of the (still unsorted) table.
# Positional access (.iloc) also works if rows were dropped or the index does
# not start at label 0.
experimental_setup["lambda"] = exp["wavelength"].iloc[0] * 1e-9  # presumably nm -> m
experimental_setup["ccd_dist"] = (exp["ccdz"].iloc[0] + 50) * 1e-3  # presumably mm -> m, +50 offset

# Delay scan: convert the delay-stage value into a pump-probe delay
# NOTE(review): 6.67 presumably converts a stage position in mm to ps (2/c ~ 6.67 ps/mm) - confirm
t0 = 7.25  # time-zero stage value
exp["delay_ps"] = 6.67 * (exp["delay"] - t0)  # + (640 - exp["global_delay"])
exp = exp.sort_values("delay_ps")

images_pump = np.stack(exp["images"])
print("Data loaded!")

In [None]:
# What did you scan? Scan-axis value versus acquisition index
fig, ax = plt.subplots()
ax.plot(range(len(exp)), exp[scan_axis], "-o")
ax.set_ylabel(scan_axis)
ax.set_xlabel("Index")
ax.grid()

In [None]:
# Show the stack of pumped images
fig, ax = cimshow(images_pump)
fig.set_size_inches((6, 6))
ax.set_title("Pumped Images")

In [None]:
# Flag frames with an inconsistent mean intensity (e.g. shutter not working correctly)
valid = filter_false_images(images=images_pump, filter_thres=3)

In [None]:
# Keep only the valid rows and rebuild the image stack from them
exp = exp.loc[valid]
images_pump = np.stack(exp["images"].tolist())

In [None]:
# Show the pumped images that passed the filter
fig, ax = cimshow(images_pump)
fig.set_size_inches((6, 6))
ax.set_title("Pumped Images after filtering")

## Dark images

In [None]:
# Load the dark (background) images, ordered by acquisition time
extension = "_BG"
exp_bg = preprocess_exp(datafolder, extension, keys=extra_keys)  # , full_rate=full_rate)
exp_bg = exp_bg.sort_values(by="time")

dark = np.stack(exp_bg["images"].tolist())
print("Data loaded!")

In [None]:
# Show the stack of dark images
fig, ax = cimshow(dark)
fig.set_size_inches((6, 6))
ax.set_title("Dark Images")

In [None]:
# Flag dark frames with an inconsistent mean intensity (e.g. shutter not working correctly)
valid = filter_false_images(images=dark, filter_thres=2.5)

In [None]:
# Keep only the valid dark rows and rebuild the stack
exp_bg = exp_bg.loc[valid]
dark = np.stack(exp_bg["images"].tolist())

## Laser only

In [None]:
# Load the laser-only images, ordered by acquisition time
extension = "_OL"
exp_ol = preprocess_exp(datafolder, extension, keys=extra_keys)  # , full_rate=full_rate)
exp_ol = exp_ol.sort_values(by="time")

dark_ol = np.stack(exp_ol["images"].tolist())
print("Data loaded!")

In [None]:
# Show the stack of laser-only images
fig, ax = cimshow(dark_ol)
fig.set_size_inches((6, 6))
ax.set_title("Only Laser Images")

In [None]:
# Flag laser-only frames with an inconsistent mean intensity (e.g. shutter issues)
valid = filter_false_images(images=dark_ol, filter_thres=2)

In [None]:
# Keep only the valid laser-only rows and rebuild the stack
exp_ol = exp_ol.loc[valid]
dark_ol = np.stack(exp_ol["images"].tolist())

## FEL only

In [None]:
# Load the FEL-only images, ordered by acquisition time
extension = "_OF"
exp_of = preprocess_exp(datafolder, extension, keys=extra_keys)  # , full_rate=full_rate)
exp_of = exp_of.sort_values(by="time")

dark_of = np.stack(exp_of["images"].tolist())
print("Data loaded!")

In [None]:
# Show the stack of FEL-only images
fig, ax = cimshow(dark_of)
fig.set_size_inches((6, 6))
ax.set_title("Only FEL Images")

In [None]:
# Flag FEL-only frames with an inconsistent mean intensity; a 0.5-sigma band is a tight cut
valid = filter_false_images(images=dark_of, filter_thres=0.5)

In [None]:
# Keep only the valid FEL-only rows and rebuild the stack
exp_of = exp_of.loc[valid]
dark_of = np.stack(exp_of["images"].tolist())

# Preprocessing

## Normalize images

In [None]:
# Which key to use for normalization?
norm_key = "diode_sum_sum"
filter_key = "diode_sum_sum"  # NOTE(review): not used anywhere below

# Optional correction for shutter issues (not all images are pumped with the
# laser, but all are considered for averaging). When enabled, the laser-only
# image is scaled to the pumped image inside `shutter_roi` before subtraction.
# The analysis so far fixed the factor to 1, so it is disabled by default.
use_shutter_correction = False
shutter_roi = np.s_[1950:2025, 750:1100]

# Loop over images
images = []
for index, r in tqdm(exp.iterrows(), total=len(exp)):
    # Closest (in time) dark, laser-only and FEL-only reference rows
    row_bg = exp_bg.iloc[np.argmin(abs(r.time - exp_bg.time))]
    row_ol = exp_ol.iloc[np.argmin(abs(r.time - exp_ol.time))]
    row_of = exp_of.iloc[np.argmin(abs(r.time - exp_of.time))]

    # Alternative (disabled): subtract the laser-only image first, then the normalized FEL-only image
    # im = (r.images - im_ol) / r[norm_key]
    # of_norm = (im_of - im_bg) / exp_of.iloc[idx][norm_key]
    # im = im - of_norm

    # Subtract the dark image from all images first
    im_bg = row_bg["images"]
    im_ol = row_ol["images"] - im_bg
    im_of = row_of["images"] - im_bg
    im = r.images - im_bg

    if use_shutter_correction:
        shutter_correction = np.mean(im[shutter_roi]) / np.mean(im_ol[shutter_roi])
    else:
        shutter_correction = 1

    # Remove the laser-only signal, normalize by the pumped image's monitor value and
    # subtract the FEL-only image normalized by its own monitor value
    im = (im - shutter_correction * im_ol) / r[norm_key] - im_of / row_of[norm_key]

    images.append(im)

# Plot intensity distribution
fig, ax = plt.subplots()
ax.scatter(exp[scan_axis].values, np.mean(images, axis=(-2, -1)))
ax.grid()
ax.set_xlabel(str(scan_axis))
ax.set_ylabel("Mean Intensity")

# Setup xarray
data = xr.Dataset()
data["images"] = xr.DataArray(images, dims=["index", "y", "x"])
data[scan_axis] = xr.DataArray(exp[scan_axis], dims=["index"])
data[norm_key] = xr.DataArray(exp[norm_key], dims=["index"])

data

In [None]:
# Check filter for images with too high or low intensity-monitor values
exp_mean = data[norm_key].mean()
exp_std = data[norm_key].std()
filter_thres = 1.5

# Plot filter condition: values outside the red band are removed by the filter cell
fig, ax = plt.subplots()
ax.plot(data[norm_key], "o")
ax.set_xlabel("Image Index")
ax.set_ylabel(norm_key)  # the plotted quantity is the normalization value, not an image mean
ax.axhline(exp_mean, color="g", linestyle="--")
ax.axhline(exp_mean + filter_thres * exp_std, color="r", linestyle="--")
ax.axhline(exp_mean - filter_thres * exp_std, color="r", linestyle="--")
ax.grid()

In [None]:
# Drop entries whose normalization value lies outside mean +/- filter_thres * std
data = data.where(abs(data[norm_key] - exp_mean) < filter_thres * exp_std, drop=True)
images = data.images.values
im_mean = data.images.mean(dim="index").values
data

In [None]:
# Plot the normalized and filtered images
fig, ax = cimshow(images)
fig.set_size_inches(6, 6)
ax.set_title("Normalized & Filtered Images")

## Draw beamstop mask

In [None]:
# Draw a polygon mask interactively on the mean of the last 10 images
poly_mask = interactive.draw_polygon_mask(np.mean(images[-10:],axis=0))

In [None]:
# Take poly coordinates and mask from widget
# The printed vertices can be copied into load_poly_coordinates() to reuse the mask
p_coord = poly_mask.get_vertice_coordinates()
mask = poly_mask.full_mask.astype(int)  # cast the widget's mask to integers

cimshow(mask)

print("Mask coordinates: %s" % p_coord)

In [None]:
def load_poly_coordinates():
    """Return the polygon vertex coordinates of all hand-drawn masks.

    The result maps a mask name to a list of polygons; each polygon is a list
    of vertex tuples as printed by the mask-drawing cell. To add a mask named
    "test", paste the printed coordinates as a new entry:
    ``"test": [[(a0, b0), (a1, b1), ...]]``.
    """
    return {
        "bs_cross": [
            [(1500.1, -18.0), (1508.4, 552.7), (1508.4, 854.6), (1498.7, 972.9), (1455.0, 982.0), (737.3, 997.2), (-9.0, 1002.7), (-9.0, 1156.4), (309.3, 1150.9), (956.8, 1145.4), (1505.6, 1134.4), (1508.4, 1219.5), (1507.5, 1577.5), (1507.1, 2059.1), (1659.3, 2057.6), (1664.2, 1693.6), (1661.2, 1138.7), (2055.3, 1133.1), (2056.6, 987.1), (1664.9, 988.1), (1656.9, 588.6), (1658.4, 145.5), (1658.4, -32.2)],
        ],
        "membranes": [
            [(1266.7, 969.9), (1267.7, 1015.0), (1337.4, 1016.0), (1337.4, 963.0)],
            [(1365.8, 976.6), (1367.8, 1019.8), (1425.6, 1019.8), (1430.6, 968.8)],
            [(1452.1, 974.7), (1454.1, 1030.6), (1531.6, 1032.6), (1536.5, 971.7)],
            [(1608.1, 971.7), (1609.1, 1027.7), (1705.2, 1025.7), (1707.2, 969.8)],
            [(1713.1, 974.7), (1713.1, 1033.5), (1793.5, 1030.6), (1792.5, 969.8)],
            [(1805.3, 969.8), (1805.3, 1016.9), (1877.9, 1018.8), (1879.8, 965.9)],
            [(1902.4, 973.7), (1904.4, 1019.8), (1969.1, 1014.9), (1961.2, 968.8)],
            [(1264.1, 1146.0), (1267.5, 1195.2), (1335.6, 1193.7), (1332.5, 1144.5)],
        ],
        "membranes_2": [
            [(1823.2, 993.6), (1822.4, 1040.9), (1881.0, 1043.4), (1888.3, 1022.7), (1881.4, 989.6)],
            [(1909.8, 998.3), (1908.3, 1035.4), (1935.6, 1047.8), (1966.1, 1044.1), (1980.7, 1030.7), (1981.4, 1006.7), (1972.3, 992.1), (1926.8, 989.6)],
        ],
    }

In [None]:
# Which drawn masks should be combined? Names must be keys of load_poly_coordinates();
# several can be listed, e.g. ["bs_cross", "membranes"]
polygon_names = ["bs_cross", "membranes"]
mask = mask_lib.load_poly_masks(images[0].shape, load_poly_coordinates(), polygon_names)

# Left: mean image multiplied by (1 - mask); right: mean image multiplied by mask
mi, ma = np.percentile(im_mean, [1, 99])
fig, ax = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
ax[0].imshow(im_mean * (1 - mask), vmin=mi, vmax=ma)
ax[0].set_title("(1-mask)")
ax[1].imshow(im_mean * mask, vmin=mi, vmax=ma)
ax[1].set_title("mask")

In [None]:
# Use widget to shift and expand or shrink the mask
# NOTE(review): the widget is shown on im_of, the dark-subtracted FEL-only image
# left over from the last iteration of the normalization loop; confirm this is
# the intended reference image (im_mean may have been meant)
ss_mask = interactive.Shift_Scale_Mask(im_of, mask, shift=[0, 0], scale=1)

In [None]:
# Take mask, shift and scaling from widget; `mask` is replaced by the adjusted mask
mask, mask_shift, mask_scale = ss_mask.get_mask()

In [None]:
# Add the "membranes_2" polygons to the current mask; where polygons overlap, clip back to 1
polygon_names = ["membranes_2"]
mask = mask + mask_lib.load_poly_masks(images[0].shape, load_poly_coordinates(), polygon_names)
np.putmask(mask, mask > 1, 1)

In [None]:
# Use widget to shift and expand or shrink the mask
# NOTE(review): scale=0 here versus scale=1 above, and no later cell calls
# ss_mask.get_mask(), so changes made in this widget do not reach `mask`; confirm intent
ss_mask = interactive.Shift_Scale_Mask(im_of, mask, shift=[0, 0], scale=0)

## Find center

### Basic widget to find center

Try to **align** the circles to the **center of the scattering pattern**. Caution: the position of the beamstop can be misleading and may not coincide with the actual center of the hologram.

In [None]:
# Set center position via widget
# Shown image: the first image with pixels where mask == 1 set to zero.
# c0/c1 are the initial center guess; rBS presumably the beamstop radius in pixels - confirm
ic = interactive.InteractiveCenter(np.mean(images[0:1]*(1-mask), axis=0),c0=1087,c1=1583,rBS=100)

In [None]:
# Read back the center chosen in the widget: [vertical (y), horizontal (x)] pixel position
center = [ic.c0, ic.c1]
print("Center:", center)

### Azimuthal integrator widget for finetuning

In [None]:
# Azimuthal integrator for the virtual geometry. The PONI is the chosen center,
# converted from pixels to detector units (pixel index * px_size * binning).
ai = interactive.AzimuthalIntegrator(
    detector=detector,
    dist=experimental_setup["ccd_dist"],
    wavelength=experimental_setup["lambda"],
    poni1=center[0] * experimental_setup["px_size"] * experimental_setup["binning"],  # y (vertical)
    poni2=center[1] * experimental_setup["px_size"] * experimental_setup["binning"],  # x (horizontal)
)

In [None]:
# Quick 2d azimuthal integration of the last image to find the relevant q range
I_t, q_t, phi_t = ai.integrate2d(
    images[-1],
    200,
    radial_range=(0, 0.1),
    unit="q_nm^-1",
    correctSolidAngle=False,
    dummy=np.nan,
    mask=mask,
    method="bbox",
)
az2d = xr.DataArray(I_t, dims=("phi", "q"), coords={"q": q_t, "phi": phi_t})

# Plot with robust color limits
mi, ma = np.nanpercentile(I_t, [1, 95])
fig, ax = plt.subplots()
az2d.plot.imshow(ax=ax, vmin=mi, vmax=ma)
plt.title("Azimuthal integration")

# Optional: mark candidate q values with vertical lines
# q_lines = [0.025, 0.05]
# for qt in q_lines:
#    ax.axvline(qt, ymin=0, ymax=180, c="red")

In [None]:
# Fine-tune the center with the azimuthal-integration widget, shown on the mean
# of the last 10 images and starting from the center chosen in the previous step
aic = interactive.AzimuthalIntegrationCenter(
    # np.log10(images[36] - np.min(images[36]) + 1),
    np.mean(images[-10:], axis=0),
    ai,
    c0=center[0],
    c1=center[1],
    mask=mask,
    im_data_range=[1, 95],  # presumably percentile range of the color scale - confirm
    radial_range=(0.003, 0.07),
    qlines=[40, 60],
)

In [None]:
# Read back the refined center: [vertical (y), horizontal (x)] pixel position
center = [aic.c0, aic.c1]
print("Center:", center)

# Azimuthal Integration

In [None]:
# Final azimuthal integrator for the virtual geometry, using the refined center.
# The PONI is the center in pixels converted to detector units (pixel * px_size * binning).
ai = interactive.AzimuthalIntegrator(
    detector=detector,
    dist=experimental_setup["ccd_dist"],
    wavelength=experimental_setup["lambda"],
    poni1=center[0] * experimental_setup["px_size"] * experimental_setup["binning"],  # y (vertical)
    poni2=center[1] * experimental_setup["px_size"] * experimental_setup["binning"],  # x (horizontal)
)

In [None]:
# 2d azimuthal integration (500 radial x 90 azimuthal bins) of every image
integrated = [
    ai.integrate2d(frame, 500, 90, dummy=np.nan, mask=mask, method="bbox")
    for frame in tqdm(data["images"].values)
]
list_i2d = [result[0] for result in integrated]
# The radial (q) and azimuthal (chi) axes are the same for every image
q, chi = integrated[-1][1], integrated[-1][2]

# Store axes, integrated images and the center used in the dataset
data["q"] = q
data["chi"] = chi
data["i2d"] = xr.DataArray(list_i2d, dims=["index", "chi", "q"])
data = data.assign_attrs({"center": center})

## Select relevant chi-range

In [None]:
# Plot 2d and 1d azimuthal integration to estimate the relevant chi and q range
# which image to show?
idx = -1

# Select chi-range
# Which chi-mode? ("all", "hetero", "homo", "other")
chi_mode = "all"

if chi_mode == "all":
    sel_chi = (data.chi >= -180) & (data.chi <= 180)
elif chi_mode == "hetero":
    sel_chi = (data.chi > 95) & (data.chi < 180)
elif chi_mode == "homo":
    sel_chi = ((data.chi >= -180) & (data.chi <= -95)) | ((data.chi >= 5) & (data.chi <= 90))
elif chi_mode == "other":
    sel_chi = (data.chi >= 40) & (data.chi <= 150)
else:
    # Fail loudly instead of silently keeping a stale data["i1d"]
    raise ValueError(
        f"Unknown chi_mode {chi_mode!r}; expected 'all', 'hetero', 'homo' or 'other'"
    )
# 1d profile: mean of the 2d integration over the selected chi range
data["i1d"] = data.i2d.where(sel_chi, drop=True).mean("chi")

# Plot
fig, ax = plt.subplots(
    2,
    1,
    figsize=(8, 8),
    sharex=True,
)
# Color limits from the image that is actually shown (not the earlier test integration)
mi, ma = np.nanpercentile(data["i2d"][idx], [0.1, 90])
data["i2d"][idx].plot.imshow(ax=ax[0], vmin=mi, vmax=ma)
ax[0].set_title("2d Azimuthal integration")
ax[0].grid()

# Plot 1d azimuthal integration to estimate the relevant q-range
ax[1].plot(data.q, data.i1d[idx])
ax[1].set_yscale("log")
ax[1].set_title("1d Azimuthal Integration")
ax[1].grid()
ax[1].set_ylabel("Integrated intensity")
ax[1].set_xlabel("q")

## Select relevant q-range

In [None]:
# Select relevant q-range for averaging
q0, q1 = 0.005, 0.1
binning = False
bins = []  # bin edges (or number of bins) along scan_axis; used only if binning is True

# SAXS signal: mean of the 1d profile over q0 < q < q1
sel = (data.q > q0) & (data.q < q1)
data["saxs"] = data.i1d.where(sel, drop=True).mean("q")

# Averaging of same scan axis values or binning
if binning:
    if not np.isscalar(bins) and len(bins) < 2:
        raise ValueError("binning=True requires at least two bin edges in `bins`")
    # Execute binning
    data_bin = data.groupby_bins(scan_axis, bins).mean()

    # Index by the binned scan values and drop the interval coordinate,
    # which cannot be saved to h5/netCDF
    bin_scan_axis = scan_axis + "_bins"
    data_bin = data_bin.swap_dims({bin_scan_axis: scan_axis})
    data_bin = data_bin.drop_vars(bin_scan_axis)
else:
    # Average repeated scan values; otherwise simply index by the scan axis
    _, count = np.unique(data[scan_axis].values, return_counts=True)
    if np.any(count > 1):
        data_bin = data.groupby(scan_axis).mean()
    else:
        data_bin = data.swap_dims({"index": scan_axis})

# Log-scaled profile for plotting (+1 offset; NaN wherever i1d < -1)
data_bin["i1dlog"] = np.log10(data_bin["i1d"] + 1)

# Add scan identifier
data_bin["scan"] = scan

# Add AI mask
data_bin["mask"] = xr.DataArray(mask, dims=["y", "x"])

# Plotting

In [None]:
# Plot I(q,t) and integrated intensity
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(6, 4.5), sharex=True)
# Robust, NaN-aware color limits (integration uses dummy=np.nan, so i1d can contain NaN)
vmin, vmax = np.nanpercentile(data_bin["i1d"], [1, 99])
data_bin["i1d"].plot.contourf(
    x="delay_ps",
    y="q",
    ax=ax[0],
    cmap="viridis",
    add_colorbar=False,
    vmin=vmin,
    vmax=vmax,
    levels=100,
    ylim=[q0, q1],
)

ax[1].plot(data_bin["delay_ps"], data_bin["saxs"], "o-")
ax[1].grid()
ax[0].set_xlabel("")
ax[1].set_ylabel("total scattered intensity")
ax[1].set_xlabel("delay (ps)")

# Mean pump energy and the first coil value of the first remaining row (in index order).
# Positional access: label 0 may have been removed by the frame filtering.
ir = exp["IR_mean"].mean()
mag = exp.magnet.sort_index().iloc[0][0]
fig.suptitle(f"{samplefolder}: IR: {ir:.2f} µJ, magnet: {mag:.2f} A")

fname = join(fsave, "SAXS_%s_qdelay_%s.png" % (scan, USER))
print("Saving: %s" % fname)
fig.savefig(fname)

# Gif

## Select roi for plotting

How to use:
1. Zoom into the image and adjust the field of view (FOV) until you are satisfied.
2. Run the next cell to save the current axes limits as the region of interest (ROI).

In [None]:
# Show the binned images with masked pixels set to zero; zoom in to choose the ROI
fig, ax = cimshow((1 - mask) * data_bin["images"].values)

In [None]:
# Read the current axes limits back as ROI: [y_start, y_stop, x_start, x_stop]
x1, x2 = ax.get_xlim()
y2, y1 = ax.get_ylim()  # image y-axis points down, so the limits come as (bottom, top)
roi = np.array([y1, y2, x1, x2]).astype(int)
roi_s = np.s_[roi[0]:roi[1], roi[2]:roi[3]]
print("Roi:", roi)

## Plotting

In [None]:
# Common color / y-axis limits for all frames (robust percentiles of the 2d integration)
allmin, allmax = np.nanpercentile(data_bin["i2d"].values, [10, 99])
print("Min: %d Max: %d" % (allmin, allmax))

# Create folder for gif single frames
folder_gif = helper.create_folder(join(fsave, "Scan_%s" % scan))

# Quantities that are identical for every frame (computed once, outside the loop)
scan_values = data_bin[scan_axis].values
scan_min, scan_max = data_bin[scan_axis].min(), data_bin[scan_axis].max()
ir = exp["IR_mean"].mean()
# First coil value of the first remaining row (in index order); positional access
# because label 0 may have been removed by the frame filtering
mag = exp.magnet.sort_index().iloc[0][0]
vmin, vmax = np.nanpercentile(data_bin["i1d"], [0.5, 99.5])

im_fnames = []
for i, scan_value in enumerate(tqdm(scan_values)):
    # Layout: image ROI, 1d profile, I(q, scan) map, SAXS intensity
    fig = plt.figure(figsize=(6, 10))
    gs1 = gridspec.GridSpec(
        4,
        1,
        figure=fig,
        left=0.2,
        bottom=0.05,
        right=0.975,
        top=1.1,
        wspace=0,
        hspace=0,
        height_ratios=[6, 1, 2, 1],
    )

    # Plot image roi (pixels where mask == 1 set to zero)
    ax0 = fig.add_subplot(gs1[0])
    ax0.set_title(
        f"{scan}: IR: {ir:.1f} µJ, magnet: {mag:.2f} A, delay: {scan_value:.0f} ps",
        fontsize=10,
    )
    tmp = data_bin["images"][i].values
    m = ax0.imshow(tmp[roi_s] * (1 - mask[roi_s]), vmin=allmin, vmax=allmax)
    plt.colorbar(m, ax=ax0, pad=0.045, location="bottom")

    # Plot 1d azimuthal integration
    ax1 = fig.add_subplot(gs1[1])
    ax1.plot(data_bin.q, data_bin.i1d[i])
    ax1.set_xlabel("q")
    ax1.set_ylabel("Mean Intensity")
    ax1.set_xlim([q0, q1])
    ax1.set_ylim([allmin, allmax])
    ax1.grid()

    # I(q, scan) map with the current scan value (red) and the q range (white, dashed)
    ax2 = fig.add_subplot(gs1[2])
    data_bin["i1d"].plot.contourf(
        x=scan_axis,
        y="q",
        ax=ax2,
        cmap="viridis",
        add_colorbar=False,
        vmin=vmin,
        vmax=vmax,
        levels=200,
        ylim=[q0, q1],
    )
    ax2.vlines(scan_value, q0, q1, "r")
    ax2.hlines(q0, scan_min, scan_max, "w", linestyles="dashed")
    ax2.hlines(q1, scan_min, scan_max, "w", linestyles="dashed")

    # Plot SAXS intensity with the current frame highlighted
    ax3 = fig.add_subplot(gs1[3])
    ax3.plot(scan_values, data_bin["saxs"].values)
    ax3.scatter(scan_value, data_bin["saxs"].values[i], 20, color="r")
    ax3.set_xlabel(scan_axis)
    ax3.set_ylabel("Mean intensity")
    ax3.grid()
    ax3.set_xlim(scan_min, scan_max)

    # Save the frame and release the figure
    fname = join(folder_gif, "SAXS_%s_%03d_%s.png" % (scan, i, USER))
    im_fnames.append(fname)
    fig.savefig(fname)
    plt.close(fig)

# Create gif for 1d AI
gif_path = join(fsave, f"SAXS_{scan}_{USER}.gif")
print("Saving gif:%s" % gif_path)
helper.create_gif(im_fnames, gif_path, fps=2)
print("Done!")

In [None]:
# Save everything except the (large) images as a netCDF log
data_bin2 = data_bin.drop_vars("images")

folder = join(fsave, "Logs")
helper.create_folder(folder)
fname = join(folder, f"Log_SAXS_Scan_{scan_id:03d}_{USER}.nc")

print("Saving:", fname)
data_bin2.to_netcdf(fname)