Rewrite dynamo cond() handling to not recursively call export (pytorch#90286)

The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph.  This is problematic for a number of reasons:

* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well when tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at pytorch#90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecode necessary to reapply the side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts could rely on the side effects. See the sketch after this list for the kind of mutation that was being dropped.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond (pytorch#87020 (comment)), which doesn't really make sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top-level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branches produce the same guards, which is not useful and, as far as I can tell, not even necessary for correctness)
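
For instance, the kind of program that exposed the side-effect problem looks roughly like this (a hypothetical illustration, not a test from this PR). The recursive export() only returned the tensor graph for each branch, so the `record.append(...)` mutations simply disappeared from the traced program:

```python
import torch
from functorch.experimental.control_flow import cond

record = []  # ordinary Python state mutated from inside the branches

def true_fn(x):
    record.append("took true branch")   # side effect: silently dropped by the old recursive export()
    return x * 2

def false_fn(x):
    record.append("took false branch")  # side effect: silently dropped by the old recursive export()
    return x * -1

def f(pred, x):
    return cond(pred, true_fn, false_fn, [x])
```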

This PR replaces the old implementation with a new implementation based on graph-state checkpointing. The basic idea is that to process a cond(), we checkpoint the state of our interpreter, run the true branch, roll back to our checkpoint, run the false branch, roll back to our checkpoint again, and then merge the changes from both runs. I require the true/false branches to have exactly the same side effects, but union their guards. A simplified sketch of this control flow follows.
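
In pseudocode, the new handling looks roughly like this. This is a simplified sketch, not the actual Dynamo code: `speculate_branch` and `build_cond_call` are hypothetical stand-ins, while `copy_graphstate`, `restore_graphstate`, and `diff` correspond to the checkpointing helpers touched in the diff below.

```python
# Simplified sketch of cond() speculation via graph-state checkpointing.
# speculate_branch() and build_cond_call() are hypothetical placeholders.
def handle_cond(tx, pred, true_fn, false_fn, operands):
    checkpoint = tx.copy_graphstate()            # snapshot interpreter + graph state

    true_out, true_guards = speculate_branch(tx, true_fn, operands)
    true_state = tx.copy_graphstate()
    tx.restore_graphstate(checkpoint)            # roll back before the other branch

    false_out, false_guards = speculate_branch(tx, false_fn, operands)
    false_state = tx.copy_graphstate()
    tx.restore_graphstate(checkpoint)            # roll back again before merging

    # The branches must leave the tracer in equivalent states (in particular,
    # identical side effects); diff() describes any divergence for debugging.
    mismatch = true_state.diff(false_state)
    assert mismatch is None, f"cond branches diverged: {mismatch}"

    # Guards from both branches are unioned into the enclosing frame's guards.
    tx.output.guards.update(true_guards | false_guards)

    # Emit a single call to the cond operator with the two recorded subgraphs.
    return build_cond_call(tx, pred, true_out, false_out, operands)
```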

Some of the details:

* Dynamo is too aggressive about tracking side effects when processing closures, cf. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078. The basic problem is that whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example (see the sketch after this list). To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents`, so that we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and get a better description of what the divergence was. You can do this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0` and `name_1`, now they are named `cond_true_0` and `cond_false_0`.
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculative execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what works is overriding `graph` with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate, as you have to make sure you reset the Graph to its original value before you restore a checkpoint, because checkpoints don't actually save the graph (for efficiency) and instead just undo changes on the graph. This capability may usefully get refactored into OutputGraph, but I didn't do it in this PR for simplicity.
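
For reference, the closure situation in the first bullet comes from code shaped roughly like this (a hypothetical reduction of the nested cond export example, not the actual test). `true_fn` closes over `pred2`, and under the old scheme merely reading that closure cell while inlining registered a side effect, which then tripped cond()'s no-side-effects requirement; the optimistic handling plus `RestartAnalysis` avoids that unless the cell is actually written to.

```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def f(pred, pred2, x):
    def true_fn(x):
        # true_fn closes over pred2: reading this free variable used to be
        # tracked as a side effect even though nothing is ever mutated.
        def inner_true(x):
            return x * 2

        def inner_false(x):
            return x * -1

        return cond(pred2, inner_true, inner_false, [x])

    def false_fn(x):
        return x + 1

    return cond(pred, true_fn, false_fn, [x])

# Hypothetical driver: export the whole thing as a single graph.
graph, guards = torch._dynamo.export(
    f, torch.tensor(True), torch.tensor(False), torch.randn(3, 3)
)
```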

There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting in the original implementation.

* Not a problem per se, but if an NN module is used by both the true and false branches, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* Returning a list of tensors from cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc., and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back; cf. https://github.com/pytorch/torchdynamo/issues/1840

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: pytorch#90286
Approved by: https://github.com/voznesenskym
ezyang authored and kulinseth committed Dec 9, 2022
1 parent 064724e commit cb15e39
Showing 7 changed files with 246 additions and 107 deletions.
14 changes: 7 additions & 7 deletions test/dynamo/test_export.py
@@ -1437,19 +1437,19 @@ def nop(x):
def test_export_with_module_layer(self):
from functorch.experimental.control_flow import cond

def true_fn(layer, val):
return layer(val) * torch.tensor(2)

def false_fn(layer, val):
return layer(val) * torch.tensor(-1)

class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, pred, x):
return cond(pred, true_fn, false_fn, [self.linear, x])
def true_fn(val):
return self.linear(val) * torch.tensor(2)

def false_fn(val):
return self.linear(val) * torch.tensor(-1)

return cond(pred, true_fn, false_fn, [x])

mod = Module()
x = torch.randn([3, 3])
5 changes: 4 additions & 1 deletion torch/_dynamo/convert_frame.py
@@ -6,7 +6,7 @@
import types
import weakref
from traceback import FrameSummary
from typing import cast, Dict, List, Optional
from typing import cast, Dict, List, Optional, Set

import torch
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
@@ -364,6 +364,8 @@ def _compile(
frame: Optional[types.FrameType] = None,
) -> Optional[GuardedCode]:
output: Optional[OutputGraph] = None
# This is shared across restarts
mutated_closure_cell_contents: Set[str] = set()

# from .utils import print_once; print_once(code.co_filename)
def transform(instructions, code_options):
@@ -378,6 +380,7 @@ def transform(instructions, code_options):
compiler_fn,
one_graph,
export,
mutated_closure_cell_contents,
)
tracer.run()
output = tracer.output
17 changes: 17 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -73,6 +73,21 @@ class OutputGraphState(NamedTuple):
nn_modules: Optional[Dict[str, torch.nn.Module]]
side_effects: SideEffects
timestamp: int
name_to_input: OrderedDict[str, Optional[fx.Proxy]]

def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
for k in self._fields:
if k == "side_effects":
r = self.side_effects.diff(other.side_effects)
if r is not None:
return r
continue

sv = getattr(self, k)
ov = getattr(other, k)
if sv != ov:
return f"{prefix}{k} mismatch: {sv} != {ov}"
return None


@functools.lru_cache(None)
@@ -227,6 +242,7 @@ def copy_graphstate(self) -> OutputGraphState:
dict(self.nn_modules),
self.side_effects.clone(),
self.timestamp,
self.name_to_input.copy(),
)
self.timestamp += 1
return state
@@ -239,6 +255,7 @@ def restore_graphstate(self, state: OutputGraphState):
self.nn_modules,
self.side_effects,
self.timestamp,
self.name_to_input,
) = state
# FX deepcopy doesn't work for a partially created graph, so just remove new nodes
for node in reversed(list(self.graph.nodes)):
28 changes: 27 additions & 1 deletion torch/_dynamo/side_effects.py
@@ -1,7 +1,7 @@
import collections
import dataclasses
import inspect
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import torch.nn

@@ -75,6 +75,32 @@ def __init__(self, id_to_variable=None, store_attr_mutations=None, keepalive=Non
self.store_attr_mutations = store_attr_mutations or collections.OrderedDict()
self.keepalive = keepalive or []

def __eq__(self, other: object) -> bool:
assert isinstance(other, SideEffects)
# NB: do NOT test keepalive
return (
self.id_to_variable == other.id_to_variable
and self.store_attr_mutations == other.store_attr_mutations
)

def diff(self, other: "SideEffects") -> Optional[str]:
if self.id_to_variable != other.id_to_variable:
sk_itv = self.id_to_variable.keys()
ok_itv = other.id_to_variable.keys()
if sk_itv != ok_itv:
return f"id_to_variable keys: {sk_itv} != {ok_itv}"
# Feel free to augment this with more fancy diffing logic
# if needed for debugging
return "id_to_variable: unknown diff"
elif self.store_attr_mutations != other.store_attr_mutations:
sk_sam = self.store_attr_mutations.keys()
ok_sam = other.store_attr_mutations.keys()
if sk_sam != ok_sam:
return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
return "store_attr_mutations: unknown diff"
else:
return None

def clone(self):
"""Create a shallow copy"""
return self.__class__(
33 changes: 31 additions & 2 deletions torch/_dynamo/symbolic_convert.py
@@ -13,7 +13,7 @@
import typing
import weakref
from collections.abc import Sized
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from unittest.mock import patch

import torch
@@ -116,6 +116,16 @@ class InstructionTranslatorGraphState(NamedTuple):
next_instruction: Optional[Instruction]
lineno: int

def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]:
for k in self._fields:
if k == "output":
return self.output.diff(other.output, prefix=f"{k}.")
sv = getattr(self, k)
ov = getattr(other, k)
if sv != ov:
return f"{k} mismatch: {sv} != {ov}"
return None


def stack_op(fn: typing.Callable[..., object]):
nargs = len(inspect.signature(fn).parameters)
@@ -365,6 +375,7 @@ class InstructionTranslatorBase(object):
next_instruction: Optional[Instruction]
block_stack: List[BlockStackEntry]
lineno: int
mutated_closure_cell_contents: Set[str]

checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]]
random_calls: List[
@@ -1589,6 +1600,7 @@ def __init__(
compiler_fn,
one_graph,
export,
mutated_closure_cell_contents: Set[str],
):
super(InstructionTranslator, self).__init__(
output=OutputGraph(f_globals, code_options, compiler_fn, self),
@@ -1605,6 +1617,7 @@ def __init__(
)
self.one_graph: bool = one_graph
self.export = export
self.mutated_closure_cell_contents = mutated_closure_cell_contents
if self.export:
assert (
self.one_graph
@@ -1853,14 +1866,30 @@ def STORE_DEREF(self, inst):
else:
self.output.side_effects.store_cell(cell, val)
else:
maybe_cell = self.symbolic_locals.get(inst.argval)
if isinstance(
self.symbolic_locals.get(inst.argval),
maybe_cell,
variables.NewCellVariable,
):
self.output.side_effects.store_cell(
self.symbolic_locals[inst.argval], self.pop()
)
else:
if (
maybe_cell is not None
and maybe_cell.source.name()
not in self.parent.mutated_closure_cell_contents
):
# Why is the source name here unique?
# mutated_closure_cell_contents is a per-frame
# concept, and sources identify, e.g., particular
# locals from the frame. If you had two locals,
# they'll get different source names, and therefore
# differ here.
self.parent.mutated_closure_cell_contents.add(
maybe_cell.source.name()
)
raise exc.RestartAnalysis()
unimplemented("write to __closure__ while inlining")

def LOAD_DEREF(self, inst):
26 changes: 22 additions & 4 deletions torch/_dynamo/variables/functions.py
@@ -119,6 +119,8 @@ def bind_args(self, parent, args, kwargs):
options = VariableTracker.propagate([self])
wrap = functools.partial(wrap_bound_arg, options=options)

tx = parent.output.root_tx

fn: types.FunctionType = self.fn
fake_func = types.FunctionType(
fn.__code__,
@@ -146,7 +148,7 @@ def bind_args(self, parent, args, kwargs):
if name == "__class__":
result[name] = variables.UserDefinedClassVariable(cell.cell_contents)
else:
var = parent.output.root_tx.match_nested_cell(name, cell)
var = tx.match_nested_cell(name, cell)
if var is not None:
# optimization for cleaner codegen
result[name] = var
@@ -163,15 +165,31 @@
closure_cell_contents = AttrSource(
closure_cell, "cell_contents"
)
contents_var = VariableBuilder(parent, closure_cell_contents)(
cell.cell_contents
)

if (
closure_cell_contents.name()
not in tx.mutated_closure_cell_contents
):
# Optimistically don't allocate the cell, to
# reduce the number of side effects. This is
# important for cond, as without it, any accesses
# to closures create side effects and cond doesn't
# support side effects. If we're wrong and this
# closure cell gets written to, we will restart
# the analysis with this cell's name in the
# mutated list here
result[name] = contents_var
continue

# cells are written to with "cell_contents",
# so the source should just be the closure_cell, not its contents
out = side_effects.track_cell_existing(closure_cell, cell)
side_effects.store_cell(
out,
VariableBuilder(parent, closure_cell_contents)(
cell.cell_contents
),
contents_var,
)

result[name] = out
