
[WIP] Add CLAP onnx export for zero-shot-audio-classification #1552

Draft: wants to merge 3 commits into main
Conversation

@xenova (Contributor) commented Nov 26, 2023

What does this PR do?

Adds support for CLAP export for zero-shot-audio-classification. Still a WIP.

TODOs:

  • Fix fused exports (i.e., `optimum-cli export onnx -m laion/clap-htsat-unfused out` works, but `optimum-cli export onnx -m laion/clap-htsat-fused out` fails)
  • Write tests
  • Test w/ Transformers.js
  • Cleanup


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@xenova (Contributor Author) commented Dec 1, 2023

I'm facing an issue where, although validation passes (all outputs within the 1e-5 threshold), the output of the model differs when running with the pipeline function. Here's some code to reproduce:

import torch
from transformers import AutoTokenizer, AutoProcessor, pipeline
from optimum.onnxruntime import ORTModel

from transformers.models.clap.modeling_clap import ClapOutput

class ORTClapModel(ORTModel):
    def forward(self, *args, **kwargs):
        print(args, kwargs)
        feeds = {
            'input_ids': kwargs['input_ids'].numpy(),
            'attention_mask': kwargs['attention_mask'].numpy(),
            'input_features': kwargs['input_features'].numpy(),
        }
        out = self.model.run(None, feeds)

        return ClapOutput(**{
            "logits_per_audio": torch.from_numpy(out[0]).to(self.device),
            "logits_per_text": torch.from_numpy(out[1]).to(self.device),
            "text_embeds": torch.from_numpy(out[2]).to(self.device),
            "audio_embeds": torch.from_numpy(out[3]).to(self.device),
        })

model = ORTClapModel.from_pretrained("Xenova/clap-htsat-unfused", file_name='onnx/model.onnx') # onnx checkpoint
tokenizer = AutoTokenizer.from_pretrained("Xenova/clap-htsat-unfused")
processor = AutoProcessor.from_pretrained("Xenova/clap-htsat-unfused")

audio_url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/dog_barking.wav'

audio_classifier = pipeline(task="zero-shot-audio-classification", model=model, tokenizer=tokenizer, feature_extractor=processor.feature_extractor)
output = audio_classifier(audio_url, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
print(output)
# >>> [{'score': 0.6000569462776184, 'label': 'Sound of vaccum cleaner'}, {'score': 0.399943083524704, 'label': 'Sound of a dog'}]


pt_audio_classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")
output = pt_audio_classifier(audio_url, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
print(output)
# >>> [{'score': 0.9995301961898804, 'label': 'Sound of a dog'}, {'score': 0.0004698506381828338, 'label': 'Sound of vaccum cleaner'}]
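For reference, the pipeline's scores are a softmax over logits_per_audio, so even a modest difference in the raw logits gets amplified into very different score distributions like the two above. A minimal sketch of that amplification (assuming numpy; the logit values below are made up purely for illustration and are not taken from the models):

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical logits_per_audio for two candidate labels:
# a small absolute shift in logits yields a large shift in scores.
close_logits = np.array([0.2, -0.2])
far_logits = np.array([4.0, -3.6])

print(softmax(close_logits))  # roughly [0.60, 0.40]
print(softmax(far_logits))    # roughly [0.9995, 0.0005]
```

This is why a subtle preprocessing or export bug can pass elementwise validation on intermediate outputs yet still flip the pipeline's top label.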

Original model: https://huggingface.co/laion/clap-htsat-unfused
Converted model: https://huggingface.co/Xenova/clap-htsat-unfused

Any ideas @fxmarty?

Comment on lines +688 to +692
if self.normalized_config.model_type == 'clap':
    # TODO figure out what this value is for?
    # https://huggingface.co/laion/clap-htsat-fused uses 4
    num_channels = 1
    shape = [self.batch_size, num_channels, self.feature_size, self.num_mel_bins]

update: When fusion is enabled, num_channels becomes 4 due to stacking of the fbank features.
Also, self.feature_size is incorrect here; it should be self.nb_max_frames + 1. e.g., [1, 1, 1001, 64] is a valid shape.


if self.normalized_config.model_type in ['clap', 'clap_audio_model']:
    shape = [self.batch_size, 1, 1001, 64]
else:
    shape = [self.batch_size, self.feature_size, self.nb_max_frames]
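The shape logic above can be sanity-checked with a quick standalone sketch (a minimal example, assuming numpy; the values mirror the discussion above, where the frame count is nb_max_frames + 1 = 1001 and num_mel_bins = 64):

```python
import numpy as np

# Values taken from the discussion above (unfused CLAP checkpoint):
batch_size = 1
num_channels = 1       # becomes 4 when fusion stacks the fbank features
num_frames = 1000 + 1  # nb_max_frames + 1
num_mel_bins = 64

# Dummy input_features tensor with the expected CLAP audio input shape
shape = [batch_size, num_channels, num_frames, num_mel_bins]
dummy_input_features = np.zeros(shape, dtype=np.float32)
print(dummy_input_features.shape)  # (1, 1, 1001, 64)
```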

@xenova (Contributor Author) commented Dec 1, 2023

I believe I have located the source, which is the same as this issue, and is caused by this function.

xenova added a commit to xenova/transformers.js that referenced this pull request Dec 4, 2023
xenova added a commit to xenova/transformers.js that referenced this pull request Dec 5, 2023
…ctrogram Transformer (`audio-classification`) (#427)

* Add FFT unit tests

* Refactor maths.js and audio.js

* Refactor audio processors

* Add support for AST models

* Add another audio-classification example

* Add audio processing unit tests

* Implement `log_mel='dB'` in `spectrogram` function

* Add `ClapFeatureExtractor`

* Implement `ClapFeatureExtractor` unit tests

* Add support for `CLAP`

* Add `ZeroShotAudioClassificationPipeline`

* Add listed support for  `zero-shot-audio-classification` pipeline tag

* Cleanup

* `let` -> `const`

* Update `mel_filter_bank` unit test

* Add `'Xenova/tiny-random-ClapModel'`

* Add `ClapAudioModelWithProjection` and `ClapTextModelWithProjection`

* Move audio validation to helper function

* Optimize `mel_filter_bank` computation

-30ms

* Update mel filters unit test

* Cleanup

* Optimizations

* Fix jsdoc

* Optimizations

* Add WIP conversion scripts

Will be updated once huggingface/optimum#1552 is merged
@fxmarty (Collaborator) commented Dec 5, 2023

Good catch @xenova! If we can squeeze huggingface/transformers#27790 into the upcoming transformers release, that would be nice :)
