[mlir][tosa] Add row_gather_block_scaled op #192272

Merged
psunn merged 1 commit into llvm:main from psunn:row_gather_block_scaled on Apr 16, 2026

@psunn (Contributor) commented Apr 15, 2026

Add tosa.row_gather_block_scaled to the MLIR TOSA dialect, aligned with the current TOSA 1.1 draft spec and the implementation in tosa-tools.

This includes:

  • op definition
  • verifier and shape inference support
  • validation / profile compliance wiring
  • availability and extension handling
  • lit tests for parsing, verification, shape inference, and version / extension gating

Notes

The op supports both spec-defined forms:

  • non-block-scaled: 1 input value tensor, BLOCK_SIZE_1, 1 output
  • block-scaled: data + scale tensor list, non-BLOCK_SIZE_1, 2 outputs

This also tightens existing block-scaled-only ops to reject BLOCK_SIZE_1 now that it is part of the shared enum.
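The two forms amount to a consistency rule between the length of the values tensor list, the number of outputs, and block_size. The following is a minimal standalone sketch of that rule, written to mirror the order of the checks in the verifier added by this patch. It does not use MLIR; `FormError` and `checkRowGatherForm` are hypothetical names introduced only for illustration.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical result codes for this sketch; the real verifier emits
// diagnostics through emitOpError() instead.
enum class FormError {
  None,
  BadValuesCount,      // values list length is not 1 or 2
  OutputCountMismatch, // number of outputs differs from number of values
  RequiresBlockSize1,  // 1 value tensor but block_size != 1
  RejectsBlockSize1,   // 2 value tensors but block_size == 1
};

// Checks that the tensor-list lengths and block_size describe one of the two
// supported forms: (1 tensor, BLOCK_SIZE_1) or (data + scale, non-BLOCK_SIZE_1).
constexpr FormError checkRowGatherForm(std::size_t numValues,
                                       std::size_t numOutputs,
                                       std::uint32_t blockSize) {
  if (numValues == 0 || numValues > 2)
    return FormError::BadValuesCount;
  if (numOutputs != numValues)
    return FormError::OutputCountMismatch;
  if (numValues == 1 && blockSize != 1)
    return FormError::RequiresBlockSize1;
  if (numValues == 2 && blockSize == 1)
    return FormError::RejectsBlockSize1;
  return FormError::None;
}
```

Under these assumptions, `checkRowGatherForm(1, 1, 1)` and `checkRowGatherForm(2, 2, 32)` are the only accepted combinations for the block sizes currently in the enum.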

Op-specific level checks for ROW_GATHER_BLOCK_SCALED have been deferred while the TOSA 1.1 draft is still evolving.

@llvmbot (Member) commented Apr 15, 2026

@llvm/pr-subscribers-mlir-tosa

Author: Peng Sun (psunn)

Changes

Add tosa.row_gather_block_scaled to the MLIR TOSA dialect, aligned with the current TOSA 1.1 draft spec and the implementation in tosa-tools.

This includes:

  • op definition
  • verifier and shape inference support
  • validation / profile compliance wiring
  • availability and extension handling
  • lit tests for parsing, verification, shape inference, and version / extension gating

Notes

The op supports both spec-defined forms:

  • non-block-scaled: 1 input value tensor, BLOCK_SIZE_1, 1 output
  • block-scaled: data + scale tensor list, non-BLOCK_SIZE_1, 2 outputs

This also tightens existing block-scaled-only ops to reject BLOCK_SIZE_1 now that it is part of the shared enum.


Patch is 36.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/192272.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+70)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+38)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+193)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+18-4)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+12)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+10-1)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+8)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+16)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+20)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+10-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-pro-fp-valid.mlir (+9)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10-1)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+36)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index d3e2cd129028e..50bb9f69c6242 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -396,6 +396,25 @@ profileComplianceMap = {
         {{i16T, i64T, i16T}, SpecificationVersion::V_1_1_DRAFT},
         {{i32T, i64T, i32T}, SpecificationVersion::V_1_1_DRAFT}},
        anyOf}}},
+    {"tosa.row_gather_block_scaled",
+     {{{Profile::pro_int},
+       {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp},
+       {{{i8T, i32T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i32T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i32T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp16T, i32T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, i32T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp, Profile::pro_int},
+       {{{boolT, i32T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT},
+        {{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}},
+       anyOf}}},
     {"tosa.scatter",
      {{{Profile::pro_int},
        {{{i8T, i32T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -890,6 +909,57 @@ extensionComplianceMap = {
       {{Extension::bf16, Extension::int64},
        {{{bf16T, i64T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
        allOf}}},
+    {"tosa.row_gather_block_scaled",
+     {{{Extension::fp8e4m3},
+       {{{fp8e4m3T, i32T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e5m2},
+       {{{fp8e5m2T, i32T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::bf16},
+       {{{bf16T, i32T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::int64},
+       {{{i8T, i64T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i16T, i64T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i32T, i64T, i32T, i32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{i64T, i64T, i32T, i64T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp16T, i64T, i32T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, i64T, i32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+        {{boolT, i64T, i32T, boolT}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e4m3, Extension::int64},
+       {{{fp8e4m3T, i64T, i32T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::fp8e5m2, Extension::int64},
+       {{{fp8e5m2T, i64T, i32T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::bf16, Extension::int64},
+       {{{bf16T, i64T, i32T, bf16T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::mxfp},
+       {{{fp4e2m1T, fp8ue8m0T, i32T, i32T, fp4e2m1T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, i32T, i32T, fp6e2m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, i32T, i32T, fp6e3m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, i32T, i32T, fp8e4m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, i32T, i32T, fp8e5m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, i32T, i32T, mxint8T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::mxfp, Extension::int64},
+       {{{fp4e2m1T, fp8ue8m0T, i64T, i32T, fp4e2m1T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, i64T, i32T, fp6e2m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, i64T, i32T, fp6e3m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, i64T, i32T, fp8e4m3T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, i64T, i32T, fp8e5m2T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{mxint8T, fp8ue8m0T, i64T, i32T, mxint8T, fp8ue8m0T},
+         SpecificationVersion::V_1_1_DRAFT}},
+       allOf}}},
     {"tosa.scatter",
      {{{Extension::fp8e4m3},
        {{{fp8e4m3T, i32T, fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1f05aee3e5eec..591073e9985ae 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -484,11 +484,12 @@ def Tosa_RoundingModeAttr
     : Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
                     [Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;
 
+def Tosa_BLOCK_SIZE_1 : I32EnumAttrCase<"BLOCK_SIZE_1", 1>;
 def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;
 
 def Tosa_BlockSizeAttr
-    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
-                    [Tosa_BLOCK_SIZE_32]> {
+    : Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats",
+                       "block_size", [Tosa_BLOCK_SIZE_1, Tosa_BLOCK_SIZE_32]> {
   let extraClassDeclaration = [{
     static uint32_t getBlockSizeValue(BlockSize blockSize) {
       return static_cast<uint32_t>(blockSize);
@@ -496,7 +497,6 @@ def Tosa_BlockSizeAttr
   }];
 }
 
-
 //===----------------------------------------------------------------------===//
 // TOSA Interfaces.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 45d1388a28749..ba750ef4438db 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2487,6 +2487,44 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather", [NoMemoryEffect]> {
       "operands attr-dict `:` functional-type(operands, results)";
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: row_gather_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_RowGatherBlockScaledOp
+    : Tosa_InferShapedTypeOp<"row_gather_block_scaled", [NoMemoryEffect]> {
+  let summary =
+      "Row gather operation for block-scaled and non-block-scaled data.";
+
+  let description = [{
+    Generate a tensor-list which contains a data tensor and an optional scale
+    tensor based on the input indices and row_count. The number of consecutive
+    rows gathered for each index is specified in row_count.
+
+    This operation follows the TOSA 1.1 draft specification and may evolve as
+    the specification is updated.
+
+    This operation is not pure. Undefined behaviour may occur if the specified
+    indices are out of range.
+  }];
+
+  let arguments = (ins Variadic<Tosa_Tensor3D>:$values,
+      Tosa_IndexTensor2D:$indices, Tosa_ScalarInt32Tensor:$row_count,
+      Tosa_BlockSizeAttr:$block_size);
+
+  let results = (outs Variadic<Tosa_Tensor3D>:$output);
+
+  list<Availability> availability = [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+                                     Extension<[Tosa_EXT_FP8E4M3,
+                                                Tosa_EXT_FP8E5M2, Tosa_EXT_BF16,
+                                                Tosa_EXT_MXFP, Tosa_EXT_INT64]>,
+  ];
+
+  let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: scatter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 2754d3b21d4a6..360ab25875bd7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2201,6 +2201,8 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   // Verify C is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   if (ShapedType::isStatic(C) && C % blockSize != 0)
     return emitOpError("expect C to be a multiple of block size, got C=")
            << C << ", block_size=" << blockSize;
@@ -2848,6 +2850,18 @@ static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
   return -1;
 }
 
+static FailureOr<int64_t> getConstantScalarIntValue(Value val) {
+  ElementsAttr attr;
+  if (!matchPattern(val, m_Constant(&attr)))
+    return failure();
+
+  if (!llvm::isa<IntegerType>(attr.getElementType()) ||
+      attr.getNumElements() != 1)
+    return failure();
+
+  return attr.getValues<APInt>()[0].getSExtValue();
+}
+
 template <typename T>
 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
                                      const std::string &operand) {
@@ -3096,6 +3110,48 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RowGatherBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const auto values = adaptor.getValues();
+  if (values.empty() || values.size() > 2)
+    return failure();
+
+  SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
+  const ShapeAdaptor valuesShape(values.front().getType());
+  if (valuesShape.hasRank()) {
+    dataShape[0] = valuesShape.getDimSize(0);
+    dataShape[2] = valuesShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (dataShape[0] == ShapedType::kDynamic)
+      dataShape[0] = indicesShape.getDimSize(0);
+
+    if (auto rowCount = getConstantScalarIntValue(adaptor.getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0) {
+      const int64_t indicesW = indicesShape.getDimSize(1);
+      if (ShapedType::isStatic(indicesW))
+        dataShape[1] = indicesW * rowCount.value();
+    }
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
+  if (values.size() == 1)
+    return success();
+
+  SmallVector<int64_t> scaleShape = dataShape;
+  const uint32_t blockSize =
+      BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+  if (ShapedType::isStatic(dataShape[2]))
+    scaleShape[2] = dataShape[2] / blockSize;
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
+  return success();
+}
+
 LogicalResult tosa::GatherOp::verify() {
   if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
                              /* outType = */ getOutput().getType())
@@ -3145,6 +3201,137 @@ LogicalResult tosa::GatherOp::verify() {
   return success();
 }
 
+LogicalResult tosa::RowGatherBlockScaledOp::verify() {
+  const OperandRange values = getValues();
+  const ResultRange output = getOutput();
+  if (values.empty() || values.size() > 2)
+    return emitOpError()
+           << "expects values tensor list length to be 1 or 2, got "
+           << values.size();
+  if (output.size() != values.size())
+    return emitOpError()
+           << "expects output tensor list length to match values tensor list "
+              "length, got "
+           << output.size() << " results for " << values.size()
+           << " input tensors";
+
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (values.size() == 1 && blockSize != 1)
+    return emitOpError()
+           << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
+              "length is 1";
+  if (values.size() == 2 && blockSize == 1)
+    return emitOpError()
+           << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
+              "list length is 2";
+
+  if (failed(verifySameElementTypes(*this, values[0].getType(),
+                                    output[0].getType(), "values[0]",
+                                    "output[0]")))
+    return failure();
+  if (values.size() == 2 && failed(verifySameElementTypes(
+                                *this, values[1].getType(), output[1].getType(),
+                                "values[1]", "output[1]")))
+    return failure();
+
+  if (auto rowCount = getConstantScalarIntValue(getRowCount());
+      succeeded(rowCount) && rowCount.value() <= 0)
+    return emitOpError() << "requires row_count to be > 0, got "
+                         << rowCount.value();
+
+  int64_t n = ShapedType::kDynamic;
+  int64_t k = ShapedType::kDynamic;
+  int64_t c = ShapedType::kDynamic;
+  int64_t w = ShapedType::kDynamic;
+  int64_t multiplesOfC = ShapedType::kDynamic;
+
+  const ShapeAdaptor valuesDataShape(values[0].getType());
+  if (valuesDataShape.hasRank()) {
+    n = valuesDataShape.getDimSize(0);
+    k = valuesDataShape.getDimSize(1);
+    c = valuesDataShape.getDimSize(2);
+  }
+
+  if (ShapedType::isStatic(c) && c % blockSize != 0)
+    return emitOpError() << "expects channels of values[0] (" << c
+                         << ") to be divisible by block_size (" << blockSize
+                         << ")";
+
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
+                                     "indices", "batch")))
+      return failure();
+    w = indicesShape.getDimSize(1);
+  }
+
+  const ShapeAdaptor outputDataShape(output[0].getType());
+  if (outputDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
+                                     "output[0]", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
+                                     "output[0]", "channels")))
+      return failure();
+
+    if (auto rowCount = getConstantScalarIntValue(getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0 &&
+        ShapedType::isStatic(w)) {
+      const int64_t expectedOutputRows = w * rowCount.value();
+      if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
+          outputDataShape.getDimSize(1) != expectedOutputRows)
+        return emitOpError() << "requires output[0] dimension 1 to have size "
+                             << expectedOutputRows << ", got "
+                             << outputDataShape.getDimSize(1);
+    }
+  }
+
+  if (values.size() == 2) {
+    const ShapeAdaptor valuesScaleShape(values[1].getType());
+    if (valuesScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
+                                       "values[1]", "batch")) ||
+          failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
+                                       "values[1]", "rows")))
+        return failure();
+      multiplesOfC = valuesScaleShape.getDimSize(2);
+    }
+
+    const ShapeAdaptor outputScaleShape(output[1].getType());
+    if (outputScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
+                                       "output[1]", "batch")))
+        return failure();
+
+      if (auto rowCount = getConstantScalarIntValue(getRowCount());
+          succeeded(rowCount) && rowCount.value() > 0 &&
+          ShapedType::isStatic(w)) {
+        const int64_t expectedOutputRows = w * rowCount.value();
+        if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
+            outputScaleShape.getDimSize(1) != expectedOutputRows)
+          return emitOpError() << "requires output[1] dimension 1 to have size "
+                               << expectedOutputRows << ", got "
+                               << outputScaleShape.getDimSize(1);
+      }
+
+      if (ShapedType::isDynamic(multiplesOfC))
+        multiplesOfC = outputScaleShape.getDimSize(2);
+      else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
+               multiplesOfC != outputScaleShape.getDimSize(2))
+        return emitOpError()
+               << "expected channels of output[1] to match size "
+               << multiplesOfC << ", got " << outputScaleShape.getDimSize(2);
+    }
+
+    if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
+        multiplesOfC != c / blockSize)
+      return emitOpError()
+             << "expects channels of scale tensors to equal C/block_size (" << c
+             << "/" << blockSize << "), got " << multiplesOfC;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -3987,6 +4174,8 @@ LogicalResult Conv2DBlockScaledOp::verify() {
 
   // Verify IC is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   if (ShapedType::isStatic(IC) && IC % blockSize != 0)
     return emitOpError("expect IC to be a multiple of block size, got IC=")
            << IC << ", block_size=" << blockSize;
@@ -4577,6 +4766,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
   if (inputDataShape.hasRank()) {
     const unsigned int blockSize =
         BlockSizeAttr::getBlockSizeValue(getBlockSize());
+    if (blockSize == 1)
+      return emitOpError("requires block_size to not be BLOCK_SIZE_1");
     const int64_t inputDataLastDim =
         inputDataShape.getDimSize(inputDataShape.getRank() - 1);
     if (inputDataLastDim % blockSize != 0)
@@ -4650,6 +4841,8 @@ LogicalResult CastToBlockScaledOp::verify() {
 
   const unsigned int blockSize =
       BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
   if (inputDataShape.hasRank()) {
     const int64_t inputDataLastDim =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 01c85be4f704f..4ea225b860f6c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -185,6 +185,18 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::RowGatherBlockScaledOp op) {
+  for (Value value : op.getValues())
+    addValue(value);
+  addValue(op.getIndices());
+  addValue(op.getRowCount());
+  for (Value result : op.getOutput())
+    addValue(result);
+  return success();
+}
+
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
   addValue(op.getValuesIn());
@@ -288,6 +300,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Tile)
   POPULATE_PROFILE_INFO_CUSTOM(Transpose)
   POPULATE_PROFILE_INFO_CUSTOM(Gather)
+  POPULATE_PROFILE_INFO_CUSTOM(RowGatherBlockScaled)
   POPULATE_PROFILE_INFO_CUSTOM(Scatter)
   POPULATE_PROFILE_INFO_CUSTOM(Resize)
   POPULATE_PROFILE_INFO_CUSTOM(Select)
@@ -598,10 +611,11 @@ SmallVector<OpComplianceInfo<T>> TosaProfileCompliance::findMatchedEntries(
     SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
     for (const auto &set : sets) {
       SmallVector<TypeInfo> expected = set.first;
-      assert(present.size() == expected.size() &&
-             "the entries for profile-based compliance do not match between...
[truncated]
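
The shape arithmetic in `inferReturnTypeComponents` from the patch above can be sketched without MLIR. This is a simplified illustration, not the merged implementation: `Shape3`, `kDynamic`, `isStatic`, `inferDataShape`, and `inferScaleShape` are hypothetical stand-ins, unranked inputs are modeled as all-dynamic dimensions, and a non-constant row_count is modeled by passing `kDynamic`.

```cpp
#include <cstdint>
#include <limits>

// Stand-in for a dynamic dimension marker (assumption for this sketch).
constexpr std::int64_t kDynamic = std::numeric_limits<std::int64_t>::min();

// Rank-3 shape: batch, rows, channels.
struct Shape3 {
  std::int64_t n;
  std::int64_t rows;
  std::int64_t c;
};

constexpr bool isStatic(std::int64_t dim) { return dim != kDynamic; }

// Data output is [N, W * row_count, C]. N comes from values, falling back to
// indices; rows stay dynamic unless row_count is a known positive constant
// and W is static.
constexpr Shape3 inferDataShape(std::int64_t valuesN, std::int64_t valuesC,
                                std::int64_t indicesN, std::int64_t indicesW,
                                std::int64_t rowCount) {
  Shape3 out{valuesN, kDynamic, valuesC};
  if (!isStatic(out.n))
    out.n = indicesN;
  if (rowCount > 0 && isStatic(indicesW))
    out.rows = indicesW * rowCount;
  return out;
}

// Scale output copies the data shape, with channels divided by block size
// when the channel count is static.
constexpr Shape3 inferScaleShape(Shape3 data, std::uint32_t blockSize) {
  Shape3 scale = data;
  if (isStatic(data.c))
    scale.c = data.c / blockSize;
  return scale;
}
```

For example, with values of shape [2, K, 64], indices of shape [2, 3], a constant row_count of 4, and a block size of 32, this sketch yields a data shape of [2, 12, 64] and a scale shape of [2, 12, 2].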

index 2754d3b21d4a6..360ab25875bd7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2201,6 +2201,8 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   // Verify C is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   if (ShapedType::isStatic(C) && C % blockSize != 0)
     return emitOpError("expect C to be a multiple of block size, got C=")
            << C << ", block_size=" << blockSize;
@@ -2848,6 +2850,18 @@ static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
   return -1;
 }
 
+static FailureOr<int64_t> getConstantScalarIntValue(Value val) {
+  ElementsAttr attr;
+  if (!matchPattern(val, m_Constant(&attr)))
+    return failure();
+
+  if (!llvm::isa<IntegerType>(attr.getElementType()) ||
+      attr.getNumElements() != 1)
+    return failure();
+
+  return attr.getValues<APInt>()[0].getSExtValue();
+}
+
 template <typename T>
 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
                                      const std::string &operand) {
@@ -3096,6 +3110,48 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RowGatherBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const auto values = adaptor.getValues();
+  if (values.empty() || values.size() > 2)
+    return failure();
+
+  SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
+  const ShapeAdaptor valuesShape(values.front().getType());
+  if (valuesShape.hasRank()) {
+    dataShape[0] = valuesShape.getDimSize(0);
+    dataShape[2] = valuesShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (dataShape[0] == ShapedType::kDynamic)
+      dataShape[0] = indicesShape.getDimSize(0);
+
+    if (auto rowCount = getConstantScalarIntValue(adaptor.getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0) {
+      const int64_t indicesW = indicesShape.getDimSize(1);
+      if (ShapedType::isStatic(indicesW))
+        dataShape[1] = indicesW * rowCount.value();
+    }
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
+  if (values.size() == 1)
+    return success();
+
+  SmallVector<int64_t> scaleShape = dataShape;
+  const uint32_t blockSize =
+      BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+  if (ShapedType::isStatic(dataShape[2]))
+    scaleShape[2] = dataShape[2] / blockSize;
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
+  return success();
+}
+
 LogicalResult tosa::GatherOp::verify() {
   if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
                              /* outType = */ getOutput().getType())
@@ -3145,6 +3201,137 @@ LogicalResult tosa::GatherOp::verify() {
   return success();
 }
 
+LogicalResult tosa::RowGatherBlockScaledOp::verify() {
+  const OperandRange values = getValues();
+  const ResultRange output = getOutput();
+  if (values.empty() || values.size() > 2)
+    return emitOpError()
+           << "expects values tensor list length to be 1 or 2, got "
+           << values.size();
+  if (output.size() != values.size())
+    return emitOpError()
+           << "expects output tensor list length to match values tensor list "
+              "length, got "
+           << output.size() << " results for " << values.size()
+           << " input tensors";
+
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (values.size() == 1 && blockSize != 1)
+    return emitOpError()
+           << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
+              "length is 1";
+  if (values.size() == 2 && blockSize == 1)
+    return emitOpError()
+           << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
+              "list length is 2";
+
+  if (failed(verifySameElementTypes(*this, values[0].getType(),
+                                    output[0].getType(), "values[0]",
+                                    "output[0]")))
+    return failure();
+  if (values.size() == 2 && failed(verifySameElementTypes(
+                                *this, values[1].getType(), output[1].getType(),
+                                "values[1]", "output[1]")))
+    return failure();
+
+  if (auto rowCount = getConstantScalarIntValue(getRowCount());
+      succeeded(rowCount) && rowCount.value() <= 0)
+    return emitOpError() << "requires row_count to be > 0, got "
+                         << rowCount.value();
+
+  int64_t n = ShapedType::kDynamic;
+  int64_t k = ShapedType::kDynamic;
+  int64_t c = ShapedType::kDynamic;
+  int64_t w = ShapedType::kDynamic;
+  int64_t multiplesOfC = ShapedType::kDynamic;
+
+  const ShapeAdaptor valuesDataShape(values[0].getType());
+  if (valuesDataShape.hasRank()) {
+    n = valuesDataShape.getDimSize(0);
+    k = valuesDataShape.getDimSize(1);
+    c = valuesDataShape.getDimSize(2);
+  }
+
+  if (ShapedType::isStatic(c) && c % blockSize != 0)
+    return emitOpError() << "expects channels of values[0] (" << c
+                         << ") to be divisible by block_size (" << blockSize
+                         << ")";
+
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  if (indicesShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
+                                     "indices", "batch")))
+      return failure();
+    w = indicesShape.getDimSize(1);
+  }
+
+  const ShapeAdaptor outputDataShape(output[0].getType());
+  if (outputDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
+                                     "output[0]", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
+                                     "output[0]", "channels")))
+      return failure();
+
+    if (auto rowCount = getConstantScalarIntValue(getRowCount());
+        succeeded(rowCount) && rowCount.value() > 0 &&
+        ShapedType::isStatic(w)) {
+      const int64_t expectedOutputRows = w * rowCount.value();
+      if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
+          outputDataShape.getDimSize(1) != expectedOutputRows)
+        return emitOpError() << "requires output[0] dimension 1 to have size "
+                             << expectedOutputRows << ", got "
+                             << outputDataShape.getDimSize(1);
+    }
+  }
+
+  if (values.size() == 2) {
+    const ShapeAdaptor valuesScaleShape(values[1].getType());
+    if (valuesScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
+                                       "values[1]", "batch")) ||
+          failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
+                                       "values[1]", "rows")))
+        return failure();
+      multiplesOfC = valuesScaleShape.getDimSize(2);
+    }
+
+    const ShapeAdaptor outputScaleShape(output[1].getType());
+    if (outputScaleShape.hasRank()) {
+      if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
+                                       "output[1]", "batch")))
+        return failure();
+
+      if (auto rowCount = getConstantScalarIntValue(getRowCount());
+          succeeded(rowCount) && rowCount.value() > 0 &&
+          ShapedType::isStatic(w)) {
+        const int64_t expectedOutputRows = w * rowCount.value();
+        if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
+            outputScaleShape.getDimSize(1) != expectedOutputRows)
+          return emitOpError() << "requires output[1] dimension 1 to have size "
+                               << expectedOutputRows << ", got "
+                               << outputScaleShape.getDimSize(1);
+      }
+
+      if (ShapedType::isDynamic(multiplesOfC))
+        multiplesOfC = outputScaleShape.getDimSize(2);
+      else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
+               multiplesOfC != outputScaleShape.getDimSize(2))
+        return emitOpError()
+               << "expected channels of output[1] to match size "
+               << multiplesOfC << ", got " << outputScaleShape.getDimSize(2);
+    }
+
+    if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
+        multiplesOfC != c / blockSize)
+      return emitOpError()
+             << "expects channels of scale tensors to equal C/block_size (" << c
+             << "/" << blockSize << "), got " << multiplesOfC;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -3987,6 +4174,8 @@ LogicalResult Conv2DBlockScaledOp::verify() {
 
   // Verify IC is a multiple of block size
   const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   if (ShapedType::isStatic(IC) && IC % blockSize != 0)
     return emitOpError("expect IC to be a multiple of block size, got IC=")
            << IC << ", block_size=" << blockSize;
@@ -4577,6 +4766,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
   if (inputDataShape.hasRank()) {
     const unsigned int blockSize =
         BlockSizeAttr::getBlockSizeValue(getBlockSize());
+    if (blockSize == 1)
+      return emitOpError("requires block_size to not be BLOCK_SIZE_1");
     const int64_t inputDataLastDim =
         inputDataShape.getDimSize(inputDataShape.getRank() - 1);
     if (inputDataLastDim % blockSize != 0)
@@ -4650,6 +4841,8 @@ LogicalResult CastToBlockScaledOp::verify() {
 
   const unsigned int blockSize =
       BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (blockSize == 1)
+    return emitOpError("requires block_size to not be BLOCK_SIZE_1");
   const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
   if (inputDataShape.hasRank()) {
     const int64_t inputDataLastDim =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 01c85be4f704f..4ea225b860f6c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -185,6 +185,18 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::RowGatherBlockScaledOp op) {
+  for (Value value : op.getValues())
+    addValue(value);
+  addValue(op.getIndices());
+  addValue(op.getRowCount());
+  for (Value result : op.getOutput())
+    addValue(result);
+  return success();
+}
+
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
   addValue(op.getValuesIn());
@@ -288,6 +300,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Tile)
   POPULATE_PROFILE_INFO_CUSTOM(Transpose)
   POPULATE_PROFILE_INFO_CUSTOM(Gather)
+  POPULATE_PROFILE_INFO_CUSTOM(RowGatherBlockScaled)
   POPULATE_PROFILE_INFO_CUSTOM(Scatter)
   POPULATE_PROFILE_INFO_CUSTOM(Resize)
   POPULATE_PROFILE_INFO_CUSTOM(Select)
@@ -598,10 +611,11 @@ SmallVector<OpComplianceInfo<T>> TosaProfileCompliance::findMatchedEntries(
     SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
     for (const auto &set : sets) {
       SmallVector<TypeInfo> expected = set.first;
-      assert(present.size() == expected.size() &&
-             "the entries for profile-based compliance do not match between...
[truncated]

@psunn
Contributor Author

psunn commented Apr 15, 2026

@Tai78641
Contributor

LGTM except for one nit

@psunn psunn force-pushed the row_gather_block_scaled branch from b5cc923 to 1c6765f Compare April 16, 2026 10:37
@psunn
Contributor Author

psunn commented Apr 16, 2026

Updated to resolve merge conflicts.

Contributor

@lhutton1 lhutton1 left a comment

Thanks @psunn! Had some nitpicks, otherwise LGTM!

Comment thread mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Comment thread mlir/lib/Dialect/Tosa/IR/TosaOps.cpp Outdated
Comment thread mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Comment thread mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@psunn psunn force-pushed the row_gather_block_scaled branch 3 times, most recently from bad45d3 to 4e88d6b Compare April 16, 2026 14:31
@psunn psunn requested a review from lhutton1 April 16, 2026 14:33
Note: defer op-specific level checks for ROW_GATHER_BLOCK_SCALED while the TOSA 1.1 draft is still evolving.

Signed-off-by: Peng Sun <peng.sun@arm.com>
Contributor

@lhutton1 lhutton1 left a comment

LGTM, thanks!

@psunn psunn force-pushed the row_gather_block_scaled branch from 8c57610 to a9ca5fb Compare April 16, 2026 14:35
@psunn
Contributor Author

psunn commented Apr 16, 2026

Rebased onto the current head of main.

@psunn psunn merged commit 2f268ec into llvm:main Apr 16, 2026
10 checks passed
lhutton1 added a commit to lhutton1/llvm-project that referenced this pull request Apr 17, 2026
Fixes a validation test after a merge race condition with
llvm#192122 and
llvm#192272.

Change-Id: I3b87e0ef9c04432b3af1a5aae7630dbae662e802
lhutton1 added a commit that referenced this pull request Apr 17, 2026
Fixes a validation test after a merge race condition with
#192122 and
#192272.
llvm-sync Bot pushed a commit to arm/arm-toolchain that referenced this pull request Apr 17, 2026
cpullvm-upstream-sync Bot pushed a commit to navaneethshan/cpullvm-toolchain-1 that referenced this pull request Apr 17, 2026
alexfh pushed a commit to alexfh/llvm-project that referenced this pull request Apr 18, 2026
Add `tosa.row_gather_block_scaled` to the MLIR TOSA dialect, aligned
with the current TOSA 1.1 draft spec and the implementation in
`tosa-tools`.

This includes:
- op definition
- verifier and shape inference support
- validation / profile compliance wiring
- availability and extension handling
- lit tests for parsing, verification, shape inference, and version /
  extension gating

The op supports both spec-defined forms:
- non-block-scaled: 1 input value tensor, `BLOCK_SIZE_1`, 1 output
- block-scaled: data + scale tensor list, non-`BLOCK_SIZE_1`, 2 outputs

Op-specific level checks for ROW_GATHER_BLOCK_SCALED have been deferred
while the TOSA 1.1 draft is still evolving.

Signed-off-by: Peng Sun <peng.sun@arm.com>
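
For readers skimming the commit message, the two forms and the shape rules checked in the diff can be illustrated with a small standalone sketch. It is not the MLIR implementation: `Shape3D`, `kDynamicDim`, and `sketchRowGatherShapes` are hypothetical names, no MLIR API is used, and `row_count` is assumed to be a known constant (in the patch it is only checked when it folds to a constant). The rules mirrored here are: data output `[N, W * row_count, C]`, scale output `[N, W * row_count, C / block_size]`, one values tensor pairs with `BLOCK_SIZE_1`, and two values tensors require a larger block size. It needs C++17 for `std::optional`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Illustrative stand-in for a dynamic dimension (not MLIR's ShapedType::kDynamic).
constexpr int64_t kDynamicDim = -1;

// Hypothetical helper type: a rank-3 shape [N, rows, C].
struct Shape3D {
  int64_t n;
  int64_t rows;
  int64_t c;
};

// Returns the expected result shapes (data first, then scale when present),
// or std::nullopt when the combination would be rejected.
//   n, c      : batch and channel sizes of values[0]
//   w         : second dimension of the indices tensor
//   rowCount  : assumed-constant row_count
//   hasScale  : true when the values list carries a scale tensor
std::optional<std::vector<Shape3D>>
sketchRowGatherShapes(int64_t n, int64_t c, int64_t w, int64_t rowCount,
                      uint32_t blockSize, bool hasScale) {
  assert(blockSize > 0 && "block size must be positive");
  const int64_t bs = blockSize;

  // Form check: a scale tensor is present exactly when block_size != 1.
  if (hasScale == (bs == 1))
    return std::nullopt;
  // row_count must be positive.
  if (rowCount <= 0)
    return std::nullopt;
  // Static channel counts must be divisible by the block size.
  if (c != kDynamicDim && c % bs != 0)
    return std::nullopt;

  // Each index gathers rowCount consecutive rows.
  const int64_t rows = (w == kDynamicDim) ? kDynamicDim : w * rowCount;

  std::vector<Shape3D> shapes;
  shapes.push_back(Shape3D{n, rows, c});
  if (hasScale)
    shapes.push_back(Shape3D{n, rows, c == kDynamicDim ? kDynamicDim : c / bs});
  return shapes;
}
```

Under these assumptions, values of shape `[2, K, 64]` with indices `[2, 3]`, `row_count = 4`, and a block size of 32 would give a data result of `[2, 12, 64]` and a scale result of `[2, 12, 2]`.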
llvm-upstreamsync Bot pushed a commit to qualcomm/cpullvm-toolchain that referenced this pull request Apr 24, 2026