In [None]:
import os
from pathlib import Path

import onnx
import torch
from diffusers.models.attention_processor import Attention
from diffusers import DiffusionPipeline, UNet2DConditionModel
from optimum.onnx.utils import _get_onnx_external_data_tensors, check_model_uses_external_data
from torch.onnx import export as onnx_export

In [3]:
# backbone = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", 
#                 subfolder='unet',
#                 torch_dtype=torch.float16, 
#                 use_safetensors=True, 
#                 variant="fp16"
#                 ).to("cuda")

In [3]:
# Build the UNet architecture from a local diffusers config.json.
# NOTE(review): from_config builds the model from the config alone; no
# checkpoint is loaded here, so parameters are freshly initialised (the
# commented-out from_pretrained cell above is the trained-weights variant).
# NOTE(review): from_config presumably does not consume `torch_dtype`; the fp16
# cast is done explicitly by .half() below — confirm for the diffusers version in use.
config, unused_kwargs, commit_hash = UNet2DConditionModel.load_config(
    "config.json", return_unused_kwargs=True, return_commit_hash=True
)
backbone = UNet2DConditionModel.from_config(
    config, torch_dtype=torch.float16, **unused_kwargs
)
# Move to the GPU first, then cast floating-point parameters/buffers to fp16.
backbone = backbone.to("cuda")
backbone = backbone.half()

In [3]:
# Display the repr of the instantiated backbone (the full module tree; long output).
backbone

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (add_time_proj): Timesteps()
  (add_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=2816, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
          (conv1): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=1280, out_features=320, bias=True)
          (norm2): GroupNorm(32, 320, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, i

In [3]:
# Dynamic-axis specs passed to torch.onnx.export, keyed by model name.
# Inner keys are ONNX input/output names; they must match the input_names /
# output_names that modelopt_export_sd selects for the same model name.
AXES_NAME = {
    "sdxl-1.0": {
        "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
        "timestep": {0: "steps"},
        "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
        "text_embeds": {0: "batch_size"},
        "time_ids": {0: "batch_size"},
        "latent": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
    },
    "sd1.5": {
        "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
        "timestep": {0: "steps"},
        "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
        "latent": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
    },
    "sd3-medium": {
        "hidden_states": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
        "timestep": {0: "steps"},
        "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
        "pooled_projections": {0: "batch_size"},
        "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
    },
}
# SDXL-Turbo is handled exactly like SDXL-1.0 everywhere in this notebook;
# copy the inner dicts so the two entries never alias each other.
AXES_NAME["sdxl-turbo"] = {name: dict(axes) for name, axes in AXES_NAME["sdxl-1.0"].items()}


def generate_dummy_inputs(sd_version, device):
    """Build all-ones fp16 tracing inputs for a diffusion backbone.

    Args:
        sd_version: "sdxl-1.0", "sdxl-turbo", "sd3-medium" or "sd1.5".
        device: torch device the tensors are allocated on.

    Returns:
        dict of keyword arguments for the backbone's forward(). For SDXL the
        extra conditioning tensors (text_embeds, time_ids) are nested under
        "added_cond_kwargs".

    Raises:
        NotImplementedError: for any other ``sd_version``.
    """

    def ones(*shape):
        # Allocate directly on the target device in fp16 instead of creating
        # an fp32 CPU tensor and then moving/casting it.
        return torch.ones(*shape, device=device, dtype=torch.float16)

    if sd_version in ("sdxl-1.0", "sdxl-turbo"):
        return {
            "sample": ones(2, 4, 128, 128),
            "timestep": ones(1),
            "encoder_hidden_states": ones(2, 77, 2048),
            "added_cond_kwargs": {
                "text_embeds": ones(2, 1280),
                "time_ids": ones(2, 6),
            },
        }
    if sd_version == "sd3-medium":
        return {
            "hidden_states": ones(2, 16, 128, 128),
            "timestep": ones(2),
            "encoder_hidden_states": ones(2, 333, 4096),
            "pooled_projections": ones(2, 2048),
        }
    if sd_version == "sd1.5":
        return {
            "sample": ones(2, 4, 64, 64),
            "timestep": ones(1),
            "encoder_hidden_states": ones(2, 16, 768),
        }
    raise NotImplementedError(f"Unsupported sd_version: {sd_version}")

In [4]:
def modelopt_export_sd(backbone, onnx_dir, model_name):
    """Export a diffusion backbone to ``<onnx_dir>/backbone.onnx``.

    The backbone is traced with the dummy inputs from ``generate_dummy_inputs``
    and exported with ``torch.onnx.export`` (opset 17, constant folding,
    dynamic axes from ``AXES_NAME``). If the exported graph keeps its weights
    as external data, they are consolidated into a single
    ``backbone.onnx_data`` file next to the model and the exporter's other
    external-data files are deleted.

    Args:
        backbone: torch module to export; its ``.device`` is used for the
            dummy inputs.
        onnx_dir: output directory (created if missing).
        model_name: "sdxl-1.0", "sdxl-turbo", "sd1.5" or "sd3-medium".

    Raises:
        NotImplementedError: if ``model_name`` is unsupported or has no
            dynamic-axes entry in ``AXES_NAME``.
    """
    os.makedirs(onnx_dir, exist_ok=True)

    if model_name in ("sdxl-1.0", "sdxl-turbo"):
        input_names = ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"]
        output_names = ["latent"]
    elif model_name == "sd1.5":
        input_names = ["sample", "timestep", "encoder_hidden_states"]
        output_names = ["latent"]
    elif model_name == "sd3-medium":
        # NOTE(review): names bind positionally to the graph inputs; torch.onnx
        # orders dict-supplied kwargs by the forward() signature, so this list
        # should follow that order — confirm against the transformer's forward().
        input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"]
        output_names = ["sample"]
    else:
        raise NotImplementedError(f"Unsupported sd_version: {model_name}")

    # Validate before allocating any (GPU) dummy tensors.
    dynamic_axes = AXES_NAME.get(model_name)
    if dynamic_axes is None:
        raise NotImplementedError(f"No dynamic axes defined in AXES_NAME for: {model_name}")

    dummy_inputs = generate_dummy_inputs(model_name, device=backbone.device)
    output = Path(onnx_dir) / "backbone.onnx"

    # Copied from Huggingface's Optimum
    with torch.inference_mode(), torch.autocast("cuda"):
        onnx_export(
            backbone,
            # A trailing dict in the args tuple is passed to forward() as kwargs.
            (dummy_inputs,),
            f=output.as_posix(),
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=17,
            export_params=True,
        )

    # Load the graph only (no weights) to find out whether weights were
    # written to external files.
    onnx_model = onnx.load(str(output), load_external_data=False)
    if not check_model_uses_external_data(onnx_model):
        return

    print('model_uses_external_data : True')
    # Record the exporter's external-data files before reloading with the
    # weights inlined in memory. A set guards against a file listed twice.
    tensor_paths = set(_get_onnx_external_data_tensors(onnx_model))
    onnx_model = onnx.load(str(output), load_external_data=True)

    data_location = output.name + "_data"
    data_path = output.parent / data_location
    # onnx.save appends to an existing external-data file, so a stale file from
    # a previous run (or one written by the exporter itself) would keep growing.
    # All weights are in memory now, so it is safe to start from scratch.
    if data_path.exists():
        data_path.unlink()

    onnx.save(
        onnx_model,
        str(output),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=data_location,
        size_threshold=1024,
    )
    # Remove the now-redundant exporter files, but never the consolidated file
    # that was just written.
    for tensor_path in tensor_paths - {data_location}:
        os.remove(output.parent / tensor_path)

In [5]:
# Export the SDXL-1.0 UNet backbone to onnx_unet/backbone.onnx.
onnx_dir = "onnx_unet"
model_name = "sdxl-1.0"
modelopt_export_sd(backbone, onnx_dir=onnx_dir, model_name=model_name)

  if dim % default_overall_up_factor != 0:
  assert hidden_states.shape[1] == self.channels
  assert hidden_states.shape[1] == self.channels
  assert hidden_states.shape[1] == self.channels
  if hidden_states.shape[0] >= 64:
  if not return_dict:


In [9]:
import sys
sys.path.insert(0, '../trt_demo')

from models import Optimizer
def optimize(onnx_graph, return_onnx=True, **kwargs):
    """Run the TensorRT-demo ``Optimizer`` passes over an ONNX graph.

    Cleanup, constant folding and shape inference always run; two further
    passes are opt-in via keyword flags. ``Optimizer.info`` is called after
    every pass with a short label, and a final cleanup produces the result.

    Args:
        onnx_graph: ONNX model handed to ``Optimizer``.
        return_onnx: forwarded to the final ``cleanup`` call.
        **kwargs: ``fuse_mha_qkv_int8`` (truthy) runs ``fuse_mha_qkv_int8_sq``;
            ``add_groupnorm`` (truthy) runs ``add_groupnorm``. Other keys are
            ignored.

    Returns:
        Whatever the final ``Optimizer.cleanup(return_onnx=...)`` returns
        (presumably an ``onnx.ModelProto`` when ``return_onnx`` is true).
    """
    optimizer = Optimizer(onnx_graph, verbose=True)
    optimizer.info(': original')

    mandatory_passes = (
        (optimizer.cleanup, ': cleanup'),
        (optimizer.fold_constants, ': fold constants'),
        (optimizer.infer_shapes, ': shape inference'),
    )
    for run_pass, label in mandatory_passes:
        run_pass()
        optimizer.info(label)

    # Opt-in passes.
    if kwargs.get('fuse_mha_qkv_int8', False):
        optimizer.fuse_mha_qkv_int8_sq()
        optimizer.info(': fuse QKV nodes')
    if kwargs.get('add_groupnorm', False):
        optimizer.add_groupnorm()
        optimizer.info(': add groupnorm')

    result = optimizer.cleanup(return_onnx=return_onnx)
    optimizer.info(': finished')
    return result

In [None]:
# Load the exported backbone (with its external weights), run the TensorRT-demo
# optimizer passes, and save the result as backbone_opt.onnx.
output = Path(f"{onnx_dir}/backbone.onnx")
onnx_model = onnx.load(str(output))
onnx_opt_graph = optimize(onnx_model, fuse_mha_qkv_int8=True, add_groupnorm=True)

# The export step already needed external data, so the optimized model is
# presumably also past protobuf's 2 GB serialization limit. Save the weights
# into a single side file instead of inlining them.
opt_output = output.with_name("backbone_opt.onnx")
opt_data_location = opt_output.name + "_data"
# onnx.save appends to an existing external-data file; drop any stale one first.
(opt_output.parent / opt_data_location).unlink(missing_ok=True)
onnx.save(
    onnx_opt_graph,
    str(opt_output),
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=opt_data_location,
    size_threshold=1024,
)