diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 8cca1e7ad4a9e..d65ba51afdb90 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -37,12 +37,25 @@ class BF16(ctypes.Structure): _fields_ = [("bf16", ctypes.c_int16)] + class F8E5M2(ctypes.Structure): """A ctype representation for MLIR's Float8E5M2.""" _fields_ = [("f8E5M2", ctypes.c_int8)] +class F8E3M4(ctypes.Structure): + """A ctype representation for MLIR's Float8E3M4.""" + + _fields_ = [("f8E3M4", ctypes.c_int8)] + + +class F8E4M3(ctypes.Structure): + """A ctype representation for MLIR's Float8E4M3.""" + + _fields_ = [("f8E4M3", ctypes.c_int8)] + + # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" @@ -56,6 +69,10 @@ def as_ctype(dtp): return BF16 if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2: return F8E5M2 + if ml_dtypes is not None and dtp == ml_dtypes.float8_e3m4: + return F8E3M4 + if ml_dtypes is not None and dtp == ml_dtypes.float8_e4m3: + return F8E4M3 return np.ctypeslib.as_ctypes_type(dtp) @@ -68,15 +85,17 @@ def to_numpy(array): if array.dtype == F16: return array.view("float16") assert not ( - array.dtype == BF16 and ml_dtypes is None - ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + array.dtype in (BF16, F8E5M2, F8E3M4, F8E4M3) and ml_dtypes is None + ), f"{array.dtype=} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" if array.dtype == BF16: return array.view("bfloat16") - assert not ( - array.dtype == F8E5M2 and ml_dtypes is None - ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" if array.dtype == F8E5M2: return array.view("float8_e5m2") + if array.dtype == F8E3M4: + return array.view("float8_e3m4") + if array.dtype == F8E4M3: + return array.view("float8_e4m3") + return array diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index b11340f2c19ce..858ee089042ad 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -8,7 +8,7 @@ from mlir.runtime import * try: - from ml_dtypes import bfloat16, float8_e5m2 + from ml_dtypes import bfloat16, float8_e5m2, float8_e3m4, float8_e4m3 HAS_ML_DTYPES = True except ModuleNotFoundError: @@ -623,6 +623,90 @@ def testF8E5M2Memref(): log("TEST: testF8E5M2Memref") +# Test f8E3M4 memrefs +# CHECK-LABEL: TEST: testF8E3M4Memref +def testF8E3M4Memref(): + with Context(): + module = Module.parse( + """ + module { + func.func @main(%arg0: memref<1xf8E3M4>, + %arg1: memref<1xf8E3M4>) attributes { llvm.emit_c_interface } { + %0 = arith.constant 0 : index + %1 = memref.load %arg0[%0] : memref<1xf8E3M4> + memref.store %1, %arg1[%0] : memref<1xf8E3M4> + return + } + } """ + ) + + arg1 = np.array([0.5]).astype(float8_e3m4) + arg2 = np.array([0.0]).astype(float8_e3m4) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) + + # test to-numpy utility + x = ranked_memref_to_numpy(arg2_memref_ptr[0]) + assert len(x) == 1 + assert x[0] == 0.5 + + +if HAS_ML_DTYPES: + run(testF8E3M4Memref) +else: + log("TEST: testF8E3M4Memref") + + +# Test f8E4M3 memrefs +# CHECK-LABEL: TEST: testF8E4M3Memref +def testF8E4M3Memref(): + with Context(): + module = Module.parse( + """ + module { + func.func @main(%arg0: memref<1xf8E4M3>, + %arg1: memref<1xf8E4M3>) attributes { llvm.emit_c_interface } { + %0 = arith.constant 0 : index + %1 = memref.load %arg0[%0] : memref<1xf8E4M3> + memref.store %1, %arg1[%0] : memref<1xf8E4M3> + return + } + } """ + ) + + arg1 = np.array([0.5]).astype(float8_e4m3) + arg2 = np.array([0.0]).astype(float8_e4m3) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) + + # test to-numpy utility + x = ranked_memref_to_numpy(arg2_memref_ptr[0]) + assert len(x) == 1 + assert x[0] == 0.5 + + +if HAS_ML_DTYPES: + run(testF8E4M3Memref) +else: + log("TEST: testF8E4M3Memref") + + # Test addition of two 2d_memref # CHECK-LABEL: TEST: testDynamicMemrefAdd2D def testDynamicMemrefAdd2D():