In [2]:
from onnxmltools.utils.float16_converter import convert_float_to_float16
import onnx



In [3]:
# Load the ONNX model from disk, convert it to FP16 using the converter's
# default settings (no options are passed), and save the result next to the
# original file with an `_fp16` suffix.
model = onnx.load("models/nemo-parakeet_tdt_ctc_110m.onnx")
model_fp16 = convert_float_to_float16(model)
onnx.save_model(model_fp16, "models/nemo-parakeet_tdt_ctc_110m_fp16.onnx")



In [4]:
# Load the ONNX model from disk, convert it to FP16 using the converter's
# default settings (no options are passed), and save it with an `_fp16` suffix.
# NOTE: the convert_model_with_config(...) call later in this notebook writes
# to the same output path, models/glados_fp16.onnx.
model = onnx.load("models/glados.onnx")
model_fp16 = convert_float_to_float16(model)
onnx.save_model(model_fp16, "models/glados_fp16.onnx")



In [6]:
import shutil
import os

# First do the model conversion
def convert_model_with_config(original_model_path, fp16_model_path,
                              op_block_list=None, keep_io_types=True):
    """
    Convert an ONNX model to FP16 and copy its sidecar JSON config next to the result.

    The config is looked up at ``original_model_path + '.json'`` and, when present,
    copied to ``fp16_model_path + '.json'``. A missing config is not an error; a
    warning is printed instead.

    Args:
        original_model_path: Path to the source ONNX model.
        fp16_model_path: Path the FP16 model is written to. Missing parent
            directories are created.
        op_block_list: Iterable of ONNX op types that must not be converted.
            Defaults to ``RandomNormalLike``, ``Range`` and ``Constant``.
        keep_io_types: Forwarded to the converter. The default ``True`` asks it
            to leave the model's input/output types unchanged.
    """
    # Imported lazily so that merely defining this function does not require
    # the conversion packages to be installed.
    import onnx
    from onnxconverter_common import float16

    if op_block_list is None:
        op_block_list = ('RandomNormalLike', 'Range', 'Constant')
    # NOTE(review): an explicit op_block_list is understood to replace the
    # converter's built-in default list rather than extend it - confirm this
    # is intended for the models converted here.
    excluded_ops = set(op_block_list)

    model = onnx.load(original_model_path)
    model_fp16 = float16.convert_float_to_float16(
        model,
        keep_io_types=keep_io_types,
        op_block_list=excluded_ops
    )

    # Make sure the destination directory exists before saving; a bare file
    # name has an empty dirname and needs no directory.
    out_dir = os.path.dirname(fp16_model_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save the FP16 model
    onnx.save_model(model_fp16, fp16_model_path)

    # Copy the config file
    original_config = original_model_path + '.json'
    new_config = fp16_model_path + '.json'

    if os.path.exists(original_config):
        shutil.copy2(original_config, new_config)
        print(f"Config file copied to {new_config}")
    else:
        print(f"Warning: Original config file {original_config} not found")

# Convert models/glados.onnx to FP16 and, if models/glados.onnx.json exists,
# copy it to models/glados_fp16.onnx.json alongside the converted model.
convert_model_with_config(
    "models/glados.onnx",
    "models/glados_fp16.onnx"
)

Config file copied to models/glados_fp16.onnx.json


In [3]:
# Load the ONNX model from disk, convert it to FP16 using the converter's
# default settings (no options are passed), and save the result next to the
# original file with an `_fp16` suffix.
model = onnx.load("models/phomenizer_en.onnx")
model_fp16 = convert_float_to_float16(model)
onnx.save_model(model_fp16, "models/phomenizer_en_fp16.onnx")





In [15]:
import onnx
import warnings
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.quant_utils import QuantizationMode

def get_audio_input_name(model_path):
    """
    Get the name of the audio input node from the ONNX model.

    The audio input is taken to be the first graph input that is not also an
    initializer. Some models list their weights (initializers) among the graph
    inputs; skipping those avoids returning a weight name instead of a real
    runtime input.

    Args:
        model_path: Path to the ONNX model
    Returns:
        str: Name of the audio input node
    Raises:
        ValueError: If the graph declares no inputs other than initializers.
    """
    model = onnx.load(model_path)
    initializer_names = {init.name for init in model.graph.initializer}

    # Usually the first (non-initializer) input is the audio input
    for graph_input in model.graph.input:
        if graph_input.name not in initializer_names:
            return graph_input.name

    raise ValueError(f"Model {model_path} has no non-initializer graph inputs")

def quantize_model(input_path, output_dir):
    """
    Quantize an ONNX model to different precision formats while keeping audio input as FP32.

    Three files are written to ``output_dir`` (created if missing), named after
    the input model: ``<name>_fp16.onnx``, ``<name>_int8.onnx`` and
    ``<name>_int4.onnx``. INT4 is attempted on a best-effort basis: any failure
    is reported on stdout and does not abort the function.

    The model's input/output types are preserved in the FP16 variant through
    ``keep_io_types=True``. Dynamic quantization is restricted to ``Conv`` and
    ``MatMul`` ops.

    Args:
        input_path: Path to the input ONNX model
        output_dir: Directory to save quantized models
    """
    # Local import: the import block of this cell does not bring in `os`.
    import os

    # Load the original model
    model = onnx.load(input_path)

    # Strip the directory and only a trailing '.onnx' (not every occurrence).
    model_name = os.path.basename(input_path)
    if model_name.endswith('.onnx'):
        model_name = model_name[:-len('.onnx')]

    # The save calls below do not create directories themselves.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # FP16 conversion with preserved FP32 input
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, module="onnxconverter_common.float16")
        # keep_io_types is what keeps the graph inputs/outputs in their
        # original type. No op_block_list is passed: that parameter takes op
        # types, so an input tensor name in it matches nothing, and supplying
        # it replaces the converter's default list of ops that are left in FP32.
        model_fp16 = float16.convert_float_to_float16(
            model,
            keep_io_types=True,  # Keep input/output types as is
        )
        fp16_path = f"{output_dir}/{model_name}_fp16.onnx"
        onnx.save_model(model_fp16, fp16_path)
        print(f"Saved FP16 model to {fp16_path}")

    # INT8 quantization
    # nodes_to_exclude is not passed: it takes node names, and a graph input
    # tensor name is not a node name, so it never excluded anything.
    int8_path = f"{output_dir}/{model_name}_int8.onnx"
    quantize_dynamic(
        input_path,
        int8_path,
        weight_type=QuantType.QInt8,
        op_types_to_quantize=['Conv', 'MatMul'],  # Quantize only specific operations
    )
    print(f"Saved INT8 model to {int8_path}")

    # INT4 quantization (experimental)
    try:
        int4_path = f"{output_dir}/{model_name}_int4.onnx"
        quantize_dynamic(
            input_path,
            int4_path,
            # Looked up inside the try so a runtime without QInt4 is reported
            # like any other INT4 failure.
            weight_type=QuantType.QInt4,
            op_types_to_quantize=['Conv', 'MatMul'],
        )
        print(f"Saved INT4 model to {int4_path}")
    except Exception as e:
        print(f"INT4 quantization failed: {str(e)}")
        print("Note: INT4 quantization is experimental and may not be supported for this model")

def verify_model_inputs(model_path):
    """
    Print the element type code of every graph input of an ONNX model.

    Nothing is checked or returned; the codes are printed for manual
    inspection (per the original note, code 1 corresponds to FLOAT).

    Args:
        model_path: Path to the ONNX model
    """
    loaded = onnx.load(model_path)
    print(f"\nVerifying model: {model_path}")

    for graph_input in loaded.graph.input:
        # Numeric ONNX tensor element type code of this input
        elem_type = graph_input.type.tensor_type.elem_type
        print(f"Input '{graph_input.name}' type: {elem_type}")

if __name__ == "__main__":
    input_model = "models/phomenizer_en.onnx"
    output_dir = "models/quantized"

    quantize_model(input_model, output_dir)

    # Report the input types of each variant quantize_model may have written.
    # A variant that cannot be loaded (e.g. INT4 was skipped) is reported and
    # does not stop the remaining checks.
    model_stem = input_model.split('/')[-1].replace('.onnx', '')
    for model_type in ('fp16', 'int8', 'int4'):
        quantized_path = f"{output_dir}/{model_stem}_{model_type}.onnx"
        try:
            verify_model_inputs(quantized_path)
        except Exception as e:
            print(f"Could not verify {model_type} model: {str(e)}")



Saved FP16 model to models/quantized/phomenizer_en_fp16.onnx




Saved INT8 model to models/quantized/phomenizer_en_int8.onnx
Saved INT4 model to models/quantized/phomenizer_en_int4.onnx

Verifying model: models/quantized/phomenizer_en_fp16.onnx
Input 'modelInput' type: 7

Verifying model: models/quantized/phomenizer_en_int8.onnx
Input 'modelInput' type: 7

Verifying model: models/quantized/phomenizer_en_int4.onnx
Input 'modelInput' type: 7


In [16]:

# Run quantize_model on models/glados.onnx, writing the variants to
# models/quantized. Rebinds the notebook-level input_model / output_dir names.
input_model = "models/glados.onnx"
output_dir = "models/quantized"

quantize_model(input_model, output_dir)

Saved FP16 model to models/quantized/glados_fp16.onnx




Saved INT8 model to models/quantized/glados_int8.onnx




Saved INT4 model to models/quantized/glados_int4.onnx


In [17]:



# Run quantize_model on models/nemo-parakeet_tdt_ctc_110m.onnx, writing the
# variants to models/quantized. Rebinds the notebook-level input_model /
# output_dir names.
input_model = "models/nemo-parakeet_tdt_ctc_110m.onnx"
output_dir = "models/quantized"

quantize_model(input_model, output_dir)

Saved FP16 model to models/quantized/nemo-parakeet_tdt_ctc_110m_fp16.onnx




Saved INT8 model to models/quantized/nemo-parakeet_tdt_ctc_110m_int8.onnx




Saved INT4 model to models/quantized/nemo-parakeet_tdt_ctc_110m_int4.onnx
