Skip to content

Commit

Permalink
Fix #18
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath committed Feb 10, 2019
1 parent 991230a commit e5dac90
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
39 changes: 38 additions & 1 deletion importance_sampling/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,41 @@ def _augment_model(self, model, score, reweighting):
self._train_on_batch = train_on_batch
self._evaluate_on_batch = evaluate_on_batch

def _normalize_metric_size(self, metric_values, n_samples):
"""Normalize the metric size so that they can be aggregated per sample.
Fixes #18.
Arguments
---------
metric_values: array, the output return by Keras
n_samples: int, the number of samples in our batch
"""
shape = metric_values.shape

# If everything is correct return early
if len(shape) == 2 and shape[0] == n_samples:
return metric_values

# Case 1: The metrics are of shape (n_samples,)
if len(shape) == 1 and shape[0] == n_samples:
return np.expand_dims(metric_values, -1)

# Case 2: The metrics have size 1
if metric_values.size == 1:
return np.tile(
metric_values.reshape(1, 1),
(n_samples, 1)
)

# Report the error in a normal way
raise ValueError(("A metric function returns a non scalar value. "
"In order to use the automatic aggregation method "
"for evaluation on the validation set the metrics "
"need to be scalar per sample or per batch but the "
"shape is {}.".format(shape)))

def evaluate_batch(self, x, y):
n_samples = len(y)
if len(y.shape) == 1:
y = np.expand_dims(y, axis=1)
dummy_weights = np.ones((y.shape[0], self.reweighting.weight_size))
Expand All @@ -295,7 +329,10 @@ def evaluate_batch(self, x, y):

signal("is.evaluate_batch").send(outputs)

return np.hstack([outputs[self.LOSS]] + outputs[self.METRIC0:])
return np.hstack([
self._normalize_metric_size(m, n_samples)
for m in [outputs[self.LOSS]] + outputs[self.METRIC0:]
])

def score_batch(self, x, y):
if len(y.shape) == 1:
Expand Down
8 changes: 7 additions & 1 deletion importance_sampling/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,16 @@ def fit_dataset(self, dataset, steps_per_epoch=None, batch_size=32,
return self.history

def _get_metric_names(self):
def name(x):
try:
return x.__name__
except AttributeError:
return str(x)

metrics = self.original_model.metrics or []
return (
["loss"] +
list(map(str, metrics)) +
list(map(name, metrics)) +
["score"]
)

Expand Down
20 changes: 20 additions & 0 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import unittest

from keras import backend as K
from keras.layers import Dense, Input, dot
from keras.models import Model, Sequential
import numpy as np
Expand Down Expand Up @@ -141,6 +142,25 @@ def on_sample(sampler, idxs, w, scores):
self.assertEqual(16, calls[0])
calls[0] = 0

def test_metrics(self):
def const_metric(a, b):
return K.tf.constant(1, dtype=K.tf.float32)

model = Sequential([
Dense(10, activation="relu", input_shape=(2,)),
Dense(10, activation="relu"),
Dense(2)
])
model.compile("sgd", "mse", metrics=[const_metric])

for Training in self.TRAININGS:
Training(model).fit(
np.random.rand(64, 2), np.random.rand(64, 2),
batch_size=16,
epochs=4,
validation_data=[np.random.rand(64, 2), np.random.rand(64, 2)]
)


if __name__ == "__main__":
unittest.main()

0 comments on commit e5dac90

Please sign in to comment.