In [1]:
import os
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch
import random
import pandas as pd
from collections import OrderedDict

In [2]:
from networks import get_model

In [3]:
def normalization(rgb_img, mean_list=(0.485, 0.456, 0.406), std_list=(0.229, 0.224, 0.225)):
    """Normalize an HWC RGB image with per-channel mean/std and return it as CHW.

    The statistics are given on a 0-1 scale (the defaults are the ImageNet
    values) and are multiplied by 255 so they apply to raw 0-255 pixel values.

    Args:
        rgb_img: array-like of shape (H, W, C) holding 0-255 pixel values.
        mean_list: per-channel means on a 0-1 scale, one per channel.
        std_list: per-channel standard deviations on a 0-1 scale, one per channel.

    Returns:
        Floating-point array of shape (C, H, W) equal to
        (pixel - 255 * mean) / (255 * std).

    Raises:
        ValueError: if ``rgb_img`` is not 3-dimensional.
    """
    # Defaults are tuples so there is no shared mutable default object.
    mean = 255 * np.asarray(mean_list, dtype=np.float64)
    std = 255 * np.asarray(std_list, dtype=np.float64)

    img = np.asarray(rgb_img)
    if img.ndim != 3:
        raise ValueError(f"expected an (H, W, C) image, got shape {img.shape}")

    # HWC -> CHW, the layout the PyTorch model consumes.
    chw = img.transpose(2, 0, 1)
    # Broadcast the (C,) statistics over the spatial dims.
    return (chw - mean[:, None, None]) / std[:, None, None]

In [19]:
# Source PyTorch checkpoint; the ONNX file reuses its stem with an .onnx extension.
ckpt_path = "Safas_wildcelebA_align.pth"
output_path = os.path.splitext(ckpt_path)[0] + ".onnx"

In [20]:
# Model hyper-parameters passed to the project's `get_model` factory.
# NOTE(review): these presumably must match the configuration the checkpoint
# was trained with, otherwise load_state_dict will fail — confirm.
model_type = 'ResNet18_lgt'
max_iter = -1
total_cls_num = 2
normfc = False
usebias = True
feat_loss = 'supcon'

model = get_model(
    model_type,
    max_iter,
    total_cls_num,
    pretrained=False,
    normed_fc=normfc,
    use_bias=usebias,
    simsiam=(feat_loss == 'simsiam'),
)

In [21]:
# Load the checkpoint onto the CPU: the export needs no GPU, and a checkpoint
# saved from CUDA tensors would otherwise fail to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt['state_dict']

# Checkpoints saved from a DataParallel-wrapped model prefix every key with
# "module."; strip only that leading prefix so inner parameter names that
# happen to contain "module." are left intact.
dp_prefix = "module."
new_state_dict = OrderedDict(
    (name[len(dp_prefix):] if name.startswith(dp_prefix) else name, tensor)
    for name, tensor in state_dict.items()
)
model.load_state_dict(new_state_dict)
_ = model.eval()  # inference mode for tracing (e.g. fixed BatchNorm statistics)

In [22]:
# Dummy input for tracing the export: a random 256x256 RGB uint8 image, normalized
# to CHW and given a leading batch axis of 1, as a float32 tensor.
img = torch.tensor(
    normalization(
        np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8)
    )[np.newaxis].astype(np.float32)
)

In [23]:
# Export: trace the eval-mode PyTorch model with the dummy input and write ONNX.
opset = 12
torch.onnx.export(
    model,
    img,
    output_path,
    opset_version=opset,
    # Weights are stored as graph initializers only, not duplicated as inputs.
    keep_initializers_as_inputs=False,
    verbose=False,
)

In [24]:
# simplify
# Make the exported graph's batch dimension symbolic, run onnx-simplifier, and
# overwrite the exported file with the simplified model.

import onnxsim
import onnx

# Keep the ONNX proto under its own name so `model` still refers to the PyTorch
# module; rebinding it broke re-running the load/export cells after this one.
onnx_model = onnx.load(output_path)
graph = onnx_model.graph
# Read the input name from the graph rather than hard-coding the exporter's
# auto-generated name (e.g. "input.1"), which varies across torch versions.
input_name = graph.input[0].name
# The dummy input's leading axis is the batch; replace its fixed size with a symbol.
graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'

# Reuse the dummy input's shape instead of repeating its magic numbers.
onnx_model, check = onnxsim.simplify(onnx_model, input_shapes={input_name: tuple(img.shape)})
assert check, "Simplified ONNX model could not be validated"
onnx.save(onnx_model, output_path)