Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX To Pytorch Conversion #168

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions nebullvm/operations/conversions/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from nebullvm.operations.base import Operation
from nebullvm.operations.conversions.pytorch import convert_torch_to_onnx
from nebullvm.operations.conversions.tensorflow import convert_tf_to_onnx
from nebullvm.operations.conversions.onnx import convert_onnx_to_torch
from nebullvm.optional_modules.onnx import onnx
from nebullvm.optional_modules.tensorflow import tensorflow as tf
from nebullvm.optional_modules.torch import torch
from nebullvm.tools.base import DeepLearningFramework, ModelParams
from nebullvm.tools.data import DataManager
from onnx import ModelProto


class Converter(Operation, abc.ABC):
Expand All @@ -27,7 +29,7 @@ def __init__(self, model_name: Optional[str] = None):
self.model_name = model_name or "temp"

def set_state(
self, model: Union[torch.nn.Module, tf.Module, str], data: DataManager
self, model: Union[torch.nn.Module, tf.Module, onnx.ModelProto, str], data: DataManager
):
self.model = model
self.data = data
Expand Down Expand Up @@ -104,26 +106,32 @@ def pytorch_conversion(self):


class ONNXConverter(Converter):
DEST_FRAMEWORKS = []
DEST_FRAMEWORKS = [DeepLearningFramework.NUMPY]

def execute(self, save_path, model_params):
onnx_path = save_path / f"{self.model_name}{self.ONNX_EXTENSION}"
try:
model_onnx = onnx.load(str(self.model))
onnx.save(model_onnx, str(onnx_path))
except Exception:
self.logger.error(
"The provided onnx model path is invalid. Please provide"
" a valid path to a model in order to use Nebullvm."
)
self.converted_models = []

self.converted_models = [str(onnx_path)]
def execute(
self,
save_path: Path,
model_params: ModelProto,
):
self.converted_models = [self.model]
for framework in self.DEST_FRAMEWORKS:
if framework is DeepLearningFramework.NUMPY:
self.pytorch_conversion(save_path, model_params)
else:
raise NotImplementedError()

def tensorflow_conversion(self):
# TODO: Implement conversion from ONNX to Tensorflow
raise NotImplementedError()

def pytorch_conversion(self):
# TODO: Implement conversion from ONNX to Pytorch
raise NotImplementedError()
def pytorch_conversion(self, save_path, model_params):
self.model_onnx = self.model
#torch_path = save_path / f"{self.model_name}{self.TORCH_EXTENSION}"
torch_model = convert_onnx_to_torch(
onnx_model=self.model_onnx
)
if self.converted_models is None:
self.converted_models = [torch_model]
else:
self.converted_models.append(torch_model)

27 changes: 27 additions & 0 deletions nebullvm/operations/conversions/onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import logging
from pathlib import Path

from nebullvm.optional_modules.torch import torch
from nebullvm.tools.base import Device
from nebullvm.optional_modules.onnx import ModelProto
logger = logging.getLogger("nebullvm_logger")

from nebullvm.optional_modules.onnx import convert

def convert_onnx_to_torch(
onnx_model: ModelProto
):
"""Function importing a custom ONNX model and converting it in Pytorch

Args:
onnx_model: ONNX model (tested with model=onnx.load("model.onnx")).
"""
try:
torch_model = torch.fx.symbolic_trace(convert(onnx_model))
return torch_model
except Exception as e:
logger.warning("Exception raised during conversion of ONNX to Pytorch."
"ONNX to Torch pipeline will be skipped")
logger.warning(e)
return None

18 changes: 18 additions & 0 deletions nebullvm/operations/inference_learners/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from tempfile import TemporaryDirectory
from typing import Tuple, Union, Optional, List

import numpy as np

from nebullvm.operations.inference_learners.base import (
PytorchBaseInferenceLearner,
LearnerMetadata,
Expand Down Expand Up @@ -116,3 +118,19 @@ def from_torch_model(
input_data=input_data,
device=device,
)


class NumpyPytorchBackendInferenceLearner(
PytorchBackendInferenceLearner, PytorchBaseInferenceLearner
):

"""
Wrapper around PytorchBackendInferenceLearner to allow numpy inputs and outputs
"""


def run(self, *input_tensors: np.ndarray) -> Tuple[np.ndarray, ...]:
input_tensors = [torch.from_numpy(t) for t in input_tensors]
# Call the PytorchBackendInferenceLearner run method
res = super().run(*input_tensors)
return tuple(out.numpy() for out in res)
6 changes: 6 additions & 0 deletions nebullvm/optional_modules/onnx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
try:
import onnx # noqa F401
from onnx import ModelProto
except ImportError:
onnx = None

Expand All @@ -11,3 +12,8 @@

except ImportError:
convert_float_to_float16_model_path = object

try:
from onnx2torch import convert # noqa F401
except ImportError:
convert = None