Skip to content

Commit

Permalink
Load the PLAsTiCC dataset with the new API
Browse files Browse the repository at this point in the history
  • Loading branch information
kboone committed Apr 23, 2019
1 parent 43e5e09 commit 851dc11
Show file tree
Hide file tree
Showing 8 changed files with 2,762 additions and 2,576 deletions.
5 changes: 5 additions & 0 deletions avocado/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .settings import settings
from .logging import logger

from .astronomical_object import *
from .dataset import *

from . import plasticc
126 changes: 63 additions & 63 deletions avocado/astronomical_object.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
from astropy.stats import biweight_location

class AstronomicalObject():
"""Class representing an astronomical object.
Expand All @@ -8,7 +9,7 @@ class AstronomicalObject():
Parameters
----------
metadata : dict-like
metadata : dict or pandas Series
Metadata for this object. This is represented using a dict
internally, and must be able to be cast to a dict. Any keys and
information are allowed. Various functions assume that the
Expand All @@ -24,12 +25,12 @@ class AstronomicalObject():
- class: The true class label of the object (only available for the
training data).
observations : DataFrame
observations : pandas.DataFrame
Observations of the object's light curve. This should be a pandas
DataFrame with at least the following columns:
- mjd: The MJD date of each observation.
- passband: The passband used for the observation.
- band: The band used for the observation.
- flux: The measured flux value of the observation.
- flux_err: The flux measurement uncertainty of the observation.
"""
Expand All @@ -38,67 +39,66 @@ def __init__(self, metadata, observations):
self.metadata = metadata
self.observations = observations

def _get_gp_data(self, object_meta, object_data, fix_background=True):
times = []
fluxes = []
bands = []
flux_errs = []
wavelengths = []

# The zeropoints were arbitrarily set from the first image. Pick the
# 20th percentile of all observations in each channel as a new
# zeropoint. This has good performance when there are supernova-like
# bursts in the image, even if they are quite wide.
# UPDATE: when picking the 20th percentile, observations with just
# noise get really messed up. Revert back to the median for now and see
# if that helps. It doesn't really matter if supernovae go slightly
# negative...
# UPDATE 2: most of the objects of interest are short-lived in time.
# The only issue with the background occurs when there was flux from
# the transient in the reference image. To deal with this, look at the
# last observations and see if they are negative (indicating that the
# reference has additional flux in it). If so, then update the
# background level. Otherwise, leave the background at the reference
# level.
for passband in range(num_passbands):
band_data = object_data[object_data['passband'] == passband]
if len(band_data) == 0:
# No observations in this band
continue
@property
def bands(self):
"""Return a list of bands that this object has observations in."""
return np.unique(self.observations['band'])

def subtract_background(self):
    """Subtract the background levels from each band.

    The background levels are estimated using a biweight location
    estimator. This estimator will calculate a robust estimate of the
    background level for objects that have short-lived light curves, and it
    will return something like the median flux level for periodic or
    continuous light curves.

    Returns
    -------
    subtracted_observations : pandas.DataFrame
        A copy of the observations DataFrame with the estimated background
        level of each band removed from its ``flux`` values.
        ``self.observations`` itself is left untouched.
    """
    subtracted_observations = self.observations.copy()

    for band in self.bands:
        # Boolean row mask selecting the observations taken in this band.
        mask = self.observations['band'] == band

        # Use a biweight location to estimate the background. The mask has
        # to be applied as a row selector; indexing with the string 'mask'
        # would look for a column of that name and raise a KeyError.
        ref_flux = biweight_location(self.observations.loc[mask, 'flux'])

        # Row/column selection must go through .loc; a tuple key such as
        # ['flux', mask] is not a valid DataFrame lookup.
        subtracted_observations.loc[mask, 'flux'] -= ref_flux

    return subtracted_observations

def plot_light_curve(self, data_only=False):
"""Plot the object's light curve.

For each band, the observed fluxes are drawn with error bars. Unless
`data_only` is set, the predicted curve for the band is drawn on top, and
a shaded uncertainty band is added when the `uncertainties` option is
enabled.

Parameters
----------
data_only : bool
    If True, skip the prediction curve and the uncertainty band and only
    draw the observations.
"""
# NOTE(review): `args` and `kwargs` are not parameters of this method, so
# this line (and the `kwargs.get` below) will raise a NameError as written.
# `predict_gp` is not defined in the visible part of the class -- confirm
# it exists and which arguments it expects.
result = self.predict_gp(*args, **kwargs)

# NOTE(review): `plt`, `np`, `num_passbands`, `band_colors` and `band_names`
# are not imported or defined in the visible part of this module.
plt.figure()

for band in range(num_passbands):
# Select the observations that belong to this band.
cut = result['bands'] == band
color = band_colors[band]
plt.errorbar(result['times'][cut], result['fluxes'][cut],
result['flux_errs'][cut], fmt='o', c=color,
markersize=5, label=band_names[band])

if data_only:
continue

plt.plot(result['pred_times'], result['pred'][band], c=color)

if kwargs.get('uncertainties', False):
# Show uncertainties with a shaded band
pred = result['pred'][band]
# 'pred_var' is converted to a one-sigma error by taking its square root.
err = np.sqrt(result['pred_var'][band])
plt.fill_between(result['pred_times'], pred-err, pred+err,
alpha=0.2, color=color)

plt.legend()

plt.xlabel('Time (days)')
plt.ylabel('Flux')
plt.tight_layout()
84 changes: 83 additions & 1 deletion avocado/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,85 @@
import numpy as np

from sklearn.model_selection import StratifiedKFold

from .astronomical_object import AstronomicalObject
from .logging import logger
from .settings import settings

class Dataset():
    """Class representing a dataset of many astronomical objects.

    Parameters
    ----------
    name : str
        Name of the dataset. This will be used to determine the filenames of
        various outputs such as computed features and predictions.
    metadata : pandas.DataFrame
        DataFrame where each row is the metadata for an object in the
        dataset. It must contain an ``object_id`` column, which becomes the
        (sorted) index of ``self.metadata``. See :class:`AstronomicalObject`
        for details.
    observations : pandas.DataFrame, optional
        Observations of all of the objects' light curves, with an
        ``object_id`` column identifying the object of each row. See
        :class:`AstronomicalObject` for details. If omitted, no objects are
        built and ``self.objects`` is set to None.
    """
    def __init__(self, name, metadata, observations=None):
        """Create a new Dataset from a set of metadata and observations"""

        # Use object_id as the index. set_index returns a new frame, so the
        # caller's DataFrame is not modified by the in-place sort below.
        metadata = metadata.set_index('object_id')
        metadata.sort_index(inplace=True)

        self.name = name
        self.metadata = metadata

        # Label folds for training datasets
        if 'class' in self.metadata:
            self.label_folds()

        # Metadata-only dataset: there is nothing to build objects from.
        if observations is None:
            self.objects = None
            return

        # Load each astronomical object in the dataset. Iterating over the
        # GroupBy itself yields (object_id, sub-DataFrame) pairs, whereas
        # `.groups` only maps each key to the row labels of its group.
        objects = []
        for object_id, object_observations in \
                observations.groupby('object_id'):
            object_metadata = self.metadata.loc[object_id]
            new_object = AstronomicalObject(object_metadata,
                                            object_observations)
            objects.append(new_object)

        # dtype=object keeps the array type stable even when it is empty.
        self.objects = np.array(objects, dtype=object)

    def label_folds(self):
        """Separate the dataset into groups for k-folding

        This is only applicable to training datasets that have assigned
        classes.

        The number of folds is set by the `num_folds` settings parameter.

        This needs to happen before augmentation to avoid leakage, so
        augmented datasets and similar datasets should already have the folds
        set. If the metadata already has a `fold` column it is kept as is,
        and a warning is logged when its fold count differs from the
        requested one.
        """
        if 'class' not in self.metadata:
            logger.warning("Dataset %s does not have labeled classes! Can't "
                           "separate into folds.", self.name)
            return

        num_folds = settings['num_folds']

        if 'fold' in self.metadata:
            # Warn if the fold count doesn't match.
            data_num_folds = np.max(self.metadata['fold']) + 1
            if data_num_folds != num_folds:
                logger.warning("Using %d preset folds in dataset instead of "
                               "%d requested.", data_num_folds, num_folds)
            # Preset folds always win, whether or not the counts agree.
            return

        # Label folds. Every object is assigned the index of the fold in
        # which it is part of the validation split; -1 marks "unassigned".
        classes = self.metadata['class']
        folds = StratifiedKFold(n_splits=num_folds, shuffle=True,
                                random_state=1)
        kfold_indices = -1 * np.ones(len(classes), dtype=int)
        for idx, (fold_train, fold_val) in \
                enumerate(folds.split(classes, classes)):
            kfold_indices[fold_val] = idx
        self.metadata['fold'] = kfold_indices
5 changes: 5 additions & 0 deletions avocado/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import logging

from . import settings

# Package-wide logger; other avocado modules import this object instead of
# creating their own, so all messages are emitted under the 'avocado' name.
logger = logging.getLogger('avocado')

0 comments on commit 851dc11

Please sign in to comment.