Skip to content

Commit

Permalink
Merge pull request #202 from nikhil-sarin/newheatingratemodel
Browse files Browse the repository at this point in the history
Misc and get set for v1.02
  • Loading branch information
nikhil-sarin committed Mar 19, 2024
2 parents ef70edc + b7ad7e1 commit 898317f
Show file tree
Hide file tree
Showing 18 changed files with 882 additions and 53 deletions.
43 changes: 43 additions & 0 deletions examples/calling_a_model_with_inbuilt_constraints.py
@@ -0,0 +1,43 @@
import bilby.core.prior
from bilby.core.prior import PriorDict, Uniform, Constraint
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import redback
from redback.constraints import csm_constraints

model = 'csm_interaction'
model_priors = redback.priors.get_priors(model=model)
# model = 'csm_interaction_bolometric'
function = redback.model_library.all_models_dict[model]

priors = bilby.core.prior.PriorDict(conversion_function=csm_constraints)
priors.update(model_priors)
# priors['photosphere_constraint_1'] = Constraint(0, 1)
# priors['photosphere_constraint_2'] = Constraint(0, 1)
priors['csm_mass'] = 58.0
priors['mej'] = 46
priors['vej'] = 5500
priors['r0'] = 617
priors['nn'] = 8.8
priors['delta'] = 0.
priors['rho'] = 19
priors['eta'] = 2
priors['redshift'] = 0.16
samples = pd.DataFrame(priors.sample(10))
time = np.geomspace(0.01, 500, 500)
redshift = 0.01


for x in range(10):
kwargs = samples.iloc[x]
kwargs['output_format'] = 'magnitude'
kwargs['bands'] = ['lsstg']
# kwargs['interaction_process'] = None
mag = function(time, **kwargs)
print(mag)
# plt.loglog(time, mag)
plt.plot(time, mag)
plt.gca().invert_yaxis()
plt.show()
4 changes: 4 additions & 0 deletions redback/__init__.py
Expand Up @@ -2,3 +2,7 @@
transient_models, utils, photosphere, sed, interaction_processes, constraints, plotting, model_library, simulate_transients
from redback.transient import afterglow, kilonova, prompt, supernova, tde
from redback.sampler import fit_model
from redback.utils import setup_logger

__version__ = "1.0.1"
setup_logger(log_level='info')
228 changes: 207 additions & 21 deletions redback/likelihoods.py

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions redback/plotting.py
Expand Up @@ -346,7 +346,7 @@ def _plot_lightcurves(self, axes: matplotlib.axes.Axes, times: np.ndarray) -> No
axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth,
zorder=self.zorder)
elif self.uncertainty_mode == "credible_intervals":
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list)
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list, interval=self.credible_interval_level)
axes.fill_between(
times, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

Expand Down Expand Up @@ -391,11 +391,11 @@ class LuminosityPlotter(IntegratedFluxPlotter):

class MagnitudePlotter(Plotter):

xlim_low_phase_model_multiplier = 0.9
xlim_high_phase_model_multiplier = 1.1
xlim_high_multiplier = 1.2
ylim_low_magnitude_multiplier = 0.8
ylim_high_magnitude_multiplier = 1.2
xlim_low_phase_model_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
xlim_high_phase_model_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.1)
xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.2)
ylim_low_magnitude_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.8)
ylim_high_magnitude_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.2)
ncols = KwargsAccessorWithDefault("ncols", 2)

@property
Expand Down Expand Up @@ -635,11 +635,11 @@ def plot_lightcurve(
elif self.uncertainty_mode == "credible_intervals":
if band in self.band_scaling:
if self.band_scaling.get("type") == 'x':
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) * self.band_scaling.get(band))
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) * self.band_scaling.get(band), interval=self.credible_interval_level)
elif self.band_scaling.get("type") == '+':
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) + self.band_scaling.get(band))
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list) + self.band_scaling.get(band), interval=self.credible_interval_level)
else:
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list))
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=np.array(random_ys_list), interval=self.credible_interval_level)
axes.fill_between(
times - self._reference_mjd_date, lower_bound, upper_bound,
alpha=self.uncertainty_band_alpha, color=color_sample)
Expand Down Expand Up @@ -789,7 +789,7 @@ def plot_multiband_lightcurve(
axes[ii].plot(times - self._reference_mjd_date, random_ys, color=color_sample,
alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder)
elif self.uncertainty_mode == "credible_intervals":
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list)
lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list, interval=self.credible_interval_level)
axes[ii].fill_between(
times - self._reference_mjd_date, lower_bound, upper_bound,
alpha=self.uncertainty_band_alpha, color=color_sample)
Expand Down
11 changes: 11 additions & 0 deletions redback/priors.py
Expand Up @@ -9,6 +9,17 @@


def get_priors(model, times=None, y=None, yerr=None, dt=None, **kwargs):
"""
Get the prior for the given model. If the model is a prompt model, the times, y, and yerr must be provided.
:param model: String referring to a name of a model implemented in Redback.
:param times: Time array
:param y: Y values, arbitrary units
:param yerr: Error on y values, arbitrary units
:param dt: time interval
:param kwargs: Extra arguments to be passed to the prior function
:return: priors: PriorDict object
"""
prompt_prior_functions = dict(gaussian=get_gaussian_priors, skew_gaussian=get_skew_gaussian_priors,
skew_exponential=get_skew_exponential_priors, fred=get_fred_priors,
fred_extended=get_fred_extended_priors)
Expand Down
5 changes: 5 additions & 0 deletions redback/priors/one_comp_kne_rosswog_heatingrate.prior
@@ -0,0 +1,5 @@
redshift = Uniform(1e-6, 0.1, 'redshift', latex_label = r'$z$')
mej = Uniform(1e-2, 0.05, 'mej', latex_label = r'$M_{\mathrm{ej} }~(M_\odot)$')
vej = Uniform(0.05, 0.3, 'vej', latex_label = r'$v_{\mathrm{ej}}~(c)$')
ye = Uniform(0.05, 0.4, 'ye', latex_label = r'$Y_{e}$')
temperature_floor = LogUniform(100, 6000, 'temperature_floor', latex_label = r'$T_{\mathrm{floor}}~(\mathrm{K})$')
9 changes: 9 additions & 0 deletions redback/priors/two_comp_kne_rosswog_heatingrate.prior
@@ -0,0 +1,9 @@
redshift = Uniform(1e-6, 0.1, 'redshift', latex_label = r'$z$')
mej_1 = Uniform(1e-2, 0.05, 'mej', latex_label = r'$M_{\mathrm{ej}~1}~(M_\odot)$')
vej_1 = Uniform(0.05, 0.3, 'vej', latex_label = r'$v_{\mathrm{ej}~1}~(c)$')
ye_1 = Uniform(0.05, 0.4, 'ye', latex_label = r'$Y_{e}~1$')
temperature_floor_1 = LogUniform(100, 6000, 'temperature_floor', latex_label = r'$T_{\mathrm{floor}~1}~(\mathrm{K})$')
mej_2 = Uniform(1e-2, 0.05, 'mej', latex_label = r'$M_{\mathrm{ej}~2}~(M_\odot)$')
vej_2 = Uniform(0.05, 0.3, 'vej', latex_label = r'$v_{\mathrm{ej}~2}~(c)$')
ye_2 = Uniform(0.05, 0.4, 'ye', latex_label = r'$Y_{e}~2$')
temperature_floor_2 = LogUniform(100, 6000, 'temperature_floor', latex_label = r'$T_{\mathrm{floor}~2}~(\mathrm{K})$')
5 changes: 5 additions & 0 deletions redback/sed.py
Expand Up @@ -304,6 +304,11 @@ def get_correct_output_format_from_spectra(time, time_eval, spectra, lambda_arra
:param output_format: 'flux', 'magnitude', 'sncosmo_source', 'flux_density'
:return: flux, magnitude or SNcosmo TimeSeries Source depending on output format kwarg
"""
# clean up spectrum to remove nonsensical values before creating sncosmo source
spectra = np.nan_to_num(spectra)
spectra[spectra.value == np.nan_to_num(np.inf)] = 1e-30 * np.mean(spectra[10])
spectra[spectra.value == 0.] = 1e-30 * np.mean(spectra[10])

source = TimeSeriesSource(phase=time_eval, wave=lambda_array, flux=spectra)
if kwargs['output_format'] == 'flux':
bands = kwargs['bands']
Expand Down
33 changes: 28 additions & 5 deletions redback/simulate_transients.py
Expand Up @@ -14,7 +14,7 @@

class SimulateGenericTransient(object):
def __init__(self, model, parameters, times, model_kwargs, data_points,
seed=1234, multiwavelength_transient=False, noise_term=0.2):
seed=1234, multiwavelength_transient=False, noise_term=0.2, noise_type='gaussianmodel', extra_scatter=0.0):
"""
A generic interface to simulating transients
Expand All @@ -31,7 +31,12 @@ def __init__(self, model, parameters, times, model_kwargs, data_points,
and the data points are sampled in bands/frequency as well,
rather than just corresponding to one wavelength/filter.
This also allows the same time value to be sampled multiple times.
:param noise_term: Float. Factor which is multiplied by the model flux/magnitude to give the sigma.
:param noise_type: String. Type of noise to add to the model.
Default is 'gaussianmodel' where sigma is noise_term * model.
Another option is 'gaussian' i.e., a simple Gaussian noise with sigma = noise_term.
:param noise_term: Float. Factor which is multiplied by the model flux/magnitude to give the sigma
or is sigma itself for 'gaussian' noise.
:param extra_scatter: Float. Sigma of normal added to output for additional scatter.
"""
self.model = redback.model_library.all_models_dict[model]
self.parameters = parameters
Expand Down Expand Up @@ -81,8 +86,26 @@ def __init__(self, model, parameters, times, model_kwargs, data_points,
if 'frequency' in model_kwargs.keys():
data['frequency'] = self.subset_frequency
data['true_output'] = true_output
data['output'] = np.random.normal(true_output, self.noise_term * true_output)
data['output_error'] = self.noise_term * true_output

if noise_type == 'gaussianmodel':
noise = np.random.normal(0, self.noise_term * true_output, len(true_output))
output = true_output + noise
output_error = self.noise_term * true_output
elif noise_type == 'gaussian':
noise = np.random.normal(0, self.noise_term, len(true_output))
output = true_output + noise
output_error = self.noise_term
else:
logger.warning(f"noise_type {noise_type} not implemented.")
raise ValueError('noise_type must be either gaussianmodel or gaussian')

if extra_scatter > 0:
extra_noise = np.random.normal(0, extra_scatter, len(true_output))
output = output + extra_noise
output_error = np.sqrt(output_error**2 + extra_noise**2)

data['output'] = output
data['output_error'] = output_error
self.data = data

def save_transient(self, name):
Expand All @@ -98,7 +121,7 @@ def save_transient(self, name):
path = 'simulated/' + name + '.csv'
injection_path = 'simulated/' + name + '_injection_parameters.csv'
self.data.to_csv(path, index=False)
self.parameters=pd.DataFrame.from_dict(self.parameters)
self.parameters=pd.DataFrame.from_dict([self.parameters])
self.parameters.to_csv(injection_path, index=False)

class SimulateOpticalTransient(object):
Expand Down
5 changes: 3 additions & 2 deletions redback/transient/afterglow.py
Expand Up @@ -57,7 +57,7 @@ def __init__(
:type flux: np.ndarray, optional
:type flux_err: np.ndarray, optional
:param flux_err: Flux error values.
:param flux_density:Flux density values.
:param flux_density: Flux density values.
:type flux_density: np.ndarray, optional
:param flux_density_err: Flux density error values.
:type flux_density_err: np.ndarray, optional
Expand Down Expand Up @@ -244,7 +244,8 @@ def _set_data(self) -> None:
'BAT Photon Index (15-150 keV) (PL = simple power-law, CPL = cutoff power-law)'].fillna(0)
self.meta_data = meta_data
except FileNotFoundError:
logger.warning("Meta data does not exist for this event.")
logger.info("Metadata does not exist for this event.")
logger.info("Setting metadata to None. This is not an error, but a warning that no metadata could be found online.")
self.meta_data = None

def _set_photon_index(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion redback/transient/kilonova.py
Expand Up @@ -50,7 +50,7 @@ def __init__(
:type flux: np.ndarray, optional
:type flux_err: np.ndarray, optional
:param flux_err: Flux error values.
:param flux_density:Flux density values.
:param flux_density: Flux density values.
:type flux_density: np.ndarray, optional
:param flux_density_err: Flux density error values.
:type flux_density_err: np.ndarray, optional
Expand Down
2 changes: 1 addition & 1 deletion redback/transient/supernova.py
Expand Up @@ -45,7 +45,7 @@ def __init__(
:type flux: np.ndarray, optional
:type flux_err: np.ndarray, optional
:param flux_err: Flux error values.
:param flux_density:Flux density values.
:param flux_density: Flux density values.
:type flux_density: np.ndarray, optional
:param flux_density_err: Flux density error values.
:type flux_density_err: np.ndarray, optional
Expand Down
2 changes: 1 addition & 1 deletion redback/transient/tde.py
Expand Up @@ -42,7 +42,7 @@ def __init__(
:type flux: np.ndarray, optional
:type flux_err: np.ndarray, optional
:param flux_err: Flux error values.
:param flux_density:Flux density values.
:param flux_density: Flux density values.
:type flux_density: np.ndarray, optional
:param flux_density_err: Flux density error values.
:type flux_density_err: np.ndarray, optional
Expand Down
4 changes: 2 additions & 2 deletions redback/transient/transient.py
Expand Up @@ -137,7 +137,7 @@ def __init__(
self.system = system
self.data_mode = data_mode
self.active_bands = active_bands
self.sncosmo_bands = redback.utils.sncosmo_bandname_from_band(self.bands, warning_style='soft')
self.sncosmo_bands = redback.utils.sncosmo_bandname_from_band(self.bands)
self.redshift = redshift
self.name = name
self.use_phase_model = use_phase_model
Expand Down Expand Up @@ -906,7 +906,7 @@ def _set_data(self) -> None:
meta_data = pd.read_csv(self.event_table, on_bad_lines='skip', delimiter=',', dtype='str')
except FileNotFoundError as e:
redback.utils.logger.warning(e)
redback.utils.logger.warning("Setting metadata to None")
redback.utils.logger.warning("Setting metadata to None. This is not an error, but a warning that no metadata could be found online.")
meta_data = None
self.meta_data = meta_data

Expand Down

0 comments on commit 898317f

Please sign in to comment.