Skip to content
Permalink
Browse files

Fixed Keras tests for changes to weights in RMSProp optimizer for TF …

…2.0 compat that adds iteration scalar to the head of the weights list (#1392)
  • Loading branch information...
tgaddair committed Sep 10, 2019
1 parent c90b8ba commit 93d1a4fc4febbe510be10aaf160ad4a5fc083e21
Showing with 12 additions and 13 deletions.
  1. +12 −13 test/test_keras.py
@@ -92,10 +92,7 @@ def test_load_model(self):
self.assertEqual(type(new_opt).__module__, 'horovod._keras')
self.assertEqual(type(new_opt).__name__, 'RMSprop')
self.assertEqual(K.get_value(opt.lr), K.get_value(new_opt.lr))
self.assertEqual(len(opt.get_weights()), len(new_opt.get_weights()))
for weights, new_weights in zip(opt.get_weights(),
new_opt.get_weights()):
self.assertListEqual(weights.tolist(), new_weights.tolist())
self._check_optimizer_weights(opt, new_opt)

def test_load_model_custom_optimizers(self):
class TestOptimizer(keras.optimizers.RMSprop):
@@ -131,11 +128,7 @@ def __init__(self, **kwargs):

self.assertEqual(type(new_opt).__module__, 'horovod._keras')
self.assertEqual(type(new_opt).__name__, 'TestOptimizer')
self.assertEqual(K.get_value(opt.lr), K.get_value(new_opt.lr))
self.assertEqual(len(opt.get_weights()), len(new_opt.get_weights()))
for weights, new_weights in zip(opt.get_weights(),
new_opt.get_weights()):
self.assertListEqual(weights.tolist(), new_weights.tolist())
self._check_optimizer_weights(opt, new_opt)

def test_load_model_custom_objects(self):
class TestOptimizer(keras.optimizers.RMSprop):
@@ -175,10 +168,7 @@ def __init__(self, **kwargs):
self.assertEqual(type(new_opt).__module__, 'horovod._keras')
self.assertEqual(type(new_opt).__name__, 'TestOptimizer')
self.assertEqual(K.get_value(opt.lr), K.get_value(new_opt.lr))
self.assertEqual(len(opt.get_weights()), len(new_opt.get_weights()))
for weights, new_weights in zip(opt.get_weights(),
new_opt.get_weights()):
self.assertListEqual(weights.tolist(), new_weights.tolist())
self._check_optimizer_weights(opt, new_opt)

def test_load_model_broadcast(self):
def create_model():
@@ -239,3 +229,12 @@ def generator():
initial_epoch=1)

self.assertEqual(len(model.optimizer.weights), 5)

def _check_optimizer_weights(self, opt, new_opt):
self.assertEqual(len(opt.get_weights()), len(new_opt.get_weights()))
for weights, new_weights in zip(opt.get_weights(),
new_opt.get_weights()):
if np.isscalar(weights):
self.assertEqual(weights, new_weights)
else:
self.assertListEqual(weights.tolist(), new_weights.tolist())

0 comments on commit 93d1a4f

Please sign in to comment.
You can’t perform that action at this time.