Skip to content

Commit

Permalink
Wrap dataset generation function to disable autograph to fix issues w…
Browse files Browse the repository at this point in the history
…ith invalid tensor shapes (#2069)

Signed-off-by: Travis Addair <taddair@uber.com>
  • Loading branch information
tgaddair committed Jun 30, 2020
1 parent 62c2314 commit a860d5e
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion horovod/spark/keras/util.py
Expand Up @@ -15,6 +15,8 @@

import io

from distutils.version import LooseVersion

import h5py
import numpy as np
import tensorflow as tf
Expand All @@ -28,6 +30,8 @@
BARE_KERAS = 'keras'
TF_KERAS = 'tf_keras'

_HAS_AUTOGRAPH = LooseVersion(tf.__version__) >= LooseVersion('1.15')


class TFKerasUtil(object):
type = TF_KERAS
Expand Down Expand Up @@ -74,7 +78,7 @@ def fn(reader, shuffle_buffer_size, is_batch_reader, shuffle=False):

dataset = dataset.batch(batch_size).map(prep_data_tf_keras)
return dataset
return fn
return tf.autograph.experimental.do_not_convert(fn) if _HAS_AUTOGRAPH else fn

@staticmethod
def get_horovod():
Expand Down

0 comments on commit a860d5e

Please sign in to comment.