In [None]:
# Optional: uncomment to force JAX onto the CPU (e.g. on a machine without a usable GPU).
# Set it before any JAX-based package (such as pzflow) is imported.
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from pzflow import Flow
import pzflow
from tdastro.astro_utils.noise_model import apply_noise
from tdastro.astro_utils.passbands import PassbandGroup
from tdastro.astro_utils.snia_utils import (
    DistModFromRedshift,
    HostmassX1Func,
    X0FromDistMod,
    num_snia_per_redshift_bin,
)
from tdastro.astro_utils.pzflow_node import PZFlowNode
from tdastro.astro_utils.unit_utils import flam_to_fnu, fnu_to_flam
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.scipy_random import SamplePDF
from tdastro.sources.sncomso_models import SncosmoWrapperModel
from tdastro.sources.snia_host import SNIaHost
from tdastro.astro_utils.opsim import OpSim

In [None]:
# Load the OpSim cadence database, keeping only g, r, i, z visits
# (these must match the passbands loaded below).
opsim_db = OpSim.from_db(
    "../../../../tdastro/opsim_db/baseline_v3.4_10yrs.db",
    sql_query="SELECT * FROM observations WHERE filter IN ('g','r','i','z')",
)

# Time span covered by the selected visits; t0 values are drawn uniformly inside it.
obs_mjd = opsim_db["observationStartMJD"]
t_min, t_max = obs_mjd.min(), obs_mjd.max()

In [None]:
# Distinct values of the OpSim `target` column among the selected visits.
pd.unique(opsim_db["target"])

In [None]:
# Describe the LSST g, r, i, z passbands (keep in sync with the filter list in the
# OpSim SQL query above); PassbandGroup does the actual loading and processing.
# NOTE(review): `passbands_dir` already ends in "LSST" and another "LSST" is appended
# per band, giving .../passbands/LSST/LSST/<band>.dat — confirm this doubled directory
# matches the test-data layout.
passbands_dir = Path("../../tests/tdastro/data/passbands/LSST")
passband_list = []
for band in "griz":
    file_path = passbands_dir / "LSST" / f"{band}.dat"
    passband_list.append({"filter_name": band, "table_path": file_path})
    print(f"Loading band {band} from {file_path}")

# Passband names become "LSST_<band>", which is why the simulation cells prefix
# OpSim filter names with "LSST_".
passbands = PassbandGroup(
    passband_parameters=passband_list,
    survey="LSST",
    units="nm",
    trim_quantile=0.001,
    delta_wave=1,
)

In [None]:
# Host galaxies: position, stellar mass and redshift are drawn from a pre-trained
# pzflow model, exposed to the parameter graph through a PZFlowNode.
# NOTE: the model is loaded from a serialized (.pkl) file — only use trusted files.
flow = Flow(file="snia_hosts_test_pzflow.pkl")
pz_node = PZFlowNode(flow, node_label="pznode")

host = SNIaHost(
    node_label="host",
    # Each host parameter is bound to one output column of the flow node.
    ra=pz_node.RA_GAL,
    dec=pz_node.DEC_GAL,
    hostmass=pz_node.LOGMASS,
    redshift=pz_node.ZTRUE,
)

In [None]:
# Cosmology used to convert the host redshift into a distance modulus.
H0 = 73.0  # Hubble constant (presumably km/s/Mpc)
OMEGA_M = 0.3  # matter density parameter
# Coefficients X0FromDistMod applies to x1 (alpha) and c (beta).
ALPHA = 0.14
BETA = 3.1

distmod_func = DistModFromRedshift(host.redshift, H0=H0, Omega_m=OMEGA_M)

# x1 is derived from the host stellar mass; colour c and absolute magnitude are
# drawn from normal distributions.
x1_func = HostmassX1Func(host.hostmass)
c_func = NumpyRandomFunc("normal", loc=0, scale=0.02)
m_abs_func = NumpyRandomFunc("normal", loc=-19.3, scale=0.1)

# The amplitude x0 is computed from the distance modulus and the same x1/c nodes
# that are later passed to the source model.
x0_func = X0FromDistMod(
    distmod=distmod_func,
    x1=x1_func,
    c=c_func,
    alpha=ALPHA,
    beta=BETA,
    m_abs=m_abs_func,
    node_label="x0_func",
)

In [None]:
# SALT2 (sncosmo "salt2-h17") source whose x0, x1 and c are wired to the nodes above.
sncosmo_modelname = "salt2-h17"

# t0 is uniform over the OpSim time span; the position is scattered around the
# host with a normal of scale 0.01 (presumably degrees).
t0_dist = NumpyRandomFunc("uniform", low=t_min, high=t_max)
ra_dist = NumpyRandomFunc("normal", loc=host.ra, scale=0.01)
dec_dist = NumpyRandomFunc("normal", loc=host.dec, scale=0.01)

source = SncosmoWrapperModel(
    sncosmo_modelname,
    node_label="source",
    t0=t0_dist,
    x0=x0_func,
    x1=x1_func,
    c=c_func,
    ra=ra_dist,
    dec=dec_dist,
    redshift=host.redshift,
)

In [None]:
%%timeit

nsamples = 1_000
states = source.sample_parameters(num_samples=nsamples)
lc_list = []

for i in range(0, nsamples):
    state = states.extract_single_sample(i)

    ra = state["source"]["ra"]
    dec = state["source"]["dec"]
    t0 = state["source"]["t0"]
    z = state["source"]["redshift"]

    # print(ra,dec,t0,z)

    opsim = opsim_db

    obs_index = np.array(opsim.range_search(ra, dec, radius=1.75))

    # Update obs_index to only include observations within SN lifespan
    # We need this until we have a detection model
    phase_obs = opsim["time"][obs_index] - t0
    obs_index = obs_index[(phase_obs > -20 * (1.0 + z)) & (phase_obs < 50 * (1.0 + z))]

    # Extract the timing and filter information for those observations, changing the
    # match band names in passbands object.
    times = opsim["time"][obs_index].to_numpy()
    if len(times) == 0:
        print(f"No overlap time in opsim for (ra,dec)=({ra:.2f},{dec:.2f}), index={i}")

    filters = opsim["filter"][obs_index].to_numpy(str)
    filters = np.char.add("LSST_", filters)

    # Compute the band_flixes over just the given filters.
    try:
        bandfluxes_perfect = source.get_band_fluxes(passbands, times, filters, state)
    except Exception as e:
        print(f"{e}, index={i}, redshift={z}")
        continue

    bandfluxes_error = opsim.bandflux_error_point_source(bandfluxes_perfect, obs_index)
    bandfluxes = apply_noise(bandfluxes_perfect, bandfluxes_error, rng=None)
    lc = pd.DataFrame(
        {"id": i, "mjd": times, "filter": filters, "flux": bandfluxes, "fluxerr": bandfluxes_error}
    )
    lc_list.append(lc)

In [None]:
%%timeit

# try above in batch

nsamples = 1_000
states = source.sample_parameters(num_samples=nsamples)
ra = states.extract_parameters("host.ra")["host.ra"]
dec = states.extract_parameters("host.dec")["host.dec"]
z = states.extract_parameters("host.redshift")["host.redshift"]
t0 = states.extract_parameters("t0")["t0"]

opsim = opsim_db
obs_indexes = np.array(opsim.range_search(ra, dec, radius=1.75))

lc_list = []

for i, obs_index in enumerate(obs_indexes):
    # Update obs_index to only include observations within SN lifespan
    # We need this until we have a detection model
    phase_obs = np.array(opsim["time"][obs_index]) - t0[i]
    obs_index = np.array(obs_index)[(phase_obs > -20 * (1.0 + z[i])) & (phase_obs < 50 * (1.0 + z[i]))]

    # Extract the timing and filter information for those observations, changing the
    # match band names in passbands object.
    times = opsim["time"][obs_index].to_numpy()
    if len(times) == 0:
        print(f"No overlap time in opsim for (ra,dec)=({ra[i]:.2f},{dec[i]:.2f}), index={i}")

    filters = opsim["filter"][obs_index].to_numpy(str)
    filters = np.char.add("LSST_", filters)

    # Compute the band_flixes over just the given filters.
    try:
        bandfluxes_perfect = source.get_band_fluxes(
            passbands, times, filters, states.extract_single_sample(i)
        )
    except Exception as e:
        print(f"{e}, index={i}, redshift={z[i]}")
        continue

    bandfluxes_error = opsim.bandflux_error_point_source(bandfluxes_perfect, obs_index)
    bandfluxes = apply_noise(bandfluxes_perfect, bandfluxes_error, rng=None)
    lc = pd.DataFrame(
        {"id": i, "mjd": times, "filter": filters, "flux": bandfluxes, "fluxerr": bandfluxes_error}
    )
    lc_list.append(lc)

In [None]:
# Stack the per-object light curves into one long table (one row per observation).
if not lc_list:
    raise ValueError("No light curves were simulated; check the simulation cell output above.")
# ignore_index: each per-object frame starts at 0, so give the combined table a fresh index.
lightcurves = pd.concat(lc_list, ignore_index=True)
lightcurves.head()

In [None]:
# Plot a few randomly chosen simulated light curves, one figure per object.
n_plot = 5
plot_rng = np.random.default_rng(42)  # fixed seed so the same objects are shown on re-run
unique_ids = lightcurves["id"].unique()
# Sample without replacement so no object is plotted twice, and never request more
# objects than were simulated.
random_ids = plot_rng.choice(unique_ids, size=min(n_plot, len(unique_ids)), replace=False)

for random_id in random_ids:
    lc = lightcurves.loc[lightcurves["id"] == random_id]

    fig, ax = plt.subplots()
    for band, lc_band in lc.groupby("filter"):
        ax.errorbar(lc_band["mjd"], lc_band["flux"], yerr=lc_band["fluxerr"], fmt="o", label=band)
    ax.set(xlabel="MJD", ylabel="nJy", title=f"Simulated light curve, id={random_id}")
    ax.legend()
    plt.show()