Skip to content

Commit

Permalink
[BuddyLeNet] Fix lenet error and format files (#291)
Browse files Browse the repository at this point in the history
* fix lenet error

* format files

* update cmake file

* fix lenet cmake files
  • Loading branch information
weilinquan committed May 10, 2024
1 parent e65fddf commit faec06b
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 26 deletions.
2 changes: 2 additions & 0 deletions examples/BuddyLeNet/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ data
__pycache__
*.pth
lenet.mlir
forward.mlir
subgraph0.mlir
33 changes: 23 additions & 10 deletions examples/BuddyLeNet/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
add_custom_command(
OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/lenet.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/arg0.data
OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/arg0.data
COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/buddy-lenet-import.py
COMMENT "Generating lenet.mlir and parameter files"
COMMENT "Generating forward.mlir, subgraph0.mlir and parameter files"
)

add_custom_command(
OUTPUT lenet.o
COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/lenet.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" |
OUTPUT forward.o
COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" |
${LLVM_MLIR_BINARY_DIR}/mlir-opt
-pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llvm-as |
${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/forward.o
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir
COMMENT "Building forward.o"
VERBATIM)

add_custom_command(
OUTPUT subgraph0.o
COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" |
${BUDDY_BINARY_DIR}/buddy-opt
-eliminate-empty-tensors
-convert-tensor-to-linalg
-linalg-bufferize
-convert-linalg-to-affine-loops
-lower-affine
-func-bufferize
-func-bufferize-dynamic-offset
-arith-bufferize
-tensor-bufferize
-buffer-deallocation
Expand All @@ -31,12 +44,12 @@ add_custom_command(
-reconcile-unrealized-casts |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llvm-as |
${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/lenet.o
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/lenet.mlir
COMMENT "Building lenet.o"
${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/subgraph0.o
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir
COMMENT "Building subgraph0.o"
VERBATIM)

add_library(LENET STATIC lenet.o)
add_library(LENET STATIC subgraph0.o forward.o)

SET_TARGET_PROPERTIES(LENET PROPERTIES LINKER_LANGUAGE C)

Expand Down
26 changes: 18 additions & 8 deletions examples/BuddyLeNet/buddy-lenet-import.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
# ===---------------------------------------------------------------------------

import os
from pathlib import Path

import numpy
import numpy as np
import torch
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse
from buddy.compiler.ops import tosa
from model import LeNet

Expand Down Expand Up @@ -53,14 +56,21 @@
assert len(graphs) == 1
graph = graphs[0]
params = dynamo_compiler.imported_params[graph]
graph.lower_to_top_level_ir(do_params_pack=True)
pattern_list = [simply_fuse]
graphs[0].fuse_ops(pattern_list)
driver = GraphDriver(graphs[0])
driver.subgraphs[0].lower_to_top_level_ir()
path_prefix = os.path.dirname(os.path.abspath(__file__))
# Write the MLIR module to the file.
with open(os.path.join(path_prefix, "lenet.mlir"), "w") as module_file:
print(graph._imported_module, file=module_file)
with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
print(driver.subgraphs[0]._imported_module, file=module_file)
with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
print(driver.construct_main_graph(True), file=module_file)

# Concatenate all parameters into a single numpy array and write to a file.
all_param = numpy.concatenate(
params = dynamo_compiler.imported_params[graph]
current_path = os.path.dirname(os.path.abspath(__file__))

float32_param = np.concatenate(
[param.detach().numpy().reshape([-1]) for param in params]
)
all_param.tofile(os.path.join(path_prefix, "arg0.data"))

float32_param.tofile(Path(current_path) / "arg0.data")
2 changes: 1 addition & 1 deletion frontend/Python/graph/graph_driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# ===- graph_driver.py -------------------------------------------------------------
# ===- graph_driver.py ---------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
25 changes: 19 additions & 6 deletions frontend/Python/ops/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ def func_op(node: FuncOp, symbol_table: Dict[Tuple[str, int], ir.Operation]):
mlir_dtype = mlir_element_type_get(arg.dtype)
stride = []
for dim, dim_size in enumerate(shape):
stride.append(functools.reduce(lambda x, y: x * y, shape[dim+1:]+[1]))
memref_attr = ir.Attribute.parse("strided<{}, offset: ?>".format(stride))
stride.append(
functools.reduce(lambda x, y: x * y, shape[dim + 1 :] + [1])
)
memref_attr = ir.Attribute.parse(
"strided<{}, offset: ?>".format(stride)
)
arguments.append(ir.MemRefType.get(shape, mlir_dtype, memref_attr))
results = []
for i, shape in enumerate(node.tensor_meta["shape"]):
Expand All @@ -61,8 +65,12 @@ def call_op(node: CallOp, symbol_table: Dict[Tuple[str, int], ir.Operation]):
stride = []
shape = memref_type.shape
for dim, dim_size in enumerate(shape):
stride.append(functools.reduce(lambda x, y: x * y, shape[dim+1:]+[1]))
memref_attr = ir.Attribute.parse("strided<{}, offset: ?>".format(stride))
stride.append(
functools.reduce(lambda x, y: x * y, shape[dim + 1 :] + [1])
)
memref_attr = ir.Attribute.parse(
"strided<{}, offset: ?>".format(stride)
)
dest = ir.MemRefType.get(shape, memref_type.element_type, memref_attr)
cast_op = memref.CastOp(dest, input_node)
arguments.append(cast_op)
Expand Down Expand Up @@ -125,7 +133,9 @@ def param_extract(
return memref_subview_op
stride = []
for dim, dim_size in enumerate(output_shape):
stride.append(functools.reduce(lambda x, y: x * y, output_shape[dim+1:]+[1]))
stride.append(
functools.reduce(lambda x, y: x * y, output_shape[dim + 1 :] + [1])
)
memref_attr = ir.Attribute.parse(
"strided<{}, offset: {}>".format(stride, offset)
)
Expand All @@ -143,9 +153,12 @@ def param_extract(
None,
)
axis = ir.ArrayAttr.get([axis], None)
expand_shape_op = memref.ExpandShapeOp(memref_type, memref_subview_op.result, axis)
expand_shape_op = memref.ExpandShapeOp(
memref_type, memref_subview_op.result, axis
)
return expand_shape_op


ops_registry = {
"FuncOp": func_op,
"CallOp": call_op,
Expand Down
2 changes: 1 addition & 1 deletion midend/lib/Conversion/FuncBufferize/FuncBufferizePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
#include "llvm/Support/Debug.h"
#include <cstdint>
#include <memory>
#include <mlir/IR/OperationSupport.h>
#include <utility>
using namespace mlir;
using namespace mlir::func;
Expand All @@ -53,6 +52,7 @@ class FuncBufferizeDynamicOffsetPass
: public PassWrapper<FuncBufferizeDynamicOffsetPass,
OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncBufferizeDynamicOffsetPass)
FuncBufferizeDynamicOffsetPass() = default;
llvm::StringRef getArgument() const final {
return "func-bufferize-dynamic-offset";
Expand Down

0 comments on commit faec06b

Please sign in to comment.