# Mini-challenge starting point

_Alex Malz (LINCC-Frameworks@CMU) & {add your name here}_

TODOs:
- [x] replace sources with some sncosmo sources using setup from plasticc_snia notebook
- [ ] timeit/cprofile

In [None]:
import numpy as np
from astropy.cosmology import FlatLambdaCDM
import matplotlib.pyplot as plt

from tdastro.sources.basic_sources import StaticSource
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.astro_utils.snia_utils import DistModFromRedshift
from tdastro.base_models import FunctionNode
from tdastro.math_nodes.basic_math_node import BasicMathNode

from tdastro.astro_utils.opsim import OpSim
from tdastro.astro_utils.passbands import PassbandGroup
from tdastro.astro_utils.pzflow_node import PZFlowNode
from tdastro.astro_utils.snia_utils import (
    DistModFromRedshift,
    HostmassX1Func,
    X0FromDistMod,
)
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.simulate import simulate_lightcurves
from tdastro.sources.sncomso_models import SncosmoWrapperModel
from tdastro.sources.snia_host import SNIaHost
from tdastro.utils.plotting import plot_lightcurves

from tdastro import _TDASTRO_BASE_DATA_DIR

import cProfile, pstats, io
from pstats import SortKey

The goal is to show how to define multiple sources and generate a population of light curves under two OpSims,
then add resource/time profiling to find bottlenecks.

## Define classes of source and their linkages and parameters

Adapted from `tdastro/docs/notebooks/pre_executed/plasticc_snia.ipynb`.

In [None]:
# Host-galaxy population for the SNe Ia.
#
# PZFlowNode wraps the flow model; every output of the flow is available as an
# attribute of the node, so other nodes can link to it.
# NOTE(review): the model is loaded by a path relative to this notebook. The
# commented-out alternative used the packaged data dir, presumably
# _TDASTRO_BASE_DATA_DIR / "model_files" / "snia_hosts_test_pzflow.pkl"
# (the original note read "model_f?iles") -- TODO confirm.
pz_node = PZFlowNode.from_file(
    '../../tests/tdastro/data/snia_hosts_test_pzflow.pkl',
    node_label="pznode",
)

# Each sampled host takes its position and stellar mass from the flow model.
# Its redshift is drawn separately, from a uniform distribution on [0.1, 0.6].
# The attribute names used here must match the flow model's output names.
host = SNIaHost(
    ra=pz_node.RA_GAL,
    dec=pz_node.DEC_GAL,
    hostmass=pz_node.LOGMASS,
    redshift=NumpyRandomFunc("uniform", low=0.1, high=0.6),
    node_label="host",
)

The OpSim must be loaded up front, because the SN Ia source's `t0` prior is bounded by the survey's time range (`t_min`, `t_max`).

In [None]:
# DON'T RUN ME AGAIN! TAKES A MINUTE
#
# Load the OpSim database and the griz passbands under cProfile to see where the
# start-up time goes. The `with` block guarantees the profiler is disabled even
# if loading raises. A profiler left enabled in the kernel can make later
# profiling cells fail.
with cProfile.Profile() as pr:
    # Load the OpSim data.
    opsim_db = OpSim.from_db('../../tests/tdastro/data/baseline_v3.4_10yrs.db')  # _TDASTRO_BASE_DATA_DIR / "opsim_db" / "baseline_v3.4_10yrs.db"
    t_min, t_max = opsim_db.time_bounds()

    # Load the passband data for the griz filters only.
    passband_group = PassbandGroup(
        preset="LSST", filters_to_load=["g", "r", "i", "z"], units="nm", trim_quantile=0.001, delta_wave=1
    )

# Report outside the profiled region so string formatting doesn't show up in the stats.
print(f"Loaded OpSim with {len(opsim_db)} rows and times [{t_min}, {t_max}]")
print(f"Loaded Passbands: {passband_group}")

# Show the 30 most expensive calls by cumulative time (drop the 30 to see all).
pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE).print_stats(30)

In [None]:
# Keep only the OpSim pointings whose filter is one of the loaded passbands.
filter_mask = passband_group.mask_by_filter(opsim_db["filter"])
opsim_db = opsim_db.filter_rows(filter_mask)

# Dropping rows can only narrow the survey window, so recompute its bounds.
t_min, t_max = opsim_db.time_bounds()
print(f"Filtered OpSim to {len(opsim_db)} rows and times [{t_min}, {t_max}]")

Creating the sncosmo source below will download the model files (here `salt2-h17`).

TODO:
- [ ] all the sncosmo models!

In [None]:
# SALT2 parameter nodes, linked to the host:
#   distmod_func: distance modulus from the host redshift (H0=73.0, Omega_m=0.3)
#   x1_func:      stretch derived from the host mass
#   c_func:       colour ~ Normal(0, 0.02)
#   m_abs_func:   absolute magnitude ~ Normal(-19.3, 0.1)
#   x0_func:      amplitude combining all of the above with alpha=0.14, beta=3.1
distmod_func = DistModFromRedshift(host.redshift, H0=73.0, Omega_m=0.3)
x1_func = HostmassX1Func(host.hostmass)
c_func = NumpyRandomFunc("normal", loc=0, scale=0.02)
m_abs_func = NumpyRandomFunc("normal", loc=-19.3, scale=0.1)
x0_func = X0FromDistMod(
    distmod=distmod_func, x1=x1_func, c=c_func,
    alpha=0.14, beta=3.1,
    m_abs=m_abs_func,
    node_label="x0_func",
)

# Epoch: uniform over the current (filtered) survey window.
# Position: Gaussian scatter of 0.01 around the host's ra/dec.
# The nodes are built in the same order as before: t0, ra, dec.
t0_func = NumpyRandomFunc("uniform", low=t_min, high=t_max)
ra_func = NumpyRandomFunc("normal", loc=host.ra, scale=0.01)
dec_func = NumpyRandomFunc("normal", loc=host.dec, scale=0.01)

sncosmo_modelname = "salt2-h17"
source = SncosmoWrapperModel(
    sncosmo_modelname,
    t0=t0_func,
    x0=x0_func,
    x1=x1_func,
    c=c_func,
    ra=ra_func,
    dec=dec_func,
    redshift=host.redshift,
    node_label="source",
)

In [None]:
# Three toy StaticSources to compare against the SN Ia model:
#   source1: fixed brightness 10.0 at (ra, dec) = (55.5, -43.5)
#   source2: brightness ~ Uniform(11.0, 15.5) at (65.5, -53.5)
#   source3: brightness ~ Normal(20.0, 2.0), redshift ~ Uniform(0.1, 0.5),
#            t0 = 0 at (1.0, 2.0)
source1 = StaticSource(
    brightness=10.0,
    node_label="my_static_source",
    ra=55.5,
    dec=-43.5,
)

brightness_func = NumpyRandomFunc("uniform", low=11.0, high=15.5)
source2 = StaticSource(
    brightness=brightness_func,
    ra=65.5,
    dec=-53.5,
    node_label="my_static_source_2",
)

source3 = StaticSource(
    brightness=NumpyRandomFunc("normal", loc=20.0, scale=2.0),
    redshift=NumpyRandomFunc("uniform", low=0.1, high=0.5),
    ra=1.0,
    dec=2.0,
    node_label="test",
    t0=0.0,
)

# host = StaticSource(brightness=15.0, ra=1.0, dec=2.0, node_label="host")
# source = StaticSource(brightness=10.0, ra=host.ra, dec=host.dec, node_label="source")
# state = source.sample_parameters(num_samples=5)

# for i in range(5):
#     print(
#         f"{i}: Host=({state['host']['ra'][i]}, {state['host']['dec'][i]})"
#         f"Source=({state['source']['ra'][i]}, {state['source']['dec'][i]})"
#     )

### Sample the underlying parameters (or let `simulate_lightcurves` sample them implicitly)

In [None]:
# state = source1.sample_parameters(num_samples=10)
# state["my_static_source"]["brightness"]

# state = source1.sample_parameters(num_samples=10)
# state["my_static_source"]["brightness"]

# state = source2.sample_parameters(num_samples=10)
# state["my_static_source_2"]["brightness"]

# num_samples = 10
# state = source3.sample_parameters(num_samples=num_samples)
# for i in range(num_samples):
#     print(f"{i}: brightness={state['test']['brightness'][i]} redshift={state['test']['redshift'][i]}")

In [None]:
# single_sample = state.extract_single_sample(0)
# print(str(single_sample))

In [None]:
# cosmo_obj = FlatLambdaCDM(H0=73.0, Om0=0.3)
# redshifts = np.array([0.1, 0.2, 0.3])
# distmods = cosmo_obj.distmod(redshifts).value
# print(distmods)

In [None]:
# distmod_obj = DistModFromRedshift(
#     H0=73.0, Omega_m=0.3, redshift=NumpyRandomFunc("uniform", low=0.1, high=0.5)
# )

## Apply observational effects and opsim/passbands to make light curves

You need to download an OpSim database (a convenient one is [here](https://drive.google.com/drive/folders/1XIgEfi9BOEHW0W-uM7XvzlnOcRfaCSM0)). Then change the path passed to `OpSim.from_db` above to its location relative to this notebook.

In [None]:
# Simulate SN Ia light curves through the filtered OpSim and griz passbands.
num_snia = 1_000
lightcurves = simulate_lightcurves(source, num_snia, opsim_db, passband_group)
print(lightcurves)

API feedback: use either "params" or "parameters" consistently, not both. Also, could these be printed with `repr` instead of digging into the structure? (The `get` vs. `extract` naming is confusing too.)

In [None]:
# test_samps = source2.sample_parameters()

# test_samps.extract_parameters('dec')

In [None]:
# Profile the simulation of 10 light curves for source1 (fixed brightness 10.0).
# The `with` block guarantees the profiler is disabled even if the simulation
# raises. A profiler left enabled in the kernel can make later profiling cells fail.
with cProfile.Profile() as pr:
    lightcurves1 = simulate_lightcurves(source1, 10, opsim_db, passband_group)

# Print outside the profiled region so rendering the table doesn't skew the stats.
print(lightcurves1)

# Show the 30 most expensive calls by cumulative time (drop the 30 to see all).
pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE).print_stats(30)

In [None]:
# Profile the simulation of 100 light curves for source2 (uniformly sampled brightness).
# The `with` block guarantees the profiler is disabled even if the simulation
# raises. A profiler left enabled in the kernel can make later profiling cells fail.
with cProfile.Profile() as pr:
    lightcurves2 = simulate_lightcurves(source2, 100, opsim_db, passband_group)

# Print outside the profiled region so rendering the table doesn't skew the stats.
print(lightcurves2)

# Show the 30 most expensive calls by cumulative time (drop the 30 to see all).
pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE).print_stats(30)

In [None]:
# Profile the simulation of 1000 light curves for source3 (sampled brightness and redshift).
# The `with` block guarantees the profiler is disabled even if the simulation
# raises. A profiler left enabled in the kernel can make later profiling cells fail.
with cProfile.Profile() as pr:
    lightcurves3 = simulate_lightcurves(source3, 1000, opsim_db, passband_group)

# Print outside the profiled region so rendering the table doesn't skew the stats.
print(lightcurves3)

# Show the 30 most expensive calls by cumulative time (drop the 30 to see all).
pstats.Stats(pr).sort_stats(SortKey.CUMULATIVE).print_stats(30)

## Look at the light curves

In [None]:
def show_me_lcs(lcs, num_to_show=10, plot_func=None):
    """Plot a random selection of simulated light curves.

    Parameters
    ----------
    lcs : table of light curves
        For example, the output of ``simulate_lightcurves``. Each row needs an
        ``nobs`` column and a nested ``lightcurve`` column with ``filter``,
        ``mjd``, ``flux`` and ``fluxerr`` entries. Rows are looked up by index
        label via ``.loc``.
    num_to_show : int, optional
        Maximum number of distinct objects to draw (default 10). Objects with
        ``nobs <= 0`` are skipped, so fewer plots may appear.
    plot_func : callable, optional
        Called with ``fluxes``, ``times``, ``fluxerrs`` and ``filters`` keyword
        arguments for each plotted object. Defaults to ``plot_lightcurves``.
    """
    if plot_func is None:
        plot_func = plot_lightcurves

    # Nothing to sample from; np.random.choice would raise on an empty population.
    if len(lcs) == 0:
        return

    # Sample from *this* table's own index labels (not a global table, and not
    # positional integers) so that ``.loc`` always finds the row. Sampling is
    # without replacement so no object is plotted twice.
    num_to_show = min(num_to_show, len(lcs))
    random_ids = np.random.choice(lcs.index, size=num_to_show, replace=False)

    for random_id in random_ids:
        # Extract the row for this object.
        lc = lcs.loc[random_id]
        if lc["nobs"] <= 0:
            continue

        # Unpack the nested columns (filters, mjd, flux, and flux error).
        nested = lc["lightcurve"]
        plot_func(
            fluxes=np.asarray(nested["flux"], dtype=float),
            times=np.asarray(nested["mjd"], dtype=float),
            fluxerrs=np.asarray(nested["fluxerr"], dtype=float),
            filters=np.asarray(nested["filter"], dtype=str),
        )

In [None]:
# Plot a random handful of the source3 light curves.
show_me_lcs(lcs=lightcurves3)