# Run a Periodogram Across Full ZTF Sources

This notebook is an adaptation of the Nested Dask [tutorial for loading HiPSCat data](https://nested-dask.readthedocs.io/en/latest/tutorials/work_with_lsdb.html).

## Install dependencies for the notebook

The notebook requires a few packages to be installed:
- `lsdb` to load and join "object" (pointing) and "source" (detection) ZTF catalogs
- `aiohttp`, an optional dependency of `lsdb`, to download the data over the web
- `light-curve` to extract features from light curves
- `matplotlib` to plot the results

In [1]:
# Comment the following line to skip dependencies installation
%pip install --quiet tqdm aiohttp light-curve matplotlib lsdb

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
from importlib.metadata import version
from pathlib import Path

import dask.array
import dask.distributed
import dask_jobqueue
import light_curve as licu
import matplotlib.pyplot as plt
import nested_pandas as npd
import numpy as np
import pandas as pd
from lsdb import read_hipscat
from matplotlib.colors import LogNorm
from nested_dask import NestedFrame

In [3]:
print(f"{version('lsdb') = }")
print(f"{version('nested-dask') = }")
print(f"{version('dask') = }")
print(f"{version('dask-expr') = }")

version('lsdb') = '0.3.0'
version('nested-dask') = '0.2.0'
version('dask') = '2024.8.2'
version('dask-expr') = '1.1.13'


Some additional setup for using Dask on PSC Bridges2:

## Load ZTF DR14

In [4]:
# Full catalog
search_area = None

In [5]:
catalogs_dir = "https://data.lsdb.io/unstable/ztf"


lsdb_object = read_hipscat(
    f"{catalogs_dir}/ztf_dr14",
    columns=["ra", "dec", "ps1_objid"],
    search_filter=search_area,
)
lsdb_source = read_hipscat(
    f"{catalogs_dir}/ztf_zource",
    columns=["mjd", "ra", "dec", "mag", "magerr", "band", "ps1_objid", "catflags"],
    search_filter=search_area,
)
lc_columns = ["mjd", "mag", "magerr", "band", "catflags"]

In [6]:
lsdb_source

| npartitions=41679 | mjd | ra | dec | mag | magerr | band | ps1_objid | catflags |
|---|---|---|---|---|---|---|---|---|
| 0 | double[pyarrow] | double[pyarrow] | double[pyarrow] | float[pyarrow] | float[pyarrow] | string[pyarrow] | int64[pyarrow] | int16[pyarrow] |
| 4503599627370496 | ... | ... | ... | ... | ... | ... | ... | ... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 13833932155375321088 | ... | ... | ... | ... | ... | ... | ... | ... |
| 18446744073709551615 | ... | ... | ... | ... | ... | ... | ... | ... |


We need to join these two catalogs to get the light-curve data.
This is done with LSDB's `.join_nested()` method, which returns a new catalog in which each object row holds a nested frame of its ZTF sources. For this tutorial, we use the underlying nested dataframe for the rest of the analysis rather than the LSDB catalog directly.
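Conceptually, nesting sources into objects groups every source row under the object that shares its `ps1_objid`. Below is a minimal NumPy sketch of just that grouping step, using hypothetical toy arrays. It is not LSDB's implementation and ignores how LSDB aligns the two catalogs' sky partitions before joining.

```python
import numpy as np

# Hypothetical flat "source" table: one entry per detection
src_objid = np.array([7, 3, 7, 3, 7, 9])
src_mjd = np.array([58300.1, 58301.2, 58299.5, 58305.0, 58310.3, 58302.2])
src_mag = np.array([18.1, 19.4, 18.0, 19.5, 18.2, 17.3], dtype=np.float32)


def nest_sources(objid, **columns):
    """Group flat per-detection columns into one small light curve per object id."""
    order = np.argsort(objid, kind="stable")  # keep original order within each object
    ids, starts = np.unique(objid[order], return_index=True)
    # Split every column at the boundaries between consecutive object ids
    parts = {name: np.split(col[order], starts[1:]) for name, col in columns.items()}
    return {int(obj): {name: p[i] for name, p in parts.items()} for i, obj in enumerate(ids)}


nested = nest_sources(src_objid, mjd=src_mjd, mag=src_mag)
print(nested[7]["mjd"])  # the three detections of object 7
```

In the real workflow each nested `lc` value behaves like a small dataframe rather than a dict of arrays, but the shape of the result is similar: one row per object, one light curve per row.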

In [7]:
# Nesting Sources into Object
nested_ddf = lsdb_object.join_nested(lsdb_source, left_on="ps1_objid", right_on="ps1_objid", nested_column_name="lc")

# TODO: remove once LSDB wrappers for nested_dask methods (reduce, dropna, etc.) are added
nested_ddf = nested_ddf._ddf



## Convert LSDB joined catalog to `nested_dask.NestedFrame`

The `._ddf` step above gives us the joined catalog as a `nested_dask.NestedFrame`. From here on we only plan the computation: Dask builds a task graph, and nothing is executed until `.compute()` is called.

Now we filter the detections, keeping only those with `catflags == 0` (no quality flags raised) and `band == 'r'`.
After filtering, we count the number of detections per object and keep only the objects with more than 10 detections.
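The three steps (filter detections, count per object, keep objects with more than 10 detections) can be sketched on a flat, non-nested toy table with NumPy. The arrays and the nonzero flag value are hypothetical; this mirrors only the logic, not how `nested-dask` evaluates it.

```python
import numpy as np

# Hypothetical flat detection table: 40 detections of three objects
objid = np.repeat([1, 2, 3], [20, 12, 8])
idx = np.arange(objid.size)
band = np.where(idx % 4 == 0, "g", "r")   # every 4th detection is g-band
catflags = np.where(idx % 5 == 0, 1, 0)   # every 5th detection has a nonzero flag

# Keep only unflagged r-band detections (the `query` step)
good = (catflags == 0) & (band == "r")

# Count surviving detections per object (the `reduce(np.size, ...)` step)
ids, nobs = np.unique(objid[good], return_counts=True)

# Keep objects with more than 10 detections (the `nobs > 10` step)
selected = ids[nobs > 10]
print(dict(zip(ids.tolist(), nobs.tolist())), selected.tolist())
```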

In [8]:
%%time

r_band = nested_ddf.query("lc.catflags == 0 and lc.band == 'r'")
nobs = r_band.reduce(np.size, "lc.mjd", meta={0: int}).rename(columns={0: "nobs"})
r_band = r_band[nobs["nobs"] > 10]
r_band

CPU times: user 2 s, sys: 2.5 ms, total: 2 s
Wall time: 2.58 s


| npartitions=41679 | ra | dec | ps1_objid | lc |
|---|---|---|---|---|
| 0 | double[pyarrow] | double[pyarrow] | int64[pyarrow] | `nested<mjd: [double], ra: [double], dec: [double], mag: [float], magerr: [float], band: [string], catflags: [int16]>` |
| 4503599627370496 | ... | ... | ... | ... |
| ... | ... | ... | ... | ... |
| 13833932155375321088 | ... | ... | ... | ... |
| 18446744073709551615 | ... | ... | ... | ... |


Later we are going to extract features, and `light-curve` requires all input arrays to share the same float dtype, so the light-curve data must first be converted to a common float format.
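A minimal sketch of that preparation on one hypothetical light curve: cast time and magnitude to `float32`, then use `np.unique(..., return_index=True)` to get the times sorted with duplicates removed (keeping the first occurrence of each time). The toy values, including the duplicated time, are made up for illustration.

```python
import numpy as np

# Hypothetical raw arrays for one light curve:
# float64 times (unsorted, with a duplicate) and float32 magnitudes
mjd = np.array([58310.25, 58300.5, 58305.75, 58300.5])
mag = np.array([18.2, 18.0, 18.1, 18.05], dtype=np.float32)

# Cast both arrays to one dtype, offsetting time first to keep float32 resolution
t = np.asarray(mjd - 60000, dtype=np.float32)
m = np.asarray(mag, dtype=np.float32)

# np.unique returns sorted unique values plus the index of each value's first occurrence
t_sorted, first = np.unique(t, return_index=True)
m_sorted = m[first]
print(t_sorted, m_sorted)
```

The second detection at the duplicated time is dropped, which is what the ordering and uniqueness requirement of `light-curve` calls for.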

### Extract features from ZTF light curves

Features we could extract include:
- Top periodogram peak
- Mean magnitude
- Von Neumann's eta statistic
- Excess variance statistic
- Number of observations

In this run, the extractor is configured to compute only the top periodogram peak (its period and signal-to-noise ratio). We use the [`light-curve`](https://github.com/light-curve/light-curve-python) package for this purpose.
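To make the "top periodogram peak" feature concrete, here is a minimal NumPy sketch of the classical Lomb-Scargle periodogram on a hypothetical, irregularly sampled light curve. This is the textbook direct-sum version, not what `light-curve` runs internally: its `fast=True` mode uses a faster algorithm, and it also reports the peak's signal-to-noise ratio, which this sketch omits. The frequency grid here is an arbitrary choice; `light-curve` builds its own grid, whose upper end `max_freq_factor` scales.

```python
import numpy as np


def lomb_scargle_power(t, y, freqs):
    """Classical (unnormalized) Lomb-Scargle power of y(t) at each frequency in freqs."""
    y = y - y.mean()                      # the classical periodogram assumes zero-mean data
    omega = 2.0 * np.pi * freqs[:, None]  # shape (n_freq, 1) to broadcast against t
    # Time offset tau makes the sine and cosine terms orthogonal at each frequency
    tau = np.arctan2(
        np.sin(2.0 * omega * t).sum(axis=1),
        np.cos(2.0 * omega * t).sum(axis=1),
    ) / (2.0 * omega[:, 0])
    phase = omega * (t - tau[:, None])
    c, s = np.cos(phase), np.sin(phase)
    return 0.5 * (
        (y * c).sum(axis=1) ** 2 / (c**2).sum(axis=1)
        + (y * s).sum(axis=1) ** 2 / (s**2).sum(axis=1)
    )


# Hypothetical irregularly sampled light curve with a 2.5-day period
rng = np.random.default_rng(42)
t = np.sort(rng.uniform(0.0, 100.0, size=200))
mag = 18.0 + 0.3 * np.sin(2.0 * np.pi * t / 2.5) + rng.normal(0.0, 0.05, size=t.size)

freqs = np.linspace(0.01, 2.0, 5000)  # in 1/day; grid choice is an assumption
power = lomb_scargle_power(t, mag, freqs)
best_period = 1.0 / freqs[np.argmax(power)]
print(f"top peak period: {best_period:.3f} days")
```

The direct sum costs O(n_freq × n_obs) per light curve, which is why faster algorithms matter when running over a full-sky catalog.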

In [9]:
%%time

extractor = licu.Extractor(
    licu.Periodogram(
        peaks=1,
        max_freq_factor=1.0, # Currently 1.0 for fast runs, will raise for more interesting graphs later
        fast=True,
    ),  # Would give two features: peak period and signal-to-noise ratio of the peak
)


# light-curve requires all arrays to be the same dtype.
# It also requires the time array to be ordered and to have no duplicates.
def extract_features(mjd, mag, **kwargs):
    # Offset the dates before casting to float32: for ZTF dates the float32 step
    # is then at most ~10 seconds, versus ~6 minutes for raw MJD values
    t = np.asarray(mjd - 60000, dtype=np.float32)
    m = np.asarray(mag, dtype=np.float32)
    # np.unique sorts the times and drops duplicates, keeping the first occurrence
    _, sort_index = np.unique(t, return_index=True)
    features = extractor(
        t[sort_index],
        m[sort_index],
        **kwargs,
    )
    # Return the features as a dictionary
    return dict(zip(extractor.names, features))


features = r_band.reduce(
    extract_features,
    "lc.mjd",
    "lc.mag",
    meta={name: np.float32 for name in extractor.names},
)

CPU times: user 3.14 ms, sys: 37 μs, total: 3.18 ms
Wall time: 3.12 ms
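The 60000-day offset in `extract_features` matters because `float32` has only 24 bits of significand, so its resolution depends on the magnitude of the number. A quick check with `np.spacing`, using a hypothetical date in mid-2018 (roughly MJD 58300):

```python
import numpy as np

SECONDS_PER_DAY = 86_400.0


def float32_step_seconds(t_days):
    """Gap between t_days and the next representable float32 value, in seconds."""
    # np.spacing is negative for negative inputs, so take the absolute value
    return abs(float(np.spacing(np.float32(t_days)))) * SECONDS_PER_DAY


for label, t in [("raw MJD", 58300.0), ("MJD - 60000", 58300.0 - 60000.0)]:
    print(f"{label:>12}: float32 step = {float32_step_seconds(t):.1f} s")
```

The step shrinks from 337.5 s for the raw MJD to about 10.5 s after the offset; the rounding error of a single time is at most half of that step.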


Before we actually run the computation, we create a Dask client, which allows it to run in parallel.

Now we can compute some statistics and plot them.

In [10]:
%%time

mean_period = features['period_0'].mean()

CPU times: user 1.11 ms, sys: 0 ns, total: 1.11 ms
Wall time: 1.14 ms
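Defining `mean_period` only adds tasks to the graph; nothing runs until `.compute()`. A mean parallelizes well because it can be assembled from per-partition `(sum, count)` pairs. Below is a minimal pure-Python sketch of this chunk/combine pattern with hypothetical partition values. Dask's actual implementation is more involved (for example, it handles missing values).

```python
from functools import reduce

# Hypothetical per-partition values of a feature column
partitions = [[2.0, 4.0], [1.0], [3.0, 5.0, 7.0]]


def chunk(values):
    """Per-partition step: reduce a partition to (sum, count)."""
    return (sum(values), len(values))


def combine(a, b):
    """Merge two partial results; associative, so it can be applied as a tree."""
    return (a[0] + b[0], a[1] + b[1])


total, count = reduce(combine, map(chunk, partitions))
mean = total / count
print(mean)
```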


In [11]:
%%time

print("Dask task graph length", len(mean_period.dask))

Dask task graph length 585861
CPU times: user 4min 18s, sys: 17.1 s, total: 4min 35s
Wall time: 4min 36s
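Dividing the graph length by the number of partitions reported for the catalog gives a rough sense of scale. This is only an average over all tasks in the graph, including the final aggregation, not a per-partition breakdown.

```python
n_tasks = 585_861      # reported by len(mean_period.dask)
n_partitions = 41_679  # npartitions of the joined and filtered frame
tasks_per_partition = n_tasks / n_partitions
print(f"{tasks_per_partition:.2f} tasks per partition on average")
```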


In [12]:
client = dask.distributed.Client(n_workers=8, threads_per_worker=8)
client

Client: distributed.LocalCluster (connection method: Cluster object)
Dashboard: http://127.0.0.1:8787/status
Status: running, Using processes: True
Workers: 8, Total threads: 64, Total memory: 247.07 GiB
Each worker: 8 threads, 30.88 GiB memory

Task exception was never retrieved
future: <Task finished name='Task-1996' coro=<Client._gather.<locals>.wait() done, defined at /ocean/projects/phy210048p/malanche/lsdb-tests/cenv/lib/python3.11/site-packages/distributed/client.py:2382> exception=AllExit()>
Traceback (most recent call last):
  File "/ocean/projects/phy210048p/malanche/lsdb-tests/cenv/lib/python3.11/site-packages/distributed/client.py", line 2391, in wait
    raise AllExit()
distributed.client.AllExit


In [None]:
%%time

# Print a timestamp to track how long graph scheduling takes
from time import time

print(time())

# Run the computation
mean_period_value = mean_period.compute()
mean_period_value

In [None]:
client.close()