This notebook explores variability in hail's python (macro-)benchmarks when  
said benchmarks are executed on the hail batch service. The analyses within  
are based on the methods proposed in [1], albeit slightly modified for  
long-running benchmarks. The goals of these analyses are:  

- to determine whether we can reliably detect slowdowns of 5% or less when  
  running benchmarks on hail batch.  
- to identify configurations (number of batch jobs x iterations) that allow us  
  to detect slowdowns efficiently (i.e., without excessive time and money).  

[1] Laaber et al., Software Microbenchmarking in the Cloud. How Bad is it Really?  
    https://dl.acm.org/doi/10.1007/s10664-019-09681-1

In [None]:
from pathlib import Path

import plotly.io as pio
import yaml
from benchmark.tools import annotate_index, maybe
from benchmark.tools.impex import import_benchmarks
from benchmark.tools.plotting import (
    plot_iteration_against_time,
    plot_mean_time_per_instance,
)
from benchmark.tools.statistics import (
    laaber_mds,
    schultz_mds,
    variability,
)
from IPython.display import Pretty, clear_output, display
from plotly.offline import init_notebook_mode

import hail as hl

In [None]:
# All inputs (./in) and outputs (./out) live under the notebook's working
# directory; start a local (spark) hail session.
prefix = str(Path.cwd())
hl.init(quiet=True, backend='spark')

# Render plotly figures inline in the notebook.
init_notebook_mode()
pio.renderers.default = 'notebook_connected'

### Import benchmark data

Benchmarks under `hail/python/benchmarks` are executed with a custom pytest  
plugin and their results are output as json lines (.jsonl). Unscrupulously,  
we use hail to analyse itself.

In [None]:
# Import the raw .jsonl results into a hail table and checkpoint it so later
# cells can re-read it without re-importing. overwrite=True (as for the
# filtered.ht checkpoint later on) lets this cell be re-run; otherwise the
# checkpoint fails once the output already exists.
with hl.TemporaryDirectory() as tmpdir:
    ht = import_benchmarks(Path(f'{prefix}/in/benchmarks.jsonl'), tmpdir=tmpdir)
    # Checkpoint inside the `with`: presumably import_benchmarks stages
    # intermediates in tmpdir, which is removed on exit - verify.
    ht = ht.checkpoint(f"{prefix}/out/benchmarks.ht", overwrite=True)

# Every benchmark, identified as `path::name`, in sorted order.
benchmarks = sorted(ht.aggregate(hl.agg.collect_as_set(ht.path + hl.str('::') + ht.name)))
print(*benchmarks, sep='\n')

In these next sections, we'll estimate the number of iterations required for  
a benchmark to reach a steady-state - the so-called "burn-in" iterations. It  
would be nice to automate this. Laaber et al. reference some techniques but  
that's beyond the scope of this work for now. 

The process of estimating the number of burn-in iterations is subjective and  
requires some amount of eye-balling: by plotting iteration vs execution time  
for all instances, you select an iteration number after which execution time  
no longer decays. This is easier said than done because not all benchmarks  
reach a steady state. As always, use your best judgement.  

The following cell contains some values that I eye-balled. It's included as a  
separate cell in case you want to take them for granted and skip ahead.

In [None]:
# Eye-balled number of burn-in iterations per benchmark, keyed by test name.
# remove_burn_in_iterations keeps only iterations whose index is >= this value
# (benchmarks absent from the dict default to 0). The filtering cell also
# drops any benchmark whose name is not a key here.
first_stable_index = {
    'test_analyze_benchmarks': 5,
    'test_block_matrix_to_matrix_table_row_major': 4,
    'test_blockmatrix_write_from_entry_expr_range_mt_standardize': 8,
    'test_blockmatrix_write_from_entry_expr_range_mt': 10,
    'test_concordance': 2,
    'test_export_range_matrix_table_col_p100': 15,
    'test_export_range_matrix_table_entry_field_p100': 3,
    'test_export_range_matrix_table_row_p100': 8,
    'test_export_vcf': 8,
    'test_genetics_pipeline': 4,
    'test_gnomad_coverage_stats_optimized': 5,
    'test_gnomad_coverage_stats': 5,
    'test_group_by_collect_per_row': 5,
    'test_group_by_take_rekey': 10,
    'test_hwe_normalized_pca_blanczos_small_data_0_iterations': 8,
    'test_hwe_normalized_pca_blanczos_small_data_10_iterations': 8,
    'test_hwe_normalized_pca': 6,
    'test_import_and_transform_gvcf': 2,
    'test_import_bgen_filter_count': 18,
    'test_import_bgen_force_count_all': 4,
    'test_import_bgen_force_count_just_gp': 20,
    'test_import_bgen_info_score': 12,
    'test_import_gvcf_force_count': 2,
    'test_import_vcf_count_rows': 1,
    'test_import_vcf_write': 5,
    'test_join_partitions_table[10-10]': 4,
    'test_join_partitions_table[10-100]': 2,
    'test_join_partitions_table[10-1000]': 5,
    'test_join_partitions_table[100-10]': 10,
    'test_join_partitions_table[100-100]': 10,
    'test_join_partitions_table[100-1000]': 8,
    'test_join_partitions_table[1000-10]': 12,
    'test_join_partitions_table[1000-100]': 10,
    'test_join_partitions_table[1000-1000]': 8,
    'test_kyle_sex_specific_qc': 6,
    'test_large_range_matrix_table_sum': 5,
    'test_ld_prune_profile_25': 10,
    'test_linear_regression_rows': 10,
    'test_logistic_regression_rows_wald': 5,
    'test_make_ndarray': 5,
    'test_matrix_table_aggregate_entries': 8,
    'test_matrix_table_array_arithmetic': 20,
    'test_matrix_table_call_stats_star_star': 8,
    'test_matrix_table_cols_show': 5,
    'test_matrix_table_decode_and_count_just_gt': 5,
    'test_matrix_table_decode_and_count': 8,
    'test_matrix_table_entries_show': 4,
    'test_matrix_table_entries_table_no_key': 4,
    'test_matrix_table_entries_table': 10,
    'test_matrix_table_filter_entries_unfilter': 8,
    'test_matrix_table_filter_entries': 6,
    'test_matrix_table_many_aggs_col_wise': 3,
    'test_matrix_table_many_aggs_row_wise': 2,
    'test_matrix_table_nested_annotate_rows_annotate_entries': 4,
    'test_matrix_table_rows_force_count': 20,
    'test_matrix_table_rows_is_transition': 5,
    'test_matrix_table_rows_show': 10,
    'test_matrix_table_scan_count_cols_2': 20,
    'test_matrix_table_scan_count_rows_2': 5,
    'test_matrix_table_show': 7,
    'test_matrix_table_take_col': 10,
    'test_matrix_table_take_entry': 8,
    'test_matrix_table_take_row': 10,
    'test_minimal_detectable_slowdown[laaber_mds]': 5,
    'test_minimal_detectable_slowdown[schultz_mds]': 6,
    'test_mt_group_by_memory_usage': 5,
    'test_mt_localize_and_collect': 5,
    'test_ndarray_addition': 10,
    'test_ndarray_matmul_float64': 6,
    'test_ndarray_matmul_int64': 10,
    'test_pc_relate_5k_5k': 4,
    'test_pc_relate': 3,
    'test_per_row_stats_star_star': 10,
    'test_python_only_10k_combine': 6,
    'test_python_only_10k_transform': 10,
    'test_read_decode_gnomad_coverage': 10,
    'test_read_force_count_partitions[10]': 10,
    'test_read_force_count_partitions[100]': 8,
    'test_read_force_count_partitions[1000]': 10,
    'test_read_with_index[1000]': 8,
    'test_sample_qc': 3,
    'test_sentinel_cpu_hash_1': 5,
    'test_sentinel_read_gunzip': 10,
    'test_shuffle_key_by_aggregate_bad_locality': 8,
    'test_shuffle_key_by_aggregate_good_locality': 5,
    'test_shuffle_key_rows_by_4096_byte_rows': 2,
    'test_shuffle_key_rows_by_65k_byte_rows': 4,
    'test_shuffle_key_rows_by_mt': 10,
    'test_shuffle_order_by_10m_int': 10,
    'test_split_multi_hts': 10,
    'test_split_multi': 8,
    'test_sum_table_of_ndarrays': 5,
    'test_table_aggregate_approx_cdf': 4,
    'test_table_aggregate_array_sum': 5,
    'test_table_aggregate_counter': 10,
    'test_table_aggregate_downsample_dense': 5,
    'test_table_aggregate_downsample_worst_case': 10,
    'test_table_aggregate_int_stats': 5,
    'test_table_aggregate_linreg': 8,
    'test_table_aggregate_take_by_strings': 10,
    'test_table_annotate_many_flat': 5,
    'test_table_big_aggregate_compilation': 4,
    'test_table_big_aggregate_compile_and_execute': 2,
    'test_table_expr_take': 25,
    'test_table_foreign_key_join[1000000-1000]': 3,
    'test_table_foreign_key_join[1000000-1000000]': 3,
    'test_table_group_by_aggregate_sorted': 8,
    'test_table_group_by_aggregate_unsorted': 7,
    'test_table_import_ints_impute': 8,
    'test_table_import_ints': 4,
    'test_table_import_strings': 3,
    'test_table_key_by_shuffle': 3,
    'test_table_python_construction': 10,
    'test_table_range_array_range_force_count': 6,
    'test_table_range_force_count': 10,
    'test_table_range_join[1000000000-1000]': 20,
    'test_table_range_join[1000000000-1000000000]': 10,
    'test_table_range_means': 10,
    'test_table_read_force_count_ints': 5,
    'test_table_read_force_count_strings': 4,
    'test_table_scan_prev_non_null': 20,
    'test_table_scan_sum_1k_partitions': 2,
    'test_table_show': 12,
    'test_table_take': 20,
    'test_test_head_and_tail_region_memory': 10,
    'test_test_inner_join_region_memory': 10,
    'test_test_left_join_region_memory': 10,
    'test_test_map_filter_region_memory': 10,
    'test_union_partitions_table[10-10]': 4,
    'test_union_partitions_table[10-100]': 5,
    'test_union_partitions_table[10-1000]': 8,
    'test_union_partitions_table[100-10]': 16,
    'test_union_partitions_table[100-100]': 5,
    'test_union_partitions_table[100-1000]': 15,
    'test_union_partitions_table[1000-10]': 10,
    'test_union_partitions_table[1000-100]': 6,
    'test_union_partitions_table[1000-1000]': 15,
    'test_variant_and_sample_qc_nested_with_filters_2': 2,
    'test_variant_and_sample_qc_nested_with_filters_4_counts': 8,
    'test_variant_and_sample_qc_nested_with_filters_4': 10,
    'test_variant_and_sample_qc': 10,
    'test_variant_qc': 5,
    'test_vds_combiner_chr22': 10,
    'test_write_profile_mt': 8,
    'test_write_range_matrix_table_p100': 2,
    'test_write_range_table[10000000-10]': 3,
    'test_write_range_table[10000000-100]': 3,
    'test_write_range_table[10000000-1000]': 6,
}

Short of an accurate algorithm for computing this, you, noble reader, are  
tasked with the mind-numbing task of looking at graphs and picking numbers.  
This is an iterative process. You'll likely lose the will to live mid-way.  

Persevere, friend. Your sacrifice will not go unrewarded.

You'll be shown a plot of iteration vs execution time for each benchmark. An  
x-axis intercept marks the current value of the first stable index.  

You'll then be prompted to enter a new first stable index for each benchmark  
until you arrive at a fixed point. Press Enter at the prompt to skip the current  
benchmark. Press Ctrl+C at any time to give up.  

Note that some benchmarks never really reach stability. In this case try to  
pick a value that compromises between cost and accuracy (i.e., if a benchmark  
is really slow, we don't want to be doing tons of burn-in iterations, whereas  
for a fast benchmark we could justify more).  

At the end, the first_stable_index dict will be printed. Please commit values  
for new benchmarks or if your estimates differ from mine.  

Good luck.

In [None]:
# Interactively refine first_stable_index until it reaches a fixed point.
# Each pass plots the benchmarks queued in `names`. Any benchmark whose index
# the user changes (or mistypes) is queued for another pass; blank input keeps
# the current value. Ctrl+C abandons the whole loop, not just the current pass.
ht = hl.read_table(f'{prefix}/out/benchmarks.ht')
names: list[str] = ht.aggregate(hl.agg.collect_as_set(ht.name))  # type: ignore
names = sorted(names)

try:
    while names:
        pending, names = names, []
        for fig in plot_iteration_against_time(ht, pending, first_stable_index):
            clear_output(wait=True)
            pio.renderers.default = 'notebook'

            name: str = fig.labels.title  # type: ignore
            cur_index = first_stable_index.get(name)

            try:
                fig.show()
            except Exception:
                # Best-effort: skip figures that fail to render. This catches
                # Exception, not a bare except, so Ctrl+C is not swallowed.
                continue

            response = input('Enter the first stable index (or blank keep same)')
            try:
                new_index = maybe(int, response or None)
            except ValueError:
                # Not an integer - show this benchmark again on the next pass
                # rather than aborting the whole loop.
                names.append(name)
                continue

            if new_index is not None and new_index != cur_index:
                first_stable_index[name] = new_index
                names.append(name)
except KeyboardInterrupt:
    pass

first_stable_index

Before conducting our analysis, we need to clean our data. That means we exclude:
- burn-in iterations
- iterations that timed-out or otherwise failed
- outliers with significant divergence from median execution time.
- benchmarks that never reach a steady-state (maybe they should be deleted?)
- instances that do not have sufficient iterations for analysis
- benchmarks that do not have sufficient instances for analysis

In [None]:
def remove_burn_in_iterations(ht: hl.Table, first_stable_index: dict[str, int]) -> hl.Table:
    """Drop each instance's burn-in iterations.

    `first_stable_index` is stored as a table global. Within every instance,
    iterations are annotated with their index (via `annotate_index`), and only
    those whose index is at least the benchmark's first stable index are kept.
    Benchmarks absent from the dict use a threshold of 0.
    """
    ht = ht.annotate_globals(first_stable_index=first_stable_index)
    threshold = ht.first_stable_index.get(ht.name, 0)

    def drop_burn_in(instance):
        stable = hl.filter(lambda it: it.idx >= threshold, annotate_index(instance.iterations))
        return instance.annotate(iterations=stable)

    return ht.select(instances=ht.instances.map(drop_burn_in))


def remove_outliers(ht: hl.Table, factor: hl.Float64Expression) -> hl.Table:
    """Drop iterations whose time diverges too far from their instance's median.

    Divergence is the symmetric ratio max(time, median) / min(time, median).
    An iteration is kept only while that ratio is strictly below `factor`.
    """

    def within_factor(instance):
        times = instance.iterations.map(lambda it: it.time)
        return hl.bind(
            lambda median: instance.iterations.filter(
                lambda it: hl.max([it.time, median]) / hl.min([it.time, median]) < factor
            ),
            hl.median(times),
        )

    return ht.select(
        instances=ht.instances.map(
            lambda instance: instance.annotate(iterations=within_factor(instance)),
        ),
    )


def keep_names(ht: hl.Table, names: hl.SetExpression) -> hl.Table:
    """Keep only the rows whose benchmark `name` is a member of `names`."""
    is_wanted = names.contains(ht.name)
    return ht.filter(is_wanted)


def remove_failed_iterations(ht: hl.Table) -> hl.Table:
    """Drop iterations that timed out or otherwise failed.

    An iteration survives only if it did not time out AND recorded no failure.
    """

    def succeeded(t):
        # `&`, not `|`: with `|`, an iteration that failed without timing out
        # (or timed out without a recorded failure) would be kept.
        return ~t.timed_out & hl.is_missing(t.failure)

    return ht.annotate(
        instances=ht.instances.map(
            lambda i: i.annotate(iterations=hl.filter(succeeded, i.iterations)),
        ),
    )


def remove_non_viable_instances(
    ht: hl.Table,
    min_instances: hl.Int32Expression,
    min_iterations: hl.Int32Expression,
) -> hl.Table:
    """Drop short instances, then benchmarks left with too few instances.

    Instances with fewer than `min_iterations` iterations are removed first.
    The instance count compared against `min_instances` is therefore taken
    after that pruning.
    """
    viable = ht.instances.filter(lambda instance: hl.len(instance.iterations) >= min_iterations)
    ht = ht.annotate(instances=viable)
    return ht.filter(hl.len(ht.instances) >= min_instances)


ht = hl.read_table(f'{prefix}/out/benchmarks.ht')
# `all_names` rather than `all`, which would shadow the builtin for the rest of
# the notebook session.
all_names: set[str] = ht.aggregate(hl.agg.collect_as_set(ht.name))  # type: ignore

# Clean the data, in order:
# - drop benchmarks without an eye-balled burn-in;
# - drop burn-in iterations;
# - drop failed iterations;
# - drop iterations at least 10x away from their instance median;
# - drop instances with fewer than 50 iterations, then benchmarks with fewer
#   than 50 remaining instances.
ht = keep_names(ht, hl.set([n for n in all_names if n in first_stable_index]))
ht = remove_burn_in_iterations(ht, first_stable_index)
ht = remove_failed_iterations(ht)
ht = remove_outliers(ht, hl.float64(10))
ht = remove_non_viable_instances(ht, hl.int(50), hl.int(50))

ht = ht.checkpoint(f'{prefix}/out/filtered.ht', overwrite=True)

benchmarks = ht.aggregate(hl.agg.collect_as_set(ht.name))

# Report every benchmark removed by the cleaning, in a stable (sorted) order.
# The set difference is computed once instead of rebuilding a set per name.
print('Filtered:', *sorted(set(all_names) - set(benchmarks)), sep='\n')

## Benchmark Variability

The next cells concern themselves with examining the variability of benchmarks,  
both on a per-instance basis as well as total.

The first cell plots mean execution time per instance to look for distinct modes.  
If present and significant, it may be harder to identify performance differences  
for this benchmark between instances.

The second cell quantifies variability by computing the total and per-instance  
coefficient of variation, defined as the ratio of the standard deviation to the mean.

In [None]:
# Page through the mean-time-per-instance plots one benchmark at a time.
# Press Enter for the next plot; Ctrl+C stops early (as in the other
# interactive cells) instead of ending with a traceback.
ht = hl.read_table(f'{prefix}/out/filtered.ht')
try:
    for f in plot_mean_time_per_instance(ht):
        clear_output(wait=True)
        pio.renderers.default = 'notebook'
        f.show()
        input('Press Enter for the next benchmark (Ctrl+C to stop)')
except KeyboardInterrupt:
    pass

In [None]:
# Coefficient of variation of iteration times, both total and per instance.
ht = hl.read_table(f'{prefix}/out/filtered.ht')
ht = ht.select(instances=ht.instances.iterations.time)
cv_exprs = variability(ht)
ht.select(**cv_exprs).show()

## Detecting Slowdowns

In the following section, we'll see what configurations are required to reliably  
detect slowdowns via two methods: that of Laaber et al. and another of Patrick's  
devising. See the comments in the source code for details of each.

In a later section, we'll use these results to select configurations that minimise  
false positives and minimise the minimal detectable slowdown (i.e., detect the  
smallest possible changes), allowing for cost and time.

The analyses are fairly computationally intensive and I recommend switching to  
the `batch` backend for this task.

In [None]:
# optional - switch to the batch backend for MDS computations
hl.stop()
hl.init(backend='batch')

# Stage the filtered table under the batch backend's remote tmpdir and make
# that the working prefix. `new_prefix` keeps the local prefix so results can
# be copied back later.
backend = hl.current_backend()
new_prefix = backend.remote_tmpdir
backend.fs.copy(f'{prefix}/out/filtered.ht', f'{new_prefix}/out/filtered.ht')
prefix, new_prefix = new_prefix, prefix

In [None]:
# Run both minimal-detectable-slowdown analyses over the per-instance
# iteration times and persist their results.
ht = hl.read_table(f'{prefix}/out/filtered.ht')
ht = ht.select(instances=ht.instances.iterations.time)

# overwrite=True, like the filtered.ht checkpoint, so this cell can be re-run
# without first deleting earlier outputs by hand.
laaber_mds(ht).write(f'{prefix}/out/laaber-mds.ht', overwrite=True)
schultz_mds(ht).write(f'{prefix}/out/schultz-mds.ht', overwrite=True)

In [None]:
# optional - switch back to spark for plotting and localise the results
# Copy both MDS tables from the remote prefix to the local one, then swap the
# prefixes back so later cells read the local copies.
fs = hl.current_backend().fs
for method in ('laaber', 'schultz'):
    fs.copy(f'{prefix}/out/{method}-mds.ht', f'{new_prefix}/out/{method}-mds.ht')

prefix, new_prefix = new_prefix, prefix
hl.stop()
hl.init(backend='spark')

## Benchmark Configurations

Now that we've computed the slowdowns detectable by each configuration, we need  
to select a configuration that reduces
- the rate of false positives
- the size of the change we can detect

The end result of this work is to produce a configuration file for automating  
benchmarking in such a way as to reliably detect slowdowns.

Starting with a simulated slowdown of 1 (i.e., no slowdown at all), you'll pick a  
configuration of `ninstances, niterations` where the likelihood of detecting a  
change is as close to zero as possible. I'm assuming that this acts as a threshold  
configuration, beyond which configurations are no more likely to detect false positives.

Removing all configurations less than this threshold, you'll then scan the  
detectable slowdowns for configurations whose rate of detection of a given  
slowdown is as close as possible to 1. For some benchmarks, it may be impossible  
to detect small slowdowns. Enter your configuration in the prompt, leave the  
prompt empty to skip, or hit Ctrl+C to give up.

At the end, a yaml document will be printed to the output.

In [None]:
# Per benchmark: mean iteration time over all instances and the total
# coefficient of variation, collected locally as (path, name, mean, cv).
t = hl.read_table(f'{prefix}/out/filtered.ht')
t = t.select(instances=t.instances.iterations.time)

mean_time = t.instances.aggregate(lambda times: hl.agg.explode(hl.agg.mean, times))
total_cv = variability(t).total
t = t.select(mean=mean_time, cv=total_cv)

benchmarks = t.aggregate(hl.agg.collect((t.path, t.name, t.mean, t.cv)))

In [None]:
def _read_mds(method: str) -> hl.Table:
    """Read `{prefix}/out/{method}-mds.ht`, keyed by benchmark and configuration."""
    mds_ht = hl.read_table(f'{prefix}/out/{method}-mds.ht')
    mds_ht = mds_ht._key_by_assert_sorted('path', 'name', 'slowdown', 'ninstances', 'niterations')
    return mds_ht.select('ibs', 'tbs')


def _prompt(message: str, parse):
    """Prompt until `parse` accepts the response; return None on blank input.

    Fault-tolerant: a response that `parse` rejects with ValueError (a typo,
    the wrong number of fields) re-prompts instead of losing the benchmark.
    Ctrl+C (KeyboardInterrupt) propagates to the caller.
    """
    while True:
        response = input(message)
        if not response:
            return None
        try:
            return parse(response)
        except ValueError:
            print(f'Could not parse {response!r}; try again or leave blank to skip.')


def _parse_threshold(response: str) -> tuple[int, int]:
    """Parse `ninstances, niterations`; raises ValueError on malformed input."""
    ninstances, niterations = (int(x.strip()) for x in response.split(','))
    return ninstances, niterations


def _parse_config(response: str) -> dict:
    """Parse `slowdown, ninstances, niterations`; raises ValueError on malformed input."""
    slowdown, ninstances, niterations = (x.strip() for x in response.split(','))
    return {
        'slowdown': float(slowdown),
        'instances': int(ninstances),
        'iterations': int(niterations),
    }


laaber, schultz = (_read_mds(m) for m in ('laaber', 'schultz'))

# Both methods' results side by side, one row per
# (path, name, slowdown, ninstances, niterations).
mds = laaber.select(laaber=laaber.row_value, schultz=schultz[laaber.key])

results: list[dict] = []

for path, name, rt, cv in benchmarks:
    try:
        cur = {
            'item': f'{path}::{name}',
            'burn_in': first_stable_index[name],
            'mean': float(f'{rt:.3f}'),
            'cv': float(f'{cv:.3f}'),
        }

        info = Pretty(yaml.dump(data=[cur], sort_keys=False))

        display(info, clear=True)

        # Step 1: with no simulated slowdown (slowdown == 1), the user picks a
        # threshold configuration that minimises false positives.
        t = mds.filter((mds.path == path) & (mds.name == name))
        t.filter(t.slowdown == 1).show(100_000)

        threshold = _prompt(
            'Enter `ninstances, niterations` that minimises the likelihood of detecting a false positive.',
            _parse_threshold,
        )
        if threshold is None:
            continue
        m, n = threshold

        display(info, clear=True)

        # Step 2: among configurations at or above the threshold, the user
        # picks the one that reliably detects the smallest slowdown (detection
        # rate as close as possible to 1). "At or above" is componentwise, so
        # neither instances nor iterations drop below the threshold. A
        # lexicographic tuple comparison would admit e.g. (m + 1, 1).
        t.filter((t.slowdown > 1) & (t.ninstances >= m) & (t.niterations >= n)).show(100_000)

        config = _prompt(
            'Enter `slowdown, ninstances, niterations` that maximises the likelihood of detecting a slowdown.',
            _parse_config,
        )
        if config is not None:
            results.append({**cur, 'config': [config]})

    except KeyboardInterrupt:
        break

# Emit the chosen configurations as a yaml document, ready to commit.
display(Pretty(yaml.dump(data=results, sort_keys=False)), clear=True)