diff --git a/test/test_mps.py b/test/test_mps.py index 2085d0cebe721..b7907e7ed1990 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -19,7 +19,7 @@ import yaml import platform from collections import defaultdict -from torch._six import inf +from torch import inf from torch.nn import Parameter from torch.testing._internal import opinfo from torch.testing._internal.common_utils import \