Add an option to pin to GPU for all estimators #3526

Merged: 6 commits, Apr 30, 2022
Changes from 2 commits
23 changes: 22 additions & 1 deletion horovod/spark/common/params.py
@@ -96,6 +96,14 @@ class EstimatorParams(Params):

label_shapes = Param(Params._dummy(), 'label_shapes', 'specifies the shape (or shapes) of the label column (or columns)')

inmemory_cache_all = Param(Params._dummy(), 'inmemory_cache_all',
'Cache the data in memory for training and validation.',
typeConverter=TypeConverters.toBoolean)

pin_gpu = Param(Params._dummy(), 'pin_gpu',
'Whether to pin the training process to the GPU. Defaults to True.',
typeConverter=TypeConverters.toBoolean)

def __init__(self):
super(EstimatorParams, self).__init__()

@@ -129,7 +137,9 @@ def __init__(self):
train_reader_num_workers=2,
val_reader_num_workers=2,
reader_pool_type='process',
label_shapes=None)
label_shapes=None,
inmemory_cache_all=False,
pin_gpu=True)

def _check_params(self, metadata):
model = self.getModel()
@@ -334,6 +344,17 @@ def setLabelShapes(self, value):
def getLabelShapes(self):
return self.getOrDefault(self.label_shapes)

def setInMemoryCacheAll(self, value):
return self._set(inmemory_cache_all=value)

def getInMemoryCacheAll(self):
return self.getOrDefault(self.inmemory_cache_all)

def setPinGpu(self, value):
self._set(pin_gpu=value)

def getPinGpu(self):
return self.getOrDefault(self.pin_gpu)

class ModelParams(HasOutputCols):
history = Param(Params._dummy(), 'history', 'history')
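A minimal usage sketch of the new param, assuming a compiled Keras model, a Horovod store, and a training DataFrame already exist (none of these are part of this PR); pin_gpu behaves like any other Spark ML param, so it can be set in the constructor or through the generated getter/setter pair:

from horovod.spark.keras import KerasEstimator

keras_estimator = KerasEstimator(
    num_proc=2,
    store=store,              # assumed: a Horovod Store on shared storage
    model=model,              # assumed: a compiled Keras model
    optimizer='adam',         # assumed optimizer for this sketch
    loss='mse',               # assumed loss for this sketch
    feature_cols=['features'],
    label_cols=['y'],
    batch_size=32,
    epochs=3,
    pin_gpu=False)            # opt out of GPU pinning

assert not keras_estimator.getPinGpu()
keras_estimator.setPinGpu(True)  # toggle it back on via the setter
assert keras_estimator.getPinGpu()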
14 changes: 3 additions & 11 deletions horovod/spark/keras/estimator.py
@@ -147,14 +147,12 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
inmemory_cache_all: boolean value. Cache the data in memory for training and validation. Default: False.
backend_env: dict to add to the environment of the backend. Defaults to setting the java heap size to
2G min and max for libhdfs through petastorm
pin_gpu: Whether to pin the training process to the GPU. Defaults to True.
"""

custom_objects = Param(Params._dummy(), 'custom_objects', 'custom objects')
checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
'model checkpointing callback')
inmemory_cache_all = Param(Params._dummy(), 'inmemory_cache_all',
'Cache the data in memory for training and validation.',
typeConverter=TypeConverters.toBoolean)
backend_env = Param(Params._dummy(), "backend_env",
"dict to add to the environment of the command run on the environment")

@@ -192,14 +190,14 @@ def __init__(self,
label_shapes=None,
checkpoint_callback=None,
inmemory_cache_all=False,
backend_env=None):
backend_env=None,
pin_gpu=True):

super(KerasEstimator, self).__init__()

self._setDefault(optimizer=None,
custom_objects={},
checkpoint_callback=None,
inmemory_cache_all=False,
backend_env={'LIBHDFS_OPTS': '-Xms2048m -Xmx2048m'})

kwargs = self._input_kwargs
@@ -235,12 +233,6 @@ def setCheckpointCallback(self, value):
def getCheckpointCallback(self):
return self.getOrDefault(self.checkpoint_callback)

def setInMemoryCacheAll(self, value):
return self._set(inmemory_cache_all=value)

def getInMemoryCacheAll(self):
return self.getOrDefault(self.inmemory_cache_all)

def setBackendEnv(self, value):
self._set(backend_env=value)

14 changes: 11 additions & 3 deletions horovod/spark/keras/remote.py
@@ -52,6 +52,7 @@ def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx):
user_verbose = estimator.getVerbose()
checkpoint_callback = estimator.getCheckpointCallback()
inmemory_cache_all = estimator.getInMemoryCacheAll()
should_pin_gpu = estimator.getPinGpu()

# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
@@ -111,7 +112,16 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
hvd = get_horovod()
hvd.init()

pin_gpu(hvd, tf, k)
# Verbose mode 1 will print a progress bar
verbose = user_verbose if hvd.rank() == 0 else 0

if should_pin_gpu:
if verbose:
print(f"Pinning current process to the GPU.")
Tixxx marked this conversation as resolved.
Show resolved Hide resolved
pin_gpu(hvd, tf, k)
else:
if verbose:
print(f"Skip pinning current process to the GPU.")
Tixxx marked this conversation as resolved.
Show resolved Hide resolved

if random_seed is not None:
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
@@ -137,8 +147,6 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
scaled_lr = k.backend.get_value(model.optimizer.lr) * hvd.size()
k.backend.set_value(model.optimizer.lr, scaled_lr)

# Verbose mode 1 will print a progress bar
verbose = user_verbose if hvd.rank() == 0 else 0

if verbose:
print(f"Shared lib path is pointing to: {_horovod.common.process_sets._basics.MPI_LIB_CTYPES}")
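For context, pin_gpu(hvd, tf, k) above is Horovod's existing helper; the general idea behind GPU pinning in TF2 is to make each worker see only the GPU that matches its Horovod local rank. A rough illustrative sketch of that pattern (not the helper's actual source):

import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    # Expose only the GPU matching this worker's local rank.
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')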
15 changes: 3 additions & 12 deletions horovod/spark/lightning/estimator.py
@@ -181,6 +181,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
debug_data_loader: (Optional) Debugging flag for data loader.
train_async_data_loader_queue_size: (Optional) Size of train async data loader queue.
val_async_data_loader_queue_size: (Optional) Size of val async data loader queue.
pin_gpu: Whether to pin the training process to the GPU. Defaults to True.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -189,10 +190,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
train_minibatch_fn = Param(Params._dummy(), 'train_minibatch_fn',
'functions that construct the minibatch train function for torch')

inmemory_cache_all = Param(Params._dummy(), 'inmemory_cache_all',
'Cache the data in memory for training and validation.',
typeConverter=TypeConverters.toBoolean)

num_gpus = Param(Params._dummy(), 'num_gpus',
'Number of gpus per process, default to 1 when CUDA is available in the backend, otherwise 0.')

@@ -266,14 +263,14 @@ def __init__(self,
profiler=None,
debug_data_loader=False,
train_async_data_loader_queue_size=None,
val_async_data_loader_queue_size=None):
val_async_data_loader_queue_size=None,
pin_gpu=True):

super(TorchEstimator, self).__init__()
self._setDefault(loss_constructors=None,
input_shapes=None,
train_minibatch_fn=None,
transformation_fn=None,
inmemory_cache_all=False,
num_gpus=None,
logger=None,
log_every_n_steps=50,
@@ -315,12 +312,6 @@ def setLossConstructors(self, value):
def getLossConstructors(self):
return self.getOrDefault(self.loss_constructors)

def setInMemoryCacheAll(self, value):
return self._set(inmemory_cache_all=value)

def getInMemoryCacheAll(self):
return self.getOrDefault(self.inmemory_cache_all)

def setNumGPUs(self, value):
return self._set(num_gpus=value)

6 changes: 5 additions & 1 deletion horovod/spark/lightning/remote.py
@@ -64,6 +64,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
debug_data_loader = estimator.getDebugDataLoader()
train_async_data_loader_queue_size = estimator.getTrainAsyncDataLoaderQueueSize()
val_async_data_loader_queue_size = estimator.getValAsyncDataLoaderQueueSize()
should_pin_gpu = estimator.getPinGpu()

# get logger
logger = estimator.getLogger()
@@ -194,7 +195,10 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {_val_steps_per_epoch}\n"
f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")

cuda_available = torch.cuda.is_available()
if not should_pin_gpu and verbose:
print("Skip pinning current process to the GPU.")
Collaborator: Why not use the logger? Why is there a verbose flag when there is a logger?

Collaborator (Author): The logger doesn't write to stdout and stderr properly in this function, since it runs in a Ray executor. The train_logger is for passing specialized loggers (ones that don't write directly to stdout and stderr) to PyTorch Lightning. I have tried some generic logger modules here; they either failed to serialize or produced no output.

cuda_available = torch.cuda.is_available() and should_pin_gpu
# We need to check that all ranks have the same device type for training.
# Horovod doesn't support heterogeneous allreduce for gradients.
cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')
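The allgather at the end of this hunk supports the comment's constraint: gradients are allreduced across ranks, so every rank must train on the same device type. A sketch of the consistency check this sets up (assumed shape, inferred from the comment rather than shown in the diff):

cuda_available = torch.cuda.is_available() and should_pin_gpu
cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')
if not all(flag == cuda_avail_list[0] for flag in cuda_avail_list):
    # Fail fast instead of hanging inside a heterogeneous allreduce.
    raise RuntimeError('All ranks must use the same device type for training.')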
18 changes: 5 additions & 13 deletions horovod/spark/torch/estimator.py
@@ -147,6 +147,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
val_reader_num_workers: Similar to the train_reader_num_workers.
reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
Should be one of ['thread', 'process']. Defaults to 'process'.
inmemory_cache_all: (Optional) Cache the data in memory for training and validation.
pin_gpu: Whether to pin the training process to the GPU. Defaults to True.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -155,10 +157,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
train_minibatch_fn = Param(Params._dummy(), 'train_minibatch_fn',
'functions that construct the minibatch train function for torch')

inmemory_cache_all = Param(Params._dummy(), 'inmemory_cache_all',
'Cache the data in memory for training and validation.',
typeConverter=TypeConverters.toBoolean)

@keyword_only
def __init__(self,
num_proc=None,
@@ -193,14 +191,14 @@ def __init__(self,
val_reader_num_workers=None,
reader_pool_type=None,
label_shapes=None,
inmemory_cache_all=False):
inmemory_cache_all=False,
pin_gpu=True):

super(TorchEstimator, self).__init__()
self._setDefault(loss_constructors=None,
input_shapes=None,
train_minibatch_fn=None,
transformation_fn=None,
inmemory_cache_all=False)
transformation_fn=None)

kwargs = self._input_kwargs

@@ -227,12 +225,6 @@ def setLossConstructors(self, value):
def getLossConstructors(self):
return self.getOrDefault(self.loss_constructors)

def setInMemoryCacheAll(self, value):
return self._set(inmemory_cache_all=value)

def getInMemoryCacheAll(self):
return self.getOrDefault(self.inmemory_cache_all)

def _get_optimizer(self):
return self.getOrDefault(self.optimizer)

6 changes: 5 additions & 1 deletion horovod/spark/torch/remote.py
@@ -60,6 +60,7 @@ def RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_id
transformation_fn = estimator.getTransformationFn()
transformation = transformation_fn if transformation_fn else None
inmemory_cache_all = estimator.getInMemoryCacheAll()
should_pin_gpu = estimator.getPinGpu()

# If loss weight is not provided, use equal loss for all the labels
loss_weights = estimator.getLossWeights()
@@ -134,7 +135,10 @@ def train(serialized_model, optimizer_cls, model_opt_state_serialized,
raise ValueError("user_shuffle_buffer_size cannot be negative!")
shuffle_buffer_size = user_shuffle_buffer_size

cuda_available = torch.cuda.is_available()
if not should_pin_gpu and user_verbose:
print("Skip pinning current process to the GPU.")

cuda_available = torch.cuda.is_available() and should_pin_gpu
# We need to check that all ranks have the same device type for training.
# Horovod doesn't support heterogeneous allreduce for gradients.
cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')
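On the PyTorch side the flag gates device selection: with should_pin_gpu False, cuda_available is forced to False and training stays on the CPU even on GPU hosts. A rough sketch of the device binding this implies (illustrative, not the exact code path; should_pin_gpu stands in for the flag read from the estimator above):

import torch
import horovod.torch as hvd

hvd.init()
if should_pin_gpu and torch.cuda.is_available():
    # Bind this worker to the GPU matching its Horovod local rank.
    torch.cuda.set_device(hvd.local_rank())
    device = torch.device('cuda')
else:
    device = torch.device('cpu')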
5 changes: 4 additions & 1 deletion test/integration/test_spark_keras.py
@@ -98,7 +98,10 @@ def test_fit_model(self):
batch_size=1,
random_seed=1,
epochs=3,
verbose=2)
verbose=2,
pin_gpu=False)

assert not keras_estimator.getPinGpu()

keras_model = keras_estimator.fit(df)
