Skip to content

Commit

Permalink
[mlir] Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR
Browse files Browse the repository at this point in the history
Float8E5M2FNUZ and Float8E4M3FNUZ have been added to APFloat in D141863.
This change adds these types as MLIR builtin types alongside Float8E5M2
and Float8E4M3FN (added in D133823 and D138075).

Reviewed By: krzysz00

Differential Revision: https://reviews.llvm.org/D143744
  • Loading branch information
jakeh-gc authored and Chris Jackson committed Feb 13, 2023
1 parent 7c84f6a commit 96267b6
Show file tree
Hide file tree
Showing 19 changed files with 201 additions and 3 deletions.
14 changes: 14 additions & 0 deletions mlir/include/mlir-c/BuiltinTypes.h
Expand Up @@ -81,6 +81,20 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);

/// Checks whether the given type is an f8E5M2FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);

/// Creates an f8E5M2FNUZ type in the given context. The type is owned by the
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);

/// Checks whether the given type is an f8E4M3FNUZ type.
MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);

/// Creates an f8E4M3FNUZ type in the given context. The type is owned by the
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);

/// Checks whether the given type is a bf16 type.
MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/Builders.h
Expand Up @@ -62,6 +62,8 @@ class Builder {
// Types.
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3FNType();
FloatType getFloat8E5M2FNUZType();
FloatType getFloat8E4M3FNUZType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getF32Type();
Expand Down
15 changes: 13 additions & 2 deletions mlir/include/mlir/IR/BuiltinTypes.h
Expand Up @@ -47,6 +47,8 @@ class FloatType : public Type {
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
Expand Down Expand Up @@ -374,8 +376,9 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}

inline bool FloatType::classof(Type type) {
return type.isa<Float8E5M2Type, Float8E4M3FNType, BFloat16Type, Float16Type,
Float32Type, Float64Type, Float80Type, Float128Type>();
return type.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, BFloat16Type, Float16Type, Float32Type,
Float64Type, Float80Type, Float128Type>();
}

inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
Expand All @@ -386,6 +389,14 @@ inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
return Float8E4M3FNType::get(ctx);
}

inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) {
return Float8E5M2FNUZType::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {
return Float8E4M3FNUZType::get(ctx);
}

inline FloatType FloatType::getBF16(MLIRContext *ctx) {
return BFloat16Type::get(ctx);
}
Expand Down
44 changes: 44 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Expand Up @@ -118,6 +118,50 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> {
}];
}

//===----------------------------------------------------------------------===//
// Float8E5M2FNUZType

def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ"> {
let summary = "8-bit floating point with 2 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
mantissa. This is not a standard type as defined by IEEE-754, but it follows
similar conventions, with the exception that there are no infinity values,
no negative zero, and only one NaN representation. This type has the
following characteristics:

* bit encoding: S1E5M2
* exponent bias: 16
* infinities: Not supported
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
* denormals when exponent is 0

Described in: https://arxiv.org/abs/2206.02915
}];
}

//===----------------------------------------------------------------------===//
// Float8E4M3FNUZType

def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> {
let summary = "8-bit floating point with 3 bit mantissa";
let description = [{
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits
mantissa. This is not a standard type as defined by IEEE-754, but it follows
similar conventions, with the exception that there are no infinity values,
no negative zero, and only one NaN representation. This type has the
following characteristics:

* bit encoding: S1E4M3
* exponent bias: 8
* infinities: Not supported
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s
* denormals when exponent is 0

Described in: https://arxiv.org/abs/2209.05433
}];
}

//===----------------------------------------------------------------------===//
// BFloat16Type

Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Expand Up @@ -488,6 +488,10 @@ def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
BuildableType<"$_builder.getFloat8E4M3FNType()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
BuildableType<"$_builder.getFloat8E5M2Type()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;

def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
"complex-type", "::mlir::ComplexType">;
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/Types.h
Expand Up @@ -122,6 +122,8 @@ class Type {
bool isIndex() const;
bool isFloat8E5M2() const;
bool isFloat8E4M3FN() const;
bool isFloat8E5M2FNUZ() const;
bool isFloat8E4M3FNUZ() const;
bool isBF16() const;
bool isF16() const;
bool isF32() const;
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/AsmParser/TokenKinds.def
Expand Up @@ -95,6 +95,8 @@ TOK_KEYWORD(f64)
TOK_KEYWORD(f80)
TOK_KEYWORD(f8E5M2)
TOK_KEYWORD(f8E4M3FN)
TOK_KEYWORD(f8E5M2FNUZ)
TOK_KEYWORD(f8E4M3FNUZ)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
TOK_KEYWORD(floordiv)
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/AsmParser/TypeParser.cpp
Expand Up @@ -33,6 +33,8 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::inttype:
case Token::kw_f8E5M2:
case Token::kw_f8E4M3FN:
case Token::kw_f8E5M2FNUZ:
case Token::kw_f8E4M3FNUZ:
case Token::kw_bf16:
case Token::kw_f16:
case Token::kw_f32:
Expand Down Expand Up @@ -295,6 +297,12 @@ Type Parser::parseNonFunctionType() {
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
return builder.getFloat8E4M3FNType();
case Token::kw_f8E5M2FNUZ:
consumeToken(Token::kw_f8E5M2FNUZ);
return builder.getFloat8E5M2FNUZType();
case Token::kw_f8E4M3FNUZ:
consumeToken(Token::kw_f8E4M3FNUZ);
return builder.getFloat8E4M3FNUZType();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Expand Up @@ -139,6 +139,42 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
}
};

/// Floating Point Type subclass - Float8E4M3FNUZ.
class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
static constexpr const char *pyClassName = "Float8E4M3FNUZType";
using PyConcreteType::PyConcreteType;

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
return PyFloat8E4M3FNUZType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
}
};

/// Floating Point Type subclass - Float8E5M2FNUZ.
class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
static constexpr const char *pyClassName = "Float8E5M2FNUZType";
using PyConcreteType::PyConcreteType;

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
return PyFloat8E5M2FNUZType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
}
};

/// Floating Point Type subclass - BF16Type.
class PyBF16Type : public PyConcreteType<PyBF16Type> {
public:
Expand Down Expand Up @@ -700,6 +736,8 @@ void mlir::python::populateIRTypes(py::module &m) {
PyIndexType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
PyFloat8E4M3FNUZType::bind(m);
PyFloat8E5M2FNUZType::bind(m);
PyBF16Type::bind(m);
PyF16Type::bind(m);
PyF32Type::bind(m);
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Expand Up @@ -84,6 +84,22 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx)));
}

bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
return unwrap(type).isFloat8E5M2FNUZ();
}

MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx)));
}

bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
return unwrap(type).isFloat8E4M3FNUZ();
}

MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx)));
}

bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }

MlirType mlirBF16TypeGet(MlirContext ctx) {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Expand Up @@ -2410,6 +2410,8 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Case<IndexType>([&](Type) { os << "index"; })
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
.Case<BFloat16Type>([&](Type) { os << "bf16"; })
.Case<Float16Type>([&](Type) { os << "f16"; })
.Case<Float32Type>([&](Type) { os << "f32"; })
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/IR/Builders.cpp
Expand Up @@ -41,6 +41,14 @@ FloatType Builder::getFloat8E4M3FNType() {
return FloatType::getFloat8E4M3FN(context);
}

FloatType Builder::getFloat8E5M2FNUZType() {
return FloatType::getFloat8E5M2FNUZ(context);
}

FloatType Builder::getFloat8E4M3FNUZType() {
return FloatType::getFloat8E4M3FNUZ(context);
}

FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }

FloatType Builder::getF16Type() { return FloatType::getF16(context); }
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/IR/BuiltinTypes.cpp
Expand Up @@ -88,7 +88,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
if (isa<Float8E5M2Type, Float8E4M3FNType>())
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType>())
return 8;
if (isa<Float16Type, BFloat16Type>())
return 16;
Expand All @@ -109,6 +110,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
return APFloat::Float8E5M2();
if (isa<Float8E4M3FNType>())
return APFloat::Float8E4M3FN();
if (isa<Float8E5M2FNUZType>())
return APFloat::Float8E5M2FNUZ();
if (isa<Float8E4M3FNUZType>())
return APFloat::Float8E4M3FNUZ();
if (isa<BFloat16Type>())
return APFloat::BFloat();
if (isa<Float16Type>())
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/IR/MLIRContext.cpp
Expand Up @@ -209,6 +209,8 @@ class MLIRContextImpl {
/// Cached Type Instances.
Float8E5M2Type f8E5M2Ty;
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
Float32Type f32Ty;
Expand Down Expand Up @@ -281,6 +283,8 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
/// Floating-point Types.
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->f32Ty = TypeUniquer::get<Float32Type>(this);
Expand Down Expand Up @@ -870,6 +874,12 @@ Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNTy;
}
Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E5M2FNUZTy;
}
Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNUZTy;
}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/IR/Types.cpp
Expand Up @@ -36,6 +36,8 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }

bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
bool Type::isBF16() const { return isa<BFloat16Type>(); }
bool Type::isF16() const { return isa<Float16Type>(); }
bool Type::isF32() const { return isa<Float32Type>(); }
Expand Down
16 changes: 16 additions & 0 deletions mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Expand Up @@ -52,6 +52,8 @@ __all__ = [
"DictAttr",
"Float8E4M3FNType",
"Float8E5M2Type",
"Float8E4M3FNUZType",
"Float8E5M2FNUZType",
"F16Type",
"F32Type",
"F64Type",
Expand Down Expand Up @@ -593,6 +595,20 @@ class Float8E5M2Type(Type):
@staticmethod
def isinstance(arg: Any) -> bool: ...

class Float8E4M3FNUZType(Type):
def __init__(self, cast_from_type: Type) -> None: ...
@staticmethod
def get(*args, **kwargs) -> Float8E4M3FNUZType: ...
@staticmethod
def isinstance(arg: Any) -> bool: ...

class Float8E5M2FNUZType(Type):
def __init__(self, cast_from_type: Type) -> None: ...
@staticmethod
def get(*args, **kwargs) -> Float8E5M2FNUZType: ...
@staticmethod
def isinstance(arg: Any) -> bool: ...

# TODO: Auto-generated. Audit and fix.
class F16Type(Type):
def __init__(self, cast_from_type: Type) -> None: ...
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/IR/attribute.mlir
Expand Up @@ -44,6 +44,14 @@ func.func @float_attrs_pass() {
// CHECK: float_attr = 2.000000e+00 : f8E4M3FN
float_attr = 2. : f8E4M3FN
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
float_attr = 2. : f8E5M2FNUZ
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
float_attr = 2. : f8E4M3FNUZ
} : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f16
float_attr = 2. : f16
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/python/ir/builtin_types.py
Expand Up @@ -197,6 +197,10 @@ def testFloatType():
print("float:", Float8E4M3FNType.get())
# CHECK: float: f8E5M2
print("float:", Float8E5M2Type.get())
# CHECK: float: f8E5M2FNUZ
print("float:", Float8E5M2FNUZType.get())
# CHECK: float: f8E4M3FNUZ
print("float:", Float8E4M3FNUZType.get())
# CHECK: float: bf16
print("float:", BF16Type.get())
# CHECK: float: f16
Expand Down
2 changes: 2 additions & 0 deletions mlir/utils/lldb-scripts/mlirDataFormatters.py
Expand Up @@ -52,6 +52,8 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
"mlir::UnknownLoc": '"loc(unknown)"',
"mlir::Float8E5M2Type": '"f8E5M2"',
"mlir::Float8E4M3FNType": '"f8E4M3FN"',
"mlir::Float8E5M2FNUZType": '"f8E5M2FNUZ"',
"mlir::Float8E4M3FNUZType": '"f8E4M3FNUZ"',
"mlir::BFloat16Type": '"bf16"',
"mlir::Float16Type": '"f16"',
"mlir::Float32Type": '"f32"',
Expand Down

0 comments on commit 96267b6

Please sign in to comment.