[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/casangi/graphviper/blob/main/docs/graph_building_tutorial_processing_set.ipynb)

# GraphVIPER Tutorial

This tutorial provides examples of how `GraphVIPER` can be used to build `Dask` graphs by mapping a dictionary-based container of `xarray.Datasets` to `Dask` graph nodes, followed by a reduction step. The dictionary of `Datasets` used in this tutorial is referred to as a `Processing Set`, although any dictionary containing `xarray.Datasets` can be used. `GraphVIPER` [map](https://graphviper.readthedocs.io/en/latest/_api/autoapi/graphviper/graph_tools/map/index.html#module-contents) can be thought of as a generalization of [xarray.map_blocks](https://docs.xarray.dev/en/stable/generated/xarray.map_blocks.html) to more than one `xarray.Dataset`. The graphs are built using [dask.delayed](https://docs.dask.org/en/stable/delayed.html).

The following types of mapping are supported:

- Partitions defined by any combination of the coordinates in the `Processing Set`.
- More than one `xarray.Dataset` can be assigned to a single mapping node.
- `xarray.Dataset` partitions assigned to different nodes can have coordinates that overlap.

The tutorial will cover the following examples:

- Frequency Map Reduce: This example explains the concepts of `parallel_coords` and `node_task_data_mapping` that define parallelism and mapping.
- Overlapping Frequency Map Reduce.
- Baseline and Frequency Map Reduce.
- Time Map Reduce.

`GraphVIPER` provides improvements over the [CNGI prototype](https://cngi-prototype.readthedocs.io/en/stable/development.html):

- There is a clear separation between the concurrency layer ([GraphVIPER](https://graphviper.readthedocs.io/en/latest/)) and the domain layer (science code, [AstroVIPER](https://github.com/casangi/astroviper)).
- The memory backpressure issue was solved by incorporating the loading of data into the compute nodes. Memory backpressure is an issue for Radio Astronomy cube imaging that has to create large in-memory image cubes, which `Dask` is not aware of, causing `Dask` to be overeager in loading data from disk into memory. `Dask` might provide an alternative solution in the future where graph nodes can be annotated with expected memory usage.
- The number of graph nodes has been minimized, again by incorporating the loading of data into the compute nodes. When `Xarray` datasets backed by `Dask` arrays are used, a graph node is created for each data variable; since Radio Astronomy datasets have numerous data variables, this led to bloated graphs that hurt scaling performance.
- Multiple `xarray.Datasets` can be processed together with overlap. This cannot be done with the current `Xarray` functionality, such as [xarray.map_blocks](https://docs.xarray.dev/en/stable/generated/xarray.map_blocks.html).
- Using a [Dask plugin](https://distributed.dask.org/en/latest/plugins.html), the `Dask Scheduler` has been modified so that data can be cached to a local disk when multiple passes over larger-than-memory data have to be done. This reduces clustered file system or binary object store access (see [GraphVIPER Client](https://graphviper.readthedocs.io/en/latest/_api/autoapi/graphviper/dask/client/index.html)).


## Install GraphVIPER

In [1]:
import os

try:
    import graphviper

    print("GraphVIPER version", graphviper.__version__, "already installed.")
except ImportError as e:
    print(e)
    print("Installing GraphVIPER")

    os.system("pip install graphviper")

    import graphviper

    print("GraphVIPER version", graphviper.__version__, "installed.")

GraphVIPER version 0.0.3 already installed.


## Download and Convert Dataset

In [2]:
graphviper.utils.data.download(file="Antennae_North.cal.lsrk.split.ms")

from xradio.vis.convert_msv2_to_processing_set import convert_msv2_to_processing_set

# The chunk size on disk. A chunk size can be specified for any of the following dimensions:
# time, baseline_id (interferometer) / antenna_id (single dish), frequency, and polarization.
chunks_on_disk = {"frequency": 3}

infile = "Antennae_North.cal.lsrk.split.ms"
outfile = "Antennae_North.cal.lsrk.split.vis.zarr"

convert_msv2_to_processing_set(
    in_file=infile,
    out_file=outfile,
    parallel=False,
    overwrite=True,
    main_chunksize=chunks_on_disk,
)

[2024-02-02 11:47:47,899]     INFO  graphviper:  File exists: Antennae_North.cal.lsrk.split.ms


## Set Up a Dask Cluster
We start a local client with two worker processes. To keep things simple, the next cell then switches `Dask` to the synchronous scheduler, so everything runs serially in a single thread.

In [3]:
import dask

from graphviper.dask.client import local_client

viper_client = local_client(
    cores=2, 
    memory_limit="4GB",
    autorestrictor=True,
    log_params={
        'logger_name': "graphviper",
        'log_to_term': True,
        'log_level': 'INFO',
        'log_to_file': False,
        'log_file': None
    },
    worker_log_params={
        'logger_name': "graphviper",
        'log_to_term': True,
        'log_level': 'INFO',
        'log_to_file': False,
        'log_file': None
    }
)

viper_client.dashboard_link

[2024-02-02 11:47:49,135]     INFO  graphviper:  Checking parameter values for client.local_client
[2024-02-02 11:47:49,137]     INFO  graphviper:  /export/home/ajax/jhoskins/Development/graphviper-logger/
[2024-02-02 11:47:49,138]     INFO  graphviper:  Searching /export/home/ajax/jhoskins/Development/graphviper-logger/ for configuration file, please wait ...
[2024-02-02 11:47:50,329]     INFO  graphviper:  Created client <MenrvaClient: 'tcp://127.0.0.1:36335' processes=2 threads=2, memory=7.45 GiB>


'http://127.0.0.1:8787/status'

In [None]:
dask.config.set(scheduler="synchronous")

## Inspect the Processing Set

The `read_processing_set` function is lazy: only metadata is loaded, and no data is read into memory (`load_processing_set`, by contrast, loads everything into memory). Metadata is everything that is not an `xarray` data variable. Note that `GraphVIPER` does not require a `Processing Set`; any dictionary of `xarray.datasets` can be used.

In [None]:
import pandas as pd
from xradio.vis.read_processing_set import read_processing_set

pd.options.display.max_colwidth = 100

ps_name = "Antennae_North.cal.lsrk.split.vis.zarr"
intents = ["OBSERVE_TARGET#ON_SOURCE"]
fields = None

# Lazily open the Processing Set: only metadata is read at this point.
ps = read_processing_set(
    ps_name=ps_name,
    intents=intents,
    fields=fields,
)
display(ps.summary())

## Inspect a single MS v4

Each `xarray.dataset` within a Processing Set is a Measurement Set v4 (`ms_v4`).

In [None]:
ms_xds = ps[
    "Antennae_North.cal.lsrk.split_ddi_0_intent_OBSERVE_TARGET#ON_SOURCE_field_id_0"
]
ms_xds

## Nomenclature

- input data: A dictionary of `xarray.datasets` or a `processing_set`.
- n_datasets: The number of `xarray.Datasets` in the input data.
- dim_i: The name of the ith dimension.
- n_dims: The number of dimensions over which parallelism will occur.
- n_dim_i_chunks: Number of chunks into which the dimension coordinate `dim_i` has been divided.
- n_nodes: Number of nodes in the mapping stage of a Map Reduce graph.
- _{}: If curly brackets are preceded by an underscore, it indicates a subscript and not a dictionary value.

## How Graph Parallelism is Specified: `parallel_coords`

The `parallel_coords` is a dictionary where the keys are dimensions over which parallelism will occur and can be any of the dimension coordinate names present in the input data. For the `ms_v4` `xarray.dataset`, the options include time, baseline_id (interferometer) / antenna_id (single dish), frequency, and polarization. Each dimension coordinate name is associated with a dictionary that describes the data selection for that dimension in each node of the mapping stage of the graph.

The structure of `parallel_coords`:
```
        parallel_coords = {
            dim_0: {
                'data': 1D list/np.ndarray of Number,
                'data_chunks': {
                    0 : 1D list/np.ndarray of Number,
                    ⋮
                    n_dim_i_chunks-1 : ...,
                },
                'data_chunks_edges': 1D list/np.ndarray of Number,
                'dims': (dim_0,),
                'attrs': measure attribute,
            },
            ⋮
            dim_{n_dims-1}: ...
        }
```

The `dim_i` dictionaries have keys with the following meanings:

- `data`: An array containing all the coordinate values for that dimension. These values do not need to match the coordinate values of the input data (a dictionary of `xarray.datasets` or a `processing_set`), because the input coordinates are interpolated onto them. The range of `data` may also be narrower than the coordinate range of an individual `xarray.dataset` (a larger minimum or a smaller maximum); data outside that range is simply excluded from processing. Note that the `parallel_coords` and the input data coordinates must have the same measure attributes (reference frame, units, etc.).
- `data_chunks`: A dictionary, keyed by integer chunk index, that splits `data` into chunks. This chunking determines the parallelism of the graph. The values in different chunks can overlap.
- `data_chunks_edges`: An array with the start and end values of each chunk.
- `dims`: The dimension coordinate name.
- `attrs`: The `XRADIO` measures attributes of the data (refer to [XRADIO documentation](https://docs.google.com/spreadsheets/d/14a6qMap9M5r_vjpLnaBKxsR9TF4azN5LVdOxLacOX-s/edit#gid=1504318014)).

The combinations of all the chunks in `parallel_coords` determine the parallelism of the graph. For example, if you have `parallel_coords` with 5 `time` and 3 `frequency` chunks, you would have 15-way parallelism (5x3).
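The node count is simply the product of the per-dimension chunk counts. As a minimal, standard-library sketch (the `chunk_counts` dictionary is hypothetical, and the order in which GraphVIPER numbers its nodes is an assumption here), every combination of chunk indices corresponds to one mapping node:

```python
from itertools import product
from math import prod

# Hypothetical chunk counts per parallel dimension: 5 time chunks x 3 frequency chunks.
chunk_counts = {"time": 5, "frequency": 3}

parallel_dims = tuple(chunk_counts)
# One tuple of chunk indices per mapping node: (0, 0), (0, 1), (0, 2), (1, 0), ...
all_chunk_indices = list(product(*(range(n) for n in chunk_counts.values())))
n_nodes = prod(chunk_counts.values())

assert len(all_chunk_indices) == n_nodes
print(parallel_dims, n_nodes)  # ('time', 'frequency') 15
```

Each tuple plays the role of a node's `chunk_indices`, with one entry per dimension in `parallel_dims`.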

This description may seem somewhat convoluted, but the following examples should help clarify things.

## Frequency Map Reduce

### Create Parallel Coordinates

GraphVIPER offers a convenience function, `make_parallel_coord`, that converts any [XRADIO measure](https://docs.google.com/spreadsheets/d/14a6qMap9M5r_vjpLnaBKxsR9TF4azN5LVdOxLacOX-s/edit#gid=1504318014) to a `parallel_coord`. In this case, we use the frequency coordinate of one of the datasets in the `processing_set`. Note that all datasets in this `processing_set` have the same frequency coordinates but different time coordinates, because they represent the same spectral window but different fields of a mosaic observation.

In [None]:
from graphviper.graph_tools.coordinate_utils import make_parallel_coord
from graphviper.utils.display import dict_to_html
from IPython.display import HTML, display

parallel_coords = {}
n_chunks = 3
parallel_coords["frequency"] = make_parallel_coord(
    coord=ms_xds.frequency, n_chunks=n_chunks
)
display(HTML(dict_to_html(parallel_coords["frequency"])))

The display of the frequency `parallel_coords` clearly shows how the data was split into 3 chunks. All the chunks must have the same number of values, except the last chunk, which can have fewer. GraphVIPER also has convenience functions that can create frequency and time coordinate measures:
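One way to satisfy the "equal chunks except a shorter last chunk" rule is to use a chunk size of `ceil(n_values / n_chunks)`. The sketch below is a hypothetical helper, `split_into_chunks`, not GraphVIPER's implementation; whether `make_parallel_coord` uses exactly this scheme is an assumption. The frequencies reuse the start and spacing passed to `make_frequency_coord` in this tutorial.

```python
import math


def split_into_chunks(values, n_chunks):
    """Split values into consecutive chunks keyed by integer chunk index.

    Every chunk holds ceil(len(values) / n_chunks) values except the last,
    which may hold fewer. Hypothetical helper for illustration only.
    """
    if n_chunks < 1:
        raise ValueError("n_chunks must be at least 1")
    if not values:
        return {}
    chunk_size = math.ceil(len(values) / n_chunks)
    return {
        i: values[start:start + chunk_size]
        for i, start in enumerate(range(0, len(values), chunk_size))
    }


# Eight channel frequencies in Hz.
freq_start = 343928096685.9587
freq_delta = 11231488.981445312
freqs = [freq_start + i * freq_delta for i in range(8)]

chunks = split_into_chunks(freqs, n_chunks=3)
# First and last value of each chunk, analogous to the chunk edges.
edges = [(c[0], c[-1]) for c in chunks.values()]
print({k: len(v) for k, v in chunks.items()})  # {0: 3, 1: 3, 2: 2}
```

A consequence of this particular rule is that some inputs yield fewer chunks than requested (9 values into 4 chunks gives 3 chunks of 3); how GraphVIPER handles such cases is not covered here.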

In [None]:
from graphviper.graph_tools.coordinate_utils import make_frequency_coord

n_chunks = 3

coord = make_frequency_coord(
    freq_start=343928096685.9587,
    freq_delta=11231488.981445312,
    n_channels=8,
    velocity_frame="lsrk",
)
# Build the parallel coordinate from the coordinate created above.
parallel_coords["frequency"] = make_parallel_coord(
    coord=coord, n_chunks=n_chunks
)
display(HTML(dict_to_html(parallel_coords["frequency"])))

### Create Node Task Data Mapping

Next, the coordinates of the input data must be mapped onto the `parallel_coords`. This is done with the `interpolate_data_coords_onto_parallel_coords` function, which produces the `node_task_data_mapping`: a dictionary whose keys are the ids of the nodes in the mapping stage of the graph.

Structure of the `node_task_data_mapping`:
```
    node_task_data_mapping = {
        0: {
            'chunk_indices': tuple of int,
            'parallel_dims': tuple of str,
            'data_selection': {
                dataset_name_0: {
                    dim_0: slice,
                    ⋮
                    dim_{n_dims-1}: slice,
                },
                ⋮
                dataset_name_{n_datasets-1}: ...,
            },
            'task_coords': {  # XRADIO measures
                dim_0: {
                    'data': list/np.ndarray of Number,
                    'dims': str,
                    'attrs': measure attribute,
                },
                ⋮
                dim_{n_dims-1}: ...,
            },
        },
        ⋮
        n_nodes-1: ...,
    }
```

Each node's dictionary has the following keys:

- `chunk_indices`: The indices of the `parallel_coords` data chunks assigned to this node, one index per entry in `parallel_dims`.
- `parallel_dims`: The dimension coordinates over which parallelism occurs.
- `data_selection`: A dictionary whose keys are the names of the datasets in the `processing_set` and whose values are dictionaries mapping coordinates to slices. If a coordinate is not included, all of its values are selected.
- `task_coords`: The chunk of the `parallel_coords` assigned to this node.
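To make `data_selection` concrete, here is a minimal sketch that applies one node's selection to toy coordinates. It assumes the slices are positional (as with `xarray.Dataset.isel`); the dataset names, values, and the helper `apply_selection` are hypothetical, plain lists stand in for `xarray.Datasets`, and `task_coords` is omitted for brevity.

```python
# Hypothetical coordinates of two small datasets (dimension -> coordinate values).
datasets = {
    "field_0": {"time": [0.0, 1.0, 2.0], "frequency": [0, 1, 2, 3, 4, 5, 6, 7]},
    "field_1": {"time": [5.0, 6.0], "frequency": [0, 1, 2, 3, 4, 5, 6, 7]},
}

# Hypothetical mapping entry for the node that handles frequency chunk 1.
node = {
    "chunk_indices": (1,),
    "parallel_dims": ("frequency",),
    "data_selection": {
        "field_0": {"frequency": slice(3, 6)},
        "field_1": {"frequency": slice(3, 6)},
    },
}


def apply_selection(coords, selection):
    """Slice every listed dimension by position; unlisted dimensions keep all values."""
    return {dim: values[selection.get(dim, slice(None))] for dim, values in coords.items()}


selected = {
    name: apply_selection(datasets[name], sel)
    for name, sel in node["data_selection"].items()
}
print(selected["field_0"])  # {'time': [0.0, 1.0, 2.0], 'frequency': [3, 4, 5]}
```

Because `time` is not listed in the selection, every time value is kept, matching the rule for omitted coordinates.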

In [None]:
from graphviper.graph_tools.coordinate_utils import (
    interpolate_data_coords_onto_parallel_coords,
)

node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(
    parallel_coords, ps
)
display(HTML(dict_to_html(node_task_data_mapping)))

### Create a Node Task Function and Map Graph

The `map` function combines a `node_task_data_mapping` and a `node_task` to create the map portion of the graph. The `node_task` must be a function that takes a single dictionary as input and returns a single output, such as `my_func` in the example below. The `map` function passes the `input_parms` dictionary to the `node_task` after adding the following items from the `node_task_data_mapping`:

- chunk_indices
- parallel_dims
- data_selection
- task_coords
- task_id

If local caching is enabled, the following items are also added to the `input_parms` dictionary:

- date_time
- viper_local_dir

In [None]:
from graphviper.graph_tools.map import map
import dask
from graphviper.utils.display import dict_to_html
from IPython.display import display, HTML


def my_func(input_parms):
    display(HTML(dict_to_html(input_parms)))

    print("*" * 30)
    return input_parms["test_input"]


input_parms = {}
input_parms["test_input"] = 42

graph = map(
    input_data=ps,
    node_task_data_mapping=node_task_data_mapping,
    node_task=my_func,
    input_parms=input_parms,
)

dask.visualize(graph, filename="map_graph")

In [None]:
graph

### Run Map Graph

In [None]:
dask.compute(graph)

### Reduce Graph

In [None]:
from graphviper.graph_tools import reduce
import numpy as np

def my_sum(graph_inputs, input_parms):
    print(graph_inputs)
    return np.sum(graph_inputs) + input_parms["test_input"]

input_parms = {}
input_parms["test_input"] = 5
graph_reduce = reduce(
    graph, my_sum, input_parms, mode="single_node"
)  # mode can be "single_node" or "tree"
dask.visualize(graph_reduce)

In [None]:
from graphviper.graph_tools import reduce
import numpy as np


def my_sum(graph_inputs, input_parms):
    print(graph_inputs)
    return np.sum(graph_inputs) + input_parms["test_input"]


input_parms = {}
input_parms["test_input"] = 5
graph_reduce = reduce(
    graph, my_sum, input_parms, mode="tree"
)  # mode can be "single_node" or "tree"
dask.visualize(graph_reduce)

### Run Map Reduce Graph

In [None]:
dask.compute(graph_reduce)
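Note that `my_sum` adds `test_input` every time it is called. In `"single_node"` mode it runs once, but a tree reduction applies the function at several levels, so the constant can be added more than once. The sketch below contrasts the two on the three map outputs of 42 produced above. Both helpers are hypothetical, and the pairwise grouping in `tree_reduce` is an assumption that may differ from GraphVIPER's, so the tree-mode number is illustrative only.

```python
def my_sum(graph_inputs, input_parms):
    # Same logic as the tutorial's my_sum, using the built-in sum instead of np.sum.
    return sum(graph_inputs) + input_parms["test_input"]


def single_node_reduce(values, func, params):
    """Apply func once to all map outputs."""
    return func(values, params)


def tree_reduce(values, func, params, fan_in=2):
    """Apply func to groups of fan_in values, level by level, until one value remains."""
    if not values:
        raise ValueError("nothing to reduce")
    level = list(values)
    while len(level) > 1:
        level = [func(level[i:i + fan_in], params) for i in range(0, len(level), fan_in)]
    return level[0]


map_outputs = [42, 42, 42]
params = {"test_input": 5}
print(single_node_reduce(map_outputs, my_sum, params))  # 131
print(tree_reduce(map_outputs, my_sum, params))  # 141: the constant is added three times
```

For a reduce function to give the same answer in both modes, combining partial results must be equivalent to combining everything at once, as a plain sum is.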

## Overlapping Frequency Map Reduce

### Create Parallel Coordinates

In [None]:
from graphviper.utils.display import dict_to_html
import dask

dask.config.set(scheduler="synchronous")
from xradio.vis.read_processing_set import read_processing_set
from IPython.display import HTML, display


ps = read_processing_set(
    ps_name="Antennae_North.cal.lsrk.split.vis.zarr",
    intents=["OBSERVE_TARGET#ON_SOURCE"],
)
ms_xds = ps.get(1)
n_chunks = 3

parallel_coords = {}
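freq_coord = ms_xds.frequency.to_dict()
# Hand-built overlapping chunks (channel indices 0-3, 3-6, and 4-7):
# channel 3 is shared by chunks 0 and 1, channels 4-6 by chunks 1 and 2.
freq_coord["data_chunks"] = {
    0: freq_coord["data"][0:4],
    1: freq_coord["data"][3:7],
    2: freq_coord["data"][4:8],
}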
parallel_coords["frequency"] = freq_coord

display(HTML(dict_to_html(parallel_coords["frequency"])))
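The overlap built above can be checked with a short standard-library sketch that works on channel indices rather than frequency values (the 8-channel count matches the slices used above):

```python
from itertools import combinations

n_channels = 8
# Channel-index slices matching the hand-built overlapping chunks above.
chunk_slices = {0: slice(0, 4), 1: slice(3, 7), 2: slice(4, 8)}

# Channel indices assigned to each chunk.
chunk_channels = {k: set(range(n_channels)[s]) for k, s in chunk_slices.items()}

# Channels shared by every pair of chunks.
overlaps = {
    (a, b): sorted(chunk_channels[a] & chunk_channels[b])
    for a, b in combinations(chunk_channels, 2)
}
print(overlaps)  # {(0, 1): [3], (0, 2): [], (1, 2): [4, 5, 6]}

# Every channel is still covered by at least one chunk.
assert set().union(*chunk_channels.values()) == set(range(n_channels))
```

Since channels 3 to 6 are assigned to two nodes each, a reduction that sums per-channel results would count them twice unless the node task or reduce function accounts for the overlap.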

### Create Node Task Data Mapping

In [None]:
from graphviper.graph_tools.coordinate_utils import (
    interpolate_data_coords_onto_parallel_coords,
)

node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(
    parallel_coords, ps
)
display(HTML(dict_to_html(node_task_data_mapping)))

### Map Graph

In [None]:
from graphviper.graph_tools.map import map
import dask
from IPython.display import display, HTML
from xradio.vis.read_processing_set import read_processing_set


def my_func(input_parms):
    display(HTML(dict_to_html(input_parms)))

    print("*" * 30)
    return input_parms["test_input"]


# Besides test_input, my_func receives: chunk_indices, parallel_dims, data_selection, task_coords, task_id
# (plus date_time and viper_local_dir when local caching is enabled).
input_parms = {}
input_parms["test_input"] = 42

ps = read_processing_set(
    ps_name="Antennae_North.cal.lsrk.split.vis.zarr",
    intents=["OBSERVE_TARGET#ON_SOURCE"],
)

graph = map(
    input_data=ps,
    node_task_data_mapping=node_task_data_mapping,
    node_task=my_func,
    input_parms=input_parms,
)

dask.visualize(graph, filename="map_graph")

### Run Map Graph

In [None]:
dask.compute(graph)

## Baseline and Frequency Map Reduce

### Create Parallel Coordinates

In [None]:
from graphviper.utils.display import dict_to_html
from graphviper.graph_tools.coordinate_utils import make_parallel_coord
import dask

dask.config.set(scheduler="synchronous")

from xradio.vis.read_processing_set import read_processing_set

from IPython.display import HTML, display

intents = ["OBSERVE_TARGET#ON_SOURCE"]
ps = read_processing_set(
    ps_name="Antennae_North.cal.lsrk.split.vis.zarr",
    intents=intents,
)
ms_xds = ps.get(1)

parallel_coords = {}

import xarray as xr
import numpy as np

n_chunks = 4
parallel_coords["baseline_id"] = make_parallel_coord(
    coord=ms_xds.baseline_id, n_chunks=n_chunks
)

n_chunks = 3
parallel_coords["frequency"] = make_parallel_coord(
    coord=ms_xds.frequency, n_chunks=n_chunks
)

display(HTML(dict_to_html(parallel_coords)))

### Create Node Task Data Mapping

In [None]:
from graphviper.graph_tools.coordinate_utils import (
    interpolate_data_coords_onto_parallel_coords,
)

node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(
    parallel_coords, ps
)
display(HTML(dict_to_html(node_task_data_mapping)))

### Map Graph

In [None]:
from graphviper.graph_tools.map import map
import dask
from IPython.display import display, HTML


def my_func(input_parms):
    display(HTML(dict_to_html(input_parms)))

    print("*" * 30)
    return input_parms["test_input"]


# Besides test_input, my_func receives: chunk_indices, parallel_dims, data_selection, task_coords, task_id
# (plus date_time and viper_local_dir when local caching is enabled).
input_parms = {}
input_parms["test_input"] = 42

graph = map(
    input_data=ps,
    node_task_data_mapping=node_task_data_mapping,
    node_task=my_func,
    input_parms=input_parms,
)

dask.visualize(graph, filename="map_graph")

### Run Map Graph

In [None]:
dask.compute(graph)

## Time Map Reduce

### Create Parallel Coordinates

In [None]:
from graphviper.graph_tools.coordinate_utils import make_parallel_coord
from graphviper.utils.display import dict_to_html
import dask

dask.config.set(scheduler="synchronous")

from xradio.vis.read_processing_set import read_processing_set
from IPython.display import HTML, display

intents = ["OBSERVE_TARGET#ON_SOURCE"]
ps = read_processing_set(
    ps_name="Antennae_North.cal.lsrk.split.vis.zarr",
    intents=intents,
)
ms_xds = ps.get(1)

parallel_coords = {}

import xarray as xr
import numpy as np

t0, t1, t2 = (ps.get(1).time, ps.get(0).time, ps.get(2).time)
time_coord = xr.concat([t0, t1, t2], dim="time").sortby("time").to_dict()
n_chunks = 4
parallel_coords["time"] = make_parallel_coord(coord=time_coord, n_chunks=n_chunks)
display(HTML(dict_to_html(parallel_coords["time"])))

### Create Node Task Data Mapping

In [None]:
from graphviper.graph_tools.coordinate_utils import (
    interpolate_data_coords_onto_parallel_coords,
)

node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(
    parallel_coords, ps
)
display(HTML(dict_to_html(node_task_data_mapping)))
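Because the mosaic fields were observed at different times, a time chunk may intersect only some of the datasets. The sketch below uses made-up time ranges and chunk edges, plus a hypothetical helper `datasets_in_chunk`, to show which datasets each chunk touches. It does not establish whether GraphVIPER omits non-intersecting datasets from a node's `data_selection` or includes them with an empty slice; the mapping displayed above shows the actual behaviour.

```python
# Hypothetical (start, end) time ranges of three datasets, in seconds.
dataset_time_ranges = {
    "field_0": (0.0, 100.0),
    "field_1": (150.0, 250.0),
    "field_2": (300.0, 400.0),
}
# Hypothetical (start, end) edges of four time chunks.
chunk_edges = [(0.0, 90.0), (95.0, 190.0), (200.0, 290.0), (300.0, 400.0)]


def datasets_in_chunk(chunk_start, chunk_end, ranges):
    """Names of datasets whose closed time range intersects [chunk_start, chunk_end]."""
    return [
        name
        for name, (t_start, t_end) in ranges.items()
        if t_start <= chunk_end and t_end >= chunk_start
    ]


for chunk_id, (start, end) in enumerate(chunk_edges):
    print(chunk_id, datasets_in_chunk(start, end, dataset_time_ranges))
# 0 ['field_0']
# 1 ['field_0', 'field_1']
# 2 ['field_1']
# 3 ['field_2']
```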

### Map Graph

In [None]:
from graphviper.graph_tools.map import map
import dask
from IPython.display import display, HTML


def my_func(input_parms):
    display(HTML(dict_to_html(input_parms)))

    print("*" * 30)
    return input_parms["test_input"]


# Besides test_input, my_func receives: chunk_indices, parallel_dims, data_selection, task_coords, task_id
# (plus date_time and viper_local_dir when local caching is enabled).
input_parms = {}
input_parms["test_input"] = 42

graph = map(
    input_data=ps,
    node_task_data_mapping=node_task_data_mapping,
    node_task=my_func,
    input_parms=input_parms,
)

dask.visualize(graph, filename="map_graph")

### Run Map Graph

In [None]:
dask.compute(graph)