In [1]:
import lsdb
import ast
from tape import Ensemble, ColumnMapper
import matplotlib.pyplot as plt
import dask
import numpy as np
import pandas as pd
from collections.abc import Iterable

# Directory Dask uses for temporary/scratch files (the worker local
# directories are created under it).
# NOTE(review): this is a user-specific absolute path -- adjust it before
# running the notebook on another machine.
dask.config.set({'temporary_directory': '/data/epyc/users/brantd/tmp'})
# Turn off Dask's DataFrame query-planning option.
# NOTE(review): lsdb/tape are imported before this line and may already have
# imported dask.dataframe -- confirm the option still takes effect here.
dask.config.set({'dataframe.query-planning': False})

from dask.distributed import Client
# Start a local cluster: 10 single-threaded workers, each limited to 60G of
# memory, with the dashboard served on port 38764.
client = Client(n_workers=10, threads_per_worker=1,
                memory_limit="60G",
                dashboard_address=':38764')

# Show the client summary as the cell output
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:38764/status,

0,1
Dashboard: http://127.0.0.1:38764/status,Workers: 10
Total threads: 10,Total memory: 558.79 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:33681,Workers: 10
Dashboard: http://127.0.0.1:38764/status,Total threads: 10
Started: Just now,Total memory: 558.79 GiB

0,1
Comm: tcp://127.0.0.1:43371,Total threads: 1
Dashboard: http://127.0.0.1:45185/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:34647,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-g4342r4h,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-g4342r4h

0,1
Comm: tcp://127.0.0.1:35973,Total threads: 1
Dashboard: http://127.0.0.1:33369/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:45611,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-krkwnk79,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-krkwnk79

0,1
Comm: tcp://127.0.0.1:41558,Total threads: 1
Dashboard: http://127.0.0.1:39471/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:33551,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-abk9nzty,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-abk9nzty

0,1
Comm: tcp://127.0.0.1:39305,Total threads: 1
Dashboard: http://127.0.0.1:42409/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:37169,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-3s3uof80,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-3s3uof80

0,1
Comm: tcp://127.0.0.1:41692,Total threads: 1
Dashboard: http://127.0.0.1:40357/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:34995,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-846u6i69,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-846u6i69

0,1
Comm: tcp://127.0.0.1:36556,Total threads: 1
Dashboard: http://127.0.0.1:37208/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:40945,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-42v2aa8d,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-42v2aa8d

0,1
Comm: tcp://127.0.0.1:36808,Total threads: 1
Dashboard: http://127.0.0.1:39463/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:35011,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-657cmqqp,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-657cmqqp

0,1
Comm: tcp://127.0.0.1:39684,Total threads: 1
Dashboard: http://127.0.0.1:35272/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:45433,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-6eusmr30,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-6eusmr30

0,1
Comm: tcp://127.0.0.1:46419,Total threads: 1
Dashboard: http://127.0.0.1:33245/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:34326,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-8w_tsbg9,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-8w_tsbg9

0,1
Comm: tcp://127.0.0.1:34496,Total threads: 1
Dashboard: http://127.0.0.1:39970/status,Memory: 55.88 GiB
Nanny: tcp://127.0.0.1:40191,
Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-lssbg0_j,Local directory: /data/epyc/users/brantd/tmp/dask-scratch-space/worker-lssbg0_j


## Approach 1: Batch Find Intervals, Join and Query Source

Advantage:
* See an intermediate product of the generated intervals

Disadvantage:
* Slow and expensive: the intervals are joined onto every source row and then checked row by row

In [2]:
# Create an Ensemble attached to the Dask client, then load the saved
# ensemble stored at ./ztf_small_ensemble
ens = Ensemble(client=client)
ens.from_ensemble("./ztf_small_ensemble")

<tape.ensemble.Ensemble at 0x7fb2003b4eb0>

In [3]:
# Then continue to do analysis

# Define an example interval generator function
def get_intervals(mjd):
    """Build two example time intervals from the percentiles of ``mjd``.

    Parameters
    ----------
    mjd : array-like
        Observation times of one light curve.

    Returns
    -------
    list of tuple
        ``[(p10, p30), (p70, p90)]`` where ``pN`` is the N-th percentile of
        ``mjd``.
    """
    p10, p30, p70, p90 = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
    return [(p10, p30), (p70, p90)]

# Run get_intervals over the "mjd" column with ens.batch. The result is
# declared with meta ("intervals", str) because row_query later parses the
# value with ast.literal_eval.
intervals = ens.batch(get_intervals, "mjd", meta=("intervals", str)) # use string for literal_evals
# Join the batch result onto the source table and push the change back to
# the Ensemble, so each source row carries an "intervals" column.
ens.source.join(intervals).update_ensemble()

# Define a query to determine whether each mjd is in the interval(s)
def row_query(row):
    """Return whether the row's ``mjd`` lies strictly inside any of its intervals.

    Parameters
    ----------
    row : mapping-like (e.g. a pandas Series holding one source row)
        Must provide ``row["mjd"]`` and ``row["intervals"]``. The intervals
        value may be a string literal such as ``"[(lo, hi), ...]"`` or an
        already-parsed iterable of ``(lo, hi)`` pairs.

    Returns
    -------
    bool
        True if ``lo < mjd < hi`` holds for at least one interval. False if
        no interval matches or the intervals value is not iterable
        (e.g. ``None`` or a NaN left behind by the join).
    """
    interval_list = row["intervals"]

    # Only strings need parsing: ast.literal_eval raises ValueError when it is
    # handed an already-parsed list or a NaN, so those are passed through.
    if isinstance(interval_list, str):
        interval_list = ast.literal_eval(interval_list)
    if not isinstance(interval_list, Iterable):
        return False

    mjd = row["mjd"]
    # Bounds are exclusive; any() stops at the first matching interval.
    return any(interval[0] < mjd < interval[1] for interval in interval_list)


# Evaluate row_query on every source row (axis=1); meta declares the result
# as a boolean Series named "in_interval".
interval_mask = ens.source.apply(row_query, axis=1, meta = pd.Series(dtype='bool', name='in_interval'))
# Attach the mask as a column, keep only the rows where it is True, and push
# the filtered source table back to the Ensemble.
ens.source.assign(in_interval=interval_mask).query("in_interval==True").update_ensemble()

Using generated label, result_1, for a batch result.


<tape.ensemble.Ensemble at 0x7fb2003b4eb0>

In [4]:
# Count the rows left in the source table after the interval filter
len(ens.source)

1634455

## Approach 2: Determine Intervals and Immediately Filter Source

Advantage:
* Faster and more memory efficient

Disadvantage:
* Cannot view intervals without creating another function to generate them separately
* The current version of TAPE does not re-attach the filtered source table automatically, so it has to be linked back to the Ensemble by hand at the end of the cell
* Cannot handle objects that don't return a result

In [2]:
# Create an Ensemble attached to the Dask client, then load the saved
# ensemble stored at ./ztf_small_ensemble
ens = Ensemble(client=client)
ens.from_ensemble("./ztf_small_ensemble")

<tape.ensemble.Ensemble at 0x7f31d46c29b0>

In [3]:
# Try a batch interval generation -> filter
def filter_intervals(df, intervals=None):
    """Keep only the rows of ``df`` whose ``mjd`` lies inside an interval.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to filter; must contain an ``"mjd"`` column. In this notebook
        it is one group of the source table, grouped by the id column.
    intervals : sequence of (low, high) pairs, optional
        Intervals to keep (bounds exclusive). When omitted, two intervals
        spanning the 10th-30th and 70th-90th percentiles of ``df["mjd"]``
        are computed.

    Returns
    -------
    pandas.DataFrame
        The rows of ``df`` with ``low < mjd < high`` for at least one
        interval. With an empty ``intervals`` every row is kept.
    """
    mjd = df["mjd"]

    # Interval calculation (skipped when the caller supplies the intervals)
    if intervals is None:
        bounds = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
        intervals = [(bounds[0], bounds[1]), (bounds[2], bounds[3])]

    # Filter on mjd: OR together one boolean mask per interval
    if len(intervals) > 0:
        mjd_mask = None
        for low, high in intervals:
            inside = (mjd > low) & (mjd < high)
            mjd_mask = inside if mjd_mask is None else (mjd_mask | inside)
    else:
        # No intervals: keep everything. The mask needs one entry per row,
        # so the shape is len(mjd), not the Series itself.
        mjd_mask = np.ones(len(mjd), dtype=bool)

    return df[mjd_mask]

# Group the source table by the Ensemble's id column and run
# filter_intervals on each group; meta declares that the output has the same
# schema as the source table.
source_subset = ens.source.groupby(ens._id_col, group_keys=False).apply(lambda x: filter_intervals(x), meta=ens.source._meta)

# Have to manually re-establish the source for now: link the filtered frame
# back to the Ensemble, flag it via set_dirty(True) (presumably so TAPE
# treats it as modified -- confirm), then push it with update_ensemble().
source_subset.ensemble = ens
source_subset.set_dirty(True)
source_subset.update_ensemble()


<tape.ensemble.Ensemble at 0x7f31d46c29b0>

In [4]:
# Count the rows left in the source table after the interval filter
len(ens.source)

1634455

## Approach 3: Generate Intervals, Filter on Objects, then Recalculate and Filter Sources

Advantages:
* Interval information is available in the Object Table
* Relatively Fast
* Robust to Objects with no intervals

Disadvantages:
* The intervals are calculated twice. This is not a big issue if the interval calculation is lightweight, and it is likely still faster than joining the intervals to the source table

In [2]:
# Create an Ensemble attached to the Dask client, then load the saved
# ensemble stored at ./ztf_small_ensemble
ens = Ensemble(client=client)
ens.from_ensemble("./ztf_small_ensemble")

<tape.ensemble.Ensemble at 0x7fbbeff8abf0>

In [3]:
# Define an example interval generator function
def get_intervals(mjd):
    """Build two example time intervals from ``mjd`` and report how many there are.

    Parameters
    ----------
    mjd : array-like
        Observation times of one light curve.

    Returns
    -------
    pandas.Series
        ``"n_intervals"``: the number of intervals found;
        ``"intervals"``: ``[(p10, p30), (p70, p90)]`` where ``pN`` is the
        N-th percentile of ``mjd``.
    """
    p10, p30, p70, p90 = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
    found = [(p10, p30), (p70, p90)]
    return pd.Series({"n_intervals": len(found), "intervals": found})

# Run get_intervals over the "mjd" column with ens.batch; meta declares the
# two result columns ("n_intervals" as int, "intervals" as str).
intervals = ens.batch(get_intervals, "mjd", meta={"n_intervals": int, "intervals": str}) # use string for literal_evals
# Join the result onto the object table, drop objects that produced no
# intervals, and push the filtered object table back to the Ensemble.
ens.object.join(intervals).query("n_intervals > 0").update_ensemble()

Using generated label, result_1, for a batch result.


<tape.ensemble.Ensemble at 0x7fdde411c130>

In [3]:
# Recalculate intervals and filter source

def filter_intervals(df, intervals=None):
    """Keep only the rows of ``df`` whose ``mjd`` lies inside an interval.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to filter; must contain an ``"mjd"`` column. In this notebook
        it is one group of the source table, grouped by the id column.
    intervals : sequence of (low, high) pairs, optional
        Intervals to keep (bounds exclusive). When omitted, the intervals
        are recalculated from ``df["mjd"]`` as the 10th-30th and 70th-90th
        percentile ranges. Pass them explicitly if they were joined onto
        the source table instead.

    Returns
    -------
    pandas.DataFrame
        The rows of ``df`` with ``low < mjd < high`` for at least one
        interval. With an empty ``intervals`` every row is kept.
    """
    mjd = df["mjd"]

    # Interval calculation -- done again here unless the caller supplies them
    if intervals is None:
        bounds = np.percentile(mjd, [10.0, 30.0, 70.0, 90.0])
        intervals = [(bounds[0], bounds[1]), (bounds[2], bounds[3])]

    # Filter on mjd: OR together one boolean mask per interval
    if len(intervals) > 0:
        mjd_mask = None
        for low, high in intervals:
            inside = (mjd > low) & (mjd < high)
            mjd_mask = inside if mjd_mask is None else (mjd_mask | inside)
    else:
        # No intervals: keep everything. The mask needs one entry per row,
        # so the shape is len(mjd), not the Series itself.
        mjd_mask = np.ones(len(mjd), dtype=bool)

    return df[mjd_mask]

# Alternative (disabled): join the stored intervals onto the source table
# instead of recalculating them -- generally slower when the intervals are
# fast to compute.
#ens.source.join(intervals["intervals"]).update_ensemble()

# Group the source table by the Ensemble's id column and run
# filter_intervals on each group; meta declares that the output has the same
# schema as the source table.
source_subset = ens.source.groupby(ens._id_col, group_keys=False).apply(lambda x: filter_intervals(x), meta=ens.source._meta)

# Have to manually re-establish the source for now: link the filtered frame
# back to the Ensemble, flag it via set_dirty(True) (presumably so TAPE
# treats it as modified -- confirm), then push it with update_ensemble().
source_subset.ensemble = ens
source_subset.set_dirty(True)
source_subset.update_ensemble()


<tape.ensemble.Ensemble at 0x7fbbeff8abf0>

In [4]:
# Count the rows left in the source table after the interval filter
len(ens.source)

1634455