diff --git a/CMakeLists.txt b/CMakeLists.txt index a9e75b487c20..0704e05d4653 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,13 @@ set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include) set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include) +if (SANDBOX_ENABLE_ALP) + # ALP includes + add_compile_definitions("SANDBOX_ENABLE_ALP") + include_directories(experimental/alp/include) + include_directories(${CMAKE_BINARY_DIR}/tools/sandbox/experimental/alp/include) +endif() + list(APPEND CMAKE_MODULE_PATH ${MLIR_MAIN_SRC_DIR}/cmake/modules) list(APPEND CMAKE_MODULE_PATH ${LLVM_MAIN_SRC_DIR}/cmake) set(MLIR_TABLEGEN_EXE mlir-tblgen) @@ -80,3 +87,8 @@ add_subdirectory(lib) add_subdirectory(python) add_subdirectory(test) add_subdirectory(tools) +if (SANDBOX_ENABLE_ALP) +# ALP cmake files +add_subdirectory(experimental) +endif() + diff --git a/configure.py b/configure.py index e6a63b1ebac3..db018fb2762d 100755 --- a/configure.py +++ b/configure.py @@ -23,6 +23,11 @@ def parse_arguments(): help="Build with ENABLE_LLD=ON (optional)", dest="enable_lld", default = False) + parser.add_argument("--enable-alp", + help="Build with SANDBOX_ENABLE_ALP=ON (optional)", + dest="enable_alp", + action="store_true", + default = False) parser.add_argument("--no-ccache", help="Disables ccache (if available)", dest="enable_ccache", @@ -102,6 +107,10 @@ def main(args): else: print("WARNING: LLD (ld.lld) not found on path. Configure may fail.") + # Optionally enable Alp: forward SANDBOX_ENABLE_ALP=ON to the configure args + if args.enable_alp: + llvm_configure_args.append("-DSANDBOX_ENABLE_ALP=ON") + # Detect ccache. 
if args.enable_ccache: ccache_path = shutil.which("ccache") diff --git a/experimental/CMakeLists.txt b/experimental/CMakeLists.txt new file mode 100644 index 000000000000..246f5f08c24d --- /dev/null +++ b/experimental/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(alp) diff --git a/experimental/alp/CMakeLists.txt b/experimental/alp/CMakeLists.txt new file mode 100644 index 000000000000..5fb4975f1e74 --- /dev/null +++ b/experimental/alp/CMakeLists.txt @@ -0,0 +1,14 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + +set(LLVM_LINK_COMPONENTS + Core + Support + nativecodegen + native + OrcJIT + ) + +include_directories(include/) +add_subdirectory(include) +add_subdirectory(lib) diff --git a/experimental/alp/alp/compile_op.py b/experimental/alp/alp/compile_op.py new file mode 100644 index 000000000000..acbd7d1d50ee --- /dev/null +++ b/experimental/alp/alp/compile_op.py @@ -0,0 +1,219 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import os +import tempfile +import shutil +from pathlib import Path +from .utils import print_command, run_and_save, run_command, add_extension +from .library.blas import gemm + +def build_main_obj(benchmark_prog, m, n, k, op, reps, mktmp_fn): + benchmark_prog = benchmark_prog.replace("_M_", str(m)) + benchmark_prog = benchmark_prog.replace("_K_", str(k)) + benchmark_prog = benchmark_prog.replace("_N_", str(n)) + benchmark_prog = benchmark_prog.replace("__OP__", op) + benchmark_prog = benchmark_prog.replace("_REPS_", str(reps)) + + main_mlir = mktmp_fn("test.mlir") + main_mlir_lowered = mktmp_fn("test.llvm.mlir") + main_llvm = mktmp_fn("test.ll") + main_obj = mktmp_fn("test.o") + + f = open(main_mlir, "w") + f.write(benchmark_prog) + f.close() + + # main program + cmd = ["mlir-opt"] + cmd.append(main_mlir) + cmd.append("--linalg-bufferize") + cmd.append("--std-bufferize") + cmd.append("--tensor-constant-bufferize") + cmd.append("--tensor-bufferize") + cmd.append("--func-bufferize") + cmd.append("-convert-linalg-to-affine-loops") + cmd.append("-lower-affine") + cmd.append("-convert-scf-to-std") + cmd.append("-convert-memref-to-llvm") + cmd.append("-convert-std-to-llvm") + cmd.append("-reconcile-unrealized-casts") + cmd.append(f"> {main_mlir_lowered}") + run_command(cmd) + print_command(cmd) + + cmd = ["mlir-translate"] + cmd.append("--mlir-to-llvmir") + cmd.append(f"{main_mlir_lowered}") + cmd.append(f"> {main_llvm}") + run_command(cmd) + + cmd = ["llc"] + cmd.append(f"{main_llvm}") + cmd.append("-O3") + cmd.append("-filetype=obj") + cmd.append(f"-o {main_obj}") + run_command(cmd) + + +def apply(transform_list, op_mlir_file, verbosity_level): + cmd = ["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-proto-opt "] + for t in transform_list: + if not t: + continue + if type(t) is tuple: + (l, ext) = t + if l >= verbosity_level: + run_and_save(cmd, op_mlir_file, add_extension(op_mlir_file, ext)) + else: + cmd.append(t) + 
+ output = add_extension(op_mlir_file, "llvm") + run_and_save(cmd, op_mlir_file, output) + return output + +def SaveIR(x, ext): + return (x, ext) + +def build_operator_obj(op_prog, m, n, k, op, option_list, mktmp_fn, verbosity_level=0): + op_prog = op_prog.replace("_M_", str(m)) + op_prog = op_prog.replace("_K_", str(k)) + op_prog = op_prog.replace("_N_", str(n)) + + op_mlir = mktmp_fn(f"{op}.mlir") + f = open(f"{op_mlir}", "w") + f.write(op_prog) + f.close() + + # Transformation options + tile_sizes = option_list["tile_sizes"] + reorder_tile_sizes = option_list["reorder_tile_sizes"] + register_tile_sizes = option_list["register_tile_sizes"] + reorder_register_tile_sizes = option_list["reorder_register_tile_sizes"] + hoist_packing = option_list['hoist_packing'] + split_vector_transfer = option_list['split_vector_transfers_to'] + extract_micro_kernel = option_list['extract_micro_kernel'] + modulo_scheduling = option_list['modulo_scheduling'] + + Canonicalize = " --canonicalize --cse" + CodegenDriver = "--linalg-tensor-codegen-driver=\"anchor-func=gemm anchor-op=linalg.generic" + + # Transformations + OuterTiling = CodegenDriver + f" tile-sizes={tile_sizes} tile-interchange={reorder_tile_sizes}\"" + Canonicalize + + InnerTiling = CodegenDriver + f" tile-sizes={register_tile_sizes} tile-interchange={reorder_register_tile_sizes}" + \ + f" pad pack-paddings=1,1,0 hoist-paddings={hoist_packing} \"" + Canonicalize + + DecomposeToLowerDimensionalNamedOp = CodegenDriver + " decompose-to-lower-dim\"" + Canonicalize + + Vectorize = CodegenDriver + " vectorize vectorize-padding\"" + Canonicalize + + Bufferize = "--linalg-bufferization-driver" + Canonicalize + + LowerVector = "--linalg-vector-lowering=\"max-transfer-rank=1 " +\ + f" split-transfers={split_vector_transfer}" +\ + " lower-vector-transpose-to=eltwise" +\ + " lower-vector-multi-reduction-to=innerparallel" +\ + " lower-vector-contraction-to=outerproduct" +\ + " unroll-vector-transfers=true" + + LowerVectorStage = 
lambda stage : LowerVector+f" lower-vector-stage={stage}\"" + Canonicalize + + ExtractKernel = "--alp-extract-kernel" + Canonicalize if extract_micro_kernel else "" + ModuloScheduling = "--alp-modulo-scheduling" if modulo_scheduling else "" # TODO: Order is not preserved if I canonicalize + + LowerToLLVM = "--convert-vector-to-scf " +\ + "--convert-linalg-to-loops " +\ + "--canonicalize " +\ + "--lower-affine " +\ + "--convert-scf-to-std " +\ + "--convert-linalg-to-llvm " +\ + "--convert-vector-to-llvm " +\ + "--convert-math-to-llvm " +\ + "--convert-memref-to-llvm " +\ + "--convert-std-to-llvm " +\ + "--canonicalize " +\ + "--cse " +\ + "--reconcile-unrealized-casts " + + + TransformList = [OuterTiling, + InnerTiling, + SaveIR(4, "tile"), + DecomposeToLowerDimensionalNamedOp, + Vectorize, + SaveIR(4, "vectorize"), + Bufferize, + SaveIR(4, "bufferize"), + LowerVectorStage(0), + SaveIR(4, "lower_vector"), + ExtractKernel, + ModuloScheduling, + SaveIR(4, "micro_kernel"), + LowerVectorStage(1), + LowerVectorStage(2), + LowerVectorStage(3), + LowerVectorStage(4), + LowerVectorStage(5), + LowerVectorStage(6), + SaveIR(4, "micro_kernel_final"), + LowerToLLVM] + + op_llvm_mlir = apply(TransformList, op_mlir, verbosity_level) + + out = run_command(["$IREE_LLVM_SANDBOX_BUILD_DIR/bin/mlir-translate --mlir-to-llvmir " + op_llvm_mlir]) + op_llvm = mktmp_fn(f"{op}.ll") + f = open(f"{op_llvm}", "w") + f.write(out) + f.close() + + op_obj = mktmp_fn(f"{op}.o") + op_asm = mktmp_fn(f"{op}.s") + + cmd = ["llc"] + cmd.append(op_llvm) + cmd.append("-O3") + cmd.append("-filetype=obj") + cmd.append(f"-o {op_obj}") + run_command(cmd) + + cmd = ["llc"] + cmd.append(f"{op_llvm}") + cmd.append("-O3") + cmd.append("-filetype=asm") + cmd.append(f"-o {op_asm}") + run_command(cmd) + +def link_main(op, mktmp_fn): + out_bin = "exec_matmul" + main_obj = mktmp_fn("test.o") + op_obj = mktmp_fn(f"{op}.o") + + cmd = ["clang++"] + cmd.append(f"{main_obj}") + cmd.append(f"{op_obj}") + cmd.append(f"-o 
{out_bin}") + cmd.append("-lmlir_c_runner_utils") + print_command(cmd) + run_command(cmd) + +def build_mlir(op, m, n, k, options): + verbose = ("verbosity_level" in options) and options["verbosity_level"] > 0 + reps= 1 + if options["reps"]: + reps = options["reps"] + + if verbose: + Path("./tmp").mkdir(exist_ok=True) + tmp_dir_name = "./tmp" + verbosity_level=options["verbosity_level"] + else: + tmp_dir = tempfile.TemporaryDirectory() + tmp_dir_name = tmp_dir.name + verbosity_level=0 + + (benchmark, op_mlir)= gemm(False) + mktmp = lambda x : os.path.join(tmp_dir_name, x) + build_main_obj(benchmark, m, n, k, op, reps, mktmp) + build_operator_obj(op_mlir, m, n, k, op, options, mktmp, verbosity_level) + link_main(op, mktmp) diff --git a/experimental/alp/alp/library/blas.py b/experimental/alp/alp/library/blas.py new file mode 100644 index 000000000000..5a1db1fdd55d --- /dev/null +++ b/experimental/alp/alp/library/blas.py @@ -0,0 +1,135 @@ +typest = """ +!memref_type_A = type tensor<_K_x_M_xf32> +!memref_type_B = type tensor<_K_x_N_xf32> +!memref_type_C = type tensor<_M_x_N_xf32> +""" + +types = """ +!memref_type_A = type tensor<_M_x_K_xf32> +!memref_type_B = type tensor<_K_x_N_xf32> +!memref_type_C = type tensor<_M_x_N_xf32> +""" + +init_tensors = """ + %A0 = linalg.init_tensor [_M_,_K_] : !memref_type_A + %B0 = linalg.init_tensor [_K_,_N_] : !memref_type_B + %C = linalg.init_tensor [_M_, _N_] : !memref_type_C +""" + +init_tensors_t = """ + %A0 = linalg.init_tensor [_K_,_M_] : !memref_type_A + %B0 = linalg.init_tensor [_K_,_N_] : !memref_type_B + %C = linalg.init_tensor [_M_, _N_] : !memref_type_C +""" + + +gemm_benchmark = f""" +func @main() -> i32 {{ + call @print_pid() : () -> () + __INIT_TENSORS__ + + %elem = arith.constant 1.0 : f32 + %A = linalg.fill(%elem, %A0) : f32, !memref_type_A -> !memref_type_A + %B = linalg.fill(%elem, %B0) : f32, !memref_type_B -> !memref_type_B + + %out = call @gemm(%A, %B, %C) : (!memref_type_A, !memref_type_B, !memref_type_C) -> 
!memref_type_C + %reps = arith.constant _REPS_ : index + %t_start = call @rtclock() : () -> f64 + affine.for %arg0 = 0 to %reps {{ + call @gemm(%A, %B, %C) : (!memref_type_A, !memref_type_B, !memref_type_C) -> !memref_type_C + }} + %t_end = call @rtclock() : () -> f64 + %repsi = arith.index_cast %reps : index to i64 + %repsf = arith.sitofp %repsi: i64 to f64 + %t_tot = arith.subf %t_end, %t_start : f64 + %t = arith.divf %t_tot, %repsf : f64 + + call @print_time(%t) : (f64) -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + %M = tensor.dim %C, %c0 : !memref_type_C + %N = tensor.dim %C, %c1 : !memref_type_C + %K = tensor.dim %A, %c0 : !memref_type_A + + %Mi32 = arith.index_cast %M: index to i64 + %Ni32 = arith.index_cast %N: index to i64 + %Ki32 = arith.index_cast %K: index to i64 + + %c2 = arith.constant 2 : i64 + %f1 = arith.muli %Mi32, %Ni32 : i64 + %f2 = arith.muli %f1, %Ki32 : i64 + %f3 = arith.muli %c2, %f2 : i64 + + // 2*M*N*K. + %num_flops_f = arith.sitofp %f3: i64 to f64 + %flops = arith.divf %num_flops_f, %t : f64 + call @print_flops(%flops) : (f64) -> () + + %i0 = arith.constant 0 : i32 + return %i0 : i32 +}} + + +func private @print_flops(f64) +func private @print_time(f64) +func private @printNewline() +func private @print_pid() +func private @rtclock() -> f64 +func private @print_memref_f32(memref<*xf32>) +func private @gemm(%A : !memref_type_A, %B : !memref_type_B, %C : !memref_type_C) -> !memref_type_C +""" + + + + +GEMM = """ +func @gemm(%A : !memref_type_A {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %B : !memref_type_B {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %C : !memref_type_C {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> !memref_type_C { + %0 = linalg.generic + {indexing_maps = [affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)>], + 
iterator_types = ["parallel", "parallel", "reduction"]} + + ins(%A, %B: !memref_type_A, !memref_type_B) + outs(%C: !memref_type_C) { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 + linalg.yield %e : f32 + } -> !memref_type_C + return %0 : !memref_type_C + } +""" + +GEMM_T = """ +func @gemm(%A : !memref_type_A {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %B : !memref_type_B {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %C : !memref_type_C {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> !memref_type_C { + %0 = linalg.generic + {indexing_maps = [affine_map<(m, n, k) -> (k, m)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)>], + iterator_types = ["parallel", "parallel", "reduction"]} + + ins(%A, %B: !memref_type_A, !memref_type_B) + outs(%C: !memref_type_C) { + ^bb0(%a: f32, %b: f32, %c: f32) : + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 + linalg.yield %e : f32 + } -> !memref_type_C + return %0 : !memref_type_C + } +""" + +def gemm(trA): + if trA: + bench = gemm_benchmark.replace("__INIT_TENSORS__", str(init_tensors_t)) + return (typest + bench, typest + GEMM_T) + else: + bench = gemm_benchmark.replace("__INIT_TENSORS__", str(init_tensors)) + return (types + bench, types+ GEMM) diff --git a/experimental/alp/alp/mlirc.py b/experimental/alp/alp/mlirc.py new file mode 100644 index 000000000000..f46fdaf5c9b7 --- /dev/null +++ b/experimental/alp/alp/mlirc.py @@ -0,0 +1,60 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import sys +import argparse +from .utils import parse, run_command, print_command +from .compile_op import build_mlir + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("mlirc") + + # GEMM size + parser.add_argument("--M", type=int) + parser.add_argument("--N", type=int) + parser.add_argument("--K", type=int) + + # Outer tiling + parser.add_argument("--tile-sizes", nargs='+', type=int) + parser.add_argument("--reorder-tile-sizes", nargs='+', type=int) + + # Inner tiling + parser.add_argument("--register-tile-sizes", nargs='+', type=int) + parser.add_argument("--reorder-register-tile-sizes", nargs='+', type=int) + parser.add_argument("--hoist-packing", nargs='+', type=int) + + # Vector lowering + parser.add_argument("--unroll-vector-transfers", action="store_true") + parser.add_argument("--split-vector-transfers-to") + + # micro-kernel transforms + parser.add_argument("--extract-micro-kernel", action="store_true") + parser.add_argument("--modulo-scheduling", action="store_true") + + # Verbosity + parser.add_argument("--verbose", action="store_true") + parser.add_argument("--verbosity-level", type=int, default=0) + parser.add_argument("--reps", type=int, default=1) + + args = parser.parse_args() + + stringify = lambda l : ','.join([str(e) for e in l]) + options = { "tile_sizes" : stringify(args.tile_sizes), + "register_tile_sizes" : stringify(args.register_tile_sizes), + "split_vector_transfers_to" : args.split_vector_transfers_to, + "unroll_vector_transfers" : args.unroll_vector_transfers, + "reorder_tile_sizes": stringify(args.reorder_tile_sizes), + "reorder_register_tile_sizes": stringify(args.reorder_register_tile_sizes), + "hoist_packing": stringify(args.hoist_packing), + "extract_micro_kernel": args.extract_micro_kernel, + "modulo_scheduling": args.modulo_scheduling, + "verbosity_level" : 0, + "reps": args.reps + } + + if (args.verbose): + options["verbosity_level"]=1 + if 
(args.verbosity_level > 0): + options["verbosity_level"]=args.verbosity_level + build_mlir("gemm", args.M, args.N, args.K, options) diff --git a/experimental/alp/alp/tuner.py b/experimental/alp/alp/tuner.py new file mode 100644 index 000000000000..827c6483e202 --- /dev/null +++ b/experimental/alp/alp/tuner.py @@ -0,0 +1,149 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#!/usr/bin/env python +import opentuner +from opentuner import ConfigurationManipulator +from opentuner.search.manipulator import IntegerParameter, PowerOfTwoParameter, EnumParameter, BooleanParameter +from opentuner import MeasurementInterface +from opentuner import Result +import sys + +from .utils import assemble_options_from_list, parse + +max_flops = 0 +class MLIRFlagsTuner(MeasurementInterface): + + def manipulator(self): + """ + Define the search space by creating a + ConfigurationManipulator + """ + manipulator = ConfigurationManipulator() + + manipulator.add_parameter( + PowerOfTwoParameter('mr', 4, 4)) + + manipulator.add_parameter( + PowerOfTwoParameter('nr', 16, 16)) + + manipulator.add_parameter( + PowerOfTwoParameter('kr', 16, 64)) + + manipulator.add_parameter( + PowerOfTwoParameter('kc', 64, 128)) + + manipulator.add_parameter( + PowerOfTwoParameter('mc', 256, 2048)) + + manipulator.add_parameter( + PowerOfTwoParameter('nc', 64, 2048)) + + manipulator.add_parameter( + IntegerParameter('ha', 4 , 4)) + + manipulator.add_parameter( + IntegerParameter('hb', 3 , 3)) + + return manipulator + + def run(self, desired_result, input, limit): + global max_flops + + + """ + Compile and run a given configuration then + return performance + """ + + cfg = desired_result.configuration.data + + mr = cfg['mr'] + nr = cfg['nr'] + kr = cfg['kr'] + kc = cfg['kc'] + mc = cfg['mc'] + nc = cfg['nc'] + ha = cfg['ha'] + hb = cfg['hb'] + # reordering = 
cfg['reorder'] + + M = self.args.M + N = self.args.N + K = self.args.K + + # mr = min(mr,mc) + # nr = min(nr,nc) + # kr = min(kr, kc) + # kr = kc + + cfg['mr'] = mr + cfg['nr'] = nr + cfg['kr'] = kr + reordering = "Afirst" + + if reordering == "Afirst": + reorder_inner = "0 1 2" + reorder_outer = "0 2 1" + else: + reorder_inner = "1 0 2" + reorder_outer = "1 2 0" + + hoisting_params = f"{ha} {hb} 0" + cmd = ['python3 -m alp.mlirc'] + cmd.append(f'--M {M}') + cmd.append(f'--N {N}') + cmd.append(f'--K {K}') + + cmd.append(f"--tile-sizes {mc} {nc} {kc}") + cmd.append(f"--register-tile-sizes {mr} {nr} {kr}") + cmd.append(f"--reorder-tile-sizes {reorder_outer}") + cmd.append(f"--reorder-register-tile-sizes {reorder_inner}") + + #if cfg['unrollVectorTransfers']: + cmd.append(f"--unroll-vector-transfers") + cmd.append(f"--split-vector-transfers-to none") # {cfg['splitVectorTransfersTo']}") + cmd.append(f"--hoist-packing {hoisting_params}") + + compile_result = self.call_program(' '.join(cmd)) + + + if compile_result['returncode'] != 0: + return Result(time=sys.maxsize) + + assert compile_result['returncode'] == 0 + + run_cmd = './exec_matmul' + run_result = self.call_program(run_cmd, limit=0.7) + + if run_result['returncode'] != 0: + return Result(time=sys.maxsize) + + assert run_result['returncode'] == 0 + + secs, flops = parse(run_result['stderr']) + + if(flops>max_flops): + s = ' '.join([str(elem) for elem in cmd]) + max_flops=flops + + + return Result(time=1/flops) + + def save_final_config(self, configuration): + """called at the end of tuning""" + print("Optimal block size written to mmm_final_config.json:", configuration.data) + M = self.args.M + N = self.args.N + K = self.args.K + self.manipulator().save_to_file(configuration.data, + f'mmm_final_config_{M}_{N}_{K}.json') + + +if __name__ == '__main__': + argparser = opentuner.default_argparser() + argparser.add_argument("--M", type=int) + argparser.add_argument("--N", type=int) + argparser.add_argument("--K", 
type=int) + MLIRFlagsTuner.main(argparser.parse_args()) diff --git a/experimental/alp/alp/utils.py b/experimental/alp/alp/utils.py new file mode 100644 index 000000000000..9926a2ad914c --- /dev/null +++ b/experimental/alp/alp/utils.py @@ -0,0 +1,102 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import subprocess +import os +import numpy as np +from subprocess import PIPE, Popen + +def run_command(cmd): + # print(cmd) + output = subprocess.check_output(' '.join(cmd), shell=True) + return output.decode('ascii') + +def print_command(cmd): + print(' '.join(cmd)) + +def run_and_save(cmd, original_ir, new_ir): + out = run_command(cmd + [original_ir]) + f = open(f"{new_ir}", "w") + f.write(out) + f.close() + +def add_extension(fname, ext): + orig_ext = os.path.splitext(fname)[1] + newfilename = os.path.splitext(fname)[0] + "." + ext + orig_ext + return newfilename + +def parse(out): + secs = 0 + flops = 0 + lines = out.split('\n') + for l in lines: + if not l: + continue + [a,b]= l.split() + if b == "secs": + secs = float(a) + if b == "GFLOPS": + flops = float(a) + return (secs, flops) + +def analytical_model(hw, Sdata): + # Analyitical model for GEMM + # https://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf + + # Vector unit properties + Nvec = hw["Nvec"] + Lvfma = hw["Lvfma"] + Nvfma = hw["Nvfma"] + + # Determine mr/nr + K = Nvec*Nvfma*Lvfma + mr = np.ceil((np.sqrt(K)/Nvec))*Nvec + nr = np.ceil(K/mr) + + # L1 properties + SL1 = hw["SL"][0]*1024 + WL1 = hw["WL"][0] + + # L2 properties + SL2 = hw["SL"][1] *1024 + WL2 = hw["WL"][1] + + if "CL" in hw: + CL1 = hw["CL"][0] + CL2 = hw["CL"][1] + NL1 = SL1/(WL1*CL1) + NL2 = SL2/(WL2*CL2) + elif "NL" in hw: + NL1 = hw["NL"][0] + NL2 = hw["NL"][1] + CL1 = SL1/(WL1*NL1) + CL2 = SL2/(WL2*NL2) + + # if L3 properties are specified, then determine nc + if 
hw["num_caches"] == 3: + SL3 = hw["SL"][2] * 1024 + WL3 = hw["WL"][2] + + if "CL" in hw: + CL3 = hw["CL"][2] + NL3 = SL3/(WL3*CL3) + elif "NL" in hw: + NL3 = hw["NL"][2] + CL3 = SL3/(WL3*NL3) + + # Determine kc + CAr = np.floor((WL1-1)/(1+nr/mr)) + kc = (CAr*NL1*CL1)/(mr*Sdata) + + # Determine mc + CBr2 = np.ceil(nr*kc*Sdata/(NL2*CL2)) + mc = ( (WL2-1-CBr2)*NL2*CL2/(kc*Sdata)) + + # Determine nc + if hw["num_caches"] == 3: + CAc3 = np.ceil(mc*kc*Sdata/(NL3*CL3)) + nc = ((WL3-CAc3-1)*NL3*CL3)/(kc*Sdata) + else: + nc = -1 + + return (mc, nc, kc, mr, nr) diff --git a/experimental/alp/include/CMakeLists.txt b/experimental/alp/include/CMakeLists.txt new file mode 100644 index 000000000000..246f5f08c24d --- /dev/null +++ b/experimental/alp/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(alp) diff --git a/experimental/alp/include/alp/CMakeLists.txt b/experimental/alp/include/alp/CMakeLists.txt new file mode 100644 index 000000000000..e31af3266116 --- /dev/null +++ b/experimental/alp/include/alp/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/experimental/alp/include/alp/Transforms/CMakeLists.txt b/experimental/alp/include/alp/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..65311815ee0b --- /dev/null +++ b/experimental/alp/include/alp/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name ALP) +add_public_tablegen_target(ALPPassIncGen) diff --git a/experimental/alp/include/alp/Transforms/PassDetail.h b/experimental/alp/include/alp/Transforms/PassDetail.h new file mode 100644 index 000000000000..587cabcd8c54 --- /dev/null +++ b/experimental/alp/include/alp/Transforms/PassDetail.h @@ -0,0 +1,46 @@ +//===- PassDetail.h - Pass class details ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef IREE_LLVM_SANDBOX_PASSDETAIL_H_ +#define IREE_LLVM_SANDBOX_PASSDETAIL_H_ + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +// Forward declaration from Dialect.h +template <typename ConcreteDialect> +void registerDialect(DialectRegistry &registry); + +namespace linalg { +class LinalgDialect; +} // end namespace linalg + +namespace scf { +class SCFDialect; +} // end namespace scf + +namespace memref { +class MemRefDialect; +} // end namespace memref + +namespace tensor { +class TensorDialect; +} // end namespace tensor + +namespace vector { +class VectorDialect; +} // end namespace vector + +#define GEN_PASS_CLASSES +#include "alp/Transforms/Passes.h.inc" + +} // end namespace mlir + +#endif // IREE_LLVM_SANDBOX_PASSDETAIL_H_ diff --git a/experimental/alp/include/alp/Transforms/Passes.h b/experimental/alp/include/alp/Transforms/Passes.h new file mode 100644 index 000000000000..ea46e027b78e --- /dev/null +++ b/experimental/alp/include/alp/Transforms/Passes.h @@ -0,0 +1,36 @@ +//===- Passes.h - Alp pass entry points ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. 
+// +//===----------------------------------------------------------------------===// + +#ifndef ALP_LLVM_SANDBOX_PASSES_H +#define ALP_LLVM_SANDBOX_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +/// Create a pass extract the kernel function out +std::unique_ptr> createExtractKernelPass(); + +/// Create a pass to modulo-schedule the kernel +std::unique_ptr createModuloSchedulingPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "alp/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif // ALP_LLVM_SANDBOX_PASSES_H diff --git a/experimental/alp/include/alp/Transforms/Passes.td b/experimental/alp/include/alp/Transforms/Passes.td new file mode 100644 index 000000000000..1079966729d3 --- /dev/null +++ b/experimental/alp/include/alp/Transforms/Passes.td @@ -0,0 +1,35 @@ +//===-- Passes.td - Alp pass definition file -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ALP_LLVM_SANDBOX_PASSES +#define ALP_LLVM_SANDBOX_PASSES + +include "mlir/Pass/PassBase.td" + +def ExtractKernelPass: Pass<"alp-extract-kernel", "ModuleOp"> { + let summary = "Pass to extract the kernel in a separate function."; + let constructor = "mlir::createExtractKernelPass()"; + let options = [ + Option<"anchorFuncOpName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + //Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + // "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def ModuloSchedulingPass: FunctionPass<"alp-modulo-scheduling"> { + let summary = "Pass to modulo-schedule a loop."; + let constructor = "mlir::createModuloSchedulingPass()"; + let options = [ + + Option<"unrolling", "unrolling", "int", /*default=*/"", + "Unrolling level before scheduling the loop.">, + ]; +} + +#endif // ALP_LLVM_SANDBOX_PASSES diff --git a/experimental/alp/lib/CMakeLists.txt b/experimental/alp/lib/CMakeLists.txt new file mode 100644 index 000000000000..e31af3266116 --- /dev/null +++ b/experimental/alp/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/experimental/alp/lib/Transforms/CMakeLists.txt b/experimental/alp/lib/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..0a9d4c533e25 --- /dev/null +++ b/experimental/alp/lib/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_library(ExperimentalAlpTransforms + extract_kernel_pass.cpp + modulo_scheduling_pass.cpp + + LINK_LIBS PRIVATE + MLIRLinalg + MLIRLinalgTransforms + + DEPENDS + ALPPassIncGen + MLIRLinalg # TODO: Why needed here? + MLIRLinalgTransforms # TODO: Why needed here? 
+)
diff --git a/experimental/alp/lib/Transforms/extract_kernel_pass.cpp b/experimental/alp/lib/Transforms/extract_kernel_pass.cpp
new file mode 100644
index 000000000000..c9754681f081
--- /dev/null
+++ b/experimental/alp/lib/Transforms/extract_kernel_pass.cpp
@@ -0,0 +1,207 @@
+//===-- extract_kernel_pass.cpp - Extract Kernel Pass ------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "alp/Transforms/PassDetail.h"
+#include "alp/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "alp/Transforms/Passes.h"
+
+#include
+
+#define DEBUG_TYPE "extract-kernel"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
+using namespace mlir;
+
+// Outline `block` into a new function called `func_name`, created at the start
+// of `parentModule`, and leave a call to it in the region that held `block`.
+//
+// Constants defined outside the block are cloned into the new function. Every
+// other outside value the block uses becomes a function argument, followed by
+// one function argument per block argument. `loc` is the location given to the
+// operations created for the return, the call and the re-added yield.
+void extract_function(StringRef func_name, Block *block, ModuleOp parentModule,
+                      RewriterBase &rewriter, Location loc) {
+  SmallVector Input, Output;
+
+  // Create the function (callee site) with an empty block
+  rewriter.setInsertionPointToStart(parentModule.getBody());
+  auto func_op = rewriter.create(
+      parentModule.getLoc(), func_name,
+      FunctionType::get(parentModule.getContext(), Input, Output));
+  auto entry_block = func_op.addEntryBlock();
+
+  // Build the dominance tree of the parent op of the block
+  Region *region = block->getParent();
+  Operation *parent_op = region->getParentOp();
+  auto dom_info = mlir::DominanceInfo(parent_op);
+
+  // std::set consts;
+  llvm::SmallVector vals;
+  llvm::SmallVector consts;
+  bool add_yield = false;
+
+  // Walk the block and find out all the variables that were defined outside
+  // this block and are used inside the block (i.e., all the variables x that
+  // properly dominate the block). The only things we will redefine inside the
+  // entry block are constants. For all other variables, we will add them as
+  // inputs to the function
+  block->walk([&](Operation *inst) {
+    for (Value val : inst->getOperands()) {
+      if (dom_info.properlyDominates(val, parent_op)) {
+        arith::ConstantOp const_op = val.getDefiningOp();
+        if (const_op) {
+          // It's useless to add many times the same index
+          if (std::find(consts.begin(), consts.end(), const_op) ==
+              consts.end()) {
+            consts.push_back(const_op);
+            rewriter.setInsertionPointToStart(entry_block);
+            Operation *new_const = rewriter.clone(*const_op);
+            rewriter.replaceOpWithinBlock(const_op, new_const->getResult(0),
+                                          block);
+          }
+        } else {
+          if (std::find(vals.begin(), vals.end(), val) == vals.end()) {
+            func_op.insertArgument(vals.size(), val.getType(), {});
+            vals.push_back(val);
+          }
+        }
+      }
+    }
+
+    // Remove Yield operations and signal to add it from the caller site
+    // TODO: this is wrong if yield has results/operands connected to it.
+    // We should clone the yield in the caller block
+    if (dyn_cast(inst) && inst->getBlock() == block) {
+      add_yield = true;
+      rewriter.eraseOp(inst);
+    }
+  });
+
+  llvm::SmallVector newtypes;
+  // We are not done yet. We need to merge the block into the entry block. To do
+  // this: 1 If an operation in the block is using a value coming from the block
+  // argument, add the value as function argument and replace the value with
+  // it
+  // 2 If an operation in the block is using a value generated outside the
+  // block, simply replace its value with a function argument
+
+  // Step 1: get all the block arguments, add them as function arguments and
+  // replace their use inside the block
+  // Block argument i is placed at index vals.size() + i so that the function
+  // signature matches the operands of the call built below: outside values
+  // first, then the block arguments in their original order.
+  for (auto block_arg : block->getArguments()) {
+    unsigned arg_index = vals.size() + block_arg.getArgNumber();
+    func_op.insertArgument(arg_index, block_arg.getType(), {});
+    auto arg = func_op.getArgument(arg_index);
+    block_arg.replaceAllUsesWith(arg);
+    newtypes.push_back(block_arg.getType());
+  }
+
+  // Step 2: replace all the values that are pointing outside the block and
+  // replace them with function arguments
+  auto args = func_op.getArguments();
+  for (unsigned i = 0; i < vals.size(); i++) {
+    auto val = vals[i];
+    auto arg = args[i];
+    val.replaceUsesWithIf(arg, [&](OpOperand &op) {
+      return dom_info.dominates(block, op.getOwner()->getBlock());
+    });
+  }
+
+  // Save some information about the original block. Once the block is merged
+  // inside the entry block this information won't be available anymore
+  bool has_no_successor = block->hasNoSuccessors();
+  Block *succ = (has_no_successor ? nullptr : block->getSuccessor(0));
+
+  // Remove all arguments from the block signature. Always erase the first
+  // one: each erase shifts the remaining arguments down, so walking an
+  // increasing index would skip every other argument.
+  while (block->getNumArguments() > 0) {
+    block->eraseArgument(0);
+  }
+
+  // Merge block into entry_block (this destroys block)
+  rewriter.mergeBlocks(block, entry_block);
+
+  // Add a returnOp into the block to properly terminate it
+  rewriter.setInsertionPointToEnd(entry_block);
+  rewriter.create(loc);
+
+  // We are done with the callee. Now we have to work on the caller. The overall
+  // idea is to insert a new_block right before the successor of the old block.
+  // If the old block has no successors, then add it at the end of the region
+  Block *new_block = nullptr;
+  if (has_no_successor) {
+    new_block = rewriter.createBlock(region, region->end(), newtypes);
+  } else {
+    new_block = rewriter.createBlock(succ, newtypes);
+  }
+
+  // Remember to add the block arguments as inputs to the function
+  for (unsigned i = 0; i < new_block->getNumArguments(); i++) {
+    vals.push_back(new_block->getArgument(i));
+  }
+
+  // Create the call
+  rewriter.create(loc, func_op, vals);
+
+  if (add_yield) {
+    rewriter.create(loc);
+  }
+}
+
+// Walk the for loops and find the one that has iter operands (the last one
+// visited wins). In GEMM this is the micro-kernel. The block containing that
+// loop is extracted into a function named "kernel".
+// TODO: we should have the linalg::split to signal the microkernel of the
+// operation and use it to run the function extractor if needed
+struct ExtractKernelPass : public ExtractKernelPassBase {
+
+  ExtractKernelPass() = default;
+  ExtractKernelPass(const ExtractKernelPass &pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {}
+  void runOnOperation() override {
+    // Get the module being operated on.
+    auto module = getOperation();
+    scf::ForOp loop;
+
+    for (FuncOp func : module.getOps()) {
+      // Walk the operations within the function.
+      func.walk([&](scf::ForOp forop) {
+        if (forop.getNumIterOperands()) {
+          loop = forop;
+        }
+      });
+    }
+
+    // No loop with iter operands was found: there is nothing to extract, and
+    // `loop` is a null handle that must not be dereferenced.
+    if (!loop) {
+      return;
+    }
+
+    IRRewriter rewriter(module.getContext());
+    extract_function("kernel", loop->getBlock(), module, rewriter,
+                     module.getLoc());
+  }
+};
+
+// Create the pass that extracts the kernel into its own function.
+std::unique_ptr> mlir::createExtractKernelPass() {
+  return std::make_unique();
+}
\ No newline at end of file
diff --git a/experimental/alp/lib/Transforms/modulo_scheduling_pass.cpp b/experimental/alp/lib/Transforms/modulo_scheduling_pass.cpp
new file mode 100644
index 000000000000..de7c8ea04d79
--- /dev/null
+++ b/experimental/alp/lib/Transforms/modulo_scheduling_pass.cpp
@@ -0,0 +1,151 @@
+//===-- modulo_scheduling_pass.cpp - Implement modulo scheduling ------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "alp/Transforms/PassDetail.h"
+#include "alp/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+#include
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopUtils.h"
+
+using namespace mlir;
+
+namespace {
+// Unrolls the scf.for loop that carries iter operands (the last one visited)
+// and software-pipelines it, assigning each operation of the unrolled body a
+// stage and a position relative to two anchor operations.
+struct ModuloSchedulingPass
+    : public ModuloSchedulingPassBase {
+  // ModuloScheduling(int unrollFactor):unrollFactor_(unrollFactor){}
+  ModuloSchedulingPass() = default;
+  ModuloSchedulingPass(const ModuloSchedulingPass &pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert();
+  }
+  void runOnFunction() override {
+    // Get the current FuncOp operation being operated on.
+    auto f = getFunction();
+
+    scf::ForOp loop;
+
+    // Find the kernel: the last visited scf.for loop with iter operands
+    f.walk([&](scf::ForOp forop) {
+      if (forop.getNumIterOperands()) {
+        loop = forop;
+      }
+    });
+
+    if (loop) {
+      // Unroll
+      auto annotateFn = [](unsigned i, Operation *op, OpBuilder b) {
+        op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
+      };
+
+      // The schedule built below assumes the body was actually unrolled, so
+      // stop here instead of pipelining a loop whose unrolling failed.
+      if (failed(loopUnrollByFactor(loop, unrollFactor_, annotateFn))) {
+        return;
+      }
+
+      // Pipeline the kernel
+      RewritePatternSet patterns(&getContext());
+      mlir::scf::PipeliningOption options;
+
+      // Order/stage the instruction within the for loop. We are looking for a
+      // pattern like %x0 = load -> (stage0, pos3) %y0 = load -> (stage0, pos4)
+      // %z0 = outerprod(%x0, %y0) -> (stage1, pos2)
+      // %x1 = load -> (stage1, pos0)
+      // %y1 = load -> (stage1, pos1)
+      // %z1 = outerprod(%x1, %y1) -> (stage1, pos5)
+      std::unordered_map stage_map;
+      std::unordered_map clock_map;
+
+      int anchor0 = -1;
+      int anchor1 = -1;
+      int stage = 0;
+      int clock = 0;
+      // Take care of the stages
+      for (Operation &operation : loop.getBody()->getOperations()) {
+        Operation *op = &operation;
+        if (dyn_cast(op)) {
+          continue;
+        }
+        if (anchor0 == -1 && dyn_cast(op)) {
+          anchor0 = clock;
+          stage = 1;
+        } else if (anchor1 == -1 && dyn_cast(op)) {
+          anchor1 = clock;
+        }
+        clock_map[op] = clock++;
+        stage_map[op] = stage;
+      }
+
+      // Both anchors are required. If the second one is missing the remapping
+      // below yields negative clocks, and the schedule would then be indexed
+      // out of bounds.
+      if (anchor0 < 0 || anchor1 < 0) {
+        return;
+      }
+
+      // Take care of the clocks
+      int diff = anchor1 - anchor0;
+      for (Operation &operation : loop.getBody()->getOperations()) {
+        Operation *op = &operation;
+        // swap the loads
+        if (dyn_cast(op)) {
+          continue;
+        }
+        // Use find() rather than operator[]: the latter would insert a zero
+        // clock for operations skipped by the previous loop and corrupt the
+        auto it = clock_map.find(op);
+        if (it == clock_map.end()) {
+          continue;
+        }
+        int clock = it->second;
+        if (clock < anchor0) {
+          it->second = diff + clock;
+        } else if (clock == anchor0) {
+          it->second = diff - 1;
+        } else if (clock > anchor0 && clock < anchor1) {
+          it->second = clock - anchor0 - 1;
+        }
+      }
+
+      options.getScheduleFn =
+          [&](scf::ForOp forOp,
+              std::vector> &schedule) {
+            schedule.resize(forOp.getBody()->getOperations().size() - 1);
+            for (const auto &p : clock_map) {
+              Operation *op = p.first;
+              int clock = p.second;
+              int stage = stage_map[op];
+              schedule[clock] = {op, stage};
+            }
+          };
+
+      scf::populateSCFLoopPipeliningPatterns(patterns, options);
+      (void)applyOpPatternsAndFold(loop, std::move(patterns));
+    }
+  }
+  // Unroll factor applied to the kernel loop. The staging logic above looks for
+  // exactly two anchors, i.e. it assumes a factor of 2.
+  int unrollFactor_ = 2;
+};
+} // namespace
+
+// Create the pass that unrolls and software-pipelines the kernel loop.
+std::unique_ptr mlir::createModuloSchedulingPass() {
+  return std::make_unique();
+}
diff --git a/tools/mlir-proto-opt/CMakeLists.txt b/tools/mlir-proto-opt/CMakeLists.txt index 518b85df1d4d..9f1225af002a 100644 --- a/tools/mlir-proto-opt/CMakeLists.txt +++ b/tools/mlir-proto-opt/CMakeLists.txt @@ -26,4 +26,11 @@ PRIVATE MLIRVectorExtTestPasses MLIRVectorExtTransform ) +if (SANDBOX_ENABLE_ALP) + target_link_libraries(mlir-proto-opt + PRIVATE + ExperimentalAlpTransforms + ) +endif() + mlir_check_all_link_libraries(mlir-proto-opt) diff --git a/tools/mlir-proto-opt/mlir-proto-opt.cpp b/tools/mlir-proto-opt/mlir-proto-opt.cpp index d67a01705198..cf7af9894685 100644 --- a/tools/mlir-proto-opt/mlir-proto-opt.cpp +++ b/tools/mlir-proto-opt/mlir-proto-opt.cpp @@ -54,17 +54,28 @@ static void registerIreeDialects(DialectRegistry ®istry) {} #endif void registerTestPasses() { registerTestVectorMaskingUtils(); } +#ifdef SANDBOX_ENABLE_ALP +#include "alp/Transforms/Passes.h" +#endif + +static void registerExperimentalPasses() { +#ifdef SANDBOX_ENABLE_ALP + registerALPPasses(); +#endif +} int main(int argc, char **argv) { llvm::InitLLVM y(argc, argv); registerAllPasses(); ireeLlvmSandboxRegisterPasses(); + registerExperimentalPasses(); linalg_ext::registerLinalgExtPasses(); registerTestPasses(); DialectRegistry registry; registerAllDialects(registry); registerIreeDialects(registry); + registry.insert(); linalg_ext::registerTilingInterfaceExternalModels(registry);