Commit
Add augmentation for full datasets.
kboone committed May 14, 2019
1 parent c319b92 commit 21fe659
Showing 2 changed files with 130 additions and 65 deletions.
145 changes: 93 additions & 52 deletions avocado/augment.py
@@ -2,8 +2,10 @@
import numpy as np
import pandas as pd
import string
from tqdm import tqdm

from .astronomical_object import AstronomicalObject
from .dataset import Dataset
from .instruments import band_central_wavelengths
from .utils import settings, logger

@@ -283,58 +285,6 @@ def _simulate_detection(self, observations, augmented_metadata):
"""
raise NotImplementedError

def augment_object(self, reference_object, force_success=True):
"""Generate an augmented version of an object.
Parameters
==========
reference_object : :class:`AstronomicalObject`
The object to use as a reference for the augmentation.
force_success : bool
If True, then if we fail to generate an augmented light curve for a
specific set of augmented parameters, we choose a different set of
augmented parameters until we eventually get an augmented light
curve. This is useful for debugging/interactive work, but when
actually augmenting a dataset, ignoring bad light curves gives a
massive speed up without a major change in classification
performance.
Returns
=======
augmented_object : :class:`AstronomicalObject`
The augmented object. If force_success is False, this can be None.
"""
# Create a new object id for the augmented object. We append a random
# string to the end of the original object id, which makes collisions
# very unlikely.
ref_object_id = reference_object.metadata['object_id']
random_str = ''.join(np.random.choice(list(string.ascii_letters), 10))
new_object_id = '%s_aug_%s' % (ref_object_id, random_str)

while True:
# Augment the metadata. The details of how this should work are
# survey-specific, so this must be implemented in subclasses.
augmented_metadata = self._augment_metadata(reference_object)
augmented_metadata['object_id'] = new_object_id
augmented_metadata['reference_object_id'] = ref_object_id

# Generate an augmented light curve for this augmented metadata.
observations = self._resample_light_curve(reference_object,
augmented_metadata)

if observations is not None:
# Successfully generated a light curve.
augmented_object = AstronomicalObject(augmented_metadata,
observations)
return augmented_object
elif not force_success:
# Failed to generate a light curve, and we aren't retrying
# until we are successful.
return None
else:
logger.warning("Failed to generate a light curve for redshift "
"%.2f. Retrying." % augmented_metadata['redshift'])

def _resample_light_curve(self, reference_object, augmented_metadata):
"""Resample a light curve as part of the augmenting procedure
@@ -419,3 +369,94 @@ def _resample_light_curve(self, reference_object, augmented_metadata):

# Failed to generate valid observations.
return None

def augment_object(self, reference_object, force_success=True):
"""Generate an augmented version of an object.
Parameters
==========
reference_object : :class:`AstronomicalObject`
The object to use as a reference for the augmentation.
force_success : bool
If True, then if we fail to generate an augmented light curve for a
specific set of augmented parameters, we choose a different set of
augmented parameters until we eventually get an augmented light
curve. This is useful for debugging/interactive work, but when
actually augmenting a dataset, ignoring bad light curves gives a
massive speed up without a major change in classification
performance.
Returns
=======
augmented_object : :class:`AstronomicalObject`
The augmented object. If force_success is False, this can be None.
"""
# Create a new object id for the augmented object. We append a random
# string to the end of the original object id, which makes collisions
# very unlikely.
ref_object_id = reference_object.metadata['object_id']
random_str = ''.join(np.random.choice(list(string.ascii_letters), 10))
new_object_id = '%s_aug_%s' % (ref_object_id, random_str)

while True:
# Augment the metadata. The details of how this should work are
# survey-specific, so this must be implemented in subclasses.
augmented_metadata = self._augment_metadata(reference_object)
augmented_metadata['object_id'] = new_object_id
augmented_metadata['reference_object_id'] = ref_object_id

# Generate an augmented light curve for this augmented metadata.
observations = self._resample_light_curve(reference_object,
augmented_metadata)

if observations is not None:
# Successfully generated a light curve.
augmented_object = AstronomicalObject(augmented_metadata,
observations)
return augmented_object
elif not force_success:
# Failed to generate a light curve, and we aren't retrying
# until we are successful.
return None
else:
logger.warning("Failed to generate a light curve for redshift "
"%.2f. Retrying." % augmented_metadata['redshift'])

def augment_dataset(self, dataset, num_augments, tag="augment",
include_reference=True):
"""Generate augmented versions of all objects in a dataset.
Parameters
==========
dataset : :class:`Dataset`
The dataset to use as a reference for the augmentation.
num_augments : int
The number of times to use each object in the dataset as a
reference for augmentation. Note that augmentation sometimes fails,
so this is the number of tries, not the number of successes.
tag : str (optional)
The tag that is prepended to the reference dataset's name to form
the name of the new augmented dataset. Default: "augment".
include_reference : bool (optional)
If True (default), the reference objects are included in the new
augmented dataset. Otherwise they are dropped.
Returns
=======
augmented_dataset : :class:`Dataset`
The augmented dataset.
"""
augmented_objects = []

for reference_object in tqdm(dataset.objects):
if include_reference:
augmented_objects.append(reference_object)

for i in range(num_augments):
augmented_object = self.augment_object(reference_object,
force_success=False)
if augmented_object is not None:
augmented_objects.append(augmented_object)

new_name = "%s_%s" % (tag, dataset.name)

augmented_dataset = Dataset.from_objects(new_name, augmented_objects)

return augmented_dataset
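
A minimal usage sketch of the new augment_dataset entry point (not part of this commit). It assumes a concrete subclass of the augmentor class in avocado/augment.py that implements the survey-specific hooks such as _augment_metadata and _resample_light_curve; the subclass name MySurveyAugmentor and the dataset name 'plasticc_train' are placeholders.

from avocado.dataset import Dataset

dataset = Dataset.load('plasticc_train')  # placeholder dataset name
augmentor = MySurveyAugmentor()  # hypothetical concrete subclass

# Each object is used as a reference num_augments times. Inside
# augment_dataset, augment_object runs with force_success=False, so
# failed resamplings are skipped rather than retried.
augmented = augmentor.augment_dataset(dataset, num_augments=10)

print(augmented.name)  # "augment_plasticc_train" per "%s_%s" % (tag, name)
print(len(augmented.objects))  # reference objects plus successful augments
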
50 changes: 37 additions & 13 deletions avocado/dataset.py
@@ -13,19 +13,20 @@ class Dataset():
Parameters
----------
name : str
Name of the dataset. This will be used to determine the filenames of
various outputs such as computed features and predictions.
metadata : pandas.DataFrame
DataFrame where each row is the metadata for an object in the dataset.
See :class:`AstronomicalObject` for details.
observations : pandas.DataFrame
Observations of all of the objects' light curves. See
:class:`AstronomicalObject` for details.
name : str
Name of the dataset. This will be used to determine the filenames of
various outputs such as computed features and predictions.
objects : list
A list of :class:`AstronomicalObject` instances. Either this or
observations can be specified but not both.
"""
def __init__(self, name, metadata, observations=None):
def __init__(self, name, metadata, observations=None, objects=None):
"""Create a new Dataset from a set of metadata and observations"""
# Make copies of everything so that we don't mess anything up.
metadata = metadata.copy()
@@ -36,7 +37,7 @@ def __init__(self, name, metadata, observations=None):
self.metadata = metadata

if observations is None:
self.objects = None
self.objects = objects
else:
# Load each astronomical object in the dataset.
self.objects = np.zeros(len(self.metadata), dtype=object)
@@ -53,8 +54,7 @@
self.objects[meta_index] = new_object

@classmethod
def load(cls, dataset_name, metadata_only=False, chunk=None,
num_chunks=None):
def load(cls, name, metadata_only=False, chunk=None, num_chunks=None):
"""Load a dataset that has been saved in HDF5 format in the data
directory.
@@ -66,7 +66,7 @@ def load(cls, dataset_name, metadata_only=False, chunk=None,
Parameters
----------
dataset_name : str
name : str
The name of the dataset to load
metadata_only : bool (optional)
If False (default), the observations are loaded. Otherwise, only
@@ -84,10 +84,10 @@
"""
data_directory = settings['data_directory']

data_path = os.path.join(data_directory, dataset_name + '.h5')
data_path = os.path.join(data_directory, name + '.h5')

if not os.path.exists(data_path):
raise AvocadoException("Couldn't find dataset %s!" % dataset_name)
raise AvocadoException("Couldn't find dataset %s!" % name)

if chunk is None:
# Load the full dataset
@@ -144,14 +144,38 @@ def load(cls, dataset_name, metadata_only=False, chunk=None,
observations = pd.read_hdf(data_path, 'observations')

# Create a Dataset object
dataset = Dataset(dataset_name, metadata, observations)
dataset = cls(name, metadata, observations)

# Label folds if we have a full dataset with fold information
if chunk is None and 'category' in dataset.metadata:
dataset.label_folds()

return dataset

@classmethod
def from_objects(cls, name, objects):
"""Load a dataset from a list of AstronomicalObject instances.
Parameters
----------
name : str
The name of the dataset.
objects : list
A list of AstronomicalObject instances.
Returns
-------
dataset : :class:`Dataset`
The loaded dataset.
"""
# Pull the metadata out of the objects
metadata = pd.DataFrame([i.metadata for i in objects])
metadata.set_index('object_id', inplace=True)

# Load the new dataset.
dataset = cls(name, metadata, objects=objects)

return dataset

def label_folds(self):
"""Separate the dataset into groups for k-folding
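
A short sketch of what from_objects enables (not part of this commit): filter the objects of a loaded dataset and rebuild a named Dataset from the survivors. The 'redshift' metadata key and the dataset names are illustrative assumptions.

from avocado.dataset import Dataset

# Assumes the dataset was loaded with observations, so dataset.objects
# is populated (it is not filled in under metadata_only=True).
dataset = Dataset.load('plasticc_train')  # placeholder dataset name

low_z = [obj for obj in dataset.objects
         if obj.metadata['redshift'] < 0.5]  # 'redshift' key assumed

# from_objects rebuilds the metadata DataFrame from each object's
# metadata and indexes it by 'object_id', so metadata queries work as
# they do for a dataset loaded from disk.
subset = Dataset.from_objects('low_z_subset', low_z)
print(len(subset.metadata))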
