diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index 34eee1edb57ff..b875d639e9d40 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -21,6 +21,7 @@ Float8E4M3Type, Float8E5M2Type, Float8E8M0FNUType, + FloatTF32Type, FunctionType, IndexType, IntegerType, @@ -70,6 +71,7 @@ def ui(width): f16 = lambda: F16Type.get() f32 = lambda: F32Type.get() +tf32 = lambda: FloatTF32Type.get() f64 = lambda: F64Type.get() bf16 = lambda: BF16Type.get() diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index 48ddc8359ca0a..6ce0fc12d8082 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -639,6 +639,7 @@ def testTypeIDs(): (BF16Type, BF16Type.get()), (F16Type, F16Type.get()), (F32Type, F32Type.get()), + (FloatTF32Type, FloatTF32Type.get()), (F64Type, F64Type.get()), (NoneType, NoneType.get()), (ComplexType, ComplexType.get(f32)), @@ -668,6 +669,7 @@ def testTypeIDs(): # CHECK: BF16Type(bf16) # CHECK: F16Type(f16) # CHECK: F32Type(f32) + # CHECK: FloatTF32Type(tf32) # CHECK: F64Type(f64) # CHECK: NoneType(none) # CHECK: ComplexType(complex) @@ -734,6 +736,9 @@ def print_downcasted(typ): # CHECK: F32Type # CHECK: F32Type(f32) print_downcasted(F32Type.get()) + # CHECK: FloatTF32Type + # CHECK: FloatTF32Type(tf32) + print_downcasted(FloatTF32Type.get()) # CHECK: F64Type # CHECK: F64Type(f64) print_downcasted(F64Type.get())