# Find periodic variables in ComCam data

In [None]:
    # %pip install lsdb dask nested-dask astropy light-curve gatspy astroml

In [None]:
import lsdb
# Show the installed lsdb version in the cell output.
lsdb.__version__

## Start Dask client

In [None]:
import dask_jobqueue
from dask.distributed import Client

## Loading the catalog

In [None]:
# Catalog to process. The cells below accept exactly "object" or "dia_object"
# and raise ValueError for any other value.
CATALOG_TYPE = "dia_object"


In [None]:
from pathlib import Path

# Release tag; it is used further down as a prefix of the saved figure names.
release = 'v29_0_0'
# Location of the HATS catalog that is read below; the directory name depends
# on the selected CATALOG_TYPE.
catalog_path = Path(f"/sdf/data/rubin/user/kostya/hats/{CATALOG_TYPE}_lc")
catalog_path

In [None]:
# Load the Forced Source + MJD Table
from lsdb import read_hats
from nested_pandas import NestedDtype


# Threshold used in the read-time filter `r_psfMag > BRIGHTEST_R_MAG`, which is
# only applied when reading the "object" catalog.
BRIGHTEST_R_MAG = 21.5

# Reject unsupported catalog types before any column name is configured.
if CATALOG_TYPE not in ("object", "dia_object"):
    raise ValueError(f"Unknown catalog type: {CATALOG_TYPE}")

if CATALOG_TYPE == "object":
    # Object table: magnitudes are the photometry columns, fluxes are tracked
    # separately.
    kwargs = {"filters": [("r_psfMag", ">", BRIGHTEST_R_MAG)]}
    id_column, lc_column = "objectId", "objectForcedSource"
    coord_columns = ["coord_ra", "coord_dec"]
    phot_column, err_column = "psfMag", "psfMagErr"
    flux_column, fluxerr_column = "psfFlux", "psfFluxErr"
else:
    # DIA object table. The `filters` argument is omitted entirely because an
    # empty filter list cannot be passed, see
    # https://github.com/astronomy-commons/lsdb/issues/739
    kwargs = {}
    id_column, lc_column = "diaObjectId", "diaObjectForcedSource"
    coord_columns = ["ra", "dec"]
    # The difference flux doubles as both the photometry and the flux column.
    phot_column = flux_column = "psfDiffFlux"
    err_column = fluxerr_column = "psfDiffFluxErr"

# Read all columns and expose the nested light-curve column under the common
# name "lc", whichever catalog type was selected.
obj_lc = read_hats(catalog_path, columns="all", **kwargs).map_partitions(
    lambda df: df.rename(columns={lc_column: "lc"}),
)
obj_lc

In [None]:
# Show the nested "lc" column dtype as a plain string.
str(obj_lc.dtypes["lc"])

## Filter out "bad" detections and select light curves with enough observations

In [None]:
import numpy as np
import light_curve as licu

# Quality cuts on individual detections. The first two depend on the catalog
# type; the pixel-flag cuts are shared by both.
if CATALOG_TYPE == "object":
    _detection_cuts = [
        "lc.psfMagErr < 0.3",
        "~lc.psfFlux_flag",
    ]
elif CATALOG_TYPE == "dia_object":
    _detection_cuts = [
        "abs(lc.psfDiffFlux) > 3.0 * lc.psfDiffFluxErr",
        "~lc.psfDiffFlux_flag",
    ]
else:
    raise ValueError(f"Unknown catalog type: {CATALOG_TYPE}")
_detection_cuts += [
    "~lc.pixelFlags_suspect",
    "~lc.pixelFlags_saturated",
    "~lc.pixelFlags_cr",
    "~lc.pixelFlags_bad",
]
query = " and ".join(_detection_cuts)

# Drop detections without a flux, apply the cuts, then drop objects whose
# light curve became empty.
# NOTE(review): "lc.psfFlux" is hard-coded here for both catalog types, while
# `flux_column` is configured per catalog type and never used — confirm intent.
filtered_lc = obj_lc.dropna(subset="lc.psfFlux")
filtered_lc = filtered_lc.query(query)
filtered_lc = filtered_lc.dropna(subset="lc")

# Minimum requirements for a light curve to be kept for the period search.
MIN_NOBS = 50        # detections in all bands together
MIN_NOBS_BAND = 30   # detections in a single band
MIN_RCHI2 = 2        # reduced chi^2 feature

# The amplitude threshold is expressed in the units of the photometry column,
# hence the different value per catalog type.
MIN_AMPLITUDE = 0.05 if CATALOG_TYPE == "object" else 10.0

BANDS = "ugrizy"
SCAN_BANDS = "griz"

# Features evaluated on the r-band light curve, in this output order:
# observation count, inter-percentile range, reduced chi^2.
feature_extractor = licu.Extractor(
    licu.ObservationCount(),
    licu.InterPercentileRange(0.05),
    licu.ReducedChi2(),
)

def extract_features(band, t, y, yerr):
    """Compute variability features of a single light curve from its r band.

    Parameters
    ----------
    band : array
        Band label of every detection; compared element-wise with ``'r'``.
    t : array
        Observation times of the detections.
    y, yerr : array-like
        Photometry values and their uncertainties; converted to float arrays.

    Returns
    -------
    dict
        ``nobs`` is the number of detections in all bands. ``nobs_r``,
        ``amplitude_r`` and ``rchi2_r`` are the three outputs of the
        module-level ``feature_extractor`` evaluated on the r-band subset
        (NaN-filled through ``fill_value=np.nan``).
    """
    y, yerr = np.asarray(y, dtype=float), np.asarray(yerr, dtype=float)

    # Counted before the band selection, so it covers every band.
    nobs = len(band)

    band_idx = band == 'r'
    t, y, yerr = t[band_idx], y[band_idx], yerr[band_idx]

    # np.unique returns the index of the first occurrence of every distinct
    # time, ordered by time: this sorts the light curve and drops repeated
    # timestamps in one step.
    _, sort_index = np.unique(t, return_index=True)
    t, y, yerr = t[sort_index], y[sort_index], yerr[sort_index]

    nobs_r, amplitude_r, rchi2_r = feature_extractor(t, y, yerr, fill_value=np.nan)

    return {'nobs': nobs, 'nobs_r': nobs_r, 'amplitude_r': amplitude_r, 'rchi2_r': rchi2_r}


# Columns produced by extract_features; all of them are declared as float.
_feature_names = ["nobs", "nobs_r", "amplitude_r", "rchi2_r"]

# Keep only well-sampled light curves that look variable in the r band.
_feature_cut = (
    f"nobs >= {MIN_NOBS}"
    f" and nobs_r >= {MIN_NOBS_BAND}"
    f" and amplitude_r > {MIN_AMPLITUDE}"
    f" and rchi2_r >= {MIN_RCHI2}"
)

lc_w_features = filtered_lc.reduce(
    extract_features,
    "lc.band",
    "lc.midpointMjdTai",
    f"lc.{phot_column}",
    f"lc.{err_column}",
    meta={name: float for name in _feature_names},
    append_columns=True,
).query(_feature_cut)

### Add heliocentric times

In [None]:
from astropy.coordinates import SkyCoord
from astropy.time import Time

from approx_light_travel_time import fast_light_travel_time_heliocentric_elliptical

def add_helio_mjd(df):
    """Add the nested column ``lc.helioMjd`` to a partition.

    The TAI observation times in ``lc.midpointMjdTai`` are shifted by the
    value returned by ``fast_light_travel_time_heliocentric_elliptical`` for
    the per-detection coordinates, and the result is stored as MJD.

    The frame is modified in place and the same object is returned.

    NOTE(review): the coordinate columns are hard-coded as ``lc.coord_ra`` /
    ``lc.coord_dec`` rather than taken from ``coord_columns`` — confirm they
    exist for every catalog type.
    """
    sky_position = SkyCoord(ra=df["lc.coord_ra"], dec=df["lc.coord_dec"], unit="deg")
    tai_time = Time(df["lc.midpointMjdTai"], format="mjd", scale="tai")
    travel_time = fast_light_travel_time_heliocentric_elliptical(tai_time, sky_position)
    df["lc.helioMjd"] = (tai_time + travel_time).mjd
    return df


# Apply add_helio_mjd to every partition of the feature-filtered catalog.
lc_helio = lc_w_features.map_partitions(add_helio_mjd)

## Running Lomb-Scargle
Compute a Lomb–Scargle periodogram with astropy's `LombScargle`, separately for each band in `SCAN_BANDS`.

In [None]:
from astropy.timeseries import LombScargle
# NOTE(review): this extractor is not referenced anywhere else in the visible
# part of the notebook — confirm whether it is still needed.
reduced_chi2_extractor = licu.ReducedChi2()

# Period limits accepted by the search, in days.
MAX_PERIOD = 1  # days
MIN_PERIOD = 5 / 60 / 24  # 5 minutes

# Periods (days) that are rejected, together with the relative half-width of
# the rejected window around each of them.
BAD_PERIODS = np.array([1/3, 0.25, 0.5, 2/3, 1, 2, 29.5])
BAD_PERIOD_REL_RANGE = 10 / 365.2422


def filter_periods(periods):
    """Return a boolean mask of the periods that are acceptable.

    A period passes when it lies within ``[MIN_PERIOD, MAX_PERIOD]`` and its
    relative offset from every entry of ``BAD_PERIODS`` is larger than
    ``BAD_PERIOD_REL_RANGE``.

    Parameters
    ----------
    periods : array-like, 1-D
        Trial periods in days.

    Returns
    -------
    numpy.ndarray of bool, same length as ``periods``.
    """
    periods = np.asarray(periods)
    # Shape (n_periods, n_bad): relative offset of each period from each
    # rejected period.
    rel_offset = np.abs(periods[:, None] / BAD_PERIODS - 1.0)
    clear_of_bad_periods = (rel_offset > BAD_PERIOD_REL_RANGE).all(axis=1)
    within_limits = (periods >= MIN_PERIOD) & (periods <= MAX_PERIOD)
    return clear_of_bad_periods & within_limits


# Regular frequency grid (1/day) covering periods from 5 minutes to 12 hours.
FREQS = np.linspace(1 / 0.5, 1 / (5 / (60 * 24)), 300_000)

def extract_period_single_band(band, t, flux, fluxerr, single_band, **kwargs):
    """Find the strongest Lomb-Scargle period of one band of a light curve.

    Parameters
    ----------
    band : array
        Band label of every detection.
    t, flux, fluxerr : array
        Times, photometry values and uncertainties of the detections.
    single_band : str
        Band to analyse; also used as the prefix of the output keys.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    dict
        Always has the same three keys:
        ``{single_band}_period_0`` (period of the highest accepted peak),
        ``{single_band}_period_s_to_n_0`` (peak height above the mean power in
        units of the power standard deviation) and
        ``{single_band}_period_0_false_alarm_prob``.
        When the band has fewer than ``MIN_NOBS_BAND`` detections, or no trial
        frequency passes ``filter_periods``, the placeholder values
        ``1e9``, ``0.0`` and ``1.0`` are returned.
    """
    del kwargs  # unused

    # Placeholder for "no period could be measured". Every return path uses
    # the same band-prefixed keys so the output always matches the `meta`
    # declared by the caller.
    no_detection = {
        f"{single_band}_period_0": 1e9,
        f"{single_band}_period_s_to_n_0": 0.0,
        f"{single_band}_period_0_false_alarm_prob": 1.0,
    }

    band_idx = band == single_band
    del band
    t, flux, fluxerr = t[band_idx], flux[band_idx], fluxerr[band_idx]
    # Times are offset by 60000 days and stored as float32.
    t = np.asarray(t - 60_000.0, dtype=np.float32)

    if len(t) < MIN_NOBS_BAND:
        return no_detection

    ls = LombScargle(t, flux, fluxerr)
    # The periodogram is evaluated on the full regular grid; unwanted periods
    # are masked out afterwards.
    power = ls.power(FREQS)

    freq_idx = filter_periods(1 / FREQS)
    freq, power = FREQS[freq_idx], power[freq_idx]

    if len(freq) == 0:
        return no_detection

    idx_period = np.argmax(power)
    period = 1 / freq[idx_period]
    s2n = (power[idx_period] - np.mean(power)) / np.std(power, ddof=1)
    # Called with the default arguments of LombScargle.false_alarm_probability.
    period_0_false_alarm_prob = ls.false_alarm_probability(power[idx_period])

    return {
        f"{single_band}_period_0": period,
        f"{single_band}_period_s_to_n_0": s2n,
        f"{single_band}_period_0_false_alarm_prob": period_0_false_alarm_prob,
    }


# Chain one reduce per scanned band; each pass appends that band's three
# period columns to the catalog.
tmp_cat = lc_helio
for single_band in SCAN_BANDS:
    _period_columns = (
        f"{single_band}_period_0",
        f"{single_band}_period_s_to_n_0",
        f"{single_band}_period_0_false_alarm_prob",
    )
    tmp_cat = tmp_cat.reduce(
        extract_period_single_band,
        "lc.band",
        "lc.helioMjd",
        f"lc.{phot_column}",
        f"lc.{err_column}",
        single_band=single_band,
        meta=dict.fromkeys(_period_columns, float),
        append_columns=True,
    )
lc_w_periods = tmp_cat
lc_w_periods

## Periodic Candidate Selection

In [None]:
import pandas as pd

def select_best_period_per_row(row):
    """Pick the most consistent period of one object across ``SCAN_BANDS``.

    Intended for ``DataFrame.apply(..., axis=1)``. Every pair of bands is
    scored by the smallest relative difference between the first band's period
    and 0.25, 0.5, 1, 2 or 4 times the second band's period. The pair with the
    smallest score is selected, and of its two bands the one with the lower
    false-alarm probability provides the reported period.

    Only pairs in which both periods are finite and within ``(0, MAX_PERIOD]``
    are ranked. This excludes bands carrying the "no detection" placeholder
    period: two such bands have identical periods, hence a relative difference
    of exactly zero, and would otherwise always win over a genuinely agreeing
    pair. When no pair qualifies, all pairs are ranked as a fallback.

    Parameters
    ----------
    row : pandas.Series
        Must contain ``{band}_period_0`` and
        ``{band}_period_0_false_alarm_prob`` for every band in ``SCAN_BANDS``.

    Returns
    -------
    pandas.Series
        ``best_period_band``, ``period_0``, ``period_0_false_alarm_prob`` and
        ``min_rel_period_diff``.
    """
    from itertools import combinations

    def has_period(band):
        # True when the band holds a period from the searched range.
        period = row[f"{band}_period_0"]
        return bool(np.isfinite(period)) and 0.0 < period <= MAX_PERIOD

    def rel_period_diff(first_band, second_band):
        # Smallest relative mismatch, allowing for simple period ratios.
        first_period = row[f"{first_band}_period_0"]
        second_period = row[f"{second_band}_period_0"]
        smallest = np.inf
        for multiplier in (0.25, 0.5, 1.0, 2.0, 4.0):
            diff = np.abs(first_period - multiplier * second_period) / first_period
            smallest = min(smallest, diff)
        return smallest

    all_pairs = list(combinations(SCAN_BANDS, 2))
    usable_pairs = [
        (first, second)
        for first, second in all_pairs
        if has_period(first) and has_period(second)
    ]
    period_diff = {
        pair: rel_period_diff(*pair) for pair in (usable_pairs or all_pairs)
    }

    # Ties resolve to the earliest pair in SCAN_BANDS order.
    best_pair = min(period_diff, key=period_diff.get)
    min_rel_period_diff = period_diff[best_pair]

    first_band, second_band = best_pair
    first_band_prob = row[f"{first_band}_period_0_false_alarm_prob"]
    second_band_prob = row[f"{second_band}_period_0_false_alarm_prob"]
    if first_band_prob < second_band_prob:
        best_band, best_prob = first_band, first_band_prob
    else:
        best_band, best_prob = second_band, second_band_prob

    return pd.Series({
        "best_period_band": best_band,
        "period_0": row[f"{best_band}_period_0"],
        "period_0_false_alarm_prob": best_prob,
        "min_rel_period_diff": min_rel_period_diff,
    })


def select_best_period(df):
    """Append the best-period columns to one partition.

    For a non-empty frame the result of ``select_best_period_per_row`` is
    joined row by row. An empty frame instead gets four empty, correctly typed
    columns, so that every partition has the same set of columns.

    Parameters
    ----------
    df : pandas.DataFrame
        One partition holding the per-band period columns.

    Returns
    -------
    pandas.DataFrame
        ``df`` plus ``best_period_band``, ``period_0``,
        ``period_0_false_alarm_prob`` and ``min_rel_period_diff``.
    """
    if len(df) > 0:
        per_row = df.apply(select_best_period_per_row, axis=1)
        return df.join(per_row)

    # Nothing to apply the row function to: build the new columns by hand.
    column_dtypes = {
        "best_period_band": "str",
        "period_0": "float",
        "period_0_false_alarm_prob": "float",
        "min_rel_period_diff": "float",
    }
    empty_columns = pd.DataFrame(
        {name: pd.Series(dtype=dtype) for name, dtype in column_dtypes.items()}
    )
    return pd.concat([df, empty_columns], axis=1)


# Add the best-period columns to every partition of the catalog.
lc_period_cand = lc_w_periods.map_partitions(select_best_period)
lc_period_cand

## Run the pipeline

In [None]:
%%time

# Run the pipeline on a SLURM-backed Dask cluster. Both the cluster and the
# client are context managers, so they are closed when the block exits.
with dask_jobqueue.SLURMCluster(
    processes=1,
    queue="roma",
    account="rubin:commissioning",
    cores=2,
    memory="8GB",
    walltime="01:00:00",
) as cluster:
    # Adaptive scaling, capped at 200 jobs.
    cluster.adapt(maximum_jobs=200)
    with Client(cluster) as client:
        display(client)
        # Write the candidates to the relative path "periodic_cand",
        # replacing whatever is already there (overwrite=True).
        lc_period_cand.to_hats("periodic_cand", overwrite=True)

In [None]:
from gatspy.periodic import RRLyraeTemplateModelerMultiband


def rchi2_rrlyr(t, mag, magerr, band, period, extra_sigma=0.1):
    """Reduced chi^2 of an RR Lyrae template fit at a fixed period.

    A ``RRLyraeTemplateModelerMultiband`` model is fitted to the light curve
    and evaluated at the observed times for the given ``period``.

    Parameters
    ----------
    t, mag, magerr, band : array
        Times, magnitudes, magnitude uncertainties and band labels.
    period : float
        Period at which the model prediction is evaluated.
    extra_sigma : float, optional
        Added in quadrature to ``magerr`` in the chi^2 denominator.

    Returns
    -------
    dict
        ``{"rchi2_rrlyr": value}``, where the chi^2 sum is divided by
        ``len(t) - 1``.
    """
    template_model = RRLyraeTemplateModelerMultiband()
    template_model.fit(t, mag, magerr, band)
    predicted_mag = template_model.predict(t, band, period=period)

    squared_residual = (mag - predicted_mag) ** 2
    variance = magerr ** 2 + extra_sigma ** 2
    n_dof = len(t) - 1
    return {"rchi2_rrlyr": np.sum(squared_residual / variance) / n_dof}

# Re-open the catalog written by the pipeline cell.
cand_cat = lsdb.read_hats("periodic_cand")
# Keep objects with a very low false-alarm probability, periods that agree
# between two bands, and cuts on the r-band flux extremes.
# NOTE(review): r_psfFluxMax / r_psfFluxMin are object-level columns that are
# not created in this notebook — confirm they exist for the selected
# CATALOG_TYPE.
cand_cat = cand_cat.query(
   "log10(period_0_false_alarm_prob) < -10"
   " and min_rel_period_diff < 0.001"
#    " and (period_0 > 0.251 or period_0 < 0.249)"
   " and 31.4 - 2.5 * log10(r_psfFluxMax) > 22 and 31.4 - 2.5 * log10(abs(r_psfFluxMin)) > 22"
)
# Score every remaining object against the RR Lyrae template at its period.
# NOTE(review): the magnitude columns are hard-coded as lc.psfMag /
# lc.psfMagErr instead of phot_column / err_column — confirm they are present
# in the light curves of the selected CATALOG_TYPE.
cand_cat = cand_cat.reduce(
    rchi2_rrlyr,
    "lc.helioMjd",
    "lc.psfMag",
    "lc.psfMagErr",
    "lc.band",
    "period_0",
    meta={"rchi2_rrlyr": float},
    append_columns=True,
)
# cand_cat = cand_cat.query("rchi2_rrlyr < 2.0")

# Compute the selection on a smaller SLURM-backed Dask cluster and bring the
# result into memory.
with dask_jobqueue.SLURMCluster(
    processes=1,
    queue="roma",
    account="rubin:commissioning",
    cores=1,
    memory="4GB",
    walltime="00:20:00",
) as cluster:
    # Adaptive scaling, capped at 50 jobs.
    cluster.adapt(maximum_jobs=50)
    with Client(cluster) as client:
        display(client)
        cand_subset = cand_cat.compute()
cand_subset

In [None]:
# Order the candidates by template-fit quality, lowest reduced chi^2 first.
rr_cand = cand_subset.sort_values("rchi2_rrlyr")
rr_cand

## Plotting a few Phase-Folded Candidates

In [None]:
import matplotlib.pyplot as plt

# Plot colour of every band.
COLORS = {'u': '#0c71ff', 'g': '#49be61', 'r': '#c61c00',
          'i': '#ffc200', 'z': '#f341a2', 'y': '#5d0000'}

FOLDED = True

# Output directory of the figures.
# NOTE(review): this is the same relative path the pipeline cell writes to
# with to_hats(..., overwrite=True) — confirm sharing the directory is intended.
fig_path = Path("periodic_cand")
fig_path.mkdir(exist_ok=True, parents=True)

# One figure per candidate for the first 200 rows of rr_cand: the light curve
# against time on the left, phase-folded at the selected period on the right.
# NOTE(review): figures are not closed after saving, so up to 200 stay open.
for healpix29, cand in rr_cand.iloc[:200].iterrows():
    helio_mjd = cand.lc["helioMjd"]
    period = cand["period_0"]
    phase = helio_mjd % period / period
    phot = cand.lc[phot_column]
    err = cand.lc[err_column]

    fig, (ax_mjd, ax_phase) = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
    for b in BANDS:
        idx = (cand.lc["band"] == b)
        errorbar_kwargs = dict(
            y=phot[idx],
            yerr=err[idx],
            fmt="o",
            color=COLORS[b],
            label=f'{b}',
            alpha=0.3,
        )
        ax_mjd.errorbar(helio_mjd[idx], **errorbar_kwargs)
        ax_phase.errorbar(phase[idx], **errorbar_kwargs)

    # The line break lives in the non-raw fragment: inside the raw f-string
    # below (raw because of the LaTeX thin space) "\n" would be rendered as a
    # literal backslash followed by "n".
    fig.suptitle(
        f"OID: {cand[id_column]}, RA: {cand[coord_columns[0]]:.5f}, Dec: {cand[coord_columns[1]]:.5f}\n"
        rf"Period: {cand['period_0']:.5f}$\,$d, L—S lg(F-P): {np.log10(cand['period_0_false_alarm_prob']):.1f}"
    )

    ax_mjd.set_xlabel("MJD")
    ax_mjd.set_xlim(np.min(helio_mjd)-1, np.max(helio_mjd)+1)
    ax_phase.set_xlabel("Phase")
    ax_phase.set_xlim(0, 1)

    # Dashed reference line at zero on both panels.
    ax_mjd.plot(ax_mjd.get_xlim(), [0, 0], color='k', linestyle='--', alpha=0.5)
    ax_phase.plot(ax_phase.get_xlim(), [0, 0], color='k', linestyle='--', alpha=0.5)
    ax_mjd.legend(loc='upper left')
    ax_mjd.grid()
    ax_phase.grid()

    plt.savefig(fig_path / f"{release}-{cand[id_column]}.pdf")

    print(cand[id_column], cand[coord_columns[0]], cand[coord_columns[1]], cand["period_0"])