Commit
Merge pull request #732 from mv1388/ddp_callback_registration_check
DDP _spawn_fit callback registration and execution added
mv1388 committed Aug 13, 2022
2 parents 4241bdd + 3d9ef6a commit 3867a2b
Showing 5 changed files with 158 additions and 23 deletions.
55 changes: 40 additions & 15 deletions aitoolbox/torchtrain/train_loop/components/callback_handler.py
@@ -37,7 +37,7 @@ def __init__(self, train_loop_obj):
self.registered_cbs = [
self.cbs_on_epoch_begin, self.cbs_on_epoch_end,
self.cbs_on_train_begin, self.cbs_on_train_end,
- self.cbs_on_batch_begin, self.cbs_on_batch_end,
+ self.cbs_on_batch_begin, self.cbs_on_batch_end,
self.cbs_on_after_gradient_update, self.cbs_on_after_optimizer_step,
self.cbs_on_multiprocess_start,
self.cbs_on_after_batch_prediction
@@ -49,11 +49,13 @@ def register_callbacks(self, callbacks, cache_callbacks=False):
Normally, this is called from inside the train loop by the TrainLoop itself. Basically, the train loop
"registers" itself with each of the provided callbacks.
Newly provided callbacks are appended to the existing ones.
Args:
- callbacks (list or None): list of callbacks
- cache_callbacks (bool): should provided callbacks be cached and not yet registered. First subsequent time
-     this method is called without ``cache_callbacks`` enabled all the previously cached callbacks are added
-     and also registered with the current list of callbacks.
+ callbacks (list or None): list of new callbacks to be added (appended)
+ cache_callbacks (bool): should the provided callbacks be cached and not yet registered. The first subsequent
+     time this method is called without ``cache_callbacks`` enabled, all the previously cached callbacks
+     are added and also registered with the current list of callbacks.
Returns:
None
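
For illustration, a minimal sketch of the caching behaviour described above, modeled on this commit's test setup (the import path and the `NetUnifiedBatchFeed` dummy model are borrowed from the test file and are assumptions, not part of this diff):

```python
# Hedged sketch of the cache_callbacks flow; import paths are assumed.
from aitoolbox.torchtrain.callbacks.abstract import AbstractCallback

train_loop = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)
handler = train_loop.callbacks_handler

cb_a = AbstractCallback('cb A')
cb_b = AbstractCallback('cb B')

# Cached only: cb_a is stored aside and not yet registered with the loop
handler.register_callbacks([cb_a], cache_callbacks=True)

# First call without caching drains the cache: cb_a and cb_b are both
# appended to train_loop.callbacks and registered together
handler.register_callbacks([cb_b])
```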
@@ -70,18 +72,46 @@ def register_callbacks(self, callbacks, cache_callbacks=False):

if callbacks is not None and len(callbacks) > 0:
self.enforce_callbacks_quality(callbacks)

self.train_loop_obj.callbacks += [
cb.register_train_loop_object(self.train_loop_obj) for cb in callbacks
- if self.train_loop_obj.device.index is None or
-    cb.device_idx_execution is None or
-    (cb.device_idx_execution is not None and cb.device_idx_execution == self.train_loop_obj.device.index)
+ if self.should_enable_callback(cb)
]

if not all(0 == cb.execution_order for cb in self.train_loop_obj.callbacks):
self.train_loop_obj.callbacks = sorted(self.train_loop_obj.callbacks, key=lambda cb: cb.execution_order)

- # Note: using `callbacks` here instead of `self.train_loop_obj.callbacks` is correct.
+ # Provide the original input `callbacks` to this method instead of `self.train_loop_obj.callbacks`,
+ # which we added the new callbacks to above. In case some callbacks were already registered at some
+ # earlier time, this prevents their duplication in the execution-position-split self.registered_cbs.
self.split_on_execution_position(callbacks, register_train_loop=False)

+ def should_enable_callback(self, callback):
+     """Determine if a callback should be enabled and executed, in accordance with the GPU device setting
+
+     Always True in the case of training on a single device (CPU or one GPU).
+
+     In the case of multi (GPU) device training such as DDP, this function checks if a callback should be
+     executed on the particular GPU device. If the callback doesn't have any ``device_idx_execution`` set,
+     then it is executed on all the GPUs. If the parameter is set in the callback, then this function returns
+     True only when the callback's ``device_idx_execution`` and the train loop's GPU device index match.
+     In other words, the callback will be executed only in the DDP process which sits on the matching GPU.
+
+     Args:
+         callback (AbstractCallback): callback which will be checked if it should be enabled during the
+             particular train loop run
+
+     Returns:
+         bool: if the provided callback should be enabled or disabled based on (GPU) device index matching.
+     """
+     return self.train_loop_obj.device.index is None or \
+         callback.device_idx_execution is None or \
+         (
+             callback.device_idx_execution is not None and
+             callback.device_idx_execution == self.train_loop_obj.device.index
+         )

def execute_epoch_begin(self):
for callback in self.cbs_on_epoch_begin:
callback.on_epoch_begin()
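
The new predicate, in short: a callback runs when training is on a single device (device index None), when the callback is not pinned to any device, or when its pin matches this process's GPU index. A hedged illustration, reusing the `train_loop` and `handler` sketch objects from the example above:

```python
import torch

# Hypothetical callbacks; AbstractCallback import as in the earlier sketch
pinned_cb = AbstractCallback('rank-0 only', device_idx_execution=0)
unpinned_cb = AbstractCallback('runs everywhere')

train_loop.device = torch.device("cuda:0")
assert handler.should_enable_callback(pinned_cb)      # pin matches cuda:0
assert handler.should_enable_callback(unpinned_cb)    # no pin -> always enabled

train_loop.device = torch.device("cuda:1")
assert not handler.should_enable_callback(pinned_cb)  # pin mismatch -> skipped
assert handler.should_enable_callback(unpinned_cb)
```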
@@ -125,10 +155,7 @@ def execute_after_batch_prediction(self, y_pred_batch, y_test_batch, metadata_ba
def split_on_execution_position(self, callbacks, register_train_loop=False):
if callbacks is not None and len(callbacks) > 0:
for callback in callbacks:
- if self.train_loop_obj.device.index is None or \
-         callback.device_idx_execution is None or \
-         (callback.device_idx_execution is not None and
-          callback.device_idx_execution == self.train_loop_obj.device.index):
+ if self.should_enable_callback(callback):

if register_train_loop:
callback = callback.register_train_loop_object(self.train_loop_obj)
@@ -191,9 +218,7 @@ def mp_filter_callbacks(self):
]

def _mp_filter_cb_list(self, callbacks_list):
- return [cb for cb in callbacks_list
-         if cb.device_idx_execution is None or
-         (cb.device_idx_execution is not None and cb.device_idx_execution == self.train_loop_obj.device.index)]
+ return [cb for cb in callbacks_list if self.should_enable_callback(cb)]

def enforce_callbacks_quality(self, callbacks):
for cb in callbacks:
5 changes: 5 additions & 0 deletions aitoolbox/torchtrain/train_loop/train_loop.py
@@ -827,7 +827,12 @@ def _spawn_fit(self, gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_
torch.manual_seed(0)
torch.cuda.set_device(gpu)
self.device = torch.device(f"cuda:{gpu}")

+ # DDP MP: device-filter any existing callbacks and add the new ones
+ self.callbacks_handler.mp_filter_callbacks()
+ self.callbacks_handler.register_callbacks(callbacks)
+ # Set callbacks to None so they aren't double added/registered later in the `_train()` method
+ callbacks = None

# Optionally load data in-process
self.callbacks_handler.register_callbacks(in_process_data_load)
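
The ordering of these calls matters; below is a hedged paraphrase of the control flow this hunk adds to each spawned DDP worker (a sketch, not the actual method body):

```python
import torch

def spawned_worker_callback_flow(train_loop, gpu, callbacks):
    """Sketch of the new _spawn_fit callback handling in one DDP process."""
    train_loop.device = torch.device(f"cuda:{gpu}")
    # 1) Drop previously registered callbacks pinned to a different GPU rank
    train_loop.callbacks_handler.mp_filter_callbacks()
    # 2) Register the fit()-provided callbacks; should_enable_callback()
    #    filters out the ones pinned to another device index
    train_loop.callbacks_handler.register_callbacks(callbacks)
    # 3) _train() calls register_callbacks(callbacks) again later; passing
    #    None there makes that call a no-op, so nothing is registered twice
    return None
```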
Binary file modified dist/aitoolbox-1.6.1-py3-none-any.whl
Binary file modified dist/aitoolbox-1.6.1.tar.gz
121 changes: 113 additions & 8 deletions tests/test_torchtrain/test_tl_components/test_callback_handler.py
@@ -64,6 +64,35 @@ def test_register_callbacks(self):
self.assertEqual(cb_handler.cbs_on_epoch_begin, [])
self.assertEqual(cb_handler.cbs_on_epoch_end, [])

+ def test_should_enable_callback(self):
+     train_loop = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)
+     cb_handler = CallbacksHandler(train_loop)
+     train_loop.callbacks_handler = cb_handler
+
+     # train_loop device.index is None
+     train_loop.device = torch.device(f"cpu")
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb')))
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=None)))
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=0)))
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=2)))
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=3)))
+
+     # train_loop device.index is not None
+     train_loop.device = torch.device(f"cuda:0")
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb')))
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=None)))
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=0)))
+     self.assertFalse(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=2)))
+     self.assertFalse(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=3)))
+
+     # train_loop device.index is not None
+     train_loop.device = torch.device(f"cuda:3")
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb')))
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=None)))
+     self.assertFalse(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=0)))
+     self.assertFalse(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=2)))
+     self.assertTrue(cb_handler.should_enable_callback(AbstractCallback('dummy cb', device_idx_execution=3)))
def test_enforce_callbacks_quality(self):
train_loop = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)
cb_handler = train_loop.callbacks_handler
@@ -294,6 +323,82 @@ def test_handler_cache_callbacks_further_add(self):
)
self.assertEqual(cb_handler.callbacks_cache, [])

+ def test_mp_filter_callbacks_device_idx_none(self):
+     train_loop = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)
+     cb_handler = CallbacksHandler(train_loop)
+     train_loop.callbacks_handler = cb_handler
+
+     device_idx_execution = None
+     callbacks = [
+         BatchBeginCB(device_idx_execution=device_idx_execution),
+         BatchBeginTrainBeginCB(device_idx_execution=device_idx_execution),
+         BatchBeginTrainBeginAfterOptiCB(device_idx_execution=device_idx_execution),
+         AfterBatchPredictionCB(True, device_idx_execution=device_idx_execution)
+     ]
+     cb_handler.register_callbacks(callbacks)
+     train_loop.device = torch.device(f"cpu")
+     pre_filtered_reg_cbs = cb_handler.registered_cbs
+     cb_handler.mp_filter_callbacks()
+     self.assertEqual(train_loop.callbacks, callbacks)
+     self.assertEqual(pre_filtered_reg_cbs, cb_handler.registered_cbs)
+
+ def test_mp_filter_callbacks_device_match(self):
+     train_loop = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)
+     cb_handler = CallbacksHandler(train_loop)
+     train_loop.callbacks_handler = cb_handler
+
+     device_idx_execution = 1
+     callbacks = [
+         BatchBeginCB(device_idx_execution=device_idx_execution),
+         BatchBeginTrainBeginCB(device_idx_execution=device_idx_execution),
+         BatchBeginTrainBeginAfterOptiCB(device_idx_execution=device_idx_execution),
+         AfterBatchPredictionCB(True, device_idx_execution=device_idx_execution)
+     ]
+     cb_handler.register_callbacks(callbacks)
+     train_loop.device = torch.device(f"cuda:1")
+     pre_filtered_reg_cbs = cb_handler.registered_cbs
+     cb_handler.mp_filter_callbacks()
+     self.assertEqual(train_loop.callbacks, callbacks)
+     self.assertEqual(pre_filtered_reg_cbs, cb_handler.registered_cbs)
+
+ def test_mp_filter_callbacks_device_mismatch(self):
+     train_loop = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)
+     cb_handler = CallbacksHandler(train_loop)
+     train_loop.callbacks_handler = cb_handler
+
+     device_idx_execution = 1
+     callbacks = [
+         BatchBeginCB(device_idx_execution=device_idx_execution),
+         BatchBeginTrainBeginCB(device_idx_execution=device_idx_execution),
+         BatchBeginTrainBeginAfterOptiCB(device_idx_execution=device_idx_execution),
+         AfterBatchPredictionCB(True, device_idx_execution=device_idx_execution)
+     ]
+     cb_handler.register_callbacks(callbacks)
+     train_loop.device = torch.device(f"cuda:0")
+     cb_handler.mp_filter_callbacks()
+     self.assertEqual(train_loop.callbacks, [])
+     self.assertEqual(cb_handler.registered_cbs, [[], [], [], [], [], [], [], [], [], []])
+
+ def test_mp_filter_callbacks_device_semi_mismatch(self):
+     train_loop = TrainLoop(NetUnifiedBatchFeed(), None, None, None, None, None)
+     cb_handler = CallbacksHandler(train_loop)
+     train_loop.callbacks_handler = cb_handler
+
+     cb_1 = BatchBeginCB(device_idx_execution=0)
+     cb_2 = BatchBeginTrainBeginCB(device_idx_execution=1)
+     cb_3 = BatchBeginTrainBeginAfterOptiCB(device_idx_execution=0)
+     cb_4 = AfterBatchPredictionCB(True, device_idx_execution=1)
+     callbacks = [cb_1, cb_2, cb_3, cb_4]
+
+     cb_handler.register_callbacks(callbacks)
+     train_loop.device = torch.device(f"cuda:0")
+     cb_handler.mp_filter_callbacks()
+     self.assertEqual(
+         train_loop.callbacks,
+         [cb for cb in callbacks if cb.device_idx_execution == train_loop.device.index]
+     )
+     self.assertEqual(cb_handler.registered_cbs, [[], [], [cb_3], [], [cb_1, cb_3], [], [], [cb_3], [], []])
def test_handler_execution_after_batch_prediction(self):
self.execute_training_with_on_batch_prediction_cb(
num_epochs=5, train_loader=list(range(4)), val_loader=list(range(3)), test_loader=None
@@ -352,8 +457,8 @@ def test_handler_disable_execution_after_batch_prediction(self):


class BatchBeginCB(AbstractCallback):
- def __init__(self, execution_order=0):
-     super().__init__('', execution_order)
+ def __init__(self, execution_order=0, device_idx_execution=None):
+     super().__init__('', execution_order, device_idx_execution)
self.registered_tl = False

def on_train_loop_registration(self):
@@ -364,8 +469,8 @@ def on_batch_begin(self):


class BatchBeginTrainBeginCB(AbstractCallback):
- def __init__(self, execution_order=0):
-     super().__init__('', execution_order)
+ def __init__(self, execution_order=0, device_idx_execution=None):
+     super().__init__('', execution_order, device_idx_execution)

def on_batch_begin(self):
print("executed")
@@ -375,8 +480,8 @@ def on_train_begin(self):


class BatchBeginTrainBeginAfterOptiCB(AbstractCallback):
- def __init__(self, execution_order=0):
-     super().__init__('', execution_order)
+ def __init__(self, execution_order=0, device_idx_execution=None):
+     super().__init__('', execution_order, device_idx_execution)
self.exe_on_batch_begin = False
self.exe_on_train_begin = False
self.exe_on_after_optimizer_step = False
@@ -392,8 +497,8 @@ def on_after_optimizer_step(self):


class AfterBatchPredictionCB(AbstractCallback):
- def __init__(self, execute_callbacks, execution_order=0):
-     super().__init__('', execution_order)
+ def __init__(self, execute_callbacks, execution_order=0, device_idx_execution=None):
+     super().__init__('', execution_order, device_idx_execution)
self.execute_callbacks = execute_callbacks
self.cb_execution_ctr = 0
self.cb_execution_ctr_dict = {'train': 0, 'validation': 0, 'test': 0}
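
Putting the new parameter to use, a hedged sketch of a user-defined callback pinned to a single DDP rank (the `on_epoch_end` hook and the `train_loop_obj` attribute are assumed parts of the AbstractCallback API, mirrored from the helper classes above):

```python
from aitoolbox.torchtrain.callbacks.abstract import AbstractCallback


class Rank0ReportCB(AbstractCallback):
    """Runs only in the DDP process sitting on GPU device index 0."""

    def __init__(self):
        super().__init__('rank-0 report', execution_order=0, device_idx_execution=0)

    def on_epoch_end(self):
        # Executed once per epoch, only on cuda:0 (e.g. to avoid duplicated logging)
        print(f"Epoch finished on device {self.train_loop_obj.device}")
```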