
Conversation

@clementval
Contributor

Shared memory used by TMA operations needs to be 16-byte aligned. Add the ability to set an alignment on the cuf.shared_memory operation.
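After this change the operation can print an explicit alignment in its assembly syntax. The shape below mirrors one of the updated lowering tests in this PR:

```mlir
// Shared-memory allocation carrying the 16-byte alignment that TMA requires.
%tmp = cuf.shared_memory !fir.array<1024xf32> align 16 {bindc_name = "tmp", uniq_name = "_QFtest_tma_bulk_load_r4Etmp"} -> !fir.ref<!fir.array<1024xf32>>
```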

@clementval clementval requested a review from wangzpgi December 2, 2025 21:50
@llvmbot llvmbot added the flang and flang:fir-hlfir labels Dec 2, 2025
@llvmbot
Member

llvmbot commented Dec 2, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Shared memory used by TMA operations needs to be 16-byte aligned. Add the ability to set an alignment on the cuf.shared_memory operation.


Full diff: https://github.com/llvm/llvm-project/pull/170372.diff

4 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td (+3-4)
  • (modified) flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp (+10)
  • (modified) flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp (+1-1)
  • (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+7)
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 07bb47e26b968..3fda523acb382 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -350,15 +350,14 @@ def cuf_SharedMemoryOp
   let arguments = (ins TypeAttr:$in_type, OptionalAttr<StrAttr>:$uniq_name,
       OptionalAttr<StrAttr>:$bindc_name, Variadic<AnyIntegerType>:$typeparams,
       Variadic<AnyIntegerType>:$shape,
-      Optional<AnyIntegerType>:$offset // offset in bytes from the shared memory
-                                       // base address.
-  );
+      // offset in bytes from the shared memory base address.
+      Optional<AnyIntegerType>:$offset, OptionalAttr<I64Attr>:$alignment);
 
   let results = (outs fir_ReferenceType:$ptr);
 
   let assemblyFormat = [{
       (`[` $offset^ `:` type($offset) `]`)? $in_type (`(` $typeparams^ `:` type($typeparams) `)`)?
-        (`,` $shape^ `:` type($shape) )?  attr-dict `->` qualified(type($ptr))
+        (`,` $shape^ `:` type($shape) )?  (`align` $alignment^ )? attr-dict `->` qualified(type($ptr))
   }];
 
   let builders = [OpBuilder<(ins "mlir::Type":$inType,
diff --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
index 270037f5fcb00..67af481cec31a 100644
--- a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
@@ -17,6 +17,8 @@
 #include "flang/Evaluate/common.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/MutableBox.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -1489,6 +1491,13 @@ void CUDAIntrinsicLibrary::genTMABulkG2S(
       builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
 }
 
+static void setAlignment(mlir::Value ptr, unsigned alignment) {
+  if (auto declareOp = mlir::dyn_cast<hlfir::DeclareOp>(ptr.getDefiningOp()))
+    if (auto sharedOp = mlir::dyn_cast<cuf::SharedMemoryOp>(
+            declareOp.getMemref().getDefiningOp()))
+      sharedOp.setAlignment(alignment);
+}
+
 static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
                            mlir::Value barrier, mlir::Value src,
                            mlir::Value dst, mlir::Value nelem,
@@ -1496,6 +1505,7 @@ static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
   auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
   barrier = builder.createConvert(loc, llvmPtrTy, barrier);
+  setAlignment(dst, 16);
   dst = builder.createConvert(loc, llvmPtrTy, dst);
   src = builder.createConvert(loc, llvmPtrTy, src);
   mlir::NVVM::InlinePtxOp::create(
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 687007d957225..671e5f9455c22 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -333,7 +333,7 @@ void cuf::SharedMemoryOp::build(
       bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
   build(builder, result, wrapAllocaResultType(inType),
         mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
-        /*offset=*/mlir::Value{});
+        /*offset=*/mlir::Value{}, /*alignment=*/mlir::IntegerAttr{});
   result.addAttributes(attributes);
 }
 
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 434322ea22265..7f350944d70f6 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -538,6 +538,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c4
 ! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
 ! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.shared_memory !fir.array<1024xcomplex<f32>> align 16 {bindc_name = "tmp", uniq_name = "_QFtest_tma_bulk_load_c4Etmp"} -> !fir.ref<!fir.array<1024xcomplex<f32>>>
 ! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
@@ -557,6 +558,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c8
 ! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
 ! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.shared_memory !fir.array<1024xcomplex<f64>> align 16 {bindc_name = "tmp", uniq_name = "_QFtest_tma_bulk_load_c8Etmp"} -> !fir.ref<!fir.array<1024xcomplex<f64>>>
 ! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 16 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
@@ -576,6 +578,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i4
 ! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
 ! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.shared_memory !fir.array<1024xi32> align 16 {bindc_name = "tmp", uniq_name = "_QFtest_tma_bulk_load_i4Etmp"} -> !fir.ref<!fir.array<1024xi32>>
 ! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
@@ -595,6 +598,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i8
 ! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
 ! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.shared_memory !fir.array<1024xi64> align 16 {bindc_name = "tmp", uniq_name = "_QFtest_tma_bulk_load_i8Etmp"} -> !fir.ref<!fir.array<1024xi64>>
 ! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
@@ -614,6 +618,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r2
 ! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r2Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
 ! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r2Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.shared_memory !fir.array<1024xf16> align 16 {bindc_name = "tmp", uniq_name = "_QFtest_tma_bulk_load_r2Etmp"} -> !fir.ref<!fir.array<1024xf16>>
 ! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 2 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
@@ -633,6 +638,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r4
 ! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
 ! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.shared_memory !fir.array<1024xf32> align 16 {bindc_name = "tmp", uniq_name = "_QFtest_tma_bulk_load_r4Etmp"} -> !fir.ref<!fir.array<1024xf32>>
 ! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
@@ -652,6 +658,7 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r8
 ! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
 ! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: cuf.shared_memory !fir.array<1024xf64> align 16 {bindc_name = "tmp", uniq_name = "_QFtest_tma_bulk_load_r8Etmp"} -> !fir.ref<!fir.array<1024xf64>>
 ! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
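
Since the new `align` clause sits in the declarative assembly format alongside the optional offset operand, both can appear on the same op once a later pass materializes shared-memory offsets. A hedged sketch under that assumption (the SSA names and uniq_name are illustrative, not taken from this PR's tests):

```mlir
// Hypothetical op after an offset has been assigned; per the updated
// assembly format, `align 16` prints after the shape and before the attr-dict.
%c0 = arith.constant 0 : i32
%tmp = cuf.shared_memory[%c0 : i32] !fir.array<1024xi64> align 16 {uniq_name = "_QFtestEtmp"} -> !fir.ref<!fir.array<1024xi64>>
```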

@clementval clementval enabled auto-merge (squash) December 2, 2025 22:00
@clementval clementval merged commit d3256d9 into llvm:main Dec 2, 2025
13 checks passed
@clementval clementval deleted the cuf_shared_op_align branch December 2, 2025 22:13