In [None]:
import logging
import os

import lightkurve as lk
import numpy as np
import polars as pl
import unpopular
from astrocut import CutoutFactory
from astropy.coordinates import SkyCoord
from dask.distributed import wait, Client, LocalCluster
from scipy.signal import find_peaks

In [2]:
# Target table: one row per (star, sector) with the FFI location needed for the cutout.
targets = (
    pl.read_parquet("./data/targets.parquet")
    .select(["ID", "ra", "dec", "sector", "camera", "ccd"])
)

In [4]:
def _count_harmonics(
    lc: lk.LightCurve,
    height: float = 0.15,
    max_harmonic: int = 8,
    tolerance: float = 0.1,
    distance: int = 120,
    max_period: float = 2,
) -> list[tuple[float, float]]:
    """Find the harmonics in the L-S periodogram of a given lightcurve.

    The fundamental period ``P`` is the period of maximum periodogram power.
    For each expected harmonic period ``P / n`` (``n = 1 .. max_harmonic``),
    the closest detected peak lying within ``±tolerance`` (fractional) of it
    is selected. Each peak is matched to at most one harmonic, because the
    windows of neighbouring high harmonics overlap.

    Args:
        lc (lk.LightCurve): The lightcurve to analyse.
        height (float, optional): The minimum height of a peak as a fraction of
            the maximum periodogram power. Defaults to 0.15.
        max_harmonic (int, optional): Highest harmonic number ``n`` searched
            for. Defaults to 8.
        tolerance (float, optional): Fractional half-width of the window around
            each expected harmonic period. Defaults to 0.1 (i.e. 0.9x-1.1x).
        distance (int, optional): Minimum separation, in periodogram samples,
            between detected peaks (passed to ``scipy.signal.find_peaks``).
            Defaults to 120.
        max_period (float, optional): If the fundamental period (in the
            periodogram's period units, presumably days — confirm) is at least
            this value, no harmonics are reported. Defaults to 2.

    Returns:
        list[tuple[float, float]]: Each matched harmonic as ``(period, power)``,
        ordered from the fundamental (``n = 1``) upward. Harmonics with no
        matching peak are omitted.
    """
    pg = lc.to_periodogram()
    fundamental = pg.period_at_max_power.value

    # Long-period signals are not considered for harmonic counting.
    if fundamental >= max_period:
        return []

    peaks, properties = find_peaks(
        pg.power.value, distance=distance, height=pg.max_power.value * height
    )
    peak_periods = pg.period.value[peaks]
    peak_heights = properties["peak_heights"]

    found_harmonics = []
    used_peaks = set()
    for n in range(1, max_harmonic + 1):
        expected = fundamental / n
        lower, upper = (1 - tolerance) * expected, (1 + tolerance) * expected
        candidates = [
            idx
            for idx, peak_period in enumerate(peak_periods)
            if idx not in used_peaks and lower <= peak_period <= upper
        ]
        if not candidates:
            continue
        # Prefer the peak nearest the expected harmonic period.
        best = min(candidates, key=lambda idx: abs(peak_periods[idx] - expected))
        used_peaks.add(best)
        found_harmonics.append((float(peak_periods[best]), float(peak_heights[best])))

    return found_harmonics

def is_complex(lc: lk.LightCurve, min_harmonics: int = 3) -> bool:
    """Check if a given lightcurve is complex by counting the number of harmonics.

    Args:
        lc (lk.LightCurve): The lightcurve to classify.
        min_harmonics (int, optional): Minimum number of harmonics that
            ``_count_harmonics`` (with its default settings) must find for the
            lightcurve to be classified as complex. Defaults to 3.

    Returns:
        bool: ``True`` if at least ``min_harmonics`` harmonics were found.
    """
    return len(_count_harmonics(lc)) >= min_harmonics

In [5]:
def make_lightcurve(tic, coords, sector, camera, ccd):
    """Build a CPM-detrended TESS lightcurve for one target from an FFI cube cutout.

    A 50x50 cutout is cut from the public TESS FFI cube on S3 for the given
    sector/camera/ccd. The cutout is modelled with ``unpopular``'s CPM, and the
    CPM-subtracted flux of a fixed aperture is returned as a lightcurve. The
    temporary cutout file is always deleted, even if modelling fails.

    Args:
        tic: Target identifier. Currently unused; kept for call-site compatibility.
        coords (SkyCoord): Sky position of the target.
        sector: TESS sector number (zero-padded to 4 digits in the cube name).
        camera: TESS camera number.
        ccd: TESS CCD number.

    Returns:
        lk.TessLightCurve: Lightcurve with ``s.time`` as time and the aperture's
        CPM-subtracted flux as flux.
    """
    cube_cutter = CutoutFactory()

    cube_file = f"s3://stpubdata/tess/public/mast/tess-s{str(sector).zfill(4)}-{camera}-{ccd}-cube.fits"
    # cube_cut returns the path of the cutout file it wrote; we own its cleanup.
    cutout = cube_cutter.cube_cut(cube_file, coordinates=coords, cutout_size=50, verbose=True, threads="auto")

    try:
        s = unpopular.Source(cutout, remove_bad=True)
        # NOTE(review): aperture limits are hard-coded for cutout_size=50 above
        # (presumably the target's pixel near the centre) — keep the two in sync.
        s.set_aperture(rowlims=[25, 26], collims=[25, 26])

        s.add_cpm_model(exclusion_size=5, n=64, predictor_method="similar_brightness")
        s.set_regs([0.1])
        s.holdout_fit_predict(k=100)

        apt_detrended_flux = s.get_aperture_lc(data_type="cpm_subtracted_flux")
        return lk.TessLightCurve(time=s.time, flux=apt_detrended_flux)
    finally:
        # Delete the cutout even when CPM modelling raises, so failed targets
        # don't accumulate cutout files in the working directory.
        os.remove(cutout)

In [6]:
def process_target(target):
    """Build the lightcurve for one target and classify it as complex or not.

    Args:
        target (dict): A row of ``targets`` with ``ID``, ``ra``, ``dec``
            (degrees), ``sector``, ``camera`` and ``ccd`` keys.

    Returns:
        dict: A copy of ``target`` with an added ``"result"`` key holding the
        boolean from ``is_complex``, or ``None`` if classification raised.
        Errors from ``make_lightcurve`` are not caught and propagate to the caller.
    """
    coords = SkyCoord(target["ra"], target["dec"], frame="icrs", unit="deg")
    lc = make_lightcurve(target["ID"], coords, target["sector"], target["camera"], target["ccd"])

    try:
        result = is_complex(lc)
    except Exception:
        # Best-effort: a failed classification is recorded as None rather than
        # failing the task, but the cause is logged instead of silently dropped.
        # Catching Exception (not bare except) lets KeyboardInterrupt/SystemExit through.
        logging.getLogger(__name__).warning(
            "is_complex failed for target %s", target["ID"], exc_info=True
        )
        result = None

    return {**target, "result": result}

In [None]:
# Local Dask cluster: one single-threaded worker process per CPU reported by os.cpu_count().
# NOTE(review): make_lightcurve calls cube_cut with threads="auto", which may
# spawn extra threads inside each worker — confirm this doesn't oversubscribe.
num_workers = os.cpu_count()
cluster = LocalCluster(
    processes=True,
    n_workers=num_workers,
    threads_per_worker=1,
)
client = Client(cluster)



In [None]:
# Submit the first 1000 target rows (as dicts) and block until every task completes.
futures = client.map(process_target, targets.head(1000).to_dicts())
wait(futures)