Commit

Merge pull request #640 from mv1388/remove-ApexDistributedDataParallel
Remove TTApexDistributedDataParallel since APEX was deprecated
mv1388 committed Nov 1, 2020
2 parents 7bba82d + 33d9803 commit 8b6d702
Showing 3 changed files with 0 additions and 21 deletions.
21 changes: 0 additions & 21 deletions aitoolbox/torchtrain/parallel.py
@@ -1,11 +1,6 @@
 import functools
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
-try:
-    from apex.parallel import DistributedDataParallel as ApexDistributedDataParallel
-    APEX_AVAILABLE = True
-except ImportError:
-    APEX_AVAILABLE = False
 try:
     from deepspeed import DeepSpeedLight
     DEEPSPEED_AVAILABLE = True
@@ -87,22 +82,6 @@ def __init__(self, module,
         TTParallelBase.__init__(self, module, default_model_methods)


-if APEX_AVAILABLE:
-    class TTApexDistributedDataParallel(ApexDistributedDataParallel, TTParallelBase):
-        def __init__(self, module,
-                     default_model_methods=('get_loss', 'get_loss_eval', 'get_predictions'), **kwargs):
-            """torchtrain enabled Nvidia Apex DistributedDataParallel
-
-            Args:
-                module (aitoolbox.torchtrain.model.TTModel): neural network model
-                default_model_methods (list or tuple): list of core methods which are present also in TTModel
-                    abstract class
-                **kwargs: additional parameters for underlying apex.parallel.DistributedDataParallel
-            """
-            ApexDistributedDataParallel.__init__(self, module, **kwargs)
-            TTParallelBase.__init__(self, module, default_model_methods)
-
-
 if DEEPSPEED_AVAILABLE:
     class TTDeepSpeedLight(DeepSpeedLight, TTParallelBase):
         def __init__(self, args, model,
Binary file modified dist/aitoolbox-1.2.0-py3-none-any.whl
Binary file modified dist/aitoolbox-1.2.0.tar.gz
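Note: with the Apex wrapper removed, distributed training in this module relies on the torch.nn.parallel.DistributedDataParallel import that is kept at the top of parallel.py. Below is a minimal, self-contained sketch (illustrative only, not part of this commit) of the native PyTorch replacement for apex.parallel.DistributedDataParallel plus apex.amp, namely torch's own DistributedDataParallel combined with torch.cuda.amp. The single-process "gloo" setup, the toy nn.Linear model, and the random tensors are assumptions made purely for demonstration.

# Minimal migration sketch (illustrative, not from this repository): replacing
# apex.parallel.DistributedDataParallel and apex.amp with native PyTorch
# DistributedDataParallel and torch.cuda.amp. Runs as a single "gloo" process.
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = nn.Linear(16, 4)                      # stand-in for an actual TTModel
ddp_model = DistributedDataParallel(model)    # native DDP instead of the removed Apex wrapper

optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
criterion = nn.MSELoss()
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)   # torch.cuda.amp replaces apex.amp

inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
    loss = criterion(ddp_model(inputs), targets)
scaler.scale(loss).backward()   # GradScaler is a no-op pass-through when AMP is disabled
scaler.step(optimizer)
scaler.update()

dist.destroy_process_group()

Within aitoolbox itself, the retained wrapper built on torch's DistributedDataParallel (visible in the unchanged context of the diff above) plays the same role the Apex-based class used to.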
