# Find periodic variables in ComCam data

In [None]:
# %pip install lsdb dask nested-dask astropy light-curve

In [None]:
import lsdb
lsdb.__version__

In [None]:
import pandas as pd
pd.__version__

## Start Dask client

In [None]:
from dask.distributed import Client

# Start a local Dask client: 24 workers with one thread each and
# memory_limit="16GB".
# NOTE(review): the earlier comment called this a "small" client; reduce
# n_workers / memory_limit if the host cannot provide these resources.
client = Client(n_workers=24, memory_limit="16GB", threads_per_worker=1)
# Bare expression so the notebook displays the client summary.
client

## Loading the catalog

In [None]:
from pathlib import Path

# Release tag; it doubles as the directory name under the shared HATS root.
release = "v29_0_0_rc5"
hats_path = Path("/sdf/data/rubin/shared/lsdb_commissioning/hats/", release)

# Show which entries are available for this release.
print([str(entry) for entry in hats_path.iterdir()])

# Catalog that is opened in the next cell.
catalog_path = hats_path.joinpath("object_lc")

In [None]:
# Load the Forced Source + MJD Table
from lsdb import read_hats
from nested_pandas import NestedDtype


# Threshold used by the read-time row filter below.
# NOTE(review): larger magnitudes are fainter, so "r_psfMag > BRIGHTEST_R_MAG"
# keeps only objects fainter than 21.5 mag -- confirm this is the intended cut.
BRIGHTEST_R_MAG = 21.5

# Read only the listed columns and keep rows with r_psfMag > BRIGHTEST_R_MAG.
obj_lc = read_hats(
    catalog_path,
    columns=["objectId", "coord_ra", "coord_dec", "objectForcedSource"],
    filters=[("r_psfMag", ">", BRIGHTEST_R_MAG)],
).map_partitions(
    # Cast "objectForcedSource" to a NestedDtype derived from its own dtype,
    # store it under the shorter name "lc", and drop the original column.
    lambda df: df.assign(
        lc=df["objectForcedSource"].astype(
                NestedDtype.from_pandas_arrow_dtype(df.dtypes["objectForcedSource"])
        ),
    ).drop(columns=["objectForcedSource"]),
)
# Bare expression so the notebook displays the catalog.
obj_lc

## Filter out "bad" detections and select light curves with enough observations

In [None]:
import numpy as np
import light_curve as licu

# Quality cuts on the nested light curves: drop observations without a PSF
# flux, keep only those with psfMagErr < 0.1 and none of the listed flags set,
# then drop rows whose "lc" column is null (presumably objects left without
# any observation after the cuts -- confirm against nested-pandas semantics).
filtered_lc = obj_lc.dropna(subset="lc.psfFlux").query(
    "lc.psfMagErr < 0.1"
    " and ~lc.psfFlux_flag"
    " and ~lc.pixelFlags_suspect"
    " and ~lc.pixelFlags_saturated"
    " and ~lc.pixelFlags_cr"
    " and ~lc.pixelFlags_bad"
).dropna(subset="lc")

# Thresholds for the feature-based pre-selection.
MIN_NOBS = 50  # minimum number of observations over all bands ("nobs")
MIN_NOBS_BAND = 30  # minimum number of observations in a single band
MIN_RCHI2 = 10  # minimum r-band reduced chi^2 ("rchi2_r")
MIN_AMPLITUDE = 0.05  # r-band inter-percentile range must exceed this

# Bands iterated by the light-curve based period finders and by the plots.
BANDS = 'ugrizy'
# Bands for which a single-band Lomb-Scargle period is computed.
SCAN_BANDS = "griz"

# Combined extractor; its outputs are unpacked by extract_features in the
# order listed here (count, inter-percentile range, reduced chi^2).
feature_extractor = licu.Extractor(
    licu.ObservationCount(),
    licu.InterPercentileRange(0.05),
    licu.ReducedChi2(),
)

def extract_features(band, t, y, yerr):
    """Compute pre-selection features for one light curve.

    Parameters
    ----------
    band : array-like of str
        Band label of every observation.
    t : array-like of float
        Observation times.
    y, yerr : array-like of float
        Measurements and their uncertainties, aligned with ``t``.

    Returns
    -------
    dict
        ``nobs`` is the number of observations over all bands; ``nobs_r``,
        ``amplitude_r`` and ``rchi2_r`` are the three outputs of the
        module-level ``feature_extractor`` evaluated on the r-band subset.
    """
    # Coerce everything to NumPy arrays so that the boolean mask and the
    # integer sort index below are always applied positionally, whatever
    # container the caller hands over.
    band = np.asarray(band)
    t = np.asarray(t, dtype=float)
    y, yerr = np.asarray(y, dtype=float), np.asarray(yerr, dtype=float)

    nobs = len(band)

    in_r_band = band == 'r'
    t, y, yerr = t[in_r_band], y[in_r_band], yerr[in_r_band]

    # np.unique returns the indices of the first occurrence of each distinct
    # time in ascending order: this sorts the series and drops repeated times.
    _, sort_index = np.unique(t, return_index=True)
    t, y, yerr = t[sort_index], y[sort_index], yerr[sort_index]

    nobs_r, amplitude_r, rchi2_r = feature_extractor(t, y, yerr, fill_value=np.nan)

    return {'nobs': nobs, 'nobs_r': nobs_r, 'amplitude_r': amplitude_r, 'rchi2_r': rchi2_r}


# Run extract_features on the nested band / time / magnitude / error columns,
# append the resulting feature columns to the catalog, and keep only objects
# that pass every threshold in the query.
lc_w_features = filtered_lc.reduce(
    extract_features,
    "lc.band",
    "lc.midpointMjdTai",
    "lc.psfMag",
    "lc.psfMagErr",
    meta=dict.fromkeys(['nobs', 'nobs_r', 'amplitude_r', 'rchi2_r'], float),
    append_columns=True,
).query(f"nobs >= {MIN_NOBS} and nobs_r >= {MIN_NOBS_BAND} and amplitude_r > {MIN_AMPLITUDE} and rchi2_r >= {MIN_RCHI2}")

In [None]:
# Non-lazy computation
# len(lc_w_features._ddf)

### Add heliocentric times

In [None]:
from astropy.coordinates import EarthLocation, SkyCoord
from astropy.time import Time

# Observer location looked up by site name; passed as `location` to the
# light-travel-time computation in add_helio_mjd.
lsst_location = EarthLocation.of_site("LSST")


def add_helio_mjd(df):
    """Add a heliocentric MJD sub-column to the nested light curves.

    Reads "lc.coord_ra", "lc.coord_dec" (degrees) and "lc.midpointMjdTai"
    (MJD, TAI scale), adds the heliocentric light-travel time for an observer
    at ``lsst_location`` and stores the result as "lc.helioMjd".

    The frame is modified in place and also returned.
    """
    sky_position = SkyCoord(ra=df["lc.coord_ra"], dec=df["lc.coord_dec"], unit="deg")
    observed = Time(df["lc.midpointMjdTai"], format="mjd", scale="tai")
    travel_time = observed.light_travel_time(
        sky_position, "heliocentric", location=lsst_location
    )
    df["lc.helioMjd"] = (observed + travel_time).mjd
    return df


# Apply the heliocentric correction to every partition of the catalog.
lc_helio = lc_w_features.map_partitions(add_helio_mjd)

## Running Lomb-Scargle
Period search can be run either with the `light-curve` package or with `astropy.timeseries`.

In [None]:
from astropy.timeseries import BoxLeastSquares, LombScargle, LombScargleMultiband

# Settings forwarded to the light-curve package periodogram below.
PERIODOGRAM_RESOLUTION = 30_000
PERIODOGRAM_NYQUIST_FACTOR = 10

# Periodogram extractor configured with peaks=10; the period finders below
# read its output as alternating (period, signal-to-noise) values.
periodogram_extractor = licu.Periodogram(
    peaks=10,
    nyquist='average',
    # Alternative, smaller settings kept for reference:
    # resolution=1000,
    # max_freq_factor=10,
    resolution=PERIODOGRAM_RESOLUTION,
    max_freq_factor=PERIODOGRAM_NYQUIST_FACTOR,
    fast=False,
)
# Used to weight per-band periodograms in extract_period_multiband_licu.
reduced_chi2_extractor = licu.ReducedChi2()

# Accepted period range used by filter_periods.
MAX_PERIOD = 1  # days
MIN_PERIOD = 5 / 60 / 24  # 5 minutes expressed in days
# Periods rejected by filter_periods (same unit as MAX_PERIOD).
BAD_PERIODS = np.array([1/3, 0.5, 2/3, 1, 2, 29.5])
# Alternative that disables the bad-period rejection:
# BAD_PERIODS = np.array([])
# Relative half-width of the rejection window around each bad period.
BAD_PERIOD_REL_RANGE = 10 / 365.2422

def filter_periods(
    periods,
    *,
    bad_periods=None,
    bad_period_rel_range=None,
    min_period=None,
    max_period=None,
):
    """Return a boolean mask selecting acceptable trial periods.

    A period is accepted when it lies within ``[min_period, max_period]`` and
    its relative distance ``|period / bad - 1|`` to every entry of
    ``bad_periods`` is strictly greater than ``bad_period_rel_range``.

    Parameters
    ----------
    periods : array-like, 1-D
        Trial periods.
    bad_periods : array-like, optional
        Periods to reject; defaults to the module-level ``BAD_PERIODS``.
        An empty sequence disables the rejection.
    bad_period_rel_range : float, optional
        Relative half-width of the rejection window; defaults to
        ``BAD_PERIOD_REL_RANGE``.
    min_period, max_period : float, optional
        Inclusive period bounds; default to ``MIN_PERIOD`` / ``MAX_PERIOD``.

    Returns
    -------
    numpy.ndarray of bool
        Mask with the same length as ``periods``.
    """
    # Module-level constants are looked up at call time so that existing
    # single-argument callers keep their behaviour.
    if bad_periods is None:
        bad_periods = BAD_PERIODS
    if bad_period_rel_range is None:
        bad_period_rel_range = BAD_PERIOD_REL_RANGE
    if min_period is None:
        min_period = MIN_PERIOD
    if max_period is None:
        max_period = MAX_PERIOD

    periods = np.asarray(periods, dtype=float)
    bad_periods = np.asarray(bad_periods, dtype=float)

    # Broadcast to (n_periods, n_bad); with no bad periods np.all over the
    # empty axis yields True for every period.
    away_from_bad = np.all(
        np.abs(periods[:, None] / bad_periods - 1.0) > bad_period_rel_range,
        axis=1,
    )
    return away_from_bad & (periods >= min_period) & (periods <= max_period)


def extract_period_multiband_licu(band, t, flux, fluxerr, **kwargs):
    """Find the best period from per-band and combined periodograms.

    For every band in ``BANDS`` with at least ``MIN_NOBS_BAND`` points a
    periodogram is computed with ``periodogram_extractor`` and restricted to
    the periods accepted by ``filter_periods``. The per-band periodograms are
    then interpolated onto a common frequency grid and averaged with chi^2
    weights. The returned period is the one with the highest signal-to-noise
    among the per-band peaks and the combined peak.

    ``kwargs`` are forwarded to ``reduced_chi2_extractor``.

    Returns
    -------
    dict
        ``{"period_0": ..., "period_s_to_n_0": ...}``; both values are 0.0
        when no band yields a usable periodogram.
    """
    # NOTE(review): an earlier comment here said "We offset date, so we still
    # would have <1 second precision", but no time offset is applied in this
    # function -- confirm whether one was intended.

    # Sort by time; np.unique also drops observations with repeated times.
    _, sort_index = np.unique(t, return_index=True)
    band, t, flux, fluxerr = band[sort_index], t[sort_index], flux[sort_index], fluxerr[sort_index]
    # Convert fluxes to magnitudes (no zero point) and propagate the errors.
    mag = -2.5 * np.log10(flux)
    magerr = 2.5 / np.log(10) * fluxerr / flux

    band_freqs = []
    band_periodograms = []
    band_weights = []
    periods = []
    s2n = []
    for b in BANDS:
        band_idx = band == b

        # At least few points in the band
        if np.count_nonzero(band_idx) < MIN_NOBS_BAND:
            continue
        
        band_t = t[band_idx]
        band_mag = mag[band_idx]
        band_magerr = magerr[band_idx]

        freq, periodogram = periodogram_extractor.freq_power(band_t, band_mag)
        # The frequencies are converted with 2*pi / freq, i.e. they are
        # treated as angular frequencies.
        freq_idx = filter_periods(2*np.pi / freq)
        if not np.any(freq_idx):
            continue
        band_freqs.append(freq[freq_idx])
        band_periodograms.append(periodogram[freq_idx])

        # Weight of this band in the combined periodogram: its chi^2,
        # recovered from the reduced chi^2 with (n - 1) degrees of freedom.
        reduced_chi2 = reduced_chi2_extractor(band_t, band_mag, band_magerr, **kwargs)[0]
        chi2 = reduced_chi2 * (len(band_t) - 1)
        band_weights.append(chi2)

        # Best period of this band and its signal-to-noise: peak height above
        # the mean, in units of the periodogram's standard deviation.
        idx_band_period = np.argmax(band_periodograms[-1])
        periods.append(2*np.pi / band_freqs[-1][idx_band_period])
        s2n.append(
            (band_periodograms[-1][idx_band_period] - np.mean(band_periodograms[-1])) / np.std(band_periodograms[-1], ddof=1)
        )

    if len(band_periodograms) == 0:
        return {"period_0": 0.0, "period_s_to_n_0": 0.0}

    # Common grid: union of all per-band frequencies. Outside a band's own
    # frequency range its periodogram is replaced by its mean value.
    freq = np.unique(np.concatenate(band_freqs))
    periodograms = [np.interp(freq, f, p, left=np.mean(p), right=np.mean(p)) for f, p in zip(band_freqs, band_periodograms)]
    multiband_periodogram = np.average(periodograms, weights=band_weights, axis=0)

    idx_multiband_period = np.argmax(multiband_periodogram)
    multiband_s2n = (multiband_periodogram[idx_multiband_period] - np.mean(multiband_periodogram)) / np.std(multiband_periodogram, ddof=1)
    multiband_period = 2*np.pi / freq[idx_multiband_period]

    # The combined peak competes with the per-band peaks.
    periods.append(multiband_period)
    s2n.append(multiband_s2n)

    best_period = periods[np.argmax(s2n)]
    best_s2n = s2n[np.argmax(s2n)]

    # Return the features as a dictionary
    return {"period_0": best_period, "period_s_to_n_0": best_s2n}


def extract_period_singleband_licu(band, t, flux, fluxerr, **kwargs):
    """Pick the best periodogram peak over all sufficiently sampled bands.

    Each band in ``BANDS`` with at least ``MIN_NOBS_BAND`` points is run
    through ``periodogram_extractor`` (``kwargs`` are forwarded to it). Its
    output is read as alternating (period, signal-to-noise) values. Among all
    peaks accepted by ``filter_periods`` the one with the highest
    signal-to-noise is returned.

    ``fluxerr`` is accepted for signature compatibility and is not used.

    Returns
    -------
    dict
        ``{"period_0": ..., "period_s_to_n_0": ...}``; both 0.0 when no peak
        is available or accepted.
    """
    # Time-order the observations; repeated times are dropped by np.unique.
    _, order = np.unique(t, return_index=True)
    band, t, flux = band[order], t[order], flux[order]

    peak_periods = []
    peak_s2n = []
    for b in BANDS:
        in_band = band == b
        times = t[in_band]
        if len(times) < MIN_NOBS_BAND:
            continue

        magnitudes = -2.5 * np.log10(flux[in_band])
        features = periodogram_extractor(times, magnitudes, **kwargs)
        peak_periods.extend(features[::2])
        peak_s2n.extend(features[1::2])

    peak_periods = np.asarray(peak_periods)
    peak_s2n = np.asarray(peak_s2n)

    if len(peak_periods) == 0:
        return {"period_0": 0.0, "period_s_to_n_0": 0.0}

    accepted = filter_periods(peak_periods)
    if not np.any(accepted):
        return {"period_0": 0.0, "period_s_to_n_0": 0.0}

    accepted_periods = peak_periods[accepted]
    accepted_s2n = peak_s2n[accepted]
    winner = np.argmax(accepted_s2n)
    return {
        "period_0": accepted_periods[winner],
        "period_s_to_n_0": accepted_s2n[winner],
    }


def extract_period_rband_licu(band, t, flux, fluxerr, **kwargs):
    """Return the first accepted r-band periodogram peak.

    The r-band subset must contain at least ``MIN_NOBS_BAND`` points. The
    output of ``periodogram_extractor`` (``kwargs`` are forwarded to it) is
    read as alternating (period, signal-to-noise) values, and the first peak
    accepted by ``filter_periods`` is returned.

    ``fluxerr`` is accepted for signature compatibility and is not used.

    Returns
    -------
    dict
        ``{"period_0": ..., "period_s_to_n_0": ...}``; both 0.0 when the band
        is under-sampled or no peak is accepted.
    """
    # Time-order the observations; repeated times are dropped by np.unique.
    _, order = np.unique(t, return_index=True)
    band, t, flux = band[order], t[order], flux[order]

    is_r = band == 'r'
    t = t[is_r]
    flux = flux[is_r]

    if len(t) < MIN_NOBS_BAND:
        return {"period_0": 0.0, "period_s_to_n_0": 0.0}

    features = periodogram_extractor(t, -2.5 * np.log10(flux), **kwargs)
    peak_periods = features[::2]
    peak_s2n = features[1::2]

    accepted = filter_periods(peak_periods)
    if not np.any(accepted):
        return {"period_0": 0.0, "period_s_to_n_0": 0.0}

    return {
        "period_0": peak_periods[accepted][0],
        "period_s_to_n_0": peak_s2n[accepted][0],
    }


def extract_period_multiband_astropy(band, t, flux, fluxerr, **kwargs):
    """Find the best period with astropy's multiband Lomb-Scargle.

    Fluxes are converted to magnitudes (no zero point), the periodogram is
    evaluated on astropy's automatic frequency grid and restricted to the
    periods accepted by ``filter_periods``.

    ``kwargs`` are accepted for signature compatibility and ignored.

    Returns
    -------
    dict
        ``{"period_0": ..., "period_s_to_n_0": ...}``. The signal-to-noise is
        the peak height above the mean of the retained periodogram in units of
        its sample standard deviation. Both values are 0.0 when no frequency
        survives the period filter; the signal-to-noise is 0.0 when it cannot
        be estimated (fewer than two retained samples or zero spread).
    """
    del kwargs  # unused

    mag = -2.5 * np.log10(flux)
    magerr = 2.5 / np.log(10) * fluxerr / flux
    freq, power = LombScargleMultiband(t, mag, band, magerr).autopower()

    freq_idx = filter_periods(1 / freq)
    freq, power = freq[freq_idx], power[freq_idx]

    if len(freq) == 0:
        return {"period_0": 0.0, "period_s_to_n_0": 0.0}

    idx_period = np.argmax(power)
    period = 1 / freq[idx_period]

    # The statistics must be taken over the whole retained periodogram:
    # mean/std of the single peak value would give 0 / NaN.
    spread = np.std(power, ddof=1) if len(power) > 1 else 0.0
    if spread > 0:
        s2n = (power[idx_period] - np.mean(power)) / spread
    else:
        s2n = 0.0
    return {"period_0": period, "period_s_to_n_0": s2n}


def extract_period_rband_box_astropy(band, t, flux, fluxerr, **kwargs):
    """Find the best r-band period with astropy's Box Least Squares.

    The r-band subset must contain at least ``MIN_NOBS_BAND`` points and span
    more than 2.0 time units. The periodogram is evaluated on astropy's
    automatic period grid for ten log-spaced transit durations between one
    minute and one day.

    ``kwargs`` are accepted for signature compatibility and ignored.

    Returns
    -------
    dict
        ``{"period_0": ..., "period_s_to_n_0": ...}``. The signal-to-noise is
        the peak height above the mean of the periodogram in units of its
        sample standard deviation (0.0 when it cannot be estimated). Both
        values are 0.0 when the light curve is rejected.
    """
    del kwargs  # unused

    r_band = band == 'r'
    t, flux, fluxerr = t[r_band], flux[r_band], fluxerr[r_band]

    # Not enough points in the light curve
    if len(t) < MIN_NOBS_BAND:
        return {"period_0": 0.0, "period_s_to_n_0": 0.0}

    # Light curve is too short
    if np.ptp(t) <= 2.0:
        return {"period_0": 0.0, "period_s_to_n_0": 0.0}

    # NOTE(review): the periods searched here come from autopower's defaults
    # and are not restricted by filter_periods / MAX_PERIOD -- confirm that
    # this is intended.
    result = BoxLeastSquares(t, flux, fluxerr).autopower(
        duration=np.geomspace(1.0 / (24 * 60), 1.0, 10),
    )

    idx_period = np.argmax(result.power)

    # Compare the peak with the mean of the whole periodogram; the mean of the
    # single peak value would always cancel the numerator.
    spread = np.std(result.power, ddof=1) if len(result.power) > 1 else 0.0
    if spread > 0:
        s2n = (result.power[idx_period] - np.mean(result.power)) / spread
    else:
        s2n = 0.0
    return {"period_0": result.period[idx_period], "period_s_to_n_0": s2n}


# Fixed, evenly spaced frequency grid for the single-band Lomb-Scargle scan.
# 1 / FREQS is compared against the period limits in days, so the grid runs
# from 0.03 to 24 * 60 / 5 = 288 cycles per day (periods down to 5 minutes).
FREQS = np.linspace(0.03, 24 * 60 / 5, 1_000_000)


def extract_period_single_band_astropy(band, t, flux, fluxerr, single_band, **kwargs):
    """Find the best Lomb-Scargle period in one band.

    The periodogram of the ``single_band`` subset (fluxes converted to
    magnitudes without zero point) is evaluated on the module-level ``FREQS``
    grid and restricted to the periods accepted by ``filter_periods``.

    ``kwargs`` are accepted for signature compatibility and ignored.

    Returns
    -------
    dict
        Always the same three keys, prefixed with the band name:
        ``<band>_period_0``, ``<band>_period_s_to_n_0`` and
        ``<band>_period_0_false_alarm_prob``. When the band has fewer than
        ``MIN_NOBS_BAND`` points or no frequency survives the period filter,
        the "no detection" values 1e9, 0.0 and 1.0 are returned.
    """
    del kwargs  # unused

    key_period = f"{single_band}_period_0"
    key_s2n = f"{single_band}_period_s_to_n_0"
    key_fap = f"{single_band}_period_0_false_alarm_prob"
    # Every exit path must produce the same columns, otherwise the result does
    # not match the `meta` declared by the caller.
    no_detection = {key_period: 1e9, key_s2n: 0.0, key_fap: 1.0}

    band_idx = band == single_band
    del band
    t, flux, fluxerr = t[band_idx], flux[band_idx], fluxerr[band_idx]

    if len(t) < MIN_NOBS_BAND:
        return no_detection

    mag = -2.5 * np.log10(flux)
    magerr = 2.5 / np.log(10) * fluxerr / flux
    ls = LombScargle(t, mag, magerr)
    power = ls.power(FREQS)

    freq_idx = filter_periods(1 / FREQS)
    freq, power = FREQS[freq_idx], power[freq_idx]

    if len(freq) == 0:
        return no_detection

    idx_period = np.argmax(power)
    period = 1 / freq[idx_period]

    # Peak height above the mean of the retained periodogram, in units of its
    # sample standard deviation. The statistics must use the whole array:
    # mean/std of the single peak value would give 0 / NaN.
    spread = np.std(power, ddof=1) if len(power) > 1 else 0.0
    if spread > 0:
        s2n = (power[idx_period] - np.mean(power)) / spread
    else:
        s2n = 0.0
    period_0_false_alarm_prob = ls.false_alarm_probability(power[idx_period])

    return {key_period: period, key_s2n: s2n, key_fap: period_0_false_alarm_prob}


# Chain one reduce per scanned band. Each pass appends that band's
# "<band>_period_0", "<band>_period_s_to_n_0" and
# "<band>_period_0_false_alarm_prob" columns to the catalog.
tmp_cat = lc_helio
for single_band in SCAN_BANDS:
    tmp_cat = tmp_cat.reduce(
        extract_period_single_band_astropy,
        "lc.band",
        "lc.helioMjd",
        "lc.psfFlux",
        "lc.psfFluxErr",
        single_band=single_band,
        meta={f"{single_band}_period_0": float, f"{single_band}_period_s_to_n_0": float, f"{single_band}_period_0_false_alarm_prob": float},
        append_columns=True,
    )
lc_w_periods = tmp_cat
# Bare expression so the notebook displays the catalog.
lc_w_periods

## Periodic Candidate Selection

In [None]:
import pyarrow as pa

def select_best_period(row):
    """Pick the period on which two bands agree best.

    Intended for ``DataFrame.apply(..., axis=1)``. For every pair of bands in
    ``SCAN_BANDS`` the relative difference of their periods (normalised by the
    first band's period) is computed. The pair with the smallest difference
    wins, and within that pair the band with the lower false-alarm probability
    supplies the reported period.

    Pairs in which a band has no detection (non-finite or non-positive period,
    or a false-alarm probability of 1 or more) are skipped. This matters
    because two bands carrying the same "no detection" placeholder period
    would otherwise agree perfectly. If no pair has detections in both bands,
    all pairs are considered.

    Returns
    -------
    pandas.Series
        The input row extended with ``best_period_band``, ``period_0``,
        ``period_0_false_alarm_prob`` and ``min_rel_period_diff``.
    """
    def period_of(band_name):
        return row[f"{band_name}_period_0"]

    def false_alarm_prob_of(band_name):
        return row[f"{band_name}_period_0_false_alarm_prob"]

    def has_detection(band_name):
        period = period_of(band_name)
        return bool(
            np.isfinite(period)
            and period > 0
            and false_alarm_prob_of(band_name) < 1.0
        )

    all_pairs = [
        (first_band, second_band)
        for i_first_band, first_band in enumerate(SCAN_BANDS)
        for second_band in SCAN_BANDS[i_first_band + 1:]
    ]
    detected_pairs = [
        pair for pair in all_pairs if has_detection(pair[0]) and has_detection(pair[1])
    ]

    period_diff = {
        (first_band, second_band): (
            np.abs(period_of(first_band) - period_of(second_band)) / period_of(first_band)
        )
        for first_band, second_band in (detected_pairs or all_pairs)
    }

    # Smallest relative difference = best agreement between two bands.
    best_pair = min(period_diff, key=period_diff.get)
    min_rel_period_diff = period_diff[best_pair]
    first_band, second_band = best_pair

    first_band_prob = false_alarm_prob_of(first_band)
    second_band_prob = false_alarm_prob_of(second_band)
    if first_band_prob < second_band_prob:
        best_band, best_prob = first_band, first_band_prob
    else:
        best_band, best_prob = second_band, second_band_prob

    new_data = pd.Series({
        "best_period_band": best_band,
        "period_0": period_of(best_band),
        "period_0_false_alarm_prob": best_prob,
        "min_rel_period_diff": min_rel_period_diff,
    })
    return pd.concat([row, new_data])


# Apply select_best_period row by row within each partition. `meta` declares
# the output schema: the existing catalog columns followed by the four
# columns added by select_best_period.
lc_period_cand = lc_w_periods.map_partitions(
    lambda df: df.apply(select_best_period, axis=1),
    meta=pd.concat(
        [
            lc_w_periods._ddf.meta,
            pd.DataFrame({
                "best_period_band": np.array([], dtype=str),
                "period_0": np.array([], dtype=float),
                "period_0_false_alarm_prob": np.array([], dtype=float),
                "min_rel_period_diff": np.array([], dtype=float),
            })
        ],
        axis=1
    ),
)
# Bare expression so the notebook displays the catalog.
lc_period_cand

## Plotting a few Phase-Folded Candidates

In [None]:
# Run the whole pipeline and collect the result.
cand_df = lc_period_cand.compute()
# Optional: persist the candidates to a Parquet file named after the release.
# pd.DataFrame.to_parquet(cand_df, f"periodic_cand-{release}.parquet")

In [None]:
# Order candidates from the most to the least significant period, then keep
# those with a false-alarm probability below 1e-6 whose best band pair agrees
# on the period to better than 20%.
cand_subset = cand_df.sort_values(by="period_0_false_alarm_prob", ascending=True)
# The second fragment starts with a space: adjacent string literals are
# concatenated verbatim, and "-6and" only parses through the deprecated
# "numeric literal immediately followed by a keyword" tokenization.
cand_subset = cand_subset.query(
    "log10(period_0_false_alarm_prob) < -6"
    " and min_rel_period_diff < 0.2"
)
# Bare expression so the notebook displays the selection.
cand_subset

In [None]:
import matplotlib.pyplot as plt

# Plot colour for each band.
COLORS = {'u': '#0c71ff', 'g': '#49be61', 'r': '#c61c00',
          'i': '#ffc200', 'z': '#f341a2', 'y': '#5d0000'}

# NOTE(review): not read anywhere in this cell; kept in case other cells use it.
FOLDED = True

# For each of the first 50 candidates draw the light curve against time (left)
# and against phase (right), one colour per band, and save the figure as PDF.
for healpix29, cand in cand_subset.iloc[:50].iterrows():
    phase = cand.lc["helioMjd"] % cand["period_0"] / cand["period_0"]
    mag = -2.5 * np.log10(cand.lc["psfFlux"]) + 31.4
    magerr = 2.5 / np.log(10) * cand.lc["psfFluxErr"] / cand.lc["psfFlux"]
    fig, (ax_mjd, ax_phase) = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
    all_delta_mag = []
    for b in BANDS:
        idx = (cand.lc["band"] == b) & (magerr < 0.1)
        # Skip bands without usable points: their mean would be NaN and they
        # would only add a "nan" entry to the legend.
        if not np.any(idx):
            continue
        mean_band_mag = np.mean(mag[idx])
        delta_mag = mag[idx] - mean_band_mag
        all_delta_mag.append(delta_mag)
        errorbar_kwargs = dict(
            y=delta_mag,
            yerr=magerr[idx],
            fmt="o",
            color=COLORS[b],
            label=f'{b} $- {mean_band_mag:.2f}$',
            alpha=0.3,
        )
        ax_mjd.errorbar(cand.lc["helioMjd"][idx], **errorbar_kwargs)
        ax_phase.errorbar(phase[idx], **errorbar_kwargs)
    # The line break lives in the non-raw part of the title; inside the raw
    # f-string (needed for the "\," math spacing) "\n" would be printed
    # literally.
    fig.suptitle(
        f"OID: {cand.objectId}, RA: {cand['coord_ra']:.5f}, Dec: {cand['coord_dec']:.5f}\n"
        rf"Period: {cand['period_0']:.5f}$\,$d, L—S lg(F-P): {np.log10(cand['period_0_false_alarm_prob']):.1f}"
    )
    ax_mjd.set_ylabel("mag - mean(mag)")

    ax_mjd.set_xlabel("MJD")
    ax_mjd.set_xlim(np.min(cand.lc["helioMjd"])-1, np.max(cand.lc["helioMjd"])+1)
    ax_phase.set_xlabel("Phase")
    ax_phase.set_xlim(0, 1)

    # Clip the (shared) y-range to the central 98% of the points and flip it
    # so that brighter (smaller magnitude) is up.
    if all_delta_mag:
        y_lim_min, y_lim_max = np.quantile(np.concatenate(all_delta_mag), [0.01, 0.99])
        ax_mjd.set_ylim(y_lim_min, y_lim_max)
    ax_mjd.invert_yaxis()
    ax_mjd.plot(ax_mjd.get_xlim(), [0, 0], color='k', linestyle='--', alpha=0.5)
    ax_phase.plot(ax_phase.get_xlim(), [0, 0], color='k', linestyle='--', alpha=0.5)
    ax_mjd.legend(loc='upper left')
    ax_mjd.grid()
    ax_phase.grid()

    fig.savefig(f"periodic_cand-{release}-{cand.objectId}.pdf")

    print(cand.objectId)