# Pump Probe Notebook
## For 1015861: This notebook can be used to look at the pump probe signal in the filtered and unfiltered data, but be wary about the filtered results.

This notebook takes pump-probe runs, either timetool-corrected (ttc) or not, and performs:
 - Filtering of the data
 - Re-bin in time / timetool-correct
 - Determine a normalization method
 - Plot percent difference signal
 - Plot $q$, and $\Delta t$ lineouts 

In [None]:
import xrayscatteringtools as xrst
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from matplotlib.colors import LogNorm
import warnings
# Turn on xrst's underscore cleanup.
# NOTE(review): presumably this clears underscore-prefixed scratch variables
# (e.g. `_data`, `_mask`) between cells, which would explain the leading
# underscores used throughout this notebook — confirm in the xrayscatteringtools docs.
xrst.enable_underscore_cleanup()

### Loading in the data

In [None]:
###############################################
runNumbers = [] # <- this must be a list of int.
folders = xrst.get_data_paths(runNumbers) # Defaults to the info in config.yaml. You can overwrite this with strings, character arrays, or lists of either.
_whichSide = 'unfiltered' # <- 'unfiltered' (azav_mask0) or 'filtered' (azav_mask1). Change this to look at the pump probe signal on each side separately.
###############################################
# Validate the user settings before the (slow) data load.
if np.size(runNumbers) == 0:
    raise ValueError('runNumbers is empty: add the run number(s) you want to load.')
if _whichSide not in ('unfiltered', 'filtered'):
    # Any other value would otherwise silently fall through to the filtered side below.
    raise ValueError(f"_whichSide must be 'unfiltered' or 'filtered', got {_whichSide!r}.")

# (1) keys_to_combine: some keys loaded for each shot & stored per shot 
# (2) keys_to_sum: some keys loaded per each run and added 
# (3) keys_to_check : check if some keys exist and have same values in all runs and load these keys 
_keys_to_combine = [
    # Azimuthal Averages
    'jungfrau4M/azav_mask0_azav',
    'jungfrau4M/azav_mask1_azav',
    # Photon energy
    'ebeam/photon_energy',
    # Timetool data
    'tt/FLTPOS', 'tt/AMPL', 'tt/FLTPOSFWHM',
     # Laser-Xray Timing. Use either time tool corrected (ttc) or not.
    # 'scan/lxt',
    'scan/lxt_ttc',
    # Upstream Diode
    'ipm_dg2/sum',
    # Gas detectors
    'gas_detector/f_11_ENRC', 'gas_detector/f_22_ENRC',
    # Light status
    'lightStatus/laser', 'lightStatus/xray',
]
_keys_to_sum = [
    'Sums/jungfrau4M_calib',
    'Sums/jungfrau4M_calib_xrayOn_thresADU1'
]
_keys_to_check = [
    'UserDataCfg/jungfrau4M/cmask', # Combined mask
    'UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_q', # q bin centers
    'UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_qbin', # q bin size
    'UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_qbins', # q bin edges
    'UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_userMask', # Mask for unfiltered side
    'UserDataCfg/jungfrau4M/azav_mask1__azav_mask1_userMask', # Mask for filtered side
    # These keys are typically not needed, but feel free to uncomment them.
    # 'UserDataCfg/jungfrau4M/azav__azav_idxq',
    # 'UserDataCfg/jungfrau4M/azav__azav_idxphi',
    # 'UserDataCfg/jungfrau4M/azav__azav_nphi',
    # 'UserDataCfg/jungfrau4M/azav__azav_matrix_q',
    # 'UserDataCfg/jungfrau4M/azav__azav_matrix_phi',
]
##### Load the data in #####
_data = xrst.combineRuns(runNumbers, folders, _keys_to_combine, _keys_to_sum, _keys_to_check, verbose=False)  # this is the function to load the data with defined keys
############################
# String for nice things
runNumbersRange = xrst.compress_ranges(runNumbers)
runType = xrst.get_config_for_runs(runNumbers[0],'samples','sample') # Default to information in the first run you load.
niceTitle = f"{xrst.get_config('expNumber')} : {'Run' if np.size(runNumbers)==1 else 'Runs'} {runNumbersRange} : {runType}"

# Jungfrau4M Data either filtered or unfiltered
if _whichSide == 'unfiltered':
    azav = np.squeeze(_data['jungfrau4M/azav_mask0_azav'])
    _mask = _data['UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_userMask'].astype(bool) # Mask for detector created
else:  # 'filtered'
    azav = np.squeeze(_data['jungfrau4M/azav_mask1_azav'])
    _mask = _data['UserDataCfg/jungfrau4M/azav_mask1__azav_mask1_userMask'].astype(bool) # Mask for detector created

# Combining these two masks here into the cmask so I don't have to change any code later on
_cmask = _data['UserDataCfg/jungfrau4M/cmask'].astype(bool) # Mask for detector created 
cmask = _cmask & _mask

# Per-shot total of the azimuthal average (sum over every axis except the shot axis)
J4MSum = np.nansum(azav,axis=tuple(range(1, azav.ndim)))
# NOTE(review): the q axis is always read from azav_mask0, even for the filtered side;
# this assumes mask0 and mask1 share the same q binning — confirm.
q = _data['UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_q'] # q bin centers
qbin = _data['UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_qbin'] # q bin-size
qbins = _data['UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_qbins'] # q bins edges
jungfrau_sum = _data['Sums/jungfrau4M_calib_xrayOn_thresADU1']   # Total Jungfrau detector counts with Thresholds added, summed in a run

# Scan and timetool information
scan = _data['scan/lxt_ttc']
# scan = _data['scan/lxt']  # uncorrected delays; also enable 'scan/lxt' in _keys_to_combine
time_bins = (np.unique(scan[:]))*10**12  # list of time bins in picoseconds
ttpos = _data['tt/FLTPOS']
ttampl = _data['tt/AMPL']
ttfwhm = _data['tt/FLTPOSFWHM']

# Event codes
laserOn = _data['lightStatus/laser'].astype(bool)  # laser on events
xrayOn = _data['lightStatus/xray'].astype(bool)  # xray on events
run_indicator = _data['run_indicator'] # run indicator for each shot

# X-ray beam diagnostics
dg2 = _data['ipm_dg2/sum']   # upstream diode x-ray intensity
pulse_energy = _data['gas_detector/f_11_ENRC']   # xray energy from gas detector (not calibrated to actual values)
photon_energy = _data['ebeam/photon_energy']    # x-ray photon energy in eV
# Print total shots
_total_shots = len(run_indicator)
print("Total shots: ", _total_shots)

### Plot the 1 ADU thresholded masked Jungfrau4M Sum

In [None]:
###############################################
_show_j4m_sum = True  # Set to False to skip this plot
###############################################

if _show_j4m_sum:
    # Work on a copy so the loaded run sum is untouched, then hide masked pixels.
    _masked_sum = np.copy(jungfrau_sum)
    _masked_sum[np.logical_not(cmask.astype(bool))] = np.nan
    # Clip the colour scale at the 99.5th percentile so a few very bright pixels don't set it.
    _vmax = np.nanpercentile(_masked_sum, 99.5)

    plt.figure(figsize=(10, 8))
    pcm = xrst.plot_j4m(_masked_sum, vmin=0, vmax=_vmax)
    plt.colorbar(pcm)
    plt.title('1 ADU Jungfrau4M Sum')
    plt.suptitle(niceTitle)
    plt.show()

### Set thresholds on the detectors and diagnostics, then determine the good shots and the shots with good timetool data.

In [None]:
########## Different filter cutoffs##########
_J4M_cutoff = [0.4, 1]
_dg2_cutoff = [0.4, 1]
_pulse_energy_cutoff = [0.5, 1.5] # In mJ!!!
_tt_edgePos_cutoff = [0.1, 1]     
_tt_amp_cutoff = [0.3, 1]
_tt_width_cutoff = [0.1, 0.3]
_plot_display = 'log' # 'linear'
#############################################

# Every diagnostic except the pulse energy is expressed as a fraction of its own
# maximum, so those cutoffs are relative; the pulse-energy cutoffs are absolute.
_J4MSumNorm = J4MSum / np.nanmax(J4MSum)
_dg2Norm = dg2 / np.nanmax(dg2)
_ttposNorm = ttpos / np.nanmax(ttpos)
_ttamplNorm = ttampl / np.nanmax(ttampl)
_ttfwhmNorm = ttfwhm / np.nanmax(ttfwhm)

# One (values, cutoffs, title, x-label) entry per histogram panel, in subplot order.
_panels = [
    (_J4MSumNorm, _J4M_cutoff, 'J4M Sum Histogram', 'Fraction of Maximum'),
    (_dg2Norm, _dg2_cutoff, 'DG2 Sum Histogram', 'Fraction of Maximum'),
    (pulse_energy, _pulse_energy_cutoff, 'Pulse Energy Histogram', 'mJ'),
    (_ttposNorm, _tt_edgePos_cutoff, 'Timetool Edge Histogram', 'Fraction of Maximum'),
    (_ttamplNorm, _tt_amp_cutoff, 'Timetool Amplitude Histogram', 'Fraction of Maximum'),
    (_ttfwhmNorm, _tt_width_cutoff, 'Timetool FWHM Histogram', 'Fraction of Maximum'),
]

plt.figure(figsize=[17,10])
for _panel_number, (_values, _cut, _title, _xlabel) in enumerate(_panels, start=1):
    plt.subplot(2, 3, _panel_number)
    plt.hist(_values, bins=200, range=[0, np.nanmax(_values)])
    for _edge in (_cut[0], _cut[1]):
        plt.axvline(_edge, color='r', linestyle='--')
    plt.axvspan(_cut[0], _cut[1], color='r', alpha=0.2, label='"Good" Data')
    plt.yscale(_plot_display)
    plt.title(_title)
    plt.xlabel(_xlabel)
    plt.ylabel('Counts')
    plt.legend()
plt.suptitle(niceTitle)
plt.show()


def _in_window(_x, _cut):
    """Boolean mask of shots with _cut[0] <= _x <= _cut[1] (NaN compares False)."""
    return (_cut[0] <= _x) & (_x <= _cut[1])


# Good X-ray shots: detector sum, upstream diode and pulse energy all inside their windows.
goodIdx = (
    _in_window(_J4MSumNorm, _J4M_cutoff)
    & _in_window(_dg2Norm, _dg2_cutoff)
    & _in_window(pulse_energy, _pulse_energy_cutoff)
    & xrayOn
)
# Good timetool shots: edge position, amplitude and width in range, with both beams on.
goodIdx_timetool = (
    _in_window(_ttposNorm, _tt_edgePos_cutoff)
    & _in_window(_ttamplNorm, _tt_amp_cutoff)
    & _in_window(_ttfwhmNorm, _tt_width_cutoff)
    & xrayOn
    & laserOn
)

# Report how many shots each selection keeps: index 0 = rejected, index 1 = kept.
_counts = np.bincount(goodIdx.astype(int).ravel(), minlength=2)
_counts_timetool = np.bincount(goodIdx_timetool.astype(int).ravel(), minlength=2)
for _name, _cnt in (('goodIdx', _counts), ('goodIdx_timetool', _counts_timetool)):
    print(f'{_name} data represents {_cnt[1]/np.sum(_cnt)*100:.2f}% of the total shots. ({_cnt[1]} out of {np.sum(_cnt)}).')

### Extra filtering based on correlations between diagnostics (this is handy enough that it may be wrapped into a function in the future)

In [None]:
#####################################################################
# Define your bounds as dictionaries with slope 'm' and intercept 'b'
line_upper = {'m': 1.0, 'b': 0.15}
line_lower = {'m': 1.0, 'b': -0.15}
_bypass = False
#####################################################################

# Min-max scale both diagnostics over the shots that already passed goodIdx, so
# 0 is the smallest and 1 the largest good value and the lines above live in the unit square.
_dg2_subset = dg2[goodIdx]
_j4m_subset = J4MSum[goodIdx]

_dg2_min = np.nanmin(_dg2_subset)
_dg2_range = np.nanmax(_dg2_subset - _dg2_min)
_dg2_norm = (_dg2_subset - _dg2_min) / _dg2_range

_j4m_min = np.nanmin(_j4m_subset)
_j4m_range = np.nanmax(_j4m_subset - _j4m_min)
_j4m_norm = (_j4m_subset - _j4m_min) / _j4m_range

plt.figure(figsize=[12,5])

# 2D Histogram
plt.subplot(1,2,1)
plt.hist2d(_dg2_norm, _j4m_norm, bins=200, norm=LogNorm())
plt.colorbar()
plt.xlabel('DG2 (fraction of max)')
plt.ylabel('J4M (fraction of max)')
plt.title('J4M vs DG2 Histogram, Good Shots')

# Plot the arbitrary linear boundary lines
_x = np.linspace(0, 1, 500)
_y_lower_plot = line_lower['m'] * _x + line_lower['b']
_y_upper_plot = line_upper['m'] * _x + line_upper['b']

plt.plot(_x, _y_lower_plot, '--', color='r')
plt.plot(_x, _y_upper_plot, '--', color='r')

plt.xlim(0,1)
plt.ylim(0,1)

# 1D Histogram (Relative Position between lines)
plt.subplot(1,2,2)

# Expected J4M value on each boundary line at every shot's DG2 value
_y_bound_lower = line_lower['m'] * _dg2_norm + line_lower['b']
_y_bound_upper = line_upper['m'] * _dg2_norm + line_upper['b']

# Relative position between the lines: 0 on the lower line, 1 on the upper line.
# 1e-9 to prevent division by zero in case lines intersect
_relative_position = (_j4m_norm - _y_bound_lower) / (_y_bound_upper - _y_bound_lower + 1e-9)

# Plotting range is set slightly wider than [0, 1] to show the rejected data tails
plt.hist(_relative_position, bins=200, range=(-1.5, 2.5))
plt.axvline(0, color='r', linestyle='--')
plt.axvline(1, color='r', linestyle='--')
plt.axvspan(0, 1, color='r', alpha=0.2, label='"Good" Data')
plt.legend()
plt.xlabel('Relative Position (0=Lower Line, 1=Upper Line)')
plt.ylabel('Counts')
plt.title('Data Position Between Bounds')

plt.suptitle(niceTitle)
plt.tight_layout()
plt.show()

# Per-shot pass mask for the goodIdx subset (same length as _j4m_subset)
if _bypass:
    # Keep every shot; bool dtype so both branches produce the same kind of mask.
    _subset_pass = np.ones_like(_j4m_subset, dtype=bool)
    print('Bypassing this filter')
else:
    _subset_pass = np.logical_and(
        _j4m_norm >= _y_bound_lower,
        _j4m_norm <= _y_bound_upper
    )
    # Displaying how much data was kept due to this filtering
    _counts, _ = np.histogram(_subset_pass.astype(int),[0,1,2])
    print(f'This filter retained {_counts[1]/np.sum(_counts)*100:.2f}% of the already filtered shots. ({_counts[1]} out of {np.sum(_counts)}).')
# Scatter the subset result back onto all shots: True only for shots that
# passed goodIdx AND lie between the two boundary lines.
goodIdx_2 = np.zeros_like(goodIdx, dtype=bool)
goodIdx_2[goodIdx] = _subset_pass
# Displaying how much data was kept due to this filtering

### Timetool-correct the shots (optional) and bin them into custom time bins. This notebook can handle runs with different timetool calibrations.

In [None]:
#############################################################################################################################################
do_timetool_correction = True
# Bin EDGES in ps. Edit as desired; only has an effect if do_timetool_correction is True. Defaults to time_bins if list is empty.
_new_time_bin_edges = np.arange(-2.5,2.5,0.03)
# Use the data in config.yaml or set your own, these should be in seconds, and can be scalar, sequence or ndarray
_tt_slopes = xrst.get_config_for_runs(runNumbers,'tt_calibration','slope')
_tt_intercepts = xrst.get_config_for_runs(runNumbers,'tt_calibration','intercept')
# Set to False to skip the delay histograms
_plotting = True
#############################################################################################################################################

# Select bins, in ps
if do_timetool_correction and len(_new_time_bin_edges)!=0:
    time_bins_selected = _new_time_bin_edges
else:
    # Default to the time bins that were recorded
    time_bins_selected = time_bins

# Doing the timetool correction on a run by run basis
if do_timetool_correction:
    time_delay = xrst.calib.apply_timetool_correction(
        delays = scan,
        edge_positions = ttpos,
        slopes = _tt_slopes,
        intercepts = _tt_intercepts,
        run_indicator = run_indicator # Optional, if None, only one timetool calibration applied to all shots.
    )
    if _plotting:
        plt.figure(figsize=[12,5])
        plt.subplot(1,2,1)
        plt.hist(time_delay[goodIdx_timetool]*10**12,bins=200)
        plt.xlabel('Delay (ps)')
        plt.title('Timetool corrected delays, before binning')
        plt.ylabel('Counts')
        plt.subplot(1,2,2)
else:
    time_delay = scan

# Convert to ps and give every shot a ZERO-BASED time-bin index; -1 means "in no bin".
# np.digitize returns i with bins[i-1] <= t < bins[i] (0 below the first value,
# len(bins) at/after the last one), so shift by one to match the downstream
# `binned_shot_idxs == k` <-> row k of time_bins_selected.
_time_delay_ps = time_delay*10**12
binned_shot_idxs = np.digitize(_time_delay_ps, time_bins_selected) - 1
if do_timetool_correction:
    # time_bins_selected are treated as bin edges: bin k holds edges[k] <= t < edges[k+1],
    # so only 0..len-2 are real bins. Index len-1 (t >= last edge) is overflow, not a bin.
    binned_shot_idxs[binned_shot_idxs >= len(time_bins_selected) - 1] = -1
# Without the correction every delay equals one of the recorded time_bins exactly,
# so a shot at time_bins[k] gets index k (the last recorded delay included).
# NaN delays sort past the end in np.digitize; flag them as unbinned too.
binned_shot_idxs[~np.isfinite(_time_delay_ps)] = -1

if _plotting:
    plt.hist(time_delay[goodIdx_timetool]*10**12,bins=time_bins_selected)
    plt.xlabel('Delay (ps)')
    plt.title('Counts per selected time bin')
    plt.ylabel('Counts')
    # Inside the guard: otherwise an empty titled figure is shown when _plotting is False.
    plt.suptitle(niceTitle)
    plt.show()

### Determine Final Shot masks

In [None]:
# Final laser-on / laser-off shot selections used for the pump-probe signal.
if do_timetool_correction:
    # goodIdx_timetool already includes laserOn shots
    on_mask = goodIdx_2 & goodIdx_timetool
else:
    # Just the good shots where the laser is on
    on_mask = goodIdx_2 & laserOn
# Timetool isn't needed for the off shots
off_mask = goodIdx_2 & ~laserOn

# Displaying how much data was kept due to this filtering
_n_laser_on, _n_laser_off = np.sum(laserOn), np.sum(~laserOn)
_n_on_kept, _n_off_kept = np.sum(on_mask), np.sum(off_mask)
print(f'Combined filters retained {_n_on_kept/_n_laser_on*100:.2f}% of the laser on shots. ({_n_on_kept} out of {_n_laser_on}).')
print(f'Combined filters retained {_n_off_kept/_n_laser_off*100:.2f}% of the laser off shots. ({_n_off_kept} out of {_n_laser_off}).')

### Select normalization method

In [None]:
##########################################################################################
_norm_method = 'dg2' # Supports 'dg2', 'none', 'q_range'
_q_range = [2.5,4] # This is only used if you select 'q_range' as the normalization method
##########################################################################################
if _norm_method == 'none':
    azav_norm = azav
elif _norm_method == 'dg2':
    # Divide every shot's azimuthal average by its upstream diode reading;
    # shots with dg2 == 0 become inf/NaN (warnings silenced).
    with np.errstate(divide='ignore', invalid='ignore'):
        azav_norm = azav/dg2[:,np.newaxis]
elif _norm_method == 'q_range':
    # Determine valid q range idxs for normalizing the signal (open interval).
    _qidx_norm = (_q_range[0]<q) & (q<_q_range[1])
    if not np.any(_qidx_norm):
        raise ValueError(f'_q_range {_q_range} contains no q bins; widen it.')
    # Divide every shot by its own mean intensity inside the q window.
    with np.errstate(divide='ignore', invalid='ignore'):
        azav_norm = azav / np.nanmean(azav[:,_qidx_norm],axis=1)[:,np.newaxis]
else:
    # Without this, azav_norm would be undefined or silently keep a stale value
    # from an earlier execution of this cell.
    raise ValueError(f"Unknown _norm_method {_norm_method!r}; use 'dg2', 'none' or 'q_range'.")

### Calculate the pump-probe signal (2D)

In [None]:
################################################################################################################
# Set the rolling window size for the background (in number of total shots). Set really high for total avg 
_rolling_window = 10_000
# Set a gaussian blur here (note: with _sigma > 0, NaNs from empty bins spread into their neighbours)
_sigma = 0
# Color limits for the plot
_clims = [-1.5, 1.5]
# Adjust the timing manually (in ps; delays are plotted as time_bins_selected - t0)
t0 = -1.7 
################################################################################################################

# Laser-off reference in shot order: NaN everywhere except the valid off shots
azav_off_chronological = np.full_like(azav_norm, np.nan)
azav_off_chronological[off_mask] = azav_norm[off_mask]

# Centered rolling mean over shot index. NaN rows (non-off shots) are ignored by
# the mean, and min_periods=1 keeps windows near the ends defined.
azav_off_rolling = pd.DataFrame(azav_off_chronological).rolling(
    window=_rolling_window, 
    center=True, 
    min_periods=1
).mean().values

# Percent difference of every shot relative to its local laser-off reference
with np.errstate(divide='ignore', invalid='ignore'):
    dI_I_shot_by_shot = 100 * (azav_norm - azav_off_rolling) / azav_off_rolling

# Average the valid laser-on shots of each time bin: row k <-> binned_shot_idxs == k.
# Bins without shots end up as all-NaN rows.
dI_I_2d = np.full((len(time_bins_selected), azav_norm.shape[1]), np.nan)
with warnings.catch_warnings():
    # nanmean warns "Mean of empty slice" for empty bins; the resulting NaN is intended.
    warnings.simplefilter("ignore", category=RuntimeWarning)
    for _tbinIdx in range(len(time_bins_selected)):
        _tbin_mask = on_mask & (binned_shot_idxs == _tbinIdx)
        dI_I_2d[_tbinIdx] = np.nanmean(dI_I_shot_by_shot[_tbin_mask], axis=0)

# Cropping the outside nans for plotting (bounding box of all finite entries)
_keep_points = np.argwhere(~np.isnan(dI_I_2d))
if len(_keep_points) > 0:
    _top_left = _keep_points.min(axis=0)
    _bottom_right = _keep_points.max(axis=0)
    dI_I_2d_cropped = dI_I_2d[_top_left[0]:_bottom_right[0]+1, _top_left[1]:_bottom_right[1]+1]
    q_cropped = q[_top_left[1]:_bottom_right[1]+1]
    time_bins_selected_cropped = time_bins_selected[_top_left[0]:_bottom_right[0]+1]
else:
    # Fallback in case all data is NaN
    dI_I_2d_cropped, q_cropped, time_bins_selected_cropped = dI_I_2d, q, time_bins_selected

# Plot the resulting binned array
plt.figure()
plt.pcolormesh(q_cropped, time_bins_selected_cropped - t0, gaussian_filter(dI_I_2d_cropped, sigma=_sigma), cmap='seismic', shading='auto')
# Raw string: '\A' is an invalid escape sequence in a normal string literal.
plt.xlabel(r'q ($\AA^{-1}$)', fontsize=14)
plt.ylabel('Time delay (ps)', fontsize=14)
plt.title('ΔI/I', fontsize=14)
plt.colorbar()
plt.clim(_clims)
plt.suptitle(niceTitle)
plt.show()

### Plot selected q-lineouts

In [None]:
#########################################################
# Make a list of q-lineouts here like this:
_q_lineouts = [
    # [1,2],
    # [3,3.9],
    # [3.9,4.5]
]
# Set a gaussian blur here
_sigma = 0
# Color limits for the plots
_clims = [-1.5, 1.5]
# Time limits for the plots, in the t0-shifted delays shown on the axes
_tlims = [np.nanmin(time_bins_selected_cropped - t0), np.nanmax(time_bins_selected_cropped - t0)]
#########################################################
# Pull out the color cycle; wrap around when there are more lineouts than colors
_color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
_lineout_colors = [_color_cycle[_i % len(_color_cycle)] for _i in range(len(_q_lineouts))]
# Precompute so we don't have to do it twice
_2d_data = gaussian_filter(dI_I_2d_cropped, sigma=_sigma)
_time_axis = time_bins_selected_cropped - t0
_ymin, _ymax = _time_axis[[0,-1]]

plt.figure(figsize=[12,5])
plt.subplot(1,2,1)
plt.pcolormesh(q_cropped, _time_axis, _2d_data, cmap='seismic', shading='auto')
# np.concatenate([]) raises, so only draw the boundary lines when lineouts are requested
if len(_q_lineouts) > 0:
    plt.vlines(
        np.concatenate(_q_lineouts),
        _ymin, _ymax,
        linestyles='--',
        colors=np.repeat(_lineout_colors, [len(x) for x in _q_lineouts]),
        linewidth=2.5
    )
for (_q_min, _q_max), _c in zip(_q_lineouts, _lineout_colors):
    plt.axvspan(_q_min, _q_max, color=_c, alpha=0.3)
plt.xlabel(r'q ($\AA^{-1}$)', fontsize=14)
plt.ylabel('Time delay (ps)', fontsize=14)
plt.title('ΔI/I', fontsize=14)
plt.colorbar()
plt.clim(_clims)
plt.ylim(_tlims)

plt.subplot(1,2,2)
# Average each q window (inclusive bounds) to get one ΔI/I trace versus delay
for (_q_min, _q_max), _c in zip(_q_lineouts, _lineout_colors):
    _qmask = (q_cropped >= _q_min) & (q_cropped <= _q_max)
    _lineout = np.nanmean(_2d_data[:, _qmask],axis=1)
    plt.plot(_time_axis, _lineout, color=_c, linewidth=2.5,marker='o', label=f'{_q_min}-{_q_max} Å$^{{-1}}$')
plt.axhline(0, linestyle='--', linewidth=1)
plt.xlabel('Time delay (ps)', fontsize=14)
plt.ylabel('⟨ΔI/I⟩', fontsize=14)
plt.title('q-averaged lineouts', fontsize=14)
plt.ylim(_clims)
plt.xlim(_tlims)
if len(_q_lineouts) > 0:  # avoid the "no artists with labels" legend warning
    plt.legend()
plt.suptitle(niceTitle)
plt.show()

### Plot selected $\Delta t$ lineouts

In [None]:
#########################################################
# Make a list of t-lineouts here like this (in the t0-shifted delays shown on the axes):
_t_lineouts = [
    # [0, 0.3],
    # [1, 2],
    # [3, 4]
]
# Set a gaussian blur here
_sigma = 0
# Color limits for the plots
_clims = [-1.5, 1.5]
# q limits for the plots
_qlims = [np.nanmin(q_cropped), np.nanmax(q_cropped)]
#########################################################

# Pull out the color cycle; wrap around when there are more lineouts than colors
_color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
_lineout_colors = [_color_cycle[_i % len(_color_cycle)] for _i in range(len(_t_lineouts))]
# Precompute filtered data
_2d_data = gaussian_filter(dI_I_2d_cropped, sigma=_sigma)
_time_axis = time_bins_selected_cropped - t0
_xmin, _xmax = q_cropped[[0,-1]]

plt.figure(figsize=[12,5])
plt.subplot(1,2,1)
plt.pcolormesh(q_cropped, _time_axis, _2d_data, cmap='seismic', shading='auto')
# np.concatenate([]) raises, so only draw the boundary lines when lineouts are requested
if len(_t_lineouts) > 0:
    plt.hlines(
        np.concatenate(_t_lineouts),
        _xmin, _xmax,
        linestyles='--',
        colors=np.repeat(_lineout_colors, [len(x) for x in _t_lineouts]),
        linewidth=2.5
    )
for (_t_min, _t_max), _c in zip(_t_lineouts, _lineout_colors):
    plt.axhspan(_t_min, _t_max, color=_c, alpha=0.3)
plt.xlabel('q ($\\AA^{-1}$)', fontsize=14)
plt.ylabel('Time delay (ps)', fontsize=14)
plt.title('ΔI/I', fontsize=14)
plt.colorbar()
plt.clim(_clims)
plt.xlim(_qlims)

plt.subplot(1,2,2)
# Average each delay window (inclusive bounds) to get one ΔI/I trace versus q
for (_t_min, _t_max), _c in zip(_t_lineouts, _lineout_colors):
    _tmask = (_time_axis >= _t_min) & (_time_axis <= _t_max)
    _lineout = np.nanmean(_2d_data[_tmask, :], axis=0)
    plt.plot(q_cropped, _lineout,
             color=_c,
             linewidth=2.5,
             marker='o',
             label=f'{_t_min}-{_t_max} ps')
plt.axhline(0, linestyle='--', linewidth=1)
plt.xlabel('q ($\\AA^{-1}$)', fontsize=14)
plt.ylabel('⟨ΔI/I⟩', fontsize=14)
plt.title('time-averaged lineouts', fontsize=14)
plt.ylim(_clims)
plt.xlim(_qlims)
if len(_t_lineouts) > 0:  # avoid the "no artists with labels" legend warning
    plt.legend()
plt.suptitle(niceTitle)
plt.show()