
Commit ddd4563

jon-chuang authored and dmenig committed
[AOT Refactor] runtime wrappers (pytorch#114557)
---

Part _ of pytorch#114548

Pull Request resolved: pytorch#114557
Approved by: https://github.com/bdhirsh
ghstack dependencies: pytorch#114550, pytorch#114551, pytorch#114552, pytorch#114553, pytorch#114554, pytorch#114555, pytorch#114556
1 parent 64431f5 commit ddd4563

2 files changed: +351 -284 lines changed

Lines changed: 344 additions & 0 deletions
@@ -0,0 +1,344 @@
"""
This module defines runtime wrappers, which, based on previous analysis,
attempt to process the inputs and outputs, apply mutations, functionalize randomness,
and dispatch subclasses.
"""

from contextlib import nullcontext
from typing import Callable, List, Optional, Union
from unittest.mock import patch

import torch
from torch._decomp.decompositions_for_rng import PhiloxStateTracker
from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper

from .functional_utils import gen_alias_from_base
from .schemas import OutputType, SubclassCreationMeta, TensorAlias, ViewAndMutationMeta
from .subclass_utils import unwrap_tensor_subclasses, wrap_tensor_subclasses
from .utils import call_func_at_runtime_with_args, make_boxed_func
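
# Note on the boxed calling convention (an explanatory sketch, not part of the original diff):
# the compiled callables handed to these wrappers are expected to be "boxed", i.e. they take a
# single list of arguments rather than *args, and advertise this via a `_boxed_call` attribute.
# make_boxed_func adapts an unboxed function, roughly:
#
#     def make_boxed_func(f):
#         def g(args):
#             return f(*args)
#         g._boxed_call = True
#         return g
#
# which lets the compiled function clear entries of the argument list to free inputs as early
# as possible.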


# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic
# that needs to run after the compiled function.
#
# This function accepts a trace_joint flag, indicating whether we're generating the runtime
# epilogue for a forward-only inference graph or for an autograd.Function.apply function.
# This is because there are some minor differences in how we treat these cases at runtime:
# - resize_() is currently handled in the inference case, but not fully handled in the autograd case.
# - the autograd case inserts TensorAlias wrapper objects for outputs that alias inputs
def create_runtime_wrapper(
    compiled_fn,
    *,
    runtime_metadata: ViewAndMutationMeta,
    indices_of_inps_to_detach: List[int],
    trace_joint: bool,
    keep_input_mutations: bool,
    disable_amp: bool,
):
    if not hasattr(compiled_fn, "_boxed_call"):
        compiled_fn = make_boxed_func(compiled_fn)

    def runtime_wrapper(*args):
        if trace_joint:
            args_ = list(args)
            # See Note [Detaching inputs that never need gradients]
            for idx in indices_of_inps_to_detach:
                if isinstance(args_[idx], torch.Tensor):
                    args_[idx] = args_[idx].detach()
            with torch.autograd._force_original_view_tracking(True):
                all_outs = call_func_at_runtime_with_args(
                    compiled_fn,
                    args_,
                    disable_amp=disable_amp,
                )
        else:
            # When we have an inference graph, we run with torch.no_grad.
            # It's possible to get an inference graph with inputs that require grad,
            # in which case we want to make sure autograd is disabled
            # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on)
            with torch.no_grad():
                all_outs = call_func_at_runtime_with_args(
                    compiled_fn,
                    args,
                    disable_amp=disable_amp,
                )

        num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices
        num_intermediate_bases = runtime_metadata.num_intermediate_bases

        if keep_input_mutations and trace_joint:
            num_graph_handled = runtime_metadata.num_mutated_graph_handled_indices
            # autograd.Function requires us to return the mutated inputs as extra outputs to the autograd.Function.forward
            if num_graph_handled > 0:
                all_outs = all_outs[:-num_graph_handled]

        assert (
            len(all_outs)
            == num_mutated_runtime_inps
            + runtime_metadata.num_outputs
            + num_intermediate_bases
        )

        # Step 3: After running the compiled fw, apply updates to mutated inputs
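        # For example (an explanatory sketch, not part of the original diff): given
        #     def f(x):
        #         x.mul_(2)
        #         return x + 1
        # the functionalized graph roughly computes
        #     x_updated = x * 2; out = x_updated + 1
        # and returns x_updated as an extra output. The loop below writes x_updated back
        # into the caller's original x (via copy_/as_strided_/set_, depending on whether
        # the mutation touched data, metadata, or the storage).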
        num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices
        if num_mutations_to_apply > 0:
            updated_inputs = all_outs[:num_mutations_to_apply]
            fw_outs = all_outs[num_mutations_to_apply:]

            for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices):
                meta = runtime_metadata.input_info[inpt_idx]
                if not meta.mutates_data and not meta.mutates_metadata:
                    continue
                original_inpt = args[inpt_idx]
                updated_inpt = updated_inputs[i]
                if meta.mutates_storage_metadata:
                    # mutates_storage_metadata means our input saw a x.set_(y) call.
                    # What if x **also** saw a data and/or a metadata mutation?
                    # (1) If the [meta]data mutation occurred after the set_(),
                    #     then there is no need to copy_() the data.
                    #     When we perform x.set_(x_updated), we are guaranteed that
                    #     x_updated already has the final version of the data/metadata
                    # (2) If a data mutation occurred before the set_().
                    #     This case seems very difficult to support.
                    #     TODO: discuss on the PR and decide if we want to try to
                    #     either support it, or detect and ban it.
                    if trace_joint:
                        assert isinstance(updated_inpt, TensorAlias)
                        updated_inpt = updated_inpt.alias
                    original_inpt.set_(updated_inpt)
                    continue
                if meta.mutates_metadata and not meta.mutates_data:
                    if trace_joint:
                        assert isinstance(updated_inpt, TensorAlias)
                        updated_inpt = updated_inpt.alias
                    # We need to grab the size/stride/storage_offset from the compiled forward,
                    # and use that to mutate the metadata of the input
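                    # (e.g. an in-place transpose_() or squeeze_() on the input changes only
                    # sizes/strides, so replaying the new metadata with as_strided_ is enough)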
                    original_inpt.as_strided_(
                        updated_inpt.size(),
                        updated_inpt.stride(),
                        updated_inpt.storage_offset(),
                    )
                else:
                    if meta.mutates_data and meta.mutates_metadata:
                        original_inpt.as_strided_(
                            updated_inpt.size(),
                            updated_inpt.stride(),
                            updated_inpt.storage_offset(),
                        )
                    else:
                        assert meta.mutates_data
                    if meta.is_leaf and original_inpt.requires_grad:
                        # We can hit this situation in this case:
                        #   def f(x):
                        #       x.detach().mul_(2)
                        #       return x + 1
                        # AOTAutograd will see a mutation in the above case, and try to
                        # apply a copy_() here, in the epilogue.
                        # But if x required gradients, and is a leaf, then autograd
                        # will yell at us for trying to mutate it.
                        # However, it's only possible to end up in this scenario (like the above)
                        # if all of the mutations to the leaf input were non-autograd-tracking mutations
                        # (aka mutations under no_grad(), or on detached views).
                        # In that case, we fully want to hide the mutation from autograd, so detaching is ok.
                        original_inpt.detach().copy_(updated_inpt)
                    else:
                        original_inpt.copy_(updated_inpt)
        else:
            fw_outs = all_outs

        # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of
        # compiling them.
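        # For example (an explanatory sketch, not part of the original diff): for
        #     def f(x):
        #         return x.view(-1)
        # the output is OutputType.alias_of_input. Instead of returning whatever tensor the
        # compiled graph produced, we regenerate the view off of the caller's original x
        # (gen_alias_from_base below), so the returned tensor genuinely aliases the user's
        # input and autograd's view tracking stays correct.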
        if runtime_metadata.num_outputs_aliased > 0:
            # The compiled forward also returned intermediate bases. We don't want to return them to the user.
            if runtime_metadata.num_intermediate_bases > 0:
                fw_outs_no_intermediate_bases = fw_outs[
                    : -runtime_metadata.num_intermediate_bases
                ]
                intermediate_bases = fw_outs[-runtime_metadata.num_intermediate_bases :]
            else:
                fw_outs_no_intermediate_bases = fw_outs
                intermediate_bases = []

            assert len(fw_outs_no_intermediate_bases) == len(
                runtime_metadata.output_info
            )
            fw_outs_including_aliases = []
            for i, (o, info) in enumerate(
                zip(fw_outs_no_intermediate_bases, runtime_metadata.output_info)
            ):
                if info.output_type in [
                    OutputType.non_alias,
                    OutputType.unsafe_view_alias,
                    OutputType.custom_function_view,
                ]:
                    fw_outs_including_aliases.append(o)
                    continue
                if trace_joint:
                    assert isinstance(o, TensorAlias)
                    o_ = o.alias
                else:
                    o_ = o

                o_grad = runtime_metadata.output_info[i].requires_grad
                if info.output_type == OutputType.alias_of_input:
                    aliased_base_tensor = args[info.base_idx]  # type: ignore[index]
                    regenerated_out = gen_alias_from_base(
                        aliased_base_tensor, o_, o_grad
                    )
                    fw_outs_including_aliases.append(regenerated_out)
                    continue
                elif info.output_type == OutputType.is_input:
                    aliased_base_tensor = args[info.base_idx]  # type: ignore[index]
                    regenerated_out = aliased_base_tensor
                    fw_outs_including_aliases.append(regenerated_out)
                    continue
                elif info.output_type == OutputType.alias_of_intermediate:
                    base_tensor_list = intermediate_bases
                elif (
                    info.output_type == OutputType.alias_of_intermediate_save_as_output
                ):
                    base_tensor_list = intermediate_bases
                else:
                    assert (
                        info.output_type
                        == OutputType.alias_of_intermediate_base_is_user_output
                    )
                    base_tensor_list = fw_outs_no_intermediate_bases
                aliased_base_tensor = base_tensor_list[info.base_idx]
                # TODO: handle the custom autograd function case here.
                # We need a way to check whether a tensor came from a custom autograd fn from python,
                # AND a way to replay that custom view fn.
                regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
                fw_outs_including_aliases.append(regenerated_out)
            ret_outs = fw_outs_including_aliases
        else:
            ret_outs = fw_outs

        if runtime_metadata.dynamic_outputs:
            for t, o in zip(ret_outs, runtime_metadata.output_info):
                if o.dynamic_dims is None:
                    continue
                if hasattr(t, "_dynamo_weak_dynamic_indices"):
                    t._dynamo_weak_dynamic_indices |= o.dynamic_dims
                else:
                    t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
        if runtime_metadata.grad_enabled_mutation is not None:
            torch.set_grad_enabled(runtime_metadata.grad_enabled_mutation)
        return ret_outs

    return runtime_wrapper
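
# Usage sketch (hypothetical, not from the original diff): the AOT dispatch paths are expected
# to wrap their compiled artifact roughly like
#
#     compiled_fn = create_runtime_wrapper(
#         compiled_fw,
#         runtime_metadata=fw_metadata,
#         indices_of_inps_to_detach=[],
#         trace_joint=False,
#         keep_input_mutations=aot_config.keep_inference_input_mutations,
#         disable_amp=disable_amp,
#     )
#
# where compiled_fw, fw_metadata, aot_config and disable_amp come from the surrounding
# dispatcher and are not defined in this module.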


# Calling convention: If we are running functionalized RNG, then outs consists
# of (user_outs, rng_offset)
def functionalized_rng_runtime_epilogue(
    metadata: ViewAndMutationMeta, outs, return_new_outs=True
):
    if metadata.is_rng_op_functionalized:
        assert metadata.num_outputs_rng_offset == 1
        new_rng_offset = outs[-1]
        CUDARngStateHelper.set_new_offset(new_rng_offset)
        if return_new_outs:
            user_outs = outs[:-1]
            return user_outs
        else:
            return None
    return outs


def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True):
    # Functionalization of rng ops changes the calling convention of the joint graph.
    # It goes from (primals, tangents) to (seed, offset, primals, tangents)
    # At runtime, we pass on the current seed and offset. This is hidden from
    # the user.
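    # For example (an explanatory sketch, not part of the original diff): a joint graph
    # traced as
    #     def joint(primals, tangents): ...
    # is re-traced below as
    #     def traced_joint(primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset): ...
    # so every functionalized rng op reads an explicit (seed, offset) pair instead of the
    # global CUDA rng state, and the updated offsets are appended to the outputs.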
    fake_mode = detect_fake_mode()
    if fake_mode is None:
        fake_mode = nullcontext()

    def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
        out = PhiloxStateTracker.get_state_as_tensor()
        return out

    def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"):
        PhiloxStateTracker.set_state_from_tensor(x)

    def append_rng_offsets(args):
        if trace_joint:
            # args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs)
            # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_outputs, new_bwd_rng_offset)
            return (
                (*args[0], PhiloxStateTracker.get_updated_fwd_offset()),
                (*args[1], PhiloxStateTracker.get_updated_bwd_offset()),
            )
        else:
            # args signature before: Tuple(fwd_outputs)
            # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset)
            return (*args, PhiloxStateTracker.get_updated_fwd_offset())

    def traced_joint(
        primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset
    ):
        with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
            "torch.cuda.set_rng_state", override_set_rng_state
        ):
            return append_rng_offsets(func(primals, tangents))

    def traced_forward(*primals_fwd_seed_fwd_base_offset):
        # The signature is (*primals, seed, offset)
        with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
            "torch.cuda.set_rng_state", override_set_rng_state
        ):
            return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2]))

    if trace_joint:
        # Get the current seed and offset to setup tracing.
        fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
        PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward")
        return traced_joint, (
            *args,
            fwd_seed,
            fwd_base_offset,
            bwd_seed,
            bwd_base_offset,
        )
    else:
        # Get the current seed and offset to setup tracing.
        fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
        return traced_forward, (*args, fwd_seed, fwd_base_offset)


# This wrapper handles the AOTDispatch runtime logic for tensor subclasses.
# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor,
# but the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs).
# This function handles the wrapping and unwrapping of tensor subclasses at runtime.
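# For example (an explanatory sketch, not part of the original diff): if the user passes in a
# subclass that wraps two dense tensors (say, a hypothetical TwoTensor(a, b)),
# unwrap_tensor_subclasses flattens the arguments into plain tensors [a, b] before calling the
# compiled fn, and wrap_tensor_subclasses rebuilds a TwoTensor around the dense outputs using
# the recorded subclass_metas.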
def aot_dispatch_subclass_wrapper(
    runtime_fn: Callable,
    *,
    subclass_metas: List[Union[int, SubclassCreationMeta]],
    num_fw_outs_saved_for_bw: Optional[int],
) -> Callable:
    def inner_fn(args):
        unwrapped_args = unwrap_tensor_subclasses(args, is_joint_structure=False)
        # expectation: runtime_fn is a boxed fn
        unwrapped_outs = runtime_fn(unwrapped_args)
        wrapped_outs = wrap_tensor_subclasses(
            unwrapped_outs,
            subclass_metas=subclass_metas,
            num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
            is_runtime=True,
        )
        return wrapped_outs

    # box it
    inner_fn._boxed_call = True  # type: ignore[attr-defined]
    return inner_fn
