From 2c2d1e5ad9898ec270e07271113544b4c1e461cb Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Tue, 26 Apr 2022 18:00:38 +0800 Subject: [PATCH] feat(codebase/cls): support vision_transformer (#403) * feat(codebase/cls): support vision_transformer * style(onnx2ncnn): format cpp code, upgrade mmcls version * fix(CI): upgrade mmcv to 1.4.2 * fix(onnx2ncnn): offset out of range during fuse conv reshape * docs(vision_transformer.py): update VisionTransformer desc * docs(onnx2ncnn.cpp): add more comment * feat(onnx2ncnn.cpp): revert fuse weight * docs(onnx2ncnn.cpp): add more comment * test(vision_transformer): add test case * refactor(vision_transformer.py): use symbol rewrite layer_norm * refactor(vision_transformer): fix review * fix(attention): add missing files --- .github/workflows/build.yml | 6 +- csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp | 396 ++++++++++++++++-- csrc/backend_ops/ncnn/ops/shape/shape.cpp | 14 +- mmdeploy/codebase/mmcls/models/__init__.py | 1 + .../mmcls/models/backbones/__init__.py | 6 +- .../models/backbones/vision_transformer.py | 68 +++ .../codebase/mmcls/models/utils/__init__.py | 4 + .../codebase/mmcls/models/utils/attention.py | 95 +++++ mmdeploy/pytorch/ops/__init__.py | 4 +- mmdeploy/pytorch/ops/gelu.py | 11 + mmdeploy/pytorch/ops/layer_norm.py | 32 ++ requirements/optional.txt | 2 +- .../test_mmcls/test_mmcls_models.py | 75 +++- 13 files changed, 668 insertions(+), 46 deletions(-) create mode 100644 mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py create mode 100644 mmdeploy/codebase/mmcls/models/utils/__init__.py create mode 100644 mmdeploy/codebase/mmcls/models/utils/attention.py create mode 100644 mmdeploy/pytorch/ops/gelu.py create mode 100644 mmdeploy/pytorch/ops/layer_norm.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0f54134ad..8d05b7045 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -23,7 +23,7 @@ jobs: matrix: python-version: [3.7] torch: [1.8.0, 1.9.0] - mmcv: [1.4.0] + mmcv: [1.4.2] include: - torch: 1.8.0 torch_version: torch1.8 @@ -65,7 +65,7 @@ jobs: matrix: python-version: [3.7] torch: [1.9.0+cu102] - mmcv: [1.4.0] + mmcv: [1.4.2] include: - torch: 1.9.0+cu102 torch_version: torch1.9 @@ -108,7 +108,7 @@ jobs: matrix: python-version: [3.7] torch: [1.8.0+cu111] - mmcv: [1.4.0] + mmcv: [1.4.2] include: - torch: 1.8.0+cu111 torch_version: torch1.8 diff --git a/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp index a8a8b1c8d..872091f99 100644 --- a/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp +++ b/csrc/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include "onnx.pb.h" @@ -73,6 +74,17 @@ static std::vector get_node_attr_ai(const onnx::NodeProto& node, const char return v; } +static void set_node_attr_ai(onnx::NodeProto& node, const char* key, + const std::vector& value) { + onnx::AttributeProto* attr_group = node.add_attribute(); + attr_group->set_name(key); + for (auto v : value) { + attr_group->add_ints(v); + } + + return; +} + static std::vector get_node_attr_af(const onnx::NodeProto& node, const char* key) { std::vector v; @@ -137,8 +149,9 @@ static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const return onnx::TensorProto(); } -static float get_node_attr_from_input_f(const onnx::TensorProto& tp) { - float v = 0.f; +template +static T get_node_attr_from_input(const onnx::TensorProto& tp) { + T v = 0.f; // float if (tp.data_type() == 1) { @@ -183,7 +196,7 @@ static float get_node_attr_from_input_f(const onnx::TensorProto& tp) { } else { // fprintf(stderr, "tp.name: %s\n", tp.name().c_str()); fprintf(stderr, "Unknown data type %d\n", tp.data_type()); - fprintf(stderr, "get_node_attr_from_input_f\n"); + fprintf(stderr, "get_node_attr_from_input\n"); abort(); } @@ -680,7 +693,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, const onnx::TensorProto& add_three = weights[node->input(1)]; if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; - float constant_add_three = get_node_attr_from_input_f(add_three); + float constant_add_three = get_node_attr_from_input(add_three); if (constant_add_three != 3.f) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); @@ -708,8 +721,8 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, const onnx::TensorProto& min_tp = weights[node2->input(1)]; const onnx::TensorProto& max_tp = weights[node2->input(2)]; - relu6_min = get_node_attr_from_input_f(min_tp); - relu6_max = get_node_attr_from_input_f(max_tp); + relu6_min = get_node_attr_from_input(min_tp); + relu6_max = get_node_attr_from_input(max_tp); } if (relu6_min != 0.f || relu6_max != 6.f) continue; @@ -722,7 +735,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph, const onnx::TensorProto& div_six = weights[node4->input(1)]; if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; - float constant_div_six = get_node_attr_from_input_f(div_six); + float constant_div_six = get_node_attr_from_input(div_six); if (node4->op_type() == "Div" && constant_div_six != 6.f) continue; if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; @@ -831,7 +844,7 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, const onnx::TensorProto& add_three = weights[node->input(1)]; if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; - float constant_add_three = get_node_attr_from_input_f(add_three); + float constant_add_three = get_node_attr_from_input(add_three); if (constant_add_three != 3.f) continue; onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); @@ -857,8 +870,8 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, const onnx::TensorProto& min_tp = weights[node2->input(1)]; const onnx::TensorProto& max_tp = weights[node2->input(2)]; - relu6_min = get_node_attr_from_input_f(min_tp); - relu6_max = get_node_attr_from_input_f(max_tp); + relu6_min = get_node_attr_from_input(min_tp); + relu6_max = get_node_attr_from_input(max_tp); } if (relu6_min != 0.f || relu6_max != 6.f) continue; @@ -867,7 +880,7 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, const onnx::TensorProto& div_six = weights[node3->input(1)]; if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; - float constant_div_six = get_node_attr_from_input_f(div_six); + float constant_div_six = get_node_attr_from_input(div_six); if (node3->op_type() == "Div" && constant_div_six != 6.f) continue; if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; @@ -1090,7 +1103,7 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph, } else { const onnx::TensorProto& min_tp = weights[node2->input(1)]; - clip_min = get_node_attr_from_input_f(min_tp); + clip_min = get_node_attr_from_input(min_tp); } // reduce @@ -1343,7 +1356,7 @@ static void fuse_layernorm(onnx::GraphProto* mutable_graph, const onnx::TensorProto& pow_two = weights[node3->input(1)]; if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1) continue; - float constant_pow_two = get_node_attr_from_input_f(pow_two); + float constant_pow_two = get_node_attr_from_input(pow_two); if (constant_pow_two != 2.f) continue; std::vector axes4 = get_node_attr_ai(*node4, "axes"); @@ -1360,7 +1373,7 @@ static void fuse_layernorm(onnx::GraphProto* mutable_graph, const onnx::TensorProto& add_eps = weights[node5->input(1)]; if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1) continue; - float eps = get_node_attr_from_input_f(add_eps); + float eps = get_node_attr_from_input(add_eps); int affine = 0; while (i + 8 < node_count) { @@ -2546,6 +2559,320 @@ static void fuse_multiheadattention(onnx::GraphProto* mutable_graph, } } +/** + * @brief find graph node by output name + * + * @param graph + * @param name + * @return onnx::NodeProto* + */ +static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph, + const std::string& name) { + const int input_count = mutable_graph->node_size(); + for (int i = 0; i < input_count; ++i) { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + for (int j = 0; j < node->output_size(); ++j) { + auto output = node->output(j); + if (output == name) { + return node; + } + } + } + + return nullptr; +} + +/** + * @brief query output shape of target node + * + * @param mutable_graph + * @param target + * @param weights + * @param context + * @return std::tuple> + */ +static std::tuple> query_shape( + onnx::GraphProto* mutable_graph, onnx::NodeProto* target, + const std::map& weights, + std::map>& context) { + // emplace all input nodes + const int input_count = mutable_graph->input_size(); + for (int i = 0; i < input_count; i++) { + auto inp = mutable_graph->input(i); + onnx::TypeProto inp_type = inp.type(); + onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape(); + + auto dim_size = shape_proto.dim_size(); + std::vector shape(dim_size); + for (int index = 0; index < dim_size; ++index) { + shape[index] = shape_proto.dim(index).dim_value(); + } + + context.emplace(inp.name(), shape); + } + + // BFS the tree, `target` as root, onnx::graph inputs and weights as leaf nodes + std::vector serial = {target}; + { + std::set mark_as_appended = {}; + while (true) { + int start = 0, end = serial.size(); + for (int i = start; i < end; ++i) { + auto node_ptr = serial[i]; + auto len = node_ptr->input_size(); + + for (int j = 0; j < len; ++j) { + std::string name = node_ptr->input(j); + if (context.find(name) != context.end()) { + // if input founded, skip + continue; + } + + if (weights.find(name) != weights.end()) { + // if founded in weights, extract shape to context + auto weight = weights.at(name); + std::vector shape; + for (auto index = 0; index < weight.dims_size(); ++index) { + shape.emplace_back(weight.dims(index)); + } + context.emplace(name, shape); + continue; + } + + if (mark_as_appended.find(name) != mark_as_appended.end()) { + // if mark as appended, skip + continue; + } + // else append it to serialization list + auto depend_ptr = find_node_by_output_name(mutable_graph, name); + if (depend_ptr == nullptr) { + fprintf(stderr, "cannot find %s from graph !\n", name.c_str()); + return std::make_tuple(false, std::vector{}); + } + mark_as_appended.insert(name); + serial.emplace_back(depend_ptr); + } + } + + if (serial.size() <= end) { + // if not new node added, quit + break; + } + + // update start and end position, continue BFS the tree + start = end; + end = serial.size(); + } + } + + // for each node in serialization list, calculate the output shape + { + std::reverse(serial.begin(), serial.end()); + for (auto node : serial) { + if (node->op_type() == "Conv") { + auto inp = context[node->input(0)]; + auto weight = context[node->input(1)]; + assert(inp.size() == 4 and weight.size() == 4); + + int group = get_node_attr_i(*node, "group", 1); + assert(group == 1); + + // treat multiple spatial attr as single one +#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT) \ + int ATTR = DEFAULT; \ + { \ + std::vector _vec = get_node_attr_ai(*node, NAME); \ + if (not _vec.empty()) { \ + ATTR = _vec[0]; \ + } \ + } + + EXTRACT_REPEATED_PARAM("dilations", dilation, 1); + EXTRACT_REPEATED_PARAM("pads", pad, 0); + EXTRACT_REPEATED_PARAM("strides", stride, 1); + +#undef EXTRACT_REPEATED_PARAM + + int on = inp[0]; + int oc = weight[0]; + int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1; + int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1; + context.emplace(node->output(0), std::vector{on, oc, oh, ow}); + + } else if (node->op_type() == "Shape") { + auto inp = context[node->input(0)]; + context.emplace(node->output(0), std::vector{1, inp[1], inp[2], inp[3]}); + + } else if (node->op_type() == "Slice") { + assert(node->input_size() >= 4); + + auto inp = context[node->input(0)]; + int start = get_node_attr_from_input(weights.at(node->input(1))); + int end = get_node_attr_from_input(weights.at(node->input(2))); + int axes = get_node_attr_from_input(weights.at(node->input(3))); + + if (axes != 0) { + fprintf(stderr, "Not support axes=%d !\n", axes); + return std::make_tuple(false, std::vector{}); + } + + assert(inp.size() >= end - start); + context.emplace(node->output(0), std::vector{inp.begin() + start, inp.begin() + end}); + + } else if (node->op_type() == "Concat") { + assert(node->input_size() >= 2); + + auto axis = get_node_attr_i(*node, "axis", 0); + if (axis != 0) { + fprintf(stderr, "Not support axes=%d !\n", axis); + return std::make_tuple(false, std::vector{}); + } + + std::vector inp = context[node->input(0)]; + std::vector w_data = get_node_attr_from_input_ai(weights.at(node->input(1))); + + // concat data on axis 0 + inp.insert(inp.end(), w_data.begin(), w_data.end()); + context.emplace(node->output(0), inp); + + } else { + fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str()); + return std::make_tuple(false, std::vector{}); + } + } + } + + assert(context.find(target->output(0)) != context.end()); + auto target_shape = context[target->output(0)]; + return std::make_tuple(true, target_shape); +} + +/** + * @brief fuse subgraph + * + * conv - - - - - - - - - - - -> reshape + * \ / + * shape - slice - concat + * + * to + * + * conv --> reshape + * + * @param mutable_graph + * @param weights + * @param node_reference + * @param blob_names + * @param reduced_node_count + */ +static void fuse_conv_reshape(onnx::GraphProto* mutable_graph, + std::map& weights, + std::map& node_reference, + std::set& blob_names, int& reduced_node_count) { + std::map> shape_context; + const int node_count = mutable_graph->node_size(); + + for (int i = 0; i < node_count; i++) { + onnx::NodeProto* conv = mutable_graph->mutable_node(i); + + if (conv->op_type() != "Conv") { + continue; + } + + if (i + 4 >= node_count) { + continue; + } + + onnx::NodeProto *shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr; + + // match [Shape ... Slice, Concat ... Reshape] from near sequence, skip useless Constant + std::vector> candidates = { + {"Shape", &shape}, {"Slice", &slice}, {"Concat", &concat}, {"Reshape", &reshape}}; + + int MAX = std::min(10, node_count - i - 1); + int pos_candidate = 0; + + for (int j = 0; j < MAX; ++j) { + auto node_ptr = mutable_graph->mutable_node(j + i + 1); + if (node_ptr->op_type() == "Constant") { + continue; + } + if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) { + *(std::get<1>(candidates[pos_candidate])) = node_ptr; + pos_candidate++; + } + } + + if (pos_candidate != candidates.size()) { + // not match the sequence + continue; + } + + if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 || + node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 || + node_reference[reshape->output(0)] != 1) { + continue; + } + + // check the connections + if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) { + continue; + } + if (slice->input(0) != shape->output(0)) { + continue; + } + if (concat->input(0) != slice->output(0)) { + continue; + } + if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) { + continue; + } + + // add reshape attr + auto result = query_shape(mutable_graph, concat, weights, shape_context); + if (!std::get<0>(result)) { + continue; + } + set_node_attr_ai(*reshape, "shape", std::get<1>(result)); + + // reconstruct graph + { + // remove reference + node_reference[reshape->input(1)] -= 1; + node_reference[concat->input(0)] -= 1; + node_reference[slice->input(0)] -= 1; + node_reference[shape->input(0)] -= 1; + + // remove tensor/blob on edge + blob_names.erase(slice->input(0)); + blob_names.erase(slice->input(1)); + blob_names.erase(slice->input(2)); + blob_names.erase(slice->input(3)); + weights.erase(slice->input(1)); + weights.erase(slice->input(2)); + weights.erase(slice->input(3)); + + blob_names.erase(concat->input(0)); + blob_names.erase(concat->input(1)); + weights.erase(concat->input(1)); + + blob_names.erase(reshape->input(0)); + + // update edge + shape->clear_input(); + reshape->clear_input(); + reshape->add_input(conv->output(0)); + + shape->set_op_type("noop_reducedncnn"); + slice->set_op_type("noop_reducedncnn"); + concat->set_op_type("noop_reducedncnn"); + + reduced_node_count += 3; + } + i += 3; + } +} + static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, @@ -2563,7 +2890,7 @@ static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, const onnx::TensorProto& scalar_b = weights[node->input(1)]; if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue; - float b = get_node_attr_from_input_f(scalar_b); + float b = get_node_attr_from_input(scalar_b); node_reference[node->input(1)] -= 1; @@ -2763,6 +3090,7 @@ int main(int argc, char** argv) { // op chain fusion int reduced_node_count = 0; + fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count); fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count); @@ -3200,6 +3528,8 @@ int main(int argc, char** argv) { fprintf(pp, "%-16s", "UnaryOp"); } else if (op == "Gather") { fprintf(pp, "%-16s", "Gather"); + } else if (op == "Gelu") { + fprintf(pp, "%-16s", "GELU"); } else if (op == "Gemm") { float alpha = get_node_attr_f(node, "alpha", 1.f); float beta = get_node_attr_f(node, "beta", 1.f); @@ -3542,10 +3872,10 @@ int main(int argc, char** argv) { max = get_node_attr_f(node, "max", FLT_MAX); } else { min = weights.find(node.input(1)) != weights.end() - ? get_node_attr_from_input_f(weights[node.input(1)]) + ? get_node_attr_from_input(weights[node.input(1)]) : -FLT_MAX; max = weights.find(node.input(2)) != weights.end() - ? get_node_attr_from_input_f(weights[node.input(2)]) + ? get_node_attr_from_input(weights[node.input(2)]) : FLT_MAX; } @@ -3835,6 +4165,8 @@ int main(int argc, char** argv) { fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1); } fprintf(pp, " 0=%d", axis); + } else if (op == "Gelu") { + fprintf(pp, " 0=0"); } else if (op == "Gemm") { float alpha = get_node_attr_f(node, "alpha", 1.f); float beta = get_node_attr_f(node, "beta", 1.f); @@ -4405,7 +4737,7 @@ int main(int argc, char** argv) { for (int j = 0; j < embed_dim; j++) { for (int k = 0; k < embed_dim; k++) { - float vb = wptr[k * embed_dim * 3 + j]; + float vb = wptr[j * embed_dim * 3 + k]; fwrite(&vb, sizeof(float), 1, bp); } } @@ -4424,7 +4756,7 @@ int main(int argc, char** argv) { for (int j = 0; j < embed_dim; j++) { for (int k = 0; k < embed_dim; k++) { - float vb = wptr[k * embed_dim * 3 + j + embed_dim]; + float vb = wptr[j * embed_dim * 3 + k + embed_dim]; fwrite(&vb, sizeof(float), 1, bp); } } @@ -4443,7 +4775,7 @@ int main(int argc, char** argv) { for (int j = 0; j < embed_dim; j++) { for (int k = 0; k < embed_dim; k++) { - float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2]; + float vb = wptr[j * embed_dim * 3 + k + embed_dim * 2]; fwrite(&vb, sizeof(float), 1, bp); } } @@ -4459,7 +4791,7 @@ int main(int argc, char** argv) { for (int j = 0; j < embed_dim; j++) { for (int k = 0; k < embed_dim; k++) { - float vb = wptr[k * embed_dim + j]; + float vb = wptr[j * embed_dim + k]; fwrite(&vb, sizeof(float), 1, bp); } } @@ -4489,7 +4821,7 @@ int main(int argc, char** argv) { for (int j = 0; j < embed_dim; j++) { for (int k = 0; k < embed_dim; k++) { - float vb = wptr[k * embed_dim + j]; + float vb = wptr[j * embed_dim + k]; fwrite(&vb, sizeof(float), 1, bp); } } @@ -4504,7 +4836,7 @@ int main(int argc, char** argv) { for (int j = 0; j < embed_dim; j++) { for (int k = 0; k < embed_dim; k++) { - float vb = wptr[k * embed_dim + j]; + float vb = wptr[j * embed_dim + k]; fwrite(&vb, sizeof(float), 1, bp); } } @@ -4519,7 +4851,7 @@ int main(int argc, char** argv) { for (int j = 0; j < embed_dim; j++) { for (int k = 0; k < embed_dim; k++) { - float vb = wptr[k * embed_dim + j]; + float vb = wptr[j * embed_dim + k]; fwrite(&vb, sizeof(float), 1, bp); } } @@ -4534,7 +4866,7 @@ int main(int argc, char** argv) { for (int j = 0; j < embed_dim; j++) { for (int k = 0; k < embed_dim; k++) { - float vb = wptr[k * embed_dim + j]; + float vb = wptr[j * embed_dim + k]; fwrite(&vb, sizeof(float), 1, bp); } } @@ -4552,17 +4884,17 @@ int main(int argc, char** argv) { // fprintf(stderr, "node.input_size(): %d\n", node.input_size()); if (node.input_size() >= 3) { // fprintf(stderr, "ok12!\n"); - max_dets = (int)(get_node_attr_from_input_f(weights[node.input(2)]) + 0.5); + max_dets = (int)(get_node_attr_from_input(weights[node.input(2)]) + 0.5); } if (node.input_size() >= 4) { // fprintf(stderr, "iou_thre: %f\n", - // get_node_attr_from_input_f(weights[node.input(3)])); - iou_thre = get_node_attr_from_input_f(weights[node.input(3)]); + // get_node_attr_from_input(weights[node.input(3)])); + iou_thre = get_node_attr_from_input(weights[node.input(3)]); } if (node.input_size() >= 5) { // fprintf(stderr, "score_thre: %f\n", - // get_node_attr_from_input_f(weights[node.input(4)])); - score_thre = get_node_attr_from_input_f(weights[node.input(4)]); + // get_node_attr_from_input(weights[node.input(4)])); + score_thre = get_node_attr_from_input(weights[node.input(4)]); } fprintf(pp, " 0=%d", max_dets); fprintf(pp, " 1=%f", iou_thre); @@ -4736,8 +5068,10 @@ int main(int argc, char** argv) { if (node.input_size() == 1) { shape = get_node_attr_ai(node, "shape"); - } else { + } else if (weights.find(node.input(1)) != weights.end()) { shape = get_node_attr_from_input_ai(weights[node.input(1)]); + } else { + fprintf(stderr, "Unsupported reshape weight ! \n"); } if (shape.size() == 1) { diff --git a/csrc/backend_ops/ncnn/ops/shape/shape.cpp b/csrc/backend_ops/ncnn/ops/shape/shape.cpp index dbafff7e7..f538eabba 100755 --- a/csrc/backend_ops/ncnn/ops/shape/shape.cpp +++ b/csrc/backend_ops/ncnn/ops/shape/shape.cpp @@ -21,27 +21,27 @@ int Shape::forward(const Mat &bottom_blob, Mat &top_blob, const Option &opt) con return -100; } float *outptr = top_blob; + if (dims == 1) { outptr[0] = 1.0f; outptr[1] = w; - return 0; - } - if (dims == 2) { + } else if (dims == 2) { int h = bottom_blob.h; outptr[0] = 1.0f; outptr[1] = h; outptr[2] = w; - return 0; - } - if (dims == 3) { + } else if (dims == 3) { int h = bottom_blob.h; int channels = bottom_blob.c; outptr[0] = 1.0f; outptr[1] = channels; outptr[2] = h; outptr[3] = w; - return 0; + } else { + fprintf(stdout, "Unsupported dims=%d\n", dims); } + + return 0; } } // namespace mmdeploy diff --git a/mmdeploy/codebase/mmcls/models/__init__.py b/mmdeploy/codebase/mmcls/models/__init__.py index bf8d72f26..2419512c3 100644 --- a/mmdeploy/codebase/mmcls/models/__init__.py +++ b/mmdeploy/codebase/mmcls/models/__init__.py @@ -2,3 +2,4 @@ from .backbones import * # noqa: F401,F403 from .classifiers import * # noqa: F401,F403 from .heads import * # noqa: F401,F403 +from .utils import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmcls/models/backbones/__init__.py b/mmdeploy/codebase/mmcls/models/backbones/__init__.py index cda932d65..7e6c954ea 100644 --- a/mmdeploy/codebase/mmcls/models/backbones/__init__.py +++ b/mmdeploy/codebase/mmcls/models/backbones/__init__.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .shufflenet_v2 import shufflenetv2_backbone__forward__ncnn +from .vision_transformer import visiontransformer__forward__ncnn -__all__ = ['shufflenetv2_backbone__forward__ncnn'] +__all__ = [ + 'shufflenetv2_backbone__forward__ncnn', + 'visiontransformer__forward__ncnn', +] diff --git a/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py b/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py new file mode 100644 index 000000000..21d99aa27 --- /dev/null +++ b/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcls.models.utils import resize_pos_embed + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import Backend + + +@FUNCTION_REWRITER.register_rewriter( + func_name= # noqa: E251 + 'mmcls.models.backbones.vision_transformer.VisionTransformer.forward', + backend=Backend.NCNN.value) +def visiontransformer__forward__ncnn(ctx, self, x): + """Rewrite `forward` of VisionTransformer for ncnn backend. + + The chunk in original VisionTransformer.forward will convert + `self.cls_token` to `where` operator in ONNX, which will raise + error in ncnn. + + Args: + ctx (ContextCaller): The context with additional information. + self (VisionTransformer): The instance of the class InvertedResidual. + x (Tensor): Input features of shape (N, Cin, H, W). + Returns: + out (Tensor): A feature map output from InvertedResidual. The tensor + shape (N, Cout, H, W). + """ + B = x.shape[0] + x, patch_resolution = self.patch_embed(x) + + # cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((self.cls_token, x), dim=1) + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + if i in self.out_indices: + B, _, C = x.shape + if self.with_cls_token: + patch_token = x[:, 1:].reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = x[:, 0] + else: + patch_token = x.reshape(B, *patch_resolution, C) + patch_token = patch_token.permute(0, 3, 1, 2) + cls_token = None + if self.output_cls_token: + out = [patch_token, cls_token] + else: + out = patch_token + outs.append(out) + + return tuple(outs) diff --git a/mmdeploy/codebase/mmcls/models/utils/__init__.py b/mmdeploy/codebase/mmcls/models/utils/__init__.py new file mode 100644 index 000000000..a3b76e8d7 --- /dev/null +++ b/mmdeploy/codebase/mmcls/models/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .attention import multiheadattention__forward__ncnn + +__all__ = ['multiheadattention__forward__ncnn'] diff --git a/mmdeploy/codebase/mmcls/models/utils/attention.py b/mmdeploy/codebase/mmcls/models/utils/attention.py new file mode 100644 index 000000000..882f47343 --- /dev/null +++ b/mmdeploy/codebase/mmcls/models/utils/attention.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import Backend + + +class MultiHeadAttentionop(torch.autograd.Function): + """Create onnx::MultiHeadAttention op.""" + + @staticmethod + def forward(ctx, q: Tensor, k: Tensor, v: Tensor, q_weight: Tensor, + q_bias: Tensor, k_weight: Tensor, k_bias: Tensor, + v_weight: Tensor, v_bias: Tensor, o_weight: Tensor, + o_bias: Tensor, embed_dims: int, num_heads: int) -> Tensor: + return torch.rand_like(q) + + @staticmethod + def symbolic(g, q: torch._C.Value, k: torch._C.Value, v: torch._C.Value, + q_weight: torch._C.Value, q_bias: torch._C.Value, + k_weight: torch._C.Value, k_bias: torch._C.Value, + v_weight: torch._C.Value, v_bias: torch._C.Value, + o_weight: torch._C.Value, o_bias: torch._C.Value, + embed_dims: int, num_heads: int): + + q_weight.setDebugName('q_weight') + q_bias.setDebugName('q_bias') + + k_weight.setDebugName('k_weight') + k_bias.setDebugName('k_bias') + + v_weight.setDebugName('v_weight') + v_bias.setDebugName('v_bias') + + o_weight.setDebugName('o_weight') + o_bias.setDebugName('o_bias') + + return g.op( + 'mmdeploy::MultiHeadAttention', + q, + k, + v, + q_weight, + q_bias, + k_weight, + k_bias, + v_weight, + v_bias, + o_weight, + o_bias, + embed_dim_i=embed_dims, + num_heads_i=num_heads) + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmcls.models.utils.attention.MultiheadAttention.forward', + backend=Backend.NCNN.value) +def multiheadattention__forward__ncnn(ctx, self, qkv_input): + """Rewrite `forward` of MultiheadAttention used in vision_transformer for + ncnn backend. + + Args: + ctx (ContextCaller): The context with additional information. + self (MultiheadAttention): The instance of the class + MultiheadAttention. + x (Tensor): Input features of shape (N, Cin, H, W). + Returns: + out (Tensor): A feature map output from MultiHeadAttention. The tensor + shape (N, Cout, H, W). + """ + + # split qkv weight and bias + qkv_weight = self.qkv.weight.data.reshape(3, self.input_dims, + self.embed_dims) + + q_weight = qkv_weight[0] + k_weight = qkv_weight[1] + v_weight = qkv_weight[2] + + qkv_bias = self.qkv.bias.data.reshape(3, self.embed_dims) + q_bias = qkv_bias[0] + k_bias = qkv_bias[1] + v_bias = qkv_bias[2] + + # out weight and bias + o_weight = self.proj.weight.data + o_bias = self.proj.bias.data + + out = MultiHeadAttentionop.apply(qkv_input, qkv_input, qkv_input, q_weight, + q_bias, k_weight, k_bias, v_weight, + v_bias, o_weight, o_bias, self.embed_dims, + self.num_heads) + return out diff --git a/mmdeploy/pytorch/ops/__init__.py b/mmdeploy/pytorch/ops/__init__.py index 0608aadf7..1b6a5e601 100644 --- a/mmdeploy/pytorch/ops/__init__.py +++ b/mmdeploy/pytorch/ops/__init__.py @@ -2,9 +2,11 @@ from .adaptive_avg_pool import (adaptive_avg_pool1d__default, adaptive_avg_pool2d__default, adaptive_avg_pool3d__default) +from .gelu import gelu__ncnn from .grid_sampler import grid_sampler__default from .hardsigmoid import hardsigmoid__default from .instance_norm import instance_norm__tensorrt +from .layer_norm import layer_norm__ncnn from .lstm import generic_rnn__ncnn from .squeeze import squeeze__default @@ -12,5 +14,5 @@ 'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default', 'adaptive_avg_pool3d__default', 'grid_sampler__default', 'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn', - 'squeeze__default' + 'squeeze__default', 'gelu__ncnn', 'layer_norm__ncnn' ] diff --git a/mmdeploy/pytorch/ops/gelu.py b/mmdeploy/pytorch/ops/gelu.py new file mode 100644 index 000000000..039e5a114 --- /dev/null +++ b/mmdeploy/pytorch/ops/gelu.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from mmdeploy.core import SYMBOLIC_REWRITER +from mmdeploy.utils import Backend + + +@SYMBOLIC_REWRITER.register_symbolic( + 'gelu', is_pytorch=True, arg_descriptors=['v'], backend=Backend.NCNN.value) +def gelu__ncnn(ctx, g, self): + """Support export GELU with ncnn backend.""" + return g.op('mmdeploy::Gelu', self) diff --git a/mmdeploy/pytorch/ops/layer_norm.py b/mmdeploy/pytorch/ops/layer_norm.py new file mode 100644 index 000000000..b05406858 --- /dev/null +++ b/mmdeploy/pytorch/ops/layer_norm.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from: +# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py + +from torch.onnx.symbolic_helper import parse_args + +from mmdeploy.core import SYMBOLIC_REWRITER +from mmdeploy.utils import Backend + + +@parse_args('v', 'is', 'v', 'v', 'f', 'i') +def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): + """Symbolic function for `layer_norm`. + + PyTorch does not support export layer_norm to ONNX by default. We add the + support here. `layer_norm` will be exported as ONNX node + 'mmdeploy::layer_norm' + """ + weight.setDebugName('layernorm_weight') + bias.setDebugName('layernorm_bias') + return g.op( + 'mmdeploy::LayerNorm', input, weight, bias, affine_i=1, epsilon_f=eps) + + +@SYMBOLIC_REWRITER.register_symbolic( + 'layer_norm', is_pytorch=True, backend=Backend.NCNN.value) +def layer_norm__ncnn(ctx, *args): + """Register default symbolic function for `layer_norm`. + + Add support to layer_norm to ONNX. + """ + return layer_norm(*args) diff --git a/requirements/optional.txt b/requirements/optional.txt index 68d2edba8..0e4fbac63 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,4 +1,4 @@ -mmcls>=0.15.0,<=0.19.0 +mmcls>=0.21.0,<=0.22.1 mmdet>=2.19.0,<=2.20.0 mmedit mmocr>=0.3.0,<=0.4.1 diff --git a/tests/test_codebase/test_mmcls/test_mmcls_models.py b/tests/test_codebase/test_mmcls/test_mmcls_models.py index a3572fa38..245e0f3ed 100644 --- a/tests/test_codebase/test_mmcls/test_mmcls_models.py +++ b/tests/test_codebase/test_mmcls/test_mmcls_models.py @@ -14,7 +14,7 @@ input = torch.rand(1) -def get_invertedresudual_model(): +def get_invertedresidual_model(): from mmcls.models.backbones.shufflenet_v2 import InvertedResidual model = InvertedResidual(16, 16) @@ -22,6 +22,43 @@ def get_invertedresudual_model(): return model +def get_vit_model(): + from mmcls.models.classifiers.image import ImageClassifier + model = ImageClassifier( + backbone={ + 'type': + 'VisionTransformer', + 'arch': + 'b', + 'img_size': + 384, + 'patch_size': + 32, + 'drop_rate': + 0.1, + 'init_cfg': [{ + 'type': 'Kaiming', + 'layer': 'Conv2d', + 'mode': 'fan_in', + 'nonlinearity': 'linear' + }] + }, + head={ + 'type': 'VisionTransformerClsHead', + 'num_classes': 1000, + 'in_channels': 768, + 'loss': { + 'type': 'CrossEntropyLoss', + 'loss_weight': 1.0 + }, + 'topk': (1, 5) + }, + ) + model.requires_grad_(False) + + return model + + def test_baseclassifier_forward(): from mmcls.models.classifiers import BaseClassifier @@ -78,7 +115,7 @@ def test_multilabel_cls_head(): def test_shufflenetv2_backbone__forward(backend_type: Backend): check_backend(backend_type, True) - model = get_invertedresudual_model() + model = get_invertedresidual_model() model.cpu().eval() if backend_type.value == 'tensorrt': deploy_cfg = mmcv.Config( @@ -121,3 +158,37 @@ def test_shufflenetv2_backbone__forward(backend_type: Backend): rewrite_output = rewrite_output.cpu().numpy() assert np.allclose( model_output, rewrite_output, rtol=1e-03, atol=1e-05) + + +@pytest.mark.parametrize('backend_type', [Backend.NCNN]) +def test_vision_transformer_backbone__forward(backend_type: Backend): + + check_backend(backend_type, True) + model = get_vit_model() + model.eval() + + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(input_shape=None, output_names=['output']), + codebase_config=dict(type='mmcls', task='Classification'))) + + imgs = torch.rand((1, 3, 384, 384)) + model_outputs = model.forward(imgs, return_loss=False) + wrapped_model = WrapModel(model, 'forward') + rewrite_inputs = {'img': imgs} + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if isinstance(rewrite_outputs, dict): + rewrite_outputs = rewrite_outputs['output'] + for model_output, rewrite_output in zip(model_outputs, rewrite_outputs): + if isinstance(rewrite_output, torch.Tensor): + rewrite_output = rewrite_output.cpu().numpy() + assert np.allclose( + model_output.reshape(-1), + rewrite_output.reshape(-1), + rtol=1e-03, + atol=1e-05)