
STFT layer output shape deviates from STFTTflite layer in batch dimension #130

Closed
PhilippMatthes opened this issue Aug 25, 2021 · 5 comments


@PhilippMatthes

Use Case

I want to convert the STFT layer in my model to an STFTTflite layer so I can deploy the model to my mobile device. The documentation mentions that an extra dimension is added to account for complex numbers, but I also encountered behaviour that is not documented.

Expected Behaviour

input_shape = (2048, 1)  # mono signal

model = keras.models.Sequential()  # TFLite incompatible model
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))

tflite_model = keras.models.Sequential()  # TFLite compatible model
tflite_model.add(kapre.STFTTflite(n_fft=1024, hop_length=512, input_shape=input_shape))

model has the output shape (None, 3, 513, 1). Therefore, tflite_model should have the output shape (None, 3, 513, 1, 2).
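The shapes above can be sanity-checked with a small sketch. The frame formula below assumes no padding (kapre's default here), which matches the behaviour of tf.signal.stft that kapre wraps:

```python
# Illustrative sketch: compute the expected STFT output shape (excluding the
# batch dimension) for an unpadded STFT, as in the models above.
def stft_output_shape(n_samples, n_fft, hop_length, n_channels=1):
    n_frames = 1 + (n_samples - n_fft) // hop_length  # hops that fit without padding
    n_bins = n_fft // 2 + 1                           # one-sided (real-input) spectrum
    return (n_frames, n_bins, n_channels)

print(stft_output_shape(2048, 1024, 512))  # (3, 513, 1)
```

For STFTTflite, an extra trailing dimension of size 2 holds the real and imaginary parts, giving (3, 513, 1, 2).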

Observed Behaviour

The output shape of tflite_model is (1, 3, 513, 1, 2) instead of (None, 3, 513, 1, 2).

Problem Solution

  • If this behaviour is unwanted:
    • Change the model output format so that the batch dimension is correctly shaped.
  • Otherwise:
    • Explain in the documentation why the batch dimension is shaped to 1.
    • Explain in the documentation how to include this layer into models which expect the batch dimension to be shaped None.
@kenders2000
Contributor

kenders2000 commented Aug 25, 2021

Hi,

Yeah, the restriction to a batch size of one in tflite is enforced, and it is something I have been trying to address, but I have been getting seg-faults when running tflite inference, so I have not yet solved it. If you are interested, you can look at my fork (work in progress) for details:

https://github.com/kenders2000/kapre/tree/feature/tflite-variable-batch-size

The use case I have developed for this is to train a model using the 'vanilla' kapre.STFT layers; then, when you want to convert the model to tflite for deployment, create a new model with the kapre.STFTTflite layers and load the weights from the vanilla one. On the mobile device you are restricted to performing a single inference at a time, but this is generally not that restrictive.
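That workflow can be sketched as follows. This is a minimal illustration, not kapre's own API: convert_for_deployment and build_tflite_model are hypothetical names, and it assumes the two models are architecturally identical apart from the STFT/STFTTflite (and Magnitude/MagnitudeTflite) layer swap:

```python
# Hypothetical sketch of the train-then-convert workflow:
# train with kapre.STFT, then copy the weights into an identical
# model built with kapre.STFTTflite before TFLite conversion.
def convert_for_deployment(trained_model, build_tflite_model):
    """Return a TFLite-compatible twin of `trained_model`.

    `build_tflite_model` is a zero-argument factory that assembles the
    same architecture using the *Tflite layer variants.
    """
    tflite_model = build_tflite_model()
    # The STFT layers themselves have no trainable weights, so the
    # weight lists of the two models line up one-to-one.
    tflite_model.set_weights(trained_model.get_weights())
    return tflite_model
```

The returned model can then be passed to tf.lite.TFLiteConverter.from_keras_model as usual.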

I recently tried using tf.signal.stft, but from what I could see it is still not tflite compatible.

Hope this helps, I agree there is probably a case for improved documentation to explain this.

Cheers

Paul

@PhilippMatthes
Author

PhilippMatthes commented Aug 25, 2021

Thanks for your response @kenders2000!

I have a follow-up question. I was able to create a batch-compatible TFLite classification model as follows:

inputs = keras.layers.Input(input_shape)
x = kapre.STFTTflite(n_fft=1024, hop_length=512, pad_begin=True)(inputs)
x = kapre.MagnitudeTflite()(x)
x = kapre.MagnitudeToDecibel()(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(n_outputs, activation='softmax')(x)
model = keras.models.Model(inputs, x)

With the resulting model.summary():

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_8 (InputLayer)         [(None, 2048, 1)]         0         
_________________________________________________________________
stft_tflite_10 (STFTTflite)  (1, 4, 513, 1, 2)         0         
_________________________________________________________________
magnitude_tflite_3 (Magnitud (1, 4, 513, 1)            0         
_________________________________________________________________
magnitude_to_decibel_3 (Magn (1, 4, 513, 1)            0         
_________________________________________________________________
flatten_1 (Flatten)          (1, 2052)                 0         
_________________________________________________________________
dense_1 (Dense)              (1, 9)                    18477     
=================================================================
Total params: 18,477
Trainable params: 18,477
Non-trainable params: 0

Do you expect this model to run into seg-faults with inference on mobile devices?

Note: this is just a minimal example model.

@kenders2000
Contributor

Hi,

I can see in the model summary that while the input batch size is None, the STFTTflite layers still have a batch size of 1. So while I would expect this model to convert and run fine (when you provide a batch size of one), you would still need to resize the input dimension of the resulting tflite file to a batch size of 1.

E.g. using resize_tensor_input()
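A sketch of how that might look with the TFLite Python interpreter. The model path and signal shape are placeholders, not values from this thread:

```python
# Hypothetical helper: run one inference through a converted .tflite model,
# pinning the input batch dimension to 1 via resize_tensor_input().
def run_single_inference(model_path, signal):
    """`signal` is an unbatched array, e.g. shape (2048, 1)."""
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    inp = interpreter.get_input_details()[0]
    out = interpreter.get_output_details()[0]
    # Fix the batch dimension at 1 to match the STFTTflite layer.
    interpreter.resize_tensor_input(inp["index"], [1] + list(signal.shape))
    interpreter.allocate_tensors()
    interpreter.set_tensor(inp["index"], signal[None, ...].astype(np.float32))
    interpreter.invoke()
    return interpreter.get_tensor(out["index"])
```

Batches larger than one would then be handled by looping over this call, one example at a time.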

Cheers

Paul

@PhilippMatthes
Author

I see, thanks!

@keunwoochoi
Owner

Thank you so much, everyone!
