-
Notifications
You must be signed in to change notification settings - Fork 11.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for numpy arrays to memref conversions.
This offers the ability to pass numpy arrays to the corresponding memref argument. Reviewed By: mehdi_amini, nicolasvasilache Differential Revision: https://reviews.llvm.org/D100077
- Loading branch information
Showing
3 changed files
with
297 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .np_to_memref import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
# This file contains functions to convert between Memrefs and NumPy arrays and vice-versa. | ||
|
||
import numpy as np | ||
import ctypes | ||
|
||
|
||
def make_nd_memref_descriptor(rank, dtype): | ||
class MemRefDescriptor(ctypes.Structure): | ||
""" | ||
Build an empty descriptor for the given rank/dtype, where rank>0. | ||
""" | ||
|
||
_fields_ = [ | ||
("allocated", ctypes.c_longlong), | ||
("aligned", ctypes.POINTER(dtype)), | ||
("offset", ctypes.c_longlong), | ||
("shape", ctypes.c_longlong * rank), | ||
("strides", ctypes.c_longlong * rank), | ||
] | ||
|
||
return MemRefDescriptor | ||
|
||
|
||
def make_zero_d_memref_descriptor(dtype): | ||
class MemRefDescriptor(ctypes.Structure): | ||
""" | ||
Build an empty descriptor for the given dtype, where rank=0. | ||
""" | ||
|
||
_fields_ = [ | ||
("allocated", ctypes.c_longlong), | ||
("aligned", ctypes.POINTER(dtype)), | ||
("offset", ctypes.c_longlong), | ||
] | ||
|
||
return MemRefDescriptor | ||
|
||
|
||
class UnrankedMemRefDescriptor(ctypes.Structure): | ||
""" Creates a ctype struct for memref descriptor""" | ||
|
||
_fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] | ||
|
||
|
||
def get_ranked_memref_descriptor(nparray): | ||
""" | ||
Return a ranked memref descriptor for the given numpy array. | ||
""" | ||
if nparray.ndim == 0: | ||
x = make_zero_d_memref_descriptor(np.ctypeslib.as_ctypes_type(nparray.dtype))() | ||
x.allocated = nparray.ctypes.data | ||
x.aligned = nparray.ctypes.data_as( | ||
ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype)) | ||
) | ||
x.offset = ctypes.c_longlong(0) | ||
return x | ||
|
||
x = make_nd_memref_descriptor( | ||
nparray.ndim, np.ctypeslib.as_ctypes_type(nparray.dtype) | ||
)() | ||
x.allocated = nparray.ctypes.data | ||
x.aligned = nparray.ctypes.data_as( | ||
ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype)) | ||
) | ||
x.offset = ctypes.c_longlong(0) | ||
x.shape = nparray.ctypes.shape | ||
|
||
# Numpy uses byte quantities to express strides, MLIR OTOH uses the | ||
# torch abstraction which specifies strides in terms of elements. | ||
strides_ctype_t = ctypes.c_longlong * nparray.ndim | ||
x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) | ||
return x | ||
|
||
|
||
def get_unranked_memref_descriptor(nparray): | ||
""" | ||
Return a generic/unranked memref descriptor for the given numpy array. | ||
""" | ||
d = UnrankedMemRefDescriptor() | ||
d.rank = nparray.ndim | ||
x = get_ranked_memref_descriptor(nparray) | ||
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) | ||
return d | ||
|
||
|
||
def unranked_memref_to_numpy(unranked_memref, np_dtype): | ||
""" | ||
Converts unranked memrefs to numpy arrays. | ||
""" | ||
descriptor = make_nd_memref_descriptor( | ||
unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype) | ||
) | ||
val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) | ||
np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) | ||
strided_arr = np.lib.stride_tricks.as_strided( | ||
np_arr, | ||
np.ctypeslib.as_array(val[0].shape), | ||
np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, | ||
) | ||
return strided_arr | ||
|
||
|
||
def ranked_memref_to_numpy(ranked_memref): | ||
""" | ||
Converts ranked memrefs to numpy arrays. | ||
""" | ||
np_arr = np.ctypeslib.as_array( | ||
ranked_memref[0].aligned, shape=ranked_memref[0].shape | ||
) | ||
strided_arr = np.lib.stride_tricks.as_strided( | ||
np_arr, | ||
np.ctypeslib.as_array(ranked_memref[0].shape), | ||
np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, | ||
) | ||
return strided_arr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters