Skip to content

Conversation

@lhutton1
Copy link
Contributor

Adds initial support for the ext-shape extension, including the operations:

  • ADD_SHAPE
  • SUB_SHAPE
  • MUL_SHAPE
  • DIV_FLOOR_SHAPE
  • DIV_CEIL_SHAPE

to align with the spec change: arm/tosa-specification@efc88a1.

This includes the operator definition, same rank checks and level checks during validation. It does not currently include support for folding or shape inference. This will be added in a later commit.

Based on work originally implemented by @Tai78641.

Adds initial support for the ext-shape extension, including
the operations:
- ADD_SHAPE
- SUB_SHAPE
- MUL_SHAPE
- DIV_FLOOR_SHAPE
- DIV_CEIL_SHAPE
to align with the spec change:
arm/tosa-specification@efc88a1.

This includes the operator definition, same rank checks
and level checks during validation. It does not currently
include support for folding or shape inference. This will
be added in a later commit.

Change-Id: I544af295552b9a9fecaba50b6131d7876113e47c
@llvmbot
Copy link
Member

llvmbot commented Nov 24, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

Adds initial support for the ext-shape extension, including the operations:

  • ADD_SHAPE
  • SUB_SHAPE
  • MUL_SHAPE
  • DIV_FLOOR_SHAPE
  • DIV_CEIL_SHAPE

to align with the spec change: arm/tosa-specification@efc88a1.

This includes the operator definition, same rank checks and level checks during validation. It does not currently include support for folding or shape inference. This will be added in a later commit.

Based on work originally implemented by @Tai78641.


Full diff: https://github.com/llvm/llvm-project/pull/169321.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+4-2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+116-7)
  • (modified) mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp (+1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+5)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+15-3)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+10)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+21-1)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+45)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10-1)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index cc23955f31f23..419340256fa59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
 // DYNAMIC      : Removes all Compile Time Constant state for CTC inputs.
 // MXFP         : Microscaling formats.
+// SHAPE        : Shape calcuation operators.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 def Tosa_EXT_DYNAMIC      : I32EnumAttrCase<"dynamic", 11>;
 def Tosa_EXT_MXFP         : I32EnumAttrCase<"mxfp", 12>;
 def Tosa_EXT_INT64        : I32EnumAttrCase<"int64", 13>;
+def Tosa_EXT_SHAPE        : I32EnumAttrCase<"shape", 14>;
 
 
 def Tosa_ExtensionAttr
@@ -281,7 +283,7 @@ def Tosa_ExtensionAttr
       Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
       Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
       Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
-      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
+      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_SHAPE,
     ]> {
   let extraClassDeclaration = [{
     static llvm::SmallVector<Extension, 13> getAllValues() {
@@ -290,7 +292,7 @@ def Tosa_ExtensionAttr
         Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
         Extension::variable, Extension::controlflow, Extension::doubleround,
         Extension::inexactround, Extension::dynamic, Extension::mxfp,
-        Extension::int64
+        Extension::int64, Extension::shape
       };
     }
   }];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index ea58f49b64c44..bee253689bab7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -154,6 +154,7 @@ class TosaProfileCompliance {
     case Extension::controlflow:
     case Extension::dynamic:
     case Extension::int64:
+    case Extension::shape:
       return {Profile::pro_fp, Profile::pro_int};
     case Extension::none:
       return {};
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index 90cda42d95624..7b1c7e208ebe3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -30,15 +30,8 @@ def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
 
 class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
     : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
-  list<Availability> availability = [
-    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
-    Extension<[]>,
-  ];
-
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
-
-  let hasFolder = 1;
 }
 
 // op trait: shape operator has same ranks for operands and results
@@ -53,6 +46,29 @@ class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
 }
 
 
+//===----------------------------------------------------------------------===//
+// Operator: AddShape
+//===----------------------------------------------------------------------===//
+def Tosa_AddShapeOp : Tosa_ElementwiseShapeOp<"add_shape", [Pure]> {
+  let summary = "Elementwise addition of shapes.";
+
+  let description = [{
+      Elementwise addition of input1 and input2. Size of shapes must match.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: ConstShape
 //===----------------------------------------------------------------------===//
@@ -80,6 +96,99 @@ def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
   ];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: DivCeilShape
+//===----------------------------------------------------------------------===//
+def Tosa_DivCeilShapeOp : Tosa_ElementwiseShapeOp<"div_ceil_shape", [Pure]> {
+  let summary = "Elementwise ceiling divide of shapes.";
+
+  let description = [{
+      Elementwise divide of input1 by input2. The result of the divide is rounded up.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: DivFloorShape
+//===----------------------------------------------------------------------===//
+def Tosa_DivFloorShapeOp : Tosa_ElementwiseShapeOp<"div_floor_shape", [Pure]> {
+  let summary = "Elementwise floor divide of shapes.";
+
+  let description = [{
+      Elementwise integer divide of input1 by input2. The result of the divide is rounded down.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: MulShape
+//===----------------------------------------------------------------------===//
+def Tosa_MulShapeOp : Tosa_ElementwiseShapeOp<"mul_shape", [Pure]> {
+  let summary = "Elementwise multiplication of shapes.";
+
+  let description = [{
+      Elementwise multiplication of input1 and input2.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: SubShape
+//===----------------------------------------------------------------------===//
+def Tosa_SubShapeOp : Tosa_ElementwiseShapeOp<"sub_shape", [Pure]> {
+  let summary = "Elementwise subtraction of shapes.";
+
+  let description = [{
+      Elementwise subtraction of input1 and input2. Size of shapes must match.
+  }];
+
+  let arguments = (ins
+    Tosa_Shape:$input1,
+    Tosa_Shape:$input2
+  );
+
+  let results = (outs Tosa_Shape:$output);
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_SHAPE]>,
+  ];
 }
 
 #endif // TOSA_SHAPE_OPS
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85cf9b0b..01f78f86d427b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
     return TosaSpecificationVersion(1, 0);
   case Extension::mxfp:
   case Extension::int64:
+  case Extension::shape:
     return TosaSpecificationVersion(1, 1);
   case Extension::none:
     return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ddd9c70402fdc..c9150d5b34d00 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -317,7 +317,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   // Type Invariant Extension, a capability extension that is independent
   // of the data type, meaning any compatible type can be used. No type
   // constraint for those operations.
+  POPULATE_PROFILE_INFO_SKIP(AddShape)
   POPULATE_PROFILE_INFO_SKIP(ConstShape)
+  POPULATE_PROFILE_INFO_SKIP(DivCeilShape)
+  POPULATE_PROFILE_INFO_SKIP(DivFloorShape)
+  POPULATE_PROFILE_INFO_SKIP(MulShape)
+  POPULATE_PROFILE_INFO_SKIP(SubShape)
   POPULATE_PROFILE_INFO_SKIP(Yield)
   POPULATE_PROFILE_INFO_SKIP(If)
   POPULATE_PROFILE_INFO_SKIP(While)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b54ed5585d72d..421ef237e628f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -218,6 +218,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
       if (type.getRank() > highest_rank)
         return op->emitOpError() << "failed level check: " << operandOrResult
                                  << " rank(shape) <= MAX_RANK";
+    } else if (tosa::shapeType shapeType =
+                   dyn_cast<tosa::shapeType>(typeToCheck)) {
+      if (shapeType.getRank() > highest_rank)
+        return op->emitOpError()
+               << "failed shape type level check: " << typeToCheck
+               << " exceeds MAX_RANK";
     }
     return success();
   }
@@ -638,15 +644,21 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
   CHECK_RANKS_AND_SIZES(CastToBlockScaled);
   CHECK_RANKS_AND_SIZES(Rescale);
+  // Data Nodes
+  CHECK_RANKS_AND_SIZES(Const);
+  CHECK_RANKS_AND_SIZES(Identity);
   // Control Flow Operators
   CHECK_RANKS_AND_SIZES(If);
   // Variable Operators
   CHECK_RANKS_AND_SIZES(Variable);
   CHECK_RANKS_AND_SIZES(VariableWrite);
   CHECK_RANKS_AND_SIZES(VariableRead);
-  // Data Nodes
-  CHECK_RANKS_AND_SIZES(Const);
-  CHECK_RANKS_AND_SIZES(Identity);
+  // Shape Operators
+  CHECK_RANKS_AND_SIZES(AddShape);
+  CHECK_RANKS_AND_SIZES(DivCeilShape);
+  CHECK_RANKS_AND_SIZES(DivFloorShape);
+  CHECK_RANKS_AND_SIZES(MulShape);
+  CHECK_RANKS_AND_SIZES(SubShape);
 
   // For the following operators, check whether the size of each tensor
   // operand is valid in a given Level.
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 68a95787b81c7..a06406fcdab1f 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -584,3 +584,13 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_mul_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error@+1 {{'tosa.mul_shape' op illegal: requires [shape] but not enabled in target}}
+  %c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index a7087647e542b..213c4ae054c51 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -390,7 +390,7 @@ func.func @test_pad_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1
 
 func.func @test_reshape_rank_invalid(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> {
   %1 = tosa.const_shape {values = dense<[1, 1, 1, 1, 1, 1, 819]> : tensor<7xindex>} : () -> !tosa.shape<7>
-  // expected-error@+1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}}
+  // expected-error@+1 {{'tosa.reshape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
   %0 = "tosa.reshape"(%arg0, %1) : (tensor<13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x1x1x819xf32>
   return %0 : tensor<1x1x1x1x1x1x819xf32>
 }
@@ -1662,3 +1662,23 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_add_shape_invalid_rank() -> !tosa.shape<13> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]> : tensor<13xindex>} : () -> !tosa.shape<13>
+  // expected-error@+1 {{'tosa.add_shape' op failed shape type level check: '!tosa.shape<13>' exceeds MAX_RANK}}
+  %c = tosa.add_shape %a, %b : (!tosa.shape<13>, !tosa.shape<13>) -> !tosa.shape<13>
+  return %c : !tosa.shape<13>
+}
+
+// -----
+
+func.func @test_div_floor_shape_invalid_rank() -> !tosa.shape<7> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  %b = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6, 7]> : tensor<7xindex>} : () -> !tosa.shape<7>
+  // expected-error@+1 {{'tosa.div_floor_shape' op failed shape type level check: '!tosa.shape<7>' exceeds MAX_RANK}}
+  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<7>, !tosa.shape<7>) -> !tosa.shape<7>
+  return %c : !tosa.shape<7>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a4591f7ffd393..2c4ec857ad20e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1374,3 +1374,48 @@ func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
     %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
     return %0 : tensor<2x!tosa.mxint8>
 }
+
+// -----
+// CHECK-LABEL: test_add_shape
+func.func @test_add_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_sub_shape
+func.func @test_sub_shape() -> !tosa.shape<3> {
+  %a = tosa.const_shape {values = dense<[10, 5, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %b = tosa.const_shape {values = dense<[2, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %c = tosa.sub_shape %a, %b : (!tosa.shape<3>, !tosa.shape<3>) -> !tosa.shape<3>
+  return %c : !tosa.shape<3>
+}
+
+// -----
+// CHECK-LABEL: test_mul_shape
+func.func @test_mul_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[2, 3, 4, 5]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[7, 0, 2, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.mul_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_div_ceil_shape
+func.func @test_div_ceil_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[2, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.div_ceil_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
+
+// -----
+// CHECK-LABEL: test_div_floor_shape
+func.func @test_div_floor_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[5, 7, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[2, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index c285ae3cf44ee..66a94559348a8 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64,shape" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
 
 // -----
 
@@ -156,3 +156,12 @@ func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: te
   %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
   return %0 : tensor<2x52x3xf32>
 }
+
+// -----
+// CHECK-LABEL: test_add_shape
+func.func @test_add_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %b = tosa.const_shape {values = dense<[5, 6, 7, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<4>
+  return %c : !tosa.shape<4>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 6cf76cdc7ad8e..a70709b4ecc6a 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1222,3 +1222,19 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_elementwise_shape_op_same_inputs_rank(%arg0: !tosa.shape<4>, %arg1: !tosa.shape<3>) -> !tosa.shape<4> {
+  // expected-error@+1 {{'tosa.add_shape' op operands don't have matching ranks}}
+  %0 = tosa.add_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<3>) -> !tosa.shape<4>
+  return %0 : !tosa.shape<4>
+}
+
+// -----
+
+func.func @test_elementwise_shape_op_same_input_output_rank(%arg0: !tosa.shape<4>, %arg1: !tosa.shape<4>) -> !tosa.shape<3> {
+  // expected-error@+1 {{'tosa.div_floor_shape' op result shape has different rank than operands}}
+  %0 = tosa.div_floor_shape %arg0, %arg1 : (!tosa.shape<4>, !tosa.shape<4>) -> !tosa.shape<3>
+  return %0 : !tosa.shape<3>
+}

@lhutton1
Copy link
Contributor Author

cc @Tai78641 @udaya-ranga @psunn

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants