Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/python/mlir/runtime/np_to_memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class BF16(ctypes.Structure):

_fields_ = [("bf16", ctypes.c_int16)]

class F8E5M2(ctypes.Structure):
"""A ctype representation for MLIR's Float8E5M2."""

_fields_ = [("f8E5M2", ctypes.c_int8)]


# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype
def as_ctype(dtp):
Expand All @@ -49,6 +54,8 @@ def as_ctype(dtp):
return F16
if ml_dtypes is not None and dtp == ml_dtypes.bfloat16:
return BF16
if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2:
return F8E5M2
return np.ctypeslib.as_ctypes_type(dtp)


Expand All @@ -65,6 +72,11 @@ def to_numpy(array):
), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
if array.dtype == BF16:
return array.view("bfloat16")
assert not (
array.dtype == F8E5M2 and ml_dtypes is None
), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
if array.dtype == F8E5M2:
return array.view("float8_e5m2")
return array


Expand Down
41 changes: 40 additions & 1 deletion mlir/test/python/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *
from ml_dtypes import bfloat16
from ml_dtypes import bfloat16, float8_e5m2


# Log everything to stderr and flush so that we have a unified stream to match
Expand Down Expand Up @@ -561,6 +561,45 @@ def testBF16Memref():
run(testBF16Memref)


# Test f8E5M2 memrefs
# CHECK-LABEL: TEST: testF8E5M2Memref
def testF8E5M2Memref():
with Context():
module = Module.parse(
"""
module {
func.func @main(%arg0: memref<1xf8E5M2>,
%arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
%0 = arith.constant 0 : index
%1 = memref.load %arg0[%0] : memref<1xf8E5M2>
memref.store %1, %arg1[%0] : memref<1xf8E5M2>
return
}
} """
)

arg1 = np.array([0.5]).astype(float8_e5m2)
arg2 = np.array([0.0]).astype(float8_e5m2)

arg1_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg1))
)
arg2_memref_ptr = ctypes.pointer(
ctypes.pointer(get_ranked_memref_descriptor(arg2))
)

execution_engine = ExecutionEngine(lowerToLLVM(module))
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)

# test to-numpy utility
# CHECK: [0.5]
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
log(npout)


run(testF8E5M2Memref)


# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
Expand Down
Loading