Describe the issue
I have exported a custom UNet model with BatchNorm2d layers from PyTorch. This produced an ONNX file with the BatchNormalization nodes in training mode (it seems impossible to export this model in eval mode; there are many documented issues with the exporter). In onnxruntime all is fine and the output matches that of the original model. To use the model externally (TensorRT) I need the BatchNormalization nodes to be in eval mode. It was suggested to simply flip this flag in the ONNX file and remove the two additional outputs. This results in a valid ONNX model, but it gives different output to the original PyTorch and ONNX models. I have checked all the values of the two models in Netron and they are identical apart from training_mode.
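The same check can also be done without Netron. A minimal sketch that dumps each BatchNormalization node's training_mode and output list ("mymodel.onnx" is the file exported by the script below):

import onnx

onnx_model = onnx.load("mymodel.onnx")
for node in onnx_model.graph.node:
    if node.op_type == "BatchNormalization":
        # Gather the node's attributes (training_mode, epsilon, momentum) by name
        attrs = {a.name: onnx.helper.get_attribute_value(a) for a in node.attribute}
        print(node.name, "training_mode =", attrs.get("training_mode"),
              "outputs =", list(node.output))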
To reproduce
The following script reproduces the error:
import torch
from PIL import Image
from torchvision import transforms as T
import torchvision
import onnx
def get_transform(image):
    transform = T.ToTensor()
    image = transform(image)
    return image
def disable_running_stats(model):
    """Disable track_running_stats to restore original training behavior."""
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):  # Adjust for InstanceNorm1d/3d if needed
            module.track_running_stats = False
# Sets BatchNorm layers to eval (not train) and deletes the 2 extra outputs
def remove_bn_extra_outputs(model_path, output_path):
    # Load the ONNX model
    onnx_model = onnx.load(model_path)
    graph = onnx_model.graph
    for node in graph.node:
        if node.op_type == "BatchNormalization":
            # Find the 'training_mode' attribute and set it to 0 (False)
            for attribute in node.attribute:
                if attribute.name == 'training_mode' and attribute.i == 1:
                    attribute.i = 0
            # Drop the two extra (running_mean/running_var) outputs, keeping only Y
            node.output.remove(node.output[1])
            node.output.remove(node.output[1])
    # Check the model for validity and save
    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, output_path)
    print(f"Modified model saved to {output_path}")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Load image data
pic = Image.open('snapshot10X-crop.jpg')
pic = get_transform(pic).unsqueeze(0).to(device)
# Load PyTorch model
model_gc = torch.load("gc_traced.pt", weights_only=False)
model_gc.to(device)
# do inference in PyTorch (This is correct)
out = model_gc(pic)
# Calculate statistics of PyTorch output
max_value = torch.max(out[0,0,:,:])
min_value = torch.min(out[0,0,:,:])
mean_value = torch.mean(out[0,0,:,:])
print("Out0 Max,Min,Mean:")
print(max_value)
print(min_value)
print(mean_value)
max_value = torch.max(out[0,1,:,:])
min_value = torch.min(out[0,1,:,:])
mean_value = torch.mean(out[0,1,:,:])
print("Out1 Max,Min,Mean:")
print(max_value)
print(min_value)
print(mean_value)
#Save results as pngs (these look correct)
torchvision.io.write_png((out[0,0:1,:,:]*255).cpu().type(torch.uint8),"result0.png")
torchvision.io.write_png((out[0,1:2,:,:]*255).cpu().type(torch.uint8),"result1.png")
# Convert to onnx
from torch.onnx import export
#export(model_gc, pic, "mymodel.onnx", opset_version=18, input_names=['image'] ,output_names=['output'], dynamo=True) # New converter - doesn't work for this model
export(model_gc, pic, "mymodel.onnx", opset_version=14, input_names=['image'], output_names=['output']) # Legacy converter (works but can only output onnx in "Train" mode!)
remove_bn_extra_outputs("mymodel.onnx","mymodel_fixed.onnx")
#
# Now load, and do inference with original ONNX model (has BatchNorm nodes in training mode)
#
import onnxruntime
session = onnxruntime.InferenceSession("mymodel.onnx")
# Check if the model has been loaded successfully
if session is None:
    raise ValueError("Failed to load the model")
# Load image data
pic = Image.open('snapshot10X-crop.jpg')
pic = get_transform(pic).unsqueeze(0).to(device)
result = session.run(["output"], {"image": pic.cpu().numpy()})
import numpy as np
out = result[0]
max_value = np.max(out[0,0,:,:])
min_value = np.min(out[0,0,:,:])
mean_value = np.mean(out[0,0,:,:])
print("Original ONNX Out0 Max,Min,Mean:")
print(max_value)
print(min_value)
print(mean_value)
max_value = np.max(out[0,1,:,:])
min_value = np.min(out[0,1,:,:])
mean_value = np.mean(out[0,1,:,:])
print("Original ONNX Out1 Max,Min,Mean:")
print(max_value)
print(min_value)
print(mean_value)
torchvision.io.write_png(torch.from_numpy((out[0,0:1,:,:]*255).astype(np.uint8)),"ONNXresult0.png")
torchvision.io.write_png(torch.from_numpy((out[0,1:2,:,:]*255).astype(np.uint8)),"ONNXresult1.png")
#
# Now load, and do inference with fixed ONNX model (has BatchNorm nodes in eval mode)
#
session2 = onnxruntime.InferenceSession("mymodel_fixed.onnx")
# Check if the model has been loaded successfully
if session2 is None:
    raise ValueError("Failed to load the model")
# Load image data
pic = Image.open('snapshot10X-crop.jpg')
pic = get_transform(pic).unsqueeze(0).to(device)
result = session2.run(["output"], {"image": pic.cpu().numpy()})
import numpy as np
out = result[0]
max_value = np.max(out[0,0,:,:])
min_value = np.min(out[0,0,:,:])
mean_value = np.mean(out[0,0,:,:])
print("Fixed ONNX Out0 Max,Min,Mean:")
print(max_value)
print(min_value)
print(mean_value)
max_value = np.max(out[0,1,:,:])
min_value = np.min(out[0,1,:,:])
mean_value = np.mean(out[0,1,:,:])
print("Fixed ONNX Out1 Max,Min,Mean:")
print(max_value)
print(min_value)
print(mean_value)
The output is:
Out0 Max,Min,Mean:
tensor(1., device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0.0086, device='cuda:0', grad_fn=<MinBackward1>)
tensor(0.8373, device='cuda:0', grad_fn=<MeanBackward0>)
Out1 Max,Min,Mean:
tensor(0.9914, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0., device='cuda:0', grad_fn=<MinBackward1>)
tensor(0.1627, device='cuda:0', grad_fn=<MeanBackward0>)
C:\Users\Derek Magee\Downloads\GC_Expt\drm_eval3.py:86: DeprecationWarning: You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.9, the new torch.export-based ONNX exporter will be the default. To switch now, set dynamo=True in torch.onnx.export. This new exporter supports features like exporting LLMs with DynamicCache. We encourage you to try it and share feedback to help improve the experience. Learn more about the new export logic: https://pytorch.org/docs/stable/onnx_dynamo.html. For exporting control flow: https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html.
export(model_gc, pic, "mymodel.onnx", opset_version=14, input_names=['image'] ,output_names=['output']) # Legacy converter (works but model flawed - fine with torchscript model though!)
c:\Python311\Lib\site-packages\torch\onnx\utils.py:807: UserWarning: no signature found for builtin <built-in method __call__ of PyCapsule object at 0x00000168E79D9A10>, skipping _decide_input_format
warnings.warn(f"{e}, skipping _decide_input_format")
c:\Python311\Lib\site-packages\torch\onnx\symbolic_helper.py:1460: UserWarning: ONNX export mode is set to TrainingMode.EVAL, but operator 'batch_norm' is set to train=True. Exporting with train=True.
warnings.warn(
Modified model saved to mymodel_fixed.onnx
Original ONNX Out0 Max,Min,Mean:
1.0
0.008597248
0.8372893
Original ONNX Out1 Max,Min,Mean:
0.9914028
0.0
0.16271073
Fixed ONNX Out0 Max,Min,Mean:
1.0
0.11693505
0.9321432
Fixed ONNX Out1 Max,Min,Mean:
0.8830649
1.151566e-19
0.06785674
As you can see, the initial ONNX model (with BatchNorm in train mode) matches the original PyTorch model, but the "fixed" model does not. The only difference between the two ONNX models is that training_mode is set to 0 on the BatchNormalization nodes and the two extra outputs are removed (done in remove_bn_extra_outputs).
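For what it's worth, this matches the ONNX spec: with training_mode=1 a BatchNormalization node normalises with the mean/variance of the current batch, while with training_mode=0 it normalises with the stored input_mean/input_var tensors. Flipping the flag therefore only preserves the output if the stored statistics happen to match the batch statistics. A toy NumPy sketch of the two paths (illustrative values only, not taken from my model):

import numpy as np

x = np.random.randn(1, 3, 8, 8).astype(np.float32)
scale = np.ones((1, 3, 1, 1), np.float32)
bias = np.zeros((1, 3, 1, 1), np.float32)
eps = 1e-5

# training_mode=1: statistics come from the current batch
batch_mean = x.mean(axis=(0, 2, 3), keepdims=True)
batch_var = x.var(axis=(0, 2, 3), keepdims=True)
y_train = scale * (x - batch_mean) / np.sqrt(batch_var + eps) + bias

# training_mode=0: statistics come from the stored input_mean/input_var
# (here PyTorch's initial values of 0 and 1, i.e. never-updated buffers)
stored_mean = np.zeros((1, 3, 1, 1), np.float32)
stored_var = np.ones((1, 3, 1, 1), np.float32)
y_eval = scale * (x - stored_mean) / np.sqrt(stored_var + eps) + bias

print("max |train - eval|:", np.abs(y_train - y_eval).max())  # nonzero

If the model was trained with track_running_stats disabled (which disable_running_stats above suggests), the stored statistics may never have been updated; a check on the actual exported file is sketched after the download links below.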
Model available at (note: this is a TorchScript model for ease of use):
https://drive.google.com/file/d/1R_fIsJs7C8uBL-RqW6aBF5Qao0a1LUFK/view?usp=sharing
Data image at:
https://drive.google.com/file/d/1aw75jZFNH98NJbTs2v9huVECGLEyMdMc/view?usp=sharing
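To check the stored statistics directly, the running_mean/running_var tensors feeding each BatchNormalization node can be read back with onnx.numpy_helper. A sketch, assuming the statistics are plain graph initializers at inputs 3 and 4 (which matches what I see in Netron):

import onnx
from onnx import numpy_helper

onnx_model = onnx.load("mymodel_fixed.onnx")
# Map initializer names to numpy arrays
inits = {t.name: numpy_helper.to_array(t) for t in onnx_model.graph.initializer}
for node in onnx_model.graph.node:
    if node.op_type == "BatchNormalization":
        mean = inits.get(node.input[3])  # input_mean
        var = inits.get(node.input[4])   # input_var
        if mean is not None and var is not None:
            print(node.name,
                  "mean min/max:", mean.min(), mean.max(),
                  "var min/max:", var.min(), var.max())

If every mean is exactly 0 and every variance exactly 1 (PyTorch's initial buffer values), the running statistics were never updated during training, and an eval-mode version of this graph would not be expected to reproduce the train-mode output.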
I'm using:
PyTorch: 2.8.0+cu128
Python 3.11
Windows 10
Urgency
No response
Platform
Windows
OS Version
10
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.20.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
12.8