Skip to content

Commit

Permalink
Merge pull request #90 from dlshriver/update-ops
Browse files Browse the repository at this point in the history
Add Slice Operation and better Dropout support
  • Loading branch information
dlshriver committed Jun 9, 2022
2 parents e852463 + 5ab1ac5 commit f29424b
Show file tree
Hide file tree
Showing 35 changed files with 503 additions and 30 deletions.
36 changes: 35 additions & 1 deletion dnnv/nn/converters/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def _to_onnx_proto(
tensor_proto = onnx.numpy_helper.from_array(value, name=opname)
self.initializer.append(tensor_proto)
return tensor_proto
if isinstance(value, bool):
tensor_proto = onnx.numpy_helper.from_array(np.asarray(value), name=opname)
self.initializer.append(tensor_proto)
return tensor_proto
if isinstance(value, (int, float)):
tensor_proto = onnx.numpy_helper.from_array(
np.array(value, dtype=f"{type(value).__name__}32"), name=opname
Expand Down Expand Up @@ -289,10 +293,13 @@ def visit_Dropout(self, operation: operations.Dropout) -> onnx.NodeProto:

x = self._to_onnx_proto(operation.x, f"{opname}.x")
ratio = self._to_onnx_proto(operation.ratio, f"{opname}.ratio")
training_mode = self._to_onnx_proto(
operation.training_mode, f"{opname}.training_mode"
)

node = onnx.helper.make_node(
op_type,
inputs=[x.name, ratio.name],
inputs=[x.name, ratio.name, training_mode.name],
outputs=[opname],
name=opname,
)
Expand Down Expand Up @@ -552,6 +559,33 @@ def visit_Sigmoid(self, operation: operations.Sigmoid) -> onnx.NodeProto:

return node

def visit_Slice(self, operation: operations.Slice) -> onnx.NodeProto:
op_type = str(operation)
idx = self.op_counts[op_type] = self.op_counts[op_type] + 1
opname = f"{op_type}_{idx}"

x = self._to_onnx_proto(operation.x, f"{opname}.x")
starts = self._to_onnx_proto(operation.starts, f"{opname}.starts")
ends = self._to_onnx_proto(operation.ends, f"{opname}.ends")

inputs = [x.name, starts.name, ends.name]
if operation.steps is not None:
axes = self._to_onnx_proto(operation.axes, f"{opname}.axes")
steps = self._to_onnx_proto(operation.steps, f"{opname}.steps")
inputs.extend([axes.name, steps.name])
elif operation.axes is not None:
axes = self._to_onnx_proto(operation.axes, f"{opname}.axes")
inputs.append(axes.name)

node = onnx.helper.make_node(
op_type,
inputs=inputs,
outputs=[opname],
name=opname,
)

return node

def visit_Sub(self, operation: operations.Sub) -> onnx.NodeProto:
op_type = str(operation)
idx = self.op_counts[op_type] = self.op_counts[op_type] + 1
Expand Down
37 changes: 37 additions & 0 deletions dnnv/nn/converters/tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import tensorflow as tf

from .. import operations
from ..graph import OperationGraph
from ..operations import Operation
from ..utils import ONNX_TO_TENSORFLOW_DTYPE
Expand Down Expand Up @@ -391,6 +392,7 @@ def visit_Dropout(self, operation):
@self._cached
def dropout_func(*inputs):
x = _concretize([x_], inputs)
assert not operation.training_mode
return x

return dropout_func
Expand Down Expand Up @@ -808,6 +810,41 @@ def sign_func(*inputs):

return sign_func

def visit_Slice(self, operation: operations.Slice):
x_ = operation.x
if isinstance(x_, Operation):
x_ = self.visit(x_)
starts_ = operation.starts
if isinstance(starts_, Operation):
starts_ = self.visit(starts_)
ends_ = operation.ends
if isinstance(ends_, Operation):
ends_ = self.visit(ends_)
axes_ = operation.axes
if isinstance(axes_, Operation):
axes_ = self.visit(axes_)
steps_ = operation.steps
if isinstance(steps_, Operation):
steps_ = self.visit(steps_)

@self._cached
def slice_func(*inputs):
x, starts, ends, axes, steps = _concretize(
[x_, starts_, ends_, axes_, steps_], inputs
)
n = x.ndim
slices = [slice(None) for _ in range(n)]
if axes is None:
axes = range(n)
if steps is None:
steps = [1 for _ in range(n)]
for i, axis in enumerate(axes):
slices[axis] = slice(starts[i], ends[i], steps[i])
result = x[tuple(slices)]
return result

return slice_func

def visit_Softmax(self, operation):
x_ = operation.x
if isinstance(x_, Operation):
Expand Down
30 changes: 26 additions & 4 deletions dnnv/nn/operations/nn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -206,23 +207,44 @@ def from_onnx(cls, onnx_node, *inputs):


class Dropout(Operation):
def __init__(self, x, *, ratio=0.5, include_mask=False, name: Optional[str] = None):
def __init__(
self,
x,
*,
ratio=0.5,
training_mode=False,
include_mask=False,
name: Optional[str] = None,
):
super().__init__(name=name)
self.x = x
self.ratio = ratio
self.training_mode = training_mode
self.include_mask = include_mask

@classmethod
def from_onnx(cls, onnx_node, *inputs):
attributes = {a.name: as_numpy(a) for a in onnx_node.attribute}
ratio = attributes.get("ratio", 0.5)
if len(onnx_node.output) == 2:
training_mode = bool(attributes.get("training_mode", False))
if training_mode:
logger = logging.getLogger(__name__)
logger.warning("Dropout operations in training mode have limited support.")
if training_mode and len(onnx_node.output) == 2:
raise NotImplementedError(
"Using the mask of a Dropout operation is not yet supported."
" If you need this functionality, please open a GitHub issue."
)
return cls(*inputs, ratio=ratio, include_mask=True, name=onnx_node.name)
return cls(*inputs, ratio=ratio, name=onnx_node.name)
return cls(
*inputs,
ratio=ratio,
training_mode=training_mode,
include_mask=True,
name=onnx_node.name,
)
return cls(
*inputs, ratio=ratio, training_mode=training_mode, name=onnx_node.name
)


class GlobalAveragePool(Operation):
Expand Down
20 changes: 20 additions & 0 deletions dnnv/nn/operations/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,25 @@ def from_onnx(cls, onnx_node, *inputs):
return cls(*inputs, name=onnx_node.name)


class Slice(Operation):
def __init__(
self, x, starts, ends, *, axes=None, steps=None, name: Optional[str] = None
):
super().__init__(name=name)
self.x = x
self.starts = starts
self.ends = ends
self.axes = axes
self.steps = steps

@classmethod
def from_onnx(cls, onnx_node, *inputs):
attributes = {a.name: as_numpy(a) for a in onnx_node.attribute}
axes = attributes.get("axes")
steps = attributes.get("steps")
return cls(*inputs, axes=axes, steps=steps, name=onnx_node.name)


class Tile(Operation):
def __init__(self, x, repeats, *, name: Optional[str] = None):
super().__init__(name=name)
Expand Down Expand Up @@ -232,6 +251,7 @@ def from_onnx(cls, onnx_node, *inputs):
"Reshape",
"Resize",
"Shape",
"Slice",
"Tile",
"Transpose",
"Unsqueeze",
Expand Down
2 changes: 2 additions & 0 deletions dnnv/nn/transformers/simplifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .convert_reshape_to_flatten import ConvertReshapeToFlatten
from .convert_sub_to_add import ConvertSubToAdd
from .drop_identities import (
DropDropout,
DropIdentity,
DropUnnecessaryConcat,
DropUnnecessaryFlatten,
Expand All @@ -35,6 +36,7 @@
ConvertMul,
ConvertReshapeToFlatten,
ConvertSubToAdd,
DropDropout,
DropIdentity,
DropUnnecessaryConcat,
DropUnnecessaryFlatten,
Expand Down
8 changes: 8 additions & 0 deletions dnnv/nn/transformers/simplifiers/drop_identities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from .base import Simplifier


class DropDropout(Simplifier):
def visit_Dropout(self, operation: operations.Dropout):
if operation.training_mode:
return operation
return operation.x


class DropIdentity(Simplifier):
def visit_Identity(self, operation: operations.Identity):
return operation.x
Expand Down Expand Up @@ -36,6 +43,7 @@ def visit_Relu(self, operation: operations.Relu):


__all__ = [
"DropDropout",
"DropIdentity",
"DropUnnecessaryConcat",
"DropUnnecessaryFlatten",
Expand Down
22 changes: 22 additions & 0 deletions dnnv/nn/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,28 @@ def visit_Sigmoid(self, operation: operations.Sigmoid) -> None:
self.print_op_id(operation)
print(f"Sigmoid({self.get_op_id(operation.x)})")

def visit_Slice(self, operation: operations.Slice) -> None:
self.generic_visit(operation)
self.print_op_id(operation)
axes = (
f", axes={self.get_op_id(operation.axes)}"
if operation.axes is not None
else ""
)
steps = (
f", steps={self.get_op_id(operation.steps)}"
if operation.steps is not None
else ""
)
print(
"Slice("
f"{self.get_op_id(operation.x)}, "
f"{self.get_op_id(operation.starts)}, "
f"{self.get_op_id(operation.ends)}"
f"{axes}{steps}"
")"
)

def visit_Sign(self, operation: operations.Sign) -> None:
self.generic_visit(operation)
self.print_op_id(operation)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend
import pytest

from dnnv.nn.converters.onnx import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend
import pytest

from dnnv.nn.converters.onnx import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import onnxruntime
import pytest
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend
import pytest

from dnnv.nn.converters.onnx import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend
import pytest

from dnnv.nn.converters.onnx import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend

from dnnv.nn.converters.onnx import *
from dnnv.nn.operations import *
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import onnxruntime
import onnxruntime.backend
import pytest

from dnnv.nn.converters.onnx import *
Expand Down

0 comments on commit f29424b

Please sign in to comment.