# Load large catalog data with LSDB

Here we load a small part of ZTF DR14, stored as a HiPSCat catalog, using [LSDB](https://lsdb.readthedocs.io/).

The notebook is an adaptation of [the tutorial](https://github.com/lincc-frameworks/Rare_Gems_Demo/blob/main/Notebook_2_Basic_Time_Domain.ipynb) presented by Neven Caplar at the Rare Gems in Big Data conference, May 2024. 

## Install dependencies for the notebook

The notebook requires `nested-dask` and a few other packages to be installed:
- `lsdb` to load and join the "object" (pointing) and "source" (detection) ZTF catalogs
- `aiohttp`, an optional `lsdb` dependency, to download the data over the web
- `light-curve` to extract features from light curves
- `matplotlib` to plot the results

In [None]:
# Uncomment the following line to install nested-dask
# %pip install nested-dask

# Comment out the following line to skip installing the dependencies
%pip install --quiet lsdb aiohttp light-curve matplotlib

In [None]:
import dask.array
import dask.distributed
import light_curve as licu
import matplotlib.pyplot as plt
import nested_pandas as npd
import numpy as np
import pandas as pd
from dask_expr import from_legacy_dataframe
from lsdb import read_hipscat
from matplotlib.colors import LogNorm
from nested_dask import NestedFrame

## Load ZTF DR14
For demonstration purposes, we use a lightweight version of the ZTF DR14 catalog distributed by LINCC Frameworks: a half-degree circle around RA=180, Dec=10.
We load the data over HTTPS as two LSDB catalogs: object (metadata catalog) and source (light curve catalog).

In [None]:
catalogs_dir = "https://epyc.astro.washington.edu/~lincc-frameworks/half_degree_surveys/ztf/"

lsdb_object = read_hipscat(
    f"{catalogs_dir}/ztf_object",
    columns=["ra", "dec", "ps1_objid"],
)
lsdb_source = read_hipscat(
    f"{catalogs_dir}/ztf_source",
    columns=["mjd", "mag", "magerr", "band", "ps1_objid", "catflags"],
)
lc_columns = ["mjd", "mag", "magerr", "band", "catflags"]

We need to merge these two catalogs to get the light curve data.
This is done with LSDB's `.join()` method, which gives us a new catalog with all the columns from both catalogs.
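
As a rough analogy in plain pandas (not LSDB's implementation, and assuming inner-join semantics), joining on `ps1_objid` repeats each object row once per matching detection. The toy tables and their values below are hypothetical:

```python
import pandas as pd

# Hypothetical toy tables; the values are made up for illustration.
objects = pd.DataFrame(
    {"ps1_objid": [1, 2], "ra": [180.0, 180.1], "dec": [10.0, 10.1]}
)
sources = pd.DataFrame(
    {
        "ps1_objid": [1, 1, 2],
        "mjd": [58300.5, 58301.5, 58300.6],
        "mag": [15.1, 15.3, 17.0],
    }
)

# Inner join on the shared key: each object row is repeated
# once per matching detection.
joined = objects.merge(sources, on="ps1_objid", how="inner")
print(joined)
```

The repeated object rows are why the joined table is later regrouped into one row per object with the detections nested inside it.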

In [None]:
# The warning can be ignored here: this particular join does not need a margin cache
lsdb_joined = lsdb_object.join(
    lsdb_source,
    left_on="ps1_objid",
    right_on="ps1_objid",
    suffixes=("", ""),
)
joined_ddf = lsdb_joined._ddf
joined_ddf

## Convert LSDB joined catalog to `nested_dask.NestedFrame`

First, we plan the computation to convert the joined Dask DataFrame to a NestedFrame.

In [None]:
def convert_to_nested_frame(df: pd.DataFrame, nested_columns: list[str]):
    other_columns = [col for col in df.columns if col not in nested_columns]

    # Since object rows are repeated, we just drop duplicates
    object_df = df[other_columns].groupby(level=0).first()
    nested_frame = npd.NestedFrame(object_df)

    source_df = df[nested_columns]
    # lc is for light curve
    return nested_frame.add_nested(source_df, "lc")


ddf = joined_ddf.map_partitions(
    lambda df: convert_to_nested_frame(df, nested_columns=lc_columns),
    meta=convert_to_nested_frame(joined_ddf._meta, nested_columns=lc_columns),
)
nested_ddf = NestedFrame.from_dask_dataframe(from_legacy_dataframe(ddf))
nested_ddf

Now we filter the detections, keeping only those with `catflags` equal to 0 (no flags set, which corresponds to perfect observing conditions) and `band` equal to `r`.
After that, we count the remaining detections per object and keep only the objects with more than 10 detections.
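
The same two-step logic can be sketched on a flat table with NumPy alone. This is an illustration of the idea rather than what `nested-dask` does internally; the arrays are hypothetical, and the threshold is lowered from 10 to 1 to fit the toy data:

```python
import numpy as np

# Hypothetical flat detection table: one entry per detection.
obj_id = np.array([1, 1, 1, 2, 2, 3])
band = np.array(["r", "r", "g", "r", "r", "r"])
catflags = np.array([0, 0, 0, 0, 32768, 0])  # any nonzero value means "flagged"

# Step 1: keep only clean r-band detections.
good = (catflags == 0) & (band == "r")

# Step 2: count the surviving detections per object.
ids, nobs = np.unique(obj_id[good], return_counts=True)

# Keep objects with more detections than the threshold
# (the notebook uses 10; 1 is used here to fit the toy data).
min_nobs = 1
selected = ids[nobs > min_nobs]
print(dict(zip(ids.tolist(), nobs.tolist())), selected.tolist())
```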

In [None]:
r_band = nested_ddf.query("lc.catflags == 0 and lc.band == 'r'")
nobs = r_band.reduce(np.size, "lc.mjd", meta={0: int}).rename(columns={0: "nobs"})
r_band = r_band[nobs["nobs"] > 10]
r_band

Later we will extract features, so the light-curve arrays need to share the same floating-point dtype.
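
A minimal sketch of that preparation, using a hypothetical four-point light curve: the times are shifted by a reference epoch before being cast to `float32`, and `np.unique` is used to sort them and drop duplicated epochs in one step.

```python
import numpy as np

# Hypothetical light curve: unsorted times with one duplicated epoch.
mjd = np.array([58302.25, 58300.5, 58302.25, 58301.75])
mag = np.array([15.2, 15.0, 15.3, 15.1], dtype=np.float32)

# Subtract a reference epoch before casting, so the float32 values are small.
t = np.asarray(mjd - 60000, dtype=np.float32)

# np.unique returns the sorted unique values; return_index gives the position
# of the first occurrence of each, which selects the matching magnitudes.
_, idx = np.unique(t, return_index=True)
t_clean, mag_clean = t[idx], mag[idx]

# Gap between adjacent float32 values, in seconds, near a raw MJD
# and near an offset time of similar size to the ones above.
raw_step = float(np.spacing(np.float32(59000.0))) * 86400
offset_step = float(np.spacing(np.float32(1000.0))) * 86400
print(t_clean, mag_clean, raw_step > offset_step)
```

Comparing `raw_step` with `offset_step` shows why the offset matters: the smaller the absolute value, the finer the time resolution that `float32` can hold.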

### Extract features from ZTF light curves

Now we are going to extract some features:
- Top periodogram peak
- Mean magnitude
- Von Neumann's eta statistic
- Excess variance statistic
- Number of observations

We use the [`light-curve`](https://github.com/light-curve/light-curve-python) package for this purpose.
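
To give a feel for what two of these features measure, here is a NumPy sketch based on the common textbook definitions. It is an assumption that these match the package's conventions; the exact normalization used by `light-curve` may differ, and the helper names are hypothetical.

```python
import numpy as np


def weighted_mean(mag, magerr):
    """Inverse-variance weighted mean magnitude."""
    w = 1.0 / np.square(magerr)
    return np.sum(w * mag) / np.sum(w)


def von_neumann_eta(mag):
    """Mean squared successive difference divided by the sample variance."""
    return np.mean(np.square(np.diff(mag))) / np.var(mag, ddof=1)


# Hypothetical light curve that alternates between two magnitudes.
mag = np.array([1.0, 2.0, 1.0, 2.0])
magerr = np.array([0.1, 0.1, 0.1, 0.1])
print(weighted_mean(mag, magerr), von_neumann_eta(mag))
```

A smoothly varying light curve has small successive differences relative to its overall scatter, so its eta is low, while rapidly alternating data like the toy example above give a high value.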

In [None]:
extractor = licu.Extractor(
    licu.Periodogram(
        peaks=1, max_freq_factor=50.0, fast=True
    ),  # Gives two features: the peak period and the signal-to-noise ratio of the peak
    licu.WeightedMean(),  # Mean magnitude
    licu.Eta(),  # Von Neumann's eta statistic
    licu.ExcessVariance(),  # Excess variance statistic
    licu.ObservationCount(),  # Number of observations
)


# light-curve requires all arrays to be the same dtype.
# It also requires the time array to be ordered and to have no duplicates.
def extract_features(mjd, mag, magerr, **kwargs):
    # Offset the dates before casting: float32 resolves seconds for these
    # small values, but only minutes for raw MJD values
    t = np.asarray(mjd - 60000, dtype=np.float32)
    _, sort_index = np.unique(t, return_index=True)
    features = extractor(
        t[sort_index],
        mag[sort_index],
        magerr[sort_index],
        **kwargs,
    )
    # Return the features as a dictionary
    return dict(zip(extractor.names, features))


features = r_band.reduce(
    extract_features,
    "lc.mjd",
    "lc.mag",
    "lc.magerr",
    meta={name: np.float32 for name in extractor.names},
)

df_w_features = r_band.join(features)
df_w_features

Before actually running the computation, let's create a Dask client, which allows us to run it in parallel.

In [None]:
client = dask.distributed.Client()

Now we can collect some statistics and plot them.
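
The statistics here are a 2D histogram over two log-scaled features. One detail worth seeing in isolation is the array orientation, sketched below with plain NumPy and hypothetical random values: the histogram's first axis follows the first input, while image plotting treats the first axis as rows, so the counts are transposed before display.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical log10 values standing in for lg(period) and lg(S/N).
lg_period = rng.uniform(-1, 2, size=1000)
lg_snr = rng.uniform(0, 2, size=1000)

x_bins = np.linspace(-1, 2, 4)  # 3 bins along lg(period)
y_bins = np.linspace(0, 2, 3)  # 2 bins along lg(S/N)

counts, _, _ = np.histogram2d(lg_period, lg_snr, bins=[x_bins, y_bins])

# counts has shape (n_x_bins, n_y_bins). Image plotting treats the first
# axis as rows (vertical), so transpose to put lg(period) on the x axis.
image = counts.T
print(counts.shape, image.shape)
```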

In [None]:
lg_period_bins = np.linspace(-1, 2, 101)
lg_snr_bins = np.linspace(0, 2, 101)

lg_period = dask.array.log10(df_w_features["period_0"].to_dask_array())
lg_snr = dask.array.log10(df_w_features["period_s_to_n_0"].to_dask_array())

hist2d = dask.array.histogram2d(
    lg_period,
    lg_snr,
    bins=[lg_period_bins, lg_snr_bins],
)
# Run the computation
hist2d = hist2d[0].compute()

# Plot the 2D histogram
plt.imshow(
    hist2d.T,
    extent=(lg_period_bins[0], lg_period_bins[-1], lg_snr_bins[0], lg_snr_bins[-1]),
    origin="lower",
    norm=LogNorm(vmin=1, vmax=hist2d.max()),
)
plt.colorbar(label="Number of stars")
plt.xlabel("lg Period/day")
plt.ylabel("lg S/N")

Let's select a bright star with a high signal-to-noise ratio and plot its phase-folded light curve. We also filter out periods close to 1 day, because they are very likely to be bogus (most likely aliases of the nightly observing cadence).
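
Phase folding maps every observation time onto a single cycle of the trial period, so a periodic signal lines up into one repeating shape. A minimal sketch, with a hypothetical `fold` helper and made-up times:

```python
import numpy as np


def fold(t, period):
    """Return phases in [0, 1) for observation times t and a trial period."""
    return (np.asarray(t) % period) / period


# Hypothetical observation times in days, folded with a 0.5-day period.
t = np.array([0.0, 0.125, 0.5, 0.625, 1.25])
phase = fold(t, period=0.5)
print(phase)
```

Observations separated by a whole number of periods, such as the first and third times above, land on the same phase.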

In [None]:
bright_periodic_stars = df_w_features.query(
    "period_s_to_n_0 > 20 and weighted_mean < 15 and (period_0 < 0.9 or period_0 > 1.1)"
)
df = bright_periodic_stars.compute()

obj = df.iloc[0]
lc = obj["lc"]
period = obj["period_0"]
ra = obj["ra"]
dec = obj["dec"]

plt.errorbar(lc["mjd"] % period / period, lc["mag"], lc["magerr"], fmt="o")
plt.gca().invert_yaxis()
plt.xlabel("Phase")
plt.ylabel("Magnitude")
plt.xlim([0, 1])
plt.title(f"RA={ra:.4f}, Dec={dec:.4f} period={period:.4f} days")

print("Search for this object on the SNAD ZTF Viewer:")
print(f"https://ztf.snad.space/search/{ra}%20{dec}/1")

It looks like a nice RR Lyrae star!

In [None]:
client.close()