From e9c144ce621c52182b8c3490e19ecfcbf1a546fe Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 30 Jun 2020 12:35:58 -0700 Subject: [PATCH] Wrap dataset generation function to disable autograph to fix issues with invalid tensor shapes Signed-off-by: Travis Addair --- horovod/spark/keras/util.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/horovod/spark/keras/util.py b/horovod/spark/keras/util.py index 193d9a89ad..6fcdd97fe9 100644 --- a/horovod/spark/keras/util.py +++ b/horovod/spark/keras/util.py @@ -15,6 +15,8 @@ import io +from distutils.version import LooseVersion + import h5py import numpy as np import tensorflow as tf @@ -28,6 +30,8 @@ BARE_KERAS = 'keras' TF_KERAS = 'tf_keras' +_HAS_AUTOGRAPH = LooseVersion(tf.__version__) >= LooseVersion('1.15') + class TFKerasUtil(object): type = TF_KERAS @@ -74,7 +78,7 @@ def fn(reader, shuffle_buffer_size, is_batch_reader, shuffle=False): dataset = dataset.batch(batch_size).map(prep_data_tf_keras) return dataset - return fn + return tf.autograph.experimental.do_not_convert(fn) if _HAS_AUTOGRAPH else fn @staticmethod def get_horovod():