Enhanced generators with grad-mode decorators (pytorch#49017)
Summary:
This PR addresses the feature request outlined in pytorch#48713 for two-way communication with enhanced generators from [pep-342](https://www.python.org/dev/peps/pep-0342/).

Briefly, the logic of the patch resembles `yield from` [pep-380](https://www.python.org/dev/peps/pep-0380/), which cannot be used here, since the generator **must be interacted with from within the grad-mode context**, while yields from the decorator **must take place outside of that context**. Hence every interaction with the wrapped generator, whether via [.send](https://docs.python.org/3/reference/expressions.html?highlight=throw#generator.send), [.throw](https://docs.python.org/3/reference/expressions.html?highlight=throw#generator.throw), or even [.close](https://docs.python.org/3/reference/expressions.html?highlight=throw#generator.close), must be wrapped by a `with` clause. The patch is compatible with the `for i in gen: pass` and `next(gen)` use cases and allows two-way communication with the generator via `.send <-> yield` points.
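As a quick refresher on the PEP 342 mechanics the patch builds on, here is a minimal plain-Python coroutine (no torch involved; `echo` is a made-up example, not part of this PR):

```python
# A generator becomes a coroutine once values flow both ways through `yield`.
def echo():
    reply = None
    while True:
        try:
            request = yield reply   # .send(value) resumes execution here
        except GeneratorExit:
            return                  # .close() raises GeneratorExit at the yield
        reply = request * 2

gen = echo()
assert next(gen) is None   # prime the coroutine: run up to the first `yield`
assert gen.send(3) == 6    # two-way traffic: 3 goes in, the doubled value comes back
assert gen.send(5) == 10
gen.close()                # throws GeneratorExit into the paused generator
```

The decorator in this PR drives exactly these `.send`/`.throw`/`.close` entry points, but re-enters the grad-mode context around each of them.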

### Logic
At lines [L37-L38](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L37-L38) we (the decorator) **start the wrapped generator** (coroutine) by sending `None` into it (equivalently, we could use `next(gen)` here). Then we **dispatch the generator's responses** to our ultimate caller and **relay the latter's requests** back into the generator in the loop on lines [L39-L52](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L39-L52).

We yield the most recent response on [L40-L41](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L40-L41), at which point we become **paused**, waiting for the ultimate caller's next interaction with us. If the caller **sends us a request**, we become unpaused, move to [L51-L52](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L51-L52) and **forward it into the generator**, then pause again, waiting for its response. The response may be a value, an exception, or a `StopIteration`. In the case of an exception from the generator, we let it **bubble up** from the immediately surrounding [except clause](https://docs.python.org/3/reference/compound_stmts.html#the-try-statement) to the ultimate caller through the [outer try-except](https://github.com/ivannz/pytorch/blob/2dc287bba87fa6f05c49446c0239ffdcdb1e896e/torch/autograd/grad_mode.py#L36-L54). In the case of a `StopIteration`, we **take its payload and propagate it** to the caller via [return](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L54). In the case of a value, the flow continues and the loop repeats.
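The `StopIteration` payload mechanics can be seen in isolation with a plain-Python toy (`summer` is a hypothetical example, unrelated to the patch): a generator's `return` value rides on `StopIteration.value`, which is what the decorator picks up and re-returns.

```python
# A generator's `return x` surfaces to the caller as StopIteration(x).
def summer():
    total = 0
    while True:
        x = yield total
        if x is None:
            return total      # becomes StopIteration(total) for the caller
        total += x

gen = summer()
assert next(gen) == 0         # prime: run to the first yield
assert gen.send(2) == 2
assert gen.send(3) == 5
try:
    gen.send(None)            # triggers the `return total` path
except StopIteration as e:
    assert e.value == 5       # the payload the decorator propagates
```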

The caller **throwing an exception at us** is handled much like a proper request, except that the exception plays the role of the request. In this case we **forward it into the generator** on lines [L47-L49](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L47-L49) and await its response. We explicitly **advance** the traceback one frame up in order to indicate the **source of the exception within the generator**.
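In plain Python, forwarding an exception into a generator with `.throw` looks like this (`resilient` is a made-up generator, not part of the patch; it is the same pattern the PR's benign-exception tests exercise):

```python
# .throw(...) raises the exception at the generator's paused `yield`;
# if the generator handles it, .throw returns the next yielded value.
def resilient():
    while True:
        try:
            yield "ok"
        except ValueError:
            yield "recovered"    # .throw resumes here and returns this value

gen = resilient()
assert next(gen) == "ok"
assert gen.throw(ValueError) == "recovered"
assert next(gen) == "ok"         # normal operation resumes afterwards
```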

Finally, `GeneratorExit` is handled on lines [L42-L45](https://github.com/ivannz/pytorch/blob/2d40296c0c6617b3980c86762be466c995aa7f8e/torch/autograd/grad_mode.py#L42-L45): we close the wrapped generator from within the grad-mode context and re-raise.
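The relay loop described above can be sketched as a generic plain-Python decorator; `guarded` and `ctx_factory` are hypothetical names for illustration only (the actual implementation lives in `torch/autograd/grad_mode.py` and re-enters the grad-mode class instead of an arbitrary context manager):

```python
import contextlib
import functools
import sys

def guarded(ctx_factory):
    """Re-enter ctx_factory() around every interaction with the wrapped generator."""
    def decorator(func):
        @functools.wraps(func)
        def generator_context(*args, **kwargs):
            gen = func(*args, **kwargs)
            try:
                # Fire up the generator inside the context
                with ctx_factory():
                    response = gen.send(None)
                while True:
                    try:
                        # Relay the response; pause until the caller interacts
                        request = yield response
                    except GeneratorExit:
                        # Close the still-active generator inside the context
                        with ctx_factory():
                            gen.close()
                        raise
                    except BaseException:
                        # Forward the caller's exception into the generator
                        with ctx_factory():
                            response = gen.throw(*sys.exc_info())
                    else:
                        # Forward the caller's request into the generator
                        with ctx_factory():
                            response = gen.send(request)
            except StopIteration as e:
                # Propagate the generator's return value to our caller
                return e.value
        return generator_context
    return decorator

@guarded(contextlib.nullcontext)
def countdown(n):
    while n > 0:
        yield n
        n -= 1
    return "lift-off"

assert list(countdown(3)) == [3, 2, 1]
```

With `torch.no_grad()` (or `torch.enable_grad()`) in place of `ctx_factory`, this is the shape of the wrapper added by the PR.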

Updates: clarified exception propagation

Pull Request resolved: pytorch#49017

Reviewed By: izdeby

Differential Revision: D25567796

Pulled By: albanD

fbshipit-source-id: 801577cccfcb2b5e13a08e77faf407881343b7b0
ivannz authored and hwangdeyu committed Dec 23, 2020
1 parent 076d62f commit 197266d
Showing 2 changed files with 222 additions and 7 deletions.
181 changes: 181 additions & 0 deletions test/test_autograd.py
@@ -1161,6 +1161,187 @@ def no_grad_context_manager_recursive(depth):
        enable_grad_context_manager_recursive(10)
        self.assertFalse(torch.is_grad_enabled())

    def test_set_grad_coroutines(self):
        @torch.no_grad()
        def coro_no_grad(n=10):
            self.assertFalse(torch.is_grad_enabled())
            for i in range(n):
                self.assertFalse(torch.is_grad_enabled())
                r = yield i
                self.assertFalse(torch.is_grad_enabled())
                self.assertEqual(i, r)
            self.assertFalse(torch.is_grad_enabled())

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            self.assertTrue(torch.is_grad_enabled())
            for i in range(n):
                self.assertTrue(torch.is_grad_enabled())
                r = yield i
                self.assertTrue(torch.is_grad_enabled())
                self.assertEqual(i, r)
            self.assertTrue(torch.is_grad_enabled())

        with torch.enable_grad():
            self.assertTrue(torch.is_grad_enabled())
            coro, r = coro_no_grad(), None
            try:
                while True:
                    self.assertTrue(torch.is_grad_enabled())
                    r = coro.send(r)
                    self.assertTrue(torch.is_grad_enabled())

            except StopIteration:
                pass

        with torch.no_grad():
            self.assertFalse(torch.is_grad_enabled())
            coro, r = coro_enable_grad(), None
            try:
                while True:
                    self.assertFalse(torch.is_grad_enabled())
                    r = coro.send(r)
                    self.assertFalse(torch.is_grad_enabled())

            except StopIteration:
                pass

    def test_set_grad_coroutines_benign_exceptions(self):
        class RecoverableException(Exception):
            pass

        @torch.no_grad()
        def coro_no_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except RecoverableException:
                    self.assertFalse(torch.is_grad_enabled())
                    has_raised = True

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except RecoverableException:
                    self.assertTrue(torch.is_grad_enabled())
                    has_raised = True

        with torch.enable_grad():
            coro = coro_no_grad()
            assert 0 == next(coro)
            try:
                while True:
                    r = coro.throw(RecoverableException)
                    self.assertLess(r, 0)

            except StopIteration:
                pass

        with torch.no_grad():
            coro = coro_enable_grad()
            assert 0 == next(coro)
            try:
                while True:
                    r = coro.throw(RecoverableException)
                    self.assertLess(r, 0)

            except StopIteration:
                pass

    def test_set_grad_coroutines_critical_exceptions(self):
        class UnrecoverableException(Exception):
            pass

        class SecondaryException(Exception):
            pass

        @torch.no_grad()
        def coro_no_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except UnrecoverableException:
                    self.assertFalse(torch.is_grad_enabled())
                    raise SecondaryException

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except UnrecoverableException:
                    self.assertTrue(torch.is_grad_enabled())
                    raise SecondaryException

        with torch.enable_grad():
            coro = coro_no_grad()
            assert 0 == next(coro)
            with self.assertRaises(SecondaryException):
                coro.throw(UnrecoverableException)

        with torch.no_grad():
            coro = coro_enable_grad()
            assert 0 == next(coro)
            with self.assertRaises(SecondaryException):
                coro.throw(UnrecoverableException)

    def test_set_grad_coroutines_exit(self):
        @torch.no_grad()
        def coro_no_grad(state):
            for i in range(10):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield i

                except GeneratorExit:
                    self.assertFalse(torch.is_grad_enabled())
                    state.add('GeneratorExit')
                    raise

        @torch.enable_grad()
        def coro_enable_grad(state):
            for i in range(10):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield i

                except GeneratorExit:
                    self.assertTrue(torch.is_grad_enabled())
                    state.add('GeneratorExit')
                    raise

        state = set()
        with torch.enable_grad():
            coro = coro_no_grad(state)
            for i in range(5):
                next(coro)

            coro.close()
        self.assertTrue('GeneratorExit' in state)

        state = set()
        with torch.no_grad():
            coro = coro_enable_grad(state)
            for i in range(5):
                next(coro)

            coro.close()
        self.assertTrue('GeneratorExit' in state)

    def test_no_grad_python_function(self):
        """Python Functions should respect grad mode."""
        x = torch.ones(5, 5, requires_grad=True)
48 changes: 41 additions & 7 deletions torch/autograd/grad_mode.py
@@ -1,3 +1,4 @@
+import sys
 import torch
 import functools
 import inspect
@@ -31,13 +32,46 @@ def _wrap_generator(self, func):
         @functools.wraps(func)
         def generator_context(*args, **kwargs):
             gen = func(*args, **kwargs)
-            while True:
-                try:
-                    with self.__class__():
-                        x = next(gen)
-                    yield x
-                except StopIteration:
-                    break
+
+            # Generators are suspended and unsuspended at `yield`, hence we
+            # make sure the grad mode is properly set every time the execution
+            # flow returns into the wrapped generator and restored when it
+            # returns through our `yield` to our caller (see PR #49017).
+            cls = type(self)
+            try:
+                # Issuing `None` to a generator fires it up
+                with cls():
+                    response = gen.send(None)
+
+                while True:
+                    try:
+                        # Forward the response to our caller and get its next request
+                        request = yield response
+
+                    except GeneratorExit:
+                        # Inform the still active generator about its imminent closure
+                        with cls():
+                            gen.close()
+                        raise
+
+                    except BaseException:
+                        # Propagate the exception thrown at us by the caller
+                        with cls():
+                            response = gen.throw(*sys.exc_info())
+
+                    else:
+                        # Pass the last request to the generator and get its response
+                        with cls():
+                            response = gen.send(request)
+
+            # We let the exceptions raised above by the generator's `.throw` or
+            # `.send` methods bubble up to our caller, except for StopIteration
+            except StopIteration as e:
+                # The generator informed us that it is done: take whatever its
+                # returned value (if any) was and indicate that we're done too
+                # by returning it (see docs for python's return-statement).
+                return e.value

         return generator_context

     def __enter__(self) -> None:
