Added a class for ONNX models #980

Merged 1 commit on Sep 17, 2022
7 changes: 3 additions & 4 deletions backend/src/nodes/model_save_nodes.py
@@ -1,6 +1,5 @@
# These sad files have to be all on their own :(
from __future__ import annotations
from typing import Any
from sanic.log import logger

from .node_base import NodeBase
@@ -11,6 +10,7 @@
from .properties.outputs import *

from .utils.ncnn_model import NcnnModel
from .utils.onnx_model import OnnxModel


@NodeFactory.register("chainner:onnx:save_model")
@@ -34,13 +34,12 @@ def __init__(self):
self.side_effects = True

def run(
self, onnx_model: Tuple[Any, bytes], directory: str, model_name: str
self, model: OnnxModel, directory: str, model_name: str
) -> None:
full_path = f"{os.path.join(directory, model_name)}.onnx"
logger.info(f"Writing file to path: {full_path}")
with open(full_path, "wb") as f:
_, onnx_model_bytes = onnx_model
f.write(onnx_model_bytes)
f.write(model.bytes)


@NodeFactory.register("chainner:ncnn:save_model")
74 changes: 18 additions & 56 deletions backend/src/nodes/onnx_nodes.py
@@ -19,6 +19,8 @@
from .properties.outputs import *
from .utils.ncnn_model import NcnnModel
from .utils.onnx_auto_split import onnx_auto_split_process
from .utils.onnx_model import OnnxModel
from .utils.onnx_session import get_onnx_session
from .utils.onnx_to_ncnn import Onnx2NcnnConverter
from .utils.utils import get_h_w_c, np2nptensor, nptensor2np, convenient_upscale
from .utils.exec_options import get_execution_options
@@ -28,44 +30,6 @@
from .model_save_nodes import OnnxSaveModelNode


def create_inference_session(model_as_string: bytes) -> ort.InferenceSession:
exec_options = get_execution_options()
if exec_options.onnx_execution_provider == "TensorrtExecutionProvider":
providers = [
(
"TensorrtExecutionProvider",
{
"device_id": exec_options.onnx_gpu_index,
},
),
(
"CUDAExecutionProvider",
{
"device_id": exec_options.onnx_gpu_index,
},
),
"CPUExecutionProvider",
]
elif exec_options.onnx_execution_provider == "CUDAExecutionProvider":
providers = [
(
"CUDAExecutionProvider",
{
"device_id": exec_options.onnx_gpu_index,
},
),
"CPUExecutionProvider",
]
else:
providers = [exec_options.onnx_execution_provider, "CPUExecutionProvider"]

session = ort.InferenceSession(
model_as_string,
providers=providers,
)
return session


@NodeFactory.register("chainner:onnx:load_model")
class OnnxLoadModelNode(NodeBase):
def __init__(self):
@@ -87,7 +51,7 @@ def __init__(self):

self.model = None # Defined in run

def run(self, path: str) -> Tuple[Tuple[ort.InferenceSession, bytes], str, str]:
def run(self, path: str) -> Tuple[OnnxModel, str, str]:
"""Read a pth file from the specified path and return it as a state dict
and loaded model after finding arch config"""

@@ -100,10 +64,8 @@ def run(self, path: str) -> Tuple[Tuple[ort.InferenceSession, bytes], str, str]:

model_as_string = model.SerializeToString() # type: ignore

session = create_inference_session(model_as_string)

dirname, basename = os.path.split(os.path.splitext(path)[0])
return (session, model_as_string), dirname, basename
return OnnxModel(model_as_string), dirname, basename


@NodeFactory.register("chainner:onnx:upscale_image")
@@ -155,12 +117,12 @@ def upscale(

def run(
self,
onnx_model: Tuple[ort.InferenceSession, bytes],
model: OnnxModel,
img: np.ndarray,
tile_mode: Union[int, None],
) -> np.ndarray:
"""Upscales an image with a pretrained model"""
session, _ = onnx_model
session = get_onnx_session(model, get_execution_options())
shape = session.get_inputs()[0].shape
if isinstance(shape[1], int) and shape[1] <= 4:
in_nc = shape[1]
@@ -238,22 +200,22 @@ def perform_interp(

return interp_weights_list

def check_will_upscale(self, interp: Tuple[ort.InferenceSession, bytes]):
def check_will_upscale(self, model: OnnxModel):
fake_img = np.ones((3, 3, 3), dtype=np.float32, order="F")
result = OnnxImageUpscaleNode().run(interp, fake_img, None)
result = OnnxImageUpscaleNode().run(model, fake_img, None)

mean_color = np.mean(result)
del result
return mean_color > 0.5

def run(
self,
a: Tuple[ort.InferenceSession, bytes],
b: Tuple[ort.InferenceSession, bytes],
a: OnnxModel,
b: OnnxModel,
amount: int,
) -> Tuple[Tuple[ort.InferenceSession, bytes], int, int]:
model_a = a[1]
model_b = b[1]
) -> Tuple[OnnxModel, int, int]:
model_a = a.bytes
model_b = b.bytes
if amount == 0:
return a, 100, 0
elif amount == 100:
@@ -286,13 +248,13 @@ def run(
model_proto_interp.graph.initializer.extend(interp_weights_list) # type: ignore
model_interp: bytes = model_proto_interp.SerializeToString() # type: ignore

session = create_inference_session(model_interp)
if not self.check_will_upscale((session, model_interp)):
model = OnnxModel(model_interp)
if not self.check_will_upscale(model):
raise ValueError(
"These models are not compatible and not able to be interpolated together"
)

return (session, model_interp), 100 - amount, amount
return model, 100 - amount, amount


@NodeFactory.register("chainner:onnx:convert_to_ncnn")
Expand Down Expand Up @@ -324,10 +286,10 @@ def __init__(self):
except:
pass

def run(self, onnx_model: bytes, is_fp16: int) -> Tuple[NcnnModel, str]:
def run(self, model: OnnxModel, is_fp16: int) -> Tuple[NcnnModel, str]:
fp16 = bool(is_fp16)

model_proto = onnx.load_model_from_string(onnx_model)
model_proto = onnx.load_model_from_string(model.bytes)
passes = onnxoptimizer.get_fuse_and_elimination_passes()
opt_model = onnxoptimizer.optimize(model_proto, passes) # type: ignore

15 changes: 4 additions & 11 deletions backend/src/nodes/pytorch_nodes.py
@@ -21,6 +21,7 @@
from .utils.pytorch_auto_split import auto_split_process
from .utils.utils import get_h_w_c, np2tensor, tensor2np, convenient_upscale
from .utils.exec_options import get_execution_options, ExecutionOptions
from .utils.onnx_model import OnnxModel
from .utils.torch_types import PyTorchModel
from .utils.pytorch_model_loading import load_state_dict

@@ -436,7 +437,7 @@ def __init__(self):
except:
pass

def run(self, model: torch.nn.Module) -> List[Any]:
def run(self, model: torch.nn.Module) -> OnnxModel:
exec_options = to_pytorch_execution_options(get_execution_options())

model = model.eval()
@@ -464,15 +465,7 @@ def run(self, model: torch.nn.Module) -> List[Any]:
f.seek(0)
onnx_model_bytes = f.read()

try:
# pylint: disable=import-outside-toplevel
from .onnx_nodes import create_inference_session

session = create_inference_session(onnx_model_bytes)
except:
session = None

return [(session, onnx_model_bytes)]
return OnnxModel(onnx_model_bytes)


@NodeFactory.register("chainner:pytorch:model_dim")
Expand Down Expand Up @@ -522,6 +515,6 @@ def run(self, model: torch.nn.Module, is_fp16: int) -> Any:
manager to use this node."
)
onnx_model = ConvertTorchToONNXNode().run(model)
ncnn_model, fp_mode = ConvertOnnxToNcnnNode().run(onnx_model[0][1], is_fp16)
ncnn_model, fp_mode = ConvertOnnxToNcnnNode().run(onnx_model, is_fp16)

return ncnn_model, fp_mode
5 changes: 5 additions & 0 deletions backend/src/nodes/utils/onnx_model.py
@@ -0,0 +1,5 @@
# This class defines an interface.
# It is important that it does not contain types that depend on ONNX.
class OnnxModel:
def __init__(self, model_as_bytes: bytes):
self.bytes = model_as_bytes
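
For context, a minimal sketch of how the new `OnnxModel` wrapper can be used outside the diff: it stores only the serialized bytes, so nodes can pass models around without touching onnxruntime until a session is needed. The import path, helper function, and file name below are assumptions for illustration, not part of this PR.

```python
# Hypothetical usage sketch: wrap raw ONNX bytes in OnnxModel.
from nodes.utils.onnx_model import OnnxModel  # import path is an assumption


def load_model_bytes(path: str) -> OnnxModel:
    """Read a serialized .onnx file and wrap it, without importing onnxruntime."""
    with open(path, "rb") as f:
        return OnnxModel(f.read())


model = load_model_bytes("example.onnx")  # illustrative file name
print(len(model.bytes))  # the wrapped, serialized model
```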
57 changes: 57 additions & 0 deletions backend/src/nodes/utils/onnx_session.py
@@ -0,0 +1,57 @@
from __future__ import annotations
import onnxruntime as ort
from weakref import WeakKeyDictionary

from .exec_options import ExecutionOptions
from .onnx_model import OnnxModel


def create_inference_session(
model: OnnxModel, exec_options: ExecutionOptions
) -> ort.InferenceSession:
if exec_options.onnx_execution_provider == "TensorrtExecutionProvider":
providers = [
(
"TensorrtExecutionProvider",
{
"device_id": exec_options.onnx_gpu_index,
},
),
(
"CUDAExecutionProvider",
{
"device_id": exec_options.onnx_gpu_index,
},
),
"CPUExecutionProvider",
]
elif exec_options.onnx_execution_provider == "CUDAExecutionProvider":
providers = [
(
"CUDAExecutionProvider",
{
"device_id": exec_options.onnx_gpu_index,
},
),
"CPUExecutionProvider",
]
else:
providers = [exec_options.onnx_execution_provider, "CPUExecutionProvider"]

session = ort.InferenceSession(model.bytes, providers=providers)
return session


__session_cache: WeakKeyDictionary[
OnnxModel, ort.InferenceSession
] = WeakKeyDictionary()


def get_onnx_session(
model: OnnxModel, exec_options: ExecutionOptions
) -> ort.InferenceSession:
cached = __session_cache.get(model)
if cached is None:
cached = create_inference_session(model, exec_options)
__session_cache[model] = cached
return cached
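
As a rough illustration of the caching behavior above, a hedged usage sketch follows. The absolute import paths and the .onnx file name are assumptions; inside the codebase these modules are imported relatively.

```python
# Hypothetical sketch of get_onnx_session's caching semantics.
from nodes.utils.exec_options import get_execution_options  # path assumed
from nodes.utils.onnx_model import OnnxModel                # path assumed
from nodes.utils.onnx_session import get_onnx_session       # path assumed

with open("example.onnx", "rb") as f:  # illustrative file name
    model = OnnxModel(f.read())

options = get_execution_options()

# The first call builds an ort.InferenceSession; the second returns the
# cached session because the WeakKeyDictionary is keyed on the same
# OnnxModel instance.
s1 = get_onnx_session(model, options)
s2 = get_onnx_session(model, options)
assert s1 is s2

# Because the cache holds only a weak reference to the key, the cached
# session is released once `model` itself is garbage-collected.
```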