Skip to content

Commit

Permalink
example: pytorch_lightning_mnist.py (#3290)
Browse files Browse the repository at this point in the history
Set GPU device with horovod local rank.

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
  • Loading branch information
chongxiaoc committed Jan 21, 2022
1 parent a89efa9 commit 15a6aa3
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions examples/pytorch/pytorch_lightning_mnist.py
Expand Up @@ -110,8 +110,12 @@ def test():

if __name__ == '__main__':
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
hvd.init()
args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
torch.cuda.set_device(hvd.local_rank())
torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 2}
# When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
Expand Down Expand Up @@ -205,4 +209,3 @@ def on_train_end(self, trainer, model):
if args.cuda:
model = model.cuda()
test()

0 comments on commit 15a6aa3

Please sign in to comment.