Skip to content

[BUG] If-Else statement is not evaluated at runtime #3213

@LRlr239

Description

@LRlr239

Which component has the problem?

CuTe DSL

Bug Report

Describe the bug
This if-else statement does not work with nvidia-cutlass-dsl==4.5.0 (the `generic` branch is still compiled and verified even though the pointer is in `gmem`):

@cute.jit
def create_tensor_from_ptr(ptr: cute.Pointer):
    ...
    if ptr.memspace == cute_rt.AddressSpace.generic:
        ...
    elif ptr.memspace == cute_rt.AddressSpace.gmem:

Steps/Code to reproduce bug
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.

pip install nvidia-cutlass-dsl==4.5.0
import torch
import numpy as np
import cutlass
import cutlass.cute as cute
from cutlass.torch import dtype as cutlass_torch_dtype
import cutlass.cute.runtime as cute_rt

@cute.kernel
def kernel_demo(a: cute.Tensor):  # device-side helper: prints tensor `a` from a single thread
    tidx, _, _ = cute.arch.thread_idx()  # keep only the first thread-index component
    if tidx == 0:  # only the thread whose first index is 0 prints
        cute.print_tensor(a)



@cute.jit
def create_tensor_from_ptr(ptr: cute.Pointer):  # wraps `ptr` in an (8, 5) tensor, then branches on ptr.memspace
    layout = cute.make_layout((8, 5), stride=(5, 1))  # shape (8, 5) with strides (5, 1)
    tensor = cute.make_tensor(ptr, layout)

    if ptr.memspace == cute_rt.AddressSpace.generic:  # NOTE(review): cutlass.const_expr(...) may be needed to resolve this at trace time — confirm against the CuTe DSL control-flow docs
        cute.printf("generic tensor")
        tensor.fill(1)
        cute.print_tensor(tensor)  # the op named in the reported verification error (cell position 17:8)
    elif ptr.memspace == cute_rt.AddressSpace.gmem:  # the branch the reproducer's gmem pointer is meant to take
        cute.printf("gmem tensor")
        kernel_demo(tensor).launch(grid=(1, 1, 1), block=(1, 1, 1))  # launch with a 1x1x1 grid and 1x1x1 block

target_device = "cuda"  # selects both the torch device and the pointer's address space below

a = torch.randn(8, 5, dtype=cutlass_torch_dtype(cute.Float32), device=target_device)  # 8x5 random tensor on target_device
ptr_a = cute_rt.make_ptr(cutlass.Float32, a.data_ptr(), mem_space=cute_rt.AddressSpace.generic if target_device == "cpu" else cute_rt.AddressSpace.gmem)  # generic when target_device is "cpu", gmem otherwise

print(cutlass_torch_dtype(cute.Float32))
print(ptr_a.dtype, ptr_a.memspace, ptr_a.size_in_bytes())  # shows the pointer's dtype, address space and byte size

if target_device == "cuda":
    cutlass.cuda.initialize_cuda_context()

create_tensor_from_ptr(ptr_a)  # the call at which the traceback reports DSLRuntimeError

It causes the following exception:

---------------------------------------------------------------------------
MLIRError                                 Traceback (most recent call last)
File /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py:1088, in BaseDSL.build_module(self, module, function_name)
   1087 try:
-> 1088     module.operation.verify()
   1089 except Exception as e:

MLIRError: Verification failed:
error: "cute.print_tensor(tensor)"("/tmp/ipykernel_3888/1609407769.py":17:8): 'cute.print_view' op in host context can only print view in host memory space, got '!cute.memref<f32, gmem, "(8,5):(5,1)">'
 note: "cute.print_tensor(tensor)"("/tmp/ipykernel_3888/1609407769.py":17:8): see current operation: "cute.print_view"(%3) <{is_signed = false, verbose = false}> : (!cute.memref<f32, gmem, "(8,5):(5,1)">) -> ()

During handling of the above exception, another exception occurred:

DSLRuntimeError                           Traceback (most recent call last)
Cell In[10], line 13
     10 if target_device == "cuda":
     11     cutlass.cuda.initialize_cuda_context()
---> 13 create_tensor_from_ptr(ptr_a)

File /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py:457, in BaseDSL.jit_runner.<locals>.jit_runner_decorator.<locals>.jit_wrapper(*args, **kwargs)
    453     return getattr(func._dsl_object, executor_name)(
    454         func, *args, **kwargs, _name_prefix=custom_name
    455     )
    456 else:
--> 457     return getattr(func._dsl_object, executor_name)(
    458         func, *args, **kwargs
    459     )

File /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py:1571, in BaseDSL._func(self, funcBody, *args, **kwargs)
   1569 # Generate MLIR Context and start generating IR
   1570 log().debug(f"Generating MLIR for function '{function_name}'")
-> 1571 result = self.generate_mlir(
   1572     funcBody,
   1573     canonicalized_kwargs,
   1574     function_name,
   1575     gpu_module_attrs,
   1576     canonicalized_args,
   1577     args_spec,
   1578     pipeline,
   1579     no_cache,
   1580     no_jit_engine,
   1581     compile_only,
   1582     location=self.decorator_location,
   1583 )
   1584 return result

File /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py:1360, in BaseDSL.generate_mlir(self, funcBody, kwargs, function_name, gpu_module_attrs, args, args_spec, pipeline, no_cache, no_jit_engine, compile_only, location)
   1357 original_function_name = funcBody.__name__
   1359 # Generate original ir module and its hash value.
-> 1360 module, module_hash, result = self.generate_original_ir(
   1361     ir,
   1362     func,
   1363     funcBody,
   1364     kwargs,
   1365     function_name,
   1366     func_types,
   1367     gpu_module_attrs,
   1368     args,
   1369     args_spec,
   1370     location=location,
   1371 )
   1373 # dryrun is used to only generate IR
   1374 if self.envar.dryrun:

File /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py:1165, in BaseDSL.generate_original_ir(self, ir, func, funcBody, kwargs, function_name, func_types, gpu_module_attrs, args, args_spec, location)
   1162     module, result = build_ir_module()
   1163 module_hash = self.get_module_hash(module, function_name)
-> 1165 module = self.build_module(module, function_name)
   1167 return module, module_hash, result

File /usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py:1090, in BaseDSL.build_module(self, module, function_name)
   1088     module.operation.verify()
   1089 except Exception as e:
-> 1090     raise DSLRuntimeError("🧊🧊🧊 ICE IR Verification Failed 🧊🧊🧊", cause=e)
   1092 return module

DSLRuntimeError: DSLRuntimeError: 🧊🧊🧊 ICE IR Verification Failed 🧊🧊🧊
  Caused exception: Verification failed:
error: "cute.print_tensor(tensor)"("/tmp/ipykernel_3888/1609407769.py":17:8): 'cute.print_view' op in host context can only print view in host memory space, got '!cute.memref<f32, gmem, "(8,5):(5,1)">'
 note: "cute.print_tensor(tensor)"("/tmp/ipykernel_3888/1609407769.py":17:8): see current operation: "cute.print_view"(%3) <{is_signed = false, verbose = false}> : (!cute.memref<f32, gmem, "(8,5):(5,1)">) -> ()

Expected behavior
A clear and concise description of what you expected to happen.

Environment details (please complete the following information):

  • Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)]

Additional context
Add any other context about the problem here.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions