diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 95fe1e0535843..73afdb29b6149 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -309,12 +309,15 @@ class NVVM_SingleResultIntrinsicOp traits = [], str class NVVM_PureSpecialRegisterOp traits = []> : NVVM_IntrOp { let arguments = (ins); + let results = (outs I32:$res); let assemblyFormat = "attr-dict `:` type($res)"; } -class NVVM_SpecialRegisterOp traits = []> : +class NVVM_SpecialRegisterOp traits = []> : NVVM_IntrOp { let arguments = (ins); + let results = (outs resultType:$res); let assemblyFormat = "attr-dict `:` type($res)"; } @@ -421,8 +424,8 @@ def NVVM_AggrSmemSize : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.aggr.smem.s //===----------------------------------------------------------------------===// // Clock registers def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">; -def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">; -def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">; +def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64", I64>; +def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer", I64>; def NVVM_GlobalTimerLoOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer.lo">; //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index e849b59b846f7..0e3357992be18 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -2115,3 +2115,19 @@ module attributes { dlti.dl_spec = #dlti.dl_spec< %0 = llvm.ptrtoaddr %arg0 : !llvm.ptr to i64 } } + +// ----- + +func.func @nvvm_read_sreg_tid_x_wrong_type() { + // expected-error@+1 {{'nvvm.read.ptx.sreg.tid.x' op result #0 must be 32-bit signless integer, but got 'i64'}} + %0 = nvvm.read.ptx.sreg.tid.x : i64 + return +} + +// ----- + +func.func @nvvm_read_sreg_clock64_wrong_type() { + // expected-error@+1 {{'nvvm.read.ptx.sreg.clock64' op result #0 must be 64-bit signless integer, but got 'i32'}} + %0 = nvvm.read.ptx.sreg.clock64 : i32 + return +} diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py index 24abf617548b8..f5e057812642e 100644 --- a/mlir/test/python/dialects/nvvm.py +++ b/mlir/test/python/dialects/nvvm.py @@ -377,3 +377,16 @@ def reductions(mask, vi32, vf32): # CHECK: %[[REDUX_35:.*]] = nvvm.redux.sync fmax %[[ARG2]], %[[ARG1]] : f32 -> f32 # CHECK: return # CHECK: } + + +# CHECK-LABEL: TEST: testSpecialRegisterInferredResults +@constructAndPrintInModule +def testSpecialRegisterInferredResults(): + # CHECK: %{{.*}} = nvvm.read.ptx.sreg.tid.x : i32 + nvvm.ThreadIdXOp() + # CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock : i32 + nvvm.ClockOp() + # CHECK: %{{.*}} = nvvm.read.ptx.sreg.clock64 : i64 + nvvm.Clock64Op() + # CHECK: %{{.*}} = nvvm.read.ptx.sreg.globaltimer : i64 + nvvm.GlobalTimerOp()