diff --git a/examples/pytorch/pytorch_lightning_mnist.py b/examples/pytorch/pytorch_lightning_mnist.py index e6ffcccf1f..4c72de6df6 100644 --- a/examples/pytorch/pytorch_lightning_mnist.py +++ b/examples/pytorch/pytorch_lightning_mnist.py @@ -110,8 +110,12 @@ def test(): if __name__ == '__main__': args = parser.parse_args() - args.cuda = not args.no_cuda and torch.cuda.is_available() + torch.manual_seed(args.seed) hvd.init() + args.cuda = not args.no_cuda and torch.cuda.is_available() + if args.cuda: + torch.cuda.set_device(hvd.local_rank()) + torch.cuda.manual_seed(args.seed) kwargs = {'num_workers': 2} # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent @@ -205,4 +209,3 @@ def on_train_end(self, trainer, model): if args.cuda: model = model.cuda() test() -