### Client and Setup

In [None]:
import dask
import time
import random
import webbrowser

import numpy as np
import xarray as xr

import matplotlib
import toolviper.dask.client as client

### Start Dask Client

In [None]:
# Import the module under a private alias so that binding the name ``client``
# below does not shadow it. With ``client = client.local_client(...)`` a second
# run of this cell would look up ``local_client`` on the returned client object
# instead of on the ``toolviper.dask.client`` module.
import toolviper.dask.client as _client_module

# Start a local client with 10 cores; both the client and the workers log to
# the terminal only, at DEBUG level.
client = _client_module.local_client(
    cores=10,
    log_params={
        "log_to_file": False,
        "log_to_term": True,
        "log_level": "DEBUG",
    },
    worker_log_params={
        "log_to_file": False,
        "log_to_term": True,
        "log_level": "DEBUG",
    },
)

# Spawn the dashboard window in a separate tab;
# comment this out if you don't want it to open.
webbrowser.open(url=client.dashboard_link)

### Workflow Skeleton for hsd_blflag
<img src="media/hsd_blflag-workflow.png" width="750">

In [None]:

def generate_delay(n=1, m=2):
    """Pause for a random duration to simulate data processing.

    The pause length, in seconds, is drawn with ``random.uniform(n, m)``.

    Args:
        n: One end of the interval the pause length is drawn from.
        m: The other end of the interval the pause length is drawn from.
    """
    pause_seconds = random.uniform(n, m)
    time.sleep(pause_seconds)

# Build a simple zero-filled dataset based on a given set of axes
def simulate(field, spw, polarization, antenna, row):
    """Create an ``xarray.Dataset`` holding a zero-filled ``DATA`` variable.

    Args:
        field: Number of entries on the ``field`` axis, labelled ``field_<i>``.
        spw: Number of entries on the ``spw`` axis, labelled ``spw_<i>``.
        polarization: Sequence of labels used as-is for the ``polarization``
            coordinate; its length sets the size of that axis.
        antenna: Number of entries on the ``antenna`` axis, labelled
            ``antenna_<i>``.
        row: Number of entries on the ``row`` axis, labelled ``0 .. row - 1``.

    Returns:
        A dataset whose ``DATA`` variable is a NumPy array of zeros with
        dimensions ``(field, spw, polarization, antenna, row)``.
    """
    coords = {
        "field": [f"field_{index}" for index in range(field)],
        "spw": [f"spw_{index}" for index in range(spw)],
        "polarization": polarization,
        "antenna": [f"antenna_{index}" for index in range(antenna)],
        "row": list(range(row)),
    }

    # Dimension order follows the insertion order of the coordinate mapping.
    dim_names = list(coords)
    zeros = np.zeros((field, spw, len(polarization), antenna, row))

    return xr.Dataset(
        data_vars={"DATA": (dim_names, zeros)},
        coords=coords,
    )

# Recursive processing to build graph
def process_(dataset, axes, function, shape, delayed_list, previous):
    """Recursively append one delayed ``function`` call per axis combination.

    The axes are walked depth-first: for every entry of the first axis'
    coordinate the function recurses on the remaining axes, and once no axes
    are left a single ``dask.delayed(function)(previous)`` task is appended.
    The total number of appended tasks is therefore the product of the
    coordinate lengths of ``axes``.

    Args:
        dataset: Object exposing ``coords[axis].data`` for every name in
            ``axes`` (an ``xarray.Dataset`` in this notebook).
        axes: Sequence of coordinate names still to iterate over.
        function: Callable wrapped with ``dask.delayed`` at each leaf.
        shape: Sizes of the dimensions that are not iterated. Currently only
            passed along and otherwise unused; kept so callers do not change.
        delayed_list: List that receives the delayed tasks (mutated in place).
        previous: Argument handed to every leaf task; passing a delayed object
            here makes each leaf depend on it.

    Returns:
        None. Results are communicated through ``delayed_list``.
    """
    if len(axes) == 0:
        # Leaf of the recursion. The skeleton does not slice the data yet, so
        # every task simply receives ``previous``.
        delayed_list.append(dask.delayed(function)(previous))
        return

    axis = axes[0]
    # The coordinate values themselves are not forwarded to the task (yet);
    # they only determine how many tasks are created.
    for _ in dataset.coords[axis].data:
        process_(
            dataset=dataset,
            axes=axes[1:],
            function=function,
            shape=shape,
            delayed_list=delayed_list,
            previous=previous,
        )

# Distributed function to build graph and map functions
def distribute(dataset, axes, function, previous=None):
    """Build the list of delayed tasks for one workflow stage.

    Args:
        dataset: Dataset with a ``DATA`` variable and a coordinate for every
            name in ``axes``.
        axes: Names of the dimensions to fan out over.
        function: Callable to run once per combination of ``axes`` entries.
        previous: Value forwarded to every task; defaults to ``None``.

    Returns:
        The list of delayed tasks collected by ``process_``.

    Raises:
        KeyError: If a name in ``axes`` is not a dimension of ``DATA``.
    """
    # Sizes of the dimensions that are NOT distributed over. This is only
    # useful for the walking skeleton example.
    remaining = dict(dataset.DATA.sizes)
    for name in axes:
        del remaining[name]

    tasks = []
    process_(
        dataset=dataset,
        axes=axes,
        function=function,
        shape=list(remaining.values()),
        delayed_list=tasks,
        previous=previous,
    )
    return tasks

In [None]:
# Build the simulated dataset with the desired dimensions:
# 2 fields x 1 spw x 2 polarizations x 2 antennas x 2 rows.
dataset = simulate(
    field=2,
    spw=1,
    polarization=["XX", "YY"],
    antenna=2,
    row=2,
)

### Workflow Job Stages

In [None]:
# field, spw, polarization
def flag_summary(dummy):
    """Flag-summary stage of the skeleton: runs ``flag_data_``.

    ``dummy`` is accepted so the stage can be chained in the task graph but
    is not used. Returns ``None``.
    """
    flag_data_()

# field, spw, antenna, row, polarization
def calculate_statistics(previous):
    """Statistics stage of the skeleton.

    Runs ``read_data_``, ``calculate_stats_`` and ``write_stats_`` in that
    order. ``previous`` is only used to chain the stage in the task graph and
    is otherwise ignored. Returns ``None``.
    """
    read_data_()

    calculate_stats_()

    write_stats_()

# field, spw, antenna, pol
def apply_flag_metric(previous):
    """Flag-metric stage of the skeleton.

    Runs ``get_flags_from_stats_``, ``apply_stats_flag_``,
    ``flag_expected_rms_`` and ``flag_summary_`` in that order. ``previous``
    is only used to chain the stage in the task graph and is otherwise
    ignored. Returns ``None``.
    """
    get_flags_from_stats_()

    apply_stats_flag_()

    flag_expected_rms_()

    flag_summary_()

# field, spw, antenna
def generate_flag(previous):
    """Flag-generation stage of the skeleton: runs ``generate_flag_``.

    ``previous`` is only used to chain the stage in the task graph and is
    otherwise ignored. Returns ``None``.
    """
    generate_flag_()

# field, spw, antenna
# NOTE(review): the pipeline cell currently calls this stage once through
# dask.delayed instead of distributing it over these axes -- confirm which
# of the two is intended.
def apply_flag_ms(previous):
    """Apply-flags stage of the skeleton: runs ``flag_data_``.

    ``previous`` is only used to chain the stage in the task graph and is
    otherwise ignored. Returns ``None``.
    """
    flag_data_()

# field, spw, polarization
# NOTE(review): the pipeline cell distributes this stage over
# ["field", "spw", "antenna"], which disagrees with the axes listed here --
# confirm which is intended.
def generate_plots(previous):
    """Plotting stage of the skeleton: runs ``generate_plots_``.

    ``previous`` is only used to chain the stage in the task graph and is
    otherwise ignored. Returns ``None``.
    """
    generate_plots_()

# serial
def generate_weblog(previous):
    """Weblog stage of the skeleton: runs ``generate_weblog_``.

    Returns ``previous`` unchanged so the next stage can be chained on it.
    """
    generate_weblog_()

    return previous

# serial
def quality_assurance(previous):
    """Quality-assurance stage of the skeleton: runs ``quality_assurance_``.

    Returns ``previous`` unchanged.
    """
    quality_assurance_()

    return previous

### Workflow Subtasks

In [None]:
def flag_data_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def read_data_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def calculate_stats_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def write_stats_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def get_flags_from_stats_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def apply_stats_flag_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def flag_expected_rms_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()
    
def flag_summary_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def generate_flag_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def generate_plots_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def generate_weblog_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def quality_assurance_():
    """Placeholder subtask: only waits via ``generate_delay()``."""
    generate_delay()

def gather(result):
    """Return ``result`` unchanged.

    This mainly exists to create a gather (fan-in) node in the task graph.
    It is not yet known whether the final product will contain a gather.
    """
    return result

### Run Processing Pipeline

In [None]:
# Each stage is fanned out with ``distribute`` and then funnelled through a
# delayed ``gather``. The gathered result is passed as ``previous`` to the
# next stage, so the stages are chained one after another in the graph.

# Flag Summary (1)
flag_summary_results = distribute(
    dataset=dataset,
    axes=["field", "spw", "polarization"],
    function=flag_summary,
)
gather_results = dask.delayed(gather)(flag_summary_results)

# Calculate Statistics (2)
calculate_stats_results = distribute(
    dataset=dataset,
    axes=["field", "spw", "antenna", "polarization", "row"],
    function=calculate_statistics,
    previous=gather_results,
)
gather_results = dask.delayed(gather)(calculate_stats_results)

# Apply Flag Metric (3)
flag_metric_results = distribute(
    dataset=dataset,
    axes=["field", "spw", "antenna", "polarization"],
    function=apply_flag_metric,
    previous=gather_results,
)
gather_results = dask.delayed(gather)(flag_metric_results)

# Generate Flag Commandline (4)
generate_flag_results = distribute(
    dataset=dataset,
    axes=["field", "spw", "antenna"],
    function=generate_flag,
    previous=gather_results,
)
gather_results = dask.delayed(gather)(generate_flag_results)

# Apply Flag to MS (5) -- a single task, not distributed over any axis
apply_flag_results = dask.delayed(apply_flag_ms)(gather_results)

# Generate Plots
generate_plots_results = distribute(
    dataset=dataset,
    axes=["field", "spw", "antenna"],
    function=generate_plots,
    previous=apply_flag_results,
)
gather_results = dask.delayed(gather)(generate_plots_results)

# Flag Summary (second pass, after the flags have been applied)
flag_summary_results = distribute(
    dataset=dataset,
    axes=["field", "spw", "polarization"],
    function=flag_summary,
    previous=gather_results,
)
gather_results = dask.delayed(gather)(flag_summary_results)

# Generate Weblog (single task)
gather_results = dask.delayed(generate_weblog)(gather_results)

# Quality Assurance (single task); ``gather_results`` now ends the graph
gather_results = dask.delayed(quality_assurance)(gather_results)

In [None]:
# Render the task graph that ends in ``gather_results``; rankdir="LR" asks
# for a left-to-right layout.
dask.visualize(gather_results, rankdir="LR", verbose=False)

In [None]:
# Execute the whole graph. Every subtask sleeps (see ``generate_delay``), so
# the wall time reflects how well the stages parallelise.
dask.compute(gather_results)

In [None]:
# Shut down the client started at the top of the notebook.
# NOTE(review): presumably this also stops the local scheduler and workers --
# confirm against the client type returned by ``local_client``.
client.shutdown()