# Reducing WFM data

This notebook illustrates how to work with `wfm`, the wavelength frame multiplication (WFM) submodule.
We will create a beamline that resembles the beamline of the ODIN instrument,
generate some fake neutron data,
and then show how to convert the neutron arrival times at the detector to neutron time-of-flight,
from which a wavelength can then be computed (a process commonly known as 'stitching').

In [None]:
import numpy as np
import matplotlib.pyplot as plt
plt.ioff() # Turn off auto-showing of figures
import scipp as sc
import scippneutron as scn
import ess.wfm as wfm
np.random.seed(1) # Fixed for reproducibility

## Create beamline components

We first create all the components necessary for a beamline to run in WFM mode
(see [Introduction to WFM](introduction-to-wfm.ipynb) for the meanings of the different symbols).
The beamline will contain

- a neutron source, located at the origin ($x = y = z =  0$)
- a pulse with a defined length ($2860 ~\mu s$) and $t_0$ ($130 ~\mu s$)
- a single pixel detector, located at $z = 60$ m
- two WFM choppers, located at $z = 6.775$ m and $z = 7.225$ m, each with 6 frame windows/openings

The `wfm` module provides a helper function to quickly create such a beamline.
It returns a `dict` of coordinates, which can subsequently be added to a data container.

In [None]:
coords = wfm.make_fake_beamline(nframes=6)
coords

## Generate some fake data

Next, we generate some fake data that mimics a spectrum with a Bragg edge located at $4\unicode{x212B}$.
We start by defining a function that will act as our underlying distribution:

In [None]:
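# Wavelength grid (in angstroms) on which the distribution is defined
x = np.linspace(1.0, 10.0, 100000)
a = 20.0  # steepness of the Bragg edge
b = 4.0  # position of the Bragg edge (angstroms)
# Logistic step modelling the Bragg edge
y1 = 0.7 / (np.exp(-a * (x - b)) + 1.0)
# Linearly decreasing underlying spectrum
y2 = 1.4 - 0.2 * x
y = y1 + y2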
fig1, ax1 = plt.subplots()
ax1.plot(x, y)
ax1.set_xlabel("Wavelength [angstroms]")
fig1
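Written out, the distribution plotted above is a logistic step (the Bragg edge) on top of a linearly decreasing spectrum, with $\lambda$ expressed in $\unicode{x212B}$:

$$
y(\lambda) = \frac{0.7}{1 + e^{-a(\lambda - b)}} + 1.4 - 0.2\,\lambda, \qquad a = 20, \quad b = 4.
$$

The step has a height of 0.7 and is centered on $b$; a larger $a$ makes the edge sharper, its width being of order $1/a$.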

We then proceed to generate two sets of 100,000 events:
- one for the `sample`, using the distribution defined above
- one for the `background`, which is simply a flat (uniform) random distribution

For the events in both `sample` and `background`,
we define a wavelength for the neutrons as well as a birth time,
which is a random time between the pulse $t_0$ and the end of the usable pulse, $t_0$ plus the pulse length.

In [None]:
nevents = 100_000
events = {
    "sample": {
        "wavelengths": sc.array(
            dims=["wavelength"],
            values=np.random.choice(x, size=nevents, p=y/np.sum(y)),
            unit="angstrom"),
        "birth_times": sc.array(
            dims=["wavelength"],
            values=np.random.random(nevents) * coords["source_pulse_length"].value,
            unit="us") + coords["source_pulse_t_0"]
    },
    "background": {
        "wavelengths": sc.array(
            dims=["wavelength"],
            values=np.random.random(nevents) * 9.0 + 1.0,
            unit="angstrom"),
        "birth_times": sc.array(
            dims=["wavelength"],
            values=np.random.random(nevents) * coords["source_pulse_length"].value,
            unit="us") + coords["source_pulse_t_0"]
    }
}

We can then take a quick look at our fake data by histogramming the events

In [None]:
# Histogram and plot the event data
bins = np.linspace(1.0, 10.0, 129)
fig2, ax2 = plt.subplots()
for key in events:
    ax2.hist(events[key]["wavelengths"].values, bins=bins, alpha=0.5, label=key)
ax2.set_xlabel("Wavelength [angstroms]")
ax2.set_ylabel("Counts")
ax2.legend()
fig2

We can also verify that the birth times fall within the expected range:

In [None]:
for key, item in events.items():
    print(key)
    print(sc.min(item["birth_times"]))
    print(sc.max(item["birth_times"]))

We can then compute the arrival times of the events at the detector pixel

In [None]:
# The ratio of neutron mass to the Planck constant, in us / (angstrom * m)
alpha = 2.5278e+2 * (sc.Unit('us') / sc.Unit('angstrom') / sc.Unit('m'))
# The distance between the source and the detector
dz = sc.norm(coords['position'] - coords['source_position'])
for key, item in events.items():
    item["arrival_times"] = alpha * dz * item["wavelengths"] + item["birth_times"]
events["sample"]["arrival_times"]
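The cell above computes $t_{\mathrm{arrival}} = t_{\mathrm{birth}} + \alpha L \lambda$, where $L$ is the source-detector distance and $\alpha = m_{\mathrm{n}}/h$. For quick back-of-the-envelope checks, the same relation and its inverse can be written in plain NumPy. This is only a sketch; the function names are illustrative and not part of `ess.wfm`:

```python
import numpy as np

# Ratio of neutron mass to Planck's constant, in microseconds / (angstrom * metre)
ALPHA = 252.78

def flight_time(wavelength, distance):
    """Time (us) a neutron of the given wavelength (angstrom) needs to travel `distance` (m)."""
    return ALPHA * distance * np.asarray(wavelength)

def wavelength_from_time(time_of_flight, distance):
    """Inverse relation: wavelength (angstrom) from a time-of-flight (us) over `distance` (m)."""
    return np.asarray(time_of_flight) / (ALPHA * distance)

# A 4 angstrom neutron needs about 60.7 ms to cover the 60 m source-detector distance
t = flight_time(4.0, 60.0)
print(f"{float(t):.1f} us")  # 60667.2 us
```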

## Visualize the beamline's chopper cascade

We first attach the beamline geometry to a Dataset

In [None]:
ds = sc.Dataset(coords=coords)
ds

The `wfm.plot` submodule provides a useful tool to visualize the chopper cascade as a time-distance diagram.
This is achieved by calling:

In [None]:
wfm.plot.time_distance_diagram(ds)

This shows the 6 frames generated by the WFM choppers,
as well as their predicted time boundaries at the position of the detector.

Each frame has a time window during which neutrons are allowed to pass through,
as well as minimum and maximum allowed wavelengths.

This information is obtained from the beamline geometry by calling

In [None]:
frames = wfm.get_frames(ds)
frames
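Why does each frame come with a wavelength band? A neutron that passes the near chopper at $t_1$ and the far chopper at $t_2$ has $\lambda = (t_2 - t_1)/(\alpha\,\Delta z)$, where $\Delta z$ is the distance between the choppers, so the two openings of a frame bound the wavelengths that can get through. The sketch below evaluates this two-chopper constraint on its own. The opening times are made-up values, the function name is hypothetical, and the calculation ignores the source pulse, so it is not necessarily how `wfm.get_frames` derives its limits:

```python
ALPHA = 252.78  # neutron mass / Planck constant, in us / (angstrom * m)

def frame_wavelength_band(near_open, near_close, far_open, far_close, chopper_gap):
    """Wavelength range (angstrom) that can pass both openings of one frame.

    Times are in microseconds, `chopper_gap` (distance between the choppers) in metres.
    The fastest neutron passes the near chopper just before it closes and reaches the
    far chopper just as it opens; the slowest does the opposite.
    """
    lam_min = max(far_open - near_close, 0.0) / (ALPHA * chopper_gap)
    lam_max = max(far_close - near_open, 0.0) / (ALPHA * chopper_gap)
    return lam_min, lam_max

# Hypothetical opening times for a single frame; choppers at 6.775 m and 7.225 m
lam_min, lam_max = frame_wavelength_band(2000.0, 2200.0, 2400.0, 2600.0, 7.225 - 6.775)
print(f"{lam_min:.2f} to {lam_max:.2f} angstrom")
```

Because the chopper gap is small (0.45 m), even a few hundred microseconds of opening time translate into a band several angstroms wide.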

## Discard neutrons that do not make it through the chopper windows

Once we have the parameters of the 6 wavelength frames,
we need to run through all our generated neutrons and filter out all the neutrons with invalid flight paths,
i.e. the ones that do not make it through both chopper openings in a given frame.

In [None]:
events["sample"]["valid_indices"] = []
events["background"]["valid_indices"] = []
near_wfm_chopper = ds.coords["choppers"].value["WFMC1"]
far_wfm_chopper = ds.coords["choppers"].value["WFMC2"]
near_time_open = near_wfm_chopper.time_open
near_time_close = near_wfm_chopper.time_close
far_time_open = far_wfm_chopper.time_open
far_time_close = far_wfm_chopper.time_close

for item in events.values():
    # Compute event arrival times at wfm choppers 1 and 2
    slopes = 1.0 / (alpha * item["wavelengths"])
    intercepts = -slopes * item["birth_times"]
    times_at_wfm1 = (sc.norm(near_wfm_chopper.position) - intercepts) / slopes
    times_at_wfm2 = (sc.norm(far_wfm_chopper.position) - intercepts) / slopes
    # Create a mask to see if neutrons go through both openings of one of the frames
    mask = sc.zeros(dims=times_at_wfm1.dims, shape=times_at_wfm1.shape, dtype=bool)
    for i in range(len(frames["time_min"])):
        mask |= ((times_at_wfm1 >= near_time_open["frame", i]) &
                 (times_at_wfm1 <= near_time_close["frame", i]) &
                 (times_at_wfm2 >= far_time_open["frame", i]) &
                 (times_at_wfm2 <= far_time_close["frame", i]) &
                 (item["wavelengths"] >= frames["wavelength_min"]["frame", i]).data &
                 (item["wavelengths"] <= frames["wavelength_max"]["frame", i]).data)
    item["valid_indices"] = np.argwhere(mask.values)
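For reference, the chopper selection can also be written with plain NumPy broadcasting instead of a Python loop over frames, checking the openings of both choppers. This is a sketch: the function name is hypothetical, the source is assumed to sit at $z = 0$, and the example arrays and opening times are made up rather than taken from the beamline above:

```python
import numpy as np

ALPHA = 252.78  # neutron mass / Planck constant, in us / (angstrom * m)

def passes_choppers(wavelengths, birth_times, z_near, z_far,
                    near_open, near_close, far_open, far_close):
    """Boolean mask of neutrons that pass both choppers within the same frame.

    `wavelengths` (angstrom) and `birth_times` (us) have shape (nevents,);
    the opening/closing times (us) have shape (nframes,). Source at z = 0.
    """
    t_near = birth_times + ALPHA * z_near * wavelengths  # arrival at near chopper
    t_far = birth_times + ALPHA * z_far * wavelengths    # arrival at far chopper
    # Broadcast events (rows) against frames (columns): shape (nevents, nframes)
    in_near = (t_near[:, None] >= near_open) & (t_near[:, None] <= near_close)
    in_far = (t_far[:, None] >= far_open) & (t_far[:, None] <= far_close)
    # Keep a neutron if it passes both openings of at least one frame
    return np.any(in_near & in_far, axis=1)

wavelengths = np.array([1.0, 1.0, 1.5, 2.0])
birth_times = np.array([0.0, 5.0, 0.0, 0.0])
mask = passes_choppers(wavelengths, birth_times, 6.775, 7.225,
                       near_open=np.array([1700.0, 3400.0]),
                       near_close=np.array([1720.0, 3450.0]),
                       far_open=np.array([1820.0, 3640.0]),
                       far_close=np.array([1830.0, 3660.0]))
print(mask)  # [ True False False  True]
```

The second event passes the near chopper but reaches the far chopper after it has closed, which is why checking only one chopper is not enough.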

## Create a realistic Dataset

We now create a dataset that contains:
- the beamline geometry
- the time coordinate
- the histogrammed events

In [None]:
for item in events.values():
    item["valid_times"] = item["arrival_times"].values[item["valid_indices"]]

tmin = min([item["valid_times"].min() for item in events.values()])
tmax = max([item["valid_times"].max() for item in events.values()])

dt = 0.1 * (tmax - tmin)  # pad the time axis by 10% of the range on each side
time_coord = sc.linspace(dim='time',
                         start=tmin - dt,
                         stop=tmax + dt,
                         num=513,
                         unit=events["sample"]["arrival_times"].unit)
# Histogram the data
for key, item in events.items():
    item["counts"], _ = np.histogram(item["valid_times"], bins=time_coord.values)
    ds[key] = sc.array(dims=['time'], values=item["counts"], unit='counts')
# Add the time coordinate
ds.coords["time"] = time_coord
ds

In [None]:
ds.plot()

## Stitch the frames

Wavelength frame multiplication creates 6 new pulses (frames) from the original source pulse.
This implies that the WFM choppers act as a source chopper.
Hence, to compute a wavelength from a time and a distance between source and detector,
the location of the source must now be at the position of the WFM choppers,
or more exactly at the mid-point between the two WFM choppers.
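As a worked example for the beamline created above (WFM choppers at 6.775 m and 7.225 m, detector at 60 m), the effective source position, the remaining flight path, and the conversion from time-of-flight to wavelength are

$$
z_{\mathrm{WFM}} = \frac{6.775~\mathrm{m} + 7.225~\mathrm{m}}{2} = 7.0~\mathrm{m},
\qquad
L = 60~\mathrm{m} - z_{\mathrm{WFM}} = 53~\mathrm{m},
\qquad
\lambda = \frac{t_{\mathrm{tof}}}{\alpha L},
$$

where $t_{\mathrm{tof}}$ is counted from the moment a frame passes $z_{\mathrm{WFM}}$, and $\alpha = m_{\mathrm{n}}/h \approx 252.78~\mu s\,\unicode{x212B}^{-1}\,\mathrm{m}^{-1}$ is the same constant used above to compute the arrival times.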

The stitching operation amounts to converting the `time` dimension to time-of-flight (`tof`)
by subtracting from each frame a time shift corresponding to that frame's effective emission time at the mid-point between the two WFM choppers.
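Conceptually, stitching cuts the time axis into the frame windows and shifts each window by its own offset. A minimal NumPy sketch on individual event times follows; the frame boundaries (`time_min`, `time_max`) and per-frame shifts (`time_shift`) are hypothetical inputs, and unlike `wfm.stitch`, which operates on the histogrammed data, this sketch does not rebin anything:

```python
import numpy as np

def stitch_event_times(arrival_times, time_min, time_max, time_shift):
    """Convert detector arrival times (us) into time-of-flight (us).

    Frame i is the window [time_min[i], time_max[i]); events inside it are shifted
    by time_shift[i]. Events outside every window are dropped, and the result is
    grouped frame by frame rather than kept in input order.
    """
    arrival_times = np.asarray(arrival_times, dtype=float)
    tof = []
    for t_lo, t_hi, shift in zip(time_min, time_max, time_shift):
        in_frame = (arrival_times >= t_lo) & (arrival_times < t_hi)
        tof.append(arrival_times[in_frame] - shift)
    return np.concatenate(tof)

# Two hypothetical frames: [10, 20) shifted by 5 us and [20, 40) shifted by 12 us
tof = stitch_event_times([12.0, 25.0, 45.0],
                         time_min=[10.0, 20.0], time_max=[20.0, 40.0],
                         time_shift=[5.0, 12.0])
print(tof)  # [ 7. 13.]
```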

This is performed with the `stitch` function in the `wfm` module:

In [None]:
stitched = wfm.stitch(frames=frames,
                      data=ds,
                      dim='time',
                      bins=513)
stitched

In [None]:
stitched.plot()

By default, the `stitch` function returns a single object,
where all frames have been combined onto a common axis.
It is however possible to return the individual frames separately, using the `merge_frames=False` argument.
This makes it possible to visualize the different frames, for diagnostic purposes:

In [None]:
stitched_frames = wfm.stitch(frames=frames,
                             data=ds,
                             dim='time',
                             merge_frames=False)
# In the case of stitching a Dataset,
# stitch returns a dict of frames for each entry in the Dataset
sc.plot(stitched_frames["sample"])

## Convert to wavelength

Now that the data coordinate is time-of-flight (`tof`),
we can use `scippneutron` to perform the unit conversion from `tof` to `wavelength`.

In [None]:
converted = scn.convert(stitched, origin='tof', target='wavelength', scatter=False)
converted.plot()

## Normalization

Normalization is performed simply by dividing the counts of the `sample` run by the counts of the `background` run.

In [None]:
normalized = converted['sample'] / converted['background']
normalized.plot()
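Any bin in which the `background` has zero counts would give an infinite or undefined ratio. One way to guard against this, sketched here with plain NumPy on made-up count arrays (the notebook itself performs the plain division, and `safe_ratio` is a hypothetical helper), is to mark such bins as NaN explicitly:

```python
import numpy as np

def safe_ratio(sample_counts, background_counts):
    """Sample / background per bin, with NaN wherever the background is empty."""
    sample_counts = np.asarray(sample_counts, dtype=float)
    background_counts = np.asarray(background_counts, dtype=float)
    ratio = np.full_like(sample_counts, np.nan)
    # Only divide where the background is non-zero; other bins keep their NaN
    np.divide(sample_counts, background_counts, out=ratio, where=background_counts > 0)
    return ratio

ratio = safe_ratio([10, 4, 3], [5, 0, 6])
print(ratio)
```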

## Comparing to the raw wavelengths

The final step is a sanity check to verify that the wavelength-dependent data obtained from the stitching process
agrees (to within the beamline resolution) with the original wavelength distribution that was generated at
the start of the workflow.

For this, we simply histogram the raw neutron events using the same bins as the `normalized` data,
filtering out the neutrons with invalid flight paths.

In [None]:
for item in events.values():
    item["wavelength_counts"], _ = np.histogram(
        item["wavelengths"].values[item["valid_indices"]],
        bins=normalized.coords['wavelength'].values)

We then normalize the `sample` by the `background` run,
and plot the resulting spectrum alongside the one obtained from the stitching.

In [None]:
original = sc.DataArray(
    data=sc.array(dims=['wavelength'],
                  values=events["sample"]["wavelength_counts"] /
                         events["background"]["wavelength_counts"]),
    coords = {"wavelength": normalized.coords['wavelength']})

sc.plot({"stitched": normalized, "original": original})

We can see that the `stitched` spectrum agrees very well with the original one.
The `stitched` result shows some smoothing,
which is expected given the resolution limitations imposed on the beamline by its long source pulse.
This smoothing (or smearing) would, however, be much stronger if WFM choppers were not used.
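A rough estimate makes the last point concrete. If the emission time were only known to within the full source pulse, the wavelength uncertainty at the detector would be $\Delta\lambda \approx \Delta t/(\alpha L)$. The sketch below evaluates this for the pulse length and source-detector distance of the fake beamline; it deliberately ignores every other contribution to the resolution:

```python
ALPHA = 252.78  # neutron mass / Planck constant, in us / (angstrom * m)

pulse_length = 2860.0  # us, source pulse length of the fake beamline
distance = 60.0        # m, source to detector

# Wavelength uncertainty if the emission time is only known to within one pulse length
delta_lambda = pulse_length / (ALPHA * distance)
print(f"delta lambda = {delta_lambda:.3f} angstrom")  # about 0.189 angstrom
for wavelength in (2.0, 4.0, 8.0):
    print(f"{wavelength} angstrom: {100 * delta_lambda / wavelength:.1f} %")
```

At the $4\unicode{x212B}$ Bragg edge, this corresponds to a relative uncertainty of roughly 5%.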