Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Transpose op in TFlite #25297

Open
wants to merge 7 commits into
base: 4.x
Choose a base branch
from
Open

Conversation

CNOCycle
Copy link

@CNOCycle CNOCycle commented Mar 29, 2024

Pull Request Readiness Checklist

Merge with extra: opencv/opencv_extra#1168

The purpose of this PR is to introduce support for the Transpose op in TFlite format and to add a shape comparison between the output tensors and the references. In some occasional cases, the shape of the output tensor is [1,4,1,1], while the shape of the reference tensor is [1,4]. Consequently, the norm check incorrectly reports that the test has passed, as the residual is zero.

Below is a Python script for generating testing data. The generated data can be integrated into the repo opencv_extra.

import numpy as np
import tensorflow as tf

PREFIX_TFL = '/path/to/opencv_extra/testdata/dnn/tflite/'

def generator(input_tensor, model, saved_name):

    # convert keras model to .tflite format
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    #converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.optimizations = [None]
    tflite_model = converter.convert()
    with open(f'{PREFIX_TFL}/{saved_name}.tflite', 'wb') as f:
        f.write(tflite_model)

    # save the input tensor to .npy
    if input_tensor.ndim == 4:
        opencv_tensor = np.transpose(input_tensor, (0,3,1,2))
    else:
        opencv_tensor = input_tensor
    opencv_tensor = np.copy(opencv_tensor, order='C').astype(np.float32)
    np.save(f'{PREFIX_TFL}/{saved_name}_inp.npy', opencv_tensor)

    # generate output tenosr and save it to .npy
    mat_out = model(input_tensor).numpy()
    mat_out = np.copy(mat_out, order='C').astype(np.float32)
    if mat_out.ndim == 4:
        mat_out = np.transpose(mat_out, (0,3,1,2))
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    out_name = interpreter.get_output_details()[0]['name']
    np.save(f'{PREFIX_TFL}/{saved_name}_out_{out_name}.npy', mat_out)

def build_transpose():

    model_name = "keras_permute"
    mat_in = np.array([[[1,2,3], [4,5,6]]], dtype=np.float32)

    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(2,3)))
    model.add(tf.keras.layers.Permute((2,1)))
    model.summary()

    generator(mat_in, model, model_name)

if __name__ == '__main__':
    build_transpose()
  • I agree to contribute to the project under Apache 2 License.
  • To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
  • The PR is proposed to the proper branch
  • There is a reference to the original bug report and related work
  • There is accuracy test, performance test and test data in opencv_extra repository, if applicable
    Patch to opencv_extra has the same branch name.
  • The feature is well documented and sample code can be built with the project CMake

@dkurt
Copy link
Member

dkurt commented Mar 30, 2024

@CNOCycle, please open a PR to https://github.com/opencv/opencv_extra/ with the same branch name as here.

@CNOCycle
Copy link
Author

CNOCycle commented Apr 1, 2024

Hi, @dkurt,

Thank you for your valuable feedback. After careful consideration, I have decided to split this PR into two separate ones. The first will focus on implementing shape checkers, while the second will address the support for the Transpose op.

Regarding the Transpose op support, I have chosen to postpone it until we resolve the shape issues. There are a couple of reasons for this decision. Firstly, I need additional time to seamlessly integrate my stand-alone script for generating testing data into the existing script in the repository opencv/opencv_extra. Secondly, upon conducting tests in my local environment, I discovered that the layout issue is more intricate than initially anticipated.

Specifically, I added the shape checker as you suggested (ASSERT_EQ(ref.size, outs[i].size);) on the branch 4.x (eba158f) but encountered failures in at least two tests. The details are outlined below:

[ RUN      ] Test_TFLite.face_landmark/0, where GetParam() = OCV/CPU
... ...
Expected equality of these values:
  ref.size
    Which is: 1 x 1 x 1 x 1404
  outs[i].size
    Which is: 1 x 1404 x 1 x 1
[  FAILED  ] Test_TFLite.face_landmark/0, where GetParam() = OCV/CPU (3133 ms)

[ RUN      ] Test_TFLite.selfie_segmentation/0, where GetParam() = OCV/CPU
... ...
Expected equality of these values:
  ref.size
    Which is: 256 x 256
  outs[i].size
    Which is: 1 x 1 x 256 x 256
[  FAILED  ] Test_TFLite.selfie_segmentation/0, where GetParam() = OCV/CPU (5182 ms)

[==========] 8 tests from 1 test case ran. (382801 ms total)
[  PASSED  ] 6 tests.
[  FAILED  ] 2 tests, listed below:
[  FAILED  ] Test_TFLite.face_landmark/0, where GetParam() = OCV/CPU
[  FAILED  ] Test_TFLite.selfie_segmentation/0, where GetParam() = OCV/CPU

 2 FAILED TESTS

These errors are reproducible on RISC-V RVV with Debug mode and x64 with Release mode. The introduced shape checker effectively identifies the inconsistent shape issue. Do you have any insights or suggestions regarding these errors?

@dkurt
Copy link
Member

dkurt commented Apr 1, 2024

@CNOCycle, this is a data layout issue I mentioned before. TFLite/TensorFlow work with NHWC by default, OpenCV with NCHW. So during a layer import you have to change axes order:

std::vector<int> perm = allTensors[op.inputs()->Get(1)];

DataLayout inpLayout = layouts[op.inputs()->Get(0)];
if (inpLayout == DNN_LAYOUT_NHWC && perm.size() == 4) {
    static const int order[] = {0, 2, 3, 1};  // NHWC -> NCHW
    for (int& dim: perm) {
        CV_Assert(dim >= 0 && dim < 4);
        dim = order[dim];
    }
}

If in your test example inpLayout is different, please add a simple Conv2D layer before permutation to TFLite model.

@CNOCycle
Copy link
Author

CNOCycle commented Apr 1, 2024

Thank you for providing a clear explanation of the data layout issue. It's important to note that the failed tests I mentioned earlier are not new tests related to the Transpose op; they are existing tests in the opencv repo.

Upon direct inspection of the .tflite models, I observed that the shapes of the output tensors from face_landmark.tflite and selfie_segmentation.tflite are [1,1,1,1404] and [1, 256, 256, 1], respectively. Based on my understanding, the expected shapes in OpenCV should be [1, 1404, 1, 1] and [1, 1, 256, 256]. Consequently, it appears that there is a missing axis swap in the first model, and the shape of the reference tensor in the second model seems to be incorrect. Fortunately, these errors can be easily rectified by updating the shapes of the reference tensors. Please correct me if I am mistaken.

@dkurt dkurt self-assigned this Apr 5, 2024
@dkurt
Copy link
Member

dkurt commented Apr 5, 2024

@CNOCycle, thanks for the observation about shapes. I verified that test data for both models were saved in native view, without necessary reshaping. We can fix it by updating test data but I prefer to just add a workaround in test engine:

    ASSERT_EQ(outs.size(), outNames.size());
    for (int i = 0; i < outNames.size(); ++i) {
        Mat ref = blobFromNPY(findDataFile(format("dnn/tflite/%s_out_%s.npy", modelName.c_str(), outNames[i].c_str())));
        if (modelName == "face_landmark" || modelName == "selfie_segmentation") {
            ref = ref.reshape(1, 1);
            outs[i] = outs[i].reshape(1, 1);
        }
        normAssert(ref, outs[i], outNames[i].c_str(), l1, lInf);
    }

Note that normAssert will check the shapes.

@asmorkalov
Copy link
Contributor

@CNOCycle I merged another patch for TFLite tests and it generates conflict. Could you rebase your PR and fix the conflict.

@CNOCycle
Copy link
Author

CNOCycle commented Apr 9, 2024

Apologies for the delayed response. I encountered a link error while attempting to build the scalable RVV on Debug mode from the latest 4.x branch. I'm unsure of the origin of this error. Nevertheless, I will verify the correctness of this PR using x64 mode or another mode, and promptly push a new one based on the latest branch.

@asmorkalov
Copy link
Contributor

@dkurt is it ready for merge?

else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1) {
std::vector<int> orderLP = {0, 2, 1, 3};
layerParams.set("order", DictValue::arrayInt<int*>(orderLP.data(), orderLP.size()));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, change layout of output:

void TFImporter::parseTranspose(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)

}
if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3) {
std::vector<int> orderLP = {0, 1, 2, 3};
layerParams.set("order", DictValue::arrayInt<int*>(orderLP.data(), orderLP.size()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reduce code duplications

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CNOCycle
Copy link
Author

Hi @dkurt,

Thank you for the valuable feedback. Concerning the code complexity issues, it's essential to note that the transpose operation in TF implementation also encompasses six cases to handle data layout.

if (inpLayout == DNN_LAYOUT_NHWC)
{
if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
{
// in TensorFlow: NHWC->NCHW
// in OpenCV: NCHW->NCHW
data_layouts[name] = DNN_LAYOUT_NCHW;
}
else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
{
// in TensorFlow: NHWC->NHWC
// in OpenCV: NCHW->NCHW
data_layouts[name] = DNN_LAYOUT_NHWC;
}
else if (permData[0] == 0 && permData[1] == 3 && permData[2] == 2 && permData[3] == 1)
{
// in TensorFlow: NHWC->NCWH
// in OpenCV: NCHW->NCWH
int permData[] = {0, 1, 3, 2};
layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
data_layouts[name] = DNN_LAYOUT_NCHW; // we keep track NCHW because channels position only matters
type = "Permute";
}
else
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
}
else if (inpLayout == DNN_LAYOUT_NCHW)
{
if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1)
{
// in TensorFlow: NCHW->NHWC
// in OpenCV: NCHW->NCHW
data_layouts[name] = DNN_LAYOUT_NHWC;
}
else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
{
// in TensorFlow: NCHW->NCHW
// in OpenCV: NCHW->NCHW
data_layouts[name] = DNN_LAYOUT_NCHW;
}
else
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
}

The rationale behind specifying these six cases is elucidated in the comment:

Since applying the NCHW permutation to a NCHW tensor mirrors the NHWC permutation applied to an NHWC tensor, n additional NHWC -> NCHW conversion is requred to match the data layout.

To ensure alignment with the NCHW layout, merely applying a NHWC -> NCHW conversion to permutation vector is insufficient. The sequence of the order vector varies based on the given permutation vector.

Allow me to illustrate what I mean through the following demonstration. Suppose we have a vector A = [N, H, W, C] and a permutation vector P = [0, 1, 2, 3]. After the transpose operation, the output should be [N, H, W, C], while the expected format for OpenCV is [N, C, H, W]. Once the input is represented in a channel-first format, denoted as At = [N, C, H, W], the permutation vector should be adjusted accordingly. Applying NHWC -> NCHW conversion to the permutation vector, it becomes Pt = [0, 3, 1, 2]. However, applying Pt on At results in [N, W, C, H], which is an incorrect outcome.

As evidenced, when providing At = [N, C, H, W] with P = [0, 1, 2, 3], the expected output is [N, C, H, W], resulting in the order vector being set to [0, 1, 2, 3]. The mapping of the order vector for six cases is shown below:

A = [N, H, W, C], At = [N, C, H, W]
case1: P = [0, 1, 2, 3] ->Ap = [N, H, W, C] -> Acv = [N, C, H, W] -> order = [0, 1, 2, 3]
case2: P = [0, 1, 3, 2] ->Ap = [N, H, C, W] -> Acv = [N, W, H, C] -> order = [0, 3, 2, 1]
case3: P = [0, 2, 1, 3] ->Ap = [N, W, H, C] -> Acv = [N, C, W, H] -> order = [0, 1, 3, 2]
case4: P = [0, 2, 3, 1] ->Ap = [N, W, C, H] -> Acv = [N, H, W, C] -> order = [0, 2, 3, 1]
case5: P = [0, 3, 1, 2] ->Ap = [N, C, H, W] -> Acv = [N, W, C, H] -> order = [0, 3, 1, 2]
case6: P = [0, 3, 2, 1] ->Ap = [N, C, W, H] -> Acv = [N, H, C, W] -> order = [0, 2, 1, 3]

@asmorkalov
Copy link
Contributor

modules/dnn/src/tflite/tflite_importer.cpp:734: tab in indent.
+	// For implementation details, please refer to the disscusion:

@CNOCycle
Copy link
Author

Hi @dkurt

Is there any progress on this PR?

@asmorkalov
Copy link
Contributor

Run cd /home/ci/opencv
modules/dnn/src/tflite/tflite_importer.cpp:734: tab in indent.
+	// For implementation details, please refer to the disscusion:
modules/dnn/src/tflite/tflite_importer.cpp:735: tab in indent.
+	// https://github.com/opencv/opencv/pull/25297#issuecomment-2049762298

@dkurt
Copy link
Member

dkurt commented Apr 24, 2024

@CNOCycle, please consider review comments

@CNOCycle
Copy link
Author

@dkurt

I have fixed the tab issue and provided 5 test cases to verify correctness. If still have any concerns about this PR, please let me know. Thanks.

@asmorkalov
Copy link
Contributor

@dkurt could you take a look again?

@CNOCycle
Copy link
Author

CNOCycle commented May 6, 2024

sorry for late reply. I will re-submit a revised one today.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants