Describe the bug
Note: This might be something for the MVP program #12635 if there's anyone who already has a deep understanding of rotary embeddings and complex numbers. I don't.
The Qwen image pipeline calls `apply_rotary_emb_qwen` with `use_real=False`.
The function therefore operates on complex numbers.
When compiled, torch.compile warns about this: venv/lib/python3.12/site-packages/torch/_inductor/lowering.py:1890: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
Performance being worse than eager isn't a big deal on its own, since this is not a performance-critical part of the model.
However, due to a subtle torch.compile bug (pytorch/pytorch#163876), it leads to random compile failures.
Can the code path with real numbers be used instead?
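For context, the complex and real code paths compute the same rotation; the complex path just expresses it as multiplication by e^{iθ}, which Inductor cannot codegen. The sketch below is illustrative only (function names and tensor layout are assumptions, not the diffusers implementation) and checks that both formulations agree:

```python
import torch

def rotary_complex(x, freqs_cis):
    # Complex path: view adjacent pairs as complex numbers and multiply by e^{i*theta}.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).flatten(-2).type_as(x)

def rotary_real(x, cos, sin):
    # Real path: the same rotation written with cos/sin on interleaved pairs,
    # avoiding complex tensors entirely.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return out.flatten(-2).type_as(x)

torch.manual_seed(0)
x = torch.randn(2, 4, 8)                                # (batch, seq, head_dim)
theta = torch.randn(4, 4)                               # one angle per pair
freqs_cis = torch.polar(torch.ones_like(theta), theta)  # e^{i*theta}
print(torch.allclose(rotary_complex(x, freqs_cis),
                     rotary_real(x, theta.cos(), theta.sin()), atol=1e-5))
```

Since (x1 + i·x2)·(cosθ + i·sinθ) = (x1·cosθ − x2·sinθ) + i·(x1·sinθ + x2·cosθ), the two paths are numerically equivalent, so switching to the real path should not change model outputs.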
Reproduction
I cannot provide reproduction code because the failure is random: it mostly shows up when a kernel is recompiled, and even then not consistently.
Multiple users are affected, though. It can be worked around by wrapping the function in a torch.compiler.disable decorator, but I don't like that solution because it prevents compiling with fullgraph=True.
Logs
packed_predicted_flow = model.transformer(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/
modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/linux/KI/OneTrainer/src/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 629, in forward
encoder_hidden_states, hidden_states = block(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1771, in _wrapped_call_impl
return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
raise BackendCompilerFailed(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/__init__.py", line 2380, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2418, in compile_fx
return aot_autograd(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 109, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1199, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1140, in load
compiled_fn = dispatch_and_compile()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1184, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 1604, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
return self.compiler_fn(gm, example_inputs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2250, in fw_compiler_base
return inner_compile(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 745, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 896, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1456, in codegen_and_compile
compiled_module = graph.compile_to_module()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2293, in compile_to_module
return self._compile_to_module()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2299, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2238, in codegen
self.scheduler.codegen()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4598, in codegen
else self._codegen(self.nodes)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4750, in _codegen
self.get_backend(device).codegen_node(node)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 107, in codegen_node
return self._triton_scheduling.codegen_node(node)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1363, in codegen_node
coalesce_analysis = analyze_memory_coalescing(node)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 650, in analyze_memory_coalescing
norm_read_writes = extract_normalized_read_writes(fused_node)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 482, in extract_normalized_read_writes
if any(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 483, in <genexpr>
(isinstance(var, sympy.Expr) and not var.is_constant())
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/expr.py", line 724, in is_constant
b = expr._random(None, -1, 0, 1, 0)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/expr.py", line 562, in _random
nmag = abs(self.evalf(2, subs=reps))
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1654, in evalf
result = evalf(self, prec + 4, options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1489, in evalf
r = rf(x, prec, options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 602, in evalf_add
terms = [evalf(arg, prec + 10, options) for arg in v.args]
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 602, in <listcomp>
terms = [evalf(arg, prec + 10, options) for arg in v.args]
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1489, in evalf
r = rf(x, prec, options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 650, in evalf_mul
result = evalf(arg, prec, options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1493, in evalf
x = x.subs(evalf_subs(prec, options['subs']))
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1171, in subs
rv = rv._subs(old, new, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
retval = cfunc(*args, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1285, in _subs
rv = fallback(self, old, new)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1262, in fallback
rv = self.func(*args)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
retval = cfunc(*args, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 450, in __new__
return cls._new_(*args, **options) # type: ignore
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 472, in _new_
result = super().__new__(cls, *args, **options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
retval = cfunc(*args, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 309, in __new__
evaluated = cls.eval(*args)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/utils/_sympy/functions.py", line 488, in eval
assert p >= 0, p
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: -1470286036225387/1000000000000000
System Info
various with torch 2.8