Add support ORT whisper #420
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks for working on this great feature @mht-sharma 🚀 !!!
The API looks very good to me and most of my comments are nits. Did you manage to resolve the issue with the generations not matching the transformers model?
SPEECH_SEQ2SEQ_ONNX_MODEL_DOCSTRING = r"""
    Arguments:
        input_features (`torch.FloatTensor`):
            Float values mel features extracted from the raw speech waveform.
Suggested change:
-            Float values mel features extracted from the raw speech waveform.
+            Mel features extracted from the raw speech waveform.
Done
@@ -503,14 +582,14 @@ class ORTEncoder:
        The ONNX Runtime inference session associated to the encoder.
    """

-    def __init__(self, session: onnxruntime.InferenceSession, device: torch.device):
+    def __init__(self, session: onnxruntime.InferenceSession, device: torch.device, main_input_name: str):
Is this a breaking change to the existing API (I'm not 100% sure)?
If yes, one option would be to use:
-    def __init__(self, session: onnxruntime.InferenceSession, device: torch.device, main_input_name: str):
+    def __init__(self, session: onnxruntime.InferenceSession, device: torch.device, main_input_name: str = "input_ids"):
This should not break anything, but I added a default value for the input name.
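To make the backward-compatibility point concrete, here is a minimal standalone sketch (class body and call sites are hypothetical stand-ins, not the actual optimum implementation) showing why a keyword default keeps the old two-argument call sites working:

```python
# Minimal reproduction: giving main_input_name a default means older call
# sites that never passed it keep constructing the encoder unchanged.
# (Sketch only; session/device handling is elided.)
class ORTEncoder:
    def __init__(self, session, device, main_input_name: str = "input_ids"):
        self.session = session
        self.device = device
        self.main_input_name = main_input_name

# Old-style call (before the new parameter existed) still works:
legacy = ORTEncoder(session=None, device=None)
# New-style call overrides the default for speech models:
speech = ORTEncoder(session=None, device=None, main_input_name="input_features")

print(legacy.main_input_name)  # input_ids
print(speech.main_input_name)  # input_features
```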
QQ - Do we run the CI on a GPU machine? The following test case tests the pipeline on GPU (L1148). However, as per my understanding the pipeline is not running on GPU and it should fail, or am I missing something?
Hey @mht-sharma, the CI for GPU is scheduled to run nightly.
        return ORTEncoderForSpeechSeq2Seq(session=encoder_session, device=device, main_input_name=self.main_input_name)

def get_encoder_onnx_config(encoder_config: PretrainedConfig) -> OnnxConfig:
    return SpeechSeq2SeqEncoderOnnxConfig(encoder_config, task="default")
Currently, the ORTModelForSpeechSeq2Seq class returns generic encoder/decoder ONNX configs; the same design is used for the ORTModelForSeq2SeqLM class.
AutoModelForSpeechSeq2Seq contains multiple speech model classes, and there may be some models that are not supported by the generic ONNX configs (need to confirm). Similar to the AutoClasses, we may need a structure that switches to the appropriate OnnxConfig based on the model type. Probably something we can work on after the refactor.
Since the whole ONNX part is also changing, I think it's fine to allow time for a refactor to support audio if needed
Thanks a lot for making this beast of a model work with optimum @mht-sharma!
Out of curiosity, is there any significant slowdown when running the ORT model on CPU/GPU?
If not, I think this PR is in great shape - the main question is whether we should merge + refactor for the new ONNX exporter in #403, or wait until that PR is merged (at the expense of delaying this feature)?
WDYT @michaelbenayoun @echarlaix ?
@lewtun I think it makes more sense to merge the
Force-pushed from 2d0a7ca to 325668f
The PR is now updated with the following changes.
Gently pinging team members for review. @lewtun @michaelbenayoun
# bind logits
output_shape, output_buffer = self.prepare_output_buffer(
    batch_size=input_features.size(0),
    sequence_length=input_features.size(2) // 2,
This way of getting the output shape seems too specific to the Whisper model and may not fit other SpeechSeq2Seq models. Is there a way to avoid giving output shapes? @JingyaHuang
Or, if this is the only way, should we rename the class to ORTModelForWhisperConditionalGeneration? This may lead to having different classes for each model type in the future. Probably something to discuss.
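For context on where the `// 2` comes from: Whisper's encoder front end applies two convolutions, the second of which uses stride 2, halving the frame axis. A quick sanity check with the standard conv-output-length formula (helper name is mine, not from the PR):

```python
def conv_out_len(n_frames: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    # Standard 1D convolution output-length formula:
    # floor((n + 2*padding - kernel) / stride) + 1
    return (n_frames + 2 * padding - kernel) // stride + 1

# Whisper pads/trims audio to 30 s = 3000 mel frames; after the stride-2 conv
# the encoder sequence length is 1500, which equals n_frames // 2 for any even
# frame count - hence sequence_length=input_features.size(2) // 2.
print(conv_out_len(3000))  # 1500
```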
There is a way to avoid giving the output shape -> bind the output with an OrtValue, which will be the case for custom tasks #447. But the flaw is that you then need to transfer ownership across frameworks, which is something we try to avoid. IMO, if you can infer the output shape, you should bind it directly with a torch tensor. cc @philschmid
This could then be handled case by case in ORTModelForSpeechSeq2Seq, which can return the appropriate ORTModelEncoder based on the model_type.
Currently AutoModelForSpeechSeq2Seq contains three model types: whisper, speech_to_text and speech_encoder_decoder, so a simple if/else can be a cleaner approach. WDYT @JingyaHuang @lewtun
class ORTModelForSpeechSeq2Seq:
...
...
def _initialize_encoder(
self,
session: onnxruntime.InferenceSession,
config: transformers.PretrainedConfig,
device: torch.device,
use_io_binding: bool = True,
) -> ORTEncoderForSpeechSeq2Seq:
if config.model_type == "whisper":
return ORTEncoderForWhisper(...)
else:
return ...
Second this. We could have a general model_type -> ORTEncoder class mapping, and ORTModelForSpeechSeq2Seq would use it:
class ORTModelForSpeechSeq2Seq:
...
...
def _initialize_encoder(
self,
session: onnxruntime.InferenceSession,
config: transformers.PretrainedConfig,
device: torch.device,
use_io_binding: bool = True,
) -> ORTEncoderForSpeechSeq2Seq:
        return _MODEL_TYPE_TO_ORTENCODER.get(config.model_type, ORTEncoderForSpeechSeq2Seq)(...)
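A self-contained version of this mapping idea might look like the following (class bodies and the generic fallback are illustrative stubs, not the actual optimum implementation):

```python
class ORTEncoderForSpeechSeq2Seq:
    """Generic fallback encoder wrapper (stub for illustration)."""

class ORTEncoderForWhisper(ORTEncoderForSpeechSeq2Seq):
    """Whisper-specific encoder wrapper (stub for illustration)."""

# model_type -> encoder class; any unregistered type falls back to the
# generic encoder, so new model types work without code changes here.
_MODEL_TYPE_TO_ORTENCODER = {"whisper": ORTEncoderForWhisper}

def encoder_class_for(model_type: str) -> type:
    return _MODEL_TYPE_TO_ORTENCODER.get(model_type, ORTEncoderForSpeechSeq2Seq)

print(encoder_class_for("whisper").__name__)         # ORTEncoderForWhisper
print(encoder_class_for("speech_to_text").__name__)  # ORTEncoderForSpeechSeq2Seq
```

The dispatch lives in one dictionary, so adding a specialized encoder later is a one-line registration rather than another if/else branch.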
Updated
@@ -120,6 +120,7 @@ class OnnxConfig(ExportConfig, ABC):
     "seq2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
     "sequence-classification": OrderedDict({"logits": {0: "batch_size"}}),
     "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+    "speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
Is it the "official" name? We could take:
- automatic-speech-recognition, to match the pipelines
- speech2text
@lewtun WDYT?
I think the idea was to partially align with the underlying autoclass, but I agree automatic-speech-recognition would be more intuitive.
In general (not for this PR), I think we should take the opportunity to align more closely with the Hub tasks, e.g. seq2seq-lm could also be text2text-generation, right?
Alright, then I guess we can keep speech2seq-lm for now since the other names are aligned with the AutoClass, and maybe change that (if needed) for all the tasks in another PR.
Force-pushed from 3d97bba to af7ddd3
@@ -206,6 +207,17 @@ def is_torch_support_available(self) -> bool:
             return TORCH_VERSION >= self.MIN_TORCH_VERSION
         return False

+    @property
+    def torch_to_onnx_input_map(self) -> Mapping[str, str]:
I would make it clear that it is needed when the dummy input names and the exported input names do not match.
Updated the docstring
optimum/exporters/onnx/base.py
Outdated
@@ -229,6 +243,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
         # TODO: figure out a smart way of re-ordering potential nested structures.
         # to_insert = sorted(to_insert, key=lambda t: t[0])
         for name, dynamic_axes in to_insert:
+            name = torch_to_onnx_input_map[name] if name in torch_to_onnx_input_map else name
Suggested change:
-            name = torch_to_onnx_input_map[name] if name in torch_to_onnx_input_map else name
+            name = self.torch_to_onnx_input_map.get(name, name)
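To make the suggested `.get(name, name)` behavior concrete, here is a tiny standalone sketch (the map contents and helper name are hypothetical - in the PR the real mapping comes from each model's OnnxConfig):

```python
from collections import OrderedDict

def rename_ordered_inputs(ordered_inputs, torch_to_onnx_input_map):
    # dict.get(name, name) renames only inputs that appear in the map;
    # everything else keeps its original (torch-side) name, which is
    # exactly what the longer "if name in ..." expression did.
    return OrderedDict(
        (torch_to_onnx_input_map.get(name, name), axes)
        for name, axes in ordered_inputs.items()
    )

inputs = OrderedDict([("input_ids", {0: "batch"}), ("attention_mask", {0: "batch"})])
renamed = rename_ordered_inputs(inputs, {"input_ids": "decoder_input_ids"})
print(list(renamed))  # ['decoder_input_ids', 'attention_mask']
```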
* added whisper to exporters
* Removed redundant code
* Added io binding for ORTModelForSpeechSeq2Seq
Force-pushed from 63a255f to 14358a0
Please change this line to (add .from_pretrained)
Hello, I have seen this discussion, which shows how we can use the pipeline for audio longer than 30s and how to change the task to transcribe; the chunking works for ONNX too, but I couldn't change the task to transcribe in the pipeline configuration for ONNX. Thanks!
What does this PR do?
This PR enables the export of the Whisper model to ONNX.
To enable this new modality, I refactored the existing ORTModelForConditionalGeneration to add support for multimodal models. The PR has a dependency on transformers PR 19525, which integrates the ONNX config for the Whisper model and adds support for the audio preprocessor.
Usage
Using Transformers AutoModelForSpeechSeq2Seq
Using Optimum ORTModelForSpeechSeq2Seq
Before submitting