Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] densenet example #305

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/BuddyDensenet/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
arg0.data
arg1.data
forward.mlir
subgraph0.mlir
42 changes: 42 additions & 0 deletions examples/BuddyDensenet/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
add_custom_command(
OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/forward.mlir ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/subgraph0.mlir ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/arg0.data ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/arg1.data
COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/import-densenet.py
COMMENT "Generating forward.mlir, subgraph0.mlir and parameter files"
)


add_custom_command(
OUTPUT forward.o
COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/forward.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" |
${LLVM_MLIR_BINARY_DIR}/mlir-opt
-pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llvm-as |
${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyDensenet/forward.o
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/forward.mlir
COMMENT "Building forward.o"
VERBATIM)

add_custom_command(
OUTPUT subgraph0.o
COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/subgraph0.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, func-bufferize-dynamic-offset, arith-bufferize, func.func(tensor-bufferize,linalg-bufferize))" |
${LLVM_MLIR_BINARY_DIR}/mlir-opt
-pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops, lower-affine), eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llvm-as |
${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyDensenet/subgraph0.o
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyDensenet/subgraph0.mlir
COMMENT "Building subgraph0.o"
VERBATIM)

add_library(DENSENET STATIC forward.o subgraph0.o)

SET_TARGET_PROPERTIES(DENSENET PROPERTIES LINKER_LANGUAGE C)

add_executable(buddy-densenet-run densenet-main.cpp)
target_link_directories(buddy-densenet-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR})

set(BUDDY_DENSENET_LIBS DENSENET mlir_c_runner_utils)
target_link_libraries(buddy-densenet-run ${BUDDY_DENSENET_LIBS} ${OpenCV_LIBS})
23 changes: 23 additions & 0 deletions examples/BuddyDensenet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Buddy Compiler DenseNet Image Classification Example

## Introduction
This example shows how to use Buddy Compiler to compile a DenseNet model to MLIR code then run it. The [model](DenseNet121) is trained to classify the image type.


## How to run
1. Ensure that LLVM, OpenCV, Buddy Compiler and the Buddy Compiler python packages are installed properly. You can refer to [here](https://github.com/buddy-compiler/buddy-mlir) for more information and do a double check.

2. Set the `PYTHONPATH` environment variable.
```bash
$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH}
```

3. Build and run the BERT example
```bash
$ cmake -G Ninja .. -DBUDDY_DENSENET_EXAMPLES=ON
$ ninja buddy-densenet-run
$ cd bin
$ ./buddy-densenet-run
```

4. Enjoy it!
133 changes: 133 additions & 0 deletions examples/BuddyDensenet/densenet-main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//===- bert-main.cpp ------------------------------------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include <buddy/Core/Container.h>
#include <buddy/LLM/TextContainer.h>
#include <filesystem>
#include <limits>
#include <opencv2/imgproc.hpp>
#include <opencv2/opencv.hpp>
#include <string>
#include <utility>
#include <vector>

using namespace buddy;

// Declare DenseNet forward function.
extern "C" void _mlir_ciface_forward(MemRef<float, 2> *result,
MemRef<float, 1> *arg0,
MemRef<long long, 1> *arg1,
MemRef<float, 4> *arg2);

void loadParameters(const std::string &floatParamPath,
const std::string &int64ParamPath,
MemRef<float, 1> &floatParam,
MemRef<long long, 1> &int64Param) {
std::ifstream floatParamFile(floatParamPath, std::ios::in | std::ios::binary);
if (!floatParamFile.is_open()) {
std::string errMsg = "Failed to open float param file: " +
std::filesystem::canonical(floatParamPath).string();
throw std::runtime_error(errMsg);
}
floatParamFile.read(reinterpret_cast<char *>(floatParam.getData()),
floatParam.getSize() * sizeof(float));
if (floatParamFile.fail()) {
throw std::runtime_error("Failed to read float param file");
}
floatParamFile.close();

std::ifstream int64ParamFile(int64ParamPath, std::ios::in | std::ios::binary);
if (!int64ParamFile.is_open()) {
std::string errMsg = "Failed to open int64 param file: " +
std::filesystem::canonical(int64ParamPath).string();
throw std::runtime_error(errMsg);
}
int64ParamFile.read(reinterpret_cast<char *>(int64Param.getData()),
int64Param.getSize() * sizeof(long long));
if (int64ParamFile.fail()) {
throw std::runtime_error("Failed to read int64 param file");
}
int64ParamFile.close();
}

int main(int argc, char **argv) {
/// Print the title of this example.
const std::string title = "DenseNet Inference Powered by Buddy Compiler";
std::cout << "\033[33;1m" << title << "\033[0m" << std::endl;

/// Load weights to MemRef container.
MemRef<float, 1> arg0({8062504});
MemRef<long long, 1> arg1({121});
loadParameters("../../examples/BuddyBert/arg0.data",
"../../examples/BuddyBert/arg1.data", arg0, arg1);

if (argc != 2) {
std::cout << "Need Img Path" << std::endl;
}
/// Get user image.
std::cout << "Read Img:" << argv[1] << std::endl;
cv::Mat image = cv::imread(argv[1], cv::IMREAD_COLOR);

cv::Mat resize_image;
cv::resize(image, resize_image, cv::Size(224, 224));

MemRef<float, 4> input({1, 3, 224, 224});
float *dst = input.getData();
unsigned char *src = resize_image.ptr<unsigned char>();
// from BGR to channal RGB
for (int i = 0; i < 224; ++i) {
for (int j = 0; j < 224; ++j) {
float r = src[(i * 224 + j) * 3 + 2], g = src[(i * 224 + j) * 3 + 1],
b = src[(i * 224 + j) * 3];

dst[i * 224 + j] = (r / 255 - 0.485) / 0.229;
dst[224 * 224 + i * 224 + j] = (g / 255 - 0.456) / 0.224;
dst[224 * 224 * 2 + i * 224 + j] = (b / 255 - 0.406) / 0.225;
}
}

/// Initialize data containers.
MemRef<float, 2> result({1, 1000});

const auto inferenceStart = std::chrono::high_resolution_clock::now();

/// Execute forward inference of the model.
_mlir_ciface_forward(&result, &arg0, &arg1, &input);

const auto inferenceEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> inferenceTime =
inferenceEnd - inferenceStart;
/// Find the selected emotion.
int predict_label = -1;
float max_logits = std::numeric_limits<float>::min();
for (int i = 0; i < 1000; i++) {
if (max_logits < result.getData()[i]) {
max_logits = result.getData()[i];
predict_label = i;
}
}

std::cout << "\033[33;1m[Result] \033[0m";
std::cout << "The label of your image is ";
std::cout << "\033[32;1m" << predict_label << "\033[0m";
std::cout << "." << std::endl;

/// Print the performance.
std::cout << "\033[33;1m[Time] \033[0m";
std::cout << inferenceTime.count() << " ms" << std::endl;

return 0;
}
Binary file added examples/BuddyDensenet/example.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
72 changes: 72 additions & 0 deletions examples/BuddyDensenet/import-densenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# ===- import-densenet.py ----------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# This is the test of DenseNet model.
#
# ===---------------------------------------------------------------------------

import os
from pathlib import Path

import numpy as np
import torch
from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import (
simply_fuse,
useless_placeholder_eliminate,
)
from buddy.compiler.ops import tosa
from torch._inductor.decomposition import decompositions as inductor_decomp
from torchvision.models import densenet121, DenseNet121_Weights

weights = DenseNet121_Weights.DEFAULT
model = densenet121(weights=weights)
model.eval()
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
)

inputs = torch.randn((3, 224, 224), dtype=torch.float32).unsqueeze(0)

with torch.no_grad():
graphs = dynamo_compiler.importer(model, inputs)

assert len(graphs) == 1
graph = graphs[0]
params = dynamo_compiler.imported_params[graph]
pattern_list = [simply_fuse, useless_placeholder_eliminate]
graph.fuse_ops(pattern_list)
driver = GraphDriver(graphs[0])
driver.subgraphs[0].lower_to_top_level_ir()
path_prefix = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
print(driver.subgraphs[0]._imported_module, file=module_file)
with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
print(driver.construct_main_graph(True), file=module_file)

params = dynamo_compiler.imported_params[graph]
current_path = os.path.dirname(os.path.abspath(__file__))

float32_param = np.concatenate(
[param.detach().numpy().reshape([-1]) for param in params[:-1]]
)

float32_param.tofile(Path(current_path) / "arg0.data")

int64_param = params[-1].detach().numpy().reshape([-1])
int64_param.tofile(Path(current_path) / "arg1.data")
4 changes: 4 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ if (BUDDY_LENET_EXAMPLES)
add_subdirectory(BuddyLeNet)
endif()

if (BUDDY_DENSENET_EXAMPLES)
add_subdirectory(BuddyDensenet)
endif()

if(BUDDY_DSL_EXAMPLES)
add_subdirectory(ToyDSL)
endif()
Expand Down
1 change: 1 addition & 0 deletions examples/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
'BuddyBert',
'BuddyLlama',
'BuddyBert',
'BuddyDensenet',
'BuddyResNet18',
'ConvOpt',
'DAPDialect',
Expand Down
3 changes: 3 additions & 0 deletions frontend/Python/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def __init__(
"where.self": WhereOp,
"sqrt.default": SqrtOp,
"reciprocal.default": ReciprocalOp,
"clamp_min.default": ClampMinOp,
"clamp_max.default": ClampMaxOp,
"avg_pool2d.default": AvgPool2dOp,
}

@property
Expand Down
Loading