In [20]:

# ELAsTiCC Dataset Light Curves Processing and Visualization - Object Oriented Version

import os
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from astropy.io import fits
import matplotlib.pyplot as plt
from astropy import units as u
from astropy.coordinates import SkyCoord
import copy
import random
import warnings

# Folder layout, relative to the notebook's working directory.
# Both paths keep a trailing "/" because the example usage builds file
# paths from them by simple string concatenation.
DATA_FOLDER = "../DataSets/"                             # input datasets
VISUALIZATION_FOLDER = "../LightCurves_Visualization/"   # saved light curves / plots


class LightCurve:
    """Single-band light curve holding MJD, flux and magnitude arrays.

    Instances are built from in-memory arrays, from a three-column text file
    (MJD, mag, mag_err), or from an SNANA-style FITS file whose HDU 1 is the
    header table and HDU 2 the photometry table.
    """

    def __init__(self, idx, mjd=None, flux=None, flux_err=None, mag=None, mag_err=None,
                 source="txt", select="snid", file_path=None, zp=27.5, band=None):
        """Create a light curve.

        Args:
            idx: Object selector for the SNANA loader -- an SNID when
                ``select="snid"``, a row number when ``select="idx"``; unused
                otherwise.
            mjd, flux, flux_err, mag, mag_err: In-memory data, used when no
                ``file_path`` is given. Magnitudes / magnitude errors that are
                not supplied are derived from ``flux`` / ``flux_err``.
            source: ``"txt"`` or ``"snana"``; which loader reads ``file_path``.
            select: ``"snid"``, ``"idx"`` or ``"random"`` (SNANA loader only).
            file_path: File to load; when falsy the array arguments are used.
            zp: Zero point for flux <-> magnitude conversion.
            band: Passband name (used for labels); required for ``"snana"``.

        Raises:
            ValueError: ``file_path`` given with an unknown ``source``, or
                ``source="snana"`` without a ``band``.
        """
        self.band = band
        self.zp = zp
        if file_path and source == "txt":
            self.load_from_txt(file_path, zp=zp)
        elif file_path and source == "snana":
            if band is None:
                raise ValueError("`band` is required when source='snana'")
            self.load_from_snana(file_path, band, select=select, idx=idx, zp=zp)
        elif file_path:
            raise ValueError(f"Unknown source {source!r}; expected 'txt' or 'snana'")
        else:
            self.mjd = self._as_array(mjd)
            self.flux = self._as_array(flux)
            self.flux_err = self._as_array(flux_err)
            mag, mag_err = self._as_array(mag), self._as_array(mag_err)
            # Derive only what was not supplied explicitly.
            if self.flux is not None and (mag is None or mag_err is None):
                flux_mag, flux_mag_err = self.flux_to_mag(self.flux, self.flux_err, zp=zp)
                if mag is None:
                    mag = flux_mag
                if mag_err is None:
                    mag_err = flux_mag_err
            self.mag, self.mag_err = mag, mag_err

    @staticmethod
    def _as_array(values):
        """Return ``values`` as a NumPy array, passing ``None`` through unchanged."""
        return None if values is None else np.asarray(values)

    @staticmethod
    def flux_to_mag(flux, flux_err=None, zp=27.5):
        """Convert flux to magnitude and optionally propagate the flux error.

        ``mag = zp - 2.5 * log10(|flux|)``; the error uses first-order
        propagation ``mag_err = 2.5 / ln(10) * flux_err / |flux|`` (a 1-sigma
        error that stays finite when ``flux_err > flux``).

        Returns:
            ``(mag, mag_err)``; ``mag_err`` is ``None`` when ``flux_err`` is None.
        """
        flux = np.asarray(flux, dtype=float)
        with np.errstate(divide="ignore"):
            mag = zp - 2.5 * np.log10(np.abs(flux))
        if flux_err is None:
            return mag, None
        flux_err = np.asarray(flux_err, dtype=float)
        with np.errstate(divide="ignore", invalid="ignore"):
            mag_err = 2.5 / np.log(10) * flux_err / np.abs(flux)
        return mag, mag_err

    def load_from_snana(self, file_path, band, select="snid", idx=None, fit_file=None, zp=27.5):
        """Load one object's light curve in ``band`` from an SNANA-style FITS file.

        Args:
            file_path: FITS file with the header table in HDU 1 and the
                photometry table in HDU 2.
            band: Band to keep; matched case-insensitively and ignoring padding.
            select: ``"snid"`` (match ``idx`` against the SNID column),
                ``"idx"`` (``idx`` is a header row number) or ``"random"``.
            idx: SNID or row number, depending on ``select``.
            fit_file: Currently unused; kept for backward compatibility.
            zp: Zero point for the FLUXCAL -> magnitude conversion.

        Raises:
            ValueError: no photometry HDU, or unknown ``select``.
            KeyError: the requested SNID is not in the header table.
        """
        with fits.open(file_path) as hdul:
            head = hdul[1].data
            phot = hdul[2].data if len(hdul) > 2 else None
            if phot is None:
                raise ValueError(f"{file_path} has no photometry table in HDU 2")

            if select == "snid":
                snids = np.char.strip(np.asarray(head["SNID"]).astype(str))
                rows = np.flatnonzero(snids == str(idx).strip())
                if rows.size == 0:
                    raise KeyError(f"SNID {idx!r} not found in {file_path}")
                row = rows[0]
            elif select == "idx":
                row = idx
            elif select == "random":
                # randrange excludes len(head); randint(0, len(head)) could overrun.
                row = random.randrange(len(head))
            else:
                raise ValueError(f"Unknown select {select!r}; expected 'snid', 'idx' or 'random'")

            # SNANA pointers are 1-based and inclusive, hence the -1 on the start.
            start = int(head["PTROBS_MIN"][row]) - 1
            end = int(head["PTROBS_MAX"][row])
            obs = phot[start:end]
            obs_bands = np.char.lower(np.char.strip(np.asarray(obs["BAND"]).astype(str)))
            lc_band = obs[obs_bands == band.strip().lower()]
            # Copy so the arrays stay valid after the file is closed.
            self.mjd = np.array(lc_band["MJD"], dtype=float)
            self.flux = np.array(lc_band["FLUXCAL"], dtype=float)
            self.flux_err = np.array(lc_band["FLUXCALERR"], dtype=float)
        self.mag, self.mag_err = self.flux_to_mag(self.flux, self.flux_err, zp=zp)

    def load_from_txt(self, file_path, zp=27.5):
        """Load a three-column (MJD, mag, mag_err) text file.

        Flux is reconstructed from the magnitude with ``zp``. The flux error is
        obtained by inverting ``mag_err = 2.5 / ln(10) * flux_err / flux``.
        """
        # ndmin=2 keeps single-row files unpackable into three columns.
        mjd, mag, mag_err = np.loadtxt(file_path, ndmin=2).T
        self.mjd, self.mag, self.mag_err = mjd, mag, mag_err
        self.flux = 10 ** ((zp - mag) / 2.5)
        self.flux_err = self.flux * mag_err * np.log(10) / 2.5

    def plot(self, ax=None, show=True):
        """Plot magnitude vs. MJD with error bars and return ``(fig, ax)``.

        The y axis is inverted (bright up) only if it is not already inverted,
        so several bands can share one ``ax``; calling ``invert_yaxis`` once
        per band would toggle the axis back and forth.
        """
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure
        ax.errorbar(self.mjd, self.mag, self.mag_err, fmt='o', label=f"Band {self.band}", alpha=0.8)
        ax.set_xlabel("MJD")
        ax.set_ylabel("Magnitude")
        ax.set_title(f"Light Curve in Band {self.band}")
        if not ax.yaxis_inverted():
            ax.invert_yaxis()
        ax.legend()
        if show:
            plt.show()
        return fig, ax

    def to_dict(self):
        """Convert the light curve to a dictionary format."""
        return {"mjd": self.mjd, "mag": self.mag, "mag_err": self.mag_err}

    def to_file(self, path):
        """Save (MJD, mag, mag_err) columns to a text file.

        A missing ``mag_err`` is written as ``nan`` so the file keeps three
        columns and can be read back by ``load_from_txt`` / ``from_file``.
        """
        mag_err = self.mag_err if self.mag_err is not None else np.full(len(self.mag), np.nan)
        with open(path, 'w') as f:
            for mj, m, merr in zip(self.mjd, self.mag, mag_err):
                f.write(f"   {mj:.3f}  {m:.4f}   {merr:.3f}\n")
            f.write('\n')

    # Alias used by DataSet.save_all_to_files.
    save_to_file = to_file

    @classmethod
    def from_file(cls, file_path, band, zp=27.5):
        """Load a light curve in ``band`` from a (MJD, mag, mag_err) text file."""
        return cls(None, source="txt", file_path=file_path, zp=zp, band=band)


# ---------------------------------
# 2. DataSet Class
# ---------------------------------

class DataSet:
    """A named collection of single-band light curves, keyed by band name."""

    def __init__(self, name, data_folder=DATA_FOLDER):
        self.name = name
        self.data_folder = Path(data_folder, name)
        self.light_curves = {}

    def load_from_fits(self, file_path, band):
        """Load one band's light curve from a FITS table and store it under ``band``.

        Reads the ``MJD``, ``FLUX`` and ``FLUX_ERR`` columns of HDU 1.
        NOTE(review): these column names come from an earlier draft of this
        method -- confirm they match the files in DATA_FOLDER.

        Returns:
            The created LightCurve.
        """
        with fits.open(file_path) as hdul:
            table = hdul[1].data
            # Copy so the arrays stay valid after the file is closed.
            mjd = np.array(table["MJD"])
            flux = np.array(table["FLUX"])
            flux_err = np.array(table["FLUX_ERR"])
        lc = LightCurve(None, mjd, flux, flux_err)
        lc.band = band
        self.light_curves[band] = lc
        return lc

    def load_from_dat(self, file_path, band, zp=27.5):
        """Load one band's light curve from a (MJD, mag, mag_err) text file.

        The flux is reconstructed with zero point ``zp``. The flux error is
        propagated from ``mag_err``.

        Returns:
            The created LightCurve.
        """
        # ndmin=2 keeps single-row files unpackable into three columns.
        mjd, mag, mag_err = np.loadtxt(file_path, ndmin=2).T
        flux = 10 ** ((zp - mag) / 2.5)
        flux_err = flux * mag_err * np.log(10) / 2.5
        lc = LightCurve(None, mjd, flux, flux_err, mag, mag_err)
        lc.band = band
        self.light_curves[band] = lc
        return lc

    def save_all_to_files(self, output_folder):
        """Save every light curve to ``<output_folder>/<name>_<band>.txt``."""
        output_folder = Path(output_folder)
        output_folder.mkdir(parents=True, exist_ok=True)
        for band, lc in self.light_curves.items():
            file_path = output_folder / f"{self.name}_{band}.txt"
            lc.to_file(file_path)

    def plot_all(self, bands="ugrizy", show=True):
        """Plot the light curves of ``bands`` (those present) on one axes."""
        fig, ax = plt.subplots()
        for band in bands:
            if band in self.light_curves:
                self.light_curves[band].plot(ax=ax, show=False)
        # Per-band plotting may toggle the y axis; make sure magnitudes end
        # up inverted (bright up) regardless of how many bands were drawn.
        if not ax.yaxis_inverted():
            ax.invert_yaxis()
        ax.set_title(f"Light Curves for Dataset {self.name}")
        if show:
            plt.show()
        return fig, ax

    def filter_light_curves(self, min_mjd=None, max_mjd=None):
        """Keep only epochs with ``min_mjd <= mjd <= max_mjd`` (in place).

        The same mask is applied to every per-epoch array present (mjd, flux,
        flux_err, mag, mag_err), so they stay aligned. Arrays that are None
        are left as None.
        """
        if min_mjd is None and max_mjd is None:
            return
        for lc in self.light_curves.values():
            mjd = np.asarray(lc.mjd)
            mask = np.ones(mjd.shape, dtype=bool)
            if min_mjd is not None:
                mask &= mjd >= min_mjd
            if max_mjd is not None:
                mask &= mjd <= max_mjd
            for attr in ("mjd", "flux", "flux_err", "mag", "mag_err"):
                values = getattr(lc, attr, None)
                if values is not None:
                    setattr(lc, attr, np.asarray(values)[mask])


In [None]:


# ---------------------------------
# 3. Example Usage
# ---------------------------------

if __name__ == "__main__":
    # Demo: load one dataset's per-band FITS files, plot them, save them as
    # text, then restrict to an MJD window and plot again.
    dataset_name = "ELASTICC2"
    dataset = DataSet(dataset_name)

    for band in "ugrizy":
        # DATA_FOLDER ends in "/", so joining equals plain concatenation.
        file_path = os.path.join(DATA_FOLDER, f"{dataset_name}_{band}.fits")
        if not os.path.exists(file_path):
            continue  # no file exported for this band
        dataset.load_from_fits(file_path, band)

    dataset.plot_all(show=True)
    dataset.save_all_to_files(os.path.join(VISUALIZATION_FOLDER, dataset_name))

    # Keep epochs with 59000 <= MJD <= 60000 and show the result.
    dataset.filter_light_curves(min_mjd=59000, max_mjd=60000)
    dataset.plot_all(show=True)


In [None]:
# `!conda activate TV_Classifier` ran in a throwaway subshell and could not
# change this kernel's environment. Select the TV_Classifier kernel in Jupyter
# instead. This cell only warns when the running interpreter does not appear
# to live in that environment (heuristic: the env directory name).
import os
import sys
import warnings

if os.path.basename(sys.prefix) != "TV_Classifier":
    warnings.warn(
        f"Kernel is running from {sys.prefix!r}, not the 'TV_Classifier' conda "
        "environment; select that kernel to use its packages."
    )