SpeechT5 ONNX support #1404

Merged: 18 commits into huggingface:main, Oct 18, 2023
Conversation

fxmarty (Collaborator) commented Sep 21, 2023

This PR adds support for exporting SpeechT5 to ONNX.

fxmarty (Collaborator, Author) commented Sep 21, 2023

Hi @xenova, a long-awaited one =) This PR is still missing tests, documentation, and KV cache support, but it is already in a good state. I'll finish it next week. For now I have only implemented the text-to-speech task, following transformers' generate_speech.

Working version: optimum-cli export onnx --model microsoft/speecht5_tts speecht5_onnx --model-kwargs '{"vocoder": "microsoft/speecht5_hifigan"}'
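
The export should also be reachable programmatically (a sketch, assuming main_export forwards model_kwargs the same way the --model-kwargs flag does):

from optimum.exporters.onnx import main_export

# Programmatic equivalent of the CLI command above.
main_export(
    model_name_or_path="microsoft/speecht5_tts",
    output="speecht5_onnx",
    model_kwargs={"vocoder": "microsoft/speecht5_hifigan"},
)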

Also left to do: aligning the -with-past and variant arguments.

xenova (Contributor) commented Sep 21, 2023

Wow, this is amazing, thanks so much @fxmarty! I've uploaded my model files here. I'll test it in transformers.js, and I'll update those files when the other options are available. I don't suppose you have any Python code I can use for testing, something similar to this?
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForCausalLM

# ipt-350m is a decoder-only model, so ORTModelForCausalLM matches the text-generation task
session = ORTModelForCausalLM.from_pretrained('Xenova/ipt-350m', subfolder='onnx')
tokenizer = AutoTokenizer.from_pretrained('Xenova/ipt-350m')

generator_ort = pipeline(
    task="text-generation",
    model=session,
    tokenizer=tokenizer,
)

generator_ort('La nostra azienda')
# [{'generated_text': "La nostra azienda è specializzata nella vendita di prodotti per l'igiene orale e per la salute."}]

Or will ORTModelForTextToWaveform and/or ORTModelForTextToSpectrogram come later in this PR?

The speecht5 docs have a nice example here too.
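
For reference, that docs example boils down to roughly this PyTorch snippet (a sketch of the transformers API; the zero speaker embeddings are a placeholder I substituted, real usage loads x-vectors):

import torch
import soundfile as sf
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
speaker_embeddings = torch.zeros((1, 512))  # placeholder; real voices use x-vector embeddings

# generate_speech runs the autoregressive spectrogram loop and the vocoder in one call
speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
sf.write("speech.wav", speech.numpy(), samplerate=16000)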


Also, I had to downgrade to onnxruntime==1.15.1, since 1.16.0 gives this error:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/transformers.js/scripts/convert.py", line 18, in <module>
    from onnxruntime.quantization import (
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/__init__.py", line 1, in <module>
    from .calibrate import (  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/calibrate.py", line 21, in <module>
    from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distribution
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/quantization/quant_utils.py", line 115, in <module>
    onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn,
AttributeError: FLOAT8E4M3FN

I assume this is because I had onnx<1.14 installed, but just posting here in case.
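
If that is the cause, a quick check along these lines should confirm it (assuming, as the traceback suggests, that the float8 types only landed in onnx 1.14):

import onnx
from packaging import version

# onnxruntime 1.16's quantization module references TensorProto.FLOAT8E4M3FN,
# which older onnx releases do not define, hence the AttributeError above.
assert version.parse(onnx.__version__) >= version.parse("1.14.0"), (
    f"onnx {onnx.__version__} predates the float8 types onnxruntime 1.16 expects"
)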

fxmarty (Collaborator, Author) commented Sep 26, 2023

I'll wrap up this PR and add a python example :)

fxmarty (Collaborator, Author) commented Sep 26, 2023

@xenova something like this (not optimized at all). Does that work for you?

import onnxruntime as ort
import numpy as np
import soundfile as sf
from transformers import SpeechT5Processor

encoder_path = "/path/to/encoder_model.onnx"
decoder_path = "/path/to/decoder_model_merged.onnx"
postnet_and_vocoder_path = "/path/to/decoder_postnet_and_vocoder.onnx"

encoder = ort.InferenceSession(encoder_path, providers=["CPUExecutionProvider"])
decoder = ort.InferenceSession(decoder_path, providers=["CPUExecutionProvider"])
postnet_and_vocoder = ort.InferenceSession(postnet_and_vocoder_path, providers=["CPUExecutionProvider"])

# First decoder call: pass zero-length past key/values (sequence dim 0) as
# placeholders; the merged decoder takes its no-cache branch on this step.
def add_fake_pkv(inputs):
    shape = (1, 12, 0, 64)
    for i in range(6):
        inputs[f"past_key_values.{i}.encoder.key"] = np.zeros(shape).astype(np.float32)
        inputs[f"past_key_values.{i}.encoder.value"] = np.zeros(shape).astype(np.float32)
        inputs[f"past_key_values.{i}.decoder.key"] = np.zeros(shape).astype(np.float32)
        inputs[f"past_key_values.{i}.decoder.value"] = np.zeros(shape).astype(np.float32)
    return inputs

# Later decoder calls: reuse the cross-attention KV computed on the first pass
# and the self-attention KV returned by the previous step.
def add_real_pkv(inputs, previous_outputs, cross_attention_pkv):
    for i in range(6):
        inputs[f"past_key_values.{i}.encoder.key"] = cross_attention_pkv[f"present.{i}.encoder.key"]
        inputs[f"past_key_values.{i}.encoder.value"] = cross_attention_pkv[f"present.{i}.encoder.value"]
        inputs[f"past_key_values.{i}.decoder.key"] = previous_outputs[f"present.{i}.decoder.key"]
        inputs[f"past_key_values.{i}.decoder.value"] = previous_outputs[f"present.{i}.decoder.value"]
    return inputs

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

inputs = processor(text="Hello, my dog is cute", return_tensors="np")

inp = {
    "input_ids": inputs["input_ids"]
}

outputs = encoder.run(None, inp)
outputs = {output_key.name: outputs[idx] for idx, output_key in enumerate(encoder.get_outputs())}

encoder_last_hidden_state = outputs["encoder_outputs"]
encoder_attention_mask = outputs["encoder_attention_mask"]

minlenratio = 0.0
maxlenratio = 20.0
reduction_factor = 2
threshold = 0.5
num_mel_bins = 80

maxlen = int(encoder_last_hidden_state.shape[1] * maxlenratio / reduction_factor)
minlen = int(encoder_last_hidden_state.shape[1] * minlenratio / reduction_factor)

spectrogram = []
cross_attentions = []
past_key_values = None
idx = 0
cross_attention_pkv = None
use_cache_branch = False

# Zero speaker embeddings as a placeholder; use real x-vectors for an actual voice.
speaker_embeddings = np.zeros((1, 512)).astype(np.float32)

while True:
    idx += 1

    decoder_inputs = {}
    decoder_inputs["use_cache_branch"] = np.array([use_cache_branch])
    decoder_inputs["encoder_attention_mask"] = encoder_attention_mask
    decoder_inputs["speaker_embeddings"] = speaker_embeddings

    if not use_cache_branch:
        decoder_inputs = add_fake_pkv(decoder_inputs)
        decoder_inputs["output_sequence"] = np.zeros((1, 1, num_mel_bins)).astype(np.float32)
        use_cache_branch = True
        decoder_inputs["encoder_hidden_states"] = encoder_last_hidden_state
    else:
        decoder_inputs = add_real_pkv(decoder_inputs, decoder_outputs, cross_attention_pkv)
        decoder_inputs["output_sequence"] = decoder_outputs["output_sequence_out"]
        decoder_inputs["encoder_hidden_states"] = np.zeros((1, 0, 768)).astype(np.float32)  # useless when cross-attention KV has already been computed

    decoder_outputs = decoder.run(None, decoder_inputs)
    decoder_outputs = {output_key.name: decoder_outputs[idx] for idx, output_key in enumerate(decoder.get_outputs())}

    if idx == 1:  # i.e. use_cache_branch = False
        cross_attention_pkv = {key: val for key, val in decoder_outputs.items() if ("encoder" in key and "present" in key)}

    prob = decoder_outputs["prob"]
    spectrum = decoder_outputs["spectrum"]

    spectrogram.append(spectrum)
    
    print("prob", prob)

    # Finished when stop token or maximum length is reached.
    if idx >= minlen and (int(sum(prob >= threshold)) > 0 or idx >= maxlen):
        print("len spectrogram", len(spectrogram))
        spectrogram = np.concatenate(spectrogram)
        vocoder_output = postnet_and_vocoder.run(None, {"spectrogram": spectrogram})
        break

sf.write("speech.wav", vocoder_output[0], samplerate=16000)
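
A note on the design above: the use_cache_branch flag lets a single merged decoder graph serve both the first step (zero-length past key/values as placeholders, real encoder_hidden_states) and every later step (cross-attention KV cached from step one, self-attention KV from the previous step, and an empty encoder_hidden_states that is ignored). The stop condition mirrors generate_speech: decode until the stop probability crosses the threshold or maxlen is reached.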

fxmarty marked this pull request as ready for review on October 5, 2023, 14:09.
xenova (Contributor) left a comment

Works with transformers.js! 🚀 xenova/transformers.js#345

fxmarty (Collaborator, Author) commented Oct 6, 2023

@echarlaix would you prefer to merge the decoders PR first? I expect some conflicts between the two.

fxmarty (Collaborator, Author) commented Oct 16, 2023

@echarlaix WDYT?

echarlaix (Collaborator) commented

> @echarlaix would you prefer to merge the decoders PR first? I expect some conflicts between the two.

Yes, that would be great, thanks for letting me know. I think we can merge the decoder PR first, cc @michaelbenayoun.

echarlaix (Collaborator) left a comment

Super cool, thanks @fxmarty!

optimum/exporters/onnx/model_configs.py (outdated, resolved)
-    # Attempt to merge only if the decoder was exported without/with past
-    if self.use_past is True and len(models_and_onnx_configs) == 3:
+    # Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task
+    if len(onnx_files_subpaths) >= 3 and self.use_past is True or self.variant == "with-past":
echarlaix (Collaborator):

Why do we need to check self.variant?

fxmarty (Collaborator, Author):

I'm not sure. I'll need to double check.

fxmarty merged commit 554a83a into huggingface:main on Oct 18, 2023. 65 of 68 checks passed.
baskrahmer (Contributor) left a comment

Accidentally clicked review 😝; I meant to just submit a comment.

)
model_type = config.model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
    custom_architecture = True
baskrahmer (Contributor):

@fxmarty this line currently does nothing, since it is set to False again on line 381. Do you want to have a look?

fxmarty (Collaborator, Author):

Good catch! I'll fix it.
