Skip to content

Commit

Permalink
Add detection for augmented light curves
Browse files Browse the repository at this point in the history
  • Loading branch information
kboone committed May 13, 2019
1 parent 3372777 commit 858a80f
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 45 deletions.
167 changes: 123 additions & 44 deletions avocado/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import string

from .instruments import band_central_wavelengths
from .utils import settings

class Augmentor():
"""Class used to augment a dataset.
Expand All @@ -25,6 +26,8 @@ class Augmentor():
These methods are:
- `_augment_metadata`
- Either `_choose_sampling_times` or `_choose_target_observation_count`
- `_simulate_light_curve_uncertainties`
- `_simulate_detection`
Parameters
----------
Expand Down Expand Up @@ -221,6 +224,58 @@ def _choose_sampling_times(self, reference_object, augmented_metadata,

return sampling_times

def _simulate_light_curve_uncertainties(self, observations,
augmented_metadata):
"""Simulate the observation-related noise for a light curve.
This method needs to be implemented in survey-specific subclasses of
this class. It should simulate the observation uncertainties for the
light curve.
Parameters
==========
observations : pandas.DataFrame
The augmented observations that have been sampled from a Gaussian
Process. These observations have model flux uncertainties listed
that should be included in the final uncertainties.
augmented_metadata : dict
The augmented metadata
Returns
=======
observations : pandas.DataFrame
The observations with uncertainties added.
"""
return NotImplementedError

def _simulate_detection(self, observations, augmented_metadata):
"""Simulate the detection process for a light curve.
This method needs to be implemented in survey-specific subclasses of
this class. It should simulate whether each observation is detected as
a point-source by the survey and set the "detected" flag in the
observations DataFrame. It should also return whether or not the light
curve passes a base set of criterion to be included in the sample that
this classifier will be applied to.
Parameters
==========
observations : pandas.DataFrame
The augmented observations that have been sampled from a Gaussian
Process.
augmented_metadata : dict
The augmented metadata
Returns
=======
observations : pandas.DataFrame
The observations with the detected flag set.
pass_detection : bool
Whether or not the full light curve passes the detection thresholds
used for the full sample.
"""
return NotImplementedError

def augment_object(self, reference_object):
"""Generate an augmented version of an object.
Expand All @@ -241,16 +296,16 @@ def augment_object(self, reference_object):
random_str = ''.join(np.random.choice(list(string.ascii_letters), 10))
new_object_id = '%s_aug_%s' % (ref_object_id, random_str)

# Augment the metadata. This is survey specific, so this must be
# implemented in subclasses.
# Augment the metadata. The details of how this should work is survey
# specific, so this must be implemented in subclasses.
augmented_metadata = self._augment_metadata(reference_object)
augmented_metadata['object_id'] = new_object_id
augmented_metadata['reference_object_id'] = ref_object_id

return self._resample_light_curve(reference_object, augmented_metadata)

# Add noise to the light_curve
object_data = _simulate_light_curve_noise(object_model, new_ddf)
# Add observation-related uncertainties to the light_curve
object_data = _simulate_light_curve_uncertainties(object_model, new_ddf)

# Model the photoz
if photoz_reference is not None:
Expand Down Expand Up @@ -285,53 +340,77 @@ def _resample_light_curve(self, reference_object, augmented_metadata):
This uses the Gaussian process fit to a light curve to generate new
simulated observations of that light curve.
In some cases, the light curve that is generated will be accidentally
shifted out of the frame, or otherwise missed. If that is the case, the
light curve will automatically be regenerated with the same metadata
until it is either detected or until the number of tries has exceeded
settings['augment_retries'].
Parameters
----------
reference_object : :class:`AstronomicalObject`
The object to use as a reference for the augmentation.
augmented_metadata : dict
The augmented metadata
Returns
-------
augmented_observations : pandas.DataFrame
The simulated observations for the augmented object. If the chosen
metadata leads to an object that is too faint or otherwise unable
to be detected, None will be returned instead.
"""
# Get the GP. This uses a cache if possible.
gp = reference_object.get_default_gaussian_process()

# Figure out where to sample the augmented light curve at.
observations = self._choose_sampling_times(reference_object,
augmented_metadata)

# Compute the fluxes from the GP at the augmented observation times.
new_redshift = augmented_metadata['redshift']
reference_redshift = reference_object.metadata['redshift']
redshift_scale = (1 + new_redshift) / (1 + reference_redshift)

new_wavelengths = np.array([band_central_wavelengths[i] for i in
observations['band']])
eval_wavelengths = new_wavelengths / redshift_scale
pred_x_data = np.vstack([observations['time'], eval_wavelengths]).T
new_fluxes, new_fluxvars = gp(pred_x_data, return_var=True)

observations['flux'] = new_fluxes
observations['flux_error'] = np.sqrt(new_fluxvars)

# Update the brightness of the new observations.
if reference_redshift == 0:
# Adjust brightness for galactic objects.
adjust_mag = np.random.normal(0, 0.5)
# adjust_mag = np.random.lognormal(-1, 0.5)
adjust_scale = 10**(-0.4*adjust_mag)
else:
# Adjust brightness for extragalactic objects. We simply follow the
# Hubble diagram.
delta_distmod = (self.cosmology.distmod(reference_redshift) -
self.cosmology.distmod(new_redshift)).value
adjust_scale = 10**(0.4*delta_distmod)

observations['flux'] *= adjust_scale
observations['flux_error'] *= adjust_scale

# We have the resampled models! Note that there is no error added in yet,
# so we set the detected flags to default values and clean up.
observations['detected'] = 1
observations.reset_index(inplace=True, drop=True)

return observations
for attempt in range(settings['augment_retries']):
# Figure out where to sample the augmented light curve at.
observations = self._choose_sampling_times(reference_object,
augmented_metadata)

# Compute the fluxes from the GP at the augmented observation
# times.
new_redshift = augmented_metadata['redshift']
reference_redshift = reference_object.metadata['redshift']
redshift_scale = (1 + new_redshift) / (1 + reference_redshift)

new_wavelengths = np.array([band_central_wavelengths[i] for i in
observations['band']])
eval_wavelengths = new_wavelengths / redshift_scale
pred_x_data = np.vstack([observations['time'], eval_wavelengths]).T
new_fluxes, new_fluxvars = gp(pred_x_data, return_var=True)

observations['flux'] = new_fluxes
observations['flux_error'] = np.sqrt(new_fluxvars)

# Update the brightness of the new observations.
if reference_redshift == 0:
# Adjust brightness for galactic objects.
adjust_mag = np.random.normal(0, 0.5)
# adjust_mag = np.random.lognormal(-1, 0.5)
adjust_scale = 10**(-0.4*adjust_mag)
else:
# Adjust brightness for extragalactic objects. We simply follow
# the Hubble diagram.
delta_distmod = (self.cosmology.distmod(reference_redshift) -
self.cosmology.distmod(new_redshift)).value
adjust_scale = 10**(0.4*delta_distmod)

observations['flux'] *= adjust_scale
observations['flux_error'] *= adjust_scale

# Add in light curve noise. This is survey specific and must be
# implemented in subclasses.
observations = self._simulate_light_curve_uncertainties(
observations, augmented_metadata)

# Simulate detection
observations, pass_detection = self._simulate_detection(
observations, augmented_metadata)

# If our light curve passes detection thresholds, we're done!
if pass_detection:
return observations

# Failed to generate valid observations.
return None
100 changes: 99 additions & 1 deletion avocado/plasticc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Utility functions to interact with the PLAsTiCC dataset"""

import numpy as np
import pandas as pd
import os
import pandas as pd
from scipy.special import erf

from .dataset import Dataset
from .utils import settings, AvocadoException, logger
Expand Down Expand Up @@ -371,3 +372,100 @@ def _choose_target_observation_count(self, augmented_metadata):
np.clip(np.random.normal(mu, sigma), 50, None))

return target_observation_count

def _simulate_light_curve_uncertainties(self, observations,
augmented_metadata):
"""Simulate the observation-related noise and detections for a light
curve.
For the PLAsTiCC dataset, we estimate the measurement uncertainties for
each band with a lognormal distribution for both the WFD and DDF
surveys. Those measurement uncertainties are added to the simulated
observations.
Parameters
==========
observations : pandas.DataFrame
The augmented observations that have been sampled from a Gaussian
Process. These observations have model flux uncertainties listed
that should be included in the final uncertainties.
augmented_metadata : dict
The augmented metadata
Returns
=======
observations : pandas.DataFrame
The observations with uncertainties added.
"""
# Make a copy so that we don't modify the original array.
observations = observations.copy()

if len(observations) == 0:
# No data, skip
return observations

if augmented_metadata['ddf']:
band_noises = {
'lsstu': (0.68, 0.26),
'lsstg': (0.25, 0.50),
'lsstr': (0.16, 0.36),
'lssti': (0.53, 0.27),
'lsstz': (0.88, 0.22),
'lssty': (1.76, 0.23),
}
else:
band_noises = {
'lsstu': (2.34, 0.43),
'lsstg': (0.94, 0.41),
'lsstr': (1.30, 0.41),
'lssti': (1.82, 0.42),
'lsstz': (2.56, 0.36),
'lssty': (3.33, 0.37),
}

# Calculate the new noise levels using a lognormal distribution for
# each band.
lognormal_parameters = np.array([band_noises[i] for i in
observations['band']])
add_stds = np.random.lognormal(lognormal_parameters[:, 0],
lognormal_parameters[:, 1])

noise_add = np.random.normal(loc=0.0, scale=add_stds)
observations['flux'] += noise_add
observations['flux_error'] = np.sqrt(
observations['flux_error']**2 + add_stds**2
)

return observations

def _simulate_detection(self, observations, augmented_metadata):
"""Simulate the detection process for a light curve.
We model the PLAsTiCC detection probabilities with an error function.
I'm not entirely sure why this isn't deterministic. The full light
curve is considered to be detected if there are at least 2 individual
detected observations.
Parameters
==========
observations : pandas.DataFrame
The augmented observations that have been sampled from a Gaussian
Process.
augmented_metadata : dict
The augmented metadata
Returns
=======
observations : pandas.DataFrame
The observations with the detected flag set.
pass_detection : bool
Whether or not the full light curve passes the detection thresholds
used for the full sample.
"""
s2n = np.abs(observations['flux']) / observations['flux_error']
prob_detected = (erf((s2n - 5.5) / 2) + 1) / 2.
observations['detected'] = np.random.rand(len(s2n)) < prob_detected

pass_detection = np.sum(observations['detected']) >= 2

return observations, pass_detection
2 changes: 2 additions & 0 deletions avocado_settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

"num_folds": 5,

"augment_retries": 10,

"RAW_DATA_DIR": "./data",
"RAW_TRAINING_PATH": "./data/training_set.csv",
"RAW_TRAINING_METADATA_PATH": "./data/training_set_metadata.csv",
Expand Down

0 comments on commit 858a80f

Please sign in to comment.