In [1]:
import lsdb
import ast
from tape import Ensemble, ColumnMapper
import matplotlib.pyplot as plt
import dask
import numpy as np
import pandas as pd
from collections.abc import Iterable

dask.config.set({'temporary_directory': '/data/epyc/users/brantd/tmp'})
dask.config.set({'dataframe.query-planning': False})

from dask.distributed import Client
client = Client(n_workers=10, threads_per_worker=1,
                memory_limit="60G",
                dashboard_address=':38764')

client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 43747 instead


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:43747/status,

0,1
Dashboard: http://127.0.0.1:43747/status,Workers: 10
Total threads: 10,Total memory: 558.79 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:41565,Workers: 10
Dashboard: http://127.0.0.1:43747/status,Total threads: 10
Started: Just now,Total memory: 558.79 GiB

0,1
Comm: tcp://127.0.0.1:37710,Total threads: 1
Dashboard: http://127.0.0.1:36275/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:38910,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-s6q6xzz5,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-s6q6xzz5

0,1
Comm: tcp://127.0.0.1:40413,Total threads: 1
Dashboard: http://127.0.0.1:38530/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:33887,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-p0u8911t,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-p0u8911t

0,1
Comm: tcp://127.0.0.1:41741,Total threads: 1
Dashboard: http://127.0.0.1:37458/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:37558,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-mllg_3l6,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-mllg_3l6

0,1
Comm: tcp://127.0.0.1:38041,Total threads: 1
Dashboard: http://127.0.0.1:42255/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:38132,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-shwj9_w8,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-shwj9_w8

0,1
Comm: tcp://127.0.0.1:38661,Total threads: 1
Dashboard: http://127.0.0.1:39828/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:37681,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-hfbtx8o8,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-hfbtx8o8

0,1
Comm: tcp://127.0.0.1:42377,Total threads: 1
Dashboard: http://127.0.0.1:32808/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:35605,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-47fpcz2q,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-47fpcz2q

0,1
Comm: tcp://127.0.0.1:37733,Total threads: 1
Dashboard: http://127.0.0.1:39104/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:36727,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-pya7tokg,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-pya7tokg

0,1
Comm: tcp://127.0.0.1:42421,Total threads: 1
Dashboard: http://127.0.0.1:34881/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:35899,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-n25z543m,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-n25z543m

0,1
Comm: tcp://127.0.0.1:37496,Total threads: 1
Dashboard: http://127.0.0.1:35718/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:34711,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-ruh00f__,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-ruh00f__

0,1
Comm: tcp://127.0.0.1:32970,Total threads: 1
Dashboard: http://127.0.0.1:34159/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:46128,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-b7ctvnwo,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-b7ctvnwo


## Approach 1: Batch Find Intervals, Join and Query Source

Advantage:
* See an intermediate product of the generated intervals

Disadvantage:
* Slow and Expensive

In [2]:
# Load from the small Ensemble
ens = Ensemble(client=client)
ens.from_ensemble("./ztf_small_ensemble")

<tape.ensemble.Ensemble at 0x7fb2003b4eb0>

In [3]:
# Then continue to do analysis

# Define an example interval generator function
def get_intervals(mjd):
    bounds = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
    intervals = [(bounds[0], bounds[1]), (bounds[2], bounds[3])]
    return intervals
    #return pd.Series({'intervals':intervals})

# Apply our interval function to batch and assign as a source column
intervals = ens.batch(get_intervals, "mjd", meta=("intervals", str)) # use string for literal_evals
ens.source.join(intervals).update_ensemble()

# Define a query to determine whether each mjd is in the interval(s)
def row_query(row):
    interval_list = row["intervals"]
    interval_list = ast.literal_eval(interval_list)
    if not isinstance(interval_list, Iterable):
        return False
    
    res = [((row["mjd"] > interval[0]) * (row["mjd"] < interval[1])) for interval in interval_list]
    return np.any(res)


# Apply this query to the source table
interval_mask = ens.source.apply(row_query, axis=1, meta = pd.Series(dtype='bool', name='in_interval'))
ens.source.assign(in_interval=interval_mask).query("in_interval==True").update_ensemble()

Using generated label, result_1, for a batch result.


<tape.ensemble.Ensemble at 0x7fb2003b4eb0>

In [4]:
len(ens.source)

1634455

## Approach 2: Determine Intervals and Immediately Filter Source

Advantage:
* Faster and more memory efficient

Disadvantage:
* Cannot view intervals without creating another function to generate them separately
* Some limitation in TAPE support as of current version (have to do some extra things at the end)
* Cannot handle objects that don't return a result

In [2]:
# Load from the small Ensemble
ens = Ensemble(client=client)
ens.from_ensemble("./ztf_small_ensemble")

<tape.ensemble.Ensemble at 0x7f31d46c29b0>

In [3]:
# Try a batch interval generation -> filter
def filter_intervals(df):
    mjd = df["mjd"]

    # Interval calculation
    bounds = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
    intervals = [(bounds[0], bounds[1]), (bounds[2], bounds[3])]
    #intervals = []
    

    # Filter on mjd
    if len(intervals) > 0:
        for i, interval in enumerate(intervals):
            if i == 0:
                mjd_mask = ((mjd > interval[0]) * (mjd < interval[1]))
            else:
                mjd_mask += ((mjd > interval[0]) * (mjd < interval[1]))
    else:
        mjd_mask = np.ones(mjd, dtype=bool)
    
    df = df[mjd_mask]
    return df

source_subset = ens.source.groupby(ens._id_col, group_keys=False).apply(lambda x: filter_intervals(x), meta=ens.source._meta)

# Have to manually re-establish the source for now
source_subset.ensemble = ens
source_subset.set_dirty(True)
source_subset.update_ensemble()


<tape.ensemble.Ensemble at 0x7f31d46c29b0>

In [4]:
len(ens.source)

1634455

## Approach 3: Generate Filters, filter on Objects then recalculate and filter Sources

Advantages:
* Interval information is available in the Object Table
* Relatively Fast
* Robust to Objects with no intervals

Disadvantages:
* Inefficiency in double calculation of intervals, not too big of an issue if interval calculation is lightweight and this is likely faster than trying to join the intervals to source

In [167]:
# Load from the small Ensemble
ens = Ensemble(client=client)
ens.from_ensemble("./ztf_small_ensemble")

<tape.ensemble.Ensemble at 0x7fb6a9c3dea0>

In [3]:
# Define an example interval generator function
def get_intervals(mjd):
    bounds = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
    intervals = [(bounds[0], bounds[1]), (bounds[2], bounds[3])]
    n_intervals = len(intervals)
    return pd.Series({"n_intervals":n_intervals, "intervals":intervals})
    #return n_intervals, intervals
    #return pd.Series({'intervals':intervals})

# Apply our interval function to batch and filter on n_intervals to remove objects with no intervals
intervals = ens.batch(get_intervals, "mjd", meta={"n_intervals": int, "intervals": str}) # use string for literal_evals
ens.object.join(intervals).query("n_intervals > 0").update_ensemble()

Using generated label, result_1, for a batch result.


<tape.ensemble.Ensemble at 0x7fdde411c130>

In [3]:
# Recalculate intervals and filter source

def filter_intervals(df):
    mjd = df["mjd"]
    
    #if you join intervals to source
    #intervals = df["intervals"].map(ast.literal_eval).values[0]

    # Interval calculation -- do it again
    bounds = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
    intervals = [(bounds[0], bounds[1]), (bounds[2], bounds[3])]

    # Filter on mjd
    if len(intervals) > 0:
        for i, interval in enumerate(intervals):
            if i == 0:
                mjd_mask = ((mjd > interval[0]) * (mjd < interval[1]))
            else:
                mjd_mask += ((mjd > interval[0]) * (mjd < interval[1]))
    else:
        mjd_mask = np.ones(mjd, dtype=bool)
    
    df = df[mjd_mask]
    return df

# Join the intervals -- slower generally if intervals are fast to compute
#ens.source.join(intervals["intervals"]).update_ensemble()

source_subset = ens.source.groupby(ens._id_col, group_keys=False).apply(lambda x: filter_intervals(x), meta=ens.source._meta)

# Have to manually re-establish the source for now
source_subset.ensemble = ens
source_subset.set_dirty(True)
source_subset.update_ensemble()


<tape.ensemble.Ensemble at 0x7fbbeff8abf0>

In [4]:
len(ens.source)

1634455

## Approach 4: Extending to Multi-Band

In [23]:
# Load from the small Ensemble
ens = Ensemble(client=client)
ens.from_ensemble("./ztf_small_ensemble")

<tape.ensemble.Ensemble at 0x7f8e8b863e50>

In [24]:
# Define an example interval generator function per-band
def get_band_intervals(mjd, band, band_label="g"):
    band_mask = band == band_label
    band_mjd = mjd[band_mask]

    if len(band_mjd) > 0:
        bounds = np.percentile(band_mjd, [10.0, 30.0, 70.0, 90.0])
        intervals = [(bounds[0], bounds[1]), (bounds[2], bounds[3])]
        n_intervals = len(intervals)
        return pd.Series({f"n_intervals_{band_label}":n_intervals, f"{band_label}_intervals":intervals})
    else:
        return pd.Series({f"n_intervals_{band_label}":0, f"{band_label}_intervals":[]})

In [25]:
# Apply our interval function to batch and filter on n_intervals to remove objects with no intervals in each band
for band_label in ["g","r","i"]:
    intervals = ens.batch(get_band_intervals, "mjd", "band", band_label=band_label, meta={f"n_intervals_{band_label}": int, f"{band_label}_intervals": str}) # use string for literal_evals
    ens.object.join(intervals).update_ensemble()

# Query for objects with at least 1 valid interval
ens.object.query("n_intervals_g + n_intervals_r + n_intervals_i > 0").update_ensemble()


Using generated label, result_1, for a batch result.
Using generated label, result_2, for a batch result.
Using generated label, result_3, for a batch result.


<tape.ensemble.Ensemble at 0x7f8e8b863e50>

In [13]:
# Define a function to filter Source by recalculating intervals -- adds a interval label flag to source table
def filter_intervals(df):
    mjd = df["mjd"]
    band = df["band"]
    #flag = df["catflags"]
    mag = df["mag"]
    magerr = df["magerr"]

    # Denotes interval groupings
    interval_labels = np.zeros_like(mjd, dtype=int)

    # Interval calculation -- do it again
    # Loop over each band and calculate intervals
    for i, band_label in enumerate(["g","r","i"]):
        intervals = get_band_intervals(mjd, band, band_label)[f"{band_label}_intervals"] # This returns a pandas series
        band_mask = band==band_label # define an initial band mask
        
        # Filter based on intervals
        if len(intervals) > 0:
            for j, interval in enumerate(intervals):
                 # Mask times based on the interval
                interval_mask = np.logical_and(band_mask, ((mjd > interval[0]) * (mjd < interval[1])))
                # Generate labels for the interval (unique per band-interval)
                interval_labels += ((100*i)+j) * interval_mask

                # Build the full mask (all intervals in all bands)
                if i == 0:
                    full_mask = interval_mask
                else:
                    full_mask = np.logical_or(full_mask, interval_mask)
        else:
            # don't keep any data from band if no intervals
            if i == 0:
                full_mask = np.zeros_like(mjd, dtype=bool)
        
        # construct full mask for all bands
        #if i == 0:
        #    full_mask = mjd_mask
        #else:
        #    full_mask = np.logical_or(full_mask, mjd_mask)

    df = df.assign(interval_flag = interval_labels)
    
    df = df[full_mask]
    return df

In [26]:
# Define a function to filter Source by recalculating intervals -- adds a interval label flag to source table
def filter_intervals(df):
    mjd = df["mjd"]
    band = df["band"]
    #flag = df["catflags"]
    mag = df["mag"]
    magerr = df["magerr"]

    # Denotes interval groupings
    interval_labels = np.zeros_like(mjd, dtype=int)

    # Interval calculation -- do it again
    # Loop over each band and calculate intervals
    for i, band_label in enumerate(["g","r","i"]):
        intervals = get_band_intervals(mjd, band, band_label)[f"{band_label}_intervals"] # This returns a pandas series
        band_mask = band==band_label # define an initial band mask
        
        # Filter based on intervals
        if len(intervals) > 0:
            for j, interval in enumerate(intervals):
                 # Mask times based on the interval
                interval_mask = np.logical_and(band_mask, ((mjd > interval[0]) * (mjd < interval[1])))
                # Generate labels for the interval (unique per band-interval)
                interval_labels += ((100*i)+j) * interval_mask

    df = df.assign(interval_flag = interval_labels)
    
    df = df[interval_labels.astype(bool)]
    return df

In [19]:
np.array([1,2,3,0]).astype(bool)

array([ True,  True,  True, False])

In [27]:
# Apply source filter function and generate a new source table
source_subset = ens.source.groupby(ens._id_col, group_keys=False).apply(lambda x: filter_intervals(x), meta=ens.source._meta.assign(interval_flag=None))

# Have to manually re-establish the source for now
#source_subset.ensemble = ens
#source_subset.set_dirty(True)
#source_subset.update_ensemble()

In [28]:
source_subset.head(5, npartitions=-1)

Unnamed: 0_level_0,ps1_objid_object,ra_object,dec_object,ps1_gMeanPSFMag_object,ps1_rMeanPSFMag_object,ps1_iMeanPSFMag_object,nobs_g_object,nobs_r_object,nobs_i_object,mean_mag_g_object,...,Dir_object,Npix_object,ps1_objid,ra,dec,mjd,mag,magerr,band,interval_flag
_hipscat_index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
5433293159180271616,120013408788005954,340.878784,10.012841,20.7568,19.6542,19.116301,18,44,28,20.830534,...,0,1206,120013408788005954,340.878784,10.012841,58379.28908,19.518396,0.093768,r,101
5433293159180271616,120013408788005954,340.878784,10.012841,20.7568,19.6542,19.116301,18,44,28,20.830534,...,0,1206,120013408788005954,340.878784,10.012841,58377.29735,19.512482,0.093367,r,101
5433293159180271616,120013408788005954,340.878784,10.012841,20.7568,19.6542,19.116301,18,44,28,20.830534,...,0,1206,120013408788005954,340.878784,10.012841,58374.32507,19.511116,0.093274,r,101
5433293159180271616,120013408788005954,340.878784,10.012841,20.7568,19.6542,19.116301,18,44,28,20.830534,...,0,1206,120013408788005954,340.878784,10.012841,58373.31629,19.573908,0.097601,r,101
5433293159180271616,120013408788005954,340.878784,10.012841,20.7568,19.6542,19.116301,18,44,28,20.830534,...,0,1206,120013408788005954,340.878784,10.012841,58372.32936,19.54579,0.095645,r,101


In [29]:
len(source_subset)

1381284

## Approach 4.5: Extending to Multi-band with grouping

In [97]:
# Load from the small Ensemble
ens = Ensemble(client=client)
ens.from_ensemble("./ztf_small_ensemble")

<tape.ensemble.Ensemble at 0x7fb6ad286230>

In [100]:
# Define an example interval generator function
def get_intervals(mjd, return_intervals=True):

    if len(mjd) > 0:
        bounds = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
        intervals = [(bounds[0], bounds[1]), (bounds[2], bounds[3])]
        n_intervals = len(intervals)
    else:
        n_intervals = 0
        intervals = 0

    if return_intervals:
        return pd.Series({"n_intervals":n_intervals, "intervals":intervals})
    else:
        return pd.Series({"n_intervals":n_intervals})
        #return n_intervals
    #return n_intervals, intervals
    #return pd.Series({'intervals':intervals})

# Apply our interval function to batch and filter on n_intervals to remove objects with no intervals
intervals = ens.batch(get_intervals, "mjd", by_band=True, meta={"n_intervals":int}, return_intervals=False) # use string for literal_evals
ens.object.join(intervals).update_ensemble()

Using generated label, result_2, for a batch result.


<tape.ensemble.Ensemble at 0x7fb6ad286230>

In [101]:
ens.object.dtypes

ps1_objid                    int64
ra                         float64
dec                        float64
ps1_gMeanPSFMag            float64
ps1_rMeanPSFMag            float64
ps1_iMeanPSFMag            float64
nobs_g                       int32
nobs_r                       int32
nobs_i                       int32
mean_mag_g                 float64
mean_mag_r                 float64
mean_mag_i                 float64
Norder                       int32
Dir                          int32
Npix                         int32
n_intervals_g      string[pyarrow]
n_intervals_i      string[pyarrow]
n_intervals_r      string[pyarrow]
dtype: object