Add missing typing.Optional type annotations to function parameters.
PiperOrigin-RevId: 376873484
rchen152 authored and mn-robot committed Jun 1, 2021
1 parent f41c85c commit ec4d407
Showing 5 changed files with 98 additions and 96 deletions.
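
The change is mechanical across all five files: every parameter that carries a type annotation but defaults to None now uses an explicit Optional[...] annotation instead of relying on "implicit Optional", presumably to satisfy type checkers that disallow implicit Optional (mypy's no-implicit-optional behavior is the best-known example). A minimal, hypothetical sketch of the pattern; the functions and the Op stand-in below are illustrative and not taken from the diff:

from typing import List, Optional


class Op:
  """Stand-in for tf.Operation; purely illustrative."""


# Before: an annotated parameter with a None default relies on implicit
# Optional, which strict type checking flags.
def build_before(ops: List[Op] = None):
  return ops or []


# After: the possibility of None is stated explicitly in the annotation.
def build_after(ops: Optional[List[Op]] = None):
  return ops or []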
40 changes: 21 additions & 19 deletions morph_net/network_regularizers/activation_regularizer.py
@@ -5,6 +5,8 @@
 # [internal] enable type annotations
 from __future__ import print_function
 
+from typing import Optional
+
 from morph_net.framework import batch_norm_source_op_handler
 from morph_net.framework import conv2d_transpose_source_op_handler
 from morph_net.framework import conv_source_op_handler
@@ -22,15 +24,15 @@
 class GammaActivationRegularizer(generic_regularizers.NetworkRegularizer):
   """A NetworkRegularizer that targets activation count using Gamma L1."""
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      gamma_threshold,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group=None,
-      regularizer_blacklist=None):
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               gamma_threshold,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group=None,
+               regularizer_blacklist=None):
     """Creates a GammaActivationRegularizer object.
 
     Args:
@@ -95,16 +97,16 @@ def cost_name(self):
 class GroupLassoActivationRegularizer(generic_regularizers.NetworkRegularizer):
   """A NetworkRegularizer that targets activation count using L1 group lasso."""
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      threshold,
-      l1_fraction=0,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group=None,
-      regularizer_blacklist=None):
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               threshold,
+               l1_fraction=0,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group=None,
+               regularizer_blacklist=None):
     """Creates a GroupLassoActivationRegularizer object.
 
     Args:
40 changes: 20 additions & 20 deletions morph_net/network_regularizers/flop_regularizer.py
@@ -4,7 +4,7 @@
 from __future__ import division
 # [internal] enable type annotations
 from __future__ import print_function
-from typing import Type, List
+from typing import List, Optional, Type
 
 from morph_net.framework import batch_norm_source_op_handler
 from morph_net.framework import conv2d_transpose_source_op_handler as conv2d_transpose_handler
@@ -40,15 +40,15 @@ def cost_name(self):
 class GammaFlopsRegularizer(generic_regularizers.NetworkRegularizer):
   """A NetworkRegularizer that targets FLOPs using Gamma L1 as OpRegularizer."""
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      gamma_threshold,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group=None,
-      regularizer_blacklist=None):
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               gamma_threshold,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group=None,
+               regularizer_blacklist=None):
     """Creates a GammaFlopsRegularizer object.
 
     Args:
@@ -113,16 +113,16 @@ def cost_name(self):
 class GroupLassoFlopsRegularizer(generic_regularizers.NetworkRegularizer):
   """A NetworkRegularizer that targets FLOPs using L1 group lasso."""
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      threshold,
-      l1_fraction=0,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group=None,
-      regularizer_blacklist=None):
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               threshold,
+               l1_fraction=0,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group=None,
+               regularizer_blacklist=None):
     """Creates a GroupLassoFlopsRegularizer object.
 
     Args:
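
For orientation, here is a hedged sketch of how the GammaFlopsRegularizer constructor shown above is typically wired into a TF1 model. The helper function, the tensors it receives, and the threshold/strength values are assumptions for illustration, not part of this commit, and get_regularization_term() is assumed from the NetworkRegularizer interface these classes implement:

from morph_net.network_regularizers import flop_regularizer


def add_flop_regularization(inputs, labels, logits, model_loss):
  """Hypothetical helper: attaches a FLOP regularizer to an existing model.

  `inputs`, `labels`, `logits`, and `model_loss` are assumed to come from a
  batch-normalized TF1 convolutional model; none of this is in the commit.
  """
  network_regularizer = flop_regularizer.GammaFlopsRegularizer(
      output_boundary=[logits.op],
      input_boundary=[inputs.op, labels.op],
      gamma_threshold=1e-3)  # batch-norm gammas below this are treated as inactive
  # Scale the FLOP regularization term and fold it into the training loss.
  regularization_strength = 1e-10
  return model_loss + (
      regularization_strength * network_regularizer.get_regularization_term())

Callers that explicitly pass None for input_boundary or regularizer_decorator now type-check cleanly, since the annotations spell out Optional.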
50 changes: 25 additions & 25 deletions morph_net/network_regularizers/latency_regularizer.py
@@ -1,6 +1,6 @@
 """A NetworkRegularizer that targets inference latency."""
 
-from typing import Type, List
+from typing import List, Optional, Type
 
 from morph_net.framework import batch_norm_source_op_handler
 from morph_net.framework import conv2d_transpose_source_op_handler as conv2d_transpose_handler
@@ -52,19 +52,19 @@ class LogisticSigmoidLatencyRegularizer(
     regularized. See op_regularizer_manager for more detail.
   """
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      hardware,
-      batch_size=1,
-      regularize_on_mask=True,
-      alive_threshold=0.1,
-      mask_as_alive_vector=True,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group=None,
-      regularizer_blacklist=None):
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               hardware,
+               batch_size=1,
+               regularize_on_mask=True,
+               alive_threshold=0.1,
+               mask_as_alive_vector=True,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group=None,
+               regularizer_blacklist=None):
 
     self._hardware = hardware
     self._batch_size = batch_size
@@ -97,17 +97,17 @@ def cost_name(self):
 class GammaLatencyRegularizer(generic_regularizers.NetworkRegularizer):
   """A NetworkRegularizer that targets latency using Gamma L1."""
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      gamma_threshold,
-      hardware,
-      batch_size=1,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group=None,
-      regularizer_blacklist=None) -> None:
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               gamma_threshold,
+               hardware,
+               batch_size=1,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group=None,
+               regularizer_blacklist=None) -> None:
     """Creates a GammaLatencyRegularizer object.
 
     Latency cost and regularization loss is calculated for a specified hardware
24 changes: 12 additions & 12 deletions morph_net/network_regularizers/logistic_sigmoid_regularizer.py
@@ -6,7 +6,7 @@
 from __future__ import print_function
 
 import abc
-from typing import Type, List
+from typing import List, Optional, Type
 
 from morph_net.framework import generic_regularizers
 from morph_net.framework import logistic_sigmoid_source_op_handler as ls_handler
@@ -23,17 +23,17 @@
 class LogisticSigmoidRegularizer(generic_regularizers.NetworkRegularizer):
   """Base class for NetworkRegularizers that use probabilistic sampling."""
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      regularize_on_mask=True,
-      alive_threshold=0.1,
-      mask_as_alive_vector=True,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group=None,
-      regularizer_blacklist=None):
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               regularize_on_mask=True,
+               alive_threshold=0.1,
+               mask_as_alive_vector=True,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group=None,
+               regularizer_blacklist=None):
     """Creates a LogisticSigmoidFlopsRegularizer object.
 
     Args:
40 changes: 20 additions & 20 deletions morph_net/network_regularizers/model_size_regularizer.py
@@ -4,7 +4,7 @@
 from __future__ import division
 # [internal] enable type annotations
 from __future__ import print_function
-from typing import Text, Type, List
+from typing import List, Optional, Text, Type
 
 from morph_net.framework import batch_norm_source_op_handler
 from morph_net.framework import conv2d_transpose_source_op_handler as conv2d_transpose_handler
@@ -40,15 +40,15 @@ def cost_name(self):
 class GammaModelSizeRegularizer(generic_regularizers.NetworkRegularizer):
   """A NetworkRegularizer that targets model size using Gamma L1."""
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      gamma_threshold,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group=None,
-      regularizer_blacklist=None):
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               gamma_threshold,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group=None,
+               regularizer_blacklist=None):
     """Creates a GammaModelSizeRegularizer object.
 
     Args:
@@ -112,16 +112,16 @@ def cost_name(self):
 class GroupLassoModelSizeRegularizer(generic_regularizers.NetworkRegularizer):
   """A NetworkRegularizer that targets model size using L1 group lasso."""
 
-  def __init__(
-      self,
-      output_boundary: List[tf.Operation],
-      threshold,
-      l1_fraction=0.0,
-      regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
-      decorator_parameters=None,
-      input_boundary: List[tf.Operation] = None,
-      force_group: List[Text] = None,
-      regularizer_blacklist: List[Text] = None):
+  def __init__(self,
+               output_boundary: List[tf.Operation],
+               threshold,
+               l1_fraction=0.0,
+               regularizer_decorator: Optional[Type[
+                   generic_regularizers.OpRegularizer]] = None,
+               decorator_parameters=None,
+               input_boundary: Optional[List[tf.Operation]] = None,
+               force_group: Optional[List[Text]] = None,
+               regularizer_blacklist: Optional[List[Text]] = None):
     """Creates a GroupLassoModelSizeRegularizer object.
 
     Args:
