# SNCOSMO Feature Extraction for TDE Classification

**Purpose:** Fit SN templates (SALT2, SALT3, Nugent) to all objects and save features.

**Key insight:** TDEs should fit SN templates poorly (high reduced χ²), while actual SNe should fit well (low reduced χ², ideally near 1).

**Outputs:**
- `train_sncosmo_features.parquet` - Fit features for training set
- `test_sncosmo_features.parquet` - Fit features for test set

**Runtime:** ~2-4 hours on Kaggle with 8 parallel CPU workers (the logged run below took about 2.3 hours: 37.9 min for train, 101.6 min for test)
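
As a concrete reading of the statistic behind this idea, the sketch below computes χ² and reduced χ² (χ² divided by the degrees of freedom) for toy data. The helper `reduced_chisq` and all numbers are illustrative assumptions; in the notebook itself these values come from the `sncosmo.fit_lc` result.

```python
def reduced_chisq(flux, model_flux, flux_err, n_free_params):
    """Return (chisq, ndof, rchisq) for a set of observations.

    Illustrative helper (not part of the notebook): chisq is the sum of
    squared, error-normalised residuals; ndof = n_points - n_free_params.
    """
    chisq = sum(((f - m) / e) ** 2 for f, m, e in zip(flux, model_flux, flux_err))
    ndof = len(flux) - n_free_params
    rchisq = chisq / ndof if ndof > 0 else float("nan")
    return chisq, ndof, rchisq


# Toy data: every point is off by exactly one sigma.
flux = [11.0, 19.0, 31.0, 39.0, 51.0, 59.0]
model = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
err = [1.0] * 6

chisq, ndof, rchisq = reduced_chisq(flux, model, err, n_free_params=2)
print(chisq, ndof, rchisq)
```

A reduced χ² near 1 means the residuals are about as large as the quoted errors; values far above 1 indicate that the model does not describe the data.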

In [1]:
import os
import json
import warnings
warnings.filterwarnings('ignore')
# iminuit provides the default minimizer used by sncosmo.fit_lc
!pip install iminuit

import numpy as np
import pandas as pd
from pathlib import Path
from astropy.table import Table
from joblib import Parallel, delayed
from tqdm.auto import tqdm
import time

# Install sncosmo if needed (use the running interpreter's pip)
try:
    import sncosmo
except ImportError:
    import subprocess
    import sys
    subprocess.run([sys.executable, "-m", "pip", "-q", "install", "sncosmo"], check=True)
    import sncosmo

# Install extinction if needed
try:
    from extinction import fitzpatrick99
except ImportError:
    import subprocess
    import sys
    subprocess.run([sys.executable, "-m", "pip", "-q", "install", "extinction==0.4.7"], check=True)
    from extinction import fitzpatrick99

print(f"sncosmo version: {sncosmo.__version__}")

Collecting iminuit
  Downloading iminuit-2.32.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)
Downloading iminuit-2.32.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (448 kB)
Installing collected packages: iminuit
Successfully installed iminuit-2.32.0
sncosmo version: 2.12.1


## Configuration

In [2]:
# Paths - adjust for Kaggle
DATASET_DIR = Path("/kaggle/input/mallorn-astronomical-classification-challenge")
OUTPUT_DIR = Path("/kaggle/working")

# Fitting parameters
N_JOBS = 8  # Parallel workers
TMIN_REL = -50.0  # Days before peak to include
TMAX_REL = 400.0  # Days after peak to include
MIN_POINTS = 10  # Minimum points for fitting

# LSST band mapping for sncosmo
LSST_BAND_MAP = {"u": "lsstu", "g": "lsstg", "r": "lsstr", 
                 "i": "lssti", "z": "lsstz", "y": "lssty"}

# Effective wavelengths for extinction correction
EFF_WL = {
    "u": 3641.0, "g": 4704.0, "r": 6155.0,
    "i": 7504.0, "z": 8695.0, "y": 10056.0,
}

## Test Cell: Verify sncosmo Works

Run this first to make sure sncosmo is properly installed and models load correctly.

In [3]:
# Quick test that sncosmo models load correctly
print("Testing sncosmo model loading...")

test_models = [
    ("salt2", "salt2"),
    ("salt3", "salt3"),
    ("nugent-sn1a", "template"),
]

n_failed = 0
for source, kind in test_models:
    try:
        model = sncosmo.Model(source=source)
        print(f"  ✓ {source}: params = {model.param_names}")
    except Exception as e:
        n_failed += 1
        print(f"  ✗ {source}: FAILED - {e}")

# Test LSST bandpasses
print("\nTesting LSST bandpasses...")
for band in ["lsstu", "lsstg", "lsstr", "lssti", "lsstz", "lssty"]:
    try:
        bp = sncosmo.get_bandpass(band)
        print(f"  ✓ {band}: λ_eff = {bp.wave_eff:.0f} Å")
    except Exception as e:
        n_failed += 1
        print(f"  ✗ {band}: FAILED - {e}")

# Only report success if nothing failed
if n_failed == 0:
    print("\n✓ All tests passed! Ready to proceed.")
else:
    print(f"\n✗ {n_failed} test(s) failed. Fix these before proceeding.")

Testing sncosmo model loading...
Downloading https://sncosmo.github.io/data/models/salt2/salt2-k21-frag.tar.gz [Done]
  ✓ salt2: params = ['z', 't0', 'x0', 'x1', 'c']
Downloading https://sncosmo.github.io/data/models/salt3/salt3-f22.tar.gz [Done]
  ✓ salt3: params = ['z', 't0', 'x0', 'x1', 'c']
Downloading http://c3.lbl.gov/nugent/templates/sn1a_flux.v1.2.dat.gz [Done]
  ✓ nugent-sn1a: params = ['z', 't0', 'amplitude']

Testing LSST bandpasses...
Downloading https://sncosmo.github.io/data/bandpasses/lsst/total_u.dat [Done]
  ✓ lsstu: λ_eff = 3671 Å
Downloading https://sncosmo.github.io/data/bandpasses/lsst/total_g.dat [Done]
  ✓ lsstg: λ_eff = 4827 Å
Downloading https://sncosmo.github.io/data/bandpasses/lsst/total_r.dat [Done]
  ✓ lsstr: λ_eff = 6223 Å
Downloading https://sncosmo.github.io/data/bandpasses/lsst/total_i.dat [Done]
  ✓ lssti: λ_eff = 7546 Å
Downloading https://sncosmo.github.io/data/bandpasses/lsst/total_z.dat [Done]
  ✓ lsstz: λ_eff = 8691 Å
Downloading https://sncosmo.g

## Load Data

In [4]:
print("Loading data...")

train_log = pd.read_csv(DATASET_DIR / "train_log.csv")
test_log = pd.read_csv(DATASET_DIR / "test_log.csv")

print(f"Train objects: {len(train_log):,}")
print(f"Test objects: {len(test_log):,}")

# Load all lightcurve splits
split_dirs = sorted([p for p in DATASET_DIR.glob("split_*") if p.is_dir()])
print(f"Found {len(split_dirs)} data splits")

train_lc = pd.concat([pd.read_csv(d / "train_full_lightcurves.csv") for d in split_dirs], ignore_index=True)
test_lc = pd.concat([pd.read_csv(d / "test_full_lightcurves.csv") for d in split_dirs], ignore_index=True)

print(f"Train observations: {len(train_lc):,}")
print(f"Test observations: {len(test_lc):,}")

Loading data...
Train objects: 3,043
Test objects: 7,135
Found 20 data splits
Train observations: 479,384
Test observations: 1,145,125


## Preprocess Lightcurves

In [5]:
FILTERS = ["u", "g", "r", "i", "z", "y"]
f2i = {f: i for i, f in enumerate(FILTERS)}

def prep_lightcurves(df):
    df = df.copy()
    df["Filter"] = df["Filter"].astype(str).str.strip().str.lower()
    df["filter_id"] = df["Filter"].map(f2i).astype("int64")
    df["Time (MJD)"] = df["Time (MJD)"].astype("float64")
    df["Flux"] = df["Flux"].astype("float64")
    df["Flux_err"] = df["Flux_err"].astype("float64")
    return df

train_lc = prep_lightcurves(train_lc)
test_lc = prep_lightcurves(test_lc)

# Merge metadata
train_lc = train_lc.merge(train_log[["object_id", "EBV", "Z"]], on="object_id", how="left")
test_lc = test_lc.merge(test_log[["object_id", "EBV", "Z"]], on="object_id", how="left")

print("Preprocessed lightcurves.")

Preprocessed lightcurves.


In [6]:
def apply_deextinction(df):
    """Apply Milky Way extinction correction."""
    df = df.copy()
    df["Flux_corr"] = df["Flux"].copy()
    df["Fluxerr_corr"] = df["Flux_err"].copy()
    
    for filt in df["Filter"].unique():
        if filt not in EFF_WL:
            continue
        filt_mask = df["Filter"] == filt
        wl_arr = np.array([EFF_WL[filt]])
        
        for ebv_val in df.loc[filt_mask, "EBV"].unique():
            if pd.isna(ebv_val) or ebv_val == 0:
                continue
            mask = filt_mask & (df["EBV"] == ebv_val)
            A_lambda = fitzpatrick99(wl_arr, ebv_val * 3.1)[0]
            factor = 10 ** (A_lambda / 2.5)
            df.loc[mask, "Flux_corr"] = df.loc[mask, "Flux"] * factor
            df.loc[mask, "Fluxerr_corr"] = df.loc[mask, "Flux_err"] * factor
    
    return df

print("Applying extinction correction...")
train_lc = apply_deextinction(train_lc)
test_lc = apply_deextinction(test_lc)
print("Done.")

# Drop bad rows
def drop_bad_rows(lc, name=""):
    before = len(lc)
    keep = (
        np.isfinite(lc["Time (MJD)"].values) &
        np.isfinite(lc["Flux_corr"].values) &
        np.isfinite(lc["Fluxerr_corr"].values)
    )
    lc = lc[keep].copy()
    print(f"{name}: Dropped {before - len(lc):,} bad rows")
    return lc

train_lc = drop_bad_rows(train_lc, "Train")
test_lc = drop_bad_rows(test_lc, "Test")

Applying extinction correction...
Done.
Train: Dropped 891 bad rows
Test: Dropped 2,022 bad rows
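
The correction above multiplies each flux by 10^(A_λ/2.5), where A_λ is the Milky Way extinction in magnitudes at the band's effective wavelength. A minimal sketch of that arithmetic follows, using a made-up `a_lambda` value rather than a call to `fitzpatrick99`; the helper name `deredden` is also an assumption.

```python
def deredden(flux, flux_err, a_lambda):
    """Scale flux and its error by the de-extinction factor 10**(A_lambda / 2.5).

    a_lambda is the extinction in magnitudes; the value used below is a
    hypothetical example, not one computed from the dataset.
    """
    factor = 10.0 ** (a_lambda / 2.5)
    return flux * factor, flux_err * factor


flux_corr, err_corr = deredden(flux=100.0, flux_err=5.0, a_lambda=0.25)
print(f"factor = {flux_corr / 100.0:.4f}")
print(f"S/N before = {100.0 / 5.0:.1f}, after = {flux_corr / err_corr:.1f}")
```

Because flux and error are scaled by the same factor, the signal-to-noise ratio of each point is unchanged; only the relative brightness between bands shifts.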


## SNCOSMO Fitting Functions

In [7]:
def lc_to_sncosmo_table(
    lc: pd.DataFrame,
    time_col="Time (MJD)",
    flux_col="Flux_corr",
    err_col="Fluxerr_corr",
    band_col="Filter",
    zp=23.9,
    zpsys="ab",
):
    """Convert lightcurve DataFrame to sncosmo Table format."""
    df = lc.copy()
    df["time"] = df[time_col].astype(float)
    df["flux"] = df[flux_col].astype(float)
    df["fluxerr"] = df[err_col].astype(float)

    filt = df[band_col].astype(str).str.lower().str.strip()
    df["band"] = filt.map(LSST_BAND_MAP)

    df = df.dropna(subset=["time", "flux", "fluxerr", "band"])
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=["time", "flux", "fluxerr"])
    df = df[df["fluxerr"] > 0]

    df["zp"] = float(zp)
    df["zpsys"] = str(zpsys)

    return Table.from_pandas(df[["time", "band", "flux", "fluxerr", "zp", "zpsys"]])
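
`lc_to_sncosmo_table` tags every row with `zp=23.9` and `zpsys="ab"`. That zero point corresponds to fluxes expressed in microjanskys; whether the challenge data uses that unit is an assumption carried by the notebook's defaults. A short sketch of the flux-to-magnitude relation this implies (the helper `flux_to_abmag` is hypothetical):

```python
import math

ZP_AB_MICROJY = 23.9  # AB zero point for fluxes in microjansky


def flux_to_abmag(flux_ujy, zp=ZP_AB_MICROJY):
    """Convert a positive flux (assumed to be in microjansky) to an AB magnitude."""
    if flux_ujy <= 0:
        raise ValueError("magnitude is undefined for non-positive flux")
    return -2.5 * math.log10(flux_ujy) + zp


print(flux_to_abmag(1.0))    # 1 microjansky
print(flux_to_abmag(100.0))  # 100x brighter, so 5 magnitudes smaller
```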

In [8]:
def robust_peak_time(
    lc: pd.DataFrame,
    time_col="Time (MJD)",
    flux_col="Flux_corr",
    err_col="Fluxerr_corr",
    snr_min=3.0,
    baseline_snr_max=2.0,
    peak_k_sigma=2.0,
    max_err_quantile=0.90,
):
    """
    Robust peak finding:
    - Ignore worst error points
    - Define baseline from |SNR| <= baseline_snr_max
    - Require peak >= baseline + k*sigma_base and SNR >= snr_min
    """
    df = lc[[time_col, flux_col, err_col]].dropna().copy()
    df = df[df[err_col] > 0]
    if len(df) < 5:
        raise ValueError("Too few valid points.")

    # Drop huge-error points for peak detection
    err_thr = df[err_col].quantile(max_err_quantile)
    dfp = df[df[err_col] <= err_thr].copy()
    if len(dfp) < 5:
        dfp = df.copy()

    dfp["snr"] = dfp[flux_col] / dfp[err_col]

    # Baseline from low-SNR points
    base = dfp[dfp["snr"].abs() <= baseline_snr_max]
    if len(base) >= 5:
        baseline = float(np.median(base[flux_col]))
        sigma_base = float(1.4826 * np.median(np.abs(base[flux_col] - baseline)))
        if not np.isfinite(sigma_base) or sigma_base <= 0:
            sigma_base = float(np.median(base[err_col]))
    else:
        baseline = float(np.median(dfp[flux_col]))
        sigma_base = float(np.median(dfp[err_col]))

    cand = dfp[(dfp["snr"] >= snr_min) & (dfp[flux_col] >= baseline + peak_k_sigma * sigma_base)]
    if len(cand) > 0:
        cand = cand.sort_values([flux_col, "snr"], ascending=False)
        t0 = float(cand.iloc[0][time_col])
        info = {"method": "snr+baseline", "baseline": baseline, "sigma_base": sigma_base, "n_cand": int(len(cand))}
        return t0, info

    # Fallback: max SNR
    i = dfp["snr"].idxmax()
    t0 = float(dfp.loc[i, time_col])
    info = {"method": "max_snr_fallback", "baseline": baseline, "sigma_base": sigma_base, "n_cand": 0}
    return t0, info


def slice_window(lc, t_peak, tmin_rel=-50.0, tmax_rel=200.0, time_col="Time (MJD)"):
    """Extract lightcurve window around peak."""
    m = (lc[time_col] >= t_peak + tmin_rel) & (lc[time_col] <= t_peak + tmax_rel)
    return lc.loc[m].copy()


def compute_aic_bic(chisq, n, k):
    """Compute AIC and BIC."""
    if n <= 0:
        return np.nan, np.nan
    return float(chisq + 2 * k), float(chisq + k * np.log(n))
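
`compute_aic_bic` uses the χ²-based forms AIC = χ² + 2k and BIC = χ² + k ln(n), which makes fits with different numbers of free parameters comparable. The sketch below applies the same formulas to two made-up fits, one with four free parameters (as for SALT2/SALT3 in this notebook) and one with two (as for the Nugent templates). The χ² values and the point count are invented for illustration.

```python
import math


def aic_bic(chisq, n, k):
    """AIC and BIC in their chi-square form (same formulas as compute_aic_bic)."""
    if n <= 0:
        return float("nan"), float("nan")
    return chisq + 2 * k, chisq + k * math.log(n)


n_points = 40  # hypothetical number of photometric points

# Hypothetical fits: the 4-parameter model reaches a slightly lower chisq.
aic_salt, bic_salt = aic_bic(chisq=50.0, n=n_points, k=4)
aic_tmpl, bic_tmpl = aic_bic(chisq=52.0, n=n_points, k=2)

print(f"4-param: AIC={aic_salt:.1f}, BIC={bic_salt:.1f}")
print(f"2-param: AIC={aic_tmpl:.1f}, BIC={bic_tmpl:.1f}")
print("preferred by AIC:", "4-param" if aic_salt < aic_tmpl else "2-param")
```

In this toy case the small χ² gain of the 4-parameter model does not offset its extra parameters, so both criteria favor the simpler fit.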

In [9]:
def fit_one(tbl: Table, z: float, source: str, kind: str):
    """
    Fit a single sncosmo model to the lightcurve table.
    
    Returns: chisq, ndof, rchisq, aic, bic, param_dict
    """
    model = sncosmo.Model(source=source)
    model.set(z=float(z))

    # Init t0 near median time
    model.set(t0=float(np.median(tbl["time"])))

    if kind == "salt2":
        # SALT2/SALT3 parametric models
        model.set(x1=0.0, c=0.0)
        params = ["t0", "x0", "x1", "c"]
        bounds = {"x1": (-5, 5), "c": (-0.3, 0.6)}
    else:
        # Template time series: fit t0 + amplitude
        pnames = list(model.param_names)
        amp = None
        for cand in ["amplitude", "x0", "norm"]:
            if cand in pnames:
                amp = cand
                break
        if amp is None:
            raise ValueError(f"No amplitude-like parameter found; params={pnames}")
        params = ["t0", amp]
        bounds = None

    res, mfit = sncosmo.fit_lc(tbl, model, params, bounds=bounds)

    chisq = float(res.chisq)
    ndof = int(res.ndof)
    rchisq = float(chisq / ndof) if ndof > 0 else np.nan
    aic, bic = compute_aic_bic(chisq, n=len(tbl), k=len(params))
    param_dict = dict(zip(mfit.param_names, map(float, mfit.parameters)))

    return chisq, ndof, rchisq, aic, bic, param_dict

In [10]:
# Model bank: SN templates to fit
MODEL_BANK = [
    # Ia-like (parametric) - most important for discrimination
    {"name": "salt2", "source": "salt2", "kind": "salt2"},
    {"name": "salt3", "source": "salt3", "kind": "salt2"},

    # Ia subtype templates (time-series)
    {"name": "nugent-sn1a", "source": "nugent-sn1a", "kind": "template"},
    {"name": "nugent-sn91t", "source": "nugent-sn91t", "kind": "template"},
    {"name": "nugent-sn91bg", "source": "nugent-sn91bg", "kind": "template"},

    # Core collapse
    {"name": "nugent-sn1bc", "source": "nugent-sn1bc", "kind": "template"},  # Ib/c
    {"name": "nugent-sn2p", "source": "nugent-sn2p", "kind": "template"},    # II-P
    {"name": "nugent-sn2l", "source": "nugent-sn2l", "kind": "template"},    # II-L
    {"name": "nugent-sn2n", "source": "nugent-sn2n", "kind": "template"},    # IIn
]

print(f"Model bank: {len(MODEL_BANK)} models")
for m in MODEL_BANK:
    print(f"  - {m['name']} ({m['kind']})")

Model bank: 9 models
  - salt2 (salt2)
  - salt3 (salt2)
  - nugent-sn1a (template)
  - nugent-sn91t (template)
  - nugent-sn91bg (template)
  - nugent-sn1bc (template)
  - nugent-sn2p (template)
  - nugent-sn2l (template)
  - nugent-sn2n (template)


## Test Fitting on a Few Objects

Before running the full extraction, test on 3 objects to make sure everything works.

In [11]:
# Test on 3 objects (1 TDE, 2 non-TDE if possible)
print("="*60)
print("TEST FITTING ON 3 OBJECTS")
print("="*60)

# Get sample objects
tde_ids = train_log[train_log["target"] == 1]["object_id"].values[:1]
non_tde_ids = train_log[train_log["target"] == 0]["object_id"].values[:2]
test_ids = list(tde_ids) + list(non_tde_ids)

print(f"Testing on objects: {test_ids}")
print(f"  TDE: {list(tde_ids)}")
print(f"  Non-TDE: {list(non_tde_ids)}")

z_dict = train_log.set_index("object_id")["Z"].to_dict()

for oid in test_ids:
    is_tde = oid in tde_ids
    label = "TDE" if is_tde else "Non-TDE"
    z = z_dict.get(oid, 0.1)
    
    print(f"\n--- Object {oid} ({label}, z={z:.3f}) ---")
    
    lc = train_lc[train_lc["object_id"] == oid].copy()
    print(f"  Points: {len(lc)}")
    
    if len(lc) < 10:
        print(f"  ✗ Too few points")
        continue
    
    # Find peak
    try:
        t_peak, peak_info = robust_peak_time(lc)
        print(f"  Peak: t={t_peak:.1f}, method={peak_info['method']}")
    except Exception as e:
        print(f"  ✗ Peak finding failed: {e}")
        continue
    
    # Window
    lc_w = slice_window(lc, t_peak, tmin_rel=TMIN_REL, tmax_rel=TMAX_REL)
    print(f"  Window points: {len(lc_w)}")
    
    # Convert to sncosmo table
    tbl = lc_to_sncosmo_table(lc_w)
    print(f"  Fit points: {len(tbl)}")
    
    if len(tbl) < 10:
        print(f"  ✗ Too few points after cleaning")
        continue
    
    # Test SALT2 fit only (fastest)
    try:
        chisq, ndof, rchisq, aic, bic, params = fit_one(tbl, z=z, source="salt2", kind="salt2")
        print(f"  ✓ SALT2 fit: rchisq={rchisq:.2f}, x1={params.get('x1', 0):.2f}, c={params.get('c', 0):.2f}")
    except Exception as e:
        print(f"  ✗ SALT2 fit failed: {e}")
    
    # Test one template fit
    try:
        chisq, ndof, rchisq, aic, bic, params = fit_one(tbl, z=z, source="nugent-sn1a", kind="template")
        print(f"  ✓ Nugent-SN1a fit: rchisq={rchisq:.2f}")
    except Exception as e:
        print(f"  ✗ Nugent-SN1a fit failed: {e}")

print("\n" + "="*60)
print("TEST COMPLETE - If no errors above, proceed with full extraction")
print("="*60)

TEST FITTING ON 3 OBJECTS
Testing on objects: ['amon_imloth_luin', 'Dornhoth_fervain_onodrim', 'Dornhoth_galadh_ylf']
  TDE: ['amon_imloth_luin']
  Non-TDE: ['Dornhoth_fervain_onodrim', 'Dornhoth_galadh_ylf']

--- Object amon_imloth_luin (TDE, z=0.777) ---
  Points: 119
  Peak: t=63938.4, method=snr+baseline
  Window points: 30
  Fit points: 30
  ✓ SALT2 fit: rchisq=4.72, x1=5.00, c=-0.30
  ✓ Nugent-SN1a fit: rchisq=8.24

--- Object Dornhoth_fervain_onodrim (Non-TDE, z=3.049) ---
  Points: 65
  Peak: t=63772.2, method=snr+baseline
  Window points: 22
  Fit points: 22
  ✗ SALT2 fit failed: No data points with S/N > 5.0. Initial guessing failed.
  ✓ Nugent-SN1a fit: rchisq=195.00

--- Object Dornhoth_galadh_ylf (Non-TDE, z=0.432) ---
  Points: 167
  Peak: t=62749.9, method=snr+baseline
  Window points: 41
  Fit points: 41
  ✓ SALT2 fit: rchisq=5.52, x1=1.51, c=0.60
  ✓ Nugent-SN1a fit: rchisq=13.41

TEST COMPLETE - If no errors above, proceed with full extraction
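
In the test output above, the first object's SALT2 fit returns `x1=5.00` and `c=-0.30`, and the third returns `c=0.60`. These are exactly the limits passed to `sncosmo.fit_lc` in `fit_one`, which suggests the minimizer ran into the bounds rather than finding an interior optimum. A small sketch of how such fits could be flagged as an extra feature follows; the helper names `at_bound` and `bound_flags` and the tolerance are assumptions, not part of the notebook.

```python
SALT_BOUNDS = {"x1": (-5.0, 5.0), "c": (-0.3, 0.6)}  # same limits as in fit_one


def at_bound(value, lo, hi, tol=1e-3):
    """Return True if a fitted value sits within tol of either bound."""
    return abs(value - lo) <= tol or abs(value - hi) <= tol


def bound_flags(params, bounds=SALT_BOUNDS):
    """Map each bounded parameter to 1 if it hit a limit, else 0."""
    return {
        f"{name}_at_bound": int(at_bound(params[name], lo, hi))
        for name, (lo, hi) in bounds.items()
        if name in params
    }


# Fitted values as printed in the test output above.
first = bound_flags({"x1": 5.00, "c": -0.30})
third = bound_flags({"x1": 1.51, "c": 0.60})
print(first)
print(third)
```

The feature summary later in the notebook shows the same pattern at scale: the 75th percentile of `sn_salt2_x1` is 4.999999 and the 25th percentile of `sn_salt2_c` is -0.2999999.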


In [12]:
def process_object(oid: str, lc_all: pd.DataFrame, z: float,
                   tmin_rel=-50.0, tmax_rel=200.0, min_points=10):
    """
    Process a single object: find peak, fit all models.
    
    Returns list of result dicts (one per model).
    """
    lc = lc_all[lc_all["object_id"] == oid].copy()
    
    if len(lc) < min_points:
        return [{"object_id": oid, "z": z, "model": m["name"], 
                 "status": "too_few_points", "n_points_total": len(lc)} 
                for m in MODEL_BANK]

    # Find peak
    try:
        t_peak, peak_info = robust_peak_time(lc)
    except Exception as e:
        return [{"object_id": oid, "z": z, "model": m["name"],
                 "status": "peak_fail", "error": str(e)} 
                for m in MODEL_BANK]

    # Window around peak
    lc_w = slice_window(lc, t_peak, tmin_rel=tmin_rel, tmax_rel=tmax_rel)
    if len(lc_w) < min_points:
        return [{"object_id": oid, "z": z, "t_peak": t_peak, "model": m["name"],
                 "status": "too_few_points_window", "n_points_window": len(lc_w),
                 "peak_method": peak_info.get("method", "unknown")} 
                for m in MODEL_BANK]

    # Convert to sncosmo table
    tbl = lc_to_sncosmo_table(lc_w)
    if len(tbl) < min_points:
        return [{"object_id": oid, "z": z, "t_peak": t_peak, "model": m["name"],
                 "status": "too_few_points_after_clean", 
                 "n_points_window": len(lc_w), "n_points_fit": len(tbl),
                 "peak_method": peak_info.get("method", "unknown")} 
                for m in MODEL_BANK]

    # Fit all models
    rows = []
    for m in MODEL_BANK:
        row = {
            "object_id": oid,
            "z": z,
            "t_peak": t_peak,
            "n_points_window": len(lc_w),
            "n_points_fit": len(tbl),
            "model": m["name"],
            "kind": m["kind"],
            "peak_method": peak_info.get("method", "unknown"),
        }
        try:
            chisq, ndof, rchisq, aic, bic, params = fit_one(
                tbl, z=z, source=m["source"], kind=m["kind"]
            )
            row.update({
                "status": "ok",
                "chisq": chisq,
                "ndof": ndof,
                "rchisq": rchisq,
                "aic": aic,
                "bic": bic,
            })
            # Add model-specific params
            if m["kind"] == "salt2":
                row["x0"] = params.get("x0", np.nan)
                row["x1"] = params.get("x1", np.nan)
                row["c"] = params.get("c", np.nan)
                row["t0_fit"] = params.get("t0", np.nan)
        except Exception as e:
            row.update({"status": "fit_fail", "error": str(e)})
        rows.append(row)

    return rows

## Run Fitting on All Objects

In [13]:
def run_all_fits(lc_df, meta_df, output_path, n_jobs=8, desc="Fitting"):
    """
    Run sncosmo fitting on all objects.
    
    Args:
        lc_df: Lightcurve DataFrame with Flux_corr, Fluxerr_corr
        meta_df: Metadata DataFrame with object_id and Z columns
        output_path: Where to save results
        n_jobs: Parallel workers
        desc: Label for the run (kept for readability at the call site; not used inside the function)
    
    Returns:
        DataFrame with fit results
    """
    object_ids = meta_df["object_id"].unique()
    z_dict = meta_df.set_index("object_id")["Z"].to_dict()
    
    print(f"Processing {len(object_ids):,} objects with {n_jobs} workers...")
    start_time = time.time()
    
    def work(oid):
        z = z_dict.get(oid, 0.1)  # Default z if missing
        if pd.isna(z) or z <= 0:
            z = 0.1
        return process_object(oid, lc_df, z, 
                              tmin_rel=TMIN_REL, tmax_rel=TMAX_REL, 
                              min_points=MIN_POINTS)
    
    # Parallel execution with progress
    nested = Parallel(n_jobs=n_jobs, backend="loky", verbose=5)(
        delayed(work)(oid) for oid in object_ids
    )
    
    # Flatten results
    rows = [r for sub in nested for r in sub]
    result_df = pd.DataFrame(rows)
    
    elapsed = time.time() - start_time
    print(f"Done in {elapsed/60:.1f} minutes")
    
    # Save
    result_df.to_parquet(output_path, index=False)
    print(f"Saved to {output_path}")
    
    return result_df
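
`process_object` selects each object's rows with a boolean mask over the full lightcurve DataFrame, so the whole table (about 480k rows for train, per the load output) is scanned once per object. One possible alternative, sketched here on a tiny made-up DataFrame, is to split the lightcurves once with `groupby` and give each task only its own rows. The variable names are illustrative and this is not what the notebook ran.

```python
import pandas as pd

# Tiny stand-in for train_lc; the object IDs and values are made up.
lc_df = pd.DataFrame({
    "object_id": ["a", "a", "b", "b", "b", "c"],
    "Time (MJD)": [1.0, 2.0, 1.0, 2.0, 3.0, 1.0],
    "Flux_corr": [5.0, 6.0, 1.0, 2.0, 3.0, 9.0],
})

# Split once: one sub-DataFrame per object, keyed by object_id.
groups = {
    oid: g.reset_index(drop=True)
    for oid, g in lc_df.groupby("object_id", sort=False)
}


def n_points(oid):
    """Stand-in for the per-object work: only touches that object's rows."""
    return len(groups[oid])


counts = {oid: n_points(oid) for oid in groups}
print(counts)
```

With joblib, the per-object frame could then be passed directly, for example `delayed(process)(oid, groups[oid])`, where `process` is a hypothetical variant of `process_object` that accepts pre-sliced rows.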

In [14]:
# Run on train set
print("="*60)
print("FITTING TRAIN SET")
print("="*60)

train_fits = run_all_fits(
    train_lc, train_log,
    output_path=OUTPUT_DIR / "train_sncosmo_raw.parquet",
    n_jobs=N_JOBS,
    desc="Train"
)

print(f"\nTrain fit status:")
print(train_fits["status"].value_counts())

FITTING TRAIN SET
Processing 3,043 objects with 8 workers...


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.


Downloading http://c3.lbl.gov/nugent/templates/sn91t_flux.v1.1.dat.gz [Done]
Downloading http://c3.lbl.gov/nugent/templates/sn91bg_flux.v1.1.dat.gz [Done]
Downloading http://c3.lbl.gov/nugent/templates/sn1bc_flux.v1.1.dat.gz [Done]
Downloading http://c3.lbl.gov/nugent/templates/sn2p_flux.v1.2.dat.gz [Done]
Downloading http://c3.lbl.gov/nugent/templates/sn2l_flux.v1.2.dat.gz [Done]
Downloading http://c3.lbl.gov/nugent/templates/sn2n_flux.v2.1.dat.gz [Done]


[Parallel(n_jobs=8)]: Done   2 tasks      | elapsed:   20.4s
[Parallel(n_jobs=8)]: Done  56 tasks      | elapsed:   57.7s
[Parallel(n_jobs=8)]: Done 146 tasks      | elapsed:  2.0min
[Parallel(n_jobs=8)]: Done 272 tasks      | elapsed:  3.4min
[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed:  5.4min
[Parallel(n_jobs=8)]: Done 632 tasks      | elapsed:  7.9min
[Parallel(n_jobs=8)]: Done 866 tasks      | elapsed: 10.8min
[Parallel(n_jobs=8)]: Done 1136 tasks      | elapsed: 14.1min
[Parallel(n_jobs=8)]: Done 1442 tasks      | elapsed: 17.8min
[Parallel(n_jobs=8)]: Done 1784 tasks      | elapsed: 21.9min
[Parallel(n_jobs=8)]: Done 2162 tasks      | elapsed: 26.6min
[Parallel(n_jobs=8)]: Done 2576 tasks      | elapsed: 31.6min
[Parallel(n_jobs=8)]: Done 3026 tasks      | elapsed: 37.7min
[Parallel(n_jobs=8)]: Done 3043 out of 3043 | elapsed: 37.9min finished


Done in 37.9 minutes
Saved to /kaggle/working/train_sncosmo_raw.parquet

Train fit status:
status
ok                       25957
too_few_points_window     1161
fit_fail                   269
Name: count, dtype: int64
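
The status counts are per (object, model) row, not per object: each of the 3,043 training objects contributes one row for each of the 9 models. A quick arithmetic check on the printed numbers:

```python
N_MODELS = 9       # len(MODEL_BANK)
N_OBJECTS = 3043   # training objects

status_rows = {"ok": 25957, "too_few_points_window": 1161, "fit_fail": 269}

total_rows = sum(status_rows.values())
assert total_rows == N_OBJECTS * N_MODELS  # every object has one row per model

# Window failures apply to all models of an object at once.
objects_without_window = status_rows["too_few_points_window"] // N_MODELS
objects_fitted = N_OBJECTS - objects_without_window

fit_fail_rate = status_rows["fit_fail"] / (status_rows["ok"] + status_rows["fit_fail"])

print(f"objects skipped (too few points in window): {objects_without_window}")
print(f"objects with at least one attempted fit:    {objects_fitted}")
print(f"share of attempted fits that failed:        {fit_fail_rate:.1%}")
```

The resulting 2,914 fitted objects match the non-null count of `sn_n_points_fit` in the feature summary further down.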


In [15]:
# Run on test set
print("="*60)
print("FITTING TEST SET") 
print("="*60)

test_fits = run_all_fits(
    test_lc, test_log,
    output_path=OUTPUT_DIR / "test_sncosmo_raw.parquet",
    n_jobs=N_JOBS,
    desc="Test"
)

print(f"\nTest fit status:")
print(test_fits["status"].value_counts())

FITTING TEST SET
Processing 7,135 objects with 8 workers...


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   2 tasks      | elapsed:    3.2s
[Parallel(n_jobs=8)]: Done  56 tasks      | elapsed:   45.9s
[Parallel(n_jobs=8)]: Done 146 tasks      | elapsed:  2.0min
[Parallel(n_jobs=8)]: Done 272 tasks      | elapsed:  3.7min
[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed:  6.4min
[Parallel(n_jobs=8)]: Done 632 tasks      | elapsed:  9.0min
[Parallel(n_jobs=8)]: Done 866 tasks      | elapsed: 12.2min
[Parallel(n_jobs=8)]: Done 1136 tasks      | elapsed: 16.4min
[Parallel(n_jobs=8)]: Done 1442 tasks      | elapsed: 20.4min
[Parallel(n_jobs=8)]: Done 1784 tasks      | elapsed: 25.3min
[Parallel(n_jobs=8)]: Done 2162 tasks      | elapsed: 30.7min
[Parallel(n_jobs=8)]: Done 2576 tasks      | elapsed: 36.4min
[Parallel(n_jobs=8)]: Done 3026 tasks      | elapsed: 42.6min
[Parallel(n_jobs=8)]: Done 3512 tasks      | elapsed: 49.8min
[Parallel(n_jobs=8)]: Done 4034 tasks      | elapsed: 56.9min
[P

Done in 101.6 minutes
Saved to /kaggle/working/test_sncosmo_raw.parquet

Test fit status:
status
ok                       60737
too_few_points_window     2745
fit_fail                   733
Name: count, dtype: int64


## Pivot to Feature Format

Convert from long format (one row per object-model pair) to wide format (one row per object) for easy integration into the training pipeline.

In [16]:
def pivot_to_features(fit_df):
    """
    Pivot fit results to feature format (one row per object).
    
    Creates features like:
    - sn_salt2_rchisq, sn_salt3_rchisq, etc.
    - sn_salt2_x1, sn_salt2_c (SALT params)
    - sn_best_ia_rchisq (min of Ia fits)
    - sn_is_good_ia_fit (boolean)
    """
    # Get successful fits
    ok = fit_df[fit_df["status"] == "ok"].copy()
    
    # Get unique object IDs from full df
    all_objects = fit_df["object_id"].unique()
    
    # Pivot rchisq for each model
    features = pd.DataFrame({"object_id": all_objects})
    
    for m in MODEL_BANK:
        model_name = m["name"].replace("-", "_")
        model_fits = ok[ok["model"] == m["name"]][["object_id", "rchisq", "chisq", "ndof"]]
        model_fits = model_fits.rename(columns={
            "rchisq": f"sn_{model_name}_rchisq",
            "chisq": f"sn_{model_name}_chisq",
            "ndof": f"sn_{model_name}_ndof",
        })
        features = features.merge(model_fits, on="object_id", how="left")
        
        # SALT-specific params
        if m["kind"] == "salt2":
            salt_params = ok[ok["model"] == m["name"]][["object_id", "x0", "x1", "c"]]
            salt_params = salt_params.rename(columns={
                "x0": f"sn_{model_name}_x0",
                "x1": f"sn_{model_name}_x1",
                "c": f"sn_{model_name}_c",
            })
            features = features.merge(salt_params, on="object_id", how="left")
    
    # Aggregate features
    ia_models = ["salt2", "salt3", "nugent_sn1a", "nugent_sn91t", "nugent_sn91bg"]
    ia_cols = [f"sn_{m}_rchisq" for m in ia_models if f"sn_{m}_rchisq" in features.columns]
    
    if ia_cols:
        # Best (lowest) Ia fit
        features["sn_best_ia_rchisq"] = features[ia_cols].min(axis=1)
        
        # Is it a "good" Ia fit? (rchisq < 2 is reasonable)
        features["sn_is_good_ia_fit"] = (features["sn_best_ia_rchisq"] < 2.0).astype(int)
        
        # Worst Ia fit (for comparison)
        features["sn_worst_ia_rchisq"] = features[ia_cols].max(axis=1)
    
    # Core collapse models
    cc_models = ["nugent_sn1bc", "nugent_sn2p", "nugent_sn2l", "nugent_sn2n"]
    cc_cols = [f"sn_{m}_rchisq" for m in cc_models if f"sn_{m}_rchisq" in features.columns]
    
    if cc_cols:
        features["sn_best_cc_rchisq"] = features[cc_cols].min(axis=1)
        features["sn_is_good_cc_fit"] = (features["sn_best_cc_rchisq"] < 2.0).astype(int)
    
    # Overall best SN fit
    all_rchisq_cols = [c for c in features.columns if c.endswith("_rchisq") and c.startswith("sn_") 
                      and "best" not in c and "worst" not in c]
    if all_rchisq_cols:
        features["sn_best_any_rchisq"] = features[all_rchisq_cols].min(axis=1)
        features["sn_is_good_any_fit"] = (features["sn_best_any_rchisq"] < 2.0).astype(int)
    
    # Add peak info and point counts
    peak_info = fit_df.groupby("object_id").first()[["t_peak", "n_points_fit", "peak_method"]].reset_index()
    features = features.merge(peak_info, on="object_id", how="left")
    features = features.rename(columns={
        "t_peak": "sn_t_peak",
        "n_points_fit": "sn_n_points_fit",
        "peak_method": "sn_peak_method",
    })
    
    return features


print("Pivoting to feature format...")
train_features = pivot_to_features(train_fits)
test_features = pivot_to_features(test_fits)

print(f"Train features: {train_features.shape}")
print(f"Test features: {test_features.shape}")

Pivoting to feature format...
Train features: (3043, 44)
Test features: (7135, 44)
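
One subtlety in `pivot_to_features`: a comparison such as `rchisq < 2.0` evaluates to `False` when `rchisq` is NaN, so objects with no successful fit receive `sn_is_good_*_fit = 0`, the same value as objects with a poor fit. The sketch below reproduces the behaviour with plain floats and shows one possible way to keep the two cases apart; both helper names are hypothetical.

```python
import math


def is_good_fit(rchisq, threshold=2.0):
    """Mirror of the notebook's flag: NaN compares False, so it maps to 0."""
    return int(rchisq < threshold)


def is_good_fit_or_missing(rchisq, threshold=2.0):
    """Alternative: return None when there is no fit, so 'missing' stays distinct."""
    if math.isnan(rchisq):
        return None
    return int(rchisq < threshold)


for value in (1.2, 8.5, float("nan")):
    print(value, is_good_fit(value), is_good_fit_or_missing(value))
```

In the feature summary this shows up as a count of 3,043 for the `sn_is_good_*` columns against 2,892 for `sn_best_ia_rchisq`.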


In [17]:
# Save final feature files
train_features.to_parquet(OUTPUT_DIR / "train_sncosmo_features.parquet", index=False)
test_features.to_parquet(OUTPUT_DIR / "test_sncosmo_features.parquet", index=False)

print(f"Saved train_sncosmo_features.parquet: {len(train_features)} objects, {len(train_features.columns)} columns")
print(f"Saved test_sncosmo_features.parquet: {len(test_features)} objects, {len(test_features.columns)} columns")

Saved train_sncosmo_features.parquet: 3043 objects, 44 columns
Saved test_sncosmo_features.parquet: 7135 objects, 44 columns


In [18]:
# Show feature overview
print("\n" + "="*60)
print("Feature Overview")
print("="*60)
print(f"\nColumns: {list(train_features.columns)}")
print(f"\nTrain features describe:")
display(train_features.describe())


Feature Overview

Columns: ['object_id', 'sn_salt2_rchisq', 'sn_salt2_chisq', 'sn_salt2_ndof', 'sn_salt2_x0', 'sn_salt2_x1', 'sn_salt2_c', 'sn_salt3_rchisq', 'sn_salt3_chisq', 'sn_salt3_ndof', 'sn_salt3_x0', 'sn_salt3_x1', 'sn_salt3_c', 'sn_nugent_sn1a_rchisq', 'sn_nugent_sn1a_chisq', 'sn_nugent_sn1a_ndof', 'sn_nugent_sn91t_rchisq', 'sn_nugent_sn91t_chisq', 'sn_nugent_sn91t_ndof', 'sn_nugent_sn91bg_rchisq', 'sn_nugent_sn91bg_chisq', 'sn_nugent_sn91bg_ndof', 'sn_nugent_sn1bc_rchisq', 'sn_nugent_sn1bc_chisq', 'sn_nugent_sn1bc_ndof', 'sn_nugent_sn2p_rchisq', 'sn_nugent_sn2p_chisq', 'sn_nugent_sn2p_ndof', 'sn_nugent_sn2l_rchisq', 'sn_nugent_sn2l_chisq', 'sn_nugent_sn2l_ndof', 'sn_nugent_sn2n_rchisq', 'sn_nugent_sn2n_chisq', 'sn_nugent_sn2n_ndof', 'sn_best_ia_rchisq', 'sn_is_good_ia_fit', 'sn_worst_ia_rchisq', 'sn_best_cc_rchisq', 'sn_is_good_cc_fit', 'sn_best_any_rchisq', 'sn_is_good_any_fit', 'sn_t_peak', 'sn_n_points_fit', 'sn_peak_method']

Train features describe:


statistic,sn_salt2_rchisq,sn_salt2_chisq,sn_salt2_ndof,sn_salt2_x0,sn_salt2_x1,sn_salt2_c,sn_salt3_rchisq,sn_salt3_chisq,sn_salt3_ndof,sn_salt3_x0,...,sn_nugent_sn2n_ndof,sn_best_ia_rchisq,sn_is_good_ia_fit,sn_worst_ia_rchisq,sn_best_cc_rchisq,sn_is_good_cc_fit,sn_best_any_rchisq,sn_is_good_any_fit,sn_t_peak,sn_n_points_fit
count,2855.0,2857.0,2857.0,2857.0,2857.0,2857.0,2855.0,2857.0,2857.0,2857.0,...,2892.0,2892.0,3043.0,2892.0,2892.0,3043.0,2892.0,3043.0,3043.0,2914.0
mean,112.70137,3263.227,34.19566,0.000154,2.132209,-0.08037258,106.894622,3564.036,34.760238,6.3e-05,...,38.578492,68.516217,0.187315,174.128284,70.597347,0.037134,56.831939,0.202432,63053.808391,40.548387
std,1095.63124,39595.86,24.946747,0.005309,2.977025,0.2551897,973.268461,43896.39,25.161487,0.000804,...,25.550874,774.369304,0.390228,1254.072961,715.300678,0.189122,701.358947,0.401878,644.595981,25.480953
min,0.53632,6.425687e-06,-1.0,-0.018585,-5.0,-0.3,0.383443,1.533772,-1.0,-0.000229,...,6.0,0.383443,0.0,1.590887,0.666842,0.0,0.383443,0.0,61016.5899,10.0
25%,3.218742,105.7841,19.0,8e-06,-0.132436,-0.2999999,3.028547,99.51818,20.0,9e-06,...,24.0,2.658009,0.0,12.221357,4.838351,0.0,2.368605,0.0,62560.60515,26.0
50%,11.434668,299.8391,31.0,1.5e-05,2.466222,-0.1330139,10.444396,287.9483,32.0,1.8e-05,...,36.0,8.516164,0.0,25.384791,10.003116,0.0,6.649986,0.0,63095.1727,38.0
75%,31.982909,777.4799,44.0,3e-05,4.999999,4.255494e-18,29.268749,754.6232,45.0,3.7e-05,...,48.0,21.5593,0.0,68.011131,24.038528,0.0,17.086969,0.0,63588.12405,50.0
max,35620.588691,1531685.0,308.0,0.272771,5.0,0.6,33842.080946,1759788.0,308.0,0.042431,...,345.0,32361.317346,1.0,41804.392084,29103.633195,1.0,29103.633195,1.0,64623.0971,347.0


## Quick Validation: Check Discriminative Power

In [19]:
# Merge with labels to check discriminative power
train_check = train_features.merge(train_log[["object_id", "target"]], on="object_id")

print("\n" + "="*60)
print("DISCRIMINATIVE POWER CHECK")
print("="*60)

tde = train_check[train_check["target"] == 1]
non_tde = train_check[train_check["target"] == 0]

key_features = [
    "sn_salt2_rchisq", "sn_salt3_rchisq", "sn_best_ia_rchisq",
    "sn_salt2_x1", "sn_salt2_c", "sn_is_good_ia_fit"
]

print(f"\n{'Feature':<25} {'TDE mean':>12} {'Non-TDE mean':>12} {'Diff':>10} {'Cohen d':>10}")
print("-" * 70)

for feat in key_features:
    if feat not in train_check.columns:
        continue
    
    tde_vals = tde[feat].dropna()
    non_tde_vals = non_tde[feat].dropna()
    
    if len(tde_vals) < 10 or len(non_tde_vals) < 10:
        continue
    
    tde_mean = tde_vals.mean()
    non_tde_mean = non_tde_vals.mean()
    diff = tde_mean - non_tde_mean
    
    # Cohen's d, with the pooled std taken as the unweighted mean of the two variances
    pooled_std = np.sqrt((tde_vals.std()**2 + non_tde_vals.std()**2) / 2)
    cohen_d = diff / pooled_std if pooled_std > 0 else 0
    
    print(f"{feat:<25} {tde_mean:>12.3f} {non_tde_mean:>12.3f} {diff:>10.3f} {cohen_d:>10.2f}")

print("\n(Positive Cohen's d means TDEs have higher values)")
print("(Expect: TDEs should have HIGHER rchisq = worse SN fits)")


DISCRIMINATIVE POWER CHECK

Feature                       TDE mean Non-TDE mean       Diff    Cohen d
----------------------------------------------------------------------
sn_salt2_rchisq                 56.451      115.755    -59.304      -0.07
sn_salt3_rchisq                 67.830      109.015    -41.186      -0.06
sn_best_ia_rchisq               39.993       70.044    -30.051      -0.05
sn_salt2_x1                      3.277        2.070      1.207       0.43
sn_salt2_c                      -0.130       -0.078     -0.053      -0.22
sn_is_good_ia_fit                0.014        0.196     -0.183      -0.62

(Positive Cohen's d means TDEs have higher values)
(Expect: TDEs should have HIGHER rchisq = worse SN fits)
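
The check above pools the two standard deviations with equal weight, which is a common simplification when group sizes are similar. If the classes are very unequal in size, the pooled standard deviation weighted by degrees of freedom is the more usual definition of Cohen's d. A minimal sketch with the standard library, on made-up numbers (the function name `cohen_d` is an assumption):

```python
from statistics import mean, variance
import math


def cohen_d(a, b):
    """Cohen's d with the pooled standard deviation weighted by degrees of freedom."""
    na, nb = len(a), len(b)
    pooled_var = ((na - 1) * variance(a) + (nb - 1) * variance(b)) / (na + nb - 2)
    if pooled_var <= 0:
        return 0.0
    return (mean(a) - mean(b)) / math.sqrt(pooled_var)


# Made-up samples: a small group and a larger one.
group_a = [2.0, 4.0, 6.0]
group_b = [1.0, 2.0, 3.0, 4.0, 5.0]

d = cohen_d(group_a, group_b)
print(f"d = {d:.3f}")
```

Note also that the rchisq distributions are heavy-tailed (the maximum of `sn_salt2_rchisq` in the summary exceeds 35,000 while its median is about 11), so means and standard deviations are dominated by a few objects. Comparing medians or log-transformed values may be more informative than the raw means.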


In [20]:
# Good Ia fit rate by class
print("\nGood Ia fit rate by class:")
print(f"  TDEs: {tde['sn_is_good_ia_fit'].mean()*100:.1f}%")
print(f"  Non-TDEs: {non_tde['sn_is_good_ia_fit'].mean()*100:.1f}%")
print("\n(Lower is better for TDEs - they shouldn't fit Ia templates well)")


Good Ia fit rate by class:
  TDEs: 1.4%
  Non-TDEs: 19.6%

(Lower is better for TDEs - they shouldn't fit Ia templates well)


## Done!

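**Next steps:**
1. Add `train_sncosmo_features.parquet` and `test_sncosmo_features.parquet` as inputs to your main notebook
2. Merge with your feature matrix: `X_train = X_train.merge(sncosmo_features, on='object_id', how='left')`
3. Key features to use:
   - `sn_salt2_rchisq` - Reduced χ² of the SALT2 fit. The intent is higher = worse SN Ia fit = more likely TDE, but the class means in the check above do not show this (Cohen's d = -0.07), so treat it as a weak feature on its own
   - `sn_best_ia_rchisq` - Lowest reduced χ² across the five Ia models
   - `sn_is_good_ia_fit` - Binary flag (1 = best Ia rchisq < 2, likely SN Ia); the strongest separator in the check above (Cohen's d = -0.62)
   - `sn_salt2_x1`, `sn_salt2_c` - SALT2 stretch and color params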
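
The merge in step 2 can be sketched on toy data as follows. The frames and the `some_feature` column are made up; `validate="one_to_one"` is an optional safeguard added here, not something the notebook uses.

```python
import pandas as pd

# Made-up stand-ins for the main feature matrix and the sncosmo features.
X_train = pd.DataFrame({"object_id": ["a", "b", "c"], "some_feature": [0.1, 0.2, 0.3]})
sncosmo_features = pd.DataFrame({
    "object_id": ["a", "b"],
    "sn_best_ia_rchisq": [1.4, 25.0],
})

merged = X_train.merge(
    sncosmo_features,
    on="object_id",
    how="left",             # keep every row of X_train
    validate="one_to_one",  # raise if object_id is duplicated on either side
)

n_missing = int(merged["sn_best_ia_rchisq"].isna().sum())
print(merged)
print("rows without sncosmo features:", n_missing)
```

Objects without a successful fit end up with NaN in the sncosmo columns, so the downstream model needs to tolerate missing values or they need to be imputed.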

In [21]:
print("\n" + "="*60)
print("SNCOSMO FEATURE EXTRACTION COMPLETE")
print("="*60)
print(f"\nOutput files:")
print(f"  - train_sncosmo_features.parquet ({len(train_features)} objects)")
print(f"  - test_sncosmo_features.parquet ({len(test_features)} objects)")
print(f"\nFeatures created: {len(train_features.columns) - 1}")


SNCOSMO FEATURE EXTRACTION COMPLETE

Output files:
  - train_sncosmo_features.parquet (3043 objects)
  - test_sncosmo_features.parquet (7135 objects)

Features created: 43
