Skip to content

Commit

Permalink
Merge pull request #725 from mv1388/ddp_tests_naming
Browse files Browse the repository at this point in the history
Fix DDP testing method names
  • Loading branch information
mv1388 committed Aug 7, 2022
2 parents 98c56c0 + 28ecc21 commit f7af487
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tests/test_torchtrain/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,41 +199,41 @@ def _ddp_model_wrap_forward_attribute_access(gpu):
for i in range(1, 101):
assert ddp_model(100) == i

def test_dp_model_wrap_get_loss_attribute_access(self):
def test_ddp_model_wrap_get_loss_attribute_access(self):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8888'
mp.spawn(self._dp_model_wrap_get_loss_attribute_access, nprocs=2)
mp.spawn(self._ddp_model_wrap_get_loss_attribute_access, nprocs=2)

@staticmethod
def _dp_model_wrap_get_loss_attribute_access(gpu):
def _ddp_model_wrap_get_loss_attribute_access(gpu):
dist.init_process_group(backend='gloo', init_method='env://', world_size=2, rank=gpu)
model = DPModel()
ddp_model = TTDistributedDataParallel(model)

for i in range(1, 101):
assert ddp_model.get_loss(100, None, None) == (i, i, 'my_new_fn return value', 'test string')

def test_dp_model_wrap_get_predictions_attribute_access(self):
def test_ddp_model_wrap_get_predictions_attribute_access(self):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8888'
mp.spawn(self._dp_model_wrap_get_predictions_attribute_access, nprocs=2)
mp.spawn(self._ddp_model_wrap_get_predictions_attribute_access, nprocs=2)

@staticmethod
def _dp_model_wrap_get_predictions_attribute_access(gpu):
def _ddp_model_wrap_get_predictions_attribute_access(gpu):
dist.init_process_group(backend='gloo', init_method='env://', world_size=2, rank=gpu)
model = DPModel()
ddp_model = TTDistributedDataParallel(model)

for i in range(1, 101):
assert ddp_model.get_predictions(100, None) == (i, i, 'my_new_fn return value', 'test string')

def test_dp_model_wrap_all_methods_mix_attribute_access(self):
def test_ddp_model_wrap_all_methods_mix_attribute_access(self):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8888'
mp.spawn(self._dp_model_wrap_all_methods_mix_attribute_access, nprocs=2)
mp.spawn(self._ddp_model_wrap_all_methods_mix_attribute_access, nprocs=2)

@staticmethod
def _dp_model_wrap_all_methods_mix_attribute_access(gpu):
def _ddp_model_wrap_all_methods_mix_attribute_access(gpu):
dist.init_process_group(backend='gloo', init_method='env://', world_size=2, rank=gpu)
model = DPModel()
ddp_model = TTDistributedDataParallel(model)
Expand All @@ -244,13 +244,13 @@ def _dp_model_wrap_all_methods_mix_attribute_access(gpu):
for i in range(1, 101):
assert ddp_model.get_loss(100, None, None) == (i + 100, i, 'my_new_fn return value', 'test string')

def test_dp_model_wrap_unreachable_attribute_access(self):
def test_ddp_model_wrap_unreachable_attribute_access(self):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8888'
mp.spawn(self._dp_model_wrap_unreachable_attribute_access, nprocs=2)
mp.spawn(self._ddp_model_wrap_unreachable_attribute_access, nprocs=2)

@staticmethod
def _dp_model_wrap_unreachable_attribute_access(gpu):
def _ddp_model_wrap_unreachable_attribute_access(gpu):
dist.init_process_group(backend='gloo', init_method='env://', world_size=2, rank=gpu)
model = DPModel()
ddp_model = TTDistributedDataParallel(model)
Expand Down

0 comments on commit f7af487

Please sign in to comment.