## Parameters, imports, and data reading

In [None]:
import matplotlib.pyplot as plt
import os, glob
import numpy as np
from tqdm.auto import tqdm
from astropy.io import fits
from astropy.table import Table, vstack, MaskedColumn
from astropy.time import Time, TimeDelta
from astropy import units as u
import pandas as pd
import pickle
import json
import seaborn as sns
import george
from george import kernels
from scipy.optimize import minimize
from matplotlib.lines import Line2D
from scipy.stats import iqr
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from collections import defaultdict
from matplotlib.colors import Normalize
from matplotlib.cm import viridis
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, auc
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from sklearn.neural_network import MLPClassifier
import re
import time
import shap
#from dask import delayed, compute
#from dask.distributed import Client, LocalCluster
import logging
import multiprocessing as mp
import concurrent.futures
import gc  # For garbage collection

#from sklearn.neural_network import MLPClassifier
#import torch
#import torch.nn as nn
#import torch.optim as optim
#from torch.utils.data import DataLoader, TensorDataset
import xgboost as xgb

# Short module-level alias for pyplot's colormap lookup
get_cmap = plt.get_cmap

# Root logger: every record from DEBUG upwards goes to 'data_conversion_debug.log',
# which is truncated at the start of each run (filemode 'w')
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='data_conversion_debug.log',
    filemode='w',
)

In [None]:
# LSST passbands and their effective wavelengths (presumably in Angstrom). The
# wavelength is used as the second coordinate of the 2-D (time, wavelength) GP inputs.
lsst_bands = {
    "u": 3670.69,
    "g": 4826.85,
    "r": 6223.24,
    "i": 7545.98,
    "z": 8590.90,
    "Y": 9710.28
}

# Band names in wavelength order, derived from lsst_bands so the two cannot drift apart
bands = list(lsst_bands)

# Plot colour per band. 'Y' is orange (as in the later plotting cells, which redefine
# this mapping) because yellow is nearly invisible on a white background.
band_colors = {'u': 'blue', 'g': 'green', 'r': 'red', 'i': 'purple', 'z': 'brown', 'Y': 'orange'}


In [None]:
def read_elasticc_file(filename):
    """
    Read an ELAsTiCC HEAD/PHOT FITS pair and tag every photometry row with its SNID.

    Either the ``_HEAD`` or the ``_PHOT`` file name may be passed; the partner file name
    is derived by swapping that suffix.

    Parameters:
        filename (str): path to the ``_HEAD`` or ``_PHOT`` FITS file.

    Returns:
        tuple: (table, head)
            table -- photometry with stripped ``BAND`` values, added ``mag``/``magerr``
                     columns (NaN where FLUXCAL <= 0) and an ``SNID`` column; rows with
                     MJD <= 0 (the light-curve separator rows) are removed.
            head  -- the header table with ``SNID`` cast to int64.

    Raises:
        FileNotFoundError: if either the PHOT or the HEAD file is missing.
    """
    if '_PHOT' in filename:
        headname = filename.replace('_PHOT', '_HEAD')
    else:
        headname = filename
        filename = filename.replace('_HEAD', '_PHOT')

    # Debug prints to verify paths
    #print(f"reading phot file: {filename}")
    print(f"reading head file: {headname}")

    if not os.path.exists(filename):
        raise FileNotFoundError(f"File not found: {filename}")
    if not os.path.exists(headname):
        raise FileNotFoundError(f"File not found: {headname}")

    table = Table.read(filename)
    head = Table.read(headname)

    # Strip padding from the band names in one vectorised, in-place pass
    # (the fixed-width column is kept; shorter values are simply null-padded)
    table['BAND'][:] = np.char.strip(np.asarray(table['BAND']))

    head['SNID'] = np.int64(head['SNID'])

    # Sanity check: each light curve in the PHOT table is terminated by a row with a
    # negative MJD, so there should be exactly one such row per HEAD entry
    n_separators = np.sum(table['MJD'] < 0)
    if n_separators != len(head):
        print(filename, 'is broken:', n_separators, '!=', len(head))

    # Measured mag and magerr (zero point 27.5) - simulated one is in SIM_MAGOBS
    table['mag'] = np.nan
    table['magerr'] = np.nan
    idx = table['FLUXCAL'] > 0

    table['mag'][idx] = 27.5 - 2.5 * np.log10(table['FLUXCAL'][idx])
    table['magerr'][idx] = 2.5 / np.log(10) * table['FLUXCALERR'][idx] / table['FLUXCAL'][idx]

    # Augment table with SNID (light curve id) from head: rows between consecutive
    # separator rows belong to consecutive HEAD entries. Indexing head directly (rather
    # than zipping) keeps a loud IndexError if there are more separators than objects.
    table['SNID'] = 0

    idx = np.where(table['MJD'] < 0)[0]
    idx = np.hstack((np.array([0]), idx))

    for i in range(1, len(idx)):
        i0, i1 = idx[i - 1], idx[i]
        table['SNID'][i0:i1] = head['SNID'][i - 1]

    # Drop the separator rows
    table = table[table['MJD'] > 0]

    return table, head


## Functions 

In [None]:
def get_snid_head_sub(table, head):
    """
    Split photometry and header tables into per-object pieces.

    Parameters:
        table: photometry table (anything indexable by column name and by an integer
               index array, e.g. an astropy Table) with an ``SNID`` column.
        head: header table with an ``SNID`` column.

    Returns:
        tuple: (snids, shead_list, sub_list) where ``snids`` are the sorted unique SNIDs of
        ``table``, and ``shead_list[k]`` / ``sub_list[k]`` hold the header rows / photometry
        rows with SNID ``snids[k]``, in their original row order (a header piece may be
        empty if the SNID is absent from ``head``).
    """
    table_snid = np.asarray(table['SNID'])
    head_snid = np.asarray(head['SNID'])

    # Group photometry rows by SNID with one stable sort instead of one full boolean
    # mask per SNID (which was O(n_snid * n_rows)); stability preserves row order.
    order = np.argsort(table_snid, kind='stable')
    snids, starts = np.unique(table_snid[order], return_index=True)
    stops = np.append(starts[1:], len(order))
    sub_list = [table[order[a:b]] for a, b in zip(starts, stops)]

    # Locate each SNID's header rows via binary search on the stably sorted head SNIDs
    head_order = np.argsort(head_snid, kind='stable')
    sorted_head = head_snid[head_order]
    lo = np.searchsorted(sorted_head, snids, side='left')
    hi = np.searchsorted(sorted_head, snids, side='right')
    shead_list = [head[head_order[a:b]] for a, b in zip(lo, hi)]

    return snids, shead_list, sub_list


### Compute the Gaussian-process fit

In [None]:
def try_optimization(gp, snid, neg_ln_like, grad_neg_ln_like, initial_guess, retries=3):
    """
    Optimise GP hyperparameters with L-BFGS-B, retrying from a perturbed start on failure.

    Parameters:
        gp: object exposing ``set_parameter_vector`` (e.g. a george GP).
        snid: object identifier, used only in the failure message.
        neg_ln_like, grad_neg_ln_like: objective and its gradient for ``scipy.optimize.minimize``.
        initial_guess (array-like): starting parameter vector; it is NOT modified.
        retries (int): maximum number of optimisation attempts.

    Returns:
        bool: True if an attempt converged (the GP then holds the optimum), False otherwise
        (a message is printed and the GP is reset to ``initial_guess``).
    """
    # Work on private copies so the caller's array is never mutated
    start = np.array(initial_guess, dtype=float)
    guess = start.copy()

    for _ in range(retries):
        result = minimize(neg_ln_like, guess, jac=grad_neg_ln_like, method='L-BFGS-B')
        if result.success:
            # Accept the first converged solution, whatever earlier failed attempts reported
            gp.set_parameter_vector(result.x)
            return True
        # Slightly perturb the initial guess for the next attempt
        guess = guess + np.random.normal(0, 1e-2, size=guess.shape)

    print(f"All optimization attempts failed for SNID {snid}")
    # The objective calls left the GP at whatever point the optimiser last evaluated;
    # restore the original, deterministic starting parameters instead.
    gp.set_parameter_vector(start)
    return False

In [None]:
def compute_gp(sub, snid, verbose=False):
    """
    Fit a 2-D Gaussian process over (MJD, effective wavelength) to one light curve.

    Parameters:
        sub: table-like light curve with 'MJD', 'FLUXCAL', 'FLUXCALERR' and 'BAND' columns.
        snid: object identifier, used only in messages.
        verbose (bool): currently unused; kept for call compatibility.

    Returns:
        tuple: (gp, flux, x, params). ``flux`` and ``x`` (columns: MJD, wavelength) are the
        cleaned points the GP was conditioned on; ``params`` is the final parameter vector.
        On any failure a message is printed and (None, None, None, None) is returned.
    """
    try:
        # Ensure inputs are numpy arrays with appropriate dtype
        t = np.array(sub['MJD'], dtype=float)
        flux = np.array(sub['FLUXCAL'], dtype=float)
        fluxerr = np.array(sub['FLUXCALERR'], dtype=float)
        # Bands missing from lsst_bands map to None -> NaN and are removed by the mask below
        band = np.array([lsst_bands.get(b) for b in sub['BAND']], dtype=float)

        # 2D positions of data points (time and wavelength)
        x = np.column_stack([t, band])

        # Clean the data: remove rows with NaNs, infs, and non-positive flux values
        mask = np.isfinite(flux) & np.isfinite(fluxerr) & np.all(np.isfinite(x), axis=1) & (flux > 0)
        x, flux, fluxerr = x[mask], flux[mask], fluxerr[mask]

        if len(flux) < 5:  # Ensure there are enough data points
            raise ValueError("Not enough data points to perform GP fitting.")

        # Amplitude scale: flux of the highest-S/N point, using a 1%-of-peak error floor
        # (fluxes are strictly positive here, so no abs() is needed)
        signal_to_noises = flux / np.sqrt(fluxerr ** 2 + (1e-2 * np.max(flux)) ** 2)
        scale = flux[np.argmax(signal_to_noises)]

        # Matern-3/2 kernel over (MJD, wavelength); george's metric takes squared length
        # scales, i.e. initial scales of 100 in time and 6000 in wavelength
        kernel = (0.5 * scale) ** 2 * george.kernels.Matern32Kernel([100 ** 2, 6000 ** 2], ndim=2)

        # GP with the HODLR solver; the per-point flux errors passed to compute() provide
        # the diagonal (measurement-noise) term
        gp = george.GP(kernel, solver=george.HODLRSolver)
        gp.compute(x, fluxerr)

        # Negative log likelihood and its gradient for the optimiser
        def neg_ln_like(p):
            gp.set_parameter_vector(p)
            return -gp.log_likelihood(flux)

        def grad_neg_ln_like(p):
            gp.set_parameter_vector(p)
            return -gp.grad_log_likelihood(flux)

        # Optimise the hyperparameters (with retries from perturbed starting points)
        initial_guess = gp.get_parameter_vector()
        try_optimization(gp, snid, neg_ln_like, grad_neg_ln_like, initial_guess, retries=3)

        # Return the GP, flux, data points, and final parameters
        return gp, flux, x, gp.get_parameter_vector()

    except Exception as e:
        # Best-effort per object: report and let the caller skip this light curve
        print(f"GP optimization failed for SNID {snid}: {e}")
        return None, None, None, None


### Peak and rise/fade-time estimation

In [None]:
def _peak_rise_fade(mjd_grid, curve, peak_threshold):
    """Return (peak_mjd, peak_flux, rise_time, fade_time) for one predicted light curve."""
    i_peak = int(np.argmax(curve))
    peak_mjd = mjd_grid[i_peak]
    peak_flux = curve[i_peak]

    # Flux level `peak_threshold` times fainter than the peak (2.512 ~ 1 mag)
    faint_level = peak_flux / peak_threshold

    # Rise: last grid point before the peak at or below the faint level
    below_before = np.flatnonzero(curve[:i_peak] <= faint_level)
    rise_time = peak_mjd - mjd_grid[below_before[-1]] if below_before.size else None

    # Fade: first grid point from the peak onwards at or below the faint level
    below_after = np.flatnonzero(curve[i_peak:] <= faint_level)
    fade_time = mjd_grid[i_peak + below_after[0]] - peak_mjd if below_after.size else None

    return peak_mjd, peak_flux, rise_time, fade_time


def find_peak_and_calculate_times(gp, sub, flux, band1='g', band2='r', peak_threshold=2.512):
    """
    Find the GP-predicted peak and the rise/fade times in two bands.

    The GP mean is evaluated on 1000 MJDs spanning [min(MJD) - 50, max(MJD) + 75] at the
    effective wavelength of each band. Rise (fade) time is the interval between the peak and
    the last (first) grid point before (after) it whose flux is <= peak / ``peak_threshold``
    (2.512 corresponds to 1 magnitude fainter).

    Parameters:
        gp: fitted GP over (MJD, wavelength) exposing ``predict(y, X, return_cov=...)``.
        sub: light-curve table with an 'MJD' column.
        flux (array-like): fluxes the GP was conditioned on.
        band1, band2 (str): keys of ``lsst_bands`` (may be equal).
        peak_threshold (float): flux ratio defining the "faint" level.

    Returns:
        tuple of 8: (peak_mjd, peak_flux, rise_time, fade_time) for band1 followed by the
        same four values for band2. Rise/fade times are None when the curve never drops to
        the faint level inside the prediction window.
    """
    # Prediction grid (identical for both bands, so built once)
    t_min, t_max = sub['MJD'].min(), sub['MJD'].max()
    mjd_grid = np.linspace(t_min - 50, t_max + 75, 1000)

    results = []
    for band in (band1, band2):
        x_pred = np.column_stack([mjd_grid, np.full_like(mjd_grid, lsst_bands[band])])
        # Only the predictive mean is needed; skipping the variance avoids extra solves
        mean_pred = gp.predict(flux, x_pred, return_cov=False)
        # Storing by position (not by name) also fills band2 when band1 == band2
        results.append(_peak_rise_fade(mjd_grid, mean_pred, peak_threshold))

    return results[0] + results[1]


### Color estimation

In [None]:
def calc_color(gp, flux, x, x1, band1, band2):
    """
    Predict the colour ``band1 - band2`` (magnitudes) from a 2-D GP at the times ``x1``.

    Parameters:
        gp: fitted GP over (MJD, wavelength) exposing ``predict(y, X, return_var=True)``.
        flux (array-like): fluxes the GP was conditioned on.
        x: unused; kept for call compatibility.
        x1 (array-like): MJDs at which to evaluate the colour (list, array or column).
        band1, band2 (str): keys of ``lsst_bands``.

    Returns:
        tuple: (x1, color, color_err) as numpy arrays, restricted to epochs where both the
        colour and its error are finite. Predicted fluxes are clipped at 1e-10 before taking
        logarithms, so non-positive predictions give finite but meaningless magnitudes
        (typically with very large errors, which callers filter on).
    """
    x1 = np.asarray(x1, dtype=float)
    if x1.size == 0:
        # Nothing to predict; skip the GP calls entirely
        return x1, np.empty(0), np.empty(0)

    def _predict_mag(band):
        """GP magnitude (arbitrary zero point) and its error in ``band`` at the epochs x1."""
        x_band = np.column_stack([x1, np.full_like(x1, lsst_bands[band])])
        flux_pred, flux_var = gp.predict(flux, x_band, return_var=True)
        # Ensure positive fluxes and non-negative variances
        flux_pred = np.clip(flux_pred, 1e-10, None)
        flux_var = np.clip(flux_var, 0, None)
        mag = -2.5 * np.log10(flux_pred)
        with np.errstate(invalid='ignore', divide='ignore'):
            magerr = 2.5 / np.log(10) * (np.sqrt(flux_var) / flux_pred)
        return mag, magerr

    mag_band1, magerr_band1 = _predict_mag(band1)
    mag_band2, magerr_band2 = _predict_mag(band2)

    color = mag_band1 - mag_band2
    color_err = np.sqrt(magerr_band1**2 + magerr_band2**2)

    # Keep only finite colours with finite errors
    valid_mask = np.isfinite(color) & np.isfinite(color_err)
    return x1[valid_mask], color[valid_mask], color_err[valid_mask]


### Mean color and color evolution

In [None]:
def weighted_mean_std(value, error, min_error=1e-6):
    """
    Inverse-variance weighted mean and its standard error.

    Parameters:
        value (array-like): Data values.
        error (array-like): 1-sigma errors of the values.
        min_error (float): Errors <= 0 are replaced by this value (giving them a very
            large but finite weight) to prevent division by zero.

    Returns:
        tuple: (weighted_mean, standard_error_of_mean) where the standard error is
        sqrt(1 / sum(weights)). Points with non-finite errors get zero weight and do not
        influence the mean. Returns (nan, nan) when no point has a positive weight.
    """
    value = np.asarray(value, dtype=float)
    error = np.asarray(error, dtype=float)

    # Replace zero or negative errors with min_error
    error_safe = np.where(error > 0, error, min_error)

    with np.errstate(divide='ignore', over='ignore'):
        weight = 1.0 / error_safe**2

    # NaN/inf errors carry no information: give them zero weight (a NaN error would
    # otherwise fail `error > 0` and be clamped to min_error, i.e. the LARGEST weight).
    # Also drop weights that overflowed.
    weight = np.where(np.isfinite(error) & np.isfinite(weight), weight, 0.0)

    sum_weights = np.sum(weight)
    if sum_weights == 0:
        # If all weights are zero, return NaN
        return np.nan, np.nan

    # Zero-weight points are excluded explicitly so a NaN/inf value there cannot poison the sum
    with np.errstate(invalid='ignore'):
        weighted_sum = np.sum(np.where(weight > 0, value * weight, 0.0))

    weighted_mean = weighted_sum / sum_weights
    weighted_std = np.sqrt(1.0 / sum_weights)
    return weighted_mean, weighted_std


In [None]:
def _colors_in_window(sub, gp, flux, x, band1, band2, in_window):
    """
    GP colours at the band1/band2 epochs of ``sub`` selected by ``in_window(mjd_column)``.

    Only points with finite MJD and colour and colour error < 1 mag are kept. Any failure
    (e.g. a None peak time inside ``in_window`` or a GP prediction error) gives (None, None, None).
    """
    try:
        mjd = sub['MJD']
        rows = ((sub['BAND'] == band1) | (sub['BAND'] == band2)) & in_window(mjd)
        epochs, color, color_err = calc_color(gp, flux, x, mjd[rows], band1, band2)
        keep = np.isfinite(epochs) & np.isfinite(color) & (color_err < 1)
        return epochs[keep], color[keep], color_err[keep]
    except Exception:
        return None, None, None


def _mean_color(color, color_err):
    """Weighted mean colour and its standard error, or (None, None) when there is no data."""
    if color is None or len(color) == 0:
        return None, None
    return weighted_mean_std(color, color_err)


def _color_slope(epochs, color, color_err, min_error=1e-6):
    """
    Weighted straight-line fit of colour versus MJD.

    Returns (slope, slope_err): (None, None) with fewer than two points, (nan, nan) when
    the fit cannot be made.
    """
    if epochs is None or len(epochs) < 2:
        return None, None
    try:
        # Replace zero or negative errors with a minimum threshold
        err_safe = np.where(color_err > 0, color_err, min_error)
        # np.polyfit applies `w` to the UNsquared residuals, so inverse-variance
        # weighting requires w = 1/sigma (not 1/sigma**2)
        weights = 1 / err_safe
        weights = np.where(np.isfinite(weights), weights, 0)
        if np.sum(weights) == 0:
            return np.nan, np.nan
        # cov=True scales the covariance by the fit residuals; current numpy refuses to do
        # that with only two points, so such windows end in the except branch (nan, nan)
        coeffs, cov = np.polyfit(epochs, color, 1, w=weights, cov=True)
        return coeffs[0], np.sqrt(cov[0, 0])  # slope and its standard error
    except Exception:
        return np.nan, np.nan


def calc_mean_colors_and_slope(sub, gp, flux, x, band1, band2, rise_time, fade_time, t_peak):
    """
    Summarise the ``band1 - band2`` colour before and after peak.

    The pre-peak window is [t_peak - rise_time, t_peak] and the post-peak window is
    (t_peak, t_peak + fade_time]. A rise/fade time that is None or non-finite falls back to
    50 / 75 days. Colours are GP predictions (see calc_color) evaluated at the observed
    band1/band2 epochs inside each window.

    Returns:
        tuple of 8: (mean_pre, mean_pre_err, slope_pre, slope_pre_err,
                     mean_post, mean_post_err, slope_post, slope_post_err).
        Means are None when a window has no usable points; slopes are None with fewer than
        two points and nan when the fit fails.
    """
    # Fallback to default rise_time and fade_time if not valid
    calc_rise_time = 50 if rise_time is None or not np.isfinite(rise_time) else rise_time
    calc_fade_time = 75 if fade_time is None or not np.isfinite(fade_time) else fade_time

    # The window predicates are evaluated inside _colors_in_window's try-block, so a
    # missing t_peak degrades to "no data" instead of raising
    x_pre, color_pre, err_pre = _colors_in_window(
        sub, gp, flux, x, band1, band2,
        lambda mjd: (mjd >= t_peak - calc_rise_time) & (mjd <= t_peak),
    )
    x_post, color_post, err_post = _colors_in_window(
        sub, gp, flux, x, band1, band2,
        lambda mjd: (mjd > t_peak) & (mjd <= t_peak + calc_fade_time),
    )

    mean_color_pre_peak, std_err_mean_color_pre_peak = _mean_color(color_pre, err_pre)
    mean_color_post_peak, std_err_mean_color_post_peak = _mean_color(color_post, err_post)

    slope_pre_peak, slope_err_pre_peak = _color_slope(x_pre, color_pre, err_pre)
    slope_post_peak, slope_err_post_peak = _color_slope(x_post, color_post, err_post)

    return (
        mean_color_pre_peak, std_err_mean_color_pre_peak, slope_pre_peak, slope_err_pre_peak,
        mean_color_post_peak, std_err_mean_color_post_peak, slope_post_peak, slope_err_post_peak
    )


In [None]:
from astropy.io import fits

# One HEAD file used only to inspect which header columns are available
# (point this elsewhere if the data live in a different location)
example_file = "../../../karpov/ELASTICC2/ELASTICC2_FINAL_TDE/ELASTICC2_FINAL_NONIaMODEL0-0001_HEAD.FITS.gz"

# The table sits in extension 1, i.e. the HDU right after the primary one
with fits.open(example_file) as hdul:
    data = hdul[1].data
    column_names = data.columns.names

# List the column names, one per line
for col_name in column_names:
    print(col_name)


## Main processing loop

In [None]:
print(f"{os.getcwd()}")  # show the working directory, since all data paths below are relative


### multiprocessing

### Fraction of observations per band

In [None]:
# Start the timer
start_time = time.time()

# Base path and template for file names
base_path = "../../../karpov/ELASTICC2/"
filename_template = "ELASTICC2_FINAL_{object_type}/ELASTICC2_FINAL_NONIaMODEL0-{index}_HEAD.FITS.gz"

# Object types and model names
object_info = [
    'TDE',
    'AGN',
    'SLSN-I+host',
    'SLSN-I_no_host',
    'SNIa-SALT3',
    'SNIa-91bg',
    'SNIax',
    'SNIcBL+HostXT_V19',
    'SNIb+HostXT_V19',
    'SNIIn-MOSFIT',
    'SNII-NMF',
    'SNII+HostXT_V19',
    'SNIIb+HostXT_V19',
    'KN_B19',
    'KN_K17',
]

# Generate file paths for each object type (currently only file index 0001 per type)
all_filenames = []
object_types = []
for object_type in object_info:
    filenames = [os.path.join(base_path, filename_template.format(object_type=object_type, index=str(i).zfill(4))) for i in range(1, 2)]
    all_filenames.extend(filenames)
    object_types.extend([object_type] * len(filenames))

# Number of detections, organized by object type and band
count_data = defaultdict(lambda: defaultdict(int))

# Main loop to process each file
for filename, object_type in tqdm(zip(all_filenames, object_types), total=len(all_filenames)):
    try:
        table, head = read_elasticc_file(filename)
        snids, shead_list, sub_list = get_snid_head_sub(table, head)

        # Process each SNID
        for snid, shead, sub in zip(snids, shead_list, sub_list):
            try:
                if not isinstance(sub, Table):
                    sub = Table(sub)

                # Compute SNR of each data point
                sub['SNR'] = sub['FLUXCAL'] / sub['FLUXCALERR']

                # DETECTION CHECK: keep only points with SNR > 5 and FLUXCAL > 100 (any band)
                snr_flux_mask = (sub['SNR'] > 5) & (sub['FLUXCAL'] > 100)
                filtered_data = sub[snr_flux_mask]

                # Tally the detections per band
                for band in filtered_data['BAND']:
                    count_data[object_type][band] += 1

            except Exception as e:
                print(f"An error occurred while processing SNID {snid}: {e}")
                continue

    except FileNotFoundError as e:
        print(e)
        continue  # Skip this file and continue with the next one


# End the timer and print the elapsed time
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")


In [None]:
# Define the color coding for the bands
band_colors = {'u': 'blue', 'g': 'green', 'r': 'red', 'i': 'purple', 'z': 'brown', 'Y': 'orange'}

# Define the combined classes
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Combine the per-subtype band counts (count_data) into the combined classes
combined_count_data = defaultdict(lambda: defaultdict(int))

for obj_class, subtypes in combined_classes.items():
    for subtype in subtypes:
        if subtype in count_data:
            for band in bands:
                combined_count_data[obj_class][band] += count_data[subtype][band]

# Calculate total counts for normalization
total_counts = {cls: sum(band_counts.values()) for cls, band_counts in combined_count_data.items()}

# Create subplots for each combined class and one for cumulative efficiency
num_combined_classes = len(combined_classes)
fig, axs = plt.subplots(num_combined_classes + 1, 1, figsize=(8, 3 * (num_combined_classes + 1)), sharex=True)

# One panel per entry of combined_classes, in that order. Iterating combined_classes
# (not combined_count_data) keeps classes without detections on their own labelled
# panel, instead of silently shifting the other classes and leaving blank axes.
for ax, combined_class in zip(axs, combined_classes):
    ax.set_title(f'{combined_class}')
    if total_counts.get(combined_class, 0) > 0:
        band_counts = combined_count_data[combined_class]
        efficiencies = [band_counts[band] / total_counts[combined_class] for band in bands]
        colors = [band_colors[band] for band in bands]
        ax.bar(bands, efficiencies, color=colors, alpha=0.75)
        ax.set_ylabel('Fraction')
        ax.grid(True, linestyle='--', alpha=0.7)
    else:
        ax.text(0.5, 0.5, 'no detections', ha='center', va='center', transform=ax.transAxes)

# Plot cumulative efficiency across all combined classes
cumulative_counts = defaultdict(int)
for band_counts in combined_count_data.values():
    for band in bands:
        cumulative_counts[band] += band_counts[band]

total_cumulative_count = sum(cumulative_counts.values())
if total_cumulative_count > 0:
    cumulative_efficiencies = [cumulative_counts[band] / total_cumulative_count for band in bands]
    axs[-1].bar(bands, cumulative_efficiencies, color=[band_colors[band] for band in bands], alpha=0.75)
    axs[-1].set_ylabel('Fraction')
    axs[-1].set_title('All classes combined')
    axs[-1].grid(True, linestyle='--', alpha=0.7)

# Set the xlabel only for the last subplot
axs[-1].set_xlabel('Band')

plt.tight_layout()
plt.show()


In [None]:
# Define the color coding for the bands
band_colors = {'u': 'blue', 'g': 'green', 'r': 'red', 'i': 'purple', 'z': 'brown', 'Y': 'orange'}

# Define the combined classes
# NOTE(review): 'CLAGN' is not among the object types processed in the counting loop
# (that list uses 'AGN'), so this class may receive no counts -- confirm which label
# the AGN counts were collected under.
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'CLAGN': ['CLAGN'],
    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Combine the per-subtype band counts (count_data) into the combined classes
combined_count_data = defaultdict(lambda: defaultdict(int))

for obj_class, subtypes in combined_classes.items():
    for subtype in subtypes:
        if subtype in count_data:
            for band in bands:
                combined_count_data[obj_class][band] += count_data[subtype][band]

# Calculate total counts for normalization
total_counts = {cls: sum(band_counts.values()) for cls, band_counts in combined_count_data.items()}

# Create subplots for each combined class and one for cumulative efficiency
num_combined_classes = len(combined_classes)
fig, axs = plt.subplots(num_combined_classes + 1, 1, figsize=(8, 3 * (num_combined_classes + 1)), sharex=True)

# One panel per entry of combined_classes, in that order. Iterating combined_classes
# (not combined_count_data) keeps classes without detections on their own labelled
# panel, instead of silently shifting the other classes and leaving blank axes.
for ax, combined_class in zip(axs, combined_classes):
    ax.set_title(f'{combined_class}')
    if total_counts.get(combined_class, 0) > 0:
        band_counts = combined_count_data[combined_class]
        efficiencies = [band_counts[band] / total_counts[combined_class] for band in bands]
        colors = [band_colors[band] for band in bands]
        ax.bar(bands, efficiencies, color=colors, alpha=0.75)
        ax.set_ylabel('Fraction')
        ax.grid(True, linestyle='--', alpha=0.7)
    else:
        ax.text(0.5, 0.5, 'no detections', ha='center', va='center', transform=ax.transAxes)

# Plot cumulative efficiency across all combined classes
cumulative_counts = defaultdict(int)
for band_counts in combined_count_data.values():
    for band in bands:
        cumulative_counts[band] += band_counts[band]

total_cumulative_count = sum(cumulative_counts.values())
if total_cumulative_count > 0:
    cumulative_efficiencies = [cumulative_counts[band] / total_cumulative_count for band in bands]
    axs[-1].bar(bands, cumulative_efficiencies, color=[band_colors[band] for band in bands], alpha=0.75)
    axs[-1].set_ylabel('Fraction')
    axs[-1].set_title('All classes combined')
    axs[-1].grid(True, linestyle='--', alpha=0.7)

# Set the xlabel only for the last subplot
axs[-1].set_xlabel('Band')

plt.tight_layout()
plt.show()


## Quality cut plots

### Counts and efficiencies

In [None]:
# Import necessary libraries
import time
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
import logging
import gc

# Configure logging for this section: INFO and above, formatted as "LEVEL:message".
# logging.basicConfig() silently does nothing when the root logger already has handlers,
# which it does once the notebook's first cell has attached its DEBUG file handler.
# force=True (Python 3.8+) removes those handlers so this configuration takes effect;
# note that this also detaches the earlier 'data_conversion_debug.log' handler.
logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(message)s', force=True)

# Start the timer
start_time = time.time()

# Define the base directory for the ELASTICC2
base_path = "../../../karpov/ELASTICC2/"

# Combined classes and their respective object types
combined_classes = {
    'TDE': ['TDE'],
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host']
}

# Flatten the combined_classes dictionary to map individual types to combined classes
object_type_to_class = {}
for combined_class, obj_types in combined_classes.items():
    for obj_type in obj_types:
        object_type_to_class[obj_type] = combined_class

# Object types, in combined-class order
object_info = list(object_type_to_class.keys())

# File-name template (loop-invariant, so defined once)
filename_template = "ELASTICC2_FINAL_{object_type}/ELASTICC2_FINAL_NONIaMODEL0-{index}_HEAD.FITS.gz"

# Collect the existing HEAD files for each object type
all_filenames = []
object_types = []
for object_type in object_info:
    for i in range(2, 3):  # Adjust range based on the number of files per object type
        file_path = os.path.join(base_path, filename_template.format(object_type=object_type, index=str(i).zfill(4)))

        # Check if the file exists
        if os.path.exists(file_path):
            all_filenames.append(file_path)
            object_types.append(object_type)
        else:
            logging.warning(f"File not found: {file_path}")
# **Step 1: Count Total Objects by Combined Class**

def count_total_objects_by_combined_class(all_filenames, object_types, object_type_to_class):
    """
    Count the total number of objects per combined class across HEAD FITS files.

    Each row of a HEAD file's first extension is counted as one object.

    Parameters:
    - all_filenames (list): List of HEAD FITS file paths.
    - object_types (list): Corresponding list of object types for each file.
    - object_type_to_class (dict): Mapping from individual object types to combined classes.

    Returns:
    - pd.Series: Total counts per combined class, indexed by the distinct values of
      ``object_type_to_class`` in first-seen order. Unknown object types are skipped
      and unreadable files are logged, both with a warning/error rather than an exception.
    """
    # Derive the class list from the mapping argument instead of the notebook-global
    # `combined_classes`, so the function depends only on its inputs and cannot hit a
    # KeyError for a class that the global happens not to list.
    total_counts = dict.fromkeys(object_type_to_class.values(), 0)

    for file_path, obj_type in zip(all_filenames, object_types):
        combined_class = object_type_to_class.get(obj_type, 'Unknown')
        if combined_class == 'Unknown':
            logging.warning("Object type '%s' not found in combined_classes mapping.", obj_type)
            continue
        try:
            with fits.open(file_path) as hdulist:
                # The object table is in extension 1 (the HDU after the primary)
                num_objects = len(hdulist[1].data)
            total_counts[combined_class] += num_objects
            logging.info("Counted %d objects for class '%s' in file '%s'", num_objects, combined_class, file_path)
        except Exception as e:
            logging.error("Error reading %s: %s", file_path, e)

    return pd.Series(total_counts)

# **Step 2: Count Passed-Cut Objects by Combined Class from FITS Files**

def count_passing_objects_from_fits(all_filenames, object_types, object_type_to_class, selection_criteria):
    """
    Count, per combined class, the objects that pass the detection cuts.

    An object passes when at least ``selection_criteria['min_points']`` of its points in a
    single band have SNR > ``selection_criteria['SNR_threshold']`` and
    FLUXCAL > ``selection_criteria['FLUXCAL_threshold']``.

    HEAD files hold one row per object and carry no BAND/FLUXCAL columns; the photometry
    lives in the matching PHOT file, and per-point SNIDs are only recoverable from the
    HEAD/PHOT pair. Both are therefore read with ``read_elasticc_file``.

    Parameters:
    - all_filenames (list): List of HEAD (or PHOT) FITS file paths.
    - object_types (list): Corresponding list of object types for each file.
    - object_type_to_class (dict): Mapping from individual object types to combined classes.
    - selection_criteria (dict): keys 'SNR_threshold', 'FLUXCAL_threshold', 'min_points'.

    Returns:
    - pd.Series: Passed-cut counts per combined class, indexed by the distinct values of
      ``object_type_to_class``. Files that fail to load are logged and contribute nothing.
    """
    passed_counts = dict.fromkeys(object_type_to_class.values(), 0)

    for file_path, obj_type in zip(all_filenames, object_types):
        combined_class = object_type_to_class.get(obj_type, 'Unknown')
        if combined_class == 'Unknown':
            logging.warning(f"Object type '{obj_type}' not found in combined_classes mapping.")
            continue
        try:
            table, _ = read_elasticc_file(file_path)

            flux = np.asarray(table['FLUXCAL'], dtype=float)
            fluxerr = np.asarray(table['FLUXCALERR'], dtype=float)
            with np.errstate(divide='ignore', invalid='ignore'):
                snr = flux / fluxerr

            good = (snr > selection_criteria['SNR_threshold']) & (flux > selection_criteria['FLUXCAL_threshold'])

            # Count qualifying points per (object, band) and keep objects where any band
            # reaches the required number of points
            detections = pd.DataFrame({
                'SNID': np.asarray(table['SNID'])[good],
                'BAND': np.asarray(table['BAND'])[good],
            })
            per_band = detections.groupby(['SNID', 'BAND']).size()
            passing = per_band[per_band >= selection_criteria['min_points']]
            passed_counts[combined_class] += passing.index.get_level_values(0).nunique()

            logging.info(f"File '{file_path}' processed. Passed-count for class '{combined_class}': {passed_counts[combined_class]}")
        except Exception as e:
            logging.error(f"Error processing {file_path}: {e}")

    return pd.Series(passed_counts)

# **Step 3: Count Passed-Cut Objects by Combined Class from CSV File**

def count_passed_cut_objects_combined(csv_file_path, object_type_to_class, combined_classes):
    """
    Tally, per combined class, the objects listed in the processed CSV file.

    The processed CSV holds the objects that passed the selection cuts, so the
    number of rows per combined class is the passed-cut population.

    Parameters:
    - csv_file_path (str): Location of the processed CSV; it must provide an
      'Object_Type' column.
    - object_type_to_class (dict): Lookup from individual object type to its
      combined class.
    - combined_classes (dict): Combined classes with their member object types;
      only the keys are used, to fix the index and order of the result.

    Returns:
    - pd.Series: Count per combined class, indexed by the keys of
      `combined_classes` (0 for classes with no objects). If the file cannot
      be read or has no 'Object_Type' column, an all-zero Series is returned.
    """
    def _all_zero():
        # Fallback result: every requested combined class with a count of 0.
        return pd.Series({cls: 0 for cls in combined_classes.keys()})

    try:
        frame = pd.read_csv(csv_file_path)

        if 'Object_Type' not in frame.columns:
            logging.error(f"'Object_Type' column not found in {csv_file_path}.")
            return _all_zero()

        frame['Combined_Class'] = frame['Object_Type'].map(object_type_to_class)

        # Types absent from the lookup map to NaN; report and discard them.
        n_unmapped = frame['Combined_Class'].isna().sum()
        if n_unmapped > 0:
            logging.warning(f"{n_unmapped} objects have unmapped 'Object_Type'. They will be excluded from counts.")
            frame = frame.dropna(subset=['Combined_Class'])

        # Reindex so every requested class appears, in the caller's order.
        counts = (
            frame['Combined_Class']
            .value_counts()
            .sort_index()
            .reindex(combined_classes.keys(), fill_value=0)
        )

        logging.info("Passed-Cut object counts by combined class from CSV:")
        logging.info(counts)
        return counts

    except Exception as e:
        logging.error(f"Error reading {csv_file_path}: {e}")
        return _all_zero()

# **Step 4: Save Counts to CSV Files**

def save_counts_to_csv(total_counts, passed_counts_fits, passed_counts_csv, output_dir='counts'):
    """
    Write the total and passed-cut count Series to one CSV file each.

    Parameters:
    - total_counts (pd.Series): Total objects per combined class.
    - passed_counts_fits (pd.Series): Passed-cut counts computed from the FITS files.
    - passed_counts_csv (pd.Series): Passed-cut counts computed from the processed CSV.
    - output_dir (str): Destination directory; created if it does not exist.

    Each Series is written with a header row, and the written path is logged.
    """
    os.makedirs(output_dir, exist_ok=True)

    # (series, file name, log description), written in this order.
    outputs = (
        (total_counts, 'total_counts_combined_classes.csv', "Total counts"),
        (passed_counts_fits, 'passed_counts_from_fits_combined_classes.csv', "Passed-cut counts from FITS"),
        (passed_counts_csv, 'passed_counts_from_csv_combined_classes.csv', "Passed-cut counts from CSV"),
    )
    for series, file_name, description in outputs:
        target = os.path.join(output_dir, file_name)
        series.to_csv(target, header=True)
        logging.info(f"{description} saved to '{target}'")

# **Step 5: Create Overlaid Histogram Comparing FITS and CSV Counts**

def plot_combined_class_histogram_comparison(total_counts, passed_counts_fits, passed_counts_csv, output_path='object_type_histogram_comparison.png'):
    """
    Plot a grouped bar chart of total objects vs. passed-cut objects (from the
    FITS files and from the processed CSV) per combined class, save it to
    `output_path`, and display it.

    Parameters:
    - total_counts (pd.Series): Counts of total objects by combined class.
    - passed_counts_fits (pd.Series): Counts of passed-cut objects from FITS by combined class.
    - passed_counts_csv (pd.Series): Counts of passed-cut objects from CSV by combined class.
    - output_path (str or None): Path to save the figure image. Pass None or an
      empty string to only display the figure without saving it.
    """
    # Align the three Series on their class index; a class missing from one of
    # them gets a count of 0.
    combined_df = pd.DataFrame({
        'Total': total_counts,
        'Passed_Cut_FITS': passed_counts_fits,
        'Passed_Cut_CSV': passed_counts_csv
    }).fillna(0)

    # Pass fraction in percent. The Total is masked to NaN where it is not
    # positive, so those classes end up with 0% instead of inf/NaN. This is
    # vectorized, so it also works for an empty frame.
    positive_total = combined_df['Total'].where(combined_df['Total'] > 0)
    for source in ('FITS', 'CSV'):
        combined_df[f'Passed_Cut_{source}_Percent'] = (
            combined_df[f'Passed_Cut_{source}'] / positive_total * 100
        ).fillna(0)

    # Largest classes first for better visualization.
    combined_df = combined_df.sort_values('Total', ascending=False)

    x = np.arange(len(combined_df))  # the label locations
    width = 0.25  # the width of the bars

    fig, ax = plt.subplots(figsize=(16, 10))

    bars1 = ax.bar(x - width, combined_df['Total'], width, label='Total Objects', color='skyblue')
    bars2 = ax.bar(x, combined_df['Passed_Cut_FITS'], width, label='Passed Cut (FITS)', color='salmon')
    bars3 = ax.bar(x + width, combined_df['Passed_Cut_CSV'], width, label='Passed Cut (CSV)', color='lightgreen')

    ax.set_xlabel('Combined Object Class', fontsize=16)
    ax.set_ylabel('Number of Objects', fontsize=16)
    ax.set_title('Comparison of Total Objects and Passed-Cut Objects by Combined Class', fontsize=18)
    ax.set_xticks(x)
    ax.set_xticklabels(combined_df.index, rotation=45, ha='right', fontsize=12)
    ax.legend(fontsize=12)

    def autolabel(bars, counts, percentages=None):
        """Annotate each bar with its count and, when given, its percentage."""
        if percentages is None:
            percentages = [None] * len(bars)
        for bar, count, percent in zip(bars, counts, percentages):
            if percent is not None and not np.isnan(percent):
                label = f'{int(count)} ({percent:.1f}%)'
            else:
                label = f'{int(count)}'
            ax.annotate(
                label,
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                xytext=(0, 3),  # 3 points vertical offset
                textcoords="offset points",
                ha='center', va='bottom',
                fontsize=10
            )

    autolabel(bars1, combined_df['Total'])
    autolabel(bars2, combined_df['Passed_Cut_FITS'], combined_df['Passed_Cut_FITS_Percent'])
    autolabel(bars3, combined_df['Passed_Cut_CSV'], combined_df['Passed_Cut_CSV_Percent'])

    fig.tight_layout()

    # Save to the requested path before showing. bbox_inches='tight' keeps the
    # rotated tick labels inside the image.
    if output_path:
        fig.savefig(output_path, bbox_inches='tight')
        logging.info(f"Comparison histogram saved to '{output_path}'")

    plt.show()

# **Step 6: Main Execution Flow**

def main():
    """
    Run the count-and-compare pipeline for the combined classes.

    Steps:
    1. Count total objects per combined class from the FITS files.
    2. Count objects passing the selection cuts directly from the FITS files.
    3. Count passed-cut objects from the processed CSV.
    4. Save all counts to CSV.
    5. Plot the comparison histogram.

    The elapsed wall time is logged at the end.

    Relies on notebook-level names defined elsewhere: `all_filenames`,
    `object_types`, `object_type_to_class` and `combined_classes`.
    """
    # Local timer: the elapsed time covers exactly this run and does not depend
    # on a `start_time` global being defined somewhere else.
    start_time = time.time()

    # **Counting Total Objects by Combined Class**
    logging.info("Counting total objects by combined class from FITS files...")
    total_object_counts = count_total_objects_by_combined_class(all_filenames, object_types, object_type_to_class)

    # **Counting Passed-Cut Objects by Combined Class from FITS Files**
    logging.info("Counting passed-cut objects by combined class directly from FITS files...")
    selection_criteria = {
        'min_points': 3,
        'SNR_threshold': 5,
        'FLUXCAL_threshold': 100
    }
    passed_cut_counts_fits = count_passing_objects_from_fits(all_filenames, object_types, object_type_to_class, selection_criteria)

    # **Counting Passed-Cut Objects by Combined Class from CSV File**
    logging.info("Counting passed-cut objects by combined class from CSV file...")
    processed_csv_path = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/hdf5_files/ELAsTiCC2_Test.csv'  # Update if necessary
    passed_cut_counts_csv = count_passed_cut_objects_combined(processed_csv_path, object_type_to_class, combined_classes)

    # **Save Counts to CSV Files**
    logging.info("Saving counts to CSV files...")
    save_counts_to_csv(total_object_counts, passed_cut_counts_fits, passed_cut_counts_csv, output_dir='counts')

    # **Plot the Overlaid Histogram Comparing FITS and CSV Counts**
    logging.info("Plotting the comparison histogram...")
    plot_combined_class_histogram_comparison(total_object_counts, passed_cut_counts_fits, passed_cut_counts_csv, output_path='object_type_histogram_comparison.png')

    # Stop the timer and log the elapsed time.
    elapsed_time = time.time() - start_time
    logging.info(f"Script completed in {elapsed_time:.2f} seconds")


if __name__ == "__main__":
    main()


### Population statistics

In [None]:
# Map each individual object type onto a broader combined class.
combined_classes = {
    'TDE': ['TDE'],
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'CLAGN': ['CLAGN'],
    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
}

# Start from the raw type, then overwrite every type that belongs to a
# combined class; types not listed above keep their own name.
df0['Combined_Class'] = df0['Object_Type']
for class_name, member_types in combined_classes.items():
    is_member = df0['Object_Type'].isin(member_types)
    df0.loc[is_member, 'Combined_Class'] = class_name

# Number of objects per combined class (value_counts sorts largest first).
combined_class_counts = df0['Combined_Class'].value_counts()

# Wedge colours, applied in the order of combined_class_counts.
custom_colors = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99', '#c2c2f0', '#ffb3e6']

# Pull only the TDE wedge out of the pie; all other wedges stay in place.
explode_values = [0.1 if name == 'TDE' else 0 for name in combined_class_counts.index]

# Pie-chart label formatter: percentage on the first line, count on the second.
def autopct_with_count(pct, allvalues):
    """Return '<pct>%\\n(<count>)' for a pie wedge.

    `pct` is the wedge share in percent, as matplotlib passes it to `autopct`.
    The absolute count is recovered by scaling the sum of `allvalues` and
    rounding to the nearest integer.
    """
    total = np.sum(allvalues)
    count = int(np.round(pct / 100. * total))
    return f"{pct:.1f}%\n({count:d})"

# Pie chart of the combined-class populations, labelled with % and counts.
plt.figure(figsize=(10, 6))
combined_class_counts.plot.pie(
    autopct=lambda pct: autopct_with_count(pct, combined_class_counts.values),
    colors=custom_colors,
    startangle=140,
    shadow=True,
    explode=explode_values,
)
plt.title('Distribution of Combined Classes', fontsize=16, fontweight='bold')
plt.ylabel('')
plt.show()

# Binary split: TDE vs. every other combined class.
df0['TDE_vs_NonTDE'] = df0['Combined_Class'].eq('TDE').map({True: 'TDE', False: 'Non-TDE'})
tde_vs_nontde_counts = df0['TDE_vs_NonTDE'].value_counts()

# Two-wedge pie chart with the TDE wedge pulled out.
plt.figure(figsize=(8, 6))
tde_vs_nontde_counts.plot.pie(
    autopct=lambda pct: autopct_with_count(pct, tde_vs_nontde_counts.values),
    colors=['#66b3ff', '#ff9999'],
    startangle=140,
    shadow=True,
    explode=[0.1 if label == 'TDE' else 0 for label in tde_vs_nontde_counts.index],
)
plt.title('Distribution of TDE vs Non-TDEs', fontsize=16, fontweight='bold')
plt.ylabel('')
plt.show()


In [None]:
# True peak vs GP peak: compare TruePeakMJD with the GP-derived peak_time_MJD.
file_path = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/test19-Copy1.1_snr5_intermediate.csv'
data = pd.read_csv(file_path)

# Difference in days; NaN wherever either peak time is missing.
data['MJD_Difference'] = data['TruePeakMJD'] - data['peak_time_MJD']

# Keep only objects that have both peak times.
data = data.dropna(subset=['TruePeakMJD', 'peak_time_MJD'])

# Plot 1: Histogram of the differences
plt.figure(figsize=(10, 6))
plt.hist(data['MJD_Difference'], bins=50, color='skyblue', edgecolor='black')
plt.title('Histogram of MJD Differences (TruePeakMJD - peak_time_MJD)')
plt.xlabel('MJD Difference')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

# Plot 2: KDE plot (Density estimation)
plt.figure(figsize=(10, 6))
sns.kdeplot(data['MJD_Difference'], fill=True, color='lightcoral')
plt.title('KDE of MJD Differences (TruePeakMJD - peak_time_MJD)')
plt.xlabel('MJD Difference')
plt.grid(True)
plt.show()

# Plot 3: Scatter plot of TruePeakMJD vs. peak_time_MJD
plt.figure(figsize=(10, 6))
plt.scatter(data['TruePeakMJD'], data['peak_time_MJD'], alpha=0.5, color='purple')
plt.title('Scatter Plot: TruePeakMJD vs peak_time_MJD')
plt.xlabel('TruePeakMJD')
plt.ylabel('peak_time_MJD')
plt.grid(True)
plt.show()

# Plot 4: Jointplot (scatter plot + marginal histograms)
sns.jointplot(x='TruePeakMJD', y='peak_time_MJD', data=data, kind="scatter", color="blue", marginal_kws=dict(bins=50, fill=True))
plt.show()

# Plot 5: Box plot of MJD differences. The values are passed as `x` so the box
# is horizontal and the differences actually lie on the x-axis labelled below
# (a bare Series is drawn vertically by recent seaborn versions).
plt.figure(figsize=(10, 6))
sns.boxplot(x=data['MJD_Difference'], color='green')
plt.title('Box Plot of MJD Differences')
plt.xlabel('MJD Difference')
plt.grid(True)
plt.show()

# Display some basic statistics
print(data['MJD_Difference'].describe())


In [None]:
# Processed feature table from the test20.3 run.
file_path = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/test20.3_snr5_intermediate.csv'

# Load it as an Astropy Table and print a per-column summary.
table1 = Table.read(file_path, format='csv')
table1.info()


In [None]:

# Truth-object catalogue (nonIa-TRUTH_OBJECTS.csv).
file_path = '/home/bhardwaj/nonIa-TRUTH_OBJECTS.csv'

# Load it as an Astropy Table and print a per-column summary.
truth_table = Table.read(file_path, format='csv')
truth_table.info()

In [None]:
from astropy.table import Table, join

# Both tables must carry the object identifier used as the join key.
if 'SNID' not in truth_table.colnames or 'SNID' not in table1.colnames:
    raise ValueError("'SNID' column not found in one or both tables.")

# Inner join on SNID: keep only objects present in both the truth catalogue
# and the processed feature table. Done once; later cells reuse matched_table.
matched_table = join(truth_table, table1, keys='SNID', join_type='inner')

# Display the results of the matching
print("Matched Rows:")
print(matched_table)



In [None]:

# Both peak-magnitude columns must be present in the joined table.
if 'PeakMag' not in matched_table.colnames or 'PEAKMAG_g' not in matched_table.colnames:
    raise ValueError("'PeakMag' or 'PEAKMAG_g' column not found in the matched table.")

# Calculate the difference in PeakMag
matched_table['MagDifference'] = matched_table['PeakMag'] - matched_table['PEAKMAG_g']

# Explicitly drop masked and non-finite differences (e.g. an SNID missing a
# peak magnitude in either table), so the histogram only sees valid numbers.
mag_difference = np.ma.masked_invalid(
    np.ma.asarray(matched_table['MagDifference'], dtype=float)
).compressed()
n_dropped = len(matched_table) - mag_difference.size
if n_dropped:
    print(f"Skipping {n_dropped} matched SNID(s) without a finite PeakMag difference.")

# Plot the difference
plt.figure(figsize=(10, 6))
plt.hist(mag_difference, bins=30, edgecolor='k', alpha=0.7)
plt.xlabel('Difference in PeakMag (PeakMag - PEAKMAG_g)')
plt.ylabel('Number of SNID')
plt.title('Difference in PeakMag for Matched SNID')
plt.grid(True)
plt.show()

In [None]:
# Print every column name of df0, one per line.
variable_names = df0.columns.tolist()
if variable_names:
    print("\n".join(str(name) for name in variable_names))


In [None]:
import matplotlib.pyplot as plt

# Object_Type values counted as TDE; every other type is 'non-TDE'.
tde_types = ['TDE']  # Replace 'TDE' with the actual values that represent TDE in your data

# Two-way category label used to split the statistics below.
df0['Category'] = np.where(df0['Object_Type'].isin(tde_types), 'TDE', 'non-TDE')

# Dictionary mapping old column names to new column names
rename_mapping = {
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time',
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)'
}

# Rename the columns in the DataFrame
df0.rename(columns=rename_mapping, inplace=True)

# Features whose extraction success rate is reported.
selected_variables = [
    'LengthScale_Wavelength',
    'LengthScale_Time',
    'Amplitude',
    'Rise time',
    'Fade time',
    'Mean Color Pre Peak (g-r)',
    'Mean Color Pre Peak (r-i)',
    'Mean Color Post Peak (r-i)',
    'Mean Color Post Peak (g-r)',
    'Slope Pre Peak (g-r)',
    'Slope Pre Peak (r-i)',
    'Slope Post Peak (r-i)',
    'Slope Post Peak (g-r)'
]

# Ignore features that are not present in df0.
selected_variables = [var for var in selected_variables if var in df0.columns]

non_missing_percentage_by_category = {}

for category in ['TDE', 'non-TDE']:
    df_subset = df0[df0['Category'] == category]
    n_objects = len(df_subset)

    # Percentage of objects with a non-missing value, per feature.
    non_missing_percentage = df_subset[selected_variables].notnull().mean() * 100

    # Percentage of objects with every selected feature present. An empty
    # category would otherwise raise ZeroDivisionError.
    n_complete = df_subset[selected_variables].dropna().shape[0]
    percentage_all_present = n_complete / n_objects * 100 if n_objects else 0.0

    non_missing_percentage['All features extracted'] = percentage_all_present
    non_missing_percentage_by_category[category] = non_missing_percentage

# One horizontal bar chart per category, sharing the percentage axis.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(10, 6), sharex=True, constrained_layout=True)

for ax, (category, non_missing_percentage) in zip(axes, non_missing_percentage_by_category.items()):
    non_missing_percentage_sorted = non_missing_percentage.sort_values()

    non_missing_percentage_sorted.plot(kind='barh', color='skyblue', ax=ax)

    ax.set_title(f'Proportion of Extracted Features: {category}', fontsize=14, fontweight='bold')
    ax.set_xlabel('Percentage', fontsize=12)
    ax.tick_params(axis='both', which='major', labelsize=10)
    ax.grid(axis='x', alpha=0.3)
    # Leave room to the right of 100% bars for the value annotations.
    ax.set_xlim(0, 115)

    # Annotate each bar with its percentage value.
    for i, value in enumerate(non_missing_percentage_sorted):
        ax.text(value + 1, i, f'{value:.1f}%', va='center', fontsize=9)

# NOTE(review): this path repeats "home/bhardwaj/notebooksLSST/tdes-fzu" and
# looks like two absolute paths concatenated; confirm the intended location.
save_path = '/home/bhardwaj/notebooksLSST/tdes-fzu/home/bhardwaj/notebooksLSST/tdes-fzu/notebooksLSST/ELAsTiCC2_processed/results-images/successrate'
fig.savefig(save_path, dpi=100)
print(f"Figure saved to {save_path}")

# Display the plot
plt.show()


## Plots 

In [None]:
import pandas as pd
import plotly.graph_objects as go

# Combined classes (keys) and the object types they group.
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    # The AGN object type in this data set is 'CLAGN' (as in the population
    # pie chart); an ['AGN'] list matches no objects and leaves the trace empty.
    'AGN': ['CLAGN'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

colors = {
    'SNI': 'red',
    'SNII': 'green',
    'AGN': 'yellow',
    'SLSN': 'purple',
    'TDE': 'blue'
}

# Per-class GP length scales and amplitude, in df0 row order. Despite their
# names, `rise_times` holds LengthScale_Wavelength and `decay_times` holds
# LengthScale_Time (names kept for the rest of the notebook). Vectorized
# selection replaces a per-row iterrows scan.
rise_times, decay_times, amplitudes = {}, {}, {}
for combined_class, original_classes in combined_classes.items():
    members = df0.loc[df0['Object_Type'].isin(original_classes)]
    rise_times[combined_class] = members['LengthScale_Wavelength'].tolist()
    decay_times[combined_class] = members['LengthScale_Time'].tolist()
    amplitudes[combined_class] = members['Amplitude'].tolist()

# Create 3D scatter plot using Plotly
fig = go.Figure()

# One trace per combined class.
for combined_class, color in colors.items():
    fig.add_trace(go.Scatter3d(
        x=rise_times[combined_class],
        y=decay_times[combined_class],
        z=amplitudes[combined_class],
        mode='markers',
        marker=dict(
            size=1,
            color=color,  # Set color for each class
            opacity=0.5
        ),
        name=combined_class
    ))

# Set the layout for the 3D plot
fig.update_layout(
    scene=dict(
        xaxis_title="LengthScale_Wavelength",
        yaxis_title="LengthScale_Time",
        zaxis_title="Amplitude"
    ),
    title="3D Scatter Plot of LengthScale_Wavelength, LengthScale_Time, and Amplitude",
    margin=dict(l=0, r=0, b=0, t=50)
)

# Display the interactive 3D plot
fig.show()


In [None]:
# Combined classes and their respective colors
combined_classes = {
    'SNI': ['SNIa-SALT3','SNIa-91bg','SNIax', 'SNIcBL+HostXT_V19','SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT','SNII-NMF','SNII+HostXT_V19','SNIIb+HostXT_V19'],
    # The AGN object type in this data set is 'CLAGN'; ['AGN'] matched nothing.
    'AGN': ['CLAGN'],
  #  'KN': ['KN_K17','KN_B19'],
  #  'SLSN': ['SLSN-I+host','SLSN-I_no_host'],
    'TDE': ['TDE']
}

colors = {
    'SNI': 'red',
    'SNII': 'green',
    'AGN': 'orange',
   # 'KN': 'purple',
   # 'SLSN': 'purple',
    'TDE': 'blue'
}

# Per-class GP length scales, in df0 row order. `rise_times` holds
# LengthScale_Wavelength and `decay_times` holds LengthScale_Time (names kept
# for the rest of the notebook). Vectorized selection replaces iterrows.
rise_times, decay_times = {}, {}
for combined_class, original_classes in combined_classes.items():
    members = df0.loc[df0['Object_Type'].isin(original_classes)]
    rise_times[combined_class] = members['LengthScale_Wavelength'].tolist()
    decay_times[combined_class] = members['LengthScale_Time'].tolist()

plt.figure(figsize=(10, 5))

# Plot each combined class
for combined_class, color in colors.items():
    plt.scatter(rise_times[combined_class], decay_times[combined_class], color=color, label=combined_class, marker='.', alpha=0.4)

# Add labels, limits, legend
plt.xlabel("LengthScale_Wavelength")
plt.ylabel("LengthScale_Time")
plt.xlim(0, 50)
plt.ylim(0, 20)
#plt.title("Measured by peakmag+1 either side of peak, using GP on g band")
plt.legend()
plt.savefig('/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/LWvsLT.png', dpi=300)

# Display the plot
plt.show()

In [None]:
# Show df0 as this cell's output (bare last expression, rendered by Jupyter).
df0

### Rate of color change vs. mean color

In [None]:
# Combined classes and their respective colors
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'CLAGN': ['CLAGN'],
    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

colors = {
    'SNI': 'red',
    'SNII': 'green',
    'CLAGN': 'yellow',
    'KN': 'purple',
    'SLSN': 'black',
    'TDE': 'blue'
}

# Combined_Class: the raw Object_Type, overwritten with the combined-class
# name for every type listed above.
df0['Combined_Class'] = df0['Object_Type']
for combined_class, original_classes in combined_classes.items():
    df0.loc[df0['Object_Type'].isin(original_classes), 'Combined_Class'] = combined_class

fig, ((ax8, ax9), (ax10, ax11)) = plt.subplots(2, 2, figsize=(17, 6))

# (axis, mean-colour column, slope column, slope-error column, x label).
# The *_GR columns are explicitly g-r, so their panels are labelled g-r; the
# unsuffixed columns were previously labelled g-r too, leaving the _GR panels
# marked "r-i".
# NOTE(review): that the unsuffixed columns hold r-i colours is inferred from
# the *_GR naming; confirm against the feature-extraction code.
panels = [
    (ax8, 'Mean_Color_Pre_Peak', 'Slope_Pre_Peak', 'Slope_Err_Pre_Peak', "Mean Pre-peak r-i Color"),
    (ax9, 'Mean_Color_Post_Peak', 'Slope_Post_Peak', 'Slope_Err_Post_Peak', "Mean Post-peak r-i Color"),
    (ax10, 'Mean_Color_Pre_Peak_GR', 'Slope_Pre_Peak_GR', 'Slope_Err_Pre_Peak_GR', "Mean Pre-peak g-r Color"),
    (ax11, 'Mean_Color_Post_Peak_GR', 'Slope_Post_Peak_GR', 'Slope_Err_Post_Peak_GR', "Mean Post-peak g-r Color"),
]

for combined_class, color in colors.items():
    mask = df0['Combined_Class'] == combined_class

    for ax, mean_col, slope_col, err_col, _ in panels:
        # Rows of this class where the colour, slope and slope error are all present.
        valid_mask = mask & df0[[mean_col, slope_col, err_col]].notna().all(axis=1)

        if valid_mask.any():
            ax.errorbar(
                df0.loc[valid_mask, mean_col], df0.loc[valid_mask, slope_col],
                yerr=df0.loc[valid_mask, err_col],
                fmt='.', label=combined_class, alpha=0.4, color=color,
            )
        else:
            # The axes have no titles, so identify the panel by its columns.
            print(f"No valid data points for {combined_class} in {mean_col} vs {slope_col}")

# Labels, legends, grid and common axis limits.
for ax, _, _, _, xlabel in panels:
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Rate of Color Change (1/day)")
    ax.legend()
    ax.grid(True)
    ax.set_xlim(-3, 2.5)
    ax.set_ylim(-0.07, 0.07)

plt.tight_layout()
plt.show()


### Color-color diagram

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Classes shown in the colour-colour diagrams. 'AGN' groups the CLAGN type;
# kilonovae are left out of these plots.
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['CLAGN'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE'],
}

colors = {
    'SNI': 'red',
    'SNII': 'green',
    'AGN': 'yellow',
    'SLSN': 'purple',
    'TDE': 'blue',
}

# Combined_Class: the raw Object_Type, overwritten with the combined-class
# name for every type listed above.
df0['Combined_Class'] = df0['Object_Type']
for cls_name, member_types in combined_classes.items():
    df0.loc[df0['Object_Type'].isin(member_types), 'Combined_Class'] = cls_name

# Draw one class onto a JointGrid: error-bar scatter plus marginal histograms.
def plot_with_marginals(joint, x_data, y_data, x_error, y_error, label, color):
    """Add one class to a seaborn JointGrid.

    The joint axes get an error-bar scatter of (x_data, y_data) with the given
    x/y errors, labelled `label`. Each marginal axis gets a 100-bin histogram
    with a KDE overlay of the matching coordinate. Everything is drawn in `color`.
    """
    joint.ax_joint.errorbar(
        x_data, y_data,
        xerr=x_error, yerr=y_error,
        fmt='.', alpha=0.3, color=color, label=label,
    )

    shared_hist_kws = dict(bins=100, kde=True, color=color, alpha=0.5)
    sns.histplot(x=x_data, ax=joint.ax_marg_x, **shared_hist_kws)
    sns.histplot(y=y_data, ax=joint.ax_marg_y, orientation='horizontal', **shared_hist_kws)

# Fixed seed so the per-class subsamples are reproducible.
RANDOM_SEED = 42

# --- Pre-peak colour-colour diagram: 30% subsample of each class ---
pre_joint = sns.JointGrid(height=6, ratio=4)

for combined_class, color in colors.items():
    df_class = df0.loc[df0['Combined_Class'] == combined_class]
    if df_class.empty:
        continue  # class absent from df0

    df_sampled = df_class.sample(frac=0.3, random_state=RANDOM_SEED)
    plot_with_marginals(
        pre_joint,
        df_sampled['Mean Color Pre Peak (g-r)'],
        df_sampled['Mean Color Pre Peak (r-i)'],
        df_sampled['Pre_Peak_Color_err_gr'],
        df_sampled['Pre_Peak_Color_err_ri'],
        combined_class,
        color,
    )

pre_joint.set_axis_labels("Mean Pre-peak g-r Color", "Mean Pre-peak r-i Color")
pre_joint.ax_joint.legend(title='Combined Class')
pre_joint.ax_joint.set_xlim(-1.5, 3)
pre_joint.ax_joint.set_ylim(-1.5, 1.5)
pre_joint.fig.suptitle('Pre-peak Color-Color Diagram', y=1.02)

# --- Post-peak colour-colour diagram: 10% subsample of each class ---
post_joint = sns.JointGrid(height=6, ratio=4)

for combined_class, color in colors.items():
    df_class = df0.loc[df0['Combined_Class'] == combined_class]
    if df_class.empty:
        continue  # class absent from df0

    df_sampled = df_class.sample(frac=0.1, random_state=RANDOM_SEED)
    plot_with_marginals(
        post_joint,
        df_sampled['Mean Color Post Peak (g-r)'],
        df_sampled['Mean Color Post Peak (r-i)'],
        df_sampled['Post_Peak_Color_err_gr'],
        df_sampled['Post_Peak_Color_err_ri'],
        combined_class,
        color,
    )

post_joint.set_axis_labels("Mean post-peak g-r", "Mean post-peak r-i")
post_joint.ax_joint.legend(loc='lower right')
post_joint.ax_joint.set_xlim(-0.8, 2)
post_joint.ax_joint.set_ylim(-0.9, 0.9)

# plt.tight_layout acts on the current figure, i.e. the post-peak grid created last.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])

# Only the post-peak diagram is written to disk.
save_path_post = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/results-images/colordiag_post_peak.png'
post_joint.fig.savefig(save_path_post, dpi=150, bbox_inches='tight')
print(f"Post-peak Color-Color Diagram saved to {save_path_post}")

plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Classes shown in the contour colour-colour diagrams ('AGN' groups CLAGN).
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['CLAGN'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE'],
}

colors = {
    'SNI': 'red',
    'SNII': 'green',
    'AGN': 'orange',
    'SLSN': 'purple',
    'TDE': 'blue',
}

# Combined_Class: the raw Object_Type, overwritten with the combined-class
# name for every type listed above.
df0['Combined_Class'] = df0['Object_Type']
for cls_name, member_types in combined_classes.items():
    df0.loc[df0['Object_Type'].isin(member_types), 'Combined_Class'] = cls_name

# Fixed seed so the per-class subsamples are reproducible.
RANDOM_SEED = 42

# Overlay KDE density contours for one class on a JointGrid.
def plot_contours(joint, x_data, y_data, label, color):
    """Draw KDE contours of (x_data, y_data) on `joint.ax_joint`.

    Uses 5 contour levels, drawn as 1-point lines in `color` at 60% opacity,
    and passes `label` through to seaborn's kdeplot.
    """
    contour_style = {'levels': 5, 'color': color, 'linewidths': 1, 'alpha': 0.6}
    sns.kdeplot(x=x_data, y=y_data, ax=joint.ax_joint, label=label, **contour_style)

# --- Pre-peak colour-colour diagram: KDE contours from a 30% subsample per class ---
pre_joint = sns.JointGrid(height=6, ratio=4)

for combined_class, color in colors.items():
    df_class = df0.loc[df0['Combined_Class'] == combined_class]
    if df_class.empty:
        continue  # class absent from df0

    df_sampled = df_class.sample(frac=0.3, random_state=RANDOM_SEED)
    plot_contours(
        pre_joint,
        df_sampled['Mean Color Pre Peak (g-r)'],
        df_sampled['Mean Color Pre Peak (r-i)'],
        combined_class,
        color,
    )

pre_joint.set_axis_labels("Mean Pre-peak g-r Color", "Mean Pre-peak r-i Color")
pre_joint.ax_joint.legend(title='Combined Class')
pre_joint.ax_joint.set_xlim(-1.5, 3)
pre_joint.ax_joint.set_ylim(-1.5, 1.5)
pre_joint.fig.suptitle('Pre-peak Color-Color Diagram', y=1.02)

# --- Post-peak colour-colour diagram: KDE contours from a 50% subsample per class ---
post_joint = sns.JointGrid(height=6, ratio=4)

for combined_class, color in colors.items():
    df_class = df0.loc[df0['Combined_Class'] == combined_class]
    if df_class.empty:
        continue  # class absent from df0

    df_sampled = df_class.sample(frac=0.5, random_state=RANDOM_SEED)
    plot_contours(
        post_joint,
        df_sampled['Mean Color Post Peak (g-r)'],
        df_sampled['Mean Color Post Peak (r-i)'],
        combined_class,
        color,
    )

post_joint.set_axis_labels("Mean post-peak g-r", "Mean post-peak r-i")
post_joint.ax_joint.legend(loc='lower right')
post_joint.ax_joint.set_xlim(-0.8, 2)
post_joint.ax_joint.set_ylim(-0.9, 0.9)

# plt.tight_layout acts on the current figure, i.e. the post-peak grid created last.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])

# NOTE(review): this is the same file name as the error-bar version of the
# post-peak diagram, so running both cells overwrites one figure with the other.
save_path_post = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/results-images/colordiag_post_peak.png'
post_joint.fig.savefig(save_path_post, dpi=150, bbox_inches='tight')
print(f"Post-peak Color-Color Diagram saved to {save_path_post}")

plt.show()


### Rise time vs. decay time

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from scipy.stats import gaussian_kde

# Assuming df0 is already loaded with 'Object_Type', 'Rise time' and 'Fade time' columns.
# Combined classes and the ELAsTiCC2 Object_Type labels each one groups together
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Plot colour for each combined class
colors = {
    'SNI': 'red',
    'SNII': 'green',
    'AGN': 'yellow',
    'SLSN': 'purple',
    'TDE': 'blue'
}

# Reverse lookup Object_Type -> combined class. setdefault keeps the FIRST class that
# lists a type, i.e. the same first-match result as scanning the classes in order.
type_to_class = {}
for combined_class, original_classes in combined_classes.items():
    for obj_type in original_classes:
        type_to_class.setdefault(obj_type, combined_class)

# Rise and decay (fade) times per combined class, in df0 row order.
# Vectorised: one map over the column instead of an iterrows() scan per row and class.
row_class = df0['Object_Type'].map(type_to_class)
rise_times = {cls: df0.loc[row_class == cls, 'Rise time'].tolist() for cls in combined_classes}
decay_times = {cls: df0.loc[row_class == cls, 'Fade time'].tolist() for cls in combined_classes}

# gaussian_kde needs more samples than dimensions and a non-singular covariance;
# classes with fewer finite points than this are drawn as a plain scatter only.
MIN_KDE_POINTS = 3
# Points whose KDE density is at or above this percentile (the densest 10%) are
# represented only by the contours; the other 90% are scattered individually.
DENSITY_PERCENTILE = 90

plt.figure(figsize=(10, 6))

for combined_class, color in colors.items():
    x = np.asarray(rise_times[combined_class], dtype=float)
    y = np.asarray(decay_times[combined_class], dtype=float)

    # gaussian_kde cannot handle NaN/inf, so drop incomplete (rise, fade) pairs
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]
    if x.size == 0:
        continue  # no usable rows for this class

    scatter_x, scatter_y = x, y
    if x.size >= MIN_KDE_POINTS:
        xy = np.vstack([x, y])
        try:
            kde = gaussian_kde(xy)
        except (np.linalg.LinAlgError, ValueError):
            kde = None  # degenerate data (e.g. all points on a line): scatter only
        if kde is not None:
            # KDE contour lines drawn by Seaborn
            sns.kdeplot(
                x=x,
                y=y,
                levels=5,
                color=color,
                fill=False,  # Only draw contour lines
                linewidths=1,
                alpha=0.6,
            )
            z = kde(xy)
            threshold = np.percentile(z, DENSITY_PERCENTILE)
            outside_core = z < threshold
            scatter_x, scatter_y = x[outside_core], y[outside_core]

    plt.scatter(
        scatter_x,
        scatter_y,
        color=color,
        label=combined_class,
        marker='.',
        alpha=0.4,
        s=1,
    )

# Add labels, axis limits, legend
plt.xlabel("Rise (days)")
plt.ylabel("Fade (days)")
plt.ylim(-5, 200)
plt.xlim(-5, 200)
plt.legend()

plt.tight_layout()
plt.show()


### Post-peak color vs. color-change rate

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib.lines import Line2D

# Assuming df0 is already loaded with the colour/slope columns and their errors, and
# that df0['Combined_Class'] has been populated by an earlier cell.
# Sample 50% of each combined class. GroupBy.sample keeps the 'Combined_Class' column
# in the result (groupby().apply(lambda g: g.sample(...)) is deprecated for this in
# pandas >= 2.2), and RANDOM_SEED makes the saved figure reproducible.
df_sampled = (
    df0.groupby('Combined_Class')
    .sample(frac=0.5, random_state=RANDOM_SEED)
    .reset_index(drop=True)
)

# Define combined classes and respective colors
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

colors = {
    'SNI': 'red',
    'SNII': 'green',
    'AGN': 'yellow',
    'SLSN': 'purple',
    'TDE': 'blue'
}

# Left: post-peak r-i colour vs its rate of change; right: the same for g-r.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

for combined_class, color in colors.items():
    subset = df_sampled.loc[df_sampled['Combined_Class'] == combined_class]

    # r-i panel: error bars must come from the r-i uncertainty columns
    ax1.errorbar(subset['Mean Color Post Peak (r-i)'],
                 subset['Slope Post Peak (r-i)'],
                 yerr=subset['Slope_Err_Post_Peak_ri'],  # Error for y
                 xerr=subset['Post_Peak_Color_err_ri'],  # Error for x
                 fmt='.', label=combined_class, alpha=0.4, color=color)

    # g-r panel with the g-r uncertainty columns
    ax2.errorbar(subset['Mean Color Post Peak (g-r)'],
                 subset['Slope Post Peak (g-r)'],
                 yerr=subset['Slope_Err_Post_Peak_gr'],  # Error for y
                 xerr=subset['Post_Peak_Color_err_gr'],  # Error for x
                 fmt='.', label=combined_class, alpha=0.3, color=color)

# Axis labels/limits only need to be set once, after all classes are drawn
ax1.set_ylabel("Rate of color change (r-i) per day")
ax1.set_xlabel("Mean Color Post Peak r-i (mag)")
ax1.set_xlim(-0.7, 1)
ax1.set_ylim(-0.02, 0.05)

ax2.set_ylabel("Rate of color change (g-r) per day")
ax2.set_xlabel("Mean Color Post Peak g-r (mag)")
ax2.set_xlim(-0.7, 1)
ax2.set_ylim(-0.02, 0.05)

# Custom legend handles at full opacity (the plotted points are semi-transparent)
legend_handles = [Line2D([0], [0], marker='.', color='w', label=class_type,
                          markerfacecolor=color, markersize=10, alpha=1) for class_type, color in colors.items()]

# Add legend to both axes
ax1.legend(handles=legend_handles, title="Combined Class", loc='upper left', frameon=True, facecolor='white', edgecolor='black')
ax2.legend(handles=legend_handles, title="Combined Class", loc='upper left', frameon=True, facecolor='white', edgecolor='black')

plt.tight_layout()
plt.savefig('/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/colorvslope.png', dpi=300)
plt.show()


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib.lines import Line2D

# Assuming df0 is already loaded with the colour columns and their errors.

# Combined classes and respective colors (AGN is deliberately left out of this plot)
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

colors = {
    'SNI': 'red',
    'SNII': 'green',
    'SLSN': 'purple',
    'TDE': 'blue'
}

# Create a combined class column in df0 (Object_Types not listed keep their own label)
df0['Combined_Class'] = df0['Object_Type']
for combined_class, original_classes in combined_classes.items():
    df0.loc[df0['Object_Type'].isin(original_classes), 'Combined_Class'] = combined_class

# Sample 50% of each combined class. GroupBy.sample keeps the grouping column and
# RANDOM_SEED makes the saved figure reproducible.
df_sampled = (
    df0.groupby('Combined_Class')
    .sample(frac=0.5, random_state=RANDOM_SEED)
    .reset_index(drop=True)
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

for combined_class, color in colors.items():
    subset = df_sampled.loc[df_sampled['Combined_Class'] == combined_class]

    # Left: post-peak vs pre-peak mean r-i colour, r-i uncertainties on both axes
    ax1.errorbar(subset['Mean Color Post Peak (r-i)'],
                 subset['Mean Color Pre Peak (r-i)'],
                 xerr=subset['Post_Peak_Color_err_ri'],  # Error for x
                 yerr=subset['Pre_Peak_Color_err_ri'],  # Error for y
                 fmt='.', label=combined_class, alpha=0.3, color=color)

    # Right: post-peak mean g-r colour vs pre-peak mean r-i colour
    ax2.errorbar(subset['Mean Color Post Peak (g-r)'],
                 subset['Mean Color Pre Peak (r-i)'],
                 xerr=subset['Post_Peak_Color_err_gr'],  # Error for x
                 yerr=subset['Pre_Peak_Color_err_ri'],  # Error for y
                 fmt='.', label=combined_class, alpha=0.3, color=color)

# Axis labels/limits only need to be set once, after all classes are drawn
ax1.set_ylabel("Mean Color Pre Peak r-i (mag)")
ax1.set_xlabel("Mean Color Post Peak r-i (mag)")
ax1.set_xlim(-1, 1)
ax1.set_ylim(-1, 1)

ax2.set_ylabel("Mean Color Pre Peak r-i (mag)")
ax2.set_xlabel("Mean Color Post Peak g-r (mag)")
ax2.set_xlim(-1, 1)
ax2.set_ylim(-1, 1)

# Custom legend handles at full opacity (the plotted points are semi-transparent)
legend_handles = [Line2D([0], [0], marker='.', color='w', label=class_type,
                          markerfacecolor=color, markersize=10, alpha=1) for class_type, color in colors.items()]

# Add legend to both axes
ax1.legend(handles=legend_handles, title="Combined Class", loc='upper left', frameon=True, facecolor='white', edgecolor='black')
ax2.legend(handles=legend_handles, title="Combined Class", loc='upper left', frameon=True, facecolor='white', edgecolor='black')

plt.tight_layout()
plt.savefig('/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/colorvscolor.png', dpi=300)
plt.show()


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib.lines import Line2D

# Assuming df0 is already loaded with the colour columns and their errors.

# Combined classes: each key groups several ELAsTiCC2 Object_Type labels
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Plot colour per combined class
colors = {
    'SNI': 'red',
    'SNII': 'green',
    'AGN': 'yellow',
    'SLSN': 'purple',
    'TDE': 'blue'
}

# Label every row of df0 with its combined class. Object_Types that no group lists
# keep their own label; Series.replace performs exact, single-pass value substitution
# (later groups would win if a type were listed twice, as with sequential overwrites).
_object_type_to_combined = {
    member: group
    for group, members in combined_classes.items()
    for member in members
}
df0['Combined_Class'] = df0['Object_Type'].replace(_object_type_to_combined)

# ----------------------------
# Step 1: Remove Extreme Outliers
# ----------------------------

def remove_outliers_iqr(df, column, lower_q=0.1, upper_q=0.9, k=1.5):
    """Drop rows of *df* whose *column* value lies outside a quantile fence.

    Despite the name, the fence is built from the ``lower_q``/``upper_q``
    quantiles (10th/90th percentile by default), not the textbook 25th/75th:
    a row is kept when

        Q_low - k * (Q_high - Q_low) <= value <= Q_high + k * (Q_high - Q_low)

    Rows whose value is NaN fail both comparisons and are therefore dropped too.
    The number of removed rows is printed.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    column : str
        Name of the column to test.
    lower_q, upper_q : float, optional
        Quantiles defining the central range (defaults 0.1 and 0.9).
    k : float, optional
        Fence multiplier applied to the quantile spread (default 1.5).

    Returns
    -------
    pandas.DataFrame
        The rows of *df* inside the fence, with the original index.

    Raises
    ------
    ValueError
        If the quantiles are not ordered within [0, 1].
    """
    if not 0.0 <= lower_q <= upper_q <= 1.0:
        raise ValueError(f"need 0 <= lower_q <= upper_q <= 1, got {lower_q}, {upper_q}")

    q_low = df[column].quantile(lower_q)
    q_high = df[column].quantile(upper_q)
    spread = q_high - q_low
    lower_bound = q_low - k * spread
    upper_bound = q_high + k * spread

    before_count = df.shape[0]
    df_filtered = df[(df[column] >= lower_bound) & (df[column] <= upper_bound)]
    after_count = df_filtered.shape[0]
    print(f"Removed {before_count - after_count} outliers from '{column}'.")
    return df_filtered

# Columns to check for outliers
# NOTE(review): each call below rebinds the notebook-global df0 to a filtered frame,
# so every later cell that reads df0 sees the reduced data set until df0 is reloaded
# — confirm this is intended.
# NOTE(review): remove_outliers_iqr is redefined later in the notebook with a
# (df, columns) list signature; if that cell has run, passing a single column name
# here iterates over its characters. Re-run this cell's single-column definition first.
outlier_columns = ['Mean Color Pre Peak (r-i)', 'Mean Color Post Peak (r-i)', 'LengthScale_Wavelength']

for col in outlier_columns:
    if col in df0.columns:
        df0 = remove_outliers_iqr(df0, col)
    else:
        print(f"Column '{col}' not found in DataFrame.")

# ----------------------------
# Step 2: Remove Points with Mean Color Pre/Post Peak == 0.0
# ----------------------------

# Keep rows whose pre- AND post-peak mean r-i colours are both non-zero.
# An exact 0.0 presumably marks a colour that could not be computed — TODO confirm.
# NaN compares != 0.0 as True, so this step by itself does not drop NaN rows.
conditions = (df0['Mean Color Pre Peak (r-i)'] != 0.0) & (df0['Mean Color Post Peak (r-i)'] != 0.0)

# Number of points before removal
before_zero_removal = df0.shape[0]

# Apply conditions (rebinds the global df0 again)
df0 = df0[conditions]

# Number of points after removal
after_zero_removal = df0.shape[0]

print(f"Removed {before_zero_removal - after_zero_removal} points with 'Mean Color Pre/Post Peak (r-i)' == 0.0.")

# ----------------------------
# Step 3: Remove Sampling Step (Use Entire Dataset)
# ----------------------------

# Commented out the sampling step
# df_sampled = df0.groupby('Combined_Class').apply(lambda x: x.sample(frac=0.5)).reset_index(drop=True)

# Use the entire cleaned dataset (independent copy of the filtered df0)
df_cleaned = df0.copy()
print(f"Total number of rows after cleaning: {df_cleaned.shape[0]}")

# ----------------------------
# Step 4: Plotting
# ----------------------------

# Left: pre-peak mean r-i colour vs LengthScale_Wavelength; right: post-peak
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

for combined_class, color in colors.items():
    mask = df_cleaned['Combined_Class'] == combined_class
    
    # Skip classes with no rows left after cleaning
    if not mask.any():
        continue
    
    # Plot Pre-Peak Mean Color vs LengthScale_Wavelength with error bars
    ax1.errorbar(
        df_cleaned.loc[mask, 'LengthScale_Wavelength'], 
        df_cleaned.loc[mask, 'Mean Color Pre Peak (r-i)'], 
        yerr=df_cleaned.loc[mask, 'Pre_Peak_Color_err_ri'],  # Error for y
        fmt='.', label=combined_class, alpha=0.6, color=color, markersize=4, linestyle='none'
    )
    
    # Plot Post-Peak Mean Color vs LengthScale_Wavelength with error bars
    ax2.errorbar(
        df_cleaned.loc[mask, 'LengthScale_Wavelength'], 
        df_cleaned.loc[mask, 'Mean Color Post Peak (r-i)'], 
        yerr=df_cleaned.loc[mask, 'Post_Peak_Color_err_ri'],  # Error for y
        fmt='.', label=combined_class, alpha=0.6, color=color, markersize=4, linestyle='none'
    )

# Set labels and limits for ax1
ax1.set_xlabel("LengthScale_Wavelength")
ax1.set_ylabel("Mean Color Pre Peak r-i (mag)")
ax1.set_xlim(10, 30)
ax1.set_ylim(-2, 2)
ax1.set_title("Pre-Peak Mean Color vs LengthScale_Wavelength")

# Set labels and limits for ax2
ax2.set_xlabel("LengthScale_Wavelength")
ax2.set_ylabel("Mean Color Post Peak r-i (mag)")
ax2.set_xlim(10, 30)
ax2.set_ylim(-2, 2)
ax2.set_title("Post-Peak Mean Color vs LengthScale_Wavelength")

# Custom legend handles: circle markers at the same alpha (0.6) as the plotted points
legend_handles = [
    Line2D([0], [0], marker='o', color='w', label=class_type,
           markerfacecolor=color, markersize=8, alpha=0.6)
    for class_type, color in colors.items()
]

# Add legend to both axes
ax1.legend(handles=legend_handles, title="Combined Class", loc='upper right', frameon=True, facecolor='white', edgecolor='black')
ax2.legend(handles=legend_handles, title="Combined Class", loc='upper right', frameon=True, facecolor='white', edgecolor='black')

plt.tight_layout()
plt.savefig('/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/LWvscolor_cleaned.png', dpi=300)
plt.show()


In [None]:
# Combined classes and their respective colors
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'CLAGN': ['CLAGN'],
    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

colors = {
    'SNI': 'red',
    'SNII': 'green',
    'CLAGN': 'yellow',
    'KN': 'purple',
    'SLSN': 'black',
    'TDE': 'blue'
}

# Create a combined class column in df0 (Object_Types not listed keep their own label)
df0['Combined_Class'] = df0['Object_Type']
for combined_class, original_classes in combined_classes.items():
    df0.loc[df0['Object_Type'].isin(original_classes), 'Combined_Class'] = combined_class

# Redshift vs. mean pre-peak (left) and post-peak (right) g-r colour, one scatter per
# combined class, drawn on this cell's own figure. Masks are computed only after the
# Combined_Class column is complete.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

for combined_class, color in colors.items():
    mask = df0['Combined_Class'] == combined_class

    # NOTE(review): 'Mean_Color_Pre_Peak_GR' / 'Mean_Color_Post_Peak_GR' differ from the
    # 'Mean_Color_Pre_Peak_gr' / 'Mean Color Pre Peak (g-r)' names used in other cells
    # — confirm these columns exist in df0.
    ax1.scatter(df0.loc[mask, 'REDSHIFT_FINAL'], df0.loc[mask, 'Mean_Color_Pre_Peak_GR'],
                label=combined_class, alpha=0.5, color=color, marker='.')
    ax2.scatter(df0.loc[mask, 'REDSHIFT_FINAL'], df0.loc[mask, 'Mean_Color_Post_Peak_GR'],
                label=combined_class, alpha=0.5, color=color, marker='.')

# Axis decorations are applied once, after every class has been drawn
ax1.set_xlabel("Redshift")
ax1.set_ylabel("Pre-Peak Mean Color")
ax1.set_ylim(-1, 1)
ax1.legend()

ax2.set_xlabel("Redshift")
ax2.set_ylabel("Post-Peak Mean Color")
ax2.set_ylim(-1, 1)
ax2.legend()

plt.tight_layout()
plt.show()

### Redshift vs peak-mag/flux

In [None]:
# Plot colour for each (fine-grained) ELAsTiCC2 Object_Type
colors = {'TDE': 'blue', 'SNIa-SALT3': 'red', 'CLAGN': 'green', 'SLSN-I+host': 'yellow', 'SNIIn-MOSFIT': 'purple', 'SNIb+HostXT_V19': 'grey', 'SNIcBL+HostXT_V19': 'pink', 'SNIax': 'black' }

# Redshift vs. peak magnitude, one scatter per Object_Type, on a fresh figure
plt.figure()

for object_type, color in colors.items():
    mask = df0['Object_Type'] == object_type
    plt.scatter(df0.loc[mask, 'REDSHIFT_FINAL'], df0.loc[mask, 'PeakMag'],
                label=object_type, alpha=0.5, color=color, marker='.')

plt.xlabel("Redshift")
# The y values come from the 'PeakMag' column, so label the axis accordingly
plt.ylabel("PeakMag")
plt.legend()

plt.tight_layout()
plt.show()


In [None]:
# Colour for each Object_Type label plotted below.
# NOTE(review): 'SNIa', 'SLSN-I' and 'SNII' differ from the Object_Type labels used in
# other cells (e.g. 'SNIa-SALT3', 'SLSN-I+host'); labels with no matching rows give
# empty scatters — confirm the intended labels.
colors = {'TDE': 'blue', 'SNIa': 'red', 'CLAGN': 'green', 'SLSN-I': 'yellow', 'SNII': 'purple'}

# Left: redshift vs. peak magnitude; right: redshift vs. mean post-peak colour.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

for object_type, color in colors.items():
    mask = df0['Object_Type'] == object_type
    scatter_kw = dict(label=object_type, alpha=0.5, color=color, marker='.')
    ax1.scatter(df0.loc[mask, 'REDSHIFT_FINAL'], df0.loc[mask, 'PeakMag'], **scatter_kw)
    ax2.scatter(df0.loc[mask, 'REDSHIFT_FINAL'], df0.loc[mask, 'Mean_Color_Post_Peak_GR'], **scatter_kw)

# Labels, limits and legends only need applying once all classes are drawn.
ax1.set_xlabel("Redshift")
ax1.set_ylabel("PeakMag")
ax1.legend()

ax2.set_xlabel("Redshift")
ax2.set_ylabel("Post-Peak Mean Color")
ax2.set_ylim(-1,1)
ax2.legend()

plt.tight_layout()
plt.show()


### variable distributions 

In [None]:
# Assuming df0 is your DataFrame with all object types
# Histogram colour per combined class
colors = {
    'SNI': 'red',
    'SNII': 'green',
    'AGN': 'orange',
 #   'KN': 'purple',
    'SLSN': 'purple',
    'TDE': 'blue'
}

# Combined class -> Object_Type labels it groups.
# NOTE(review): 'AGN' here vs 'CLAGN' in some other cells — confirm the label in df0.
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
#    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Raw feature column names -> display names used in plots and later cells
rename_map = {
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time'
    
}

# inplace=True mutates the notebook-global df0, so later cells see the new names
df0.rename(columns=rename_map, inplace=True)

# Columns to plot (commented-out entries are currently disabled)
filtered_columns = [
 #   'Amplitude',
 #   'LengthScale_Time',
 #   'LengthScale_Wavelength',
 #   'Mean Color Pre Peak (g-r)',
 #   'Mean Color Post Peak (g-r)',
 #   'Mean Color Pre Peak (r-i)',
 #   'Mean Color Post Peak (r-i)',
#     'Slope Pre Peak (g-r)',
     'Slope Post Peak (g-r)',
   #  'Slope Pre Peak (r-i)',
     'Slope Post Peak (r-i)',
     'Rise time',
     'Fade time'
    
]

# Custom x-axis limits per column (applied after plotting; the bins are not affected)
custom_limits = {
    'Amplitude': (5, 15),
    'LengthScale_Time': (4, 14),
    'LengthScale_Wavelength': (12, 24),
    'Mean Color Pre Peak (g-r)': (-0.8, 0.8),
    'Mean Color Pre Peak (r-i)': (-0.8, 0.8),
    'Mean Color Post Peak (g-r)': (-0.8, 0.8),
    'Mean Color Post Peak (r-i)': (-0.8, 0.8),
    'Slope Pre Peak (g-r)': (-0.04, 0.04),
    'Slope Pre Peak (r-i)': (-0.04, 0.04),
    'Slope Post Peak (g-r)': (-0.02, 0.04),
    'Slope Post Peak (r-i)': (-0.02, 0.04),
    'Rise time': (0, 100),
    'Fade time': (0, 200)
}


# Two subplots per row
num_rows = int(np.ceil(len(filtered_columns) / 2))

# Set up the matplotlib figure with the right number of subplots
fig, axes = plt.subplots(num_rows, 2, figsize=(20, 4 * num_rows))
axes = axes.flatten()  # Flatten the axes array for easy iteration

# One subplot per column, with one overlaid histogram per combined class
for i, col in enumerate(filtered_columns):
    ax = axes[i]

    # Binning range: 0.05th to 99.95th percentile of the column over all rows of df0
    lower_quantile = df0[col].quantile(0.0005)
    upper_quantile = df0[col].quantile(0.9995)
    
    # Shared bin edges for every class over that range
    bins = np.linspace(lower_quantile, upper_quantile, 1000)  # 1000 bin edges, i.e. 999 bins

    # Values are clipped into the binning range, so extreme values pile up in the
    # first/last bin instead of being dropped
    for class_type, obj_types in combined_classes.items():
        data = df0[df0['Object_Type'].isin(obj_types)][col].clip(lower_quantile, upper_quantile)
        alpha = 0.6 if class_type == 'TDE' else 0.4  # Emphasize TDE with alpha 0.6
        linewidth = 1.2 if class_type == 'TDE' else 1  # Emphasize TDE with linewidth 1.2
        
        # stat="probability" on a separate call per class: each class's bars sum to 1
        sns.histplot(
            data, bins=bins, color=colors[class_type], label=class_type, stat="probability", linewidth=linewidth,
            ax=ax, alpha=alpha, kde=False, kde_kws={'bw_adjust': 0.1}  # kde=False, so kde_kws has no effect
        )

    # Set custom axis limits if specified
    if col in custom_limits:
        ax.set_xlim(custom_limits[col])

    ax.set_xlabel(col, fontsize=12, fontweight='bold')
    ax.set_ylabel('Probability', fontsize=12, fontweight='bold')
    ax.legend(loc='upper right')  # Add legend to each subplot

# Remove the unused trailing axis when the number of columns is odd
# (uses i from the loop above, so filtered_columns must be non-empty)
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])

plt.tight_layout()
plt.savefig('/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/variable_dist1.png', dpi=300)
plt.show()


In [None]:
# NOTE(review): 'white' bars are presumably invisible on a default white background
# (only their edges show) — confirm this colour is intended.
colors = {
    'TDE': 'blue',
    'Non-TDE': 'white'
}

# Two groups: TDEs vs. every other listed Object_Type
combined_classes = {
    'TDE': ['TDE'],
    'Non-TDE': [
        'SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19', 
        'SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19', 
        'CLAGN', 'KN_K17', 'KN_B19', 'SLSN-I+host', 'SLSN-I_no_host'
    ]
}

# Columns to plot.
# NOTE(review): the rename cell maps 'Rise_Time' -> 'Rise time', and names such as
# 'Mean_Color_Pre_Peak', 'Slope_Pre_Peak' and 'Decay_Time' do not match the
# band-specific names used elsewhere — confirm these columns exist in df0.
filtered_columns = [
    'Amplitude',
    'LengthScale_Time',
    'LengthScale_Wavelength',
    'Mean_Color_Pre_Peak',
    'Mean_Color_Post_Peak',
    'Slope_Pre_Peak',
    'Slope_Post_Peak',
    'Rise_Time',
    'Decay_Time'
]

# Two subplots per row
num_rows = int(np.ceil(len(filtered_columns) / 2))

# Set up the matplotlib figure with the right number of subplots
fig, axes = plt.subplots(num_rows, 2, figsize=(20, 4 * num_rows))
axes = axes.flatten()  # Flatten the axes array for easy iteration

# One subplot per column, TDE and Non-TDE histograms overlaid

for i, col in enumerate(filtered_columns):
    ax = axes[i]

    # Binning range: 0.05th to 99.95th percentile of the column over all rows of df0
    lower_quantile = df0[col].quantile(0.0005)
    upper_quantile = df0[col].quantile(0.9995)
    
    # Shared bin edges over that range
    bins = np.linspace(lower_quantile, upper_quantile, 100)  # 100 bin edges, i.e. 99 bins

    # Values are clipped into the binning range; each group is normalised separately
    for class_type, obj_types in combined_classes.items():
        data = df0[df0['Object_Type'].isin(obj_types)][col].clip(lower_quantile, upper_quantile)
        alpha = 1.0 if class_type == 'TDE' else 0.5  # Emphasize TDE by setting alpha to 1.0
        linewidth = 2 if class_type == 'TDE' else 1  # Emphasize TDE by setting linewidth to 2
        sns.histplot(data, bins=bins, color=colors[class_type], label=class_type if i == 0 else "", stat="probability", linewidth=linewidth, ax=ax, alpha=alpha, kde=True, )

    ax.set_xlabel(col, fontsize=12, fontweight='bold')
    ax.set_ylabel('Probability', fontsize=12, fontweight='bold')
    if i == 0:  # Add legend only to the first subplot for clarity
        ax.legend()

# Remove the unused trailing axis when the number of columns is odd
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])

plt.tight_layout()
plt.show()


## Correlation, PCA, ML classifiers

In [None]:
# Columns left out of the TDE correlation/PCA analysis. Every other numeric column
# (e.g. Amplitude, LengthScale_*, mean colours, slopes, rise/fade times) is kept.
excluded_columns = [
    # identifier
    'SNID',
    # peak-time columns
    'TruePeakMJD',
    'peak_time_MJD',
    # g-r colour / slope uncertainties
    'Pre_Peak_Color_err_gr',
    'Post_Peak_Color_err_gr',
    'Slope_Err_Pre_Peak_gr',
    'Slope_Err_Post_Peak_gr',
    # r-i colour / slope uncertainties
    'Pre_Peak_Color_err_ri',
    'Post_Peak_Color_err_ri',
    'Slope_Err_Pre_Peak_ri',
    'Slope_Err_Post_Peak_ri',
    # redshift and its error
    'REDSHIFT_FINAL',
    'REDSHIFT_FINAL_ERR',
    # 'Error' column (meaning not visible in this cell)
    'Error',
]



# Function to filter numeric columns based on exclusions
def filter_columns(df, excluded_cols):
    """Return the names of *df*'s numeric columns that are not in *excluded_cols*.

    The result is a plain list in *df*'s column order.
    """
    return [
        name
        for name in df.select_dtypes(include=[np.number]).columns
        if name not in excluded_cols
    ]

# Work on a TDE-only copy of df0, so later column edits leave df0 untouched.
df_tde = df0.loc[df0['Object_Type'].eq('TDE')].copy()

# Numeric feature columns that survive the exclusion list; print them for inspection
# (e.g. to confirm "Amplitude" is among them).
filtered_columns = filter_columns(df_tde, excluded_columns)
print("Filtered Columns: ", filtered_columns)

# Plotting Correlation Heatmap for TDE with the lower triangle and diagonal
def plot_correlation_heatmap(data):
    """Draw, save and show a lower-triangle Kendall-tau correlation heatmap.

    The strict upper triangle is masked so each pair appears once (the diagonal
    is kept). The figure is written to ``corrTDE_ELAsTiCC2.png`` at 150 dpi
    before it is shown.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric columns to correlate pairwise.
    """
    plt.figure(figsize=(12, 10))

    tau = data.corr(method='kendall')
    upper = np.triu(np.ones_like(tau, dtype=bool), k=1)  # True strictly above the diagonal

    sns.heatmap(
        tau,
        mask=upper,
        cmap='coolwarm',
        annot=True,
        fmt=".2f",
        square=True,
        cbar_kws={"shrink": .75},
        linewidths=.5,
    )

    plt.title('Correlation Matrix of Selected TDE Features', fontsize=18)
    plt.xticks(rotation=45, ha='right', fontsize=11)
    plt.yticks(fontsize=11)
    plt.tight_layout()
    plt.savefig('/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/corrTDE_ELAsTiCC2.png', dpi=150)

    plt.show()


# Heatmap of the numeric, non-excluded TDE feature columns
data_numeric = df_tde.loc[:, filtered_columns]
plot_correlation_heatmap(data_numeric)

# Function to perform PCA and plot the results
def perform_pca(data, axes):
    """Run a standardised PCA on *data*; draw a scree plot and a loadings heatmap.

    Pipeline: drop columns that are entirely NaN, median-impute the remaining
    NaNs, standardise every column (zero mean, unit variance), then fit a PCA
    keeping all components.

    All-NaN columns are removed up front because ``SimpleImputer`` with the
    ``'median'`` strategy cannot impute them and discards them from its output.
    That would leave ``pca.components_`` narrower than ``data.columns`` and make
    the loadings ``DataFrame`` construction raise.

    Parameters
    ----------
    data : pandas.DataFrame
        Numeric feature columns, one row per object.
    axes : sequence of matplotlib Axes
        ``axes[0]`` receives the scree plot (explained-variance ratio per
        component); ``axes[1]`` receives the loadings heatmap (rows = PCs,
        columns = features).

    Returns
    -------
    sklearn.decomposition.PCA or None
        The fitted PCA, or ``None`` (nothing drawn) when fewer than two rows or
        no non-empty column remain.
    """
    empty_mask = data.isna().all(axis=0).to_numpy()
    if empty_mask.any():
        print(f"PCA: dropping all-NaN columns {list(data.columns[empty_mask])}")
    usable = data.loc[:, ~empty_mask]

    if usable.shape[1] == 0 or usable.shape[0] < 2:
        print("PCA skipped: need at least 2 rows and 1 non-empty column.")
        return None

    imputer = SimpleImputer(strategy='median')  # Handle NaN values by imputation
    data_imputed = imputer.fit_transform(usable)
    scaler = StandardScaler()  # Standardize the data
    data_scaled = scaler.fit_transform(data_imputed)

    pca = PCA()
    pca.fit(data_scaled)  # only components / explained variance are plotted

    ax_scree = axes[0]  # First subplot for scree plot
    ax_loadings = axes[1]  # Second subplot for loadings plot

    # Scree plot of the explained variance
    ax_scree.bar(range(1, len(pca.explained_variance_ratio_) + 1), pca.explained_variance_ratio_, color='blue', alpha=0.7)
    ax_scree.set_ylabel('Explained Variance Ratio', fontsize=14)
    ax_scree.set_xlabel('Principal Component', fontsize=14)
    ax_scree.set_title('PCA Scree Plot for TDE Features', fontsize=16)

    # PCA component loadings heatmap; columns are the features actually used
    components_df = pd.DataFrame(pca.components_, columns=usable.columns, index=[f'PC{i+1}' for i in range(len(pca.components_))])
    sns.heatmap(components_df, annot=True, cmap='coolwarm', ax=ax_loadings, fmt=".2f", cbar_kws={"shrink": .75})
    ax_loadings.set_title('PCA Component Loadings for TDE Features', fontsize=16)
    ax_loadings.set_xticklabels(ax_loadings.get_xticklabels(), rotation=45, ha='right', fontsize=12)
    ax_loadings.set_yticklabels(ax_loadings.get_yticklabels(), fontsize=12)

    return pca

# Preparing for PCA
fig, axes = plt.subplots(1, 2, figsize=(18, 6))
data_numeric = df_tde[filtered_columns]
if data_numeric.shape[1] > 0:  # Proceed with PCA if there are columns left
    perform_pca(data_numeric, axes)
else:
    print("No columns left for PCA after preprocessing for TDE. Check your data and preprocessing steps.")

plt.tight_layout()

plt.show()


In [None]:
from scipy.stats import pearsonr, spearmanr

# Pair the post-peak r-i slope (x) with the post-peak mean r-i colour (y) for TDEs.
x = df_tde['Slope Post Peak (r-i)']
y = df_tde['Mean Color Post Peak (r-i)']

# Both correlation tests need complete pairs: keep rows where x and y are present.
data = pd.DataFrame({'x': x, 'y': y})
data_clean = data.dropna(subset=['x', 'y'])
x_clean, y_clean = data_clean['x'], data_clean['y']

# Linear (Pearson) and rank-based (Spearman) coefficients; p-values are not used.
pearson_corr = pearsonr(x_clean, y_clean)[0]
spearman_corr = spearmanr(x_clean, y_clean)[0]
print(f"Pearson Correlation: {pearson_corr}")
print(f"Spearman Correlation: {spearman_corr}")


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Scatter of the NaN-free TDE (slope, colour) pairs built in the previous cell.
plt.figure(figsize=(8, 6))
sns.scatterplot(x=x_clean, y=y_clean)
plt.title('Slope Post Peak (r-i) vs Mean Color Post Peak (r-i)')
plt.xlabel('Slope Post Peak (r-i)')
plt.ylabel('Mean Color Post Peak (r-i)')
plt.show()
# Same data with a fitted regression line. lmplot creates its own figure, so the
# plt.title/xlabel/ylabel calls below target that new figure's current axes.
sns.lmplot(x='x', y='y', data=data_clean)
plt.title('Linear Regression Fit')
plt.xlabel('Slope Post Peak (r-i)')
plt.ylabel('Mean Color Post Peak (r-i)')
plt.show()


In [None]:
def remove_outliers_iqr(df, columns, lower_q=0.1, upper_q=0.9, k=3):
    """Drop rows of *df* falling outside a quantile fence in any of *columns*.

    Columns are processed one after another, and each fence is computed on the
    rows that survived the previous columns. For a column with quantiles
    ``Q_low``/``Q_high`` (10th/90th percentile by default, not the textbook
    25th/75th despite the name) a row is kept when

        Q_low - k * (Q_high - Q_low) <= value <= Q_high + k * (Q_high - Q_low)

    Rows with NaN in a processed column fail both comparisons and are dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    columns : str or iterable of str
        Column name(s) to filter on. A single name is accepted, so this also
        serves callers of the single-column variant of this function defined
        earlier in the notebook. Duplicate names are processed only once;
        re-filtering a column against recomputed quantiles would tighten its fence.
    lower_q, upper_q : float, optional
        Quantiles defining the central range (defaults 0.1 and 0.9).
    k : float, optional
        Fence multiplier applied to the quantile spread (default 3).

    Returns
    -------
    pandas.DataFrame
        Filtered copy of *df* with the original index.

    Raises
    ------
    ValueError
        If the quantiles are not ordered within [0, 1].
    """
    if not 0.0 <= lower_q <= upper_q <= 1.0:
        raise ValueError(f"need 0 <= lower_q <= upper_q <= 1, got {lower_q}, {upper_q}")
    if isinstance(columns, str):
        columns = [columns]

    df_clean = df.copy()
    for col in dict.fromkeys(columns):  # order-preserving de-duplication
        q_low = df_clean[col].quantile(lower_q)
        q_high = df_clean[col].quantile(upper_q)
        spread = q_high - q_low

        lower_bound = q_low - k * spread
        upper_bound = q_high + k * spread

        df_clean = df_clean[(df_clean[col] >= lower_bound) & (df_clean[col] <= upper_bound)]
    return df_clean


import pandas as pd
import numpy as np

# Columns excluded from the TDE feature set, assembled by category:
# identifiers/peak times, per-feature uncertainties, redshift information, and the
# class label (non-numeric; filter_columns keeps only numeric columns anyway).
excluded_columns = (
    ['SNID', 'TruePeakMJD', 'peak_time_MJD']
    + [
        'Pre_Peak_Color_err_gr',
        'Post_Peak_Color_err_gr',
        'Slope_Err_Pre_Peak_gr',
        'Slope_Err_Post_Peak_gr',
        'Pre_Peak_Color_err_ri',
        'Post_Peak_Color_err_ri',
        'Slope_Err_Pre_Peak_ri',
        'Slope_Err_Post_Peak_ri',
    ]
    + ['REDSHIFT_FINAL', 'REDSHIFT_FINAL_ERR', 'Error']
    + ['Object_Type']
)

# Function to filter numeric columns based on exclusions
def filter_columns(df, excluded_cols):
    """List *df*'s numeric column names, skipping those in *excluded_cols* (order kept)."""
    numeric = df.select_dtypes(include=[np.number]).columns
    keep = []
    for name in numeric:
        if name in excluded_cols:
            continue
        keep.append(name)
    return keep

# TDE-only working copy of df0 (assumes df0 is loaded); copying keeps later edits off df0.
is_tde = df0['Object_Type'] == 'TDE'
df_tde = df0.loc[is_tde].copy()

# Numeric feature columns minus the exclusion list, echoed for inspection.
filtered_columns = filter_columns(df_tde, excluded_columns)
print("Filtered Columns: ", filtered_columns)


In [None]:
# Remove outliers
# NOTE(review): columns_with_outliers is not used below (the call filters every
# column in filtered_columns), and it lists 'Slope Post Peak (r-i)' and
# 'Mean Color Post Peak (r-i)' twice — confirm which column set is intended.

columns_with_outliers = [
    'Slope Post Peak (r-i)', 
    'Mean Color Post Peak (r-i)', 
    'Slope Post Peak (r-i)', 
    'Mean Color Post Peak (r-i)',
    'Fade time',
    'Rise time'
]

# Column-by-column quantile-fence filter over all TDE feature columns; rows with NaN
# in any of them are removed too, because NaN fails both bound comparisons.
df_tde_no_outliers = remove_outliers_iqr(df_tde, filtered_columns)

# Check the number of observations before and after outlier removal
print(f"Number of observations before outlier removal: {len(df_tde)}")
print(f"Number of observations after outlier removal: {len(df_tde_no_outliers)}")


In [None]:
# Drop rows with a missing value in any of the filtered columns. remove_outliers_iqr
# already discards NaN rows in these columns, so this is presumably a safeguard.
df_tde_no_outliers = df_tde_no_outliers.dropna(subset=filtered_columns)


In [None]:
# Coerce the feature columns to numeric. filtered_columns were chosen with
# select_dtypes(include=[np.number]), so this is expected to be a no-op; any value
# that did fail to parse would become NaN here and would NOT be dropped by this cell.
df_tde_no_outliers[filtered_columns] = df_tde_no_outliers[filtered_columns].apply(pd.to_numeric, errors='coerce')


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_correlation_heatmap(
    data,
    columns,
    save_path='/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/corrTDE_no_outliers.png',
    title='Correlation Matrix of Selected TDE Features (Outliers Removed)',
):
    """Plot the lower-triangular Pearson correlation matrix of ``data[columns]``.

    Parameters:
    - data: DataFrame holding the features.
    - columns: Column names to correlate.
    - save_path: File the figure is written to (150 dpi); pass None/'' to skip
      saving. Defaults to the previously hard-coded location.
    - title: Figure title.
    """
    plt.figure(figsize=(12, 10))

    corr_matrix = data[columns].corr()

    # Hide the strict upper triangle (k=1 keeps the diagonal visible)
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool), k=1)

    sns.heatmap(
        corr_matrix,
        mask=mask,
        cmap='coolwarm',
        annot=True,
        fmt=".2f",
        square=True,
        cbar_kws={"shrink": .75},
        linewidths=.5
    )

    plt.title(title, fontsize=18)
    plt.xticks(rotation=45, ha='right', fontsize=11)
    plt.yticks(fontsize=11)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150)
    plt.show()

# Plot the correlation heatmap
plot_correlation_heatmap(df_tde_no_outliers, filtered_columns)


In [None]:
import re  # local: turn column names into file-name-safe slugs

# Side-by-side box plots of every feature before and after outlier removal.
for col in filtered_columns:
    plt.figure(figsize=(12, 5))

    # Before outlier removal
    plt.subplot(1, 2, 1)
    sns.boxplot(x=df_tde[col])
    plt.title(f'Before Outlier Removal - {col}')

    # After outlier removal
    plt.subplot(1, 2, 2)
    sns.boxplot(x=df_tde_no_outliers[col])
    plt.title(f'After Outlier Removal - {col}')

    plt.tight_layout()
    # Save each figure *before* showing it. A single savefig after the loop would
    # capture at most the last column's figure, and with backends that close
    # figures on show() (e.g. the notebook inline backend) a blank canvas.
    col_slug = re.sub(r'[^A-Za-z0-9]+', '_', str(col)).strip('_')
    plt.savefig(f'/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/after_outliers_{col_slug}.png', dpi=150)
    plt.show()




### ML Classifiers

In [None]:
# Feature table CSV for the ELAsTiCC2 test sample
file_path = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/hdf5_files/ELAsTiCC2_Test.csv'
df0 = pd.read_csv(file_path)

# Show the loaded frame as the cell output (or preview it with the line below)
#print(df0.head())
df0


### Multi-classification

In [None]:
# Feature matrix: every numeric column except identifiers, truth information
# (peak magnitude/flux, redshift) and anything whose name contains 'err'.
excluded_columns = ['SNID', 'PeakMag', 'PeakFlux_GP', 'REDSHIFT_FINAL']
excluded_columns.extend(col for col in df0.columns if 'err' in col.lower())
feature_columns = [
    col
    for col in df0.select_dtypes(include=[np.number]).columns
    if col not in excluded_columns
]

# Missing values become the sentinel -999 before scaling
X = df0[feature_columns].fillna(-999)
y = df0['Object_Type']

# StandardScaler: z = (x - mean) / std, per feature.
# NOTE(review): the scaler is fit on all rows before the train/test split, so
# test-set statistics leak into the training features.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 30 % train / 70 % test, unstratified
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.7, random_state=42)

def plot_normalized_conf_matrix(y_true, y_pred, ax, title):
    """Draw a confusion matrix on *ax* with each row divided by its row total.

    Only labels that occur in ``y_true`` become rows/columns; cells are annotated
    with two decimals.
    """
    present = np.unique(y_true)
    counts = confusion_matrix(y_true, y_pred, labels=present)
    row_fractions = counts.astype('float') / counts.sum(axis=1, keepdims=True)
    sns.heatmap(
        row_fractions,
        annot=True,
        fmt=".2f",
        cmap='Blues',
        ax=ax,
        xticklabels=present,
        yticklabels=present,
    )
    ax.set_xlabel('Predicted Class')
    ax.set_ylabel('True Class')
    ax.set_title(title)

# Training-set resampling strategies to compare ('Original' = no resampling)
strategies = {
    'Original': None,
    'Undersampling': RandomUnderSampler(random_state=42),
    'Oversampling (SMOTE)': SMOTE(random_state=42),
}

# One row per strategy: class counts (left) and RF confusion matrix (right)
fig, axs = plt.subplots(
    len(strategies), 2,
    figsize=(24, len(strategies) * 8),
    gridspec_kw={'width_ratios': [2, 3]},
)

for i, (strategy_name, sampler) in enumerate(strategies.items()):
    # Only the training split is resampled; the test split is never touched
    if sampler is None:
        X_train_mod, y_train_mod = X_train, y_train
    else:
        X_train_mod, y_train_mod = sampler.fit_resample(X_train, y_train)

    # Class counts, most frequent first, with tilted labels
    sns.countplot(y=y_train_mod, ax=axs[i, 0], palette='Set2', order=y_train_mod.value_counts().index)
    axs[i, 0].set_title(f'Class Distribution : {strategy_name}')
    axs[i, 0].set_xlabel('Counts')
    axs[i, 0].set_ylabel('Class')
    axs[i, 0].tick_params(axis='y', rotation=45)

    # Class-balanced random forest and depth-1 gradient boosting (fit returns self)
    rf_classifier = RandomForestClassifier(
        n_estimators=100, random_state=42, class_weight='balanced'
    ).fit(X_train_mod, y_train_mod)
    bdt_classifier = GradientBoostingClassifier(
        n_estimators=100, learning_rate=1.0, max_depth=1, random_state=42
    ).fit(X_train_mod, y_train_mod)

    y_pred_rf = rf_classifier.predict(X_test)
    y_pred_bdt = bdt_classifier.predict(X_test)

    # Only the random forest's confusion matrix is drawn
    plot_normalized_conf_matrix(y_test, y_pred_rf, axs[i, 1], f'RF: {strategy_name}')

plt.tight_layout()
plt.show()

def plot_precision_recall_curve_multiclass(y_true, y_probas, class_labels):
    """Draw one-vs-rest precision-recall curves for every class and show them.

    ``y_probas[:, k]`` is the score for ``class_labels[k]``; samples whose
    ``y_true`` equals that label are the positives. The legend reports the
    area under each curve.
    """
    plt.figure(figsize=(10, 7))
    for column, label in enumerate(class_labels):
        positives = (y_true == label).astype(int)
        prec, rec, _ = precision_recall_curve(positives, y_probas[:, column])
        area = auc(rec, prec)
        plt.plot(rec, prec, lw=2, label=f'Class {label} (AUC = {area:0.2f})')
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="best")
    plt.title("Precision-Recall Curve for Multiple Classes")
    plt.grid(True)
    plt.show()

# Test-set class probabilities from the classifiers left over from the last loop
# iteration above (trained on the SMOTE-resampled training set)
y_probs_rf = rf_classifier.predict_proba(X_test)
y_probs_bdt = bdt_classifier.predict_proba(X_test)

# predict_proba columns follow the fitted classifier's `classes_` order
class_labels = rf_classifier.classes_

# One-vs-rest precision-recall curves for the random forest
plot_precision_recall_curve_multiclass(y_test, y_probs_rf, class_labels)

# Heading and report on consecutive lines
print("Random Forest Classification Report:", classification_report(y_test, y_pred_rf), sep='\n')
print("\nGradient Boosting Classification Report:", classification_report(y_test, y_pred_bdt), sep='\n')

In [None]:
# Group the fine-grained simulated types into broad classes
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Reverse lookup: fine-grained Object_Type -> combined class label
class_mapping = {
    member: group
    for group, members in combined_classes.items()
    for member in members
}
df0['Combined_Type'] = df0['Object_Type'].map(class_mapping)

# Numeric features only; drop identifiers, truth information and error estimates
excluded_columns = ['SNID', 'PeakMag', 'PeakFlux_GP', 'REDSHIFT_FINAL', 'Num_Peaks_GP']
excluded_columns.extend(col for col in df0.columns if 'err' in col.lower())
feature_columns = [
    col
    for col in df0.select_dtypes(include=[np.number]).columns
    if col not in excluded_columns
]

# Sentinel -999 for missing values, then per-feature standardisation
X = df0[feature_columns].fillna(-999)
y = df0['Combined_Type']
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 30 % train / 70 % test, unstratified
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.7, random_state=42)

def plot_normalized_conf_matrix(y_true, y_pred, ax, title):
    """Heat-map a confusion matrix normalised per true class (row) onto *ax*.

    The label set is the sorted unique values of ``y_true``.
    """
    label_set = np.unique(y_true)
    raw = confusion_matrix(y_true, y_pred, labels=label_set)
    totals = raw.sum(axis=1)[:, None]
    sns.heatmap(raw.astype('float') / totals, annot=True, fmt=".2f", cmap='Blues', ax=ax,
                xticklabels=label_set, yticklabels=label_set)
    ax.set_xlabel('Predicted Class')
    ax.set_ylabel('True Class')
    ax.set_title(title)

# Training-set resampling strategies to compare ('Original' = no resampling)
strategies = {
    'Original': None,
    'Undersampling': RandomUnderSampler(random_state=42),
    'Oversampling (SMOTE)': SMOTE(random_state=42),
}

# One row per strategy: class counts (left) and RF confusion matrix (right)
fig, axs = plt.subplots(
    len(strategies), 2,
    figsize=(14, len(strategies) * 4),
    gridspec_kw={'width_ratios': [2, 3]},
)

for i, (strategy_name, sampler) in enumerate(strategies.items()):
    # Resample the training split only
    if sampler is None:
        X_train_mod, y_train_mod = X_train, y_train
    else:
        X_train_mod, y_train_mod = sampler.fit_resample(X_train, y_train)

    sns.countplot(y=y_train_mod, ax=axs[i, 0], palette='Set2', order=y_train_mod.value_counts().index)
    axs[i, 0].set_title(f'Class Distribution: {strategy_name}')
    axs[i, 0].set_xlabel('Counts')
    axs[i, 0].set_ylabel('Class')
    axs[i, 0].tick_params(axis='y', rotation=45)

    # Class-balanced random forest and depth-1 gradient boosting (fit returns self)
    rf_classifier = RandomForestClassifier(
        n_estimators=100, random_state=42, class_weight='balanced'
    ).fit(X_train_mod, y_train_mod)
    bdt_classifier = GradientBoostingClassifier(
        n_estimators=100, learning_rate=1.0, max_depth=1, random_state=42
    ).fit(X_train_mod, y_train_mod)

    y_pred_rf = rf_classifier.predict(X_test)
    y_pred_bdt = bdt_classifier.predict(X_test)

    plot_normalized_conf_matrix(y_test, y_pred_rf, axs[i, 1], f'RF: {strategy_name}')

plt.tight_layout()
plt.show()

def plot_precision_recall_curve_multiclass(y_true, y_probas, class_labels):
    """Show one-vs-rest precision-recall curves, one line per class label.

    Column ``k`` of ``y_probas`` scores ``class_labels[k]``.
    """
    plt.figure(figsize=(10, 7))
    for k, name in enumerate(class_labels):
        p, r, _ = precision_recall_curve((y_true == name).astype(int), y_probas[:, k])
        plt.plot(r, p, lw=2, label=f'Class {name} (AUC = {auc(r, p):0.2f})')
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="best")
    plt.title("Precision-Recall Curve for Multiple Classes")
    plt.grid(True)
    plt.show()

# Test-set probabilities from the models of the final (SMOTE) loop iteration
y_probs_rf = rf_classifier.predict_proba(X_test)
y_probs_bdt = bdt_classifier.predict_proba(X_test)

# Column order of predict_proba == classifier.classes_
class_labels = rf_classifier.classes_

# One-vs-rest precision-recall curves for the random forest
plot_precision_recall_curve_multiclass(y_test, y_probs_rf, class_labels)

print("Random Forest Classification Report:", classification_report(y_test, y_pred_rf), sep='\n')
print("\nGradient Boosting Classification Report:", classification_report(y_test, y_pred_bdt), sep='\n')


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold
from sklearn.metrics import (
    confusion_matrix, precision_recall_curve, classification_report, roc_auc_score,
    make_scorer, precision_score, recall_score
)
import xgboost as xgb
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from scipy.stats import uniform, randint
import os
from sklearn.inspection import permutation_importance
from sklearn.metrics import auc

# ---------------------------------------
# Helper Functions
# ---------------------------------------

def find_threshold_multiclass(y_true, y_probs, target_class, target_metric='precision', target_value=0.95,
                              class_names=None):
    """
    Find the probability threshold that reaches a target precision or recall for one class.

    The class is handled one-vs-rest: samples with ``y_true == target_class`` are the
    positives and column ``target_class`` of ``y_probs`` is their score.

    Parameters:
    - y_true: Ground-truth labels, encoded so they compare equal to ``target_class``.
    - y_probs: Predicted probabilities, one column per class.
    - target_class: The class index for which to find the threshold.
    - target_metric: 'precision' or 'recall'.
    - target_value: Desired value for the target metric.
    - class_names: Optional sequence used only to print a readable class name in
      the "no threshold found" message (the index is printed otherwise).

    Returns:
    - 'precision': the lowest threshold whose precision is >= target_value (the one
      keeping the most recall), or 1.0 if no threshold reaches it.
    - 'recall': the highest threshold whose recall is >= target_value, or 0.0 if
      no threshold reaches it.

    Raises:
    - ValueError: if target_metric is not 'precision' or 'recall'.
    """
    y_true_bin = (y_true == target_class).astype(int)
    precision, recall, thresholds = precision_recall_curve(y_true_bin, y_probs[:, target_class])

    # precision_recall_curve appends a final (precision=1, recall=0) point that has
    # no threshold. Drop it so the arrays line up with `thresholds`; otherwise the
    # precision search always "succeeds" on that sentinel and indexes past the end.
    precision = precision[:-1]
    recall = recall[:-1]
    name = class_names[target_class] if class_names is not None else target_class

    if target_metric == 'precision':
        # Indices where precision >= target_value
        idx = np.flatnonzero(precision >= target_value)
        if idx.size == 0:
            print(f"No threshold found to achieve {target_value*100}% precision for class '{name}'. Returning threshold=1.0")
            return 1.0  # Maximum threshold
        # Thresholds are increasing: the first hit is the lowest qualifying threshold
        return thresholds[idx[0]]
    elif target_metric == 'recall':
        # Indices where recall >= target_value
        idx = np.flatnonzero(recall >= target_value)
        if idx.size == 0:
            print(f"No threshold found to achieve {target_value*100}% recall for class '{name}'. Returning threshold=0.0")
            return 0.0  # Minimum threshold
        # The last hit is the highest threshold that still reaches the recall
        return thresholds[idx[-1]]
    else:
        raise ValueError("target_metric must be 'precision' or 'recall'.")

def plot_confusion_matrix_multiclass(cm, classes, ax, title, normalize=True):
    """
    Plot a confusion matrix for multi-class classification.

    Each cell is annotated with the row-normalised percentage plus the raw count
    (``normalize=True``), or with the raw count only.

    Parameters:
    - cm: Confusion matrix (rows = true class, columns = predicted class).
    - classes: Class names for the tick labels, in matrix order.
    - ax: Matplotlib Axes object.
    - title: Title for the plot.
    - normalize: If True, colour by the fraction of each true class (row);
      otherwise colour by raw counts.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without any true samples are shown as 0 instead of NaN (0 / 0)
        cm_normalized = np.divide(
            cm.astype(float), row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        # Fractions: switch to white text above 0.5
        dark_cutoff = 0.5
    else:
        cm_normalized = cm
        # Raw counts: use half the largest count so the switch to white text
        # follows the colour scale instead of firing for every count >= 1
        dark_cutoff = cm.max() / 2 if cm.size else 0

    sns.heatmap(cm_normalized, annot=False, fmt='.2f', cmap='Blues', ax=ax,
                xticklabels=classes, yticklabels=classes, cbar=False)

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if normalize:
                percentage = cm_normalized[i, j] * 100
                count = cm[i, j]
                text = f'{percentage:.1f}%\n({count})'
            else:
                text = f'{cm[i, j]}'
            # Adjust text color based on background for better visibility
            color = 'white' if cm_normalized[i, j] > dark_cutoff else 'black'
            ax.text(j + 0.5, i + 0.5, text,
                    ha='center', va='center',
                    color=color, fontsize=8,
                    bbox=dict(facecolor='none', edgecolor='none'))

    ax.set_xlabel('Predicted Class', fontsize=10)
    ax.set_ylabel('True Class', fontsize=10)
    ax.set_title(title, fontsize=12)

def plot_precision_recall_curve_multiclass(y_true, y_probas, class_labels, save_path, figsize=(6, 4)):
    """
    Plot one-vs-rest Precision-Recall curves for every class, save, show and close the figure.

    Parameters:
    - y_true: Integer-encoded ground-truth labels (label k is compared with column k).
    - y_probas: Predicted probabilities, one column per class.
    - class_labels: Class names for the legend, in column order.
    - save_path: File the figure is written to (300 dpi).
    - figsize: Figure size in inches.
    """
    plt.figure(figsize=figsize)
    for k, name in enumerate(class_labels):
        is_positive = (y_true == k).astype(int)
        curve_p, curve_r, _ = precision_recall_curve(is_positive, y_probas[:, k])
        area = auc(curve_r, curve_p)
        plt.plot(curve_r, curve_p, lw=1.5, label=f'{name} (AUC = {area:.2f})')

    plt.xlabel('Recall', fontsize=10)
    plt.ylabel('Precision', fontsize=10)
    plt.title('Precision-Recall Curve for Multi-class', fontsize=12)
    plt.legend(fontsize=8, loc='lower left')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

def plot_feature_importance_multiclass(model, feature_names, save_path, figsize=(8, 6)):
    """
    Plot and save the top-20 gain-based feature importances of a fitted XGBoost model.

    Parameters:
    - model: Fitted XGBoost model (scikit-learn API).
    - feature_names: Training column names in column order. Used to replace
      XGBoost's default 'f<i>' labels when the booster carries no feature names
      (e.g. it was fit on a plain array); ignored when it already has names.
    - save_path: File the figure is written to (300 dpi).
    - figsize: Figure size in inches.
    """
    plt.figure(figsize=figsize)
    ax = plt.gca()
    xgb.plot_importance(
        model,
        max_num_features=20,
        importance_type='gain',
        show_values=False,
        ax=ax
    )

    # Relabel default 'f<i>' tick labels with the supplied names
    if feature_names is not None and model.get_booster().feature_names is None:
        names = list(feature_names)
        relabelled = []
        for tick in ax.get_yticklabels():
            text = tick.get_text()
            index = text[1:]
            if text.startswith('f') and index.isdigit() and int(index) < len(names):
                text = str(names[int(index)])
            relabelled.append(text)
        ax.set_yticklabels(relabelled)

    ax.set_title('Feature Importance', fontsize=12)
    ax.set_xlabel('Importance (Gain)', fontsize=10)
    ax.set_ylabel('Features', fontsize=10)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

def plot_class_distribution(y, classes, ax, title):
    """
    Plot a horizontal bar chart of class counts.

    Parameters:
    - y: Integer-encoded labels (NumPy array, list or pandas Series).
    - classes: Class names indexed by encoded label (e.g. LabelEncoder.classes_).
    - ax: Matplotlib Axes object.
    - title: Title for the plot.
    """
    # Accept any array-like. y_train here comes from train_test_split on a NumPy
    # array, which has no `.value_counts()`; count with NumPy instead.
    y = np.asarray(y)
    classes = np.asarray(classes)
    present, counts = np.unique(y, return_counts=True)  # sorted labels + counts

    sns.countplot(y=y, ax=ax, palette='Set2', order=present)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Counts', fontsize=10)
    ax.set_ylabel('Class', fontsize=10)

    # Dynamically adjust x-ticks (about five steps) to prevent overlapping
    max_count = int(counts.max()) if counts.size else 0
    step = max_count // 5 if max_count > 5 else 1
    ticks = list(range(0, max_count + step, step))
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks, rotation=0, fontsize=8)

    ax.set_yticklabels(classes[present], rotation=45, fontsize=9)
    ax.tick_params(axis='y', labelsize=8)

# ---------------------------------------
# Data Preparation
# ---------------------------------------

# Define the combined classes ('KN' is intentionally left out of this run)
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
   # 'KN': ['KN_K17', 'KN_B19'],  # Commented out as per original code
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Map the original classes to combined classes; unlisted types become NaN
class_mapping = {original: combined for combined, originals in combined_classes.items() for original in originals}
df0['Combined_Type'] = df0['Object_Type'].map(class_mapping)

# Display initial counts (value_counts() silently skips the NaN rows)
initial_counts = df0['Combined_Type'].value_counts()
print("Initial Class Distribution")
print(initial_counts)

# Objects with no combined class (e.g. KN, now commented out) are excluded
# explicitly. Left in, their NaN target would reach the LabelEncoder and become
# an extra 'nan' class (or raise, depending on the scikit-learn version).
labelled_mask = df0['Combined_Type'].notna()
n_unlabelled = int((~labelled_mask).sum())
if n_unlabelled:
    print(f"\nExcluding {n_unlabelled} objects whose Object_Type has no combined class:")
    print(df0.loc[~labelled_mask, 'Object_Type'].value_counts())

# Select relevant columns for the model and exclude 'SNID', any truth info
excluded_columns = ['SNID', 'PeakMag', 'PeakFlux_GP', 'REDSHIFT_FINAL', 'Num_Peaks_GP']
excluded_columns += [col for col in df0.columns if 'err' in col.lower()]  # Remove error estimates
feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

# Renaming columns for clarity
rename_map = {
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time'
}
df0.rename(columns=rename_map, inplace=True)
feature_columns = [rename_map.get(col, col) for col in feature_columns]

# Define features and target (labelled objects only)
X = df0.loc[labelled_mask, feature_columns].fillna(-999)  # Handling missing values by filling with -999
y = df0.loc[labelled_mask, 'Combined_Type']

# Encode the target variable
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
class_labels = label_encoder.classes_

# Standardize the features.
# NOTE(review): the scaler is fit on all rows before the split, so test-set
# statistics leak into training (harmless for tree models, not for others).
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_scaled = pd.DataFrame(X_scaled, columns=feature_columns)  # Retain feature names

# Split the data with a 1:1 train-test ratio and stratification
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y_encoded, test_size=0.5, random_state=42, stratify=y_encoded
)

# Display counts after train-test split
print("\nTraining Set Class Distribution")
print(pd.Series(y_train).value_counts().sort_index().rename(index=lambda x: class_labels[x]))

print("\nTest Set Class Distribution")
print(pd.Series(y_test).value_counts().sort_index().rename(index=lambda x: class_labels[x]))

# ---------------------------------------
# Hyperparameter Tuning
# ---------------------------------------

# Define the parameter grid for RandomizedSearchCV.
# scipy.stats: randint(low, high) draws integers in [low, high) (high exclusive);
# uniform(loc, scale) draws floats in [loc, loc + scale].
param_dist = {
    'n_estimators': randint(100, 1000),  # 100 .. 999 trees
    'learning_rate': uniform(0.01, 0.29),  # 0.01 to 0.3
    'max_depth': randint(3, 10),  # 3 .. 9
    'min_child_weight': randint(1, 10),  # 1 .. 9
    'subsample': uniform(0.5, 0.5),  # 0.5 to 1.0
    'colsample_bytree': uniform(0.5, 0.5),  # 0.5 to 1.0
    'gamma': uniform(0, 5),  # 0 to 5
    # NOTE(review): XGBoost documents scale_pos_weight as a positive/negative
    # weight balance; with objective='multi:softprob' it likely has no effect.
    # Confirm before interpreting its tuned value. Sampled range is 1 to 11.
    'scale_pos_weight': uniform(1, 10)  # Useful for imbalanced classes
}

# Initialize the XGBoost classifier for multi-class: one softmax probability per
# class, with num_class set to the number of encoded labels.
# NOTE(review): recent XGBoost releases no longer use `use_label_encoder`; it may
# only trigger an "unused parameter" warning.
xgb_clf = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=len(class_labels),
    use_label_encoder=False,
    eval_metric='mlogloss',
    random_state=42
)

# Define a custom scorer that emphasizes macro recall at a certain precision threshold
def macro_recall_at_precision_threshold(y_true, y_probas, precision_threshold=0.95):
    """
    Calculate the macro recall across all classes at a specified precision threshold.

    For every class k (one column of ``y_probas``) the lowest threshold reaching
    ``precision_threshold`` is found on (y_true, y_probas). The recall of
    "predict k when prob_k >= threshold" is computed, and the mean over classes
    is returned. The class count comes from ``y_probas.shape[1]`` rather than the
    notebook-global ``class_labels``, so the scorer depends only on its inputs.

    Parameters:
    - y_true: Integer-encoded ground-truth labels.
    - y_probas: Predicted probabilities, shape (n_samples, n_classes).
    - precision_threshold: Desired precision threshold.

    Returns:
    - Macro recall score.
    """
    y_true = np.asarray(y_true)
    y_probas = np.asarray(y_probas)
    recalls = []
    for class_idx in range(y_probas.shape[1]):
        threshold = find_threshold_multiclass(
            y_true, y_probas, target_class=class_idx,
            target_metric='precision', target_value=precision_threshold,
        )
        y_pred = (y_probas[:, class_idx] >= threshold).astype(int)
        recall = recall_score((y_true == class_idx).astype(int), y_pred, zero_division=0)
        recalls.append(recall)
    return np.mean(recalls)

import inspect  # local: detect which make_scorer API this scikit-learn offers

# Create a scorer for RandomizedSearchCV
def custom_scorer_multiclass(y_true, y_pred_probas):
    """Score = macro recall at 95% per-class precision, from predicted probabilities."""
    return macro_recall_at_precision_threshold(y_true, y_pred_probas, precision_threshold=0.95)

# scikit-learn 1.4 replaced `needs_proba=True` with `response_method='predict_proba'`
# and later removed the old keyword; use whichever this version supports.
if 'response_method' in inspect.signature(make_scorer).parameters:
    scorer = make_scorer(custom_scorer_multiclass, response_method='predict_proba')
else:
    scorer = make_scorer(custom_scorer_multiclass, needs_proba=True)

# Randomised hyper-parameter search: 100 sampled configurations, each scored by
# shuffled 3-fold stratified cross-validation with `scorer`, on all cores.
random_search = RandomizedSearchCV(
    xgb_clf,
    param_dist,
    scoring=scorer,
    n_iter=100,
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=42),
    random_state=42,
    n_jobs=-1,
    verbose=1,
)

# Tune on the training split only (default refit=True refits the winner on it)
random_search.fit(X_train, y_train)
best_xgb = random_search.best_estimator_

print(f"\nBest parameters found: {random_search.best_params_}")
print(f"Best macro recall at 95% precision: {random_search.best_score_:.4f}")

# ---------------------------------------
# Threshold Determination for Each Class
# ---------------------------------------
from sklearn.model_selection import cross_val_predict  # local: out-of-fold probabilities

# Test-set probabilities from the tuned model (kept for later use)
y_probs = best_xgb.predict_proba(X_test)

# The per-class 95%-precision thresholds are derived from *out-of-fold* training
# predictions. Taking them from y_test and then scoring on y_test tunes the
# decision rule on the evaluation data and inflates the reported precision.
# cross_val_predict clones best_xgb per fold, so the fitted best_xgb is untouched.
y_probs_train_oof = cross_val_predict(
    best_xgb, X_train, y_train,
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=42),
    method='predict_proba',
)

# Define thresholds based on desired precision for each class
thresholds_dict = {}
for class_idx, class_label in enumerate(class_labels):
    # Threshold for 95% Precision
    threshold_95p_precision = find_threshold_multiclass(
        y_train, y_probs_train_oof, target_class=class_idx, target_metric='precision', target_value=0.95
    )
    thresholds_dict[class_label] = {
        '95% Precision': threshold_95p_precision
    }

# Function to get predictions based on thresholds (One-vs-Rest)
def get_multiclass_predictions(y_probas, thresholds_dict, class_labels=None, threshold_key='95% Precision'):
    """
    Generate multi-class predictions based on per-class probability thresholds.

    A sample is a candidate for class k when ``y_probas[:, k] >= threshold_k``.
    When several classes qualify, the one with the highest probability wins
    (rather than whichever class happens to be checked last). Samples that clear
    no threshold fall back to the plain arg-max class.

    Parameters:
    - y_probas: Predicted probabilities, shape (n_samples, n_classes).
    - thresholds_dict: {class_label: {threshold_key: threshold}}.
    - class_labels: Optional label order matching the columns of ``y_probas``;
      defaults to the insertion order of ``thresholds_dict``.
    - threshold_key: Which threshold to read from each class's entry.

    Returns:
    - Integer array of predicted class indices, shape (n_samples,).

    Raises:
    - ValueError: if the number of labels differs from the number of columns.
    """
    y_probas = np.asarray(y_probas)
    labels = list(thresholds_dict) if class_labels is None else list(class_labels)
    if len(labels) != y_probas.shape[1]:
        raise ValueError(
            f"Got {len(labels)} class thresholds for {y_probas.shape[1]} probability columns."
        )
    thresholds = np.array([thresholds_dict[label][threshold_key] for label in labels], dtype=float)

    passes = y_probas >= thresholds  # (n_samples, n_classes), broadcast per column
    # Most probable class among those that cleared their threshold
    y_pred = np.argmax(np.where(passes, y_probas, -np.inf), axis=1)

    # Assign the class with the highest probability if no threshold is met
    no_pass = ~passes.any(axis=1)
    y_pred[no_pass] = np.argmax(y_probas[no_pass], axis=1)
    return y_pred

# Generate predictions using the 95% Precision thresholds
y_pred_threshold = get_multiclass_predictions(y_probs, thresholds_dict)

# ---------------------------------------
# Plotting Confusion Matrices
# ---------------------------------------

# Output directory for the figures of this analysis (created if missing)
save_directory = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/XGB_multiclass'
os.makedirs(save_directory, exist_ok=True)

# `combined_classes` (with 'KN' excluded) is defined once in the Data Preparation
# section and is not modified in between, so it is not redefined here.

# Training-set resampling strategies compared below
strategies = {
    'Original': None,  # No sampling for the original distribution
    'Undersampling': RandomUnderSampler(random_state=42),
    'Oversampling (SMOTE)': SMOTE(random_state=42)
}

# One row per strategy: training-set class counts (left), XGB confusion matrix (right)
fig, axs = plt.subplots(
    len(strategies), 2,
    figsize=(12, len(strategies) * 4),
    gridspec_kw={'width_ratios': [5, 7]},
)

for i, (strategy_name, sampler) in enumerate(strategies.items()):
    # Resample only the training split; 'Original' uses it unchanged
    if sampler is None:
        X_train_mod, y_train_mod = X_train, y_train
    else:
        X_train_mod, y_train_mod = sampler.fit_resample(X_train, y_train)

    # Class counts after resampling (available for inspection, e.g. print(sampled_counts))
    sampled_counts = pd.Series(y_train_mod).value_counts()

    plot_class_distribution(
        y_train_mod, class_labels,
        ax=axs[i, 0],
        title=f'Training Set Class Distribution: {strategy_name}',
    )

    # Fresh XGBoost model that reuses the tuned hyper-parameters of best_xgb
    xgb_classifier = xgb.XGBClassifier(
        **{
            param: best_xgb.get_params()[param]
            for param in (
                'n_estimators', 'learning_rate', 'max_depth', 'min_child_weight',
                'subsample', 'colsample_bytree', 'gamma', 'scale_pos_weight',
            )
        },
        objective='multi:softprob',
        num_class=len(class_labels),
        use_label_encoder=False,
        eval_metric='mlogloss',
        random_state=42,
    )
    xgb_classifier.fit(X_train_mod, y_train_mod)

    # Class decisions from the per-class 95%-precision thresholds
    y_probs_xgb = xgb_classifier.predict_proba(X_test)
    y_pred_xgb = get_multiclass_predictions(y_probs_xgb, thresholds_dict)

    # Confusion matrix normalised by the true class
    cm = confusion_matrix(y_test, y_pred_xgb, labels=np.arange(len(class_labels)))
    plot_confusion_matrix_multiclass(
        cm, classes=class_labels, ax=axs[i, 1],
        title=f'XGB Confusion Matrix ({strategy_name})', normalize=True,
    )

plt.tight_layout()
confusion_matrix_path = os.path.join(save_directory, 'XGB_multiclass_confusion_matrices.png')
plt.savefig(confusion_matrix_path, dpi=300)
plt.show()
plt.close()  # release the figure's memory

# ---------------------------------------
# Plotting Precision-Recall Curves
# ---------------------------------------

# The curves, predictions and importances below use the tuned model trained on
# the unresampled ('Original') training split. RandomizedSearchCV was created with
# the default refit=True, so best_estimator_ is already fit on all of
# X_train / y_train; fitting it again would only repeat the same training run.
best_xgb_original = best_xgb

# Predict probabilities with the best model
y_probs_best = best_xgb_original.predict_proba(X_test)

# Generate predictions using the 95% Precision thresholds
y_pred_best = get_multiclass_predictions(y_probs_best, thresholds_dict)

# Plot Precision-Recall curves
pr_save_path = os.path.join(save_directory, 'XGB_multiclass_precision_recall.png')
plot_precision_recall_curve_multiclass(y_test, y_probs_best, class_labels, pr_save_path, figsize=(8, 6))

# ---------------------------------------
# Plotting Feature Importance
# ---------------------------------------

# Plot and save Feature Importance
fi_save_path = os.path.join(save_directory, 'XGB_multiclass_feature_importance.png')
plot_feature_importance_multiclass(best_xgb_original, feature_columns, fi_save_path, figsize=(10, 8))

# ---------------------------------------
# Plotting Class Distribution
# ---------------------------------------

# Unresampled training split (left) next to the test split (right)
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
plot_class_distribution(y_train, class_labels, ax=axs[0], title='Training Set Class Distribution')
plot_class_distribution(y_test, class_labels, ax=axs[1], title='Test Set Class Distribution')

plt.tight_layout()
class_distribution_path = os.path.join(save_directory, 'XGB_multiclass_class_distribution.png')
plt.savefig(class_distribution_path, dpi=300)
plt.show()
plt.close()

# ---------------------------------------
# Evaluation Metrics
# ---------------------------------------

# Per-class precision/recall/F1 of the threshold-based predictions
print("XGBoost Multi-class Classification Report:", classification_report(y_test, y_pred_best, target_names=class_labels), sep='\n')

# One-vs-rest ROC AUC from the raw probabilities
roc_auc = roc_auc_score(y_test, y_probs_best, multi_class='ovr')
print(f"ROC AUC (One-vs-Rest): {roc_auc:.4f}")

# ---------------------------------------
# Permutation Feature Importance
# ---------------------------------------

# Mean drop in one-vs-rest ROC AUC on the test split over 10 shuffles per feature
result = permutation_importance(
    best_xgb_original, X_test, y_test,
    scoring='roc_auc_ovr', n_repeats=10, random_state=42,
)

# Feature -> mean importance, largest first
perm_importance = (
    pd.Series(result.importances_mean, index=feature_columns)
    .sort_values(ascending=False)
)

# Bar chart of the 20 most important features
plt.figure(figsize=(10, 8))
sns.barplot(x=perm_importance.head(20).values, y=perm_importance.head(20).index, palette='viridis')
plt.title('Permutation Feature Importance (Top 20)', fontsize=12)
plt.xlabel('Mean Decrease in ROC AUC', fontsize=10)
plt.ylabel('Features', fontsize=10)
plt.tight_layout()
permutation_importance_path = os.path.join(save_directory, 'XGB_multiclass_permutation_importance.png')
plt.savefig(permutation_importance_path, dpi=300)
plt.show()
plt.close()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold
from sklearn.metrics import (
    confusion_matrix, precision_recall_curve, classification_report, roc_auc_score,
    make_scorer, precision_score, recall_score
)
import xgboost as xgb
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from scipy.stats import uniform, randint
import os
from sklearn.inspection import permutation_importance
from sklearn.metrics import auc

# ---------------------------------------
# Helper Functions
# ---------------------------------------

def find_threshold_multiclass(y_true, y_probs, target_class, target_metric='precision', target_value=0.95):
    """
    Find the one-vs-rest threshold for desired precision or recall for a specific class.

    Class ``target_class`` is treated as positive and every other class as
    negative, scored by column ``target_class`` of ``y_probs``.

    Parameters:
    - y_true: Ground truth labels (integer class indices; array-like).
    - y_probs: Predicted probabilities, one column per class.
    - target_class: The class index for which to find the threshold.
    - target_metric: 'precision' or 'recall'.
    - target_value: Desired value for the target metric.

    Returns:
    - Threshold value. For 'precision': the lowest threshold whose precision is
      >= target_value (1.0 if none qualifies). For 'recall': the highest
      threshold whose recall is >= target_value (0.0 if none qualifies).

    Raises:
    - ValueError: if target_metric is neither 'precision' nor 'recall'.
    """
    y_true_bin = (np.asarray(y_true) == target_class).astype(int)
    precision, recall, thresholds = precision_recall_curve(y_true_bin, y_probs[:, target_class])

    # precision_recall_curve appends a final (precision=1, recall=0) point that has
    # no matching threshold. Drop it so every index refers to a real threshold.
    # Otherwise a match on only that point indexes past the end of `thresholds`,
    # and its precision of 1 makes the "not achievable" precision branch unreachable.
    precision = precision[:-1]
    recall = recall[:-1]

    if target_metric == 'precision':
        idx = np.flatnonzero(precision >= target_value)
        if idx.size == 0:
            print(f"No threshold found to achieve {target_value*100}% precision for class {target_class}. Returning threshold=1.0")
            return 1.0
        # Thresholds increase with index: the first hit is the lowest qualifying
        # threshold, which keeps the most recall.
        return thresholds[idx[0]]
    elif target_metric == 'recall':
        idx = np.flatnonzero(recall >= target_value)
        if idx.size == 0:
            print(f"No threshold found to achieve {target_value*100}% recall for class {target_class}. Returning threshold=0.0")
            return 0.0
        # The last hit is the highest threshold that still reaches the target recall.
        return thresholds[idx[-1]]
    else:
        raise ValueError("target_metric must be 'precision' or 'recall'.")

def plot_confusion_matrix_multiclass(cm, classes, ax, title, normalize=True):
    """
    Plot a confusion matrix for multi-class classification.

    Each cell is annotated with its row percentage and raw count when
    ``normalize`` is True, or with the raw count only otherwise.

    Parameters:
    - cm: Confusion matrix (rows = true class, columns = predicted class).
    - classes: Class names used as tick labels on both axes.
    - ax: Matplotlib Axes object.
    - title: Title for the plot.
    - normalize: Whether to normalize each row by its total (true-class count).
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has an all-zero row. Leave it at 0 instead
        # of dividing by zero, which would produce NaN cells and "nan%" labels.
        cm_normalized = np.divide(
            cm.astype(float), row_totals,
            out=np.zeros(cm.shape, dtype=float), where=row_totals != 0
        )
    else:
        cm_normalized = cm

    # annot=False: the cell text (percentage + count) is drawn manually below.
    sns.heatmap(cm_normalized, annot=False, fmt='.2f', cmap='Blues', ax=ax,
                xticklabels=classes, yticklabels=classes, cbar=False)

    for (i, j), count in np.ndenumerate(cm):
        if normalize:
            text = f'{cm_normalized[i, j] * 100:.1f}%\n({count})'
        else:
            text = f'{count}'
        # Heatmap cell (i, j) spans [j, j+1] x [i, i+1]; +0.5 centres the text.
        ax.text(j + 0.5, i + 0.5, text,
                ha='center', va='center',
                color='black', fontsize=8,
                bbox=dict(facecolor='white', edgecolor='white'))

    ax.set_xlabel('Predicted Class', fontsize=10)
    ax.set_ylabel('True Class', fontsize=10)
    ax.set_title(title, fontsize=12)

def plot_precision_recall_curve_multiclass(y_true, y_probs, class_labels, save_path, figsize=(6, 4)):
    """
    Draw one-vs-rest precision-recall curves (one line per class) and save the figure.

    Class k is the set of samples where ``y_true == k``; its scores are column k
    of ``y_probs``. Each legend entry reports the area under that class's curve.

    Parameters:
    - y_true: Ground truth labels, encoded as column indices of y_probs.
    - y_probs: Predicted probabilities, one column per class.
    - class_labels: Display names, in the column order of y_probs.
    - save_path: Destination file for the figure (written at 300 dpi).
    - figsize: Figure size.
    """
    plt.figure(figsize=figsize)
    for col, name in enumerate(class_labels):
        is_positive = (y_true == col).astype(int)
        prec, rec, _ = precision_recall_curve(is_positive, y_probs[:, col])
        area = auc(rec, prec)
        plt.plot(rec, prec, lw=1.5, label=f'{name} (AUC = {area:.2f})')

    # Axis decoration, legend and grid.
    plt.xlabel('Recall', fontsize=10)
    plt.ylabel('Precision', fontsize=10)
    plt.title('Precision-Recall Curve for Multi-class', fontsize=12)
    plt.legend(fontsize=8, loc='lower left')
    plt.grid(True, linestyle='--', alpha=0.5)

    # Write to disk, display, then release the figure.
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

def plot_feature_importance_multiclass(model, feature_names, save_path, figsize=(8, 6)):
    """
    Plot the 20 most important features of an XGBoost model and save the figure.

    Importance is XGBoost's 'weight' type, drawn with xgb.plot_importance
    without per-bar value annotations.

    Parameters:
    - model: Trained XGBoost model.
    - feature_names: Not used by this function; the bar labels are whatever
      xgb.plot_importance derives from `model` itself.
    - save_path: Path to save the plot (300 dpi).
    - figsize: Figure size.
    """
    fig = plt.figure(figsize=figsize)
    axis = fig.gca()
    xgb.plot_importance(
        model,
        ax=axis,
        importance_type='weight',
        max_num_features=20,
        show_values=False,
    )
    for setter, text, size in ((axis.set_title, 'Feature Importance', 12),
                               (axis.set_xlabel, 'Importance (Weight)', 10),
                               (axis.set_ylabel, 'Features', 10)):
        setter(text, fontsize=size)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

def plot_class_distribution(y, classes, ax, title):
    """
    Plot the class distribution as horizontal bars of label counts.

    Parameters:
    - y: Array-like of integer-encoded labels.
    - classes: Class names indexed by encoded label (NumPy array or list).
    - ax: Matplotlib Axes object.
    - title: Title for the plot.
    """
    # np.unique already returns the labels present in y in sorted order, so the
    # bars and the tick names below share the same ascending order.
    present_labels = np.unique(y)
    sns.countplot(y=y, ax=ax, palette='Set2', order=present_labels)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Counts', fontsize=10)
    ax.set_ylabel('Class', fontsize=10)
    # np.asarray lets `classes` be a plain list: fancy indexing needs an array.
    ax.set_yticklabels(np.asarray(classes)[present_labels], rotation=45, fontsize=9)
    # This overrides the fontsize=9 above, so the y tick labels end up at size 8.
    ax.tick_params(axis='y', labelsize=8)

# ---------------------------------------
# Data Preparation
# ---------------------------------------

# Ensure df0 is loaded
if 'df0' not in globals():
    raise ValueError("DataFrame 'df0' is not loaded. Please load your data into 'df0' before running the script.")

# Define combined classes
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'AGN': ['AGN'],
    # 'KN': ['KN_K17', 'KN_B19'],  # Commented out as per original code
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Map the original classes to combined classes.
# Any Object_Type not listed above (e.g. the commented-out KN types) maps to NaN.
class_mapping = {original: combined for combined, originals in combined_classes.items() for original in originals}
df0['Combined_Type'] = df0['Object_Type'].map(class_mapping)

# Display initial counts (value_counts drops NaN, so unmapped objects are not listed)
initial_counts = df0['Combined_Type'].value_counts()
print("Initial Class Distribution:")
print(initial_counts)

# Select relevant columns for the model and exclude 'SNID', any truth info
excluded_columns = ['SNID', 'PeakMag', 'PeakFlux_GP', 'REDSHIFT_FINAL', 'Num_Peaks_GP']
excluded_columns += [col for col in df0.columns if 'err' in col.lower()]  # Remove error estimates

feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

# Renaming columns for clarity (applied to df0 in place, so later cells see the new names)
rename_map = {
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time'
}
df0.rename(columns=rename_map, inplace=True)
feature_columns = [rename_map.get(col, col) for col in feature_columns]

# Restrict the model to objects that belong to a combined class. Unmapped rows
# would otherwise reach LabelEncoder with a NaN target and either be rejected
# or be encoded as an extra "nan" class.
is_mapped = df0['Combined_Type'].notna()
if not is_mapped.all():
    dropped_types = df0.loc[~is_mapped, 'Object_Type'].unique().tolist()
    print(f"\nDropping {int((~is_mapped).sum())} objects with unmapped types: {dropped_types}")
df_model = df0.loc[is_mapped]

# Define features and target
X = df_model[feature_columns].fillna(-999)  # Handling missing values by filling with -999
y = df_model['Combined_Type']

# Encode the target variable (class_labels[k] is the name of encoded label k)
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)
class_labels = label_encoder.classes_

# Standardize the features
# NOTE: the scaler is fit on all rows before the split. This is largely immaterial
# for the tree-based model trained below; fit it on X_train only if a
# scale-sensitive model is added.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)  # Standardize features
X_scaled = pd.DataFrame(X_scaled, columns=feature_columns)  # Retain feature names

# Split the data with a 1:1 train-test ratio and stratification
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y_encoded, test_size=0.5, random_state=42, stratify=y_encoded
)

# Display counts after train-test split, with encoded labels shown by name
train_counts = pd.Series(y_train).value_counts().sort_index()
test_counts = pd.Series(y_test).value_counts().sort_index()
print("\nTraining Set Class Distribution:")
print(train_counts.rename(index=lambda x: class_labels[x]))

print("\nTest Set Class Distribution:")
print(test_counts.rename(index=lambda x: class_labels[x]))

# ---------------------------------------
# Hyperparameter Tuning
# ---------------------------------------

# Define the parameter grid for RandomizedSearchCV.
# scipy.stats.uniform(loc, scale) samples [loc, loc + scale];
# randint(low, high) samples integers in [low, high).
param_dist = {
    'n_estimators': randint(100, 1000),  # 100 to 999
    'learning_rate': uniform(0.01, 0.29),  # 0.01 to 0.3
    'max_depth': randint(3, 10),  # 3 to 9
    'min_child_weight': randint(1, 10),  # 1 to 9
    'subsample': uniform(0.5, 0.5),  # 0.5 to 1.0
    'colsample_bytree': uniform(0.5, 0.5),  # 0.5 to 1.0
    'gamma': uniform(0, 5),  # 0 to 5
    # 'scale_pos_weight' is deliberately not searched: XGBoost applies it only to
    # binary objectives, so under 'multi:softprob' it is an unused parameter and
    # sampling it just wastes search iterations.
}

# Initialize the XGBoost classifier for multi-class
# 'multi:softprob' makes predict_proba return one probability column per class.
xgb_clf = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=len(class_labels),
    # NOTE(review): use_label_encoder is deprecated in recent XGBoost releases;
    # drop it if the installed version reports it as unused.
    use_label_encoder=False,
    eval_metric='mlogloss',
    random_state=42
)

# Define a custom scorer that emphasizes macro recall at a certain precision threshold
def macro_recall_at_precision_threshold(y_true, y_probas, precision_threshold=0.80):
    """
    Calculate the macro recall at a specified precision threshold across all classes.

    For each class column k of ``y_probas`` (one-vs-rest):
    - the threshold is the value returned by ``find_threshold_multiclass`` for
      ``precision_threshold`` precision, computed on these same labels;
    - class k is predicted where ``y_probas[:, k] >= threshold``;
    - the recall of that prediction is recorded.
    The unweighted mean of the per-class recalls is returned.

    Parameters:
    - y_true: Ground truth labels (integer class indices 0..n_classes-1).
    - y_probas: Predicted probabilities, shape (n_samples, n_classes).
    - precision_threshold: Desired precision threshold.

    Returns:
    - Macro recall score.
    """
    y_true = np.asarray(y_true)
    y_probas = np.asarray(y_probas)
    # Take the class count from the probability matrix rather than the global
    # class_labels, which later notebook cells reassign (e.g. to binary labels).
    n_classes = y_probas.shape[1]
    recalls = []
    for class_idx in range(n_classes):
        threshold = find_threshold_multiclass(
            y_true, y_probas, target_class=class_idx,
            target_metric='precision', target_value=precision_threshold
        )
        y_pred = (y_probas[:, class_idx] >= threshold).astype(int)
        recall = recall_score((y_true == class_idx).astype(int), y_pred, zero_division=0)
        recalls.append(recall)
    return np.mean(recalls)

# Create a scorer for RandomizedSearchCV
import inspect

def custom_scorer_multiclass(y_true, y_pred_probas):
    """Score a CV fold by macro recall at 80% per-class precision (higher is better)."""
    return macro_recall_at_precision_threshold(y_true, y_pred_probas, precision_threshold=0.80)

# The scorer must be fed predict_proba output.
# - scikit-learn >= 1.4 spells this response_method='predict_proba' and deprecates
#   needs_proba, which was later removed.
# - Older releases lack response_method and would silently forward it to the score
#   function as an extra keyword.
# Use whichever form this installation supports.
if 'response_method' in inspect.signature(make_scorer).parameters:
    scorer = make_scorer(custom_scorer_multiclass, response_method='predict_proba')
else:
    scorer = make_scorer(custom_scorer_multiclass, needs_proba=True)

# Initialize RandomizedSearchCV with StratifiedKFold
# 100 random draws from param_dist, each scored with 3-fold stratified CV on the
# training split using the custom macro-recall scorer; n_jobs=-1 uses all cores.
random_search = RandomizedSearchCV(
    estimator=xgb_clf,
    param_distributions=param_dist,
    n_iter=100,
    scoring=scorer,
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=42),
    verbose=1,
    random_state=42,
    n_jobs=-1
)

# Perform hyperparameter tuning
random_search.fit(X_train, y_train)

# Retrieve the best estimator
# (refit on the whole training split, since RandomizedSearchCV's refit defaults to True)
best_xgb = random_search.best_estimator_

print(f"\nBest parameters found: {random_search.best_params_}")
# best_score_ is the mean cross-validated scorer value of the best candidate.
print(f"Best macro recall at 80% precision: {random_search.best_score_:.4f}")

# ---------------------------------------
# Threshold Determination for Each Class
# ---------------------------------------

# Predict probabilities with the best model
# Column k of y_probs is the probability of encoded class k.
y_probs = best_xgb.predict_proba(X_test)

# Define thresholds based on desired precision and recall for each class
# Keys are inserted in class_labels order, i.e. the column order of y_probs.
# NOTE(review): thresholds are tuned on the test split itself, so any test-set
# metric computed with them is optimistic.
thresholds_dict = {}
for class_idx, class_label in enumerate(class_labels):
    # Threshold for 80% Precision
    threshold_80p_precision = find_threshold_multiclass(
        y_test, y_probs, target_class=class_idx, target_metric='precision', target_value=0.80
    )
    # Threshold for 95% Precision
    threshold_95p_precision = find_threshold_multiclass(
        y_test, y_probs, target_class=class_idx, target_metric='precision', target_value=0.95
    )
    # Threshold for 95% Recall
    threshold_95p_recall = find_threshold_multiclass(
        y_test, y_probs, target_class=class_idx, target_metric='recall', target_value=0.95
    )
    
    thresholds_dict[class_label] = {
        '80% Precision': threshold_80p_precision,
        '95% Precision': threshold_95p_precision,
        '95% Recall': threshold_95p_recall
    }

# Function to get predictions based on thresholds (One-vs-Rest)
def get_multiclass_predictions(y_probs, thresholds_dict, threshold_key='80% Precision'):
    """
    Generate multi-class predictions based on per-class thresholds.

    A class "fires" for a sample when its probability is >= that class's
    threshold. If several classes fire, the one with the highest probability
    wins; if none fires, the plain argmax of the probabilities is used.

    Parameters:
    - y_probs: Predicted probabilities, shape (n_samples, n_classes).
    - thresholds_dict: {class_label: {threshold_name: value}}; its keys must be
      in the same order as the columns of y_probs.
    - threshold_key: Which per-class threshold to apply (default '80% Precision').

    Returns:
    - Integer array of predicted class indices, shape (n_samples,).
    """
    y_probs = np.asarray(y_probs, dtype=float)
    # One threshold per column, taken in thresholds_dict insertion order.
    thresholds = np.array(
        [per_class[threshold_key] for per_class in thresholds_dict.values()], dtype=float
    )
    fires = y_probs >= thresholds  # broadcast: (n_samples, n_classes)

    # Resolve overlaps by probability (argmax over the firing classes only)
    # rather than by which class happens to be processed last.
    y_pred = np.argmax(np.where(fires, y_probs, -np.inf), axis=1)

    # Assign the class with the highest probability if no threshold is met
    no_fire = ~fires.any(axis=1)
    y_pred[no_fire] = np.argmax(y_probs[no_fire], axis=1)
    return y_pred

# Generate predictions using the 80% Precision thresholds
y_pred_threshold = get_multiclass_predictions(y_probs, thresholds_dict)

# ---------------------------------------
# Plotting Confusion Matrices
# ---------------------------------------

# Define the save directory (created if missing; every figure saved in this cell goes here)
save_directory = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/XGB_multiclass'
os.makedirs(save_directory, exist_ok=True)

# Compute confusion matrix over every encoded class so rows/columns match class_labels
cm = confusion_matrix(y_test, y_pred_threshold, labels=np.arange(len(class_labels)))

# Create a single plot for confusion matrix
plt.figure(figsize=(10, 8))
plot_confusion_matrix_multiclass(cm, classes=class_labels, ax=plt.gca(), title='XGBoost Confusion Matrix', normalize=True)
plt.tight_layout()
confusion_matrix_path = os.path.join(save_directory, 'XGB_multiclass_confusion_matrix.png')
plt.savefig(confusion_matrix_path, dpi=300)
plt.show()
# Release the figure as the other plots in this cell do; otherwise every re-run
# leaves another open figure behind.
plt.close()

# ---------------------------------------
# Plotting Precision-Recall Curves
# ---------------------------------------

# Define save path for Precision-Recall curve
save_path_pr = os.path.join(save_directory, 'XGB_multiclass_precision_recall.png')

# Plot and save Precision-Recall curves (one one-vs-rest curve per encoded class)
# NOTE(review): a later cell redefines plot_precision_recall_curve_multiclass with a
# 3-argument signature; re-running this cell after that one will fail on this call.
plot_precision_recall_curve_multiclass(y_test, y_probs, class_labels, save_path_pr, figsize=(8, 6))

# ---------------------------------------
# Plotting Feature Importance
# ---------------------------------------

# Define save path for Feature Importance plot
save_path_fi = os.path.join(save_directory, 'XGB_multiclass_feature_importance.png')

# Plot and save Feature Importance ('weight' importance, top 20 features)
plot_feature_importance_multiclass(best_xgb, feature_columns, save_path_fi, figsize=(10, 8))

# ---------------------------------------
# Plotting Class Distribution
# ---------------------------------------

# Side-by-side label counts for the train and test splits.
fig, axs = plt.subplots(1, 2, figsize=(12, 5))

# Training set class distribution
plot_class_distribution(
    y_train, class_labels,
    ax=axs[0],
    title='Training Set Class Distribution'
)

# Test set class distribution
plot_class_distribution(
    y_test, class_labels,
    ax=axs[1],
    title='Test Set Class Distribution'
)

plt.tight_layout()
class_distribution_path = os.path.join(save_directory, 'XGB_multiclass_class_distribution.png')
plt.savefig(class_distribution_path, dpi=300)
plt.show()
plt.close()

# ---------------------------------------
# Evaluation Metrics
# ---------------------------------------

# Report for the threshold-based predictions (not the plain argmax of y_probs).
print("XGBoost Multi-class Classification Report:")
print(classification_report(y_test, y_pred_threshold, target_names=class_labels))

# Compute ROC AUC for multi-class
# (threshold-free: computed from the raw probabilities)
roc_auc = roc_auc_score(y_test, y_probs, multi_class='ovr')
print(f"ROC AUC (One-vs-Rest): {roc_auc:.4f}")

# ---------------------------------------
# Permutation Feature Importance
# ---------------------------------------

# Shuffle each test-set feature 10 times (seeded) and record the mean drop in
# one-vs-rest ROC AUC of the tuned model.
result = permutation_importance(
    best_xgb, X_test, y_test, n_repeats=10, random_state=42, scoring='roc_auc_ovr'
)

# X_test was built from a DataFrame with columns=feature_columns, so positions line up.
perm_importance = pd.Series(result.importances_mean, index=feature_columns).sort_values(ascending=False)

# Bar chart of the 20 largest mean importances.
plt.figure(figsize=(10, 8))
sns.barplot(x=perm_importance.values[:20], y=perm_importance.index[:20], palette='viridis')
plt.title('Permutation Feature Importance')
plt.xlabel('Mean Decrease in ROC AUC')
plt.ylabel('Features')
plt.tight_layout()
permutation_importance_path = os.path.join(save_directory, 'XGB_multiclass_permutation_importance.png')
plt.savefig(permutation_importance_path, dpi=300)
plt.show()
plt.close()


In [None]:
# Gain- and cover-based importances of the tuned multi-class model.
# xgb.plot_importance opens its own default-size figure when no Axes is given,
# so pass the Axes of the sized figure created just before; otherwise
# figsize=(10, 8) is ignored and an extra empty figure is shown.

# Gain-based importance
plt.figure(figsize=(10, 8))
xgb.plot_importance(best_xgb, importance_type='gain', max_num_features=20, show_values=False, ax=plt.gca())
plt.title('Feature Importance (Gain)')
plt.tight_layout()
plt.show()

# Cover-based importance
plt.figure(figsize=(10, 8))
xgb.plot_importance(best_xgb, importance_type='cover', max_num_features=20, show_values=False, ax=plt.gca())
plt.title('Feature Importance (Cover)')
plt.tight_layout()
plt.show()


In [None]:
# Define the classes to exclude
excluded_classes = ['KN_K17', 'KN_B19', 'SLSN-I_no_host', 'SNIIn-MOSFIT']

# Filter out the excluded classes
df0_filtered = df0[~df0['Object_Type'].isin(excluded_classes)]

# Select relevant columns for the model and exclude 'SNID', any truth info
# NOTE(review): unlike the multi-class XGBoost cell, 'Num_Peaks_GP' is not excluded
# here -- confirm that is intended.
excluded_columns = ['SNID', 'PeakMag', 'PeakFlux_GP', 'REDSHIFT_FINAL']
excluded_columns += [col for col in df0_filtered.columns if 'err' in col.lower()]  # remove error estimates
feature_columns = [col for col in df0_filtered.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

X = df0_filtered[feature_columns].fillna(-999)  # Handling missing values by filling with -999
# Target is the raw (uncombined) Object_Type string.
y = df0_filtered['Object_Type']
# NOTE(review): the scaler is fit on all rows before the split, so test-set
# statistics leak into preprocessing (immaterial for the tree models below).
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)  # Standardize features

# 30% train / 70% test, not stratified.
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.7, random_state=42)

def plot_normalized_conf_matrix(y_true, y_pred, ax, title):
    """
    Draw a heatmap of the row-normalised confusion matrix (each row sums to 1).

    Rows and columns are the sorted distinct labels found in ``y_true``;
    predictions of labels that never occur in ``y_true`` are not counted.
    """
    present = np.unique(y_true)
    counts = confusion_matrix(y_true, y_pred, labels=present)
    row_totals = counts.sum(axis=1)[:, np.newaxis]
    fractions = counts.astype('float') / row_totals
    sns.heatmap(fractions, annot=True, fmt=".2f", cmap='Blues', ax=ax,
                xticklabels=present, yticklabels=present)
    ax.set_xlabel('Predicted Class')
    ax.set_ylabel('True Class')
    ax.set_title(title)

# Resampling strategies to compare; each is applied to the training split only.
strategies = {
    'Original': None,  # No sampling for the original distribution
    'Undersampling': RandomUnderSampler(random_state=42),
    'Oversampling (SMOTE)': SMOTE(random_state=42)
}

# Adjusting the figure size and orientation of the labels for clarity
# One row per strategy: class counts (left) and the RF confusion matrix (right).
fig, axs = plt.subplots(len(strategies), 2, figsize=(24, len(strategies) * 8), gridspec_kw={'width_ratios': [2, 3]})

for i, (strategy_name, sampler) in enumerate(strategies.items()):
    if strategy_name == 'Original':
        X_train_mod, y_train_mod = X_train, y_train
    else:
        X_train_mod, y_train_mod = sampler.fit_resample(X_train, y_train)
    
    # Visualize class distribution with improved label orientation
    # (bars ordered from most to least frequent class)
    sns.countplot(y=y_train_mod, ax=axs[i, 0], palette='Set2', order=y_train_mod.value_counts().index)
    axs[i, 0].set_title(f'Class Distribution: {strategy_name}')
    axs[i, 0].set_xlabel('Counts')
    axs[i, 0].set_ylabel('Class')
    axs[i, 0].tick_params(axis='y', rotation=45)  # Rotate labels to avoid overlap
    
    # Train Random Forest classifier
    # NOTE(review): class_weight='balanced' re-weights classes on top of any
    # resampling done above.
    rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced')
    rf_classifier.fit(X_train_mod, y_train_mod)

    # Train the Gradient Boosting classifier
    # NOTE(review): this model is fit for every strategy but never plotted; only the
    # last iteration's (SMOTE) bdt_classifier / y_pred_bdt are used after the loop.
    bdt_classifier = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=42)
    bdt_classifier.fit(X_train_mod, y_train_mod)

    # Prediction 
    y_pred_rf = rf_classifier.predict(X_test)
    y_pred_bdt = bdt_classifier.predict(X_test)

    # Plot normalized confusion matrix
    plot_normalized_conf_matrix(y_test, y_pred_rf, axs[i, 1], f'RF: {strategy_name}')

plt.tight_layout()
plt.show()

def plot_precision_recall_curve_multiclass(y_true, y_probas, class_labels):
    """
    Show one-vs-rest precision-recall curves for every class (nothing is saved).

    Unlike the encoded-label helper of the same name earlier in the notebook,
    this version compares ``y_true`` against the label *values* in
    ``class_labels`` and scores class_labels[i] with column i of ``y_probas``.
    Defining it replaces that earlier helper for every later cell.
    """
    plt.figure(figsize=(10, 7))
    for col, label in enumerate(class_labels):
        positives = (y_true == label).astype(int)
        prec, rec, _ = precision_recall_curve(positives, y_probas[:, col])
        plt.plot(rec, prec, lw=2, label=f'Class {label} (AUC = {auc(rec, prec):0.2f})')
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="best")
    plt.title("Precision-Recall Curve for Multiple Classes")
    plt.grid(True)
    plt.show()

# Get the probabilities for each class
# rf_classifier / bdt_classifier are the models left over from the last iteration
# of the strategy loop above ('Oversampling (SMOTE)').
y_probs_rf = rf_classifier.predict_proba(X_test)
y_probs_bdt = bdt_classifier.predict_proba(X_test)

# Class labels for the multi-class setup
# (column order of predict_proba; this reassigns the global class_labels)
class_labels = rf_classifier.classes_

# Plot the precision-recall curve for Random Forest
plot_precision_recall_curve_multiclass(y_test, y_probs_rf, class_labels)

# y_pred_rf / y_pred_bdt are likewise from the last (SMOTE) loop iteration.
print("Random Forest Classification Report:")
print(classification_report(y_test, y_pred_rf))
print("\nGradient Boosting Classification Report:")
print(classification_report(y_test, y_pred_bdt))


### Precision-recall curves

In [None]:
# Threshold selection adapted from tdescore (Stein).
# Derives three operating points for the TDE class from the Random Forest scores:
# loose (recall-driven), strict (95% precision) and balanced (90% precision).

cls = 'TDE'
probs_cls = y_probs_rf[:, class_labels.tolist().index(cls)]
# pr and recall have one more entry than thresholds (a final point with no threshold).
pr, recall, thresholds = precision_recall_curve(y_test == cls, probs_cls)

index = np.arange(len(thresholds))

# Loose threshold: largest index i with recall[i + 1] >= 96%.
# NOTE(review): recall[1:] is offset by one from thresholds, so this picks the
# threshold one position below the highest one whose own recall is >= 96%;
# confirm this matches the tdescore reference. max() raises ValueError if no
# point qualifies.
mask = recall[1:] >= 0.96
loose_index = max(index[mask])
threshold_loose = thresholds[loose_index]
print(f"Loose threshold {threshold_loose:.2f}, Purity={100.*pr[loose_index]:.1f}%, Efficiency={100.*recall[loose_index]:.1f}%")

# Strict threshold: lowest threshold whose precision is >= 95% (min() raises if none).
mask = pr[:-1] >= .95
strict_index = min(index[mask])
threshold_strict = thresholds[strict_index]
print(f"Strict Threshold {threshold_strict:.2f}, Purity={100.*pr[strict_index]:.1f}%, Efficiency={100.*recall[strict_index]:.1f}%")

# Balanced threshold: lowest threshold whose precision is >= 90%.
mask = pr[:-1] >= 0.90
balanced_index = min(index[mask])
threshold_balanced = thresholds[balanced_index]
print(f"Balanced Threshold {threshold_balanced:.2f}, Purity={100.*pr[balanced_index]:.1f}%, Efficiency={100.*recall[balanced_index]:.1f}%")

In [None]:
# Precision and recall of the TDE class vs. decision threshold (Random Forest scores).
cls = 'TDE'
probs_cls = y_probs_rf[:, class_labels.tolist().index(cls)]

# Calculate precision, recall, and thresholds for TDE
precision, recall, thresholds_pr = precision_recall_curve(y_test == cls, probs_cls)

# Plot precision and recall as functions of the threshold
plt.figure(figsize=(10, 7))

# The final precision/recall entry has no threshold, hence [:-1].
plt.plot(thresholds_pr, precision[:-1], label='Precision', color='blue')
plt.plot(thresholds_pr, recall[:-1], label='Recall', color='orange')

# Adding vertical lines for specific thresholds
# Lowest threshold reaching 95% precision; raises IndexError if none does.
threshold_95_precision = thresholds_pr[np.where(precision[:-1] >= 0.95)[0][0]]
#threshold_90_recall = thresholds_pr[np.where(recall[:-1] >= 0.90)[0][0]]
#threshold_balanced = thresholds_pr[np.argmax(precision[:-1] >= 0.80)]

plt.axvline(x=threshold_95_precision, color='black', linestyle='--', label='95% Precision Threshold')
#plt.axvline(x=threshold_90_recall, color='black', linestyle='-', label='90% Recall Threshold')
#plt.axvline(x=threshold_balanced, color='black', linestyle='-', label='Balanced Threshold (80% Precision)')

plt.xlabel('Threshold')
plt.ylabel('Score')
plt.title('Precision and Recall as a function of threshold')
plt.legend(loc='best')
plt.grid(True)
plt.show()


In [None]:
# Group SN subtypes into SNI and SNII; keep CLAGN, KN, SLSN and TDE as separate classes.
# NOTE(review): the multi-class XGBoost cell maps 'AGN' whereas this one maps
# 'CLAGN' -- confirm which Object_Type name the data actually uses.
combined_classes = {
    'SNI': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19'],
    'SNII': ['SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19'],
    'CLAGN': ['CLAGN'],
    'KN': ['KN_K17', 'KN_B19'],
    'SLSN': ['SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}


# Create a combined class column in df0
# Types not listed in combined_classes keep their original Object_Type name.
df0['Combined_Class'] = df0['Object_Type']
for combined_class, original_classes in combined_classes.items():
    df0.loc[df0['Object_Type'].isin(original_classes), 'Combined_Class'] = combined_class

# Select relevant columns for the model and exclude 'SNID', any truth info
excluded_columns = ['SNID', 'PeakMag', 'PeakFlux_GP', 'REDSHIFT_FINAL']
excluded_columns += [col for col in df0.columns if 'err' in col.lower()]
feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

X = df0[feature_columns].fillna(-999)
y = df0['Combined_Class']
# NOTE(review): scaler fit on all rows before the split (test statistics leak into
# preprocessing; immaterial for the tree model below).
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 30% train / 70% test, not stratified.
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.7, random_state=42)

# Balance the training set using SMOTE
smote = SMOTE(random_state=42)
X_train_balanced, y_train_balanced = smote.fit_resample(X_train, y_train)

# Train Random Forest classifier on the balanced dataset
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced')
rf_classifier.fit(X_train_balanced, y_train_balanced)

# Get the probabilities for each class (classifier scores)
y_probs_rf = rf_classifier.predict_proba(X_test)
class_labels = rf_classifier.classes_

# Define a color palette
palette = sns.color_palette("tab10", len(class_labels))

# Plot precision and recall as functions of the threshold for each class
plt.figure(figsize=(12, 8))

for idx, cls in enumerate(class_labels):
    # Column lookup equals idx here, since class_labels is rf_classifier.classes_.
    probs_cls = y_probs_rf[:, class_labels.tolist().index(cls)]
    precision, recall, thresholds_pr = precision_recall_curve(y_test == cls, probs_cls)
    
    # Solid line = precision, dashed line = recall, one colour per class.
    plt.plot(thresholds_pr, precision[:-1], label=f'Precision {cls}', color=palette[idx])
    plt.plot(thresholds_pr, recall[:-1], label=f'Recall {cls}', color=palette[idx], linestyle='--')

plt.xlabel('Threshold')
plt.ylabel('Score')
plt.title('Precision and Recall as a function of threshold for each class')
plt.legend(loc='best', fontsize='small')
plt.grid(True)
plt.show()


In [None]:
# Define the combined classes
# Binary grouping: TDE against every other listed transient type.
combined_classes = {
    'Non-TDE': ['SNIa-SALT3', 'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19', 
                'SNIIn-MOSFIT', 'SNII-NMF', 'SNII+HostXT_V19', 'SNIIb+HostXT_V19', 
                'CLAGN', 'KN_K17', 'KN_B19', 
                'SLSN-I+host', 'SLSN-I_no_host'],
    'TDE': ['TDE']
}

# Map the original classes to combined classes
# NOTE(review): any Object_Type not listed above (e.g. 'AGN', if that is the name
# in the data) maps to NaN and is passed to LabelEncoder unchanged.
class_mapping = {original: combined for combined, originals in combined_classes.items() for original in originals}
df0['Combined_Class'] = df0['Object_Type'].map(class_mapping)

# Select relevant columns for the model and exclude 'SNID', any truth info
excluded_columns = ['SNID', 'PeakMag', 'PeakFlux_GP', 'REDSHIFT_FINAL', 'Num_Peaks_GP']
excluded_columns += [col for col in df0.columns if 'err' in col.lower()]  # remove error estimates
feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

X = df0[feature_columns].fillna(-999)  # Handling missing values by filling with -999
y = df0['Combined_Class']

# Encode the target variable
# LabelEncoder sorts its classes, so 'Non-TDE' -> 0 and 'TDE' -> 1 (when no NaN is present).
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)  # Standardize features

# 30% train / 70% test, not stratified.
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_encoded, test_size=0.7, random_state=42)

# Balance the training set using SMOTE
smote = SMOTE(random_state=42)
X_train_balanced, y_train_balanced = smote.fit_resample(X_train, y_train)

# Train XGBoost classifier on the balanced dataset
xgb_classifier = xgb.XGBClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
xgb_classifier.fit(X_train_balanced, y_train_balanced)

# Get the probabilities for each class (classifier scores)
# class_labels holds the encoded integer labels the model was trained on.
y_probs_xgb = xgb_classifier.predict_proba(X_test)
class_labels = xgb_classifier.classes_

# Plot precision and recall as functions of the threshold for TDE vs Non-TDE
plt.figure(figsize=(12, 8))

for idx, cls in enumerate(class_labels):
    probs_cls = y_probs_xgb[:, class_labels.tolist().index(cls)]
    # y_test is encoded too, so compare against the encoded cls; names are decoded only for the legend.
    precision, recall, thresholds_pr = precision_recall_curve(y_test == cls, probs_cls)
    
    plt.plot(thresholds_pr, precision[:-1], label=f'Precision {label_encoder.inverse_transform([cls])[0]}', color=sns.color_palette("tab10")[idx])
    plt.plot(thresholds_pr, recall[:-1], label=f'Recall {label_encoder.inverse_transform([cls])[0]}', color=sns.color_palette("tab10")[idx], linestyle='--')

plt.xlabel('Threshold')
plt.ylabel('Score')
plt.title('Precision and Recall as a function of threshold for TDE vs Non-TDE')
plt.legend(loc='best', fontsize='small')
plt.grid(True)
plt.show()

def plot_normalized_conf_matrix(y_true, y_pred, ax, title, class_names=None):
    """
    Draw a row-normalised confusion-matrix heatmap for integer-encoded labels.

    Parameters:
    - y_true: True labels, encoded as 0 .. n_classes-1.
    - y_pred: Predicted labels, same encoding.
    - ax: Matplotlib Axes to draw on.
    - title: Axes title.
    - class_names: Names indexed by encoded label; defaults to the global
      ``label_encoder.classes_``.
    """
    if class_names is None:
        class_names = label_encoder.classes_
    # Pin the label set so the matrix always has one row/column per class name.
    # Otherwise a class absent from both y_true and y_pred shrinks the matrix and
    # the tick labels no longer line up with the cells.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows of classes with no true samples stay 0 instead of becoming NaN.
    cm_normalized = np.divide(cm.astype(float), row_sums,
                              out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
    sns.heatmap(cm_normalized, annot=True, fmt=".2f", cmap='Blues', ax=ax,
                xticklabels=class_names, yticklabels=class_names)
    ax.set_xlabel('Predicted Class')
    ax.set_ylabel('True Class')
    ax.set_title(title)

# Plot normalized confusion matrix
# Hard predictions come from XGBClassifier.predict, not from any tuned threshold.
fig, ax = plt.subplots(figsize=(10, 8))
y_pred_xgb = xgb_classifier.predict(X_test)
plot_normalized_conf_matrix(y_test, y_pred_xgb, ax, 'XGB: TDE vs Non-TDE')
plt.show()

print("XGBoost Classification Report:")
# target_names assumes every encoded class appears in y_test / y_pred_xgb.
print(classification_report(y_test, y_pred_xgb, target_names=label_encoder.classes_))


### Binary Classification: TDE vs Others

#### RF

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold
from sklearn.metrics import (
    confusion_matrix, precision_recall_curve, classification_report, roc_auc_score,
    make_scorer
)
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
import os
from scipy.stats import uniform, randint

# ---------------------------------------
# Helper Functions
# ---------------------------------------

def find_threshold(y_true, y_probs, target='precision', target_value=0.95):
    """
    Find the threshold for desired precision or recall.

    Parameters:
    - y_true: Ground truth binary labels.
    - y_probs: Predicted probabilities for the positive class.
    - target: 'precision' or 'recall'.
    - target_value: Desired value for the target metric.

    Returns:
    - Threshold value. For 'precision': the lowest threshold whose precision is
      >= target_value (1.0 if none qualifies). For 'recall': the highest
      threshold whose recall is >= target_value (0.0 if none qualifies).

    Raises:
    - ValueError: if target is neither 'precision' nor 'recall'.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_probs)

    # precision_recall_curve appends a final (precision=1, recall=0) point that has
    # no matching threshold. Drop it so every index refers to a real threshold.
    # Otherwise a match on only that point indexes past the end of `thresholds`,
    # and the "not achievable" precision branch could never be reached.
    precision = precision[:-1]
    recall = recall[:-1]

    if target == 'precision':
        # Find indices where precision >= target_value
        idx = np.flatnonzero(precision >= target_value)
        if idx.size == 0:
            print(f"No threshold found to achieve {target_value*100}% precision. Returning threshold=1.0")
            return 1.0  # Maximum threshold
        # Thresholds increase with index: the first hit is the lowest qualifying one.
        return thresholds[idx[0]]

    elif target == 'recall':
        # Find indices where recall >= target_value
        idx = np.flatnonzero(recall >= target_value)
        if idx.size == 0:
            print(f"No threshold found to achieve {target_value*100}% recall. Returning threshold=0.0")
            return 0.0  # Minimum threshold
        # The last hit is the highest threshold that still reaches the target recall.
        return thresholds[idx[-1]]

    else:
        raise ValueError("Target must be 'precision' or 'recall'.")

def plot_confusion_matrix(cm, ax, title, normalize=True, class_names=('Non-TDE', 'TDE')):
    """
    Plot a confusion matrix with percentages and counts.

    Parameters:
    - cm: Confusion matrix (rows = true label, columns = predicted label).
    - ax: Matplotlib Axes object.
    - title: Title for the subplot.
    - normalize: Whether to normalize the confusion matrix per true label.
    - class_names: Tick labels for both axes, in label order. Defaults to the
      binary ('Non-TDE', 'TDE') pair this helper was written for.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Avoid 0/0 (NaN cells and "nan%" text) when a true label has no samples.
        cm_normalized = np.divide(
            cm.astype(float), row_totals,
            out=np.zeros(cm.shape, dtype=float), where=row_totals != 0
        )
    else:
        cm_normalized = cm

    tick_labels = list(class_names)
    sns.heatmap(
        cm_normalized, annot=False, fmt='.2f', cmap='Blues', ax=ax,
        xticklabels=tick_labels, yticklabels=tick_labels, cbar=False
    )

    # Annotate cells manually (annot=False above) so percentage and count can be
    # shown together; +0.5 centres the text in heatmap cell (i, j).
    for (i, j), count in np.ndenumerate(cm):
        if normalize:
            text = f'{cm_normalized[i, j] * 100:.1f}%\n({count})'
        else:
            text = f'{count}'
        ax.text(
            j + 0.5, i + 0.5, text, ha='center', va='center',
            color='black', fontsize=8, bbox=dict(facecolor='white', edgecolor='white')
        )

    ax.set_xlabel('Predicted Label', fontsize=10)
    ax.set_ylabel('True Label', fontsize=10)
    ax.set_title(title, fontsize=12)

def plot_precision_recall_curve(y_true, y_probs, thresholds_to_mark, save_path, figsize=(6, 4)):
    """
    Draw precision and recall against the decision threshold, mark selected
    thresholds with dashed vertical lines, then save and show the figure.

    Parameters:
    - y_true: Ground truth binary labels.
    - y_probs: Predicted probabilities for the positive class.
    - thresholds_to_mark: Mapping of label -> threshold value to draw; labels
      containing '95% Precision' are drawn in red, all others in blue.
    - save_path: Destination file for the saved figure (300 dpi).
    - figsize: Tuple specifying the figure size.
    """
    prec, rec, cutoffs = precision_recall_curve(y_true, y_probs)

    plt.figure(figsize=figsize)
    # The final precision/recall pair has no threshold, hence the [:-1] slices.
    curves = ((prec[:-1], 'Precision', 'blue'), (rec[:-1], 'Recall', 'green'))
    for values, curve_name, curve_colour in curves:
        plt.plot(cutoffs, values, label=curve_name, color=curve_colour)

    for marker_name, cutoff in thresholds_to_mark.items():
        line_colour = 'red' if '95% Precision' in marker_name else 'blue'
        plt.axvline(x=cutoff, linestyle='--', color=line_colour, label=f'{marker_name} ({cutoff:.2f})')

    plt.xlabel('Threshold', fontsize=10)
    plt.ylabel('Score', fontsize=10)
    plt.title('Precision and Recall vs. Threshold', fontsize=12)
    plt.legend(fontsize=8)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

def plot_feature_importance(model, feature_names, save_path, figsize=(8, 6), top_n=20):
    """
    Plot and save the impurity-based feature importance of a fitted tree ensemble.

    Parameters:
    - model: Trained model exposing ``feature_importances_`` (e.g. Random Forest).
    - feature_names: Sequence of feature names, aligned with the model's inputs.
    - save_path: Path to save the plot.
    - figsize: Tuple specifying the figure size.
    - top_n: Number of most important features to show (default 20, the
      previously hard-coded value).
    """
    plt.figure(figsize=figsize)
    importances = np.asarray(model.feature_importances_)
    names = list(feature_names)
    # Indices of the top_n largest importances, most important first
    indices = np.argsort(importances)[::-1][:top_n]
    sorted_importances = importances[indices]
    sorted_features = [names[i] for i in indices]

    sns.barplot(x=sorted_importances, y=sorted_features, palette='viridis')
    plt.title('Feature Importance', fontsize=12)
    plt.xlabel('Importance', fontsize=10)
    plt.ylabel('Features', fontsize=10)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

def plot_permutation_importance_rf(model, X_test, y_test, feature_names, save_path, figsize=(10, 8), top_n=20):
    """
    Plot and save permutation feature importance (drop in ROC AUC) for a fitted model.

    Parameters:
    - model: Trained classifier (e.g. Random Forest) supporting predict_proba.
    - X_test: Test feature set.
    - y_test: Test labels.
    - feature_names: List of feature names, aligned with the columns of X_test.
    - save_path: Path to save the plot.
    - figsize: Tuple specifying the figure size.
    - top_n: Number of most important features to show (default 20).
    """
    # permutation_importance is not among the file's top-level imports, so import it
    # here to make the function usable on its own instead of raising NameError.
    from sklearn.inspection import permutation_importance

    result = permutation_importance(
        model, X_test, y_test, n_repeats=10, random_state=42, scoring='roc_auc'
    )
    perm_importance = pd.Series(result.importances_mean, index=feature_names).sort_values(ascending=False)
    top_features = perm_importance.iloc[:top_n]

    plt.figure(figsize=figsize)
    sns.barplot(x=top_features.values, y=top_features.index, palette='viridis')
    plt.title('Permutation Feature Importance')
    plt.xlabel('Mean Decrease in ROC AUC')
    plt.ylabel('Features')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

# ---------------------------------------
# Data Preparation
# ---------------------------------------

# Ensure df0 is loaded (at module level locals() is the notebook namespace)
if 'df0' not in locals():
    raise ValueError("DataFrame 'df0' is not loaded. Please load your data into 'df0' before running the script.")

# Create a simplified object type column: 'TDE' vs everything else
df0['Simple_Object_Type'] = df0['Object_Type'].apply(lambda x: 'TDE' if x == 'TDE' else 'Other')

# Select numeric columns for the model and exclude specified columns
excluded_columns = [
    'SNID', 
    'PeakMag',
    'PeakMagErr',
    'REDSHIFT_FINAL',
    'REDSHIFT_FINAL_ERR',
    # Add other excluded columns if needed
]

# Exclude columns containing 'err', 'flux', or 'mjd' (case-insensitive)
excluded_columns += [col for col in df0.columns if 'err' in col.lower() or 'flux' in col.lower() or 'mjd' in col.lower()]

# Select feature columns
feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

# Renaming columns for clarity
rename_map = {
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time'
}
df0.rename(columns=rename_map, inplace=True)
feature_columns = [rename_map.get(col, col) for col in feature_columns]

# Define features and target
X = df0[feature_columns].fillna(-999)  # Handle missing values with a -999 sentinel
y = df0['Simple_Object_Type'].apply(lambda x: 1 if x == 'TDE' else 0)  # Binary encoding: TDE=1, other=0

# Split the raw features first (stratified to maintain class distribution). The
# scaler is then fitted on the training rows only; fitting it on the full table
# leaked test-set statistics into training. The row split is unchanged because
# train_test_split only depends on y, the sample count and random_state.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=42, stratify=y
)

# Standardize the features using training statistics only. The original index is
# kept so X_train/X_test line up with y_train/y_test (a fresh RangeIndex did not).
scaler = StandardScaler()
scaler.fit(X_train_raw)
X_train = pd.DataFrame(scaler.transform(X_train_raw), columns=feature_columns, index=X_train_raw.index)
X_test = pd.DataFrame(scaler.transform(X_test_raw), columns=feature_columns, index=X_test_raw.index)

# Whole table scaled with the training-fitted scaler, kept for any later use
X_scaled = pd.DataFrame(scaler.transform(X), columns=feature_columns, index=X.index)

# Display class distribution
print("Class distribution:")
print("Training set:")
print(y_train.value_counts())
print("\nTest set:")
print(y_test.value_counts())

# ---------------------------------------
# Hyperparameter Tuning for Random Forest
# ---------------------------------------

from sklearn.ensemble import RandomForestClassifier
# randint is otherwise only imported by the later XGB cell; import it here so this
# cell does not depend on execution order.
from scipy.stats import randint

# Define the parameter grid for RandomizedSearchCV
param_dist_rf = {
    'n_estimators': randint(100, 1000),
    'max_depth': randint(5, 50),
    'min_samples_split': randint(2, 20),
    'min_samples_leaf': randint(1, 20),
    # 'auto' was removed from RandomForestClassifier in scikit-learn 1.3 (for
    # classifiers it meant 'sqrt'), so it is dropped to avoid failing candidates.
    'max_features': ['sqrt', 'log2', None],
    'bootstrap': [True, False],
    'class_weight': ['balanced', 'balanced_subsample', None]
}

# Initialize the Random Forest classifier
rf_clf = RandomForestClassifier(
    random_state=42,
    n_jobs=-1
)

# Define a custom scorer that emphasizes recall at a precision threshold
def recall_at_precision_threshold(y_true, y_probs, precision_threshold=0.80):
    """
    Return the best recall reachable while precision stays >= precision_threshold.

    Every point of the precision-recall curve is considered; 0 is returned when
    none of them reaches the required precision.
    """
    precision, recall, _ = precision_recall_curve(y_true, y_probs)
    meets_precision = precision >= precision_threshold
    if not meets_precision.any():
        return 0
    return recall[meets_precision].max()

# Create a scorer for RandomizedSearchCV
def custom_scorer_rf(y_true, y_pred_proba):
    """Score function for the RF search: recall at >= 80% precision."""
    required_precision = 0.80
    return recall_at_precision_threshold(y_true, y_pred_proba, precision_threshold=required_precision)

from sklearn.metrics import make_scorer
# RandomizedSearchCV is otherwise only imported by the later XGB cell.
from sklearn.model_selection import RandomizedSearchCV
import inspect

# `needs_proba` was deprecated in scikit-learn 1.4 and removed in 1.6 in favour of
# `response_method`. On versions before 1.4 an unknown keyword is forwarded to the
# score function and only fails at scoring time, so probe the signature rather than
# relying on try/except around the constructor.
if 'response_method' in inspect.signature(make_scorer).parameters:
    scorer_rf = make_scorer(custom_scorer_rf, response_method='predict_proba')
else:
    scorer_rf = make_scorer(custom_scorer_rf, needs_proba=True)

# Initialize RandomizedSearchCV
random_search_rf = RandomizedSearchCV(
    estimator=rf_clf,
    param_distributions=param_dist_rf,
    n_iter=100,
    scoring=scorer_rf,
    cv=3,
    verbose=1,
    random_state=42,
    n_jobs=-1
)

# Perform hyperparameter tuning
random_search_rf.fit(X_train, y_train)

# Retrieve the best estimator (refit on the full training split by default)
best_rf = random_search_rf.best_estimator_

print(f"Best parameters found: {random_search_rf.best_params_}")
print(f"Best recall at 80% precision: {random_search_rf.best_score_:.4f}")

# ---------------------------------------
# Threshold Determination
# ---------------------------------------

# Predict probabilities with the best model (column 1 = TDE)
y_probs_rf = best_rf.predict_proba(X_test)[:, 1]

# Define thresholds based on desired precision and recall
threshold_80p_precision_rf = find_threshold(y_test, y_probs_rf, target='precision', target_value=0.80)  # 80% Precision
threshold_95p_precision_rf = find_threshold(y_test, y_probs_rf, target='precision', target_value=0.95)  # 95% Precision
threshold_95p_recall_rf = find_threshold(y_test, y_probs_rf, target='recall', target_value=0.95)        # 95% Recall

thresholds_dict_rf = {
    '80% Precision': threshold_80p_precision_rf,
    '95% Precision': threshold_95p_precision_rf,
    '95% Recall': threshold_95p_recall_rf
}

# Function to get binary predictions based on a threshold
def get_predictions_rf(y_probs, threshold):
    """Return an int array holding 1 where y_probs >= threshold and 0 elsewhere."""
    is_positive = y_probs >= threshold
    return is_positive.astype(int)

# Generate predictions for the desired thresholds
# (one 0/1 array per operating point, keyed like thresholds_dict_rf)
predictions_rf = {
    label: get_predictions_rf(y_probs_rf, thresh)
    for label, thresh in thresholds_dict_rf.items()
}

# ---------------------------------------
# Plotting Confusion Matrices
# ---------------------------------------

# Titles for classifiers and thresholds (same order as all_predictions_rf below)
classifier_titles_rf = [
    'Random Forest - 80% Precision',
    'Random Forest - 95% Precision',
    'Random Forest - 95% Recall'
]

# Combine predictions into a list maintaining the order of classifier_titles_rf
all_predictions_rf = [
    predictions_rf['80% Precision'],
    predictions_rf['95% Precision'],
    predictions_rf['95% Recall']
]

# Define the save directory
# NOTE(review): absolute, user-specific path; consider making it configurable.
save_directory_rf = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/RandomForest'
os.makedirs(save_directory_rf, exist_ok=True)

# Prepare subplots with wider horizontal layout
fig_rf, axs_rf = plt.subplots(1, 3, figsize=(18, 6))  # 1 row, 3 columns with a total size of 18x6 inches

# Plot one confusion matrix per operating point; labels=[0, 1] fixes the
# row/column order to (Non-TDE, TDE) even if a class is never predicted.
for i in range(3):
    cm_rf = confusion_matrix(y_test, all_predictions_rf[i], labels=[0, 1])
    title_rf = classifier_titles_rf[i]
    plot_confusion_matrix(cm_rf, axs_rf[i], title_rf, normalize=True)

# Adjust layout, save, and display the confusion matrices plot
plt.tight_layout()
confusion_matrix_path_rf = os.path.join(save_directory_rf, 'RandomForest_confusion_matrices.png')
plt.savefig(confusion_matrix_path_rf, dpi=300)
plt.show()
plt.close()  # Close the figure to free memory

# ---------------------------------------
# Plotting Precision and Recall Curves
# ---------------------------------------

# Define thresholds to mark (the 95% Recall threshold is not drawn on this plot)
thresholds_to_mark_rf = {
    '80% Precision': threshold_80p_precision_rf,
    '95% Precision': threshold_95p_precision_rf
}

# Define save path for Precision-Recall curve
save_path_pr_rf = os.path.join(save_directory_rf, 'RandomForest_precision_recall.png')

# Define desired figure size for Precision-Recall curve
pr_figsize_rf = (6, 4)  # Adjust as needed

# Plot and save Precision-Recall curve
plot_precision_recall_curve(y_test, y_probs_rf, thresholds_to_mark_rf, save_path_pr_rf, figsize=pr_figsize_rf)

# ---------------------------------------
# Plotting Feature Importance
# ---------------------------------------

# Define save path for Feature Importance plot
save_path_fi_rf = os.path.join(save_directory_rf, 'RandomForest_feature_importance.png')

# Define desired figure size for Feature Importance plot
fi_figsize_rf = (10, 8)  # Adjust as needed

# Plot and save Feature Importance
# NOTE(review): this resolves to whichever plot_feature_importance was defined last
# in the session; the XGB cell redefines it with an XGBoost-specific version, so
# confirm cell execution order.
plot_feature_importance(best_rf, feature_columns, save_path_fi_rf, figsize=fi_figsize_rf)

# Optionally, plot Permutation Feature Importance (evaluated on the test split)
save_path_perm_fi_rf = os.path.join(save_directory_rf, 'RandomForest_permutation_importance.png')
plot_permutation_importance_rf(best_rf, X_test, y_test, feature_columns, save_path_perm_fi_rf, figsize=(10, 8))

# ---------------------------------------
# Evaluation Metrics
# ---------------------------------------

def evaluate_model_rf(y_true, y_probs, predictions, classifier_name):
    """
    Print a per-class report for thresholded predictions, then the ROC AUC of the
    underlying probabilities (which does not depend on the threshold).
    """
    print(f"\n{classifier_name} Classification Report:")
    report_text = classification_report(y_true, predictions, target_names=['Non-TDE', 'TDE'])
    print(report_text)
    roc_auc_value = roc_auc_score(y_true, y_probs)
    print(f"ROC AUC: {roc_auc_value:.4f}")

# Evaluate the Random Forest at each operating point (iteration order = dict order)
for label in thresholds_dict_rf:
    classifier_name_rf = 'Random Forest - ' + label
    evaluate_model_rf(y_test, y_probs_rf, predictions_rf[label], classifier_name_rf)


In [None]:
# Re-draw the Random Forest figures from the previous cell with smaller figure sizes.
# NOTE: this writes to the same file names, so it overwrites the earlier PNGs.

# Prepare subplots with wider horizontal layout
fig_rf, axs_rf = plt.subplots(1, 3, figsize=(8, 3))  # 1 row, 3 columns with a total size of 8x3 inches

# Plot confusion matrices (row/column order fixed to Non-TDE, TDE by labels=[0, 1])
for i in range(3):
    cm_rf = confusion_matrix(y_test, all_predictions_rf[i], labels=[0, 1])
    title_rf = classifier_titles_rf[i]
    plot_confusion_matrix(cm_rf, axs_rf[i], title_rf, normalize=True)

# Adjust layout, save, and display the confusion matrices plot
plt.tight_layout()
confusion_matrix_path_rf = os.path.join(save_directory_rf, 'RandomForest_confusion_matrices.png')
plt.savefig(confusion_matrix_path_rf, dpi=300)
plt.show()
plt.close()  # Close the figure to free memory

# ---------------------------------------
# Plotting Precision and Recall Curves
# ---------------------------------------

# Define thresholds to mark (the 95% Recall threshold is not drawn)
thresholds_to_mark_rf = {
    '80% Precision': threshold_80p_precision_rf,
    '95% Precision': threshold_95p_precision_rf
}

# Define save path for Precision-Recall curve
save_path_pr_rf = os.path.join(save_directory_rf, 'RandomForest_precision_recall.png')

# Define desired figure size for Precision-Recall curve
pr_figsize_rf = (6, 4)  # Adjust as needed

# Plot and save Precision-Recall curve
plot_precision_recall_curve(y_test, y_probs_rf, thresholds_to_mark_rf, save_path_pr_rf, figsize=pr_figsize_rf)

# ---------------------------------------
# Plotting Feature Importance
# ---------------------------------------

# Define save path for Feature Importance plot
save_path_fi_rf = os.path.join(save_directory_rf, 'RandomForest_feature_importance.png')

# Define desired figure size for Feature Importance plot
fi_figsize_rf = (6, 3)  # Adjust as needed

# Plot and save Feature Importance
# NOTE(review): if the XGB cell has already run, plot_feature_importance refers to its
# XGBoost-specific redefinition rather than the Random Forest version; confirm order.
plot_feature_importance(best_rf, feature_columns, save_path_fi_rf, figsize=fi_figsize_rf)

# Optionally, plot Permutation Feature Importance (recomputed from scratch here)
save_path_perm_fi_rf = os.path.join(save_directory_rf, 'RandomForest_permutation_importance.png')
plot_permutation_importance_rf(best_rf, X_test, y_test, feature_columns, save_path_perm_fi_rf, figsize=(10, 8))

#### XGB

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold
from sklearn.metrics import (
    confusion_matrix, precision_recall_curve, classification_report, roc_auc_score,
    make_scorer
)
import xgboost as xgb
import os
from scipy.stats import uniform, randint

# ---------------------------------------
# Helper Functions
# ---------------------------------------

def find_threshold(y_true, y_probs, target='precision', target_value=0.95):
    """
    Find the decision threshold that reaches a desired precision or recall.

    Parameters:
    - y_true: Ground truth binary labels.
    - y_probs: Predicted probabilities for the positive class.
    - target: 'precision' or 'recall'.
    - target_value: Desired value for the target metric.

    Returns:
    - Threshold value. For 'precision': the lowest threshold whose precision is
      >= target_value, or 1.0 if none qualifies. For 'recall': the highest
      threshold whose recall is >= target_value, or 0.0 if none qualifies.

    Raises:
    - ValueError: if target is neither 'precision' nor 'recall'.
    """
    if target not in ('precision', 'recall'):
        raise ValueError("Target must be 'precision' or 'recall'.")

    precision, recall, thresholds = precision_recall_curve(y_true, y_probs)

    # precision_recall_curve appends a final point (precision=1, recall=0) that has
    # no corresponding threshold. Drop it so every index found below is valid for
    # `thresholds`. Previously, when no real threshold reached the requested
    # precision, that sentinel was selected and thresholds[idx[0]] raised IndexError.
    precision = precision[:-1]
    recall = recall[:-1]

    if target == 'precision':
        # Find indices where precision >= target_value
        idx = np.where(precision >= target_value)[0]
        if len(idx) == 0:
            print(f"No threshold found to achieve {target_value*100}% precision. Returning threshold=1.0")
            return 1.0  # Maximum threshold
        # Thresholds are increasing, so the first hit is the lowest qualifying
        # threshold (the one that keeps the most recall).
        return thresholds[idx[0]]

    # target == 'recall': find indices where recall >= target_value
    idx = np.where(recall >= target_value)[0]
    if len(idx) == 0:
        print(f"No threshold found to achieve {target_value*100}% recall. Returning threshold=0.0")
        return 0.0  # Minimum threshold
    # Recall does not increase with the threshold, so the last hit is the highest
    # threshold that still meets the recall target.
    return thresholds[idx[-1]]

def plot_confusion_matrix(cm, ax, title, normalize=True):
    """
    Plot a 2x2 Non-TDE / TDE confusion matrix with percentages and counts.

    Parameters:
    - cm: Confusion matrix; rows are drawn as the true label and columns as the
      predicted label, both in the order (Non-TDE, TDE).
    - ax: Matplotlib Axes object to draw on.
    - title: Title for the subplot.
    - normalize: Whether to normalize the confusion matrix per true label (row).
      Rows whose total is zero are shown as 0% instead of NaN.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no true samples yields an all-zero row; plain division would
        # emit a RuntimeWarning and print "nan%" in that row, so leave it at 0.
        cm_normalized = np.divide(
            cm.astype('float'), row_totals,
            out=np.zeros(cm.shape, dtype=float), where=row_totals != 0,
        )
    else:
        cm_normalized = cm

    # Cell text is drawn manually below, so seaborn's own annotations are disabled.
    sns.heatmap(
        cm_normalized, annot=False, fmt='.2f', cmap='Blues', ax=ax,
        xticklabels=['Non-TDE', 'TDE'], yticklabels=['Non-TDE', 'TDE'], cbar=False
    )

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if normalize:
                percentage = cm_normalized[i, j] * 100
                count = cm[i, j]
                text = f'{percentage:.1f}%\n({count})'
            else:
                text = f'{cm[i, j]}'
            # +0.5 centres the label inside the heatmap cell
            ax.text(
                j + 0.5, i + 0.5, text, ha='center', va='center',
                color='black', fontsize=8, bbox=dict(facecolor='white', edgecolor='white')
            )

    ax.set_xlabel('Predicted Label', fontsize=10)
    ax.set_ylabel('True Label', fontsize=10)
    ax.set_title(title, fontsize=12)

def plot_precision_recall_curve(y_true, y_probs, thresholds_to_mark, save_path, figsize=(6, 4)):
    """
    Draw precision and recall against the decision threshold, mark selected
    thresholds with dashed vertical lines, then save and show the figure.

    Parameters:
    - y_true: Ground truth binary labels.
    - y_probs: Predicted probabilities for the positive class.
    - thresholds_to_mark: Mapping of label -> threshold value to draw; labels
      containing '95% Precision' are drawn in red, all others in blue.
    - save_path: Destination file for the saved figure (300 dpi).
    - figsize: Tuple specifying the figure size.
    """
    prec, rec, cutoffs = precision_recall_curve(y_true, y_probs)

    plt.figure(figsize=figsize)
    # The final precision/recall pair has no threshold, hence the [:-1] slices.
    curves = ((prec[:-1], 'Precision', 'blue'), (rec[:-1], 'Recall', 'green'))
    for values, curve_name, curve_colour in curves:
        plt.plot(cutoffs, values, label=curve_name, color=curve_colour)

    for marker_name, cutoff in thresholds_to_mark.items():
        line_colour = 'red' if '95% Precision' in marker_name else 'blue'
        plt.axvline(x=cutoff, linestyle='--', color=line_colour, label=f'{marker_name} ({cutoff:.2f})')

    plt.xlabel('Threshold', fontsize=10)
    plt.ylabel('Score', fontsize=10)
    plt.title('Precision and Recall vs. Threshold', fontsize=12)
    plt.legend(fontsize=8)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

def plot_feature_importance(model, feature_names, save_path, figsize=(8, 6), top_n=20):
    """
    Plot and save the gain-based feature importance of a fitted XGBoost model.

    Parameters:
    - model: Trained XGBoost model, passed straight to xgb.plot_importance.
    - feature_names: Accepted for call compatibility with the Random Forest
      version but not used; xgb.plot_importance takes the bar labels from the
      model itself.
    - save_path: Path to save the plot.
    - figsize: Tuple specifying the figure size.
    - top_n: Maximum number of features to show (default 20, the previously
      hard-coded value).
    """
    plt.figure(figsize=figsize)
    ax = plt.gca()
    xgb.plot_importance(
        model,
        max_num_features=top_n,
        importance_type='gain',
        show_values=False,
        ax=ax
    )
    plt.xlabel('Importance (gain)', fontsize=10)
    plt.ylabel('Features', fontsize=10)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
    plt.close()

# ---------------------------------------
# Data Preparation
# ---------------------------------------

# Ensure df0 is loaded (at module level locals() is the notebook namespace)
if 'df0' not in locals():
    raise ValueError("DataFrame 'df0' is not loaded. Please load your data into 'df0' before running the script.")

# Create a simplified object type column: 'TDE' vs everything else
df0['Simple_Object_Type'] = df0['Object_Type'].apply(lambda x: 'TDE' if x == 'TDE' else 'Other')

# Select numeric columns for the model and exclude specified columns
excluded_columns = [
    'SNID', 
    'PeakMag',
    'PeakMagErr',
    'REDSHIFT_FINAL',
    'REDSHIFT_FINAL_ERR',
    # Add other excluded columns if needed
]

# Exclude columns containing 'err', 'flux', or 'mjd' (case-insensitive)
excluded_columns += [col for col in df0.columns if 'err' in col.lower() or 'flux' in col.lower() or 'mjd' in col.lower()]

# Select feature columns
feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

# Renaming columns for clarity
rename_map = {
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time'
}
df0.rename(columns=rename_map, inplace=True)
feature_columns = [rename_map.get(col, col) for col in feature_columns]

# ---------------------------------------
# Feature Selection
# ---------------------------------------

# Define the list of selected features
# Modify this list based on your requirements
# Example: Selecting GP hyperparameters, Rise time, Fade time, and Mean Color Post Peak
selected_features = [
    'Amplitude',  
    'LengthScale_Time',
    'LengthScale_Wavelength',
    'Mean Color Post Peak (g-r)',
    'Mean Color Post Peak (r-i)',
    'Mean Color Pre Peak (g-r)',
    'Mean Color Pre Peak (r-i)',
    'Slope Pre Peak (g-r)',
    'Slope Post Peak (g-r)',
    'Slope Pre Peak (r-i)',
    'Slope Post Peak (r-i)',
    'Rise time',
    'Fade time']

# Validate selected features (names must match the renamed feature columns)
missing_features = set(selected_features) - set(feature_columns)
if missing_features:
    raise ValueError(f"The following selected features are not present in the feature columns: {missing_features}")

# Update feature_columns to include only selected features
feature_columns_selected = selected_features

# Define features and target using selected features
X = df0[feature_columns_selected].fillna(-999)  # Handle missing values with a -999 sentinel
y = df0['Simple_Object_Type'].apply(lambda x: 1 if x == 'TDE' else 0)  # Binary encoding: TDE=1, other=0

# Split the raw features first (stratified to maintain class distribution). The
# scaler is then fitted on the training rows only; fitting it on the full table
# leaked test-set statistics into training. The row split is unchanged because
# train_test_split only depends on y, the sample count and random_state.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=42, stratify=y
)

# Standardize the features using training statistics only. The original index is
# kept so X_train/X_test line up with y_train/y_test (a fresh RangeIndex did not).
scaler = StandardScaler()
scaler.fit(X_train_raw)
X_train = pd.DataFrame(scaler.transform(X_train_raw), columns=feature_columns_selected, index=X_train_raw.index)
X_test = pd.DataFrame(scaler.transform(X_test_raw), columns=feature_columns_selected, index=X_test_raw.index)

# Whole table scaled with the training-fitted scaler, kept for any later use
X_scaled = pd.DataFrame(scaler.transform(X), columns=feature_columns_selected, index=X.index)

# Display class distribution
print("Class distribution:")
print("Training set:")
print(y_train.value_counts())
print("\nTest set:")
print(y_test.value_counts())

# ---------------------------------------
# Hyperparameter Tuning
# ---------------------------------------

# Define the parameter grid for RandomizedSearchCV
# (scipy's uniform(loc, scale) samples from [loc, loc + scale])
param_dist = {
    'n_estimators': randint(100, 500),
    'learning_rate': uniform(0.01, 0.3),
    'max_depth': randint(3, 8),
    'min_child_weight': randint(1, 10),
    'subsample': uniform(0.5, 0.5),
    'colsample_bytree': uniform(0.5, 0.5),
    'gamma': uniform(0, 5),
    'scale_pos_weight': uniform(1, 10)  # Useful for imbalanced classes
}

# Initialize the XGBoost classifier.
# `use_label_encoder` is no longer passed: it has been deprecated since XGBoost 1.6
# and newer releases warn that it is unused; the labels here are already 0/1 ints.
xgb_clf = xgb.XGBClassifier(
    objective='binary:logistic',
    eval_metric='logloss',
    random_state=42
)

# Define a custom scorer that emphasizes recall at a specified precision threshold
def recall_at_precision_threshold(y_true, y_pred_proba, precision_threshold=0.80):
    """
    Return the best recall reachable while precision stays >= precision_threshold.

    Every point of the precision-recall curve is considered; 0 is returned when
    none of them reaches the required precision.
    """
    precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
    meets_precision = precision >= precision_threshold
    if not meets_precision.any():
        return 0
    return recall[meets_precision].max()

# Create a scorer for RandomizedSearchCV
def custom_scorer(y_true, y_pred_proba):
    """Score function for the XGB search: recall at >= 80% precision."""
    min_precision = 0.80
    return recall_at_precision_threshold(y_true, y_pred_proba, precision_threshold=min_precision)

# Register the custom scorer.
# `needs_proba` was deprecated in scikit-learn 1.4 and removed in 1.6 in favour of
# `response_method`. On versions before 1.4 an unknown keyword is forwarded to the
# score function and only fails at scoring time, so probe the signature rather than
# relying on try/except around the constructor.
import inspect

if 'response_method' in inspect.signature(make_scorer).parameters:
    scorer = make_scorer(custom_scorer, response_method='predict_proba', greater_is_better=True)
else:
    scorer = make_scorer(custom_scorer, needs_proba=True, greater_is_better=True)

# Initialize StratifiedKFold for cross-validation
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

# Initialize RandomizedSearchCV
random_search = RandomizedSearchCV(
    estimator=xgb_clf,
    param_distributions=param_dist,
    n_iter=100,
    scoring=scorer,
    cv=skf,
    verbose=1,
    random_state=42,
    n_jobs=-1
)

# Perform hyperparameter tuning
random_search.fit(X_train, y_train)

# Retrieve the best estimator (refit on the full training split by default)
best_xgb = random_search.best_estimator_

print(f"Best parameters found: {random_search.best_params_}")
print(f"Best recall at 80% precision: {random_search.best_score_:.4f}")

# ---------------------------------------
# Threshold Determination
# ---------------------------------------

# Predict probabilities with the best model (column 1 = TDE)
y_probs = best_xgb.predict_proba(X_test)[:, 1]

# Define thresholds based on desired precision and recall
threshold_80p_precision = find_threshold(y_test, y_probs, target='precision', target_value=0.80)  # 80% Precision
threshold_95p_precision = find_threshold(y_test, y_probs, target='precision', target_value=0.95)  # 95% Precision
threshold_95p_recall = find_threshold(y_test, y_probs, target='recall', target_value=0.95)        # 95% Recall

thresholds_dict = {
    '80% Precision': threshold_80p_precision,
    '95% Precision': threshold_95p_precision,
    '95% Recall': threshold_95p_recall
}

# Function to get binary predictions based on a threshold
def get_predictions(y_probs, threshold):
    """Return an int array holding 1 where y_probs >= threshold and 0 elsewhere."""
    is_positive = y_probs >= threshold
    return is_positive.astype(int)

# Generate predictions for the desired thresholds
# (one 0/1 array per operating point, keyed like thresholds_dict)
predictions = {
    label: get_predictions(y_probs, thresh)
    for label, thresh in thresholds_dict.items()
}

# ---------------------------------------
# Plotting Confusion Matrices
# ---------------------------------------

# Titles for classifiers and thresholds (same order as all_predictions below)
classifier_titles = [
    'XGB - 80% Precision',
    'XGB - 95% Precision',
    'XGB - 95% Recall'
]

# Combine predictions into a list maintaining the order of classifier_titles
all_predictions = [
    predictions['80% Precision'],
    predictions['95% Precision'],
    predictions['95% Recall']
]

# Define the save directory
# NOTE(review): absolute, user-specific path; consider making it configurable.
save_directory = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/XGB'
os.makedirs(save_directory, exist_ok=True)

# Prepare subplots with wider horizontal layout
fig, axs = plt.subplots(1, 3, figsize=(8, 3))  # 1 row, 3 columns with a total size of 8x3 inches

# Plot one confusion matrix per operating point; labels=[0, 1] fixes the
# row/column order to (Non-TDE, TDE) even if a class is never predicted.
for i in range(3):
    cm = confusion_matrix(y_test, all_predictions[i], labels=[0, 1])
    title = classifier_titles[i]
    plot_confusion_matrix(cm, axs[i], title, normalize=True)

# Adjust layout, save, and display the confusion matrices plot
plt.tight_layout()
confusion_matrix_path = os.path.join(save_directory, 'XGB_confusion_matrices.png')
plt.savefig(confusion_matrix_path, dpi=300)
plt.show()
plt.close()  # Close the figure to free memory

# ---------------------------------------
# Plotting Precision and Recall Curves
# ---------------------------------------

# Define thresholds to mark (the 95% Recall threshold is not drawn on this plot)
thresholds_to_mark = {
    '80% Precision': threshold_80p_precision,
    '95% Precision': threshold_95p_precision
}

# Define save path for Precision-Recall curve
save_path_pr = os.path.join(save_directory, 'XGB_precision_recall.png')

# Define desired figure size for Precision-Recall curve
pr_figsize = (6, 4)  # Adjust as needed

# Plot and save Precision-Recall curve
plot_precision_recall_curve(y_test, y_probs, thresholds_to_mark, save_path_pr, figsize=pr_figsize)

# ---------------------------------------
# Plotting Feature Importance
# ---------------------------------------

# Define save path for Feature Importance plot
save_path_fi = os.path.join(save_directory, 'XGB_feature_importance.png')

# Define desired figure size for Feature Importance plot
fi_figsize = (6, 3)  # Adjust as needed

# Plot and save Feature Importance (gain-based, via the XGBoost-specific helper above)
plot_feature_importance(best_xgb, feature_columns_selected, save_path_fi, figsize=fi_figsize)

# ---------------------------------------
# Evaluation Metrics
# ---------------------------------------

def evaluate_model(y_true, y_probs, predictions, classifier_name):
    """
    Print a per-class report for thresholded predictions, then the ROC AUC of the
    underlying probabilities (which does not depend on the threshold).
    """
    print(f"\n{classifier_name} Classification Report:")
    report_text = classification_report(y_true, predictions, target_names=['Non-TDE', 'TDE'])
    print(report_text)
    roc_auc_value = roc_auc_score(y_true, y_probs)
    print(f"ROC AUC: {roc_auc_value:.4f}")

# Evaluate the XGBoost model at each operating point (iteration order = dict order)
for label in thresholds_dict:
    classifier_name = 'XGB - ' + label
    evaluate_model(y_test, y_probs, predictions[label], classifier_name)


### Using actual Train and Test data

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
import xgboost as xgb
from imblearn.over_sampling import SMOTE

# Suppress warnings for cleaner output.
# NOTE(review): filterwarnings('ignore') with no category silences every warning for
# the rest of the session (including library deprecation warnings); consider
# restricting it to specific categories or modules.
import warnings
warnings.filterwarnings('ignore')

# File paths (absolute, user-specific; adjust for your environment)
TRAIN_DATA_PATH = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/ELAsTiCC2_snr5.csv'
TEST_DATA_PATH = '/home/bhardwaj/notebooksLSST/ELAsTiCC2_processed/hdf5_files/ELAsTiCC2_Test_cleaned.csv'

# Load the training data
print("Loading training data...")
df_train = pd.read_csv(TRAIN_DATA_PATH)
print("Training data loaded successfully.\n")

# Load the test data
print("Loading test data...")
df_test = pd.read_csv(TEST_DATA_PATH)
print("Test data loaded successfully.\n")

# Function to display column names
def display_columns(df, df_name):
    """Print a header naming the DataFrame, one bulleted line per column, then spacing."""
    report_lines = [f"Columns in {df_name}:"]
    report_lines.extend(f"- {column_name}" for column_name in df.columns)
    for text_line in report_lines:
        print(text_line)
    print("\n")

# Display columns before renaming
display_columns(df_train, "Training Data (Before Renaming)")
display_columns(df_test, "Test Data (Before Renaming)")

# Define rename map to handle both training and test data.
# The training and test files spell the colour features differently
# (e.g. 'Mean_Color_Pre_Peak' vs 'Mean_Color_Pre_Peak_gr'); both spellings map to the
# same display name so the two frames end up with matching column names.
# NOTE(review): 'LengthScale_Time' / 'LengthScale_Wavelength' are renamed here with a
# space, unlike the XGB cell's selected_features, which uses the underscore names.
rename_map = {
    # Training data mappings
    'Mean_Color_Pre_Peak': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak': 'Mean Color Post Peak (g-r)',
    'Slope_Pre_Peak': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak': 'Slope Post Peak (g-r)',
    'Mean_Color_Pre_Peak_RI': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_RI': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_RI': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_RI': 'Slope Post Peak (r-i)',
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time',
    'Amplitude': 'Amplitude',  # identity mapping; kept for completeness
    'LengthScale_Time': 'LengthScale Time',
    'LengthScale_Wavelength': 'LengthScale Wavelength',
    'TruePeakMJD': 'True Peak MJD',
    'peak_time_MJD': 'Peak Time MJD',
    
    # Test data mappings
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
}

# Function to preprocess data
def preprocess_data(df, rename_map, feature_columns, excluded_columns, df_name=""):
    """Build the feature matrix and binary TDE label for one dataset.

    *df* is modified in place: a ``Simple_Object_Type`` column is added
    ('TDE' / 'Other', or NaN when ``Object_Type`` is absent), and columns are
    renamed with *rename_map*.

    Parameters
    ----------
    df : pandas.DataFrame
        Input table.
    rename_map : dict
        Raw column name -> display name.
    feature_columns : list of str
        Requested features, given by their raw (pre-rename) names.
    excluded_columns : list
        Accepted for interface compatibility; not used by this function.
    df_name : str
        Label used in printed messages.

    Returns
    -------
    X : pandas.DataFrame
        The requested features that exist in *df* (renamed), NaN -> -999.
    y : pandas.Series
        1 for TDE, 0 otherwise (all 0 when ``Object_Type`` is missing).
    available_features, missing_features : list of str
        Renamed feature names found / not found in *df*.
    """
    if 'Object_Type' not in df.columns:
        print(f"'Object_Type' column not found in {df_name}. Ensure the test data contains this column.")
        # Without the true class the label cannot be derived; keep a placeholder
        df['Simple_Object_Type'] = np.nan
    else:
        df['Simple_Object_Type'] = df['Object_Type'].apply(lambda label: 'TDE' if label == 'TDE' else 'Other')

    df.rename(columns=rename_map, inplace=True)

    # Translate the requested names into the post-rename namespace, then split
    # them by whether the (renamed) table actually has them.
    wanted = [rename_map.get(name, name) for name in feature_columns]
    present = set(df.columns)
    available_features = [name for name in wanted if name in present]
    missing_features = [name for name in wanted if name not in present]

    if missing_features:
        print(f"Warning: The following expected columns are missing in {df_name}:")
        print("\n".join(f"- {name}" for name in missing_features))
        print("These columns will be excluded from the feature set.\n")

    X = df[available_features].fillna(-999)  # -999 sentinel for missing values
    y = df['Simple_Object_Type'].apply(lambda label: 1 if label == 'TDE' else 0)  # 1 = TDE
    return X, y, available_features, missing_features

# Columns never used as features: identifiers, redshift, and the test-set
# 'Error' column.
excluded_columns = [
    'SNID', 
    'REDSHIFT_FINAL',
    'REDSHIFT_FINAL_ERR',
    'Error'  # From Test Data
]

# Also exclude any column (training or test) whose name contains 'err',
# 'flux' or 'mjd', case-insensitively.
excluded_columns += [col for col in df_train.columns if 'err' in col.lower() or 'flux' in col.lower() or 'mjd' in col.lower()]
excluded_columns += [col for col in df_test.columns if 'err' in col.lower() or 'flux' in col.lower() or 'mjd' in col.lower()]

# Candidate features: numeric training columns that are not excluded
feature_columns = [col for col in df_train.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

# Preprocess training data (renames df_train's columns in place)
print("Preprocessing training data...")
X_train, y_train, available_train_features, missing_train_features = preprocess_data(
    df_train, rename_map, feature_columns, excluded_columns, df_name="Training Data"
)
print("Training data preprocessed.\n")

# Preprocess test data with the same training-derived feature list
print("Preprocessing test data...")
X_test, y_test, available_test_features, missing_test_features = preprocess_data(
    df_test, rename_map, feature_columns, excluded_columns, df_name="Test Data"
)
print("Test data preprocessed.\n")

# Display available features
print("Available Features in Training Data:", available_train_features)
print("Available Features in Test Data:", available_test_features)
print("\n")

# Keep only features present in both sets, in training-column order.
# A set intersection has no stable order (str hashing is randomised per Python
# process), so the column order -- and with it the trained models and the
# feature-importance plots -- would change from run to run.
# dict.fromkeys removes duplicates the way the set did.
common_features = list(dict.fromkeys(
    feature for feature in available_train_features if feature in available_test_features
))
print(f"Number of common features between Training and Test data: {len(common_features)}")
print("Common Features:")
for feature in common_features:
    print(f"- {feature}")
print("\n")

# If there are no common features, you cannot proceed further
if not common_features:
    raise ValueError("No common features found between Training and Test data after renaming. Please check the rename_map and feature selection.")

# Restrict both matrices to the shared features (same column order)
X_train = X_train[common_features]
X_test = X_test[common_features]

# Standardize the features: fit on training data only, apply to both sets
print("Standardizing features...")
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
print("Feature scaling completed.\n")

# Display class distribution before SMOTE
print("Class distribution before SMOTE:")
print("Training set:")
print(y_train.value_counts())
print("\nTest set:")
print(y_test.value_counts())

# Oversample the minority class in the training set only; the test set keeps
# its real class balance
print("\nApplying SMOTE to the training data...")
smote = SMOTE(random_state=42)
X_train_smote, y_train_smote = smote.fit_resample(X_train_scaled, y_train)
print("SMOTE applied successfully.\n")

# Display class distribution after SMOTE
print("Class distribution after SMOTE:")
print(pd.Series(y_train_smote).value_counts())

# Function to plot confusion matrices with specified normalization
def plot_confusion_matrix(y_true, y_pred, ax, title, normalize='pred'):
    """
    Plots a confusion matrix with specified normalization.
    
    Parameters:
    - y_true: True labels
    - y_pred: Predicted labels
    - ax: Matplotlib Axes object
    - title: Title for the plot
    - normalize: 'pred' for prediction normalized, 'true' for truth normalized
    """
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    
    if normalize == 'pred':
        cm_normalized = cm.astype('float') / cm.sum(axis=0)
    elif normalize == 'true':
        cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    else:
        cm_normalized = cm
    
    sns.heatmap(cm_normalized, annot=False, fmt=".2f", cmap='Blues', ax=ax, 
                xticklabels=['Non-TDE', 'TDE'], yticklabels=['Non-TDE', 'TDE'])
    
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if normalize == 'pred':
                percentage = cm_normalized[i, j] * 100 if cm.sum(axis=0)[j] != 0 else 0
            elif normalize == 'true':
                percentage = cm_normalized[i, j] * 100 if cm.sum(axis=1)[i] != 0 else 0
            else:
                percentage = cm_normalized[i, j]
            text = f'{percentage:.1f}%\n({cm[i, j]})'
            ax.text(j + 0.5, i + 0.5, text, ha='center', va='center', color='black', fontsize=10, 
                    bbox=dict(facecolor='white', edgecolor='white'))
    
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title(title)

# Train XGBoost on the original (imbalanced) training data and on the
# SMOTE-balanced data with identical hyper-parameters, so the two models
# differ only in their training set.
# `use_label_encoder` is not passed: the labels are already 0/1 integers, and
# the option is deprecated (ignored with a warning) in recent XGBoost releases.
xgb_params = dict(
    n_estimators=100,
    learning_rate=0.3,
    max_depth=3,
    random_state=42,
    eval_metric='logloss',
)

print("Training XGBoost classifier on original training data...")
xgb_original = xgb.XGBClassifier(**xgb_params)
xgb_original.fit(X_train_scaled, y_train)
y_pred_original = xgb_original.predict(X_test_scaled)
y_probs_original = xgb_original.predict_proba(X_test_scaled)[:, 1]  # P(TDE), 1-D
print("Original XGBoost classifier trained.\n")

print("Training XGBoost classifier on SMOTE-augmented training data...")
xgb_smote = xgb.XGBClassifier(**xgb_params)
xgb_smote.fit(X_train_smote, y_train_smote)
y_pred_smote = xgb_smote.predict(X_test_scaled)
y_probs_smote = xgb_smote.predict_proba(X_test_scaled)[:, 1]  # P(TDE), 1-D
print("SMOTE XGBoost classifier trained.\n")

# Confusion matrices: rows = model (original / SMOTE),
# columns = prediction-normalized / truth-normalized
fig, axs = plt.subplots(2, 2, figsize=(12, 10))

plot_confusion_matrix(
    y_test, y_pred_original, axs[0, 0], 'XGB Original - Prediction Normalized', normalize='pred'
)
plot_confusion_matrix(
    y_test, y_pred_original, axs[0, 1], 'XGB Original - Truth Normalized', normalize='true'
)
plot_confusion_matrix(
    y_test, y_pred_smote, axs[1, 0], 'XGB SMOTE - Prediction Normalized', normalize='pred'
)
plot_confusion_matrix(
    y_test, y_pred_smote, axs[1, 1], 'XGB SMOTE - Truth Normalized', normalize='true'
)

plt.tight_layout()
plt.show()

# Function to plot feature importance
def plot_feature_importance(model, ax, title, feature_names):
    """Plot the top-20 XGBoost feature importances ('weight') with readable names.

    When a model is fitted on a plain array, ``xgb.plot_importance`` labels the
    bars ``f0``, ``f1``, ... and sorts them by importance. Each such label is
    mapped back to ``feature_names[N]``, so every name stays attached to its own
    bar. Overwriting the tick labels with *feature_names* in column order would
    put the wrong names on the sorted bars. Labels not of the form ``fN`` (e.g.
    real feature names stored in the booster) are kept. Every name is finally
    passed through ``rename_map``.

    Parameters
    ----------
    model : fitted XGBoost model accepted by ``xgb.plot_importance``
    ax : matplotlib Axes to draw on
    title : str
    feature_names : sequence of str
        Column names in the order of the training matrix.
    """
    xgb.plot_importance(model, max_num_features=20, importance_type='weight', ax=ax)

    labels = []
    for tick in ax.get_yticklabels():
        name = tick.get_text()
        match = re.fullmatch(r'f(\d+)', name)
        if match and int(match.group(1)) < len(feature_names):
            name = feature_names[int(match.group(1))]
        labels.append(rename_map.get(name, name))
    ax.set_yticklabels(labels, fontsize=10)
    ax.set_xlabel('Feature Importance', fontsize=10)
    ax.set_title(title, fontsize=12)

# Plot feature importances for both classifiers side by side
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Original Classifier Feature Importance
plot_feature_importance(
    xgb_original, axs[0], 'XGBoost Feature Importance - Original', common_features
)

# SMOTE Classifier Feature Importance
plot_feature_importance(
    xgb_smote, axs[1], 'XGBoost Feature Importance - SMOTE', common_features
)

plt.tight_layout()
plt.show()

# roc_auc_score is not in the header's sklearn.metrics import list; import it
# here so this cell does not depend on an import made in some other cell.
from sklearn.metrics import roc_auc_score

# Additional Evaluation Metrics
print("Evaluation Metrics:\n")

# ROC AUC is undefined when y_test holds a single class (e.g. all zeros when the
# test table has no 'Object_Type' column); report that instead of crashing.
_auc_defined = y_test.nunique() > 1

print("Original Classifier:")
print(classification_report(y_test, y_pred_original))
if _auc_defined:
    print(f"ROC AUC: {roc_auc_score(y_test, y_probs_original):.4f}\n")
else:
    print("ROC AUC: undefined (only one class present in y_test)\n")

print("SMOTE Classifier:")
print(classification_report(y_test, y_pred_smote))
if _auc_defined:
    print(f"ROC AUC: {roc_auc_score(y_test, y_probs_smote):.4f}\n")
else:
    print("ROC AUC: undefined (only one class present in y_test)\n")


In [None]:
# Choose operating thresholds on the TDE probability from the precision-recall curve.
# NOTE(review): `y_probs_xgb_original` is not defined in this part of the notebook
# (the training cell above defines the 1-D `y_probs_original`). It is assumed to
# be a 2-D predict_proba output of the original XGBoost model -- confirm.
cls = 1  # 'TDE' class is encoded as 1

# Get the probabilities for the target class
probs_cls = y_probs_xgb_original[:, cls]
pr, recall, thresholds = precision_recall_curve(y_test, probs_cls)

# precision_recall_curve returns one more precision/recall value than thresholds
# (the final point has no threshold), so pr[:-1][i] and recall[:-1][i] are the
# values obtained at thresholds[i]. All masks below use that [:-1] alignment.
index = np.arange(len(thresholds))

# Loose threshold: the highest threshold that still keeps recall >= 0.90
mask = recall[:-1] >= 0.90
if np.any(mask):
    loose_index = max(index[mask])
    threshold_loose = thresholds[loose_index]
    print(f"Loose threshold {threshold_loose:.2f}, Purity={100.*pr[loose_index]:.1f}%, Efficiency={100.*recall[loose_index]:.1f}%")

# Strict threshold: the lowest threshold reaching precision >= 0.95
mask = pr[:-1] >= 0.95
if np.any(mask):
    strict_index = min(index[mask])
    threshold_strict = thresholds[strict_index]
    print(f"Strict Threshold {threshold_strict:.2f}, Purity={100.*pr[strict_index]:.1f}%, Efficiency={100.*recall[strict_index]:.1f}%")

# Balanced threshold: the lowest threshold reaching precision >= 0.90
mask = pr[:-1] >= 0.90
if np.any(mask):
    balanced_index = min(index[mask])
    threshold_balanced = thresholds[balanced_index]
    print(f"Balanced Threshold {threshold_balanced:.2f}, Purity={100.*pr[balanced_index]:.1f}%, Efficiency={100.*recall[balanced_index]:.1f}%")

In [None]:
# Function to find the threshold where precision reaches a specific value
def find_threshold_for_precision(precision, thresholds, target_precision=0.95):
    """Return the first threshold whose paired precision is >= *target_precision*.

    *precision* and *thresholds* are walked in lockstep (a trailing extra
    precision value, as produced by ``precision_recall_curve``, is ignored).
    Returns ``None`` when no pair qualifies.
    """
    return next(
        (thr for prec, thr in zip(precision, thresholds) if prec >= target_precision),
        None,
    )

# Function to plot precision and recall as a function of the threshold
def plot_precision_recall_thresholds(y_true, y_probs, model_name):
    """Plot purity (precision) and efficiency (recall) against the decision threshold.

    Curves are drawn for the TDE class (label 1, probability column 1) in blue
    and for the non-TDE class (label 0, column 0) in red. A grey dashed
    vertical line marks the first TDE threshold reaching 95% purity, if any.

    Parameters
    ----------
    y_true : array-like of 0/1 labels
    y_probs : ndarray, shape (n_samples, 2)
        Class probabilities as returned by ``predict_proba``.
    model_name : str
        Text drawn at the centre of the axes.
    """
    # (label value / probability column, colour, purity label, efficiency label)
    classes = (
        (1, 'blue', 'Purity TDE', 'Efficiency TDE'),
        (0, 'red', 'Purity Non-TDE', 'Efficiency Non-TDE'),
    )
    curves = {
        column: precision_recall_curve(y_true == column, y_probs[:, column])
        for column, _, _, _ in classes
    }

    tde_precision, _, tde_thresholds = curves[1]
    purity_95_cut = find_threshold_for_precision(tde_precision, tde_thresholds)

    plt.figure(figsize=(7, 3.5))

    for column, colour, purity_label, efficiency_label in classes:
        precision, recall, thresholds = curves[column]
        # The last precision/recall entries have no threshold, hence [:-1]
        plt.plot(thresholds, precision[:-1], label=purity_label, color=colour)
        plt.plot(thresholds, recall[:-1], label=efficiency_label, color=colour, linestyle='--')

    if purity_95_cut is not None:
        plt.axvline(x=purity_95_cut, color='gray', linestyle='--', label='95% Purity Threshold for TDE')

    plt.xlabel('Threshold')
    plt.ylabel('Score')
    axes = plt.gca()
    plt.text(0.5, 0.5, model_name, fontsize=14, ha='center', transform=axes.transAxes,
             bbox=dict(facecolor='white', edgecolor='white'))
    plt.legend(loc='lower right', fontsize='small')
    plt.grid(True)
    plt.show()

# For Random Forest (disabled)
#plot_precision_recall_thresholds(y_test, y_probs_rf_original, 'Random Forest')

# For XGBoost
# NOTE(review): plot_precision_recall_thresholds indexes y_probs[:, 0] and
# y_probs[:, 1], so `y_probs_xgb_original` must be a 2-D predict_proba array.
# It is not defined in this part of the notebook (the training cell defines the
# 1-D `y_probs_original`) -- confirm which model it comes from.
plot_precision_recall_thresholds(y_test, y_probs_xgb_original, 'XGBoost')


## Check GP fits

In [None]:
def read_elasticc_file(filename):
    """Load one ELAsTiCC light-curve file pair (``*_PHOT`` photometry + ``*_HEAD`` header).

    *filename* may name either member of the pair; the other path is derived
    by swapping the ``_PHOT`` / ``_HEAD`` tag.

    Returns
    -------
    table : astropy Table
        Photometry with stripped ``BAND`` strings, added ``mag`` / ``magerr``
        (zero point 27.5; NaN where FLUXCAL <= 0), an ``SNID`` column copied
        from the header, and the negative-MJD separator rows removed.
    head : astropy Table
        Header table with ``SNID`` cast to int64.

    Raises
    ------
    FileNotFoundError
        If either file of the pair does not exist.
    """
    if '_PHOT' in filename:
        phot_path = filename
        head_path = filename.replace('_PHOT', '_HEAD')
    else:
        phot_path = filename.replace('_HEAD', '_PHOT')
        head_path = filename

    print(f"reading head file: {head_path}")

    # Photometry file is checked first, then the header
    for path in (phot_path, head_path):
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")

    table = Table.read(phot_path)
    head = Table.read(head_path)

    # Strip padding from band names, row by row
    for row in table:
        row['BAND'] = row['BAND'].strip()

    head['SNID'] = np.int64(head['SNID'])

    # Rows with negative MJD delimit light curves; there should be one per header entry
    n_separators = np.sum(table['MJD'] < 0)
    if n_separators != len(head):
        print(phot_path, 'is broken:', n_separators, '!=', len(head))

    # Measured mag and magerr, only where the flux is positive
    # (the simulated magnitude lives in SIM_MAGOBS)
    table['mag'] = np.nan
    table['magerr'] = np.nan
    positive = table['FLUXCAL'] > 0
    table['mag'][positive] = 27.5 - 2.5 * np.log10(table['FLUXCAL'][positive])
    table['magerr'][positive] = 2.5 / np.log(10) * table['FLUXCALERR'][positive] / table['FLUXCAL'][positive]

    # Segment k (bounded by consecutive separators) belongs to header row k
    table['SNID'] = 0
    bounds = np.hstack((np.array([0]), np.where(table['MJD'] < 0)[0]))
    for k, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:])):
        table['SNID'][start:stop] = head['SNID'][k]

    return table[table['MJD'] > 0], head


In [None]:
def try_optimization(gp, neg_ln_like, grad_neg_ln_like, initial_guess, retries=3, snid=None):
    """Optimise GP hyper-parameters with L-BFGS-B, retrying from jittered starts.

    Up to *retries* minimisations are run. The first successful result is
    written to *gp* and returned. After each failed attempt, the start point
    is jittered with N(0, 1e-2) noise. If every attempt fails, a message is
    printed and *gp* is reset to the unperturbed *initial_guess*.

    Fixes over the previous version:
    - A success that followed a lower-objective failure used to be discarded.
    - The failure message read an undefined name (``shead``).
    - The caller's *initial_guess* array was modified in place.

    Parameters
    ----------
    gp : object
        Model exposing ``set_parameter_vector`` (e.g. a george GP).
    neg_ln_like, grad_neg_ln_like : callable
        Objective and its gradient; both take a parameter vector.
    initial_guess : array-like
        Starting parameters. Not modified.
    retries : int
        Maximum number of attempts.
    snid : optional
        Object identifier, used only in the failure message.

    Returns
    -------
    scipy.optimize.OptimizeResult or None
        The successful result, or ``None`` if all attempts failed.
    """
    start = np.array(initial_guess, dtype=float)  # private copy
    for _ in range(retries):
        result = minimize(neg_ln_like, start, jac=grad_neg_ln_like, method='L-BFGS-B')
        if result.success:
            gp.set_parameter_vector(result.x)
            return result
        # Slightly perturb the start point for the next attempt
        start = start + np.random.normal(0, 1e-2, size=start.shape)

    label = 'unknown' if snid is None else snid
    print(f"All optimization attempts failed for SNID {label}.")
    # Fall back to the original starting parameters rather than a failed end point
    gp.set_parameter_vector(np.array(initial_guess, dtype=float))
    return None

In [None]:
def compute_gp(sub, shead, verbose=False):
    """Fit a 2-D (time x wavelength) Gaussian process to one object's light curve.

    Parameters
    ----------
    sub : table-like
        Photometry of a single object with ``MJD``, ``FLUXCAL``, ``FLUXCALERR``
        and ``BAND`` columns; bands are mapped to wavelengths via ``lsst_bands``.
    shead : row-like
        Header entry; only ``shead['SNID']`` is read (for the error message).
    verbose : bool
        Currently unused.

    Returns
    -------
    tuple
        ``(gp, flux, x, params)`` on success, where ``x`` is the (N, 2) array of
        (MJD, wavelength) points kept after cleaning and ``flux`` the matching
        fluxes. ``(None, None, None, None)`` if any step raises.
    """
    try:
        mjd = np.array(sub['MJD'], dtype=float)
        y = np.array(sub['FLUXCAL'], dtype=float)
        yerr = np.array(sub['FLUXCALERR'], dtype=float)
        # Bands missing from lsst_bands yield None -> NaN and are dropped below
        wave = np.array([lsst_bands.get(b) for b in sub['BAND']], dtype=float)

        points = np.vstack([mjd, wave]).T

        # Keep finite, strictly positive measurements only
        keep = (np.isfinite(y) & np.isfinite(yerr)
                & np.isfinite(points).all(axis=1) & (y > 0))
        points, y, yerr = points[keep], y[keep], yerr[keep]

        if y.size < 5:
            raise ValueError("Not enough data points to perform GP fitting.")

        # Kernel amplitude from the flux of the highest-S/N point
        snr = np.abs(y) / np.sqrt(yerr ** 2 + (1e-2 * np.max(y)) ** 2)
        scale = np.abs(y[np.argmax(snr)])

        kernel = (0.5 * scale) ** 2 * george.kernels.Matern32Kernel([100 ** 2, 6000 ** 2], ndim=2)
        gp = george.GP(kernel, white_noise=np.log(np.var(yerr)), fit_white_noise=True,
                       solver=george.HODLRSolver)
        gp.compute(points, yerr)

        def neg_ln_like(p):
            gp.set_parameter_vector(p)
            return -gp.log_likelihood(y)

        def grad_neg_ln_like(p):
            gp.set_parameter_vector(p)
            return -gp.grad_log_likelihood(y)

        try_optimization(gp, neg_ln_like, grad_neg_ln_like, gp.get_parameter_vector())
        return gp, y, points, gp.get_parameter_vector()

    except Exception as e:
        print(f"GP optimization failed for SNID {shead['SNID']}: {e}")
        return None, None, None, None  # signals failure to the caller


In [None]:
def peak_and_risefade(gp, x, flux, peak_threshold=2.512, wavelengths=None):
    """Locate the light-curve peak and measure rise/fade times from a fitted GP.

    The GP mean is evaluated in the g band on a 1000-point MJD grid spanning
    the data, padded by 10 days on each side. Its maximum gives ``t_peak`` and
    ``fpeak``. Rise and fade times are measured on that same g-band curve:

    - Rise is the span from the last pre-peak grid point where the mean is
      <= ``fpeak / peak_threshold`` up to the peak.
    - Fade is the span from the peak to the first post-peak grid point where
      the mean is <= ``fpeak / peak_threshold``.

    The default ``peak_threshold`` of 2.512 corresponds to one magnitude.

    Previously the threshold search ran on the r-band curve (the last band
    evaluated) while ``t_peak``/``fpeak`` came from the g band, mixing the two
    curves. The r-band prediction was otherwise unused and is no longer computed.

    Parameters
    ----------
    gp : fitted GP exposing ``predict(flux, x_pred, return_var=True)``
    x : ndarray, shape (N, 2)
        Training points; only the min/max of column 0 (MJD) are used.
    flux : ndarray
        Training fluxes, passed through to ``gp.predict``.
    peak_threshold : float
        Flux ratio defining the rise/fade level.
    wavelengths : mapping, optional
        Band -> effective wavelength. Defaults to the module-level ``lsst_bands``.

    Returns
    -------
    rise_time, fade_time, t_peak, fpeak
        Rise/fade default to 50 / 75 days when the threshold is not crossed
        inside the prediction grid.
    """
    if wavelengths is None:
        wavelengths = lsst_bands

    t_min, t_max = x[:, 0].min(), x[:, 0].max()
    mjd_for_pred = np.linspace(t_min - 10, t_max + 10, 1000)

    x_pred = np.vstack([mjd_for_pred, wavelengths['g'] * np.ones_like(mjd_for_pred)]).T
    mean_pred, _ = gp.predict(flux, x_pred, return_var=True)

    peak_idx = int(np.argmax(mean_pred))
    t_peak = mjd_for_pred[peak_idx]
    fpeak = mean_pred[peak_idx]
    threshold_flux = fpeak / peak_threshold

    # Rise: last grid point before the peak that is at or below the threshold
    before = np.flatnonzero(mean_pred[:peak_idx] <= threshold_flux)
    rise_time = t_peak - mjd_for_pred[before[-1]] if before.size else 50  # default rise time

    # Fade: first grid point at/after the peak that is at or below the threshold
    after = np.flatnonzero(mean_pred[peak_idx:] <= threshold_flux)
    fade_time = mjd_for_pred[peak_idx + after[0]] - t_peak if after.size else 75  # default fade time

    return rise_time, fade_time, t_peak, fpeak


In [None]:
def calc_mean_colors_and_slope(sub, shead, gp, band1, band2, object_type, snid, rise_time, fade_time, classification, t_peak):
    """
    Fit pre- and post-peak colour slopes for one object and plot its light curve and colour.

    Observations in ``band1``/``band2`` inside the rise window
    [t_peak - rise_time, t_peak] and the fade window (t_peak, t_peak + fade_time]
    are turned into colours with ``calc_color``. Each set is fitted with a
    1/sigma^2-weighted straight line against time since peak. A two-panel
    figure is then shown:
    - top: GP predictions and data for the six bands u, g, r, i, z, Y;
    - bottom: the colours and their linear fits.

    Despite the name, no mean colours are computed or returned.

    NOTE: a later cell redefines this function with a revised plot layout;
    once that cell has run, it replaces this version.

    Parameters
    ----------
    sub : table
        Photometry of one object (``BAND``, ``MJD``, ``FLUXCAL``, ``FLUXCALERR``).
    shead : unused (only referenced in a commented-out line).
    gp : fitted GP with ``predict(flux, X, return_var=True)``.
    band1, band2 : str
        Bands whose observations define the colour epochs.
    object_type, classification : str
        Used only in the plot title.
    snid : bytes, str or int
        Object id; only its digits are shown.
    rise_time, fade_time : float or ndarray (first element used)
        Window lengths in days; they must not be None.
    t_peak : float
        Reference MJD (time zero of both panels).

    Returns
    -------
    slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak
        Each is None when fewer than two colour points survive or the fit
        raises ValueError/TypeError. Both fits share one try block, so a
        failing pre-peak fit also skips the post-peak fit.

    Notes
    -----
    Reads ``x``, ``flux``, ``calc_color`` and ``lsst_bands`` from the enclosing
    (notebook-global) scope rather than from parameters.
    """
    # Normalise the SNID to a string and keep only its digits for display
    if isinstance(snid, bytes):
        snid = snid.decode("utf-8")
    elif isinstance(snid, (np.integer, int)):
        snid = str(snid)

    snid_numeric = re.sub(r'\D', '', snid)  # Extract only numeric part of SNID
        
    # Array-valued windows are reduced to their first element
    if isinstance(rise_time, np.ndarray):
        rise_time = rise_time[0]
    if isinstance(fade_time, np.ndarray):
        fade_time = fade_time[0]

    # NOTE(review): `x` and `flux` are not parameters; they are read from the
    # notebook-global scope (the main plotting loop assigns them from compute_gp()).
    # They must belong to the same object as `sub` and `gp`.
    # Prediction grid: 10 days before the first to 75 days after the last fitted point
    t_min, t_max = x[:, 0].min(), x[:, 0].max()
    x1 = np.linspace(t_min - 10, t_max + 75, 1000)

    
    # Pre-peak epochs: observations in either band within [t_peak - rise_time, t_peak]
    indices_pre_peak = ((sub['BAND'] == band1) | (sub['BAND'] == band2)) & \
                       (sub['MJD'] >= t_peak - rise_time) & (sub['MJD'] <= t_peak)
    mjd_pre_peak = sub['MJD'][indices_pre_peak]
    time_pre_peak = mjd_pre_peak - t_peak  # Time since peak
    # calc_color is defined elsewhere; its outputs are assumed to be aligned
    # element-wise with mjd_pre_peak (the mask below indexes time_pre_peak) -- TODO confirm
    x_pre_peak, color_pre_peak, color_err_pre_peak = calc_color(gp, flux, sub, mjd_pre_peak, band1, band2)


    # Drop NaN colours and colour errors >= 1; from here on x_pre_peak is time since peak
    mask_pre = ~np.isnan(x_pre_peak) & ~np.isnan(color_pre_peak) & (color_err_pre_peak < 1)
    x_pre_peak, color_pre_peak, color_err_pre_peak = time_pre_peak[mask_pre], color_pre_peak[mask_pre], color_err_pre_peak[mask_pre]

    # Post-peak epochs: observations in either band within (t_peak, t_peak + fade_time]
    indices_post_peak = ((sub['BAND'] == band1) | (sub['BAND'] == band2)) & \
                        (sub['MJD'] > t_peak) & (sub['MJD'] <= t_peak + fade_time)
    mjd_post_peak = sub['MJD'][indices_post_peak]
    time_post_peak = mjd_post_peak - t_peak  # Time since peak
    x_post_peak, color_post_peak, color_err_post_peak = calc_color(gp, flux, sub, mjd_post_peak, band1, band2)


    # Same quality cut as pre-peak; x_post_peak becomes time since peak
    mask_post = ~np.isnan(x_post_peak) & ~np.isnan(color_post_peak) & (color_err_post_peak < 1)
    x_post_peak, color_post_peak, color_err_post_peak = time_post_peak[mask_post], color_post_peak[mask_post], color_err_post_peak[mask_post]  
    
    slope_pre_peak, intercept_pre_peak, slope_err_pre_peak = None, None, None
    slope_post_peak, intercept_post_peak, slope_err_post_peak = None, None, None

    # Weighted (1/sigma^2) straight-line fits; slope error from the fit covariance
    try:
        if len(x_pre_peak) >= 2:
            weights_pre = 1 / color_err_pre_peak**2
            p_pre_peak, cov_pre_peak = np.polyfit(x_pre_peak, color_pre_peak, 1, w=weights_pre, cov=True)
            slope_pre_peak, intercept_pre_peak = p_pre_peak
            slope_err_pre_peak = np.sqrt(cov_pre_peak[0, 0])

        if len(x_post_peak) >= 2:
            weights_post = 1 / color_err_post_peak**2
            p_post_peak, cov_post_peak = np.polyfit(x_post_peak, color_post_peak, 1, w=weights_post, cov=True)
            slope_post_peak, intercept_post_peak = p_post_peak
            slope_err_post_peak = np.sqrt(cov_post_peak[0, 0])
    except (ValueError, TypeError) as e:
        print(f"Linear fit failed for SNID {snid_numeric}: {e}")

    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(8, 6), sharex=True)
    
    # Top panel: GP mean (clipped at 0) with a +/- 1 sigma band, per band
    for band, color in [('u', 'blue'), ('g', 'green'), ('r', 'red'), ('i', 'orange'), ('z', 'purple'), ('Y', 'yellow')]:
        wavelength = lsst_bands[band]
        x_band = np.vstack([x1, wavelength * np.ones_like(x1)]).T
        flux_band, fluxerr_band = gp.predict(flux, x_band, return_var=True)
        flux_band = np.maximum(flux_band, 0)  # Ensure positive predictions
        
        ax[0].plot(x_band[:, 0] - t_peak, flux_band, color=color, lw=1.5, alpha=0.5) #, label=f'{band} GP')
        ax[0].fill_between(x_band[:, 0] - t_peak, flux_band - np.sqrt(fluxerr_band), flux_band + np.sqrt(fluxerr_band), color=color, alpha=0.2)

    # Top panel: observed fluxes per band
    band_colors = {'u': 'blue', 'g': 'green', 'r': 'red', 'i': 'orange', 'z': 'purple', 'Y': 'yellow'}    
    for band in ['u', 'g', 'r', 'i', 'z', 'Y']:
        idx = (sub['BAND'] == band)
        ax[0].errorbar(sub['MJD'][idx] - t_peak, sub['FLUXCAL'][idx], sub['FLUXCALERR'][idx], fmt='.', color=band_colors[band], label=f'{band}')

    ax[0].axvline(0, color='blue', linestyle='dashed', label='GP Peak')
 #   ax[0].axvline(shead['PEAKMJD'] - t_peak, color='orange', linestyle='dashed', label='True Peak')

    # Plotting rise and fade times on the flux plot relative to peak.
    # NOTE(review): rise_time/fade_time cannot be None here (the window
    # selections above would already have raised), so these checks always pass.
    if rise_time is not None:
        ax[0].axvline(-rise_time, color='green', linestyle='dashed', label='Rise Time')  # Negative because it's before the peak
        print(f"Plotting rise time relative to peak at: {-rise_time}")
    
    if fade_time is not None:
        ax[0].axvline(fade_time, color='red', linestyle='dashed', label='Fade Time')  # Positive because it's after the peak
        print(f"Plotting fade time relative to peak at: {fade_time}")

         
    ax[0].grid(alpha=0.3)
    ax[0].set_ylabel('Flux')
    ax[0].set_title(f'{object_type} {snid_numeric} {classification}')
    ax[0].legend()#loc='upper left', bbox_to_anchor=(1, 1), borderaxespad=0., fontsize='small')
    
    # Bottom panel: colour points. The labels say (g-r) regardless of band1/band2.
    if len(x_pre_peak) == len(color_pre_peak) and len(color_pre_peak) == len(color_err_pre_peak):
        ax[1].errorbar(x_pre_peak, color_pre_peak, yerr=color_err_pre_peak, fmt='.', label='Pre-Peak Color (g-r)')
    else:
        print("Mismatch in lengths for color pre-peak.")

    if len(x_post_peak) == len(color_post_peak) and len(color_post_peak) == len(color_err_post_peak):
        ax[1].errorbar(x_post_peak, color_post_peak, yerr=color_err_post_peak, fmt='.', label='Post-Peak Color (g-r)')
    else:
        print("Mismatch in lengths for color post-peak.")

    # Bottom panel: fitted lines, drawn only over the fitted points
    if slope_pre_peak is not None and slope_err_pre_peak is not None:
        line_pre_peak = slope_pre_peak * np.array(x_pre_peak) + intercept_pre_peak
        ax[1].plot(x_pre_peak, line_pre_peak, 'r-', label='Pre-Peak Fit')

    if slope_post_peak is not None and slope_err_post_peak is not None:
        line_post_peak = slope_post_peak * np.array(x_post_peak) + intercept_post_peak
        ax[1].plot(x_post_peak, line_post_peak, 'g-', label='Post-Peak Fit')

    ax[1].set_xlabel('Time Since Peak (days)')
    ax[1].set_ylabel('Color (g-r)')
    ax[1].legend()#loc='upper left', bbox_to_anchor=(1, 1), borderaxespad=0., fontsize='small')
    ax[1].set_xlim(-100, 200)
    ax[1].set_ylim(-1, 1)
    ax[1].grid(alpha=0.3)
    # Added after legend(), so its label does not appear in the legend
    ax[1].axvline(0, color='blue', linestyle='dashed', label='GP-g Peak')

    plt.tight_layout()
    plt.show()

    return slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak


## False positives (FP) and false negatives (FN)

In [None]:
# Display the true-positive TDE SNIDs used to filter the plotting loop below
# (presumably built in an earlier cell).
true_positive_snids

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt

def calc_mean_colors_and_slope(sub, shead, gp, band1, band2, object_type, snid, rise_time, fade_time, classification, t_peak):
    """
    Fit pre- and post-peak colour slopes for one object and plot its light curve and colour.

    Revised-layout version; it replaces any earlier definition of the same name.

    Observations in ``band1``/``band2`` inside the rise window
    [t_peak - rise_time, t_peak] and the fade window (t_peak, t_peak + fade_time]
    are turned into colours with ``calc_color``. Each set is fitted with a
    1/sigma^2-weighted straight line against time since peak. A two-panel
    figure (height ratio 3:1) is then shown:
    - top: GP predictions and data for the six bands u, g, r, i, z, Y, with
      labelled rise/fade markers;
    - bottom: the colours and their linear fits.

    Despite the name, no mean colours are computed or returned.

    Parameters
    ----------
    sub : table
        Photometry of one object (``BAND``, ``MJD``, ``FLUXCAL``, ``FLUXCALERR``).
    shead : unused (only referenced in a commented-out line).
    gp : fitted GP with ``predict(flux, X, return_var=True)``.
    band1, band2 : str
        Bands whose observations define the colour epochs.
    object_type, classification : str
        Used only in the plot title.
    snid : bytes, str or int
        Object id; only its digits are shown.
    rise_time, fade_time : float or ndarray (first element used)
        Window lengths in days; they must not be None.
    t_peak : float
        Reference MJD (time zero of both panels).

    Returns
    -------
    slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak
        Each is None when fewer than two colour points survive or the fit
        raises ValueError/TypeError. Both fits share one try block, so a
        failing pre-peak fit also skips the post-peak fit.

    Notes
    -----
    Reads ``x``, ``flux``, ``calc_color`` and ``lsst_bands`` from the enclosing
    (notebook-global) scope rather than from parameters.
    """
    # Normalise the SNID to a string and keep only its digits for display
    if isinstance(snid, bytes):
        snid = snid.decode("utf-8")
    elif isinstance(snid, (np.integer, int)):
        snid = str(snid)

    snid_numeric = re.sub(r'\D', '', snid)  # Extract only numeric part of SNID
        
    # Array-valued windows are reduced to their first element
    if isinstance(rise_time, np.ndarray):
        rise_time = rise_time[0]
    if isinstance(fade_time, np.ndarray):
        fade_time = fade_time[0]

    # NOTE(review): `x` and `flux` are not parameters; they are read from the
    # notebook-global scope (the main plotting loop assigns them from compute_gp()).
    # They must belong to the same object as `sub` and `gp`.
    # Prediction grid: 10 days before the first to 75 days after the last fitted point
    t_min, t_max = x[:, 0].min(), x[:, 0].max()
    x1 = np.linspace(t_min - 10, t_max + 75, 1000)

    # Pre-peak epochs: observations in either band within [t_peak - rise_time, t_peak]
    indices_pre_peak = ((sub['BAND'] == band1) | (sub['BAND'] == band2)) & \
                       (sub['MJD'] >= t_peak - rise_time) & (sub['MJD'] <= t_peak)
    mjd_pre_peak = sub['MJD'][indices_pre_peak]
    time_pre_peak = mjd_pre_peak - t_peak  # Time since peak
    # calc_color is defined elsewhere; its outputs are assumed to be aligned
    # element-wise with mjd_pre_peak (the mask below indexes time_pre_peak) -- TODO confirm
    x_pre_peak, color_pre_peak, color_err_pre_peak = calc_color(gp, flux, sub, mjd_pre_peak, band1, band2)

    # Drop NaN colours and colour errors >= 1; from here on x_pre_peak is time since peak
    mask_pre = ~np.isnan(x_pre_peak) & ~np.isnan(color_pre_peak) & (color_err_pre_peak < 1)
    x_pre_peak, color_pre_peak, color_err_pre_peak = time_pre_peak[mask_pre], color_pre_peak[mask_pre], color_err_pre_peak[mask_pre]

    # Post-peak epochs: observations in either band within (t_peak, t_peak + fade_time]
    indices_post_peak = ((sub['BAND'] == band1) | (sub['BAND'] == band2)) & \
                        (sub['MJD'] > t_peak) & (sub['MJD'] <= t_peak + fade_time)
    mjd_post_peak = sub['MJD'][indices_post_peak]
    time_post_peak = mjd_post_peak - t_peak  # Time since peak
    x_post_peak, color_post_peak, color_err_post_peak = calc_color(gp, flux, sub, mjd_post_peak, band1, band2)

    # Same quality cut as pre-peak; x_post_peak becomes time since peak
    mask_post = ~np.isnan(x_post_peak) & ~np.isnan(color_post_peak) & (color_err_post_peak < 1)
    x_post_peak, color_post_peak, color_err_post_peak = time_post_peak[mask_post], color_post_peak[mask_post], color_err_post_peak[mask_post]  
    
    slope_pre_peak, intercept_pre_peak, slope_err_pre_peak = None, None, None
    slope_post_peak, intercept_post_peak, slope_err_post_peak = None, None, None

    # Weighted (1/sigma^2) straight-line fits; slope error from the fit covariance
    try:
        if len(x_pre_peak) >= 2:
            weights_pre = 1 / color_err_pre_peak**2
            p_pre_peak, cov_pre_peak = np.polyfit(x_pre_peak, color_pre_peak, 1, w=weights_pre, cov=True)
            slope_pre_peak, intercept_pre_peak = p_pre_peak
            slope_err_pre_peak = np.sqrt(cov_pre_peak[0, 0])

        if len(x_post_peak) >= 2:
            weights_post = 1 / color_err_post_peak**2
            p_post_peak, cov_post_peak = np.polyfit(x_post_peak, color_post_peak, 1, w=weights_post, cov=True)
            slope_post_peak, intercept_post_peak = p_post_peak
            slope_err_post_peak = np.sqrt(cov_post_peak[0, 0])
    except (ValueError, TypeError) as e:
        print(f"Linear fit failed for SNID {snid_numeric}: {e}")

    # Create subplots with adjusted height ratios (flux panel 3x taller than colour panel)
    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(8, 5), sharex=True, 
                           gridspec_kw={'height_ratios': [3, 1]})
    
    # Top panel: GP mean (clipped at 0) with a +/- 1 sigma band, per band
    for band, color in [('u', 'blue'), ('g', 'green'), ('r', 'red'), 
                        ('i', 'orange'), ('z', 'purple'), ('Y', 'yellow')]:
        wavelength = lsst_bands[band]
        x_band = np.vstack([x1, wavelength * np.ones_like(x1)]).T
        flux_band, fluxerr_band = gp.predict(flux, x_band, return_var=True)
        flux_band = np.maximum(flux_band, 0)  # Ensure positive predictions
        
        ax[0].plot(x_band[:, 0] - t_peak, flux_band, color=color, lw=1.5, alpha=0.5)
        ax[0].fill_between(x_band[:, 0] - t_peak, 
                           flux_band - np.sqrt(fluxerr_band), 
                           flux_band + np.sqrt(fluxerr_band), 
                           color=color, alpha=0.2)

    # Top panel: observed fluxes per band
    band_colors = {'u': 'blue', 'g': 'green', 'r': 'red', 
                  'i': 'orange', 'z': 'purple', 'Y': 'yellow'}    
    for band in ['u', 'g', 'r', 'i', 'z', 'Y']:
        idx = (sub['BAND'] == band)
        ax[0].errorbar(sub['MJD'][idx] - t_peak, 
                      sub['FLUXCAL'][idx], 
                      sub['FLUXCALERR'][idx], 
                      fmt='.', 
                      color=band_colors[band], 
                      label=f'{band}')

    ax[0].axvline(0, color='blue', linestyle='dashed', label='GP Peak')
    # ax[0].axvline(shead['PEAKMJD'] - t_peak, color='orange', linestyle='dashed', label='True Peak')

    # Plotting rise and fade times on the flux plot relative to peak.
    # NOTE(review): rise_time/fade_time cannot be None here (the window
    # selections above would already have raised), so these checks always pass.
    if rise_time is not None:
        ax[0].axvline(-rise_time, color='green', linestyle='dashed', label='Rise Time')  # Negative because it's before the peak
        print(f"Plotting rise time relative to peak at: {-rise_time}")
    
    if fade_time is not None:
        ax[0].axvline(fade_time, color='red', linestyle='dashed', label='Fade Time')  # Positive because it's after the peak
        print(f"Plotting fade time relative to peak at: {fade_time}")

    # Fix the lower y limit of the top plot at -50 (the upper limit stays automatic)
    ax[0].set_ylim(bottom=-50)

    # Add text labels "rise" and "fade" near the respective vertical lines,
    # at 95% of the current upper y limit
    y_max_flux = ax[0].get_ylim()[1]
    text_y_position = y_max_flux * 0.95

    if rise_time is not None:
        ax[0].text(-rise_time, text_y_position, 'rise', 
                   color='green', fontsize=9, ha='right', va='top', 
                   bbox=dict(facecolor='white', edgecolor='green', boxstyle='round,pad=0.2'))
    
    if fade_time is not None:
        ax[0].text(fade_time, text_y_position, 'fade', 
                   color='red', fontsize=9, ha='left', va='top', 
                   bbox=dict(facecolor='white', edgecolor='red', boxstyle='round,pad=0.2'))

    ax[0].grid(alpha=0.3)
    ax[0].set_ylabel('Flux')
    ax[0].set_title(f'ELAsTiCC2 {object_type} {snid_numeric} {classification}')
    ax[0].legend()

    # Plot Pre-Peak Color (labels say g-r regardless of band1/band2)
    if len(x_pre_peak) == len(color_pre_peak) and len(color_pre_peak) == len(color_err_pre_peak):
        ax[1].errorbar(x_pre_peak, color_pre_peak, yerr=color_err_pre_peak, fmt='.', 
                      label='Pre-Peak Color (g-r)')
    else:
        print("Mismatch in lengths for color pre-peak.")

    # Plot Post-Peak Color
    if len(x_post_peak) == len(color_post_peak) and len(color_post_peak) == len(color_err_post_peak):
        ax[1].errorbar(x_post_peak, color_post_peak, yerr=color_err_post_peak, fmt='.', 
                      label='Post-Peak Color (g-r)')
    else:
        print("Mismatch in lengths for color post-peak.")

    # Plot Pre-Peak Fit (only over the fitted points)
    if slope_pre_peak is not None and slope_err_pre_peak is not None:
        line_pre_peak = slope_pre_peak * np.array(x_pre_peak) + intercept_pre_peak
        ax[1].plot(x_pre_peak, line_pre_peak, 'r-', label='Pre-Peak Fit')

    # Plot Post-Peak Fit (only over the fitted points)
    if slope_post_peak is not None and slope_err_post_peak is not None:
        line_post_peak = slope_post_peak * np.array(x_post_peak) + intercept_post_peak
        ax[1].plot(x_post_peak, line_post_peak, 'g-', label='Post-Peak Fit')

    ax[1].set_xlabel('Time Since Peak (days)')
    ax[1].set_ylabel('g-r (magnitude)')
    ax[1].legend()
    ax[1].set_xlim(-100, 200)
    ax[1].set_ylim(-0.5, 0.7)
    ax[1].grid(alpha=0.3)
    # Added after legend(), so its label does not appear in the legend
    ax[1].axvline(0, color='blue', linestyle='dashed', label='GP-g Peak')

    plt.tight_layout()
    plt.show()

    return slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak


In [None]:
# True positives: re-fit the GP and plot the g-r color evolution for every
# test-set TDE that the classifier labelled correctly (SNIDs in `true_positive_snids`).

# Start the timer in this cell so the elapsed time printed at the end covers
# this cell only (it previously relied on a `start_time` set somewhere else).
start_time = time.time()

# Base path and template for file names
base_path = "../../../karpov/ELASTICC2/"
filename_template = "ELASTICC2_FINAL_{object_type}/ELASTICC2_FINAL_NONIaMODEL0-{index}_HEAD.FITS.gz"

# Object types and model names
object_info = [
    'TDE', #'AGN', 'SLSN-I+host', 'SLSN-I_no_host', 'SNIa-SALT3', 'SNIa-91bg',
    #'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19', 'SNIIn-MOSFIT', 'SNII-NMF',
    #'SNII+HostXT_V19', 'SNIIb+HostXT_V19', 'KN_B19', 'KN_K17'
]

# Generate all filenames (file indices 0005-0019) and corresponding object types
all_filenames = []
object_types = []
for object_type in object_info:
    filenames = [os.path.join(base_path, filename_template.format(object_type=object_type, index=str(i).zfill(4))) for i in range(5, 20)]
    all_filenames.extend(filenames)
    object_types.extend([object_type] * len(filenames))

# Build the lookup set once: `in` on a list is O(n) for every SNID checked.
wanted_snids = set(true_positive_snids)

# Process files
for filename, object_type in zip(all_filenames, object_types):
    table, head = read_elasticc_file(filename)
    snids, shead_list, sub_list = get_snid_head_sub(table, head)

    for snid, shead, sub in zip(snids, shead_list, sub_list):
        snid = int(snid)  # Convert to standard integer

        # Skip objects that are not in the TP TDE list
        if snid not in wanted_snids:
            continue

        # Compute GP for the data set of the SNID
        gp, flux, x, params = compute_gp(sub, shead)

        if gp is None:
            print(f"Skipping SNID {shead['SNID']} due to GP optimization failure.")
            continue  # Skip processing this SNID if GP failed

        # Calculate peak, rise, and fade times using the combined function
        rise_time, fade_time, t_peak, fpeak = peak_and_risefade(gp, x, flux)

        # calc_mean_colors_and_slope is redefined several times in this notebook:
        # some versions return the 4 slope values, later ones prepend the figure.
        # Star-unpacking takes the last four values and works with either version.
        *_extra, slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak = calc_mean_colors_and_slope(
            sub, shead, gp, 'g', 'r', object_type, snid, rise_time, fade_time, "TP", t_peak)

        plt.show()  # Display the plot immediately (no-op if already shown)
        plt.close('all')  # Free the figure so memory does not grow per SNID

# End the timer and print the elapsed time
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt

def _weighted_linear_fit(t, y, yerr, snid_label=''):
    """Fit ``y = slope * t + intercept`` by inverse-variance weighted least squares.

    Parameters:
    - t, y, yerr: equal-length 1-D sequences (times, values, 1-sigma errors).
    - snid_label: identifier used in the failure message.

    Returns:
    - (slope, intercept, slope_err), or (None, None, None) when fewer than three
      usable points remain or np.polyfit fails. Points with non-finite values or
      non-positive errors are ignored.
    """
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    yerr = np.asarray(yerr, dtype=float)
    usable = np.isfinite(t) & np.isfinite(y) & np.isfinite(yerr) & (yerr > 0)
    t, y, yerr = t[usable], y[usable], yerr[usable]
    # np.polyfit(..., cov=True) scales the covariance by the residuals and raises
    # ValueError unless there are more points than fit coefficients (2).
    if len(t) < 3:
        return None, None, None
    try:
        # np.polyfit applies `w` to the *unsquared* residuals, so inverse-variance
        # weighting means w = 1/sigma (not 1/sigma**2).
        (slope, intercept), cov = np.polyfit(t, y, 1, w=1.0 / yerr, cov=True)
    except (ValueError, TypeError, np.linalg.LinAlgError) as e:
        print(f"Linear fit failed for SNID {snid_label}: {e}")
        return None, None, None
    return slope, intercept, np.sqrt(cov[0, 0])


def _peak_windows(mjd, t_peak, rise_time, fade_time):
    """Return boolean masks for the pre-peak [t_peak - rise_time, t_peak] and
    post-peak (t_peak, t_peak + fade_time] windows; a None time gives an empty mask."""
    no_epochs = np.zeros(len(mjd), dtype=bool)
    pre = no_epochs if rise_time is None else (mjd >= t_peak - rise_time) & (mjd <= t_peak)
    post = no_epochs if fade_time is None else (mjd > t_peak) & (mjd <= t_peak + fade_time)
    return pre, post


def _color_segment(sub, gp, flux, band1, band2, time_mask, t_peak):
    """Return (time since peak, color, color error) for the band1/band2 epochs
    selected by ``time_mask``, keeping only non-NaN colors with error below 1."""
    in_bands = (sub['BAND'] == band1) | (sub['BAND'] == band2)
    mjd = sub['MJD'][in_bands & time_mask]
    if len(mjd) == 0:
        empty = np.array([])
        return empty, empty, empty
    x_color, color, color_err = calc_color(gp, flux, sub, mjd, band1, band2)
    keep = ~np.isnan(x_color) & ~np.isnan(color) & (color_err < 1)
    return (mjd - t_peak)[keep], color[keep], color_err[keep]


def calc_mean_colors_and_slope(sub, shead, gp, band1, band2, object_type, snid, rise_time, fade_time, classification, t_peak, save_path=None, show_plot=True, gp_x=None, gp_flux=None):
    """
    Fit the band1-band2 color evolution before and after peak and plot it
    beneath the multi-band GP light curve for one SNID.

    Colors are evaluated (via calc_color) at the band1/band2 epochs inside the
    pre-peak window [t_peak - rise_time, t_peak] and the post-peak window
    (t_peak, t_peak + fade_time]. Each segment gets its own inverse-variance
    weighted straight-line fit, so a failure in one does not discard the other.

    Parameters:
    - sub: per-epoch photometry with 'BAND', 'MJD', 'FLUXCAL', 'FLUXCALERR' columns.
    - shead: header row for this SNID (unused here; kept for call compatibility).
    - gp: fitted GP exposing ``predict(y, x, return_var=True)``.
    - band1, band2: bands whose color (band1 - band2) is measured.
    - object_type, classification: shown in the plot title.
    - snid: object ID (int, bytes or str); only its digits are displayed.
    - rise_time, fade_time: days before/after peak bounding the fit windows.
      1-element arrays are unwrapped; None gives an empty window and no marker.
    - t_peak: peak time; every plotted time is relative to it.
    - save_path (str or None): Path to save the figure. If None, the figure is not saved.
    - show_plot (bool): Whether to display the plot. Defaults to True.
    - gp_x, gp_flux: GP training inputs (column 0 = time, column 1 = wavelength)
      and fluxes. Default to the notebook-level ``x`` and ``flux`` that the
      calling loops set from compute_gp.

    Returns:
    - fig: The matplotlib figure object.
    - slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak
      (mag per day; None when a segment has fewer than 3 usable points or the fit fails).
    """
    # Previously the notebook globals were read implicitly; keep that as the default.
    if gp_x is None:
        gp_x = x
    if gp_flux is None:
        gp_flux = flux

    if isinstance(snid, bytes):
        snid = snid.decode("utf-8")
    elif isinstance(snid, (np.integer, int)):
        snid = str(snid)
    snid_numeric = re.sub(r'\D', '', snid)  # Extract only numeric part of SNID

    if isinstance(rise_time, np.ndarray):
        rise_time = rise_time[0]
    if isinstance(fade_time, np.ndarray):
        fade_time = fade_time[0]

    color_label = f'{band1}-{band2}'

    # Colors and weighted linear fits in the pre- and post-peak windows
    pre_window, post_window = _peak_windows(sub['MJD'], t_peak, rise_time, fade_time)
    x_pre_peak, color_pre_peak, color_err_pre_peak = _color_segment(
        sub, gp, gp_flux, band1, band2, pre_window, t_peak)
    x_post_peak, color_post_peak, color_err_post_peak = _color_segment(
        sub, gp, gp_flux, band1, band2, post_window, t_peak)
    slope_pre_peak, intercept_pre_peak, slope_err_pre_peak = _weighted_linear_fit(
        x_pre_peak, color_pre_peak, color_err_pre_peak, snid_numeric)
    slope_post_peak, intercept_post_peak, slope_err_post_peak = _weighted_linear_fit(
        x_post_peak, color_post_peak, color_err_post_peak, snid_numeric)

    band_colors = {'u': 'blue', 'g': 'green', 'r': 'red',
                   'i': 'orange', 'z': 'purple', 'Y': 'pink'}

    # Time grid for the GP prediction: the data span plus margins on both sides
    t_min, t_max = gp_x[:, 0].min(), gp_x[:, 0].max()
    t_grid = np.linspace(t_min - 10, t_max + 75, 1000)

    # Create subplots with adjusted height ratios
    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(10, 6), sharex=True,
                           gridspec_kw={'height_ratios': [3, 1]})

    # GP mean +/- 1 sigma at each band's effective wavelength
    for band, color in band_colors.items():
        x_band = np.vstack([t_grid, lsst_bands[band] * np.ones_like(t_grid)]).T
        flux_band, var_band = gp.predict(gp_flux, x_band, return_var=True)
        flux_band = np.maximum(flux_band, 0)  # Clip negative predictions to zero
        sigma_band = np.sqrt(var_band)
        t_rel = x_band[:, 0] - t_peak
        ax[0].plot(t_rel, flux_band, color=color, lw=1.5, alpha=0.5)
        ax[0].fill_between(t_rel, flux_band - sigma_band, flux_band + sigma_band,
                           color=color, alpha=0.2)

    # Observed photometry
    for band, color in band_colors.items():
        idx = (sub['BAND'] == band)
        ax[0].errorbar(sub['MJD'][idx] - t_peak,
                       sub['FLUXCAL'][idx],
                       sub['FLUXCALERR'][idx],
                       fmt='.',
                       color=color,
                       label=f'{band}')

    ax[0].axvline(0, color='green', linestyle='dashed', label='GP-g Peak')

    # Rise and fade times relative to peak
    if rise_time is not None:
        ax[0].axvline(-rise_time, color='blue', linestyle='dashed', label='Rise Time')  # Negative because it's before the peak
        print(f"Plotting rise time relative to peak at: {-rise_time}")

    if fade_time is not None:
        ax[0].axvline(fade_time, color='red', linestyle='dashed', label='Fade Time')  # Positive because it's after the peak
        print(f"Plotting fade time relative to peak at: {fade_time}")

    ax[0].set_ylim(bottom=-50)  # Prevent y-axis from going below -50

    # "rise"/"fade" labels placed just inside the rise/fade lines, at 5% of the y-axis maximum
    delta_x = 5  # days
    text_y_position = ax[0].get_ylim()[1] * 0.05

    if rise_time is not None:
        ax[0].text(-rise_time + delta_x, text_y_position, 'rise',
                   color='blue', fontsize=9, ha='left', va='bottom',
                   bbox=dict(facecolor='white', edgecolor='blue', boxstyle='round,pad=0.2'))

    if fade_time is not None:
        ax[0].text(fade_time - delta_x, text_y_position, 'fade',
                   color='red', fontsize=9, ha='right', va='bottom',
                   bbox=dict(facecolor='white', edgecolor='red', boxstyle='round,pad=0.2'))

    ax[0].grid(alpha=0.3)
    ax[0].set_ylabel('Flux (ADU)')
    ax[0].set_title(f'ELAsTiCC2 {object_type} {snid_numeric} {classification}')
    ax[0].legend()

    # Color points (the three arrays of each segment always have equal length)
    ax[1].errorbar(x_pre_peak, color_pre_peak, yerr=color_err_pre_peak, fmt='.',
                   label=f'Pre-Peak Color ({color_label})')
    ax[1].errorbar(x_post_peak, color_post_peak, yerr=color_err_post_peak, fmt='.',
                   label=f'Post-Peak Color ({color_label})')

    # Linear fits
    if slope_pre_peak is not None:
        line_pre_peak = slope_pre_peak * np.asarray(x_pre_peak) + intercept_pre_peak
        ax[1].plot(x_pre_peak, line_pre_peak, 'r-', label='Pre-Peak Fit')

    if slope_post_peak is not None:
        line_post_peak = slope_post_peak * np.asarray(x_post_peak) + intercept_post_peak
        ax[1].plot(x_post_peak, line_post_peak, 'g-', label='Post-Peak Fit')

    # Draw the peak marker before legend() so its label is actually listed
    ax[1].axvline(0, color='green', linestyle='dashed', label='GP-g Peak')
    ax[1].set_xlabel('Time Since Peak (days)')
    ax[1].set_ylabel(f'{color_label} (mag)')
    ax[1].legend()
    ax[1].set_xlim(-70, 200)
    ax[1].set_ylim(-1, 1)
    ax[1].grid(alpha=0.3)

    plt.tight_layout()

    # Save the figure if save_path is provided
    if save_path is not None:
        fig.savefig(save_path)
        print(f"Figure saved to {save_path}")

    # Show the plot if requested
    if show_plot:
        plt.show()

    return fig, slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt

def _weighted_linear_fit(t, y, yerr, snid_label=''):
    """Fit ``y = slope * t + intercept`` by inverse-variance weighted least squares.

    Parameters:
    - t, y, yerr: equal-length 1-D sequences (times, values, 1-sigma errors).
    - snid_label: identifier used in the failure message.

    Returns:
    - (slope, intercept, slope_err), or (None, None, None) when fewer than three
      usable points remain or np.polyfit fails. Points with non-finite values or
      non-positive errors are ignored.
    """
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    yerr = np.asarray(yerr, dtype=float)
    usable = np.isfinite(t) & np.isfinite(y) & np.isfinite(yerr) & (yerr > 0)
    t, y, yerr = t[usable], y[usable], yerr[usable]
    # np.polyfit(..., cov=True) scales the covariance by the residuals and raises
    # ValueError unless there are more points than fit coefficients (2).
    if len(t) < 3:
        return None, None, None
    try:
        # np.polyfit applies `w` to the *unsquared* residuals, so inverse-variance
        # weighting means w = 1/sigma (not 1/sigma**2).
        (slope, intercept), cov = np.polyfit(t, y, 1, w=1.0 / yerr, cov=True)
    except (ValueError, TypeError, np.linalg.LinAlgError) as e:
        print(f"Linear fit failed for SNID {snid_label}: {e}")
        return None, None, None
    return slope, intercept, np.sqrt(cov[0, 0])


def _peak_windows(mjd, t_peak, rise_time, fade_time):
    """Return boolean masks for the pre-peak [t_peak - rise_time, t_peak] and
    post-peak (t_peak, t_peak + fade_time] windows; a None time gives an empty mask."""
    no_epochs = np.zeros(len(mjd), dtype=bool)
    pre = no_epochs if rise_time is None else (mjd >= t_peak - rise_time) & (mjd <= t_peak)
    post = no_epochs if fade_time is None else (mjd > t_peak) & (mjd <= t_peak + fade_time)
    return pre, post


def _color_segment(sub, gp, flux, band1, band2, time_mask, t_peak):
    """Return (time since peak, color, color error) for the band1/band2 epochs
    selected by ``time_mask``, keeping only non-NaN colors with error below 1."""
    in_bands = (sub['BAND'] == band1) | (sub['BAND'] == band2)
    mjd = sub['MJD'][in_bands & time_mask]
    if len(mjd) == 0:
        empty = np.array([])
        return empty, empty, empty
    x_color, color, color_err = calc_color(gp, flux, sub, mjd, band1, band2)
    keep = ~np.isnan(x_color) & ~np.isnan(color) & (color_err < 1)
    return (mjd - t_peak)[keep], color[keep], color_err[keep]


def calc_mean_colors_and_slope(sub, shead, gp, band1, band2, object_type, snid, rise_time, fade_time, classification, t_peak, save_path=None, show_plot=True, gp_x=None, gp_flux=None):
    """
    Fit the band1-band2 color slopes before and after peak, and plot the
    multi-band GP fit (mean +/- 1 sigma) together with the observed photometry.

    Colors are evaluated (via calc_color) at the band1/band2 epochs inside the
    pre-peak window [t_peak - rise_time, t_peak] and the post-peak window
    (t_peak, t_peak + fade_time]. Each segment gets its own inverse-variance
    weighted straight-line fit. Only the GP light curve is drawn: no color
    panel and no peak/rise/fade markers.

    Parameters:
    - sub: per-epoch photometry with 'BAND', 'MJD', 'FLUXCAL', 'FLUXCALERR' columns.
    - shead: header row for this SNID (unused here; kept for call compatibility).
    - gp: fitted GP exposing ``predict(y, x, return_var=True)``.
    - band1, band2: bands whose color (band1 - band2) is fitted.
    - object_type, classification: shown in the plot title.
    - snid: object ID (int, bytes or str); only its digits are displayed.
    - rise_time, fade_time: days before/after peak bounding the fit windows.
      1-element arrays are unwrapped; None gives an empty window.
    - t_peak: peak time; every plotted time is relative to it.
    - save_path (str or None): Path to save the figure. If None, the figure is not saved.
    - show_plot (bool): Whether to display the plot. Defaults to True.
    - gp_x, gp_flux: GP training inputs (column 0 = time, column 1 = wavelength)
      and fluxes. Default to the notebook-level ``x`` and ``flux`` that the
      calling loops set from compute_gp.

    Returns:
    - fig: The matplotlib figure object.
    - slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak
      (mag per day; None when a segment has fewer than 3 usable points or the fit fails).
    """
    # Previously the notebook globals were read implicitly; keep that as the default.
    if gp_x is None:
        gp_x = x
    if gp_flux is None:
        gp_flux = flux

    if isinstance(snid, bytes):
        snid = snid.decode("utf-8")
    elif isinstance(snid, (np.integer, int)):
        snid = str(snid)
    snid_numeric = re.sub(r'\D', '', snid)  # Extract only numeric part of SNID

    if isinstance(rise_time, np.ndarray):
        rise_time = rise_time[0]
    if isinstance(fade_time, np.ndarray):
        fade_time = fade_time[0]

    # Colors and weighted linear fits in the pre- and post-peak windows
    pre_window, post_window = _peak_windows(sub['MJD'], t_peak, rise_time, fade_time)
    x_pre_peak, color_pre_peak, color_err_pre_peak = _color_segment(
        sub, gp, gp_flux, band1, band2, pre_window, t_peak)
    x_post_peak, color_post_peak, color_err_post_peak = _color_segment(
        sub, gp, gp_flux, band1, band2, post_window, t_peak)
    slope_pre_peak, _, slope_err_pre_peak = _weighted_linear_fit(
        x_pre_peak, color_pre_peak, color_err_pre_peak, snid_numeric)
    slope_post_peak, _, slope_err_post_peak = _weighted_linear_fit(
        x_post_peak, color_post_peak, color_err_post_peak, snid_numeric)

    band_colors = {'u': 'blue', 'g': 'green', 'r': 'red',
                   'i': 'orange', 'z': 'purple', 'Y': 'pink'}

    # Time grid for the GP prediction: the data span plus margins on both sides
    t_min, t_max = gp_x[:, 0].min(), gp_x[:, 0].max()
    t_grid = np.linspace(t_min - 10, t_max + 75, 1000)

    # Create a single subplot for GP fit
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 4))

    # GP mean +/- 1 sigma at each band's effective wavelength
    for band, color in band_colors.items():
        x_band = np.vstack([t_grid, lsst_bands[band] * np.ones_like(t_grid)]).T
        flux_band, var_band = gp.predict(gp_flux, x_band, return_var=True)
        flux_band = np.maximum(flux_band, 0)  # Clip negative predictions to zero
        sigma_band = np.sqrt(var_band)
        t_rel = x_band[:, 0] - t_peak
        ax.plot(t_rel, flux_band, color=color, lw=1.5, alpha=0.7)
        ax.fill_between(t_rel, flux_band - sigma_band, flux_band + sigma_band,
                        color=color, alpha=0.2)

    # Observed photometry for every band
    for band, color in band_colors.items():
        idx = (sub['BAND'] == band)
        ax.errorbar(sub['MJD'][idx] - t_peak,
                    sub['FLUXCAL'][idx],
                    sub['FLUXCALERR'][idx],
                    fmt='.',
                    color=color,
                    label=f'{band}')

    ax.set_ylim(bottom=-50)  # Prevent y-axis from going below -50

    ax.grid(alpha=0.3)
    ax.set_ylabel('Flux (ADU)')
    ax.set_xlabel('Time since peak (days)')
    ax.set_title(f'ELAsTiCC2 {object_type} {snid_numeric} {classification}')
    ax.legend()

    plt.tight_layout()

    # Save the figure if save_path is provided
    if save_path is not None:
        fig.savefig(save_path)
        print(f"Figure saved to {save_path}")

    # Show the plot if requested
    if show_plot:
        plt.show()

    return fig, slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak


In [None]:
# True positives: re-plot the GP fits for correctly classified TDEs and save the
# figures for the SNIDs listed in `save_snids` (others are only displayed).
# NOTE(review): `true_positive_snids` is produced by the training cell further down
# the notebook, so that cell must have been run first.

# Start the timer in this cell so the elapsed time printed below covers it alone.
start_time = time.time()

# Base path and template for file names
base_path = "../../../karpov/ELASTICC2/"
filename_template = "ELASTICC2_FINAL_{object_type}/ELASTICC2_FINAL_NONIaMODEL0-{index}_HEAD.FITS.gz"

# Object types and model names
object_info = [
    'TDE', #'AGN', 'SLSN-I+host', 'SLSN-I_no_host', 'SNIa-SALT3', 'SNIa-91bg',
    #'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19', 'SNIIn-MOSFIT', 'SNII-NMF',
    #'SNII+HostXT_V19', 'SNIIb+HostXT_V19', 'KN_B19', 'KN_K17'
]

# Generate all filenames (file indices 0005-0006) and corresponding object types
all_filenames = []
object_types = []
for object_type in object_info:
    filenames = [os.path.join(base_path, filename_template.format(object_type=object_type, index=str(i).zfill(4))) for i in range(5, 7)]
    all_filenames.extend(filenames)
    object_types.extend([object_type] * len(filenames))

# Define the SNID(s) you want to save the figure for
save_snids = [137495236, 110044328]  # Replace with your specific SNID(s)

# Directory to save figures. Absolute path: without the leading '/', the
# directory tree "home/bhardwaj/..." would be created under the notebook's cwd.
save_directory = "/home/bhardwaj/notebooksLSST/tdes-fzu/notebooksLSST/ELAsTiCC2_processed/results-images"
os.makedirs(save_directory, exist_ok=True)  # Create the directory if it doesn't exist

# Build the lookup sets once: `in` on a list is O(n) for every SNID checked.
wanted_snids = set(true_positive_snids)
snids_to_save = set(save_snids)

# Process files
for filename, object_type in zip(all_filenames, object_types):
    table, head = read_elasticc_file(filename)
    snids, shead_list, sub_list = get_snid_head_sub(table, head)

    for snid, shead, sub in zip(snids, shead_list, sub_list):
        snid = int(snid)  # Convert to standard integer

        # Skip objects that are not in the TP TDE list
        if snid not in wanted_snids:
            continue

        # Compute GP for the data set of the SNID
        gp, flux, x, params = compute_gp(sub, shead)

        if gp is None:
            print(f"Skipping SNID {shead['SNID']} due to GP optimization failure.")
            continue  # Skip processing this SNID if GP failed

        # Calculate peak, rise, and fade times using the combined function
        rise_time, fade_time, t_peak, fpeak = peak_and_risefade(gp, x, flux)

        # Saved figures are written to disk but not displayed; the rest are displayed
        if snid in snids_to_save:
            save_path = os.path.join(save_directory, f"SNID_{snid}.png")
            show_plot = False  # Don't display the plot immediately
            print(f"Saving figure for SNID {snid} to {save_path}")
        else:
            save_path = None
            show_plot = True  # Display the plot normally

        # Plot using the modified calc_mean_colors_and_slope function
        fig, slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak = calc_mean_colors_and_slope(
            sub, shead, gp, 'g', 'r', object_type, snid, rise_time, fade_time, "GP fit with 1σ C.I.", t_peak,
            save_path=save_path,
            show_plot=show_plot
        )

        # Close every figure once it has been saved and/or shown, so memory does
        # not grow with the number of SNIDs plotted.
        plt.close(fig)

# End the timer and print the elapsed time
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")


In [None]:
# True negatives: re-fit the GP and plot the g-r color evolution for every
# test-set non-TDE that the classifier labelled correctly (SNIDs in `true_negative_snids`).

# Start the timer in this cell so the elapsed time printed below covers it alone.
start_time = time.time()

# Base path and template for file names
base_path = "../../../karpov/ELASTICC2/"
filename_template = "ELASTICC2_FINAL_{object_type}/ELASTICC2_FINAL_NONIaMODEL0-{index}_HEAD.FITS.gz"

# Object types and model names
object_info = [
    'TDE', 'AGN', 'SLSN-I+host', 'SLSN-I_no_host', #'SNIa-SALT3', 
    'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19', 'SNIIn-MOSFIT', 'SNII-NMF',
    'SNII+HostXT_V19', 'SNIIb+HostXT_V19', 'KN_B19', 'KN_K17'
]

# Generate all filenames (file index 0002 only) and corresponding object types
all_filenames = []
object_types = []
for object_type in object_info:
    filenames = [os.path.join(base_path, filename_template.format(object_type=object_type, index=str(i).zfill(4))) for i in range(2, 3)]
    all_filenames.extend(filenames)
    object_types.extend([object_type] * len(filenames))

# Build the lookup set once: `in` on a list is O(n) for every SNID checked.
wanted_snids = set(true_negative_snids)

# Process files
for filename, object_type in zip(all_filenames, object_types):
    table, head = read_elasticc_file(filename)
    snids, shead_list, sub_list = get_snid_head_sub(table, head)

    for snid, shead, sub in zip(snids, shead_list, sub_list):
        snid = int(snid)  # Convert to standard integer

        # Skip objects that are not in the TN list
        if snid not in wanted_snids:
            continue

        # Compute GP for the data set of the SNID
        gp, flux, x, params = compute_gp(sub, shead)

        if gp is None:
            print(f"Skipping SNID {shead['SNID']} due to GP optimization failure.")
            continue  # Skip processing this SNID if GP failed

        # Calculate peak, rise, and fade times using the combined function
        rise_time, fade_time, t_peak, fpeak = peak_and_risefade(gp, x, flux)

        # The latest calc_mean_colors_and_slope returns (fig, 4 slope values) while
        # older definitions return only the 4 slope values. Plain 4-name unpacking
        # raises ValueError against the 5-tuple; star-unpacking accepts both.
        *_extra, slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak = calc_mean_colors_and_slope(
            sub, shead, gp, 'g', 'r', object_type, snid, rise_time, fade_time, "TN", t_peak)

        plt.show()  # Display the plot immediately (no-op if already shown)
        plt.close('all')  # Free the figure so memory does not grow per SNID

# End the timer and print the elapsed time
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")


In [None]:
# False positives: re-fit the GP and plot the g-r color evolution for every
# test-set non-TDE that the classifier labelled as a TDE (SNIDs in `false_positive_snids`).

# Start the timer in this cell so the elapsed time printed below covers it alone.
start_time = time.time()

# Base path and template for file names
base_path = "../../../karpov/ELASTICC2/"
filename_template = "ELASTICC2_FINAL_{object_type}/ELASTICC2_FINAL_NONIaMODEL0-{index}_HEAD.FITS.gz"

# Object types and model names
object_info = [
    #'TDE', 'AGN',
    'SLSN-I+host', 'SLSN-I_no_host', #'SNIa-SALT3',
    'SNIa-91bg', 'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19', 'SNIIn-MOSFIT', 'SNII-NMF',
    'SNII+HostXT_V19', 'SNIIb+HostXT_V19'
]

# Generate all filenames (file indices 0002-0003) and corresponding object types
all_filenames = []
object_types = []
for object_type in object_info:
    filenames = [os.path.join(base_path, filename_template.format(object_type=object_type, index=str(i).zfill(4))) for i in range(2, 4)]
    all_filenames.extend(filenames)
    object_types.extend([object_type] * len(filenames))

# Build the lookup set once: `in` on a list is O(n) for every SNID checked.
wanted_snids = set(false_positive_snids)

# Process files
for filename, object_type in zip(all_filenames, object_types):
    table, head = read_elasticc_file(filename)
    snids, shead_list, sub_list = get_snid_head_sub(table, head)

    for snid, shead, sub in zip(snids, shead_list, sub_list):
        snid = int(snid)  # Convert to standard integer

        # Skip objects that are not in the FP list
        if snid not in wanted_snids:
            continue

        # Compute GP for the data set of the SNID
        gp, flux, x, params = compute_gp(sub, shead)

        if gp is None:
            print(f"Skipping SNID {shead['SNID']} due to GP optimization failure.")
            continue  # Skip processing this SNID if GP failed

        # Calculate peak, rise, and fade times using the combined function
        rise_time, fade_time, t_peak, fpeak = peak_and_risefade(gp, x, flux)

        # The latest calc_mean_colors_and_slope returns (fig, 4 slope values) while
        # older definitions return only the 4 slope values. Plain 4-name unpacking
        # raises ValueError against the 5-tuple; star-unpacking accepts both.
        *_extra, slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak = calc_mean_colors_and_slope(
            sub, shead, gp, 'g', 'r', object_type, snid, rise_time, fade_time, "FP", t_peak)

        plt.show()  # Display the plot immediately (no-op if already shown)
        plt.close('all')  # Free the figure so memory does not grow per SNID

# End the timer and print the elapsed time
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")


In [None]:
# False negatives: re-fit the GP and plot the g-r color evolution for every
# test-set TDE that the classifier missed (SNIDs in `false_negative_snids`).

# Start the timer in this cell so the elapsed time printed below covers it alone.
start_time = time.time()

# Base path and template for file names
base_path = "../../../karpov/ELASTICC2/"
filename_template = "ELASTICC2_FINAL_{object_type}/ELASTICC2_FINAL_NONIaMODEL0-{index}_HEAD.FITS.gz"

# Object types and model names
object_info = [
    'TDE', 'AGN', 'SLSN-I+host', 'SLSN-I_no_host', 'SNIa-SALT3', 'SNIa-91bg',
    'SNIax', 'SNIcBL+HostXT_V19', 'SNIb+HostXT_V19', 'SNIIn-MOSFIT', 'SNII-NMF',
    'SNII+HostXT_V19', 'SNIIb+HostXT_V19', 'KN_B19', 'KN_K17'
]

# Generate all filenames (file indices 0001-0040) and corresponding object types
all_filenames = []
object_types = []
for object_type in object_info:
    filenames = [os.path.join(base_path, filename_template.format(object_type=object_type, index=str(i).zfill(4))) for i in range(1, 41)]
    all_filenames.extend(filenames)
    object_types.extend([object_type] * len(filenames))

# Build the lookup set once: `in` on a list is O(n) for every SNID checked.
wanted_snids = set(false_negative_snids)

# Process files
for filename, object_type in zip(all_filenames, object_types):
    table, head = read_elasticc_file(filename)
    snids, shead_list, sub_list = get_snid_head_sub(table, head)

    for snid, shead, sub in zip(snids, shead_list, sub_list):
        snid = int(snid)  # Convert to standard integer

        # Skip objects that are not in the FN TDE list
        if snid not in wanted_snids:
            continue

        # Compute GP for the data set of the SNID
        gp, flux, x, params = compute_gp(sub, shead)

        if gp is None:
            print(f"Skipping SNID {shead['SNID']} due to GP optimization failure.")
            continue  # Skip processing this SNID if GP failed

        # Calculate peak, rise, and fade times using the combined function
        rise_time, fade_time, t_peak, fpeak = peak_and_risefade(gp, x, flux)

        # The latest calc_mean_colors_and_slope returns (fig, 4 slope values) while
        # older definitions return only the 4 slope values. Plain 4-name unpacking
        # raises ValueError against the 5-tuple; star-unpacking accepts both.
        *_extra, slope_pre_peak, slope_post_peak, slope_err_pre_peak, slope_err_post_peak = calc_mean_colors_and_slope(
            sub, shead, gp, 'g', 'r', object_type, snid, rise_time, fade_time, "FN", t_peak)

        plt.show()  # Display the plot immediately (no-op if already shown)
        plt.close('all')  # Free the figure so memory does not grow per SNID

# End the timer and print the elapsed time
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")


### Search and plot GP fits by SNID and object type

In [None]:
# Plot the SHAP explanation for a single object, selected by SNID.
# NOTE(review): plot_shap_for_snid is defined elsewhere in the notebook.
plot_shap_for_snid(82038200)

In [None]:
# Example: Interactive waterfall plot
# shap.initjs() loads SHAP's JavaScript into the notebook so interactive plots can render.
shap.initjs()
plot_shap_for_snid(68314204)


In [None]:
# Report SNIDs that occur more than once in the test-set results.
# Every repeat after an SNID's first occurrence is listed; nothing is removed.
snid_column = results['SNID']
is_repeat = snid_column.duplicated()
duplicate_snids = snid_column[is_repeat]
if len(duplicate_snids) > 0:
    print(f"Duplicate SNIDs found: {duplicate_snids.tolist()}")
    # Duplicates are only reported here; handle them as necessary.


In [None]:
# Sanity check: a plain NumPy X_test_scaled has lost its column names, so at
# least confirm it has one column per entry of feature_columns. This checks
# only the number of features, not their order.
assert not isinstance(X_test_scaled, np.ndarray) or X_test_scaled.shape[1] == len(feature_columns), \
    "Mismatch in number of features."


In [None]:
# create a dependence scatter plot to show the effect of a single feature across the whole dataset
# (points are colored by shap's choice of interacting feature, since color=shap_values)
shap.plots.scatter(shap_values[:, "Mean Color Post Peak (r-i)"], color=shap_values)

In [None]:
# summarize the effects of all the features
# (top 20 features, colored with matplotlib's "cool" colormap)
shap.plots.beeswarm(shap_values, max_display=20, color=plt.get_cmap("cool"))

In [None]:
# Beeswarm with features ordered by their maximum absolute SHAP value over all
# samples, so features with rare but large effects rank higher.
shap.plots.beeswarm(shap_values, order=shap_values.abs.max(0))

In [None]:
# Beeswarm of absolute SHAP values (effect magnitude only, sign dropped), drawn in a single color.
shap.plots.beeswarm(shap_values.abs, color="shap_red")

In [None]:
# Bar plot of global feature importance (shap's default for an Explanation:
# mean absolute SHAP value per feature).
shap.plots.bar(shap_values)

In [None]:
# Training (first version): XGBoost on standardized features with missing values
# filled by -999, followed by collecting the SNIDs of the TP/TN/FP/FN test objects.

# Prepare the data: binary TDE-vs-Other target
df0['Simple_Object_Type'] = df0['Object_Type'].apply(lambda x: 'TDE' if x == 'TDE' else 'Other')

# Select numeric columns for the model and exclude 'SNID', any 'peak', and 'redshift' related columns
excluded_columns = ['SNID', 'PeakMag', 'PeakMagErr', 'REDSHIFT_FINAL', 'REDSHIFT_FINAL_ERR']

excluded_columns += [col for col in df0.columns if 'err' in col.lower() or 'flux' in col.lower() or 'mjd' in col.lower()] #or 'slope' in col.lower()]
feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

X = df0[feature_columns].fillna(-999)  # Handle missing values by filling with -999
y = df0['Simple_Object_Type'].apply(lambda x: 1 if x == 'TDE' else 0)  # Target variable

# Standardize the features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Split the dataset (y_test keeps df0's index labels)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.5, random_state=42)

# Train the XGBoost classifier on the original dataset
xgb_classifier_original = xgb.XGBClassifier(n_estimators=1000, learning_rate=0.5, max_depth=3, random_state=42)
xgb_classifier_original.fit(X_train, y_train)
y_pred_xgb_original = xgb_classifier_original.predict(X_test)
y_probs_xgb_original = xgb_classifier_original.predict_proba(X_test)

# Get the corresponding SNIDs and Object_Types of the test rows.
# y_test.index holds df0 index *labels*, so use .loc; .iloc would treat them as
# positions and silently pick the wrong rows unless df0 has a default RangeIndex.
test_rows = df0.loc[y_test.index]
test_snids = test_rows['SNID'].apply(int).tolist()
test_object_types = test_rows['Object_Type'].tolist()

# Combine SNID, Object_Type, and predictions
results = pd.DataFrame({
    'SNID': test_snids,
    'Object_Type': test_object_types,
    'True_Label': y_test,
    'Predicted_Label': y_pred_xgb_original
})

# Split the test set into the four confusion-matrix cells
false_positives = results[(results['True_Label'] == 0) & (results['Predicted_Label'] == 1)]
false_negatives = results[(results['True_Label'] == 1) & (results['Predicted_Label'] == 0)]
true_negatives = results[(results['True_Label'] == 0) & (results['Predicted_Label'] == 0)]
true_positives = results[(results['True_Label'] == 1) & (results['Predicted_Label'] == 1)]


# SNID and Object_Type lists for each cell (used by the plotting loops)
false_negative_snids = false_negatives['SNID'].tolist()
false_negative_types = false_negatives['Object_Type'].tolist()
false_positive_snids = false_positives['SNID'].tolist()
false_positive_types = false_positives['Object_Type'].tolist()
true_negative_snids = true_negatives['SNID'].tolist()
true_negative_types = true_negatives['Object_Type'].tolist()
true_positive_snids = true_positives['SNID'].tolist()
true_positive_types = true_positives['Object_Type'].tolist()

print(f"Total False Negative TDEs: {len(false_negative_snids)}")
print(f"Total False Positive TDEs: {len(false_positive_snids)}")
print(f"Total True Negative TDEs: {len(true_negative_snids)}")
print(f"Total True Positive TDEs: {len(true_positive_snids)}")

print(f"Sample FN TDE SNIDs: {false_negative_snids[:10]}")
print(f"Sample FN TDE Object Types: {false_negative_types[:10]}")
print(f"Sample FP TDE SNIDs: {false_positive_snids[:10]}")
print(f"Sample FP TDE Object Types: {false_positive_types[:10]}")

In [None]:
# Display the list of feature columns selected by the training cell above.
feature_columns

In [None]:
# Data Preparation
# ---------------------------------------

# Prepare the data: binary TDE-vs-Other target
df0['Simple_Object_Type'] = df0['Object_Type'].apply(lambda x: 'TDE' if x == 'TDE' else 'Other')

# Select numeric columns for the model and exclude specified columns
excluded_columns = ['SNID', 'PeakMag', 'PeakMagErr', 'REDSHIFT_FINAL', 'REDSHIFT_FINAL_ERR']

# Exclude columns containing 'err', 'flux', or 'mjd'
excluded_columns += [col for col in df0.columns if 'err' in col.lower() or 'flux' in col.lower() or 'mjd' in col.lower()]

# Select feature columns
feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

# Human-readable feature names (used e.g. as SHAP plot labels)
rename_map = {
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time'
}

# Rename in place; re-running the cell is safe because rename_map.get leaves
# already-renamed columns unchanged.
df0.rename(columns=rename_map, inplace=True)
feature_columns = [rename_map.get(col, col) for col in feature_columns]

# Define features and target
X = df0[feature_columns]
y = df0['Simple_Object_Type'].apply(lambda x: 1 if x == 'TDE' else 0)  # Binary encoding

# ---------------------------------------
# Handling Missing Values
# ---------------------------------------

# Split the dataset first to prevent data leakage (stratified on the target)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=42, stratify=y
)

# Compute the mean of each feature from the training set
feature_means = X_train.mean()

# Fill missing values in both training and test sets with the training set means
X_train_filled = X_train.fillna(feature_means)
X_test_filled = X_test.fillna(feature_means)

# ---------------------------------------
# Feature Scaling
# ---------------------------------------

# Initialize the StandardScaler
scaler = StandardScaler()

# Fit the scaler on the training data and transform both training and test data
X_train_scaled = scaler.fit_transform(X_train_filled)
X_test_scaled = scaler.transform(X_test_filled)

# ---------------------------------------
# Model Training
# ---------------------------------------

# Initialize the XGBoost classifier with specified hyperparameters.
# `use_label_encoder` is not passed: it is deprecated since XGBoost 1.6 and
# removed in 2.0, where passing it only triggers an "unused parameter" warning.
xgb_classifier_original = xgb.XGBClassifier(
    n_estimators=1000,
    learning_rate=0.5,
    max_depth=3,
    random_state=42,
    eval_metric='logloss'     # Specify evaluation metric to avoid warnings
)

# Train the classifier on the original training data
xgb_classifier_original.fit(X_train_scaled, y_train)

# Make predictions on the test set; keep only P(TDE) from predict_proba
y_pred_xgb_original = xgb_classifier_original.predict(X_test_scaled)
y_probs_xgb_original = xgb_classifier_original.predict_proba(X_test_scaled)[:, 1]

# ---------------------------------------
# Extracting Test Set Information
# ---------------------------------------

# Get the corresponding SNIDs and Object_Types for the test set
test_snids = df0.iloc[y_test.index]['SNID'].apply(int).tolist()
test_object_types = df0.iloc[y_test.index]['Object_Type'].tolist()

# Combine SNID, Object_Type, and predictions into a DataFrame
results = pd.DataFrame({
    'SNID': test_snids,
    'Object_Type': test_object_types,
    'True_Label': y_test,
    'Predicted_Label': y_pred_xgb_original
})

# ---------------------------------------
# Identifying Prediction Outcomes
# ---------------------------------------

# Identify false positives, false negatives, true positives, and true negatives
false_positives = results[(results['True_Label'] == 0) & (results['Predicted_Label'] == 1)]
false_negatives = results[(results['True_Label'] == 1) & (results['Predicted_Label'] == 0)]
true_negatives = results[(results['True_Label'] == 0) & (results['Predicted_Label'] == 0)]
true_positives = results[(results['True_Label'] == 1) & (results['Predicted_Label'] == 1)]

# Extract SNIDs and Object_Types for each category
false_negative_snids = false_negatives['SNID'].tolist()
false_negative_types = false_negatives['Object_Type'].tolist()
false_positive_snids = false_positives['SNID'].tolist()
false_positive_types = false_positives['Object_Type'].tolist()
true_negative_snids = true_negatives['SNID'].tolist()
true_negative_types = true_negatives['Object_Type'].tolist()
true_positive_snids = true_positives['SNID'].tolist()
true_positive_types = true_positives['Object_Type'].tolist()

# ---------------------------------------
# Displaying Results
# ---------------------------------------

print(f"Total False Negative TDEs: {len(false_negative_snids)}")
print(f"Total False Positive TDEs: {len(false_positive_snids)}")
print(f"Total True Negative TDEs: {len(true_negative_snids)}")
print(f"Total True Positive TDEs: {len(true_positive_snids)}\n")

print(f"Sample FN TDE SNIDs: {false_negative_snids[:10]}")
print(f"Sample FN TDE Object Types: {false_negative_types[:10]}")
print(f"Sample FP TDE SNIDs: {false_positive_snids[:10]}")
print(f"Sample FP TDE Object Types: {false_positive_types[:10]}")


In [None]:
# Display the feature matrix X built in the training cell above.
X

In [None]:
# Prepare the SHAP explainer on exactly the inputs the model was fitted on.
# xgb_classifier_original was trained on mean-filled, standardised features
# (X_train_scaled / X_test_scaled). Explaining the raw X_test (unscaled and still
# containing NaN) would send samples down the trees against the wrong split
# values, so the attributions would not match the model's predictions.
# Feature values shown in the SHAP plots are therefore standardised values.
explainer_data = pd.DataFrame(X_test_scaled, columns=feature_columns, index=X_test.index)
explainer = shap.Explainer(xgb_classifier_original, explainer_data)
shap_values = explainer(explainer_data, check_additivity=False)  # Disable additivity check
# Row i of shap_values corresponds to row i of `results` (both follow X_test order)
snid_to_shap_index = {snid: i for i, snid in enumerate(results['SNID'])}

In [None]:
def plot_shap_for_snid(snid):
    """Draw a SHAP waterfall plot (top 5 features) for one test-set object.

    Relies on notebook globals: the FN/FP/TN/TP SNID lists,
    ``snid_to_shap_index``, ``shap_values``, ``feature_columns`` and ``results``.

    Parameters
    ----------
    snid : int
        SNID of a test-set object as stored in ``results['SNID']``.

    Lookup failures are printed rather than raised.
    """
    try:
        # Outcome category of this object under the original XGBoost model
        if snid in false_negative_snids:
            classification = "FN"
        elif snid in false_positive_snids:
            classification = "FP"
        elif snid in true_negative_snids:
            classification = "TN"
        elif snid in true_positive_snids:
            classification = "TP"
        else:
            print(f"SNID {snid} not found in FP/FN/TP/TN lists.")
            return

        # Row position of the SNID in the SHAP explanation / results table
        index = snid_to_shap_index[snid]

        shap_value = shap_values[index]
        print(f"Successfully accessed SHAP value for SNID {snid} at index {index}.")

        # Label the waterfall bars with the model's feature names
        shap_value.feature_names = list(feature_columns)

        shap.plots.waterfall(shap_value, max_display=5, show=False)
        plt.title(f"SNID: {snid}, Object Type: {results.iloc[index]['Object_Type']}, Classification: {classification}")
        plt.show()

    except (IndexError, KeyError) as e:
        # KeyError: SNID present in an outcome list but missing from snid_to_shap_index
        print(f"{type(e).__name__}: {e}")


In [None]:
from lime import lime_tabular
import matplotlib.pyplot as plt
import numpy as np

# 1. Verify Data Integrity
def verify_data(X_train, X_test):
    """Check that the train/test feature frames are finite and flag constant columns.

    Raises
    ------
    ValueError
        If either frame contains NaN or +/-inf; the training frame is checked first.

    Columns of ``X_train`` with at most one distinct non-null value are only
    printed, not removed.
    """
    for label, frame in (("Training", X_train), ("Test", X_test)):
        if np.any(np.isnan(frame)) or np.any(np.isinf(frame)):
            raise ValueError(f"{label} data contains NaN or infinite values.")

    distinct_counts = X_train.nunique()
    constant_columns = [name for name, count in distinct_counts.items() if count <= 1]
    if constant_columns:
        print(f"Features with zero or one unique value: {constant_columns}")
        # Handle accordingly

# 2. Initialize LIME Explainer
def initialize_lime(X_train, feature_columns, categorical_features=None, categorical_names=None):
    """Build a LIME tabular explainer for the binary TDE / non-TDE classifier.

    Parameters
    ----------
    X_train : pandas.DataFrame
        Training features; ``X_train.values`` is the explainer's reference data.
    feature_columns : list of str
        Feature names, in column order.
    categorical_features : list of int, optional
        Indices of categorical columns (default: none).
    categorical_names : dict, optional
        Category names per categorical column index (default: none).

    Returns
    -------
    lime_tabular.LimeTabularExplainer
        Classification-mode explainer with class names ``['Non-TDE', 'TDE']``
        and ``discretize_continuous=True``.
    """
    # None sentinels instead of `[]` / `{}` defaults, which would be a single
    # object shared by every call.
    if categorical_features is None:
        categorical_features = []
    if categorical_names is None:
        categorical_names = {}

    lime_explainer = lime_tabular.LimeTabularExplainer(
        training_data=X_train.values,
        feature_names=feature_columns,
        class_names=['Non-TDE', 'TDE'],
        categorical_features=categorical_features,
        categorical_names=categorical_names,
        mode='classification',
        discretize_continuous=True
    )
    return lime_explainer

# 3. Define SNID to Index Mapping
def create_snid_mapping(results):
    """Return a dict from each SNID in ``results['SNID']`` to its row position.

    If an SNID occurs more than once, its last position wins.
    """
    snid_column = results['SNID']
    return dict(zip(snid_column, range(len(snid_column))))

# 4. Define the LIME Plotting Function
def plot_lime_for_snid(snid, lime_explainer, snid_to_index, pipeline, aligned_results,
                      false_negative_snids, false_positive_snids, true_negative_snids, true_positive_snids,
                      X_explain=None):
    """Plot a LIME explanation (top 5 features) for one test-set object.

    Parameters
    ----------
    snid : int
        SNID to explain; must be a key of ``snid_to_index``.
    lime_explainer : lime_tabular.LimeTabularExplainer
        Explainer built on the same feature space as ``X_explain``.
    snid_to_index : dict
        Maps SNID -> row position in ``X_explain`` and ``aligned_results``.
    pipeline : fitted estimator
        Model being explained. Its ``predict_proba`` is LIME's ``predict_fn``;
        its ``predict`` chooses which label's explanation is plotted.
    aligned_results : pandas.DataFrame
        Row-aligned with ``X_explain``; ``Object_Type`` is shown in the title.
    false_negative_snids, false_positive_snids, true_negative_snids, true_positive_snids : list
        Used to label the object as FN/FP/TN/TP.
    X_explain : pandas.DataFrame, optional
        Features of the objects to explain. Defaults to the notebook global
        ``X_test_imputed``.

    Errors are printed, not raised.
    """
    try:
        # Outcome category of this object
        if snid in false_negative_snids:
            classification = "FN"
        elif snid in false_positive_snids:
            classification = "FP"
        elif snid in true_negative_snids:
            classification = "TN"
        elif snid in true_positive_snids:
            classification = "TP"
        else:
            print(f"SNID {snid} not found in FP/FN/TP/TN lists.")
            return

        index = snid_to_index[snid]

        if X_explain is None:
            # Fall back to the notebook global; if it is not defined yet the
            # NameError is reported by the generic handler below.
            X_explain = X_test_imputed
        instance = X_explain.iloc[index].values

        # Debug: show the row being explained
        print(f"Inspecting SNID {snid} at index {index}:")
        print(instance)

        explanation = lime_explainer.explain_instance(
            data_row=instance,
            predict_fn=pipeline.predict_proba,
            num_features=5,
            top_labels=1
        )

        # Plot the explanation for the class the model actually predicts
        predicted_class = pipeline.predict(instance.reshape(1, -1))[0]
        explanation.as_pyplot_figure(label=predicted_class)
        plt.title(f"SNID: {snid}, Object Type: {aligned_results.iloc[index]['Object_Type']}, Classification: {classification}")
        plt.tight_layout()
        plt.show()

    except KeyError as e:
        print(f"KeyError: {e} - SNID {snid} not found.")
    except ValueError as e:
        print(f"ValueError: {e}")
    except Exception as e:
        print(f"An error occurred: {e}")


# 5. Execute the Workflow with Debugging
def main():
    """Impute train/test features, build a LIME explainer and plot example explanations.

    Reads the notebook globals ``X_train`` and ``X_test`` (modified in place:
    +/-inf becomes NaN), ``feature_columns``, ``results``, the FN/FP/TN/TP SNID
    lists and ``pipeline_xgb_smote``. Publishes the imputed test set as the
    global ``X_test_imputed``.
    """
    # plot_lime_for_snid reads the rows to explain from the *global*
    # X_test_imputed. Without this declaration the imputed frame stayed local,
    # and the plot used a missing (NameError) or stale global instead.
    global X_test_imputed

    # Step 1: Handle Data Cleaning
    # Replace infinite values with NaN (in place, on the global frames)
    X_train.replace([np.inf, -np.inf], np.nan, inplace=True)
    X_test.replace([np.inf, -np.inf], np.nan, inplace=True)

    # Impute NaN values with training-set means; keep column names and row index.
    # NOTE(review): SimpleImputer drops columns that are entirely NaN in X_train
    # by default, which would break the column re-labelling here. Confirm no such
    # columns exist.
    imputer = SimpleImputer(strategy='mean')  # Ensure this matches your pipeline's strategy
    X_train_imputed = pd.DataFrame(imputer.fit_transform(X_train), columns=X_train.columns, index=X_train.index)
    X_test_imputed = pd.DataFrame(imputer.transform(X_test), columns=X_test.columns, index=X_test.index)

    # Step 2: Verify Data
    verify_data(X_train_imputed, X_test_imputed)

    # Step 3: Initialize LIME Explainer with imputed data
    lime_explainer = initialize_lime(
        X_train=X_train_imputed,
        feature_columns=feature_columns,
        categorical_features=[],  # Update if necessary
        categorical_names={}      # Update if necessary
    )

    # Step 4: Create SNID to Index Mapping
    snid_to_index = create_snid_mapping(results)

    # Step 5: SNIDs to explain. results['SNID'] holds ints, so these must be ints
    # too; the string '14291270' never matched and was reported as "not found".
    example_snids = [14291270]  # Replace with actual SNIDs

    # Step 6: Plot Explanations
    # NOTE(review): the FN/FP/TN/TP lists were produced by xgb_classifier_original,
    # not pipeline_xgb_smote, and GridSearchCV fits *clones* of the pipeline it is
    # given (the fitted copy is best_xgb_smote). Confirm which fitted model should
    # be explained here.
    for snid in example_snids:
        plot_lime_for_snid(
            snid=snid,
            lime_explainer=lime_explainer,
            snid_to_index=snid_to_index,
            pipeline=pipeline_xgb_smote,  # Replace with the appropriate pipeline
            aligned_results=results,  # Ensure this is correctly defined
            false_negative_snids=false_negative_snids,
            false_positive_snids=false_positive_snids,
            true_negative_snids=true_negative_snids,
            true_positive_snids=true_positive_snids
        )

# Run the main function
main()



In [None]:
from sklearn.impute import SimpleImputer

# Step 1: Replace infinite values with NaN (modifies the global X_train / X_test in place)
X_train.replace([np.inf, -np.inf], np.nan, inplace=True)
X_test.replace([np.inf, -np.inf], np.nan, inplace=True)

# Step 2: Impute NaN values with per-column means learned from the training split
# only, then rebuild DataFrames that keep the original column names and row index.
# NOTE(review): SimpleImputer drops columns that are entirely NaN in X_train by
# default, which would make the column count disagree with X_train.columns below.
# Confirm no such columns exist.
imputer = SimpleImputer(strategy='mean')  # Ensure this matches your pipeline's strategy

X_train_imputed = pd.DataFrame(imputer.fit_transform(X_train), columns=X_train.columns, index=X_train.index)
X_test_imputed = pd.DataFrame(imputer.transform(X_test), columns=X_test.columns, index=X_test.index)

# Step 3: Verify data (raises ValueError if any NaN/inf remains; prints constant columns)
verify_data(X_train_imputed, X_test_imputed)


In [None]:
# Transform the data using the combined pipeline's preprocessor.
# NOTE(review): combined_pipeline is not defined in this part of the notebook;
# it is assumed to be a fitted Pipeline with a 'preprocessor' step. Confirm.
X_train_processed = combined_pipeline.named_steps['preprocessor'].transform(X_train)
X_test_processed = combined_pipeline.named_steps['preprocessor'].transform(X_test)

# Convert back to DataFrame. Assumes the preprocessor outputs exactly one column
# per entry of feature_columns, in the same order (TODO confirm: no dropped or
# expanded columns).
X_train_processed = pd.DataFrame(X_train_processed, columns=feature_columns, index=X_train.index)
X_test_processed = pd.DataFrame(X_test_processed, columns=feature_columns, index=X_test.index)

# Verify the processed data (raises ValueError on any NaN/inf)
verify_data(X_train_processed, X_test_processed)


In [None]:
from lime import lime_tabular

# No categorical columns are declared for LIME; fill these in if the data gains any.
categorical_names = {}     # Update based on your data
categorical_features = []  # Update based on your data

# LIME explainer over the preprocessed training matrix: binary TDE / non-TDE
# classification with continuous features discretised.
lime_explainer = lime_tabular.LimeTabularExplainer(
    X_train_processed.values,
    mode='classification',
    feature_names=feature_columns,
    class_names=['Non-TDE', 'TDE'],
    categorical_features=categorical_features,
    categorical_names=categorical_names,
    discretize_continuous=True,
)


In [None]:
# Row position of every SNID within `results` (row order follows X_test);
# a repeated SNID keeps its last position.
snid_to_index = dict(zip(results['SNID'], range(len(results['SNID']))))


In [None]:
def plot_lime_for_snid(snid, lime_explainer, snid_to_index, pipeline, aligned_results,
                      false_negative_snids, false_positive_snids, true_negative_snids, true_positive_snids,
                      X_explain=None):
    """Plot a LIME explanation (top 5 features) for one test-set object.

    Parameters
    ----------
    snid : int
        SNID to explain; must be a key of ``snid_to_index``.
    lime_explainer : lime_tabular.LimeTabularExplainer
        Explainer to use.
    snid_to_index : dict
        Maps SNID -> row position in ``X_explain`` and ``aligned_results``.
    pipeline : fitted estimator
        Model being explained. Its ``predict_proba`` is LIME's ``predict_fn``
        and its ``predict`` chooses which label's explanation is plotted. (This
        argument used to be ignored in favour of the global ``combined_pipeline``.)
    aligned_results : pandas.DataFrame
        Row-aligned with ``X_explain``; ``Object_Type`` is shown in the title.
    false_negative_snids, false_positive_snids, true_negative_snids, true_positive_snids : list
        Used to label the object as FN/FP/TN/TP.
    X_explain : pandas.DataFrame, optional
        Raw features of the objects to explain. Defaults to the notebook global
        ``X_test``.

    Errors are printed, not raised.
    """
    try:
        # Outcome category of this object
        if snid in false_negative_snids:
            classification = "FN"
        elif snid in false_positive_snids:
            classification = "FP"
        elif snid in true_negative_snids:
            classification = "TN"
        elif snid in true_positive_snids:
            classification = "TP"
        else:
            print(f"SNID {snid} not found in FP/FN/TP/TN lists.")
            return

        index = snid_to_index[snid]

        # Row to explain (raw data)
        if X_explain is None:
            X_explain = X_test
        instance = X_explain.iloc[index].values

        explanation = lime_explainer.explain_instance(
            data_row=instance,
            predict_fn=pipeline.predict_proba,
            num_features=5,
            top_labels=1
        )

        # Plot the explanation for the class the model actually predicts
        predicted_class = pipeline.predict(instance.reshape(1, -1))[0]
        explanation.as_pyplot_figure(label=predicted_class)
        plt.title(f"SNID: {snid}, Object Type: {aligned_results.iloc[index]['Object_Type']}, Classification: {classification}")
        plt.tight_layout()
        plt.show()

    except KeyError as e:
        print(f"KeyError: {e} - SNID {snid} not found.")
    except ValueError as e:
        print(f"ValueError: {e}")
    except Exception as e:
        print(f"An error occurred: {e}")


In [None]:
def main():
    """Explain example test-set objects with LIME using ``combined_pipeline``.

    Reads the notebook globals ``combined_pipeline`` (assumed already fitted),
    ``X_train``, ``X_test``, ``feature_columns``, ``categorical_features``,
    ``categorical_names``, ``results`` and the FN/FP/TN/TP SNID lists.
    """
    # Step 1: Define and fit the combined pipeline (already done earlier)

    # Step 2: Transform the data using the combined pipeline's preprocessor
    X_train_processed = combined_pipeline.named_steps['preprocessor'].transform(X_train)
    X_test_processed = combined_pipeline.named_steps['preprocessor'].transform(X_test)

    # Convert back to DataFrame (assumes one output column per feature, same order)
    X_train_processed = pd.DataFrame(X_train_processed, columns=feature_columns, index=X_train.index)
    X_test_processed = pd.DataFrame(X_test_processed, columns=feature_columns, index=X_test.index)

    # Step 3: Verify the processed data
    verify_data(X_train_processed, X_test_processed)

    # Step 4: Initialize LIME Explainer with processed training data.
    # NOTE(review): the explainer's reference data is *preprocessed*, but
    # plot_lime_for_snid explains raw X_test rows through
    # combined_pipeline.predict_proba, which applies the preprocessor again.
    # Confirm the explainer and the explained rows live in the same feature space.
    lime_explainer = lime_tabular.LimeTabularExplainer(
        training_data=X_train_processed.values,
        feature_names=feature_columns,
        class_names=['Non-TDE', 'TDE'],
        categorical_features=categorical_features,
        categorical_names=categorical_names,
        mode='classification',
        discretize_continuous=True
    )

    # Step 5: Create SNID to Index Mapping
    snid_to_index = {snid: i for i, snid in enumerate(results['SNID'])}

    # Step 6: SNIDs to explain. results['SNID'] holds ints, so these must be ints
    # too; the string '14291270' never matched and was reported as "not found".
    example_snids = [14291270]  # Replace with actual SNIDs

    # Step 7: Plot Explanations
    for snid in example_snids:
        plot_lime_for_snid(
            snid=snid,
            lime_explainer=lime_explainer,
            snid_to_index=snid_to_index,
            pipeline=combined_pipeline,
            aligned_results=results,  # Ensure this is correctly defined
            false_negative_snids=false_negative_snids,
            false_positive_snids=false_positive_snids,
            true_negative_snids=true_negative_snids,
            true_positive_snids=true_positive_snids
        )


In [None]:
# Execute the main function. main() is defined in more than one cell, so this runs
# whichever definition was executed most recently in the session.
main()


## Understanding Classifiers

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.metrics import (
    confusion_matrix, classification_report, roc_auc_score, precision_recall_curve
)
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

import xgboost as xgb
from imblearn.over_sampling import SMOTE, ADASYN
from imblearn.pipeline import Pipeline as ImbPipeline  # To include SMOTE/ADASYN in the pipeline

import warnings
# Installs a global "ignore" filter: every warning is hidden for the rest of the
# session, not just in this cell.
warnings.filterwarnings('ignore')  # Suppress warnings for cleaner output

# ---------------------------------------
# Helper Functions
# ---------------------------------------

def find_threshold_for_precision(y_true, y_probs, desired_precision):
    """Return the lowest decision threshold whose precision reaches a target.

    Parameters
    ----------
    y_true : array-like
        Ground-truth binary labels.
    y_probs : array-like
        Predicted probabilities of the positive class.
    desired_precision : float
        Target precision as a fraction (e.g. 0.80 for 80%).

    Returns
    -------
    float or None
        The first threshold reported by ``precision_recall_curve`` (thresholds
        are in increasing order) whose precision is >= ``desired_precision``.
        ``None``, after printing a message, when no threshold reaches it.
    """
    curve_precision, _, curve_thresholds = precision_recall_curve(y_true, y_probs)
    # precision has one more entry than thresholds; its final point has no threshold
    meets_target = curve_precision[:-1] >= desired_precision

    if not meets_target.any():
        print(f"No threshold found to achieve {desired_precision*100}% precision.")
        return None

    # argmax on a boolean array gives the position of the first True
    return curve_thresholds[np.argmax(meets_target)]





def plot_confusion_matrix(cm, strategy_name, normalize=True):
    """
    Plot a confusion matrix with percentages and counts.

    Parameters:
    - cm: Confusion matrix.
    - strategy_name: Name of the strategy for the title.
    - normalize: Whether to normalize the confusion matrix per true label.
    """
    if normalize:
        cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    else:
        cm_normalized = cm

    plt.figure(figsize=(6, 4))
    sns.heatmap(
        cm_normalized, annot=False, fmt='.2f', cmap='Blues',
        xticklabels=['Non-TDE', 'TDE'],
        yticklabels=['Non-TDE', 'TDE'],
        cbar=False
    )

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if normalize:
                percentage = cm_normalized[i, j] * 100
                count = cm[i, j]
                text = f'{percentage:.1f}%\n({count})'
            else:
                text = f'{cm[i, j]}'
            plt.text(
                j + 0.5, i + 0.5, text,
                ha='center', va='center',
                color='black', fontsize=12,
                bbox=dict(facecolor='white', edgecolor='white')
            )

    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'Confusion Matrix - {strategy_name}')
    plt.show()

def plot_precision_recall_curve_with_thresholds(y_true, y_probs, strategy_name, desired_precisions=[0.80, 0.95]):
    """
    Plot Precision and Recall as functions of the threshold, marking specified thresholds.

    Parameters:
    - y_true: Ground truth binary labels.
    - y_probs: Predicted probabilities for the positive class.
    - strategy_name: Name of the strategy for identification.
    - desired_precisions: List of desired precision levels to mark.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_probs)
    # Exclude the last precision and recall values which have no corresponding threshold
    precision = precision[:-1]
    recall = recall[:-1]
    thresholds = thresholds

    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, precision, label='Precision', color='blue')
    plt.plot(thresholds, recall, label='Recall', color='green')

    # Find and plot thresholds for desired precisions
    for dp in desired_precisions:
        thresh = find_threshold_for_precision(y_true, y_probs, dp)
        if thresh is not None:
            label = f'{int(dp*100)}% Precision Threshold ({thresh:.2f})'
            color = 'red' if dp == 0.80 else 'purple'
            plt.axvline(x=thresh, linestyle='--', color=color, label=label)
        else:
            print(f"{strategy_name}: Desired precision of {int(dp*100)}% not achievable.")

    plt.xlabel('Threshold')
    plt.ylabel('Score')
    plt.title(f'Precision and Recall vs. Threshold - {strategy_name}')
    plt.legend()
    plt.grid(True)
    plt.show()





def train_and_evaluate(pipeline, param_grid, X_train, y_train, X_test, y_test, strategy_name):
    """Tune ``pipeline`` by grid search, then evaluate the best model on the test data.

    Hyperparameters are chosen by ``GridSearchCV`` with shuffled 5-fold
    stratified cross-validation (``random_state=42``), ROC AUC scoring and all
    CPU cores (``n_jobs=-1``). The best estimator predicts ``X_test``; a
    row-normalised confusion matrix is plotted, and the classification report
    and test ROC AUC are printed.

    Parameters
    ----------
    pipeline : estimator
        scikit-learn or imblearn pipeline to tune.
    param_grid : dict
        Hyperparameter grid for ``GridSearchCV``.
    X_train, y_train : array-like
        Training features and binary labels (1 = TDE).
    X_test, y_test : array-like
        Test features and binary labels.
    strategy_name : str
        Label used in printed output and plot titles.

    Returns
    -------
    best_estimator : estimator
        ``grid_search.best_estimator_``.
    metrics : dict
        Has these keys:

        - ``'Confusion Matrix'``
        - ``'Classification Report'`` (text)
        - ``'Classification Report Dict'`` (the same report as a nested dict
          keyed by ``'Non-TDE'``, ``'TDE'``, ``'accuracy'``, ...)
        - ``'ROC AUC'``
        - ``'Predicted Probabilities'`` (P(TDE) per test row)
        - ``'Predictions'`` (hard 0/1 labels)
    """
    class_names = ['Non-TDE', 'TDE']

    # Shuffled, stratified 5-fold CV keeps the rare TDE class in every fold
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    grid_search = GridSearchCV(
        estimator=pipeline,
        param_grid=param_grid,
        scoring='roc_auc',
        cv=skf,
        n_jobs=-1,
        verbose=1
    )
    grid_search.fit(X_train, y_train)

    print(f"Best parameters for {strategy_name}: {grid_search.best_params_}")
    print(f"Best ROC AUC for {strategy_name}: {grid_search.best_score_:.4f}\n")

    # Predictions with the best estimator
    best_estimator = grid_search.best_estimator_
    y_pred = best_estimator.predict(X_test)
    y_probs = best_estimator.predict_proba(X_test)[:, 1]

    cm = confusion_matrix(y_test, y_pred)
    report = classification_report(y_test, y_pred, target_names=class_names)
    # Structured copy of the same report, so per-class scores can be read
    # directly instead of parsing the text layout (whose first data row is
    # Non-TDE, not TDE).
    report_dict = classification_report(y_test, y_pred, target_names=class_names, output_dict=True)
    roc_auc = roc_auc_score(y_test, y_probs)

    plot_confusion_matrix(cm, strategy_name, normalize=True)

    print(f"Classification Report for {strategy_name}:\n{report}")
    print(f"ROC AUC for {strategy_name}: {roc_auc:.4f}\n")

    return best_estimator, {
        'Confusion Matrix': cm,
        'Classification Report': report,
        'Classification Report Dict': report_dict,
        'ROC AUC': roc_auc,
        'Predicted Probabilities': y_probs,
        'Predictions': y_pred
    }


In [None]:
# Prepare the data: binary target, 1 for TDE and 0 for everything else
df0['Simple_Object_Type'] = df0['Object_Type'].apply(lambda x: 1 if x == 'TDE' else 0)

# Columns never used as features. The target (and the label it is derived from)
# must be listed explicitly: Simple_Object_Type is numeric here, so the
# numeric-dtype selection below would otherwise feed the answer to every model.
excluded_columns = ['SNID', 'PeakMag', 'PeakMagErr', 'REDSHIFT_FINAL', 'REDSHIFT_FINAL_ERR',
                    'Simple_Object_Type', 'Object_Type']
# Also exclude columns whose name contains 'err', 'flux' or 'mjd' (case-insensitive)
excluded_columns += [col for col in df0.columns if 'err' in col.lower() or 'flux' in col.lower() or 'mjd' in col.lower()]

# Feature columns = numeric columns that are not excluded
feature_columns = [col for col in df0.select_dtypes(include=[np.number]).columns if col not in excluded_columns]

# Human-readable names for the colour / slope / timescale features
rename_map = {
    'Mean_Color_Pre_Peak_gr': 'Mean Color Pre Peak (g-r)',
    'Mean_Color_Post_Peak_gr': 'Mean Color Post Peak (g-r)',
    'Mean_Color_Pre_Peak_ri': 'Mean Color Pre Peak (r-i)',
    'Mean_Color_Post_Peak_ri': 'Mean Color Post Peak (r-i)',
    'Slope_Pre_Peak_gr': 'Slope Pre Peak (g-r)',
    'Slope_Post_Peak_gr': 'Slope Post Peak (g-r)',
    'Slope_Pre_Peak_ri': 'Slope Pre Peak (r-i)',
    'Slope_Post_Peak_ri': 'Slope Post Peak (r-i)',
    'Rise_Time': 'Rise time',
    'Fade_Time': 'Fade time'
}

# Rename in df0 itself and keep feature_columns in sync with the new names
df0.rename(columns=rename_map, inplace=True)
feature_columns = [rename_map.get(col, col) for col in feature_columns]

# Define features and target
X = df0[feature_columns]
y = df0['Simple_Object_Type']


In [None]:
# Stratified 50/50 split, so both halves keep the TDE / non-TDE ratio of y.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    stratify=y,
    random_state=42,
    test_size=0.5,
)

# Report the class balance of each split.
print("Training set class distribution:", y_train.value_counts(), sep="\n")
print("\nTest set class distribution:", y_test.value_counts(), sep="\n")


In [None]:
# Strategy 1: XGBoost with Mean Imputation and SMOTE
# The imblearn Pipeline lets SMOTE sit between preprocessing and the classifier;
# resampling happens only while fitting, never when predicting.
pipeline_xgb_smote = ImbPipeline(
    steps=[
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler()),
        ('smote', SMOTE(random_state=42)),
        ('classifier', xgb.XGBClassifier(random_state=42, eval_metric='logloss', use_label_encoder=False)),
    ]
)

# XGBoost search space: 3 values per hyperparameter -> 243 combinations.
param_grid_xgb = dict(
    classifier__n_estimators=[100, 500, 1000],
    classifier__learning_rate=[0.01, 0.1, 0.3],
    classifier__max_depth=[3, 5, 7],
    classifier__subsample=[0.6, 0.8, 1.0],
    classifier__colsample_bytree=[0.6, 0.8, 1.0],
)


In [None]:
# Strategy 2: Random Forest with Median Imputation and Class Weights
# No resampling: class_weight='balanced' reweights classes by inverse frequency,
# so a plain sklearn Pipeline is enough.
pipeline_rf_class_weights = Pipeline(
    steps=[
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
        ('classifier', RandomForestClassifier(random_state=42, class_weight='balanced')),
    ]
)

# Random Forest search space: 162 combinations.
param_grid_rf = dict(
    classifier__n_estimators=[100, 300, 500],
    classifier__max_depth=[None, 10, 20],
    classifier__min_samples_split=[2, 5, 10],
    classifier__min_samples_leaf=[1, 2, 4],
    classifier__bootstrap=[True, False],
)


In [None]:
# Strategy 3: Logistic Regression with KNN Imputation and ADASYN
# KNN imputation -> scaling -> ADASYN oversampling (fit time only) -> L2 logistic regression.
pipeline_lr_adasyn = ImbPipeline(
    steps=[
        ('imputer', KNNImputer(n_neighbors=5)),
        ('scaler', StandardScaler()),
        ('adasyn', ADASYN(random_state=42)),
        ('classifier', LogisticRegression(random_state=42, max_iter=1000, solver='lbfgs', penalty='l2')),
    ]
)

# Logistic Regression search space: regularisation strength C x solver (8 combinations).
param_grid_lr = dict(
    classifier__C=[0.01, 0.1, 1, 10],
    classifier__solver=['lbfgs', 'saga'],
    classifier__penalty=['l2'],
)


In [None]:
# Strategy 4: XGBoost without SMOTE
# Same preprocessing as Strategy 1 but no resampling, so a plain sklearn Pipeline is used.
pipeline_xgb_no_smote = Pipeline(
    steps=[
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler()),
        ('classifier', xgb.XGBClassifier(random_state=42, eval_metric='logloss', use_label_encoder=False)),
    ]
)

# Search space identical to Strategy 1 (243 combinations), kept separate so it can diverge.
param_grid_xgb_no_smote = dict(
    classifier__n_estimators=[100, 500, 1000],
    classifier__learning_rate=[0.01, 0.1, 0.3],
    classifier__max_depth=[3, 5, 7],
    classifier__subsample=[0.6, 0.8, 1.0],
    classifier__colsample_bytree=[0.6, 0.8, 1.0],
)


In [None]:
# Strategy 5: Neural Network Classifier with ADASYN
# KNN imputation -> scaling -> ADASYN oversampling (fit time only) -> MLP.
pipeline_nn_adasyn = ImbPipeline(
    steps=[
        ('imputer', KNNImputer(n_neighbors=5)),
        ('scaler', StandardScaler()),
        ('adasyn', ADASYN(random_state=42)),
        ('classifier', MLPClassifier(
            random_state=42,
            max_iter=500,
            solver='adam',
            activation='relu',
            hidden_layer_sizes=(100,),
        )),
    ]
)

# MLP search space: architecture, activation, optimiser, L2 penalty and learning-rate schedule (72 combinations).
param_grid_nn = dict(
    classifier__hidden_layer_sizes=[(50,), (100,), (100, 50)],
    classifier__activation=['relu', 'tanh'],
    classifier__solver=['adam', 'sgd'],
    classifier__alpha=[0.0001, 0.001, 0.01],
    classifier__learning_rate=['constant', 'adaptive'],
)


In [None]:
# Strategy 1: XGBoost with Mean Imputation and SMOTE
print("=== Strategy 1: XGBoost with Mean Imputation and SMOTE ===\n")
# Tune on the training split, then score the best model on the test split.
best_xgb_smote, metrics_xgb_smote = train_and_evaluate(
    pipeline_xgb_smote, param_grid_xgb,
    X_train, y_train,
    X_test, y_test,
    'XGBoost with SMOTE',
)


In [None]:
# Strategy 2: Random Forest with Median Imputation and Class Weights
print("=== Strategy 2: Random Forest with Median Imputation and Class Weights ===\n")
# Tune on the training split, then score the best model on the test split.
best_rf_class_weights, metrics_rf_class_weights = train_and_evaluate(
    pipeline_rf_class_weights, param_grid_rf,
    X_train, y_train,
    X_test, y_test,
    'Random Forest with Class Weights',
)


In [None]:
# Strategy 3: Logistic Regression with KNN Imputation and ADASYN
print("=== Strategy 3: Logistic Regression with KNN Imputation and ADASYN ===\n")
# Tune on the training split, then score the best model on the test split.
best_lr_adasyn, metrics_lr_adasyn = train_and_evaluate(
    pipeline_lr_adasyn, param_grid_lr,
    X_train, y_train,
    X_test, y_test,
    'Logistic Regression with ADASYN',
)


In [None]:
# Strategy 4: XGBoost without SMOTE
print("=== Strategy 4: XGBoost without SMOTE ===\n")
# Tune on the training split, then score the best model on the test split.
best_xgb_no_smote, metrics_xgb_no_smote = train_and_evaluate(
    pipeline_xgb_no_smote, param_grid_xgb_no_smote,
    X_train, y_train,
    X_test, y_test,
    'XGBoost without SMOTE',
)



In [None]:
# Strategy 5: Neural Network Classifier with ADASYN
print("=== Strategy 5: Neural Network Classifier with ADASYN ===\n")
# Tune on the training split, then score the best model on the test split.
best_nn_adasyn, metrics_nn_adasyn = train_and_evaluate(
    pipeline_nn_adasyn, param_grid_nn,
    X_train, y_train,
    X_test, y_test,
    'Neural Network with ADASYN',
)

In [None]:
# Precision and recall vs. threshold for each compared strategy, with the default
# 80% / 95% precision thresholds marked. The Logistic Regression and Neural Network
# (ADASYN) strategies are currently disabled; uncomment their entries to include them.
for strategy_metrics, strategy_label in (
    (metrics_xgb_smote, 'XGBoost with SMOTE'),
    (metrics_rf_class_weights, 'Random Forest with Class Weights'),
    # (metrics_lr_adasyn, 'Logistic Regression with ADASYN'),
    (metrics_xgb_no_smote, 'XGBoost without SMOTE'),
    # (metrics_nn_adasyn, 'Neural Network with ADASYN'),
):
    plot_precision_recall_curve_with_thresholds(
        y_test,
        strategy_metrics['Predicted Probabilities'],
        strategy_label,
    )


In [None]:
# Build a side-by-side performance summary (ROC AUC plus per-class precision /
# recall / F1 parsed from each strategy's text classification report).


def _tde_scores(report, row=2):
    """Return ``(precision, recall, f1)`` parsed from one row of a text report.

    The row is split on whitespace and read from the right
    (``<class name> precision recall f1-score support``), so class names that
    contain spaces still parse correctly.

    NOTE(review): in scikit-learn's text ``classification_report``, line 0 is
    the header, line 1 is blank and line 2 is the FIRST class row (the lowest
    label). If TDE is encoded as the positive label 1, its row is line 3, not
    2 — confirm against ``train_and_evaluate`` before trusting the "(TDE)"
    columns, and adjust ``row`` if needed.
    """
    fields = report.split('\n')[row].split()
    precision, recall, f1 = (float(value) for value in fields[-4:-1])
    return precision, recall, f1


def _summary_row(strategy, metrics):
    """Return one summary-table row for a strategy's metrics dict."""
    precision, recall, f1 = _tde_scores(metrics['Classification Report'])
    return {
        'Strategy': strategy,
        'ROC AUC': metrics['ROC AUC'],
        'Precision (TDE)': precision,
        'Recall (TDE)': recall,
        'F1-Score (TDE)': f1,
    }


# The Logistic Regression / Neural Network (ADASYN) strategies are disabled;
# uncomment their entries to add them to the table.
summary_extended = pd.DataFrame([
    _summary_row(name, metrics)
    for name, metrics in [
        ('XGBoost with SMOTE', metrics_xgb_smote),
        ('Random Forest with Class Weights', metrics_rf_class_weights),
        # ('Logistic Regression with ADASYN', metrics_lr_adasyn),
        ('XGBoost without SMOTE', metrics_xgb_no_smote),
        # ('Neural Network with ADASYN', metrics_nn_adasyn),
    ]
])

print("=== Extended Performance Summary ===")
print(summary_extended)


In [None]:
# Example: report the decision threshold that achieves each target precision
# for the Random Forest (class weights) model on the test set.
# find_threshold_for_precision returns None when the target is not reachable.
desired_precisions = [0.80, 0.95]
strategy_name = 'Random Forest with Class Weights'

for dp in desired_precisions:
    thresh = find_threshold_for_precision(y_test, metrics_rf_class_weights['Predicted Probabilities'], dp)
    # Format with the '%' spec (rounds) instead of int(dp*100) (truncates):
    # int(0.29 * 100) == 28 because 0.29 * 100 == 28.999999999999996.
    pct = f"{dp:.0%}"
    if thresh is None:
        print(f"{strategy_name}: {pct} precision not achievable.")
    else:
        print(f"{strategy_name}: Threshold for {pct} precision = {thresh:.4f}")


In [None]:
# Plot the top-20 feature importances ('weight' = number of splits using the
# feature) of the XGBoost-without-SMOTE classifier, labelled with real names.
#
# xgb.plot_importance sorts the features by importance and labels each bar
# itself. Overwriting the y tick labels with `feature_columns` afterwards
# therefore attaches names to the wrong bars, and it errors in recent matplotlib
# when the label count differs from the 20 ticks. Instead, the importance keys
# are renamed before plotting.


def _rename_xgb_features(scores, names):
    """Map XGBoost's positional keys ('f0', 'f1', ...) in *scores* to *names*.

    Keys that are not of the form ``f<index>``, or whose index is out of range
    for *names*, are kept unchanged.
    """
    renamed = {}
    for key, value in scores.items():
        match = re.fullmatch(r'f(\d+)', key)
        if match is not None and int(match.group(1)) < len(names):
            key = names[int(match.group(1))]
        renamed[key] = value
    return renamed


feature_names = feature_columns  # Ensure feature names are correctly mapped
_xgb_booster = best_xgb_no_smote.named_steps['classifier'].get_booster()
_xgb_importance = _xgb_booster.get_score(importance_type='weight')
if _xgb_booster.feature_names is None:
    # Model was fit without column names, so keys are 'f<column index>'.
    # NOTE(review): assumes the pipeline steps before 'classifier' keep the
    # count and order of feature_columns (e.g. an imputer that drops
    # all-missing columns would shift indices) — confirm.
    _xgb_importance = _rename_xgb_features(_xgb_importance, list(feature_names))

fig, ax = plt.subplots(figsize=(12, 8))
xgb.plot_importance(
    _xgb_importance,  # a dict is plotted as-is; labels come from its keys
    max_num_features=20,
    ax=ax,
    show_values=False
)
ax.tick_params(axis='y', labelsize=10)
ax.set_xlabel('Feature Importance (Weight)', fontsize=12)
ax.set_title('XGBoost Feature Importance - Without SMOTE Classifier', fontsize=14)
plt.show()
