[MLIR][XeGPU] Refactor isEvenlyDistributable() to Layout attribute interface#191945

Merged: Jianhui-Li merged 3 commits into llvm:main from Jianhui-Li:users/Jianhui-Li/XeGPU/RefactorIsEvenlyDistributable on Apr 15, 2026

@Jianhui-Li
Contributor

This PR refactors isEvenlyDistributable() into the layout attribute interface method isDistributable() and uses it in all anchor operations to check that the shape can be distributed with the anchor layout.
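
For readers who want to try the check outside of MLIR, here is a minimal standalone sketch of the three-phase logic that the isDistributable() method in this patch follows. It is an approximation under stated assumptions, not the merged code: `Shape`, `Layout`, and `distributeOneLevel` are hypothetical stand-ins (std::vector and std::optional replace SmallVector and FailureOr), an empty vector models an absent layout field, and a rank mismatch is simply treated as failure.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical flattened view of a layout; an empty vector means "absent".
struct Layout {
  Shape sgLayout, sgData;     // workgroup -> subgroup
  Shape instData;             // subgroup -> instruction tile
  Shape laneLayout, laneData; // instruction tile -> lane
};

// Stand-in for computeDistributedShape at one level; std::nullopt plays the
// role of failure(). Treating a rank mismatch as failure is an assumption.
std::optional<Shape> distributeOneLevel(const Shape &shape, const Shape &layout,
                                        const Shape &data) {
  if (data.empty() || data.size() != shape.size() ||
      layout.size() != shape.size())
    return std::nullopt;
  Shape out(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i) {
    assert(layout[i] > 0 && data[i] > 0 && "factors must be positive");
    if (shape[i] % (layout[i] * data[i]) == 0)
      out[i] = shape[i] / layout[i]; // even case: split across the layout
    else if (shape[i] == data[i])
      out[i] = shape[i];             // wrap-around case: every unit sees the tile
    else
      return std::nullopt;
  }
  return out;
}

bool isDistributable(Shape shape, const Layout &l) {
  // Phase 1: distribute across subgroups (sg_layout + sg_data).
  if (!l.sgLayout.empty()) {
    auto sgShape = distributeOneLevel(shape, l.sgLayout, l.sgData);
    if (!sgShape)
      return false;
    shape = *sgShape;
    if (l.instData.empty() && l.laneLayout.empty())
      return true; // nothing left once the subgroup level is dropped
  }
  // Phase 2: the subgroup shape must be a multiple of inst_data.
  if (!l.instData.empty()) {
    if (l.instData.size() != shape.size())
      return false;
    for (std::size_t i = 0; i < shape.size(); ++i) {
      assert(l.instData[i] > 0 && "inst_data must be positive");
      if (shape[i] % l.instData[i] != 0)
        return false;
    }
    shape = l.instData; // inst_data becomes the shape for the lane phase
    if (l.laneLayout.empty())
      return true;
  }
  // Phase 3: distribute across lanes (lane_layout + lane_data).
  return distributeOneLevel(shape, l.laneLayout, l.laneData).has_value();
}
```

Under this sketch, a 256x128 shape with sg_layout = [4, 8] is accepted with sg_data = [64, 16] and rejected with sg_data = [64, 32], since 128 is neither a multiple of 8 * 32 nor equal to 32. That appears consistent with the updated expectations in propagate-layout-subgroup.mlir in the patch.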

@llvmbot
Member

llvmbot commented Apr 14, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

Changes

This PR refactors isEvenlyDistributable() into the layout attribute interface method isDistributable() and uses it in all anchor operations to check that the shape can be distributed with the anchor layout.


Patch is 27.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/191945.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+53-7)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td (-4)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-84)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+73-3)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+1-1)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (-8)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir (+8-8)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops.mlir (+11-11)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f8a2beabb9b95..a7bea9881602f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -188,6 +188,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Check the availability of subgroup level layouts",
                     "bool",
                     "isForSubgroup">,
+    InterfaceMethod<"Check the availability of lane level layouts",
+                    "bool",
+                    "isForLane">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
@@ -299,26 +302,60 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       } else {
                         return failure();
                       }
-                      assert(
-                          !subShape.empty() &&
-                          "sgdata or lanedata cannot be empty for distributed shape computation");
+                      // sgdata or lanedata cannot be empty for distributed shape computation
+                      if (subShape.empty())
+                        return failure();
                       SmallVector<int64_t> distributedShape(shape.size());
                       for (auto [i, dim] : llvm::enumerate(shape)) {
                         int64_t distriUnit = layout[i]*subShape[i];
                         if ((dim % distriUnit) == 0) {
                           // Evenly divisible case, divide the dimension by the layout factor.
                           distributedShape[i] = dim / layout[i];
-                          assert((distributedShape[i] % subShape[i] == 0) &&
-                                "Even distribution: sgdata or lanedata must divide the distributed dimension");
+                          if (distributedShape[i] % subShape[i] != 0)
+                            return failure();
                         } else {
                           // wrap around case, the dimension size must be equal to subShape value
-                          assert(dim == subShape[i] &&
-                                "Wrap-around distribution: sgdata or lanedata must be same as tensor tile shape");
+                          if(dim != subShape[i])
+                            return failure();
                           distributedShape[i] = dim;
                         }
                       }
                       return distributedShape;
                     }]>,
+    InterfaceMethod<[{Checks if the given shape can be distributed by the layout}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isDistributable",
+                    /*args=*/(ins "SmallVector<int64_t>":$shape),
+                    /*methodBody=*/[{
+                        DistributeLayoutAttr curLayoutAttr = $_self;
+                        SmallVector<int64_t> curShape = shape;
+                        // Phase 1: Distribute across subgroups (sg_layout + sg_data).
+                        if (curLayoutAttr.isForWorkgroup()) {
+                          auto maybeSgShape = curLayoutAttr.computeDistributedShape(curShape);
+                          if (failed(maybeSgShape))
+                            return false;
+                          curShape = maybeSgShape.value();
+                          curLayoutAttr = curLayoutAttr.dropSgLayoutAndData();
+                          if (!curLayoutAttr)
+                            return true;
+                        }
+                        // Phase 2: Distribute across instruction data (inst_data).
+                        if (curLayoutAttr.isForSubgroup() && !curLayoutAttr.isForLane()) {
+                          SmallVector<int64_t> instData = curLayoutAttr.getEffectiveInstDataAsInt();
+                          for (size_t i = 0; i < curShape.size(); ++i) {
+                            if (curShape[i] % instData[i] != 0)
+                              return false;
+                          }
+                          // inst_data becomes the new shape for next phase
+                          curShape = instData;
+                          curLayoutAttr = curLayoutAttr.dropInstData();
+                          if (!curLayoutAttr)
+                            return true;
+                        }
+                        // Phase 3: Distribute across lanes (lane_layout + lane_data).
+                        auto maybeLaneShape = curLayoutAttr.computeDistributedShape(curShape);
+                        return succeeded(maybeLaneShape);
+                      }]>,
     InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isSliceOf",
@@ -487,6 +524,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       return !isForWorkgroup();
     }
 
+    bool isForLane() {
+      return !isForWorkgroup() && (getInstData() == nullptr);
+    }
+
     int64_t getRank() const {
       if (auto attr = getSgLayout())
         return attr.size();
@@ -687,6 +728,11 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       return parent.isForSubgroup();
     }
 
+    bool isForLane() const {
+      auto parent = dyn_cast<LayoutAttr>(getParent());
+      return parent.isForLane();
+    }
+
     /// Returns the SgLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
     SmallVector<int64_t> getEffectiveSgLayoutAsInt() const {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index c173b93face98..84fd8f9e0060c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -38,10 +38,6 @@ def XeGPU_Dialect : Dialect {
     let useDefaultAttributePrinterParser = true;
 
     let extraClassDeclaration = [{
-      /// Checks if the given shape can be evenly distributed based on the layout
-      /// and data factors provided by the LayoutAttr.
-      static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
-
       /// drops/slices the shape in the specified dims, and return the rest. e.g.,
       /// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
       template<typename T, typename U>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eaa43c02946d8..80a3fc91f1c4f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -121,74 +121,6 @@ static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
   return coordinates;
 }
 
-// Checks if the given shape can be evenly distributed based on the layout
-// and data factors provided by the LayoutAttr.
-bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
-                                         xegpu::DistributeLayoutAttr attr) {
-  assert(attr && "Layout attribute is missing.");
-
-  // Checks whether the given shape can be evenly distributed using the
-  // specified layout and data attributes. If successful, it returns the work
-  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
-  // size per compute unit is calculated as follows:
-  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
-  //   - If `data` is not null: newShape[i] = data[i]
-  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
-  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
-  // share the data.
-  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
-                           SmallVector<int64_t> layout,
-                           SmallVector<int64_t> data,
-                           bool rr = true) -> optional<SmallVector<int64_t>> {
-    llvm::SmallVector<int64_t> newShape(shape);
-    if (layout.size()) {
-      if (layout.size() != shape.size())
-        return std::nullopt;
-      auto ratio = computeShapeRatio(shape, layout);
-      if (ratio.has_value()) {
-        newShape = ratio.value();
-      } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
-        return std::nullopt;
-      }
-      // Round-robin case: continue with original newShape
-    }
-
-    if (data.size()) {
-      if (data.size() != shape.size())
-        return std::nullopt;
-      auto ratio = computeShapeRatio(newShape, data);
-      if (!ratio.has_value() && rr)
-        ratio = computeShapeRatio(data, newShape);
-      if (!ratio.has_value())
-        return std::nullopt;
-
-      // if data is not null, we always return it for next phase.
-      newShape = data;
-    }
-    return newShape;
-  };
-
-  // check the sgLayout and sgData
-  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
-                                    attr.getEffectiveSgDataAsInt());
-  if (!maybeSgShape)
-    return false;
-  auto sgShape = maybeSgShape.value();
-
-  // check InstData, it neither have layout nor need round-robin
-  auto maybeInstShape =
-      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
-  if (!maybeInstShape)
-    return false;
-  auto instShape = maybeInstShape.value();
-
-  // check LaneLayout and LaneData
-  auto maybeLaneShape =
-      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
-                    attr.getEffectiveLaneDataAsInt());
-  return maybeLaneShape.has_value();
-}
-
 //===----------------------------------------------------------------------===//
 // XeGPU_BlockTensorDescAttr
 //===----------------------------------------------------------------------===//
@@ -1431,25 +1363,12 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                            << chunkAlignmentFactor;
     }
   }
-
-  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
-  if (layoutAttr) {
+  if (auto layoutAttr =
+          mlir::dyn_cast_if_present<DistributeLayoutAttr>(layout)) {
     if (rank != (size_t)layoutAttr.getRank())
       return emitError() << "expected layout rank to match tensor rank";
 
-    auto laneData = layoutAttr.getLaneData();
-    if (scatterAttr && laneData) {
-      // Validate subgroup mapping rules for scattered tensors.
-      // if chunkSize > 1, the last dimension of the tensor should
-      // be distributed in the units divisible by chunkAlignmentFactor.
-      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
-      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
-        return emitError()
-               << "expected last dim of lane_data to be a multiple of: "
-               << chunkAlignmentFactor;
-    }
-
-    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
+    if (!layoutAttr.isDistributable(SmallVector<int64_t>(shape))) {
       std::string shapeStr;
       llvm::raw_string_ostream stream(shapeStr);
       llvm::interleaveComma(shape, stream);
@@ -1457,6 +1376,7 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                          << layoutAttr;
     }
   }
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 5697097a4c999..9107cda30a8fa 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -218,6 +218,11 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
       }
     }
   }
+
+  if (layout && !layout.isDistributable(
+                    SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
+    return emitError() << "Value shape is not distributable with the layout";
+
   if (dataShape.size() == 2) {
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -504,6 +509,14 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto tdescShape = getShapeOf(tdescTy);
+    if (!layout.isDistributable(tdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -628,6 +641,14 @@ LogicalResult LoadNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto origTdescShape = getShapeOf(tdescTy);
+    if (!layout.isDistributable(origTdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -721,6 +742,13 @@ LogicalResult StoreNdOp::verify() {
     return emitOpError(
         "Mismatched ranks between offsets and tensor descriptor");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    if (!layout.isDistributable(tdescShape))
+      return emitOpError(
+          "TensorDesc shape is not distributable with the layout");
+  }
+
   return success();
 }
 
@@ -823,6 +851,19 @@ LogicalResult PrefetchOp::verify() {
   if (getOffsetAlignByteAttr() && !srcTy.isInteger())
     return emitOpError("offset_align_byte only allowed with integer source.");
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    // get the offset operand and its shape
+    if (auto offsets = getOffsets()) {
+      auto offsetsTy = offsets.getType();
+      if (!llvm::isa<VectorType>(offsetsTy))
+        return emitOpError("Offsets should be a vector.");
+      auto offsetShape = getShapeOf(offsetsTy);
+      if (!layout.isDistributable(offsetShape))
+        return emitOpError("offset shape is not distributable with the layout");
+    }
+  }
+
   return success();
 }
 
@@ -870,6 +911,13 @@ LogicalResult LoadGatherOp::verify() {
   if (memTy && (getElementType() != memTy.getElementType()))
     return emitError() << "Value should have the same element type as MemRef.";
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto valShape = getShapeOf(valueTy);
+    if (!layout.isDistributable(valShape))
+      return emitOpError("Value shape is not distributable with the layout");
+  }
+
   auto offsetsTy = getOffsets().getType();
   return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
@@ -954,6 +1002,13 @@ LogicalResult StoreScatterOp::verify() {
   if (memTy && (getElementType() != memTy.getElementType()))
     return emitError() << "Value should have the same element type as MemRef.";
 
+  if (getAnchorLayout()) {
+    auto layout = getAnchorLayout();
+    auto valShape = getShapeOf(valueTy);
+    if (!layout.isDistributable(valShape))
+      return emitOpError("Value shape is not distributable with the layout");
+  }
+
   auto offsetsTy = getOffsets().getType();
   return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
@@ -1052,6 +1107,21 @@ LogicalResult DpasOp::verify() {
   auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
+  if (auto cdLayout = getLayoutCd())
+    if (!cdLayout->isDistributable(
+            SmallVector<int64_t>(resShape.begin(), resShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
+  if (auto aLayout = getLayoutA())
+    if (!aLayout->isDistributable(
+            SmallVector<int64_t>(lhsShape.begin(), lhsShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
+  if (auto bLayout = getLayoutB())
+    if (!bLayout->isDistributable(
+            SmallVector<int64_t>(rhsShape.begin(), rhsShape.end())))
+      return emitOpError("Value shape is not distributable with the layout");
+
   if (getAcc() && getAcc().getType() != getResultType())
     return emitOpError("Expecting the acc type to be the same as result.");
 
@@ -1103,12 +1173,12 @@ LogicalResult ConvertLayoutOp::verify() {
 
   Type srcType = getSource().getType();
   if (llvm::isa<VectorType>(srcType)) {
-    auto shape = llvm::cast<VectorType>(srcType).getShape();
-    if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
+    SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
+    if (!srcLayout.isDistributable(shape))
       return emitOpError(
           "invalid input layout, data cannot be evenly distributed.");
 
-    if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
+    if (!resLayout.isDistributable(shape))
       return emitOpError(
           "invalid target layout, data cannot be evenly distributed.");
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a095c19d66c15..d637b6828deab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -498,7 +498,7 @@ struct WgToSgVectorBroadcastOp
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
-    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+    if (!layout.isDistributable(SmallVector<int64_t>(wgShape)))
       return failure();
 
     SmallVector<Value> newBroadcastOps;
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7390b47b3f8d9..82c7879c79d56 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -325,14 +325,6 @@ func.func @create_tdesc_layout_1(%src: ui64) {
   return
 }
 
-// -----
-func.func @create_tdesc_layout_2(%src: ui64) {
-  %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // expected-error@+1 {{expected last dim of lane_data to be a multiple of: 2}}
-  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr<chunk_size = 4>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-  return
-}
-
 // -----
 func.func @load_gather_simt_1(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 831d1e05967f8..62426d619445b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -27,24 +27,24 @@ gpu.module @test {
   // CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
   func.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) {
     // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>>
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>>
     // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>>
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>>
 
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>}>
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>> -> vector<256x128xf32>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>}>
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>> -> vector<256x128xf32>
 
     // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %2, [1, 0]
-    // CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
+    // CHECK-SAME {layout_resu...
[truncated]

Comment thread mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp Outdated
if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";

if (getAnchorLayout()) {

if (auto layout = getAnchorLayout()) instead?

Contributor

+1. Also, isDistributable takes a SmallVector<int64_t>; why not feed getShapeOf(valueTy) directly into it, or change the signature to take ArrayRef?

Contributor Author

modified
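
The form suggested at the top of this thread declares the layout inside the if condition, so the lookup happens once and the variable is scoped to the check. A minimal sketch of the before and after shapes follows. `FakeLayout`, `FakeOp`, and both verify functions are hypothetical stand-ins; std::optional plays the role that a nullable MLIR attribute plays in the real verifiers.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical one-dimensional layout: a dimension is distributable when it
// is a multiple of the lane count.
struct FakeLayout {
  int64_t lanes;
  bool isDistributable(int64_t dim) const {
    assert(lanes > 0 && "lane count must be positive");
    return dim % lanes == 0;
  }
};

// Hypothetical op with an optional anchor layout.
struct FakeOp {
  std::optional<FakeLayout> anchor;
  int64_t dim;
  std::optional<FakeLayout> getAnchorLayout() const { return anchor; }
};

// Shape of the code as first posted: the getter is called twice.
std::string verifyTwoLookups(const FakeOp &op) {
  if (op.getAnchorLayout()) {
    auto layout = op.getAnchorLayout();
    if (!layout->isDistributable(op.dim))
      return "shape is not distributable with the layout";
  }
  return "";
}

// Shape of the suggested code: declare and test in one condition.
std::string verifyOneLookup(const FakeOp &op) {
  if (auto layout = op.getAnchorLayout()) {
    if (!layout->isDistributable(op.dim))
      return "shape is not distributable with the layout";
  }
  return "";
}
```

Both functions return the same result for every input; the second only removes the repeated call and keeps `layout` out of the enclosing scope.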

Contributor Author

"Changing the signature in favor of ArrayRef" is a good suggestion. Most of these interface functions need to be revisited to use ArrayRef instead of SmallVector. For this one, I tried to change it, but it calls other functions that take SmallVector, so it is not a big gain.
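
To illustrate the trade-off discussed here, the sketch below contrasts a function that takes a non-owning view with one that takes an owning vector by value. `ShapeView` is a hypothetical, deliberately tiny stand-in for the role llvm::ArrayRef plays; `isMultipleOf` and `isMultipleOfOwning` are illustrative names and not part of the XeGPU API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical minimal non-owning view over contiguous int64_t values.
// The viewed storage must outlive the view.
class ShapeView {
public:
  ShapeView(const std::vector<int64_t> &v) : ptr(v.data()), len(v.size()) {}
  ShapeView(const int64_t *data, std::size_t size) : ptr(data), len(size) {}
  std::size_t size() const { return len; }
  int64_t operator[](std::size_t i) const {
    assert(i < len && "index out of range");
    return ptr[i];
  }

private:
  const int64_t *ptr;
  std::size_t len;
};

// View-based signature: callers pass existing storage without copying it.
bool isMultipleOf(ShapeView shape, ShapeView tile) {
  if (shape.size() != tile.size())
    return false;
  for (std::size_t i = 0; i < shape.size(); ++i)
    if (tile[i] <= 0 || shape[i] % tile[i] != 0)
      return false;
  return true;
}

// Owning signature: each call copies its arguments unless the caller moves.
bool isMultipleOfOwning(std::vector<int64_t> shape, std::vector<int64_t> tile) {
  return isMultipleOf(shape, tile);
}
```

A view-based signature only saves work when the functions it calls also accept views. If a callee still requires an owning vector, the copy just moves one level down, which matches the author's remark that the gain is small for this method.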

}

bool isForLane() const {
auto parent = dyn_cast<LayoutAttr>(getParent());
Contributor

Don't we need a flatten here?
SliceAttr attr = flatten()

Contributor Author

No need to flatten().

// Phase 2: Distribute across instruction data (inst_data).
if (curLayoutAttr.isForSubgroup() && !curLayoutAttr.isForLane()) {
SmallVector<int64_t> instData = curLayoutAttr.getEffectiveInstDataAsInt();
for (size_t i = 0; i < curShape.size(); ++i) {
Contributor

@akroviakov akroviakov Apr 15, 2026

Can we distribute progressively, as in a while loop of computeDistributedShape that unravels the shape and removes the "outermost" layout? This would contain the distribution logic in one place.

Contributor Author

My initial implementation tried to use a while loop, but it didn't pay off and was hard to debug, so I would leave it as is.
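
For reference, the loop-based alternative raised in this thread could look roughly like the sketch below. This is a hypothetical illustration of the reviewer's idea, not code from the PR, which keeps the three explicit phases. `Level`, `distributeOneLevel`, and `isDistributableLoop` are made-up names, std::optional stands in for FailureOr, and a rank mismatch is assumed to be a failure.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// One distribution level, listed outermost first. An empty `layout` marks a
// tile-only level, which is the role inst_data plays in the patch.
struct Level {
  Shape layout;
  Shape data;
};

// Per-dimension rule modeled on computeDistributedShape in the diff.
std::optional<Shape> distributeOneLevel(const Shape &shape, const Shape &layout,
                                        const Shape &data) {
  if (data.empty() || data.size() != shape.size() ||
      layout.size() != shape.size())
    return std::nullopt;
  Shape out(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i) {
    assert(layout[i] > 0 && data[i] > 0 && "factors must be positive");
    if (shape[i] % (layout[i] * data[i]) == 0)
      out[i] = shape[i] / layout[i]; // even case
    else if (shape[i] == data[i])
      out[i] = shape[i];             // wrap-around case
    else
      return std::nullopt;
  }
  return out;
}

// Peel levels off one at a time, feeding each result into the next level.
bool isDistributableLoop(Shape shape, const std::vector<Level> &levels) {
  std::size_t next = 0;
  while (next < levels.size()) {
    const Level &level = levels[next++];
    if (level.layout.empty()) {
      // Tile-only level: the shape must be a multiple of the tile, and the
      // tile itself becomes the shape seen by the next level.
      if (level.data.size() != shape.size())
        return false;
      for (std::size_t i = 0; i < shape.size(); ++i)
        if (level.data[i] <= 0 || shape[i] % level.data[i] != 0)
          return false;
      shape = level.data;
      continue;
    }
    auto sub = distributeOneLevel(shape, level.layout, level.data);
    if (!sub)
      return false;
    shape = *sub;
  }
  return true;
}
```

Even in this small sketch the tile-only level needs its own branch, because it replaces the shape with the tile instead of dividing it. That special case gives some sense of why a single uniform loop may not simplify the logic much.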

Contributor

@akroviakov akroviakov left a comment

Good addition to the layout API

@Jianhui-Li Jianhui-Li merged commit 4df814a into llvm:main Apr 15, 2026
10 checks passed