Skip to content

Commit

Permalink
move data_module param to KerasEstimator; fix PetaStormDataModule
Browse files Browse the repository at this point in the history
Signed-off-by: Lee Yang <leewyang@gmail.com>
  • Loading branch information
leewyang committed Aug 31, 2022
1 parent 21fba94 commit e7bdeb0
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 4 deletions.
2 changes: 0 additions & 2 deletions horovod/spark/common/params.py
Expand Up @@ -24,7 +24,6 @@

class EstimatorParams(Params):
num_proc = Param(Params._dummy(), 'num_proc', 'number of processes')
data_module = Param(Params._dummy(), 'data_module', 'data module class to use when reading data')
train_reader_num_workers = Param(Params._dummy(),
'train_reader_num_workers',
'number of parallel worker processes to read train data')
Expand Down Expand Up @@ -140,7 +139,6 @@ def __init__(self):

self._setDefault(
num_proc=None,
data_module=None,
store=None,
backend=None,
model=None,
Expand Down
4 changes: 2 additions & 2 deletions horovod/spark/keras/datamodule.py
Expand Up @@ -21,7 +21,7 @@

class PetastormDataModule(DataModule):
"""Default Petastorm-based DataModule for KerasEstimator."""
def __init__(self, reader_pool_type: str="process",
def __init__(self, reader_pool_type: str='thread',
train_reader_worker_count: int=2,
val_reader_worker_count: int=2,
make_dataset=None,
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self, reader_pool_type: str="process",
def __enter__(self):
super().__enter__()
self.train_reader = self.reader_factory(self.train_dir,
num_epochs=self.num_train_epochs,
num_epochs=1,
cur_shard=hvd.rank(),
reader_pool_type=self.reader_pool_type,
workers_count=self.train_reader_worker_count,
Expand Down
1 change: 1 addition & 0 deletions horovod/spark/keras/estimator.py
Expand Up @@ -158,6 +158,7 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
custom_objects = Param(Params._dummy(), 'custom_objects', 'custom objects')
checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
'model checkpointing callback')
data_module = Param(Params._dummy(), 'data_module', 'data module class to use when reading data')
backend_env = Param(Params._dummy(), "backend_env",
"dict to add to the environment of the command run on the environment")

Expand Down

0 comments on commit e7bdeb0

Please sign in to comment.