Skip to content

Commit

Permalink
[tests] switch to torchrun (#22712)
Browse files Browse the repository at this point in the history
  • Loading branch information
stas00 committed Apr 12, 2023
1 parent d87ef00 commit 1306b7d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/extended/test_trainer_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def run_trainer(
n_gpus_to_use = get_gpu_count()
master_port = get_torch_dist_unique_port()
distributed_args = f"""
-m torch.distributed.launch
-m torch.distributed.run
--nproc_per_node={n_gpus_to_use}
--master_port={master_port}
{self.examples_dir_str}/pytorch/translation/run_translation.py
Expand Down
6 changes: 3 additions & 3 deletions tests/trainer/test_trainer_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class TestTrainerDistributedNeuronCore(TestCasePlus):
@require_torch_neuroncore
def test_trainer(self):
distributed_args = f"""
-m torch.distributed.launch
-m torch.distributed.run
--nproc_per_node=2
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_trainer_distributed.py
Expand All @@ -83,7 +83,7 @@ class TestTrainerDistributed(TestCasePlus):
@require_torch_multi_gpu
def test_trainer(self):
distributed_args = f"""
-m torch.distributed.launch
-m torch.distributed.run
--nproc_per_node={torch.cuda.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_trainer_distributed.py
Expand All @@ -98,7 +98,7 @@ def test_trainer(self):
if __name__ == "__main__":
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
#
# PYTHONPATH="src" python -m torch.distributed.launch --nproc_per_node 2 --output_dir output_dir ./tests/test_trainer_distributed.py
# PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 --output_dir output_dir ./tests/test_trainer_distributed.py

parser = HfArgumentParser((TrainingArguments,))
training_args = parser.parse_args_into_dataclasses()[0]
Expand Down

0 comments on commit 1306b7d

Please sign in to comment.