
Estimator: add petastorm reader_pool_type into constructor #2903

Merged 1 commit on May 7, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Estimator: add petastorm reader_pool_type into constructor ([#2903](https://github.com/horovod/horovod/pull/2903))
- Added NVTX tracing hooks for profiling with Nsight Systems. ([#2723](https://github.com/horovod/horovod/pull/2723))
- Added a generic `num_workers` API for ``RayExecutor`` ([#2870](https://github.com/horovod/horovod/pull/2870))

8 changes: 8 additions & 0 deletions horovod/spark/common/params.py
@@ -28,6 +28,7 @@ class EstimatorParams(Params):
'number of parallel worker processes to read train data')
val_reader_num_workers = Param(Params._dummy(), 'val_reader_num_workers',
'number of parallel worker processes to read validation data')
reader_pool_type = Param(Params._dummy(), 'reader_pool_type', 'type of worker pool to read data')
optimizer = Param(Params._dummy(), 'optimizer', 'optimizer')
model = Param(Params._dummy(), 'model', 'model')
backend = Param(Params._dummy(), 'backend', 'backend')
@@ -122,6 +123,7 @@ def __init__(self):
transformation_fn=None,
train_reader_num_workers=2,
val_reader_num_workers=2,
reader_pool_type='process',
label_shapes=None)

def _check_params(self, metadata):
@@ -309,6 +311,12 @@ def setValReaderNumWorker(self, value):
def getValReaderNumWorker(self):
return self.getOrDefault(self.val_reader_num_workers)

def setReaderPoolType(self, value):
return self._set(reader_pool_type=value)

def getReaderPoolType(self):
return self.getOrDefault(self.reader_pool_type)

def setLabelShapes(self, value):
return self._set(label_shapes=value)

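The new accessors follow the same Spark ML `Params` pattern as the existing reader options. A self-contained sketch of that pattern using plain `pyspark.ml.param` (illustrative only, not Horovod code; the class name is made up):

```python
# Minimal sketch of the Param / _setDefault / getOrDefault pattern used above.
# Illustrative only: ReaderParams is a made-up class, not part of Horovod.
from pyspark.ml.param import Param, Params


class ReaderParams(Params):
    reader_pool_type = Param(Params._dummy(), 'reader_pool_type',
                             'type of worker pool to read data')

    def __init__(self):
        super(ReaderParams, self).__init__()
        self._setDefault(reader_pool_type='process')

    def setReaderPoolType(self, value):
        return self._set(reader_pool_type=value)

    def getReaderPoolType(self):
        return self.getOrDefault(self.reader_pool_type)


params = ReaderParams()
assert params.getReaderPoolType() == 'process'   # default applied via _setDefault
params.setReaderPoolType('thread')                # explicit value overrides the default
assert params.getReaderPoolType() == 'thread'
```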
3 changes: 3 additions & 0 deletions horovod/spark/keras/estimator.py
@@ -158,6 +158,8 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process']. Defaults to 'process'.
"""

custom_objects = Param(Params._dummy(), 'custom_objects', 'custom objects')
@@ -194,6 +196,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
checkpoint_callback=None):

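For illustration, a minimal sketch of constructing the Keras estimator with the new argument; the store path, model, and column names are assumptions for this example, not part of the PR:

```python
# Usage sketch: the store path, model, and columns are assumptions; only
# reader_pool_type is the new argument introduced by this PR.
import tensorflow as tf
from horovod.spark.common.store import LocalStore
from horovod.spark.keras import KerasEstimator

store = LocalStore('/tmp/horovod_store')   # assumed local prefix path

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(2,)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

est = KerasEstimator(
    store=store,
    model=model,
    optimizer=tf.keras.optimizers.SGD(lr=0.1),
    loss='binary_crossentropy',
    feature_cols=['features'],
    label_cols=['y'],
    batch_size=32,
    epochs=3,
    reader_pool_type='thread')   # new argument; 'process' remains the default

# est.fit(train_df) then returns a Transformer that can score a DataFrame.
```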
5 changes: 3 additions & 2 deletions horovod/spark/keras/remote.py
@@ -54,6 +54,7 @@ def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx):
# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
val_reader_worker_count = estimator.getValReaderNumWorker()
reader_pool_type = estimator.getReaderPoolType()

# Model parameters
input_shapes, output_shapes = estimator.get_model_shapes()
@@ -214,7 +215,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
with reader_factory(remote_store.train_data_path,
num_epochs=None,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=train_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
@@ -224,7 +225,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
with reader_factory(remote_store.val_data_path,
num_epochs=None,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=val_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
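Under the hood the estimator forwards this value to petastorm's reader factory; a hedged sketch of the equivalent direct petastorm call (the dataset URL is a placeholder):

```python
# Roughly what the remote trainer does with reader_pool_type: pass it through
# to petastorm. The dataset URL below is a placeholder, not a real path.
from petastorm import make_batch_reader

with make_batch_reader('file:///tmp/horovod_store/intermediate_train_data',
                       num_epochs=1,
                       reader_pool_type='thread',   # previously hard-coded to 'process'
                       workers_count=2) as reader:
    for batch in reader:
        pass  # feed each batch of rows to the training loop
```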
3 changes: 3 additions & 0 deletions horovod/spark/lightning/estimator.py
@@ -150,6 +150,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process']. Defaults to 'process'.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -193,6 +195,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
inmemory_cache_all=False):

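The docstring above restricts the value to 'thread' or 'process'. A standalone sketch of how such a value could be validated before being handed to petastorm (hypothetical helper, not code from this PR):

```python
# Hypothetical validation helper, shown only to make the documented constraint
# concrete; Horovod itself may simply pass the value through to petastorm.
VALID_READER_POOL_TYPES = ('process', 'thread')


def check_reader_pool_type(value):
    """Return a valid reader_pool_type, defaulting to 'process' when unset."""
    if value is None:
        return 'process'
    if value not in VALID_READER_POOL_TYPES:
        raise ValueError('reader_pool_type must be one of %s, got %r'
                         % (list(VALID_READER_POOL_TYPES), value))
    return value


assert check_reader_pool_type(None) == 'process'
assert check_reader_pool_type('thread') == 'thread'
```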
9 changes: 5 additions & 4 deletions horovod/spark/lightning/remote.py
@@ -56,6 +56,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
val_reader_worker_count = estimator.getValReaderNumWorker()
reader_pool_type = estimator.getReaderPoolType()

# Utility functions
deserialize = deserialize_fn()
@@ -126,9 +127,9 @@ def train(serialized_model):
# print(row_group)

with make_petastorm_reader(model, remote_store.train_data_path, 'train_dataloader',
train_reader_worker_count), \
train_reader_worker_count, reader_pool_type), \
make_petastorm_reader(model, remote_store.val_data_path, 'val_dataloader',
val_reader_worker_count, should_validate):
val_reader_worker_count, reader_pool_type, should_validate):

trainer.fit(model)

@@ -168,7 +169,7 @@ def on_sanity_check_end(self, trainer, model):
def _make_petastorm_reader_fn(transformation, schema_fields, batch_size, calculate_shuffle_buffer_size, dataloader_cls):

@contextlib.contextmanager
def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count, should_read=True):
def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, should_read=True):
from petastorm import TransformSpec, make_reader, make_batch_reader
import horovod.torch as hvd

@@ -201,7 +202,7 @@ def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count
with reader_factory(data_path,
num_epochs=1,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
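The Lightning path threads the value through a context manager instead of calling the factory inline. A simplified, self-contained sketch of that shape (function and variable names are illustrative, not the exact Horovod helper):

```python
# Simplified sketch mirroring make_petastorm_reader above: the pool type is now
# a parameter of the context manager rather than a hard-coded 'process'.
import contextlib
from petastorm import make_batch_reader


@contextlib.contextmanager
def petastorm_reader(data_path, reader_worker_count, reader_pool_type, should_read=True):
    if not should_read:
        # Validation data may be absent, in which case no reader is created.
        yield None
        return
    with make_batch_reader(data_path,
                           num_epochs=1,
                           reader_pool_type=reader_pool_type,
                           workers_count=reader_worker_count) as reader:
        yield reader


# Nested use for train and validation readers, as in the trainer above:
# with petastorm_reader(train_path, 2, 'thread'), \
#      petastorm_reader(val_path, 2, 'thread', should_read=have_validation):
#     ...
```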
3 changes: 3 additions & 0 deletions horovod/spark/torch/estimator.py
@@ -142,6 +142,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process']. Defaults to 'process'.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -185,6 +187,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
inmemory_cache_all=False):

5 changes: 3 additions & 2 deletions horovod/spark/torch/remote.py
@@ -73,6 +73,7 @@ def RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_id
# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
val_reader_worker_count = estimator.getValReaderNumWorker()
reader_pool_type = estimator.getReaderPoolType()

# Utility functions
deserialize = deserialize_fn()
@@ -223,7 +224,7 @@ def save_checkpoint():
with reader_factory(remote_store.train_data_path,
num_epochs=1 if inmemory_cache_all else None,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=train_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
@@ -233,7 +234,7 @@ def save_checkpoint():
with reader_factory(remote_store.val_data_path,
num_epochs=1 if inmemory_cache_all else None,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=val_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
32 changes: 17 additions & 15 deletions test/integration/test_spark_keras.py
@@ -213,21 +213,23 @@ def test_keras_direct_parquet_train(self, mock_fit_fn, mock_pin_gpu_fn):
optimizer = tf.keras.optimizers.SGD(lr=0.1)
loss = 'binary_crossentropy'

est = hvd.KerasEstimator(
backend=backend,
store=store,
model=model,
optimizer=optimizer,
loss=loss,
feature_cols=['features'],
label_cols=['y'],
batch_size=1,
epochs=3,
verbose=2)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()
for reader_pool_type in ['process', 'thread']:
est = hvd.KerasEstimator(
backend=backend,
store=store,
model=model,
optimizer=optimizer,
loss=loss,
feature_cols=['features'],
label_cols=['y'],
batch_size=1,
epochs=3,
reader_pool_type=reader_pool_type,
verbose=2)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()

@mock.patch('horovod.spark.keras.remote._pin_gpu_fn')
@mock.patch('horovod.spark.keras.util.TFKerasUtil.fit_fn')
34 changes: 18 additions & 16 deletions test/integration/test_spark_lightning.py
@@ -370,22 +370,24 @@ def test_direct_parquet_train(self):
model = create_xor_model()

for inmemory_cache_all in [False, True]:
est = hvd_spark.TorchEstimator(
backend=backend,
store=store,
model=model,
input_shapes=[[-1, 2]],
feature_cols=['features'],
label_cols=['y'],
validation=0.2,
batch_size=1,
epochs=3,
verbose=2,
inmemory_cache_all=inmemory_cache_all)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()
for reader_pool_type in ['process', 'thread']:
est = hvd_spark.TorchEstimator(
backend=backend,
store=store,
model=model,
input_shapes=[[-1, 2]],
feature_cols=['features'],
label_cols=['y'],
validation=0.2,
batch_size=1,
epochs=3,
verbose=2,
inmemory_cache_all=inmemory_cache_all,
reader_pool_type=reader_pool_type)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()

def test_legacy_calculate_loss_with_sample_weight(self):
labels = torch.tensor([[1.0, 2.0, 3.0]])
40 changes: 21 additions & 19 deletions test/integration/test_spark_torch.py
@@ -353,25 +353,27 @@ def test_torch_direct_parquet_train(self):
loss = nn.BCELoss()

for inmemory_cache_all in [False, True]:
est = hvd_spark.TorchEstimator(
backend=backend,
store=store,
model=model,
optimizer=optimizer,
input_shapes=[[2]],
feature_cols=['features'],
label_cols=['y'],
batch_size=1,
epochs=3,
verbose=2,
inmemory_cache_all=inmemory_cache_all)

# To make sure that setLoss works with non-list loss.
est.setLoss(loss)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()
for reader_pool_type in ['process', 'thread']:
est = hvd_spark.TorchEstimator(
backend=backend,
store=store,
model=model,
optimizer=optimizer,
input_shapes=[[2]],
feature_cols=['features'],
label_cols=['y'],
batch_size=1,
epochs=3,
verbose=2,
reader_pool_type=reader_pool_type,
inmemory_cache_all=inmemory_cache_all)

# To make sure that setLoss works with non-list loss.
est.setLoss(loss)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()

def test_calculate_loss_with_sample_weight(self):
calculate_loss = remote._calculate_loss_fn()