# Modify HuggingFace Transformers Whisper to match expectations for Hailo -- ENCODER ONLY 

Following the patch fill

Done 

* conv1d --> conv2d + forward pass adapted
* SDPA_AVAILABLE  --> attn_implementation='eager'
* input length 30sec --> 10sec and positional embeddings reconstruction

Not done yet:

* "*1.0" on attention values: see patch v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * 1.0
  (would require complete overwrite of WhiseprAttention: https://github.com/huggingface/transformers/blob/53838edde77cb10f3a360150aa85a457637e9ac3/src/transformers/models/whisper/modeling_whisper.py#L288
  and then multiply with "1.0" here: https://github.com/huggingface/transformers/blob/53838edde77cb10f3a360150aa85a457637e9ac3/src/transformers/models/whisper/modeling_whisper.py#L340



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import WhisperForConditionalGeneration
import onnx
import types
import os
from onnxsim import simplify
import math

In [2]:
# reference model: from hailo's export script for comparison
hailo_reference_onnx = "hailo_reference_models/tiny/tiny-whisper-encoder-10s.onnx"

In [3]:
base_model_name="openai/whisper-tiny"

output_dir="hailo_compatible_models/hf_whisper_tiny"

In [4]:
# go from 30sec --> 10 sec
SCALING_FACTOR = 3

# Whisper Encoder Architecture modifications

In [None]:
from transformers.modeling_outputs import BaseModelOutput
def conv2_forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""Copy from https://github.com/huggingface/transformers/blob/53838edde77cb10f3a360150aa85a457637e9ac3/src/transformers/models/whisper/modeling_whisper.py#L632C5-L730C10
        
        Modifications for setting conv2d"""
        print(">> updated forward fn")
        # for orig seq lengh
        # p = self.config.max_source_positions
        # for modified length
        p = self.config.max_source_positions // SCALING_FACTOR
        c1 = self.conv1.stride[0]
        # this is orig:
        # c2 = self.conv2.stride[0]
        c2 = self.conv2.stride[1] # (time dimension in 2D stride)
        
        print(f"config.max_source_positions: {p}")
        print(f"self.conv1.stride[0]: {c1}")
        print(f"self.conv2.stride[0]: {c2}")

        expected_seq_length = p * c1 * c2
        print(f"--> Expected seqlen: {expected_seq_length}")
        if input_features.shape[-1] != expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        #### START HAILO PATCH PART ######

        # orig in HF Transformers
        # inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        # inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        # inputs_embeds = inputs_embeds.permute(0, 2, 1)
        
        # make compatible with 3D (default whisper) and 4D (what hailo expects)
        if len(input_features.shape) == 3:
            # handle 3D inputs
            print("GETTING 3D input...")
            x = input_features.unsqueeze(2)
        else:
            # handle 4D inputs - this is what Hailo wants
            print("GETTING 4D input...")
            x = input_features

        inputs_embeds = F.gelu(self.conv1(x))
        print(f"--> After conv1: {inputs_embeds.shape}")
        inputs_embeds = F.gelu(self.conv2(inputs_embeds))
        print(f"--> After conv2: {inputs_embeds.shape}")
        inputs_embeds = inputs_embeds.flatten(2).permute(0, 2, 1)  # Hailo patch
        print(f"--> After flatten+permute: {inputs_embeds.shape}")            
        #### END HAILO PATCH PART ######

        all_positions = torch.arange(self.embed_positions.num_embeddings, device=inputs_embeds.device)

        hidden_states = inputs_embeds + self.embed_positions(all_positions)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (len(self.layers)), (
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
            )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    None,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )    


In [28]:
def set_conv2d(model):
    encoder = model.model.encoder
    #### Conv1D --> Conv2D conversion ####
    # Weight transformations (lines 10-13)
    # Conv1d → Conv2d conversion (lines 51-52, 69-70)    
    print("1️⃣ Applying weight unsqueezing transformation...")
    conv1_weight = encoder.conv1.weight.data.clone()
    conv2_weight = encoder.conv2.weight.data.clone()
    conv1_bias = encoder.conv1.bias.data.clone() if encoder.conv1.bias is not None else None
    conv2_bias = encoder.conv2.bias.data.clone() if encoder.conv2.bias is not None else None

    conv1_weight_transformed = conv1_weight.unsqueeze(2)  # add height so we get: [384, 80, 1, 3]
    conv2_weight_transformed = conv2_weight.unsqueeze(2)  # dito

    print(f"   Conv1 weight: {conv1_weight.shape} → {conv1_weight_transformed.shape}")
    print(f"   Conv2 weight: {conv2_weight.shape} → {conv2_weight_transformed.shape}")

    # line 69-74
    new_conv1 = nn.Conv2d(
        in_channels=80,  # n_mels
        out_channels=384,  # n_state
        kernel_size=(1, 3),
        padding=(0, 1)
    )

    new_conv2 = nn.Conv2d(
        in_channels=384,  # n_state
        out_channels=384,  # n_state
        kernel_size=(1, 3),
        stride=(1, 2),
        padding=(0, 1)
    )

    new_conv1.weight.data = conv1_weight_transformed
    new_conv2.weight.data = conv2_weight_transformed

    if conv1_bias is not None:
        new_conv1.bias.data = conv1_bias
    if conv2_bias is not None:
        new_conv2.bias.data = conv2_bias

    encoder.conv1 = new_conv1
    encoder.conv2 = new_conv2

    print(" >> Conv layers converted")
    #### Conv1D --> Conv2D conversion ####

    return model

In [29]:
def create_sinusoidal_positions(n_positions, d_model, max_timescale=10000):
    # This is the EXACT implementation from OpenAI whisper/model.py
    # in openai whisper: sinusoids (https://github.com/openai/whisper/blob/c0d2f624c09dc18e709e37c2ad90c039a4eb72a2/whisper/model.py#L62)
    assert d_model % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (d_model // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(d_model // 2))
    scaled_time = torch.arange(n_positions)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

In [30]:
def apply_positional_scaling(model):
    """Apply Hailo's n_audio_ctx // 3 scaling to positional embeddings"""
    encoder = model.model.encoder

    print("Applying positional embedding scaling...")

    # Hailo scaling factor from patch line 171
    original_length = 1500  # 30-second audio (max_source_positions)
    target_length = original_length // SCALING_FACTOR  # 500 for 10-second audio, TODO make configurable
    assert target_length == 500

    print(f"   Scaling positional embeddings: {original_length} → {target_length}")

    # Get original embeddings
    original_embeddings = encoder.embed_positions.weight.data
    hidden_size = original_embeddings.shape[1]  # Should be 384
    assert hidden_size == 384

    # Create new embedding layer with reduced size
    new_embed_positions = nn.Embedding(target_length, hidden_size)

    # # Simple approach: use first 500 embeddings from the original 1500
    # new_embed_positions.weight.data = original_embeddings[:target_length].clone()

    # alternatively: regenerate sinusoidal on shortened embeddings
    sinusoidal_embeddings = create_sinusoidal_positions(target_length, hidden_size)
    new_embed_positions.weight = nn.Parameter(sinusoidal_embeddings)

    # Replace the embedding layer
    encoder.embed_positions = new_embed_positions

    print(f"   Original embedding shape: {original_embeddings.shape}")
    print(f"   New embedding shape: {new_embed_positions.weight.shape}")

    return model

# Load whisper model and apply architecture changes

In [31]:
model = WhisperForConditionalGeneration.from_pretrained(base_model_name, attn_implementation='eager')
print(f"Attention implementation: {model.config._attn_implementation}")

Attention implementation: eager


In [32]:
model = set_conv2d(model)

1️⃣ Applying weight unsqueezing transformation...
   Conv1 weight: torch.Size([384, 80, 3]) → torch.Size([384, 80, 1, 3])
   Conv2 weight: torch.Size([384, 384, 3]) → torch.Size([384, 384, 1, 3])
 >> Conv layers converted


In [33]:
model = apply_positional_scaling(model)

Applying positional embedding scaling...
   Scaling positional embeddings: 1500 → 500
   Original embedding shape: torch.Size([1500, 384])
   New embedding shape: torch.Size([500, 384])


In [34]:
encoder = model.model.encoder
encoder.forward = types.MethodType(conv2_forward, encoder)

# Tests

## Basic Inference Test


In [38]:
# test 3D compatibility

# we need to change input length
LENGTH = 3000 // SCALING_FACTOR
test_input = torch.randn(1, 80, LENGTH)  # [batch, n_mels, time_steps]

print(f"Input shape: {test_input.shape}")

with torch.no_grad():
    encoder_output = model.model.encoder(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Output shape: {encoder_output.last_hidden_state.shape}")
print(f"Expected output shape: [1, 1500, 384]")  # 3000 → 1500 after conv2 stride=2
print(f"Stats: mean={encoder_output.last_hidden_state.mean():.6f}, std={encoder_output.last_hidden_state.std():.6f}")

Input shape: torch.Size([1, 80, 1000])
>> updated forward fn
config.max_source_positions: 500
self.conv1.stride[0]: 1
self.conv2.stride[0]: 2
--> Expected seqlen: 1000
GETTING 3D input...
--> After conv1: torch.Size([1, 384, 1, 1000])
--> After conv2: torch.Size([1, 384, 1, 500])
--> After flatten+permute: torch.Size([1, 500, 384])
Input shape: torch.Size([1, 80, 1000])
Output shape: torch.Size([1, 500, 384])
Expected output shape: [1, 1500, 384]
Stats: mean=0.012366, std=1.476562


In [37]:
# test 4D compatibility

# we need to change input length
LENGTH = 3000 // SCALING_FACTOR
test_input = torch.randn(1, 80, 1, LENGTH)  # [batch, n_mels, time_steps]

print(f"Input shape: {test_input.shape}")

with torch.no_grad():
    encoder_output = model.model.encoder(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Output shape: {encoder_output.last_hidden_state.shape}")
print(f"Expected output shape: [1, 1500, 384]")  # 3000 → 1500 after conv2 stride=2
print(f"Stats: mean={encoder_output.last_hidden_state.mean():.6f}, std={encoder_output.last_hidden_state.std():.6f}")

Input shape: torch.Size([1, 80, 1, 1000])
>> updated forward fn
config.max_source_positions: 500
self.conv1.stride[0]: 1
self.conv2.stride[0]: 2
--> Expected seqlen: 1000
GETTING 4D input...
--> After conv1: torch.Size([1, 384, 1, 1000])
--> After conv2: torch.Size([1, 384, 1, 500])
--> After flatten+permute: torch.Size([1, 500, 384])
Input shape: torch.Size([1, 80, 1, 1000])
Output shape: torch.Size([1, 500, 384])
Expected output shape: [1, 1500, 384]
Stats: mean=0.010678, std=1.483264


### Compare to original model

In [39]:
orig_model = WhisperForConditionalGeneration.from_pretrained(base_model_name, attn_implementation='eager')
print(f"Attention implementation: {orig_model.config._attn_implementation}")

Attention implementation: eager


In [46]:
# Orig whisper model only accepts 3D inputs
test_input = torch.randn(1, 80, 3000)  # [batch, n_mels, time_steps]


print(f"Input shape: {test_input.shape}")

with torch.no_grad():
    encoder_output = orig_model.model.encoder(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Output shape: {encoder_output.last_hidden_state.shape}")
print(f"Expected output shape: [1, 1500, 384]")  # 3000 → 1500 after conv2 stride=2
print(f"Stats: mean={encoder_output.last_hidden_state.mean():.6f}, std={encoder_output.last_hidden_state.std():.6f}")

Input shape: torch.Size([1, 80, 3000])
Input shape: torch.Size([1, 80, 3000])
Output shape: torch.Size([1, 1500, 384])
Expected output shape: [1, 1500, 384]
Stats: mean=0.020402, std=1.463553


## Compare ONNX Graphs via Netron

In [55]:
orig_model_onnx_encoder_path = "/tmp/original_encoder.onnx"
patched_model_onnx_encoder_path = "/tmp/patched_encoder.onnx"

# can only do 3D
test_input = torch.randn(1, 80, 3000)
torch.onnx.export(orig_model.model.encoder, test_input, orig_model_onnx_encoder_path, opset_version=17)

# 4D like hailo
test_input = torch.randn(1, 80, 1, 1000)
torch.onnx.export(model.model.encoder, test_input, patched_model_onnx_encoder_path, opset_version=17)

  torch.onnx.export(orig_model.model.encoder, test_input, orig_model_onnx_encoder_path, opset_version=17)


>> updated forward fn
config.max_source_positions: 500
self.conv1.stride[0]: 1
self.conv2.stride[0]: 2
--> Expected seqlen: 1000
GETTING 4D input...
--> After conv1: torch.Size([1, 384, 1, 1000])
--> After conv2: torch.Size([1, 384, 1, 500])
--> After flatten+permute: torch.Size([1, 500, 384])


  torch.onnx.export(model.model.encoder, test_input, patched_model_onnx_encoder_path, opset_version=17)
  if input_features.shape[-1] != expected_seq_length:


In [49]:
# !pip install netron

In [111]:
! netron {orig_model_onnx_encoder_path}

Serving '/tmp/original_encoder.onnx' at http://localhost:8081
^C
Stopping http://localhost:8081


In [50]:
! netron {patched_model_onnx_encoder_path}

Serving '/tmp/patched_encoder.onnx' at http://localhost:8081
^C
Stopping http://localhost:8081


In [56]:
# compare to model provided by 
! netron {hailo_reference_onnx}

Serving 'converted_models/whisper_onnx_hailo_converted/tiny/tiny-whisper-encoder-10s.onnx' at http://localhost:8081
^C
Stopping http://localhost:8081


# Export the way Hailo does it

as from 
hailo-whisper/export/export_whisper_model.py

including onnx simplify

In [None]:
ONNX_ENCODER_MODEL_FILENAME = "whisper_tiny_encoder_10s_hailo"

# export as 4D model
def export_to_onnx_hailo_style(model, output_dir):

    # start export
    os.makedirs(output_dir, exist_ok=False)
    encoder_path_base = f"{output_dir}/{ONNX_ENCODER_MODEL_FILENAME}_base.onnx"
    encoder_path_final = f"{output_dir}/{ONNX_ENCODER_MODEL_FILENAME}_final.onnx"

    # 4D
    test_input = torch.randn(1, 80, 1, 1000)  # 10s Hailo format

    # ensure inference works
    with torch.no_grad():
        encoder_output = model.model.encoder(test_input)
    print(f"Cncoder_output: {encoder_output.last_hidden_state.shape}")
    print(f"Stats: mean={encoder_output.last_hidden_state.mean():.6f}, std={encoder_output.last_hidden_state.std():.6f}")

    # Export using EXACT Hailo reference settings
    torch.onnx.export(
        model.model.encoder,
        test_input,
        encoder_path_base,
        input_names=['x.1'],           # Match Hailo reference input name
        output_names=['output_525'],   # For now valid placeholder name (renaming later to what is expected by Hailo)
        opset_version=17               # Keep opset 17 as this was used in Hailo exporter
    )

    # Apply ONNX simplification
    print("Applying ONNX simplification...")
    model_onnx = onnx.load(encoder_path_base)
    model_simp, simplify_successful = simplify(model_onnx)
    if not simplify_successful:
        raise RuntimeError("ONNX simplification failed")

    # Rename output to match Hailo reference exactly
    old_name = model_simp.graph.output[0].name
    model_simp.graph.output[0].name = "525" # somehow this is expected by Hailo

    # Update any internal references to the old output name
    for node in model_simp.graph.node:
        for i, output in enumerate(node.output):
            if output == old_name:
                print(f"   Renaming node output {old_name} → 525")
                node.output[i] = "525"

    # safe final model
    onnx.save(model_simp, encoder_path_final)
    print("Encoder exported:")
    print(f" * base onnx: {encoder_path_base}")
    print(f" * simplified onnx: {encoder_path_final}")

    return encoder_path_base, encoder_path_final


In [63]:
encoder_path_base, encoder_path_final = export_to_onnx_hailo_style(model, "hailo_compatible_models/notebook_export")

>> updated forward fn
config.max_source_positions: 500
self.conv1.stride[0]: 1
self.conv2.stride[0]: 2
--> Expected seqlen: 1000
GETTING 4D input...
--> After conv1: torch.Size([1, 384, 1, 1000])
--> After conv2: torch.Size([1, 384, 1, 500])
--> After flatten+permute: torch.Size([1, 500, 384])
Cncoder_output: torch.Size([1, 500, 384])
Stats: mean=0.010642, std=1.462520
>> updated forward fn
config.max_source_positions: 500
self.conv1.stride[0]: 1
self.conv2.stride[0]: 2
--> Expected seqlen: 1000
GETTING 4D input...
--> After conv1: torch.Size([1, 384, 1, 1000])
--> After conv2: torch.Size([1, 384, 1, 500])
--> After flatten+permute: torch.Size([1, 500, 384])
Applying ONNX simplification...


  torch.onnx.export(
  if input_features.shape[-1] != expected_seq_length:


   Renaming node output output_525 → 525
Encoder exported:
 * base onnx: hailo_compatible_models/notebook_export/whisper_tiny_encoder_10s_hailo_base
 * simplified onnx: hailo_compatible_models/notebook_export/whisper_tiny_encoder_10s_hailo_final


## Check with Netron again

In [67]:
! netron {encoder_path_base}

Serving 'hailo_compatible_models/notebook_export/whisper_tiny_encoder_10s_hailo_base' at http://localhost:8081
^C
Stopping http://localhost:8081


In [68]:
! netron {encoder_path_final}

Serving 'hailo_compatible_models/notebook_export/whisper_tiny_encoder_10s_hailo_final' at http://localhost:8081
^C
Stopping http://localhost:8081


In [65]:
! netron {hailo_reference_onnx}

Serving 'converted_models/whisper_onnx_hailo_converted/tiny/tiny-whisper-encoder-10s.onnx' at http://localhost:8081
^C
Stopping http://localhost:8081
