Whisper: Add multilingual support, Updated with latest ORT #374

Merged · 12 commits · Jun 23, 2023
35 changes: 32 additions & 3 deletions examples/whisper/README.md
@@ -1,9 +1,9 @@
# Whisper optimization using ORT toolchain
This folder contains a sample use case of Olive to optimize a [Whisper](https://huggingface.co/openai/whisper-base) model using ONNXRuntime tools.
This folder contains a sample use case of Olive to optimize a [Whisper](https://huggingface.co/openai/whisper-tiny) model using ONNXRuntime tools.

Performs one of the following optimization pipelines, depending on target device and precision (the CPU INT8 pass list is excerpted after this list):
- CPU, FP32: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model -> Insert Beam Search Op -> Insert Pre/Post Processing Ops*
- CPU, INT8: *PyTorch Model -> Onnx Model -> Dynamic Quantized Onnx Model -> Insert Beam Search Op -> Insert Pre/Post Processing Ops*
- CPU, INT8: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model -> Dynamic Quantized Onnx Model -> Insert Beam Search Op -> Insert Pre/Post Processing Ops*
- GPU, FP32: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model -> Insert Beam Search Op -> Insert Pre/Post Processing Ops*
- GPU, FP16: *PyTorch Model -> Onnx Model -> Transformers Optimized Onnx Model -> Mixed Precision Model -> Insert Beam Search Op -> Insert Pre/Post Processing Ops*
- GPU, INT8: *PyTorch Model -> Onnx Model -> Dynamic Quantized Onnx Model -> Insert Beam Search Op -> Insert Pre/Post Processing Ops*
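
Each pipeline maps onto a pass list in `prepare_whisper_configs.py`. As a point of reference, the CPU INT8 entry, as updated by this PR, is excerpted from the `SUPPORTED_WORKFLOWS` table that appears later in this diff:

```python
# Excerpt from SUPPORTED_WORKFLOWS in prepare_whisper_configs.py (this PR adds
# transformers_optimization before dynamic quantization for CPU INT8).
SUPPORTED_WORKFLOWS = {
    ("cpu", "int8"): [
        "conversion",
        "transformers_optimization",
        "onnx_dynamic_quantization",
        "insert_beam_search",
        "prepost",
    ],
}
```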
@@ -25,7 +25,7 @@ python -m pip install -r requirements.txt

### Prepare workflow config json
```
python prepare_whisper_configs.py [--no_audio_decoder]
python prepare_whisper_configs.py [--no_audio_decoder] [--multilingual]
```

`--no_audio_decoder` is optional. If not provided, the audio decoder is used in the preprocessing ops.
@@ -36,6 +36,35 @@ python prepare_whisper_configs.py [--no_audio_decoder]
python -m pip install librosa
```

`--multilingual` is optional. If provided, the produced model will support multiple languages, selected through the `decoder_input_ids` input.

**Note:** Only supported in ONNXRuntime 1.16.0+, which has not been released yet. ONNXRuntime must be built from source at or after commit https://github.com/microsoft/onnxruntime/commit/4b69226fca914753844a3291818ce23ac2f00d8c.
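
A minimal sketch of how to verify that the installed ONNX Runtime build meets this requirement, mirroring the version gate this PR adds to `prepare_whisper_configs.py`:

```python
# Sketch: check the installed ONNX Runtime version before requesting --multilingual.
from onnxruntime import __version__ as OrtVersion
from packaging import version

if version.parse(OrtVersion) < version.parse("1.16.0"):
    raise ValueError("Multi-lingual support is only supported in ORT >= 1.16.0")
```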

**Example of decoder_input_ids:**
```python
import numpy as np
from transformers import AutoConfig, AutoProcessor


model = "openai/whisper-tiny"
config = AutoConfig.from_pretrained(model)
processor = AutoProcessor.from_pretrained(model)

# English transcription
forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
# forced_decoder_ids is of the format [(1, 50259), (2, 50359), (3, 50363)] and needs to be
# of the format [50258, 50259, 50359, 50363] where 50258 is the start token id
forced_decoder_ids = [config.decoder_start_token_id] + list(map(lambda token: token[1], forced_decoder_ids))

# If you don't want to provide specific decoder input ids or you want
# Whisper to predict the output language and task, you can set
# forced_decoder_ids = [config.decoder_start_token_id]
# [50258]

# decoder input ids
decoder_input_ids = np.array([forced_decoder_ids], dtype=np.int32)
```
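
Continuing from the snippet above, the same pattern selects other languages or tasks; for example, French transcription (the language choice is illustrative, any language the processor supports works the same way):

```python
# French transcription (continues from the example above: processor, config, np in scope)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="transcribe")
decoder_input_ids = np.array(
    [[config.decoder_start_token_id] + [token_id for _, token_id in forced_decoder_ids]],
    dtype=np.int32,
)
```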

## Run the config to optimize the model
First, install required packages according to passes.
```bash
...
```
20 changes: 9 additions & 11 deletions examples/whisper/code/user_script.py
@@ -10,19 +10,18 @@


def get_encoder_decoder_init():
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
return WhisperEncoderDecoderInit(
model,
model,
None,
model.config,
decoder_start_token_id=None,
)


def get_decoder():
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
return WhisperDecoder(model, None, model.config)
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
return WhisperDecoder(model, model.config)


def get_encdec_io_config():
@@ -47,8 +46,8 @@ def get_encdec_io_config():
input_names = ["encoder_input_ids"]

# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference.
# We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
sequence_length = "1"
# We use a workaround here: first use dim_param str(model.config.encoder_attention_heads) for num_heads,
# and later change to dim_value.
num_heads = str(model.config.encoder_attention_heads)
hidden_size = str(model.config.d_model)
head_size = str(model.config.d_model // model.config.encoder_attention_heads)
@@ -61,15 +60,15 @@
},
"logits": {
0: "batch_size",
1: sequence_length,
1: "decode_sequence_length",
},
}

if use_decoder_input_ids:
input_names.append("decoder_input_ids")
dynamic_axes["decoder_input_ids"] = {
0: "batch_size",
1: sequence_length,
1: "decode_sequence_length",
}

for name in present_names:
@@ -85,15 +84,15 @@
dynamic_axes[name] = {
0: "batch_size",
1: num_heads,
2: sequence_length,
2: "decode_sequence_length",
3: head_size,
}

return {
"input_names": input_names,
"dynamic_axes": dynamic_axes,
"output_names": output_names,
"string_to_int_dim_params": [sequence_length, num_heads, hidden_size, head_size],
"string_to_int_dim_params": [num_heads, hidden_size, head_size],
}


@@ -114,7 +113,6 @@ def get_dec_io_config():

dynamic_axes = {
"input_ids": {0: "batch_size"},
"encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length / 2"},
"logits": {0: "batch_size", 1: "sequence_length"},
}

6 changes: 6 additions & 0 deletions examples/whisper/code/whisper_dataset.py
@@ -43,7 +43,13 @@ def __init__(self, data_dir: str, use_audio_decoder: bool = True):
"num_return_sequences": np.asarray([1], dtype=np.int32),
"length_penalty": np.asarray([1.0], dtype=np.float32),
"repetition_penalty": np.asarray([1.0], dtype=np.float32),
# attention_mask only used when version < 1.16.0
"attention_mask": np.zeros((1, self.N_MELS, self.N_FRAMES)).astype(np.int32),
# decoder_input_ids only used when version >= 1.16.0 and multilingual is True
# auto detect language and task
"decoder_input_ids": np.asarray([[50258]], dtype=np.int32),
# English, transcription
# "decoder_input_ids": np.asarray([[50258, 50259, 50359, 50363]], dtype=np.int32),
}
self.data.append(inputs)
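
The hard-coded token ids above can be reproduced from the Whisper processor; a hedged sketch mirroring the README example (assumes `transformers` is installed):

```python
# Reproduce the hard-coded decoder_input_ids used in WhisperDataset.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
# Returns [(1, 50259), (2, 50359), (3, 50363)] for English transcription.
prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
# 50258 is the decoder start token id; result: [50258, 50259, 50359, 50363]
decoder_input_ids = [50258] + [token_id for _, token_id in prompt_ids]
```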

9 changes: 3 additions & 6 deletions examples/whisper/code/whisper_decoder.py
@@ -10,20 +10,18 @@


class WhisperDecoderInit(torch.nn.Module):
"""A Whisper decoder with LM head to create initial past key values.
"""A Whisper decoder to create initial past key values.
This model is only called once during starting decoding.
"""

def __init__(
self,
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
config: WhisperConfig,
decoder_start_token_id: int = None,
):
super().__init__()
self.decoder = decoder
self.lm_head = lm_head
self.config = config
self.decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
@@ -52,12 +50,11 @@ def forward(


class WhisperDecoder(torch.nn.Module):
"""A Whisper decoder with LM head and past key values"""
"""A Whisper decoder and past key values"""

def __init__(self, decoder, lm_head, config):
def __init__(self, decoder, config):
super().__init__()
self.decoder = decoder
self.lm_head = lm_head
self.config = config

def forward(self, decoder_input_ids, *past):
5 changes: 2 additions & 3 deletions examples/whisper/code/whisper_encoder_decoder_init.py
@@ -18,14 +18,13 @@ def __init__(
self,
encoder: torch.nn.Module,
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
config: WhisperConfig,
decoder_start_token_id: Optional[int] = None,
):
super().__init__()
self.config = config
self.whisper_encoder = WhisperEncoder(encoder, config)
self.whisper_decoder_init = WhisperDecoderInit(decoder, lm_head, config, decoder_start_token_id)
self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id)

def forward(
self,
@@ -64,7 +63,7 @@ def create_dummy(
decoder_input_ids = None
if use_decoder_input_ids:
dtype = torch.int32 if use_int32_inputs else torch.int64
decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
decoder_input_ids = torch.ones((batch_size, 2), dtype=dtype, device=device) * config.decoder_start_token_id

return WhisperEncoderDecoderInitInputs(encoder_inputs.input_ids, decoder_input_ids)

25 changes: 24 additions & 1 deletion examples/whisper/prepare_whisper_configs.py
@@ -8,11 +8,19 @@
from pathlib import Path
from urllib import request

from onnxruntime import __version__ as OrtVersion
from packaging import version
from transformers import WhisperConfig

SUPPORTED_WORKFLOWS = {
("cpu", "fp32"): ["conversion", "transformers_optimization", "insert_beam_search", "prepost"],
("cpu", "int8"): ["conversion", "onnx_dynamic_quantization", "insert_beam_search", "prepost"],
("cpu", "int8"): [
"conversion",
"transformers_optimization",
"onnx_dynamic_quantization",
"insert_beam_search",
"prepost",
],
("gpu", "fp32"): ["conversion", "transformers_optimization", "insert_beam_search", "prepost"],
("gpu", "fp16"): ["conversion", "transformers_optimization", "mixed_precision", "insert_beam_search", "prepost"],
("gpu", "int8"): ["conversion", "onnx_dynamic_quantization", "insert_beam_search", "prepost"],
@@ -26,12 +34,24 @@ def get_args(raw_args):
action="store_true",
help="Don't use audio decoder in the model. Default: False",
)
parser.add_argument(
"--multilingual",
action="store_true",
help="Support using model for multiple languages. Only supported in ORT >= 1.16.0. Default: False",
)
return parser.parse_args(raw_args)


def main(raw_args=None):
args = get_args(raw_args)

# version check
version_1_16 = version.parse(OrtVersion) >= version.parse("1.16.0")

# multi-lingual support check
if args.multilingual and not version_1_16:
raise ValueError("Multi-lingual support is only supported in ORT >= 1.16.0")

# load template
template_json = json.load(open("whisper_template.json", "r"))

@@ -46,6 +66,9 @@ def main(raw_args=None):
template_json["passes"]["transformers_optimization"]["config"]["num_heads"] = whisper_config.encoder_attention_heads
template_json["passes"]["transformers_optimization"]["config"]["hidden_size"] = whisper_config.d_model

# update multi-lingual support
template_json["passes"]["insert_beam_search"]["config"]["use_forced_decoder_ids"] = args.multilingual

# download audio test data
test_audio_path = download_audio_test_data()
template_json["passes"]["prepost"]["config"]["tool_command_args"]["testdata_filepath"] = str(test_audio_path)
2 changes: 1 addition & 1 deletion examples/whisper/requirements.txt
@@ -1,5 +1,5 @@
onnx==1.13.1
onnxruntime==1.15.0
onnxruntime>=1.15.0
onnxruntime-extensions>=0.8.0
torch>=1.13.1
transformers>=4.23.1
51 changes: 26 additions & 25 deletions examples/whisper/test_transcription.py
@@ -4,14 +4,20 @@
# --------------------------------------------------------------------------
import argparse
import json
import shutil
import sys
import tempfile
from pathlib import Path

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import PyOrtFunction

from olive.evaluator.olive_evaluator import OnnxEvaluator
from olive.model import ONNXModel

sys.path.append(str(Path(__file__).parent / "code"))

from whisper_dataset import WhisperDataset # noqa: E402

# hard-coded audio hyperparameters
# copied from https://github.com/openai/whisper/blob/main/whisper/audio.py#L12
SAMPLE_RATE = 16000
@@ -53,33 +59,28 @@ def main(raw_args=None):

# load output model onnx
olive_model = ONNXModel(**output_model_json["config"])
model = PyOrtFunction.from_model(olive_model.model_path)

# load audio data
if not args.audio_path:
args.audio_path = Path(config["passes"]["prepost"]["config"]["tool_command_args"]["testdata_filepath"])
use_audio_decoder = config["passes"]["prepost"]["config"]["tool_command_args"]["use_audio_decoder"]
if use_audio_decoder:
with open(args.audio_path, "rb") as _f:
audio_blob = np.asarray(list(_f.read()), dtype=np.uint8)
else:
import librosa

audio_blob, _ = librosa.load(args.audio_path)

audio_blob = np.expand_dims(audio_blob, axis=0)

output_text = model(
audio_blob,
np.asarray([200], dtype=np.int32),
np.asarray([0], dtype=np.int32),
np.asarray([2], dtype=np.int32),
np.asarray([1], dtype=np.int32),
np.asarray([1.0], dtype=np.float32),
np.asarray([1.0], dtype=np.float32),
np.zeros((1, N_MELS, N_FRAMES)).astype(np.int32),
)
return output_text[0]

# temporary directory for storing audio file
temp_dir = tempfile.TemporaryDirectory()
temp_dir_path = Path(temp_dir.name)
temp_audio_path = temp_dir_path / Path(args.audio_path).name
shutil.copy(args.audio_path, temp_audio_path)

# dataset
dataset = WhisperDataset(temp_dir_path)

# create inference session
session = olive_model.prepare_session(None, "cpu")

# get output
input, _ = dataset[0]
input = OnnxEvaluator.format_input(input, olive_model.get_io_config())
output = session.run(None, input)
return output[0][0]


if __name__ == "__main__":
9 changes: 6 additions & 3 deletions examples/whisper/whisper_template.json
@@ -6,7 +6,7 @@
"script_dir": "code",
"hf_config": {
"model_class" : "WhisperForConditionalGeneration",
"model_name" : "openai/whisper-tiny.en",
"model_name" : "openai/whisper-tiny",
"components" : [
{
"name": "encoder_decoder_init",
@@ -80,14 +80,17 @@
}
},
"insert_beam_search" : {
"type" : "InsertBeamSearch"
"type" : "InsertBeamSearch",
"config": {
"use_forced_decoder_ids": "<place_holder>"
}
},
"prepost": {
"type": "AppendPrePostProcessingOps",
"config": {
"tool_command": "whisper",
"tool_command_args": {
"model_name": "openai/whisper-tiny.en",
"model_name": "openai/whisper-tiny",
"testdata_filepath": "<place_holder>",
"use_audio_decoder" : "<place_holder>"
}