Skip to content

Commit

Permalink
Fix XLA symbolic shapes binding (pytorch#88928)
Browse files Browse the repository at this point in the history
Obsoletes pytorch#88772

Mostly revolves around NOT assuming that the inner object is a SymNode,
but instead accepting anything duck-typed to the SymNode interface.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: pytorch#88928
Approved by: https://github.com/SherlockNoMad
  • Loading branch information
ezyang authored and pytorchmergebot committed Nov 13, 2022
1 parent 2aca97c commit 46796fe
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 60 deletions.
3 changes: 0 additions & 3 deletions c10/core/SymNodeImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
// Base-class stub: unconditionally errors with "NYI"; concrete SymNode
// backends must override to provide a real clone.
virtual SymNode clone() {
TORCH_CHECK(false, "NYI");
};
// Base-class stub: unconditionally errors with "NYI". NOTE(review): this
// virtual is deleted by this commit (its only override is also removed).
virtual SymNode sym_int() {
TORCH_CHECK(false, "NYI");
}
// Base-class stub: unconditionally errors with "NYI"; backends override.
virtual SymNode sym_float() {
TORCH_CHECK(false, "NYI");
}
Expand Down
6 changes: 3 additions & 3 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch.utils._pytree import tree_map
from torch.fx.experimental import symbolic_shapes
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int
from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node
from torch.utils._python_dispatch import TorchDispatchMode
from torch import SymInt

Expand Down Expand Up @@ -478,9 +478,9 @@ def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):

def get_sym_inp(inp):
    # Wrap a plain Python number into the symbolic wrapper type via the
    # module-level duck-typed helper. NOTE(review): the scraped diff fused
    # the deleted seed_node.to_node(...) lines with the added
    # to_node(seed_node, ...) lines; this is the coherent post-commit form.
    if isinstance(inp, int):
        return torch.SymInt(to_node(seed_node, inp))
    else:
        return torch.SymFloat(to_node(seed_node, inp))

def maybe_xfail(inp1, inp2):
key = (fn, type(inp1).__name__, type(inp2).__name__)
Expand Down
2 changes: 0 additions & 2 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ class SymInt:
"""

def __init__(self, node):
    # Deliberately NO `isinstance(node, SymNode)` assert here: per this
    # commit, `node` may be any object duck-typed to the SymNode
    # interface (e.g. an XLA-provided node), not only a real SymNode.
    # This field MUST be named node; C++ binding code assumes that this
    # class has a field named node that stores SymNode
    self.node = node
Expand Down
77 changes: 52 additions & 25 deletions torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1148,38 +1148,65 @@ void initJITBindings(PyObject* module) {
// NB: This isn't actually used for regular PyTorch symbolic tracing;
// XLA is what needs this
#define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); })
#define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); })
#define SYMNODE_BINARY(n) \
.def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); })
auto symnode_class =
py::class_<c10::SymNodeImpl, c10::SymNode>(m, "_SymNode")
// clang-format off
// These DO NOT install magic methods; the SymInt/SymFloat wrapper in
// Python is responsible for this
SYMNODE_UNARY(clone)
// Named these for consistency with inner python class, but maybe
// should change the python side
SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_)
SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2(
__sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub)
SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow)
SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(
eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt)
SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min)
SYMNODE_BINARY(max) SYMNODE_UNARY(ceil)
SYMNODE_UNARY(floor) SYMNODE_UNARY(neg)
// Intentionally don't set file line, as the
// Python backtrace matters more here
.def(
"guard_int",
[](c10::SymNode a) {
return a->guard_int(nullptr, 0);
})
.def(
"__str__",
[](c10::SymNode a) { return a->str(); })
.def("__repr__", [](c10::SymNode a) {
return a->str();
});
SYMNODE_UNARY(is_int)
SYMNODE_UNARY(is_float)
SYMNODE_UNARY(bool_)
SYMNODE_UNARY(int_)
SYMNODE_UNARY(sym_float)
SYMNODE_BINARY(add)
SYMNODE_BINARY(sub)
SYMNODE_BINARY(mul)
SYMNODE_BINARY(truediv)
SYMNODE_BINARY(pow)
SYMNODE_BINARY(floordiv)
SYMNODE_BINARY(mod)
SYMNODE_BINARY(eq)
SYMNODE_BINARY(gt)
SYMNODE_BINARY(lt)
SYMNODE_BINARY(le)
SYMNODE_BINARY(ge)
SYMNODE_BINARY(min)
SYMNODE_BINARY(max)
SYMNODE_UNARY(ceil)
SYMNODE_UNARY(floor)
SYMNODE_UNARY(neg)
// Intentionally don't set file line, as the
// Python backtrace matters more here
.def(
"guard_int",
[](c10::SymNode a) {
return a->guard_int(nullptr, 0);
})
.def(
"guard_float",
[](c10::SymNode a) {
return a->guard_float(nullptr, 0);
})
.def(
"wrap_int",
[](c10::SymNode a, int64_t b) {
return a->wrap_int(b);
})
.def(
"wrap_float",
[](c10::SymNode a, double b) {
return a->wrap_float(b);
})
.def(
"__str__",
[](c10::SymNode a) { return a->str(); })
.def("__repr__", [](c10::SymNode a) {
return a->str();
});
// clang-format on

// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
Expand Down
14 changes: 11 additions & 3 deletions torch/csrc/utils/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,19 @@ py::handle type_caster<c10::SymInt>::cast(
return_value_policy /* policy */,
handle /* parent */) {
if (si.is_symbolic()) {
// TODO: generalize this to work with C++ backed class
auto* py_node =
dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
TORCH_INTERNAL_ASSERT(py_node);
return torch::get_symint_class()(py_node->getPyObj()).release();
if (py_node) {
// Return the Python directly (unwrap)
return torch::get_symint_class()(py_node->getPyObj()).release();
} else {
// Wrap the C++ into Python
auto inner = py::cast(si.toSymNodeImpl());
if (!inner) {
throw python_error();
}
return torch::get_symint_class()(inner).release();
}
} else {
return py::cast(si.as_int_unchecked()).release();
}
Expand Down
4 changes: 0 additions & 4 deletions torch/csrc/utils/python_symnode.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,6 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
return dispatch_common_(__FUNCTION__);
}

// Forwards the call, by its own name, to the wrapped Python object via
// dispatch_common_. NOTE(review): this override is deleted by this commit
// along with the sym_int virtual it overrides.
c10::SymNode sym_int() override {
return dispatch_common_(__FUNCTION__);
}

// Forwards the call, by its own name, to the wrapped Python object via
// dispatch_common_.
c10::SymNode sym_float() override {
return dispatch_common_(__FUNCTION__);
}
Expand Down
39 changes: 19 additions & 20 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ def sym_int(a):
return sym_floor(a) if a > 0 else sym_ceil(a)
return int(a)

def to_node(self, num):
    """Coerce ``num`` into a node usable alongside ``self``.

    ``SymInt``/``SymFloat`` wrappers are unwrapped to their inner node;
    plain ints and floats are wrapped through the node's own ``wrap_int``
    / ``wrap_float`` hooks, so ``self`` only needs to be duck-typed to
    the SymNode interface. Anything else yields ``NotImplemented``.
    """
    if isinstance(num, (SymInt, SymFloat)):
        return num.node
    if isinstance(num, int):
        return self.wrap_int(num)
    if isinstance(num, float):
        return self.wrap_float(num)
    # NotImplemented (rather than raising) is important so that Python
    # goes on to try the other operand's reflected magic method.
    return NotImplemented

# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
Expand All @@ -148,18 +160,6 @@ def expr(self):
def _update_expr(self):
self._expr = self.shape_env.replace(self._expr)

# NOTE(review): this is the pre-commit SymNode.to_node METHOD, shown here as
# the deleted side of the diff; the commit replaces it with the module-level
# duck-typed to_node(self, num) helper of identical body.
def to_node(self, num):
if isinstance(num, (SymInt, SymFloat)):
return num.node
elif isinstance(num, int):
return self.wrap_int(num)
elif isinstance(num, float):
return self.wrap_float(num)
else:
# NotImplemented is important so that Python tries the
# other magic method
return NotImplemented

def is_int(self):
return self.pytype is int

Expand Down Expand Up @@ -297,16 +297,15 @@ def _nyi():
always_bool_magic_methods = {"eq", "gt", "lt", "le", "ge"}

def wrap_node(x):
    """Wrap a duck-typed node into the public SymInt/SymFloat wrapper.

    Real SymNodes carrying a constant are unwrapped to the plain Python
    number; otherwise the node is classified via its runtime is_int()/
    is_float() queries (not ``pytype``), so non-SymNode duck-typed nodes
    work too. NOTE(review): the scraped diff interleaved the deleted and
    added lines of this function; this is the coherent post-commit form.
    """
    # TODO: let C++ also take advantage of this
    if isinstance(x, SymNode) and x.constant is not None:
        return x.constant
    if x.is_int():
        return SymInt(x)
    elif x.is_float():
        return SymFloat(x)
    else:
        raise AssertionError(f"unrecognized return type {x}")

def _make_node_magic(method, func):
func = lru_cache(256)(func)
Expand Down Expand Up @@ -378,13 +377,13 @@ def unary_magic_impl(self):
return wrap_node(getattr(self.node, method)())

def binary_magic_impl(self, other):
    """Dispatch the binary magic method to the underlying node.

    Returns NotImplemented when ``other`` cannot be converted, so Python
    falls back to the reflected magic method on the other operand.
    """
    # Use the module-level duck-typed helper; the fused pre-commit line
    # (other_node = self.node.to_node(other)) was dead code here -- its
    # result was immediately overwritten, and SymNode.to_node is removed
    # elsewhere in this same commit.
    other_node = to_node(self.node, other)
    if other_node is NotImplemented:
        return NotImplemented
    return wrap_node(getattr(self.node, method)(other_node))

def rbinary_magic_impl(self, other):
    """Dispatch the reflected binary magic method (other OP self).

    Returns NotImplemented when ``other`` cannot be converted.
    """
    # Same fix as binary_magic_impl: drop the dead pre-commit
    # self.node.to_node(other) call in favor of the duck-typed helper.
    other_node = to_node(self.node, other)
    if other_node is NotImplemented:
        return NotImplemented
    return wrap_node(getattr(other_node, method)(self.node))
Expand Down

0 comments on commit 46796fe

Please sign in to comment.