### This notebook is a conglomeration of the pump probe notebook along with a translated version of the ghost imaging analysis (sans some steps) in python.

Note: It does not separate shots based on time binning. Either I make a new notebook for that, or we run this notebook multiple times to determine time-dependent data.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xrayscatteringtools as xrst
import warnings
from matplotlib.colors import LogNorm
from scipy.linalg import toeplitz
from scipy.optimize import nnls
import scipy.io as scio
# NOTE(review): presumably makes xrayscatteringtools delete underscore-prefixed names once a cell has
# finished — confirm in the xrst documentation. The cells below are written so that `_name` variables
# are never needed outside the cell that assigns them.
xrst.enable_underscore_cleanup()

In [None]:
###############################################
# Runs to load and where to find them. Both values are passed straight to xrst.combineRuns below.
runNumbers = [15,16,17,18,19,20,21,22,23,24,25,26,27,28] # <- this must be a list of int.
folders = '/sdf/data/lcls/ds/cxi/cxil1037623/results/davidjr/' # Hard-coded results folder. Alternative: xrst.get_data_paths(runNumbers), which defaults to the info in config.yaml. You can overwrite this with strings, character arrays, or lists of either.
###############################################
# Three groups of keys are requested from the loader:
# (1) keys_to_combine: keys loaded for each shot and stored per shot
# (2) keys_to_sum: keys loaded once per run and added together
# (3) keys_to_check: keys checked to exist with the same value in all runs, then loaded
_keys_to_combine = [
    # Azimuthal averages, one per user mask (mask0 / mask1)
    'jungfrau4M/azav_mask0_azav',
    'jungfrau4M/azav_mask1_azav',
    # Photon energy
    'ebeam/photon_energy',
    # Timetool data (currently not loaded)
    # 'tt/FLTPOS', 'tt/AMPL', 'tt/FLTPOSFWHM',
    # Laser-Xray timing. Use either the timetool-corrected (ttc) key or the plain one (currently not loaded).
    # 'scan/lxt',
    # 'scan/lxt_ttc',
    # Upstream diode
    'ipm_dg2/sum',
    # Gas detectors
    'gas_detector/f_11_ENRC', 'gas_detector/f_22_ENRC',
    # Light status
    'lightStatus/laser', 'lightStatus/xray',
    # Spectrometer lineouts
    'feeBld/hproj',
    'unixTime',
    'epicsUser/gasCell_pressure'
]
_keys_to_sum = [
    'Sums/jungfrau4M_calib',
    # 'Sums/jungfrau4M_calib_xrayOn_thresADU1'
]
_keys_to_check = [
    'UserDataCfg/jungfrau4M/cmask', # Combined mask
    'UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_q', # q bin centers
    'UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_qbin', # q bin size
    'UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_qbins', # q bin edges
    'UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_userMask', # Mask for unfiltered side
    'UserDataCfg/jungfrau4M/azav_mask1__azav_mask1_userMask', # Mask for filtered side
    # NOTE(review): the two comments above say mask0 = unfiltered and mask1 = filtered, but the cell that
    # unpacks the data assigns the mask0 average to `azav_filtered` and the mask1 average to
    # `azav_unfiltered`. One of the two is wrong — confirm against the detector geometry.
    # These keys are typically not needed, but feel free to uncomment them.
    # 'UserDataCfg/jungfrau4M/azav__azav_idxq',
    # 'UserDataCfg/jungfrau4M/azav__azav_idxphi',
    # 'UserDataCfg/jungfrau4M/azav__azav_nphi',
    # 'UserDataCfg/jungfrau4M/azav__azav_matrix_q',
    # 'UserDataCfg/jungfrau4M/azav__azav_matrix_phi',
]
##### Load the data in #####
# Loads every requested key for all runs; the result is indexed by key name below.
_data = xrst.combineRuns(runNumbers, folders, _keys_to_combine, _keys_to_sum, _keys_to_check, verbose=False, archImport=True)  # this is the function to load the data with defined keys
############################
# Strings used for figure titles
runNumbersRange = xrst.compress_ranges(runNumbers)
runType = xrst.get_config_for_runs(runNumbers[0],'samples','sample') # Default to information in the first run you load.
niceTitle = f"{xrst.get_config('expNumber')} : {'Run' if np.size(runNumbers)==1 else 'Runs'} {runNumbersRange} : {runType}"

# Combined detector mask
_cmask = _data['UserDataCfg/jungfrau4M/cmask'].astype(bool) # Mask for detector created

# NOTE(review): naming conflict. The key list labels mask0 as the *unfiltered* side and mask1 as the
# *filtered* side, and the original comments here said the same, yet the mask0 average is stored in
# `azav_filtered` and the mask1 average in `azav_unfiltered`. Later cells form I_f / I_uf from these
# names, so a swap would invert that ratio. Confirm which detector half sits behind the filter.
# mask0 azimuthal average and its user mask
azav_filtered = np.squeeze(_data['jungfrau4M/azav_mask0_azav']) # I(q) : 1D azimuthal average of signals in each q bin
_filtered_mask = _data['UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_userMask'].astype(bool)
# mask1 azimuthal average and its user mask
azav_unfiltered = np.squeeze(_data['jungfrau4M/azav_mask1_azav']) # I(q) : 1D azimuthal average of signals in each q bin
_unfiltered_mask = _data['UserDataCfg/jungfrau4M/azav_mask1__azav_mask1_userMask'].astype(bool)

# Combining cmask and the half masks for ease of use
cmask_unfiltered = _cmask & _unfiltered_mask
cmask_filtered = _cmask & _filtered_mask

# Per-shot totals: NaN-ignoring sum over every axis except the first (shot) axis
J4MSum_unfiltered = np.nansum(azav_unfiltered,axis=tuple(range(1, azav_unfiltered.ndim)))
J4MSum_filtered = np.nansum(azav_filtered,axis=tuple(range(1, azav_filtered.ndim)))
# The q axis is taken from the mask0 configuration only
q = _data['UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_q'] # q bin centers
qbin = _data['UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_qbin'] # q bin-size
qbins = _data['UserDataCfg/jungfrau4M/azav_mask0__azav_mask0_qbins'] # q bins edges
# jungfrau_sum = _data['Sums/jungfrau4M_calib_thresADU1']   # Total Jungfrau detector counts with Thresholds added, summed in a run

# Scan and timetool information (disabled: the corresponding keys are not loaded above)
# scan = _data['scan/lxt_ttc']
# # scan = data['scan/lxt']
# time_bins = (np.unique(scan[:]))*10**12  # list of time bins in picoseconds
# ttpos = _data['tt/FLTPOS']
# ttampl = _data['tt/AMPL']
# ttfwhm = _data['tt/FLTPOSFWHM']

# Gas pressure
gasPressure = _data['epicsUser/gasCell_pressure']

# Event codes. astype(bool) returns new arrays, so these do not alias the loader's data.
laserOn = _data['lightStatus/laser'].astype(bool)  # laser on events
xrayOn = _data['lightStatus/xray'].astype(bool)  # xray on events
run_indicator = _data['run_indicator'] # run indicator for each shot

# X-ray beam diagnostics
dg2 = _data['ipm_dg2/sum']   # upstream diode x-ray intensity
pulse_energy = _data['gas_detector/f_11_ENRC']   # x-ray pulse energy from gas detector (not calibrated to actual values)
photon_energy = _data['ebeam/photon_energy']    # x-ray photon energy in eV
spec = _data['feeBld/hproj'] # Shot to shot spectrometer
# Print total shots (one run_indicator entry per shot)
_total_shots = len(run_indicator)
print("Total shots: ", _total_shots)

## Determine if there is any shot offset between the spectrometer and the gas detectors

In [None]:
# Correlate the summed spectrometer signal with the gas-detector pulse energy for several trial
# shot offsets. Each panel pairs the two signals with a different relative shift, using the same
# slices as the filtering cell; the panel with the tightest correlation shows the offset to use.
_spec_sum = spec.sum(axis=1)  # per-shot spectrometer sum, computed once for all panels

# Common axis range so that the panels are directly comparable
_hist_range = (
    (np.nanmin(_spec_sum), np.nanmax(_spec_sum)),
    (np.nanmin(pulse_energy), np.nanmax(pulse_energy)),
)

# (spectrometer slice, pulse-energy slice, panel title) for each trial offset
_offset_panels = [
    (slice(0, -3), slice(2, -1), 'Spectrometer Offset: -2 Shots'),
    (slice(0, -2), slice(1, -1), 'Spectrometer Offset: -1 Shot'),
    (slice(None),  slice(None),  'Spectrometer Offset: 0 Shots'),
    (slice(1, -1), slice(0, -2), 'Spectrometer Offset: +1 Shot'),
]

plt.figure(figsize=[20,5])
for _panel, (_spec_sel, _pulse_sel, _title) in enumerate(_offset_panels, start=1):
    plt.subplot(1, 4, _panel)
    plt.hist2d(_spec_sum[_spec_sel], pulse_energy[_pulse_sel], bins=100, norm=LogNorm(), range=_hist_range)
    plt.xlabel('XRT Spectrometer Sum')
    # The y data is the gas-detector pulse energy (previously mislabelled as the J4M sum).
    plt.ylabel('Gas Detector Pulse Energy')
    plt.title(_title)
plt.show()

## Filtering the shots based on the detector readings

In [None]:
########## Different filter cutoffs
# Each entry is a [lower, upper] acceptance window. Apart from the pulse energy, the windows are
# compared against values normalised to the largest value among the sliced shots.
_J4M_cutoff = [0.1, 1]
_dg2_cutoff = [0.1, 1]
_spec_cutoff = [0.04, 0.19]
_pulse_energy_cutoff = [0.5, 1.5] # In mJ!!!
_tt_edgePos_cutoff = [0.1, 1]
_tt_amp_cutoff = [0.3, 1]
_tt_width_cutoff = [0.1, 0.3]
_plot_display = 'log' # 'linear'
##########

###### Set the offset Here ######
_offset = -1  # integer shot offset of the spectrometer, as read off the offset-scan panels
#################################
# Determine slice indices based on offset.
#   offset < 0: spectrometer shot i is paired with detector shot i + |offset|
#   offset > 0: spectrometer shot i + offset is paired with detector shot i
# For a non-zero offset the last shot is dropped as well, exactly as in the offset-scan panels.
# (Previously two branches assigned `azav_slice` without the underscore, so `_azav_slice` was
# undefined for offsets 0 and +1, and any offset other than +/-1 was silently treated as 0.)
if _offset < 0:
    _shift = -_offset
    _spec_slice = slice(0, -(_shift + 1))
    _azav_slice = slice(_shift, -1)
elif _offset > 0:
    _spec_slice = slice(_offset, -1)
    _azav_slice = slice(0, -(_offset + 1))
else:
    _spec_slice = _azav_slice = slice(None)

# Apply slices
xray_spectrum = spec[_spec_slice]
I_uf = azav_unfiltered[_azav_slice]
I_f = azav_filtered[_azav_slice]
J4MSumSliced = J4MSum_unfiltered[_azav_slice]
J4MSumFilterSliced = J4MSum_filtered[_azav_slice]
pulse_energySliced = pulse_energy[_azav_slice]
# `gasPressure` is re-bound to its sliced form because later cells index it under this name.
# Only slice while it still has the full per-shot length, otherwise re-running this cell would
# trim two more shots on every execution.
if len(gasPressure) == len(dg2):
    gasPressure = gasPressure[_azav_slice]
I_0 = dg2[_azav_slice]
xrayOnSliced = xrayOn[_azav_slice]
laserOnSliced = laserOn[_azav_slice]
# ttposSliced = ttpos[_azav_slice]
# ttamplSliced = ttampl[_azav_slice]
# ttfwhmSliced = ttfwhm[_azav_slice]

# Precomputing the normalized values (fraction of the largest value, ignoring NaNs)
_J4MSumSlicedNorm = J4MSumSliced/np.nanmax(J4MSumSliced)
_I_0Norm = I_0/np.nanmax(I_0)
_xray_spectrum_sum = xray_spectrum.sum(axis=1)
# nanmax (rather than .max()) so a single NaN shot cannot turn every normalised value into NaN
_xray_spectrumNorm = _xray_spectrum_sum/np.nanmax(_xray_spectrum_sum)
# _ttposSlicedNorm = ttposSliced/np.nanmax(ttposSliced)
# _ttamplSlicedNorm = ttamplSliced/np.nanmax(ttamplSliced)
# _ttfwhmSlicedNorm = ttfwhmSliced/np.nanmax(ttfwhmSliced)

## Plotting

def _plot_cutoff_hist(position, values, cutoff, title, xlabel, yscale):
    """Draw one histogram panel of the 2x4 grid with its acceptance window shaded.

    position : 1-based subplot index inside the 2x4 grid
    values   : per-shot values to histogram (200 bins from 0 to the NaN-ignoring maximum)
    cutoff   : [lower, upper] acceptance window, drawn as dashed lines plus a shaded span
    title    : panel title
    xlabel   : x-axis label
    yscale   : matplotlib y-scale name, e.g. 'log' or 'linear'
    """
    plt.subplot(2, 4, position)
    plt.hist(values, bins=200, range=[0, np.nanmax(values)])
    for edge in cutoff:
        plt.axvline(edge, color='r', linestyle='--')
    plt.axvspan(cutoff[0], cutoff[1], color='r', alpha=0.2, label='"Good" Data')
    plt.yscale(yscale)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Counts')
    plt.legend()


plt.figure(figsize=[17,10])

_plot_cutoff_hist(1, _J4MSumSlicedNorm, _J4M_cutoff, 'J4M Sum Histogram', 'Fraction of Maximum', _plot_display)
_plot_cutoff_hist(2, _I_0Norm, _dg2_cutoff, 'DG2 Sum Histogram', 'Fraction of Maximum', _plot_display)
_plot_cutoff_hist(3, pulse_energySliced, _pulse_energy_cutoff, 'Pulse Energy Histogram', 'mJ', _plot_display)
# The spectrometer sum is normalised to its maximum, so the axis is a fraction (it was labelled 'mJ').
_plot_cutoff_hist(4, _xray_spectrumNorm, _spec_cutoff, 'Spectrometer Sum Histogram', 'Fraction of Maximum', _plot_display)

# Uncomment with timetool data
# _plot_cutoff_hist(5, _ttposSlicedNorm, _tt_edgePos_cutoff, 'Timetool Edge Histogram', 'Fraction of Maximum', _plot_display)
# _plot_cutoff_hist(6, _ttamplSlicedNorm, _tt_amp_cutoff, 'Timetool Amplitude Histogram', 'Fraction of Maximum', _plot_display)
# _plot_cutoff_hist(7, _ttfwhmSlicedNorm, _tt_width_cutoff, 'Timetool FWHM Histogram', 'Fraction of Maximum', _plot_display)
plt.show()


def _within(values, window):
    """Boolean mask of the entries of `values` inside the closed interval [window[0], window[1]]."""
    return (window[0] <= values) & (values <= window[1])


# A shot is kept when the x-rays were on, every diagnostic lies inside its acceptance window,
# and its detector sum is not NaN.
goodIdx = (
    xrayOnSliced
    & _within(_J4MSumSlicedNorm, _J4M_cutoff)
    & _within(_I_0Norm, _dg2_cutoff)
    & _within(pulse_energySliced, _pulse_energy_cutoff)
    & _within(_xray_spectrumNorm, _spec_cutoff)
    & ~np.isnan(J4MSumSliced)
)

# Timetool-based selection (enable together with the timetool keys and the sliced timetool arrays)
# goodIdx_timetool = (
#     _within(_ttposSlicedNorm, _tt_edgePos_cutoff)
#     & _within(_ttamplSlicedNorm, _tt_amp_cutoff)
#     & _within(_ttfwhmSlicedNorm, _tt_width_cutoff)
#     & xrayOn & laserOn   # NOTE(review): the other terms are sliced; these probably need the *Sliced versions
# )

# Report how many shots survive the filtering
_n_good = goodIdx.sum()
_n_shots = goodIdx.size
print(f'goodIdx data represents {_n_good/_n_shots*100:.2f}% of the total shots. ({_n_good} out of {_n_shots}).')
# print(f'goodIdx_timetool data represents {goodIdx_timetool.sum()/goodIdx_timetool.size*100:.2f}% of the total shots. ({goodIdx_timetool.sum()} out of {goodIdx_timetool.size}).')

In [None]:
##########
# Energy axis of the spectrometer: one value per spectrometer channel.
# The linear calibration from config.yaml (slope / intercept) is currently disabled:
# _slope = xrst.get_config_for_runs(runNumbers[0],'spectrometer_calib','slope')
# _intercept = xrst.get_config_for_runs(runNumbers[0],'spectrometer_calib','intercept')
##########
# xray_spectrum_range = np.arange(len(specGood[0])) #* _slope + _intercept
# Temporary: the axis is read from a local file in the working directory. Later cells treat the
# result as eV (axis labels, the 9664 threshold), so the factor 1000 presumably converts keV to eV
# — TODO confirm the units stored in the file and replace this with the calibration above.
xray_spectrum_range = np.load('15-28.npy')*1000 # Temp

In [None]:
# ---------- User parameters ----------
_CROP_FACTOR = 0.6       # fraction of the spectrum to keep
_FRACTION_INSIDE = 0.9925  # minimum share of a shot's total signal that must fall inside the kept window
_PLOT_SPEC_FILTER = True
# --------------------------------------

# Shot filter 1: most of the spectral weight must sit inside the central window.
_n_channels = xray_spectrum.shape[1]
_cropin = round((1 - _CROP_FACTOR) / 2 * _n_channels)  # channels discarded on each side
# A shot whose spectrum sums to zero gives 0/0 = NaN here; NaN fails the comparison below and the
# shot is rejected, so the warning is only noise.
with np.errstate(divide="ignore", invalid="ignore"):
    _cropped_sum_norm = (
        xray_spectrum[:, _cropin : _n_channels - _cropin].sum(axis=1)
        / xray_spectrum.sum(axis=1)
    )
_mask_spec_crop = _cropped_sum_norm > _FRACTION_INSIDE

# Shot filter 2: reject noisy spectra, judged by the spread of the channel-to-channel differences.
_spec_diff_std = np.std(np.diff(xray_spectrum, axis=1), axis=1)
_noise_threshold = _spec_diff_std.mean() + _spec_diff_std.std()  # mean + 1 standard deviation
_mask_noise = _spec_diff_std <= _noise_threshold

if _PLOT_SPEC_FILTER:
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # Mean spectrum with the kept window marked
    ax = axes[0, 0]
    _lo_edge = xray_spectrum_range[_cropin]
    _hi_edge = xray_spectrum_range[-_cropin - 1]  # last channel inside the kept window
    ax.plot(xray_spectrum_range, xray_spectrum.mean(axis=0))
    ax.axvline(_lo_edge, color="r", ls="--", label="Crop limit")
    ax.axvline(_hi_edge, color="r", ls="--")
    ax.axvspan(_lo_edge, _hi_edge, alpha=0.1, color="r")
    ax.set(title="Mean X-ray Spectrum", xlabel="Energy (eV)", ylabel="Intensity (a.u.)")
    ax.legend()  # the label above was never displayed before

    # Distribution of the in-window fraction (non-finite values cannot be histogrammed)
    ax = axes[0, 1]
    ax.hist(_cropped_sum_norm[np.isfinite(_cropped_sum_norm)], bins=100)
    ax.set_yscale("log")
    ax.axvline(_FRACTION_INSIDE, color="r", ls="--", label="Threshold")
    ax.set(title="Cropped & normalized spectrum histogram",
           xlabel="Normalized intensity", ylabel="Counts")
    ax.legend()

    # Distribution of the noise metric
    ax = axes[1, 0]
    ax.hist(_spec_diff_std[np.isfinite(_spec_diff_std)], bins=100)
    ax.set_yscale("log")
    ax.axvline(_noise_threshold, color="r", ls="--")
    ax.set(title="Std of spectral difference", xlabel="Std", ylabel="Counts")

    # Noise metric shot by shot
    ax = axes[1, 1]
    ax.plot(_spec_diff_std)
    ax.axhline(_noise_threshold, color="r", ls="--")
    ax.set(title="Spectral diff std per shot", xlabel="Shot index", ylabel="Std")

    plt.tight_layout()
    plt.show()

# A shot passes only if it passes both spectrometer filters
mask_spectrum = _mask_noise & _mask_spec_crop
print(f"Fraction of good data after spectrometer filtering: "
      f"{mask_spectrum.sum() / mask_spectrum.size:.4f}")

## Timetool correction / rebinning if necessary

In [None]:
######## Temp
# Temporary stand-in for the real delay scan (the 'scan/lxt*' keys are not loaded): the filtered
# shot range is cut into four equal, consecutive segments that are assigned the fake delays
# -1, 0, 1 and 2 ps (stored in seconds). Replace with the recorded scan once it is available.
_shotIdx = np.arange(len(goodIdx))
# digitize against 5 evenly spaced edges yields segment numbers 1..4; subtracting 2 gives -1..2
scan = (np.digitize(_shotIdx,np.linspace(0,len(goodIdx),5))-2)*(10**-12)
plt.plot(scan)
# Unique delay values in ps. These are delay points, not bin edges.
time_bins = np.unique(scan)*(10**12)
# No timetool data loaded, so every shot passes the timetool selection
goodIdx_timetool = np.ones_like(scan).astype(bool)

In [None]:
#############################################################################################################################################
do_timetool_correction = False
# Edit this as desired, will only have an effect if do_timetool_correction is True. Defaults to time_bins if list is empty.
_new_time_bin_edges = []
# Use the data in config.yaml or set your own, these should be in seconds, and can be scalar, sequence or ndarray
# _tt_slopes = xrst.get_config_for_runs(runNumbers,'tt_calibration','slope')
# _tt_intercepts = xrst.get_config_for_runs(runNumbers,'tt_calibration','intercept')
# Bool if you want
_plotting = True
#############################################################################################################################################
# NOTE(review): do_timetool_correction = True additionally needs `_tt_slopes` / `_tt_intercepts`
# (commented out above) and `ttposSliced` (commented out in the filtering cell). `run_indicator`
# is not offset-sliced like `scan`, so its length will not match when a shot offset is in use.

# Select bins, in ps
if do_timetool_correction and len(_new_time_bin_edges) != 0:
    time_bins_selected = np.asarray(_new_time_bin_edges, dtype=float)
else:
    # Default to the time bins that were recorded
    time_bins_selected = np.asarray(time_bins, dtype=float)

# Doing the timetool correction on a run by run basis
if do_timetool_correction:
    time_delay = xrst.calib.apply_timetool_correction(
        delays = scan,
        edge_positions = ttposSliced,
        slopes = _tt_slopes,
        intercepts = _tt_intercepts,
        run_indicator = run_indicator # Optional, if None, only one timetool calibration applied to all shots.
    )
else:
    time_delay = scan

# Convert to ps, determine bin idx for each shot.
# np.digitize returns i such that bins[i-1] <= x < bins[i]: 0 is below the first entry and
# len(bins) is at or above the last one. A delay equal to bins[k] therefore lands in index k + 1.
binned_shot_idxs = np.digitize(time_delay*10**12, time_bins_selected)

# All figure work is inside this guard, so disabling plotting no longer produces an empty titled figure.
if _plotting:
    if do_timetool_correction:
        plt.figure(figsize=[12,5])
        plt.subplot(1,2,1)
        plt.hist(time_delay[goodIdx_timetool]*10**12, bins=200)
        plt.xlabel('Delay (ps)')
        plt.title('Timetool corrected delays, before binning')
        plt.ylabel('Counts')
        plt.subplot(1,2,2)
    else:
        plt.figure()
    # NOTE(review): the selected values are used as histogram edges. When they are the recorded
    # delay points (no correction), N delays give only N-1 bars and the last bar merges the two
    # largest delays.
    plt.hist(time_delay[goodIdx_timetool]*10**12, bins=time_bins_selected)
    plt.xlabel('Delay (ps)')
    plt.title('Counts per selected time bin')
    plt.ylabel('Counts')
    plt.suptitle(niceTitle)
    plt.show()

### Determine final shot masks

In [None]:
## Temp
# Temporary: mark every 7th shot (indices 6, 13, 20, ...) as laser-off so that an "off" reference
# group exists. The assignment edits `laserOn` in place. `laserOnSliced` was created from it by
# basic slicing and is therefore a NumPy view, so it changes too — that aliasing is the only reason
# this cell affects the masks built below. Re-running the filtering cell keeps the change;
# re-running the loading cell resets it.
laserOn[7-1::7] = False

In [None]:
# Final per-shot masks: a shot must pass the detector filters (goodIdx) and the spectrometer
# filters (mask_spectrum), and is then assigned to the laser-on or the laser-off set.
if do_timetool_correction:
    # goodIdx_timetool already includes laserOn shots
    on_mask = goodIdx & goodIdx_timetool & mask_spectrum
else:
    # Just the good shots where the laser is on
    on_mask = goodIdx & laserOnSliced & mask_spectrum
# Timetool isn't needed for the off shots
off_mask = goodIdx & ~laserOnSliced & mask_spectrum


def _report_retained(kept_mask, candidate_mask, label):
    """Print how many of the shots in `candidate_mask` are still present in `kept_mask`.

    kept_mask      : boolean mask of the shots that survived all filters
    candidate_mask : boolean mask of all shots of this kind (laser on / laser off)
    label          : text inserted into the message, e.g. 'laser on'
    """
    n_kept = np.sum(kept_mask)
    n_candidates = np.sum(candidate_mask)
    # Avoid a divide-by-zero warning when there are no shots of this kind at all
    percent = n_kept / n_candidates * 100 if n_candidates else float('nan')
    print(f'Combined filters retained {percent:.2f}% of the {label} shots. ({n_kept} out of {n_candidates}).')


# Displaying how much data was kept due to this filtering
_report_retained(on_mask, laserOnSliced, 'laser on')
_report_retained(off_mask, ~laserOnSliced, 'laser off')

In [None]:

# Apply the union of on/off masks to select all analyzable shots
_all_good = on_mask | off_mask

# "_m" = restricted to the analyzable shots
gasPressure_m   = gasPressure[_all_good]
xray_spectrum_m = xray_spectrum[_all_good]
I_f_m           = I_f[_all_good]
I_uf_m          = I_uf[_all_good]
I_0_m           = I_0[_all_good]
I_ratio         = I_f_m / I_uf_m   # per-shot, per-q ratio of the two azimuthal averages

# Track laser status and time bin index within the masked set
is_on_m        = on_mask[_all_good]
is_off_m       = off_mask[_all_good]
time_bin_idx_m = binned_shot_idxs[_all_good]

# Build ordered list of time groups: index 0 = off, 1..N = laser-on time bins
_unique_on_bins = np.sort(np.unique(time_bin_idx_m[is_on_m]))
_n_edges = len(time_bins_selected)
time_group_masks = [is_off_m]       # index 0 = laser off
delay_labels     = ['off']
delay_values     = [np.nan]         # no delay for off shots

for _b in _unique_on_bins:
    if not do_timetool_correction and 1 <= _b <= _n_edges:
        # Without timetool correction the selected "bins" are the recorded delay values themselves
        # and every shot sits exactly on one of them. np.digitize maps a delay equal to entry k to
        # index k + 1, so the delay of group _b is entry _b - 1. (Labelling it with the midpoint to
        # the next entry, as before, shifted every label by half a delay step.)
        _center = time_bins_selected[_b - 1]
    elif 1 <= _b < _n_edges:
        # Proper bin between two edges: label with the bin centre
        _center = (time_bins_selected[_b - 1] + time_bins_selected[_b]) / 2
    elif _b == 0:
        # Below the first edge
        _center = time_bins_selected[0]
    else:
        # At or above the last edge
        _center = time_bins_selected[-1]
    delay_values.append(_center)
    delay_labels.append(f'{_center:.2f} ps')
    time_group_masks.append(is_on_m & (time_bin_idx_m == _b))

delay_values  = np.array(delay_values)
n_time_groups = len(delay_labels)

# Report shot counts per group
print(f"{'Group':<15} {'Shots':>8}")
print("-" * 25)
for _lbl, _m in zip(delay_labels, time_group_masks):
    print(f"  {_lbl:<13} {np.sum(_m):>6}")
print(f"\nTotal time groups: {n_time_groups}")



## Preprocessing for ghost imaging


### Rebin the spectrum coarser

In [None]:
# ---------- User parameters ----------
_N_REBIN = 8               # number of neighbouring spectrometer channels merged into one bin
SPECTRUM_CROP = (0, 256)   # [start, stop) range of rebinned bins used by the later analysis cells
_PLOT_REBIN = True
# --------------------------------------

_n_shots, _n_spec = xray_spectrum_m.shape
_n_bins = _n_spec // _N_REBIN
# Trailing channels that do not fill a complete bin are dropped; without this the reshape below
# raises whenever the channel count is not a multiple of _N_REBIN.
_n_used = _n_bins * _N_REBIN

# Rebin by averaging groups of _N_REBIN consecutive channels
_xray_spec_reduced = xray_spectrum_m[:, :_n_used].reshape(_n_shots, _n_bins, _N_REBIN).mean(axis=2)
# Energy of each coarse bin = energy of the channel at the (floored) middle of its group
_idx_centers = np.floor(np.arange(_n_used).reshape(_n_bins, _N_REBIN).mean(axis=1)).astype(int)
xray_range_reduced = xray_spectrum_range[_idx_centers]

# Normalize each shot's spectrum to unit integral
xray_spec_norm = _xray_spec_reduced / _xray_spec_reduced.sum(axis=1, keepdims=True)

if _PLOT_REBIN:
    # Clamp the upper marker to the last existing bin; a slice with SPECTRUM_CROP tolerates a stop
    # beyond the axis, but plain indexing does not.
    _crop_lo = SPECTRUM_CROP[0]
    _crop_hi = min(SPECTRUM_CROP[1], _n_bins) - 1
    fig, ax = plt.subplots()
    ax.plot(xray_range_reduced, xray_spec_norm.mean(axis=0))
    ax.axvline(xray_range_reduced[_crop_lo], color="r", ls="--", label="Lower limit")
    ax.axvline(xray_range_reduced[_crop_hi], color="r", ls="--", label="Upper limit")
    ax.set(title="Mean Reduced X-ray Spectrum", xlabel="Energy (eV)", ylabel="Intensity (a.u.)")
    ax.legend()
    plt.show()

### This cell won't really do anything because the fluorescence correction has to be done in post, and we don't have time for it. For now I artificially set the fluorescence correction arrays to all ones so they no longer modify the signal.

In [None]:

# Placeholder fluorescence correction. The measured correction would come from a .mat file:
# fl_struct = scio.loadmat("F_Correction_9666.mat", squeeze_me=True)
# Until then all correction curves are ones on the q grid, so every factor computed here is 1.
_corr_q     = q                 # np.asarray(fl_struct["q"], dtype=np.float64)
_corr_f     = np.ones_like(q)   # np.asarray(fl_struct["flu_f_corr"], dtype=np.float64)
_corr_uf    = np.ones_like(q)   # np.asarray(fl_struct["flu_uf_corr"], dtype=np.float64)
_ratio_f_uf = _corr_f / _corr_uf

# Correction curves resampled onto the analysis q grid
_ratio_on_q = np.interp(q, _corr_q, _ratio_f_uf)
_uf_on_q    = np.interp(q, _corr_q, _corr_uf)

# Correction matrices of shape (n_energy_bins, n_q): rows for energy bins above 9664 eV carry the
# correction curve, all other rows stay at 1.
_is_above = (xray_range_reduced > 9664)[:, None]
_ZnFlu    = np.where(_is_above, _ratio_on_q[None, :], 1.0)
_ZnFlu_uf = np.where(_is_above, _uf_on_q[None, :], 1.0)

# --- Plot ---
fig, ax = plt.subplots()
_image = ax.imshow(
    _ZnFlu.T,
    extent=[xray_range_reduced[0], xray_range_reduced[-1], q[0], q[-1]],
    origin="lower",
    aspect="auto",
    cmap="hot",
)
ax.set(xlabel="Energy (eV)", ylabel=r"$q\;(\AA^{-1})$", title="Fluorescence Correction")
plt.colorbar(_image, ax=ax)
plt.show()

# Per-shot correction: each shot's normalised spectrum weights the rows of the matrices
_shot_flu_corr    = xray_spec_norm @ _ZnFlu       # (n_shots, n_q)
_shot_uf_flu_corr = xray_spec_norm @ _ZnFlu_uf    # (n_shots, n_q)
I_ratio_corrected = I_ratio * _shot_flu_corr

# Total scattering per time group: I_uf normalised by I_0 and gas pressure, then corrected
_shot_scale = I_0_m[:, None] * gasPressure_m[:, None]
_n_q = len(q)
I_tot     = np.zeros((_n_q, n_time_groups))
I_tot_std = np.zeros((_n_q, n_time_groups))
for _group, _sel in enumerate(time_group_masks):
    _normalised = I_uf_m[_sel] / _shot_scale[_sel] * _shot_uf_flu_corr[_sel]
    I_tot[:, _group]     = np.nanmean(_normalised, axis=0)
    I_tot_std[:, _group] = np.nanstd(_normalised, axis=0)

print(f"I_tot shape: {I_tot.shape}  →  (q, time_groups)")



## Time-dependent ghost imaging analysis
Compute A-vectors, filter function, and deconvolved scattering spectrum for each time group (off + each laser-on delay bin).


In [None]:

# ---------- User parameters ----------
_LAMBDA = 1.5   # regularisation strength
_PLOT_A = True
# --------------------------------------

# Regression inputs: predictors are the cropped, normalised spectra, targets the corrected ratio
_crop = slice(SPECTRUM_CROP[0], SPECTRUM_CROP[1])
_X_all = xray_spec_norm[:, _crop]      # (n_shots, n_features)
_Y_all = I_ratio_corrected             # (n_shots, n_q)
_n_features = _X_all.shape[1]

# Smoothness penalty built from the second-difference operator; identical for every time group
_second_diff = np.diff(np.eye(_n_features), n=2, axis=0)   # (n_features-2, n_features)
_penalty = _LAMBDA * (_second_diff.T @ _second_diff)       # (n_features, n_features)

# Regularised least squares per time group: solve (X^T X + lambda * R) A = X^T Y
A = np.zeros((_n_features, len(q), n_time_groups))
for _group, _sel in enumerate(time_group_masks):
    _X, _Y = _X_all[_sel], _Y_all[_sel]
    A[:, :, _group] = np.linalg.solve(_X.T @ _X + _penalty, _X.T @ _Y)

if _PLOT_A:
    _E_crop = xray_range_reduced[_crop]
    _ncols = min(n_time_groups, 4)
    _nrows = int(np.ceil(n_time_groups / _ncols))
    fig, axes = plt.subplots(_nrows, _ncols, figsize=(5 * _ncols, 4 * _nrows), squeeze=False)
    _panels = axes.ravel()   # row-major order, i.e. panel k sits at [k // ncols, k % ncols]
    _n_q = A.shape[1]
    for _group in range(n_time_groups):
        _panel = _panels[_group]
        # One thin line per q bin, coloured by q index
        for _qi in range(_n_q):
            _panel.plot(_E_crop, A[:, _qi, _group], lw=0.2, color=(0, 0.5, _qi / _n_q))
        _panel.set(xlim=(9645, 9685), ylim=(-0.1, 0.9),
                   title=f"A-vectors: {delay_labels[_group]}",
                   xlabel="Photon Energy (eV)", ylabel="Intensity (arb.)")
    # Hide the unused panels of the grid
    for _panel in _panels[n_time_groups:]:
        _panel.set_visible(False)
    plt.tight_layout()
    plt.show()

print(f"A shape: {A.shape}  →  (energy, q, time_groups)")



### Determine the filter function from the laser-off reference
The Zn foil filter function is derived from the laser-off A-vectors (ground state) and applied uniformly to all time groups during deconvolution.


In [None]:

# ---------- User parameters ----------
E_MIN, E_MAX = 9640, 9685  # eV
PLOT_FILTER = True
# --------------------------------------

# Energy window [E_MIN, E_MAX) on the rebinned axis, and the same window restricted to the
# SPECTRUM_CROP range that the A-vectors were computed on.
# NOTE(review): A_short and E_pix only have matching lengths while SPECTRUM_CROP contains the
# whole energy window.
_in_window      = (xray_range_reduced >= E_MIN) & (xray_range_reduced < E_MAX)
_in_window_crop = _in_window[SPECTRUM_CROP[0]:SPECTRUM_CROP[1]]

# A-vectors of all time groups inside the window, with non-finite entries replaced by 0
A_short = np.nan_to_num(A[_in_window_crop, :, :], nan=0.0, posinf=0.0, neginf=0.0)  # (N_short, n_q, n_time_groups)

E_pix   = xray_range_reduced[_in_window]
N_short = len(E_pix)

# Scattering angle of each q bin from q = 4*pi*sin(theta)/lambda
WAVELENGTH = 1.2827289955778933  # Å
twotheta = 2 * np.arcsin(q * WAVELENGTH / (4 * np.pi))

# Reference range (low-q)  — MATLAB indices 17:19 → Python 16:19
QLOW1, QLOW2 = 16, 19  # 0-based, slice is [16:19]
_q_ref = slice(QLOW1, QLOW2)

# Filter function from the laser-OFF A-vectors (time group index 0):
# -log(A) * cos(2theta), averaged over the reference q bins, then re-applied at every q as
# exp(-mut / cos(2theta)). Values are floored at 1e-30 so the logarithm stays finite.
_A_reference = np.clip(A_short[:, _q_ref, 0], a_min=1e-30, a_max=None)
_cos_ref = np.cos(twotheta[_q_ref])
_mut = np.mean(-np.log(_A_reference) * _cos_ref, axis=1)
f_q = np.real(np.exp(-_mut[:, None] / np.cos(twotheta[None, :])))
f_q = np.nan_to_num(f_q, nan=0.0, posinf=0.0, neginf=0.0)

QVALUE = 39  # 0-based index for plotting

if PLOT_FILTER:
    fig, ax = plt.subplots()
    ax.plot(E_pix, f_q[:, QVALUE], lw=2, label="Filter Function (from off)")
    ax.plot(E_pix, A_short[:, QVALUE, 0], lw=2, label="A Vector (off)")
    ax.set(
        xlim=(E_MIN, E_MAX),
        title="Filter Function (derived from laser-off)",
        xlabel="Photon Energy (eV)",
        ylabel="Intensity (arb.)",
    )
    ax.text(0.65, 0.75, f"q = {q[QVALUE]:.2f}", transform=ax.transAxes, fontsize=16)
    ax.legend()
    plt.show()



### Deconvolve scattering spectrum for each time group
Produces `S_spec` with shape `(energy, q, time_groups)` where index 0 is laser-off.


In [None]:

# ---------- User parameters ----------
PLOT_SCATTER = True
# --------------------------------------

CondN  = np.zeros(len(q))                            # condition number of the filter matrix per q bin
S_spec = np.zeros((N_short, len(q), n_time_groups))  # deconvolved spectrum (energy, q, time group)

# The filter matrix depends only on q, so it is built once per q bin and reused for every time
# group (it used to be rebuilt for each group).
for _qi in range(len(q)):
    # Lower-triangular Toeplitz matrix whose first column is the filter function at this q
    _col = np.zeros(N_short)
    _col[0] = f_q[0, _qi]
    _row = f_q[:, _qi]
    _F_Zn = toeplitz(_col, _row).T
    CondN[_qi] = np.linalg.cond(_F_Zn)

    # Non-negative least squares: find S >= 0 minimising ||F S - A||
    for _t in range(n_time_groups):
        S_spec[:, _qi, _t], _rnorm = nnls(_F_Zn, A_short[:, _qi, _t])

if PLOT_SCATTER:
    _E_loss = E_pix - E_pix[0]   # energy axis relative to the first bin of the window
    _ncols = min(n_time_groups, 4)
    _nrows = int(np.ceil(n_time_groups / _ncols))

    # Average scattering spectrum per time group (q bins from index 16 upwards)
    fig, axes = plt.subplots(_nrows, _ncols, figsize=(5 * _ncols, 4 * _nrows), squeeze=False)
    for _t in range(n_time_groups):
        _ax = axes[_t // _ncols, _t % _ncols]
        _ax.plot(_E_loss, S_spec[:, 16:, _t].mean(axis=1), "r.-")
        _ax.axhline(0, color="k", ls="--", lw=1)
        _ax.set(title=f"SF$_6$ Scattering — {delay_labels[_t]}",
                xlabel="X-Ray Energy Loss (eV)")
    for _i in range(n_time_groups, _nrows * _ncols):
        axes[_i // _ncols, _i % _ncols].set_visible(False)
    plt.tight_layout()
    plt.show()

    # Condition number vs q (same filter for all groups)
    fig, ax = plt.subplots()
    ax.plot(q, CondN, "k.-")
    ax.plot(q, CondN, "r*")
    ax.axhline(0, color="k", lw=1)
    ax.set(title="Filter Matrix Condition Number vs q",
           xlabel=r"$q\;(\AA^{-1})$")
    ax.text(0.55, 0.8, "Unshifted Filter Matrix", transform=ax.transAxes, fontsize=16)
    plt.show()

    # 2-D colour map S(q, dE) per time group
    fig, axes = plt.subplots(_nrows, _ncols, figsize=(5 * _ncols, 5 * _nrows), squeeze=False)
    for _t in range(n_time_groups):
        _ax = axes[_t // _ncols, _t % _ncols]
        _im = _ax.imshow(
            S_spec[:, :, _t].T, aspect="auto", origin="lower", cmap="hot",
            extent=[_E_loss[0], _E_loss[-1], q[0], q[-1]],
            vmin=0, vmax=0.05,
        )
        _ax.set_ylim(q[14], q[-1])   # hide the lowest q bins
        _ax.set(xlabel="Energy Loss (eV)", ylabel=r"$q\;(\AA^{-1})$",
                title=f"S(q, dE) — {delay_labels[_t]}")
        plt.colorbar(_im, ax=_ax)
    for _i in range(n_time_groups, _nrows * _ncols):
        axes[_i // _ncols, _i % _ncols].set_visible(False)
    plt.tight_layout()
    plt.show()

print(f"S_spec shape: {S_spec.shape}  →  (energy, q, time_groups)")
print(f"Time groups: {delay_labels}")


In [None]:

# ---------- User parameters ----------
_PLOT_ABS = True
_VMAX_ABS = None  # set to a float to fix the colour scale, None for auto
# --------------------------------------

# Scale the deconvolved spectra by the total scattering of each (q, time group):
#   S_spec : (energy, q, time_groups)
#   I_tot  : (q, time_groups), broadcast along the energy axis
S_spec_abs = S_spec * I_tot[None, :, :]   # (energy, q, time_groups)

if _PLOT_ABS:
    _E_loss = E_pix - E_pix[0]
    _ncols = min(n_time_groups, 4)
    _nrows = int(np.ceil(n_time_groups / _ncols))

    # --- 1D: q-averaged absolute scattering spectrum per time group (q index 16 upwards) ---
    fig, axes = plt.subplots(_nrows, _ncols, figsize=(5 * _ncols, 4 * _nrows), squeeze=False)
    _panels = axes.ravel()   # row-major, same order as [k // ncols, k % ncols]
    for _group, _label in enumerate(delay_labels[:n_time_groups]):
        _panel = _panels[_group]
        _panel.plot(_E_loss, S_spec_abs[:, 16:, _group].mean(axis=1), "b.-")
        _panel.axhline(0, color="k", ls="--", lw=1)
        _panel.set(title=f"Absolute Scattering — {_label}",
                   xlabel="X-Ray Energy Loss (eV)", ylabel="Intensity (arb.)")
    for _panel in _panels[n_time_groups:]:
        _panel.set_visible(False)
    plt.tight_layout()
    plt.show()

    # --- 2D colour maps: S_abs(q, dE) per time group ---
    # Colour scale: user value if given, otherwise the 99th percentile over q index 16 upwards
    if _VMAX_ABS is not None:
        _vmax = _VMAX_ABS
    else:
        _vmax = np.nanpercentile(S_spec_abs[:, 16:, :], 99)
    _map_extent = [_E_loss[0], _E_loss[-1], q[0], q[-1]]
    fig, axes = plt.subplots(_nrows, _ncols, figsize=(5 * _ncols, 5 * _nrows), squeeze=False)
    _panels = axes.ravel()
    for _group, _label in enumerate(delay_labels[:n_time_groups]):
        _panel = _panels[_group]
        _image = _panel.imshow(
            S_spec_abs[:, :, _group].T,
            extent=_map_extent,
            vmin=0, vmax=_vmax,
            origin="lower", aspect="auto", cmap="hot",
        )
        _panel.set_ylim(q[14], q[-1])
        _panel.set(xlabel="Energy Loss (eV)", ylabel=r"$q\;(\AA^{-1})$",
                   title=f"Abs. Scattering — {_label}")
        plt.colorbar(_image, ax=_panel)
    for _panel in _panels[n_time_groups:]:
        _panel.set_visible(False)
    plt.tight_layout()
    plt.show()

print(f"S_spec_abs shape: {S_spec_abs.shape}  →  (energy, q, time_groups)")


In [None]:

# ---------- User parameters ----------
ELASTIC_RANGE  = (0, 3)     # energy-loss range in eV for elastic
INELASTIC_RANGE = (3, 45)   # energy-loss range in eV for inelastic
_CLIM_PCT = 10               # symmetric colour limit (%)
# --------------------------------------


def _percent_change_vs_off(integrated):
    """Percent difference of every laser-on group relative to the laser-off group.

    integrated : array of shape (q, n_time_groups) whose column 0 is the laser-off reference
    Returns an array of shape (q, n_time_groups - 1). Entries whose reference is exactly 0 are
    NaN instead of a division by zero.
    """
    reference = integrated[:, :1]
    safe_reference = np.where(reference != 0, reference, np.nan)
    return (integrated[:, 1:] - reference) / safe_reference * 100


_E_loss = E_pix - E_pix[0]

# Build energy masks (half-open ranges [low, high))
_mask_elastic   = (_E_loss >= ELASTIC_RANGE[0])  & (_E_loss < ELASTIC_RANGE[1])
_mask_inelastic = (_E_loss >= INELASTIC_RANGE[0]) & (_E_loss < INELASTIC_RANGE[1])

# Integrate S_spec_abs over the energy ranges  → (q, n_time_groups)
_S_elastic   = S_spec_abs[_mask_elastic, :, :].sum(axis=0)
_S_inelastic = S_spec_abs[_mask_inelastic, :, :].sum(axis=0)

# Percent difference relative to laser-off (time group 0), for the laser-on groups (indices 1:)
_on_delays = delay_values[1:]
_pct_elastic   = _percent_change_vs_off(_S_elastic)
_pct_inelastic = _percent_change_vs_off(_S_inelastic)

# The maps are drawn from bin centres with shading="nearest", so no q or delay edges are needed.
# The former edge computations were unused, and one of them could raise for some qbins lengths.
# The delay axis uses integer row positions because the delays need not be evenly spaced; the
# tick labels carry the delay values.
_delay_rows = np.arange(len(_on_delays))
_delay_tick_labels = [f"{d:.2f}" for d in _on_delays]

_panels = [
    (_pct_elastic,   f"Elastic ({ELASTIC_RANGE[0]}–{ELASTIC_RANGE[1]} eV loss) — % Diff"),
    (_pct_inelastic, f"Inelastic ({INELASTIC_RANGE[0]}–{INELASTIC_RANGE[1]} eV loss) — % Diff"),
]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for _ax, (_pct, _title) in zip(axes, _panels):
    _im = _ax.pcolormesh(q, _delay_rows, _pct.T,
                         cmap="RdBu_r", vmin=-_CLIM_PCT, vmax=_CLIM_PCT, shading="nearest")
    _ax.set_yticks(_delay_rows)
    _ax.set_yticklabels(_delay_tick_labels)
    _ax.set(xlabel=r"$q\;(\AA^{-1})$", ylabel="Delay (ps)", title=_title)
    plt.colorbar(_im, ax=_ax, label="% Difference")

plt.suptitle(r"$\Delta S / S_\mathrm{off} \times 100$", fontsize=14)
plt.tight_layout()
plt.show()


In [None]:

# ---------- User parameters ----------
Q_INTEGRATE_RANGE = (16, None)  # q-index range to integrate over (skip low-q); None = end
_CLIM_PCT_E = 10                # symmetric colour limit (%)
# --------------------------------------

_E_loss = E_pix - E_pix[0]
_on_delays = delay_values[1:]

# Resolve the q-index range; an open upper end means "up to the last q bin"
_q_lo, _q_stop = Q_INTEGRATE_RANGE
_q_hi = len(q) if _q_stop is None else _q_stop

# Integrate S_spec_abs over q range → (energy, n_time_groups)
# NOTE(review): a plain sum propagates NaN, so a single NaN q bin inside the range blanks the
# whole map — check whether I_tot contains NaN in this range.
_S_q_int = S_spec_abs[:, _q_lo:_q_hi, :].sum(axis=1)

# Percent difference relative to laser-off (time group 0); a zero reference yields NaN
_S_off_e = _S_q_int[:, 0]
_reference = _S_off_e[:, None]
_safe_reference = np.where(_reference != 0, _reference, np.nan)
_pct_e = (_S_q_int[:, 1:] - _reference) / _safe_reference * 100

# Delays are drawn at integer row positions and labelled with their values
_delay_rows = np.arange(len(_on_delays))

fig, ax = plt.subplots(figsize=(10, 5))
_image = ax.pcolormesh(
    _E_loss, _delay_rows, _pct_e.T,
    shading="nearest", cmap="RdBu_r",
    vmin=-_CLIM_PCT_E, vmax=_CLIM_PCT_E,
)
ax.set_yticks(_delay_rows)
ax.set_yticklabels([f"{d:.2f}" for d in _on_delays])
ax.set(xlabel="X-Ray Energy Loss (eV)", ylabel="Delay (ps)",
       title=rf"$\Delta S / S_\mathrm{{off}} \times 100$ — q-integrated ({q[_q_lo]:.2f}–{q[_q_hi-1]:.2f} $\AA^{{-1}}$)")
plt.colorbar(_image, ax=ax, label="% Difference")
plt.tight_layout()
plt.show()
