remove unused args; move data_module getter/setter; fix shuffle arg
Signed-off-by: Lee Yang <leewyang@gmail.com>
leewyang committed Aug 31, 2022
1 parent e7bdeb0 commit 6e53391
Showing 5 changed files with 21 additions and 27 deletions.
11 changes: 3 additions & 8 deletions horovod/spark/common/datamodule.py
@@ -20,20 +20,18 @@ class DataModule(ABC):
     short_name = None  # implementations should provide a short name for easy reference, e.g. 'petastorm', 'nvtabular', etc.
 
     def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_val: bool=True,
-                 train_batch_size: int=32, val_batch_size: int=32, shuffle_size: int=1000,
+                 train_batch_size: int=32, val_batch_size: int=32, shuffle: bool=True,
                  transform_fn=None, inmemory_cache_all=False,
                  cur_shard: int=0, shard_count: int=1, schema_fields=None, storage_options=None,
-                 steps_per_epoch_train: int=1, steps_per_epoch_val: int=1, verbose=True,
-                 debug_data_loader: bool=False, train_async_data_loader_queue_size: int=None,
-                 val_async_data_loader_queue_size: int=None, **kwargs):
+                 steps_per_epoch_train: int=1, steps_per_epoch_val: int=1, verbose=True, **kwargs):
         super().__init__()
         self.train_dir = train_dir
         self.val_dir = val_dir
         self.num_train_epochs = num_train_epochs
         self.has_val = has_val
         self.train_batch_size = train_batch_size
         self.val_batch_size = val_batch_size
-        self.shuffle_size = shuffle_size
+        self.shuffle = shuffle
         self.transform_fn = transform_fn
         self.inmemory_cache_all = inmemory_cache_all
         self.cur_shard = cur_shard
@@ -43,9 +41,6 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_va
         self.steps_per_epoch_train = steps_per_epoch_train
         self.steps_per_epoch_val = steps_per_epoch_val
         self.verbose = verbose
-        self.debug_data_loader = debug_data_loader
-        self.train_async_data_loader_queue_size = train_async_data_loader_queue_size
-        self.val_async_data_loader_queue_size = val_async_data_loader_queue_size
 
     def __enter__(self):
         return self
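
The caller-facing effect of the change above: `shuffle` is now a plain boolean rather than the old `shuffle_size` buffer length, and the debug/async-queue arguments are gone. A minimal sketch of constructing a data module under the new signature; the `PetastormDataModule` subclass (edited later in this diff), paths, and shard values are illustrative assumptions, not part of this commit:

# Sketch only: constructing a DataModule implementation after this change.
# All argument values here are placeholders.
from horovod.spark.keras.datamodule import PetastormDataModule

dm = PetastormDataModule(
    train_dir='file:///tmp/intermediate/train',  # placeholder path
    val_dir='file:///tmp/intermediate/val',      # placeholder path
    num_train_epochs=1,
    train_batch_size=32,
    val_batch_size=32,
    shuffle=True,   # boolean flag; replaces the removed shuffle_size=1000
    cur_shard=0,    # this worker's shard (hvd.rank() during training)
    shard_count=1,  # total shards (hvd.size() during training)
)

Since `__enter__` returns `self`, implementations are meant to be used as context managers (`with dm: ...`), presumably so readers can be torn down in `__exit__`.
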
6 changes: 0 additions & 6 deletions horovod/spark/common/params.py
@@ -206,12 +206,6 @@ def setNumProc(self, value):
     def getNumProc(self):
         return self.getOrDefault(self.num_proc)
 
-    def setDataModule(self, value):
-        return self._set(data_module=value)
-
-    def getDataModule(self):
-        return self.getOrDefault(self.data_module)
-
     def setModel(self, value):
         return self._set(model=value)
 
4 changes: 1 addition & 3 deletions horovod/spark/keras/datamodule.py
@@ -25,7 +25,6 @@ def __init__(self, reader_pool_type: str='thread',
                  train_reader_worker_count: int=2,
                  val_reader_worker_count: int=2,
                  make_dataset=None,
-                 shuffle=False,
                  random_seed=0,
                  **kwargs):
         from petastorm import TransformSpec, make_reader, make_batch_reader
@@ -35,7 +34,6 @@ def __init__(self, reader_pool_type: str='thread',
         self.train_reader_worker_count = train_reader_worker_count
         self.val_reader_worker_count = val_reader_worker_count
         self.make_dataset = make_dataset
-        self.shuffle = shuffle
         self.random_seed = random_seed
 
         # In general, make_batch_reader is faster than make_reader for reading the dataset.
@@ -150,7 +148,7 @@ def train_data(self):
             cat_names=self.categorical_cols,
             cont_names=self.continuous_cols,
             engine="parquet",
-            shuffle=True,
+            shuffle=self.shuffle,
             buffer_size=0.1,  # how many batches to load at once
             parts_per_chunk=1,
             global_size=hvd.size(),
6 changes: 6 additions & 0 deletions horovod/spark/keras/estimator.py
@@ -251,6 +251,12 @@ def setBackendEnv(self, value):
     def getBackendEnv(self):
         return self.getOrDefault(self.backend_env)
 
+    def setDataModule(self, value):
+        return self._set(data_module=value)
+
+    def getDataModule(self):
+        return self.getOrDefault(self.data_module)
+
     def _check_metadata_compatibility(self, metadata):
         input_shapes, output_shapes = self.get_model_shapes()
         util.check_shape_compatibility(metadata,
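
With the accessor pair moved from the shared params mixin onto the Keras estimator, the data module is selected per estimator. A hedged sketch of the new call site; the tiny Keras model and column names are placeholders, and `PetastormDataModule` is the class edited above:

# Sketch only: selecting a data module on the Keras estimator.
import tensorflow as tf
from horovod.spark.keras import KerasEstimator
from horovod.spark.keras.datamodule import PetastormDataModule

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])  # placeholder model
est = KerasEstimator(model=model,
                     optimizer=tf.keras.optimizers.Adam(),
                     loss='mse',
                     feature_cols=['features'],   # placeholder columns
                     label_cols=['label'])
est.setDataModule(PetastormDataModule)   # accessor relocated by this commit
assert est.getDataModule() is PetastormDataModule
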
21 changes: 11 additions & 10 deletions horovod/spark/keras/remote.py
@@ -227,23 +227,14 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
                   f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")
 
         data_module_kwargs = {
-            'label_cols': label_columns,  # nvtabular
-            'continuous_cols': continuous_columns,  # nvtabular
-            'categorical_cols': categorical_columns,  # nvtabular
-            'make_dataset': make_dataset,  # petastorm
-            'random_seed': random_seed,  # petastorm
             # common
             'train_dir': remote_store.train_data_path,
             'val_dir': remote_store.val_data_path,
             'num_train_epochs': epochs,
            'has_val': should_validate is not None,
             'train_batch_size': batch_size,
             'val_batch_size': val_batch_size,
-            'random_seed': random_seed,
-            'num_reader_epochs': epochs,
-            'reader_pool_type': reader_pool_type,
-            'train_reader_worker_count': train_reader_worker_count,
-            'val_reader_worker_count': val_reader_worker_count,
+            'shuffle': shuffle,
             'transform_fn': transform_fn,
             'inmemory_cache_all': inmemory_cache_all,
             'cur_shard': hvd.rank(),
@@ -253,6 +244,16 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
             'steps_per_epoch_train': steps_per_epoch,
             'steps_per_epoch_val': validation_steps,
             'verbose': verbose,
+            # petastorm
+            'make_dataset': make_dataset,
+            'random_seed': random_seed,
+            'reader_pool_type': reader_pool_type,
+            'train_reader_worker_count': train_reader_worker_count,
+            'val_reader_worker_count': val_reader_worker_count,
+            # nvtabular
+            'categorical_cols': categorical_columns,
+            'continuous_cols': continuous_columns,
+            'label_cols': label_columns,
         }
         if verbose:
             print("data_module: {}".format(data_module))
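
The regrouping above also documents the contract: because `DataModule.__init__` accepts `**kwargs`, one kwargs dict can carry common, petastorm-only, and nvtabular-only keys, and each implementation simply ignores the keys it does not declare. A sketch of plausible consumption on a worker; the context-manager use and all values are assumptions for illustration:

# Sketch only: how the assembled kwargs are plausibly consumed on a worker.
from horovod.spark.keras.datamodule import PetastormDataModule

data_module = PetastormDataModule  # stand-in for the class chosen via setDataModule
data_module_kwargs = {
    # common
    'train_dir': '/tmp/intermediate/train',   # placeholder
    'val_dir': '/tmp/intermediate/val',       # placeholder
    'train_batch_size': 32,
    'val_batch_size': 32,
    'shuffle': True,
    # petastorm-only (ignored by other implementations)
    'reader_pool_type': 'thread',
    # nvtabular-only (ignored by other implementations)
    'label_cols': ['label'],
}
with data_module(**data_module_kwargs) as dm:
    train_ds = dm.train_data()  # train_data() appears in the keras datamodule above
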
