Skip to content

Commit

Permalink
Integrating alp with iree-llvm-sandbox
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseros authored and Giuseppe Rossini committed Dec 3, 2021
1 parent a7632f8 commit 495f7b2
Show file tree
Hide file tree
Showing 20 changed files with 1,157 additions and 0 deletions.
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include)

# Disable experimental alp by default
set(SANDBOX_ENABLE_ALP OFF)
if (SANDBOX_ENABLE_ALP)
# ALP includes
add_compile_definitions("SANDBOX_ENABLE_ALP")
include_directories(experimental/alp/include)
include_directories(${CMAKE_BINARY_DIR}/tools/sandbox/experimental/alp/include)
endif()

list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules)
list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake)
set(MLIR_TABLEGEN_EXE mlir-tblgen)
Expand Down Expand Up @@ -80,3 +89,8 @@ add_subdirectory(lib)
add_subdirectory(python)
add_subdirectory(test)
add_subdirectory(tools)
if (SANDBOX_ENABLE_ALP)
# ALP cmake files
add_subdirectory(experimental)
endif()

1 change: 1 addition & 0 deletions experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(alp)
14 changes: 14 additions & 0 deletions experimental/alp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

set(LLVM_LINK_COMPONENTS
Core
Support
nativecodegen
native
OrcJIT
)

include_directories(include/)
add_subdirectory(include)
add_subdirectory(lib)
216 changes: 216 additions & 0 deletions experimental/alp/alp/compile_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import os
import tempfile
import shutil
from pathlib import Path
from .utils import print_command, run_and_save, run_command, add_extension
from .library.blas import gemm

def build_main_obj(benchmark_prog, m, n, k, op, reps, mktmp_fn):
benchmark_prog = benchmark_prog.replace("_M_", str(m))
benchmark_prog = benchmark_prog.replace("_K_", str(k))
benchmark_prog = benchmark_prog.replace("_N_", str(n))
benchmark_prog = benchmark_prog.replace("__OP__", op)
benchmark_prog = benchmark_prog.replace("_REPS_", str(reps))

main_mlir = mktmp_fn("test.mlir")
main_mlir_lowered = mktmp_fn("test.llvm.mlir")
main_llvm = mktmp_fn("test.ll")
main_obj = mktmp_fn("test.o")

f = open(main_mlir, "w")
f.write(benchmark_prog)
f.close()

# main program
cmd = ["mlir-opt"]
cmd.append(main_mlir)
cmd.append("--linalg-bufferize")
cmd.append("--std-bufferize")
cmd.append("--tensor-constant-bufferize")
cmd.append("--tensor-bufferize")
cmd.append("--func-bufferize")
cmd.append("-convert-linalg-to-affine-loops")
cmd.append("-lower-affine")
cmd.append("-convert-scf-to-std")
cmd.append("-convert-memref-to-llvm")
cmd.append("-convert-std-to-llvm")
cmd.append("-reconcile-unrealized-casts")
cmd.append(f"> {main_mlir_lowered}")
run_command(cmd)
print_command(cmd)

cmd = ["mlir-translate"]
cmd.append("--mlir-to-llvmir")
cmd.append(f"{main_mlir_lowered}")
cmd.append(f"> {main_llvm}")
run_command(cmd)

cmd = ["llc"]
cmd.append(f"{main_llvm}")
cmd.append("-O3")
cmd.append("-filetype=obj")
cmd.append(f"-o {main_obj}")
run_command(cmd)


def apply(transform_list, op_mlir_file, verbosity_level):
cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-proto-opt "]
for t in transform_list:
if not t:
continue
if type(t) is tuple:
(l, ext) = t
if l >= verbosity_level:
run_and_save(cmd, op_mlir_file, add_extension(op_mlir_file, ext))
else:
cmd.append(t)

output = add_extension(op_mlir_file, "llvm")
run_and_save(cmd, op_mlir_file, output)
return output

def SaveIR(x, ext):
return (x, ext)

def build_operator_obj(op_prog, m, n, k, op, option_list, mktmp_fn, verbosity_level=0):
op_prog = op_prog.replace("_M_", str(m))
op_prog = op_prog.replace("_K_", str(k))
op_prog = op_prog.replace("_N_", str(n))

op_mlir = mktmp_fn(f"{op}.mlir")
f = open(f"{op_mlir}", "w")
f.write(op_prog)
f.close()

# Transformation options
tile_sizes = option_list["tile_sizes"]
reorder_tile_sizes = option_list["reorder_tile_sizes"]
register_tile_sizes = option_list["register_tile_sizes"]
reorder_register_tile_sizes = option_list["reorder_register_tile_sizes"]
hoist_packing = option_list['hoist_packing']
split_vector_transfer = option_list['split_vector_transfers_to']
extract_micro_kernel = option_list['extract_micro_kernel']
modulo_scheduling = option_list['modulo_scheduling']

Canonicalize = " --canonicalize --cse"
CodegenDriver = "--linalg-tensor-codegen-driver=\"anchor-func=gemm anchor-op=linalg.generic"

# Transformations
OuterTiling = CodegenDriver + f" tile-sizes={tile_sizes} tile-interchange={reorder_tile_sizes}\"" + Canonicalize

InnerTiling = CodegenDriver + f" tile-sizes={register_tile_sizes} tile-interchange={reorder_register_tile_sizes}" + \
f" pad pack-paddings=1,1,0 hoist-paddings={hoist_packing} \"" + Canonicalize

DecomposeToLowerDimensionalNamedOp = CodegenDriver + " decompose-to-lower-dim\"" + Canonicalize

Vectorize = CodegenDriver + " vectorize vectorize-padding\"" + Canonicalize

Bufferize = "--linalg-bufferization-driver" + Canonicalize

LowerVector = "--linalg-vector-lowering=\"max-transfer-rank=1 " +\
f" split-transfers={split_vector_transfer}" +\
" lower-vector-transpose-to=eltwise" +\
" lower-vector-multi-reduction-to=innerparallel" +\
" lower-vector-contraction-to=outerproduct" +\
" unroll-vector-transfers=true"

LowerVectorStage = lambda stage : LowerVector+f" lower-vector-stage={stage}\"" + Canonicalize

ExtractKernel = "--alp-extract-kernel" + Canonicalize if extract_micro_kernel else ""
ModuloScheduling = "--alp-modulo-scheduling" if modulo_scheduling else "" # TODO: Order is not preserved if I canonicalize

LowerToLLVM = "--convert-vector-to-scf " +\
"--convert-linalg-to-loops " +\
"--canonicalize " +\
"--lower-affine " +\
"--convert-scf-to-std " +\
"--convert-linalg-to-llvm " +\
"--convert-vector-to-llvm " +\
"--convert-math-to-llvm " +\
"--convert-memref-to-llvm " +\
"--convert-std-to-llvm " +\
"--canonicalize " +\
"--cse " +\
"--reconcile-unrealized-casts "


TransformList = [OuterTiling,
InnerTiling,
SaveIR(4, "tile"),
DecomposeToLowerDimensionalNamedOp,
Vectorize,
SaveIR(4, "vectorize"),
Bufferize,
SaveIR(4, "bufferize"),
LowerVectorStage(0),
SaveIR(4, "lower_vector"),
ExtractKernel,
ModuloScheduling,
SaveIR(4, "micro_kernel"),
LowerVectorStage(1),
LowerVectorStage(2),
LowerVectorStage(3),
LowerVectorStage(4),
LowerVectorStage(5),
LowerVectorStage(6),
SaveIR(4, "micro_kernel_final"),
LowerToLLVM]

op_llvm_mlir = apply(TransformList, op_mlir, verbosity_level)

out = run_command(["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-translate --mlir-to-llvmir " + op_llvm_mlir])
op_llvm = mktmp_fn(f"{op}.ll")
f = open(f"{op_llvm}", "w")
f.write(out)
f.close()

op_obj = mktmp_fn(f"{op}.o")
op_asm = mktmp_fn(f"{op}.s")

cmd = ["llc"]
cmd.append(op_llvm)
cmd.append("-O3")
cmd.append("-filetype=obj")
cmd.append(f"-o {op_obj}")
run_command(cmd)

cmd = ["llc"]
cmd.append(f"{op_llvm}")
cmd.append("-O3")
cmd.append("-filetype=asm")
cmd.append(f"-o {op_asm}")
run_command(cmd)

def link_main(op, mktmp_fn):
out_bin = "exec_matmul"
main_obj = mktmp_fn("test.o")
op_obj = mktmp_fn(f"{op}.o")

cmd = ["clang++"]
cmd.append(f"{main_obj}")
cmd.append(f"{op_obj}")
cmd.append(f"-o {out_bin}")
cmd.append("-lmlir_c_runner_utils")
print_command(cmd)
run_command(cmd)

def build_mlir(op, m, n, k, options):
verbose = ("verbosity_level" in options) and options["verbosity_level"] > 0
reps= 1
if options["reps"]:
reps = options["reps"]

if verbose:
Path("./tmp").mkdir(exist_ok=True)
tmp_dir_name = "./tmp"
verbosity_level=options["verbosity_level"]
else:
tmp_dir = tempfile.TemporaryDirectory()
tmp_dir_name = tmp_dir.name
verbosity_level=0

(benchmark, op_mlir)= gemm(False)
mktmp = lambda x : os.path.join(tmp_dir_name, x)
build_main_obj(benchmark, m, n, k, op, reps, mktmp)
build_operator_obj(op_mlir, m, n, k, op, options, mktmp, verbosity_level)
link_main(op, mktmp)
135 changes: 135 additions & 0 deletions experimental/alp/alp/library/blas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
typest = """
!memref_type_A = type tensor<_K_x_M_xf32>
!memref_type_B = type tensor<_K_x_N_xf32>
!memref_type_C = type tensor<_M_x_N_xf32>
"""

types = """
!memref_type_A = type tensor<_M_x_K_xf32>
!memref_type_B = type tensor<_K_x_N_xf32>
!memref_type_C = type tensor<_M_x_N_xf32>
"""

init_tensors = """
%A0 = linalg.init_tensor [_M_,_K_] : !memref_type_A
%B0 = linalg.init_tensor [_K_,_N_] : !memref_type_B
%C = linalg.init_tensor [_M_, _N_] : !memref_type_C
"""

init_tensors_t = """
%A0 = linalg.init_tensor [_K_,_M_] : !memref_type_A
%B0 = linalg.init_tensor [_K_,_N_] : !memref_type_B
%C = linalg.init_tensor [_M_, _N_] : !memref_type_C
"""


gemm_benchmark = f"""
func @main() -> i32 {{
call @print_pid() : () -> ()
__INIT_TENSORS__
%elem = arith.constant 1.0 : f32
%A = linalg.fill(%elem, %A0) : f32, !memref_type_A -> !memref_type_A
%B = linalg.fill(%elem, %B0) : f32, !memref_type_B -> !memref_type_B
%out = call @gemm(%A, %B, %C) : (!memref_type_A, !memref_type_B, !memref_type_C) -> !memref_type_C
%reps = arith.constant _REPS_ : index
%t_start = call @rtclock() : () -> f64
affine.for %arg0 = 0 to %reps {{
call @gemm(%A, %B, %C) : (!memref_type_A, !memref_type_B, !memref_type_C) -> !memref_type_C
}}
%t_end = call @rtclock() : () -> f64
%repsi = arith.index_cast %reps : index to i64
%repsf = arith.sitofp %repsi: i64 to f64
%t_tot = arith.subf %t_end, %t_start : f64
%t = arith.divf %t_tot, %repsf : f64
call @print_time(%t) : (f64) -> ()
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%M = tensor.dim %C, %c0 : !memref_type_C
%N = tensor.dim %C, %c1 : !memref_type_C
%K = tensor.dim %A, %c0 : !memref_type_A
%Mi32 = arith.index_cast %M: index to i64
%Ni32 = arith.index_cast %N: index to i64
%Ki32 = arith.index_cast %K: index to i64
%c2 = arith.constant 2 : i64
%f1 = arith.muli %Mi32, %Ni32 : i64
%f2 = arith.muli %f1, %Ki32 : i64
%f3 = arith.muli %c2, %f2 : i64
// 2*M*N*K.
%num_flops_f = arith.sitofp %f3: i64 to f64
%flops = arith.divf %num_flops_f, %t : f64
call @print_flops(%flops) : (f64) -> ()
%i0 = arith.constant 0 : i32
return %i0 : i32
}}
func private @print_flops(f64)
func private @print_time(f64)
func private @printNewline()
func private @print_pid()
func private @rtclock() -> f64
func private @print_memref_f32(memref<*xf32>)
func private @gemm(%A : !memref_type_A, %B : !memref_type_B, %C : !memref_type_C) -> !memref_type_C
"""




GEMM = """
func @gemm(%A : !memref_type_A {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
%B : !memref_type_B {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
%C : !memref_type_C {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> !memref_type_C {
%0 = linalg.generic
{indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
affine_map<(m, n, k) -> (m, n)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%A, %B: !memref_type_A, !memref_type_B)
outs(%C: !memref_type_C) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = arith.mulf %a, %b: f32
%e = arith.addf %c, %d: f32
linalg.yield %e : f32
} -> !memref_type_C
return %0 : !memref_type_C
}
"""

GEMM_T = """
func @gemm(%A : !memref_type_A {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
%B : !memref_type_B {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
%C : !memref_type_C {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> !memref_type_C {
%0 = linalg.generic
{indexing_maps = [affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (k, n)>,
affine_map<(m, n, k) -> (m, n)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%A, %B: !memref_type_A, !memref_type_B)
outs(%C: !memref_type_C) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = arith.mulf %a, %b: f32
%e = arith.addf %c, %d: f32
linalg.yield %e : f32
} -> !memref_type_C
return %0 : !memref_type_C
}
"""

def gemm(trA):
if trA:
bench = gemm_benchmark.replace("__INIT_TENSORS__", str(init_tensors_t))
return (typest + bench, typest + GEMM_T)
else:
bench = gemm_benchmark.replace("__INIT_TENSORS__", str(init_tensors))
return (types + bench, types+ GEMM)
Loading

0 comments on commit 495f7b2

Please sign in to comment.