# Figures for SNIa NIR Space Astro 2020 Paper
Michael Wood-Vasey

Figures to make:
1. SN Ia SED with UV, optical, and NIR filters overlaid, at z = 0, 0.2, 0.5, 1.0, 1.5.
2. Cosmological constraints from reduced intrinsic scatter.
3. Constraints on dust extinction from UV+optical+NIR data.

In [None]:
from copy import copy

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import numpy as np

In [None]:
import sncosmo
from astropy.cosmology import Planck15
import astropy.units as u
import astropy.constants as const

In [None]:
# Build the sncosmo model from its built-in 'hsiao' SN Ia source; this `model`
# is the one passed to all of the plotting cells below.
model = sncosmo.Model('hsiao')
# Wavelength grid spanning the model's full range in 10 Angstrom steps.
# NOTE(review): `wavelength` and `phases` are not used by the cells visible here
# (the plotting functions build their own grids and take their own phases).
wavelength = np.arange(model.minwave(), model.maxwave(), 10)  # Angstroms
phases = np.array([-19, -10, 0, 10])

In [None]:
# Display the model's parameter names (notebook cell output).
model.param_names

In [None]:
# NOTE(review): this expression only displays the bound method object; it does
# not call it. Presumably intended as `model.update?` (help) or a call -- confirm.
model.update

In [None]:
# Redshift of an object at 10 pc under the linear Hubble law z = H0 * d / c,
# reduced to a plain dimensionless quantity. At this distance the distance
# modulus should come out as ~0 mag.
z_0 = (Planck15.H0 * (10 * u.pc) / const.c).decompose()

In [None]:
# Planck15 distance modulus at z_0 (the redshift of 10 pc); a sanity check
# that it is ~0 mag.
mu = Planck15.distmod(z_0)

In [None]:
# Display mu with the magnitude unit divided out.
mu / u.mag

In [None]:
def plot_model(untouch_model, ax, phases=(-10, 0, 10, 20, 50), redshifts=(0,), scaling=None,
               wavelength_plot_buffer=(.8, .3)):
    """Plot the model flux density at each requested phase and redshift on ``ax``.

    Parameters
    ----------
    untouch_model : sncosmo.Model
        Model to plot. It is copied first, so the ``set(z=...)`` calls below do
        not change the caller's model.
    ax : matplotlib Axes
        Axes to draw on.
    phases : sequence of float
        Values passed as ``time`` to ``model.flux``. NOTE(review): sncosmo
        treats ``time`` as observer-frame time (rest-frame phase is
        (time - t0) / (1 + z)), so for z > 0 these are not rest-frame phases.
        Confirm this is intended.
    redshifts : sequence of float
        Redshifts at which the model is evaluated.
    scaling : sequence of float, optional
        One multiplicative flux factor per redshift, paired with ``redshifts``
        by ``zip``. If None, uses 10**(-0.4 * mu), where mu is the Planck15
        distance modulus at each redshift.
    wavelength_plot_buffer : (float, float)
        Padding in microns added to the left and right of the plotted
        wavelength range when setting the x limits.
    """
    model = copy(untouch_model)

    # Only label redshift or phase if it's varying.
    label_entries = []
    if len(phases) > 1:
        label_entries.append('{day:} days')
    if len(redshifts) > 1:
        label_entries.append('z={redshift:}')
    label_format_string = ','.join(label_entries)

    # The distance scaling depends only on the redshifts, so compute it once
    # instead of on every pass through the phase loop.
    if scaling is None:
        # Offset by the redshift of 10 pc (~2.3e-9 for Planck15) so that z=0
        # does not give a zero luminosity distance and a divergent modulus.
        z_eps = (Planck15.H0 * (10 * u.pc) / const.c).decompose()
        mu = Planck15.distmod(np.asarray(redshifts) + z_eps)
        scaling = 10**(-0.4*(mu/u.mag)).value

    # Track the span of every plotted curve so that the x limits cover all of
    # them, not just the last one (which clips when redshifts are unsorted).
    min_um, max_um = np.inf, -np.inf
    for d in phases:
        for z, s in zip(redshifts, scaling):
            model.set(z=z)
            this_wavelength = np.arange(model.minwave(), model.maxwave(), 1)  # Angstroms
            this_wavelength_um = this_wavelength * 1e-4
            # Rescale by the distance factor for this redshift.
            flux = model.flux(time=d, wave=this_wavelength) * s
            ax.plot(this_wavelength_um, flux,
                    label=label_format_string.format(redshift=z, day=d))
            min_um = min(min_um, this_wavelength_um[0])
            max_um = max(max_um, this_wavelength_um[-1])

    # Skip the limits entirely when nothing was plotted (empty phases or redshifts).
    if np.isfinite(min_um):
        ax.set_xlim(min_um - wavelength_plot_buffer[0],
                    max_um + wavelength_plot_buffer[1])
    ax.set_xlabel(r'Wavelength [$\mu$m]')
    ax.set_ylabel(r'flux density [d flux/d$\lambda$]')
    ax.legend(loc='upper left')

In [None]:
def plot_cumulative_flux(untouch_model, ax, phases=(-10, 0, 10, 20, 50), redshifts=(0,), scaling=None,
                         wavelength_plot_buffer=(.8, .3)):
    """Plot the normalized cumulative flux versus wavelength for each phase and redshift.

    Each curve is the running sum of the model flux over a 1 Angstrom grid,
    divided by its total, so it rises from ~0 to 1.

    Parameters
    ----------
    untouch_model : sncosmo.Model
        Model to plot. It is copied first, so the caller's model is not changed.
    ax : matplotlib Axes
        Axes to draw on.
    phases : sequence of float
        Values passed as ``time`` to ``model.flux`` (observer-frame time in
        sncosmo). NOTE(review): confirm that observer-frame time is intended.
    redshifts : sequence of float
        Redshifts at which the model is evaluated.
    scaling : ignored
        Not used, because each curve is normalized to its own total. It is
        accepted so that the same keyword set used with ``plot_model`` can be
        passed here.
    wavelength_plot_buffer : (float, float)
        Padding in microns added to the left and right of the plotted
        wavelength range when setting the x limits.
    """
    model = copy(untouch_model)

    # Only label redshift or phase if it's varying.
    label_entries = []
    if len(phases) > 1:
        label_entries.append('{day:} days')
    if len(redshifts) > 1:
        label_entries.append('z={redshift:}')
    label_format_string = ','.join(label_entries)

    # Track the span of every curve so the x limits cover all of them,
    # not only the last one plotted.
    min_um, max_um = np.inf, -np.inf
    for d in phases:
        for z in redshifts:
            model.set(z=z)
            this_wavelength = np.arange(model.minwave(), model.maxwave(), 1)  # Angstroms
            this_wavelength_um = this_wavelength * 1e-4
            flux = model.flux(time=d, wave=this_wavelength)
            cdfflux = np.cumsum(flux)
            total = cdfflux[-1]
            # Avoid 0/0 -> NaN when the model has no flux at this time.
            if total != 0:
                cdfflux = cdfflux / total
            ax.plot(this_wavelength_um, cdfflux,
                    label=label_format_string.format(redshift=z, day=d))
            min_um = min(min_um, this_wavelength_um[0])
            max_um = max(max_um, this_wavelength_um[-1])

    # Skip the limits entirely when nothing was plotted (empty phases or redshifts).
    if np.isfinite(min_um):
        ax.set_xlim(min_um - wavelength_plot_buffer[0],
                    max_um + wavelength_plot_buffer[1])
    ax.set_xlabel(r'Wavelength [$\mu$m]')
    ax.set_ylabel(r'Cumulative flux')
    ax.legend(loc='upper left')

In [None]:
def overlay_bandpasses(ax,
                       bands=('lsstu', 'lsstg', 'lsstr', 'lssti', 'lsstz', 'lssty',
                              'cspjs', 'csphs', 'cspk',
                              'f277w', 'f356w', 'f444w'),
                       wavelength_resolution=1,
                       plot_bandpass_range=True, plot_bandpass_transmission=False,
                       legend=True,
                       loc='upper right',
                       threshold=0.5,
                       ):
    """Overlay bandpass curves on a twinx of the given axis.

    All bands are drawn on one new twin y-axis (transmission scale), sharing the
    x-axis of ``ax``. The x values are in microns (Angstrom * 1e-4).

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to overlay; ``ax.twinx()`` is called to create the twin axis.
    bands : sequence of str
        sncosmo bandpass names, looked up with ``sncosmo.get_bandpass``.
    wavelength_resolution : float
        Step [Angstrom] of the grid on which each bandpass is sampled.
    plot_bandpass_range : bool
        Draw a dashed box (0 -> 1 -> 0) spanning the wavelengths where the
        transmission exceeds ``threshold`` times the band's peak.
    plot_bandpass_transmission : bool
        Draw the sampled transmission curve (dotted).
    legend : bool
        Add a legend to the twin axis.
    loc : str
        Legend location.
    threshold : float
        Fraction of the peak transmission that defines the plotted range
        (default 0.5, i.e. the half-maximum width).
    """
    short_labels = {'lsstu': 'u', 'lsstg': 'g', 'lsstr': 'r', 'lssti': 'i', 'lsstz': 'z',
                    'lssty': 'y', 'cspjs': 'J', 'csphs': 'H', 'cspk': 'K',
                    'f277w': 'f277w', 'f356w': 'f356w', 'f444w': 'f444w'}

    ax2 = ax.twinx()
    for b in bands:
        band = sncosmo.get_bandpass(b)
        wavelength = np.arange(band.minwave(), band.maxwave(), wavelength_resolution)
        wavelength_um = wavelength * 1e-4
        transmission = band(wavelength)
        # Fall back to the sncosmo name for bands without a short label.
        label = short_labels.get(b, b)
        range_style = {}

        if plot_bandpass_transmission:
            line, = ax2.plot(wavelength_um, transmission, label=label, linestyle=':')
            # Draw this band's range box in the same colour and keep it out of
            # the legend, so each band appears once with one colour.
            range_style['color'] = line.get_color()
            label = '_nolegend_'

        if plot_bandpass_range:
            # Indices above threshold * peak; assumes `wavelength` is
            # increasing, so the first and last indices bound the range.
            above = np.flatnonzero(transmission > threshold * np.max(transmission))
            if above.size == 0:
                # Zero or NaN transmission, or threshold >= 1: no range to draw.
                continue
            minwave, maxwave = wavelength[above[0]], wavelength[above[-1]]
            box_um = np.array([minwave, minwave, maxwave, maxwave]) * 1e-4
            ax2.plot(box_um, [0, 1, 1, 0], label=label, linestyle='--', **range_style)

    if legend:
        ax2.legend(loc=loc)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Same default model curves in both panels; the right-hand panel uses a log
# flux axis.
plot_model(model, axes[0])
plot_model(model, axes[1])
axes[1].set_yscale('log')

# Left panel: transmission curves and 50% ranges, drawn by two separate calls.
# Each call creates its own twin axis, and a new axis starts its own colour
# cycle (presumably why they are separate calls: matching colours per band).
overlay_bandpasses(axes[0], plot_bandpass_range=False, plot_bandpass_transmission=True)
overlay_bandpasses(axes[0], plot_bandpass_range=True, plot_bandpass_transmission=False)

# Right panel: ranges only (default options).
overlay_bandpasses(axes[1])

In [None]:
# One-panel figure: the model at the default phases and z=0, with a log flux
# axis and bandpass ranges overlaid.
fig, axes = plt.subplots(figsize=(6, 4))

plot_model(model, axes)
axes.set_yscale('log')
overlay_bandpasses(axes)

plt.savefig('SNIa_phases_restframe.pdf')

In [None]:
fig = plt.figure(figsize=(6, 4))
ax = fig.gca()
# time = -10 days at each redshift. The scaling values are fixed decade
# offsets, not the Planck15 dimming plot_model would otherwise compute
# (presumably chosen to separate the curves vertically).
kwargs = dict(
    phases=[-10],
    redshifts=[0, 0.2, 0.5, 1.0, 1.5],
    scaling=[1e0, 1e-1, 1e-2, 1e-3, 1e-4],
)

plot_model(model, ax, **kwargs)

ax.set_yscale('log')
ax.set_ylim(1e-15, 1.5e-8)

overlay_bandpasses(ax)

In [None]:
fig = plt.figure(figsize=(6, 4))
ax = fig.gca()
# time = 0 (maximum light) at each redshift, with fixed decade offsets between
# the curves instead of distance-modulus dimming.
kwargs = dict(
    phases=[0],
    redshifts=[0, 0.2, 0.5, 1.0, 1.5],
    scaling=[1e0, 1e-1, 1e-2, 1e-3, 1e-4],
)

plot_model(model, ax, **kwargs)

ax.set_yscale('log')
ax.set_ylim(1e-15, 1.5e-8)

overlay_bandpasses(ax)
plt.savefig('SNIa_at_max_over_redshift.pdf')

In [None]:
fig = plt.figure(figsize=(6, 4))
ax = fig.gca()
# time = +20 days at each redshift, with fixed decade offsets between the
# curves instead of distance-modulus dimming.
kwargs = dict(
    phases=[20],
    redshifts=[0, 0.2, 0.5, 1.0, 1.5],
    scaling=[1e0, 1e-1, 1e-2, 1e-3, 1e-4],
)

plot_model(model, ax, **kwargs)

ax.set_yscale('log')
ax.set_ylim(1e-15, 1.5e-8)

overlay_bandpasses(ax)

In [None]:
fig = plt.figure(figsize=(6, 4))
ax = fig.gca()
# Normalized cumulative flux at time = +20 days for each redshift.
# plot_cumulative_flux does not use `scaling` (each curve is normalized); it is
# presumably kept so the kwargs match the SED cells above.
kwargs = dict(
    phases=[20],
    redshifts=[0, 0.2, 0.5, 1.0, 1.5],
    scaling=[1e0, 1e-1, 1e-2, 1e-3, 1e-4],
)

plot_cumulative_flux(model, ax, **kwargs)

# Bottom-right legend, presumably to avoid the curves near the top of the axis.
overlay_bandpasses(ax, loc='lower right')