Skip to content

Commit

Permalink
[mlir][Type] Remove the remaining usages of Type::getKind in preparat…
Browse files Browse the repository at this point in the history
…ion for its removal

This revision removes all of the lingering usages of Type::getKind. A consequence of this is that FloatType is now split into 4 derived types that represent each of the possible float types(BFloat16Type, Float16Type, Float32Type, and Float64Type). Other than this split, this revision is NFC.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D85566
  • Loading branch information
River707 committed Aug 13, 2020
1 parent 18b1e67 commit 6527712
Show file tree
Hide file tree
Showing 10 changed files with 294 additions and 364 deletions.
90 changes: 74 additions & 16 deletions mlir/include/mlir/IR/StandardTypes.h
Expand Up @@ -180,25 +180,18 @@ class IntegerType
// FloatType
//===----------------------------------------------------------------------===//

class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> {
class FloatType : public Type {
public:
using Base::Base;

static FloatType get(StandardTypes::Kind kind, MLIRContext *context);
using Type::Type;

// Convenience factories.
static FloatType getBF16(MLIRContext *ctx) {
return get(StandardTypes::BF16, ctx);
}
static FloatType getF16(MLIRContext *ctx) {
return get(StandardTypes::F16, ctx);
}
static FloatType getF32(MLIRContext *ctx) {
return get(StandardTypes::F32, ctx);
}
static FloatType getF64(MLIRContext *ctx) {
return get(StandardTypes::F64, ctx);
}
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);

/// Return the bitwidth of this float type.
unsigned getWidth();
Expand All @@ -207,6 +200,67 @@ class FloatType : public Type::TypeBase<FloatType, Type, TypeStorage> {
const llvm::fltSemantics &getFloatSemantics();
};

//===----------------------------------------------------------------------===//
// BFloat16Type

class BFloat16Type
: public Type::TypeBase<BFloat16Type, FloatType, TypeStorage> {
public:
using Base::Base;

/// Return an instance of the bfloat16 type.
static BFloat16Type get(MLIRContext *context);
};

inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}

//===----------------------------------------------------------------------===//
// Float16Type

class Float16Type : public Type::TypeBase<Float16Type, FloatType, TypeStorage> {
public:
using Base::Base;

/// Return an instance of the float16 type.
static Float16Type get(MLIRContext *context);
};

inline FloatType FloatType::getF16(MLIRContext *ctx) {
return Float16Type::get(ctx);
}

//===----------------------------------------------------------------------===//
// Float32Type

class Float32Type : public Type::TypeBase<Float32Type, FloatType, TypeStorage> {
public:
using Base::Base;

/// Return an instance of the float32 type.
static Float32Type get(MLIRContext *context);
};

inline FloatType FloatType::getF32(MLIRContext *ctx) {
return Float32Type::get(ctx);
}

//===----------------------------------------------------------------------===//
// Float64Type

class Float64Type : public Type::TypeBase<Float64Type, FloatType, TypeStorage> {
public:
using Base::Base;

/// Return an instance of the float64 type.
static Float64Type get(MLIRContext *context);
};

inline FloatType FloatType::getF64(MLIRContext *ctx) {
return Float64Type::get(ctx);
}

//===----------------------------------------------------------------------===//
// NoneType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -623,6 +677,10 @@ inline bool BaseMemRefType::classof(Type type) {
return type.isa<MemRefType, UnrankedMemRefType>();
}

inline bool FloatType::classof(Type type) {
return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
}

inline bool ShapedType::classof(Type type) {
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
UnrankedMemRefType, MemRefType>();
Expand Down
14 changes: 5 additions & 9 deletions mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Expand Up @@ -210,19 +210,15 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
}

Type LLVMTypeConverter::convertFloatType(FloatType type) {
switch (type.getKind()) {
case mlir::StandardTypes::F32:
if (type.isa<Float32Type>())
return LLVM::LLVMType::getFloatTy(&getContext());
case mlir::StandardTypes::F64:
if (type.isa<Float64Type>())
return LLVM::LLVMType::getDoubleTy(&getContext());
case mlir::StandardTypes::F16:
if (type.isa<Float16Type>())
return LLVM::LLVMType::getHalfTy(&getContext());
case mlir::StandardTypes::BF16: {
if (type.isa<BFloat16Type>())
return LLVM::LLVMType::getBFloatTy(&getContext());
}
default:
llvm_unreachable("non-float type in convertFloatType");
}
llvm_unreachable("non-float type in convertFloatType");
}

// Convert a `ComplexType` to an LLVM type. The result is a complex number
Expand Down
72 changes: 25 additions & 47 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
Expand Up @@ -10,6 +10,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::LLVM;
Expand All @@ -23,46 +24,28 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,

/// Returns the keyword to use for the given type.
static StringRef getTypeKeyword(LLVMType type) {
switch (type.getKind()) {
case LLVMType::VoidType:
return "void";
case LLVMType::HalfType:
return "half";
case LLVMType::BFloatType:
return "bfloat";
case LLVMType::FloatType:
return "float";
case LLVMType::DoubleType:
return "double";
case LLVMType::FP128Type:
return "fp128";
case LLVMType::X86FP80Type:
return "x86_fp80";
case LLVMType::PPCFP128Type:
return "ppc_fp128";
case LLVMType::X86MMXType:
return "x86_mmx";
case LLVMType::TokenType:
return "token";
case LLVMType::LabelType:
return "label";
case LLVMType::MetadataType:
return "metadata";
case LLVMType::FunctionType:
return "func";
case LLVMType::IntegerType:
return "i";
case LLVMType::PointerType:
return "ptr";
case LLVMType::FixedVectorType:
case LLVMType::ScalableVectorType:
return "vec";
case LLVMType::ArrayType:
return "array";
case LLVMType::StructType:
return "struct";
}
llvm_unreachable("unhandled type kind");
return TypeSwitch<Type, StringRef>(type)
.Case<LLVMVoidType>([&](Type) { return "void"; })
.Case<LLVMHalfType>([&](Type) { return "half"; })
.Case<LLVMBFloatType>([&](Type) { return "bfloat"; })
.Case<LLVMFloatType>([&](Type) { return "float"; })
.Case<LLVMDoubleType>([&](Type) { return "double"; })
.Case<LLVMFP128Type>([&](Type) { return "fp128"; })
.Case<LLVMX86FP80Type>([&](Type) { return "x86_fp80"; })
.Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
.Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
.Case<LLVMTokenType>([&](Type) { return "token"; })
.Case<LLVMLabelType>([&](Type) { return "label"; })
.Case<LLVMMetadataType>([&](Type) { return "metadata"; })
.Case<LLVMFunctionType>([&](Type) { return "func"; })
.Case<LLVMIntegerType>([&](Type) { return "i"; })
.Case<LLVMPointerType>([&](Type) { return "ptr"; })
.Case<LLVMVectorType>([&](Type) { return "vec"; })
.Case<LLVMArrayType>([&](Type) { return "array"; })
.Case<LLVMStructType>([&](Type) { return "struct"; })
.Default([](Type) -> StringRef {
llvm_unreachable("unexpected 'llvm' type kind");
});
}

/// Prints the body of a structure type. Uses `stack` to avoid printing
Expand Down Expand Up @@ -153,14 +136,8 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
return;
}

unsigned kind = type.getKind();
os << getTypeKeyword(type);

// Trivial types only consist of their keyword.
if (LLVMType::FIRST_TRIVIAL_TYPE <= kind &&
kind <= LLVMType::LAST_TRIVIAL_TYPE)
return;

if (auto intType = type.dyn_cast<LLVMIntegerType>()) {
os << intType.getBitWidth();
return;
Expand Down Expand Up @@ -190,7 +167,8 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
if (auto structType = type.dyn_cast<LLVMStructType>())
return printStructType(os, structType, stack);

printFunctionType(os, type.cast<LLVMFunctionType>(), stack);
if (auto funcType = type.dyn_cast<LLVMFunctionType>())
return printFunctionType(os, funcType, stack);
}

void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) {
Expand Down

0 comments on commit 6527712

Please sign in to comment.