# Timetool Calibration notebook
This notebook takes a timetool calibration run and determines the relationship between the edge position on the timetool (in pixels) and the delay $\Delta t$ between the pump-laser and x-ray arrival times.

## Loading libraries

In [None]:
# Third-party scientific stack
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from scipy.optimize import curve_fit
from tqdm.auto import tqdm

# Local X-ray scattering analysis tools
import xrayscatteringtools as xrst

# NOTE(review): presumably clears `_`-prefixed temporaries between cells -- confirm in xrayscatteringtools
xrst.enable_underscore_cleanup()

## Loading the data

In [None]:
###############################################
runNumbers = [] # Enter one or more timetool calibration run numbers here
if not runNumbers:
    raise ValueError('runNumbers is empty: enter at least one timetool calibration run number above.')
folders = xrst.get_data_paths(runNumbers) # Defaults to the info in config.yaml. You can overwrite this with strings, character arrays, or lists of either.
###############################################
# (1) keys_to_combine: keys loaded for each shot and stored per shot
# (2) keys_to_sum: keys loaded for each run and summed across runs
# (3) keys_to_check: keys checked to exist with identical values in all runs, then loaded
_keys_to_combine = [
    # Timetool data
    'tt/FLTPOS',
    'tt/AMPL',
    'tt/FLTPOSFWHM',
    # Raw timetool traces
    'tt/ttCorr',
    'ttRaw/tt_reference',
    'ttRaw/tt_signal',
    # Laser-Xray Timing, not timetool corrected
    'scan/lxt',
    # Upstream Diode
    'ipm_dg2/sum',
    # Gas detectors
    'gas_detector/f_11_ENRC', 'gas_detector/f_22_ENRC',
    # Light status
    'lightStatus/laser', 'lightStatus/xray',
]
_keys_to_sum = [] # <- These can be empty
_keys_to_check = [] # <- These can be empty
##### Load the data in #####
_data = xrst.combineRuns(runNumbers, folders, _keys_to_combine, _keys_to_sum, _keys_to_check, verbose=False)  # this is the function to load the data with defined keys
############################

# Compact string form of the run list, used in the calibration plot title
runNumbersRange = xrst.compress_ranges(runNumbers)

# Raw timetool traces
ttRaw = _data['ttRaw/tt_signal'] # Raw timetool signal trace for each shot
tt_reference = _data['ttRaw/tt_reference'] # Raw timetool reference trace for each shot

# Scan data
scan = _data['scan/lxt'] # Time interval between the arrival of laser and xray (not timetool corrected)

# Simplified Timetool results (old function)
ttpos_old = _data['tt/FLTPOS']
ttfwhm = _data['tt/FLTPOSFWHM']
ttampl = _data['tt/AMPL']
# Upstream diode
dg2 = _data['ipm_dg2/sum']
# Gas Detector, this may need to be swapped depending on which is used
pulse_energy = _data['gas_detector/f_22_ENRC']
# Event codes
laserOn = _data['lightStatus/laser'].astype(bool) # laser on events
xrayOn = _data['lightStatus/xray'].astype(bool) # xray on events
_total_shots = ttRaw.shape[0]
print("Total shots: ", _total_shots)

## Filter shots for further analysis

In [None]:
########## Different filter cutoffs######
_pulse_energy_cutoff = [0.0, 3] # In mJ!!!
_tt_edgePos_cutoff = [0.1, 1]
_tt_amp_cutoff = [0.4, 1]
_tt_width_cutoff = [0.1, 0.6]
_plot_display = 'log' # 'linear'
##########################################

# Normalize the timetool quantities to their maximum so the cutoffs above are fractions of the max
_ttposNorm = ttpos_old/np.nanmax(ttpos_old)
_ttamplNorm = ttampl/np.nanmax(ttampl)
_ttfwhmNorm = ttfwhm/np.nanmax(ttfwhm)


def _plot_cutoff_hist(ax, values, cutoff, title, xlabel, yscale):
    """Histogram `values` on `ax` and shade the inclusive [cutoff[0], cutoff[1]] window kept as good data."""
    ax.hist(values, bins=200, range=[0, np.nanmax(values)])
    for _edge in cutoff:
        ax.axvline(_edge, color='r', linestyle='--')
    ax.axvspan(cutoff[0], cutoff[1], color='r', alpha=0.2, label='"Good" Data')
    ax.set_yscale(yscale)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Counts')
    ax.legend()


def _within(values, cutoff):
    """Boolean mask of `values` inside the inclusive window [cutoff[0], cutoff[1]]; NaN gives False."""
    return (cutoff[0] <= values) & (values <= cutoff[1])


_fig, _axes = plt.subplots(2, 2, figsize=[17, 10])
_plot_cutoff_hist(_axes[0, 0], pulse_energy, _pulse_energy_cutoff, 'Pulse Energy Histogram', 'mJ', _plot_display)
_plot_cutoff_hist(_axes[0, 1], _ttposNorm, _tt_edgePos_cutoff, 'Timetool Edge Histogram', 'Fraction of Maximum', _plot_display)
_plot_cutoff_hist(_axes[1, 0], _ttamplNorm, _tt_amp_cutoff, 'Timetool Amplitude Histogram', 'Fraction of Maximum', _plot_display)
_plot_cutoff_hist(_axes[1, 1], _ttfwhmNorm, _tt_width_cutoff, 'Timetool FWHM Histogram', 'Fraction of Maximum', _plot_display)
plt.show()

# Keep shots inside every window shaded above. The pulse-energy window is applied too;
# previously it was plotted as "Good" Data but silently ignored by the filter.
goodIdx = np.logical_and.reduce([
    _within(pulse_energy, _pulse_energy_cutoff),
    _within(_ttposNorm, _tt_edgePos_cutoff),
    _within(_ttamplNorm, _tt_amp_cutoff),
    _within(_ttfwhmNorm, _tt_width_cutoff),
])
# Displaying how much data was kept due to this filtering
_n_good = np.count_nonzero(goodIdx)
_n_total = goodIdx.size
print(f'goodIdx data represents {_n_good/_n_total*100:.2f}% of the total shots. ({_n_good} out of {_n_total}).')

## Fitting the raw timetool traces to an erf (fast; no iterative curve fitting)

In [None]:
# Preallocation: one slot per shot that passed the goodIdx filter
ttRawGood = ttRaw[goodIdx]
_num_shots = ttRawGood.shape[0]
_trace_length = ttRawGood.shape[1]  # read from the shape so an empty selection does not raise IndexError
range_value = np.zeros(_num_shots)
centers = np.zeros(_num_shots)
amps = np.zeros(_num_shots)
slopes = np.zeros(_num_shots)
norm_all = np.zeros((_num_shots, _trace_length))

# Determine which shots meet specific criteria before fitting: only traces whose
# minimum is at most half their maximum are fitted. Skipped shots keep all-zero
# results in the arrays above.
mins = np.min(ttRawGood, axis=1)
maxs = np.max(ttRawGood, axis=1)
valid_mask = mins <= (0.5 * maxs)
valid_indices = np.where(valid_mask)[0]

_exception_counter = 0

for _j in tqdm(valid_indices, desc="Processing valid shots"):
    # Run the fast erf fit on this shot's raw trace
    range_value[_j], centers[_j], amps[_j], norm_all[_j], slopes[_j] = xrst.calib.fast_erf_fit(ttRawGood[_j])

    # An all-zero result presumably means fast_erf_fit hit an exception -- confirm in xrst.calib
    if (range_value[_j], centers[_j], amps[_j], slopes[_j]) == (0, 0, 0, 0):
        _exception_counter += 1

print(f'Exception counter = {_exception_counter} out of {len(valid_indices)} fitted shots '
      f'({_num_shots - len(valid_indices)} of {_num_shots} shots skipped before fitting)')

## Determining "good" data from the slopes

In [None]:
#################################
slope_min = 0.7
slope_max = 1
#################################
# Normalize slopes to the steepest fit so the window above is a fraction of the maximum slope
norm_slopes = slopes/np.nanmax(slopes)
plt.figure()
hist, bins, _ = plt.hist(norm_slopes, bins=200)
plt.axvline(slope_min,color='r',linestyle='--')
plt.axvline(slope_max,color='r',linestyle='--')
plt.axvspan(slope_min,slope_max,color='r',alpha=0.2,label='"Good" Data')
plt.title("Histogram of Slopes for Timetool Fitting")
plt.xlabel("Normalized slope (fraction of maximum)")
plt.ylabel("counts")
plt.legend()  # the axvspan label is only shown with a legend
plt.show()

## Determine "good" and "bad" fits based on the slope found, show examples of each.

In [None]:
# Inclusive window, matching the shot filter. The steepest fit normalizes to exactly 1,
# so a strict `< slope_max` would always drop it. Every shot outside the window counts
# as bad, including NaN slopes and shots skipped before fitting (slope 0).
good_mask = (norm_slopes >= slope_min) & (norm_slopes <= slope_max)
good_indices = np.where(good_mask)[0]

bad_mask = ~good_mask
bad_indices = np.where(bad_mask)[0]

# Count from the masks themselves so the printed numbers match the selection used below
num_bad_fits = np.count_nonzero(bad_mask)
num_good_fits = np.count_nonzero(good_mask)

print(f"Number of bad fits: {num_bad_fits}")
print(f"Number of good fits: {num_good_fits}")

bad = [num_bad_fits]
good = [num_good_fits]

sample_good = np.random.choice(good_indices, size=min(8, len(good_indices)), replace=False)
sample_bad = np.random.choice(bad_indices, size=min(4, len(bad_indices)), replace=False)

# Mean reference trace over shots whose reference is not all zeros, normalized to a peak of 1
_tt_ref_filter = ~np.all(0==tt_reference,axis=1)
_tt_ref_mean = tt_reference[_tt_ref_filter].mean(axis=0)
tt_ref_norm = _tt_ref_mean / _tt_ref_mean.max()

###### Plotting ######
plt.figure(figsize=(12, 5))
# Randomly chosen good fits, each offset vertically by its sample index
plt.subplot(1, 2, 1)
for _j, idx in enumerate(sample_good):
    plt.plot(norm_all[idx] + _j, label=f'{slopes[idx]:.4f}')
    plt.plot(centers[idx], amps[idx] + _j, "ko")
    plt.axhline(_j, color='grey', linestyle='--')

# Plot the reference trace
plt.plot(tt_ref_norm, "-", color="black", label="Reference")

plt.title("Example of Good Fits")
plt.xlabel("Pixel")
plt.ylabel("Normalized Intensity")
plt.legend(loc='best')
# Randomly chosen bad fits, each offset vertically by its sample index
plt.subplot(1, 2, 2)
for _j, idx in enumerate(sample_bad):
    plt.plot(norm_all[idx] + _j, label=f'{slopes[idx]:.4f}')
    plt.plot(centers[idx], amps[idx] + _j, "ko")
    plt.axhline(_j, color='grey', linestyle='--')

plt.title("Example of Bad Fits")
plt.xlabel("Pixel")
plt.ylabel("Normalized Intensity")
plt.legend(loc='best')
plt.tight_layout()
plt.show()

## Fit the timetool calibration (edge position vs. LXT delay) using the good fits

In [None]:
centers_good = centers[good_mask]
amps_good = amps[good_mask]
norm_good = norm_all[good_mask]
slopes_good = slopes[good_mask]
scan_good = scan[goodIdx][good_mask]

if centers_good.size < 2:
    raise ValueError(f'Need at least 2 good fits for a linear calibration, found {centers_good.size}; '
                     'loosen the slope window or the shot filters.')

def linear_fit(x, a, b):
    """Linear calibration model: LXT delay = a * edge_position + b."""
    return a*x + b

# Loose (slope, intercept) bounds. The y-axis below is LXT in seconds, so real values are
# tiny (~1e-15 s/pixel) and sit well inside these limits.
bounds = ([-5, -1000],
          [ 5,  1000])

pars, cov = curve_fit(linear_fit, centers_good, scan_good, bounds=bounds)
errs = np.sqrt(np.diag(cov))

print('Slope: ({0} ± {1})'.format(pars[0], errs[0]))
print('Intercept: ({0} ± {1})'.format(pars[1], errs[1]))

res_fit = linear_fit(centers_good, *pars)

# Draw the fit line over sorted edge positions so it is traced monotonically
_x_line = np.sort(centers_good)

plt.figure()
plt.scatter(centers_good, scan_good, alpha=0.2)
# Legend values are scaled by 1e15 (seconds -> femtoseconds)
plt.plot(_x_line, linear_fit(_x_line, *pars), '--', color='red',
         label=f'{pars[0]*1e15:.4f} fs/px * edge_position + {pars[1]*1e15:.4f} fs')
plt.title(f'Timetool calibration - Run: {runNumbersRange}')
plt.xlabel('Edge position (in pixel)')
plt.ylabel('LXT Stage Position (s)')
plt.legend()
plt.show()

## Optionally, save/append data into `config.yaml`

In [None]:
# Write the fitted slope/intercept (in the units of the LXT scan) via xrst.
# NOTE(review): run_range presumably selects which runs this calibration applies to;
# [0, '.inf'] looks like "all runs from 0 upward" -- confirm in add_calibration_to_yaml.
_calib_slope, _calib_intercept = pars[0], pars[1]
xrst.calib.timetool_calibration.add_calibration_to_yaml(
    run_range=[0, '.inf'],
    slope=_calib_slope,
    intercept=_calib_intercept,
)