Musicgen ONNX export (text-conditional only) (#1779)
* WIP but need to work on encodec first

* musicgen onnx export

* better logs

* add tests

* rename audio_encoder_decode.onnx to encodec_decode.onnx

* fix num heads in pkv

* nits

* add build_delay_pattern_mask

* fix wrong hidden_size for cross attention pkv

* fix tests

* update doc
fxmarty committed Apr 10, 2024
1 parent 5ea14c1 commit 2f75b0d
Showing 12 changed files with 637 additions and 13 deletions.
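
With this change, a text-conditional Musicgen checkpoint can be exported end to end. A minimal sketch; the model id and output directory are illustrative, and the CLI line in the comment assumes the standard optimum-cli entry point:

from optimum.exporters.onnx import main_export

# Equivalent CLI (assumed): optimum-cli export onnx --model facebook/musicgen-small musicgen_onnx/
main_export("facebook/musicgen-small", output="musicgen_onnx")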
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -71,6 +71,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- MobileNet v2
- MPNet
- MT5
- Musicgen (text-conditional only)
- Nystromformer
- OWL-ViT
- Pegasus
7 changes: 7 additions & 0 deletions optimum/exporters/onnx/config.py
@@ -383,6 +383,13 @@ def __init__(
)

        self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config
        self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.encoder_num_attention_heads = (
            self._decoder_onnx_config._normalized_config.num_attention_heads
        )
        self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.decoder_num_attention_heads = (
            self._decoder_onnx_config._normalized_config.num_attention_heads
        )

        if isinstance(self._decoder_onnx_config, OnnxSeq2SeqConfigWithPast):
            self._past_key_values_generator = (
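
These overrides matter because the dummy past-key-value inputs must use the decoder's own head count for both self-attention and cross-attention caches (the commit messages "fix num heads in pkv" and "fix wrong hidden_size for cross attention pkv" refer to exactly this). A minimal shape sketch, with hypothetical dimensions:

# Hypothetical Musicgen-like dimensions, for illustration only.
batch, enc_seq, dec_seq = 2, 12, 8
decoder_heads, decoder_hidden = 16, 1024
head_dim = decoder_hidden // decoder_heads  # 64

# Self-attention cache: keys/values over the decoder's own sequence.
self_attn_kv = (batch, decoder_heads, dec_seq, head_dim)
# Cross-attention cache: encoder states projected into the decoder's hidden
# size, so head count and head_dim are the decoder's while the sequence
# length is the encoder's.
cross_attn_kv = (batch, decoder_heads, enc_seq, head_dim)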
1 change: 1 addition & 0 deletions optimum/exporters/onnx/constants.py
@@ -36,5 +36,6 @@

SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [
    "bart",
    "musicgen",
    "whisper",
]
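
Listing musicgen here keeps SDPA out of the traced graph. A sketch of the corresponding workaround when loading the model manually, assuming a transformers version that accepts attn_implementation:

from transformers import AutoModelForTextToWaveform

# Force eager attention so the exported graph avoids SDPA-specific ops.
model = AutoModelForTextToWaveform.from_pretrained(
    "facebook/musicgen-small", attn_implementation="eager"
)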
5 changes: 4 additions & 1 deletion optimum/exporters/onnx/convert.py
@@ -258,7 +258,7 @@ def _run_validation(

    model_kwargs = model_kwargs if model_kwargs is not None else {}

    logger.info(f"Validating ONNX model {onnx_model.as_posix()}...")
    logger.info(f"\nValidating ONNX model {onnx_model.as_posix()}...")

    if atol is None:
        atol = config.ATOL_FOR_VALIDATION
@@ -764,6 +764,9 @@ def export_models(
        output_path = output_dir / output_name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info(
            f"\n***** Exporting submodel {i + 1}/{len(models_and_onnx_configs)}: {submodel.__class__.__name__} *****"
        )
        outputs.append(
            export(
                model=submodel,
307 changes: 305 additions & 2 deletions optimum/exporters/onnx/model_configs.py

Large diffs are not rendered by default.

137 changes: 136 additions & 1 deletion optimum/exporters/onnx/model_patcher.py
@@ -254,7 +254,6 @@ def patched_forward(*args, **kwargs):
                        elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
                            # The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
                            filterd_outputs[name] = tuple([v[:2] for v in value])

            return filterd_outputs

        self.patched_forward = patched_forward
@@ -796,3 +795,139 @@ def patched_forward(input_ids, attention_mask, pixel_values):
return {"text_embeds": text_embeds, "image_embeds": image_embeds}

self.patched_forward = patched_forward


# Triu with a possibly dynamic `diagonal` argument, which is unfortunately not possible with torch.triu.
def triu_onnx(x, diagonal=0):
    l, w = x.shape
    arange_rows = torch.arange(l, device=x.device)

    arange_cols = torch.arange(w, device=x.device)
    mask = arange_cols.expand(l, w)

    arange_rows = arange_rows[:, None] + diagonal
    mask = mask >= arange_rows
    return x.masked_fill(mask == 0, 0)
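

# Sanity check (an illustrative sketch, not exercised in the export): for a
# static `diagonal`, triu_onnx matches torch.triu exactly:
#
#     x = torch.arange(24, dtype=torch.float32).reshape(4, 6)
#     assert torch.equal(triu_onnx(x, diagonal=2), torch.triu(x, diagonal=2))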


def patched_build_delay_pattern_mask(self, input_ids: torch.Tensor, pad_token_id: int, max_length: Optional[int] = None):
    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
    input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
    bsz, num_codebooks, seq_len = input_ids.shape

    max_length = max_length if max_length is not None else self.generation_config.max_length
    input_ids_shifted = torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1

    channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
    # we only apply the mask if we have a large enough seq len - otherwise we return as is
    if max_length < 2 * channel_codebooks - 1:
        raise NotImplementedError("Not supported in ONNX export. Please open an issue in the Optimum repository.")

    # fill the shifted ids with the prompt entries, offset by the codebook idx
    for codebook in range(channel_codebooks):
        if self.config.audio_channels == 1:
            # mono channel - loop over the codebooks one-by-one
            input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
        else:
            # left/right channels are interleaved in the generated codebooks, so handle one then the other
            input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
            input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]

    # construct a pattern mask that indicates the positions of padding tokens for each codebook
    # first fill the upper triangular part (the EOS padding)
    # NOTE: We could use torch.bool here, but PyTorch then complains with `The exported ONNX model failed ONNX shape inference.`
    # Using int8 leads to `Could not find an implementation for Where`.
    delay_pattern = triu_onnx(
        torch.ones((channel_codebooks, max_length), dtype=torch.int32), diagonal=max_length - channel_codebooks + 1
    )

    # NOTE: We could use torch.bool here, but PyTorch then complains with `The exported ONNX model failed ONNX shape inference.`
    # Using int32 leads to `Could not find an implementation for Trilu`, hence int64 here.

    # then fill the lower triangular part (the BOS padding)
    delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.int64))
    delay_pattern = delay_pattern.to(torch.bool)

    if self.config.audio_channels == 2:
        # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
        delay_pattern = delay_pattern.repeat_interleave(2, dim=0)

    mask = ~delay_pattern.to(input_ids.device)
    input_ids = mask * input_ids_shifted + ~mask * pad_token_id

    # find the first position to start generating - this is the first place we have the -1 token
    # and will always be in the first codebook (since it has no codebook offset)
    first_codebook_ids = input_ids[:, 0, :]
    start_ids = (first_codebook_ids == -1).nonzero()[:, 1]

    # TODO: Is this OK?
    first_start_id = start_ids.min()

    # (bsz, num_codebooks, max_length) -> (bsz * num_codebooks, max_length)
    pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
    input_ids_edited = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
    return {"input_ids_edited": input_ids_edited, "delay_pattern_mask": pattern_mask}


class MusicgenModelPatcher(Seq2SeqModelPatcher):
    def __enter__(self):
        self.patch_ops()
        if self.real_config.model_part == "build_delay_pattern_mask":
            # For build_delay_pattern_mask, we need to override the signature too.
            self._model.forward = types.MethodType(patched_build_delay_pattern_mask, self._model)
        else:
            setattr(self._model, self.orig_forward_name, self.patched_forward)

    def __exit__(self, exc_type, exc_value, traceback):
        self.restore_ops()
        if self.real_config.model_part == "build_delay_pattern_mask":
            self._model.forward = self.original_decoder_forward
        else:
            setattr(self._model, self.orig_forward_name, self.orig_forward)

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)

        if config.model_part == "build_delay_pattern_mask":
            self.original_decoder_forward = self.orig_forward
        elif config.model_part == "encodec_decode":
            # EncodecModel.forward -> EncodecModel.decode
            @functools.wraps(self.orig_forward)
            def patched_forward(
                input_values: Optional["torch.Tensor"] = None,
                padding_mask: Optional["torch.Tensor"] = None,
                audio_codes: Optional["torch.Tensor"] = None,
                bandwidth: Optional[float] = None,
                audio_scales: Optional["torch.Tensor"] = None,
                return_dict: Optional[bool] = None,
            ):
                chunk_length = self.real_config._config.audio_encoder.chunk_length
                if chunk_length is None:
                    if audio_scales is not None:
                        audio_scales = audio_scales[0]

                    if len(audio_codes) != 1:
                        raise ValueError(f"Expected one frame, got {len(audio_codes)}")
                    audio_values = self._model._decode_frame(audio_codes[0], audio_scales)
                else:
                    raise ValueError("Not supported, a meaningful error should have been raised ahead.")
                    # Unreachable: kept from EncodecModel.decode for reference.
                    decoded_frames = []

                    for frame, scale in zip(audio_codes, audio_scales):
                        frames = self._model._decode_frame(frame, scale)
                        decoded_frames.append(frames)

                    audio_values = self._model._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1)

                # truncate based on padding mask
                if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
                    audio_values = audio_values[..., : padding_mask.shape[-1]]

                return {"audio_values": audio_values}

            self.patched_forward = patched_forward
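
__enter__ swaps the bound forward for the patched function and __exit__ restores it. A minimal standalone sketch of that types.MethodType rebinding pattern (names are hypothetical):

import types

class Toy:
    def forward(self):
        return "original"

def patched(self):
    # Replacement bound as an instance method, mirroring the patcher above.
    return "patched"

toy = Toy()
original_forward = toy.forward
toy.forward = types.MethodType(patched, toy)  # as in __enter__
assert toy.forward() == "patched"
toy.forward = original_forward  # as in __exit__
assert toy.forward() == "original"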
7 changes: 6 additions & 1 deletion optimum/exporters/tasks.py
@@ -177,7 +177,7 @@ class TasksManager:
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"semantic-segmentation": "AutoModelForSemanticSegmentation",
"text-to-audio": "AutoModelForTextToSpectrogram",
"text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"),
"text-generation": "AutoModelForCausalLM",
"text2text-generation": "AutoModelForSeq2SeqLM",
"text-classification": "AutoModelForSequenceClassification",
@@ -334,6 +334,7 @@ class TasksManager:

    # TODO: some models here support text-generation export but are not supported in ORTModelForCausalLM
    # Set of model topologies we support associated to the tasks supported by each topology and the factory
    # TODO: remove `-with-past` tasks and rather rely on `variant`.
    _SUPPORTED_MODEL_TYPE = {
        "audio-spectrogram-transformer": supported_tasks_mapping(
            "feature-extraction",
@@ -813,6 +814,10 @@ class TasksManager:
"text2text-generation-with-past",
onnx="MT5OnnxConfig",
),
"musicgen": supported_tasks_mapping(
"text-to-audio", # "variant" handles the "-with-past". We should generalize that.
onnx="MusicgenOnnxConfig",
),
"m2m-100": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
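
With this registration, the musicgen entry becomes discoverable through TasksManager. A sketch, assuming this helper keeps the signature it has in this version of Optimum:

from optimum.exporters.tasks import TasksManager

tasks = TasksManager.get_supported_tasks_for_model_type("musicgen", exporter="onnx")
print(tasks)  # expected to include "text-to-audio"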
46 changes: 46 additions & 0 deletions optimum/exporters/utils.py
@@ -345,6 +345,50 @@ def get_stable_diffusion_models_for_export(
    return models_for_export


def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"):
    models_for_export = {
        "text_encoder": model.text_encoder,
        "encodec_decode": model.audio_encoder,
        # For the decoder, we do not pass model.decoder because we may need to export model.enc_to_dec_proj
        DECODER_NAME: model,
        DECODER_WITH_PAST_NAME: model,
        "build_delay_pattern_mask": model.decoder,
    }

    text_encoder_config = config.__class__(
        model.config, task=config.task, legacy=False, model_part="text_encoder", variant=config.variant
    )
    models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_config)

    audio_encoder_config = config.__class__(
        model.config, task=config.task, legacy=False, model_part="encodec_decode", variant=config.variant
    )
    models_for_export["encodec_decode"] = (models_for_export["encodec_decode"], audio_encoder_config)

    use_past = "with-past" in config.variant
    decoder_export_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False)
    decoder_export_config.model_part = "decoder"
    models_for_export[DECODER_NAME] = (models_for_export[DECODER_NAME], decoder_export_config)

    if "with-past" in config.variant:
        decoder_export_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
        decoder_export_config_with_past.model_part = "decoder"
        models_for_export[DECODER_WITH_PAST_NAME] = (
            models_for_export[DECODER_WITH_PAST_NAME],
            decoder_export_config_with_past,
        )

    build_delay_pattern_mask_config = config.__class__(
        model.config, task=config.task, legacy=False, model_part="build_delay_pattern_mask", variant=config.variant
    )
    models_for_export["build_delay_pattern_mask"] = (
        models_for_export["build_delay_pattern_mask"],
        build_delay_pattern_mask_config,
    )

    return models_for_export


def _get_submodels_for_export_sam(model, variant):
    models_for_export = {}

@@ -513,6 +557,8 @@ def _get_submodels_and_export_configs(
            models_and_export_configs = get_sam_models_for_export(model, export_config)
        elif model.config.model_type == "speecht5":
            models_and_export_configs = get_speecht5_models_for_export(model, export_config, model_kwargs)
        elif model.config.model_type == "musicgen":
            models_and_export_configs = get_musicgen_models_for_export(model, export_config)
        else:
            models_and_export_configs = {"model": (model, export_config)}

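
Once _get_submodels_and_export_configs dispatches to get_musicgen_models_for_export, five submodels are exported. A sketch of inspecting the output directory; the file names are assumptions based on Optimum's usual naming, with only encodec_decode.onnx confirmed by this commit's rename:

from pathlib import Path

for onnx_file in sorted(Path("musicgen_onnx").glob("*.onnx")):
    print(onnx_file.name)
# Assumed: build_delay_pattern_mask.onnx, decoder_model.onnx,
# decoder_with_past_model.onnx, encodec_decode.onnx, text_encoder_model.onnx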
2 changes: 1 addition & 1 deletion optimum/onnx/transformations_utils.py
@@ -160,7 +160,7 @@ def _unify_onnx_outputs(model1: ModelProto, model2: ModelProto, strict: bool):
    else:
        logger.info(
            f"The two models proto have different outputs ({len(model1_outputs)} and {len(model2_outputs)} outputs)."
            " Constant outputs will be added to unify the two models outputs."
            " Constant outputs will be added to unify the two models outputs. This is expected for encoder-decoder models where cached cross-attention key/values are constant outputs, omitted in the model with KV cache."
        )

    if model2_outputs.issubset(model1_outputs) is False:
3 changes: 3 additions & 0 deletions optimum/utils/__init__.py
@@ -48,8 +48,11 @@
    BloomDummyPastKeyValuesGenerator,
    DummyAudioInputGenerator,
    DummyBboxInputGenerator,
    DummyCodegenDecoderTextInputGenerator,
    DummyDecoderTextInputGenerator,
    DummyEncodecInputGenerator,
    DummyInputGenerator,
    DummyIntGenerator,
    DummyLabelsGenerator,
    DummyPastKeyValuesGenerator,
    DummyPix2StructInputGenerator,