diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td index c97d093c84e51..207b99164d0d6 100644 --- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td @@ -231,6 +231,14 @@ def BFloat16Type : DialectType<(type)>; def Float16Type : DialectType<(type)>; +// Make it easy to stage the addition of new floating point types so that +// readers can be updated first. This is enabled by default, but can be flipped +// to DialectTypeNoPrint if staging needed. +// Note: this will be removed post next release. +class EnableFloatPrintingJune2026 : DialectType; + +def FloatTF32Type : EnableFloatPrintingJune2026<(type)>; + def Float32Type : DialectType<(type)>; def Float64Type : DialectType<(type)>; @@ -239,6 +247,28 @@ def Float80Type : DialectType<(type)>; def Float128Type : DialectType<(type)>; +def Float8E5M2Type : EnableFloatPrintingJune2026<(type)>; + +def Float8E4M3Type : EnableFloatPrintingJune2026<(type)>; + +def Float8E4M3FNType : EnableFloatPrintingJune2026<(type)>; + +def Float8E5M2FNUZType : EnableFloatPrintingJune2026<(type)>; + +def Float8E4M3FNUZType : EnableFloatPrintingJune2026<(type)>; + +def Float8E4M3B11FNUZType : EnableFloatPrintingJune2026<(type)>; + +def Float8E3M4Type : EnableFloatPrintingJune2026<(type)>; + +def Float4E2M1FNType : EnableFloatPrintingJune2026<(type)>; + +def Float6E2M3FNType : EnableFloatPrintingJune2026<(type)>; + +def Float6E3M2FNType : EnableFloatPrintingJune2026<(type)>; + +def Float8E8M0FNUType : EnableFloatPrintingJune2026<(type)>; + def ComplexType : DialectType<(type Type:$elementType )>; @@ -371,7 +401,19 @@ def BuiltinDialectTypes : DialectTypes<"Builtin"> { UnrankedMemRefTypeWithMemSpace, UnrankedTensorType, VectorType, - VectorTypeWithScalableDims + VectorTypeWithScalableDims, + FloatTF32Type, + Float8E5M2Type, + Float8E4M3Type, + Float8E4M3FNType, + Float8E5M2FNUZType, + Float8E4M3FNUZType, + Float8E4M3B11FNUZType, + Float8E3M4Type, + Float4E2M1FNType, + Float6E2M3FNType, + Float6E3M2FNType, + Float8E8M0FNUType ]; } diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td index 184c81e6a5f7d..df60800ef639b 100644 --- a/mlir/include/mlir/IR/BytecodeBase.td +++ b/mlir/include/mlir/IR/BytecodeBase.td @@ -153,6 +153,10 @@ class DialectType : DialectAttrOrType, TypeKind { let cParser = "succeeded($_reader.readType<$_resultType>($_var))"; let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)"; } +// Variant of the above, where it never prints. Useful for staging. +class DialectTypeNoPrint : DialectType { + let printerPredicate = "false"; +} class DialectAttributes { string dialect = d; diff --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir index bcfbf64c833dd..5e421e2bf75bf 100644 --- a/mlir/test/Dialect/Builtin/Bytecode/types.mlir +++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir @@ -18,16 +18,40 @@ module @TestComplex attributes { module @TestFloat attributes { // CHECK: bytecode.test = bf16, // CHECK: bytecode.test1 = f16, + // CHECK: bytecode.test10 = f8E4M3FNUZ, + // CHECK: bytecode.test11 = f8E4M3B11FNUZ, + // CHECK: bytecode.test12 = f8E3M4, + // CHECK: bytecode.test13 = f4E2M1FN, + // CHECK: bytecode.test14 = f6E2M3FN, + // CHECK: bytecode.test15 = f6E3M2FN, + // CHECK: bytecode.test16 = f8E8M0FNU, + // CHECK: bytecode.test17 = tf32, // CHECK: bytecode.test2 = f32, // CHECK: bytecode.test3 = f64, // CHECK: bytecode.test4 = f80, - // CHECK: bytecode.test5 = f128 + // CHECK: bytecode.test5 = f128, + // CHECK: bytecode.test6 = f8E5M2, + // CHECK: bytecode.test7 = f8E4M3, + // CHECK: bytecode.test8 = f8E4M3FN, + // CHECK: bytecode.test9 = f8E5M2FNUZ bytecode.test = bf16, bytecode.test1 = f16, bytecode.test2 = f32, bytecode.test3 = f64, bytecode.test4 = f80, - bytecode.test5 = f128 + bytecode.test5 = f128, + bytecode.test6 = f8E5M2, + bytecode.test7 = f8E4M3, + bytecode.test8 = f8E4M3FN, + bytecode.test9 = f8E5M2FNUZ, + bytecode.test10 = f8E4M3FNUZ, + bytecode.test11 = f8E4M3B11FNUZ, + bytecode.test12 = f8E3M4, + bytecode.test13 = f4E2M1FN, + bytecode.test14 = f6E2M3FN, + bytecode.test15 = f6E3M2FN, + bytecode.test16 = f8E8M0FNU, + bytecode.test17 = tf32 } {} //===----------------------------------------------------------------------===//