In [None]:
# Reload edited modules (e.g. tdastro) automatically before executing each cell,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import sncosmo
import tdastro

# Append the path to the test directory so we can import run_snia_end2end.
# Only append it once, so re-running this cell does not keep growing sys.path
# with duplicate entries.
test_path = tdastro._TDASTRO_TEST_DIR
test_dir_str = str(test_path.resolve())
if test_dir_str not in sys.path:
    sys.path.append(test_dir_str)

from sources.test_snia import run_snia_end2end

### Create the data we will use for this test

Load a sample OpSim database (`opsim_shorten.db`) from the tdastro test data directory, then use `oversample_opsim()` to resample it at a fixed cadence of 0.01 days over MJD 61406.0–61771.0 (one year).

In [None]:
from tdastro.astro_utils.opsim import OpSim, oversample_opsim

# Base survey: the shortened OpSim database shipped with the tdastro test data.
opsim_name = os.path.join(tdastro._TDASTRO_TEST_DATA_DIR, "opsim_shorten.db")
base_opsim = OpSim.from_db(opsim_name)

# Oversampling settings: one pointing at (0.0, 0.0) with search_radius=180.0
# (presumably degrees, i.e. the whole sky — TODO confirm), a 0.01-day cadence,
# and a one-year MJD window. bands=None leaves band selection to oversample_opsim.
oversample_kwargs = {
    "pointing": (0.0, 0.0),
    "search_radius": 180.0,
    "delta_t": 0.01,
    "time_range": (61406.0, 61771.0),
    "bands": None,
    "strategy": "darkest_sky",
}
oversampled_observations = oversample_opsim(base_opsim, **oversample_kwargs)

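### Define the source and observations

The SN Ia source model and the observation setup are configured inside `run_snia_end2end()` (imported above from `sources/test_snia.py` in the tdastro test directory); this notebook only supplies the oversampled observations and the passband directory.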

### Run the test 
Run the end-to-end test with `run_snia_end2end()` to generate 20 samples.

TODO: Time this cell with `%timeit` or a profiler, and repeat the measurement with different `nsample` values.

In [None]:
# Number of SN Ia realizations to simulate.
n_samples = 20

passbands_dir = os.path.join(tdastro._TDASTRO_TEST_DATA_DIR, "passbands")
res, passbands = run_snia_end2end(oversampled_observations, passbands_dir=passbands_dir, nsample=n_samples)

print(f"Produced {len(res)} samples.")

### Examine the model parameters

In [None]:
def _param_column(name):
    """Return the list of true values of parameter ``name``, one per sample in ``res``."""
    return [sample["parameter_values"][name] for sample in res]


hostmass = _param_column("hostmass")
x1 = _param_column("x1")
x0 = _param_column("x0")
c = _param_column("c")
z = _param_column("redshift")
t0 = _param_column("t0")
distmod = _param_column("distmod")

In [None]:
# Distribution of the simulated host masses.
plt.hist(hostmass)
plt.xlabel("hostmass")
plt.ylabel("Number of samples")
plt.show()

In [None]:
# Distribution of the simulated SALT x1 values and their spread.
plt.hist(x1)
plt.xlabel("x1")
plt.ylabel("Number of samples")
plt.show()
print(np.std(x1))

In [None]:
# Distribution of the simulated SALT color parameter c.
plt.hist(c)
plt.xlabel("c")
plt.ylabel("Number of samples")
plt.show()

In [None]:
# Correlation between host mass and x1 in the simulated population.
plt.scatter(hostmass, x1)
plt.xlabel("hostmass")
plt.ylabel("x1")
plt.show()

In [None]:
# Distribution of the simulated SALT amplitudes x0.
plt.hist(x0)
plt.xlabel("x0")
plt.ylabel("Number of samples")
plt.show()

In [None]:
# Distribution of the simulated redshifts.
plt.hist(z)
plt.xlabel("redshift")
plt.ylabel("Number of samples")
plt.show()

### Examine the physical effects

In [None]:
# Check that the simulated SALT amplitudes follow the standardization relation.
# NOTE(review): assumes the simulation sets
#     x0 = 10 ** (-0.4 * (distmod - alpha * x1 + beta * c + m_abs))
# so that
#     mb = -2.5 * log10(x0) = distmod - alpha * x1 + beta * c + m_abs
# — confirm against the parameter definitions used by run_snia_end2end.
# distmod is taken from the simulated parameters, so no cosmology is recomputed here.
mb = -2.5 * np.log10(x0)
print(np.std(mb - distmod))

fig, ax = plt.subplots()
ax.scatter(z, mb - distmod)
ax.set(xlabel="redshift", ylabel="mb - distmod", title="Unstandardized residuals")
plt.show()

# Standardization coefficients and SN Ia absolute magnitude used to invert the
# relation above. NOTE(review): these should match the values the samples were
# generated with — confirm.
alpha = 0.14
beta = 3.1
m_abs = -19.3
# Inverting the relation: mu = mb + alpha * x1 - beta * c - m_abs.
mu = mb + alpha * np.array(x1) - beta * np.array(c) - m_abs
print(np.std(mu - distmod))

fig, ax = plt.subplots()
ax.scatter(z, mu - distmod)
ax.set(xlabel="redshift", ylabel="mu - distmod", title="Standardized residuals")
plt.show()

In [None]:
# Apparent amplitude-based magnitude and standardized distance modulus vs redshift.
fig, ax = plt.subplots()
ax.scatter(z, mb)
ax.set(xlabel="redshift", ylabel="mb = -2.5 log10(x0)")
plt.show()

fig, ax = plt.subplots()
ax.scatter(z, mu)
ax.set(xlabel="redshift", ylabel="standardized mu")
plt.show()

In [None]:
# Standardized distance residuals vs host mass (any trend would indicate a
# host-mass dependence not captured by the alpha/beta standardization).
fig, ax = plt.subplots()
ax.scatter(hostmass, mu - distmod)
ax.set(xlabel="hostmass", ylabel="mu - distmod")
plt.show()

### Examine the source model

In [None]:
# Compare the spectrum stored for each of the first few samples with sncosmo's
# SALT3 model evaluated at the same true parameters. Rest-frame wavelengths are
# shifted to the observer frame with (1 + z).
N_SED_SAMPLES = 3
for i in range(min(N_SED_SAMPLES, len(res))):
    try:
        wave = res[i]["wavelengths_rest"] * (1 + z[i])
        sim_flux = res[i]["flux_flam"][0]
    except (KeyError, IndexError) as err:
        # Best-effort: skip samples without a stored spectrum, but report why
        # instead of silently swallowing every exception.
        print(f"Skipping sample {i}: {err!r}")
        continue

    saltpars = {"x0": x0[i], "x1": x1[i], "c": c[i], "z": z[i], "t0": t0[i]}
    model = sncosmo.Model("salt3")
    model.update(saltpars)
    print(saltpars)
    print(model.parameters)
    print(res[i]["times"] - t0[i])

    # NOTE(review): assumes row 0 of flux_flam corresponds to times[0] — confirm.
    fig, ax = plt.subplots()
    ax.plot(wave, sim_flux, color="r", label="TDAstro")
    ax.plot(wave, model.flux(res[i]["times"][0], wave), color="g", label="sncosmo SALT3")
    ax.set(xlabel="Observer-frame wavelength (Angstrom)", ylabel="flux_flam", title=f"Sample {i}")
    ax.legend()
    plt.show()

### Examine the light curves

In [None]:
# Compare TDAstro band fluxes with sncosmo's SALT3 band fluxes for the first few
# samples. sncosmo's bandflux(zp=..., zpsys="ab") returns flux scaled so that
# m = -2.5 * log10(flux) + zp. The AB zero point for Jy is 8.9, and nJy = 1e-9 Jy
# adds 2.5 * 9 mag, so this zero point yields fluxes in nJy.
AB_ZP_NJY = 8.9 + 2.5 * 9
N_LC_SAMPLES = 3
band_colors = {"r": "red", "i": "brown"}

# The sncosmo bandpasses do not depend on the sample, so build them once.
sncosmo_bands = {}
for f in band_colors:
    band_name = f"LSST_{f}"
    sncosmo_bands[band_name] = sncosmo.Bandpass(
        *passbands.passbands[band_name].processed_transmission_table.T, name=band_name
    )

for i in range(min(N_LC_SAMPLES, len(res))):
    times = res[i]["times"]

    # The SALT3 model depends only on the sample, not on the band.
    saltpars = {"x0": x0[i], "x1": x1[i], "c": c[i], "z": z[i], "t0": t0[i]}
    model = sncosmo.Model("salt3")
    model.update(saltpars)
    print(saltpars)

    fig, ax = plt.subplots()
    for f, color in band_colors.items():
        band_name = f"LSST_{f}"
        ax.plot(
            times,
            res[i]["bandfluxes"][band_name],
            "-",
            label=f"{f} (TDAstro)",
            color=color,
            alpha=0.6,
            lw=2,
        )
        flux = model.bandflux(sncosmo_bands[band_name], times, zpsys="ab", zp=AB_ZP_NJY)
        ax.plot(times, flux, "--", label=f"{f} (sncosmo)", color=color)
    ax.set(xlabel="MJD", ylabel="Flux, nJy", title=f"Sample {i}")
    ax.legend()
    plt.show()

TODO:
- We now have LSST-like light curves (LCs) together with their true parameters.
- Generate idealized LCs with the same true parameters and no observational effects (e.g. from an extremely high-cadence OpSim).
- Fit each set with `scipy.optimize`, then plot the fitted-minus-true differences. These should be unbiased for the idealized data and noisier for the realistic data.

TODO: Fit the population parameters from the ensemble of noisy light curves.