# 4. SALT2 fits (sncosmo)

Fits a **SALT2** model to each light curve using **sncosmo**. For each object:

- **Inputs:** Light-curve CSV (MJD, filter g/r, forced_ujy, forced_ujy_error; zp=23.9, zpsys=ab), plus from `ztf_cleansed.csv`: redshift (`tns_redshift` when it is a valid positive number, otherwise `redshift`), and A_V → mwebv = A_V/3.1 (Milky Way dust).
- **Fitted parameters:** t0, x0, x1, c (z and mwebv fixed).
  - **t0:** time of B-band maximum (MJD).
  - **x0:** flux normalization (overall scale).
  - **x1:** stretch (positive = broader light curve).
  - **c:** colour (positive = redder).

**Output:** `runs/<run>/sncosmo_fits.csv` - columns include ztf_id, redshift, ncall, ndof, chisq, t0, x0, x1, c.

In [31]:
from pathlib import Path
import pandas as pd
import numpy as np
import datetime
import sncosmo
from astropy.table import Table
import matplotlib.pyplot as plt
import warnings

# The project root is taken to be the directory directly above the current
# working directory.
project_root = Path.cwd().parents[0]
print(f"Project root: {project_root}")

Project root: /Users/david/Code/msc


In [32]:
# Timestamp taken once when the cell runs (naive datetime, second resolution).
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Ask which run to process; surrounding whitespace in the answer is ignored.
folder_name = input("Enter the run folder name: ").strip()
run_folder = project_root.joinpath("runs", folder_name)

# Light curves for a run are read from its "laisair" subfolder; stop early if
# that folder is missing.
lightcurve_dir = run_folder.joinpath("laisair")
if not lightcurve_dir.exists():
    raise FileNotFoundError(f"Lightcurve folder not found: {lightcurve_dir}")

# Collect every "*_lightcurve.csv" in sorted order and list them one per line.
lightcurve_files = sorted(lightcurve_dir.glob("*_lightcurve.csv"))
print(f"Found {len(lightcurve_files)} lightcurve CSV in {lightcurve_dir}")
for f in lightcurve_files:
    print(f.name)

Found 4998 lightcurve CSV in /Users/david/Code/msc/runs/run7/laisair
ZTF17aabtvsy_lightcurve.csv
ZTF17aabvong_lightcurve.csv
ZTF17aacldgo_lightcurve.csv
ZTF17aadlxmv_lightcurve.csv
ZTF18aaaonon_lightcurve.csv
ZTF18aaaooqj_lightcurve.csv
ZTF18aaaqexr_lightcurve.csv
ZTF18aadlaxo_lightcurve.csv
ZTF18aadxnul_lightcurve.csv
ZTF18aadzfso_lightcurve.csv
ZTF18aaeqjmc_lightcurve.csv
ZTF18aaermez_lightcurve.csv
ZTF18aafdigb_lightcurve.csv
ZTF18aagkwgz_lightcurve.csv
ZTF18aagtwyh_lightcurve.csv
ZTF18aahfbqp_lightcurve.csv
ZTF18aahtjsc_lightcurve.csv
ZTF18aahvndq_lightcurve.csv
ZTF18aailmnv_lightcurve.csv
ZTF18aaisqmw_lightcurve.csv
ZTF18aaiwzie_lightcurve.csv
ZTF18aaiykoz_lightcurve.csv
ZTF18aaizerg_lightcurve.csv
ZTF18aajivpr_lightcurve.csv
ZTF18aajkcdn_lightcurve.csv
ZTF18aajkrxi_lightcurve.csv
ZTF18aajpjdi_lightcurve.csv
ZTF18aajvqye_lightcurve.csv
ZTF18aakitiq_lightcurve.csv
ZTF18aakiwbs_lightcurve.csv
ZTF18aaklpdo_lightcurve.csv
ZTF18aamvfeb_lightcurve.csv
ZTF18aamxads_lightcurve.csv
ZTF18aa

In [None]:
# Fit a SALT2 model to every light curve of the run.
#
# Reads:  ztf_cleansed.csv (ZTFID, redshift, optional tns_redshift, A_V, peakt)
#         and each file in `lightcurve_files`.
# Writes: `results_list`, one dict per object whose fit completed.
# Objects with no usable photometry, no catalogue row, no valid redshift, or a
# failed fit are reported on stdout and skipped.
ztf_cleansed_path = project_root / "ztf_cleansed.csv"
ztf_df = pd.read_csv(ztf_cleansed_path)

# Initialize list to collect results
results_list = []

cosmo_folder = run_folder / "cosmo"
cosmo_folder.mkdir(parents=True, exist_ok=True)

# Light-curve filter letter -> sncosmo bandpass name. Rows with any other
# filter are dropped, otherwise they would reach sncosmo with a NaN band name.
BAND_MAP = {'g': 'ztfg', 'r': 'ztfr'}

# Errors that mean "this one object cannot be fitted". fit_lc raises
# DataQualityError (a plain Exception, not a RuntimeError) when no point passes
# its S/N cut; catching it keeps one bad light curve from aborting the run.
FIT_ERRORS = (RuntimeError, ValueError, sncosmo.fitting.DataQualityError)

# Loop-invariant setup: Milky Way dust (CCM89, applied in the observer frame)
# on top of the SALT2 source. z, mwebv and t0 are set per object below.
# NOTE(review): reusing one model relies on fit_lc fitting a copy and leaving
# the model passed in unchanged - confirm against the sncosmo version in use.
dust = sncosmo.CCM89Dust()
Rv = 3.1  # Milky Way extinction law
model = sncosmo.Model(
    source='salt2',
    effects=[dust],
    effect_names=['mw'],
    effect_frames=['obs']
)

for idx, lc_path in enumerate(lightcurve_files, 1):
    print(f"Processing file {idx}/{len(lightcurve_files)}: {lc_path.name}")
    lc_df = pd.read_csv(lc_path)

    lc_df['MJD'] = pd.to_numeric(lc_df['MJD'], errors='coerce')
    # we will only use 'forced_ujy', 'forced_ujy_error'
    lc_df = lc_df.dropna(subset=['MJD', 'filter', 'forced_ujy', 'forced_ujy_error'])
    lc_df['filter'] = lc_df['filter'].astype(str).str.strip().str.lower()
    # NOTE(review): dropping non-positive forced fluxes removes non-detections
    # and may bias the fit - kept as in the original analysis, confirm intent.
    lc_df = lc_df[lc_df['forced_ujy'] > 0]
    lc_df = lc_df[lc_df['forced_ujy_error'] > 0]
    # Keep only the filters we can translate to an sncosmo bandpass.
    lc_df = lc_df[lc_df['filter'].isin(list(BAND_MAP))]
    lc_df = lc_df.sort_values('MJD')

    if lc_df.empty:
        print("  No valid data points after cleaning, skipping.")
        continue

    bands = lc_df['filter'].map(BAND_MAP).values

    data = Table({
        'time': lc_df['MJD'].values,
        'band': bands,
        'flux': lc_df['forced_ujy'].values,
        'fluxerr': lc_df['forced_ujy_error'].values,
        'zp': np.full(len(lc_df), 23.9),
        'zpsys': np.array(['ab'] * len(lc_df)),
    })
    print(f"Loaded {len(data)} data points after cleaning.")

    obj_id = lc_df['ztf_id'].iloc[0]

    z_row = ztf_df.loc[ztf_df['ZTFID'] == obj_id]
    if z_row.empty:
        print(f"  No row found in ztf_cleansed for ZTFID {obj_id}, skipping.")
        continue
    # A_V and redshift can be '-' or other non-numeric in the CSV, hence the coercion.
    A_V = pd.to_numeric(z_row['A_V'].iloc[0], errors='coerce')
    # Prefer TNS host redshift (higher precision) to reduce catalog-rounding strips in Hubble diagram
    z_tns = pd.to_numeric(z_row['tns_redshift'].iloc[0], errors='coerce') if 'tns_redshift' in z_row.columns else np.nan
    z_ztf = pd.to_numeric(z_row['redshift'].iloc[0], errors='coerce')
    if np.isfinite(z_tns) and z_tns > 0:
        z_ztf = z_tns
    if not np.isfinite(z_ztf) or z_ztf <= 0:
        print(f"  Invalid or non-positive redshift for ZTFID {obj_id} (redshift={z_row['redshift'].iloc[0]!r}), skipping.")
        continue
    z_ztf = float(z_ztf)  # use for fit and output (TNS when available, else ztf)
    # A_V: use 0 if missing or negative (negative extinction is unphysical; catalog artifacts)
    A_V = float(A_V) if np.isfinite(A_V) else 0.0
    A_V = max(0.0, A_V)
    mwebv = A_V / Rv

    # Initial t0 from the ztf_cleansed peak time (peakt, already in MJD);
    # fall back to the first MJD of the light curve so NaN is never passed on.
    # NOTE(review): fit_lc defaults to guess_t0=True, which re-estimates t0
    # from the data, so this seed may have little effect - confirm.
    t0_guess = pd.to_numeric(z_row['peakt'].iloc[0], errors='coerce')
    if not np.isfinite(t0_guess):
        t0_guess = lc_df['MJD'].iloc[0]
    t0_guess = float(t0_guess)

    # Fixed parameters (z, mwebv) and the t0 starting value for this object.
    model.set(z=z_ztf, mwebv=mwebv, t0=t0_guess)

    try:
        # Record RuntimeWarnings (overflow, invalid value, ...) raised during
        # the fit so the object can be flagged and cut later.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", RuntimeWarning)
            result, fitted_model = sncosmo.fit_lc(
                data, model,
                ['t0', 'x0', 'x1', 'c'])
            has_numerical_warning = any(
                isinstance(x.message, RuntimeWarning) for x in w
            )
    except FIT_ERRORS as e:
        print(f"  Fit failed for {obj_id}: {e}")
        continue

    # Collect results
    params = dict(zip(result.param_names, result.parameters))
    result_dict = {
        'ztf_id': obj_id,
        'redshift': z_ztf,
        'ncall': result.ncall,
        'ndof': result.ndof,
        'chisq': result.chisq,
        't0': params['t0'],
        'x0': params['x0'],
        'x1': params['x1'],
        'c': params['c'],
        'fit_numerical_warning': has_numerical_warning,
        # Reduced chi-square; undefined when there are no degrees of freedom.
        'fit': result.chisq / result.ndof if result.ndof > 0 else np.nan
    }
    results_list.append(result_dict)

    print(f"  Fit completed for {obj_id}" + (" (numerical warning)" if has_numerical_warning else ""))

    # Optional per-object diagnostic plot (disabled).
    # Create and save plot for each lightcurve fit
    # fig, ax = plt.subplots(figsize=(8, 5))
    # sncosmo.plot_lc(data, model=fitted_model, ax=ax, errors=result.errors if hasattr(result, 'errors') else None)
    # ax.set_title(f"Light Curve Fit: {obj_id}")
    # ax.set_xlabel('MJD')
    # ax.set_ylabel('Flux (uJy)')
    # plot_path = cosmo_folder / f"{obj_id}_plot.png"
    # plt.tight_layout()
    # plt.savefig(plot_path)
    # plt.close(fig)
    # print(f"  Plot saved to {plot_path}")

# After loop, create dataframe, apply the quality cuts and save to CSV.
# Column layout of one fit result; used to keep the schema when nothing was fitted.
RESULT_COLUMNS = [
    'ztf_id', 'redshift', 'ncall', 'ndof', 'chisq',
    't0', 'x0', 'x1', 'c', 'fit_numerical_warning', 'fit',
]

results_df = pd.DataFrame(results_list)
if results_df.empty:
    # No successful fit: an empty list gives a column-less frame, and the cuts
    # below would raise KeyError. Keep the schema so a header-only CSV is written.
    results_df = pd.DataFrame(columns=RESULT_COLUMNS)

cosmo_df = results_df.copy()

# Apply cuts (rows with NaN in a cut column fail the comparison and are dropped)
if not cosmo_df.empty:
    cosmo_df = cosmo_df[
        (cosmo_df['fit'] > 0.5) & (cosmo_df['fit'] < 3) & # drop over- and under-fitted objects (reduced chi-square)
        (cosmo_df['x1'] > -3) & (cosmo_df['x1'] < 3) & # drop low and high stretch
        (cosmo_df['c'] > -0.3) & (cosmo_df['c'] < 0.3) & # drop low and high colour
        (cosmo_df['redshift'] > 0) & # drop non-positive redshift
        (cosmo_df['ndof'] > 5) & # drop low dof
        (~cosmo_df['fit_numerical_warning']) # drop numerical warning
    ]

output_csv_path = run_folder / "sncosmo_fits.csv"
cosmo_df.to_csv(output_csv_path, index=False)
print(f"Results saved to {output_csv_path}")

Processing file 1/4998: ZTF17aabtvsy_lightcurve.csv
Loaded 23 data points after cleaning.
  Fit completed for ZTF17aabtvsy
Processing file 2/4998: ZTF17aabvong_lightcurve.csv
Loaded 31 data points after cleaning.
  Fit completed for ZTF17aabvong
Processing file 3/4998: ZTF17aacldgo_lightcurve.csv
Loaded 16 data points after cleaning.
  Fit completed for ZTF17aacldgo
Processing file 4/4998: ZTF17aadlxmv_lightcurve.csv
Loaded 27 data points after cleaning.
  Fit completed for ZTF17aadlxmv
Processing file 5/4998: ZTF18aaaonon_lightcurve.csv
Loaded 10 data points after cleaning.
  Fit completed for ZTF18aaaonon
Processing file 6/4998: ZTF18aaaooqj_lightcurve.csv
Loaded 40 data points after cleaning.
  Fit completed for ZTF18aaaooqj
Processing file 7/4998: ZTF18aaaqexr_lightcurve.csv
Loaded 11 data points after cleaning.
  Fit completed for ZTF18aaaqexr
Processing file 8/4998: ZTF18aadlaxo_lightcurve.csv
Loaded 20 data points after cleaning.
  Fit completed for ZTF18aadlaxo
Processing file 