diff --git a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp index 6519b65cec465..e55a695be13c8 100644 --- a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp @@ -11,9 +11,9 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; @@ -26,13 +26,15 @@ struct GpuGlobalIdRewriter : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto dim = op.getDimension(); - auto blockId = gpu::BlockIdOp::create(rewriter, loc, dim); - auto blockDim = gpu::BlockDimOp::create(rewriter, loc, dim); + Value blockId = gpu::BlockIdOp::create(rewriter, loc, dim); + Value blockDim = gpu::BlockDimOp::create(rewriter, loc, dim); + auto indexType = rewriter.getIndexType(); // Compute blockId.x * blockDim.x - auto tmp = index::MulOp::create(rewriter, op.getLoc(), blockId, blockDim); - auto threadId = gpu::ThreadIdOp::create(rewriter, loc, dim); + Value tmp = + arith::MulIOp::create(rewriter, loc, indexType, blockId, blockDim); + Value threadId = gpu::ThreadIdOp::create(rewriter, loc, dim); // Compute threadId.x + blockId.x * blockDim.x - rewriter.replaceOpWithNewOp(op, threadId, tmp); + rewriter.replaceOpWithNewOp(op, indexType, threadId, tmp); return success(); } }; diff --git a/mlir/test/Conversion/GPUCommon/lower-global-id.mlir b/mlir/test/Conversion/GPUCommon/lower-global-id.mlir new file mode 100644 index 0000000000000..b0274e0f9f290 --- /dev/null +++ b/mlir/test/Conversion/GPUCommon/lower-global-id.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -split-input-file -convert-gpu-to-rocdl | FileCheck %s --check-prefixes=ROCDL +// RUN: mlir-opt %s -split-input-file -convert-gpu-to-nvvm | FileCheck %s --check-prefixes=NVVM + +gpu.module @kernel { + gpu.func @gpu_global_id() -> (index) { + %global_id_x = gpu.global_id x + gpu.return %global_id_x : index + } +} + +// ROCDL-LABEL: llvm.func @gpu_global_id() -> i64 { +// ROCDL: %[[WORKGROUP_0:.*]] = rocdl.workgroup.id.x : i32 +// ROCDL: %[[SEXT_0:.*]] = llvm.sext %[[WORKGROUP_0]] : i32 to i64 +// ROCDL: %[[WORKGROUP_1:.*]] = rocdl.workgroup.dim.x : i32 +// ROCDL: %[[SEXT_1:.*]] = llvm.sext %[[WORKGROUP_1]] : i32 to i64 +// ROCDL: %[[MUL_0:.*]] = llvm.mul %[[SEXT_0]], %[[SEXT_1]] : i64 +// ROCDL: %[[WORKITEM_0:.*]] = rocdl.workitem.id.x : i32 +// ROCDL: %[[SEXT_2:.*]] = llvm.sext %[[WORKITEM_0]] : i32 to i64 +// ROCDL: %[[ADD_0:.*]] = llvm.add %[[SEXT_2]], %[[MUL_0]] : i64 +// ROCDL: llvm.return %[[ADD_0]] : i64 +// ROCDL: } + +// NVVM-LABEL: llvm.func @gpu_global_id() -> i64 { +// NVVM: %[[READ_0:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32 +// NVVM: %[[SEXT_0:.*]] = llvm.sext %[[READ_0]] : i32 to i64 +// NVVM: %[[READ_1:.*]] = nvvm.read.ptx.sreg.ntid.x : i32 +// NVVM: %[[SEXT_1:.*]] = llvm.sext %[[READ_1]] : i32 to i64 +// NVVM: %[[MUL_0:.*]] = llvm.mul %[[SEXT_0]], %[[SEXT_1]] : i64 +// NVVM: %[[READ_2:.*]] = nvvm.read.ptx.sreg.tid.x : i32 +// NVVM: %[[SEXT_2:.*]] = llvm.sext %[[READ_2]] : i32 to i64 +// NVVM: %[[ADD_0:.*]] = llvm.add %[[SEXT_2]], %[[MUL_0]] : i64 +// NVVM: llvm.return %[[ADD_0]] : i64 +// NVVM: } diff --git a/mlir/test/Dialect/GPU/globalId-rewrite.mlir b/mlir/test/Dialect/GPU/globalId-rewrite.mlir index 9e02d69daa436..d7d080a0093aa 100644 --- a/mlir/test/Dialect/GPU/globalId-rewrite.mlir +++ b/mlir/test/Dialect/GPU/globalId-rewrite.mlir @@ -8,27 +8,27 @@ module { threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) { // CHECK: %[[BIDY:.*]] = gpu.block_id x // CHECK-NEXT: %[[BDIMY:.*]] = gpu.block_dim x - // CHECK-NEXT: %[[TMPY:.*]] = index.mul %[[BIDY]], %[[BDIMY]] + // CHECK-NEXT: %[[TMPY:.*]] = arith.muli %[[BIDY]], %[[BDIMY]] // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x - // CHECK-NEXT: %[[GIDX:.*]] = index.add %[[TIDX]], %[[TMPY]] + // CHECK-NEXT: %[[GIDX:.*]] = arith.addi %[[TIDX]], %[[TMPY]] %idx = gpu.global_id x // CHECK: memref.store %[[GIDX]], %[[MEM]][] : memref memref.store %idx, %mem[] : memref // CHECK: %[[BIDY:.*]] = gpu.block_id y // CHECK-NEXT: %[[BDIMY:.*]] = gpu.block_dim y - // CHECK-NEXT: %[[TMPY:.*]] = index.mul %[[BIDY]], %[[BDIMY]] + // CHECK-NEXT: %[[TMPY:.*]] = arith.muli %[[BIDY]], %[[BDIMY]] // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y - // CHECK-NEXT: %[[GIDY:.*]] = index.add %[[TIDY]], %[[TMPY]] + // CHECK-NEXT: %[[GIDY:.*]] = arith.addi %[[TIDY]], %[[TMPY]] %idy = gpu.global_id y // CHECK: memref.store %[[GIDY]], %[[MEM]][] : memref memref.store %idy, %mem[] : memref // CHECK: %[[BIDZ:.*]] = gpu.block_id z // CHECK-NEXT: %[[BDIMZ:.*]] = gpu.block_dim z - // CHECK-NEXT: %[[TMPZ:.*]] = index.mul %[[BIDZ]], %[[BDIMZ]] + // CHECK-NEXT: %[[TMPZ:.*]] = arith.muli %[[BIDZ]], %[[BDIMZ]] // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z - // CHECK-NEXT: %[[GIDZ:.*]] = index.add %[[TIDZ]], %[[TMPZ]] + // CHECK-NEXT: %[[GIDZ:.*]] = arith.addi %[[TIDZ]], %[[TMPZ]] %idz = gpu.global_id z // CHECK: memref.store %[[GIDZ]], %[[MEM]][] : memref memref.store %idz, %mem[] : memref