diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index d65ba51afdb90..8cca1e7ad4a9e 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -37,25 +37,12 @@ class BF16(ctypes.Structure): _fields_ = [("bf16", ctypes.c_int16)] - class F8E5M2(ctypes.Structure): """A ctype representation for MLIR's Float8E5M2.""" _fields_ = [("f8E5M2", ctypes.c_int8)] -class F8E3M4(ctypes.Structure): - """A ctype representation for MLIR's Float8E3M4.""" - - _fields_ = [("f8E3M4", ctypes.c_int8)] - - -class F8E4M3(ctypes.Structure): - """A ctype representation for MLIR's Float8E4M3.""" - - _fields_ = [("f8E4M3", ctypes.c_int8)] - - # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" @@ -69,10 +56,6 @@ def as_ctype(dtp): return BF16 if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2: return F8E5M2 - if ml_dtypes is not None and dtp == ml_dtypes.float8_e3m4: - return F8E3M4 - if ml_dtypes is not None and dtp == ml_dtypes.float8_e4m3: - return F8E4M3 return np.ctypeslib.as_ctypes_type(dtp) @@ -85,17 +68,15 @@ def to_numpy(array): if array.dtype == F16: return array.view("float16") assert not ( - array.dtype in (BF16, F8E5M2, F8E3M4, F8E4M3) and ml_dtypes is None - ), f"{array.dtype=} requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + array.dtype == BF16 and ml_dtypes is None + ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" if array.dtype == BF16: return array.view("bfloat16") + assert not ( + array.dtype == F8E5M2 and ml_dtypes is None + ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" if array.dtype == F8E5M2: return array.view("float8_e5m2") - if array.dtype == F8E3M4: - return array.view("float8_e3m4") - if array.dtype == F8E4M3: - return array.view("float8_e4m3") - return array diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 858ee089042ad..b11340f2c19ce 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -8,7 +8,7 @@ from mlir.runtime import * try: - from ml_dtypes import bfloat16, float8_e5m2, float8_e3m4, float8_e4m3 + from ml_dtypes import bfloat16, float8_e5m2 HAS_ML_DTYPES = True except ModuleNotFoundError: @@ -623,90 +623,6 @@ def testF8E5M2Memref(): log("TEST: testF8E5M2Memref") -# Test f8E3M4 memrefs -# CHECK-LABEL: TEST: testF8E3M4Memref -def testF8E3M4Memref(): - with Context(): - module = Module.parse( - """ - module { - func.func @main(%arg0: memref<1xf8E3M4>, - %arg1: memref<1xf8E3M4>) attributes { llvm.emit_c_interface } { - %0 = arith.constant 0 : index - %1 = memref.load %arg0[%0] : memref<1xf8E3M4> - memref.store %1, %arg1[%0] : memref<1xf8E3M4> - return - } - } """ - ) - - arg1 = np.array([0.5]).astype(float8_e3m4) - arg2 = np.array([0.0]).astype(float8_e3m4) - - arg1_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg1)) - ) - arg2_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg2)) - ) - - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) - - # test to-numpy utility - x = ranked_memref_to_numpy(arg2_memref_ptr[0]) - assert len(x) == 1 - assert x[0] == 0.5 - - -if HAS_ML_DTYPES: - run(testF8E3M4Memref) -else: - log("TEST: testF8E3M4Memref") - - -# Test f8E4M3 memrefs -# CHECK-LABEL: TEST: testF8E4M3Memref -def testF8E4M3Memref(): - with Context(): - module = Module.parse( - """ - module { - func.func @main(%arg0: memref<1xf8E4M3>, - %arg1: memref<1xf8E4M3>) attributes { llvm.emit_c_interface } { - %0 = arith.constant 0 : index - %1 = memref.load %arg0[%0] : memref<1xf8E4M3> - memref.store %1, %arg1[%0] : memref<1xf8E4M3> - return - } - } """ - ) - - arg1 = np.array([0.5]).astype(float8_e4m3) - arg2 = np.array([0.0]).astype(float8_e4m3) - - arg1_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg1)) - ) - arg2_memref_ptr = ctypes.pointer( - ctypes.pointer(get_ranked_memref_descriptor(arg2)) - ) - - execution_engine = ExecutionEngine(lowerToLLVM(module)) - execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) - - # test to-numpy utility - x = ranked_memref_to_numpy(arg2_memref_ptr[0]) - assert len(x) == 1 - assert x[0] == 0.5 - - -if HAS_ML_DTYPES: - run(testF8E4M3Memref) -else: - log("TEST: testF8E4M3Memref") - - # Test addition of two 2d_memref # CHECK-LABEL: TEST: testDynamicMemrefAdd2D def testDynamicMemrefAdd2D():