From 597234ab48a1a1f86a0d92a804c28717f23a0fa0 Mon Sep 17 00:00:00 2001 From: CNOCycle <24318472+CNOCycle@users.noreply.github.com> Date: Thu, 16 May 2024 01:07:25 +0800 Subject: [PATCH] Merge pull request #25297 from CNOCycle:tflite/transpose Support Transpose op in TFlite #25297 **Merge with extra**: https://github.com/opencv/opencv_extra/pull/1168 The purpose of this PR is to introduce support for the Transpose op in TFlite format and to add a shape comparison between the output tensors and the references. In some occasional cases, the shape of the output tensor is `[1,4,1,1]`, while the shape of the reference tensor is `[1,4]`. Consequently, the norm check incorrectly reports that the test has passed, as the residual is zero. Below is a Python script for generating testing data. The generated data can be integrated into the repo `opencv_extra`. ```python import numpy as np import tensorflow as tf PREFIX_TFL = '/path/to/opencv_extra/testdata/dnn/tflite/' def generator(input_tensor, model, saved_name): # convert keras model to .tflite format converter = tf.lite.TFLiteConverter.from_keras_model(model) #converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.optimizations = [None] tflite_model = converter.convert() with open(f'{PREFIX_TFL}/{saved_name}.tflite', 'wb') as f: f.write(tflite_model) # save the input tensor to .npy if input_tensor.ndim == 4: opencv_tensor = np.transpose(input_tensor, (0,3,1,2)) else: opencv_tensor = input_tensor opencv_tensor = np.copy(opencv_tensor, order='C').astype(np.float32) np.save(f'{PREFIX_TFL}/{saved_name}_inp.npy', opencv_tensor) # generate output tensor and save it to .npy mat_out = model(input_tensor).numpy() mat_out = np.copy(mat_out, order='C').astype(np.float32) if mat_out.ndim == 4: mat_out = np.transpose(mat_out, (0,3,1,2)) interpreter = tf.lite.Interpreter(model_content=tflite_model) out_name = interpreter.get_output_details()[0]['name'] np.save(f'{PREFIX_TFL}/{saved_name}_out_{out_name}.npy', mat_out) def 
build_transpose(): model_name = "keras_permute" mat_in = np.array([[[1,2,3], [4,5,6]]], dtype=np.float32) model = tf.keras.Sequential() model.add(tf.keras.Input(shape=(2,3))) model.add(tf.keras.layers.Permute((2,1))) model.summary() generator(mat_in, model, model_name) if __name__ == '__main__': build_transpose() ``` ### Pull Request Readiness Checklist - [x] I agree to contribute to the project under Apache 2 License. - [X] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [X] The PR is proposed to the proper branch - [ ] There is a reference to the original bug report and related work - [ ] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [X] The feature is well documented and sample code can be built with the project CMake --- modules/dnn/src/tflite/tflite_importer.cpp | 45 ++++++++++++++++++++++ modules/dnn/test/test_tflite_importer.cpp | 9 +++++ 2 files changed, 54 insertions(+) diff --git a/modules/dnn/src/tflite/tflite_importer.cpp b/modules/dnn/src/tflite/tflite_importer.cpp index 8850cd9ad219..1c048ad9d026 100644 --- a/modules/dnn/src/tflite/tflite_importer.cpp +++ b/modules/dnn/src/tflite/tflite_importer.cpp @@ -70,6 +70,7 @@ class TFLiteImporter { void parseFullyConnected(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseSoftmax(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseCast(const Operator& op, const std::string& opcode, LayerParams& layerParams); + void parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams); void parseFusedActivation(const Operator& op, ActivationFunctionType activ); void parseActivation(const Operator& op, const std::string& opcode, LayerParams& layerParams, bool isFused); @@ -284,6 +285,7 @@ TFLiteImporter::DispatchMap 
TFLiteImporter::buildDispatchMap() dispatch["SOFTMAX"] = &TFLiteImporter::parseSoftmax; dispatch["CAST"] = &TFLiteImporter::parseCast; dispatch["TFLite_Detection_PostProcess"] = &TFLiteImporter::parseDetectionPostProcess; + dispatch["TRANSPOSE"] = &TFLiteImporter::parseTranspose; return dispatch; } @@ -719,6 +721,49 @@ void TFLiteImporter::parseResize(const Operator& op, const std::string& opcode, addLayer(layerParams, op); } +void TFLiteImporter::parseTranspose(const Operator& op, const std::string& opcode, LayerParams& layerParams) +{ + layerParams.type = "Permute"; + std::vector perm = allTensors[op.inputs()->Get(1)]; + + DataLayout inpLayout = layouts[op.inputs()->Get(0)]; + if (inpLayout == DNN_LAYOUT_NHWC && perm.size() == 4) { + + // OpenCV operates under the assumption of the NCHW format, whereas TFLite defaults to NHWC. + // Therefore, to align these layouts, the axes of the permutation vector should be adjusted accordingly. + // For implementation details, please refer to the discussion: + // https://github.com/opencv/opencv/pull/25297#issuecomment-2049762298 + + if (perm[0] != 0) { + CV_Error(Error::StsParseError, "The first axis should not be permuted."); + } + if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3) { + std::vector orderLP = {0, 1, 2, 3}; + layerParams.set("order", DictValue::arrayInt(orderLP.data(), orderLP.size())); + layouts[op.outputs()->Get(0)] = DNN_LAYOUT_NCHW; + } + else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2) { + std::vector orderLP = {0, 3, 2, 1}; + layerParams.set("order", DictValue::arrayInt(orderLP.data(), orderLP.size())); + } + else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3) { + std::vector orderLP = {0, 1, 3, 2}; + layerParams.set("order", DictValue::arrayInt(orderLP.data(), orderLP.size())); + layouts[op.outputs()->Get(0)] = DNN_LAYOUT_NCHW; + } + else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1) { + std::vector orderLP = {0, 2, 3, 1}; + layerParams.set("order", DictValue::arrayInt(orderLP.data(), 
orderLP.size())); + } + + } + else { + layerParams.set("order", DictValue::arrayInt(perm.data(), perm.size())); + } + + addLayer(layerParams, op); +} + int TFLiteImporter::addPermuteLayer(const std::vector& order, const std::string& permName, const std::pair& inpId, int dtype) { diff --git a/modules/dnn/test/test_tflite_importer.cpp b/modules/dnn/test/test_tflite_importer.cpp index 7ad62bf3081a..7621b44ff53d 100644 --- a/modules/dnn/test/test_tflite_importer.cpp +++ b/modules/dnn/test/test_tflite_importer.cpp @@ -251,6 +251,15 @@ TEST_P(Test_TFLite, fully_connected) { testLayer("fully_connected"); } +TEST_P(Test_TFLite, permute) { + testLayer("permutation_3d"); + // Temporarily disabled as TFLiteConverter produces an incorrect graph in this case + //testLayer("permutation_4d_0123"); + testLayer("permutation_4d_0132"); + testLayer("permutation_4d_0213"); + testLayer("permutation_4d_0231"); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets()); }} // namespace