Skip to content

Commit

Permalink
improved probability to logits conversion in keras model (#69)
Browse files Browse the repository at this point in the history
* improved probability to logits conversion in keras model

* attempt to fix theano / tensorflow keras backend differences
  • Loading branch information
jonasrauber committed Aug 30, 2017
1 parent 2237f83 commit d11a3c7
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions foolbox/models/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,21 @@ def __init__(

predictions = model.output

if predicts == 'probabilities':
predictions_are_logits = False
elif predicts == 'logits':
predictions_are_logits = True

shape = K.int_shape(predictions)
_, num_classes = shape
assert num_classes is not None

self._num_classes = num_classes

loss = K.sparse_categorical_crossentropy(
label_input, predictions, from_logits=predictions_are_logits)
if predicts == 'probabilities':
loss = K.sparse_categorical_crossentropy(
label_input, predictions, from_logits=False)
# transform the probability predictions into logits, so that
# the rest of this code can assume predictions to be logits
predictions = self._to_logits(predictions)
elif predicts == 'logits':
loss = K.sparse_categorical_crossentropy(
label_input, predictions, from_logits=True)

# sparse_categorical_crossentropy returns 1-dim tensor,
# gradients wants 0-dim tensor (for some backends)
Expand Down Expand Up @@ -95,15 +97,11 @@ def __init__(
[images_input, label_input],
[predictions, grad])

self._predictions_are_logits = predictions_are_logits

def _as_logits(self, predictions):
assert predictions.ndim in [1, 2]
if self._predictions_are_logits:
return predictions
def _to_logits(self, predictions):
from keras import backend as K
eps = 10e-8
predictions = np.clip(predictions, eps, 1 - eps)
predictions = np.log(predictions)
predictions = K.clip(predictions, eps, 1 - eps)
predictions = K.log(predictions)
return predictions

def num_classes(self):
Expand All @@ -114,15 +112,13 @@ def batch_predictions(self, images):
assert len(predictions) == 1
predictions = predictions[0]
assert predictions.shape == (images.shape[0], self.num_classes())
predictions = self._as_logits(predictions)
return predictions

def predictions_and_gradient(self, image, label):
predictions, gradient = self._pred_grad_fn([
self._process_input(image[np.newaxis]),
np.array([label])])
predictions = np.squeeze(predictions, axis=0)
predictions = self._as_logits(predictions)
gradient = np.squeeze(gradient, axis=0)
gradient = self._process_gradient(gradient)
assert predictions.shape == (self.num_classes(),)
Expand Down

0 comments on commit d11a3c7

Please sign in to comment.