diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 88c2eb3326d96..18e8270f5aa99 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -89,7 +90,22 @@ static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
     auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
     if (!stridedLayout)
       return failure();
-    mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
+    MemRefLayoutAttrInterface newLayout =
+        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
+    // Special case: if resetting the offset causes the strided layout to become
+    // the identity layout, then reset to the identity layout.
+    // TODO: this'll get a lot simpler when we have the contiguous layout.
+    SmallVector<int64_t> stridesIfIdentity;
+    if (source.hasStaticShape()) {
+      stridesIfIdentity = computeSuffixProduct(source.getShape());
+    } else if (source.getRank() <= 1) {
+      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
+    }
+    if (stridesIfIdentity == stridedLayout.getStrides()) {
+      newLayout = AffineMapAttr::get(
+          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
+    }
+    mb.setLayout(newLayout);
   }
   return (MemRefType)(mb);
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 8871b2ce0eadb..cc1162d8b0de8 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -68,7 +68,7 @@ func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1],
 }
 
 // CHECK-LABEL: func @fat_raw_buffer_cast_reset_offset
-func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
   // CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<?xi32, strided<[1], offset: ?>, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK-DAG: %[[memRefPtr:.*]] = llvm.extractvalue %[[desc]][1]
   // CHECK-DAG: %[[memRefOff:.*]] = llvm.extractvalue %[[desc]][2]
@@ -77,8 +77,8 @@ func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], of
-  %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
-  return %ret : memref<?xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+  %ret = amdgpu.fat_raw_buffer_cast %buf resetOffset : memref<?xi32, strided<[1], offset: ?>, #gpu_global_addrspace> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+  return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
 }
 
 // CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 5559ac8f1a5c3..fe2b32be04de4 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -360,10 +360,54 @@ func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.
 // CHECK-SAME: cacheSwizzleStride(%{{[^)]*}})
 // CHECK-SAME: boundsCheck(false)
 // CHECK-SAME: resetOffset
-func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>> {
+func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
   %ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
-    : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
-  func.return %ret : memref<8xi32, strided<[1]>, #amdgpu.address_space<fat_raw_buffer>>
+    : memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_1d_reset_offset
+// CHECK: amdgpu.fat_raw_buffer_cast
+func.func @fat_raw_buffer_cast_dynamic_1d_reset_offset(%m: memref<?xi32, strided<[1], offset: ?>>) -> memref<?xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<?xi32, strided<[1], offset: ?>> to memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<?xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_0d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_dynamic_0d_reset_offset(%m: memref<i32, strided<[], offset: ?>>) -> memref<i32, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<i32, strided<[], offset: ?>> to memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<i32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_static_shape_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_static_shape_2d_reset_offset(%m: memref<4x4xi32, strided<[4, 1], offset: ?>>) -> memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<4x4xi32, strided<[4, 1], offset: ?>> to memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<4x4xi32, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_dynamic_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_dynamic_2d_reset_offset(%m: memref<?x?xi32, strided<[?, 1], offset: ?>>) -> memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<?x?xi32, strided<[?, 1], offset: ?>> to memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<?x?xi32, strided<[?, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+}
+
+// CHECK-LABEL: func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset
+// CHECK: %[[ret:.+]] = amdgpu.fat_raw_buffer_cast
+// CHECK: return %[[ret]]
+func.func @fat_raw_buffer_cast_noncontiguous_2d_reset_offset(%m: memref<4x4xi32, strided<[8, 1], offset: ?>>) -> memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>> {
+  %ret = amdgpu.fat_raw_buffer_cast %m resetOffset
+    : memref<4x4xi32, strided<[8, 1], offset: ?>> to memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
+  func.return %ret : memref<4x4xi32, strided<[8, 1]>, #amdgpu.address_space<fat_raw_buffer>>
 }
 
 // CHECK-LABEL: func @raw_buffer_load_f32_from_rank_1
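
For reference, a minimal standalone sketch (not part of the patch) of the identity-strides check that the AMDGPUDialect.cpp hunk above introduces. The local `suffixProduct` and `foldsToIdentity` helpers are stand-ins for `mlir::computeSuffixProduct` and the new logic in `getFatRawBufferTypeLike`; the sample shapes and strides mirror the ops.mlir test cases.

// Sketch only: re-creates the "do the strides already describe a contiguous,
// row-major buffer?" check outside of MLIR, assuming only the standard library.
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major strides of a static shape: the suffix product of its dimensions.
static std::vector<int64_t> suffixProduct(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Mirrors the patch: identity is provable for static shapes whose strides equal
// the suffix product, and for rank <= 1 memrefs (whose strides must be [1] or []).
static bool foldsToIdentity(const std::vector<int64_t> &shape, bool staticShape,
                            const std::vector<int64_t> &strides) {
  std::vector<int64_t> stridesIfIdentity;
  if (staticShape)
    stridesIfIdentity = suffixProduct(shape);
  else if (shape.size() <= 1)
    stridesIfIdentity.assign(shape.size(), 1);
  return stridesIfIdentity == strides;
}

int main() {
  std::cout << foldsToIdentity({4, 4}, true, {4, 1}) << "\n";     // 1: contiguous 4x4
  std::cout << foldsToIdentity({4, 4}, true, {8, 1}) << "\n";     // 0: padded rows
  std::cout << foldsToIdentity({-1}, false, {1}) << "\n";         // 1: dynamic rank-1
  std::cout << foldsToIdentity({-1, -1}, false, {-1, 1}) << "\n"; // 0: dynamic rank-2
  return 0;
}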