# Modify HuggingFace Transformers Whisper Decoder for Hailo Compatibility

Following the Hailo patch modifications:

Key decoder modifications:
1. Token embedding reshape operations (unsqueeze, transpose, flatten)
2. Split final matmul into 4 chunks to avoid Hailo size limits
3. Use eager attention (no SDPA)
4. Fixed decoder sequence length


**Status:** this decoder is not workable yet.
Because running the decoder on the Hailo NPU is inefficient, I switched to a hybrid approach in which the decoder is the default (unmodified) Hugging Face decoder.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import WhisperForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.models.whisper.modeling_whisper import create_causal_mask
from transformers.utils import logging
import onnx
import types
import os
from onnxsim import simplify

logger = logging.get_logger(__name__)

In [3]:
base_model_name = "openai/whisper-tiny"
output_dir = "hailo_compatible_models/hf_whisper_tiny"

# Configuration based on the Hailo reference.
# Whisper's native window is 30 s, which the encoder maps to 1500 frames.
# SCALING_FACTOR shortens both together, so the derived values below stay
# consistent if the factor is changed.
FULL_WINDOW_SECONDS = 30
FULL_ENCODER_SEQ_LEN = 1500
SCALING_FACTOR = 3  # 30s -> 10s
INPUT_LENGTH_SECONDS = FULL_WINDOW_SECONDS // SCALING_FACTOR  # 10 s
DECODER_SEQUENCE_LENGTH = 32  # fixed decoder length for export (model allows up to 448 target positions)
ENCODER_SEQ_LEN = FULL_ENCODER_SEQ_LEN // SCALING_FACTOR  # 500 for 10 s input
HIDDEN_STATES_CHANNELS = 384  # hidden size of whisper-tiny

## Decoder Architecture Modifications

In [4]:
def split_matmul_method(self, x, num_chunks=4):
    """Compute tied-embedding logits as ``num_chunks`` smaller matmuls.

    Hailo has size limits on a single matmul, so instead of one ``x @ W.T``
    against the whole vocabulary (51865 rows for Whisper), the embedding
    matrix is sliced along the vocab axis. Each slice is multiplied
    separately and the partial logits are concatenated back together.

    From patch lines 108-125.

    Args:
        self: object exposing ``embed_tokens`` whose ``weight`` has shape
            ``(vocab_size, hidden_size)`` (e.g. the HF Whisper decoder).
        x: hidden states whose last dimension is ``hidden_size``.
        num_chunks: number of vocab slices; 4 matches the Hailo reference.
            When ``vocab_size`` is not divisible by ``num_chunks``, the last
            slice absorbs the remainder.

    Returns:
        Logits of shape ``(*x.shape[:-1], vocab_size)``.

    Raises:
        ValueError: if ``num_chunks`` is not in ``[1, vocab_size]``.
    """
    vocab_size = self.embed_tokens.weight.shape[0]
    if not 1 <= num_chunks <= vocab_size:
        raise ValueError(f"num_chunks must be in [1, {vocab_size}], got {num_chunks}")
    chunk_size = vocab_size // num_chunks

    weight = self.embed_tokens.weight.to(x.dtype)

    logit_chunks = []
    for i in range(num_chunks):
        start = i * chunk_size
        # The last chunk runs to the end so remainder rows are not dropped.
        end = vocab_size if i == num_chunks - 1 else start + chunk_size
        # (..., hidden) @ (hidden, end - start) -> (..., end - start)
        logit_chunks.append(torch.matmul(x, weight[start:end].T))

    return torch.cat(logit_chunks, dim=-1)  # (..., vocab_size)

In [5]:
def new_decoder_forward(
    self,
    input_ids=None,
    attention_mask=None,
    encoder_hidden_states=None,
    head_mask=None,
    cross_attn_head_mask=None,
    past_key_values=None,
    inputs_embeds=None,
    position_ids=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    cache_position=None,
):
    r"""Whisper decoder forward, adapted from HF for Hailo compatibility.

    Copied from the original HF implementation:
    https://github.com/huggingface/transformers/blob/53838edde77cb10f3a360150aa85a457637e9ac3/src/transformers/models/whisper/modeling_whisper.py#L765

    Differences from upstream:
      * Three reshape ops (unsqueeze / transpose / flatten+permute) are inserted
        after the embedding dropout for Hailo ONNX compatibility. They are a
        numerical no-op.
      * The cache is never auto-created when ``use_cache`` is set (see NOTE below).
      * Each decoder layer receives its own ``past_key_values[idx]`` entry through
        the singular ``past_key_value`` keyword.

    Meant to be bound to a ``WhisperDecoder`` instance with ``types.MethodType``.

    Returns:
        ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict`` is
        truthy, otherwise a tuple of the non-``None`` outputs in the order
        (hidden_states, cache, all_hidden_states, self_attns, cross_attentions).

    Raises:
        ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are
            given, or if ``head_mask`` / ``cross_attn_head_mask`` does not have
            one entry per decoder layer.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # NOTE: Upstream creates an EncoderDecoderCache(DynamicCache(config=...),
    # DynamicCache(config=...)) here when `use_cache` is set and no cache is given.
    # That block is intentionally omitted:
    # - ONNX export is stateless: all 32 tokens are processed at once rather than
    #   autoregressively, so there are no previous key/value computations to reuse.
    # - The block comes from a newer transformers version and caused compatibility
    #   issues.

    past_key_values_length = 0
    if cache_position is not None:
        past_key_values_length = cache_position[0]
    elif past_key_values is not None:
        past_key_values_length = past_key_values.get_seq_length()

    if cache_position is None:
        cache_position = torch.arange(
            past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0).repeat(input_shape[0], 1)

    # embed positions
    if input_ids is not None:
        positions = self.embed_positions(
            input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
        )
    else:
        positions = self.embed_positions(
            inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
        )

    hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    #### START HAILO PATCH (patch lines 144-146) ####
    # Numerically a no-op: [B, S, H] -unsqueeze-> [B, 1, S, H] -transpose-> [B, H, S, 1]
    # -flatten(2)-> [B, H, S] -permute-> [B, S, H], holding the same values as before.
    # The ops exist only so the exported ONNX graph contains the same
    # reshape/transpose nodes as the Hailo reference.
    hidden_states = hidden_states.unsqueeze(1)
    hidden_states = hidden_states.transpose(1, -1)
    hidden_states = hidden_states.flatten(2).permute(0, 2, 1)
    #### END HAILO PATCH ####

    causal_mask = create_causal_mask(
        config=self.config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
            )
            use_cache = False
    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

    # Each head mask, if given, must have one entry per decoder layer.
    # Report the size of the mask actually being checked: upstream formatted
    # head_mask here, which was wrong (and crashed when head_mask was None).
    for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
        if attn_mask is not None and attn_mask.size()[0] != len(self.layers):
            raise ValueError(
                f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                f" {attn_mask.size()[0]}."
            )
    for idx, decoder_layer in enumerate(self.layers):
        # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if self.training:
            dropout_probability = torch.rand([])
            if dropout_probability < self.layerdrop:
                continue

        # Extract past_key_value for this layer (if cache exists)
        past_key_value = past_key_values[idx] if past_key_values is not None else None

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=causal_mask,
            encoder_hidden_states=encoder_hidden_states,
            layer_head_mask=(head_mask[idx] if head_mask is not None else None),
            cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
            past_key_value=past_key_value,  # Singular, not plural
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = layer_outputs[0]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

            if encoder_hidden_states is not None:
                all_cross_attentions += (layer_outputs[2],)

    hidden_states = self.layer_norm(hidden_states)
    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = past_key_values if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
        cross_attentions=all_cross_attentions,
    )


def simple_decoder_forward_for_onnx(self, input_ids=None, encoder_hidden_states=None):
    r"""Minimal, export-oriented Whisper decoder forward.

    Follows the OpenAI Whisper decoder structure (patch lines 126-164) on top of
    the HF decoder submodules. It builds its own static causal mask instead of
    calling transformers helpers such as ``create_causal_mask``, which do not
    export to ONNX well. It applies the same Hailo reshape patch as
    ``new_decoder_forward``. There is no cache, and no attentions or
    intermediate hidden states are returned.

    Args:
        input_ids: decoder token ids; the sequence length is ``input_ids.shape[1]``.
        encoder_hidden_states: encoder output fed to each layer's cross-attention.

    Returns:
        ``BaseModelOutputWithPastAndCrossAttentions`` with only
        ``last_hidden_state`` populated (same container type as
        ``new_decoder_forward`` returns).
    """
    token_embeds = self.embed_tokens(input_ids)
    # Every token is processed in one pass, so positions always start at 0.
    pos_embeds = self.embed_positions(input_ids, past_key_values_length=0)

    x = token_embeds + pos_embeds
    x = nn.functional.dropout(x, p=self.dropout, training=self.training)

    #### START HAILO PATCH (patch lines 144-146) ####
    # Numerically a no-op ([B, S, H] -> [B, 1, S, H] -> [B, H, S, 1] -> [B, H, S] -> [B, S, H]).
    # It only shapes the exported ONNX graph to match the Hailo reference.
    x = x.unsqueeze(1)
    x = x.transpose(1, -1)
    x = x.flatten(2).permute(0, 2, 1)
    #### END HAILO PATCH ####

    # Static additive causal mask, like OpenAI Whisper's `self.mask`: 0 on and
    # below the diagonal, -inf above it. It is shaped [1, 1, S, S] so it
    # broadcasts over batch and heads.
    n_tokens = input_ids.shape[1]
    mask = torch.triu(
        torch.full((n_tokens, n_tokens), float("-inf"), device=x.device, dtype=x.dtype),
        diagonal=1,
    )
    mask = mask.unsqueeze(0).unsqueeze(0)

    for layer in self.layers:
        x = layer(
            x,
            attention_mask=mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_value=None,  # no cache in the exported graph
            output_attentions=False,
            use_cache=False,
        )[0]

    x = self.layer_norm(x)

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=x,
        past_key_values=None,
        hidden_states=None,
        attentions=None,
        cross_attentions=None,
    )

In [6]:
class WhisperDecoderWrapper(nn.Module):
    """Patched HF decoder followed by the split-vocab logits projection.

    Mirrors the graph intended for Hailo export. It calls the decoder's
    (replaced) ``forward``, which is expected to be ``new_decoder_forward``,
    and then turns the final hidden states into vocabulary logits with
    ``split_matmul_method``.
    """

    def __init__(self, model):
        super().__init__()
        decoder = model.model.decoder
        # Assumed to already have `forward` replaced by new_decoder_forward.
        self.decoder = decoder
        # Same module as decoder.embed_tokens; forward() does not read this
        # attribute (the projection goes through self.decoder).
        self.embed_tokens = decoder.embed_tokens

    def forward(self, decoder_input_ids, encoder_hidden_states):
        """Return logits of shape (batch, seq_len, vocab_size) for the given tokens."""
        # Keep the call minimal: cache and auxiliary outputs off, dict output on.
        outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        return split_matmul_method(self.decoder, outputs.last_hidden_state)

## Load Model and Apply Modifications

In [7]:
# Load the base checkpoint, requesting eager attention instead of SDPA
# (one of the Hailo-compatibility requirements listed above).
model = WhisperForConditionalGeneration.from_pretrained(
    base_model_name,
    attn_implementation="eager",
)

# Echo the config values the rest of the notebook relies on.
print(f"Attention implementation: {model.config._attn_implementation}")
print(f"Vocab size: {model.config.vocab_size}")
print(f"Max target positions: {model.config.max_target_positions}")

Attention implementation: eager
Vocab size: 51865
Max target positions: 448


In [8]:
# Bind the Hailo-patched forward onto this decoder instance. The bound method is
# stored as an instance attribute, so it takes the place of the class-defined
# forward for this decoder.
print("Applying decoder modifications...")
model.model.decoder.forward = types.MethodType(
    new_decoder_forward,
    model.model.decoder,
)
print("✓ Decoder modifications applied (new_decoder_forward with Hailo patch)")

# Attach the simplified export-oriented forward under a separate name.
model.model.decoder.forward_for_onnx = types.MethodType(
    simple_decoder_forward_for_onnx,
    model.model.decoder,
)
print("✓ Simplified ONNX-compatible forward attached (simple_decoder_forward_for_onnx)")

Applying decoder modifications...
✓ Decoder modifications applied (new_decoder_forward with Hailo patch)
✓ Simplified ONNX-compatible forward attached (simple_decoder_forward_for_onnx)


## Test Decoder Inference

Test the decoder wrapper to ensure it works correctly before ONNX export.
This uses the same wrapper that will be exported to ONNX.

In [11]:
# Create dummy inputs matching Hailo specs
batch_size = 1

# Encoder outputs (from 10s audio)
encoder_hidden_states = torch.randn(
    batch_size,
    ENCODER_SEQ_LEN,
    HIDDEN_STATES_CHANNELS,
    dtype=torch.float32
)

# Decoder input IDs (start token + zeros)
decoder_input_ids = torch.cat([
    torch.tensor([[50258]], dtype=torch.int64),  # Start token
    torch.zeros((1, DECODER_SEQUENCE_LENGTH - 1), dtype=torch.int64)
], dim=1)

print(f"Encoder hidden states: {encoder_hidden_states.shape}")
print(f"Decoder input IDs: {decoder_input_ids.shape}")

Encoder hidden states: torch.Size([1, 500, 384])
Decoder input IDs: torch.Size([1, 32])


In [12]:
# Run the same wrapper structure that the export uses, then verify the logits shape.
print("Testing decoder wrapper (this is what gets exported to ONNX)...")

decoder_wrapper = WhisperDecoderWrapper(model)
decoder_wrapper.eval()

with torch.no_grad():
    logits = decoder_wrapper(
        decoder_input_ids=decoder_input_ids,
        encoder_hidden_states=encoder_hidden_states,
    )

# Expected: (batch, DECODER_SEQUENCE_LENGTH, vocab_size); batch comes from the inputs.
expected_shape = (decoder_input_ids.shape[0], DECODER_SEQUENCE_LENGTH, model.config.vocab_size)
print(f"\nOutput logits shape: {logits.shape}")
print(f"Expected shape: [{expected_shape[0]}, {expected_shape[1]}, {expected_shape[2]}]")
print(f"Stats: mean={logits.mean():.6f}, std={logits.std():.6f}")

# Fail loudly instead of relying on a visual comparison of the printed shapes.
if tuple(logits.shape) != expected_shape:
    raise RuntimeError(f"Unexpected logits shape {tuple(logits.shape)}, expected {expected_shape}")

Testing decoder wrapper (this is what gets exported to ONNX)...

Output logits shape: torch.Size([1, 32, 51865])
Expected shape: [1, 32, 51865]
Stats: mean=20.190424, std=5.041745


## Compare Full vs Simplified Decoder

Before ONNX export, verify that the simplified decoder produces similar results.

In [13]:
# Confirm that the simplified export-oriented forward reproduces the full
# patched forward before it is used for export.
print("Comparing full decoder vs simplified decoder...")

with torch.no_grad():
    # Path 1: full patched forward (transformers mask helper, cache plumbing, ...).
    output_full = model.model.decoder(
        input_ids=decoder_input_ids,
        encoder_hidden_states=encoder_hidden_states,
    )
    # Path 2: minimal forward that gets exported.
    output_simple = model.model.decoder.forward_for_onnx(
        input_ids=decoder_input_ids,
        encoder_hidden_states=encoder_hidden_states,
    )
    hidden_full = output_full.last_hidden_state
    hidden_simple = output_simple.last_hidden_state

    # Compare in logit space, after the same split-vocab projection.
    logits_full, logits_simple = (
        split_matmul_method(model.model.decoder, h) for h in (hidden_full, hidden_simple)
    )

    diff = (logits_full - logits_simple).abs()
    print(f"\nFull decoder logits shape: {logits_full.shape}")
    print(f"Simple decoder logits shape: {logits_simple.shape}")
    print("\nDifference statistics:")
    for stat_name, stat_value in (("Max", diff.max()), ("Mean", diff.mean()), ("Median", diff.median())):
        print(f"  {stat_name} diff: {stat_value.item():.6f}")

    # Verdict based on the largest absolute logit difference.
    max_diff = diff.max()
    if max_diff < 1e-3:
        print("\n✓ Results are very similar! Safe to use simplified version for ONNX.")
    elif max_diff < 0.1:
        print("\n⚠ Results are close but have some differences. Should be acceptable for ONNX export.")
    else:
        print("\n✗ Results differ significantly! Need to investigate.")

Comparing full decoder vs simplified decoder...

Full decoder logits shape: torch.Size([1, 32, 51865])
Simple decoder logits shape: torch.Size([1, 32, 51865])

Difference statistics:
  Max diff: 0.000000
  Mean diff: 0.000000
  Median diff: 0.000000

✓ Results are very similar! Safe to use simplified version for ONNX.


## Export to ONNX (Hailo Style)

Using the simplified decoder for ONNX compatibility.

In [14]:
# Create ONNX-compatible wrapper that uses simplified decoder
class WhisperDecoderWrapperForONNX(nn.Module):
    """ONNX-compatible wrapper using simplified decoder forward"""
    def __init__(self, model):
        super().__init__()
        self.decoder = model.model.decoder
        self.embed_tokens = model.model.decoder.embed_tokens
        
    def forward(self, decoder_input_ids, encoder_hidden_states):
        # Use simplified forward for ONNX compatibility
        decoder_outputs = self.decoder.forward_for_onnx(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states
        )
        
        hidden_states = decoder_outputs.last_hidden_state  # [1, 32, 384]
        
        # Split matmul for logits
        logits = split_matmul_method(self.decoder, hidden_states)  # [1, 32, 51865]
        
        return logits

# Create wrapper for ONNX export
decoder_wrapper_onnx = WhisperDecoderWrapperForONNX(model)
decoder_wrapper_onnx.eval()
print("✓ ONNX-compatible wrapper created")

✓ ONNX-compatible wrapper created


In [None]:
# Export settings
os.makedirs(output_dir, exist_ok=True)

decoder_name = f"whisper-tiny-decoder-{INPUT_LENGTH_SECONDS}s-seq-{DECODER_SEQUENCE_LENGTH}"
decoder_path_base = f"{output_dir}/{decoder_name}_base.onnx"
decoder_path_final = f"{output_dir}/{decoder_name}_final.onnx"

print(f"Exporting decoder to: {decoder_path_base}")
print("This may take 2-5 minutes...")

with torch.no_grad():
    # Export to ONNX
    # torch.onnx.export(
    #     decoder_wrapper_onnx,
    #     (decoder_input_ids, encoder_hidden_states),
    #     decoder_path_base,
    #     opset_version=13,  # Match Hailo reference
    #     input_names=["decoder_input_ids", "encoder_hidden_states"],
    #     output_names=["logits"],
    #     dynamic_axes=None,  # Fixed shapes for Hailo
    #     do_constant_folding=True,
    #     export_params=True,
    #     verbose=True,  # Set to True if you want to see detailed progress
    # )
    torch.onnx.export(
        decoder_wrapper_onnx,
        (decoder_input_ids, encoder_hidden_states),
        decoder_path_base,
        opset_version=13,  # Match Hailo reference
        input_names=["decoder_input_ids", "encoder_hidden_states"],
        output_names=["logits"],
        verbose=True,  # Set to True if you want to see detailed progress
    )

print(f"✓ Base ONNX model exported to: {decoder_path_base}")

Exporting decoder to: hailo_compatible_models/hf_whisper_tiny/whisper-tiny-decoder-10s-seq-32_base.onnx
This may take 2-5 minutes...
Torch IR graph at exception: graph(%input.1 : Long(1, 32, strides=[32, 1], requires_grad=0, device=cpu),
      %1 : Float(1, 500, 384, strides=[192000, 384, 1], requires_grad=0, device=cpu),
      %decoder.embed_tokens.weight : Float(51865, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.embed_positions.weight : Float(448, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.layers.0.self_attn.k_proj.weight : Float(384, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.layers.0.self_attn.v_proj.weight : Float(384, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.layers.0.self_attn.v_proj.bias : Float(384, strides=[1], requires_grad=1, device=cpu),
      %decoder.layers.0.self_attn.q_proj.weight : Float(384, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.layers.0.s

  torch.onnx.export(


t(384, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.layers.0.self_attn.out_proj.bias : Float(384, strides=[1], requires_grad=1, device=cpu),
      %decoder.layers.0.self_attn_layer_norm.weight : Float(384, strides=[1], requires_grad=1, device=cpu),
      %decoder.layers.0.self_attn_layer_norm.bias : Float(384, strides=[1], requires_grad=1, device=cpu),
      %decoder.layers.0.encoder_attn.k_proj.weight : Float(384, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.layers.0.encoder_attn.v_proj.weight : Float(384, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.layers.0.encoder_attn.v_proj.bias : Float(384, strides=[1], requires_grad=1, device=cpu),
      %decoder.layers.0.encoder_attn.q_proj.weight : Float(384, 384, strides=[384, 1], requires_grad=1, device=cpu),
      %decoder.layers.0.encoder_attn.q_proj.bias : Float(384, strides=[1], requires_grad=1, device=cpu),
      %decoder.layers.0.encoder_attn.out_proj.weight : Float(

KeyboardInterrupt: 

In [None]:
# simplify the ONNX model

input_shapes = {
    "decoder_input_ids": [1, DECODER_SEQUENCE_LENGTH],
    "encoder_hidden_states": [1, ENCODER_SEQ_LEN, HIDDEN_STATES_CHANNELS]
}

model_onnx = onnx.load(decoder_path_base)
model_simp, check = simplify(model_onnx, overwrite_input_shapes=input_shapes)
onnx.save(model_simp, decoder_path_final)

# Check if the simplification was successful
if check:
    logger.info("ONNX model was successfully simplified!")
else:
    logger.info("ONNX model simplification failed!")

In [None]:
# View base model
! netron {decoder_path_base}

In [None]:
# View simplified model
! netron {decoder_path_final}

## Summary

Decoder modifications applied:
1. ✓ Token embedding reshape operations (unsqueeze, transpose, flatten)
2. ✓ Split final matmul into 4 chunks for Hailo compatibility
3. ✓ Eager attention implementation (no SDPA)
4. ✓ Fixed sequence length (32 tokens for tiny model)

The ONNX export above did not complete, so the decoder is not yet ready for Hailo NPU deployment (see the status note at the top of this notebook).