diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index de5b8d6f70d8b..5b3c3c4ae322d 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -18,15 +18,33 @@ class C64(ctypes.Structure): _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] +class F16(ctypes.Structure): + """A ctype representation for MLIR's Float16.""" + _fields_ = [("f16", ctypes.c_int16)] + + def as_ctype(dtp): """Converts dtype to ctype.""" if dtp is np.dtype(np.complex128): return C128 if dtp is np.dtype(np.complex64): return C64 + if dtp is np.dtype(np.float16): + return F16 return np.ctypeslib.as_ctypes_type(dtp) +def to_numpy(array): + """Converts ctypes array back to numpy dtype array.""" + if array.dtype == C128: + return array.view("complex128") + if array.dtype == C64: + return array.view("complex64") + if array.dtype == F16: + return array.view("float16") + return array + + def make_nd_memref_descriptor(rank, dtype): class MemRefDescriptor(ctypes.Structure): @@ -105,11 +123,7 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype): np.ctypeslib.as_array(val[0].shape), np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, ) - if strided_arr.dtype == C128: - return strided_arr.view("complex128") - if strided_arr.dtype == C64: - return strided_arr.view("complex64") - return strided_arr + return to_numpy(strided_arr) def ranked_memref_to_numpy(ranked_memref): @@ -121,8 +135,4 @@ def ranked_memref_to_numpy(ranked_memref): np.ctypeslib.as_array(ranked_memref[0].shape), np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, ) - if strided_arr.dtype == C128: - return strided_arr.view("complex128") - if strided_arr.dtype == C64: - return strided_arr.view("complex64") - return strided_arr + return to_numpy(strided_arr) diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 53cbac35482e7..6eed53c6efd15 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -266,6 +266,50 @@ def testMemrefAdd(): run(testMemrefAdd) +# Test addition of two f16 memrefs +# CHECK-LABEL: TEST: testF16MemrefAdd +def testF16MemrefAdd(): + with Context(): + module = Module.parse(""" + module { + func.func @main(%arg0: memref<1xf16>, + %arg1: memref<1xf16>, + %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } { + %0 = arith.constant 0 : index + %1 = memref.load %arg0[%0] : memref<1xf16> + %2 = memref.load %arg1[%0] : memref<1xf16> + %3 = arith.addf %1, %2 : f16 + memref.store %3, %arg2[%0] : memref<1xf16> + return + } + } """) + + arg1 = np.array([11.]).astype(np.float16) + arg2 = np.array([22.]).astype(np.float16) + arg3 = np.array([0.]).astype(np.float16) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1))) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2))) + arg3_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg3))) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr, + arg3_memref_ptr) + # CHECK: [11.] + [22.] = [33.] + log("{0} + {1} = {2}".format(arg1, arg2, arg3)) + + # test to-numpy utility + # CHECK: [33.] + npout = ranked_memref_to_numpy(arg3_memref_ptr[0]) + log(npout) + + +run(testF16MemrefAdd) + + # Test addition of two complex memrefs # CHECK-LABEL: TEST: testComplexMemrefAdd def testComplexMemrefAdd(): @@ -442,15 +486,15 @@ def testSharedLibLoad(): ctypes.pointer(get_ranked_memref_descriptor(arg0))) if sys.platform == 'win32': - shared_libs = [ - "../../../../bin/mlir_runner_utils.dll", - "../../../../bin/mlir_c_runner_utils.dll" - ] + shared_libs = [ + "../../../../bin/mlir_runner_utils.dll", + "../../../../bin/mlir_c_runner_utils.dll" + ] else: - shared_libs = [ - "../../../../lib/libmlir_runner_utils.so", - "../../../../lib/libmlir_c_runner_utils.so" - ] + shared_libs = [ + "../../../../lib/libmlir_runner_utils.so", + "../../../../lib/libmlir_c_runner_utils.so" + ] execution_engine = ExecutionEngine( lowerToLLVM(module), @@ -484,15 +528,15 @@ def testNanoTime(): }""") if sys.platform == 'win32': - shared_libs = [ - "../../../../bin/mlir_runner_utils.dll", - "../../../../bin/mlir_c_runner_utils.dll" - ] + shared_libs = [ + "../../../../bin/mlir_runner_utils.dll", + "../../../../bin/mlir_c_runner_utils.dll" + ] else: - shared_libs = [ - "../../../../lib/libmlir_runner_utils.so", - "../../../../lib/libmlir_c_runner_utils.so" - ] + shared_libs = [ + "../../../../lib/libmlir_runner_utils.so", + "../../../../lib/libmlir_c_runner_utils.so" + ] execution_engine = ExecutionEngine( lowerToLLVM(module),