Skip to content

Commit

Permalink
Add plot_light_curve
Browse files Browse the repository at this point in the history
  • Loading branch information
kboone committed May 8, 2019
1 parent f8c7256 commit ebd6ac1
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 6 deletions.
143 changes: 137 additions & 6 deletions avocado/astronomical_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from functools import partial
import george
from george import kernels
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from scipy.optimize import minimize

from .instruments import band_central_wavelengths
from .instruments import band_central_wavelengths, band_colors

class AstronomicalObject():
"""Class representing an astronomical object.
Expand Down Expand Up @@ -85,7 +86,7 @@ def subtract_background(self):

return subtracted_observations

def preprocess_observations(self, subtract_background=True):
def preprocess_observations(self, subtract_background=True, **kwargs):
"""Apply preprocessing to the observations.
This function is intended to be used to transform the raw observations
Expand All @@ -98,6 +99,11 @@ def preprocess_observations(self, subtract_background=True):
If True (the default), a background subtraction routine is applied
to the lightcurve before fitting the GP. Otherwise, the flux values
are used as-is.
kwargs : dict
Additional keyword arguments. These are ignored. We allow
additional keyword arguments so that the various functions that
call this one can be called with the same arguments, even if they
don't actually use them.
Returns
-------
Expand Down Expand Up @@ -145,9 +151,8 @@ def fit_gaussian_process(self, subtract_background=True, fix_scale=False,
gp_observations : pandas.DataFrame
The processed observations that the GP was fit to. This could have
effects such as background subtraction applied to it.
gp_fit_parameters : dict
A dictionary containing all of the information needed to build the
Gaussian process.
gp_fit_parameters : list
A list of the resulting GP fit parameters.
"""
gp_observations = self.preprocess_observations(**preprocessing_kwargs)

Expand All @@ -164,7 +169,7 @@ def fit_gaussian_process(self, subtract_background=True, fix_scale=False,
np.abs(fluxes) /
np.sqrt(flux_errors**2 + (1e-2 * np.max(fluxes))**2)
)
scale = fluxes[np.argmax(signal_to_noises.idxmax())]
scale = fluxes[signal_to_noises.idxmax()]

kernel = (
(0.2 * scale)**2 *
Expand Down Expand Up @@ -222,3 +227,129 @@ def grad_neg_ln_like(p):
gaussian_process = partial(gp.predict, fluxes)

return gaussian_process, gp_observations, fit_result.x

def predict_gaussian_process(self, bands, times, uncertainties=True,
fitted_gp=None, **gp_kwargs):
"""Predict the Gaussian process in a given set of bands and at a given
set of times.
Parameters
----------
bands : list(str)
bands to predict the Gaussian process in.
times : list or numpy.array of floats
times to evaluate the Gaussian process at.
uncertainties : bool (optional)
If True (default), the GP uncertainties are computed and returned
along with the mean prediction. If False, only the mean prediction
is returned.
fitted_gp : function (optional)
By default, this function will perform the GP fit before doing
predictions. If the GP fit has already been done, then the fitted
GP function (returned by fit_gaussian_process) can be passed here
instead to skip redoing the fit.
gp_kwargs : kwargs (optional)
Additional arguments that are passed to `fit_gaussian_process`.
Returns:
--------
predictions : numpy.array
A 2-dimensional array with shape (len(bands), len(times))
containing the Gaussian process mean flux predictions.
prediction_uncertainties : numpy.array
Only returned if uncertainties is True. This is an array with the
same shape as predictions containing the Gaussian process
uncertainty for the predictions.
"""
if fitted_gp is not None:
gp = fitted_gp
else:
gp, _, _ = self.fit_gaussian_process(**gp_kwargs)

# Predict the Gaussian process band-by-band.
predictions = []
prediction_uncertainties = []

for band in bands:
wavelengths = (
np.ones(len(times)) * band_central_wavelengths[band]
)
pred_x_data = np.vstack([times, wavelengths]).T
if uncertainties:
band_pred, band_pred_var = gp(pred_x_data, return_var=True)
prediction_uncertainties.append(np.sqrt(band_pred_var))
else:
band_pred = gp(pred_x_data, return_cov=False)
predictions.append(band_pred)

predictions = np.array(predictions)
if uncertainties:
prediction_uncertainties = np.array(prediction_uncertainties)
return predictions, prediction_uncertainties
else:
return predictions

def plot_light_curve(self, show_gp=True, **kwargs):
"""Plot the object's light curve
Parameters
----------
show_gp : bool (optional)
If True (default), the Gaussian process prediction is plotted along
with the raw data.
kwargs : kwargs (optional)
Additional arguments. If show_gp is True, these are passed to
`fit_gaussian_process`. Otherwise, these are passed to
`preprocess_observations`.
"""
if show_gp:
gp, observations, gp_fit_parameters = \
self.fit_gaussian_process(**kwargs)
else:
observations = self.preprocess_observations(**kwargs)

# Figure out the times to plot. We go 10% past the edges of the
# observations.
min_time_obs = np.min(observations['time'])
max_time_obs = np.max(observations['time'])
border = 0.1 * (max_time_obs - min_time_obs)
min_time = min_time_obs - border
max_time = max_time_obs + border

if show_gp:
pred_times = np.arange(min_time, max_time + 1)

predictions, prediction_uncertainties = \
self.predict_gaussian_process(self.bands, pred_times,
fitted_gp=gp)

plt.figure()

for band_idx, band in enumerate(self.bands):
mask = observations['band'] == band
band_data = observations[mask]
color = band_colors[band]

plt.errorbar(band_data['time'], band_data['flux'],
band_data['flux_error'], fmt='o', c=color,
markersize=5, label=band)

if not show_gp:
continue

pred = predictions[band_idx]
plt.plot(pred_times, pred, c=color)
err = prediction_uncertainties[band_idx]

if kwargs.get('uncertainties', True):
# If they were calculated, show uncertainties with a shaded
# band.
plt.fill_between(pred_times, pred-err, pred+err, alpha=0.2,
color=color)

plt.legend()

plt.xlabel('Time (days)')
plt.ylabel('Flux')
plt.tight_layout()
plt.xlim(min_time, max_time)
10 changes: 10 additions & 0 deletions avocado/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@
'lsstz': 8691.,
'lssty': 9710.,
}

# Colors for plotting
band_colors = {
'lsstu': 'C6',
'lsstg': 'C4',
'lsstr': 'C0',
'lssti': 'C2',
'lsstz': 'C3',
'lssty': 'goldenrod',
}

0 comments on commit ebd6ac1

Please sign in to comment.