# DDColor ONNX Export - Dynamic Batch Size

Export a DDColor model to ONNX with a dynamic batch dimension (spatial dimensions are fixed at 512x512).

## 1. Install Dependencies

In [None]:
# Uncomment if the packages are not installed.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# %pip install onnxsim onnxruntime

## 2. Import Libraries

In [1]:
import torch
import onnx
import onnxsim
# ONNX model load/save helpers and ONNX's built-in shape inference pass.
from onnx import load_model, save_model, shape_inference
# ONNX Runtime's symbolic shape inference tool (used by the optimization step).
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

import sys

# Extend the import search path before importing `basicsr` below.
# NOTE(review): this is an absolute, machine-specific path, while MODEL_PATH in the
# configuration cell is relative ("./DDColor/...") — confirm both refer to the same checkout.
sys.path.append('/workspace/DDColor')
from basicsr.archs.ddcolor_arch import DDColor

# Print library versions so the export environment is recorded in the notebook output.
print(f"PyTorch version: {torch.__version__}")
print(f"ONNX version: {onnx.__version__}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.6.0a0+df5bbc09d1.nv24.12
ONNX version: 1.16.1


## 3. Configuration

In [2]:
# --- Model -----------------------------------------------------------------
# Checkpoint of the trained DDColor model.
MODEL_PATH = "./DDColor/pretrain/ddcolor_paper_tiny.pth"
# "tiny" or "large".
MODEL_SIZE = "tiny"
# "MultiScaleColorDecoder" or "SingleColorDecoder".
DECODER_TYPE = "MultiScaleColorDecoder"

# --- Export ----------------------------------------------------------------
# Destination of the exported ONNX file.
EXPORT_PATH = "./exported/model.onnx"
# ONNX opset version passed to the exporter.
OPSET_VERSION = 12

# Fixed spatial dimensions (DO NOT CHANGE); only the batch axis is dynamic.
INPUT_SIZE = [512, 512]

# Echo the effective settings, one per line.
print(
    "Configuration:",
    f"  Model: {MODEL_PATH}",
    f"  Size: {MODEL_SIZE}",
    f"  Decoder: {DECODER_TYPE}",
    f"  Output: {EXPORT_PATH}",
    f"  Spatial dims: {INPUT_SIZE[0]}x{INPUT_SIZE[1]} (fixed)",
    "  Batch size: dynamic",
    sep="\n",
)

Configuration:
  Model: ./DDColor/pretrain/ddcolor_paper_tiny.pth
  Size: tiny
  Decoder: MultiScaleColorDecoder
  Output: ./exported/model.onnx
  Spatial dims: 512x512 (fixed)
  Batch size: dynamic


## 4. Create and Export Model

In [3]:
def create_and_export_model():
    """Build DDColor, load its checkpoint, and export it to ONNX with a dynamic batch axis.

    Reads the notebook-level settings MODEL_PATH, MODEL_SIZE, DECODER_TYPE,
    INPUT_SIZE, EXPORT_PATH and OPSET_VERSION. The model is created and traced
    on CPU using a single random image; only axis 0 of "input" and "output" is
    declared dynamic, so the spatial dimensions stay fixed at INPUT_SIZE.

    If MODEL_PATH does not exist, a warning is printed and the export proceeds
    with randomly initialised weights.

    Returns:
        None. The ONNX file is written to EXPORT_PATH (its parent directory is
        created when missing).
    """
    import os  # local import: only needed to prepare the output directory

    device = torch.device("cpu")
    # Any MODEL_SIZE other than "tiny" selects the large encoder.
    encoder_name = "convnext-t" if MODEL_SIZE == "tiny" else "convnext-l"

    # Create model with fixed spatial input size
    model = DDColor(
        encoder_name=encoder_name,
        decoder_name=DECODER_TYPE,
        input_size=INPUT_SIZE,
        num_output_channels=2,
        last_norm="Spectral",
        do_normalize=False,
        num_queries=100,
        num_scales=3,
        dec_layers=9,
    ).to(device)

    print(f"✅ Model created: {encoder_name} encoder, {DECODER_TYPE} decoder")

    # Load pretrained weights
    try:
        # NOTE(security): depending on the installed torch version's `weights_only`
        # default, torch.load may unpickle arbitrary objects — only load trusted checkpoints.
        ckpt = torch.load(MODEL_PATH, map_location=device)
        load_result = model.load_state_dict(ckpt["params"], strict=False)
        print(f"✅ Weights loaded from {MODEL_PATH}")
        # strict=False tolerates key mismatches; surface them instead of hiding them,
        # because a mismatched checkpoint would otherwise export partly random weights.
        if load_result.missing_keys:
            print(f"⚠️  {len(load_result.missing_keys)} model keys missing from checkpoint "
                  f"(first: {load_result.missing_keys[0]})")
        if load_result.unexpected_keys:
            print(f"⚠️  {len(load_result.unexpected_keys)} unexpected checkpoint keys ignored "
                  f"(first: {load_result.unexpected_keys[0]})")
    except FileNotFoundError:
        print(f"⚠️  No weights found at {MODEL_PATH}, using random initialization")

    model.eval()

    # Create dummy input: [batch=1, channels=3, height, width] on the model's device
    dummy_input = torch.randn(
        (1, 3, INPUT_SIZE[0], INPUT_SIZE[1]), dtype=torch.float32, device=device
    )

    # ONLY batch dimension is dynamic, spatial dimensions are fixed
    dynamic_axes = {
        "input": {0: "batch"},
        "output": {0: "batch"},
    }

    # The exporter does not create missing directories; do it here so a fresh
    # checkout without the output folder does not fail at the very last step.
    export_dir = os.path.dirname(EXPORT_PATH)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)

    print(f"\n🔄 Exporting to ONNX with dynamic batch size...")

    torch.onnx.export(
        model,
        dummy_input,
        EXPORT_PATH,
        export_params=True,
        opset_version=OPSET_VERSION,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,
    )

    print(f"✅ ONNX export complete: {EXPORT_PATH}")

# Run the export
create_and_export_model()

✅ Model created: convnext-t encoder, MultiScaleColorDecoder decoder
✅ Weights loaded from ./DDColor/pretrain/ddcolor_paper_tiny.pth

🔄 Exporting to ONNX with dynamic batch size...
✅ ONNX export complete: ./exported/model.onnx


## 5. Optimize and Verify ONNX Model

In [4]:
def _tensor_shape(value_info):
    """Return the dims of a graph input/output as a list (symbolic name if set, else integer size)."""
    return [
        dim.dim_param if dim.dim_param else dim.dim_value
        for dim in value_info.type.tensor_type.shape.dim
    ]


def optimize_and_verify_onnx():
    """Run shape inference and simplification on the exported model, then validate it.

    Operates in place on the file at EXPORT_PATH:
      1. ONNX shape inference,
      2. ONNX Runtime symbolic shape inference (auto_merge / guess_output_rank),
      3. onnxsim simplification,
      4. onnx.checker validation of the file as written to disk.
    The result of each of the first three steps is saved back to EXPORT_PATH.
    Finally the graph input/output shapes and the file size are printed; a
    warning is printed when a shape differs from the expected one.

    Raises:
        AssertionError: if onnxsim reports that the simplified model failed its check.
    """
    import os  # local import: only needed for the file-size report

    print("🔧 Running ONNX optimization...\n")

    # Step 1: Shape inference
    print("1. Shape inference...")
    model = shape_inference.infer_shapes(load_model(EXPORT_PATH))
    save_model(model, EXPORT_PATH)

    # Step 2: Symbolic shape inference (better dynamic handling).
    # The in-memory model from step 1 is reused rather than re-read from disk.
    print("2. Symbolic shape inference...")
    model = SymbolicShapeInference.infer_shapes(
        model,
        auto_merge=True,
        guess_output_rank=True,
    )
    save_model(model, EXPORT_PATH)

    # Step 3: Simplify
    print("3. Simplifying model...")
    model_simplified, check = onnxsim.simplify(model)
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not check:
        raise AssertionError("ONNX simplification failed!")
    onnx.save(model_simplified, EXPORT_PATH)

    # Step 4: Verify what was actually written to disk
    print("4. Verifying model...")
    onnx_model = onnx.load(EXPORT_PATH)
    onnx.checker.check_model(onnx_model)

    print(f"\n✅ Model optimized and verified: {EXPORT_PATH}")

    # Expected shapes follow the export settings: dynamic batch, fixed spatial dims,
    # 3 input channels and 2 output channels.
    expected_input = ["batch", 3, *INPUT_SIZE]
    expected_output = ["batch", 2, *INPUT_SIZE]

    # Display model info
    print("\n📊 Model Input/Output Info:")
    for input_tensor in onnx_model.graph.input:
        shape = _tensor_shape(input_tensor)
        print(f"  Input: {input_tensor.name}")
        print(f"    Shape: {shape}")
        print(f"    Expected: {expected_input}")
        if shape != expected_input:
            print("    ⚠️  Shape differs from the expected value")

    for output_tensor in onnx_model.graph.output:
        shape = _tensor_shape(output_tensor)
        print(f"\n  Output: {output_tensor.name}")
        print(f"    Shape: {shape}")
        print(f"    Expected: {expected_output}")
        if shape != expected_output:
            print("    ⚠️  Shape differs from the expected value")

    # File size
    file_size_mb = os.path.getsize(EXPORT_PATH) / (1024 * 1024)
    print(f"\n📁 File size: {file_size_mb:.2f} MB")

# Optimize and verify
optimize_and_verify_onnx()

🔧 Running ONNX optimization...

1. Shape inference...
2. Symbolic shape inference...
3. Simplifying model...
4. Verifying model...

✅ Model optimized and verified: ./exported/model.onnx

📊 Model Input/Output Info:
  Input: input
    Shape: ['batch', 3, 512, 512]
    Expected: ['batch', 3, 512, 512]

  Output: output
    Shape: ['batch', 2, 512, 512]
    Expected: ['batch', 2, 512, 512]

📁 File size: 210.31 MB


## 6. Test ONNX Model

In [5]:
import onnxruntime as ort
import numpy as np

def test_onnx_model():
    """Smoke-test the exported ONNX model with several batch sizes.

    Creates an ONNX Runtime session for EXPORT_PATH, preferring TensorRT and
    falling back to CUDA and then CPU for whichever providers this ONNX Runtime
    build reports as available. For each batch size a seeded random float32
    input of shape (batch, 3, *INPUT_SIZE) is run through the model and the
    first output's shape is checked against (batch, 2, *INPUT_SIZE).

    Raises:
        RuntimeError: if an output shape does not match the expected shape.
    """
    trt_options = {
        'device_id': 0,
        'trt_max_workspace_size': 4 * 1024 * 1024 * 1024,
        'trt_fp16_enable': True,
        'trt_engine_cache_enable': True,
        'trt_engine_cache_path': './trt_engine_cache',
        'trt_engine_cache_prefix': 'model',
        'trt_dump_subgraphs': False,
        'trt_timing_cache_enable': True,
        'trt_timing_cache_path': './trt_engine_cache',
        #'trt_builder_optimization_level': 3,
    }
    # Preference order: TensorRT is the fastest; CUDA and CPU are fallbacks so the
    # test can still run on machines without TensorRT.
    preferred = [
        ('TensorrtExecutionProvider', trt_options),
        'CUDAExecutionProvider',
        'CPUExecutionProvider',
    ]
    available = set(ort.get_available_providers())
    providers = [
        entry for entry in preferred
        if (entry[0] if isinstance(entry, tuple) else entry) in available
    ]

    # Create ONNX session
    session = ort.InferenceSession(EXPORT_PATH, providers=providers)
    print(f"Execution providers: {session.get_providers()}")

    # Ask the session for the input name instead of assuming it.
    input_name = session.get_inputs()[0].name

    print("Testing ONNX model with different batch sizes:\n")

    # Seeded generator so repeated runs feed identical inputs.
    rng = np.random.default_rng(0)

    # Test different batch sizes
    for batch_size in [1, 4, 8, 16]:
        test_input = rng.standard_normal(
            (batch_size, 3, INPUT_SIZE[0], INPUT_SIZE[1]), dtype=np.float32
        )

        # Run inference
        outputs = session.run(None, {input_name: test_input})

        # Actually verify the result before reporting success.
        expected_shape = (batch_size, 2, INPUT_SIZE[0], INPUT_SIZE[1])
        if tuple(outputs[0].shape) != expected_shape:
            raise RuntimeError(
                f"Batch size {batch_size}: expected output shape {expected_shape}, "
                f"got {outputs[0].shape}"
            )

        print(f"Batch size {batch_size:2d}: Input {test_input.shape} → Output {outputs[0].shape} ✅")

    print("\n✅ Dynamic batching works correctly!")

# Test the model (best-effort: a failure is reported but does not stop the notebook).
# onnxruntime is imported before this point, so a missing package can never reach
# this handler; the hint below names the failure causes that can.
try:
    test_onnx_model()
except Exception as e:
    print(f"Could not test model: {e}")
    print("Check that the exported ONNX file exists and that an ONNX Runtime execution provider is available")

Testing ONNX model with different batch sizes:

Batch size  1: Input (1, 3, 512, 512) → Output (1, 2, 512, 512) ✅
Batch size  4: Input (4, 3, 512, 512) → Output (4, 2, 512, 512) ✅
Batch size  8: Input (8, 3, 512, 512) → Output (8, 2, 512, 512) ✅
Batch size 16: Input (16, 3, 512, 512) → Output (16, 2, 512, 512) ✅

✅ Dynamic batching works correctly!


## Summary

Your DDColor model has been exported to ONNX format!

**Key features:**
- ✅ Dynamic batch size support
- ✅ Fixed 512x512 spatial dimensions
- ✅ Optimized and simplified
- ✅ Ready for TensorRT conversion

**Next step:** Use the TensorRT notebook to convert this ONNX model to a TensorRT engine for faster GPU inference.