In [6]:
import os

import numpy as np
import pandas as pd
#import polars as pl
import matplotlib.pyplot as plt

import astropy

import scipy
 
import torch
from torch import nn
from torch.utils.data import TensorDataset, Dataset, DataLoader
import lightning as L

from tqdm.auto import tqdm

import lcdata

In [16]:
# Raw ALeRCE transient light curves, one row per detection.
# The path can be overridden with the LIGHTCURVES_PATH environment variable so
# the notebook runs outside the original author's machine; the default keeps
# the original location.
# NOTE(review): pd.read_pickle can execute arbitrary code while loading — only
# point this at trusted files.
LIGHTCURVES_PATH = os.environ.get(
    'LIGHTCURVES_PATH',
    '/home/jurados/Supernovae_DeepLearning/data/lightcurves/lcs_transients_20240517.pkl',
)
lightcurves_alercextns = pd.read_pickle(LIGHTCURVES_PATH)

In [17]:
# Later cells select objects by `oid` and read `mjd`, `magpsf` and `sigmapsf`;
# fail fast here if the pickle's schema has drifted.
assert {'oid', 'mjd', 'magpsf', 'sigmapsf'}.issubset(lightcurves_alercextns.columns), \
    "missing expected light-curve columns"
lightcurves_alercextns.columns

Index(['oid', 'candid', 'rfid', 'mjd', 'fid', 'magpsf', 'sigmapsf'], dtype='object')

In [None]:
def process_light_curve_parsnip(ligth_curve, time_window=300, time_pad=100):
    """Align one light curve to a ParSNIP-style sidereal time grid.

    Observation times are converted to sidereal days and the nightly phase
    offset is estimated. The reference time is then placed at the (rounded)
    median time of the five highest signal-to-noise observations. Each
    observation gets a position on a grid of ``time_window`` bins whose
    centre is that reference time. Observations whose grid index falls
    outside ``[-time_pad, time_window + time_pad)`` are dropped.

    Parameters
    ----------
    ligth_curve : pandas.DataFrame
        Observations of a single object. Must contain ``mjd`` (observation
        time in days) and ``sigmapsf`` (magnitude uncertainty) columns.
        The parameter name keeps its original spelling so keyword callers
        keep working.
    time_window : int, default 300
        Number of grid bins; the reference time maps to index
        ``time_window // 2``.
    time_pad : int, default 100
        Extra grid bins kept on each side of the window.

    Returns
    -------
    pandas.DataFrame
        A filtered copy of the input with two added columns:
        ``grid_time`` (sidereal days relative to the reference time) and
        ``time_index`` (integer grid position). The input is not modified.
    """
    light_curve = ligth_curve
    new_light_curve = light_curve.copy()

    # scipy.stats.mode / np.median are ill-defined on empty input; return an
    # empty copy that still carries the output columns.
    if len(light_curve) == 0:
        new_light_curve['grid_time'] = np.array([], dtype=float)
        new_light_curve['time_index'] = np.array([], dtype=int)
        return new_light_curve

    # Ratio of a solar day to a sidereal day.
    SIDEREAL_SCALE = 86400. / 86164.0905

    time = light_curve['mjd'].to_numpy(dtype=float)
    sidereal_time = time * SIDEREAL_SCALE

    # Initial guess of the phase. Round everything to 0.1 days, and find the
    # decimal that has the largest count.
    phase_mode = scipy.stats.mode(
        np.round(sidereal_time % 1 + 0.05, 1), keepdims=True
    ).mode
    guess_offset = phase_mode[0] - 0.05

    # Shift everything by the guessed offset.
    guess_shift_time = sidereal_time - guess_offset

    # Do a proper estimate of the offset: median phase around the guess.
    sidereal_offset = guess_offset + np.median((guess_shift_time + 0.5) % 1) - 0.5

    # Shift everything by the final offset estimate.
    shift_time = sidereal_time - sidereal_offset

    # Select the five highest signal-to-noise observations. For magnitudes,
    # S/N ~= (2.5 / ln 10) / sigma_mag, so it depends only on `sigmapsf`
    # (`magpsf / sigmapsf` is not a S/N and would favour fainter points).
    s2n = (2.5 / np.log(10)) / light_curve['sigmapsf'].to_numpy(dtype=float)
    # Missing uncertainties must never be picked as the "best" points.
    s2n = np.where(np.isnan(s2n), -np.inf, s2n)
    s2n_mask = np.argsort(s2n, kind='stable')[-5:]

    cut_times = shift_time[s2n_mask]

    max_time = np.round(np.median(cut_times))

    # Convert back to a reference time in the original units. This reference
    # time corresponds to the reference of the grid in sidereal time.
    reference_time = (max_time + sidereal_offset) / SIDEREAL_SCALE
    grid_times = (time - reference_time) * SIDEREAL_SCALE
    time_indices = np.round(grid_times).astype(int) + time_window // 2
    time_mask = (
        (time_indices >= -time_pad)
        & (time_indices < time_window + time_pad)
    )

    new_light_curve['grid_time'] = grid_times
    new_light_curve['time_index'] = time_indices
    return new_light_curve[time_mask]

In [None]:
def plot_light_curve(light_curve, ax=None):
    """Plot apparent magnitude against MJD, one marker series per filter.

    Parameters
    ----------
    light_curve : pandas.DataFrame
        Observations with ``mjd`` and ``magpsf`` columns. When present,
        ``fid`` splits the points into one labelled series per value, and
        ``sigmapsf`` is drawn as vertical error bars.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. A new figure is created when omitted.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on, so callers can adjust, save or close it.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Keep filters visually separate; mixing them in one series hides colour
    # differences between bands. dropna=False keeps rows with a missing fid.
    if 'fid' in light_curve.columns:
        groups = light_curve.groupby('fid', sort=True, dropna=False)
    else:
        groups = [(None, light_curve)]

    for fid, obs in groups:
        yerr = obs['sigmapsf'].to_numpy() if 'sigmapsf' in obs.columns else None
        ax.errorbar(
            obs['mjd'].to_numpy(),
            obs['magpsf'].to_numpy(),
            yerr=yerr,
            fmt='o',
            label=None if fid is None else f'fid={fid}',
        )

    ax.set_xlabel('MJD')
    ax.set_ylabel('Apparent magnitude')
    # Smaller magnitudes are brighter, so flip the y-axis. Only flip once so a
    # caller-supplied, already-inverted axes is not toggled back.
    if not ax.yaxis_inverted():
        ax.invert_yaxis()

    if 'oid' in light_curve.columns:
        oids = light_curve['oid'].unique()
        if len(oids) == 1:
            ax.set_title(str(oids[0]))

    # Avoid matplotlib's "no artists with labels" warning on empty input.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    return ax

In [18]:
# Pick out every detection of the first object in the table; the first value
# of the `oid` column is the same as the first entry of `oid.unique()`.
one_light_curve = lightcurves_alercextns.loc[
    lightcurves_alercextns['oid'].eq(lightcurves_alercextns['oid'].iloc[0])
]
one_light_curve

Unnamed: 0,oid,candid,rfid,mjd,fid,magpsf,sigmapsf
0,ZTF19abgpgyp,1515523874715015006,681120247.0,59269.523877,2,20.385720,0.211641
1,ZTF19abgpgyp,1499511354715015014,,59253.511354,2,20.678400,0.258971
2,ZTF19abgpgyp,1502430264715015014,,59256.430266,2,20.652500,0.269586
3,ZTF19abgpgyp,1510536184715015009,,59264.536181,2,20.663600,0.299923
521734,ZTF19abgpgyp,1376156074715015007,,59130.156076,1,20.678900,0.251519
...,...,...,...,...,...,...,...
768148,ZTF19abgpgyp,930177034715015039,681120247.0,58684.177037,2,18.858166,0.088371
771630,ZTF19abgpgyp,1255363184715015007,,59009.363183,2,19.853200,0.246000
776366,ZTF19abgpgyp,980143654715015015,681120247.0,58734.143657,2,18.257755,0.069674
776367,ZTF19abgpgyp,1031115464715015011,681120147.0,58785.115463,1,19.355577,0.159911
