forked from pytorch/pytorch
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MPS] Add Python Module Bindings for the MPS backend (pytorch#94417)
- This PR is a prerequisite for the upcoming Memory Leak Detection PR. - Enable global manual seeding via `torch.manual_seed()` + test case - Add `torch.mps.synchronize()` to wait for MPS stream to finish + test case - Enable the following python interfaces for MPS: `torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]` - Added some test cases in test_mps.py - Added `mps.rst` to document the `torch.mps` module. - Fixed the failure with `test_public_bindings.py` Description of new files added: - `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`. - `torch/mps/__init__.py`: implements Python bindings for `torch.mps` module. Pull Request resolved: pytorch#94417 Approved by: https://github.com/albanD
- Loading branch information
Showing
15 changed files
with
262 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
torch.mps | ||
=================================== | ||
.. automodule:: torch.mps | ||
.. currentmodule:: torch.mps | ||
|
||
.. autosummary:: | ||
:toctree: generated | ||
:nosignatures: | ||
|
||
synchronize | ||
get_rng_state | ||
set_rng_state | ||
manual_seed | ||
seed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
#include <ATen/ATen.h> | ||
#include <c10/util/CallOnce.h> | ||
#include <torch/csrc/Generator.h> | ||
#include <torch/csrc/python_headers.h> | ||
#include <torch/csrc/utils/python_numbers.h> | ||
|
||
// pthread.h is included for tracking bad forks | ||
#ifndef WIN32 | ||
#include <pthread.h> | ||
#endif | ||
|
||
namespace torch { | ||
namespace mps { | ||
|
||
namespace { | ||
// True for children forked after mps init | ||
static bool in_bad_fork = false; | ||
|
||
// Called in the forked child if mps has already been initialized | ||
static void forked_mps_child() { | ||
in_bad_fork = true; | ||
} | ||
|
||
// Should be called before the first mps call. | ||
static void track_bad_mps_fork() { | ||
#ifndef WIN32 | ||
static c10::once_flag flag; | ||
c10::call_once( | ||
flag, [] { pthread_atfork(nullptr, nullptr, forked_mps_child); }); | ||
#endif | ||
} | ||
} // namespace | ||
|
||
static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) { | ||
HANDLE_TH_ERRORS | ||
return PyBool_FromLong(in_bad_fork); | ||
END_HANDLE_TH_ERRORS | ||
} | ||
|
||
static PyObject* MPSModule_getDefaultMPSGenerator( | ||
PyObject* _unused, | ||
PyObject* noargs) { | ||
HANDLE_TH_ERRORS | ||
track_bad_mps_fork(); | ||
return THPGenerator_initDefaultGenerator( | ||
at::detail::getMPSHooks().getDefaultMPSGenerator()); | ||
END_HANDLE_TH_ERRORS | ||
} | ||
|
||
static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) { | ||
HANDLE_TH_ERRORS | ||
track_bad_mps_fork(); | ||
if (at::detail::getMPSHooks().hasMPS()) { | ||
Py_RETURN_TRUE; | ||
} else { | ||
Py_RETURN_FALSE; | ||
} | ||
END_HANDLE_TH_ERRORS | ||
} | ||
|
||
static PyObject* MPSModule_isMacOS13orNewer( | ||
PyObject* _unused, | ||
PyObject* noargs) { | ||
HANDLE_TH_ERRORS | ||
if (at::detail::getMPSHooks().isOnMacOS13orNewer()) { | ||
Py_RETURN_TRUE; | ||
} else { | ||
Py_RETURN_FALSE; | ||
} | ||
END_HANDLE_TH_ERRORS | ||
} | ||
|
||
static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) { | ||
HANDLE_TH_ERRORS | ||
at::detail::getMPSHooks().deviceSynchronize(); | ||
Py_RETURN_NONE; | ||
END_HANDLE_TH_ERRORS | ||
} | ||
|
||
// NOLINTNEXTLINE(modernize-avoid-c-arrays, | ||
// cppcoreguidelines-avoid-non-const-global-variables, | ||
// cppcoreguidelines-avoid-c-arrays) | ||
static struct PyMethodDef _MPSModule_methods[] = { | ||
{"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr}, | ||
{"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr}, | ||
{"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr}, | ||
{"_is_mps_on_macos_13_or_newer", | ||
MPSModule_isMacOS13orNewer, | ||
METH_NOARGS, | ||
nullptr}, | ||
{"_mps_get_default_generator", | ||
MPSModule_getDefaultMPSGenerator, | ||
METH_NOARGS, | ||
nullptr}, | ||
{nullptr}}; | ||
|
||
PyMethodDef* python_functions() { | ||
return _MPSModule_methods; | ||
} | ||
|
||
} // namespace mps | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#pragma once | ||
|
||
#include <torch/csrc/python_headers.h> | ||
|
||
namespace torch { | ||
namespace mps { | ||
|
||
PyMethodDef* python_functions(); | ||
|
||
} // namespace mps | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
r""" | ||
This package enables an interface for accessing MPS backend in python | ||
""" | ||
import torch | ||
from .. import Tensor | ||
|
||
_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False) | ||
_default_mps_generator: torch._C.Generator = None # type: ignore[assignment] | ||
|
||
# local helper function (not public or exported) | ||
def _get_default_mps_generator() -> torch._C.Generator: | ||
global _default_mps_generator | ||
if _default_mps_generator is None: | ||
_default_mps_generator = torch._C._mps_get_default_generator() | ||
return _default_mps_generator | ||
|
||
def synchronize() -> None: | ||
r"""Waits for all kernels in all streams on a MPS device to complete.""" | ||
return torch._C._mps_synchronize() | ||
|
||
def get_rng_state() -> Tensor: | ||
r"""Returns the random number generator state as a ByteTensor.""" | ||
return _get_default_mps_generator().get_state() | ||
|
||
def set_rng_state(new_state: Tensor) -> None: | ||
r"""Sets the random number generator state. | ||
Args: | ||
new_state (torch.ByteTensor): The desired state | ||
""" | ||
new_state_copy = new_state.clone(memory_format=torch.contiguous_format) | ||
_get_default_mps_generator().set_state(new_state_copy) | ||
|
||
def manual_seed(seed: int) -> None: | ||
r"""Sets the seed for generating random numbers. | ||
Args: | ||
seed (int): The desired seed. | ||
""" | ||
# the torch.mps.manual_seed() can be called from the global | ||
# torch.manual_seed() in torch/random.py. So we need to make | ||
# sure mps is available (otherwise we just return without | ||
# erroring out) | ||
if not torch.has_mps: | ||
return | ||
seed = int(seed) | ||
_get_default_mps_generator().manual_seed(seed) | ||
|
||
def seed() -> None: | ||
r"""Sets the seed for generating random numbers to a random number.""" | ||
_get_default_mps_generator().seed() | ||
|
||
__all__ = [ | ||
'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize'] |
Oops, something went wrong.