Spark/Lightning: add missing transform_spec for Petastorm datamodule
Fix issue #3540

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc committed May 10, 2022
1 parent 464c82e commit be525b2
Showing 3 changed files with 13 additions and 7 deletions.
16 changes: 11 additions & 5 deletions horovod/spark/lightning/datamodule.py
@@ -13,7 +13,7 @@ class PetastormDataModule(pl.LightningDataModule):
def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_val: bool=True,
train_batch_size: int=32, val_batch_size: int=32, shuffle_size: int=1000,
num_reader_epochs=None, reader_pool_type: str="process",
- reader_worker_count: int=2, transform_spec=None, inmemory_cache_all=False,
+ reader_worker_count: int=2, transformation=None, inmemory_cache_all=False,
cur_shard: int=0, shard_count: int=1, schema_fields=None, storage_options=None,
steps_per_epoch_train: int=1, steps_per_epoch_val: int=1, verbose=True,
debug_data_loader: bool=False, train_async_data_loader_queue_size: int=None,
@@ -29,7 +29,7 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_va
self.num_reader_epochs = num_reader_epochs
self.reader_pool_type = reader_pool_type
self.reader_worker_count = reader_worker_count
- self.transform_spec = transform_spec
+ self.transformation = transformation
self.inmemory_cache_all = inmemory_cache_all
self.cur_shard = cur_shard
self.shard_count = shard_count
@@ -49,13 +49,15 @@ def __init__(self, train_dir: str, val_dir: str, num_train_epochs: int=1, has_va
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
- transform_spec = TransformSpec(self.transform_spec) if self.transform_spec else None
+ transform_spec = TransformSpec(self.transformation) if self.transformation else None
# In general, make_batch_reader is faster than make_reader for reading the dataset.
# However, we found out that make_reader performs data transformations much faster than
# make_batch_reader with parallel worker processes. Therefore, the default reader
# we choose is make_batch_reader unless there are data transformations.
+ reader_factory_kwargs = dict()
if transform_spec:
reader_factory = make_reader
+ reader_factory_kwargs['pyarrow_serialize'] = True
else:
reader_factory = make_batch_reader

@@ -64,15 +66,19 @@ def setup(self, stage=None):
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=self.schema_fields,
storage_options=self.storage_options,
+ transform_spec=transform_spec,
# Don't shuffle row groups without shuffling.
- shuffle_row_groups=True if self.shuffle_size > 0 else False)
+ shuffle_row_groups=True if self.shuffle_size > 0 else False,
+ **reader_factory_kwargs)
if self.has_val:
self.val_reader = reader_factory(self.val_dir, num_epochs=self.num_reader_epochs,
cur_shard=self.cur_shard, shard_count=self.shard_count,
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=self.schema_fields,
storage_options=self.storage_options,
- shuffle_row_groups=False)
+ transform_spec=transform_spec,
+ shuffle_row_groups=False,
+ **reader_factory_kwargs)

def teardown(self, stage=None):
if stage == "fit" or stage is None:
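Note: the hunks above wrap the user-supplied callable in a petastorm TransformSpec and, with this commit, actually forward it to the readers; the reader factory is chosen as described by the in-line comment. Below is a minimal standalone sketch of that selection logic. The dataset URL and the transform function are placeholders, not values from the commit.

    # Standalone sketch of the selection logic in setup(); the path and the
    # transform below are illustrative placeholders.
    from petastorm import make_batch_reader, make_reader
    from petastorm.transform import TransformSpec

    def normalize(row):
        # Illustrative per-row transformation, applied by petastorm workers.
        row['features'] = row['features'] / 255.0
        return row

    transformation = normalize   # what the datamodule now receives as `transformation`
    transform_spec = TransformSpec(transformation) if transformation else None

    reader_factory_kwargs = dict()
    if transform_spec:
        # make_reader applies row-level transforms faster with worker processes.
        reader_factory = make_reader
        reader_factory_kwargs['pyarrow_serialize'] = True
    else:
        # make_batch_reader is the faster default when no transform is supplied.
        reader_factory = make_batch_reader

    with reader_factory('file:///tmp/train_parquet',   # placeholder dataset URL
                        num_epochs=1,
                        transform_spec=transform_spec,  # the previously missing argument
                        shuffle_row_groups=True,
                        **reader_factory_kwargs) as reader:
        for batch in reader:
            pass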
2 changes: 1 addition & 1 deletion horovod/spark/lightning/remote.py
@@ -266,7 +266,7 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
'num_reader_epochs': loader_num_epochs,
'reader_pool_type': reader_pool_type,
'reader_worker_count': train_reader_worker_count,
- 'transform_spec': transformation,
+ 'transformation': transformation,
'inmemory_cache_all': inmemory_cache_all,
'cur_shard': hvd.rank(),
'shard_count': hvd.size(),
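Note: this dict in remote.py holds the keyword arguments handed to PetastormDataModule, so the key has to track the constructor parameter renamed in datamodule.py; it now carries the raw callable rather than a petastorm TransformSpec. A hypothetical caller-side sketch follows; the paths and the transform are illustrative, not taken from the commit.

    # Hypothetical usage of the datamodule with the renamed keyword.
    from horovod.spark.lightning.datamodule import PetastormDataModule

    def normalize(row):
        row['features'] = row['features'] / 255.0   # illustrative per-row transform
        return row

    datamodule = PetastormDataModule(
        train_dir='file:///tmp/train_parquet',       # placeholder path
        val_dir='file:///tmp/val_parquet',           # placeholder path
        num_train_epochs=1,
        train_batch_size=32,
        val_batch_size=32,
        transformation=normalize,                    # was `transform_spec` before this commit
        cur_shard=0,
        shard_count=1)
    # During setup('fit') the datamodule wraps `normalize` in TransformSpec,
    # selects make_reader, and passes the spec to the train and val readers.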
2 changes: 1 addition & 1 deletion horovod/spark/torch/remote.py
@@ -264,7 +264,7 @@ def save_checkpoint():
transform_spec=transform_spec,
storage_options=storage_options,
# Don't shuffle row groups without shuffling.
- shuffle_row_groups=True if shuffle_buffer_size > 0 else False
+ shuffle_row_groups=True if shuffle_buffer_size > 0 else False,
**reader_factory_kwargs) as train_reader:
with reader_factory(remote_store.val_data_path,
num_epochs=None,
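Note: the only change here is a trailing comma, but without it Python parses the following **reader_factory_kwargs as part of the previous keyword's expression, so the extra reader options are silently dropped (or, when shuffle_buffer_size is 0, the call fails with a TypeError). A small self-contained sketch of the parse difference, with illustrative names rather than Horovod code:

    # Illustrative reproduction of the missing-comma behaviour.
    def reader_factory(path, shuffle_row_groups=True, **kwargs):
        return shuffle_row_groups, kwargs

    reader_factory_kwargs = {'pyarrow_serialize': True}
    shuffle_buffer_size = 1000

    # Missing comma: everything after '=' is one expression, so the kwargs
    # never reach the call ('False ** dict' would raise TypeError if the
    # else branch were taken).
    print(reader_factory('file:///tmp/data',
                         shuffle_row_groups=True if shuffle_buffer_size > 0 else False
                         **reader_factory_kwargs))
    # -> (True, {})

    # With the comma restored, the options are unpacked as intended.
    print(reader_factory('file:///tmp/data',
                         shuffle_row_groups=True if shuffle_buffer_size > 0 else False,
                         **reader_factory_kwargs))
    # -> (True, {'pyarrow_serialize': True})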
