
feat: support gradient accumulation in spark torch estimator #3681

Merged: 7 commits, Sep 13, 2022
3 changes: 3 additions & 0 deletions examples/spark/pytorch/pytorch_spark_mnist.py
@@ -39,6 +39,8 @@
help='temporary working directory to write intermediate files (prefix with hdfs:// to use HDFS)')
parser.add_argument('--data-dir', default='/tmp',
help='location of the training dataset in the local filesystem (will be downloaded if needed)')
parser.add_argument('--backward-passes-per-step', type=int, default=1,
help='number of backward passes to perform before calling hvd.allreduce')

if __name__ == '__main__':
args = parser.parse_args()
@@ -114,6 +116,7 @@ def forward(self, features):
batch_size=args.batch_size,
epochs=args.epochs,
validation=0.1,
backward_passes_per_step=args.backward_passes_per_step,
verbose=1)

torch_model = torch_estimator.fit(train_df).setOutputCols(['label_prob'])
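Not part of the diff, just context on how the new flag is meant to be used: with --backward-passes-per-step N, each worker runs N forward/backward passes and accumulates gradients locally before Horovod reduces them and the optimizer applies an update, so each update effectively averages over N * batch_size samples per worker. A minimal sketch of the estimator call with the knob set explicitly; the store, model, optimizer, and loss objects stand in for the ones the example script already builds, and 4 is an arbitrary illustration value:

```python
# Sketch only: `store`, `model`, `optimizer`, and `loss` are placeholders for
# objects the MNIST example script constructs earlier.
import horovod.spark.torch as hvd_spark

torch_estimator = hvd_spark.TorchEstimator(
    store=store,
    model=model,
    optimizer=optimizer,
    loss=loss,
    input_shapes=[[-1, 1, 28, 28]],
    feature_cols=['features'],
    label_cols=['label'],
    batch_size=64,
    epochs=2,
    validation=0.1,
    backward_passes_per_step=4,   # accumulate 4 mini-batches per hvd.allreduce
    verbose=1)
```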
15 changes: 14 additions & 1 deletion horovod/spark/common/params.py
@@ -133,6 +133,12 @@ class EstimatorParams(Params):
'https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_start_method.'
'This param defaults to None.',
typeConverter=TypeConverters.toString)

backward_passes_per_step = Param(Params._dummy(), 'backward_passes_per_step',
'Number of backward passes to perform before calling hvd.allreduce. '
'This allows accumulating updates over multiple mini-batches before reducing and applying them. '
'This param defaults to 1.',
typeConverter=TypeConverters.toInt)

def __init__(self):
super(EstimatorParams, self).__init__()
@@ -175,7 +181,8 @@ def __init__(self):
label_shapes=None,
inmemory_cache_all=False,
use_gpu=True,
mp_start_method=None)
mp_start_method=None,
backward_passes_per_step=1)

def _check_params(self, metadata):
model = self.getModel()
@@ -427,6 +434,12 @@ def setMpStartMethod(self, value):

def getMpStartMethod(self):
return self.getOrDefault(self.mp_start_method)

def setBackwardPassesPerStep(self, value):
self._set(backward_passes_per_step=value)

def getBackwardPassesPerStep(self):
return self.getOrDefault(self.backward_passes_per_step)

class ModelParams(HasOutputCols):
history = Param(Params._dummy(), 'history', 'history')
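Not part of the diff: alongside the constructor keyword, the new Param gets the usual accessor pair, so the value can be changed or inspected on an already-constructed estimator. A small usage sketch, where est stands for any existing TorchEstimator:

```python
# `est` is an already-constructed TorchEstimator (hypothetical here).
est.setBackwardPassesPerStep(4)             # note: the setter does not return self
assert est.getBackwardPassesPerStep() == 4  # reads back the Param (defaults to 1)
```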
6 changes: 5 additions & 1 deletion horovod/spark/torch/estimator.py
@@ -151,6 +151,9 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
inmemory_cache_all: (Optional) Cache the data in memory for training and validation.
use_gpu: Whether to use the GPU for training. Defaults to True.
mp_start_method: The method to use to start multiprocessing. Defaults to None.
backward_passes_per_step: Number of backward passes to perform before calling hvd.allreduce.
This allows accumulating updates over multiple mini-batches before
reducing and applying them. Defaults to 1.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -196,7 +199,8 @@ def __init__(self,
label_shapes=None,
inmemory_cache_all=False,
use_gpu=True,
mp_start_method=None):
mp_start_method=None,
backward_passes_per_step=1):

super(TorchEstimator, self).__init__()
self._setDefault(loss_constructors=None,
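Not part of the diff, just a back-of-the-envelope illustration of what the docstring means by accumulating updates over multiple mini-batches: each optimizer step now averages gradients over backward_passes_per_step mini-batches per worker, so the effective batch per update grows accordingly. All values below are assumed for illustration:

```python
# Assumed illustration values, not from the PR.
batch_size = 32                # per-worker mini-batch size given to the estimator
backward_passes_per_step = 4   # mini-batches accumulated before hvd.allreduce
num_workers = 8                # Horovod processes (Spark tasks)

per_worker_samples_per_update = batch_size * backward_passes_per_step    # 128
global_samples_per_update = per_worker_samples_per_update * num_workers  # 1024
print(per_worker_samples_per_update, global_samples_per_update)
```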
15 changes: 11 additions & 4 deletions horovod/spark/torch/remote.py
@@ -67,6 +67,7 @@ def RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_id
inmemory_cache_all = estimator.getInMemoryCacheAll()
should_use_gpu = estimator.getUseGpu()
mp_start_method = estimator.getMpStartMethod()
backward_passes_per_step = estimator.getBackwardPassesPerStep()

# If loss weight is not provided, use equal loss for all the labels
loss_weights = estimator.getLossWeights()
@@ -194,6 +195,7 @@ def train(serialized_model, optimizer_cls, model_opt_state_serialized,
# Pass the compression arg only if it is specified by the user.
dist_optimizer_args['compression'] = gradient_compression
# Horovod: wrap optimizer with DistributedOptimizer.
dist_optimizer_args['backward_passes_per_step'] = backward_passes_per_step
optimizer = hvd.DistributedOptimizer(**dist_optimizer_args)

# This function takes the current optimizer and constructs a new optimizer with the
@@ -370,10 +372,12 @@ def _train(epoch):
row = next(train_loader_iter)
inputs, labels, sample_weights = prepare_batch(row)
outputs, loss = train_minibatch(model, optimizer, transform_outputs,
loss_fn, inputs, labels, sample_weights)
loss_fn, inputs, labels, sample_weights,
backward_passes_per_step, batch_idx)
update_metrics(metric_value_groups, outputs, labels)
train_loss.update(loss)
print_metrics(batch_idx, train_loss, metric_value_groups, 'train')
optimizer.step()

return aggregate_metrics('train', epoch, train_loss, metric_value_groups)

@@ -454,13 +458,16 @@ def _validate(epoch):


def _train_minibatch_fn():
def train_minibatch(model, optimizer, transform_outputs, loss_fn, inputs, labels, sample_weights):
optimizer.zero_grad()
def train_minibatch(model, optimizer, transform_outputs, loss_fn, inputs, labels, sample_weights, backward_passes_per_step, batch_idx):
if batch_idx % backward_passes_per_step == 0:
if batch_idx != 0:
optimizer.step()
optimizer.zero_grad()
outputs = model(*inputs)
outputs, labels = transform_outputs(outputs, labels)
loss = loss_fn(outputs, labels, sample_weights)
loss.div_(float(backward_passes_per_step))
loss.backward()
optimizer.step()
return outputs, loss
return train_minibatch

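Not part of the diff: a standalone sketch of the accumulation cadence the remote.py changes implement, i.e. forwarding backward_passes_per_step to hvd.DistributedOptimizer, stepping and zeroing gradients only at window boundaries, scaling each loss by 1/backward_passes_per_step, and flushing whatever is left at the end of the epoch. The model, optimizer, loss_fn, and batches names are assumed to exist already:

```python
# Hedged sketch: `model`, `optimizer`, `loss_fn`, and an iterable `batches`
# of (inputs, labels) are assumed to be defined elsewhere.
import horovod.torch as hvd

hvd.init()
backward_passes_per_step = 4

# Delay the gradient allreduce until this many backward passes have been
# accumulated (mirrors dist_optimizer_args in remote.py).
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    backward_passes_per_step=backward_passes_per_step)

for batch_idx, (inputs, labels) in enumerate(batches):
    if batch_idx % backward_passes_per_step == 0:
        if batch_idx != 0:
            optimizer.step()    # apply the accumulated (allreduced) gradients
        optimizer.zero_grad()   # start a new accumulation window
    loss = loss_fn(model(inputs), labels)
    loss = loss / backward_passes_per_step  # average over the accumulation window
    loss.backward()

optimizer.step()                # flush any remaining gradients at epoch end
```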
3 changes: 3 additions & 0 deletions test/integration/test_spark_torch.py
@@ -92,6 +92,7 @@ def test_fit_model(self):
epochs=3,
random_seed=1,
verbose=2,
backward_passes_per_step=3,
sample_weight_col='weight')

torch_model = torch_estimator.fit(df)
@@ -125,6 +126,7 @@ def test_restore_from_checkpoint(self):
batch_size=1,
epochs=1,
verbose=2,
backward_passes_per_step=3,
run_id=run_id)

torch_estimator._load_checkpoint = mock.Mock(side_effect=torch_estimator._load_checkpoint)
@@ -343,6 +345,7 @@ def test_torch_direct_parquet_train(self):
verbose=2,
reader_pool_type=reader_pool_type,
inmemory_cache_all=inmemory_cache_all,
backward_passes_per_step=3,
validation=validation)

# To make sure that setLoss works with non-list loss.