Skip to content

Commit

Permalink
Add support for numpy arrays to memref conversions.
Browse files Browse the repository at this point in the history
This offers the ability to pass numpy arrays to the corresponding
memref argument.

Reviewed By: mehdi_amini, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D100077
  • Loading branch information
Prashant Kumar authored and joker-eph committed Apr 15, 2021
1 parent b2b59f6 commit 102fd1c
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/mlir/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .np_to_memref import *
119 changes: 119 additions & 0 deletions mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains functions to convert between Memrefs and NumPy arrays and vice-versa.

import numpy as np
import ctypes


def make_nd_memref_descriptor(rank, dtype):
class MemRefDescriptor(ctypes.Structure):
"""
Build an empty descriptor for the given rank/dtype, where rank>0.
"""

_fields_ = [
("allocated", ctypes.c_longlong),
("aligned", ctypes.POINTER(dtype)),
("offset", ctypes.c_longlong),
("shape", ctypes.c_longlong * rank),
("strides", ctypes.c_longlong * rank),
]

return MemRefDescriptor


def make_zero_d_memref_descriptor(dtype):
class MemRefDescriptor(ctypes.Structure):
"""
Build an empty descriptor for the given dtype, where rank=0.
"""

_fields_ = [
("allocated", ctypes.c_longlong),
("aligned", ctypes.POINTER(dtype)),
("offset", ctypes.c_longlong),
]

return MemRefDescriptor


class UnrankedMemRefDescriptor(ctypes.Structure):
""" Creates a ctype struct for memref descriptor"""

_fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]


def get_ranked_memref_descriptor(nparray):
"""
Return a ranked memref descriptor for the given numpy array.
"""
if nparray.ndim == 0:
x = make_zero_d_memref_descriptor(np.ctypeslib.as_ctypes_type(nparray.dtype))()
x.allocated = nparray.ctypes.data
x.aligned = nparray.ctypes.data_as(
ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype))
)
x.offset = ctypes.c_longlong(0)
return x

x = make_nd_memref_descriptor(
nparray.ndim, np.ctypeslib.as_ctypes_type(nparray.dtype)
)()
x.allocated = nparray.ctypes.data
x.aligned = nparray.ctypes.data_as(
ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype))
)
x.offset = ctypes.c_longlong(0)
x.shape = nparray.ctypes.shape

# Numpy uses byte quantities to express strides, MLIR OTOH uses the
# torch abstraction which specifies strides in terms of elements.
strides_ctype_t = ctypes.c_longlong * nparray.ndim
x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides])
return x


def get_unranked_memref_descriptor(nparray):
"""
Return a generic/unranked memref descriptor for the given numpy array.
"""
d = UnrankedMemRefDescriptor()
d.rank = nparray.ndim
x = get_ranked_memref_descriptor(nparray)
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
return d


def unranked_memref_to_numpy(unranked_memref, np_dtype):
"""
Converts unranked memrefs to numpy arrays.
"""
descriptor = make_nd_memref_descriptor(
unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype)
)
val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor))
np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(val[0].shape),
np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize,
)
return strided_arr


def ranked_memref_to_numpy(ranked_memref):
"""
Converts ranked memrefs to numpy arrays.
"""
np_arr = np.ctypeslib.as_array(
ranked_memref[0].aligned, shape=ranked_memref[0].shape
)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(ranked_memref[0].shape),
np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize,
)
return strided_arr
177 changes: 177 additions & 0 deletions mlir/test/Bindings/Python/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir.runtime import *

# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
Expand Down Expand Up @@ -131,3 +132,179 @@ def callback(a, b):
log("{0} + {1} = {2}".format(arg0[0], arg1[0], res[0]*2))

run(testBasicCallback)

# Test callback with an unranked memref
# CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
# Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def callback(a):
arr = unranked_memref_to_numpy(a, np.float32)
log("Inside callback: ")
log(arr)

with Context():
# The module just forwards to a runtime function known as "some_callback_into_python".
module = Module.parse(
r"""
func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
return
}
func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
execution_engine.register_runtime("some_callback_into_python", callback)
inp_arr = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
# CHECK: Inside callback:
# CHECK{LITERAL}: [[1. 2.]
# CHECK{LITERAL}: [3. 4.]]
execution_engine.invoke(
"callback_memref",
ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(inp_arr))),
)
inp_arr_1 = np.array([5, 6, 7], dtype=np.float32)
strided_arr = np.lib.stride_tricks.as_strided(
inp_arr_1, strides=(4, 0), shape=(3, 4)
)
# CHECK: Inside callback:
# CHECK{LITERAL}: [[5. 5. 5. 5.]
# CHECK{LITERAL}: [6. 6. 6. 6.]
# CHECK{LITERAL}: [7. 7. 7. 7.]]
execution_engine.invoke(
"callback_memref",
ctypes.pointer(
ctypes.pointer(get_unranked_memref_descriptor(strided_arr))
),
)

run(testUnrankedMemRefCallback)

# Test callback with a ranked memref.
# CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
# Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
@ctypes.CFUNCTYPE(
None,
ctypes.POINTER(
make_nd_memref_descriptor(2, np.ctypeslib.as_ctypes_type(np.float32))
),
)
def callback(a):
arr = ranked_memref_to_numpy(a)
log("Inside Callback: ")
log(arr)

with Context():
# The module just forwards to a runtime function known as "some_callback_into_python".
module = Module.parse(
r"""
func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
return
}
func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
)
execution_engine = ExecutionEngine(lowerToLLVM(module))
execution_engine.register_runtime("some_callback_into_python", callback)
inp_arr = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
# CHECK: Inside Callback:
# CHECK{LITERAL}: [[1. 5.]
# CHECK{LITERAL}: [6. 7.]]
execution_engine.invoke(
"callback_memref", ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr)))
)

run(testRankedMemRefCallback)

# Test addition of two memref
# CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
with Context():
module = Module.parse(
"""
module {
func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
%0 = constant 0 : index
%1 = memref.load %arg0[%0] : memref<1xf32>
%2 = memref.load %arg1[] : memref<f32>
%3 = addf %1, %2 : f32
memref.store %3, %arg2[%0] : memref<1xf32>
return
}
} """
)
arg1 = np.array([32.5]).astype(np.float32)
arg2 = np.array(6).astype(np.float32)
res = np.array([0]).astype(np.float32)

arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))

execution_engine = ExecutionEngine(lowerToLLVM(module))
execution_engine.invoke(
"main", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
)
# CHECK: [32.5] + 6.0 = [38.5]
log("{0} + {1} = {2}".format(arg1, arg2, res))

run(testMemrefAdd)

# Test addition of two 2d_memref
# CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
with Context():
module = Module.parse(
"""
module {
func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
%c0 = constant 0 : index
%c2 = constant 2 : index
%c1 = constant 1 : index
br ^bb1(%c0 : index)
^bb1(%0: index): // 2 preds: ^bb0, ^bb5
%1 = cmpi slt, %0, %c2 : index
cond_br %1, ^bb2, ^bb6
^bb2: // pred: ^bb1
%c0_0 = constant 0 : index
%c2_1 = constant 2 : index
%c1_2 = constant 1 : index
br ^bb3(%c0_0 : index)
^bb3(%2: index): // 2 preds: ^bb2, ^bb4
%3 = cmpi slt, %2, %c2_1 : index
cond_br %3, ^bb4, ^bb5
^bb4: // pred: ^bb3
%4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
%5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
%6 = addf %4, %5 : f32
memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
%7 = addi %2, %c1_2 : index
br ^bb3(%7 : index)
^bb5: // pred: ^bb3
%8 = addi %0, %c1 : index
br ^bb1(%8 : index)
^bb6: // pred: ^bb1
return
}
}
"""
)
arg1 = np.random.randn(2,2).astype(np.float32)
arg2 = np.random.randn(2,2).astype(np.float32)
res = np.random.randn(2,2).astype(np.float32)

arg1_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg1)))
arg2_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg2)))
res_memref_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(res)))

execution_engine = ExecutionEngine(lowerToLLVM(module))
execution_engine.invoke(
"memref_add_2d", arg1_memref_ptr, arg2_memref_ptr, res_memref_ptr
)
# CHECK: True
log(np.allclose(arg1+arg2, res))

run(testDynamicMemrefAdd2D)

0 comments on commit 102fd1c

Please sign in to comment.