Skip to content

UNet2DConditionModel to ONNX with torch.onnx failed #9890

@jesenzhang

Description

@jesenzhang

Describe the bug

I want to convert a UNet to ONNX following the example. Calling the model directly works and I get a result in `ret`, but when the script reaches `torch.onnx.export` an error is reported. The code is below.

Reproduction

import os
import shutil
from pathlib import Path
import onnx
import torch
from packaging import version
from torch.onnx import export
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel,CLIPTextModelWithProjection
from diffusers import AutoencoderKL,UNet2DConditionModel

@torch.no_grad()
def export_unetxl(device='cuda'):
    """Export an SDXL-style UNet2DConditionModel to an ONNX file.

    Loads the UNet subfolder of the given checkpoint in fp16, runs one
    sanity forward pass, then calls torch.onnx.export with dynamic axes
    for batch, channel, spatial, and sequence dimensions.

    Parameters
    ----------
    device : str
        Torch device the model and sample inputs are placed on
        (default 'cuda'; fp16 export generally requires a GPU).

    Side effects: downloads the checkpoint on first use and writes
    'juggernaut-v11-unet.onnx' to the working directory.
    """
    model_opts = {'torch_dtype': torch.float16, 'use_safetensors': False}
    unet_model_dir = 'RunDiffusion/Juggernaut-XI-v11'
    model = UNet2DConditionModel.from_pretrained(
        unet_model_dir, subfolder="unet", **model_opts
    ).to(device)
    onnx_path = 'juggernaut-v11-unet.onnx'

    input_names = [
        "sample",
        "timestep",
        "encoder_hidden_states",
        "text_embeds",
        "time_ids",
    ]
    output_names = ["latent"]

    # BUG FIX: torch.onnx.export requires the per-tensor axis maps to use
    # *integer* keys (axis index -> symbolic name). The original code used
    # string keys ("0", "1", ...) — and even mixed int and str keys for
    # 'latent' — which torch.onnx.export rejects and which triggered the
    # reported failure.
    dynamic_axes = {
        "sample": {0: "2B", 1: "C", 2: "H", 3: "W"},
        "timestep": {0: "2B"},
        "encoder_hidden_states": {0: "2B", 1: "unet_hidden_sequence"},
        "latent": {0: "2B", 1: "C", 2: "H", 3: "W"},
        "text_embeds": {0: "2B", 1: "unet_text_embeds_size"},
        "time_ids": {0: "2B", 1: "unet_time_ids_size"},
    }

    # Representative SDXL shapes: 2*B for CFG (cond + uncond), 4 latent
    # channels, 128x128 latents (1024px image), 77-token prompt with
    # 2048-dim hidden states, 1280-dim pooled text embeds, 6 time ids.
    xB = 2                 # classifier-free-guidance batch multiplier
    dtype = torch.float16
    batch_size = 1
    unet_dim = 4
    latent_height = 128
    latent_width = 128
    text_maxlen = 77
    time_dim = 6
    embedding_dim = 2048
    opset_version = 19     # renamed from misleading 'opt_level': this is the ONNX opset

    def _rand(*shape):
        # Helper: fp16 random tensor on the target device.
        return torch.randn(shape, dtype=dtype, device=device)

    # Trailing dict is interpreted by torch.onnx.export as keyword
    # arguments for the model's forward(), so 'added_cond_kwargs' is
    # passed as a kwarg — matching how the model is invoked below.
    sample_input = (
        _rand(xB * batch_size, unet_dim, latent_height, latent_width),
        torch.randn((1,)).to(device=device, dtype=dtype),
        _rand(xB * batch_size, text_maxlen, embedding_dim),
        {
            'added_cond_kwargs': {
                'text_embeds': _rand(xB * batch_size, 1280),
                'time_ids': _rand(xB * batch_size, time_dim),
            }
        },
    )

    # Sanity check: the eager forward pass must succeed before tracing.
    ret = model(
        _rand(xB * batch_size, unet_dim, latent_height, latent_width),
        torch.randn((1,)).to(device=device, dtype=dtype),
        _rand(xB * batch_size, text_maxlen, embedding_dim),
        added_cond_kwargs={
            'text_embeds': _rand(xB * batch_size, 1280),
            'time_ids': _rand(xB * batch_size, time_dim),
        },
    )

    torch.onnx.export(
        model,
        sample_input,
        onnx_path,
        export_params=True,
        verbose=False,
        do_constant_folding=True,
        opset_version=opset_version,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )


export_unetxl()

Logs

diffusers.models.unets.unet_2d_condition.UNet2DConditionModel::/torch.nn.modules.normalization.GroupNorm::conv_norm_out # /home/sd/miniconda3/envs/trt/lib/python3.11/site-packages/torch/nn/functional.py:2955:0
  %input : Half(2, 320, 128, 128, strides=[5242880, 16384, 128, 1], requires_grad=0, device=cuda:0) = aten::silu(%input.1095), scope: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel::/torch.nn.modules.activation.SiLU::act # /home/sd/miniconda3/envs/trt/lib/python3.11/site-packages/torch/nn/functional.py:2380:0
  %107415 : Half(2, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cuda:0) = aten::_convolution(%input, %conv_out.weight, %conv_out.bias, %107429, %107429, %107429, %108606, %107432, %108608, %108606, %108606, %108614, %108614), scope: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel::/torch.nn.modules.conv.Conv2d::conv_out # /home/sd/miniconda3/envs/trt/lib/python3.11/site-packages/torch/nn/modules/conv.py:549:0
  return (%107415)
, {'sample': {'0': '2B', '1': 'C', '2': 'H', '3': 'W'}, 'timestep': {'0': '2B'}, 'encoder_hidden_states': {'0': '2B', '1': 'unet_hidden_sequence'}, 'latent': {0: '2B', '1': 'C', 2: 'H', 3: 'W'}, 'text_embeds': {'0': '2B', '1': 'unet_text_embeds_size'}, 'time_ids': {'0': '2B', '1': 'unet_time_ids_size'}}, ['sample', 'timestep', 'encoder_hidden_states', 'text_embeds', 'time_ids']

System Info

No LSB modules are available.
Distributor ID: Ubuntu
Description: Ubuntu 20.04.6 LTS
Release: 20.04
Codename: focal

diffusers 0.32.0.dev0
torch 2.5.0
torch2trt_dynamic 0.6.0
torchmetrics 1.5.1
torchvision 0.20.0
transformers 4.42.2
onnx 1.17.0
onnx-graphsurgeon 0.5.2
onnxconverter-common 1.14.0
onnxmltools 1.12.0
onnxruntime 1.20.0
onnxruntime_extensions 0.12.0
onnxruntime-gpu 1.19.2
onnxscript 0.1.0.dev20241107
onnxslim 0.1.35

Who can help?

@yiyixuxu @sayakpaul @DN6

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working; stale: Issues that haven't received updates

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions