
Commit

[c10d] Implement __instancecheck__ for c10d::ReduceOp (pytorch#88275)

Summary:
- Customize the metaclass of `torch.distributed.distributed_c10d.ReduceOp` so that it can supply a custom `__instancecheck__`
- Add `copy.copy`, `copy.deepcopy`, and `pickle` support with tests
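
A minimal sketch (not part of this commit) of the Python-level behavior these changes are meant to restore, assuming a PyTorch build with distributed support; the premul-sum reduction itself additionally requires NCCL >= 2.11.1:

```python
import copy
import pickle

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

# `isinstance` works again for the enum-like members (pytorch#87191).
assert isinstance(ReduceOp.SUM, ReduceOp)

# `copy.copy`, `copy.deepcopy`, and `pickle` round-trip a ReduceOp instance.
op = ReduceOp(ReduceOp.MAX)
assert copy.copy(op) == op
assert copy.deepcopy(op) == op
assert pickle.loads(pickle.dumps(op)) == op

# PREMUL_SUM instances are created via the private helper, which now takes a
# single scalar or tensor factor rather than a list of tensors.
premul = dist._make_nccl_premul_sum(torch.tensor(2.0))
assert isinstance(premul, ReduceOp)
```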

Rel:
- pytorch#81272
- pytorch#84243
- pytorch#87191
- pytorch#87303
- pytorch#87555

Ref:
- pybind/pybind11#2696

Pull Request resolved: pytorch#88275
Approved by: https://github.com/wanchaol
crcrpar authored and kulinseth committed Dec 9, 2022
1 parent cf0fb0d commit 80fa262
Showing 7 changed files with 156 additions and 40 deletions.
32 changes: 31 additions & 1 deletion test/distributed/test_c10d_common.py
@@ -2,6 +2,7 @@

import copy
import os
import pickle
import sys
import tempfile
import threading
@@ -1657,15 +1658,44 @@ def comm_fn(tensor, group=None):

class ReduceOpTest(TestCase):

# Ref: https://github.com/pytorch/pytorch/issues/87191
def test_op_isinstance_of_reduceop(self):
for reduce_op in (
c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX,
c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR,
):
self.assertTrue(isinstance(reduce_op, c10d.ReduceOp))
for scale in ([torch.tensor(1.0)], 2.0):
for scale in (torch.tensor(1.0), 2.0):
self.assertTrue(isinstance(dist._make_nccl_premul_sum(scale), c10d.ReduceOp))

# Ref: https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700
def test_reduceop_copyable(self):
for reduce_op in (
c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX,
c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR,
):
self.assertEqual(copy.copy(reduce_op), reduce_op)
self.assertEqual(copy.deepcopy(reduce_op), reduce_op)
self.assertEqual(copy.copy(c10d.ReduceOp(reduce_op)), reduce_op)
self.assertEqual(copy.deepcopy(c10d.ReduceOp(reduce_op)), reduce_op)

for scale in (torch.tensor(1.0), 2.0):
reduce_op = dist._make_nccl_premul_sum(scale)
self.assertEqual(copy.copy(reduce_op), reduce_op)
self.assertEqual(copy.deepcopy(reduce_op), reduce_op)

def test_reduceop_pickle(self):
for reduce_op in (
c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX,
c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR,
):
pickle.loads(pickle.dumps(reduce_op))
orig = c10d.ReduceOp(reduce_op)
self.assertEqual(pickle.loads(pickle.dumps(orig)), orig)
for scale in (torch.tensor(1.0), 2.0):
reduce_op = dist._make_nccl_premul_sum(scale)
self.assertEqual(pickle.loads(pickle.dumps(reduce_op)), reduce_op)


if __name__ == "__main__":
assert (
18 changes: 8 additions & 10 deletions test/distributed/test_c10d_nccl.py
@@ -348,16 +348,14 @@ def allreduce(tensors, op):
# Premul Sum
if torch.cuda.nccl.version() >= (2, 11, 1):
for dtype in torch.half, torch.float, torch.double:
for factor in (3.0,
(torch.tensor([5.0], device=local_device_id, dtype=dtype),)):
for factor in (3.0, torch.tensor([5.0], device=local_device_id, dtype=dtype)):
tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id).to(dtype=dtype)]

allreduce(tensors, c10d._make_nccl_premul_sum(factor))

f = factor if isinstance(factor, float) else factor[0]
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
self.assertEqualIgnoreType(
f * torch.tensor([float(self.world_size * (self.world_size + 1) / 2)], device=local_device_id),
factor * torch.tensor([float(self.world_size * (self.world_size + 1) / 2)], device=local_device_id),
tensors[0],
)

@@ -435,9 +433,9 @@ def reduce(xs, rootRank, rootTensor, op=None):

# Premul sum
if torch.cuda.nccl.version() >= (2, 11, 1):
for factor in (3.0, (torch.tensor([5.0], device=local_device_id),)):
if isinstance(factor, tuple):
factor_ref = factor[0].cpu().item()
for factor in (3.0, torch.tensor([5.0], device=local_device_id)):
if isinstance(factor, torch.Tensor):
factor_ref = factor.cpu().item()
else:
factor_ref = factor
float_tensors = [
@@ -933,9 +931,9 @@ def perm(n, k):
self.assertEqualIgnoreType(expected, output_tensor)

if torch.cuda.nccl.version() >= (2, 11, 1):
for factor in (3.0, (torch.tensor([5.0], device=self.rank),),):
if isinstance(factor, tuple):
factor_ref = factor[0].cpu().item()
for factor in (3.0, torch.tensor([5.0], device=self.rank)):
if isinstance(factor, torch.Tensor):
factor_ref = factor.cpu().item()
else:
factor_ref = factor
output = [t.float() for t in output]
2 changes: 2 additions & 0 deletions torch/_C/_distributed_c10d.pyi
@@ -91,6 +91,8 @@ class DebugLevel(Enum):

class ReduceOp:

def __init__(self, op: "RedOpType"): ...

SUM = ...
PRODUCT = ...
MIN = ...
7 changes: 3 additions & 4 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -83,11 +83,10 @@ ncclRedOpRAII unpackPreMulSum(
const auto* preMulSupplement =
reinterpret_cast<NCCLPreMulSumSupplement*>(reduceOp.supplement_.get());
ncclRedOp_t preMulSum;
bool has_tensor = !preMulSupplement->tensor_factors.empty();
bool has_tensor = preMulSupplement->tensor_factor.defined();
auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate;
T* ptr_factor = has_tensor
? preMulSupplement->tensor_factors[dev_in_group].data_ptr<T>()
: nullptr;
T* ptr_factor =
has_tensor ? preMulSupplement->tensor_factor.data_ptr<T>() : nullptr;
T scalar_factor = T(preMulSupplement->double_factor);
ncclRedOpCreatePreMulSum(
&preMulSum,
9 changes: 6 additions & 3 deletions torch/csrc/distributed/c10d/Types.hpp
@@ -8,6 +8,7 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/Tensor.h>

#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>

namespace c10d {
@@ -21,9 +22,11 @@ struct TORCH_API _SupplementBase : torch::CustomClassHolder {
// The point of use in ProcessGroupNCCL knows how to unpack it.
struct NCCLPreMulSumSupplement : _SupplementBase {
double double_factor{0.0};
std::vector<at::Tensor> tensor_factors;
at::Tensor tensor_factor;
NCCLPreMulSumSupplement(double f) : double_factor{f} {}
NCCLPreMulSumSupplement(std::vector<at::Tensor> f) : tensor_factors{std::move(f)} {}
NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {
TORCH_CHECK_EQ(tensor_factor.numel(), 1);
}
};

// Other ReduceOps that need different supplementary data can also
@@ -60,7 +63,7 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder {
}
}

// The heap resource supplement_, if it exists, is managed by a shared_ptr,
// The heap resource supplement_, if it exists, is managed by a c10::intrusive_ptr,
// so constructors and operator= can be simple
ReduceOp(const ReduceOp& other) :
op_(other.op_), supplement_(other.supplement_) {}
117 changes: 106 additions & 11 deletions torch/csrc/distributed/c10d/init.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/python_headers.h>

#include <c10/util/intrusive_ptr.h>
#include <c10/util/string_view.h>
#include <torch/csrc/distributed/c10d/FileStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
@@ -235,6 +236,61 @@ void _register_builtin_comm_hook(
reducer.register_builtin_comm_hook(comm_hook_type);
}

// Customize the metaclass of ::c10d::ReduceOp for the backward compatibility.
// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to
// struct from enum, sacrificing some of the Python built-in function supports
// such as `isinstance` (see https://github.com/pytorch/pytorch/issues/87191)
// and `copy` (see
// https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700). Below,
// we define a custom `isinstance` in CPython/pybind11
// (`reduceopmeta___instancecheck__`) and modify the default metaclass of
// pybind11 (`GetReduceOpMetaclass`) so that
// `isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)`
// returns :obj:`True` as if `ReduceOp` is enum.
// Ref:
// - https://docs.python.org/3/extending/newtypes_tutorial.html
// - https://docs.python.org/3/c-api/typeobj.html?highlight=tp_methods
// - https://github.com/pybind/pybind11/issues/2696
static PyObject* reduceopmeta___instancecheck__(
PyObject* self,
PyObject* args) {
if (Py_TYPE(self) == Py_TYPE(args)) {
Py_RETURN_TRUE;
}
if (c10::string_view(args->ob_type->tp_name).find("RedOpType") !=
c10::string_view::npos) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
static PyMethodDef reduceopmeta_methods[] = {
{"__instancecheck__",
(PyCFunction)reduceopmeta___instancecheck__,
METH_O,
"Custom `__instancecheck__` for ReduceOp"},
{NULL, NULL}};
PyTypeObject* GetReduceOpMetaclass() {
static auto* metaclass = [] {
PyTypeObject* base_metaclass =
pybind11::detail::get_internals().default_metaclass;
PyType_Slot slots[] = {
{Py_tp_base, base_metaclass},
{Py_tp_methods, reduceopmeta_methods},
{0},
};
PyType_Spec spec = {};
spec.name = "torch._C._distributed_c10d._ReduceOpMeta";
spec.basicsize = base_metaclass->tp_basicsize;
spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
spec.slots = slots;
PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec);
if (!metaclass)
throw py::error_already_set();
return metaclass;
}();
return metaclass;
}

PyObject* c10d_init(PyObject* _unused, PyObject* noargs) {
C10_LOG_API_USAGE_ONCE("c10d.python.import");

@@ -520,7 +576,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
// making `PREMUL_SUM` callable, i.e., allowing for
// `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol.
// https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types
py::class_<::c10d::ReduceOp> reduce_op(module, "ReduceOp", R"(
py::class_<::c10d::ReduceOp> reduce_op(
module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"(
An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``.
@@ -562,14 +619,51 @@ This class does not support ``__members__`` property.)");
[](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) {
return self == other.op_;
})
.def("__hash__", [](const ::c10d::ReduceOp& self) {
return static_cast<uint8_t>(self.op_);
});

// note(crcrpar): Deliberately skip
// [`export_values`](https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types)
// here and manually set values in Python side. See note "ReduceOp static
// class attributes to support `isinstance`"
.def(
"__hash__",
[](const ::c10d::ReduceOp& self) {
return static_cast<uint8_t>(self.op_);
})
.def(
"__copy__",
[](const ::c10d::ReduceOp& self) { return ::c10d::ReduceOp(self); })
.def(
"__deepcopy__",
[](const ::c10d::ReduceOp& self, const py::dict& memo) {
return ::c10d::ReduceOp(self);
})
.def(py::pickle(
[](const ::c10d::ReduceOp& r) {
// __getstate__
if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return py::make_tuple(r.op_, py::none());
}
TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
const auto* preMulSupplement =
reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
r.supplement_.get());
if (!preMulSupplement->tensor_factor.defined()) {
return py::make_tuple(r.op_, preMulSupplement->double_factor);
} else {
return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
}
},
[](const py::tuple t) {
// __setstate__
TORCH_CHECK(t.size() == 2, "Invalid state");
const auto op =
static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>());
if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
return ::c10d::ReduceOp(op);
}
const auto preMulSupplement_factor = t[1];
if (py::isinstance<py::float_>(preMulSupplement_factor)) {
return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
} else {
return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
}
}));

py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType")
.value("SUM", ::c10d::ReduceOp::RedOpType::SUM)
.value("AVG", ::c10d::ReduceOp::RedOpType::AVG)
@@ -579,7 +673,8 @@ This class does not support ``__members__`` property.)");
.value("BAND", ::c10d::ReduceOp::RedOpType::BAND)
.value("BOR", ::c10d::ReduceOp::RedOpType::BOR)
.value("BXOR", ::c10d::ReduceOp::RedOpType::BXOR)
.value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM);
.value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM)
.export_values();

// note(crcrpar): This could be removed because users will not pass
// `RedOpType` to reduce collective ops Ref: [Implicit
@@ -597,7 +692,7 @@ This class does not support ``__members__`` property.)");
py::call_guard<py::gil_scoped_release>())
.def(
"_make_nccl_premul_sum",
&::c10d::makeNCCLPreMulSum<std::vector<at::Tensor>>,
&::c10d::makeNCCLPreMulSum<at::Tensor>,
py::arg("factor").noconvert(),
py::return_value_policy::copy, // seems safest
py::call_guard<py::gil_scoped_release>());
11 changes: 0 additions & 11 deletions torch/distributed/distributed_c10d.py
@@ -237,17 +237,6 @@ def register_backend(cls, name, func, extended_api=False):
dist_backend = Backend


# NOTE(crcrpar): [ReduceOp static class attributes to support `isinstance`]
# A ReduceOp instance of `PREMUL_SUM` is supposed to be created via `_make_nccl_premul_sum`
# while the other `op`s (meaning RedOpType members) can be directly passed to c10d reduce collectives.
# I changed `ReduceOp` to struct from enum class and introduced RedOpType enum class for PREMUL_SUM,
# which broke an implicit contract of ReduceOp being enum-like with which users apply isinstance to
# `op`, for example, `isinstance(ReduceOp.SUM, ReduceOp)`: https://github.com/pytorch/pytorch/issues/87191
DENY_LIST = ("PREMUL_SUM", )
for _red_op_name, _red_op_value in ReduceOp.RedOpType.__members__.items():
setattr(ReduceOp, _red_op_name, _red_op_value if _red_op_name in DENY_LIST else ReduceOp(_red_op_value))


class _reduce_op(object):
r"""
Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``,
