Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ class ScalarType : public SPIRVType {
/// Returns true if the given float type is valid for the SPIR-V dialect.
static bool isValid(IntegerType);

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);

Expand Down Expand Up @@ -118,8 +116,6 @@ class CompositeType : public SPIRVType {
/// implementation dependent.
bool hasCompileTimeKnownNumElements() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);

Expand Down Expand Up @@ -148,8 +144,6 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
/// type.
unsigned getArrayStride() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);

Expand Down Expand Up @@ -193,8 +187,6 @@ class ImageType
ImageFormat getImageFormat() const;
// TODO: Add support for Access qualifier

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
Expand All @@ -213,8 +205,6 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,

StorageClass getStorageClass() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
Expand All @@ -239,8 +229,6 @@ class RuntimeArrayType
/// type.
unsigned getArrayStride() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
Expand All @@ -265,8 +253,6 @@ class SampledImageType

Type getImageType() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<spirv::StorageClass> storage = std::nullopt);
void
getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<spirv::StorageClass> storage = std::nullopt);
Expand Down Expand Up @@ -420,8 +406,6 @@ class StructType
ArrayRef<MemberDecorationInfo> memberDecorations = {},
ArrayRef<StructDecorationInfo> structDecorations = {});

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
Expand Down Expand Up @@ -456,8 +440,6 @@ class CooperativeMatrixType
/// Returns the use parameter of the cooperative matrix.
CooperativeMatrixUseKHR getUse() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);

Expand Down Expand Up @@ -512,8 +494,6 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
/// Returns the elements' type (i.e, single element type).
Type getElementType() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
Expand Down Expand Up @@ -552,8 +532,6 @@ class TensorArmType
bool hasRank() const { return !getShape().empty(); }
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
};
Expand Down
176 changes: 76 additions & 100 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,67 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"

#include <cstdint>

using namespace mlir;
using namespace mlir::spirv;

namespace {
// Helper function to collect extensions implied by a type by visiting all its
// subtypes. Maintains a set of `seen` types to avoid recursion in structs.
//
// Serves as the source-of-truth for type extension information. All extension
// logic should be added to this class, while the
// `SPIRVType::getExtensions` function should not handle extension-related logic
// directly and only invoke `TypeExtensionVisitor::add(Type *)`.
class TypeExtensionVisitor {
public:
TypeExtensionVisitor(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage)
: extensions(extensions), storage(storage) {}

// Main visitor entry point. Adds all extensions to the vector. Saves `type`
// as seen and dispatches to the right concrete `.add` function.
void add(SPIRVType type) {
if (auto [_it, inserted] = seen.insert({type, storage}); !inserted)
return;

TypeSwitch<SPIRVType>(type)
.Case<ScalarType, PointerType, CooperativeMatrixType, TensorArmType>(
[this](auto concreteType) { addConcrete(concreteType); })
.Case<VectorType, ArrayType, RuntimeArrayType, MatrixType, ImageType>(
[this](auto concreteType) { add(concreteType.getElementType()); })
.Case<StructType>([this](StructType concreteType) {
for (Type elementType : concreteType.getElementTypes())
add(elementType);
})
.Case<SampledImageType>([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
.Default([](SPIRVType) { llvm_unreachable("Unhandled type"); });
}

void add(Type type) { add(cast<SPIRVType>(type)); }

private:
// Types that add unique extensions.
void addConcrete(ScalarType type);
void addConcrete(PointerType type);
void addConcrete(CooperativeMatrixType type);
void addConcrete(TensorArmType type);

SPIRVType::ExtensionArrayRefVector &extensions;
std::optional<StorageClass> storage;
llvm::SmallDenseSet<std::pair<Type, std::optional<StorageClass>>> seen;
};

} // namespace

//===----------------------------------------------------------------------===//
// ArrayType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -65,11 +118,6 @@ Type ArrayType::getElementType() const { return getImpl()->elementType; }

unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }

void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
}

void ArrayType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
Expand Down Expand Up @@ -140,27 +188,6 @@ bool CompositeType::hasCompileTimeKnownNumElements() const {
return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
}

void CompositeType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
TypeSwitch<Type>(*this)
.Case<ArrayType, CooperativeMatrixType, MatrixType, RuntimeArrayType,
StructType>(
[&](auto type) { type.getExtensions(extensions, storage); })
.Case<VectorType>([&](VectorType type) {
return llvm::cast<ScalarType>(type.getElementType())
.getExtensions(extensions, storage);
})
.Case<TensorArmType>([&](TensorArmType type) {
static constexpr Extension ext{Extension::SPV_ARM_tensors};
extensions.push_back(ext);
return llvm::cast<ScalarType>(type.getElementType())
.getExtensions(extensions, storage);
})

.Default([](Type) { llvm_unreachable("invalid composite type"); });
}

void CompositeType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
Expand Down Expand Up @@ -284,12 +311,10 @@ CooperativeMatrixUseKHR CooperativeMatrixType::getUse() const {
return getImpl()->use;
}

void CooperativeMatrixType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
static constexpr Extension exts[] = {Extension::SPV_KHR_cooperative_matrix};
extensions.push_back(exts);
void TypeExtensionVisitor::addConcrete(CooperativeMatrixType type) {
add(type.getElementType());
static constexpr auto ext = Extension::SPV_KHR_cooperative_matrix;
extensions.push_back(ext);
}

void CooperativeMatrixType::getCapabilities(
Expand Down Expand Up @@ -403,11 +428,6 @@ ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {

ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }

void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
std::optional<StorageClass>) {
// Image types do not require extra extensions thus far.
}

void ImageType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass>) {
Expand Down Expand Up @@ -454,14 +474,15 @@ StorageClass PointerType::getStorageClass() const {
return getImpl()->storageClass;
}

void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
void TypeExtensionVisitor::addConcrete(PointerType type) {
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
llvm::cast<SPIRVType>(getPointeeType())
.getExtensions(extensions, getStorageClass());
std::optional<StorageClass> oldStorageClass = storage;
storage = type.getStorageClass();
add(type.getPointeeType());
storage = oldStorageClass;

if (auto scExts = spirv::getExtensions(getStorageClass()))
if (auto scExts = spirv::getExtensions(type.getStorageClass()))
extensions.push_back(*scExts);
}

Expand Down Expand Up @@ -513,12 +534,6 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }

unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }

void RuntimeArrayType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
}

void RuntimeArrayType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
Expand Down Expand Up @@ -553,10 +568,9 @@ bool ScalarType::isValid(IntegerType type) {
return llvm::is_contained({1u, 8u, 16u, 32u, 64u}, type.getWidth());
}

void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
if (isa<BFloat16Type>(*this)) {
static const Extension ext = Extension::SPV_KHR_bfloat16;
void TypeExtensionVisitor::addConcrete(ScalarType type) {
if (isa<BFloat16Type>(type)) {
static constexpr auto ext = Extension::SPV_KHR_bfloat16;
extensions.push_back(ext);
}

Expand All @@ -570,18 +584,16 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
case StorageClass::PushConstant:
case StorageClass::StorageBuffer:
case StorageClass::Uniform:
if (getIntOrFloatBitWidth() == 8) {
static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
ArrayRef<Extension> ref(exts, std::size(exts));
extensions.push_back(ref);
if (type.getIntOrFloatBitWidth() == 8) {
static constexpr auto ext = Extension::SPV_KHR_8bit_storage;
extensions.push_back(ext);
}
[[fallthrough]];
case StorageClass::Input:
case StorageClass::Output:
if (getIntOrFloatBitWidth() == 16) {
static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
ArrayRef<Extension> ref(exts, std::size(exts));
extensions.push_back(ref);
if (type.getIntOrFloatBitWidth() == 16) {
static constexpr auto ext = Extension::SPV_KHR_16bit_storage;
extensions.push_back(ext);
}
break;
default:
Expand Down Expand Up @@ -722,23 +734,7 @@ bool SPIRVType::isScalarOrVector() {

void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
scalarType.getExtensions(extensions, storage);
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
compositeType.getExtensions(extensions, storage);
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
imageType.getExtensions(extensions, storage);
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
sampledImageType.getExtensions(extensions, storage);
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
matrixType.getExtensions(extensions, storage);
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getExtensions(extensions, storage);
} else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
tensorArmType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
}
TypeExtensionVisitor{extensions, storage}.add(*this);
}

void SPIRVType::getCapabilities(
Expand Down Expand Up @@ -818,12 +814,6 @@ SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}

void SampledImageType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
}

void SampledImageType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
Expand Down Expand Up @@ -1182,12 +1172,6 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
structDecorations);
}

void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
for (Type elementType : getElementTypes())
llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
}

void StructType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
Expand Down Expand Up @@ -1287,11 +1271,6 @@ unsigned MatrixType::getNumElements() const {
return (getImpl()->columnCount) * getNumRows();
}

void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
}

void MatrixType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
Expand Down Expand Up @@ -1347,12 +1326,9 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type TensorArmType::getElementType() const { return getImpl()->elementType; }
ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }

void TensorArmType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {

llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
static constexpr Extension ext{Extension::SPV_ARM_tensors};
void TypeExtensionVisitor::addConcrete(TensorArmType type) {
add(type.getElementType());
static constexpr auto ext = Extension::SPV_ARM_tensors;
extensions.push_back(ext);
}

Expand Down
Loading