Skip to content

Commit

Permalink
Update and clean up ML params
Browse files Browse the repository at this point in the history
  • Loading branch information
maxpumperla committed Feb 23, 2016
1 parent 89ee4f6 commit d86b350
Showing 1 changed file with 29 additions and 30 deletions.
59 changes: 29 additions & 30 deletions elephas/ml/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,36 @@

class HasKerasModelConfig(Params):
    '''
    Mandatory field:
    Parameter mixin for Keras model yaml

    Holds the Keras model architecture serialized as a yaml string,
    stored as a Spark ML ``Param``. No default is set, so callers
    must provide a value before ``get_keras_model_config`` is used.
    '''
    def __init__(self):
        super(HasKerasModelConfig, self).__init__()
        # Spark ML parameter carrying the serialized model definition.
        self.keras_model_config = Param(self, "keras_model_config", "Serialized Keras model as yaml string")

    def set_keras_model_config(self, keras_model_config):
        # Record the yaml string in the param map and return self
        # so setters can be chained fluently.
        self._paramMap[self.keras_model_config] = keras_model_config
        return self

    def get_keras_model_config(self):
        # Raises if no value was ever set (no default registered).
        return self.getOrDefault(self.keras_model_config)


class HasOptimizerConfig(Params):
    '''
    Mandatory:
    Parameter mixin for Elephas optimizer config

    Stores the serialized properties of an Elephas optimizer as a
    Spark ML ``Param``.
    '''
    def __init__(self):
        super(HasOptimizerConfig, self).__init__()
        # Serialized optimizer settings; consumed by the Elephas estimator.
        self.optimizer_config = Param(self, "optimizer_config", "Serialized Elephas optimizer properties")
        # TODO: Define default value

    def set_optimizer_config(self, optimizer_config):
        # Fluent setter: stash the config and return self for chaining.
        self._paramMap[self.optimizer_config] = optimizer_config
        return self

    def get_optimizer_config(self):
        # No default registered yet (see TODO above), so this raises
        # when called before set_optimizer_config.
        return self.getOrDefault(self.optimizer_config)


class HasMode(Params):
Expand Down Expand Up @@ -72,7 +69,7 @@ def get_frequency(self):
return self.getOrDefault(self.frequency)


class HasNumberOfClasses(Param):
class HasNumberOfClasses(Params):
'''
Mandatory:
Expand All @@ -81,6 +78,7 @@ class HasNumberOfClasses(Param):
def __init__(self):
super(HasNumberOfClasses, self).__init__()
self.nb_classes = Param(self, "nb_classes", "number of classes")
self._setDefault(nb_classes=10)

def set_nb_classes(self, nb_classes):
self._paramMap[self.nb_classes] = nb_classes
Expand All @@ -90,42 +88,43 @@ def get_nb_classes(self):
return self.getOrDefault(self.nb_classes)


class HasCategoricalLabels(Params):
    '''
    Mandatory:
    Parameter mixin for setting categorical labels

    Boolean flag indicating whether labels are categorical
    (one-hot encoded) rather than continuous. Defaults to True.
    '''
    def __init__(self):
        super(HasCategoricalLabels, self).__init__()
        self.categorical = Param(self, "categorical", "Boolean to indicate if labels are categorical")
        # Most Elephas classification pipelines expect one-hot labels,
        # hence the default of True.
        self._setDefault(categorical=True)

    def set_categorical_labels(self, categorical):
        # Fluent setter: store the flag and return self for chaining.
        self._paramMap[self.categorical] = categorical
        return self

    def get_categorical_labels(self):
        # Falls back to the default (True) when unset.
        return self.getOrDefault(self.categorical)


class HasEpochs(Params):
    '''
    Parameter mixin for number of epochs

    Number of passes over the training data, exposed as a Spark ML
    ``Param`` named ``nb_epoch`` (Keras naming). Defaults to 10.
    '''
    def __init__(self):
        super(HasEpochs, self).__init__()
        # Named nb_epoch to mirror the Keras fit() keyword.
        self.nb_epoch = Param(self, "nb_epoch", "Number of epochs to train")
        self._setDefault(nb_epoch=10)

    def set_nb_epoch(self, nb_epoch):
        # Fluent setter: record the epoch count and return self.
        self._paramMap[self.nb_epoch] = nb_epoch
        return self

    def get_nb_epoch(self):
        # Returns the configured value, or the default of 10.
        return self.getOrDefault(self.nb_epoch)


class HasBatchSize(Param):
class HasBatchSize(Params):
'''
Parameter mixin for batch size
'''
Expand All @@ -142,13 +141,13 @@ def get_batch_size(self):
return self.getOrDefault(self.batch_size)


class HasVerbosity(Param):
class HasVerbosity(Params):
'''
Parameter mixin for output verbosity
'''
def __init__(self):
super(HasVerbosity, self).__init__()
self.verbose = Param(self, "verbosity", "Stdout verbosity")
self.verbose = Param(self, "verbose", "Stdout verbosity")
self._setDefault(verbose=0)

def set_verbosity(self, verbose):
Expand All @@ -159,7 +158,7 @@ def get_verbosity(self):
return self.getOrDefault(self.verbose)


class HasValidationSplit(Param):
class HasValidationSplit(Params):
'''
Parameter mixin for validation split percentage
'''
Expand All @@ -172,11 +171,11 @@ def set_validation_split(self, validation_split):
self._paramMap[self.validation_split] = validation_split
return self

def get_verbosity(self):
def get_validation_split(self):
return self.getOrDefault(self.validation_split)


class HasNumberOfWorkers(Param):
class HasNumberOfWorkers(Params):
'''
Parameter mixin for number of workers
'''
Expand Down

0 comments on commit d86b350

Please sign in to comment.