Skip to content

Commit

Permalink
[CUDA/ROCM] Add byte offset to the base when lowering binding (#6193)
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasRaoux committed Jun 11, 2021
1 parent d18d794 commit 845b8bc
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
13 changes: 13 additions & 0 deletions iree/compiler/Conversion/LinalgToLLVMGPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ class ConvertIREEBindingOp : public ConvertToLLVMPattern {
ireeBindingOp.getResult().getType().dyn_cast<MemRefType>();
uint64_t binding = ireeBindingOp.queryBindingOp().binding().getZExtValue();
Value llvmBufferBasePtr = llvmFuncOp.getArgument(argMapping[binding]);
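// Add the byte offset: bitcast the base pointer to an i8 pointer in the same
// address space, GEP by the binding's byte offset, then bitcast the result
// back to the original pointer type.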
Value llvmBufferBasei8Ptr = rewriter.create<LLVM::BitcastOp>(
loc,
LLVM::LLVMPointerType::get(rewriter.getIntegerType(8),
llvmBufferBasePtr.getType()
.cast<LLVM::LLVMPointerType>()
.getAddressSpace()),
llvmBufferBasePtr);
llvmBufferBasei8Ptr = rewriter.create<LLVM::GEPOp>(
loc, llvmBufferBasei8Ptr.getType(), llvmBufferBasei8Ptr,
adaptor.byte_offset());
llvmBufferBasePtr = rewriter.create<LLVM::BitcastOp>(
loc, llvmBufferBasePtr.getType(), llvmBufferBasei8Ptr);
if (memrefType.hasStaticShape()) {
auto desc = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), memrefType, llvmBufferBasePtr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
// Test that standard and GPU ops are converted to LLVM and NVVM.
func @abs_ex_dispatch_0() {
%c0 = constant 0 : index
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<16xf32>
%c128 = constant 128 : index
%0 = hal.interface.binding.subspan @io::@arg0[%c128] : memref<16xf32>
%1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<16xi32>
%2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<16xf32>
%3 = "gpu.block_id"() {dimension = "x"} : () -> index
Expand All @@ -25,6 +26,12 @@ hal.interface @io attributes {sym_visibility = "private"} {
}

// CHECK-LABEL: llvm.func @abs_ex_dispatch_0
// CHECK-SAME: (%{{.*}}: !llvm.ptr<i32>, %{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>)
// CHECK-SAME: (%[[ARG0:.+]]: !llvm.ptr<i32>, %[[ARG1:.+]]: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>)
// CHECK: %[[C128:.+]] = llvm.mlir.constant(128 : index) : i64
// CHECK: %[[PTRI8:.+]] = llvm.bitcast %[[ARG1]] : !llvm.ptr<f32> to !llvm.ptr<i8>
// CHECK: %[[OFF:.+]] = llvm.getelementptr %[[PTRI8]][%[[C128]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
// CHECK: %[[PTR:.+]] = llvm.bitcast %[[OFF]] : !llvm.ptr<i8> to !llvm.ptr<f32>
// CHECK: llvm.insertvalue %[[PTR]], %{{.*}}[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: llvm.insertvalue %[[PTR]], %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: nvvm.read.ptx.sreg.tid.x
// CHECK: llvm.fadd

0 comments on commit 845b8bc

Please sign in to comment.