Skip to content

Commit

Permalink
Fixes tensorflow#57: Adding class weight support on estimator level. …
Browse files Browse the repository at this point in the history
…Now just pass an array of n_classes weights and the logits and loss value will be adjusted accordingly.
  • Loading branch information
ilblackdragon committed Feb 14, 2016
1 parent 5ff25c9 commit 79eed03
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 10 deletions.
13 changes: 12 additions & 1 deletion skflow/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
class_weight: None or list of n_classes floats. Weights associated with
classes for loss computation. If not given, all classes are supposed to
have weight one.
tf_random_seed: Random seed for TensorFlow initializers.
Setting this value allows consistency between reruns.
continue_training: when continue_training is True, once initialized
Expand All @@ -79,7 +82,8 @@ def exp_decay(global_step):

def __init__(self, model_fn, n_classes, tf_master="", batch_size=32,
steps=200, optimizer="SGD",
learning_rate=0.1, tf_random_seed=42, continue_training=False,
learning_rate=0.1, class_weight=None,
tf_random_seed=42, continue_training=False,
num_cores=4, verbose=1, early_stopping_rounds=None,
max_to_keep=5, keep_checkpoint_every_n_hours=10000):
self.n_classes = n_classes
Expand All @@ -97,6 +101,7 @@ def __init__(self, model_fn, n_classes, tf_master="", batch_size=32,
self._early_stopping_rounds = early_stopping_rounds
self.max_to_keep = max_to_keep
self.keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
self.class_weight = class_weight

def _setup_training(self):
"""Sets up graph, model and trainer."""
Expand All @@ -117,6 +122,12 @@ def _setup_training(self):
tf.as_dtype(self._data_feeder.output_dtype), output_shape,
name="output")

# If class weights are provided, add them to the graph.
# Different loss functions can use this tensor by name.
if self.class_weight:
self._class_weight_node = tf.constant(
self.class_weight, name='class_weight')

# Add histograms for X and y if they are floats.
if self._data_feeder.input_dtype in (np.float32, np.float64):
tf.histogram_summary("X", self._inp)
Expand Down
7 changes: 6 additions & 1 deletion skflow/estimators/dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
class_weight: None or list of n_classes floats. Weights associated with
classes for loss computation. If not given, all classes are supposed to
have weight one.
tf_random_seed: Random seed for TensorFlow initializers.
Setting this value allows consistency between reruns.
continue_training: when continue_training is True, once initialized
Expand All @@ -57,6 +60,7 @@ def exp_decay(global_step):

def __init__(self, hidden_units, n_classes, tf_master="", batch_size=32,
steps=200, optimizer="SGD", learning_rate=0.1,
class_weight=None,
tf_random_seed=42, continue_training=False,
verbose=1, early_stopping_rounds=None,
max_to_keep=5, keep_checkpoint_every_n_hours=10000):
Expand All @@ -65,7 +69,8 @@ def __init__(self, hidden_units, n_classes, tf_master="", batch_size=32,
model_fn=self._model_fn,
n_classes=n_classes, tf_master=tf_master,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, tf_random_seed=tf_random_seed,
learning_rate=learning_rate, class_weight=class_weight,
tf_random_seed=tf_random_seed,
continue_training=continue_training, verbose=verbose,
early_stopping_rounds=early_stopping_rounds,
max_to_keep=max_to_keep,
Expand Down
6 changes: 4 additions & 2 deletions skflow/estimators/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ class TensorFlowLinearClassifier(TensorFlowEstimator, ClassifierMixin):
"""TensorFlow Linear Classifier model."""

def __init__(self, n_classes, tf_master="", batch_size=32, steps=200, optimizer="SGD",
learning_rate=0.1, tf_random_seed=42, continue_training=False,
learning_rate=0.1, class_weight=None,
tf_random_seed=42, continue_training=False,
verbose=1, early_stopping_rounds=None,
max_to_keep=5, keep_checkpoint_every_n_hours=10000):
super(TensorFlowLinearClassifier, self).__init__(
model_fn=models.logistic_regression, n_classes=n_classes,
tf_master=tf_master,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, tf_random_seed=tf_random_seed,
learning_rate=learning_rate, class_weight=class_weight,
tf_random_seed=tf_random_seed,
continue_training=continue_training,
verbose=verbose, early_stopping_rounds=early_stopping_rounds,
max_to_keep=max_to_keep,
Expand Down
7 changes: 6 additions & 1 deletion skflow/estimators/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def exp_decay(global_step):
return tf.train.exponential_decay(
learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001)
class_weight: None or list of n_classes floats. Weights associated with
classes for loss computation. If not given, all classes are supposed to
have weight one.
tf_random_seed: Random seed for TensorFlow initializers.
Setting this value allows consistency between reruns.
continue_training: when continue_training is True, once initialized
Expand All @@ -74,6 +77,7 @@ def __init__(self, rnn_size, n_classes, cell_type='gru', num_layers=1,
initial_state=None, bidirectional=False,
sequence_length=None, tf_master="", batch_size=32,
steps=50, optimizer="SGD", learning_rate=0.1,
class_weight=None,
tf_random_seed=42, continue_training=False,
verbose=1, early_stopping_rounds=None,
max_to_keep=5, keep_checkpoint_every_n_hours=10000):
Expand All @@ -88,7 +92,8 @@ def __init__(self, rnn_size, n_classes, cell_type='gru', num_layers=1,
model_fn=self._model_fn,
n_classes=n_classes, tf_master=tf_master,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, tf_random_seed=tf_random_seed,
learning_rate=learning_rate, class_weight=class_weight,
tf_random_seed=tf_random_seed,
continue_training=continue_training, verbose=verbose,
early_stopping_rounds=early_stopping_rounds,
max_to_keep=max_to_keep,
Expand Down
10 changes: 9 additions & 1 deletion skflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def logistic_regression(X, y, class_weight=None):
shape should be [batch_size, n_classes].
class_weight: tensor, [n_classes], holding the weight of each class.
If not provided, the default graph is checked for a tensor
named `class_weight:0`. If that is not found either,
all ones are used.
Returns:
Predictions and loss tensors.
Expand All @@ -68,6 +69,13 @@ def logistic_regression(X, y, class_weight=None):
bias = tf.get_variable('bias', [y.get_shape()[-1]])
tf.histogram_summary('logistic_regression.weights', weights)
tf.histogram_summary('logistic_regression.bias', bias)
# If no class weight is provided, try to retrieve one from the graph
# by its pre-defined tensor name.
if not class_weight:
try:
class_weight = tf.get_default_graph().get_tensor_by_name('class_weight:0')
except KeyError:
pass
return softmax_classifier(X, y, weights, bias,
class_weight=class_weight)

Expand Down
13 changes: 9 additions & 4 deletions skflow/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,21 @@ def testOneDim(self):
self.assertLess(score, 0.3, "Failed with score = {0}".format(score))

def testIris(self):
random.seed(42)
iris = datasets.load_iris()
classifier = skflow.TensorFlowLinearClassifier(n_classes=3)
classifier.fit(iris.data, iris.target)
score = accuracy_score(iris.target, classifier.predict(iris.data))
self.assertGreater(score, 0.5, "Failed with score = {0}".format(score))
self.assertGreater(score, 0.7, "Failed with score = {0}".format(score))

def testIrisClassWeight(self):
iris = datasets.load_iris()
classifier = skflow.TensorFlowLinearClassifier(
n_classes=3, class_weight=[0.1, 0.8, 0.1])
classifier.fit(iris.data, iris.target)
score = accuracy_score(iris.target, classifier.predict(iris.data))
self.assertLess(score, 0.7, "Failed with score = {0}".format(score))

def testIrisSummaries(self):
random.seed(42)
iris = datasets.load_iris()
classifier = skflow.TensorFlowLinearClassifier(n_classes=3)
classifier.fit(iris.data, iris.target, logdir='/tmp/skflow_tests/')
Expand All @@ -53,7 +59,6 @@ def testIrisSummaries(self):


def testIrisContinueTraining(self):
random.seed(42)
iris = datasets.load_iris()
classifier = skflow.TensorFlowLinearClassifier(n_classes=3,
learning_rate=0.01, continue_training=True, steps=250)
Expand Down

0 comments on commit 79eed03

Please sign in to comment.