Skip to content

Commit

Permalink
Make augment_object return AstronomicalObject instances
Browse files Browse the repository at this point in the history
  • Loading branch information
kboone committed May 13, 2019
1 parent 858a80f commit 212eba6
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions avocado/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import pandas as pd
import string

from .astronomical_object import AstronomicalObject
from .instruments import band_central_wavelengths
from .utils import settings
from .utils import settings, logger

class Augmentor():
"""Class used to augment a dataset.
Expand Down Expand Up @@ -276,18 +277,26 @@ def _simulate_detection(self, observations, augmented_metadata):
"""
return NotImplementedError

def augment_object(self, reference_object):
def augment_object(self, reference_object, force_success=True):
"""Generate an augmented version of an object.
Parameters
==========
reference_object : :class:`AstronomicalObject`
The object to use as a reference for the augmentation.
force_success : bool
If True, then if we fail to generate an augmented light curve for a
specific set of augmented parameters, we choose a different set of
augmented parameters until we eventually get an augmented light
curve. This is useful for debugging/interactive work, but when
actually augmenting a dataset there is a massive speed up to
ignoring bad light curves without a major change in classification
performance.
Returns
=======
augmented_object : :class:`AstronomicalObject`
The augmented object.
The augmented object. If force_success is False, this can be None.
"""
# Create a new object id for the augmented object. We choose a random
# string to add on to the end of the original object id that is very
Expand All @@ -296,43 +305,29 @@ def augment_object(self, reference_object):
random_str = ''.join(np.random.choice(list(string.ascii_letters), 10))
new_object_id = '%s_aug_%s' % (ref_object_id, random_str)

# Augment the metadata. The details of how this should work is survey
# specific, so this must be implemented in subclasses.
augmented_metadata = self._augment_metadata(reference_object)
augmented_metadata['object_id'] = new_object_id
augmented_metadata['reference_object_id'] = ref_object_id
while True:
# Augment the metadata. The details of how this should work is
# survey specific, so this must be implemented in subclasses.
augmented_metadata = self._augment_metadata(reference_object)
augmented_metadata['object_id'] = new_object_id
augmented_metadata['reference_object_id'] = ref_object_id

return self._resample_light_curve(reference_object, augmented_metadata)

# Add observation-related uncertainties to the light_curve
object_data = _simulate_light_curve_uncertainties(object_model, new_ddf)

# Model the photoz
if photoz_reference is not None:
# Use the reference to simulate the photoz
object_meta = _simulate_photoz_reference(object_meta,
photoz_reference)
else:
# Use a model of the photoz
object_meta = _simulate_photoz_model(object_meta)

# Smear the mwebv value a bit so that it doesn't uniquely identify
# points. I leave the position on the sky unchanged (ra, dec, etc.).
# Don't put any of those variables directly into the classifier!
object_meta['mwebv'] *= np.random.normal(1, 0.1)

# Update the object id by adding a random fractional offset to the id.
# This lets us match it to the original but uniquely identify it.
new_object_id = object_meta['object_id'] + np.random.uniform(0, 1)
object_data['object_id'] = new_object_id
object_meta['object_id'] = new_object_id

object_meta['ddf'] = new_ddf
# Generate an augmented light curve for this augmented metadata.
observations = self._resample_light_curve(reference_object,
augmented_metadata)

if full_return:
return object_meta, object_data, object_model
else:
return object_meta, object_data
if observations is not None:
# Successfully generated a light curve.
augmented_object = AstronomicalObject(augmented_metadata,
observations)
return augmented_object
elif not force_success:
# Failed to generate a light curve, and we aren't retrying
# until we are successful.
return None
else:
logger.warn("Failed to generate a light curve for redshift "
"%.2f. Retrying." % augmented_metadata['redshift'])

def _resample_light_curve(self, reference_object, augmented_metadata):
"""Resample a light curve as part of the augmenting procedure
Expand Down Expand Up @@ -399,6 +394,10 @@ def _resample_light_curve(self, reference_object, augmented_metadata):
observations['flux'] *= adjust_scale
observations['flux_error'] *= adjust_scale

# Save the model flux and flux error
observations['model_flux'] = observations['flux']
observations['model_flux_error'] = observations['flux_error']

# Add in light curve noise. This is survey specific and must be
# implemented in subclasses.
observations = self._simulate_light_curve_uncertainties(
Expand Down

0 comments on commit 212eba6

Please sign in to comment.