Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX To Pytorch Conversion #168

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions nebullvm/operations/conversions/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from nebullvm.operations.base import Operation
from nebullvm.operations.conversions.pytorch import convert_torch_to_onnx
from nebullvm.operations.conversions.tensorflow import convert_tf_to_onnx
from nebullvm.operations.conversions.onnx import convert_onnx_to_torch
from nebullvm.optional_modules.onnx import onnx
from nebullvm.optional_modules.tensorflow import tensorflow as tf
from nebullvm.optional_modules.torch import torch
from nebullvm.tools.base import DeepLearningFramework, ModelParams
from nebullvm.tools.data import DataManager
from onnx import ModelProto


class Converter(Operation, abc.ABC):
Expand Down Expand Up @@ -104,26 +106,33 @@ def pytorch_conversion(self):


class ONNXConverter(Converter):
DEST_FRAMEWORKS = []
DEST_FRAMEWORKS = [DeepLearningFramework.NUMPY]

def execute(self, save_path, model_params):
onnx_path = save_path / f"{self.model_name}{self.ONNX_EXTENSION}"
try:
model_onnx = onnx.load(str(self.model))
onnx.save(model_onnx, str(onnx_path))
except Exception:
self.logger.error(
"The provided onnx model path is invalid. Please provide"
" a valid path to a model in order to use Nebullvm."
)
self.converted_models = []

self.converted_models = [str(onnx_path)]
def execute(
self,
save_path: Path,
model_params: ModelProto,
):
self.converted_models = [self.model]
for framework in self.DEST_FRAMEWORKS:
if framework is DeepLearningFramework.NUMPY:
self.pytorch_conversion(save_path, model_params)
else:
raise NotImplementedError()

def tensorflow_conversion(self):
# TODO: Implement conversion from ONNX to Tensorflow
raise NotImplementedError()

def pytorch_conversion(self):
# TODO: Implement conversion from ONNX to Pytorch
raise NotImplementedError()
def pytorch_conversion(self, save_path, model_params):
self.model_onnx = model_params
diegofiori marked this conversation as resolved.
Show resolved Hide resolved
torch_path = save_path / f"{self.model_name}{self.TORCH_EXTENSION}"
torch_model_path = convert_onnx_to_torch(
onnx_model=self.model_onnx,
output_file_path=torch_path,
)
if self.converted_models is None:
self.converted_models = [torch_model_path]
else:
self.converted_models.append(torch_model_path)

31 changes: 31 additions & 0 deletions nebullvm/operations/conversions/onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import logging
from pathlib import Path

from nebullvm.optional_modules.torch import torch
from nebullvm.tools.base import Device
from nebullvm.optional_modules.onnx import ModelProto
logger = logging.getLogger("nebullvm_logger")

from nebullvm.optional_modules.onnx import convert

def convert_onnx_to_torch(
onnx_model: ModelProto,
output_file_path: Path,
):
"""Function importing a custom ONNX model and converting it in Pytorch

Args:
onnx_model: ONNX model (tested with model=onnx.load("model.onnx")).
output_file_path (str or Path): Path where storing the output
Pytorch file.
"""
try:
torch_model = convert(onnx_model)
torch.save(torch_model, output_file_path)
return output_file_path
except Exception as e:
logger.warning("Exception raised during conversion of ONNX to Pytorch."
"ONNX to Torch pipeline will be skipped")
logger.warning(e)
return None

6 changes: 6 additions & 0 deletions nebullvm/optional_modules/onnx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
try:
import onnx # noqa F401
from onnx import ModelProto
except ImportError:
onnx = None

Expand All @@ -11,3 +12,8 @@

except ImportError:
convert_float_to_float16_model_path = object

try:
from onnx2torch import convert # noqa F401
except ImportError:
convert = None