[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/casangi/graphviper/blob/main/docs/graph_building_tutorial_processing_set.ipynb)

# GraphVIPER Tutorial: Processing Set

In this tutorial, examples show how GraphVIPER builds Dask graphs and maps data from a processing set (a collection of Measurement Set v4 datasets) to the nodes of the graph.

## Install GraphVIPER

In [None]:
import importlib
import os

from importlib.metadata import version

# Use GraphVIPER if it is already available; otherwise install it from PyPI
# and confirm that the freshly installed package can be imported.
try:
    import graphviper

    print("GraphVIPER version", version("graphviper"), "already installed.")
except ImportError as e:
    print(e)
    print("Installing GraphVIPER")

    os.system("pip install graphviper")

    # The package was installed while this interpreter was running, so the
    # import system's cached directory listings must be refreshed first.
    importlib.invalidate_caches()

    # Import the package that was just installed (previously this imported
    # xradio, which did not verify the graphviper installation itself).
    import graphviper

    print("GraphVIPER version", version("graphviper"), " installed.")

## Download and Convert Dataset

In [2]:
from xradio.data.datasets import download
from xradio.vis.convert_msv2_to_processing_set import convert_msv2_to_processing_set

# Name the MSv2 input once so the download and the conversion always refer
# to the same file.
infile = "Antennae_North.cal.lsrk.split.ms"
outfile = "Antennae_North.cal.lsrk.split.vis.zarr"

partition_scheme = "ddi_intent_field"
chunks_on_disk = {"frequency": 3}

# Fetch the example measurement set (source: dropbox).
download(file=infile, source="dropbox")

# Convert the downloaded MSv2 into a processing set written to `outfile`,
# replacing any previous output (overwrite=True) and without parallel
# conversion (parallel=False).
convert_msv2_to_processing_set(
    in_file=infile,
    out_file=outfile,
    partition_scheme=partition_scheme,
    parallel=False,
    overwrite=True,
    main_chunksize=chunks_on_disk,
)

## Setup Dask Cluster
To simplify things, we are going to start off by using a single thread (everything will run serially).

In [None]:
# Alternative, left disabled: start a local GraphVIPER Dask client instead of
# using the synchronous scheduler configured below.
# from graphviper.dask.client import local_client
# viper_client = local_client(cores=2, memory_limit="4GB")
# viper_client

import dask
# Select Dask's "synchronous" scheduler so that every graph in this notebook
# is computed serially.
dask.config.set(scheduler="synchronous")

## Inspect Processing Set

`read_processing_set` is a lazy function: no data is loaded into memory, only metadata.

In [None]:
import pandas as pd

from xradio.vis.read_processing_set import read_processing_set

# Allow up to 100 characters per column when pandas renders a table
# (presumably so the long names in the summary are not truncated).
pd.options.display.max_colwidth = 100

ps_name = "Antennae_North.cal.lsrk.split.vis.zarr"

# Selection passed to read_processing_set: one intent, no field restriction.
intents = ["OBSERVE_TARGET#ON_SOURCE"]
fields = None

# Pass the ps_name variable rather than repeating the path literal, so the
# store that is opened always matches the name defined above.
ps = read_processing_set(
    ps_name=ps_name,
    intents=intents,
    fields=fields,
)
display(ps.summary())

## Inspect a single ms_v4

In [None]:
# Pick one MSv4 dataset out of the processing set by its partition name.
ms_xds = ps['Antennae_North.cal.lsrk.split_ddi_0_intent_OBSERVE_TARGET#ON_SOURCE_field_id_0']
# Bare expression: shown as the cell output when run in a notebook.
ms_xds

## Frequency Map Reduce
The parallel coordinates determine the parallelism of the map graph. Each chunk in the parallel coordinates represents a selection criterion used for the subselected processing set that is sent to a node. 

`parallel_coords` is a dictionary where each key represents a dimension coordinate that appears in your data. For ms_v4, the options would be time, baseline/antenna, frequency, and polarization. The make_parallel_coord function will convert any XRADIO measures into a parallel coordinate. In addition, two convenience functions have been created, make_time_coord and make_frequency_coord, which will create numpy arrays.

### Create Parallel Coordinates

In [None]:
from IPython.display import HTML, display
from graphviper.utils.display import dict_to_html
from graphviper.graph_tools.coordinate_utils import make_parallel_coord

# Split the frequency coordinate of the selected MSv4 into n_chunks chunks
# and store the result under the "frequency" key.
n_chunks = 3
parallel_coords = {
    "frequency": make_parallel_coord(coord=ms_xds.frequency, n_chunks=n_chunks),
}

# Render the frequency entry as HTML for inspection.
frequency_html = dict_to_html(parallel_coords["frequency"])
display(HTML(frequency_html))

In [None]:
from graphviper.graph_tools.coordinate_utils import make_frequency_coord

n_chunks = 3

# Build a frequency coordinate from explicit parameters instead of reading it
# from the dataset: start frequency, channel width, channel count and frame.
coord = make_frequency_coord(
    freq_start=343928096685.9587,
    freq_delta=11231488.981445312,
    n_channels=8,
    velocity_frame="lsrk",
)

# Chunk the coordinate that was just constructed. Previously `coord` was
# built but ignored, and ms_xds.frequency was chunked again instead.
parallel_coords["frequency"] = make_parallel_coord(
    coord=coord, n_chunks=n_chunks
)
display(HTML(dict_to_html(parallel_coords["frequency"])))

### Create Node Task Data Mapping

In [None]:
from graphviper.graph_tools.coordinate_utils import interpolate_data_coords_onto_parallel_coords
# Build the node task data mapping from the parallel coordinates and the
# processing set, then render it as HTML for inspection.
node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(parallel_coords, ps)
display(HTML(dict_to_html(node_task_data_mapping)))

### Create a chunk function and map graph

In [None]:
from graphviper.graph_tools.map import map
import dask
from graphviper.utils.display import dict_to_html
from IPython.display import display, HTML


def my_func(input_parms):
    """Show the parameters received by this node and return ``test_input``.

    The ``input_parms`` dict is rendered as HTML, a line of 30 asterisks is
    printed as a separator, and ``input_parms["test_input"]`` is returned.
    """
    rendered = dict_to_html(input_parms)
    display(HTML(rendered))

    separator = "*" * 30
    print(separator)

    return input_parms["test_input"]


# Keys noted by the original author (presumably those found in the
# input_parms each node task receives - TODO confirm):
# ['test_input', 'input_data_name', 'viper_local_dir', 'date_time', 'data_sel', 'chunk_coords', 'chunk_indx', 'chunk_id', 'parallel_dims']
input_parms = {"test_input": 42}

# Build the map graph with my_func as the node task.
graph = map(
    node_task=my_func,
    input_parms=input_parms,
    input_data=ps,
    node_task_data_mapping=node_task_data_mapping,
)

# Write a picture of the graph to a file named "map_graph".
dask.visualize(graph, filename="map_graph")

In [None]:
# Bare expression: in a notebook this shows the object returned by map().
graph

### Run Map Graph

In [None]:
# Execute the map graph; the computed results become the cell output.
dask.compute(graph)

### Reduce Graph

In [None]:
from graphviper.graph_tools import reduce
import numpy as np


def my_sum(graph_inputs, input_parms):
    """Print the incoming node results and return their sum plus an offset.

    ``graph_inputs`` is summed with ``np.sum`` and the value stored under
    ``input_parms["test_input"]`` is added to that total.
    """
    print(graph_inputs)
    total = np.sum(graph_inputs)
    offset = input_parms["test_input"]
    return total + offset


# Parameters for the reduce step only; this rebinds input_parms to a new dict.
input_parms = {"test_input": 5}

# Attach my_sum as the reduction over the map graph.
# mode "tree","single_node"
graph_reduce = reduce(graph, my_sum, input_parms, mode="single_node")

dask.visualize(graph_reduce)

### Run Map Reduce Graph

In [None]:
# Execute the combined map + reduce graph; the result is the cell output.
dask.compute(graph_reduce)

## Overlapping Frequency Map Reduce

### Create Parallel Coordinates

In [None]:
import dask
from IPython.display import HTML, display
from graphviper.utils.display import dict_to_html
from xradio.vis.read_processing_set import read_processing_set

dask.config.set(scheduler="synchronous")

ps = read_processing_set(
    ps_name="Antennae_North.cal.lsrk.split.vis.zarr",
    intents=["OBSERVE_TARGET#ON_SOURCE"],
)
ms_xds = ps.get(1)

# Build the frequency parallel coordinate by hand instead of calling
# make_parallel_coord, so that the chunks can overlap. Consecutive slices
# share channels: chunks 0 and 1 share index 3; chunks 1 and 2 share 4-6.
# (The unused `n_chunks = 3` assignment was dropped: the chunk layout is
# fully determined by the explicit slices below.)
freq_coord = ms_xds.frequency.to_dict()
freq_coord["data_chunks"] = {
    0: freq_coord["data"][0:4],
    1: freq_coord["data"][3:7],
    2: freq_coord["data"][4:8],
}
parallel_coords = {"frequency": freq_coord}

display(HTML(dict_to_html(parallel_coords["frequency"])))

### Create Node Task Data Mapping

In [None]:
from graphviper.graph_tools.coordinate_utils import interpolate_data_coords_onto_parallel_coords
# Build the node task data mapping for the overlapping frequency chunks and
# render it as HTML for inspection.
node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(parallel_coords, ps)
display(HTML(dict_to_html(node_task_data_mapping)))

### Map Graph

In [None]:
from graphviper.graph_tools.map import map
import dask
from IPython.display import display, HTML
from xradio.vis.read_processing_set import read_processing_set


def my_func(input_parms):
    """Show the parameters received by this node and return ``test_input``.

    The ``input_parms`` dict is rendered as HTML, a line of 30 asterisks is
    printed as a separator, and ``input_parms["test_input"]`` is returned.
    """
    rendered = dict_to_html(input_parms)
    display(HTML(rendered))

    separator = "*" * 30
    print(separator)

    return input_parms["test_input"]


# Keys noted by the original author (presumably those found in the
# input_parms each node task receives - TODO confirm):
# ['test_input', 'input_data_name', 'viper_local_dir', 'date_time', 'data_sel', 'chunk_coords', 'chunk_indx', 'chunk_id', 'parallel_dims']
input_parms = {"test_input": 42}

# Re-open the processing set, restricted to the on-source intent.
ps = read_processing_set(
    intents=["OBSERVE_TARGET#ON_SOURCE"],
    ps_name="Antennae_North.cal.lsrk.split.vis.zarr",
)

# Build the map graph with my_func as the node task.
graph = map(
    node_task=my_func,
    input_parms=input_parms,
    input_data=ps,
    node_task_data_mapping=node_task_data_mapping,
)

# Write a picture of the graph to a file named "map_graph".
dask.visualize(graph, filename="map_graph")

### Run Map Graph

In [None]:
# Execute the overlapping-frequency map graph; results are the cell output.
dask.compute(graph)

## Baseline and Frequency Map Reduce

### Create Parallel Coordinates

In [None]:
import dask
import numpy as np
import xarray as xr
from IPython.display import HTML, display
from graphviper.graph_tools.coordinate_utils import make_parallel_coord
from graphviper.utils.display import dict_to_html
from xradio.vis.read_processing_set import read_processing_set

dask.config.set(scheduler="synchronous")

# Pass the `intents` variable to the reader instead of repeating the literal,
# so the selection is defined in exactly one place.
intents = ["OBSERVE_TARGET#ON_SOURCE"]
ps = read_processing_set(
    ps_name="Antennae_North.cal.lsrk.split.vis.zarr",
    intents=intents,
)
ms_xds = ps.get(1)

parallel_coords = {}

# Parallelise over two dimensions: baseline_id in 4 chunks and frequency
# in 3 chunks.
n_chunks = 4
parallel_coords["baseline_id"] = make_parallel_coord(
    coord=ms_xds.baseline_id, n_chunks=n_chunks
)

n_chunks = 3
parallel_coords["frequency"] = make_parallel_coord(
    coord=ms_xds.frequency, n_chunks=n_chunks
)

display(HTML(dict_to_html(parallel_coords)))

### Create Node Task Data Mapping

In [None]:
from graphviper.graph_tools.coordinate_utils import interpolate_data_coords_onto_parallel_coords
# Build the node task data mapping for the baseline_id and frequency
# parallel coordinates and render it as HTML for inspection.
node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(parallel_coords, ps)
display(HTML(dict_to_html(node_task_data_mapping)))

### Map Graph

In [None]:
from graphviper.graph_tools.map import map
import dask
from IPython.display import display, HTML


def my_func(input_parms):
    """Show the parameters received by this node and return ``test_input``.

    The ``input_parms`` dict is rendered as HTML, a line of 30 asterisks is
    printed as a separator, and ``input_parms["test_input"]`` is returned.
    """
    rendered = dict_to_html(input_parms)
    display(HTML(rendered))

    separator = "*" * 30
    print(separator)

    return input_parms["test_input"]


# Keys noted by the original author (presumably those found in the
# input_parms each node task receives - TODO confirm):
# ['test_input', 'input_data_name', 'viper_local_dir', 'date_time', 'data_sel', 'chunk_coords', 'chunk_indx', 'chunk_id', 'parallel_dims']
input_parms = {"test_input": 42}

# Build the map graph with my_func as the node task.
graph = map(
    node_task=my_func,
    input_parms=input_parms,
    input_data=ps,
    node_task_data_mapping=node_task_data_mapping,
)

# Write a picture of the graph to a file named "map_graph".
dask.visualize(graph, filename="map_graph")

### Run Map Graph

In [None]:
# Execute the baseline/frequency map graph; results are the cell output.
dask.compute(graph)

## Time Map Reduce

### Create Parallel Coordinates

In [None]:
import dask
import numpy as np
import xarray as xr
from IPython.display import HTML, display
from graphviper.graph_tools.coordinate_utils import make_parallel_coord
from graphviper.utils.display import dict_to_html
from xradio.vis.read_processing_set import read_processing_set

dask.config.set(scheduler="synchronous")

# Pass the `intents` variable to the reader instead of repeating the literal,
# so the selection is defined in exactly one place.
intents = ["OBSERVE_TARGET#ON_SOURCE"]
ps = read_processing_set(
    ps_name="Antennae_North.cal.lsrk.split.vis.zarr",
    intents=intents,
)
ms_xds = ps.get(1)

parallel_coords = {}

# Combine the time coordinates of three datasets in the processing set into
# one coordinate. The order in which they are concatenated does not matter
# because the result is sorted by time before being converted to a dict.
t0, t1, t2 = (ps.get(1).time, ps.get(0).time, ps.get(2).time)
time_coord = xr.concat([t0, t1, t2], dim="time").sortby("time").to_dict()

# Split the combined time coordinate into 4 chunks.
n_chunks = 4
parallel_coords["time"] = make_parallel_coord(coord=time_coord, n_chunks=n_chunks)
display(HTML(dict_to_html(parallel_coords["time"])))

### Create Node Task Data Mapping

In [None]:
from graphviper.graph_tools.coordinate_utils import interpolate_data_coords_onto_parallel_coords
# Build the node task data mapping for the time parallel coordinate and
# render it as HTML for inspection.
node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(parallel_coords, ps)
display(HTML(dict_to_html(node_task_data_mapping)))

### Map Graph

In [None]:
from graphviper.graph_tools.map import map
import dask
from IPython.display import display, HTML


def my_func(input_parms):
    """Show the parameters received by this node and return ``test_input``.

    The ``input_parms`` dict is rendered as HTML, a line of 30 asterisks is
    printed as a separator, and ``input_parms["test_input"]`` is returned.
    """
    rendered = dict_to_html(input_parms)
    display(HTML(rendered))

    separator = "*" * 30
    print(separator)

    return input_parms["test_input"]


# Keys noted by the original author (presumably those found in the
# input_parms each node task receives - TODO confirm):
# ['test_input', 'input_data_name', 'viper_local_dir', 'date_time', 'data_sel', 'chunk_coords', 'chunk_indx', 'chunk_id', 'parallel_dims']
input_parms = {"test_input": 42}

# Build the map graph with my_func as the node task.
graph = map(
    node_task=my_func,
    input_parms=input_parms,
    input_data=ps,
    node_task_data_mapping=node_task_data_mapping,
)

# Write a picture of the graph to a file named "map_graph".
dask.visualize(graph, filename="map_graph")

### Run Map Graph

In [None]:
# Execute the time-parallel map graph; results are the cell output.
dask.compute(graph)