Skip to content

Commit

Permalink
Merge pull request #1168 from CNOCycle:tflite/transpose
Browse files Browse the repository at this point in the history
* Refine data transpose when generating testing data for tflite models

* Add test cases for tflite models

* Add more test cases for permutation op in tflite model

* Update testing data for permutation op in tflite model
  • Loading branch information
CNOCycle committed May 15, 2024
1 parent 1458fff commit 723bdf2
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 1 deletion.
67 changes: 66 additions & 1 deletion testdata/dnn/tflite/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,14 @@ def save_tflite_model(model, inp, name):
out = model(inp)
out = np.array(out)

if len(inp.shape) == 4:
# convert NHWC to NCHW format
if inp.ndim == 4:
inp = inp.transpose(0, 3, 1, 2)
inp = np.copy(inp, order='C').astype(inp.dtype)

if out.ndim == 4:
out = out.transpose(0, 3, 1, 2)
out = np.copy(out, order='C').astype(out.dtype)

np.save(f'{name}_inp.npy', inp)
np.save(f'{name}_out_Identity.npy', out)
Expand Down Expand Up @@ -102,3 +107,63 @@ def split(x):

inp = np.random.standard_normal((1, 2)).astype(np.float32)
save_tflite_model(fully_connected, inp, 'fully_connected')

permutation_3d = tf.keras.models.Sequential([
tf.keras.layers.Permute((2,1))
])

permutation_3d = tf.function(
permutation_3d.call,
input_signature=[tf.TensorSpec((1,2,3), tf.float32)],
)
inp = np.random.standard_normal((1, 2, 3)).astype(np.float32)
save_tflite_model(permutation_3d, inp, 'permutation_3d')

# Temporarily disabled as TFLiteConverter produces a incorrect graph in this case
#permutation_4d_0123 = tf.keras.models.Sequential([
# tf.keras.layers.Permute((1,2,3)),
# tf.keras.layers.Conv2D(3,1)
#])
#
#permutation_4d_0123 = tf.function(
# permutation_4d_0123.call,
# input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
#)
#inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
#save_tflite_model(permutation_4d_0123, inp, 'permutation_4d_0123')

permutation_4d_0132 = tf.keras.models.Sequential([
tf.keras.layers.Permute((1,3,2)),
tf.keras.layers.Conv2D(3,1)
])

permutation_4d_0132 = tf.function(
permutation_4d_0132.call,
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
)
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
save_tflite_model(permutation_4d_0132, inp, 'permutation_4d_0132')

permutation_4d_0213 = tf.keras.models.Sequential([
tf.keras.layers.Permute((2,1,3)),
tf.keras.layers.Conv2D(3,1)
])

permutation_4d_0213 = tf.function(
permutation_4d_0213.call,
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
)
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
save_tflite_model(permutation_4d_0213, inp, 'permutation_4d_0213')

permutation_4d_0231 = tf.keras.models.Sequential([
tf.keras.layers.Permute((2,3,1)),
tf.keras.layers.Conv2D(3,1)
])

permutation_4d_0231 = tf.function(
permutation_4d_0231.call,
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
)
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
save_tflite_model(permutation_4d_0231, inp, 'permutation_4d_0231')
Binary file added testdata/dnn/tflite/permutation_3d.tflite
Binary file not shown.
Binary file added testdata/dnn/tflite/permutation_3d_inp.npy
Binary file not shown.
Binary file added testdata/dnn/tflite/permutation_3d_out_Identity.npy
Binary file not shown.
Binary file added testdata/dnn/tflite/permutation_4d_0132.tflite
Binary file not shown.
Binary file added testdata/dnn/tflite/permutation_4d_0132_inp.npy
Binary file not shown.
Binary file not shown.
Binary file added testdata/dnn/tflite/permutation_4d_0213.tflite
Binary file not shown.
Binary file added testdata/dnn/tflite/permutation_4d_0213_inp.npy
Binary file not shown.
Binary file not shown.
Binary file added testdata/dnn/tflite/permutation_4d_0231.tflite
Binary file not shown.
Binary file added testdata/dnn/tflite/permutation_4d_0231_inp.npy
Binary file not shown.
Binary file not shown.

0 comments on commit 723bdf2

Please sign in to comment.