### Export the trained `.pth` checkpoint to ONNX and verify it with ONNX Runtime

In [6]:
import torch
import os
import sys

# Make the project root (the parent of the current working directory, which
# is expected to contain the `networks` package) importable. The membership
# check keeps repeated runs of this notebook cell from appending duplicates.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from networks.attention_unet import AttnUNet
from networks.unet import UNet  # used by the alternative configuration below

# Export on CPU. The previous expression `"cpu" if cuda else "cpu"` selected the
# CPU in both branches anyway; stating it explicitly removes the confusion.
device = torch.device("cpu")

# Checkpoint to export. To export the plain UNet instead, use:
#   model = UNet(in_channels=3, out_channels=1, channels=[64, 128]).to(device)
#   CHECKPOINT_PATH = "../runs/unet_20250725/unet_best.pth"
CHECKPOINT_PATH = "../runs/attention_unet_0725_2105/attention_unet_best.pth"

# Build the model and load the trained weights.
model = AttnUNet(in_channels=3, out_channels=1, channels=[64, 128, 256, 512]).to(device)
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

# strict=False tolerates extra entries in the checkpoint, but a *missing*
# parameter would leave randomly initialised weights in the exported model,
# so inspect the returned key report instead of discarding it.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint {CHECKPOINT_PATH!r} is missing "
        f"{len(load_result.missing_keys)} model key(s), "
        f"e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"Ignoring {len(load_result.unexpected_keys)} unexpected checkpoint "
        f"key(s), e.g. {load_result.unexpected_keys[:5]}"
    )
model.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)

# Example input used to trace the model; its size should match the training
# input size. Layout: [batch, channels, height, width].
dummy_input = torch.randn(1, 3, 256, 256, device=device)

In [7]:
# Export the model to "model.onnx" by tracing it with `dummy_input`.
# Only axis 0 (batch) is declared dynamic for both the input and the output;
# every other dimension is fixed to the traced size of `dummy_input`.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=11,  # ONNX operator-set version
    input_names=["input"],  # name of the graph input node
    output_names=["output"],  # name of the graph output node
    dynamic_axes={node: {0: "batch_size"} for node in ("input", "output")},
)

In [7]:
import onnxruntime as ort

# 加载ONNX模型
sess = ort.InferenceSession("model.onnx")
output = sess.run(["output"], {"input": dummy_input.numpy()})
print("ONNX输出形状:", output[0].shape)  # 应与PyTorch输出一致

ModuleNotFoundError: No module named 'onnxruntime'