Add redshift-weighting for classifiers
kboone committed May 27, 2019
1 parent c648e08 commit 6c45a9d
Showing 3 changed files with 183 additions and 7 deletions.
162 changes: 158 additions & 4 deletions avocado/classifier.py
@@ -59,6 +59,160 @@ def evaluate_weights_flat(dataset, class_weights=None):
return weights


def evaluate_weights_redshift(
dataset, class_weights=None, group_key=None, min_redshift=None,
max_redshift=None, num_bins=None, min_count_fraction=None):
"""Evaluate redshift-weighted weights to use to generate a
rates-independent classifier.
The redshift range is divided into logarithmically-spaced bins. Each class
is given the same weights in each bin so that the rates information in the
training set doesn't affect the classification. A classifier trained using
these weights will produce a "rates-independent" classification.
The redshift bins to use are set using a logarithmic range between
min_redshift and max_redshift with a total of num_bins. Any objects that
spill out of these bins are included in the first and last bins. A separate
bin is included for galactic objects at redshift exactly 0.
Parameters
----------
dataset : :class:`Dataset`
The dataset to evaluate weights on.
class_weights : dict (optional)
Weights to use for each class. If not set, equal weights are assumed
for each class.
group_key : str (optional)
If set, the group of each object will be loaded using group_key as the
key in the dataset's metadata. The weights will be calculated
independently for each group. This can be useful if there are multiple
very different survey strategies in the same dataset, all of which have
their own selection efficiencies. By default,
settings['redshift_weighting_group_key'] will be used.
min_redshift : float (optional)
The minimum redshift bin to use. By default,
settings['redshift_weighting_min_redshift'] will be used.
max_redshift : float (optional)
The maximum redshift bin to use. By default,
settings['redshift_weighting_max_redshift'] will be used.
num_bins : int (optional)
The number of redshift bins to use. By default,
settings['redshift_weighting_num_bins'] will be used.
    min_count_fraction : float (optional)
The minimum number of counts to use for weighting in a bin, as a
fraction of the total number of objects of all classes in the same
redshift bin and group. By default,
settings['redshift_weighting_min_count_fraction'] will be used.

    Returns
-------
weights : `pandas.Series`
The weights that should be used for classification.
"""
if group_key is None:
group_key = settings['redshift_weighting_group_key']
if min_redshift is None:
min_redshift = settings['redshift_weighting_min_redshift']
if max_redshift is None:
max_redshift = settings['redshift_weighting_max_redshift']
if num_bins is None:
num_bins = settings['redshift_weighting_num_bins']
if min_count_fraction is None:
min_count_fraction = settings['redshift_weighting_min_count_fraction']

# Create the initial bin range
redshift_bins = np.logspace(np.log10(min_redshift), np.log10(max_redshift),
num_bins+1)

# Replace the first and last bins with very small and large numbers to
# effectively extend them to infinity.
redshift_bins[0] = 1e-99
redshift_bins[-1] = 1e99

# Add in a bin for galactic objects at redshifts of exactly 0
redshift_bins = np.hstack([-1e99, redshift_bins])

# Figure out which redshift bin each object falls in.
redshift_indices = np.searchsorted(
redshift_bins, dataset.metadata['host_specz']) - 1

# Figure out how many different classes there are, and create a mapping for
# them.
object_classes = dataset.metadata['class']
class_names = np.unique(object_classes)
class_map = {class_name : i for i, class_name in enumerate(class_names)}
class_indices = [class_map[i] for i in object_classes]

# Figure out how many different groups there are, and create a mapping for
# them.
if group_key is not None:
groups = dataset.metadata[group_key]
group_names = np.unique(groups)
group_map = {group_name : i for i, group_name in
enumerate(group_names)}
group_indices = [group_map[i] for i in groups]
else:
group_names = ['default']
group_indices = np.zeros(len(dataset.metadata), dtype=int)

# Count how many objects are in each bin.
counts = np.zeros((len(group_names), len(redshift_bins) - 1,
len(class_names)))
for group_index, redshift_index, class_index in \
zip(group_indices, redshift_indices, class_indices):
counts[group_index, redshift_index, class_index] += 1

total_counts = np.sum(counts)

# Count how many extragalactic bins are actually populated. This is
# used to set the scales so that they roughly match what we have
# for the non-redshift-weighted metric. This is done so that we can
# use the same hyperparameters. For galactic objects, we don't need
# to do anything because all of the observations end up in the same
# bin. For extragalactic objects, we need to take into account the
# fact that the objects are now split up between many different
# bins. Get an estimate of how many bins are populated, and apply
# that to the data.
num_extgal_bins = np.sum(counts[:, 1:, :] > 1e-4 * total_counts)
class_extgal_counts = np.sum(np.sum(counts[:, 1:, :], axis=0), axis=0)
class_gal_counts = np.sum(counts[:, 0, :], axis=0)
extgal_mask = class_extgal_counts > class_gal_counts
num_extgal_classes = np.sum(extgal_mask)
extgal_scale = num_extgal_bins / num_extgal_classes

# Figure out the weights for each bin.
weights = np.zeros(counts.shape)

for group_index in range(len(group_names)):
for redshift_index in range(len(redshift_bins) - 1):
# Calculate the weights for each redshift and group bin separately.
bin_counts = counts[group_index, redshift_index]

            if np.sum(bin_counts) == 0:
                weights[group_index, redshift_index] = 0.
                # Nothing in this bin; skip it to avoid dividing by zero below.
                continue

# Add a floor to the counts in each bin to prevent absurdly
# high weights for poorly represented classes.
min_counts = min_count_fraction * np.sum(bin_counts)
bin_counts[bin_counts < min_counts] = min_counts

weights[group_index, redshift_index] = total_counts / bin_counts

# Rescale the weights for extragalactic classes.
weights[:, :, extgal_mask] /= extgal_scale

# If class_weights is set, rescale the weights for each class.
if class_weights is not None:
for class_idx, class_name in enumerate(class_names):
weights[:, :, class_idx] *= class_weights[class_name]

# Calculate the weights for each object
object_weights = weights[group_indices, redshift_indices, class_indices]
object_weights = pd.Series(object_weights, index=dataset.metadata.index)

return object_weights


class Classifier():
"""Classifier used to classify the different objects in a dataset.
@@ -171,7 +325,7 @@ class LightGBMClassifier(Classifier):
class_weights : dict (optional)
Weights to use for each class. If not set, equal weights are assumed
for each class.
-    weights_function : function (optional)
+    weighting_function : function (optional)
Function to use to evaluate weights. By default,
`evaluate_weights_flat` is used which normalizes the weights for each
class so that their overall weight matches the one set by
@@ -180,12 +334,12 @@ class so that their overall weight matches the one set by
it has the same signature as `evaluate_weights_flat`.
"""
    def __init__(self, name, featurizer, class_weights=None,
-                 weights_function=evaluate_weights_flat):
+                 weighting_function=evaluate_weights_flat):
super().__init__(name)

self.featurizer = featurizer
self.class_weights = class_weights
-        self.weights_function = evaluate_weights_flat
+        self.weighting_function = weighting_function

def train(self, dataset, num_folds=None, random_state=None, **kwargs):
"""Train the classifier on a dataset
@@ -208,7 +362,7 @@ def train(self, dataset, num_folds=None, random_state=None, **kwargs):
folds = dataset.label_folds(num_folds, random_state)
num_folds = np.max(folds) + 1

-        weights = self.weights_function(dataset)
+        weights = self.weighting_function(dataset)

object_classes = dataset.metadata['class']
classes = np.unique(object_classes)
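As a standalone illustration of the binning scheme implemented above, here is a
minimal sketch using the default settings values from avocado_settings.json
(shown below) and a few hypothetical example redshifts:

    import numpy as np

    min_redshift, max_redshift, num_bins = 0.1, 3.0, 10

    # Logarithmically-spaced bin edges over the nominal redshift range.
    redshift_bins = np.logspace(np.log10(min_redshift), np.log10(max_redshift),
                                num_bins + 1)

    # Stretch the outermost bins to effectively extend them to infinity, then
    # prepend a dedicated bin for galactic objects at a redshift of exactly 0.
    redshift_bins[0] = 1e-99
    redshift_bins[-1] = 1e99
    redshift_bins = np.hstack([-1e99, redshift_bins])

    # searchsorted returns the insertion index, so subtracting 1 gives the bin
    # index. Hypothetical redshifts: galactic, below range, in range, above
    # range.
    redshifts = np.array([0.0, 0.05, 0.3, 1.2, 5.0])
    print(np.searchsorted(redshift_bins, redshifts) - 1)
    # -> [ 0  1  4  8 10]: z == 0 lands in the galactic bin, and out-of-range
    # redshifts are absorbed by the first and last extragalactic bins.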
8 changes: 7 additions & 1 deletion avocado_settings.json
@@ -12,5 +12,11 @@

"classifier_directory": "./classifiers",

"predictions_directory": "./predictions"
"predictions_directory": "./predictions",

"redshift_weighting_group_key": "ddf",
"redshift_weighting_min_redshift": 0.1,
"redshift_weighting_max_redshift": 3.0,
"redshift_weighting_num_bins": 10,
"redshift_weighting_min_count_fraction": 0.01
}
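To see what redshift_weighting_min_count_fraction does, consider a hypothetical
(group, redshift) bin containing 500 objects: any class with fewer than
0.01 * 500 = 5 members has its count floored to 5 before the weights are
computed, which caps the weight a poorly-represented class can receive. A
minimal sketch, assuming made-up class counts:

    import numpy as np

    min_count_fraction = 0.01  # value from avocado_settings.json above
    bin_counts = np.array([300.0, 150.0, 48.0, 2.0])  # hypothetical counts

    # Floor the counts to prevent absurdly high weights for rare classes.
    min_counts = min_count_fraction * np.sum(bin_counts)  # 5.0
    bin_counts[bin_counts < min_counts] = min_counts

    # Weights are inversely proportional to the floored counts. (The real
    # function uses the dataset-wide total_counts as the numerator; the bin
    # total of 500 is used here for simplicity.) Without the floor, the
    # rarest class would receive a weight of 250 instead of 100.
    print(500.0 / bin_counts)
    # -> [  1.66666667   3.33333333  10.41666667 100.        ]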
20 changes: 18 additions & 2 deletions scripts/avocado_train_classifier
@@ -1,7 +1,7 @@
#!/usr/bin/env python
"""Train a classifier using avocado.
-TODO: Add options for weights, classifier kind, featurizer, etc.
+For now, this only supports a LightGBM classifier with the PLAsTiCC featurizer.
"""

import argparse
@@ -19,6 +19,12 @@ if __name__ == "__main__":
'classifier',
help='Name of the classifier to produce.'
)
parser.add_argument(
'--weighting',
help='Kind of weighting to use. (default: %(default)s)',
default='flat',
choices=('flat', 'redshift'),
)

args = parser.parse_args()

@@ -30,12 +36,22 @@ if __name__ == "__main__":
print("Loading raw features...")
dataset.load_raw_features()

# Figure out which weighting to use.
if args.weighting == 'flat':
weighting_function = avocado.evaluate_weights_flat
elif args.weighting == 'redshift':
weighting_function = avocado.evaluate_weights_redshift
else:
raise avocado.AvocadoException("Invalid weighting '%s'!" %
args.weighting)

# Train the classifier
print("Training classifier '%s'..." % args.classifier)
classifier = avocado.LightGBMClassifier(
args.classifier,
avocado.plasticc.PlasticcFeaturizer(),
-        avocado.plasticc.plasticc_kaggle_weights
+        avocado.plasticc.plasticc_kaggle_weights,
+        weighting_function=weighting_function,
)
classifier.train(dataset)

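With this change, a rates-independent classifier can be trained from the
command line by selecting the redshift weighting. An example invocation (the
classifier name here is hypothetical, and any other arguments defined in the
elided parts of the script, such as the dataset to train on, would also be
required):

    ./scripts/avocado_train_classifier my_classifier --weighting redshift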
