In [4]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np

def insert_nhwc_to_nchw_transpose(
    model_path="bestfp16.onnx",
    output_path="bestfp16_nhwc.onnx",
    input_name="images"
):
    """Rewrite an ONNX model so that it accepts NHWC float32 images.

    The graph input called ``input_name`` is re-declared as a float32 tensor of
    shape ``[batch, height, width, 3]``. A ``Transpose`` with perm
    ``[0, 3, 1, 2]`` is placed in front of every node that consumed that input,
    so the rest of the graph still receives NCHW data.

    If the input was originally declared with an element type other than
    float32 (e.g. float16), a ``Cast`` back to that type is appended after the
    transpose. Without it, the float32 tensor would be fed straight into nodes
    that were built for the original type.

    Args:
        model_path: Path of the ONNX file to read.
        output_path: Path the modified ONNX file is written to.
        input_name: Name of the graph input to convert.

    Raises:
        AssertionError: If the graph has no input called ``input_name``.
    """
    model = onnx.load(model_path)
    graph = gs.import_onnx(model)

    # Find input
    input_var = next((inp for inp in graph.inputs if inp.name == input_name), None)
    assert input_var is not None, f"❌ Input '{input_name}' not found."

    # Capture the declared element type before it is overwritten below; it tells
    # us what the downstream nodes were built to consume.
    original_dtype = input_var.dtype

    # ⚠️ Use dynamic shape: [N, H, W, 3]
    input_var.shape = ["batch", "height", "width", 3]
    input_var.dtype = np.float32

    # Output of transpose: [N, 3, H, W]
    transposed = gs.Variable(
        name="pixel_values_nchw",
        dtype=np.float32,
        shape=["batch", 3, "height", "width"],
    )

    # Insert NHWC → NCHW Transpose
    transpose_node = gs.Node(
        op="Transpose",
        name="Transpose_NHWC_to_NCHW",
        inputs=[input_var],
        outputs=[transposed],
        attrs={"perm": [0, 3, 1, 2]},  # NHWC to NCHW
    )

    new_nodes = [transpose_node]
    replacement = transposed  # tensor the former consumers of the input will read

    if original_dtype is not None:
        # Normalise to an ONNX TensorProto.DataType integer so the comparison and
        # the Cast "to" attribute work whether the dtype is a numpy dtype or
        # already an ONNX enum value.
        if isinstance(original_dtype, int):
            original_elem_type = original_dtype
        else:
            original_elem_type = onnx.helper.np_dtype_to_tensor_dtype(np.dtype(original_dtype))

        if original_elem_type != onnx.TensorProto.FLOAT:
            # The graph expects a non-float32 input: cast the transposed tensor
            # back so the types line up again.
            casted = gs.Variable(
                name="pixel_values_nchw_cast",
                dtype=original_dtype,
                shape=["batch", 3, "height", "width"],
            )
            cast_node = gs.Node(
                op="Cast",
                name="Cast_NCHW_to_model_dtype",
                inputs=[transposed],
                outputs=[casted],
                attrs={"to": original_elem_type},
            )
            new_nodes.append(cast_node)
            replacement = casted

    # Replace input usage. This runs before the new nodes are added to the graph,
    # so the Transpose keeps reading the real graph input.
    for node in graph.nodes:
        node.inputs = [replacement if inp is input_var else inp for inp in node.inputs]

    # Order is not critical here because toposort() runs below.
    for new_node in reversed(new_nodes):
        graph.nodes.insert(0, new_node)

    graph.cleanup().toposort()
    onnx.save(gs.export_onnx(graph), output_path)

    print(f"✅ Saved NHWC-compatible model to: {output_path}")

if __name__ == "__main__":
    insert_nhwc_to_nchw_transpose()


✅ Saved NHWC-compatible model to: bestfp16_nhwc.onnx


In [5]:
import onnx
import onnx_graphsurgeon as gs
import numpy as np
# Round-trip the NHWC model through GraphSurgeon to get a freshly exported ModelProto.
model_path = "bestfp16_nhwc.onnx"
model = onnx.load(model_path)
graph = gs.import_onnx(model)
exported_model = gs.export_onnx(graph)

# Pin the IR version explicitly. The opset imports are left exactly as exported.
exported_model.ir_version = 10

# Validate and report the outcome; a validation failure is printed, not re-raised.
try:
    onnx.checker.check_model(exported_model)
except onnx.checker.ValidationError as validation_error:
    print("❌ Model failed ONNX validation:", validation_error, sep="\n")
else:
    print("✅ Model passed ONNX validation.")


# Summarise the versions recorded in the exported model.
print("IR version:", exported_model.ir_version)
for opset in exported_model.opset_import:
    print("Opset version for domain '{}': {}".format(opset.domain, opset.version))

# Overwrite the file on disk with the re-exported model (this runs regardless of
# the validation result above).
onnx.save(exported_model, "bestfp16_nhwc.onnx")


✅ Model passed ONNX validation.
IR version: 10
Opset version for domain '': 20
