Skip to content

Commit

Permalink
Merge pull request #704 from mv1388/ttmodel_ddp_inheritance
Browse files Browse the repository at this point in the history
Switch around TTDistributedDataParallel inheritance
  • Loading branch information
mv1388 committed Jul 16, 2022
2 parents cb27e1c + 315c67c commit 018ed4d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion aitoolbox/torchtrain/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, module,
TTParallelBase.__init__(self, module, default_model_methods)


class TTDistributedDataParallel(DistributedDataParallel, TTParallelBase):
class TTDistributedDataParallel(TTParallelBase, DistributedDataParallel):
def __init__(self, module,
default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs):
"""torchtrain enabled DistributedDataParallel
Expand Down

0 comments on commit 018ed4d

Please sign in to comment.