Estimator: add petastorm reader_pool_type into constructor
Allow users to choose between the 'process' (default) and 'thread' worker
pools for the petastorm reader.

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc committed May 7, 2021
1 parent b2a4065 commit 9d16140
Showing 10 changed files with 84 additions and 58 deletions.
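
The feature is exercised directly from the estimator constructor. A minimal, hedged sketch of the intended usage, assuming horovod.spark.keras is imported as hvd (as in the Keras integration test below) and that store, backend, model, and optimizer are placeholders for an already-configured Horovod store, Spark backend, compiled Keras model, and optimizer:

import horovod.spark.keras as hvd

# Sketch only: store, backend, model, and optimizer must be created elsewhere.
est = hvd.KerasEstimator(
    backend=backend,
    store=store,
    model=model,
    optimizer=optimizer,
    loss='binary_crossentropy',
    feature_cols=['features'],
    label_cols=['y'],
    batch_size=1,
    epochs=3,
    reader_pool_type='thread')  # new parameter: 'process' (default) or 'thread'

transformer = est.fit_on_parquet()  # readers are created with the chosen pool type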
8 changes: 8 additions & 0 deletions horovod/spark/common/params.py
@@ -28,6 +28,7 @@ class EstimatorParams(Params):
'number of parallel worker processes to read train data')
val_reader_num_workers = Param(Params._dummy(), 'val_reader_num_workers',
'number of parallel worker processes to read validation data')
reader_pool_type = Param(Params._dummy(), 'reader_pool_type', 'type of worker pool to read data')
optimizer = Param(Params._dummy(), 'optimizer', 'optimizer')
model = Param(Params._dummy(), 'model', 'model')
backend = Param(Params._dummy(), 'backend', 'backend')
@@ -122,6 +123,7 @@ def __init__(self):
transformation_fn=None,
train_reader_num_workers=2,
val_reader_num_workers=2,
reader_pool_type='process',
label_shapes=None)

def _check_params(self, metadata):
@@ -309,6 +311,12 @@ def setValReaderNumWorker(self, value):
def getValReaderNumWorker(self):
return self.getOrDefault(self.val_reader_num_workers)

def setReaderPoolType(self, value):
return self._set(reader_pool_type=value)

def getReaderPoolType(self):
return self.getOrDefault(self.reader_pool_type)

def setLabelShapes(self, value):
return self._set(label_shapes=value)

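The new accessors follow the same Spark ML Params convention as the surrounding getters and setters, so the pool type can also be changed after construction. A hedged sketch, where est stands for any already-built Horovod Spark estimator:

# Sketch only: `est` is any estimator constructed as in the example above.
est.setReaderPoolType('thread')
assert est.getReaderPoolType() == 'thread'  # defaults to 'process' if never set
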
3 changes: 3 additions & 0 deletions horovod/spark/keras/estimator.py
@@ -158,6 +158,8 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process']. Defaults to 'process'.
"""

custom_objects = Param(Params._dummy(), 'custom_objects', 'custom objects')
@@ -194,6 +196,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
checkpoint_callback=None):

5 changes: 3 additions & 2 deletions horovod/spark/keras/remote.py
@@ -54,6 +54,7 @@ def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx):
# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
val_reader_worker_count = estimator.getValReaderNumWorker()
reader_pool_type = estimator.getReaderPoolType()

# Model parameters
input_shapes, output_shapes = estimator.get_model_shapes()
@@ -214,7 +215,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
with reader_factory(remote_store.train_data_path,
num_epochs=None,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=train_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
@@ -224,7 +225,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
with reader_factory(remote_store.val_data_path,
num_epochs=None,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=val_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
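Inside the remote trainers, reader_factory is petastorm's make_reader or make_batch_reader (selected elsewhere in these modules), and the new value is forwarded as petastorm's own reader_pool_type argument. A hedged, standalone sketch of that call, with an illustrative local Parquet path, shard settings, and worker count rather than the estimator's actual values:

from petastorm import make_batch_reader

# Sketch only: the URL, shard settings, and worker count are placeholders.
with make_batch_reader('file:///tmp/train_data',
                       num_epochs=1,
                       cur_shard=0,
                       shard_count=1,
                       reader_pool_type='thread',  # or 'process'
                       workers_count=2) as reader:
    for batch in reader:
        pass  # hand each batch to the training loop
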
3 changes: 3 additions & 0 deletions horovod/spark/lightning/estimator.py
@@ -150,6 +150,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process']. Defaults to 'process'.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -193,6 +195,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
inmemory_cache_all=False):

9 changes: 5 additions & 4 deletions horovod/spark/lightning/remote.py
@@ -56,6 +56,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
val_reader_worker_count = estimator.getValReaderNumWorker()
reader_pool_type = estimator.getReaderPoolType()

# Utility functions
deserialize = deserialize_fn()
@@ -126,9 +127,9 @@ def train(serialized_model):
# print(row_group)

with make_petastorm_reader(model, remote_store.train_data_path, 'train_dataloader',
train_reader_worker_count), \
train_reader_worker_count, reader_pool_type), \
make_petastorm_reader(model, remote_store.val_data_path, 'val_dataloader',
val_reader_worker_count, should_validate):
val_reader_worker_count, reader_pool_type, should_validate):

trainer.fit(model)

@@ -168,7 +169,7 @@ def on_sanity_check_end(self, trainer, model):
def _make_petastorm_reader_fn(transformation, schema_fields, batch_size, calculate_shuffle_buffer_size, dataloader_cls):

@contextlib.contextmanager
def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count, should_read=True):
def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, should_read=True):
from petastorm import TransformSpec, make_reader, make_batch_reader
import horovod.torch as hvd

@@ -201,7 +202,7 @@ def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count
with reader_factory(data_path,
num_epochs=1,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
3 changes: 3 additions & 0 deletions horovod/spark/torch/estimator.py
@@ -142,6 +142,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process']. Defaults to 'process'.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -185,6 +187,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
inmemory_cache_all=False):

5 changes: 3 additions & 2 deletions horovod/spark/torch/remote.py
@@ -73,6 +73,7 @@ def RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_id
# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
val_reader_worker_count = estimator.getValReaderNumWorker()
reader_pool_type = estimator.getReaderPoolType()

# Utility functions
deserialize = deserialize_fn()
@@ -223,7 +224,7 @@ def save_checkpoint():
with reader_factory(remote_store.train_data_path,
num_epochs=1 if inmemory_cache_all else None,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=train_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
@@ -233,7 +234,7 @@ def save_checkpoint():
with reader_factory(remote_store.val_data_path,
num_epochs=1 if inmemory_cache_all else None,
cur_shard=hvd.rank(),
reader_pool_type='process',
reader_pool_type=reader_pool_type,
workers_count=val_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
32 changes: 17 additions & 15 deletions test/integration/test_spark_keras.py
@@ -213,21 +213,23 @@ def test_keras_direct_parquet_train(self, mock_fit_fn, mock_pin_gpu_fn):
optimizer = tf.keras.optimizers.SGD(lr=0.1)
loss = 'binary_crossentropy'

est = hvd.KerasEstimator(
backend=backend,
store=store,
model=model,
optimizer=optimizer,
loss=loss,
feature_cols=['features'],
label_cols=['y'],
batch_size=1,
epochs=3,
verbose=2)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()
for reader_pool_type in ['process', 'thread']:
est = hvd.KerasEstimator(
backend=backend,
store=store,
model=model,
optimizer=optimizer,
loss=loss,
feature_cols=['features'],
label_cols=['y'],
batch_size=1,
epochs=3,
reader_pool_type=reader_pool_type,
verbose=2)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()

@mock.patch('horovod.spark.keras.remote._pin_gpu_fn')
@mock.patch('horovod.spark.keras.util.TFKerasUtil.fit_fn')
34 changes: 18 additions & 16 deletions test/integration/test_spark_lightning.py
@@ -370,22 +370,24 @@ def test_direct_parquet_train(self):
model = create_xor_model()

for inmemory_cache_all in [False, True]:
est = hvd_spark.TorchEstimator(
backend=backend,
store=store,
model=model,
input_shapes=[[-1, 2]],
feature_cols=['features'],
label_cols=['y'],
validation=0.2,
batch_size=1,
epochs=3,
verbose=2,
inmemory_cache_all=inmemory_cache_all)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()
for reader_pool_type in ['process', 'thread']:
est = hvd_spark.TorchEstimator(
backend=backend,
store=store,
model=model,
input_shapes=[[-1, 2]],
feature_cols=['features'],
label_cols=['y'],
validation=0.2,
batch_size=1,
epochs=3,
verbose=2,
inmemory_cache_all=inmemory_cache_all,
reader_pool_type=reader_pool_type)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()

def test_legacy_calculate_loss_with_sample_weight(self):
labels = torch.tensor([[1.0, 2.0, 3.0]])
40 changes: 21 additions & 19 deletions test/integration/test_spark_torch.py
@@ -353,25 +353,27 @@ def test_torch_direct_parquet_train(self):
loss = nn.BCELoss()

for inmemory_cache_all in [False, True]:
est = hvd_spark.TorchEstimator(
backend=backend,
store=store,
model=model,
optimizer=optimizer,
input_shapes=[[2]],
feature_cols=['features'],
label_cols=['y'],
batch_size=1,
epochs=3,
verbose=2,
inmemory_cache_all=inmemory_cache_all)

# To make sure that setLoss works with non-list loss.
est.setLoss(loss)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()
for reader_pool_type in ['process', 'thread']:
est = hvd_spark.TorchEstimator(
backend=backend,
store=store,
model=model,
optimizer=optimizer,
input_shapes=[[2]],
feature_cols=['features'],
label_cols=['y'],
batch_size=1,
epochs=3,
verbose=2,
reader_pool_type=reader_pool_type,
inmemory_cache_all=inmemory_cache_all)

# To make sure that setLoss works with non-list loss.
est.setLoss(loss)

transformer = est.fit_on_parquet()
predictions = transformer.transform(df)
assert predictions.count() == df.count()

def test_calculate_loss_with_sample_weight(self):
calculate_loss = remote._calculate_loss_fn()
