From e7bdeb075a581ec6a1246f3578cde8c1c41bbb57 Mon Sep 17 00:00:00 2001 From: Lee Yang Date: Tue, 30 Aug 2022 17:37:25 -0700 Subject: [PATCH] move data_module param to KerasEstimator; fix PetaStormDataModule Signed-off-by: Lee Yang --- horovod/spark/common/params.py | 2 -- horovod/spark/keras/datamodule.py | 4 ++-- horovod/spark/keras/estimator.py | 1 + 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/horovod/spark/common/params.py b/horovod/spark/common/params.py index 8357dfd0ef..e553d04bc3 100644 --- a/horovod/spark/common/params.py +++ b/horovod/spark/common/params.py @@ -24,7 +24,6 @@ class EstimatorParams(Params): num_proc = Param(Params._dummy(), 'num_proc', 'number of processes') - data_module = Param(Params._dummy(), 'data_module', 'data module class to use when reading data') train_reader_num_workers = Param(Params._dummy(), 'train_reader_num_workers', 'number of parallel worker processes to read train data') @@ -140,7 +139,6 @@ def __init__(self): self._setDefault( num_proc=None, - data_module=None, store=None, backend=None, model=None, diff --git a/horovod/spark/keras/datamodule.py b/horovod/spark/keras/datamodule.py index abca200507..fbd33f66cd 100644 --- a/horovod/spark/keras/datamodule.py +++ b/horovod/spark/keras/datamodule.py @@ -21,7 +21,7 @@ class PetastormDataModule(DataModule): """Default Petastorm-based DataModule for KerasEstimator.""" - def __init__(self, reader_pool_type: str="process", + def __init__(self, reader_pool_type: str='thread', train_reader_worker_count: int=2, val_reader_worker_count: int=2, make_dataset=None, @@ -55,7 +55,7 @@ def __init__(self, reader_pool_type: str="process", def __enter__(self): super().__enter__() self.train_reader = self.reader_factory(self.train_dir, - num_epochs=self.num_train_epochs, + num_epochs=1, cur_shard=hvd.rank(), reader_pool_type=self.reader_pool_type, workers_count=self.train_reader_worker_count, diff --git a/horovod/spark/keras/estimator.py b/horovod/spark/keras/estimator.py index 7a1c4e587f..d20804a138 100644 --- a/horovod/spark/keras/estimator.py +++ b/horovod/spark/keras/estimator.py @@ -158,6 +158,7 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable, custom_objects = Param(Params._dummy(), 'custom_objects', 'custom objects') checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback', 'model checkpointing callback') + data_module = Param(Params._dummy(), 'data_module', 'data module class to use when reading data') backend_env = Param(Params._dummy(), "backend_env", "dict to add to the environment of the command run on the environment")