In [8]:
from pathlib import Path
import pandas as pd
import numpy as np
import sncosmo
from astropy.table import Table
import matplotlib.pyplot as plt
import warnings

project_root = Path.cwd().parent
print(f"Project root: {project_root}")

Project root: /Users/david/Code/msc


In [9]:
# User input: run folder name
folder_name = input("Enter the run folder name: ").strip()
run_folder = project_root / "runs" / folder_name

lightcurve_files = sorted(run_folder.glob("*_lightcurve.csv"))
print(f"Found {len(lightcurve_files)} lightcurve CSV(s) in {run_folder}")
for f in lightcurve_files:
    print(f"{f.name}")

Found 1 lightcurve CSV(s) in /Users/david/Code/msc/runs/test1
ZTF18aaizerg_lightcurve.csv


In [10]:
filter_map = {"g": "ztfg", "r": "ztfr"}
# Redshift per object from ztf_cleansed.csv; fit t0, x0, x1, c only
fit_params = ["t0", "x0", "x1", "c"]
bounds = {"x1": (-3, 3), "c": (-0.3, 0.3)}

all_results = []

In [None]:
# Load ztf_cleansed.csv
ztf_cleansed_path = project_root / "ztf_cleansed.csv"
ztf_meta = pd.read_csv(ztf_cleansed_path)
ztf_meta["ZTFID"] = ztf_meta["ZTFID"].astype(str).str.strip()

total_files = len(lightcurve_files)
completed = 0

for idx, lc_path in enumerate(lightcurve_files, 1):
    obj_id = lc_path.stem.replace("_lightcurve", "")
    print(f"[{idx}/{total_files}] Processing {obj_id}....")

    # Load cleaned lightcurve from downloadLasair (MJD, filter, forced_ujy, forced_ujy_error)
    df = pd.read_csv(lc_path)
    df["MJD"] = pd.to_numeric(df["MJD"], errors="coerce")
    df = df.dropna(subset=["MJD", "filter", "forced_ujy", "forced_ujy_error"])
    df["filter"] = df["filter"].astype(str).str.strip().str.lower()
    df = df[df["forced_ujy"].gt(0) & df["forced_ujy_error"].gt(0)].copy()

    if len(df) < 5:
        print(f"  Skip {obj_id}: too few points ({len(df)})")
        continue

    bands = [filter_map.get(f, f"ztf{f}") for f in df["filter"].values]
    data = Table({
        "time": df["MJD"].values,
        "band": bands,
        "flux": df["forced_ujy"].values,
        "fluxerr": df["forced_ujy_error"].values,
        "zp": np.full(len(df), 23.9),
        "zpsys": np.array(["ab"] * len(df)),
    })

    # Look up obj_id in the ZTFID column
    redshift = float(ztf_meta.loc[ztf_meta["ZTFID"] == obj_id, "redshift"].values[0])

    model = sncosmo.Model(source="salt2")
    model.set(z=redshift)

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            result, fitted_model = sncosmo.fit_lc(data, model, fit_params, bounds=bounds)
    except Exception as e:
        print(f"  Fit failed: {e}")
        continue

    row = {"object_id": obj_id, "chisq": result.chisq, "ndof": result.ndof, "z": redshift}
    for name, val in zip(result.param_names, result.parameters):
        row[name] = val
    if result.errors is not None:
        for name, err in zip(result.param_names, result.errors):
            row[f"{name}_err"] = err
    all_results.append(row)

    fig = sncosmo.plot_lc(data, model=fitted_model, errors=result.errors)
    plt.gcf().suptitle(obj_id)
    out_plot = run_folder / f"{lc_path.stem}_salt2.png"
    plt.savefig(out_plot, dpi=150, bbox_inches="tight")
    plt.close()
    completed += 1
    print(f"  Saved {out_plot.name}  ({completed}/{total_files} completed)")

Processing ZTF18aaizerg....
  Saved ZTF18aaizerg_lightcurve_salt2.png


In [12]:
# Output CSV of sncosmo parameters
if all_results:
    params_df = pd.DataFrame(all_results)
    out_csv = run_folder / "sncosmo_parameters.csv"
    params_df.to_csv(out_csv, index=False)
    print(f"Saved sncosmo parameters to {out_csv}")
    display(params_df)
else:
    print("No successful fits to write.")

Saved sncosmo parameters to /Users/david/Code/msc/runs/test1/sncosmo_parameters.csv


Unnamed: 0,object_id,chisq,ndof,z,t0,x0,x1,c,z_err,t0_err,x0_err,x1_err
0,ZTF18aaizerg,13.955202,24,0.0685,58852.304436,0.000608,-1.288589,0.112886,t0,x0,x1,c
