In [None]:
%matplotlib widget

# Coordinate transformations

We often need to compute new coordinates from other coordinates, for example, the wavelength from the time-of-flight or the scattering angle from positions.
Scipp has a flexible utility for this purpose: [scipp.transform_coords](https://scipp.github.io/generated/functions/scipp.transform_coords.html).

## Setup

Consider a beamline with straight neutron beams (i.e., without guides, focusing optics, analyzers, etc.):

![image](../images/straight-beamline.svg)

We want to compute the total length of the flight path `Ltotal` from positions of the detector, sample, and source.
It is easy enough to write code that does this.
However, consider the backscattering QENS beamline from the McStas session:

![image](../images/qens-beamline.svg)

Here, we need to take the analyzer into account when computing `Ltotal`.

To illustrate, we begin with the example of straight beams and generate some test data.
The exact contents of the data don't matter here, but note that we store the various positions as coordinates.

In [None]:
import scipp as sc

In [None]:
# elastic_data
source_position = sc.vector([0.0, 0.0, -10.0], unit="m")
sample_position = sc.vector([0.0, 0.0, 0.0], unit="m")
position = sc.vectors(
    dims=["position"],
    values=[
        [0.0, 0.5, 1.0],
        [0.0, 1.0, 1.0],
        [0.0, 1.5, 1.0],
    ],
    unit="m",
)

elastic_data = sc.DataArray(
    sc.ones(sizes={"position": 3}),
    coords={
        "source_position": source_position,
        "sample_position": sample_position,
        "position": position,
    },
)
elastic_data

In [None]:
# qens_data
source_position = sc.vector([0.0, 0.0, -10.0], unit="m")
sample_position = sc.vector([0.0, 0.0, 0.0], unit="m")
analyzer_position = sc.vector([0.0, 1.0, 1.0], unit="m")
position = sc.vectors(
    dims=["position"],
    values=[
        [0.0, 1.9, 0.0],
        [0.0, 2.0, 0.0],
        [0.0, 2.1, 0.0],
    ],
    unit="m",
)

qens_data = sc.DataArray(
    sc.ones(sizes={"position": 3}),
    coords={
        "source_position": source_position,
        "sample_position": sample_position,
        "analyzer_position": analyzer_position,
        "position": position,
    },
)
qens_data

## The manual approach

A straightforward way of computing `Ltotal` from this data is the following:

In [None]:
L1 = sc.norm(elastic_data.coords["sample_position"] - elastic_data.coords["source_position"])
L2 = sc.norm(elastic_data.coords["position"] - elastic_data.coords["sample_position"])
Ltotal = L1 + L2
Ltotal

This uses vector arithmetic on the coordinates and [scipp.norm](https://scipp.github.io/generated/functions/scipp.norm.html) to compute vector lengths.
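To see which numbers to expect, here is the same arithmetic in plain Python for the three detector positions of `elastic_data`. `math.dist` (the Euclidean distance between two points) stands in for `sc.norm` of a difference, and all lengths are in metres:

```python
import math

source = (0.0, 0.0, -10.0)
sample = (0.0, 0.0, 0.0)
detectors = [(0.0, 0.5, 1.0), (0.0, 1.0, 1.0), (0.0, 1.5, 1.0)]

# L1 is a single number shared by all pixels; L2 differs per pixel.
L1 = math.dist(sample, source)
L2 = [math.dist(det, sample) for det in detectors]
Ltotal = [L1 + l2 for l2 in L2]
print(L1)  # 10.0
print([round(x, 3) for x in Ltotal])  # [11.118, 11.414, 11.803]
```

In the Scipp version, adding the scalar `L1` to the one-dimensional `L2` broadcasts in the same way as the list comprehension here.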

However, if we now want to do this for the QENS experiment, we need to rewrite the entire procedure:

In [None]:
L1 = sc.norm(qens_data.coords["sample_position"] - qens_data.coords["source_position"])
L2 = (sc.norm(qens_data.coords["position"] - qens_data.coords["analyzer_position"])
     + sc.norm(qens_data.coords["analyzer_position"] - qens_data.coords["sample_position"]))
Ltotal = L1 + L2
Ltotal
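Written as a formula, with $\mathbf{r}$ the detector position, the QENS flight path is

$$
L_\text{total}
= \underbrace{\lVert \mathbf{r}_\text{sample} - \mathbf{r}_\text{source} \rVert}_{L_1}
+ \underbrace{\lVert \mathbf{r}_\text{analyzer} - \mathbf{r}_\text{sample} \rVert
+ \lVert \mathbf{r} - \mathbf{r}_\text{analyzer} \rVert}_{L_2}
$$

For the middle detector at $\mathbf{r} = (0, 2, 0)\,\text{m}$ this gives $L_1 = 10\,\text{m}$ and $L_2 = \lVert (0, 1, 1) \rVert + \lVert (0, 1, -1) \rVert = 2\sqrt{2}\,\text{m} \approx 2.828\,\text{m}$, so $L_\text{total} \approx 12.828\,\text{m}$. Only the $L_2$ term differs from the straight beamline.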

## Using `transform_coords`

Instead, we are going to use [scipp.transform_coords](https://scipp.github.io/generated/functions/scipp.transform_coords.html).

First, we have to define functions to compute `Ltotal` and its components `L1` and `L2`:

In [None]:
def straight_l1(source_position, sample_position):
    return sc.norm(sample_position - source_position)


def straight_l2(sample_position, position):
    return sc.norm(position - sample_position)


def l_total(L1, L2):
    return L1 + L2

We then store these functions in a `dict`.
The keys of the `dict` are the names of the coordinates that the functions compute.

In [None]:
graph = {"L1": straight_l1, "L2": straight_l2, "Ltotal": l_total}

This dict defines a graph that connects coordinates with functions that can compute them.
We can visualize it with Scipp:

In [None]:
sc.show_graph(graph)

Note how coordinates (white boxes) and functions (gray boxes) are connected.
Scipp knows that, e.g., `straight_l1` produces `L1` because that is its key in the `dict`, and that it takes `source_position` and `sample_position` as inputs because those are the names of its arguments.
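To make this name matching concrete, here is a toy resolver in plain Python. It is a hypothetical sketch of the idea, not Scipp's implementation: it ignores dimensions, units, and dimension renaming, uses `(x, y, z)` tuples for positions, and uses `math.dist` in place of `sc.norm` of a difference. The `toy_*` names are made up for this illustration:

```python
import inspect
import math


def resolve(target, graph, coords):
    """Compute `target` from `coords` using `graph` (hypothetical toy resolver).

    Every value computed along the way is stored in `coords`.
    """
    if target in coords:
        return coords[target]
    func = graph[target]  # raises KeyError if no function produces `target`
    # The function's argument names are the names of the coordinates it needs.
    needed = inspect.signature(func).parameters
    kwargs = {name: resolve(name, graph, coords) for name in needed}
    coords[target] = func(**kwargs)
    return coords[target]


def toy_l1(source_position, sample_position):
    return math.dist(sample_position, source_position)


def toy_l2(sample_position, position):
    return math.dist(position, sample_position)


def toy_l_total(L1, L2):
    return L1 + L2


toy_graph = {"L1": toy_l1, "L2": toy_l2, "Ltotal": toy_l_total}
toy_coords = {
    "source_position": (0.0, 0.0, -10.0),
    "sample_position": (0.0, 0.0, 0.0),
    "position": (0.0, 1.0, 1.0),  # a single detector pixel
}
resolve("Ltotal", toy_graph, toy_coords)
print(sorted(toy_coords))
# ['L1', 'L2', 'Ltotal', 'position', 'sample_position', 'source_position']
```

Requesting only `Ltotal` also fills in `L1` and `L2`, because `toy_l_total` names them as arguments.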

We can now compute `Ltotal` by using the graph with `transform_coords`:

In [None]:
converted = elastic_data.transform_coords("Ltotal", graph=graph)
converted

This did several things for us.

- It computed `Ltotal` as we requested and stored it as a new coordinate.
- It also computed `L1` and `L2` because those were needed for `Ltotal`.
- It renamed the dimension from `position` to `Ltotal` because we consider the latter to have replaced the former.

### Customizing the graph for QENS

We can now adapt the above example to compute `Ltotal` for the QENS experiment.
We need a new function that computes `L2`, the flight path length from the sample via the analyzer to the detector.

![image](../images/qens-beamline.svg)

In [None]:
def backscattering_l2(sample_position, analyzer_position, position):
    a = sc.norm(analyzer_position - sample_position)
    b = sc.norm(position - analyzer_position)
    return a + b

We can reuse the graph for the straight beamline and simply replace the function for `L2`:

In [None]:
graph["L2"] = backscattering_l2
sc.show_graph(graph)

In [None]:
converted = qens_data.transform_coords("Ltotal", graph=graph)
converted
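Note that `graph["L2"] = backscattering_l2` modifies the dict in place, so `graph` no longer describes the straight beamline. If both setups are needed in the same session, one option is to derive the QENS graph as a new dict. A minimal sketch of that pattern, using hypothetical plain-Python stand-ins (`toy_*`) for the Scipp functions, with `(x, y, z)` tuples and `math.dist` instead of vectors and `sc.norm`:

```python
import math


def toy_l1(source_position, sample_position):
    return math.dist(sample_position, source_position)


def toy_straight_l2(sample_position, position):
    return math.dist(position, sample_position)


def toy_backscattering_l2(sample_position, analyzer_position, position):
    return math.dist(analyzer_position, sample_position) + math.dist(
        position, analyzer_position
    )


def toy_l_total(L1, L2):
    return L1 + L2


straight_graph = {"L1": toy_l1, "L2": toy_straight_l2, "Ltotal": toy_l_total}
# Unpacking copies the entries into a new dict; the explicit "L2" entry comes
# later and therefore wins. straight_graph itself is not modified.
qens_graph = {**straight_graph, "L2": toy_backscattering_l2}

print(straight_graph["L2"].__name__)  # toy_straight_l2
print(qens_graph["L2"].__name__)  # toy_backscattering_l2
```

With the Scipp functions from this notebook, the same pattern would read `{**straight_graph, "L2": backscattering_l2}`, where `straight_graph` holds the original straight-beamline entries.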

## The larger picture

The examples shown above are fairly small and easy to follow.
In practice, however, coordinate transformations can involve many more, and more complicated, steps.
As an example, here is the default graph provided by ScippNeutron:

In [None]:
import scippneutron as scn

graph = scn.conversion.graph.beamline.beamline(scatter=True)
sc.show_graph(graph)

It is similar to our own graph but involves additional intermediate results and can also be used to compute the scattering angle `two_theta`.
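The scattering angle `two_theta` is the angle between the incident beam (from source to sample) and the scattered beam (from sample to detector). Below is a plain-Python sketch of that definition for the first detector pixel of `elastic_data`; the helper `two_theta` is hypothetical and not ScippNeutron's implementation:

```python
import math


def two_theta(source_position, sample_position, position):
    """Angle between incident and scattered beam in radians (hypothetical helper).

    Positions are (x, y, z) tuples.
    """
    incident = [s - src for s, src in zip(sample_position, source_position)]
    scattered = [p - s for p, s in zip(position, sample_position)]
    dot = sum(i * o for i, o in zip(incident, scattered))
    cos_angle = dot / (math.hypot(*incident) * math.hypot(*scattered))
    # Clamp to [-1, 1] to guard against rounding before calling acos.
    return math.acos(max(-1.0, min(1.0, cos_angle)))


angle = two_theta((0.0, 0.0, -10.0), (0.0, 0.0, 0.0), (0.0, 0.5, 1.0))
print(round(math.degrees(angle), 3))  # 26.565
```

Since the incident beam here points along $z$, the result equals $\arctan(0.5 / 1.0)$.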

We can also add functions to compute many more coordinates, such as $d$-spacing, $Q$, or hkl indices.
If you don't know the syntax, simply read `{**a, **b}` as merging the two dicts `a` and `b` into a single dict:

In [None]:
graph = {
    **scn.conversion.graph.beamline.beamline(scatter=True),
    **scn.conversion.graph.tof.elastic(start="tof"),
}
sc.show_graph(graph)
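For orientation, these are the standard elastic relations that connect the time of flight $t$ to some of these coordinates. Here $h$ is Planck's constant, $m_n$ the neutron mass, $L_\text{total}$ the total flight path, and $\theta$ half of `two_theta`. They assume that the neutron keeps the same speed along the whole flight path, i.e., no energy transfer at the sample; how ScippNeutron splits them into intermediate graph nodes may differ:

$$
\lambda = \frac{h\,t}{m_n L_\text{total}}, \qquad
Q = \frac{4\pi\sin\theta}{\lambda}, \qquad
d = \frac{\lambda}{2\sin\theta} = \frac{2\pi}{Q}
$$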