Commit 4dcbc73

Follow-up for pytorch#37091.

hameerabbasi committed Aug 10, 2020
1 parent 05f0053 commit 4dcbc73

Showing 3 changed files with 29 additions and 33 deletions.
17 changes: 14 additions & 3 deletions test/test_overrides.py
@@ -285,7 +285,14 @@ def sub_diagonal_foo(a, b, c=None):
 # The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
 HANDLED_FUNCTIONS_TENSOR_LIKE = {}
-HANDLED_FUNCTIONS_WRAPPERS = {}
+
+
+# Note: _triggered wrapper
+# Dict that wraps the implementations from get_testing_overrides into another
+# function with a _triggered slot/flag. The triggered flag is set when the
+# implementation is called.
+WRAPPED_TRIGGERED_IMPLS = {}
 
 
 def triggered_wrapper(f):
     @functools.wraps(f)
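The body of `triggered_wrapper` is collapsed in this view. Based on the note above, a minimal sketch of what such a wrapper plausibly looks like (a reconstruction, not the verbatim source):

```python
import functools

def triggered_wrapper(f):
    # Record on the wrapper itself whether the override was ever called,
    # so the test below can assert that dispatch actually happened.
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        wrapped._triggered = True
        return f(*args, **kwargs)

    wrapped._triggered = False
    return wrapped
```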
@@ -324,7 +331,8 @@ def generate_tensor_like_torch_implementations():
         # decorate the overrides with implements_tensor_like if it's not a
         # torch.Tensor method
         wrapped = triggered_wrapper(override)
-        HANDLED_FUNCTIONS_WRAPPERS[func] = wrapped
+        # See note: "_triggered wrapper"
+        WRAPPED_TRIGGERED_IMPLS[func] = wrapped
         if is_tensor_method_or_property(func):
             implements_sub(func)(wrapped)
         else:
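`implements_sub` and `implements_tensor_like` are also collapsed in this view; they follow the usual `__torch_function__` dispatch-table decorator pattern from the overrides documentation. A hedged sketch of that pattern (names and body assumed, not taken from this diff):

```python
HANDLED_FUNCTIONS_SUB = {}

def implements_sub(torch_function):
    # Register func as the SubTensor override for torch_function
    # by recording it in the dispatch table.
    def decorator(func):
        HANDLED_FUNCTIONS_SUB[torch_function] = func
        return func
    return decorator
```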
@@ -549,6 +557,7 @@ def instance_gen():
                 t = t[:-1]
             if t == 'Tensor':
                 if arg['name'] == 'self' and is_tensor_method_or_property(func):
+                    # See "Note: properties and __get__"
                     func = func.__get__(instance_gen())
                     continue
                 func_args.append(instance_gen())
@@ -590,8 +599,9 @@ def test(self):
         # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
         # This is currently the best check but doesn't work for, for example,
         # Tensor.__add__ because it redirects to Tensor.add.
+        # See note "_triggered wrapper"
         if ret is None:
-            self.assertTrue(HANDLED_FUNCTIONS_WRAPPERS[func]._triggered)
+            self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
            return
 
         self.assertEqual(ret, -1)
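For reference, the `-1` comes from the testing overrides themselves: `get_testing_overrides()` maps each overridable function to a stub that simply returns `-1`, so a `-1` result shows that `__torch_function__` dispatch reached the override. A self-contained sketch of that mechanism, using a hypothetical tensor-like class and the classmethod signature of recent PyTorch:

```python
import torch

class StubTensor:
    # Any class defining __torch_function__ participates in torch API dispatch.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return -1  # stub result, mirroring the testing overrides

print(torch.add(StubTensor(), StubTensor()))  # prints -1
```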
@@ -601,6 +611,7 @@ def test(self):
     for func, override in get_testing_overrides().items():
         test_method = test_generator(func, override)
         if func.__name__ == "__get__":
+            # Note: properties and __get__
             # __get__ is part of the descriptor protocol.
             # https://docs.python.org/3/howto/descriptor.html
             # This is used for properties of the form
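The binding trick this note refers to (`func = func.__get__(instance_gen())` in the earlier hunk) is the plain descriptor protocol: calling `__get__` on a function or property descriptor binds it to a concrete instance. A quick standalone illustration:

```python
class A:
    def method(self, x):
        return x + 1

    @property
    def prop(self):
        return 41

a = A()

# Functions are descriptors: __get__ produces a bound method,
# which is what the test harness does for 'self' arguments.
bound = A.method.__get__(a)   # equivalent to a.method
assert bound(1) == 2

# For a property, __get__ invokes the getter on the instance.
assert A.prop.__get__(a) == 41
```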
18 changes: 6 additions & 12 deletions tools/autograd/templates/python_variable_methods.cpp
@@ -325,12 +325,9 @@ static bool dispatch_to_Bool(const Tensor & self) {
 static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
   if (check_has_torch_function(self)) {
-    try {
-      return handle_torch_function(self, "__bool__");
-    }
-    catch(const python_error&) {
-      return nullptr;
-    }
+    HANDLE_TH_ERRORS
+    return handle_torch_function(self, "__bool__");
+    END_HANDLE_TH_ERRORS
   }
   jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
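Context for this and the following hunks: `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` are PyTorch's standard error-translation macros for the Python binding layer. They expand, roughly, to a `try` block plus catch handlers that convert C++ exceptions into Python exceptions and return `nullptr`, the CPython error convention for `PyObject*`-returning functions. The handwritten `try`/`catch(const python_error&)` duplicated that boilerplate while catching only a single exception type, which is why this commit replaces it with the macros.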
@@ -1016,12 +1013,9 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
 
 static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) {
   if (check_has_torch_function(self)) {
-    try {
-      return handle_torch_function(self, "__bool__");
-    }
-    catch(const python_error&) {
-      return nullptr;
-    }
+    HANDLE_TH_ERRORS
+    return handle_torch_function(self, "__bool__");
+    END_HANDLE_TH_ERRORS
   }
   jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW);
   return THPVariable_is_nonzero(self, args);
27 changes: 9 additions & 18 deletions torch/csrc/autograd/python_variable.cpp
@@ -335,12 +335,9 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused)
 PyObject *THPVariable_get_volatile(THPVariable *self, void *unused)
 {
   if (check_has_torch_function((PyObject *)self)) {
-    try {
-      return handle_torch_function_getter(self, "volatile");
-    }
-    catch (const python_error&) {
-      return nullptr;
-    }
+    HANDLE_TH_ERRORS
+    return handle_torch_function_getter(self, "volatile");
+    END_HANDLE_TH_ERRORS
   }
   const char* msg = "volatile was removed (Variable.volatile is always False)";
   PyErr_WarnEx(PyExc_UserWarning, msg, 1);
@@ -350,12 +347,9 @@ PyObject *THPVariable_get_volatile(THPVariable *self, void *unused)
 int THPVariable_set_volatile(THPVariable *self, PyObject *obj, void *unused)
 {
   if (check_has_torch_function((PyObject *)self)) {
-    try {
-      return handle_torch_function_setter(self, "volatile", obj);
-    }
-    catch (const python_error&) {
-      return -1;
-    }
+    HANDLE_TH_ERRORS
+    return handle_torch_function_setter(self, "volatile", obj);
+    END_HANDLE_TH_ERRORS_RET(-1)
   }
   return PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
 }
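Note the setter variant here: `THPVariable_set_volatile` returns `int` rather than `PyObject*`, so it closes with `END_HANDLE_TH_ERRORS_RET(-1)`, which returns `-1` on error instead of `nullptr` — matching both the removed `catch` block's `return -1;` and the CPython convention for setters.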
@@ -469,12 +463,9 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj, void *unused
 PyObject *THPVariable_get_name(THPVariable* self, void *unused)
 {
   if (check_has_torch_function((PyObject *)self)) {
-    try {
-      return handle_torch_function_getter(self, "name");
-    }
-    catch (const python_error&) {
-      return nullptr;
-    }
+    HANDLE_TH_ERRORS
+    return handle_torch_function_getter(self, "name");
+    END_HANDLE_TH_ERRORS
   }
   if (self->cdata.name() == "")
     Py_RETURN_NONE;
