Skip to content

Commit

Permalink
First work on Augmentor class
Browse files Browse the repository at this point in the history
  • Loading branch information
kboone committed May 10, 2019
1 parent e21410e commit 71ee434
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 28 deletions.
1 change: 1 addition & 0 deletions avocado/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .astronomical_object import *
from .dataset import *
from .instruments import *
from .augment import *

from . import plasticc

Expand Down
22 changes: 20 additions & 2 deletions avocado/astronomical_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class AstronomicalObject():
following keys exist in the metadata:
- object_id: A unique ID for the object.
- galactic: Whether or not the object is in the Milky Way galaxy.
- host_photoz: The photometric redshift of the object's host galaxy.
- host_photoz_error: The error on the photometric redshift of the
object's host galaxy.
Expand All @@ -48,6 +50,8 @@ def __init__(self, metadata, observations):
self.metadata = dict(metadata)
self.observations = observations

self._default_gaussian_process = None

def __repr__(self):
return "AstronomicalObject(object_id=%s)" % self.metadata['object_id']

Expand Down Expand Up @@ -225,6 +229,19 @@ def grad_neg_ln_like(p):

return gaussian_process, gp_observations, fit_result.x

def get_default_gaussian_process(self):
    """Return the Gaussian process fitted with the default arguments.

    The first call runs fit_gaussian_process with no arguments and stores
    the resulting Gaussian process on the object. Subsequent calls return
    the stored result, so the fit only has to be performed a single time.
    """
    cached = self._default_gaussian_process
    if cached is not None:
        return cached

    # fit_gaussian_process returns three values; only the Gaussian process
    # itself is kept here.
    fitted_gp, _observations, _parameters = self.fit_gaussian_process()
    self._default_gaussian_process = fitted_gp
    return fitted_gp

def predict_gaussian_process(self, bands, times, uncertainties=True,
fitted_gp=None, **gp_kwargs):
"""Predict the Gaussian process in a given set of bands and at a given
Expand Down Expand Up @@ -362,8 +379,9 @@ def print_meta(self):
# Try to print out specific keys in a nice order. If these keys aren't
# available, then we skip them. The rest of the keys are printed out in
# a random order afterwards.
ordered_keys = ['object_id', 'category', 'fold', 'redshift',
'host_specz', 'host_photoz', 'host_photoz_error']
ordered_keys = ['object_id', 'category', 'galactic', 'fold',
'redshift', 'host_specz', 'host_photoz',
'host_photoz_error']
for key in ordered_keys:
if key in self.metadata:
print("%20s: %s" % (key, self.metadata[key]))
Expand Down
40 changes: 40 additions & 0 deletions avocado/augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
class Augmentor():
    """Class used to augment a dataset.

    This class takes :class:`AstronomicalObject`s as input and generates new
    :class:`AstronomicalObject`s with the following transformations applied:
    - Drop random observations.
    - Drop large blocks of observations.
    - For galactic observations, adjust the brightness (= distance).
    - For extragalactic observations, adjust the redshift.
    - Add noise.

    The augmentor needs to have some reasonable idea of the properties of the
    survey that it is being applied to. If there is a large dataset that the
    classifier will be used on, then that dataset can be used directly to
    estimate the properties of the survey.

    This class needs to be subclassed to implement survey specific methods.
    These methods are:
    - augment_metadata
    - TODO: list further survey-specific methods as they are added.
    """
    def __init__(self):
        pass

    def augment_metadata(self, reference_object):
        """Generate new metadata for the augmented object.

        This method needs to be implemented in survey-specific subclasses of
        this class.

        Parameters
        ==========
        reference_object : :class:`AstronomicalObject`
            The object to use as a reference for the augmentation.

        Returns
        =======
        augmented_metadata : dict
            The augmented metadata

        Raises
        ======
        NotImplementedError
            Always, in this base class. Subclasses must override this method.
        """
        # Raise (rather than return) the exception so that a subclass that
        # forgets to override this method fails loudly instead of handing the
        # exception class back to the caller as if it were metadata.
        raise NotImplementedError(
            "augment_metadata must be implemented in a survey-specific "
            "subclass of Augmentor"
        )
29 changes: 16 additions & 13 deletions avocado/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,22 @@ def __init__(self, name, metadata, observations=None):
if 'category' in self.metadata:
self.label_folds()

# Load each astronomical object in the dataset.
objects = []
meta_dicts = self.metadata.to_dict('records')
for object_id, object_observations in \
observations.groupby('object_id'):
meta_index = self.metadata.index.get_loc(object_id)
object_metadata = meta_dicts[meta_index]
object_metadata['object_id'] = object_id
new_object = AstronomicalObject(object_metadata,
object_observations)
objects.append(new_object)

self.objects = np.array(objects)
if observations is None:
self.objects = None
else:
# Load each astronomical object in the dataset.
objects = []
meta_dicts = self.metadata.to_dict('records')
for object_id, object_observations in \
observations.groupby('object_id'):
meta_index = self.metadata.index.get_loc(object_id)
object_metadata = meta_dicts[meta_index]
object_metadata['object_id'] = object_id
new_object = AstronomicalObject(object_metadata,
object_observations)
objects.append(new_object)

self.objects = np.array(objects)

def label_folds(self):
"""Separate the dataset into groups for k-folding
Expand Down

0 comments on commit 71ee434

Please sign in to comment.