Spark: adapt Petastorm 0.12.0 changes to speed up data shuffling (#3665)
Added:
- Samples are now shuffled at the Petastorm reader level, parallelized across reader workers.
- A random seed option is added to the Lightning datamodule to make
  data loading reproducible.

Changed:
- Default reader pool is changed from `process` to `thread` for lower
  memory usage.

Deprecated:
- Deprecated the `shuffle_buffer_size` field of `EstimatorParams`. Use the boolean `shuffle` parameter
  to enable or disable shuffling instead (see the usage sketch below).
chongxiaoc committed Aug 29, 2022
1 parent b182e83 commit 001260a
Showing 14 changed files with 96 additions and 301 deletions.
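
A minimal usage sketch of the migration this commit enables (not part of the diff; the model, store, and column names are hypothetical placeholders, and a working Spark/Horovod setup is assumed): the deprecated `shuffle_buffer_size` is replaced by the boolean `shuffle` flag, and `reader_pool_type` now defaults to `thread`.

```python
# Hedged sketch, not part of this commit: how an estimator could be configured
# after this change. `store` and `model` are assumed to exist already.
from horovod.spark.keras import KerasEstimator

keras_estimator = KerasEstimator(
    num_proc=4,
    store=store,                 # assumed: an existing horovod.spark Store
    model=model,                 # assumed: a compiled tf.keras model
    feature_cols=['features'],
    label_cols=['label'],
    batch_size=32,
    epochs=5,
    # Before: shuffle_buffer_size=100_000 (now deprecated)
    shuffle=True,                # shuffling is delegated to the Petastorm reader
    reader_pool_type='thread',   # new default; 'process' and 'dummy' are also accepted
    random_seed=42,              # reproducible shuffling and data loading
)
```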
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Added support for Petastorm reader-level parallel shuffling. ([#3665](https://github.com/horovod/horovod/pull/3665))
- Added random seed support for the Lightning datamodule to generate reproducible data loading outputs. ([#3665](https://github.com/horovod/horovod/pull/3665))
- Added support for `int8, uint8` allreduce and grouped_allreduce in tensorflow. ([#3649](https://github.com/horovod/horovod/pull/3649))
- Added support for batched memory copies in GPUAllgather. ([#3590](https://github.com/horovod/horovod/pull/3590))
- Added support for batched memory copies in GPUReducescatter. ([#3621](https://github.com/horovod/horovod/pull/3621))
@@ -20,8 +22,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Changed

- Default Petastorm reader pool is changed from `process` to `thread` for lower memory usage. ([#3665](https://github.com/horovod/horovod/pull/3665))

### Deprecated

- Deprecated the `shuffle_buffer_size` field of `EstimatorParams`. Use `shuffle` to enable or disable shuffling instead. ([#3665](https://github.com/horovod/horovod/pull/3665))

### Removed

### Fixed
16 changes: 14 additions & 2 deletions horovod/spark/common/params.py
@@ -74,9 +74,14 @@ class EstimatorParams(Params):

shuffle_buffer_size = Param(Params._dummy(),
'shuffle_buffer_size',
'shuffling buffer size of data before training in number of samples',
'(Deprecated) shuffling buffer size used for training samples',
typeConverter=TypeConverters.toInt)

shuffle = Param(Params._dummy(),
'shuffle',
'Whether to shuffle training samples or not. Defaults to True',
typeConverter=TypeConverters.toBoolean)

verbose = Param(Params._dummy(), 'verbose', 'verbose flag (0=silent, 1=enabled, other values used by frameworks)',
typeConverter=TypeConverters.toInt)

@@ -149,6 +154,7 @@ def __init__(self):
callbacks=[],
random_seed=None,
shuffle_buffer_size=None,
shuffle=True,
partitions_per_process=10,
run_id=None,
train_steps_per_epoch=None,
@@ -158,7 +164,7 @@
transformation_removed_fields=None,
train_reader_num_workers=2,
val_reader_num_workers=2,
reader_pool_type='process',
reader_pool_type='thread',
label_shapes=None,
inmemory_cache_all=False,
use_gpu=True,
@@ -316,9 +322,15 @@ def getRandomSeed(self):
def setShufflingBufferSize(self, value):
return self._set(shuffle_buffer_size=value)

def setShuffle(self, value):
return self._set(shuffle=value)

def getShufflingBufferSize(self):
return self.getOrDefault(self.shuffle_buffer_size)

def getShuffle(self):
return self.getOrDefault(self.shuffle)

def setOptimizer(self, value):
return self._set(optimizer=value)

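
For readers unfamiliar with the PySpark ML `Params` machinery used above, here is a small self-contained sketch (illustrative only, not Horovod code) of how a boolean `Param` with a default is declared, set, and read back -- the same pattern the new `shuffle` field follows:

```python
from pyspark.ml.param import Param, Params, TypeConverters


class ShuffleParams(Params):
    # Class-level Param declaration; re-parented per instance by Params.__init__.
    shuffle = Param(Params._dummy(), 'shuffle',
                    'Whether to shuffle training samples or not. Defaults to True',
                    typeConverter=TypeConverters.toBoolean)

    def __init__(self):
        super().__init__()
        self._setDefault(shuffle=True)

    def setShuffle(self, value):
        return self._set(shuffle=value)

    def getShuffle(self):
        return self.getOrDefault(self.shuffle)


params = ShuffleParams()
print(params.getShuffle())   # True (the default)
params.setShuffle(False)
print(params.getShuffle())   # False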
10 changes: 6 additions & 4 deletions horovod/spark/keras/estimator.py
@@ -116,10 +116,11 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
epochs: Number of epochs to train.
verbose: Verbosity level [0, 2] (default: 1).
random_seed: Optional random seed to use for Tensorflow. Default: None.
shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows (on training data).
shuffle_buffer_size: (Deprecated) Optional size of in-memory shuffle buffer in rows (on training data).
Allocating a larger buffer size increases randomness of shuffling at
the cost of more host memory. Defaults to estimating with an assumption
of 4GB of memory per host. Setting shuffle_buffer_size=0 turns off shuffling.
shuffle: (Optional) Whether to shuffle training samples or not. Defaults to True.
partitions_per_process: Number of Parquet partitions to assign per worker process from `num_proc` (default: 10).
run_id: Optional unique ID for this run for organization in the Store. Will be automatically assigned if not
provided.
@@ -142,8 +143,8 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process']. Defaults to 'process'.
reader_pool_type: Type of Petastorm worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process', 'dummy']. Defaults to 'thread'.
inmemory_cache_all: boolean value. Cache the data in memory for training and validation. Default: False.
backend_env: dict to add to the environment of the backend. Defaults to setting the java heap size to
2G min and max for libhdfs through petastorm
@@ -180,14 +181,15 @@ def __init__(self,
verbose=None,
random_seed=None,
shuffle_buffer_size=None,
shuffle=True,
partitions_per_process=None,
run_id=None,
train_steps_per_epoch=None,
validation_steps_per_epoch=None,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
reader_pool_type=None,
reader_pool_type='thread',
label_shapes=None,
checkpoint_callback=None,
inmemory_cache_all=False,
70 changes: 12 additions & 58 deletions horovod/spark/keras/remote.py
@@ -17,6 +17,7 @@
import io
import math
import os
import warnings

import h5py
import tensorflow as tf
@@ -49,6 +50,9 @@ def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx):
should_validate = estimator.getValidation()
random_seed = estimator.getRandomSeed()
user_shuffle_buffer_size = estimator.getShufflingBufferSize()
if user_shuffle_buffer_size is not None:
warnings.warn('shuffle_buffer_size is deprecated, use shuffle instead', DeprecationWarning)
shuffle = estimator.getShuffle()
user_verbose = estimator.getVerbose()
checkpoint_callback = estimator.getCheckpointCallback()
inmemory_cache_all = estimator.getInMemoryCacheAll()
@@ -84,7 +88,6 @@ def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx):

# Utility functions
deserialize_keras_model = _deserialize_keras_model_fn()
calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn()
pin_gpu = _pin_gpu_fn()

# Storage
@@ -134,15 +137,6 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
else:
tf.random.set_seed(random_seed)

# If user specifies any user_shuffle_buffer_size (even 0), we should honor it.
if user_shuffle_buffer_size is None:
shuffle_buffer_size = calculate_shuffle_buffer_size(
hvd, avg_row_size, train_rows / hvd.size())
else:
if user_shuffle_buffer_size < 0:
raise ValueError("user_shuffle_buffer_size cannot be negative!")
shuffle_buffer_size = user_shuffle_buffer_size

# needs to be deserialized in the with scope
with k.utils.custom_object_scope(custom_objects):
model = deserialize_keras_model(
@@ -238,7 +232,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):

if verbose:
print(f"Training parameters: Epochs: {epochs}, Scaled lr: {scaled_lr}, "
f"Shuffle size: {shuffle_buffer_size}, random_seed: {random_seed}\n"
f"Shuffle: {shuffle}, random_seed: {random_seed}\n"
f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {steps_per_epoch}\n"
f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {validation_steps}\n"
f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")
@@ -265,8 +259,9 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
schema_fields=schema_fields,
transform_spec=transform_spec,
storage_options=storage_options,
# Don't shuffle row groups if shuffle_buffer_size is 0 (non-shuffle case).
shuffle_row_groups=True if shuffle_buffer_size > 0 else False,
shuffle_rows=shuffle,
shuffle_row_groups=shuffle,
seed=random_seed,
**reader_factory_kwargs) as train_reader:
with reader_factory(remote_store.val_data_path,
num_epochs=1,
@@ -278,14 +273,14 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
schema_fields=schema_fields,
transform_spec=transform_spec,
storage_options=storage_options,
shuffle_rows=False,
shuffle_row_groups=False,
**reader_factory_kwargs) \
if should_validate else empty_batch_reader() as val_reader:

train_data = make_dataset(train_reader, batch_size, shuffle_buffer_size,
is_batch_reader, shuffle=True if shuffle_buffer_size > 0 else False,
cache=inmemory_cache_all, seed=random_seed)
val_data = make_dataset(val_reader, val_batch_size, shuffle_buffer_size,
train_data = make_dataset(train_reader, batch_size,
is_batch_reader, shuffle=shuffle, cache=inmemory_cache_all)
val_data = make_dataset(val_reader, val_batch_size,
is_batch_reader, shuffle=False, cache=inmemory_cache_all) \
if val_reader else None

@@ -329,47 +324,6 @@ def deserialize_keras_model(model_bytes, load_model_fn):
return deserialize_keras_model


def _calculate_shuffle_buffer_size_fn():
def calculate_shuffle_buffer_size(hvd, avg_row_size, train_row_count_per_worker):
"""
Determines the shuffling buffer size such that each worker gets at most 1GB for shuffling
buffer such that on a single machine, among all the workers on that machine, at most
memory_cap_gb GB are allocated for shuffling buffer. Also, it ensures that the buffer size
is identical among all the workers.
example 1:
memory_cap_gb = 4
machine1: 8 workers
machine2: 3 workers
shuffle_buffer_size = 0.5 GB
example 2:
memory_cap_gb = 4
machine1: 2 workers
machine2: 3 workers
shuffle_buffer_size = 1 GB
example 3:
memory_cap_gb = 4
machine1: 2 workers
machine2: 8 workers
machine3: 5 workers
shuffle_buffer_size = 0.5 GB
"""
local_size = hvd.local_size()
local_sizes = hvd.allgather([local_size])
max_local_size = int(max(local_sizes))

if max_local_size > TOTAL_BUFFER_MEMORY_CAP_GIB:
shuffle_buffer_size = TOTAL_BUFFER_MEMORY_CAP_GIB * BYTES_PER_GIB / avg_row_size / max_local_size
else:
shuffle_buffer_size = BYTES_PER_GIB / avg_row_size

return int(min(shuffle_buffer_size, train_row_count_per_worker))

return calculate_shuffle_buffer_size


def _pin_gpu_fn():
# Horovod: pin GPU to be used to process local rank (one GPU per process)
return _pin_gpu_tensorflow2_fn() if LooseVersion(tf.__version__) >= LooseVersion('2.0.0') \
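
The reader calls above boil down to letting Petastorm itself do the shuffling. A standalone sketch of that reader-level shuffling (this assumes Petastorm >= 0.12.0, where `make_batch_reader` accepts `shuffle_rows` and `seed`; the dataset URL is a placeholder):

```python
from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset

with make_batch_reader('file:///tmp/train_data.parquet',  # hypothetical path
                       reader_pool_type='thread',         # new default pool type
                       workers_count=2,
                       shuffle_row_groups=True,           # shuffle the order of row groups
                       shuffle_rows=True,                 # shuffle rows within each row group
                       seed=42) as reader:                # reproducible shuffling
    dataset = make_petastorm_dataset(reader)
    # Shuffling already happened inside the reader, so no tf.data shuffle
    # buffer is needed downstream.
```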
11 changes: 5 additions & 6 deletions horovod/spark/keras/util.py
@@ -60,20 +60,19 @@ def make_dataset_fn(feature_columns, label_columns, sample_weight_col, metadata,
has_sparse_col, sample_weight_col, feature_columns,
label_columns, input_shapes, label_shapes, output_names)

def fn(reader, batch_size, shuffle_buffer_size, is_batch_reader, shuffle=False, cache=False, seed=None):
def fn(reader, batch_size, is_batch_reader, shuffle=True, cache=False):
from petastorm.tf_utils import make_petastorm_dataset

# Samples coming from the Petastorm reader are already shuffled if needed,
# so we don't need to shuffle again in the TensorFlow dataset.
dataset = make_petastorm_dataset(reader)
if is_batch_reader:
dataset = dataset.apply(tf.data.experimental.unbatch())

# Apply cache() before shuffle, so we can reshuffle in each iteration.
if cache:
# cache() can only be applied without shuffle to generate same samples per epoch.
if cache and not shuffle:
dataset = dataset.cache()

if shuffle:
dataset = dataset.shuffle(shuffle_buffer_size, seed=seed)

# Use tf.data.Dataset.repeat() to set up an infinite iterator
# and to enable ranks to perform training and validation with
# unequal number of samples.
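
The `cache and not shuffle` guard above exists because caching freezes whatever ordering was produced on the first pass. A tiny self-contained tf.data demonstration of that effect (illustrative only, not Horovod code):

```python
import tensorflow as tf

# Shuffle, then cache: the first epoch's (shuffled) order is stored and
# replayed verbatim on every later epoch, so reshuffling never happens.
ds = tf.data.Dataset.range(5).shuffle(buffer_size=5).cache()
epoch1 = [int(x) for x in ds]
epoch2 = [int(x) for x in ds]
assert epoch1 == epoch2  # identical order every epoch once cached
```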
21 changes: 13 additions & 8 deletions horovod/spark/lightning/datamodule.py
@@ -20,9 +20,9 @@ def __init__(
has_val: bool = True,
train_batch_size: int = 32,
val_batch_size: int = 32,
shuffle_size: int = 1000,
shuffle: bool = True,
num_reader_epochs=None,
reader_pool_type: str = "process",
reader_pool_type: str = "thread",
reader_worker_count: int = 2,
transformation=None,
transformation_edit_fields=None,
@@ -38,6 +38,7 @@ def __init__(
debug_data_loader: bool = False,
train_async_data_loader_queue_size: int = None,
val_async_data_loader_queue_size: int = None,
seed: int = None,
**kwargs):
super().__init__()
self.train_dir = train_dir
@@ -46,7 +47,7 @@ def __init__(
self.has_val = has_val
self.train_batch_size = train_batch_size
self.val_batch_size = val_batch_size
self.shuffle_size = shuffle_size
self.shuffle = shuffle
self.num_reader_epochs = num_reader_epochs
self.reader_pool_type = reader_pool_type
self.reader_worker_count = reader_worker_count
@@ -64,6 +65,7 @@ def __init__(
self.debug_data_loader = debug_data_loader
self.train_async_data_loader_queue_size = train_async_data_loader_queue_size
self.val_async_data_loader_queue_size = val_async_data_loader_queue_size
self.seed = seed

if debug_data_loader:
print("Creating data_module")
@@ -101,9 +103,9 @@ def setup(self, stage=None):
schema_fields=self.schema_fields,
storage_options=self.storage_options,
transform_spec=transform_spec,
# Don't shuffle row groups
# without shuffling.
shuffle_row_groups=True if self.shuffle_size > 0 else False,
shuffle_rows=self.shuffle,
shuffle_row_groups=self.shuffle,
seed=self.seed,
**reader_factory_kwargs)
if self.has_val:
self.val_reader = reader_factory(
@@ -117,6 +119,7 @@ def setup(self, stage=None):
schema_fields=self.schema_fields,
storage_options=self.storage_options,
transform_spec=transform_spec,
shuffle_rows=False,
shuffle_row_groups=False,
**reader_factory_kwargs)

@@ -151,11 +154,13 @@ def train_dataloader(self):
if self.inmemory_cache_all:
# Use inmem dataloader
dataloader_class = PytorchInmemAsyncDataLoader
kwargs['shuffle'] = self.shuffle_size > 0
kwargs['shuffle'] = self.shuffle
kwargs['num_epochs'] = self.num_train_epochs
else:
dataloader_class = PytorchInfiniteAsyncDataLoader
kwargs['shuffling_queue_capacity'] = self.shuffle_size
# No need to shuffle again at the dataloader level.
# The reader shuffles rows within every row group since Petastorm 0.12.0.
kwargs['shuffling_queue_capacity'] = 0

if self.debug_data_loader:
kwargs['debug_data_loader'] = self.debug_data_loader
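
A hypothetical construction of this data module with the new arguments (the class name `PetastormDataModule`, the import path, and the directories are assumptions for illustration; in practice the estimator builds this object internally):

```python
from horovod.spark.lightning.datamodule import PetastormDataModule  # assumed name

data_module = PetastormDataModule(
    train_dir='file:///tmp/intermediate/train',  # hypothetical paths
    val_dir='file:///tmp/intermediate/val',
    num_train_epochs=5,
    has_val=True,
    train_batch_size=32,
    val_batch_size=32,
    shuffle=True,               # replaces the old shuffle_size buffer argument
    reader_pool_type='thread',
    reader_worker_count=2,
    seed=42,                    # forwarded to the Petastorm readers
)
```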
10 changes: 6 additions & 4 deletions horovod/spark/lightning/estimator.py
@@ -135,17 +135,18 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
for training. Not needed for lightning model.
partitions_per_process: (Optional) Number of Parquet partitions to assign per worker
process from `num_proc` (default: 10).
reader_pool_type: (Optional) Type of worker pool used to parallelize reading data from
the dataset. Should be one of ['thread', 'process']. Defaults to
'process'.
reader_pool_type: (Optional) Type of Petastorm worker pool used to parallelize reading data from
the dataset. Should be one of ['thread', 'process', 'dummy']. Defaults to
'thread'.
run_id: (Optional) unique ID for this run for organization in the Store. Will be
automatically assigned if not provided.
sample_weight_col: (Optional) column indicating the weight of each sample.
random_seed: Optional random seed to use for PyTorch Lightning. Default: None.
shuffle_buffer_size: Optional size of in-memory shuffle buffer in rows (on training data).
shuffle_buffer_size: (Deprecated) Optional size of in-memory shuffle buffer in rows (on training data).
Allocating a larger buffer size increases randomness of shuffling at
the cost of more host memory. Defaults to estimating with an assumption
of 4GB of memory per host. Setting shuffle_buffer_size=0 turns off shuffling.
shuffle: (Optional) Whether to shuffle training samples or not. Defaults to True.
store: Store object that abstracts reading and writing of intermediate data and
run results.
terminate_on_nan: (Optional) terminate the training process on seeing NaN output.
@@ -249,6 +250,7 @@ def __init__(self,
verbose=1,
random_seed=None,
shuffle_buffer_size=None,
shuffle=True,
partitions_per_process=None,
run_id=None,
train_minibatch_fn=None,
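
And the Lightning-estimator counterpart of the earlier Keras sketch (again with placeholder model, store, and columns): combining `shuffle=True` with `random_seed` makes the reader-level shuffling reproducible across runs.

```python
from horovod.spark.lightning import TorchEstimator

torch_estimator = TorchEstimator(
    num_proc=2,
    store=store,                 # assumed: an existing horovod.spark Store
    model=lightning_model,       # assumed: a pl.LightningModule
    input_shapes=[[-1, 28, 28]],
    feature_cols=['features'],
    label_cols=['label'],
    batch_size=64,
    epochs=3,
    shuffle=True,                # Petastorm shuffles rows and row groups
    random_seed=42,              # same shuffle order on every run
    reader_pool_type='thread',
)
```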
