# Benchmark LSDB for IO-bound analysis

This notebook measures the overhead that LSDB/Dask adds, compared to a pure pyarrow analysis, when reading and filtering parquet data.

Cells are self-contained, so you can run them in any order. It is recommended to run everything a few times to make the data "hot".

### Naive LSDB

In [7]:
import lsdb
from dask.distributed import Client


# Load Gaia DR3 path on PSC
PATH = '/ocean/projects/phy210048p/shared/hipscat/catalogs/gaia_dr3'
catalog = lsdb.read_hipscat(
    PATH,
    # All the following kwargs are passed to pd.read_parquet()
    engine='pyarrow',  # should be the default
    columns=[
        "source_id",
        "ra",
        "dec",
        "phot_g_mean_mag",
        "phot_proc_mode",
        "azero_gspphot",
        "classprob_dsc_combmod_star",
    ],
)

# Get Dask dataframe
df = catalog._ddf

# Filter it
df = df.query(
    "15.0 <= phot_g_mean_mag"
    + "& phot_g_mean_mag <= 16.0"
    + "& phot_proc_mode == 0"
    + "& azero_gspphot < 0.1"
    + "& classprob_dsc_combmod_star >= 0.99"
)

# Run computations with Dask Client
with Client(n_workers=16) as client:
    %time result = df.compute()
    
print(result.shape)

CPU times: user 7.26 s, sys: 1.2 s, total: 8.45 s
Wall time: 25.9 s
(4370011, 7)


### Filter with `lsdb.read_hipscat(filters)`

In [8]:
import lsdb
from dask.distributed import Client


# Load Gaia DR3 path on PSC
PATH = '/ocean/projects/phy210048p/shared/hipscat/catalogs/gaia_dr3'
catalog = lsdb.read_hipscat(
    PATH,
    # All the following kwargs are passed to pd.read_parquet()
    engine='pyarrow',  # should be the default
    columns=[
        "source_id",
        "ra",
        "dec",
        "phot_g_mean_mag",
        "phot_proc_mode",
        "azero_gspphot",
        "classprob_dsc_combmod_star",
    ],
    filters=[[  # We need nested list, because inner list does AND
        ("phot_g_mean_mag", ">=", 15.0),
        ("phot_g_mean_mag", "<=", 16.0),
        ("phot_proc_mode", "==", 0),
        ("azero_gspphot", "<", 0.1),
        ("classprob_dsc_combmod_star", ">=", 0.99),
    ]],
)

# Get Dask dataframe
df = catalog._ddf

# Run computations with Dask Client
with Client(n_workers=16) as client:
    %time result = df.compute()
    
print(result.shape)

CPU times: user 6.86 s, sys: 1.02 s, total: 7.88 s
Wall time: 29.5 s
(4370011, 7)


### Use `pyarrow` directly

In [6]:
from pathlib import Path

import pyarrow.dataset
import pyarrow.parquet
from pyarrow.dataset import field


# Gaia DR3 HiPSCat catalog location on PSC
PATH = Path('/ocean/projects/phy210048p/shared/hipscat/catalogs/gaia_dr3')

# Open the HiPSCat catalog as a pyarrow dataset: the schema is read from
# "_common_metadata" and the dataset is created from the "_metadata" file
schema = pyarrow.parquet.read_schema(PATH / "_common_metadata")
dataset = pyarrow.dataset.parquet_dataset(
    PATH / "_metadata",
    schema=schema,
    partitioning="hive",
)

# Read the selected rows and columns with pyarrow
# Wrapped in a function so that %time can measure the whole read in one line
def read_df():
    """Return the filtered subset of ``dataset`` produced by ``to_table``.

    Only the seven listed columns are loaded, and the row selection is handed
    to pyarrow as a filter expression. The result is a pyarrow Table; no
    conversion to pandas is done here.
    """
    wanted_columns = [
        "source_id",
        "ra",
        "dec",
        "phot_g_mean_mag",
        "phot_proc_mode",
        "azero_gspphot",
        "classprob_dsc_combmod_star",
    ]

    # All five conditions must hold for a row to be kept
    selection = (
        (field("phot_g_mean_mag") >= 15.0)
        & (field("phot_g_mean_mag") <= 16.0)
        & (field("phot_proc_mode") == 0)
        & (field("azero_gspphot") < 0.1)
        & (field("classprob_dsc_combmod_star") >= 0.99)
    )

    return dataset.to_table(columns=wanted_columns, filter=selection)

# Time the complete read; %time measures the single statement on its line
%time df = read_df()

# Report the shape of the result
print(df.shape)

CPU times: user 1min 59s, sys: 1min 49s, total: 3min 49s
Wall time: 38.1 s
(4370011, 7)


## Appendix

### Naive LSDB, but with no columns selected

We still set `dtype_backend='pyarrow'` to avoid the `to_pyarrow_string` overhead; see this issue:

https://github.com/astronomy-commons/lsdb/issues/89

In [10]:
import lsdb
from dask.distributed import Client


# Load Gaia DR3 path on PSC
PATH = '/ocean/projects/phy210048p/shared/hipscat/catalogs/gaia_dr3'
catalog = lsdb.read_hipscat(
    PATH,
    # All the following kwargs are passed to pd.read_parquet()
    engine='pyarrow',  # should be the default
    dtype_backend='pyarrow',
)

# Get Dask dataframe
df = catalog._ddf

# Filter it
df = df.query(
    "15.0 <= phot_g_mean_mag"
    + "& phot_g_mean_mag <= 16.0"
    + "& phot_proc_mode == 0"
    + "& azero_gspphot < 0.1"
    + "& classprob_dsc_combmod_star >= 0.99"
)

# Run computations with Dask Client
with Client(n_workers=16) as client:
    %time result = df.compute()
    
print(result.shape)

CPU times: user 1min 2s, sys: 14.4 s, total: 1min 17s
Wall time: 5min 34s
(4370011, 155)
