Skip to content

Commit

Permalink
mlir/DialectConversion: use std::optional (NFC)
Browse files Browse the repository at this point in the history
This is part of an effort to migrate from llvm::Optional to
std::optional. This patch touches DialectConversion, and modifies
existing conversions and tests appropriately.

See also: https://discourse.llvm.org/t/deprecating-llvm-optional-x-hasvalue-getvalue-getvalueor/63716

Signed-off-by: Ramkumar Ramachandra <r@artagnon.com>

Differential Revision: https://reviews.llvm.org/D140303
  • Loading branch information
artagnon committed Dec 19, 2022
1 parent 2f6439b commit 0de16fa
Show file tree
Hide file tree
Showing 27 changed files with 209 additions and 191 deletions.
6 changes: 3 additions & 3 deletions flang/lib/Optimizer/CodeGen/TypeConverter.h
Expand Up @@ -138,7 +138,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
addSourceMaterialization(
[&](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs,
mlir::Location loc) -> llvm::Optional<mlir::Value> {
mlir::Location loc) -> std::optional<mlir::Value> {
if (inputs.size() != 1)
return std::nullopt;
return inputs[0];
Expand All @@ -148,7 +148,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
addTargetMaterialization(
[&](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs,
mlir::Location loc) -> llvm::Optional<mlir::Value> {
mlir::Location loc) -> std::optional<mlir::Value> {
if (inputs.size() != 1)
return std::nullopt;
return inputs[0];
Expand All @@ -163,7 +163,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
mlir::Type indexType() { return mlir::IntegerType::get(&getContext(), 64); }

// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
llvm::Optional<mlir::LogicalResult>
std::optional<mlir::LogicalResult>
convertRecordType(fir::RecordType derived,
llvm::SmallVectorImpl<mlir::Type> &results,
llvm::ArrayRef<mlir::Type> callStack) {
Expand Down
12 changes: 7 additions & 5 deletions mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
Expand Up @@ -23,23 +23,25 @@ class SPIRVTypeConverter;
namespace spirv {
/// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
using MemorySpaceToStorageClassMap =
std::function<Optional<spirv::StorageClass>(Attribute)>;
std::function<std::optional<spirv::StorageClass>(Attribute)>;

/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V
/// using the default rule. Returns std::nullopt if the memory space is unknown.
Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(Attribute);
std::optional<spirv::StorageClass>
mapMemorySpaceToVulkanStorageClass(Attribute);
/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces
/// using the default rule. Returns std::nullopt if the storage class is
/// unsupported.
Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
std::optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);

/// Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V
/// using the default rule. Returns std::nullopt if the memory space is unknown.
Optional<spirv::StorageClass> mapMemorySpaceToOpenCLStorageClass(Attribute);
std::optional<spirv::StorageClass>
mapMemorySpaceToOpenCLStorageClass(Attribute);
/// Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces
/// using the default rule. Returns std::nullopt if the storage class is
/// unsupported.
Optional<unsigned> mapOpenCLStorageClassToMemorySpace(spirv::StorageClass);
std::optional<unsigned> mapOpenCLStorageClassToMemorySpace(spirv::StorageClass);

/// Type converter for converting numeric MemRef memory spaces into SPIR-V
/// symbolic ones.
Expand Down
59 changes: 30 additions & 29 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Expand Up @@ -56,7 +56,7 @@ class SPIRVType : public Type {
/// the given `storage` class. This method does not guarantee the uniqueness
/// of extensions; the same extension may be appended multiple times.
void getExtensions(ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);

/// The capability requirements for each type are following the
/// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
Expand All @@ -68,12 +68,12 @@ class SPIRVType : public Type {
/// uniqueness of capabilities; the same capability may be appended multiple
/// times.
void getCapabilities(CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);

/// Returns the size in bytes for each type. If no size can be calculated,
/// returns `std::nullopt`. Note that if the type has explicit layout, it is
/// also taken into account in calculation.
Optional<int64_t> getSizeInBytes();
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V scalar type: bool type, integer type, floating point type.
Expand All @@ -89,11 +89,11 @@ class ScalarType : public SPIRVType {
static bool isValid(IntegerType);

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);

Optional<int64_t> getSizeInBytes();
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
Expand All @@ -117,11 +117,11 @@ class CompositeType : public SPIRVType {
bool hasCompileTimeKnownNumElements() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);

Optional<int64_t> getSizeInBytes();
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V array type
Expand All @@ -145,13 +145,13 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
unsigned getArrayStride() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);

/// Returns the array size in bytes. Since array type may have an explicit
/// stride declaration (in bytes), we also include it in the calculation.
Optional<int64_t> getSizeInBytes();
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V image type
Expand Down Expand Up @@ -188,9 +188,9 @@ class ImageType
// TODO: Add support for Access qualifier

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
};

// SPIR-V pointer type
Expand All @@ -206,9 +206,9 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
StorageClass getStorageClass() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
};

// SPIR-V run-time array type
Expand All @@ -230,9 +230,9 @@ class RuntimeArrayType
unsigned getArrayStride() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
};

// SPIR-V sampled image type
Expand All @@ -253,9 +253,10 @@ class SampledImageType
Type getImageType() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<spirv::StorageClass> storage = std::nullopt);
std::optional<spirv::StorageClass> storage = std::nullopt);
void
getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<spirv::StorageClass> storage = std::nullopt);
};

/// SPIR-V struct type. Two kinds of struct types are supported:
Expand Down Expand Up @@ -389,9 +390,9 @@ class StructType
ArrayRef<MemberDecorationInfo> memberDecorations = {});

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
};

llvm::hash_code
Expand All @@ -416,9 +417,9 @@ class CooperativeMatrixNVType
unsigned getColumns() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
};

// SPIR-V joint matrix type
Expand All @@ -443,9 +444,9 @@ class JointMatrixINTELType
MatrixLayout getMatrixLayout() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
};

// SPIR-V matrix type
Expand Down Expand Up @@ -480,9 +481,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
Type getElementType() const;

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = std::nullopt);
std::optional<StorageClass> storage = std::nullopt);
};

} // namespace spirv
Expand Down
48 changes: 26 additions & 22 deletions mlir/include/mlir/Transforms/DialectConversion.h
Expand Up @@ -54,7 +54,7 @@ class TypeConverter {
ArrayRef<Type> getConvertedTypes() const { return argTypes; }

/// Get the input mapping for the given argument.
Optional<InputMapping> getInputMapping(unsigned input) const {
std::optional<InputMapping> getInputMapping(unsigned input) const {
return remappedInputs[input];
}

Expand All @@ -81,27 +81,28 @@ class TypeConverter {
unsigned newInputCount = 1);

/// The remapping information for each of the original arguments.
SmallVector<Optional<InputMapping>, 4> remappedInputs;
SmallVector<std::optional<InputMapping>, 4> remappedInputs;

/// The set of new argument types.
SmallVector<Type, 4> argTypes;
};

/// Register a conversion function. A conversion function must be convertible
/// to any of the following forms(where `T` is a class derived from `Type`:
/// * Optional<Type>(T)
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
/// the converter is allowed to try another conversion function to
/// perform the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
/// `failure` or `std::nullopt` to signify a failed conversion. If the
/// new set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
/// ArrayRef<Type>)
/// - This form represents a 1-N type conversion supporting recursive
/// types. The first two arguments and the return value are the same as
/// for the regular 1-N form. The third argument is contains is the
Expand All @@ -119,7 +120,7 @@ class TypeConverter {

/// Register a materialization function, which must be convertible to the
/// following form:
/// `Optional<Value>(OpBuilder &, T, ValueRange, Location)`,
/// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
Expand Down Expand Up @@ -203,7 +204,7 @@ class TypeConverter {
/// This function converts the type signature of the given block, by invoking
/// 'convertSignatureArg' for each argument. This function should return a
/// valid conversion for the signature on success, std::nullopt otherwise.
Optional<SignatureConversion> convertBlockSignature(Block *block);
std::optional<SignatureConversion> convertBlockSignature(Block *block);

/// Materialize a conversion from a set of types into one result type by
/// generating a cast sequence of some kind. See the respective
Expand All @@ -229,12 +230,12 @@ class TypeConverter {
/// The signature of the callback used to convert a type. If the new set of
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<Optional<LogicalResult>(
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;

/// The signature of the callback used to materialize a conversion.
using MaterializationCallbackFn =
std::function<Optional<Value>(OpBuilder &, Type, ValueRange, Location)>;
using MaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location)>;

/// Attempt to materialize a conversion using one of the provided
/// materialization functions.
Expand All @@ -244,23 +245,24 @@ class TypeConverter {

/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `Optional<Type>(T)`
/// With callback of form: `std::optional<Type>(T)`
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT &&callback) {
return wrapCallback<T>(
[callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
if (Optional<Type> resultOpt = callback(type)) {
if (std::optional<Type> resultOpt = callback(type)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
return Optional<LogicalResult>(success(wasSuccess));
return std::optional<LogicalResult>(success(wasSuccess));
}
return Optional<LogicalResult>();
return std::optional<LogicalResult>();
});
}
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
/// With callback of form: `std::optional<LogicalResult>(T,
/// SmallVectorImpl<Type>
/// &)`
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
Expand All @@ -272,7 +274,8 @@ class TypeConverter {
return callback(type, results);
});
}
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
/// With callback of form: `std::optional<LogicalResult>(T,
/// SmallVectorImpl<Type>
/// &, ArrayRef<Type>)`.
template <typename T, typename FnT>
std::enable_if_t<
Expand All @@ -281,7 +284,7 @@ class TypeConverter {
wrapCallback(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
Type type, SmallVectorImpl<Type> &results,
ArrayRef<Type> callStack) -> Optional<LogicalResult> {
ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
T derivedType = type.dyn_cast<T>();
if (!derivedType)
return std::nullopt;
Expand All @@ -303,7 +306,7 @@ class TypeConverter {
MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc) -> Optional<Value> {
Location loc) -> std::optional<Value> {
if (T derivedType = resultType.dyn_cast<T>())
return callback(builder, derivedType, inputs, loc);
return std::nullopt;
Expand Down Expand Up @@ -681,7 +684,8 @@ class ConversionTarget {

/// The signature of the callback used to determine if an operation is
/// dynamically legal on the target.
using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
using DynamicLegalityCallbackFn =
std::function<std::optional<bool>(Operation *)>;

ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
virtual ~ConversionTarget() = default;
Expand Down Expand Up @@ -830,7 +834,7 @@ class ConversionTarget {
//===--------------------------------------------------------------------===//

/// Get the legality action for the given operation.
Optional<LegalizationAction> getOpAction(OperationName op) const;
std::optional<LegalizationAction> getOpAction(OperationName op) const;

/// If the given operation instance is legal on this target, a structure
/// containing legality information is returned. If the operation is not
Expand All @@ -841,7 +845,7 @@ class ConversionTarget {
/// Note: Legality is actually a 4-state: Legal(recursive=true),
/// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
/// either as Legal or Illegal depending on context.
Optional<LegalOpDetails> isLegal(Operation *op) const;
std::optional<LegalOpDetails> isLegal(Operation *op) const;

/// Returns true is operation instance is illegal on this target. Returns
/// false if operation is legal, operation legality wasn't registered by user
Expand Down Expand Up @@ -873,7 +877,7 @@ class ConversionTarget {
};

/// Get the legalization information for the given operation.
Optional<LegalizationInfo> getOpInfo(OperationName op) const;
std::optional<LegalizationInfo> getOpInfo(OperationName op) const;

/// A deterministic mapping of operation name and its respective legality
/// information.
Expand Down

0 comments on commit 0de16fa

Please sign in to comment.