"""
This module defines runtime wrappers which, based on the previous analysis,
process the inputs and outputs, apply input mutations, functionalize randomness,
and dispatch tensor subclasses at runtime.
"""

from contextlib import nullcontext
from typing import Callable, List, Optional, Union
from unittest.mock import patch

import torch
from torch._decomp.decompositions_for_rng import PhiloxStateTracker
from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper

from .functional_utils import gen_alias_from_base
from .schemas import OutputType, SubclassCreationMeta, TensorAlias, ViewAndMutationMeta
from .subclass_utils import unwrap_tensor_subclasses, wrap_tensor_subclasses
from .utils import call_func_at_runtime_with_args, make_boxed_func


# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic
# that needs to run after the compiled function.
#
# This function accepts a trace_joint flag, indicating whether we are generating the runtime
# epilogue for a forward-only inference graph or for an autograd.Function.apply function.
# This is because there are some minor differences in how we treat these cases at runtime:
# - resize_() is currently handled in the inference case, but not fully handled in the autograd case.
# - the autograd case inserts TensorAlias wrapper objects for outputs that alias inputs.
def create_runtime_wrapper(
    compiled_fn,
    *,
    runtime_metadata: ViewAndMutationMeta,
    indices_of_inps_to_detach: List[int],
    trace_joint: bool,
    keep_input_mutations: bool,
    disable_amp: bool,
):
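    # Ensure compiled_fn follows the boxed calling convention (it takes a single
    # list of args), so the runtime wrapper below can invoke it uniformly.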
    if not hasattr(compiled_fn, "_boxed_call"):
        compiled_fn = make_boxed_func(compiled_fn)

    def runtime_wrapper(*args):
        if trace_joint:
            args_ = list(args)
            # See Note [Detaching inputs that never need gradients]
            for idx in indices_of_inps_to_detach:
                if isinstance(args_[idx], torch.Tensor):
                    args_[idx] = args_[idx].detach()
            with torch.autograd._force_original_view_tracking(True):
                all_outs = call_func_at_runtime_with_args(
                    compiled_fn,
                    args_,
                    disable_amp=disable_amp,
                )
        else:
            # When we have an inference graph, we run with torch.no_grad.
            # It's possible to get an inference graph with inputs that require grad,
            # in which case we want to make sure autograd is disabled
            # (since e.g. inductor will generate aten.addmm.out calls, which autograd will complain about)
            with torch.no_grad():
                all_outs = call_func_at_runtime_with_args(
                    compiled_fn,
                    args,
                    disable_amp=disable_amp,
                )

        num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices
        num_intermediate_bases = runtime_metadata.num_intermediate_bases

        if keep_input_mutations and trace_joint:
            num_graph_handled = runtime_metadata.num_mutated_graph_handled_indices
            # autograd.Function requires us to return the mutated inputs as extra outputs to the autograd.Function.forward
            if num_graph_handled > 0:
                all_outs = all_outs[:-num_graph_handled]

        assert (
            len(all_outs)
            == num_mutated_runtime_inps
            + runtime_metadata.num_outputs
            + num_intermediate_bases
        )

        # Step 3: After running the compiled fw, apply updates to mutated inputs
        num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices
        if num_mutations_to_apply > 0:
            updated_inputs = all_outs[:num_mutations_to_apply]
            fw_outs = all_outs[num_mutations_to_apply:]

            for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices):
                meta = runtime_metadata.input_info[inpt_idx]
                if not meta.mutates_data and not meta.mutates_metadata:
                    continue
                original_inpt = args[inpt_idx]
                updated_inpt = updated_inputs[i]
                if meta.mutates_storage_metadata:
                    # mutates_storage_metadata means our input saw a x.set_(y) call.
                    # What if x **also** saw a data and/or a metadata mutation?
                    # (1) If the [meta]data mutation occurred after the set_(),
                    #     then there is no need to copy_() the data.
                    #     When we perform x.set_(x_updated), we are guaranteed that
                    #     x_updated already has the final version of the data/metadata.
                    # (2) If a data mutation occurred before the set_(),
                    #     this case seems very difficult to support.
                    #     TODO: discuss on the PR and decide if we want to try to
                    #     either support it, or detect and ban it.
                    if trace_joint:
                        assert isinstance(updated_inpt, TensorAlias)
                        updated_inpt = updated_inpt.alias
                    original_inpt.set_(updated_inpt)
                    continue
                if meta.mutates_metadata and not meta.mutates_data:
                    if trace_joint:
                        assert isinstance(updated_inpt, TensorAlias)
                        updated_inpt = updated_inpt.alias
                    # We need to grab the size/stride/storage_offset from the compiled forward,
                    # and use that to mutate the metadata of the input
                    original_inpt.as_strided_(
                        updated_inpt.size(),
                        updated_inpt.stride(),
                        updated_inpt.storage_offset(),
                    )
                else:
                    if meta.mutates_data and meta.mutates_metadata:
                        original_inpt.as_strided_(
                            updated_inpt.size(),
                            updated_inpt.stride(),
                            updated_inpt.storage_offset(),
                        )
                    else:
                        assert meta.mutates_data
                    if meta.is_leaf and original_inpt.requires_grad:
                        # We can hit this situation in this case:
                        #   def f(x):
                        #       x.detach().mul_(2)
                        #       return x + 1
                        # AOTAutograd will see a mutation in the above case, and try to
                        # apply a copy_() here, in the epilogue.
                        # But if x requires gradients, and is a leaf, then autograd
                        # will yell at us for trying to mutate it.
                        # However, it's only possible to end up in this scenario (like the above)
                        # if all of the mutations to the leaf input were non-autograd-tracking mutations
                        # (aka mutations under no_grad(), or on detached views).
                        # In that case, we fully want to hide the mutation from autograd, so detaching is ok.
                        original_inpt.detach().copy_(updated_inpt)
                    else:
                        original_inpt.copy_(updated_inpt)
        else:
            fw_outs = all_outs

        # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of
        # compiling them.
        if runtime_metadata.num_outputs_aliased > 0:
            # The compiled forward also returned intermediate bases. We don't want to return them to the user.
            if runtime_metadata.num_intermediate_bases > 0:
                fw_outs_no_intermediate_bases = fw_outs[
                    : -runtime_metadata.num_intermediate_bases
                ]
                intermediate_bases = fw_outs[-runtime_metadata.num_intermediate_bases :]
            else:
                fw_outs_no_intermediate_bases = fw_outs
                intermediate_bases = []

            assert len(fw_outs_no_intermediate_bases) == len(
                runtime_metadata.output_info
            )
            fw_outs_including_aliases = []
            for i, (o, info) in enumerate(
                zip(fw_outs_no_intermediate_bases, runtime_metadata.output_info)
            ):
                if info.output_type in [
                    OutputType.non_alias,
                    OutputType.unsafe_view_alias,
                    OutputType.custom_function_view,
                ]:
                    fw_outs_including_aliases.append(o)
                    continue
                if trace_joint:
                    assert isinstance(o, TensorAlias)
                    o_ = o.alias
                else:
                    o_ = o

                o_grad = runtime_metadata.output_info[i].requires_grad
                if info.output_type == OutputType.alias_of_input:
                    aliased_base_tensor = args[info.base_idx]  # type: ignore[index]
                    regenerated_out = gen_alias_from_base(
                        aliased_base_tensor, o_, o_grad
                    )
                    fw_outs_including_aliases.append(regenerated_out)
                    continue
                elif info.output_type == OutputType.is_input:
                    aliased_base_tensor = args[info.base_idx]  # type: ignore[index]
                    regenerated_out = aliased_base_tensor
                    fw_outs_including_aliases.append(regenerated_out)
                    continue
                elif info.output_type == OutputType.alias_of_intermediate:
                    base_tensor_list = intermediate_bases
                elif (
                    info.output_type == OutputType.alias_of_intermediate_save_as_output
                ):
                    base_tensor_list = intermediate_bases
                else:
                    assert (
                        info.output_type
                        == OutputType.alias_of_intermediate_base_is_user_output
                    )
                    base_tensor_list = fw_outs_no_intermediate_bases
                aliased_base_tensor = base_tensor_list[info.base_idx]
                # TODO: handle the custom autograd function case here.
                # We need a way to check whether a tensor came from a custom autograd fn from python,
                # AND a way to replay that custom view fn.
                regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
                fw_outs_including_aliases.append(regenerated_out)
            ret_outs = fw_outs_including_aliases
        else:
            ret_outs = fw_outs

        if runtime_metadata.dynamic_outputs:
            for t, o in zip(ret_outs, runtime_metadata.output_info):
                if o.dynamic_dims is None:
                    continue
                if hasattr(t, "_dynamo_weak_dynamic_indices"):
                    t._dynamo_weak_dynamic_indices |= o.dynamic_dims
                else:
                    t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
        if runtime_metadata.grad_enabled_mutation is not None:
            torch.set_grad_enabled(runtime_metadata.grad_enabled_mutation)
        return ret_outs

    return runtime_wrapper

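
# Illustrative sketch (not used by AOTAutograd itself): how the runtime wrapper above
# might be assembled around a compiled graph. `my_compiled_fn` and `my_metadata` are
# hypothetical stand-ins for the compiled callable and the ViewAndMutationMeta that
# earlier analysis produces; treat this as a usage sketch, not the canonical entry point.
def _example_create_runtime_wrapper(my_compiled_fn, my_metadata: ViewAndMutationMeta):
    wrapped_fn = create_runtime_wrapper(
        my_compiled_fn,
        runtime_metadata=my_metadata,
        indices_of_inps_to_detach=[],  # assume no inputs need detaching in this sketch
        trace_joint=False,  # inference-only graph, so no autograd.Function epilogue
        keep_input_mutations=False,
        disable_amp=False,
    )
    # wrapped_fn takes the flat runtime args and returns the user-visible outputs,
    # after applying input mutations and regenerating outputs that alias inputs.
    return wrapped_fn
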

# Calling convention: If we are running functionalized RNG, then outs consists
# of (user_outs, rng_offset)
def functionalized_rng_runtime_epilogue(
    metadata: ViewAndMutationMeta, outs, return_new_outs=True
):
    if metadata.is_rng_op_functionalized:
        assert metadata.num_outputs_rng_offset == 1
        new_rng_offset = outs[-1]
        CUDARngStateHelper.set_new_offset(new_rng_offset)
        if return_new_outs:
            user_outs = outs[:-1]
            return user_outs
        else:
            return None
    return outs

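
# Illustrative sketch (not used by AOTAutograd itself): when RNG ops are functionalized,
# the compiled graph returns the updated philox offset as its last output, and the
# epilogue above strips it off and advances the CUDA RNG state. `compiled_fw` and
# `metadata` are hypothetical stand-ins for the compiled forward and its metadata.
def _example_rng_epilogue(compiled_fw, metadata: ViewAndMutationMeta, args):
    all_outs = compiled_fw(args)  # (*user_outs, new_rng_offset) if RNG is functionalized
    # Advances the CUDA philox offset (if needed) and returns only the user outputs.
    return functionalized_rng_runtime_epilogue(metadata, all_outs)
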

def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True):
    # Functionalization of rng ops changes the calling convention of the joint graph.
    # It goes from (primals, tangents) to (primals, tangents, fwd_seed, fwd_base_offset,
    # bwd_seed, bwd_base_offset). At runtime, we pass in the current seed and offset.
    # This is hidden from the user.
    fake_mode = detect_fake_mode()
    if fake_mode is None:
        fake_mode = nullcontext()

    def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
        out = PhiloxStateTracker.get_state_as_tensor()
        return out

    def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"):
        PhiloxStateTracker.set_state_from_tensor(x)

    def append_rng_offsets(args):
        if trace_joint:
            # args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs)
            # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_outputs, new_bwd_rng_offset)
            return (
                (*args[0], PhiloxStateTracker.get_updated_fwd_offset()),
                (*args[1], PhiloxStateTracker.get_updated_bwd_offset()),
            )
        else:
            # args signature before: Tuple(fwd_outputs)
            # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset)
            return (*args, PhiloxStateTracker.get_updated_fwd_offset())

    def traced_joint(
        primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset
    ):
        with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
            "torch.cuda.set_rng_state", override_set_rng_state
        ):
            return append_rng_offsets(func(primals, tangents))

    def traced_forward(*primals_fwd_seed_fwd_base_offset):
        # The signature is (*primals, seed, offset)
        with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
            "torch.cuda.set_rng_state", override_set_rng_state
        ):
            return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2]))

    if trace_joint:
        # Get the current seed and offset to set up tracing.
        fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
        PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward")
        return traced_joint, (
            *args,
            fwd_seed,
            fwd_base_offset,
            bwd_seed,
            bwd_base_offset,
        )
    else:
        # Get the current seed and offset to set up tracing.
        fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
            fake_mode
        )
        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
        return traced_forward, (*args, fwd_seed, fwd_base_offset)

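
# Illustrative sketch (not used by AOTAutograd itself): create_functionalized_rng_ops_wrapper
# returns both a callable to trace and an augmented argument list, so tracing proceeds
# with the RNG seed/offset appended to the original args. `my_fn` and `example_args`
# are hypothetical stand-ins for the function being traced and its flat arguments.
def _example_functionalize_rng_ops(my_fn, example_args):
    traced_fn, augmented_args = create_functionalized_rng_ops_wrapper(
        my_fn, example_args, trace_joint=False
    )
    # augmented_args is (*example_args, fwd_seed, fwd_base_offset); the extra seed and
    # offset are threaded through at runtime and hidden from the user.
    return traced_fn, augmented_args
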

# This wrapper handles the AOTDispatch runtime logic for tensor subclasses.
# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor,
# but the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs).
# This function handles the wrapping and unwrapping of tensor subclasses at runtime.
def aot_dispatch_subclass_wrapper(
    runtime_fn: Callable,
    *,
    subclass_metas: List[Union[int, SubclassCreationMeta]],
    num_fw_outs_saved_for_bw: Optional[int],
) -> Callable:
    def inner_fn(args):
        unwrapped_args = unwrap_tensor_subclasses(args, is_joint_structure=False)
        # expectation: runtime_fn is a boxed fn
        unwrapped_outs = runtime_fn(unwrapped_args)
        wrapped_outs = wrap_tensor_subclasses(
            unwrapped_outs,
            subclass_metas=subclass_metas,
            num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
            is_runtime=True,
        )
        return wrapped_outs

    # box it
    inner_fn._boxed_call = True  # type: ignore[attr-defined]
    return inner_fn
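

# Illustrative sketch (not used by AOTAutograd itself): wrapping a boxed dense-tensor
# function so that tensor subclass inputs are unwrapped before the call and the dense
# outputs are re-wrapped into subclasses afterwards. `dense_runtime_fn` and `metas`
# are hypothetical stand-ins for the compiled boxed callable and the subclass metadata
# recorded at trace time.
def _example_dispatch_subclasses(dense_runtime_fn, metas):
    wrapped = aot_dispatch_subclass_wrapper(
        dense_runtime_fn,
        subclass_metas=metas,
        num_fw_outs_saved_for_bw=None,  # assume inference: nothing saved for backward
    )
    # The result follows the boxed calling convention: it takes a single list of args.
    return wrapped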