In [None]:
# Core scientific stack, plotting and units
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import astropy.units as u
# Render figures inline at retina resolution
%matplotlib inline
%config InlineBackend.figure_format = "retina"
from matplotlib import rcParams
# Global figure defaults: high-DPI saved figures, large serif fonts
rcParams['savefig.dpi'] = 550
rcParams['font.size'] = 20
plt.rc('font', family='serif')
from tqdm import tqdm
import lsdb
import dask
# Dask scratch directory for temporary files, and the codec used to compress shuffles.
# NOTE(review): machine-specific absolute path; consider a configurable constant.
dask.config.set({"temporary-directory" :'/epyc/ssd/users/atzanida/tmp'})
dask.config.set({"dataframe.shuffle-compression": 'Snappy'})

In [None]:
# Local Dask cluster with 15 workers; memory_limit="auto" leaves the per-worker limit to Dask
from dask.distributed import Client
client = Client(n_workers=15, memory_limit="auto")

In [None]:
# Display the client object
client

In [None]:
%%time
# Open the StarHorse FGK sample stored as a HiPSCat catalog
# NOTE(review): machine-specific absolute path
fgk_object = lsdb.read_hipscat("/nvme/users/atzanida/tmp/sample_final_starhorse_hips")

In [None]:
# Keep only sources within 1000 arcsec of (ra, dec) = (131.312, 14.315),
# then materialize the result with .compute()
fgk_table = fgk_object.cone_search(ra=131.312,
    dec=14.315,
    radius_arcsec=1_000)
fgk_table = fgk_table.compute()

In [None]:
# Keep only the StarHorse RA/Dec columns and the PS1 object id used to join with ZTF
new_fgk = fgk_table[['RA_ICRS_StarHorse', 'DE_ICRS_StarHorse', 'ps1_objid_ztf_dr14']]

In [None]:
new_fgk.dtypes

In [None]:
len(new_fgk)

In [None]:
new_fgk.head(1)

In [None]:
%%time 
# Re-partition the reduced table into a new HiPSCat catalog
# (threshold / lowest_order / highest_order control the partitioning; see lsdb.from_dataframe)
hips_fgk_object = lsdb.from_dataframe(new_fgk,
                                ra_column="RA_ICRS_StarHorse", 
                                dec_column="DE_ICRS_StarHorse", 
                                threshold=5_000,
                                      lowest_order=5,
                                      highest_order=8)

In [None]:
# Load ZTF DR14 sources
ztf_sources = lsdb.read_hipscat("/epyc/data3/hipscat/catalogs/ztf_axs/ztf_zource")

In [None]:
%%time
# Attach the ZTF source rows to each FGK object by matching PS1 object ids
_sources = hips_fgk_object.join(
    ztf_sources, left_on="ps1_objid_ztf_dr14", right_on="ps1_objid")

## Initialize TAPE

In [None]:
import dask.dataframe as dd
from tape import Ensemble, ColumnMapper

# Initialize an Ensemble that reuses the Dask client created above
ens = Ensemble(client=client)

In [None]:
# Check which client the Ensemble is using
ens.client

In [None]:
# ColumnMapper establishes which table columns map to timeseries quantities
# NOTE(review): the batch call further down reads 'mjd_ztf_zource', 'mag_ztf_zource', ...;
# confirm that 'mjd', 'mag', 'magerr' and 'band' actually exist in the joined source frame.
colmap = ColumnMapper(
        id_col='_hipscat_index',
        time_col='mjd',
        flux_col='mag',
        err_col='magerr',
        band_col='band',
      )

# Load the joined ZTF sources and the FGK objects from their underlying Dask frames
ens.from_dask_dataframe(
    source_frame=_sources._ddf,
    object_frame=hips_fgk_object._ddf,
    column_mapper=colmap,
    sync_tables=False, # Avoid doing an initial sync
    sorted=True, # If the input data is already sorted by the chosen index
    sort=False,
)

## Custom Time-Series Functions for TAPE

In [2]:
from astropy import stats as astro_stats
import astropy.units as u
from astropy.modeling.models import Gaussian1D
from astropy.modeling import fitting
from scipy import stats
from scipy.optimize import curve_fit
from scipy import interpolate

from astropy.io import ascii
from scipy.signal import find_peaks
import scipy.integrate as integrate
from scipy.interpolate import interp1d
from scipy.signal import savgol_filter


def prepare_lc(time, mag, mag_err, flag, band, band_of_study='r', flag_good=0, q=None, custom_q=False):
    """
    Prepare the light curve for analysis - specifically for the ZTF data.

    Keeps only good, finite detections in one band and sorts them by time.
    It then drops duplicate timestamps (keeping the first) and removes any
    observation taken less than 0.5 day after the previous one.

    Parameters:
    -----------
    time (array-like): Input time values.
    mag (array-like): Input magnitude values.
    mag_err (array-like): Input magnitude error values.
    flag (array-like): Input flag values.
    band (array-like): Input band values.
    band_of_study (str): Band to study. Default is 'r' band
    flag_good (int): Flag value for good detections. Default is 0 (see ZTF documentation)
    q, custom_q: Unused; kept only so existing callers keep working.

    Returns:
    --------
    time (np.ndarray): Output time values, sorted ascending.
    mag (np.ndarray): Output magnitude values.
    mag_err (np.ndarray): Output magnitude error values.
    """
    # Work on plain NumPy arrays so masks and argsort() results index by
    # position (a pandas Series would index by label after the first mask).
    time, mag, mag_err = np.asarray(time), np.asarray(mag), np.asarray(mag_err)
    flag, band = np.asarray(flag), np.asarray(band)

    # Good detections in the requested band, with positive errors and no NaNs
    keep = ((flag == flag_good) & (mag_err > 0) & (band == band_of_study)
            & ~np.isnan(time) & ~np.isnan(mag) & ~np.isnan(mag_err))
    time, mag, mag_err = time[keep], mag[keep], mag_err[keep]

    # Sort chronologically
    order = np.argsort(time)
    time, mag, mag_err = time[order], mag[order], mag_err[order]

    # Drop duplicate timestamps, always keeping the first sample. Comparing
    # against np.roll() wrapped around and discarded the only point of a
    # single-point light curve (and every point if all times were equal).
    unique = np.ones(time.shape, dtype=bool)
    unique[1:] = np.abs(np.diff(time)) > 1e-5
    time, mag, mag_err = time[unique], mag[unique], mag_err[unique]

    # Remove observations that are <0.5 day after the previous one
    too_close = np.where(np.diff(time) < 0.5)[0] + 1
    time, mag, mag_err = np.delete(time, too_close), np.delete(mag, too_close), np.delete(mag_err, too_close)

    return time, mag, mag_err


def best_peak_detector(peak_dictionary, min_in_dip=1):
    """Tabulate the dips found by ``peak_detector``, one row per dip.

    Parameters:
    -----------
    peak_dictionary (tuple): ``(N_peaks, dip_summary)`` as returned by
        ``peak_detector``, where ``dip_summary`` maps ``'dip_<i>'`` to a dict of
        dip properties. A "nothing found" result such as ``(0, 0)`` or
        ``(0, {})`` is accepted.
    min_in_dip (int): Currently unused (no detection-count cut is applied);
        kept for backward compatibility. Default is 1.

    Returns:
    --------
    pd.DataFrame: Float table with one row per dip and the columns listed in
    ``columns`` below; empty (zero rows) when no dips were found.
    """
    columns = ['peak_loc', 'window_start', 'window_end', 'N_1sig_in_dip', 'N_in_dip',
               'loc_forward_dur', 'loc_backward_dur', 'dip_power', 'average_dt_dif']

    # unpack the (count, summary) pair
    N_peaks, dict_summary = peak_dictionary

    # peak_detector may signal "no peaks" with a non-dict placeholder
    if not N_peaks or not isinstance(dict_summary, dict):
        return pd.DataFrame(np.zeros((0, len(columns))), columns=columns)

    # Look each value up by name so the table does not depend on the
    # insertion order of the per-dip dicts.
    rows = [[float(event[col]) for col in columns] for event in dict_summary.values()]
    summary_matrix = np.array(rows, dtype=float).reshape(-1, len(columns))

    return pd.DataFrame(summary_matrix, columns=columns)
    
def deviation(mag, mag_err, R, S):
    """Per-point deviation from the global light-curve level.

    Each point's offset from the biweight location ``R`` is divided by the
    quadrature sum of its own error and the biweight scale ``S``::

        dev = (mag - R) / sqrt(mag_err**2 + S**2)

    d >> 0 indicates dimming; d << 0 (negative) indicates brightening.

    Parameters:
    -----------
    mag (array-like): Magnitude values of the light curve.
    mag_err (array-like): Magnitude errors of the light curve.
    R (float): Biweight location of the light curve (global).
    S (float): Biweight scale of the light curve (global).

    Returns:
    --------
    dev (array-like): Deviation values of the light curve.
    """
    offset = mag - R
    combined_sigma = np.sqrt(mag_err**2 + S**2)
    return offset / combined_sigma


def gaus(x, a, x0, sigma, ofs):
    """Gaussian profile plus a constant offset.

    Returns ``a * exp(-(x - x0)**2 / (2 * sigma**2)) + ofs``.
    """
    return a*np.exp(-(x-x0)**2/(2*sigma**2)) + ofs

def auto_fit(x, y, loc, base, return_model=False):
    """Fit ``gaus`` to (x, y) with loose priors around ``loc`` and ``base``.

    Parameters:
    -----------
    x, y (array-like): Data to fit.
    loc (float): Initial guess for the centre; the fit is bounded to loc +/- 5.
    base (float): Initial guess for the offset; the fit is bounded to base +/- 2.
    return_model (bool): If True, return the model evaluated at ``x`` instead
        of the fitted parameters.

    Returns:
    --------
    Best-fit ``[a, x0, sigma, ofs]`` (zeros if the fit fails). When
    ``return_model`` is True, ``gaus(x, *popt)`` is returned instead, or an
    all-zero array shaped like ``x`` if the fit fails.
    """
    try:
        popt, _ = curve_fit(gaus,
                            x,
                            y,
                            p0=[1, loc, 1, base],
                            bounds=((0.1, loc-5, 0.1, base-2),
                                    (np.inf, loc+5, np.inf, base+2)))
    except Exception:
        # Best effort: a failed fit (no convergence, NaNs, bad bounds) yields
        # zeros. Catch Exception (not a bare except) so Ctrl-C still works.
        popt = np.zeros(4)
        if return_model:
            # gaus(x, 0, 0, 0, 0) would divide by sigma=0 and give NaN at x == 0
            return np.zeros(np.shape(x))

    if return_model:
        return gaus(x, *popt)
    return popt

def fwhm_calc(pop):
    """Full width at half maximum of a fitted ``gaus`` profile.

    ``pop`` is ``[a, x0, sigma, ofs]``; 2.355 approximates 2*sqrt(2*ln 2).
    """
    return 2.355 * pop[2] # fwhm 

def calc_sum_score(xdat, ydat, peak_dict, base, rms):
    """Score a light curve from its dips, normalised by ``rms``.

    For each dip, a Gaussian is fitted with ``auto_fit`` around the dip's
    ``peak_loc``. The fitted FWHM is multiplied by the dip's ``dip_power`` and
    ``N_1sig_in_dip``. These terms are summed, then divided by the number of
    dips and by ``rms``.

    Parameters:
    -----------
    xdat, ydat (array-like): Data passed to ``auto_fit``.
    peak_dict (tuple): ``(N_peaks, dip_summary)`` as returned by ``peak_detector``.
    base (float): Offset guess forwarded to ``auto_fit``.
    rms (float): Normalisation of the score.

    Returns:
    --------
    float: The score, or 0.0 when ``peak_dict`` contains no dips.
    """
    n_dips = peak_dict[0]
    if n_dips == 0:
        # Nothing to average; previously this raised ZeroDivisionError
        return 0.0

    score_term = 0
    for i in range(n_dips):
        event = peak_dict[1][f'dip_{i}']
        loc = event['peak_loc']
        powr = event['dip_power']
        Ndet = event['N_1sig_in_dip']

        fit_terms = auto_fit(xdat, ydat,
                             loc, base, return_model=False)
        fwhm = fwhm_calc(fit_terms)

        score_term += fwhm * powr * Ndet

    return (1/n_dips) * (1/rms) * score_term


def detect_bursts_edges(time, mag, center_time, baseline_mean, baseline_std, burst_threshold=3.0, expansion_indices=1):
    """
    Find the time window around a burst/dip by walking outward from its centre.

    Starting at the sample at ``center_time``, the window edge is stepped
    backwards (and, separately, forwards) until the first sample whose value
    is below ``baseline_mean + burst_threshold * baseline_std``. The window is
    then widened by ``expansion_indices`` samples on each side, clipped to the
    array bounds.

    Parameters:
    -----------
    time (array-like): Time values of the light curve, sorted ascending.
    mag (array-like): Values to threshold (peak_detector passes the running deviation).
    center_time (float): Center time of the burst.
    baseline_mean (float): Mean of the baseline.
    baseline_std (float): Standard deviation of the baseline.
    burst_threshold (float): Threshold, in units of baseline_std, that ends the window. Default is 3.0.
    expansion_indices (int): Number of indices to expand the burst region. Default is 1.

    Returns:
    --------
    tuple of 7 values:
        t_start (float): Start time of the window.
        t_end (float): End time of the window.
        abs(t_start - center_time) (float): Distance of the start edge from the centre.
        abs(t_end - center_time) (float): Distance of the end edge from the centre.
        N_thresh (int): Number of samples strictly inside (t_start, t_end) whose
            value exceeds baseline_mean + 2 * baseline_std.
        0, 0: Placeholders, kept so callers indexing the tuple keep working.
    """
    time = np.asarray(time)
    mag = np.asarray(mag)
    level = baseline_mean + burst_threshold * baseline_std

    # Initialize burst_start and burst_end at the centre sample
    burst_start = burst_end = np.searchsorted(time, center_time)

    # Find burst start: walk backwards to the first sample below the threshold
    while burst_start > 0:
        burst_start -= 1
        if mag[burst_start] < level:
            break

    # Find burst end: walk forwards to the first sample below the threshold
    while burst_end < len(time) - 1:
        burst_end += 1
        if mag[burst_end] < level:
            break

    # Expand the burst region on both sides, staying inside the array
    burst_start = max(0, burst_start - expansion_indices)
    burst_end = min(len(time) - 1, burst_end + expansion_indices)

    # Final start and end
    t_start, t_end = time[burst_start], time[burst_end]

    # How many detections lie more than 2 std above the mean? Count the True
    # entries of the mask; len(mask) counted every sample in the window.
    in_window = (time > t_start) & (time < t_end)
    N_thresh_1 = int(np.count_nonzero(mag[in_window] > baseline_mean + 2*baseline_std))

    return t_start, t_end, abs(t_start-center_time), abs(t_end-center_time), N_thresh_1, 0, 0

def peak_detector(times, dips, power_thresh=3, peak_close_rmv=15, pk_2_pk_cut=30):
    """
    Run and compute dip detection algorithm on a light curve.

    Parameters:
    -----------
    times (array-like): Time values of the light curve.
    dips (array-like): Deviation values of the light curve.
    power_thresh (float): Minimum peak height passed to find_peaks. Default is 3.
    peak_close_rmv (float): Currently unused. Default is 15.
    pk_2_pk_cut (int): Minimum peak-to-peak separation in number of samples
        (find_peaks ``distance``), not in days. Default is 30.

    Returns:
    --------
    N_peaks (int): Number of peaks detected.
    dip_summary (dict): Per-dip summary keyed 'dip_<i>'. Each entry holds the
        peak location, the window start and end, the detection counts, the
        edge distances from the peak, and the dip power. The dict is empty
        when no peak is found or detection fails.
    """
    try:
        times = np.asarray(times)
        dips = np.asarray(dips)
        if len(dips) == 0:
            return 0, {}

        #TODO: add smoothing savgol_filter again...
        yht = dips

        # Scipy peak finding algorithm
        pks, _ = find_peaks(yht, height=power_thresh, distance=pk_2_pk_cut)

        # Order peak indices from highest to lowest (latest first for
        # time-sorted input). Note: this orders by index, not by peak height.
        pks = np.sort(pks)[::-1]

        # Time of peaks and dev of peaks
        t_pks, p_pks = times[pks], dips[pks]

        # Number of peaks
        N_peaks = len(t_pks)

        # Baseline statistics are identical for every peak; compute them once
        dip_mean, dip_std = np.nanmean(dips), np.nanstd(dips)

        dip_summary = {}
        for i, (time_ppk, ppk) in enumerate(zip(t_pks, p_pks)):
            _edges = detect_bursts_edges(times, dips, time_ppk, dip_mean, dip_std,
                                         burst_threshold=3.0, expansion_indices=1)
            # _edges = (t_start, t_end, |t_start - t_pk|, |t_end - t_pk|, N_thresh, 0, 0)
            # NOTE(review): 'loc_forward_dur' holds the distance to the window
            # START and 'loc_backward_dur' the distance to the END; confirm intent.
            dip_summary[f'dip_{i}'] = {
                "peak_loc": time_ppk,
                'window_start': _edges[0],
                'window_end': _edges[1],
                "N_1sig_in_dip": _edges[-3], # number of 1 sigma detections in the dip
                # NOTE(review): same value as N_1sig_in_dip; no separate total count is returned
                "N_in_dip": _edges[-3],
                'loc_forward_dur': _edges[2],
                "loc_backward_dur": _edges[3],
                "dip_power": ppk,
                "average_dt_dif": _edges[-1], # placeholder (always 0 from detect_bursts_edges)
            }

        return N_peaks, dip_summary
    except Exception:
        # Best effort: one malformed light curve must not abort a whole batch
        return 0, {}

def eval_prelim(time_cat, mag_cat, mag_err_cat, flag_cat, band_cat):
    """Count the significant dimming peaks in an object's ZTF r-band light curve.

    The r- and g-band light curves are cleaned with ``prepare_lc``. The object
    is only scored when both bands keep more than 10 points. The r-band
    magnitudes are then turned into a running deviation around their biweight
    location/scale and passed to ``peak_detector``.

    Parameters:
    -----------
    time_cat, mag_cat, mag_err_cat, flag_cat, band_cat (array-like): Raw source
        columns for one object (time, magnitude, error, quality flag, band).

    Returns:
    --------
    int: Number of peaks found. 0 when either band has too few points or no
        peak passes the threshold (this case previously returned None).
    """
    # Digest my light curve. Select band, good detections & sort
    time, mag, mag_err = prepare_lc(time_cat, mag_cat, mag_err_cat, flag_cat, band_cat,
                                    band_of_study='r', flag_good=0, q=None, custom_q=False)

    # The g band is only used to require coverage in both bands
    time_g, _, _ = prepare_lc(time_cat, mag_cat, mag_err_cat, flag_cat, band_cat,
                              band_of_study='g', flag_good=0, q=None, custom_q=False)

    if len(time) <= 10 or len(time_g) <= 10:
        return 0

    # Evaluate biweight location and scale of the r-band magnitudes
    R = astro_stats.biweight.biweight_location(mag)
    S = astro_stats.biweight.biweight_scale(mag)

    # Running deviation
    running_deviation = deviation(mag, mag_err, R, S)

    # Peak detection summary per light curve
    N_peaks, _ = peak_detector(time, running_deviation, power_thresh=4, peak_close_rmv=20, pk_2_pk_cut=20)

    return N_peaks

In [None]:
%%time
# Run eval_prelim through ens.batch; the listed source columns are passed as its
# positional arguments in this order (time, mag, mag_err, flag, band)
batch_calc = ens.batch(
    eval_prelim,
    'mjd_ztf_zource', 'mag_ztf_zource', 
    'magerr_ztf_zource', 'catflags_ztf_zource',
    'band_ztf_zource')

In [None]:
%%time
# Trigger the computation of the batch result
full_comp = batch_calc.compute()