diff --git a/horovod/spark/common/datamodule.py b/horovod/spark/common/datamodule.py
index 50a07cfb16..4a44eb1a30 100644
--- a/horovod/spark/common/datamodule.py
+++ b/horovod/spark/common/datamodule.py
@@ -20,12 +20,10 @@ class DataModule(ABC):
     short_name = None  # implementations should provide a short name for easy reference, e.g. 'petastorm', 'nvtabular', etc.
 
     def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_val: bool=True,
-                 train_batch_size: int=32, val_batch_size: int=32, shuffle_size: int=1000,
+                 train_batch_size: int=32, val_batch_size: int=32, shuffle: bool=True,
                  transform_fn=None, inmemory_cache_all=False,
                  cur_shard: int=0, shard_count: int=1, schema_fields=None, storage_options=None,
-                 steps_per_epoch_train: int=1, steps_per_epoch_val: int=1, verbose=True,
-                 debug_data_loader: bool=False, train_async_data_loader_queue_size: int=None,
-                 val_async_data_loader_queue_size: int=None, **kwargs):
+                 steps_per_epoch_train: int=1, steps_per_epoch_val: int=1, verbose=True, **kwargs):
         super().__init__()
         self.train_dir = train_dir
         self.val_dir = val_dir
@@ -33,7 +31,7 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_va
         self.has_val = has_val
         self.train_batch_size = train_batch_size
         self.val_batch_size = val_batch_size
-        self.shuffle_size = shuffle_size
+        self.shuffle = shuffle
         self.transform_fn = transform_fn
         self.inmemory_cache_all = inmemory_cache_all
         self.cur_shard = cur_shard
@@ -43,9 +41,6 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_va
         self.steps_per_epoch_train = steps_per_epoch_train
         self.steps_per_epoch_val = steps_per_epoch_val
         self.verbose = verbose
-        self.debug_data_loader = debug_data_loader
-        self.train_async_data_loader_queue_size = train_async_data_loader_queue_size
-        self.val_async_data_loader_queue_size = val_async_data_loader_queue_size
 
     def __enter__(self):
         return self
diff --git a/horovod/spark/common/params.py b/horovod/spark/common/params.py
index e553d04bc3..6efb039f81 100644
--- a/horovod/spark/common/params.py
+++ b/horovod/spark/common/params.py
@@ -206,12 +206,6 @@ def setNumProc(self, value):
     def getNumProc(self):
         return self.getOrDefault(self.num_proc)
 
-    def setDataModule(self, value):
-        return self._set(data_module=value)
-
-    def getDataModule(self):
-        return self.getOrDefault(self.data_module)
-
     def setModel(self, value):
         return self._set(model=value)
 
diff --git a/horovod/spark/keras/datamodule.py b/horovod/spark/keras/datamodule.py
index fbd33f66cd..ef788c2035 100644
--- a/horovod/spark/keras/datamodule.py
+++ b/horovod/spark/keras/datamodule.py
@@ -25,7 +25,6 @@ def __init__(self, reader_pool_type: str='thread',
                  train_reader_worker_count: int=2,
                  val_reader_worker_count: int=2,
                  make_dataset=None,
-                 shuffle=False,
                  random_seed=0,
                  **kwargs):
         from petastorm import TransformSpec, make_reader, make_batch_reader
@@ -35,7 +34,6 @@ def __init__(self, reader_pool_type: str='thread',
         self.train_reader_worker_count = train_reader_worker_count
         self.val_reader_worker_count = val_reader_worker_count
         self.make_dataset = make_dataset
-        self.shuffle = shuffle
         self.random_seed = random_seed
 
         # In general, make_batch_reader is faster than make_reader for reading the dataset.
@@ -150,7 +148,7 @@ def train_data(self):
             cat_names=self.categorical_cols,
             cont_names=self.continuous_cols,
             engine="parquet",
-            shuffle=True,
+            shuffle=self.shuffle,
             buffer_size=0.1,  # how many batches to load at once
             parts_per_chunk=1,
             global_size=hvd.size(),
diff --git a/horovod/spark/keras/estimator.py b/horovod/spark/keras/estimator.py
index d20804a138..414183d6e5 100644
--- a/horovod/spark/keras/estimator.py
+++ b/horovod/spark/keras/estimator.py
@@ -251,6 +251,12 @@ def setBackendEnv(self, value):
     def getBackendEnv(self):
         return self.getOrDefault(self.backend_env)
 
+    def setDataModule(self, value):
+        return self._set(data_module=value)
+
+    def getDataModule(self):
+        return self.getOrDefault(self.data_module)
+
     def _check_metadata_compatibility(self, metadata):
         input_shapes, output_shapes = self.get_model_shapes()
         util.check_shape_compatibility(metadata,
diff --git a/horovod/spark/keras/remote.py b/horovod/spark/keras/remote.py
index 5cb61d95d4..5d0b848598 100644
--- a/horovod/spark/keras/remote.py
+++ b/horovod/spark/keras/remote.py
@@ -227,11 +227,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
                       f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")
 
             data_module_kwargs = {
-                'label_cols': label_columns,              # nvtabular
-                'continuous_cols': continuous_columns,    # nvtabular
-                'categorical_cols': categorical_columns,  # nvtabular
-                'make_dataset': make_dataset,             # petastorm
-                'random_seed': random_seed,               # petastorm
+                # common
                 'train_dir': remote_store.train_data_path,
                 'val_dir': remote_store.val_data_path,
                 'num_train_epochs': epochs,
@@ -239,11 +235,6 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
                 'train_batch_size': batch_size,
                 'val_batch_size': val_batch_size,
                 'shuffle': shuffle,
-                'random_seed': random_seed,
-                'num_reader_epochs': epochs,
-                'reader_pool_type': reader_pool_type,
-                'train_reader_worker_count': train_reader_worker_count,
-                'val_reader_worker_count': val_reader_worker_count,
                 'transform_fn': transform_fn,
                 'inmemory_cache_all': inmemory_cache_all,
                 'cur_shard': hvd.rank(),
@@ -253,6 +244,16 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
                 'steps_per_epoch_train': steps_per_epoch,
                 'steps_per_epoch_val': validation_steps,
                 'verbose': verbose,
+                # petastorm
+                'make_dataset': make_dataset,
+                'random_seed': random_seed,
+                'reader_pool_type': reader_pool_type,
+                'train_reader_worker_count': train_reader_worker_count,
+                'val_reader_worker_count': val_reader_worker_count,
+                # nvtabular
+                'categorical_cols': categorical_columns,
+                'continuous_cols': continuous_columns,
+                'label_cols': label_columns,
             }
             if verbose:
                 print("data_module: {}".format(data_module))
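
Below is an illustrative sketch (not part of the diff) of how the relocated data-module API might be used once this change lands. Anything not visible in the diff is an assumption made for the example: the NVTabularDataModule class name, the model/optimizer/loss values, and the estimator accepting a boolean shuffle keyword (implied by the 'shuffle' kwarg in remote.py).

# Hypothetical usage sketch; names not shown in the diff are assumptions.
import tensorflow as tf
from horovod.spark.keras import KerasEstimator
from horovod.spark.keras.datamodule import NVTabularDataModule  # assumed class name

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

keras_estimator = KerasEstimator(
    num_proc=4,
    model=model,
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    loss='mse',
    batch_size=32,
    epochs=5,
    shuffle=True,   # assumed boolean flag replacing the old shuffle_size knob
)

# setDataModule()/getDataModule() now live on KerasEstimator rather than the
# shared params mixin, so the data module is selected per estimator.
keras_estimator.setDataModule(NVTabularDataModule)
assert keras_estimator.getDataModule() is NVTabularDataModule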