Format python with black and isort #4297

Closed · wants to merge 6 commits
onnx/__init__.py: 109 changes (66 additions, 43 deletions)
@@ -1,80 +1,85 @@
# SPDX-License-Identifier: Apache-2.0

import os
+from typing import IO, Any, Optional, TypeVar, Union, cast

-from .onnx_cpp2py_export import ONNX_ML
-from onnx.external_data_helper import load_external_data_for_model, write_external_data_tensors, convert_model_to_external_data
-from .onnx_pb import *  # noqa
-from .onnx_operators_pb import *  # noqa
-from .onnx_data_pb import *  # noqa
-from .version import version as __version__  # noqa
-import google.protobuf.message

# Import common subpackages so they're available when you 'import onnx'
import onnx.checker  # noqa
import onnx.compose  # noqa
import onnx.defs  # noqa
import onnx.helper  # noqa
import onnx.utils  # noqa
-import onnx.compose  # noqa
+from onnx.external_data_helper import (
+    convert_model_to_external_data,
+    load_external_data_for_model,
+    write_external_data_tensors,
+)

+import google.protobuf.message

-from typing import Union, IO, Optional, cast, TypeVar, Any
+from .onnx_cpp2py_export import ONNX_ML
+from .onnx_data_pb import *  # noqa
+from .onnx_operators_pb import *  # noqa
+from .onnx_pb import *  # noqa
+from .version import version as __version__  # noqa


# f should be either readable or a file path
def _load_bytes(f: Union[IO[bytes], str]) -> bytes:
-    if hasattr(f, 'read') and callable(cast(IO[bytes], f).read):
+    if hasattr(f, "read") and callable(cast(IO[bytes], f).read):
        s = cast(IO[bytes], f).read()
    else:
-        with open(cast(str, f), 'rb') as readable:
+        with open(cast(str, f), "rb") as readable:
            s = readable.read()
    return s


# content should be bytes,
# f should be either writable or a file path
def _save_bytes(content: bytes, f: Union[IO[bytes], str]) -> None:
-    if hasattr(f, 'write') and callable(cast(IO[bytes], f).write):
+    if hasattr(f, "write") and callable(cast(IO[bytes], f).write):
        cast(IO[bytes], f).write(content)
    else:
-        with open(cast(str, f), 'wb') as writable:
+        with open(cast(str, f), "wb") as writable:
            writable.write(content)


# f should be either a readable file or a file path
def _get_file_path(f: Union[IO[bytes], str]) -> Optional[str]:
    if isinstance(f, str):
        return os.path.abspath(f)
-    if hasattr(f, 'name'):
+    if hasattr(f, "name"):
        return os.path.abspath(f.name)
    return None


def _serialize(proto: Union[bytes, google.protobuf.message.Message]) -> bytes:
-    '''
+    """
    Serialize a in-memory proto to bytes

    Arguments:
        proto: a in-memory proto, such as a ModelProto, TensorProto, etc

    Returns:
        Serialized proto in bytes
-    '''
+    """
    if isinstance(proto, bytes):
        return proto
-    elif hasattr(proto, 'SerializeToString') and callable(proto.SerializeToString):
+    elif hasattr(proto, "SerializeToString") and callable(proto.SerializeToString):
        result = proto.SerializeToString()
        return result
    else:
-        raise TypeError('No SerializeToString method is detected. '
-                        'neither proto is a str.\ntype is {}'.format(type(proto)))
+        raise TypeError(
+            "No SerializeToString method is detected. "
+            "neither proto is a str.\ntype is {}".format(type(proto))
+        )


-_Proto = TypeVar('_Proto', bound=google.protobuf.message.Message)
+_Proto = TypeVar("_Proto", bound=google.protobuf.message.Message)


def _deserialize(s: bytes, proto: _Proto) -> _Proto:
-    '''
+    """
    Parse bytes into a in-memory proto

    Arguments:

@@ -83,24 +88,31 @@ def _deserialize(s: bytes, proto: _Proto) -> _Proto:

    Returns:
        The proto instance filled in by s
-    '''
+    """
    if not isinstance(s, bytes):
-        raise ValueError(f'Parameter s must be bytes, but got type: {type(s)}')
+        raise ValueError(f"Parameter s must be bytes, but got type: {type(s)}")

-    if not (hasattr(proto, 'ParseFromString') and callable(proto.ParseFromString)):
-        raise ValueError('No ParseFromString method is detected. '
-                         '\ntype is {}'.format(type(proto)))
+    if not (hasattr(proto, "ParseFromString") and callable(proto.ParseFromString)):
+        raise ValueError(
+            "No ParseFromString method is detected. " "\ntype is {}".format(type(proto))
+        )

    decoded = cast(Optional[int], proto.ParseFromString(s))
    if decoded is not None and decoded != len(s):
        raise google.protobuf.message.DecodeError(
            "Protobuf decoding consumed too few bytes: {} out of {}".format(
-                decoded, len(s)))
+                decoded, len(s)
+            )
+        )
    return proto
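As a sanity check on the two private helpers above, a minimal round-trip sketch (these are internal functions, and the ir_version value is illustrative only):

m = ModelProto()           # ModelProto comes from the star-import of .onnx_pb
m.ir_version = 8           # illustrative value
raw = _serialize(m)        # bytes via SerializeToString()
clone = _deserialize(raw, ModelProto())  # parse back into a fresh proto
assert clone.ir_version == m.ir_version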


-def load_model(f: Union[IO[bytes], str], format: Optional[Any] = None, load_external_data: bool = True) -> ModelProto:
-    '''
+def load_model(
+    f: Union[IO[bytes], str],
+    format: Optional[Any] = None,
+    load_external_data: bool = True,
+) -> ModelProto:
+    """
    Loads a serialized ModelProto into memory
    load_external_data is true if the external data under the same directory of the model and load the external data
    If not, users need to call load_external_data_for_model with directory to load

@@ -111,7 +123,7 @@ def load_model(f: Union[IO[bytes], str], format: Optional[Any] = None, load_exte

    Returns:
        Loaded in-memory ModelProto
-    '''
+    """
    s = _load_bytes(f)
    model = load_model_from_string(s, format=format)

@@ -125,7 +137,7 @@ def load_model(f: Union[IO[bytes], str], format: Optional[Any] = None, load_exte


def load_tensor(f: Union[IO[bytes], str], format: Optional[Any] = None) -> TensorProto:
-    '''
+    """
    Loads a serialized TensorProto into memory

    Arguments:

@@ -134,13 +146,13 @@ def load_tensor(f: Union[IO[bytes], str], format: Optional[Any] = None) -> Tenso

    Returns:
        Loaded in-memory TensorProto
-    '''
+    """
    s = _load_bytes(f)
    return load_tensor_from_string(s, format=format)


def load_model_from_string(s: bytes, format: Optional[Any] = None) -> ModelProto:
-    '''
+    """
    Loads a binary string (bytes) that contains serialized ModelProto

    Arguments:

@@ -149,12 +161,12 @@ def load_model_from_string(s: bytes, format: Optional[Any] = None) -> ModelProto

    Returns:
        Loaded in-memory ModelProto
-    '''
+    """
    return _deserialize(s, ModelProto())


def load_tensor_from_string(s: bytes, format: Optional[Any] = None) -> TensorProto:
-    '''
+    """
    Loads a binary string (bytes) that contains serialized TensorProto

    Arguments:

@@ -163,12 +175,21 @@ def load_tensor_from_string(s: bytes, format: Optional[Any] = None) -> TensorPro

    Returns:
        Loaded in-memory TensorProto
-    '''
+    """
    return _deserialize(s, TensorProto())


-def save_model(proto: Union[ModelProto, bytes], f: Union[IO[bytes], str], format: Optional[Any] = None, save_as_external_data: bool = False, all_tensors_to_one_file: bool = True, location: Optional[str] = None, size_threshold: int = 1024, convert_attribute: bool = False) -> None:
-    '''
+def save_model(
+    proto: Union[ModelProto, bytes],
+    f: Union[IO[bytes], str],
+    format: Optional[Any] = None,
+    save_as_external_data: bool = False,
+    all_tensors_to_one_file: bool = True,
+    location: Optional[str] = None,
+    size_threshold: int = 1024,
+    convert_attribute: bool = False,
+) -> None:
+    """
    Saves the ModelProto to the specified path and optionally, serialize tensors with raw data as external data before saving.

    Arguments:

@@ -182,12 +203,14 @@ def save_model(proto: Union[ModelProto, bytes], f: Union[IO[bytes], str], format
            to external data. To convert every tensor with raw data to external data set size_threshold=0.
        convert_attribute: If true, convert all tensors to external data
                           If false, convert only non-attribute tensors to external data
-    '''
+    """
    if isinstance(proto, bytes):
        proto = _deserialize(proto, ModelProto())

    if save_as_external_data:
-        convert_model_to_external_data(proto, all_tensors_to_one_file, location, size_threshold, convert_attribute)
+        convert_model_to_external_data(
+            proto, all_tensors_to_one_file, location, size_threshold, convert_attribute
+        )

    model_filepath = _get_file_path(f)
    if model_filepath:

@@ -199,14 +222,14 @@ def save_model(proto: Union[ModelProto, bytes], f: Union[IO[bytes], str], format


def save_tensor(proto: TensorProto, f: Union[IO[bytes], str]) -> None:
-    '''
+    """
    Saves the TensorProto to the specified path.

    Arguments:
        proto: should be a in-memory TensorProto
        f: can be a file-like object (has "write" function) or a string containing a file name
        format: for future use
-    '''
+    """
    s = _serialize(proto)
    _save_bytes(s, f)

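Since black and isort only change formatting, the public API in this file behaves exactly as before. For reference, a minimal load/save sketch against that API (file paths, the external-data file name, and the flag values are illustrative, not from the PR):

import onnx

# Load a ModelProto; external tensors in the model's directory are
# pulled in by default (load_external_data=True).
model = onnx.load_model("model.onnx")  # hypothetical path

# Save it back, optionally spilling large tensors to an external file.
# Tensors above size_threshold bytes are externalized; 0 would externalize all.
onnx.save_model(
    model,
    "model_ext.onnx",  # hypothetical path
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="model_ext.onnx.data",  # hypothetical external-data file name
    size_threshold=1024,
)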
onnx/backend/base.py: 67 changes (32 additions, 35 deletions)
@@ -1,46 +1,49 @@
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
-from typing import Sequence, Any, Type, Tuple, NewType, Optional, Dict
+from typing import Any, Dict, NewType, Optional, Sequence, Tuple, Type

import numpy  # type: ignore

import onnx.checker
import onnx.onnx_cpp2py_export.checker as c_checker
-from onnx import ModelProto, NodeProto, IR_VERSION
+from onnx import IR_VERSION, ModelProto, NodeProto


class DeviceType:
-    _Type = NewType('_Type', int)
+    _Type = NewType("_Type", int)
    CPU: _Type = _Type(0)
    CUDA: _Type = _Type(1)


class Device:
-    '''
+    """
    Describes device type and device id
    syntax: device_type:device_id(optional)
    example: 'CPU', 'CUDA', 'CUDA:1'
-    '''
+    """

    def __init__(self, device: str) -> None:
-        options = device.split(':')
+        options = device.split(":")
        self.type = getattr(DeviceType, options[0])
        self.device_id = 0
        if len(options) > 1:
            self.device_id = int(options[1])
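The docstring's syntax maps directly onto this parsing; a quick sketch of the (unchanged) behavior:

d = Device("CUDA:1")
assert d.type == DeviceType.CUDA     # getattr on DeviceType by name
assert d.device_id == 1              # optional suffix after ':'
assert Device("CPU").device_id == 0  # id defaults to 0 when omitted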


-def namedtupledict(typename: str, field_names: Sequence[str], *args: Any, **kwargs: Any) -> Type[Tuple[Any, ...]]:
+def namedtupledict(
+    typename: str, field_names: Sequence[str], *args: Any, **kwargs: Any
+) -> Type[Tuple[Any, ...]]:
    field_names_map = {n: i for i, n in enumerate(field_names)}
    # Some output names are invalid python identifier, e.g. "0"
-    kwargs.setdefault('rename', True)
+    kwargs.setdefault("rename", True)
    data = namedtuple(typename, field_names, *args, **kwargs)  # type: ignore

    def getitem(self: Any, key: Any) -> Any:
        if isinstance(key, str):
            key = field_names_map[key]
        return super(type(self), self).__getitem__(key)  # type: ignore

    setattr(data, "__getitem__", getitem)
    return data
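For illustration, the helper's dict-style lookup keeps working even for field names that are not valid identifiers (the names below are hypothetical):

Outputs = namedtupledict("Outputs", ["0", "probabilities"])
outs = Outputs(1.5, [0.9, 0.1])
assert outs[0] == outs["0"] == 1.5        # "0" was renamed internally; string lookup still works
assert outs["probabilities"] == [0.9, 0.1]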

@@ -52,55 +55,49 @@ def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:

class Backend:
    @classmethod
-    def is_compatible(cls,
-                      model: ModelProto,
-                      device: str = 'CPU',
-                      **kwargs: Any
-                      ) -> bool:
+    def is_compatible(
+        cls, model: ModelProto, device: str = "CPU", **kwargs: Any
+    ) -> bool:
        # Return whether the model is compatible with the backend.
        return True

    @classmethod
-    def prepare(cls,
-                model: ModelProto,
-                device: str = 'CPU',
-                **kwargs: Any
-                ) -> Optional[BackendRep]:
+    def prepare(
+        cls, model: ModelProto, device: str = "CPU", **kwargs: Any
+    ) -> Optional[BackendRep]:
        # TODO Remove Optional from return type
        onnx.checker.check_model(model)
        return None

    @classmethod
-    def run_model(cls,
-                  model: ModelProto,
-                  inputs: Any,
-                  device: str = 'CPU',
-                  **kwargs: Any
-                  ) -> Tuple[Any, ...]:
+    def run_model(
+        cls, model: ModelProto, inputs: Any, device: str = "CPU", **kwargs: Any
+    ) -> Tuple[Any, ...]:
        backend = cls.prepare(model, device, **kwargs)
        assert backend is not None
        return backend.run(inputs)

    @classmethod
-    def run_node(cls,
-                 node: NodeProto,
-                 inputs: Any,
-                 device: str = 'CPU',
-                 outputs_info: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]] = None,
-                 **kwargs: Dict[str, Any]
-                 ) -> Optional[Tuple[Any, ...]]:
-        '''Simple run one operator and return the results.
+    def run_node(
+        cls,
+        node: NodeProto,
+        inputs: Any,
+        device: str = "CPU",
+        outputs_info: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]] = None,
+        **kwargs: Dict[str, Any],
+    ) -> Optional[Tuple[Any, ...]]:
+        """Simple run one operator and return the results.
        Args:
            outputs_info: a list of tuples, which contains the element type and
                shape of each output. First element of the tuple is the dtype, and
                the second element is the shape. More use case can be found in
                https://github.com/onnx/onnx/blob/main/onnx/backend/test/runner/__init__.py
-        '''
+        """
        # TODO Remove Optional from return type
-        if 'opset_version' in kwargs:
+        if "opset_version" in kwargs:
            special_context = c_checker.CheckerContext()
            special_context.ir_version = IR_VERSION
-            special_context.opset_imports = {'': kwargs['opset_version']}  # type: ignore
+            special_context.opset_imports = {"": kwargs["opset_version"]}  # type: ignore
            onnx.checker.check_node(node, special_context)
        else:
            onnx.checker.check_node(node)
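To show how the base class is meant to be driven, a hedged sketch of a toy backend; EchoRep and EchoBackend are hypothetical stand-ins for a real implementation, and the model path is illustrative:

import numpy as np
import onnx
from onnx.backend.base import Backend, BackendRep

class EchoRep(BackendRep):
    def run(self, inputs, **kwargs):
        return tuple(inputs)  # echoes inputs back; illustration only

class EchoBackend(Backend):
    @classmethod
    def prepare(cls, model, device="CPU", **kwargs):
        super().prepare(model, device, **kwargs)  # base class runs the checker
        return EchoRep()

model = onnx.load_model("model.onnx")  # hypothetical path
outputs = EchoBackend.run_model(model, [np.array([1.0])])  # prepare + run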