Skip to content

Commit

Permalink
Merge pull request #105 from nikhil-sarin/new_dev
Browse files Browse the repository at this point in the history
New dev
  • Loading branch information
MoritzThomasHuebner committed Mar 17, 2022
2 parents 0fae104 + ffa7c4d commit 5bb19d3
Show file tree
Hide file tree
Showing 11 changed files with 441 additions and 509 deletions.
4 changes: 2 additions & 2 deletions examples/kilonova_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# creates a GRBDir with GRB
kilonova = redback.kilonova.Kilonova.from_open_access_catalogue(
name=kne, data_mode="flux_density", active_bands=np.array(["g", "i"]))
kilonova.plot_data(plot_show=False)
kilonova.plot_data(show=False)
fig, axes = plt.subplots(3, 2, sharex=True, sharey=True, figsize=(12, 8))
kilonova.plot_multiband(figure=fig, axes=axes, filters=["g", "r", "i", "z", "y", "J"], plot_show=False)
kilonova.plot_multiband(figure=fig, axes=axes, filters=["g", "r", "i", "z", "y", "J"], show=False)

# use default priors
priors = redback.priors.get_priors(model=model)
Expand Down
547 changes: 291 additions & 256 deletions redback/plotting.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions redback/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def plot_lightcurve(self, model: Union[callable, str] = None, **kwargs: dict) ->
"""
if model is None:
model = model_library.all_models_dict[self.model]
self.transient.plot_lightcurve(model=model, posterior=self.posterior, outdir=self.outdir,
self.transient.plot_lightcurve(model=model, posterior=self.posterior,
model_kwargs=self.model_kwargs, **kwargs)

def plot_multiband_lightcurve(self, model: Union[callable, str] = None, **kwargs: dict) -> None:
Expand All @@ -161,7 +161,7 @@ def plot_multiband_lightcurve(self, model: Union[callable, str] = None, **kwargs
if model is None:
model = model_library.all_models_dict[self.model]
self.transient.plot_multiband_lightcurve(
model=model, posterior=self.posterior, outdir=self.outdir, model_kwargs=self.model_kwargs, **kwargs)
model=model, posterior=self.posterior, model_kwargs=self.model_kwargs, **kwargs)

def plot_data(self, **kwargs: dict) -> None:
"""
Expand Down
117 changes: 2 additions & 115 deletions redback/transient/afterglow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


class Afterglow(Transient):

DATA_MODES = ['luminosity', 'flux', 'flux_density', 'magnitude']

def __init__(
Expand Down Expand Up @@ -268,7 +267,7 @@ def _set_data(self) -> None:
try:
meta_data = pd.read_csv(self.event_table, header=0, error_bad_lines=False, delimiter='\t', dtype='str')
meta_data['BAT Photon Index (15-150 keV) (PL = simple power-law, CPL = cutoff power-law)'] = meta_data[
'BAT Photon Index (15-150 keV) (PL = simple power-law, CPL = cutoff power-law)'].fillna(0)
'BAT Photon Index (15-150 keV) (PL = simple power-law, CPL = cutoff power-law)'].fillna(0)
self.meta_data = meta_data
except FileNotFoundError:
logger.warning("Meta data does not exist for this event.")
Expand Down Expand Up @@ -404,116 +403,6 @@ def _convert_flux_to_luminosity(
self.x, self.x_err, self.y, self.y_err = converter.convert_flux_to_luminosity()
self._save_luminosity_data()

def plot_lightcurve(
self, model: callable, filename: str = None, axes: matplotlib.axes.Axes = None, plot_save: bool = True,
plot_show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None, outdir: str = '.',
model_kwargs: dict = None, **kwargs: object) -> None:
if self.flux_data:
plotter = IntegratedFluxPlotter(transient=self)
elif self.luminosity_data:
plotter = LuminosityPlotter(transient=self)
elif self.flux_density_data:
plotter = FluxDensityPlotter(transient=self)
elif self.magnitude_data:
plotter = MagnitudePlotter(transient=self)
else:
return axes
return plotter.plot_lightcurve(
model=model, filename=filename, axes=axes, plot_save=plot_save,
plot_show=plot_show, random_models=random_models, posterior=posterior,
outdir=outdir, model_kwargs=model_kwargs, **kwargs)

def plot_data(self, axes: matplotlib.axes.Axes = None, colour: str = 'k', **kwargs: dict) -> matplotlib.axes.Axes:
"""
Plots the Afterglow lightcurve and returns Axes.
Parameters
----------
axes : Union[matplotlib.axes.Axes, None], optional
Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
colour: str, optional
Colour of the data.
kwargs: dict
Additional keyword arguments to pass in the Plotter methods.
Returns
----------
matplotlib.axes.Axes: The axes with the plot.
"""

if self.flux_data:
plotter = IntegratedFluxPlotter(transient=self)
elif self.luminosity_data:
plotter = LuminosityPlotter(transient=self)
elif self.flux_density_data:
plotter = FluxDensityPlotter(transient=self)
elif self.magnitude_data:
plotter = MagnitudePlotter(transient=self)
else:
return axes
return plotter.plot_data(axes=axes, colour=colour, **kwargs)


def plot_multiband(
self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, ncols: int = 2,
nrows: int = None, figsize: tuple = None, filters: list = None, **plot_kwargs: dict) \
-> matplotlib.axes.Axes:
"""
Parameters
----------
figure: matplotlib.figure.Figure, optional
Figure can be given if defaults are not satisfying
axes: matplotlib.axes.Axes, optional
Axes can be given if defaults are not satisfying
ncols: int, optional
Number of columns to use on the plot. Default is 2.
nrows: int, optional
Number of rows to use on the plot. If None are given this will
be inferred from ncols and the number of filters.
figsize: tuple, optional
Size of the figure. A default based on ncols and nrows will be used if None is given.
filters: list, optional
Which bands to plot. Will use default filters if None is given.
plot_kwargs:
Additional optional plotting kwargs:
wspace: Extra argument for matplotlib.pyplot.subplots_adjust
hspace: Extra argument for matplotlib.pyplot.subplots_adjust
fontsize: Label fontsize
errorbar_fmt: Errorbar format ('fmt' argument in matplotlib.pyplot.errorbar)
colors: colors to be used for the bands
xlabel: Plot xlabel
ylabel: Plot ylabel
plot_label: Addional filename label appended to the default name
Returns
-------
"""
if self.data_mode not in ['flux_density', 'magnitude']:
raise ValueError(
f'You cannot plot multiband data with {self.data_mode} data mode . Why are you doing this?')
if self.magnitude_data:
plotter = MagnitudePlotter(transient=self)
elif self.flux_density_data:
plotter = FluxDensityPlotter(transient=self)
else:
return
return plotter.plot_multiband(
figure=figure, axes=axes, ncols=ncols, nrows=nrows, figsize=figsize, filters=filters, **plot_kwargs)

def plot_multiband_lightcurve(
self, model: callable, filename: str = None, axes: matplotlib.axes.Axes = None, plot_save: bool = True,
plot_show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None, outdir: str = '.',
model_kwargs: dict = None, **kwargs: object) -> None:

if self.data_mode not in ['flux_density', 'magnitude']:
raise ValueError(
f'You cannot plot multiband data with {self.data_mode} data mode . Why are you doing this?')
return super(Afterglow, self).plot_multiband_lightcurve(
model=model, filename=filename, axes=axes, plot_save=plot_save, plot_show=plot_show,
random_models=random_models, posterior=posterior, outdir=outdir, model_kwargs=model_kwargs, **kwargs)


class SGRB(Afterglow):
pass
Expand All @@ -524,7 +413,6 @@ class LGRB(Afterglow):


class Truncator(object):

TRUNCATE_METHODS = ['prompt_time_error', 'left_of_max', 'default']

def __init__(
Expand Down Expand Up @@ -634,7 +522,6 @@ def _truncate_by_index(self, index: Union[int, np.ndarray]) -> tuple:


class FluxToLuminosityConverter(object):

CONVERSION_METHODS = ["analytical", "numerical"]

def __init__(
Expand Down Expand Up @@ -685,7 +572,7 @@ def counts_to_flux_fraction(self) -> float:
-------
float: The counts to flux fraction.
"""
return self.counts_to_flux_unabsorbed/self.counts_to_flux_absorbed
return self.counts_to_flux_unabsorbed / self.counts_to_flux_absorbed

@property
def luminosity_distance(self) -> float:
Expand Down
6 changes: 3 additions & 3 deletions redback/transient/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def plot_data(self, **kwargs: dict) -> None:
plt.clf()

def plot_lightcurve(
self, model: callable, axes: matplotlib.axes.Axes = None, plot_save: bool = True, plot_show: bool = True,
self, model: callable, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True,
random_models: int = 1000, posterior: pd.DataFrame = None, outdir: str = None, **kwargs: dict) -> None:
"""
Expand All @@ -186,9 +186,9 @@ def plot_lightcurve(
The model we are using
axes: matplotlib.axes.Axes, optional
Axes to plot into. Currently a placeholder.
plot_save: bool, option
save: bool, option
Whether to save the plot. Default is `True`. Currently, a placeholder.
plot_show: bool, optional
show: bool, optional
Whether to show the plot. Default is `True`. Currently, a placeholder.
random_models: int, optional
Number of random posterior samples to use for plots. Default is 1000.
Expand Down
73 changes: 34 additions & 39 deletions redback/transient/transient.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def set_bands_and_frequency(
self, bands: Union[None, list, np.ndarray], frequency: Union[None, list, np.ndarray]):
if (bands is None and frequency is None) or (bands is not None and frequency is not None):
self._bands = bands
self._frequency = bands
self._frequency = frequency
elif bands is None and frequency is not None:
self._frequency = frequency
self._bands = self.frequency
Expand Down Expand Up @@ -454,20 +454,21 @@ def get_colors(filters: Union[np.ndarray, list]) -> matplotlib.colors.Colormap:
Returns
-------
matplotlib.colors.Colormap: Colormap with one colour for each filter
matplotlib.colors.Colormap: Colormap with one color for each filter
"""
return matplotlib.cm.rainbow(np.linspace(0, 1, len(filters)))

def plot_data(self, axes: matplotlib.axes.Axes = None, colour: str = 'k', **kwargs: dict) -> matplotlib.axes.Axes:
def plot_data(self, axes: matplotlib.axes.Axes = None, filename: str = None, outdir: str = None, save: bool = True,
show: bool = True, plot_others: bool = True, color: str = 'k', **kwargs: dict) -> matplotlib.axes.Axes:
"""
Plots the Afterglow lightcurve and returns Axes.
Parameters
----------
axes : Union[matplotlib.axes.Axes, None], optional
Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
colour: str, optional
Colour of the data.
color: str, optional
color of the data.
kwargs: dict
Additional keyword arguments to pass in the Plotter methods.
Expand All @@ -477,20 +478,21 @@ def plot_data(self, axes: matplotlib.axes.Axes = None, colour: str = 'k', **kwar
"""

if self.flux_data:
plotter = IntegratedFluxPlotter(transient=self)
plotter = IntegratedFluxPlotter(transient=self, color=color, filename=filename, outdir=outdir, **kwargs)
elif self.luminosity_data:
plotter = LuminosityPlotter(transient=self)
plotter = LuminosityPlotter(transient=self, color=color, filename=filename, outdir=outdir, **kwargs)
elif self.flux_density_data:
plotter = FluxDensityPlotter(transient=self)
plotter = FluxDensityPlotter(transient=self, color=color, filename=filename, outdir=outdir, plot_others=plot_others, **kwargs)
elif self.magnitude_data:
plotter = MagnitudePlotter(transient=self)
plotter = MagnitudePlotter(transient=self, color=color, filename=filename, outdir=outdir, plot_others=plot_others, **kwargs)
else:
return axes
return plotter.plot_data(axes=axes, colour=colour, **kwargs)
return plotter.plot_data(axes=axes, save=save, show=show)

def plot_multiband(
self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, ncols: int = 2,
nrows: int = None, figsize: tuple = None, filters: list = None, **plot_kwargs: dict) \
self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, filename: str = None,
outdir: str = None, ncols: int = 2, save: bool = True, show: bool = True,
nrows: int = None, figsize: tuple = None, filters: list = None, **kwargs: dict) \
-> matplotlib.axes.Axes:
"""
Expand All @@ -509,7 +511,7 @@ def plot_multiband(
Size of the figure. A default based on ncols and nrows will be used if None is given.
filters: list, optional
Which bands to plot. Will use default filters if None is given.
plot_kwargs:
kwargs:
Additional optional plotting kwargs:
wspace: Extra argument for matplotlib.pyplot.subplots_adjust
hspace: Extra argument for matplotlib.pyplot.subplots_adjust
Expand All @@ -528,17 +530,16 @@ def plot_multiband(
raise ValueError(
f'You cannot plot multiband data with {self.data_mode} data mode . Why are you doing this?')
if self.magnitude_data:
plotter = MagnitudePlotter(transient=self)
plotter = MagnitudePlotter(transient=self, filters=filters, filename=filename, outdir=outdir, nrows=nrows, ncols=ncols, figsize=figsize, **kwargs)
elif self.flux_density_data:
plotter = FluxDensityPlotter(transient=self)
plotter = FluxDensityPlotter(transient=self, filters=filters, filename=filename, outdir=outdir, nrows=nrows, ncols=ncols, figsize=figsize, **kwargs)
else:
return
return plotter.plot_multiband(
figure=figure, axes=axes, ncols=ncols, nrows=nrows, figsize=figsize, filters=filters, **plot_kwargs)
return plotter.plot_multiband(figure=figure, axes=axes, save=save, show=show)

def plot_lightcurve(
self, model: callable, filename: str = None, axes: matplotlib.axes.Axes = None, plot_save: bool = True,
plot_show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None, outdir: str = '.',
self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
save: bool = True, show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None,
model_kwargs: dict = None, **kwargs: object) -> None:
"""
Expand All @@ -551,9 +552,9 @@ def plot_lightcurve(
attribute and ends with *lightcurve.png.
axes: matplotlib.axes.Axes, optional
Axes to plot in if given.
plot_save: bool, optional
save: bool, optional
Whether to save the plot.
plot_show: bool, optional
show: bool, optional
Whether to show the plot.
random_models: int, optional
Number of random posterior samples plotted faintly. Default is 100.
Expand All @@ -567,24 +568,20 @@ def plot_lightcurve(
No current function.
"""
if self.flux_data:
plotter = IntegratedFluxPlotter(transient=self)
plotter = IntegratedFluxPlotter(transient=self, model=model, filename=filename, outdir=outdir, posterior=posterior, model_kwargs=model_kwargs, **kwargs)
elif self.luminosity_data:
plotter = LuminosityPlotter(transient=self)
plotter = LuminosityPlotter(transient=self, model=model, filename=filename, outdir=outdir, posterior=posterior, model_kwargs=model_kwargs, **kwargs)
elif self.flux_density_data:
plotter = FluxDensityPlotter(transient=self)
plotter = FluxDensityPlotter(transient=self, model=model, filename=filename, outdir=outdir, posterior=posterior, model_kwargs=model_kwargs, **kwargs)
elif self.magnitude_data:
plotter = MagnitudePlotter(transient=self)
plotter = MagnitudePlotter(transient=self, model=model, filename=filename, outdir=outdir, posterior=posterior, model_kwargs=model_kwargs, **kwargs)
else:
return axes
return plotter.plot_lightcurve(
model=model, filename=filename, axes=axes, plot_save=plot_save,
plot_show=plot_show, random_models=random_models, posterior=posterior,
outdir=outdir, model_kwargs=model_kwargs, **kwargs)

return plotter.plot_lightcurve(axes=axes, save=save, show=show)

def plot_multiband_lightcurve(
self, model: callable, filename: str = None, axes: matplotlib.axes.Axes = None, plot_save: bool = True,
plot_show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None, outdir: str = '.',
self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
save: bool = True, show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None,
model_kwargs: dict = None, **kwargs: object) -> None:
"""
Expand All @@ -597,9 +594,9 @@ def plot_multiband_lightcurve(
attribute and ends with *lightcurve.png.
axes: matplotlib.axes.Axes, optional
Axes to plot in if given.
plot_save: bool, optional
save: bool, optional
Whether to save the plot.
plot_show: bool, optional
show: bool, optional
Whether to show the plot.
random_models: int, optional
Number of random posterior samples plotted faintly. Default is 100.
Expand All @@ -618,14 +615,12 @@ def plot_multiband_lightcurve(
raise ValueError(
f'You cannot plot multiband data with {self.data_mode} data mode . Why are you doing this?')
if self.magnitude_data:
plotter = MagnitudePlotter(transient=self)
plotter = MagnitudePlotter(transient=self, model=model, filename=filename, outdir=outdir, posterior=posterior, model_kwargs=model_kwargs, **kwargs)
elif self.flux_density_data:
plotter = FluxDensityPlotter(transient=self)
plotter = FluxDensityPlotter(transient=self, model=model, filename=filename, outdir=outdir, posterior=posterior, model_kwargs=model_kwargs, **kwargs)
else:
return
return plotter.plot_multiband_lightcurve(model=model, filename=filename, axes=axes, plot_save=plot_save,
plot_show=plot_show, random_models=random_models, posterior=posterior, outdir=outdir,
model_kwargs=model_kwargs, **kwargs)
return plotter.plot_multiband_lightcurve(axes=axes, save=save, show=show)


class OpticalTransient(Transient):
Expand Down

0 comments on commit 5bb19d3

Please sign in to comment.