Skip to content

Commit

Permalink
Merge pull request opencv#25297 from CNOCycle:tflite/transpose
Browse files Browse the repository at this point in the history
Support Transpose op in TFlite opencv#25297

**Merge with extra**: opencv/opencv_extra#1168

The purpose of this PR is to introduce support for the Transpose op in TFlite format and to add a shape comparison between the output tensors and the references. In some occasional cases, the shape of the output tensor is `[1,4,1,1]`, while the shape of the reference tensor is `[1,4]`. Consequently, the norm check incorrectly reports that the test has passed, as the residual is zero.

Below is a Python script for generating testing data. The generated data can be integrated into the repo `opencv_extra`.

```python
import numpy as np
import tensorflow as tf

PREFIX_TFL = '/path/to/opencv_extra/testdata/dnn/tflite/'

def generator(input_tensor, model, saved_name):

    # convert keras model to .tflite format
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    #converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.optimizations = [None]
    tflite_model = converter.convert()
    with open(f'{PREFIX_TFL}/{saved_name}.tflite', 'wb') as f:
        f.write(tflite_model)

    # save the input tensor to .npy
    if input_tensor.ndim == 4:
        opencv_tensor = np.transpose(input_tensor, (0,3,1,2))
    else:
        opencv_tensor = input_tensor
    opencv_tensor = np.copy(opencv_tensor, order='C').astype(np.float32)
    np.save(f'{PREFIX_TFL}/{saved_name}_inp.npy', opencv_tensor)

    # generate output tenosr and save it to .npy
    mat_out = model(input_tensor).numpy()
    mat_out = np.copy(mat_out, order='C').astype(np.float32)
    if mat_out.ndim == 4:
        mat_out = np.transpose(mat_out, (0,3,1,2))
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    out_name = interpreter.get_output_details()[0]['name']
    np.save(f'{PREFIX_TFL}/{saved_name}_out_{out_name}.npy', mat_out)

def build_transpose():

    model_name = "keras_permute"
    mat_in = np.array([[[1,2,3], [4,5,6]]], dtype=np.float32)

    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(2,3)))
    model.add(tf.keras.layers.Permute((2,1)))
    model.summary()

    generator(mat_in, model, model_name)

if __name__ == '__main__':
    build_transpose()
```

### Pull Request Readiness Checklist

- [x] I agree to contribute to the project under Apache 2 License.
- [X] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [X] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [X] The feature is well documented and sample code can be built with the project CMake
  • Loading branch information
CNOCycle authored and klatism committed May 17, 2024
1 parent ebb62cc commit 597234a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
45 changes: 45 additions & 0 deletions modules/dnn/src/tflite/tflite_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class TFLiteImporter {
void parseFullyConnected(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams);
void parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams);

void parseFusedActivation(const Operator& op, ActivationFunctionType activ);
void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused);
Expand Down Expand Up @@ -284,6 +285,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
dispatch["SOFTMAX"] = &TFLiteImporter::parseSoftmax;
dispatch["CAST"] = &TFLiteImporter::parseCast;
dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess;
dispatch["TRANSPOSE"] = &TFLiteImporter::parseTranspose;
return dispatch;
}

Expand Down Expand Up @@ -719,6 +721,49 @@ void TFLiteImporter::parseResize(const Operator& op, const std::string& opcode,
addLayer(layerParams, op);
}

void TFLiteImporter::parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams)
{
layerParams.type = "Permute";
std::vector<int> perm = allTensors[op.inputs()->Get(1)];

DataLayout inpLayout = layouts[op.inputs()->Get(0)];
if (inpLayout == DNN_LAYOUT_NHWC && perm.size() == 4) {

// OpenCV operates under the assumption that NCHW format, whereas TFLite defaults to NHWC.
// Therfore, to align these layouts, the axes of the permutation vector should be adjusted accordingly.
// For implementation details, please refer to the disscusion:
// https://github.com/opencv/opencv/pull/25297#issuecomment-2049762298

if (perm[0] != 0) {
CV_Error(Error::StsParseError, "The first axis should not be permuted.");
}
if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3) {
std::vector<int> orderLP = {0, 1, 2, 3};
layerParams.set("order", DictValue::arrayInt<int*>(orderLP.data(), orderLP.size()));
layouts[op.outputs()->Get(0)] = DNN_LAYOUT_NCHW;
}
else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2) {
std::vector<int> orderLP = {0, 3, 2, 1};
layerParams.set("order", DictValue::arrayInt<int*>(orderLP.data(), orderLP.size()));
}
else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3) {
std::vector<int> orderLP = {0, 1, 3, 2};
layerParams.set("order", DictValue::arrayInt<int*>(orderLP.data(), orderLP.size()));
layouts[op.outputs()->Get(0)] = DNN_LAYOUT_NCHW;
}
else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1) {
std::vector<int> orderLP = {0, 2, 3, 1};
layerParams.set("order", DictValue::arrayInt<int*>(orderLP.data(), orderLP.size()));
}

}
else {
layerParams.set("order", DictValue::arrayInt<int*>(perm.data(), perm.size()));
}

addLayer(layerParams, op);
}

int TFLiteImporter::addPermuteLayer(const std::vector<int>& order, const std::string& permName,
const std::pair<int, int>& inpId, int dtype)
{
Expand Down
9 changes: 9 additions & 0 deletions modules/dnn/test/test_tflite_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,15 @@ TEST_P(Test_TFLite, fully_connected) {
testLayer("fully_connected");
}

TEST_P(Test_TFLite, permute) {
testLayer("permutation_3d");
// Temporarily disabled as TFLiteConverter produces a incorrect graph in this case
//testLayer("permutation_4d_0123");
testLayer("permutation_4d_0132");
testLayer("permutation_4d_0213");
testLayer("permutation_4d_0231");
}

INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets());

}} // namespace
Expand Down

0 comments on commit 597234a

Please sign in to comment.