From f9bbbbcdc36826944822e9c0264c85b68cc965bf Mon Sep 17 00:00:00 2001 From: randy Date: Wed, 3 Sep 2025 17:19:58 -0400 Subject: [PATCH 1/3] Fixing bug when replacing text-audio token placeholders with audio embeddings --- src/transformers/models/voxtral/modeling_voxtral.py | 5 ++++- src/transformers/models/voxtral/modular_voxtral.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 3dbbbd9fee0b..0658cf63b6bb 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -507,9 +507,12 @@ def forward( if input_features is not None: audio_embeds = self.get_audio_embeds(input_features) + # Enable gradient tracking when inputs_embeds is a leaf tensor + if inputs_embeds.is_leaf and inputs_embeds.requires_grad: + inputs_embeds = inputs_embeds.clone() # replace text-audio token placeholders with audio embeddings audio_token_mask = input_ids == self.config.audio_token_id - inputs_embeds[audio_token_mask] = audio_embeds + inputs_embeds[audio_token_mask] = audio_embeds.to(inputs_embeds.device) outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index 0c325c0c605e..787b71fd10fa 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -242,9 +242,12 @@ def forward( if input_features is not None: audio_embeds = self.get_audio_embeds(input_features) + # Enable gradient tracking when inputs_embeds is a leaf tensor + if inputs_embeds.is_leaf and inputs_embeds.requires_grad: + inputs_embeds = inputs_embeds.clone() # replace text-audio token placeholders with audio embeddings audio_token_mask = input_ids == self.config.audio_token_id - inputs_embeds[audio_token_mask] = 
audio_embeds + inputs_embeds[audio_token_mask] = audio_embeds.to(inputs_embeds.device) outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, From 47fa46e3e2634aed4e5561864900061a73ac5aeb Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 4 Sep 2025 16:36:28 +0200 Subject: [PATCH 2/3] apply changes --- src/transformers/models/voxtral/modeling_voxtral.py | 9 ++++----- src/transformers/models/voxtral/modular_voxtral.py | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 0658cf63b6bb..8762573562ac 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -507,12 +507,11 @@ def forward( if input_features is not None: audio_embeds = self.get_audio_embeds(input_features) - # Enable gradient tracking when inputs_embeds is a leaf tensor - if inputs_embeds.is_leaf and inputs_embeds.requires_grad: - inputs_embeds = inputs_embeds.clone() # replace text-audio token placeholders with audio embeddings - audio_token_mask = input_ids == self.config.audio_token_id - inputs_embeds[audio_token_mask] = audio_embeds.to(inputs_embeds.device) + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + ) outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index 787b71fd10fa..89ac896f133f 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -242,12 +242,11 @@ def forward( if input_features is not None: audio_embeds = self.get_audio_embeds(input_features) - # Enable gradient tracking 
when inputs_embeds is a leaf tensor - if inputs_embeds.is_leaf and inputs_embeds.requires_grad: - inputs_embeds = inputs_embeds.clone() # replace text-audio token placeholders with audio embeddings - audio_token_mask = input_ids == self.config.audio_token_id - inputs_embeds[audio_token_mask] = audio_embeds.to(inputs_embeds.device) + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + ) outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, From eb17a17c615cecefc533e8593222c4ca9574e84e Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 4 Sep 2025 17:02:01 +0200 Subject: [PATCH 3/3] audio token replacement does not make sense when input ids are not provided --- src/transformers/models/voxtral/modeling_voxtral.py | 2 +- src/transformers/models/voxtral/modular_voxtral.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 8762573562ac..671d91066cae 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -504,7 +504,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - if input_features is not None: + if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_embeds(input_features) # replace text-audio token placeholders with audio embeddings diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index 89ac896f133f..a0080f58eb0d 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -239,7 +239,7 @@ def forward( if inputs_embeds is None: inputs_embeds = 
self.get_input_embeddings()(input_ids) - if input_features is not None: + if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_embeds(input_features) # replace text-audio token placeholders with audio embeddings