## Imports

In [1]:
import sys
import time
import warnings
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
from IPython.display import Markdown, display
from torchvision.models.segmentation import lraspp_mobilenet_v3_large, LRASPP_MobileNet_V3_Large_Weights
from openvino.runtime import Core

  warn(f"Failed to load image Python extension: {e}")


## Set paths

In [2]:
DIRECTORY_NAME = "models"
BASE_MODEL_NAME = DIRECTORY_NAME + "/gen"
weights_path = Path(BASE_MODEL_NAME + ".pth")

# Paths where ONNX and OpenVINO IR models will be stored.
onnx_path = weights_path.with_suffix('.onnx')
if not onnx_path.parent.exists():
    onnx_path.parent.mkdir()
ir_path = onnx_path.with_suffix(".xml")
print(weights_path)
print(onnx_path)

models\gen.pth
models\gen.onnx


## Import the generator network

In [3]:
from networks.network_generator import SPADEGenerator

## Define options

In [4]:
import argparse

def _str2bool(value):
    """Convert a command-line token to a bool.

    An untyped argparse option keeps the raw string, so ``--cuda False`` would
    otherwise yield the truthy string ``"False"``.

    Accepts (case-insensitively) true/false, yes/no, y/n, t/f, on/off, 1/0 and,
    matching the ``--cuda`` help text ("cuda or cpu"), ``cuda``/``cpu``.

    Raises:
        argparse.ArgumentTypeError: for any other token.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "yes", "y", "t", "on", "1", "cuda"):
        return True
    if token in ("false", "no", "n", "f", "off", "0", "cpu"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (or 'cuda'/'cpu'), got {value!r}")


def get_opt(args=None):
    """Parse the options used to build, load and export the generator.

    Args:
        args: Optional list of argument strings to parse. When None (the default),
            argparse parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding every option below.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--gpu_ids", default="0")
    parser.add_argument('-j', '--workers', type=int, default=0)
    parser.add_argument('-b', '--batch-size', type=int, default=1)
    parser.add_argument('--fp16', action='store_true', help='use amp')
    # Cuda availability. Parsed explicitly so that "--cuda False" / "--cuda cpu"
    # really disable CUDA instead of producing a truthy string.
    parser.add_argument('--cuda', type=_str2bool, default=True, help='cuda or cpu')

    parser.add_argument('--test_name', type=str, default='test', help='test name')
    parser.add_argument("--dataroot", default="./data/zalando-hd-resized")
    parser.add_argument("--datamode", default="test")
    parser.add_argument("--data_list", default="test_pairs.txt")
    parser.add_argument("--output_dir", type=str, default="./Output")
    parser.add_argument("--datasetting", default="unpaired")
    parser.add_argument("--fine_width", type=int, default=768)
    parser.add_argument("--fine_height", type=int, default=1024)

    # NOTE(review): "zalando-hd-resize" differs from --dataroot's "zalando-hd-resized";
    # confirm whether this default is a typo.
    parser.add_argument('--tensorboard_dir', type=str, default='./data/zalando-hd-resize/tensorboard',
                        help='save tensorboard infos')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
    parser.add_argument('--tocg_checkpoint', type=str, default='models/mtviton.pth',
                        help='tocg checkpoint')
    parser.add_argument('--gen_checkpoint', type=str, default='models/gen.pth', help='G checkpoint')

    parser.add_argument("--tensorboard_count", type=int, default=100)
    parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
    parser.add_argument("--semantic_nc", type=int, default=13)
    parser.add_argument("--output_nc", type=int, default=13)
    parser.add_argument('--gen_semantic_nc', type=int, default=7, help='# of input label classes without unknown class')

    # network
    parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
    parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")

    # training
    parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'],
                        default='warp_grad')

    # Hyper-parameters
    parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
    parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")

    # generator
    parser.add_argument('--norm_G', type=str, default='spectralaliasinstance',
                        help='instance normalization or batch normalization')
    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
    parser.add_argument('--init_type', type=str, default='xavier',
                        help='network initialization [normal|xavier|kaiming|orthogonal]')
    parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
    parser.add_argument('--num_upsampling_layers', choices=('normal', 'more', 'most'), default='most',
                        # normal: 256, more: 512
                        help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")
    # Presumably absorbs the "-f <connection-file>" argument a Jupyter kernel is
    # launched with, so parse_args() does not fail inside a notebook.
    # (The default text is Chinese for "read extra arguments".)
    parser.add_argument('-f', type=str, default="读取额外的参数")
    opt = parser.parse_args(args)
    return opt


# Parse options once for the whole notebook; later cells read (and mutate) this namespace.
# NOTE(review): with no explicit args this parses sys.argv, which in a notebook presumably
# holds the kernel's launch arguments rather than user options — confirm it parses cleanly
# in every Jupyter front-end you use.
opt = get_opt()

## Initialize the PyTorch model

In [5]:
import os
from collections import OrderedDict

def _rename_generator_key(key):
    """Map a checkpoint key to the generator's current parameter naming.

    Replaces every 'ace' substring with 'alias' and removes '.Spade' segments.
    """
    # NOTE(review): plain substring replacement would also rewrite 'ace' inside
    # unrelated names (e.g. 'space'); confirm no current parameter name contains it.
    return key.replace('ace', 'alias').replace('.Spade', '')


def load_checkpoint_G(model, checkpoint_path, opt):
    """Load generator weights from ``checkpoint_path`` into ``model`` in place.

    Every key of the saved state dict, and of its ``_metadata`` (version info)
    when present, is renamed with ``_rename_generator_key``. The result is then
    loaded with ``strict=True``. Tensors are mapped to CPU.

    Args:
        model: the torch.nn.Module to populate.
        checkpoint_path: path to a state dict written with ``torch.save``.
        opt: parsed options. Currently unused; kept for call compatibility.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        RuntimeError: from ``load_state_dict`` when keys or shapes do not match.
    """
    if not os.path.exists(checkpoint_path):
        # Fail loudly: silently returning left an untrained generator that later
        # cells would report as "Loaded" and export.
        raise FileNotFoundError(f"Invalid path! Generator checkpoint not found: {checkpoint_path}")
    # SECURITY: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    new_state_dict = OrderedDict(
        (_rename_generator_key(k), v) for k, v in state_dict.items())
    # _metadata only exists on state dicts produced by Module.state_dict(); a
    # checkpoint saved as a plain dict has none.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        new_state_dict._metadata = OrderedDict(
            (_rename_generator_key(k), v) for k, v in metadata.items())
    model.load_state_dict(new_state_dict, strict=True)

# Build the generator. Its semantic-label input width is the generator-specific
# option --gen_semantic_nc (7 by default), not --semantic_nc (13 by default).
# Override it before construction; presumably SPADEGenerator reads opt.semantic_nc.
opt.semantic_nc = opt.gen_semantic_nc
# Image input channels: three 3-channel tensors concatenated (3 + 3 + 3 = 9).
# This matches the 9-channel dummy input used for export below; confirm against SPADEGenerator.
GEN_INPUT_NC = 3 + 3 + 3
generator = SPADEGenerator(opt, GEN_INPUT_NC)

# Load Checkpoint (raises if the file is missing or does not match the model).
load_checkpoint_G(generator, opt.gen_checkpoint, opt)

# Switch to inference mode before tracing/export.
generator.eval()

print(f"Loaded Pytorch model '{Path(opt.gen_checkpoint).name}' ")

Loaded Pytorch model 'gen.pth' 


## Convert the PyTorch model to ONNX

In [6]:
# Export the generator to ONNX by tracing it with random inputs whose shapes
# follow the options: a (1, 9, H, W) image tensor and a (1, gen_semantic_nc, H, W)
# label tensor. With default options H = 1024 and W = 768.
# The export is skipped when onnx_path already exists, so re-running is cheap.
#
# NOTE(review): with the current model this export fails with
#   "RuntimeError: Unsupported: ONNX export of instance_norm for unknown channel size."
# That error appears to come from the exporter's instance_norm handling when the
# norm's input channel dimension is not statically known. TODO: investigate,
# e.g. an explicit opset_version or making the normalized tensor's channel count static.
dummy_input0 = torch.randn(1, 3 + 3 + 3, opt.fine_height, opt.fine_width)
dummy_input1 = torch.randn(1, opt.gen_semantic_nc, opt.fine_height, opt.fine_width)
# no_grad: tracing for export does not need autograd bookkeeping.
with warnings.catch_warnings(), torch.no_grad():
    # Silence all warnings raised while tracing/exporting (scoped to this block).
    warnings.simplefilter("ignore")
    if onnx_path.exists():
        print(f"ONNX model {onnx_path} already exists.")
    else:
        torch.onnx.export(
            generator,
            (dummy_input0, dummy_input1),
            onnx_path,
        )
        print(f"ONNX model exported to {onnx_path}.")

RuntimeError: Unsupported: ONNX export of instance_norm for unknown channel size.