Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ONNX parser for single-layer LSTM hidden and cell states #23475

Merged
merged 6 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions modules/dnn/src/layers/recurrent_layers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,14 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
CV_CheckEQ(Wh.rows, Wx.rows, "");
CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
CV_CheckEQ(hInternal.cols, Wh.cols, "");
CV_CheckEQ(hInternal.cols, cInternal.cols, "");
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
// Only perform these checks if hInternal and cInternal are not empty matrices
// e.g. inputs are not given by a user
if (!hInternal.empty() && !cInternal.empty())
rogday marked this conversation as resolved.
Show resolved Hide resolved
{
CV_CheckEQ(hInternal.cols, Wh.cols, "");
CV_CheckEQ(hInternal.cols, cInternal.cols, "");
CV_CheckEQ(hInternal.rows, cInternal.rows, "");
}
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());

// Peephole weights.
Expand Down Expand Up @@ -266,7 +271,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
std::vector<MatShape> &internals) const CV_OVERRIDE
{
CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
CV_Assert(inputs.size() == 1);
CV_Assert((inputs.size() == 1 || inputs.size() == 3));
const MatShape& inp0 = inputs[0];

const Mat &Wh = blobs[0], &Wx = blobs[1];
Expand Down Expand Up @@ -326,7 +331,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
inputs_arr.getMatVector(input);

CV_Assert((!usePeephole && blobs.size() == 5) || (usePeephole && blobs.size() == 8));
CV_Assert(input.size() == 1);
CV_Assert((input.size() == 1 || input.size() == 3));
const Mat& inp0 = input[0];

Mat &Wh = blobs[0], &Wx = blobs[1];
Expand Down Expand Up @@ -383,8 +388,18 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
Mat Wh = blobs[0];
Mat Wx = blobs[1];
Mat bias = blobs[2];
Mat h_0 = blobs[3];
Mat c_0 = blobs[4];

Mat h_0, c_0;
// use hx and cx when they are provided as inputs; otherwise fall back to the stored state blobs
if (input.size() == 3){
h_0 = input[1].reshape(1, input[1].size[0] * input[1].size[1]);
c_0 = input[2].reshape(1, input[2].size[0] * input[2].size[1]);
} else {
h_0 = blobs[3];
c_0 = blobs[4];
}


Mat pI, pF, pO;

Wh = Wh.rowRange(i * Wh.rows / numDirs, (i + 1) * Wh.rows / numDirs);
Expand Down
32 changes: 24 additions & 8 deletions modules/dnn/src/onnx/onnx_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,10 +1517,16 @@ void transformBlobs(std::vector<Mat>& blobs)

const int numHidden = Wh.size[2];

Mat h0 = blobs[3];
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
Mat c0 = blobs[4];
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
Mat h0, c0;
// check whether the input is dynamic or not, i.e. whether hx, cx are given by the user (blobs[3] and blobs[4] are then empty).
// Reshape the states only if they are stored as non-empty blobs
bool dyn_input = (blobs[3].empty() && blobs[4].empty()) ? true : false;
if (!dyn_input){
h0 = blobs[3];
h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
c0 = blobs[4];
c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
}

b = b.reshape(1, b.size[0]);
Mat bx = b.colRange(0, b.cols / 2);
Expand All @@ -1547,8 +1553,11 @@ void transformBlobs(std::vector<Mat>& blobs)
blobs[0] = Wh;
blobs[1] = Wx;
blobs[2] = b.reshape(1, 1);
blobs[3] = h0;
blobs[4] = c0;
// assign the reshaped states if they are given
asmorkalov marked this conversation as resolved.
Show resolved Hide resolved
if(!dyn_input){
blobs[3] = h0;
blobs[4] = c0;
}

if (blobs.size() == 5) {
// so that future patch removing copies can leave all indexing as is
Expand Down Expand Up @@ -1579,8 +1588,15 @@ void ONNXImporter::lstm_extractConsts(LayerParams& layerParams, const opencv_onn
Mat blob;
if (idx < lstm_proto.input_size() && !lstm_proto.input(idx).empty())
{
blob = getBlob(lstm_proto, idx);
CV_Assert(shape(blob) == blobShape);
if ((idx == 5 || idx == 6) && (constBlobs.find(lstm_proto.input(idx)) == constBlobs.end()))
{
blob = Mat();
}
else
{
blob = getBlob(lstm_proto, idx);
CV_Assert(shape(blob) == blobShape);
}
}
else
{
Expand Down