
Run BirdNET weights in pytorch #29

Closed
acmiyaguchi opened this issue Feb 6, 2023 · 14 comments
@acmiyaguchi
Contributor

We're probably going to settle on pytorch for this project, just based on prior experience. However, two of the models that we want to use, BirdNET and Bird MixIT, are written in tensorflow.

For BirdNET, the weights can be found under the vendor directory. Make sure that you check out all the submodules:

git submodule init
git submodule update

See this directory: https://github.com/kahst/BirdNET-Analyzer/tree/cb3707813f4465922823bcfba31358f6c5d0c370/checkpoints

We might be able to do this by exporting to onnx. I'd recommend using https://github.com/microsoft/MMdnn to transform the layers (I haven't used the tool myself).

I'd start off by taking a look at audio.py and model.py and trying to reproduce embeddings on an audio file using tensorflow directly. Afterwards, convert the weights and modify the workflow for pytorch.
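
A minimal sketch of that first step, assuming the FP32 tflite checkpoint under the vendor directory and that the model takes raw 3-second chunks; the path and input handling below are assumptions to verify against the checkpoint.

# Hedged sketch: run the BirdNET tflite checkpoint with tf.lite.Interpreter
# to confirm we can produce outputs before worrying about conversion.
# The checkpoint path and the dummy input are assumptions.
import numpy as np
import tensorflow as tf

MODEL_PATH = "vendor/BirdNET-Analyzer/checkpoints/V2.2/BirdNET_GLOBAL_3K_V2.2_Model_FP32.tflite"

interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Dummy chunk with the model's expected input shape; replace with a real
# 3-second clip loaded via audio.py.
chunk = np.zeros(input_details[0]["shape"], dtype=np.float32)

interpreter.set_tensor(input_details[0]["index"], chunk)
interpreter.invoke()
output = interpreter.get_tensor(output_details[0]["index"])
print(output.shape)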

@zemm16
Contributor

zemm16 commented Feb 13, 2023

Tried converting the tflite model to onnx, but the type of model was not supported. Getting the following error:

"Unsupported TFLite OP: 89 REDUCE_MIN!"

Not entirely sure exactly what operators are in tensorflow, but apparently onnx only supports a certain number of them (and it's lower than 89).

I initially tried the tflite model since that is what is used in the examples from BirdNET-Analyzer, but I am going to try other strategies. I am definitely going to look into MMdnn next.

@acmiyaguchi
Contributor Author

Can you try converting them into pytorch directly instead of onnx? Also, what was the full command that you ran? IIUC, there are both the tflite weights (optimized for cpu/embedded use) and the full model weights for gpu training. I'm not entirely familiar with tensorflow, but I think the latter might be the saved_model.pb.

@zemm16
Contributor

zemm16 commented Feb 13, 2023

Yeah, there are both tflite and .pb weights. I am going to try working with the .pb weights today. There does not seem to be a way to convert directly from tensorflow to pytorch; it seems you have to convert to onnx first and then to pytorch. I am going to look into MMdnn today.

@acmiyaguchi
Contributor Author

It looks like mmdnn can convert between any of its supported formats as long as the model can be represented in its intermediate format. I found this command in its README that converts a tensorflow checkpoint into a pytorch model:

$ mmconvert -sf tensorflow -in imagenet_resnet_v2_152.ckpt.meta -iw imagenet_resnet_v2_152.ckpt --dstNodeName MMdnn_Output -df pytorch -om tf_resnet_to_pth.pth

@zemm16
Contributor

zemm16 commented Feb 20, 2023

Some updates. MMdnn has not been updated to work with tensorflow 2.0 (the latest version it supports is 1.15), so I think there is some incompatibility between the saved models in BirdNET-Analyzer and MMdnn. In addition, I have been working on loading the model and looking at the architecture.
My plan was to manipulate the model so that it would run correctly in MMdnn. However, the loaded model seems empty?
The default examples run on the .tflite saved models, and when running analyze.py to analyze the audio file they work fine.
When I switch to loading the .pb model into the script, however, there is no predict method associated with the loaded model.
Going to look into this a little more and see if I can find any answers.
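
For what it's worth, a tensorflow 2.x SavedModel loaded with tf.saved_model.load is not restored as a Keras model, so there is usually no predict method; inference goes through the exported signatures. A minimal sketch to check what the checkpoint actually exposes (the path and signature name are assumptions):

# Hedged sketch: list the serving signatures of the SavedModel instead of
# looking for a Keras-style .predict. The path and signature name are
# assumptions to verify against the checkpoint.
import tensorflow as tf

model = tf.saved_model.load(
    "vendor/BirdNET-Analyzer/checkpoints/V2.2/BirdNET_GLOBAL_3K_V2.2_Model"
)
print(list(model.signatures.keys()))  # typically ["serving_default"]

infer = model.signatures["serving_default"]
print(infer.structured_input_signature)
print(infer.structured_outputs)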

@zemm16
Contributor

zemm16 commented Feb 21, 2023

Current issues with mmdnn are the following:

mmconvert \
  -sf tensorflow \
  -iw vendor/BirdNET-Analyzer/checkpoints/V2.2/BirdNET_GLOBAL_3K_V2.2_Model/saved_model.pb \
  --inNodeName INPUT \
  --inputShape 144000 \
  --dstNodeName CLASS_DENSE_1 \
  -df pytorch \
  -om tf2torch_saved_model

gives the error "Error parsing message with type 'tensorflow.GraphDef'".

The saved_model.pb from tensorflow 2.x is a completely different format from the frozen graph used by tensorflow 1.x.

Running the onnx conversion, there are issues with certain operators used in tensorflow. This is pointed out on the tf2onnx github page:
https://github.com/onnx/tensorflow-onnx

I am pretty sure (though not entirely convinced) that this is what gives the following errors when trying to convert both the tflite file and the .pb file:

python -m tf2onnx.convert \
  --saved-model ../vendor/BirdNET-Analyzer/checkpoints/V2.2/BirdNET_GLOBAL_3K_V2.2_Model/ \
  --output data/logs/BirdNet0 \
  --opset 17

gives the error:

ValueError: make_sure failure: Current implementation of RFFT or FFT only allows ComplexAbs as consumer not {'Cast'}

whereas

python -m tf2onnx.convert \
  --opset 17 \
  --tflite ../vendor/BirdNET-Analyzer/checkpoints/V2.2/BirdNET_GLOBAL_3K_V2.2_Model_FP32.tflite \
  --output model.onnx

gives the error

ValueError: make_sure failure: Current implementation of RFFT2D only allows ComplexAbs as consumer not {'Squeeze'}

@acmiyaguchi
Contributor Author

Just edited the commands for formatting.

@acmiyaguchi
Contributor Author

acmiyaguchi commented Feb 21, 2023

I took a gander at this -- it looks like we might want to build some of the tensorflow tools here: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#inspecting-graphs

In particular, I'm interested in getting a dump of the graph and taking a look at some of the operations. RFFT2D is presumably a 2-dimensional fast fourier transform over real-valued input, but I'm not entirely sure what consumers are in this context. https://www.tensorflow.org/api_docs/python/tf/signal/rfft2d

I took a look in this file for the error code: https://github.com/onnx/tensorflow-onnx/blob/ff139a41e0b34b0312f179ae66c2f41034ec2c72/tf2onnx/onnx_opset/signal.py#L116-L121

Not entirely sure what modifications are necessary in order to transform the graph so we can use it in onnx, but having the ability to inspect the graph is the first step. Some of those graph transformations ("freezing the graph") might be useful too. I think the crux of the solution will be either patching onnx or writing a graph transform of some kind so the model is compatible with onnx.
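
A rough sketch of the inspection step in python, without building the graph_transforms tool: load the SavedModel, freeze the serving signature with convert_variables_to_constants_v2, tally the op types, and list what consumes the RFFT nodes. The path and signature name are assumptions.

# Hedged sketch: freeze the serving signature and tally the op types to see
# where RFFT2D sits and which nodes consume its output. Paths and the
# signature name are assumptions.
from collections import Counter

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

model = tf.saved_model.load(
    "vendor/BirdNET-Analyzer/checkpoints/V2.2/BirdNET_GLOBAL_3K_V2.2_Model"
)
frozen = convert_variables_to_constants_v2(model.signatures["serving_default"])
graph_def = frozen.graph.as_graph_def()

# Count every op type in the frozen graph.
print(Counter(node.op for node in graph_def.node))

# "Consumers" are just the downstream nodes that take an RFFT output as input.
ops_by_name = {node.name: node.op for node in graph_def.node}
for node in graph_def.node:
    for inp in node.input:
        src = inp.split(":")[0].lstrip("^")
        if ops_by_name.get(src) in ("RFFT", "RFFT2D"):
            print(f"{node.op} ({node.name}) consumes {src}")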

@acmiyaguchi
Contributor Author

acmiyaguchi commented Feb 26, 2023

@zemm16 managed to get the model running on onnx via #51

See this notebook for details on how you can reproduce the results: https://github.com/dsgt-birdclef/birdclef-2023/blob/main/notebooks/acm-20230225-00-tensorflow-frozen-graph.ipynb

It might be useful to start off directly with the onnx model, which can be found here: gs://birdclef-2023/data/models/birdnet-onnx-v1.onnx

There are a couple of paths from here:

  • We might be able to convert directly to pytorch with https://github.com/ENOT-AutoDL/onnx2torch (a sketch is below).
  • We can try the mmdnn route again, this time using the optimized, frozen graph that's in the same directory as the onnx model.

We should probably have a notebook that shows off the workflow of using pytorch to generate an embedding and also to label a section of audio.
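
A rough sketch of the onnx2torch path, assuming the model has been copied locally (e.g. gsutil cp gs://birdclef-2023/data/models/birdnet-onnx-v1.onnx .) and that the converter handles every node; the (1, 144000) input shape is an assumption.

# Hedged sketch: convert the exported onnx graph to pytorch with onnx2torch
# and push a dummy 3-second chunk through it. The local file name and the
# (1, 144000) input shape are assumptions.
import numpy as np
import torch
from onnx2torch import convert

torch_model = convert("birdnet-onnx-v1.onnx")
torch_model.eval()

chunk = torch.from_numpy(np.zeros((1, 144000), dtype=np.float32))
with torch.no_grad():
    outputs = torch_model(chunk)
# Depending on the export, this may be a single tensor or a tuple of tensors.
print(outputs)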

@zemm16
Contributor

zemm16 commented Feb 26, 2023

This is awesome. I will check it out tomorrow. I was discussing onnx with some people the other day and was told that converting from onnx to pytorch is much easier than tf to onnx, so we might not have many issues. Fingers crossed.

@zemm16
Contributor

zemm16 commented Feb 27, 2023

I need to review some of the links @acmiyaguchi provided so I fully understand how the model conversion works, but it seems onnx2torch is not working because at some point the converter can't find nodes correctly in the onnx model. I reproduced the error in emm-20230226-00-onnx2pytorch.ipynb.

@acmiyaguchi
Contributor Author

Linking #52

@acmiyaguchi
Contributor Author

Another thing to consider -- maybe having the onnx model is enough, and we could move birdnet inference up into the data loader layer (pre-processing) instead of treating it as part of the model? Trying to fit a 144k float32 vector into GPU memory is going to be memory intensive, so it might actually be a good idea to keep it on the CPU?
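
A rough sketch of that idea, assuming onnxruntime on the CPU inside a pytorch Dataset; the class name, model path, and which output holds the embedding are assumptions.

# Hedged sketch: run BirdNET with onnxruntime on the CPU inside the Dataset so
# only the (much smaller) embedding ever reaches the GPU. The model path,
# tensor names, and which output is the embedding are assumptions.
import numpy as np
import onnxruntime as ort
import torch
from torch.utils.data import Dataset


class BirdnetEmbeddingDataset(Dataset):
    def __init__(self, chunks: np.ndarray, model_path: str = "birdnet-onnx-v1.onnx"):
        # chunks: (n, 144000) array of 3-second clips at 48 kHz.
        self.chunks = chunks
        self.session = ort.InferenceSession(
            model_path, providers=["CPUExecutionProvider"]
        )
        self.input_name = self.session.get_inputs()[0].name

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        chunk = self.chunks[idx : idx + 1].astype(np.float32)
        # run() returns one array per model output; which index is the
        # embedding (vs. the class logits) needs to be confirmed.
        embedding = self.session.run(None, {self.input_name: chunk})[0]
        return torch.from_numpy(embedding.squeeze(0))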

@zemm16
Contributor

zemm16 commented Feb 28, 2023

Pursuing building the onnx model into the data loader to extract embeddings.
