
Commit

support v_Slice.end=INT64_MAX; support single-head attention subgraph; add test for attention subgraph fusion
fengyuentau committed Nov 25, 2023
1 parent 70466d9 commit b3068a3
Showing 4 changed files with 149 additions and 10 deletions.
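
For context: an ONNX Slice may use end=INT64_MAX to mean "slice to the end of the axis" (out-of-range ends are clamped to the dimension size), so the fusion below cannot compute the v hidden size from the Slice constants alone. It records a placeholder of 0 and recovers the real value from the weight shape in finalize(). A minimal standalone sketch of that resolution (the helper name is illustrative, not part of this commit):

    #include <cstdint>
    #include <cstdio>
    #include <limits>

    // Sketch only: resolve a Slice end that may be INT64_MAX ("to the end of
    // the axis") against the known length of the sliced axis.
    static int64_t resolveSliceEnd(int64_t slice_end, int64_t axis_length) {
        if (slice_end == std::numeric_limits<int64_t>::max() || slice_end > axis_length)
            return axis_length;  // ONNX clamps out-of-range ends to the axis size
        return slice_end;
    }

    int main() {
        // e.g. a v-path Slice with start=1536, end=INT64_MAX over a weight of width 2304
        int64_t start = 1536, end = std::numeric_limits<int64_t>::max();
        std::printf("v_hidden_size=%lld\n",
                    (long long)(resolveSliceEnd(end, 2304) - start));  // prints 768
        return 0;
    }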
22 changes: 14 additions & 8 deletions modules/dnn/src/layers/attention_layer.cpp
@@ -36,16 +36,17 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
CV_CheckTrue(params.has("qkv_hidden_sizes"), "DNN/Attention: qkv_hidden_sizes is required but missing");
auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes");
CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "DNN/Attention: qkv_hidden_sizes must have exactly three elements");

qkv_hidden_sizes.clear();
qkv_hidden_sizes.resize(3);
qkv_hidden_sizes[0] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(0));
qkv_hidden_sizes[1] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(1));
qkv_hidden_sizes[2] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(2));

hidden_size = qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2];
/* v_hidden_size needs to be initialized in finalize in case v_slice_end=INT_MAX */

qkv_head_sizes.clear();
qkv_head_sizes.resize(3);
std::transform(qkv_hidden_sizes.begin(), qkv_hidden_sizes.end(), qkv_head_sizes.begin(),
[this] (const size_t w) { return static_cast<size_t>(w / num_heads); });
qkv_head_sizes[0] = static_cast<size_t>(qkv_hidden_sizes[0] / num_heads);
qkv_head_sizes[1] = static_cast<size_t>(qkv_hidden_sizes[1] / num_heads);

scale = 1.f / params.get<float>("scale", sqrt(qkv_head_sizes[0]));

@@ -64,14 +65,12 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
const auto &input_shape = inputs[0];
const auto &weight_shape = inputs[1];
const auto &bias_shape = inputs[2];
size_t dim_bias = static_cast<size_t>(std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<int>()));

CV_CheckEQ(input_shape.size(), static_cast<size_t>(3), "DNN/Attention: invalid input dimension");
CV_CheckEQ(weight_shape.size(), static_cast<size_t>(2), "DNN/Attention: invalid weight dimension");

CV_CheckEQ(input_shape[2], weight_shape[0], "DNN/Attention: invalid input shape");
CV_CheckEQ(static_cast<size_t>(weight_shape[1]), hidden_size, "DNN/Attention: invalid weight shape");
CV_CheckEQ(dim_bias, hidden_size, "DNN/Attention: invalid bias shape");
CV_CheckEQ(weight_shape[1], bias_shape[0], "DNN/Attention: invalid weight or bias shape");

outputs.assign(1, inputs[0]);
return false;
@@ -86,6 +85,13 @@ class AttentionLayerImpl CV_FINAL : public AttentionLayer {
batch_size = static_cast<size_t>(input_shape[0]);
seq_len = static_cast<size_t>(input_shape[1]);
input_hidden_size = static_cast<size_t>(input_shape[2]);

const auto weight_shape = shape(inputs[1]);
hidden_size = weight_shape[1];
qkv_hidden_sizes[2] = hidden_size - qkv_hidden_sizes[0] - qkv_hidden_sizes[1];
qkv_head_sizes[2] = static_cast<size_t>(qkv_hidden_sizes[2] / num_heads);

// std::cout << "finalize: qkv_hidden_sizes=" << qkv_hidden_sizes << ", qkv_head_sizes=" << qkv_head_sizes << ", hidden_size=" << hidden_size << std::endl;
}

void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
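The deferred initialization above reduces to simple arithmetic once the weight shape is visible in finalize(); a worked example with illustrative, BERT-base-like sizes (not taken from the commit or its test data):

    #include <cstddef>
    #include <cstdio>

    int main() {
        // The trailing 0 is the placeholder written for a v Slice with end=INT64_MAX.
        std::size_t qkv_hidden_sizes[3] = {768, 768, 0};
        std::size_t num_heads   = 12;
        std::size_t hidden_size = 2304;  // weight_shape[1], known only at finalize()
        qkv_hidden_sizes[2] = hidden_size - qkv_hidden_sizes[0] - qkv_hidden_sizes[1];
        std::printf("v_hidden_size=%zu v_head_size=%zu\n",
                    qkv_hidden_sizes[2], qkv_hidden_sizes[2] / num_heads);  // 768 and 64
        return 0;
    }
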
126 changes: 124 additions & 2 deletions modules/dnn/src/onnx/onnx_graph_simplifier.cpp
@@ -13,6 +13,7 @@

#include <opencv2/core/utils/logger.hpp>
#include <queue>
#include <limits>

namespace cv { namespace dnn {
CV__DNN_INLINE_NS_BEGIN
@@ -293,8 +294,12 @@ class AttentionSubGraph : public Subgraph {
auto fill_qkv_hidden_sizes = [&] (const int slice_node_id) {
int slice_start = extractConstant(net, matchedNodesIds[slice_node_id], 1).at<int>(0);
int slice_end = extractConstant(net, matchedNodesIds[slice_node_id], 2).at<int>(0);
int64_t hidden_size = static_cast<int64_t>(slice_end - slice_start);
qkv_hidden_sizes.push_back(hidden_size);
if (slice_end == std::numeric_limits<int>::max()) {
qkv_hidden_sizes.push_back(0); // workaround for Slice with end=INT_MAX
} else {
int64_t hidden_size = static_cast<int64_t>(slice_end - slice_start);
qkv_hidden_sizes.push_back(hidden_size);
}
};
fill_qkv_hidden_sizes(slice_q);
fill_qkv_hidden_sizes(slice_k);
@@ -351,6 +356,122 @@ class AttentionSubGraph : public Subgraph {
std::string bias_name;
};

/* Attention subgraph with a single head.
   No Reshape operator follows each Slice operator.
*/
class AttentionSingleHeadSubGraph : public Subgraph {
public:
AttentionSingleHeadSubGraph() {
int input = addNodeToMatch("");
int transpose = addNodeToMatch("Transpose", input); // this Transpose does not affect accuracy in this subgraph
att_matmul = addNodeToMatch("MatMul", transpose, addNodeToMatch(""));
att_add = addNodeToMatch("Add", addNodeToMatch(""), att_matmul);

// v_path
slice_v = addNodeToMatch("Slice", att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""));
int transpose_v = addNodeToMatch("Transpose", slice_v);

// q_path
slice_q = addNodeToMatch("Slice", att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""));
int transpose_q = addNodeToMatch("Transpose", slice_q);
div_q = addNodeToMatch("Div", transpose_q, addNodeToMatch(""));

// k_path
slice_k = addNodeToMatch("Slice", att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""));
int transpose_k = addNodeToMatch("Transpose", slice_k);

// qk
int matmul_qk = addNodeToMatch("MatMul", div_q, transpose_k);
int softmax_qk = addNodeToMatch("Softmax", matmul_qk);

// qkv
int matmul_qkv = addNodeToMatch("MatMul", softmax_qk, transpose_v);
int transpose_qkv = addNodeToMatch("Transpose", matmul_qkv);
addNodeToMatch("Reshape", transpose_qkv, addNodeToMatch(""));

setFusedNode("Attention", input);
}

static std::string getInputName(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id) {
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1) {
return onnx_net->getNameOfInitializer(initializer_id);
} else {
const auto node = net->getNode(node_id);
return node->getInputName(input_id);
}
}

virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
// get attrs - qkv_hidden_sizes
auto fill_qkv_hidden_sizes = [&] (const int slice_node_id) {
int slice_start = extractConstant(net, matchedNodesIds[slice_node_id], 1).at<int>(0);
int slice_end = extractConstant(net, matchedNodesIds[slice_node_id], 2).at<int>(0);
if (slice_end == std::numeric_limits<int>::max()) {
qkv_hidden_sizes.push_back(0); // workaround for Slice with end=INT_MAX
} else {
int64_t hidden_size = static_cast<int64_t>(slice_end - slice_start);
qkv_hidden_sizes.push_back(hidden_size);
}
};
fill_qkv_hidden_sizes(slice_q);
fill_qkv_hidden_sizes(slice_k);
fill_qkv_hidden_sizes(slice_v);
CV_CheckEQ(qkv_hidden_sizes.size(), static_cast<size_t>(3), "ONNXSimplifier/Attention: invalid qkv hidden sizes");
CV_CheckEQ(int(qkv_hidden_sizes[0]), int(qkv_hidden_sizes[1]), "ONNXSimplifier/Attention: invalid qkv hidden sizes, q_hidden_size == k_hidden_size is required");
// get attrs - num_heads, scale
num_heads = 1;
scale = extractConstant(net, matchedNodesIds[div_q], 1).at<float>(0);
// std::cout << "AttentionSingleHeadSubGraph: num_heads=" << num_heads << ", qkv_hidden_sizes=" << qkv_hidden_sizes << ", scale=" << scale << std::endl;

// get names
weight_name = getInputName(net, matchedNodesIds[att_matmul], 1);
// std::cout << "AttentionSingleHeadSubGraph: weight_name=" << weight_name << std::endl;
bias_name = getInputName(net, matchedNodesIds[att_add], 0);
// std::cout << "AttentionSingleHeadSubGraph: bias_name=" << bias_name << std::endl;
return true;
}
return false;
}

virtual void finalize(const Ptr<ImportGraphWrapper>& net,
const Ptr<ImportNodeWrapper>& fusedNode,
std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE {
// add attrs
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::AttributeProto* attr_num_heads = node->add_attribute();
attr_num_heads->set_name("num_heads");
attr_num_heads->set_i(num_heads);
opencv_onnx::AttributeProto* attr_qkv_hidden_sizes = node->add_attribute();
attr_qkv_hidden_sizes->set_name("qkv_hidden_sizes");
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[0]);
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[1]);
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[2]);
opencv_onnx::AttributeProto* attr_scale = node->add_attribute();
attr_scale->set_name("scale");
attr_scale->set_f(scale);

// add inputs
node->add_input(weight_name);
node->add_input(bias_name);
}

protected:
int att_matmul, att_add;
int slice_q, slice_k, slice_v;
int div_q;

std::vector<int64_t> qkv_hidden_sizes; // order: [qk_hidden_size, qk_hidden_size, v_hidden_size]
int64_t num_heads;
float scale;

std::string weight_name;
std::string bias_name;
};

/* Fusion for Gelu.
Graph before fusion:
@@ -1376,6 +1497,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
subgraphs.push_back(makePtr<NormalizeSubgraph5>());
subgraphs.push_back(makePtr<AttentionSubGraph>());
subgraphs.push_back(makePtr<AttentionSingleHeadSubGraph>());

simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
}
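Since the new matcher is registered alongside the existing subgraphs, the fusion runs automatically inside the ONNX importer and needs no user-facing API change; a usage sketch for verifying it (the model file name is hypothetical):

    #include <opencv2/dnn.hpp>
    #include <iostream>
    #include <vector>

    int main() {
        // Hypothetical model containing the single-head pattern matched above:
        //   Transpose -> MatMul -> Add -> Slice(q/k/v) -> ... -> Softmax -> MatMul -> Transpose -> Reshape
        cv::dnn::Net net = cv::dnn::readNetFromONNX("attention_single_head.onnx");
        std::vector<cv::String> types;
        net.getLayerTypes(types);  // after fusion this should include "Attention"
        for (const auto &t : types)
            std::cout << t << std::endl;
        return 0;
    }
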
8 changes: 8 additions & 0 deletions modules/dnn/test/test_graph_simplifier.cpp
@@ -130,4 +130,12 @@ TEST_F(Test_Graph_Simplifier, MishSubgraph) {
test("mish", "Mish");
}

TEST_F(Test_Graph_Simplifier, AttentionSubgraph) {
/* Test for 2 subgraphs:
 - AttentionSubGraph
 - AttentionSingleHeadSubGraph
*/
test("attention", "Attention");
test("attention_single_head", "Attention");
}
}}
3 changes: 3 additions & 0 deletions modules/dnn/test/test_onnx_importer.cpp
@@ -2789,6 +2789,9 @@ TEST_P(Test_ONNX_layers, Expand_shape_model4) {
TEST_P(Test_ONNX_layers, Attention) {
testONNXModels("attention");
}
TEST_P(Test_ONNX_layers, AttentionSingleHead) {
testONNXModels("attention_single_head");
}

INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());

