Before running conversion:
1. Run [1.0_Model_download.ipynb](1.0_Model_download.ipynb) to download models locally.

In [None]:
import os
import rootutils
import torch

# Register the project root on the Python path so the `src.*` imports below
# resolve. The root is identified by a `.git` or `pyproject.toml` indicator,
# with the lookup starting from the current working directory.
rootutils.setup_root(
    os.path.abspath(''),
    pythonpath=True,
    indicator=['.git', 'pyproject.toml'],
)

from src.models.components.cnn_cam_multihead import CNNCAMMultihead
from src.models.components.vit_rollout_multihead import VitRolloutMultihead
from src.models.components.nn_utils import weight_load
from src.models.components.base_model import BaseModel

#### CNN export

In [None]:
# Build the multi-head CNN on an EfficientNetV2-S backbone and restore its
# weights from the locally downloaded checkpoint.
cnn_model = CNNCAMMultihead(
    multi_head=True,
    return_node='features.6.0.block.0',
    backbone='torchvision.models/efficientnet_v2_s',
)
weights = weight_load(
    weights_only=True,
    ckpt_path='../trained_models/models--DeepVisionXplain--efficientnet_v2_s_downscaled_pcb/',
)
cnn_model.load_state_dict(weights)
cnn_model.eval()  # export in inference mode

# Random 1x3x224x224 tensor used as the example input for the export.
x = torch.randn(1, 3, 224, 224)

# Export to ONNX. Axis 0 of the input and of both outputs is marked dynamic
# so the exported graph accepts any batch size.
# NOTE(review): output names are assigned in the order the model returns its
# outputs; presumably the explanation map first, then the prediction — confirm
# against the model's forward().
torch.onnx.export(
    cnn_model,
    x,
    'efficientnet_v2_s_downscaled_pcb.onnx',
    input_names=['input'],
    output_names=['map', 'output'],
    dynamic_axes={
        tensor_name: {0: 'batch_size'}
        for tensor_name in ('input', 'map', 'output')
    },
    opset_version=20,
    export_params=True,
    do_constant_folding=False,
)

#### ViT export

In [None]:
# Build the multi-head ViT (tiny, patch 16, 224px) and restore its weights
# from the locally downloaded checkpoint.
vit_model = VitRolloutMultihead(
    multi_head=True,
    backbone='timm/vit_tiny_patch16_224.augreg_in21k_ft_in1k',
)
weights = weight_load(
    weights_only=True,
    ckpt_path='../trained_models/models--DeepVisionXplain--vit_tiny_patch16_224.augreg_in21k_ft_in1k_pcb/',
)
vit_model.load_state_dict(weights)
vit_model.eval()  # export in inference mode

# Random 1x3x224x224 tensor used as the example input for the export.
x = torch.randn(1, 3, 224, 224)

# Export to ONNX. Axis 0 of the input and of both outputs is marked dynamic
# so the exported graph accepts any batch size.
# NOTE(review): output names are assigned in the order the model returns its
# outputs; presumably the explanation map first, then the prediction — confirm
# against the model's forward().
torch.onnx.export(
    vit_model,
    x,
    'vit_tiny_patch16_224.augreg_in21k_ft_in1k_pcb.onnx',
    input_names=['input'],
    output_names=['map', 'output'],
    dynamic_axes={
        tensor_name: {0: 'batch_size'}
        for tensor_name in ('input', 'map', 'output')
    },
    opset_version=20,
    export_params=True,
    do_constant_folding=False,
)

#### Segmentation model export

In [None]:
# Build the UNet++ segmentation model with a MobileNetV2 encoder and restore
# its weights from the local checkpoint file.
segmentation_model = BaseModel(
    encoder_name='mobilenet_v2',
    model_name='segmentation_models_pytorch/UnetPlusPlus',
)
weights = weight_load(
    weights_only=True,
    ckpt_path='../trained_models/unet++.ckpt',
)
segmentation_model.load_state_dict(weights)
segmentation_model.eval()  # export in inference mode

# Random 1x3x768x640 tensor used as the example input for the export.
x = torch.randn(1, 3, 768, 640)

# Export to ONNX. Only axis 0 (batch) of the input and output is dynamic;
# the remaining dimensions are not declared dynamic.
torch.onnx.export(
    segmentation_model,
    x,
    'unet++.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        tensor_name: {0: 'batch_size'}
        for tensor_name in ('input', 'output')
    },
    opset_version=20,
    export_params=True,
    do_constant_folding=False,
)