Describe the bug
I want to convert a unet to onnx using the way as example, i can get unet result as ret,but when run into torch.onnx.export , an error reported, the code is here.
Reproduction
import os
import shutil
from pathlib import Path
import onnx
import torch
from packaging import version
from torch.onnx import export
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel,CLIPTextModelWithProjection
from diffusers import AutoencoderKL,UNet2DConditionModel
@torch.no_grad()
def export_unetxl(device='cuda'):
model_opts = {'torch_dtype': torch.float16,'use_safetensors':False}
unet_model_dir='RunDiffusion/Juggernaut-XI-v11'
model= UNet2DConditionModel.from_pretrained(unet_model_dir,subfolder="unet", **model_opts).to(device)
torch_inference=''
# model = optimize_checkpoint(model, torch_inference)
onnx_path = 'juggernaut-v11-unet.onnx'
input_names= [
"sample",
"timestep",
"encoder_hidden_states",
"text_embeds",
"time_ids"
]
output_names= [
"latent"
]
dynamic_axes={
"sample": {
"0": "2B",
"1": "C",
"2": "H",
"3": "W"
},
"timestep": { "0": "2B" },
"encoder_hidden_states": {
"0": "2B",
"1": "unet_hidden_sequence"
},
'latent': {0: '2B',"1": "C", 2: 'H', 3: 'W'},
"text_embeds": {
"0": "2B",
"1": "unet_text_embeds_size"
},
"time_ids": {
"0": "2B",
"1": "unet_time_ids_size"
}
}
shape_dict= {
"sample": [
2,
4,
128,
128
],
"encoder_hidden_states": [
2,
77,
2048
],
"latent": [
2,
4,
128,
128
],
"text_embeds": [
2,
1280
],
"time_ids": [
2,
6
]
}
xB=2
dtype = torch.float16
batch_size=1
unet_dim=4
latent_height=128
latent_width=128
text_maxlen=77
time_dim=6
embedding_dim=2048
opt_level=19
sample_input=(
torch.randn((xB*batch_size,unet_dim, latent_height, latent_width), dtype=dtype, device=device),
torch.randn((1,)).to(device=device, dtype=dtype),
torch.randn((xB*batch_size,text_maxlen,embedding_dim), dtype=dtype, device=device),
# torch.rand((xB* batch_size, 1280), dtype=dtype),
# torch.rand((xB* batch_size, time_dim), dtype=dtype)
{
'added_cond_kwargs': {
'text_embeds': torch.randn((xB*batch_size, 1280), dtype=dtype, device=device),
'time_ids' : torch.randn((xB*batch_size, time_dim), dtype=dtype, device=device)
}
}
)
ret = model(torch.randn((xB*batch_size,unet_dim, latent_height, latent_width), dtype=dtype, device=device),
torch.randn((1,)).to(device=device, dtype=dtype),
torch.randn((xB*batch_size,text_maxlen,embedding_dim), dtype=dtype, device=device),
added_cond_kwargs= {
'text_embeds': torch.randn((xB*batch_size, 1280), dtype=dtype, device=device),
'time_ids' : torch.randn((xB*batch_size, time_dim), dtype=dtype, device=device)
}
)
# with torch.inference_mode():
torch.onnx.export(model,
sample_input,
onnx_path,
export_params=True,
verbose = False,
do_constant_folding=True,
opset_version=opt_level,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
export_unetxl()
Logs
diffusers.models.unets.unet_2d_condition.UNet2DConditionModel::/torch.nn.modules.normalization.GroupNorm::conv_norm_out # /home/sd/miniconda3/envs/trt/lib/python3.11/site-packages/torch/nn/functional.py:2955:0
%input : Half(2, 320, 128, 128, strides=[5242880, 16384, 128, 1], requires_grad=0, device=cuda:0) = aten::silu(%input.1095), scope: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel::/torch.nn.modules.activation.SiLU::act # /home/sd/miniconda3/envs/trt/lib/python3.11/site-packages/torch/nn/functional.py:2380:0
%107415 : Half(2, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cuda:0) = aten::_convolution(%input, %conv_out.weight, %conv_out.bias, %107429, %107429, %107429, %108606, %107432, %108608, %108606, %108606, %108614, %108614), scope: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel::/torch.nn.modules.conv.Conv2d::conv_out # /home/sd/miniconda3/envs/trt/lib/python3.11/site-packages/torch/nn/modules/conv.py:549:0
return (%107415)
, {'sample': {'0': '2B', '1': 'C', '2': 'H', '3': 'W'}, 'timestep': {'0': '2B'}, 'encoder_hidden_states': {'0': '2B', '1': 'unet_hidden_sequence'}, 'latent': {0: '2B', '1': 'C', 2: 'H', 3: 'W'}, 'text_embeds': {'0': '2B', '1': 'unet_text_embeds_size'}, 'time_ids': {'0': '2B', '1': 'unet_time_ids_size'}}, ['sample', 'timestep', 'encoder_hidden_states', 'text_embeds', 'time_ids']
System Info
No LSB modules are available.
Distributor ID: Ubuntu
Description: Ubuntu 20.04.6 LTS
Release: 20.04
Codename: focal
diffusers 0.32.0.dev0
torch 2.5.0
torch2trt_dynamic 0.6.0
torchmetrics 1.5.1
torchvision 0.20.0
transformers 4.42.2
onnx 1.17.0
onnx-graphsurgeon 0.5.2
onnxconverter-common 1.14.0
onnxmltools 1.12.0
onnxruntime 1.20.0
onnxruntime_extensions 0.12.0
onnxruntime-gpu 1.19.2
onnxscript 0.1.0.dev20241107
onnxslim 0.1.35
Who can help?
@yiyixuxu @sayakpaul @DN6
Describe the bug
I want to convert a unet to onnx using the way as example, i can get unet result as ret,but when run into torch.onnx.export , an error reported, the code is here.
Reproduction
import os
import shutil
from pathlib import Path
import onnx
import torch
from packaging import version
from torch.onnx import export
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel,CLIPTextModelWithProjection
from diffusers import AutoencoderKL,UNet2DConditionModel
@torch.no_grad()
def export_unetxl(device='cuda'):
model_opts = {'torch_dtype': torch.float16,'use_safetensors':False}
unet_model_dir='RunDiffusion/Juggernaut-XI-v11'
model= UNet2DConditionModel.from_pretrained(unet_model_dir,subfolder="unet", **model_opts).to(device)
torch_inference=''
# model = optimize_checkpoint(model, torch_inference)
onnx_path = 'juggernaut-v11-unet.onnx'
export_unetxl()
Logs
System Info
No LSB modules are available.
Distributor ID: Ubuntu
Description: Ubuntu 20.04.6 LTS
Release: 20.04
Codename: focal
diffusers 0.32.0.dev0
torch 2.5.0
torch2trt_dynamic 0.6.0
torchmetrics 1.5.1
torchvision 0.20.0
transformers 4.42.2
onnx 1.17.0
onnx-graphsurgeon 0.5.2
onnxconverter-common 1.14.0
onnxmltools 1.12.0
onnxruntime 1.20.0
onnxruntime_extensions 0.12.0
onnxruntime-gpu 1.19.2
onnxscript 0.1.0.dev20241107
onnxslim 0.1.35
Who can help?
@yiyixuxu @sayakpaul @DN6