Skip to content

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Sep 22, 2025

Similar to ::getExtensions and ::getCapabilities, introduce a single entry point for type size calculation.

Also fix potential infinite recursion with StructTypes (even non-recursive structs), although I don't know to write a test for this without using C++. This is mostly an NFC modulo this potential bug fix.

Similar to `::getExtensions` and `::getCapabilities`, introduce a single
entry point for type size calculation.

Also fix potential infinite recursion with `StructType`s (even
non-recursive structs), although I don't know to write a test for this
without using C++. This is mostly an NFC modulo this potential bug fix.
@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Jakub Kuderski (kuhar)

Changes

Similar to ::getExtensions and ::getCapabilities, introduce a single entry point for type size calculation.

Also fix potential infinite recursion with StructTypes (even non-recursive structs), although I don't know to write a test for this without using C++. This is mostly an NFC modulo this potential bug fix.


Full diff: https://github.com/llvm/llvm-project/pull/160162.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (-8)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+30-48)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 475e3f495e065..e46b576810316 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -88,8 +88,6 @@ class ScalarType : public SPIRVType {
   static bool isValid(FloatType);
   /// Returns true if the given float type is valid for the SPIR-V dialect.
   static bool isValid(IntegerType);
-
-  std::optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
@@ -112,8 +110,6 @@ class CompositeType : public SPIRVType {
   /// Return true if the number of elements is known at compile time and is not
   /// implementation dependent.
   bool hasCompileTimeKnownNumElements() const;
-
-  std::optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V array type
@@ -137,10 +133,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
   /// Returns the array stride in bytes. 0 means no stride decorated on this
   /// type.
   unsigned getArrayStride() const;
-
-  /// Returns the array size in bytes. Since array type may have an explicit
-  /// stride declaration (in bytes), we also include it in the calculation.
-  std::optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V image type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 7c2f43bea9ddb..5ed7652987859 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Support/ErrorHandling.h"
 
 #include <cstdint>
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::spirv;
@@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }
 
 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
 
-std::optional<int64_t> ArrayType::getSizeInBytes() {
-  auto elementType = llvm::cast<SPIRVType>(getElementType());
-  std::optional<int64_t> size = elementType.getSizeInBytes();
-  if (!size)
-    return std::nullopt;
-  return (*size + getArrayStride()) * getNumElements();
-}
-
 //===----------------------------------------------------------------------===//
 // CompositeType
 //===----------------------------------------------------------------------===//
@@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) {
   }
 }
 
-std::optional<int64_t> CompositeType::getSizeInBytes() {
-  if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
-    return arrayType.getSizeInBytes();
-  if (auto structType = llvm::dyn_cast<StructType>(*this))
-    return structType.getSizeInBytes();
-  if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
-    std::optional<int64_t> elementSize =
-        llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
-    if (!elementSize)
-      return std::nullopt;
-    return *elementSize * vectorType.getNumElements();
-  }
-  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
-    std::optional<int64_t> elementSize =
-        llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
-    if (!elementSize)
-      return std::nullopt;
-    return *elementSize * tensorArmType.getNumElements();
-  }
-  return std::nullopt;
-}
-
 //===----------------------------------------------------------------------===//
 // CooperativeMatrixType
 //===----------------------------------------------------------------------===//
@@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
 #undef WIDTH_CASE
 }
 
-std::optional<int64_t> ScalarType::getSizeInBytes() {
-  auto bitWidth = getIntOrFloatBitWidth();
-  // According to the SPIR-V spec:
-  // "There is no physical size or bit pattern defined for values with boolean
-  // type. If they are stored (in conjunction with OpVariable), they can only
-  // be used with logical addressing operations, not physical, and only with
-  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
-  // Private, Function, Input, and Output."
-  if (bitWidth == 1)
-    return std::nullopt;
-  return bitWidth / 8;
-}
-
 //===----------------------------------------------------------------------===//
 // SPIRVType
 //===----------------------------------------------------------------------===//
@@ -760,11 +718,35 @@ void SPIRVType::getCapabilities(
 }
 
 std::optional<int64_t> SPIRVType::getSizeInBytes() {
-  if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
-    return scalarType.getSizeInBytes();
-  if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
-    return compositeType.getSizeInBytes();
-  return std::nullopt;
+  return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this)
+      .Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
+        // According to the SPIR-V spec:
+        // "There is no physical size or bit pattern defined for values with
+        // boolean type. If they are stored (in conjunction with OpVariable),
+        // they can only be used with logical addressing operations, not
+        // physical, and only with non-externally visible shader Storage
+        // Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
+        // Output."
+        int64_t bitWidth = type.getIntOrFloatBitWidth();
+        if (bitWidth == 1)
+          return std::nullopt;
+        return bitWidth / 8;
+      })
+      .Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
+        // Since array type may have an explicit stride declaration (in bytes),
+        // we also include it in the calculation.
+        auto elementType = cast<SPIRVType>(type.getElementType());
+        if (std::optional<int64_t> size = elementType.getSizeInBytes())
+          return (*size + type.getArrayStride()) * type.getNumElements();
+        return std::nullopt;
+      })
+      .Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
+        if(std::optional<int64_t> elementSize =
+                cast<ScalarType>(type.getElementType()).getSizeInBytes())
+          return *elementSize * type.getNumElements();
+        return std::nullopt;
+      })
+      .Default(std::optional<int64_t>());
 }
 
 //===----------------------------------------------------------------------===//

Copy link

github-actions bot commented Sep 22, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@kuhar kuhar merged commit c526c70 into llvm:main Sep 22, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants