Skip to content

Commit

Permalink
[mlir][sparse][python] migrate more code from boilerplate into proper…
Browse files Browse the repository at this point in the history
… numpy land

The boilerplate was setting up some arrays for testing. To fully illustrate
Python-MLIR potential, however, this data should also come from numpy land.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D108336
  • Loading branch information
aartbik committed Aug 20, 2021
1 parent 02d1175 commit 24ea94a
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions mlir/test/python/dialects/sparse_tensor/test_SpMM.py
Expand Up @@ -55,24 +55,19 @@ def spMxM(*args):
def boilerplate(attr: st.EncodingAttr):
"""Returns boilerplate main method.
This method sets up a boilerplate main method that calls the generated
sparse kernel. For convenience, this part is purely done as string input.
This method sets up a boilerplate main method that takes three tensors
(a, b, c), converts the first tensor a into a sparse tensor, and then
calls the sparse kernel for matrix multiplication. For convenience,
this part is purely done as string input.
"""
return f"""
func @main(%c: tensor<3x2xf64>) -> tensor<3x2xf64>
func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
attributes {{ llvm.emit_c_interface }} {{
%0 = constant dense<[ [ 1.1, 0.0, 0.0, 1.4 ],
[ 0.0, 0.0, 0.0, 0.0 ],
[ 0.0, 0.0, 3.3, 0.0 ]]> : tensor<3x4xf64>
%a = sparse_tensor.convert %0 : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
%b = constant dense<[ [ 1.0, 2.0 ],
[ 4.0, 3.0 ],
[ 5.0, 6.0 ],
[ 8.0, 7.0 ]]> : tensor<4x2xf64>
%1 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
%a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
%0 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
tensor<4x2xf64>,
tensor<3x2xf64>) -> tensor<3x2xf64>
return %1 : tensor<3x2xf64>
return %0 : tensor<3x2xf64>
}}
"""

Expand All @@ -83,25 +78,34 @@ def build_compile_and_run_SpMM(attr: st.EncodingAttr, support_lib: str,
module = build_SpMM(attr)
func = str(module.operation.regions[0].blocks[0].operations[0].operation)
module = ir.Module.parse(func + boilerplate(attr))

# Compile.
compiler(module)
engine = execution_engine.ExecutionEngine(
module, opt_level=0, shared_libs=[support_lib])
# Set up numpy input, invoke the kernel, and get numpy output.

# Set up numpy input and buffer for output.
a = np.array(
[[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
np.float64)
b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
c = np.zeros((3, 2), np.float64)
out = np.zeros((3, 2), np.float64)

mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(out)))

# Invoke the kernel and get numpy output.
# Built-in bufferization uses in-out buffers.
# TODO: replace with inplace comprehensive bufferization.
Cin = np.zeros((3, 2), np.double)
Cout = np.zeros((3, 2), np.double)
Cin_memref_ptr = ctypes.pointer(
ctypes.pointer(rt.get_ranked_memref_descriptor(Cin)))
Cout_memref_ptr = ctypes.pointer(
ctypes.pointer(rt.get_ranked_memref_descriptor(Cout)))
engine.invoke('main', Cout_memref_ptr, Cin_memref_ptr)
Cresult = rt.ranked_memref_to_numpy(Cout_memref_ptr[0])
engine.invoke('main', mem_out, mem_a, mem_b, mem_c)

# Sanity check on computed result.
expected = [[12.3, 12.0], [0.0, 0.0], [16.5, 19.8]]
if np.allclose(Cresult, expected):
expected = np.matmul(a, b);
c = rt.ranked_memref_to_numpy(mem_out[0])
if np.allclose(c, expected):
pass
else:
quit(f'FAILURE')
Expand Down Expand Up @@ -132,7 +136,10 @@ def __call__(self, module: ir.Module):
# CHECK: Passed 72 tests
@run
def testSpMM():
# Obtain path to runtime support library.
support_lib = os.getenv('SUPPORT_LIB')
assert os.path.exists(support_lib), f'{support_lib} does not exist'

with ir.Context() as ctx, ir.Location.unknown():
count = 0
# Fixed compiler optimization strategy.
Expand Down

0 comments on commit 24ea94a

Please sign in to comment.