Skip to content

Commit

Permalink
Merge pull request #617 from mv1388/dp-attribute-transfer-names
Browse files Browse the repository at this point in the history
DP/DDP wrapper attribute transfer names
  • Loading branch information
mv1388 committed Oct 10, 2020
2 parents 120d013 + cea7bd6 commit e0ddadd
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 35 deletions.
9 changes: 9 additions & 0 deletions aitoolbox/torchtrain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@ class TTModel(nn.Module, ABC):
In addition to the common ``forward()`` method required by the base nn.Module, the user also needs to implement
the additional AIToolbox specific ``get_loss()`` and ``get_predictions()`` methods.
``transfer_model_attributes`` (list or tuple): additional TTModel attributes which need to be transferred to
the TTDataParallel level to enable their use in the transferred/exposed class methods. When coding
the model's __init__() method user should also fill in the string names of attributes that should be
transferred in case the model is wrapped for DP/DDP.
"""
def __init__(self):
super().__init__()
self.transfer_model_attributes = []

@abstractmethod
def get_loss(self, batch_data, criterion, device):
"""Get loss during training stage
Expand Down
32 changes: 11 additions & 21 deletions aitoolbox/torchtrain/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@


class TTParallelBase:
def __init__(self, module, add_model_attributes=None,
def __init__(self, module,
default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions')):
"""torchtrain parallel base class used for transferring TTModel functions to the PyTorch Parallel wrappers level
Args:
module (aitoolbox.torchtrain.model.TTModel): neural network model
add_model_attributes (list or tuple or None): additional TTModel attributes which need to be transferred to
the TTDataParallel level to enable their use in the transferred/exposed class methods
default_model_methods (list or tuple): list of core methods which are present also in TTModel abstract class
"""
# Core TTModel methods which every model has
Expand All @@ -43,8 +41,8 @@ def __init__(self, module, add_model_attributes=None,
functools.partial(copy_function(getattr(module, method_name)), self))

# Optionally transfer additional TTModel attributes to the TTDataParallel level
if add_model_attributes is not None and isinstance(add_model_attributes, (list, tuple)):
for attr_name in add_model_attributes:
if module.transfer_model_attributes is not None and isinstance(module.transfer_model_attributes, (list, tuple)):
for attr_name in module.transfer_model_attributes:
setattr(self, attr_name, getattr(module, attr_name))

def get_loss(self, batch_data, criterion, device):
Expand All @@ -58,7 +56,7 @@ def get_predictions(self, batch_data, device):


class TTDataParallel(nn.DataParallel, TTParallelBase):
def __init__(self, module, add_model_attributes=None,
def __init__(self, module,
default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs):
"""torchtrain enabled DataParallel
Expand All @@ -68,67 +66,59 @@ def __init__(self, module, add_model_attributes=None,
Args:
module (aitoolbox.torchtrain.model.TTModel): neural network model
add_model_attributes (list or tuple or None): additional TTModel attributes which need to be transferred to
the TTDataParallel level to enable their use in the transferred/exposed class methods
default_model_methods (list or tuple): list of core methods which are present also in TTModel abstract class
**kwargs: additional parameters for underlying nn.DataParallel
"""
nn.DataParallel.__init__(self, module, **kwargs)
TTParallelBase.__init__(self, module, add_model_attributes, default_model_methods)
TTParallelBase.__init__(self, module, default_model_methods)


class TTDistributedDataParallel(DistributedDataParallel, TTParallelBase):
def __init__(self, module, add_model_attributes=None,
def __init__(self, module,
default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs):
"""torchtrain enabled DistributedDataParallel
Args:
module (aitoolbox.torchtrain.model.TTModel): neural network model
add_model_attributes (list or tuple or None): additional TTModel attributes which need to be transferred to
the TTDistributedDataParallel level to enable their use in the transferred/exposed class methods
default_model_methods (list or tuple): list of core methods which are present also in TTModel abstract class
**kwargs: additional parameters for underlying nn.parallel.DistributedDataParallel
"""
DistributedDataParallel.__init__(self, module, **kwargs)
TTParallelBase.__init__(self, module, add_model_attributes, default_model_methods)
TTParallelBase.__init__(self, module, default_model_methods)


if APEX_AVAILABLE:
class TTApexDistributedDataParallel(ApexDistributedDataParallel, TTParallelBase):
def __init__(self, module, add_model_attributes=None,
def __init__(self, module,
default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs):
"""torchtrain enabled Nvidia Apex DistributedDataParallel
Args:
module (aitoolbox.torchtrain.model.TTModel): neural network model
add_model_attributes (list or tuple or None): additional TTModel attributes which need to be
transferred to the TTDataParallel level to enable their use in the transferred/exposed class methods
default_model_methods (list or tuple): list of core methods which are present also in TTModel
abstract class
**kwargs: additional parameters for underlying apex.parallel.DistributedDataParallel
"""
ApexDistributedDataParallel.__init__(self, module, **kwargs)
TTParallelBase.__init__(self, module, add_model_attributes, default_model_methods)
TTParallelBase.__init__(self, module, default_model_methods)


if DEEPSPEED_AVAILABLE:
class TTDeepSpeedLight(DeepSpeedLight, TTParallelBase):
def __init__(self, args, model,
add_model_attributes=None, default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'),
default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'),
**kwargs):
"""torchtrain enabled Microsoft DeepSpeed's DeepSpeedLight engine
Args:
args (argparse.Namespace): argparser results structured as per DeepSpeed requirements. A dictionary
containing local_rank and deepspeed_config file location.
model (aitoolbox.torchtrain.model.TTModel): neural network model
add_model_attributes (list or tuple or None): additional TTModel attributes which need to be transferred
to the TTDeepSpeedLight level to enable their use in the transferred/exposed class methods
default_model_methods (list or tuple): list of core methods which are present also in TTModel
abstract class
**kwargs: additional parameters for the underlying ``deepspeed.DeepSpeedLight`` class
Possible arguments: https://deepspeed.readthedocs.io/en/latest/initialize.html
"""
DeepSpeedLight.__init__(self, args, model, **kwargs)
TTParallelBase.__init__(self, model, add_model_attributes, default_model_methods)
TTParallelBase.__init__(self, model, default_model_methods)
9 changes: 2 additions & 7 deletions aitoolbox/torchtrain/train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,9 +673,6 @@ def _train_dp(self, num_epochs, callbacks=None, grad_accumulation=1, dp_model_ar
grad_accumulation (int): number of batches the gradients are accumulated before updating weights
dp_model_args (dict or None): parameters for :class:`aitoolbox.torchtrain.parallel.TTDataParallel` /
``nn.DataParallel`` DP model wrap.
Probably the most common optional parameter to set is ``TTDataParallel``'s ``add_model_attributes``
list. In this list the user can list any additional TTModel attributes which need to be transferred to
the TTDataParallel level to enable their use in the transferred/exposed class methods.
Returns:
TTDataParallel or nn.DataParallel: trained model
Expand Down Expand Up @@ -793,7 +790,7 @@ def _spawn_fit(self, gpu, ddp_args, num_epochs, callbacks, grad_accumulation, in
self._train(num_epochs, callbacks, grad_accumulation)

def _train_deepspeed(self, deepspeed_args, num_epochs, callbacks=None,
add_model_attributes=None, **ds_model_args):
**ds_model_args):
"""Train the model using Microsoft DeepSpeed package
Before starting the training the DeepSpeed library needs to be installed on the machine. Find the installation
Expand All @@ -807,8 +804,6 @@ def _train_deepspeed(self, deepspeed_args, num_epochs, callbacks=None,
A dictionary containing local_rank and deepspeed_config file location.
num_epochs (int): how many epochs the network will be trained
callbacks (list): callbacks that are executed during the training run
add_model_attributes (list or tuple or None): additional TTModel attributes which need to be transferred to
the TTDataParallel level to enable their use in the transferred/exposed class methods
**ds_model_args: additional parameters for the underlying ``deepspeed.DeepSpeedLight`` class
Possible arguments: https://deepspeed.readthedocs.io/en/latest/initialize.html
Expand All @@ -826,7 +821,7 @@ def _train_deepspeed(self, deepspeed_args, num_epochs, callbacks=None,

self.model = TTDeepSpeedLight(
args=deepspeed_args,
model=self.model, model_parameters=self.model.parameters(), add_model_attributes=add_model_attributes,
model=self.model, model_parameters=self.model.parameters(),
training_data=self.train_loader.dataset,
**ds_model_args
)
Expand Down
Binary file modified dist/aitoolbox-1.2.0-py3-none-any.whl
Binary file not shown.
Binary file modified dist/aitoolbox-1.2.0.tar.gz
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/test_torchtrain/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def __init__(self):
super().__init__()
self.model_level_str = 'test string'

self.transfer_model_attributes = ['model_level_str']

def get_loss(self, batch_data, criterion, device):
return 'loss_return'

Expand Down
101 changes: 94 additions & 7 deletions tests/test_torchtrain/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from tests.test_torchtrain.test_model import MyModel

from aitoolbox import TTDataParallel
from aitoolbox import TTModel, TTDataParallel
from aitoolbox.torchtrain.parallel import TTParallelBase


Expand All @@ -19,11 +19,9 @@ def test_init_default(self):
self.assertTrue(hasattr(model_parallel, 'my_new_fn'))
self.assertTrue(hasattr(model_parallel, 'get_model_level_str'))

self.assertFalse(hasattr(model_parallel, 'model_level_str'))

def test_init_attr_transfer(self):
model = MyModel()
model_parallel = TTParallelBase(model, add_model_attributes=['model_level_str'])
model_parallel = TTParallelBase(model)

self.assertTrue(hasattr(model_parallel, 'model_level_str'))

Expand Down Expand Up @@ -51,8 +49,6 @@ def test_init_default(self):
self.assertTrue(hasattr(model_parallel, 'my_new_fn'))
self.assertTrue(hasattr(model_parallel, 'get_model_level_str'))

self.assertFalse(hasattr(model_parallel, 'model_level_str'))

def test_inheritance(self):
model = MyModel()
model_parallel = TTDataParallel(model)
Expand All @@ -62,7 +58,7 @@ def test_inheritance(self):

def test_init_attr_transfer(self):
model = MyModel()
model_parallel = TTDataParallel(model, add_model_attributes=['model_level_str'])
model_parallel = TTDataParallel(model)

self.assertTrue(hasattr(model_parallel, 'model_level_str'))

Expand All @@ -76,3 +72,94 @@ def test_core_model_transferred_fns(self):
self.assertEqual(model_parallel.get_loss_eval(None, None, None), 'loss_eval_return')
self.assertEqual(model_parallel.get_predictions(None, None), 'predictions_return')
self.assertEqual(model_parallel.my_new_fn(), 'my_new_fn return value')

def test_dp_model_wrap_forward_attribute_access(self):
model = DPModel()
dp_model = TTDataParallel(model)

for i in range(1, 101):
self.assertEqual(dp_model(100), i)

def test_dp_model_wrap_get_loss_attribute_access(self):
model = DPModel()
dp_model = TTDataParallel(model)

for i in range(1, 101):
self.assertEqual(dp_model.get_loss(100, None, None),
(i, i, 'my_new_fn return value', 'test string'))

def test_dp_model_wrap_get_predictions_attribute_access(self):
model = DPModel()
dp_model = TTDataParallel(model)

for i in range(1, 101):
self.assertEqual(dp_model.get_predictions(100, None),
(i, i, 'my_new_fn return value', 'test string'))

def test_dp_model_wrap_all_methods_mix_attribute_access(self):
model = DPModel()
dp_model = TTDataParallel(model)

for i in range(1, 101):
self.assertEqual(dp_model(100), i)

for i in range(1, 101):
self.assertEqual(dp_model.get_loss(100, None, None),
(i + 100, i, 'my_new_fn return value', 'test string'))

def test_dp_model_wrap_unreachable_attribute_access(self):
model = DPModel()
dp_model = TTDataParallel(model)

self.assertEqual(dp_model.get_loss(100, None, None), (1, 1, 'my_new_fn return value', 'test string'))

with self.assertRaises(nn.modules.module.ModuleAttributeError):
dp_model.get_loss(100, None, 'unreachable')

with self.assertRaises(AttributeError):
dp_model.get_loss(100, None, 'unreachable')


class DPModel(TTModel):
def __init__(self):
super().__init__()
self.model_level_str = 'test string'

self.forward_ctr = 0
self.get_loss_ctr = 0
self.get_predictions_ctr = 0

self.unreachable_attr = "Can't get me"

self.transfer_model_attributes = ['model_level_str', 'get_loss_ctr', 'get_predictions_ctr']

def forward(self, batch):
self.forward_ctr += 1
return self.forward_ctr

def get_loss(self, batch_data, criterion, device):
forward_ctr = self(batch_data)
self.get_loss_ctr += 1

my_fn_return = self.my_new_fn()
model_level_str = self.get_model_level_str()

if device == 'unreachable':
my_fn_return = self.unreachable_attr

return forward_ctr, self.get_loss_ctr, my_fn_return, model_level_str

def get_predictions(self, batch_data, device):
forward_ctr = self(batch_data)
self.get_predictions_ctr += 1

my_fn_return = self.my_new_fn()
model_level_str = self.get_model_level_str()

return forward_ctr, self.get_predictions_ctr, my_fn_return, model_level_str

def my_new_fn(self):
return 'my_new_fn return value'

def get_model_level_str(self):
return self.model_level_str

0 comments on commit e0ddadd

Please sign in to comment.