[MPS] Add Python Module Bindings for the MPS backend (pytorch#94417)
- This PR is a prerequisite for the upcoming Memory Leak Detection PR.
- Enable global manual seeding via `torch.manual_seed()` + test case
- Add `torch.mps.synchronize()` to wait for the MPS stream to finish + test case
- Enable the following Python interfaces for MPS (see the usage sketch below):
  `torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]`
- Add test cases in test_mps.py
- Add `mps.rst` to document the `torch.mps` module.
- Fix the failure in `test_public_bindings.py`
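
A minimal usage sketch of the new interfaces (illustrative only, assuming a build with the MPS backend available):

```python
import torch

if torch.backends.mps.is_available():
    # the global seed now also seeds the default MPS generator
    torch.manual_seed(230)
    a = torch.randn(5, device="mps")

    # seeding the MPS generator directly yields the same random stream
    torch.mps.manual_seed(230)
    b = torch.randn(5, device="mps")
    assert torch.equal(a.cpu(), b.cpu())

    # save and later restore the default generator's state
    state = torch.mps.get_rng_state()
    _ = torch.randn(5, device="mps")
    torch.mps.set_rng_state(state)

    # block until all queued work on the MPS stream has finished
    torch.mps.synchronize()
```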

Description of new files added:
- `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`.
- `torch/mps/__init__.py`: implements the Python bindings for the `torch.mps` module (see the call-path sketch below).
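
The two files connect as follows: each public `torch.mps` function forwards to a private `torch._C._mps_*` binding registered by `torch/csrc/mps/Module.cpp`, which dispatches through the MPS hooks into ATen. A simplified sketch of the call path (assumes MPS is available; not a verbatim excerpt):

```python
import torch

# public wrapper defined in torch/mps/__init__.py ...
torch.mps.synchronize()
# ... forwards to the private binding from torch/csrc/mps/Module.cpp, which
# calls MPSHooks::deviceSynchronize() -> at::mps::device_synchronize()
torch._C._mps_synchronize()
# the RNG helpers (manual_seed, get_rng_state, set_rng_state, seed) are all
# backed by the default MPS generator exposed through this binding
gen = torch._C._mps_get_default_generator()
```
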
Pull Request resolved: pytorch#94417
Approved by: https://github.com/albanD
razarmehr authored and pull[bot] committed Feb 21, 2023
1 parent d5846cf commit dbf3581
Showing 15 changed files with 262 additions and 16 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/detail/MPSHooksInterface.h
@@ -28,13 +28,21 @@ struct TORCH_API MPSHooksInterface {
    return false;
  }

  virtual bool isOnMacOS13orNewer() const {
    AT_ERROR("MPS backend is not available.");
  }

  virtual const Generator& getDefaultMPSGenerator() const {
    AT_ERROR("Cannot get default MPS generator without MPS backend.");
  }

  virtual Allocator* getMPSDeviceAllocator() const {
    AT_ERROR("MPSDeviceAllocator requires MPS.");
  }

  virtual void deviceSynchronize() const {
    AT_ERROR("Cannot synchronize MPS device without MPS backend.");
  }
};

struct TORCH_API MPSHooksArgs {};
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSDevice.h
@@ -79,7 +79,7 @@ class TORCH_API MPSDevice {

TORCH_API bool is_available();
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);

TORCH_API void device_synchronize();
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);

} // namespace mps
5 changes: 5 additions & 0 deletions aten/src/ATen/mps/MPSDevice.mm
@@ -3,6 +3,7 @@
#include <c10/util/CallOnce.h>

#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/IndexKernels.h>

@@ -122,5 +123,9 @@ bool is_macos_13_or_newer(MacOSVersion version) {
  return MPSDevice::getInstance()->isMacOS13Plus(version);
}

void device_synchronize() {
  getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
}

} // namespace mps
} // namespace at
8 changes: 8 additions & 0 deletions aten/src/ATen/mps/MPSHooks.cpp
@@ -16,6 +16,10 @@ bool MPSHooks::hasMPS() const {
  return at::mps::is_available();
}

bool MPSHooks::isOnMacOS13orNewer() const {
  return at::mps::is_macos_13_or_newer();
}

Allocator* MPSHooks::getMPSDeviceAllocator() const {
  return at::mps::GetMPSAllocator();
}
@@ -24,6 +28,10 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
  return at::mps::detail::getDefaultMPSGenerator();
}

void MPSHooks::deviceSynchronize() const {
  at::mps::device_synchronize();
}

using at::MPSHooksRegistry;
using at::RegistererMPSHooksRegistry;

2 changes: 2 additions & 0 deletions aten/src/ATen/mps/MPSHooks.h
@@ -13,8 +13,10 @@ struct MPSHooks : public at::MPSHooksInterface {
  MPSHooks(at::MPSHooksArgs) {}
  void initMPS() const override;
  bool hasMPS() const override;
  bool isOnMacOS13orNewer() const override;
  Allocator* getMPSDeviceAllocator() const override;
  const Generator& getDefaultMPSGenerator() const override;
  void deviceSynchronize() const override;
};

}} // at::mps
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -822,6 +822,7 @@ libtorch_python_core_sources = [
    "torch/csrc/dynamo/guards.cpp",
    "torch/csrc/dynamo/init.cpp",
    "torch/csrc/functorch/init.cpp",
    "torch/csrc/mps/Module.cpp",
    "torch/csrc/jit/backends/backend_init.cpp",
    "torch/csrc/jit/python/init.cpp",
    "torch/csrc/jit/passes/onnx.cpp",
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -81,6 +81,7 @@ Features described in this documentation are classified by release status:
   torch.autograd <autograd>
   torch.library <library>
   cuda
   mps
   torch.backends <backends>
   torch.distributed <distributed>
   torch.distributed.algorithms.join <distributed.algorithms.join>
14 changes: 14 additions & 0 deletions docs/source/mps.rst
@@ -0,0 +1,14 @@
torch.mps
===================================
.. automodule:: torch.mps
.. currentmodule:: torch.mps

.. autosummary::
    :toctree: generated
    :nosignatures:

    synchronize
    get_rng_state
    set_rng_state
    manual_seed
    seed
39 changes: 39 additions & 0 deletions test/test_mps.py
@@ -5972,6 +5972,45 @@ def test_mps_generator(self):
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        self.assertEqual(mps_x, mps_y)

    def test_default_mps_generator(self):
        # manual seeding on the "default" MPS generator using
        # the global torch.manual_seed()
        torch.manual_seed(230)
        mps_x = torch.randn(5, device='mps')
        # manual seeding using torch.mps.manual_seed()
        # which should set the "default" MPS generator
        # like the global torch.manual_seed()
        torch.mps.manual_seed(230)
        mps_y = torch.randn(5, device='mps')
        # seed values were the same, so the random tensor contents should match
        self.assertEqual(mps_x, mps_y)

        # save the default generator's state to restore it later
        g_state = torch.mps.get_rng_state()

        # generate random numbers without seeding
        mps_x = torch.randn(5, device='mps')
        # in this case, the random results must differ from the last generated random results
        self.assertNotEqual(mps_x, mps_y)

        # restore the previously saved state, and the results should match again
        torch.mps.set_rng_state(g_state)
        mps_x = torch.randn(5, device='mps')
        self.assertEqual(mps_x, mps_y)

    def test_device_synchronize(self):
        # just running some ops each followed by a synchronize to wait for
        # MPS stream to finish running each of them
        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
            .to(device='mps', dtype=torch.float)

        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
        torch.mps.synchronize()
        x = net1(x)
        torch.mps.synchronize()
        x.backward(torch.randn_like(x))
        torch.mps.synchronize()

    # Test random_.to and random_.from
    def test_random(self):
        def helper(shape, low, high, dtype=torch.int32):
8 changes: 6 additions & 2 deletions torch/_C/__init__.pyi.in
@@ -903,8 +903,6 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
def _is_mps_available() -> _bool: ...
def _is_mps_on_macos_13_or_newer() -> _bool: ...
class _LinalgBackend:
    Default: _LinalgBackend
    Cusolver: _LinalgBackend
@@ -1200,6 +1198,12 @@ class _TensorBase(metaclass=_TensorMeta):
# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...

# Defined in torch/csrc/mps/Module.cpp
def _mps_synchronize() -> None: ...
def _mps_get_default_generator() -> Generator: ...
def _is_mps_available() -> _bool: ...
def _is_mps_on_macos_13_or_newer() -> _bool: ...

# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
15 changes: 2 additions & 13 deletions torch/csrc/Module.cpp
@@ -60,6 +60,7 @@
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/lazy/python/init.h>
#include <torch/csrc/monitor/python_init.h>
#include <torch/csrc/mps/Module.h>
#include <torch/csrc/multiprocessing/init.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/profiler/python/init.h>
@@ -87,10 +88,6 @@
#endif
#endif

#if defined(USE_MPS)
#include <ATen/mps/MPSDevice.h>
#endif

#if defined(USE_VALGRIND)
#include <callgrind.h>
#endif
@@ -1271,6 +1268,7 @@ PyObject* initModule() {
  THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
  THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
  THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
  THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
#ifdef USE_CUDA
  THPUtils_addPyMethodDefs(methods, THCPModule_methods());
#endif
@@ -1593,15 +1591,6 @@ Call this whenever a new thread is created in order to propagate values from

  ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
  ASSERT_TRUE(set_module_attr("has_mps", has_mps));
  py_module.def("_is_mps_available", []() { return at::hasMPS(); });
  py_module.def("_is_mps_on_macos_13_or_newer", []() {
#ifdef USE_MPS
    return at::mps::is_macos_13_or_newer();
#else
    return false;
#endif
  });

  ASSERT_TRUE(
      set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));

102 changes: 102 additions & 0 deletions torch/csrc/mps/Module.cpp
@@ -0,0 +1,102 @@
#include <ATen/ATen.h>
#include <c10/util/CallOnce.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/python_numbers.h>

// pthread.h is included for tracking bad forks
#ifndef WIN32
#include <pthread.h>
#endif

namespace torch {
namespace mps {

namespace {
// True for children forked after mps init
static bool in_bad_fork = false;

// Called in the forked child if mps has already been initialized
static void forked_mps_child() {
  in_bad_fork = true;
}

// Should be called before the first mps call.
static void track_bad_mps_fork() {
#ifndef WIN32
  static c10::once_flag flag;
  c10::call_once(
      flag, [] { pthread_atfork(nullptr, nullptr, forked_mps_child); });
#endif
}
} // namespace

static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return PyBool_FromLong(in_bad_fork);
  END_HANDLE_TH_ERRORS
}

static PyObject* MPSModule_getDefaultMPSGenerator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  track_bad_mps_fork();
  return THPGenerator_initDefaultGenerator(
      at::detail::getMPSHooks().getDefaultMPSGenerator());
  END_HANDLE_TH_ERRORS
}

static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  track_bad_mps_fork();
  if (at::detail::getMPSHooks().hasMPS()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* MPSModule_isMacOS13orNewer(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (at::detail::getMPSHooks().isOnMacOS13orNewer()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::detail::getMPSHooks().deviceSynchronize();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static struct PyMethodDef _MPSModule_methods[] = {
    {"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
    {"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
    {"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
    {"_is_mps_on_macos_13_or_newer",
     MPSModule_isMacOS13orNewer,
     METH_NOARGS,
     nullptr},
    {"_mps_get_default_generator",
     MPSModule_getDefaultMPSGenerator,
     METH_NOARGS,
     nullptr},
    {nullptr}};

PyMethodDef* python_functions() {
  return _MPSModule_methods;
}

} // namespace mps
} // namespace torch
11 changes: 11 additions & 0 deletions torch/csrc/mps/Module.h
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/python_headers.h>

namespace torch {
namespace mps {

PyMethodDef* python_functions();

} // namespace mps
} // namespace torch
54 changes: 54 additions & 0 deletions torch/mps/__init__.py
@@ -0,0 +1,54 @@
r"""
This package enables an interface for accessing MPS backend in python
"""
import torch
from .. import Tensor

_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment]

# local helper function (not public or exported)
def _get_default_mps_generator() -> torch._C.Generator:
    global _default_mps_generator
    if _default_mps_generator is None:
        _default_mps_generator = torch._C._mps_get_default_generator()
    return _default_mps_generator

def synchronize() -> None:
    r"""Waits for all kernels in all streams on a MPS device to complete."""
    return torch._C._mps_synchronize()

def get_rng_state() -> Tensor:
    r"""Returns the random number generator state as a ByteTensor."""
    return _get_default_mps_generator().get_state()

def set_rng_state(new_state: Tensor) -> None:
    r"""Sets the random number generator state.
    Args:
        new_state (torch.ByteTensor): The desired state
    """
    new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    _get_default_mps_generator().set_state(new_state_copy)

def manual_seed(seed: int) -> None:
    r"""Sets the seed for generating random numbers.
    Args:
        seed (int): The desired seed.
    """
    # the torch.mps.manual_seed() can be called from the global
    # torch.manual_seed() in torch/random.py. So we need to make
    # sure mps is available (otherwise we just return without
    # erroring out)
    if not torch.has_mps:
        return
    seed = int(seed)
    _get_default_mps_generator().manual_seed(seed)

def seed() -> None:
    r"""Sets the seed for generating random numbers to a random number."""
    _get_default_mps_generator().seed()

__all__ = [
    'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize']
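
For context on the early-return guard inside `manual_seed()` above: the global `torch.manual_seed()` in `torch/random.py` also calls into this module, so the function must be a no-op when MPS is unavailable. A rough sketch of that global seeding path (the exact guards in `torch/random.py` are an assumption here, not a verbatim excerpt):

```python
import torch

def global_manual_seed(seed: int) -> torch._C.Generator:
    # illustrative sketch only, not the actual torch.manual_seed() source
    seed = int(seed)
    torch.cuda.manual_seed_all(seed)  # seeds CUDA generators (deferred if CUDA is uninitialized)
    torch.mps.manual_seed(seed)       # returns early when MPS is not available
    return torch.default_generator.manual_seed(seed)  # seeds the CPU generator
```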
