# Ensemble Batch Showcase

`Ensemble.batch` is a versatile function that lets users pass in external functions that operate on groupings of `Ensemble` data; most commonly, these are functions that calculate something per lightcurve. Because external functions can have a huge variety of inputs and outputs, this notebook collects example functions and shows how `batch` can be used with each of them. The hope is that one of these examples resembles the function you are trying to apply via `batch`, so that it can serve as a template for getting your function to work.

## Generate some toy data and create an Ensemble

In [None]:
from tape import Ensemble, ColumnMapper, TapeFrame
import numpy as np
import pandas as pd

In [None]:
# Generate some fake data

np.random.seed(1)

n_obj, n_obs = 100, 1250  # 100 objects with 1250 observations each

obj_ids = np.repeat(np.arange(10, 110), n_obs)  # ids 10..109, one block per object
mjds = np.tile(np.arange(0., 1250., 1.), n_obj)  # same daily time grid for every object

flux = 10*np.random.random(n_obj*n_obs)
err = flux/10  # errors are 10% of the flux
band = np.random.choice(['g','r'], n_obj*n_obs)

source_dict = {"id":obj_ids, "mjd":mjds,"flux":flux,"err":err,"band":band}

In [None]:
# Load the data into an Ensemble
ens = Ensemble()

ens.from_source_dict(source_dict, column_mapper = ColumnMapper(id_col="id",
                                                              time_col="mjd",
                                                              flux_col="flux",
                                                              err_col="err",
                                                              band_col="band"))

## Case 1: A Simple Mean

We define a simple function that takes in an array-like argument, `flux`, and returns its mean.

In [None]:
# Case 1: Simple
def my_mean(flux):
    return np.mean(flux)

my_mean([1,2,3,4,5])

To run the `my_mean` function with `Ensemble.batch`, we pass the function followed by its argument(s) as separate arguments. Here we pass "flux" as a string, and batch uses the data in the column with that label as the input to `my_mean`.

In [None]:
# Default batch
res1 = ens.batch(my_mean, "flux") # "flux" is provided to have TAPE pass the "flux" column data along to my_mean
res1.compute() # Compute to see the result

By default, `Ensemble.batch` groups each lightcurve together (grouping on the specified id column). However, batch also supports custom grouping assignments. Below, we instead pass `on=["band"]`, telling batch to calculate the mean over all data in each band.

In [None]:
# Batch with custom grouping

res2 = ens.batch(my_mean, "flux", on=["band"])
res2.compute()
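
As a rough mental model, the two calls above behave much like a pandas group-by followed by applying the function to each group's flux values. The sketch below is only an analogy in plain pandas, not TAPE's actual implementation, and the small `toy` dataframe is a hypothetical stand-in for the Ensemble data.

```python
import numpy as np
import pandas as pd

def my_mean(flux):
    return np.mean(flux)

# Hypothetical stand-in for the Ensemble's source table
toy = pd.DataFrame({
    "id": [10, 10, 11, 11],
    "flux": [1.0, 3.0, 5.0, 7.0],
    "band": ["g", "r", "g", "r"],
})

# Default grouping: one call to my_mean per lightcurve (per id)
per_object = toy.groupby("id")["flux"].apply(my_mean)

# Custom grouping: one call to my_mean per band
per_band = toy.groupby("band")["flux"].apply(my_mean)

print(per_object)
print(per_band)
```

In both cases the function itself is unchanged; only the grouping decides which flux values each call receives.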

This can be extended to more than a single column. Below, we group by id and then sub-group by band. In `Pandas`, an operation like this would return a multi-index, but because `Dask` does not support multi-indexes, we return the sub-groupings as columns.

In [None]:
# Multi-level groupbys

res3 = ens.batch(my_mean, "flux", on=["id", "band"])
res3.compute()
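
To see the difference between a multi-index and sub-groupings as columns, here is a minimal pandas-only sketch. The `toy` dataframe and the `result` column name are assumptions made for illustration; the point is only that flattening the index levels gives ordinary columns.

```python
import numpy as np
import pandas as pd

def my_mean(flux):
    return np.mean(flux)

# Hypothetical stand-in for the Ensemble's source table
toy = pd.DataFrame({
    "id": [10, 10, 10, 11, 11, 11],
    "band": ["g", "g", "r", "g", "r", "r"],
    "flux": [1.0, 3.0, 5.0, 2.0, 4.0, 8.0],
})

# pandas returns a Series indexed by an (id, band) MultiIndex
multi = toy.groupby(["id", "band"])["flux"].apply(my_mean)

# Moving the index levels into ordinary columns gives a flat table
flat = multi.reset_index(name="result")  # "result" is an illustrative name

print(multi)
print(flat)
```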

Sub-grouping by photometric band is a use case we expect to be common in TAPE workflows, so batch provides the `by_band` kwarg. It ensures that the last sub-grouping level is on band and returns an independent column for each band's result.

In [None]:
# Batch with the by_band flag
res4 = ens.batch(my_mean, "flux", by_band=True)
res4.compute()
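
The per-band layout can be pictured as a pivot of the (id, band) grouping, with one output column per band. The following is a pandas-only sketch under that assumption; the `toy` dataframe and the `result_` column prefix are hypothetical, and the column names TAPE actually produces may differ.

```python
import numpy as np
import pandas as pd

def my_mean(flux):
    return np.mean(flux)

# Hypothetical stand-in for the Ensemble's source table
toy = pd.DataFrame({
    "id": [10, 10, 10, 11, 11, 11],
    "band": ["g", "g", "r", "g", "r", "r"],
    "flux": [1.0, 3.0, 5.0, 2.0, 4.0, 8.0],
})

per_band_columns = (
    toy.groupby(["id", "band"])["flux"]
    .apply(my_mean)
    .unstack("band")        # one column per band, indexed by id
    .add_prefix("result_")  # illustrative naming only
)

print(per_band_columns)
```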

## Case 2: Functions That Return a Series

In case 2, we write a function that returns a `pandas.Series` object. The min and max of the flux array are stored under separate labels ("min" and "max") of the output series.

In [None]:
def my_bounds(flux):
    return pd.Series({"min":np.min(flux), "max":np.max(flux)})

# Function output
my_bounds([1,2,3,4,5])

As in case 1, we're able to pass this function and the "flux" column along to run the function. However, this time we need to set the `meta`. The `meta` is a required component of `Dask`'s lazy evaluation: because `Dask` does not actually compute results until requested to, `meta` describes the expected form of the output. In this case, we just need to let `Dask` know that a min and a max column will be present in a dataframe (TAPE will always return a dataframe) and that both will hold float values.

For more information on the `Dask` meta argument, read their [blog post](https://blog.dask.org/2022/08/09/understanding-meta-keyword-argument).

In [None]:
# Default Batch

res1 = ens.batch(my_bounds, "flux", meta={"min":float, "max":float}) # Requires meta to be set
res1.compute()
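
To make the role of `meta` more concrete, the sketch below builds, in plain pandas, the kind of empty dataframe that a `{"min": float, "max": float}` specification describes, and checks that the real per-group output of `my_bounds` has the same columns and dtypes. This only illustrates the idea; it does not show how `Dask` or TAPE construct their meta internally, and the `toy` dataframe is a hypothetical stand-in.

```python
import numpy as np
import pandas as pd

def my_bounds(flux):
    return pd.Series({"min": np.min(flux), "max": np.max(flux)})

# Hypothetical stand-in for the Ensemble's source table
toy = pd.DataFrame({"id": [10, 10, 11, 11], "flux": [1.0, 3.0, 5.0, 7.0]})

# An empty frame carrying only the expected column names and dtypes
expected = pd.DataFrame({
    "min": pd.Series(dtype=float),
    "max": pd.Series(dtype=float),
})

# The eagerly computed result, one row per id
actual = toy.groupby("id")[["flux"]].apply(lambda g: my_bounds(g["flux"]))

columns_match = list(actual.columns) == list(expected.columns)
dtypes_match = bool((actual.dtypes == expected.dtypes).all())
print(columns_match, dtypes_match)
```

If the function's real output had different columns or dtypes than the declared meta, the lazy description and the computed result would disagree, which is the kind of mismatch to avoid.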

The same grouping flexibility extends to case 2, again with the `meta` specified. Note that the meta given to `Ensemble.batch` stays the same because it depends only on the function output; batch handles the meta for any columns generated by the grouping on its own.

In [None]:
# Multi-level groupbys, note that meta does not need to change
res2 = ens.batch(my_bounds, "flux", on=["id", "band"], meta={"min":float, "max":float}) # Requires meta to be set
res2.compute()

Using the `by_band` kwarg extends the output columns to be per-band.

In [None]:
# Using by_band

res3 = ens.batch(my_bounds, "flux", by_band=True, meta={"min":float, "max":float}) # Requires meta to be set
res3.compute()

## Case 3: Functions That Return a DataFrame

Here we define a function, `my_bounds_df`, that computes the same quantities as `my_bounds` above, but returns the results as a dataframe.

In [None]:
def my_bounds_df(flux):
    return pd.DataFrame({'min':[np.min(flux)], 'max':[np.max(flux)]})

my_bounds_df([1,2,3,4,5])

This is perfectly reasonable, but passing a function like this through `batch` currently has an issue to watch out for.

In [None]:
# Default Batch, some things to watch out for

res1 = ens.batch(my_bounds_df, "flux", meta={'min':float, 'max':float})
res1.compute()


As with the series, we needed to pass the `meta` kwarg to let TAPE know which output columns to expect from the function. However, we see that our result carries over the index generated by the dataframe in addition to the batch index, represented as a multi-index. At the time of this notebook's creation, `Dask` does not have explicit support for multi-indexes. We can see this problem in the following cells.

In [None]:
# Pandas resolves these indexes as a multi-index
res1.reset_index().compute()

In [None]:
# Dask assumes there's just a single index column being sent to the dataframe columns
res1.reset_index()

When `Dask` and the underlying `Pandas` disagree on what the dataframe looks like, it becomes difficult for you as the user to work with the dataframe. Here, `Dask` won't recognize any calls to "id" or "level_1" and will only accept a call to "index", which in turn `Pandas` won't understand. If you run into this issue, we recommend trying to modify your function to return a non-dataframe output. If that isn't possible, here's a somewhat hacky way to work around it.

We can resolve this by updating the `Dask` meta manually, to re-align `Dask` and `Pandas`.

In [None]:
# If it's not too compute intensive, grabbing the actual dataframe is the easiest way forward
real_meta_from_result = res1.reset_index().head(0)
real_meta_from_result

In [None]:

# otherwise, can generate this ourselves
real_meta_from_dataframe = TapeFrame(columns=["id","level_1","min","max"])
real_meta_from_dataframe

In [None]:
# Overwrite the _meta property

res1_noindex = res1.reset_index()
res1_noindex._meta = real_meta_from_dataframe
res1_noindex

Note that in the above, we've reset the index because `Dask` does not support meta that tracks a multi-index. For this function, we gain no information from the "level_1" column, and it would be nice to reestablish "id" as the index, so we close the loop by executing the commands in the next cell.

In [None]:
res1 = res1_noindex.drop(columns=["level_1"]).set_index("id")
res1.compute()
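
The index behaviour described in this case can be reproduced in plain pandas, which may help when debugging your own function outside of `Dask`. The sketch below is an analogy only: the `toy` dataframe is hypothetical and no `Dask` meta is involved, but it shows where the extra "level_1" column comes from and how dropping it restores an id-indexed table.

```python
import numpy as np
import pandas as pd

def my_bounds_df(flux):
    return pd.DataFrame({"min": [np.min(flux)], "max": [np.max(flux)]})

# Hypothetical stand-in for the Ensemble's source table
toy = pd.DataFrame({"id": [10, 10, 11, 11], "flux": [1.0, 3.0, 5.0, 7.0]})

# Each call returns a one-row dataframe with its own index (0), so the
# combined result is indexed by (id, row index of the returned dataframe)
nested = toy.groupby("id")[["flux"]].apply(lambda g: my_bounds_df(g["flux"]))

# Resetting the index exposes the unnamed inner level as "level_1"
flat = nested.reset_index()

# The inner level carries no information, so drop it and restore "id"
cleaned = flat.drop(columns=["level_1"]).set_index("id")

print(nested.index)
print(cleaned)
```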

## Case 4: Functions that Require Non-Array Inputs

Let's return to case 1, but this time, instead of the list-like `flux` argument, let's say that the function needs to take in a dataframe with a column titled `my_flux`.

In [None]:
# Case 4: DataFrame input
def my_mean_from_df(df):
    return np.mean(df['my_flux'])

df = pd.DataFrame({'my_flux':[1, 2, 3, 4, 5]})
my_mean_from_df(df)

In this case, batch can't directly provide inputs to this function, because batch passes the column data to the function as arrays. However, we can make the function usable with batch by wrapping it in another function.

In [None]:
def mean_wrapper(flux):
    df = pd.DataFrame({'my_flux': flux})
    return my_mean_from_df(df)

# Can pass the wrapper function along to batch
res1 = ens.batch(mean_wrapper, "flux")
res1.compute()

This is a really simple case, but it highlights that a wrapper function can serve as a middleman between your function and `batch`. A wrapper can even sort or filter your data on a per-call basis, if that was not done as a pre-filter step on your Ensemble.
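
As a sketch of a wrapper that also filters and sorts, consider a hypothetical function `flux_change_from_df` that expects a time-sorted dataframe. Both that function and the `my_time` and `my_flux` column names are assumptions made for illustration; the wrapper takes plain arrays, which is how batch hands over column data.

```python
import numpy as np
import pandas as pd

# Hypothetical function that expects a time-sorted dataframe
def flux_change_from_df(df):
    return df["my_flux"].iloc[-1] - df["my_flux"].iloc[0]

def change_wrapper(time, flux):
    # Rebuild the dataframe the inner function expects from the arrays
    df = pd.DataFrame({"my_time": time, "my_flux": flux})
    # Per-call filtering and sorting: drop non-finite fluxes, order by time
    df = df[np.isfinite(df["my_flux"])].sort_values("my_time")
    return flux_change_from_df(df)

# Unsorted input with one bad measurement
change = change_wrapper([3.0, 1.0, 2.0], [30.0, 10.0, np.nan])
print(change)
```

Following the pattern used above of passing column labels as separate arguments, a wrapper like this would presumably be called as `ens.batch(change_wrapper, "mjd", "flux")`.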

## Case 5: TAPE Analysis Functions

TAPE analysis functions are a special case of input function to `Ensemble.batch`. Information that is normally required, such as the column labels to pass to the function and the `meta`, is passed along from the function to `Ensemble.batch` internally, so you only need to specify the function and any additional kwargs. For this case, let's leverage the [light-curve](https://github.com/light-curve/light-curve-python) package, which implements the extraction of many light curve [features](https://github.com/light-curve/light-curve-python?tab=readme-ov-file#available-features) used in astrophysics. Feature extraction from this package is also supported within TAPE as an analysis function.

In [None]:
# Grab two feature extraction methods from light-curve
from light_curve import Periodogram, OtsuSplit

In the example below, we apply the Lomb-Scargle periodogram to our `Ensemble` light curves. Note again that the `meta` we had to configure above is already handled by TAPE, and the needed timeseries columns are passed along internally as well.

In [None]:
# Find periods using Lomb-Scargle periodogram
periodogram = Periodogram(peaks=1, nyquist=0.1, max_freq_factor=10, fast=False)

# Use r band only
res_per = ens.batch(periodogram, band_to_calc='r') # band_to_calc is a kwarg of Periodogram
res_per.compute()

Next, we use `OtsuSplit`, which performs automatic thresholding. In this case, we also supply the `by_band` kwarg to get a result per photometric band.

In [None]:
res_otsu = ens.batch(OtsuSplit(), band_to_calc=None, by_band=True)
res_otsu.compute()