In [None]:
import os
import uproot
import matplotlib.pyplot as plt
import numpy as np
import dataconfig  # to get paths to data

import scipp as sc
import ess.wfm as esswfm
import ess.choppers as essch
import plopp as pp
from typing import Union

plt.rcParams.update({'figure.max_open_warning': 0})

%matplotlib widget

# McStas

In [None]:
# Name of the 2D McStas monitor output used below
mcstas_2dfile = "monitor_tx_DENEX.dat"

# The McStas output folder comes from `dataconfig`; fail early if it is missing.
assert os.path.isdir(dataconfig.data_mcstas), (
    'The folder which should contain outputs of McStas simulation does not exist.'
)

path_to_mcstas2D_file = os.path.join(dataconfig.data_mcstas, mcstas_2dfile)

# ...and the chosen file itself must exist inside that folder.
assert os.path.isfile(path_to_mcstas2D_file), (
    'There is an issue with the chosen McStas 2D datafile'
)

## Extract shape of output data

In [None]:
# Parse the McStas header:
#  - the 2D array shape from the "array_2d(nx, ny)" entry,
#  - the axis limits from the "xylimits: ..." entry (used below as
#    [x_min, x_max, y_min, y_max]).
# Reset first so that, if an entry is missing, a value left over from a
# previous run (or from another file) cannot be picked up silently.
nx_value = ny_value = xylims = None
with open(path_to_mcstas2D_file, 'r') as file:
    for line in file:
        if "array_2d" in line:
            # Take the text between the parentheses, e.g. "(200, 100)" -> 200, 100
            type_array = line.rstrip()
            start = type_array.find('(') + 1
            end = type_array.find(')', start)
            nx_value, ny_value = map(int, type_array[start:end].split(','))
        if "xylimits" in line:
            xylims = np.array(line.split(':')[1].split()).astype(float)

if nx_value is None or xylims is None or len(xylims) != 4:
    raise ValueError(
        f"Could not find valid 'array_2d' and 'xylimits' entries in the header of {path_to_mcstas2D_file}")

print(f'Limits of x- and y-axis: {xylims}\nNumber of points: nx = {nx_value}, ny = {ny_value}')

## Read file and create scipp DataArray

In [None]:
# Read only the first `ny_value` numeric rows (genfromtxt skips '#' header lines by default).
# NOTE(review): assumes the intensity block is the first data block in the McStas file -- confirm.
data2d = np.genfromtxt(path_to_mcstas2D_file, max_rows=ny_value)

# flip data along y axis (reverse the row order)
data2d_mcstas = np.flip(data2d, 0)

# Bin-centred axes: each axis uses its OWN bin width and its OWN number of points,
# so the y axis length matches the number of rows read above (ny_value).
dx = (xylims[1] - xylims[0]) / float(nx_value)
dy = (xylims[3] - xylims[2]) / float(ny_value)
# The x (time) axis is scaled by 1e6 -- presumably seconds -> microseconds,
# matching the 'us' unit given to it below.
xaxis_mcstas = np.linspace(xylims[0] + 0.5 * dx, xylims[1] - 0.5 * dx, nx_value) * 1.0e6
yaxis_mcstas = np.linspace(xylims[2] + 0.5 * dy, xylims[3] - 0.5 * dy, ny_value)

In [None]:
# Point coordinates for both axes; `hist` then turns them into a histogram with
# the same number of bins as McStas points along each axis.
mcstas_coords = {
    'tof': sc.array(dims=['tof'], values=xaxis_mcstas, unit='us'),
    'x': sc.array(dims=['x'], values=yaxis_mcstas, unit='m'),
}
da_mcstas = (
    sc.DataArray(data=sc.array(dims=['x', 'tof'], values=data2d_mcstas), coords=mcstas_coords)
    .hist(x=ny_value)
    .hist(tof=nx_value)
)
da_mcstas

In [None]:
# 2D view of the McStas data (x vs tof)
fig_mcstas_2d = pp.plot(da_mcstas, grid=True)
fig_mcstas_2d

## Beamline setup for WFM stitching
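Create the V20 beamline description with information about the WFM and frame-overlap (FOC) choppers.  
Note that each chopper's phase has been added to its slit-opening angles, so the choppers themselves are created with a phase of 0.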

In [None]:
def make_v20_beamline() -> dict:
    """
    Build the V20 beamline description matching the McStas simulation.

    Returns a dict of scipp variables intended to be used as Dataset coordinates:

    * ``source_pulse_length``, ``source_pulse_t_0`` and ``source_position``;
    * four choppers created with ``essch.make_chopper``: ``chopper_wfm_1`` and
      ``chopper_wfm_2`` (kind ``'wfm'``), ``chopper_foc_1`` and ``chopper_foc_2``
      (kind ``'frame_overlap'``), each with one slit per frame along ``'frame'``;
    * ``position`` of the point detector.

    All z positions are given relative to the mean position of the two source
    choppers. The McStas chopper phases are folded into the slit centres (the
    ``*_offset`` values), so every chopper is created with a zero phase.

    :raises ValueError: if a chopper's slit lists do not have ``nframes`` entries.
    """
    dim = 'frame'
    nframes = 6
    position_source_chopper1 = 21.729
    position_source_chopper2 = 21.759
    avg_pos_source_choppers = 0.5 * (position_source_chopper1 + position_source_chopper2)

    def _cutout_angles(offset, mid_slits, widths):
        # Convert McStas slit centres (+ common offset) and slit widths, given in
        # degrees, into per-frame Variables in rad.
        if len(mid_slits) != nframes or len(widths) != nframes:
            raise ValueError("Expected {} slits, got {} centres and {} widths".format(
                nframes, len(mid_slits), len(widths)))
        centers = sc.array(dims=[dim],
                           values=np.deg2rad(offset + np.asarray(mid_slits, dtype=float)),
                           unit='rad')
        openings = sc.array(dims=[dim],
                            values=np.deg2rad(np.asarray(widths, dtype=float)),
                            unit='rad')
        return centers, openings

    beamline = {
        "source_pulse_length": sc.scalar(2.86e+03, unit='us'),
        "source_pulse_t_0": sc.to_unit(sc.scalar(1.3e-4, unit='s'), 'us'),
        "source_position": sc.vector(value=[0.0, 0.0, 0.0], unit='m')
    }

    # WFM1
    # opening of the slits in degrees
    mcstas_wfm1_theta0 = [10.9872, 15.2964, 19.3032, 23.0076, 26.46, 29.6856]
    # offset to the positions of the slits (includes the 17.1 deg chopper phase)
    mcstas_wfm1_offset = 14.8428 + 17.1
    # centres of the slits in degrees
    mcstas_wfm1_mid_slits = [89.208, 148.1382, 202.9104, 253.827, 301.14, 345.1392]

    # WFM2
    # opening of the slits in degrees
    mcstas_wfm2_theta0 = [10.9872, 15.2964, 19.3032, 19.3032, 23.0076, 29.6856]
    # offset to the positions of the slits (includes the 46.76 deg chopper phase)
    mcstas_wfm2_offset = 14.8428 + 46.76
    # centres of the slits in degrees
    mcstas_wfm2_mid_slits = [70.5348, 133.749, 192.528, 245.322, 296.2386, 345.1644]

    # FOC1
    # opening of the slits in degrees
    mcstas_foc1_theta0 = [20.64, 23.24, 21.81, 17.87, 15.76, 24.47]
    # offset to the positions of the slits (includes the 32.4 deg chopper phase)
    mcstas_foc1_offset = 12.235 + 32.4
    # centres of the slits in degrees
    mcstas_foc1_mid_slits = [74.67, 136.67, 194.315, 245.335, 294.92, 347.765]

    # FOC2
    # opening of the slits in degrees
    mcstas_foc2_theta0 = [36.6, 36.06, 30.21, 26.88, 24.56, 29.11]
    # offset to the positions of the slits (includes the 342.27 - 360 deg chopper phase)
    mcstas_foc2_offset = 14.555 + 342.27 - 360
    # centres of the slits in degrees
    mcstas_foc2_mid_slits = [98.06, 154.44, 206.835, 254.25, 299.41, 345.445]

    cutout_angles_center_wfm1, cutout_angles_width_wfm1 = _cutout_angles(
        mcstas_wfm1_offset, mcstas_wfm1_mid_slits, mcstas_wfm1_theta0)
    cutout_angles_center_wfm2, cutout_angles_width_wfm2 = _cutout_angles(
        mcstas_wfm2_offset, mcstas_wfm2_mid_slits, mcstas_wfm2_theta0)
    cutout_angles_center_foc1, cutout_angles_width_foc1 = _cutout_angles(
        mcstas_foc1_offset, mcstas_foc1_mid_slits, mcstas_foc1_theta0)
    cutout_angles_center_foc2, cutout_angles_width_foc2 = _cutout_angles(
        mcstas_foc2_offset, mcstas_foc2_mid_slits, mcstas_foc2_theta0)

    # All phases are float zeros: an integer phase would be rounded when converted
    # from deg to rad downstream if it were ever set to a non-zero value.
    beamline["chopper_wfm_1"] = sc.scalar(
        essch.make_chopper(
            frequency=sc.scalar(70.0, unit="Hz"),
            phase=sc.scalar(0., unit='deg'),  # phase (17.10 deg) folded into mcstas_wfm1_offset
            position=sc.vector(value=[0.0, 0.0, 28.594 - 0.292 * 0.5 - avg_pos_source_choppers], unit='m'),
            cutout_angles_center=cutout_angles_center_wfm1,
            cutout_angles_width=cutout_angles_width_wfm1,
            kind=sc.scalar('wfm')))

    beamline["chopper_wfm_2"] = sc.scalar(
        essch.make_chopper(
            frequency=sc.scalar(70.0, unit="Hz"),
            phase=sc.scalar(0., unit='deg'),  # phase (46.76 deg) folded into mcstas_wfm2_offset
            position=sc.vector(value=[0.0, 0.0, 28.594 + 0.292 * 0.5 - avg_pos_source_choppers], unit='m'),
            cutout_angles_center=cutout_angles_center_wfm2,
            cutout_angles_width=cutout_angles_width_wfm2,
            kind=sc.scalar('wfm')))

    beamline["chopper_foc_1"] = sc.scalar(
        essch.make_chopper(
            frequency=sc.scalar(56.0, unit=sc.units.one / sc.units.s),
            phase=sc.scalar(0., unit='deg'),  # phase (32.4 deg) folded into mcstas_foc1_offset
            position=sc.vector(value=[0.0, 0.0, 30.444 - avg_pos_source_choppers], unit='m'),
            cutout_angles_center=cutout_angles_center_foc1,
            cutout_angles_width=cutout_angles_width_foc1,
            kind=sc.scalar('frame_overlap')))

    beamline["chopper_foc_2"] = sc.scalar(
        essch.make_chopper(
            frequency=sc.scalar(28.0, unit=sc.units.one / sc.units.s),
            phase=sc.scalar(0., unit='deg'),  # phase (342.27 - 360 deg) folded into mcstas_foc2_offset
            position=sc.vector(value=[0.0, 0.0, 37.544 - avg_pos_source_choppers], unit='m'),
            cutout_angles_center=cutout_angles_center_foc2,
            cutout_angles_width=cutout_angles_width_foc2,
            kind=sc.scalar('frame_overlap')))

    # point detector -> TO DO add spatial structure of detector
    beamline['position'] = sc.vector(value=[0., 0., 50.55 + 0.945 - avg_pos_source_choppers], unit='m')

    return beamline

In [None]:
# Beamline description for the McStas simulation (dict of scipp variables)
v20_beamline = make_v20_beamline()

# Stored as the coordinates of an otherwise empty Dataset, which is how
# `get_v20_frames` reads it back through `data.meta[...]`.
v20_data = sc.Dataset(coords=v20_beamline)
v20_data

## Frames

In [None]:
def get_v20_frames(data: Union[sc.DataArray, sc.Dataset]) -> sc.Dataset:
    """
    Compute analytical frame boundaries and shifts based on chopper
    parameters and detector positions.

    Approach based on legacy method implemented for V20 Diffraction using dress.

    For every frame, each WFM / frame-overlap chopper contributes an opening and
    a closing time (slit angles plus phase, divided by the angular velocity).
    Two lines are then built in the time-distance plane:

    * the fastest line leaving the source at ``source_pulse_length`` that does
      not reach any chopper before it opens -> earliest detector arrival
      (``time_min``);
    * the slowest line leaving the source at 0 that does not reach any chopper
      after it closes -> latest detector arrival (``time_max``).

    NOTE(review): neutrons are assumed to leave the source between t = 0 and
    t = ``source_pulse_length``; ``source_pulse_t_0`` is not used here, whereas
    ``v20_time_distance_diagram`` draws the pulse starting at t_0 -- confirm
    this is intended (legacy behaviour).

    :param data: Dataset or DataArray whose coordinates (``data.meta``) hold the
        choppers, ``source_position``, ``source_pulse_length`` and the detector
        ``position``.
    :return: Dataset with ``time_min``, ``time_max`` and ``gaps`` (in us, along
        ``'frame'``), ``time_correction`` and ``wfm_chopper_mid_point``.
    :raises RuntimeError: if there are not exactly two WFM choppers, or if the
        choppers do not all have the same number of frames.
    """
    # Materialize once: the chopper keys are used for classification and for
    # the frame-count check.
    chopper_names = list(essch.find_chopper_keys(data))

    wfm_choppers = {}
    v20_choppers = {}
    for name in chopper_names:
        chopper = data.meta[name].value
        kind = chopper["kind"].value
        if kind == "wfm":
            wfm_choppers[name] = chopper
            v20_choppers[name] = chopper
        elif kind == "frame_overlap":
            v20_choppers[name] = chopper

    if len(wfm_choppers) != 2:
        raise RuntimeError("The number of WFM choppers is expected to be 2, "
                           "found {}".format(len(wfm_choppers)))

    # All choppers must agree on the number of frames (previously a mismatch
    # left `number_frames` undefined and failed later with a NameError).
    all_number_of_frames = [data.meta[name].value.sizes['frame'] for name in chopper_names]
    if len(set(all_number_of_frames)) != 1:
        raise RuntimeError("All choppers are expected to have the same number of "
                           "frames, found {}".format(all_number_of_frames))
    number_frames = all_number_of_frames[0]

    # Find the near and far WFM choppers based on their distance to the source
    source_position = data.meta["source_position"]
    wfm_chopper_names = list(wfm_choppers.keys())
    distance_0 = sc.norm(wfm_choppers[wfm_chopper_names[0]]["position"].data - source_position)
    distance_1 = sc.norm(wfm_choppers[wfm_chopper_names[1]]["position"].data - source_position)
    if (distance_0 < distance_1).value:
        near_name, far_name = wfm_chopper_names
    else:
        far_name, near_name = wfm_chopper_names
    near_wfm_chopper = wfm_choppers[near_name]
    far_wfm_chopper = wfm_choppers[far_name]

    # Source-to-detector distance along z (point detector)
    detector_positions = (data.meta["position"] - source_position).values[2]

    # Container for frames information
    frames = sc.Dataset()
    frames["wfm_chopper_mid_point"] = 0.5 * (near_wfm_chopper["position"].data +
                                             far_wfm_chopper["position"].data)

    pulse_length = data.meta["source_pulse_length"].value

    # Opening/closing times (us) of all slits do not depend on the frame loop
    # below, so compute them once per chopper.
    # For now, ignore Wavelength band double chopper
    chopper_windows = []
    for ch in v20_choppers.values():
        omega = (2.0 * np.pi * sc.units.rad) * ch['frequency']
        phase = sc.to_unit(ch['phase'], 'rad')
        t_open = sc.to_unit((ch['cutout_angles_center']
                             - 0.5 * ch['cutout_angles_width']
                             + phase) / omega, 'us').values
        t_close = sc.to_unit((ch['cutout_angles_center']
                              + 0.5 * ch['cutout_angles_width']
                              + phase) / omega, 'us').values
        chopper_windows.append((ch['position'].value[2], t_open, t_close))

    left_edges = []
    right_edges = []
    for i in range(number_frames):
        # Find the minimum and maximum slopes that are allowed through each frame
        slope_min = 1.e30
        slope_max = -1.e30
        for z, t_open, t_close in chopper_windows:
            xmin = t_open[i]
            xmax = t_close[i]
            slope1 = z / (xmin - pulse_length)
            slope2 = z / xmax
            if slope_min > slope1:
                x2, y2, slope_min = xmin, z, slope1
            if slope_max < slope2:
                x3, y3, slope_max = xmax, z, slope2

        # Line equations y = a*x + b
        a1 = y3 / x3
        a2 = y2 / (x2 - pulse_length)
        b2 = -a2 * pulse_length

        # Frame boundaries at the detector
        left_edges.append((detector_positions - b2) / a2)
        right_edges.append(detector_positions / a1)

    frames["time_min"] = sc.array(dims=['frame'], values=left_edges, unit='us')
    frames["time_max"] = sc.array(dims=['frame'], values=right_edges, unit='us')

    # Frame time corrections: mid-time point between the WFM choppers
    frames["time_correction"] = 0.25 * (
        essch.time_open(far_wfm_chopper) +
        essch.time_closed(far_wfm_chopper) +
        essch.time_open(near_wfm_chopper) +
        essch.time_closed(near_wfm_chopper))

    # Gaps between frames: midpoint between the end of frame i and the start of
    # frame i+1, with a leading 0 for the first frame.
    # TO DO: check where to put 0-value: at the beginning or at the end?
    time_min = frames["time_min"].values
    time_max = frames["time_max"].values
    gaps = np.concatenate([[0.0], 0.5 * (time_min[1:] + time_max[:-1])])
    frames["gaps"] = sc.array(dims=['frame'], values=gaps, unit=frames["time_max"].unit)
    return frames

In [None]:
# Analytical frame boundaries and time corrections for the McStas beamline
v20_frames = get_v20_frames(data=v20_data)
v20_frames

In [None]:
# Gap positions between consecutive frames
gap_values = v20_frames['gaps'].values
gap_values

In [None]:
# Check closing of WFM1 and opening of WFM2
wfm2_open_times = essch.time_open(v20_data.meta['chopper_wfm_2'].value).values
wfm1_close_times = essch.time_closed(v20_data.meta['chopper_wfm_1'].value).values
wfm2_open_times, wfm1_close_times

In [None]:
from matplotlib.patches import Rectangle

def v20_time_distance_diagram(data: Union[sc.DataArray, sc.Dataset], **kwargs) -> plt.Figure:
    """
    Plot the time-distance diagram for the V20 beamline.

    :param data: Dataset or DataArray whose coordinates (``data.meta``) hold the
        chopper cascade, the source pulse description (``source_pulse_t_0``,
        ``source_pulse_length``, ``source_position``) and the detector ``position``.
    :param kwargs: Forwarded unchanged to ``get_v20_frames``, which computes the
        frame properties used for stitching. NOTE(review): ``get_v20_frames`` as
        defined in this notebook takes no keyword arguments, so passing any here
        raises ``TypeError``.
    :return: The matplotlib Figure containing the diagram.
    """
    # Get the frame properties
    frames = get_v20_frames(data, **kwargs)

    source_pos = data.meta["source_position"]
    t_0 = data.meta["source_pulse_t_0"]
    pulse_length = data.meta["source_pulse_length"]

    # Find detector pixel furthest away from source
    furthest_detector_pos = sc.max(sc.norm(data.meta["position"] - source_pos)).value
    pulse_rectangle_height = furthest_detector_pos / 50.0
    tmax_glob = frames["time_max"].values[-1]

    # Create figure and axes
    fig, ax = plt.subplots(1, 1, figsize=(9, 7))
    ax.grid(True, color='lightgray', linestyle="dotted")
    ax.set_axisbelow(True)

    # Light grey rectangle from the origin to t_0 + pulse_length + t_0.
    # The second t_0 should be the end of the pulse tail, but that is not needed
    # for the frame properties and may be absent, so t_0 is reused instead.
    ax.add_patch(Rectangle((0, 0),
                           (2.0 * t_0 + pulse_length).value,
                           -pulse_rectangle_height,
                           lw=1,
                           fc='lightgrey',
                           ec='k',
                           zorder=10))

    # Dark grey rectangle from t_0 to t_0 + pulse_length: the usable pulse.
    ax.add_patch(Rectangle((t_0.value, 0),
                           pulse_length.value,
                           -pulse_rectangle_height,
                           lw=1,
                           fc='grey',
                           ec='k',
                           zorder=11))

    # Indicate source pulse and add the duration.
    ax.text(t_0.value,
            -pulse_rectangle_height,
            "Source pulse ({} {})".format(pulse_length.value, pulse_length.unit),
            ha="left",
            va="top",
            fontsize=6)

    # Plot the chopper openings as gaps in a horizontal line at each chopper's distance
    for name in essch.find_chopper_keys(data):
        chopper = data.meta[name].value
        yframe = sc.norm(chopper["position"].data - source_pos).value
        time_open = essch.time_open(chopper).values
        time_close = essch.time_closed(chopper).values
        tmin = 0.0
        for fnum in range(len(time_open)):
            tmax = time_open[fnum]
            ax.plot([tmin, tmax], [yframe] * 2, color='k')
            tmin = time_close[fnum]
        ax.plot([tmin, tmax_glob], [yframe] * 2, color='k')
        ax.text(2.0 * time_close[-1] - time_open[-1],
                yframe,
                name,
                ha="left",
                va="bottom")

    # Plot the shades of possible neutron paths for each frame: from the pulse
    # start/end to the frame's latest/earliest arrival time at the detector.
    pulse_start = t_0.value
    pulse_end = (t_0 + pulse_length).value
    frame_time_min = frames["time_min"].values
    frame_time_max = frames["time_max"].values
    for i in range(frames.sizes["frame"]):
        col = "C{}".format(i)
        ax.fill([pulse_start, pulse_end, frame_time_min[i], frame_time_max[i]],
                [0, 0, furthest_detector_pos, furthest_detector_pos],
                alpha=0.3,
                color=col)
        ax.plot([pulse_start, frame_time_max[i]],
                [0, furthest_detector_pos],
                color=col,
                lw=1)
        ax.plot([pulse_end, frame_time_min[i]],
                [0, furthest_detector_pos],
                color=col,
                lw=1)

    # Add thick solid line for the detector position, spanning the entire width
    ax.plot([0, tmax_glob], [furthest_detector_pos] * 2, lw=3, color='grey')
    ax.text(0.0, furthest_detector_pos, "Detector", va="bottom", ha="left")

    # Set axis labels
    ax.set_xlabel("Time [microseconds]")
    ax.set_ylabel("Distance [m]")

    return fig

In [None]:
# Time-distance diagram for the McStas beamline
f = v20_time_distance_diagram(data=v20_data)

In [None]:
# Zero-valued markers placed at the gap position of every frame, to be overlaid
# on the summed McStas spectrum. The number of markers follows the number of
# frames instead of being hard-coded.
gap_positions = v20_frames['gaps']
n_gap_markers = gap_positions.sizes['frame']
frames_to_plot = sc.DataArray(
    data=sc.array(dims=['tof'], values=[0] * n_gap_markers),
    coords={'tof': sc.array(dims=['tof'],
                            values=gap_positions.values,
                            unit=gap_positions.unit)})
frames_to_plot

In [None]:
# Summed spectrum with the frame gaps overlaid as vertical red markers
curves = {'ini': da_mcstas.sum('x'), 'frames': frames_to_plot}
p = pp.plot(curves,
            grid=True,
            color={'frames': 'red'},
            marker={'frames': '|'},
            markersize={'frames': 100})

# The legend adds nothing for the marker overlay: hide it.
p.canvas.ax.legend().set_visible(False)
p

In [None]:
# Visual check of the frame boundaries on the McStas data, before stitching
frames_before_mcstas = esswfm.plot.frames_before_stitching(da_mcstas, v20_frames, 'tof', bins_per_frame=60)
frames_before_mcstas

In [None]:
# Visual check of the McStas frames after stitching
frames_after_mcstas = esswfm.plot.frames_after_stitching(da_mcstas, v20_frames, 'tof', bins_per_frame=30)
frames_after_mcstas

## Stitching

In [None]:
# Stitch the McStas frames onto a single time-of-flight axis with 300 bins
stitched = esswfm.stitch(data=da_mcstas, frames=v20_frames, dim='tof', bins=300)
stitched

In [None]:
# 2D view of the stitched McStas data
fig_stitched = pp.plot(stitched, grid=True)
fig_stitched

In [None]:
# Compare the summed spectrum before ('ini') and after stitching
spectra_mcstas = {'ini': da_mcstas.sum('x'), 'stitched': stitched.sum('x')}
pp.plot(spectra_mcstas, linestyle='-', marker='.', grid=True)

# DENEX detector
Experimental data recorded with the DENEX detector and stored in ROOT files.

In [None]:
# ROOT file with the DENEX measurement; the folder comes from `dataconfig`.
assert os.path.isdir(dataconfig.data_root), \
    'The path to the folder which should contain ROOT files does not exist.'

# Alternative measurement: "Spectrum03_DENEX006_1_18-02-05_0000.root"
ROOT_file_sp3 = "Spectrum13_DENEX006_1_18-02-11_0000.root"

path_to_root_file = os.path.join(dataconfig.data_root, ROOT_file_sp3)

# Fail early (as for the McStas file) if the chosen ROOT file does not exist.
assert os.path.isfile(path_to_root_file), \
    'There is an issue with the chosen ROOT datafile'

# Some metadata related to TOF channel for ROOT file:
# duration of one TOF tick in microseconds (25 ns)
tof_tick = 25e-3

In [None]:
# Open the ROOT file and extract the board parameters needed to convert TOF
# channels to time, plus the single 2D dataset named in `selected_dataset`.
# Note the vertical axis of 2D datasets is inverted.

key_spectrum = 'Spectrum14'  # 'Spectrum03'  NOTE(review): not used in this cell
dir_with_data = 'Meas_1'  # 'Meas_3'
selected_dataset = 'H_TOF,X1-X2_User_2D4_dsp_after_run_1'  # 'H_TOF,X1-X2_User_2D2_dsp_after_run_3'

# Reset the outputs so that a value left over from a previous run (or another
# file) cannot be picked up silently if it is missing from this file.
TOF_Time_Channel_Width = None
TOF_Window_Delay_Register = None
data2d_root_sp3 = None

with uproot.open(path_to_root_file) as root_file:
    myFile = root_file[dir_with_data]

    for key in myFile.keys():
        if 'BoardParam_run' in str(key):
            myObject = myFile[key]
            # Read bin labels and contents once and walk them together. This does
            # not rely on 'fEntries' (number of fills) matching the number of bins.
            labels = myObject.axis(axis=0).labels()
            if labels is None:
                labels = []
            counts = myObject.counts(False)
            for label, count in zip(labels, counts):
                if label is None:
                    continue
                if 'TOF_Time_Channel_Width' in label:
                    TOF_Time_Channel_Width = count
                elif 'TOF_Window_Delay_Register' in label:
                    TOF_Window_Delay_Register = count

        # 2D contourplot
        if selected_dataset in str(key):
            myObject = myFile[key]
            # Rows <- histogram axis 1 (reversed), columns <- histogram axis 0
            data2d_root_sp3 = np.flip(myObject.counts(False), 1).transpose().astype(np.float64)
            # create x- and y-axis from the lower bin edges
            xaxis = myObject.axis(axis=0).edges()[:-1]
            yaxis = myObject.axis(axis=1).edges()[:-1]

missing = [name for name, value in (('TOF_Time_Channel_Width', TOF_Time_Channel_Width),
                                    ('TOF_Window_Delay_Register', TOF_Window_Delay_Register),
                                    (selected_dataset, data2d_root_sp3))
           if value is None]
if missing:
    raise RuntimeError(f"Could not find {missing} in '{dir_with_data}' of {path_to_root_file}")

In [None]:
# Convert TOF channels to microseconds:
#   time = (channel * TOF_Time_Channel_Width + TOF_Window_Delay_Register) ticks,
#   with one tick = `tof_tick` us (25 ns).
# NOTE(review): earlier notes quoted a channel width of 2773 ticks and a 6.25 ms
# delay; the values actually used are the board parameters read from the ROOT file.
# NOTE(review): this overwrites `xaxis` in place, so re-running this cell without
# re-running the ROOT-reading cell converts the axis twice.
xaxis = tof_tick * (TOF_Window_Delay_Register + TOF_Time_Channel_Width * xaxis)

In [None]:
# Point coordinates for both axes; `hist` turns them into a histogram with one
# bin per ROOT bin along each axis.
sp3_coords = {
    'tof': sc.array(dims=['tof'], values=xaxis, unit='us'),
    'x': sc.array(dims=['x'], values=yaxis, unit='m'),
}
da_sp3 = (
    sc.DataArray(data=sc.array(dims=['x', 'tof'], values=data2d_root_sp3), coords=sp3_coords)
    .hist(x=len(yaxis))
    .hist(tof=len(xaxis))
)
da_sp3

In [None]:
# 2D view of the measured DENEX data
fig_sp3 = pp.plot(da_sp3)
fig_sp3

## Beamline setup for WFM stitching (experimental data)

In [None]:
def make_v20_beamline_exp() -> dict:
    """
    Build the V20 beamline description matching the experimental (DENEX) setup.

    Uses the same slit angles, frequencies and folded-in phases as
    ``make_v20_beamline``; only the source-chopper, chopper and detector
    positions differ.

    Returns a dict of scipp variables intended to be used as Dataset coordinates:

    * ``source_pulse_length``, ``source_pulse_t_0`` and ``source_position``;
    * four choppers created with ``essch.make_chopper``: ``chopper_wfm_1`` and
      ``chopper_wfm_2`` (kind ``'wfm'``), ``chopper_foc_1`` and ``chopper_foc_2``
      (kind ``'frame_overlap'``), each with one slit per frame along ``'frame'``;
    * ``position`` of the point detector.

    All z positions are given relative to the mean position of the two source
    choppers. The chopper phases are folded into the slit centres (the
    ``*_offset`` values), so every chopper is created with a zero phase.

    :raises ValueError: if a chopper's slit lists do not have ``nframes`` entries.
    """
    dim = 'frame'
    nframes = 6
    position_source_chopper1 = 21.89
    position_source_chopper2 = 21.91
    avg_pos_source_choppers = 0.5 * (position_source_chopper1 + position_source_chopper2)

    def _cutout_angles(offset, mid_slits, widths):
        # Convert slit centres (+ common offset) and slit widths, given in
        # degrees, into per-frame Variables in rad.
        if len(mid_slits) != nframes or len(widths) != nframes:
            raise ValueError("Expected {} slits, got {} centres and {} widths".format(
                nframes, len(mid_slits), len(widths)))
        centers = sc.array(dims=[dim],
                           values=np.deg2rad(offset + np.asarray(mid_slits, dtype=float)),
                           unit='rad')
        openings = sc.array(dims=[dim],
                            values=np.deg2rad(np.asarray(widths, dtype=float)),
                            unit='rad')
        return centers, openings

    beamline = {
        "source_pulse_length": sc.scalar(2.86e+03, unit='us'),
        "source_pulse_t_0": sc.to_unit(sc.scalar(1.3e-4, unit='s'), 'us'),
        "source_position": sc.vector(value=[0.0, 0.0, 0.0], unit='m')
    }

    # WFM1
    # opening of the slits in degrees
    mcstas_wfm1_theta0 = [10.9872, 15.2964, 19.3032, 23.0076, 26.46, 29.6856]
    # offset to the positions of the slits (includes the 17.1 deg chopper phase)
    mcstas_wfm1_offset = 14.8428 + 17.1
    # centres of the slits in degrees
    mcstas_wfm1_mid_slits = [89.208, 148.1382, 202.9104, 253.827, 301.14, 345.1392]

    # WFM2
    # opening of the slits in degrees
    mcstas_wfm2_theta0 = [10.9872, 15.2964, 19.3032, 19.3032, 23.0076, 29.6856]
    # offset to the positions of the slits (includes the 46.76 deg chopper phase)
    mcstas_wfm2_offset = 14.8428 + 46.76
    # centres of the slits in degrees
    mcstas_wfm2_mid_slits = [70.5348, 133.749, 192.528, 245.322, 296.2386, 345.1644]

    # FOC1
    # opening of the slits in degrees
    mcstas_foc1_theta0 = [20.64, 23.24, 21.81, 17.87, 15.76, 24.47]
    # offset to the positions of the slits (includes the 32.4 deg chopper phase)
    mcstas_foc1_offset = 12.235 + 32.4
    # centres of the slits in degrees
    mcstas_foc1_mid_slits = [74.67, 136.67, 194.315, 245.335, 294.92, 347.765]

    # FOC2
    # opening of the slits in degrees
    mcstas_foc2_theta0 = [36.6, 36.06, 30.21, 26.88, 24.56, 29.11]
    # offset to the positions of the slits (includes the 342.27 - 360 deg chopper phase)
    mcstas_foc2_offset = 14.555 + 342.27 - 360
    # centres of the slits in degrees
    mcstas_foc2_mid_slits = [98.06, 154.44, 206.835, 254.25, 299.41, 345.445]

    cutout_angles_center_wfm1, cutout_angles_width_wfm1 = _cutout_angles(
        mcstas_wfm1_offset, mcstas_wfm1_mid_slits, mcstas_wfm1_theta0)
    cutout_angles_center_wfm2, cutout_angles_width_wfm2 = _cutout_angles(
        mcstas_wfm2_offset, mcstas_wfm2_mid_slits, mcstas_wfm2_theta0)
    cutout_angles_center_foc1, cutout_angles_width_foc1 = _cutout_angles(
        mcstas_foc1_offset, mcstas_foc1_mid_slits, mcstas_foc1_theta0)
    cutout_angles_center_foc2, cutout_angles_width_foc2 = _cutout_angles(
        mcstas_foc2_offset, mcstas_foc2_mid_slits, mcstas_foc2_theta0)

    # All phases are float zeros: an integer phase would be rounded when converted
    # from deg to rad downstream if it were ever set to a non-zero value.
    beamline["chopper_wfm_1"] = sc.scalar(
        essch.make_chopper(
            frequency=sc.scalar(70.0, unit="Hz"),
            phase=sc.scalar(0., unit='deg'),  # phase (17.10 deg) folded into mcstas_wfm1_offset
            position=sc.vector(value=[0.0, 0.0, 28.57 - avg_pos_source_choppers], unit='m'),
            cutout_angles_center=cutout_angles_center_wfm1,
            cutout_angles_width=cutout_angles_width_wfm1,
            kind=sc.scalar('wfm')))

    beamline["chopper_wfm_2"] = sc.scalar(
        essch.make_chopper(
            frequency=sc.scalar(70.0, unit="Hz"),
            phase=sc.scalar(0., unit='deg'),  # phase (46.76 deg) folded into mcstas_wfm2_offset
            position=sc.vector(value=[0.0, 0.0, 28.91 - avg_pos_source_choppers], unit='m'),
            cutout_angles_center=cutout_angles_center_wfm2,
            cutout_angles_width=cutout_angles_width_wfm2,
            kind=sc.scalar('wfm')))

    beamline["chopper_foc_1"] = sc.scalar(
        essch.make_chopper(
            frequency=sc.scalar(56.0, unit=sc.units.one / sc.units.s),
            phase=sc.scalar(0., unit='deg'),  # phase (32.4 deg) folded into mcstas_foc1_offset
            position=sc.vector(value=[0.0, 0.0, 30.5 - avg_pos_source_choppers], unit='m'),
            cutout_angles_center=cutout_angles_center_foc1,
            cutout_angles_width=cutout_angles_width_foc1,
            kind=sc.scalar('frame_overlap')))

    beamline["chopper_foc_2"] = sc.scalar(
        essch.make_chopper(
            frequency=sc.scalar(28.0, unit=sc.units.one / sc.units.s),
            phase=sc.scalar(0., unit='deg'),  # phase (342.27 - 360 deg) folded into mcstas_foc2_offset
            position=sc.vector(value=[0.0, 0.0, 37.6 - avg_pos_source_choppers], unit='m'),
            cutout_angles_center=cutout_angles_center_foc2,
            cutout_angles_width=cutout_angles_width_foc2,
            kind=sc.scalar('frame_overlap')))

    # point detector
    beamline['position'] = sc.vector(value=[0., 0., 50.55 + 0.945 - avg_pos_source_choppers], unit='m')

    return beamline

In [None]:
# Beamline description for the experimental setup (dict of scipp variables)
v20_beamline_exp = make_v20_beamline_exp()

# Stored as the coordinates of an otherwise empty Dataset, read back by
# `get_v20_frames` through `data.meta[...]`.
v20_data_exp = sc.Dataset(coords=v20_beamline_exp)
v20_data_exp

## Frames (experimental data)

In [None]:
# Analytical frame boundaries and time corrections for the experimental setup
v20_frames_exp = get_v20_frames(data=v20_data_exp)
v20_frames_exp

In [None]:
# Time-distance diagram for the experimental setup
f_exp = v20_time_distance_diagram(data=v20_data_exp)

In [None]:
# Work on a copy so that the analytically computed frames stay untouched.
v20_frames_exp_modif = v20_frames_exp.copy()
print(v20_frames_exp_modif['gaps'].values, v20_frames_exp['gaps'].values)

# Hand-set gap positions (same unit as 'time_max', i.e. us).
# NOTE(review): the last three values look manually tuned to the measured
# spectrum -- document where they come from.
manual_gaps = [0., 21830.98517558, 31647.65893236, 40440, 48900, 57120]
v20_frames_exp_modif['gaps'].values = manual_gaps
v20_frames_exp_modif['gaps'].values, v20_frames_exp['gaps'].values

In [None]:
def _tof_markers(per_frame: sc.DataArray) -> sc.DataArray:
    """
    Zero-valued markers on a 'tof' coordinate, one per frame.

    :param per_frame: 1-D per-frame quantity (e.g. ``frames['gaps']``); its values
        and unit become the 'tof' coordinate of the markers.
    :return: DataArray with zero data, used to overlay frame positions on a spectrum.
    """
    n_markers = len(per_frame.values)
    return sc.DataArray(
        data=sc.array(dims=['tof'], values=[0] * n_markers),
        coords={'tof': sc.array(dims=['tof'],
                                values=per_frame.values,
                                unit=per_frame.unit)})


# One set of markers per frame quantity; the number of markers follows the
# number of frames instead of being hard-coded to 6.
frames_exp_to_plot = _tof_markers(v20_frames_exp['gaps'])
time_min_exp_to_plot = _tof_markers(v20_frames_exp['time_min'])
time_max_exp_to_plot = _tof_markers(v20_frames_exp['time_max'])
frames_exp_to_plot

In [None]:
# Summed measured spectrum with the frame gaps overlaid as vertical red markers
curves_exp = {'ini': da_sp3.sum('x'), 'frames': frames_exp_to_plot}
p = pp.plot(curves_exp,
            grid=True,
            color={'frames': 'red'},
            marker={'frames': '|'},
            markersize={'frames': 100})

# The legend adds nothing for the marker overlay: hide it.
p.canvas.ax.legend().set_visible(False)
p

In [None]:
# Summed measured spectrum with gaps (red), frame starts (green) and frame ends (blue)
marker_curves = {
    'frames': frames_exp_to_plot,
    'time_min': time_min_exp_to_plot,
    'time_max': time_max_exp_to_plot,
}
p = pp.plot({'ini': da_sp3.sum('x'), **marker_curves},
            grid=True,
            color={'frames': 'red', 'time_min': 'green', 'time_max': 'blue'},
            marker={key: '|' for key in marker_curves},
            markersize={'frames': 100, 'time_min': 90, 'time_max': 80})

# The legend adds nothing for the marker overlay: hide it.
p.canvas.ax.legend().set_visible(False)
p

In [None]:
# Visual check of the (hand-adjusted) frame boundaries on the measured data, before stitching
frames_before_exp = esswfm.plot.frames_before_stitching(da_sp3, v20_frames_exp_modif, 'tof', bins_per_frame=60)
frames_before_exp

In [None]:
# Visual check of the measured frames after stitching
frames_after_exp = esswfm.plot.frames_after_stitching(da_sp3, v20_frames_exp_modif, 'tof', bins_per_frame=30)
frames_after_exp

## Stitching (experimental data)

In [None]:
def _stitch_dense_data(item: sc.DataArray, frames: sc.Dataset, dim: str, new_dim: str,
                       bins: Union[int, sc.Variable]) -> Union[sc.DataArray, dict]:
    """
    Stitch the frames of a dense DataArray onto one common time axis.

    Each frame ``i`` is cut out of ``item`` along ``dim`` between
    ``frames["time_min"]`` and ``frames["time_max"]`` (slicing by coordinate
    value), shifted by ``frames["time_correction"]``, rebinned onto the output
    ``new_dim`` bin edges, and accumulated into the result.

    :param item: Dense data with a bin-edge coordinate along ``dim``.
    :param frames: Frame description with ``time_min``, ``time_max`` and
        ``time_correction`` along ``'frame'``.
    :param dim: Dimension holding the raw arrival times.
    :param new_dim: Name of the stitched dimension.
    :param bins: Number of output bins (edges spread linearly from the corrected
        start of the first frame to the corrected end of the last frame), or a
        Variable holding the output bin edges.
    """
    if isinstance(bins, int):
        first_start = frames["time_min"]["frame", 0] - frames["time_correction"]["frame", 0]
        last_stop = frames["time_max"]["frame", -1] - frames["time_correction"]["frame", -1]
        new_coord = sc.linspace(
            dim=new_dim,
            start=first_start.value,
            stop=last_stop.value,
            num=bins + 1,
            unit=frames["time_min"].unit,
        )
    else:
        new_coord = bins

    # Same dims as the input, with `dim` replaced by `new_dim` (sized by the new edges).
    out_dims = [new_dim if d == dim else d for d in item.dims]
    out_shape = [new_coord.sizes[new_dim] - 1 if d == dim else item.sizes[d] for d in item.dims]

    out = sc.DataArray(data=sc.zeros(dims=out_dims,
                                     shape=out_shape,
                                     with_variances=item.variances is not None,
                                     unit=item.unit),
                       coords={new_dim: new_coord})

    # Carry over every coordinate/attribute that is not the stitched one.
    for group in ("coords", "attrs"):
        source = getattr(item, group)
        target = getattr(out, group)
        for key in source:
            if key != dim:
                target[key] = source[key].copy()

    for i in range(frames.sizes["frame"]):
        t_low = frames["time_min"].data["frame", i]
        t_high = frames["time_max"].data["frame", i]
        shift = frames["time_correction"].data["frame", i]

        section = item[dim, t_low:t_high].rename_dims({dim: new_dim})
        section.coords[new_dim] = section.coords[dim] - shift
        if new_dim != dim:
            del section.coords[dim]

        out += section.rebin({new_dim: out.coords[new_dim]})

    return out

def stitch(
        data: Union[sc.DataArray, sc.Dataset],
        dim: str,
        frames: sc.Dataset,
        new_dim: str = 'tof',
        bins: Union[int, sc.Variable] = 256) -> Union[sc.DataArray, sc.Dataset, dict]:
    """
    Convert raw arrival time WFM data to time-of-flight by shifting each frame
    (described by the `frames` argument) by a time offset defined by the position
    of the WFM choppers.
    This process is also known as 'stitching' the frames.

    NOTE(review): the cells below call the packaged ``esswfm.stitch``, not this
    local version. Only a DataArray is handled by ``_stitch_dense_data``.

    :param data: The DataArray or Dataset to be stitched.
    :param dim: The dimension along which the stitching will be performed.
    :param frames: The Dataset containing the information on the frame boundaries.
    :param new_dim: New dimension of the returned data, after stitching. Default: 'tof'.
    :param bins: Number or Variable describing the bins for the returned data. Default:
        256.
    """
    # TODO: for now, if frames depend on positions, we take the mean along the
    # position dimensions. We should implement the position-dependent stitching
    # in the future.
    frames = frames.copy()
    for extra_dim in list(set(frames.dims) - {'frame'}):
        for field in ("time_min", "time_max"):
            frames[field] = sc.mean(frames[field], extra_dim)

    return _stitch_dense_data(item=data,
                              dim=dim,
                              frames=frames,
                              new_dim=new_dim,
                              bins=bins)

In [None]:
# Stitch the measured frames onto a single time-of-flight axis with 200 bins
stitched_sp3 = esswfm.stitch(data=da_sp3, frames=v20_frames_exp, dim='tof', bins=200)
stitched_sp3

In [None]:
# 2D view of the stitched measured data
fig_stitched_sp3 = pp.plot(stitched_sp3)
fig_stitched_sp3

In [None]:
# Compare the summed measured spectrum before ('ini') and after stitching
spectra_sp3 = {'ini': da_sp3.sum('x'), 'stitched': stitched_sp3.sum('x')}
pp.plot(spectra_sp3, linestyle='-', marker='.', grid=True)

In [None]:
# Stitched measured spectrum summed over x
stitched_sp3_spectrum = stitched_sp3.sum('x')
pp.plot(stitched_sp3_spectrum, linestyle='-', marker='.', grid=True)