[PyTorch Upstream] Triton regression on Inductor test case: NameError("Cannot access global variable STRING_CONSTANT_C from within @jit'ed function") #1126

Closed
etaf opened this issue May 14, 2024 · 3 comments
etaf commented May 14, 2024

After upgrading the Triton commit from dd07225 to e47fd95, the following Inductor test case, which used to pass, now fails:

python test/inductor/test_triton_kernels.py -k test_triton_kernel_constants

======================================================================
ERROR: test_triton_kernel_constants (__main__.KernelTests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/xinanlin/xinanlin/pytorch/torch/testing/_internal/common_utils.py", line 2756, in wrapper
    method(*args, **kwargs)
  File "/home/xinanlin/xinanlin/pytorch/test/inductor/test_triton_kernels.py", line 637, in test_triton_kernel_constants
    torch_result = call_triton(t)
  File "/home/xinanlin/xinanlin/pytorch/test/inductor/test_triton_kernels.py", line 623, in call_triton
    mulC_kernel[grid](
  File "/home/xinanlin/xinanlin/miniconda3/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/home/xinanlin/xinanlin/miniconda3/lib/python3.10/site-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/home/xinanlin/xinanlin/miniconda3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
  File "/home/xinanlin/xinanlin/miniconda3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
triton.compiler.errors.CompilationError: at 13:30:
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    CONSTANT_NAME: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    if CONSTANT_NAME.value == STRING_CONSTANT_C:
                              ^
NameError("Cannot access global variable STRING_CONSTANT_C from within @jit'ed function. Triton kernels can only access global variables that are annotated as
 constexpr (`x: triton.language.constexpr = 42` or `x = triton.language.constexpr(42)`).  Alternatively, set the envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1,
but we do not promise to support this forever.")
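
The first remedy named in the error message (wrapping the global in `triton.language.constexpr`) can be sketched as follows. This is a minimal illustration, not the upstream fix: the kernel signature and the lines up to the `tl.load` follow the traceback above, while the value of `STRING_CONSTANT_C`, the `CONSTANT_C` scale factor, the `else` branch, and the `tl.store` are assumptions added to make the kernel complete. The comparison is written as `CONSTANT_NAME == STRING_CONSTANT_C` rather than with `.value`, which is also a choice made here, not something taken from the test.

```python
# Hedged sketch: everything not quoted in the traceback is hypothetical.
try:
    import triton
    import triton.language as tl
    HAVE_TRITON = True
except ImportError:
    # Triton is not installed; the kernel definition below is skipped.
    HAVE_TRITON = False

if HAVE_TRITON:
    # Wrapper form quoted in the error message: the module-level global
    # becomes a compile-time constant that @triton.jit code may read.
    # The string value is an assumption for illustration.
    STRING_CONSTANT_C = tl.constexpr("CONSTANT_C")
    # Hypothetical scale factor, also declared as a constexpr global.
    CONSTANT_C = tl.constexpr(4)

    @triton.jit
    def mulC_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        CONSTANT_NAME: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        # Both operands are constexpr, so the branch is resolved at compile time.
        if CONSTANT_NAME == STRING_CONSTANT_C:
            output = x * CONSTANT_C
        else:
            output = x
        tl.store(out_ptr + offsets, output, mask=mask)
```

Defining the kernel does not compile it; compilation happens on the first launch, which needs a device that Triton supports. The other option quoted in the error, setting `TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1` in the environment, avoids code changes but is explicitly described there as not guaranteed to stay supported.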

Steps to reproduce:

1. Check out the PR: pytorch/pytorch#124147
2. Build stock PyTorch with USE_XPU=1
3. Run: python test/inductor/test_triton_kernels.py -k test_triton_kernel_constants

etaf commented May 14, 2024

@riverliuintel @vlad-penkin This issue blocks PyTorch upstream; please prioritize. Thanks.

vlad-penkin added the "bug" and "upstream: pytorch" labels on May 14, 2024
alexbaden self-assigned this on May 14, 2024

alexbaden (Contributor) commented

Will be fixed by pytorch/pytorch#126195

alexbaden (Contributor) commented

Merged: pytorch/pytorch@9554300
