Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ class ScalarType : public SPIRVType {
static bool isValid(FloatType);
/// Returns true if the given float type is valid for the SPIR-V dialect.
static bool isValid(IntegerType);

std::optional<int64_t> getSizeInBytes();
};

// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
Expand All @@ -112,8 +110,6 @@ class CompositeType : public SPIRVType {
/// Return true if the number of elements is known at compile time and is not
/// implementation dependent.
bool hasCompileTimeKnownNumElements() const;

std::optional<int64_t> getSizeInBytes();
};

// SPIR-V array type
Expand All @@ -137,10 +133,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
/// Returns the array stride in bytes. 0 means no stride decorated on this
/// type.
unsigned getArrayStride() const;

/// Returns the array size in bytes. Since array type may have an explicit
/// stride declaration (in bytes), we also include it in the calculation.
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V image type
Expand Down
78 changes: 30 additions & 48 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/Support/ErrorHandling.h"

#include <cstdint>
#include <optional>

using namespace mlir;
using namespace mlir::spirv;
Expand Down Expand Up @@ -172,14 +173,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }

unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }

std::optional<int64_t> ArrayType::getSizeInBytes() {
auto elementType = llvm::cast<SPIRVType>(getElementType());
std::optional<int64_t> size = elementType.getSizeInBytes();
if (!size)
return std::nullopt;
return (*size + getArrayStride()) * getNumElements();
}

//===----------------------------------------------------------------------===//
// CompositeType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -245,28 +238,6 @@ void TypeCapabilityVisitor::addConcrete(VectorType type) {
}
}

std::optional<int64_t> CompositeType::getSizeInBytes() {
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
return arrayType.getSizeInBytes();
if (auto structType = llvm::dyn_cast<StructType>(*this))
return structType.getSizeInBytes();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
std::optional<int64_t> elementSize =
llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
if (!elementSize)
return std::nullopt;
return *elementSize * vectorType.getNumElements();
}
if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
std::optional<int64_t> elementSize =
llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
if (!elementSize)
return std::nullopt;
return *elementSize * tensorArmType.getNumElements();
}
return std::nullopt;
}

//===----------------------------------------------------------------------===//
// CooperativeMatrixType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -714,19 +685,6 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
#undef WIDTH_CASE
}

std::optional<int64_t> ScalarType::getSizeInBytes() {
auto bitWidth = getIntOrFloatBitWidth();
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with boolean
// type. If they are stored (in conjunction with OpVariable), they can only
// be used with logical addressing operations, not physical, and only with
// non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
// Private, Function, Input, and Output."
if (bitWidth == 1)
return std::nullopt;
return bitWidth / 8;
}

//===----------------------------------------------------------------------===//
// SPIRVType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -760,11 +718,35 @@ void SPIRVType::getCapabilities(
}

std::optional<int64_t> SPIRVType::getSizeInBytes() {
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
return scalarType.getSizeInBytes();
if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
return compositeType.getSizeInBytes();
return std::nullopt;
return TypeSwitch<SPIRVType, std::optional<int64_t>>(*this)
.Case<ScalarType>([](ScalarType type) -> std::optional<int64_t> {
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with
// boolean type. If they are stored (in conjunction with OpVariable),
// they can only be used with logical addressing operations, not
// physical, and only with non-externally visible shader Storage
// Classes: Workgroup, CrossWorkgroup, Private, Function, Input, and
// Output."
int64_t bitWidth = type.getIntOrFloatBitWidth();
if (bitWidth == 1)
return std::nullopt;
return bitWidth / 8;
})
.Case<ArrayType>([](ArrayType type) -> std::optional<int64_t> {
// Since array type may have an explicit stride declaration (in bytes),
// we also include it in the calculation.
auto elementType = cast<SPIRVType>(type.getElementType());
if (std::optional<int64_t> size = elementType.getSizeInBytes())
return (*size + type.getArrayStride()) * type.getNumElements();
return std::nullopt;
})
.Case<VectorType, TensorArmType>([](auto type) -> std::optional<int64_t> {
if (std::optional<int64_t> elementSize =
cast<ScalarType>(type.getElementType()).getSizeInBytes())
return *elementSize * type.getNumElements();
return std::nullopt;
})
.Default(std::optional<int64_t>());
}

//===----------------------------------------------------------------------===//
Expand Down
Loading