In [1]:
import os
import numpy as np
import cv2
import torch
from collections import OrderedDict

In [2]:
from src.model import build_model
from src.get_config import get_config

In [5]:
def normalization(rgb_img, mean_list=(0.485, 0.456, 0.406), std_list=(0.229, 0.224, 0.225)):
    """Normalize channels-last pixels per channel and move the channel axis first.

    The mean/std are given on the [0, 1] scale and are multiplied by 255 here,
    so ``rgb_img`` is expected to hold raw 0-255 pixel values.

    Args:
        rgb_img: array of shape (H, W, C), or batched (..., H, W, C), with the
            channel axis last and channels ordered like ``mean_list``/``std_list``.
        mean_list: per-channel means on the [0, 1] scale.
        std_list: per-channel standard deviations on the [0, 1] scale.

    Returns:
        Float array of shape (C, H, W), or (..., C, H, W) for batched input.
    """
    mean = 255 * np.asarray(mean_list)
    std = 255 * np.asarray(std_list)
    # Channels-last -> channels-first: (..., H, W, C) -> (..., C, H, W).
    # For a single (H, W, C) image this is exactly transpose(-1, 0, 1).
    chw = np.moveaxis(rgb_img, -1, -3)
    # Stats reshaped to (C, 1, 1) broadcast over the spatial axes and any
    # leading batch axes.
    return (chw - mean[:, None, None]) / std[:, None, None]

def preprocessing(aimg, mean_list=(0.485, 0.456, 0.406), std_list=(0.229, 0.224, 0.225)):
    """Turn one RGB image into a normalized float32 NCHW tensor with batch size 1.

    Args:
        aimg: RGB image of shape (H, W, C) holding raw 0-255 pixel values.
        mean_list: per-channel means on the [0, 1] scale (scaled by 255 in
            ``normalization``).
        std_list: per-channel standard deviations on the [0, 1] scale.

    Returns:
        ``torch.float32`` tensor of shape (1, C, H, W).
    """
    input_img = normalization(aimg, mean_list, std_list)  # aimg is RGB
    # Add the batch axis and cast in one step. ascontiguousarray also yields a
    # C-contiguous buffer (the channels-first result may be strided), which
    # torch.from_numpy then wraps without a further copy.
    batch = np.ascontiguousarray(input_img[np.newaxis], dtype=np.float32)
    return torch.from_numpy(batch)

In [17]:
# Trained checkpoint variants: name -> (config file, weights file).
# Only the variant selected by VARIANT is loaded and exported below.
VARIANTS = {
    "base": ("configs/efficientNet_B0_celebA.py", "./save_model_224/best_epoch28.pth"),
    "add": ("configs/efficientNet_B0_celebA_add.py", "./save_model_add224/best_epoch30.pth"),
    "crop": ("configs/efficientNet_B0_celebA_crop.py", "./save_model_crop224/best_epoch17.pth"),
}
VARIANT = "crop"
cfg_path, model_path = VARIANTS[VARIANT]

cfg = get_config(cfg_path)

In [18]:
# Build the network described by the config, then load the trained checkpoint into it.
model = build_model(cfg.network, cfg.num_classes, '', False)


def strip_module_prefix(key, prefix="module."):
    """Remove leading ``prefix`` occurrences from a state-dict key.

    Keys saved from a wrapped model (presumably nn.DataParallel /
    DistributedDataParallel -- confirm against the training script) start with
    "module.". Only the leading prefix is removed; ``str.replace`` would also
    rewrite keys that merely contain "module.", e.g. "submodule.weight".
    """
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key


# map_location="cpu": the export runs on CPU with a CPU dummy input, and a
# checkpoint whose tensors were saved on GPU would otherwise fail to load on a
# CPU-only machine.
load_weight = torch.load(model_path, map_location="cpu")
new_state_dict = OrderedDict(
    (strip_module_prefix(name), value) for name, value in load_weight.items()
)

model.load_state_dict(new_state_dict)
_ = model.eval()  # assign to `_` so the module repr is not echoed as cell output

In [19]:
# Dummy 224x224 RGB frame, used only as the tracing input for the ONNX export.
# A seeded generator keeps re-runs reproducible. `high` is exclusive, so 256
# (not 255) is needed for the draw to cover the full uint8 range 0..255.
rng = np.random.default_rng(0)
dummy_rgb = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
img = preprocessing(dummy_rgb)

In [20]:
# export
# Write the ONNX file next to the checkpoint, tagged with the checkpoint
# directory's variant suffix. Switching the checkpoint in the config cell then
# cannot overwrite another variant's export. For
# "./save_model_crop224/best_epoch17.pth" this yields
# "./save_model_crop224/efficientNetB0_celebA_crop224.onnx".
export_dir = os.path.dirname(model_path)
dir_name = os.path.basename(export_dir)
variant_tag = dir_name[len("save_model_"):] if dir_name.startswith("save_model_") else dir_name
output_path = os.path.join(export_dir, f"efficientNetB0_celebA_{variant_tag}.onnx")
opset = 12
torch.onnx.export(model, img, output_path, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)

In [21]:
# simplify
# Run onnx-simplifier on the exported graph and overwrite the file in place.

import onnxsim
import onnx

# Distinct name so the PyTorch `model` from the earlier cells is not clobbered;
# otherwise re-running the export cell after this one would pass an ONNX
# ModelProto to torch.onnx.export.
onnx_model = onnx.load(output_path)
graph = onnx_model.graph
# The exporter auto-assigns the input name (e.g. "input.1"); read it instead of
# hard-coding it.
input_name = graph.input[0].name
graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
# NOTE(review): `input_shapes` below pins this input to batch 1, which likely
# replaces the 'None' batch dim set above. The export also used no dynamic_axes.
# If a dynamic batch is intended, export with dynamic_axes -- confirm against
# the installed onnxsim version.
onnx_model, check = onnxsim.simplify(onnx_model, input_shapes={input_name: (1, 3, 224, 224)})
assert check, "Simplified ONNX model could not be validated"
onnx.save(onnx_model, output_path)