Skip to content

Commit

Permalink
Use "class" as the key for object's classes rather than "category"
Browse files Browse the repository at this point in the history
  • Loading branch information
kboone committed May 20, 2019
1 parent afe22d2 commit 46a6ab2
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 36 deletions.
4 changes: 2 additions & 2 deletions avocado/astronomical_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class AstronomicalObject():
For training data objects, the following keys are assumed to exist in
the metadata:
- redshift: The true redshift of the object.
- category: The true category label of the object.
- class: The true class label of the object.
observations : pandas.DataFrame
Observations of the object's light curve. This should be a pandas
Expand Down Expand Up @@ -391,7 +391,7 @@ def print_metadata(self):
# Try to print out specific keys in a nice order. If these keys aren't
# available, then we skip them. The rest of the keys are printed out in
# a random order afterwards.
ordered_keys = ['object_id', 'category', 'galactic', 'fold',
ordered_keys = ['object_id', 'class', 'galactic', 'fold',
'redshift', 'host_specz', 'host_photoz',
'host_photoz_error']
for key in ordered_keys:
Expand Down
4 changes: 2 additions & 2 deletions avocado/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def evaluate_weights(self, dataset):
weights : `pandas.Series`
The weights that should be used for classification.
"""
object_classes = dataset.metadata['category']
object_classes = dataset.metadata['class']
class_counts = object_classes.value_counts()

norm_class_weights = {}
Expand Down Expand Up @@ -197,7 +197,7 @@ def train(self, dataset, num_folds=None, random_state=None, **kwargs):

weights = self.evaluate_weights(dataset)

object_classes = dataset.metadata['category']
object_classes = dataset.metadata['class']
classes = np.unique(object_classes)

importances = pd.DataFrame()
Expand Down
62 changes: 31 additions & 31 deletions avocado/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def label_folds(self, num_folds=None, random_state=None):
"""Separate the dataset into groups for k-folding
This is only applicable to training datasets that have assigned
categories.
classes.
If the dataset is an augmented dataset, we ensure that the
augmentations of the same object stay in the same fold.
Expand All @@ -215,9 +215,9 @@ def label_folds(self, num_folds=None, random_state=None):
if random_state is None:
random_state = settings['fold_random_state']

if 'category' not in self.metadata:
if 'class' not in self.metadata:
raise AvocadoException(
"Dataset %s does not have labeled categories! Can't separate "
"Dataset %s does not have labeled classes! Can't separate "
"into folds." % self.name
)

Expand All @@ -231,7 +231,7 @@ def label_folds(self, num_folds=None, random_state=None):
is_augmented = False
reference_metadata = self.metadata

reference_classes = reference_metadata['category']
reference_classes = reference_metadata['class']
folds = StratifiedKFold(n_splits=num_folds, shuffle=True,
random_state=random_state)
fold_map = {}
Expand All @@ -251,21 +251,21 @@ def label_folds(self, num_folds=None, random_state=None):

return fold_indices

def get_object(self, index=None, category=None, object_id=None):
def get_object(self, index=None, object_class=None, object_id=None):
"""Parse keywords to pull a specific object out of the dataset
Parameters
==========
index : int (optional)
The index of the object in the dataset in the range
[0, num_objects-1]. If a specific category is specified, then the
index only counts objects of that category.
category : int or str (optional)
Filter for objects of a specific category. If this is specified,
then index must also be specified.
[0, num_objects-1]. If a specific object_class is specified, then
the index only counts objects of that class.
object_class : int or str (optional)
Filter for objects of a specific class. If this is specified, then
index must also be specified.
object_id : str (optional)
Retrieve an object with this specific object_id. If index or
category is specified, then object_id cannot also be specified.
object_class is specified, then object_id cannot also be specified.
Returns
=======
Expand All @@ -275,24 +275,23 @@ def get_object(self, index=None, category=None, object_id=None):
# Check to make sure that we have a valid object specification.
base_error = "Error finding object! "
if object_id is not None:
if index is not None or category is not None:
if index is not None or object_class is not None:
raise AvocadoException(
base_error + "If object_id is specified, can't also "
"specify index or category!"
"specify index or object_class!"
)

if category is not None and index is None:
if object_class is not None and index is None:
raise AvocadoException(
base_error + "Must specify index if category is specified!"
base_error + "Must specify index if object_class is specified!"
)

# Figure out the index to use.
if category is not None:
if object_class is not None:
# Figure out the target object_id and use that to get the index.
category_index = index
category_meta = self.metadata[
self.metadata['category'] == category]
object_id = category_meta.index[category_index]
class_index = index
class_meta = self.metadata[self.metadata['class'] == object_class]
object_id = class_meta.index[class_index]

if object_id is not None:
try:
Expand All @@ -304,7 +303,8 @@ def get_object(self, index=None, category=None, object_id=None):

return self.objects[index]

def _get_object(self, index=None, category=None, object_id=None, **kwargs):
def _get_object(self, index=None, object_class=None, object_id=None,
**kwargs):
"""Wrapper around get_object that returns unused kwargs.
This function is used for the common situation of pulling an object out
Expand All @@ -319,7 +319,7 @@ def _get_object(self, index=None, category=None, object_id=None, **kwargs):
**kwargs
Additional arguments passed to the function that weren't used.
"""
return self.get_object(index, category, object_id), kwargs
return self.get_object(index, object_class, object_id), kwargs

def plot_light_curve(self, *args, **kwargs):
"""Plot the light curve for an object in the dataset.
Expand All @@ -339,26 +339,26 @@ def plot_interactive(self):
"""
from ipywidgets import interact, IntSlider, Dropdown

categories = {'' : None}
for category in np.unique(self.metadata['category']):
categories[category] = category
object_classes = {'' : None}
for object_class in np.unique(self.metadata['class']):
object_classes[object_class] = object_class

idx_widget = IntSlider(min=0, max=1)
category_widget = Dropdown(options=categories, index=0)
class_widget = Dropdown(options=object_classes, index=0)

def update_idx_range(*args):
if category_widget.value is None:
if class_widget.value is None:
idx_widget.max = len(self.metadata) - 1
else:
idx_widget.max = np.sum(self.metadata['category'] ==
category_widget.value) - 1
idx_widget.max = np.sum(self.metadata['class'] ==
class_widget.value) - 1

category_widget.observe(update_idx_range, 'value')
class_widget.observe(update_idx_range, 'value')

update_idx_range()

interact(self.plot_light_curve, index=idx_widget,
category=category_widget, show_gp=True, uncertainties=True,
object_class=class_widget, show_gp=True, uncertainties=True,
verbose=False, subtract_background=True)

def write(self, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion scripts/avocado_download_plasticc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def update_plasticc_metadata(metadata):
"""
# Rename columns in the metadata table to match the avocado conventions.
metadata_name_map = {
'true_target': 'category',
'true_target': 'class',
'hostgal_photoz_err': 'host_photoz_error',
'hostgal_photoz': 'host_photoz',
'hostgal_specz': 'host_specz',
Expand Down

0 comments on commit 46a6ab2

Please sign in to comment.