Skip to content

Commit

Permalink
Add python 2 and Python 3 support.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 273872341
  • Loading branch information
Daniel Kappler authored and Copybara-Service committed Oct 10, 2019
1 parent 4084f80 commit 58aa81a
Show file tree
Hide file tree
Showing 51 changed files with 247 additions and 144 deletions.
12 changes: 6 additions & 6 deletions export_generators/abstract_export_generator.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Utilities for exporting savedmodels."""

from __future__ import absolute_import
Expand All @@ -23,12 +24,13 @@
import abc
import functools
import os
import gin
from typing import Any, Dict, List, Optional, Text

import gin
import six
from tensor2robot.models import abstract_model
from tensor2robot.utils import tensorspec_utils
import tensorflow as tf
from typing import Optional, Dict, Text, Any, List

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_log_pb2
Expand All @@ -37,17 +39,15 @@


@gin.configurable
class AbstractExportGenerator(object):
class AbstractExportGenerator(six.with_metaclass(abc.ABCMeta, object)):
"""Class to manage assets related to exporting a model.
Args:
Attributes:
export_raw_receivers: Whether to export receiver_fns which do not have
preprocessing enabled. This is useful for serving using Servo, in
conjunction with client-preprocessing.
"""

__metaclass__ = abc.ABCMeta

def __init__(self, export_raw_receivers = False):
self._export_raw_receivers = export_raw_receivers
self._feature_spec = None
Expand Down
2 changes: 2 additions & 0 deletions export_generators/abstract_export_generator_test.py
Expand Up @@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Tests for tensor2robot.export_generator.sabstract_export_generator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import zip
from tensor2robot.export_generators import abstract_export_generator
from tensor2robot.preprocessors import noop_preprocessor
from tensor2robot.utils import mocks
Expand Down
2 changes: 1 addition & 1 deletion export_generators/default_export_generator.py
Expand Up @@ -36,7 +36,7 @@
class DefaultExportGenerator(abstract_export_generator.AbstractExportGenerator):
"""Class to manage assets related to exporting a model.
Args:
Attributes:
export_raw_receivers: Whether to export receiver_fns which do not have
preprocessing enabled. This is useful for serving using Servo, in
conjunction with client-preprocessing.
Expand Down
2 changes: 1 addition & 1 deletion hooks/async_export_hook_builder.py
Expand Up @@ -90,7 +90,7 @@ def _export_fn(export_dir, global_step):
class AsyncExportHookBuilder(hook_builder.HookBuilder):
"""Creates hooks for exporting for cpu and tpu for serving.
Arguments:
Attributes:
export_dir: Directory to output the latest models.
save_secs: Interval to save models, and copy the latest model from
`export_dir` to `lagged_export_dir`.
Expand Down
2 changes: 2 additions & 0 deletions hooks/checkpoint_hooks_test.py
Expand Up @@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from absl import flags
from absl import logging
from six.moves import range
from tensor2robot.hooks import checkpoint_hooks
import tensorflow as tf # tf

Expand Down
7 changes: 4 additions & 3 deletions hooks/hook_builder.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Interface to manage building hooks."""

from __future__ import absolute_import
Expand All @@ -21,14 +22,14 @@
from __future__ import print_function

import abc
from typing import List

import six
from tensor2robot.models import model_interface
import tensorflow as tf # tf
from typing import List


class HookBuilder(object):
__metaclass__ = abc.ABCMeta
class HookBuilder(six.with_metaclass(abc.ABCMeta, object)):

@abc.abstractmethod
def create_hooks(
Expand Down
27 changes: 19 additions & 8 deletions input_generators/abstract_input_generator.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""The abstract base class for input generators."""

from __future__ import absolute_import
Expand All @@ -23,15 +24,18 @@
import abc
import functools
import inspect
from typing import Callable, Optional, Text, Tuple, Union

import gin
import six
from tensor2robot.models import abstract_model
from tensor2robot.utils import tensorspec_utils
import tensorflow as tf
from typing import Callable, Optional, Text, Tuple, Union



@gin.configurable
class AbstractInputGenerator(object):
class AbstractInputGenerator(six.with_metaclass(abc.ABCMeta, object)):
"""The abstract input generator responsible for creating the input pipeline.
The main functionality for exporting models both for serialized tf.Example
Expand All @@ -40,8 +44,6 @@ class AbstractInputGenerator(object):
respective subclasses.
"""

__metaclass__ = abc.ABCMeta

def __init__(self, batch_size = 32):
"""Create an instance.
Expand Down Expand Up @@ -114,16 +116,25 @@ def set_preprocess_fn(self, preprocess_fn): # pytype: disable=invalid-annotatio
preprocess_fn: The function called during the input dataset generation to
preprocess the data.
"""

if isinstance(preprocess_fn, functools.partial): # pytype: disable=wrong-arg-types
# Note, we do not combine both conditions into one since
# inspect.getargspec does not work for functools.partial objects.
if 'mode' not in preprocess_fn.keywords:
raise ValueError('The preprocess_fn mode has to be set if a partial'
'function has been passed.')
elif 'mode' in inspect.getargspec(preprocess_fn).args:
raise ValueError('The passed preprocess_fn has an open argument `mode`'
'which should be patched by a closure or with '
'functools.partial.')
else:
if six.PY3:
argspec = inspect.getfullargspec(preprocess_fn)
# first 4 element of fullspec corresponds to spec:
# https://docs.python.org/3.4/library/inspect.html
argspec = inspect.ArgSpec(*argspec[:4])
else:
argspec = inspect.getargspec(preprocess_fn) # pylint: disable=deprecated-method
if 'mode' in argspec.args:
raise ValueError('The passed preprocess_fn has an open argument `mode`'
'which should be patched by a closure or with '
'functools.partial.')

self._preprocess_fn = preprocess_fn

Expand Down
12 changes: 7 additions & 5 deletions input_generators/default_input_generator.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Default input generators wrapping tfdata, metadata and replay buffer."""

from __future__ import absolute_import
Expand All @@ -23,16 +24,16 @@
import abc
import json
import os
from typing import Dict, Optional, Text

from absl import logging
import gin
import six
from tensor2robot.input_generators import abstract_input_generator
from tensor2robot.utils import tensorspec_utils
from tensor2robot.utils import tfdata

import tensorflow as tf

from typing import Dict, Optional, Text


_TF_CONFIG_ENV = 'TF_CONFIG'
_MULTI_EVAL_NAME = 'multi_eval_name'
Expand Down Expand Up @@ -152,9 +153,10 @@ def __init__(self,
raise ValueError('multi_eval_name not set in TF_CONFIG env variable')


class GeneratorInputGenerator(abstract_input_generator.AbstractInputGenerator):
class GeneratorInputGenerator(
six.with_metaclass(abc.ABCMeta,
abstract_input_generator.AbstractInputGenerator)):
"""Class to use for constructing input generators from Python generator objects."""
__metaclass__ = abc.ABCMeta

def __init__(self, sequence_length=None, **kwargs):
self._sequence_length = sequence_length
Expand Down
2 changes: 2 additions & 0 deletions layers/film_resnet_model.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Contains definitions for Residual Networks with/without FiLM conditioning.
Residual networks ('v1' ResNets) were originally proposed in:
Expand All @@ -36,6 +37,7 @@
from __future__ import division
from __future__ import print_function

from six.moves import range
import tensorflow as tf

_BATCH_NORM_DECAY = 0.997
Expand Down
8 changes: 6 additions & 2 deletions layers/resnet.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""ResNet tower.
"""

Expand All @@ -21,10 +22,13 @@

from __future__ import print_function

from typing import List, Optional

import gin
from six.moves import range
from tensor2robot.layers import film_resnet_model as resnet_lib
import tensorflow as tf
from typing import Optional, List

slim = tf.contrib.slim


Expand Down Expand Up @@ -58,7 +62,7 @@ def _get_block_sizes(resnet_size):
except KeyError:
err = ('Could not find layers for selected Resnet size.\n'
'Size received: {}; sizes allowed: {}.'.format(
resnet_size, choices.keys()))
resnet_size, list(choices.keys())))
raise ValueError(err)


Expand Down
2 changes: 2 additions & 0 deletions layers/resnet_test.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Tests for tensor2robot.layers.resnet."""

from __future__ import absolute_import
Expand All @@ -21,6 +22,7 @@

import functools
from absl.testing import parameterized
from six.moves import range
from tensor2robot.layers import resnet
import tensorflow as tf

Expand Down
6 changes: 4 additions & 2 deletions layers/snail.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Implementation of building blocks from https://arxiv.org/abs/1707.03141.
Implementation here is designed to match pseudocode in the paper.
Expand All @@ -23,11 +24,12 @@

from __future__ import print_function

from typing import Text

import numpy as np
from six.moves import range
import tensorflow as tf

from typing import Text

layers = tf.contrib.layers


Expand Down
2 changes: 2 additions & 0 deletions layers/snail_test.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Tests for SNAIL."""

from __future__ import absolute_import
Expand All @@ -21,6 +22,7 @@
from __future__ import print_function

import numpy as np
from six.moves import range
from tensor2robot.layers import snail
import tensorflow as tf

Expand Down
2 changes: 2 additions & 0 deletions layers/spatial_softmax.py
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""TensorFlow impl of Spatial Softmax layers. (spatial soft arg-max).
TODO(T2R_CONTRIBUTORS) - consider replacing with contrib version.
Expand All @@ -24,6 +25,7 @@

import gin
import numpy as np
from six.moves import range
import tensorflow as tf
import tensorflow_probability as tfp

Expand Down
5 changes: 5 additions & 0 deletions layers/vision_layers.py
Expand Up @@ -13,11 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Implements image-to-pose regression model from WTL paper.
Colloquially referred to as 'Berkeley-Net'.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gin
from six.moves import range
from tensor2robot.layers import spatial_softmax
import tensorflow as tf

Expand Down
9 changes: 6 additions & 3 deletions meta_learning/maml_inner_loop.py
Expand Up @@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Custom getter utilities to leverage existing models for MAML."""

from __future__ import absolute_import
from __future__ import division

from __future__ import print_function

from typing import List, Mapping, Optional, Text, Tuple

import gin
from six.moves import zip
import tensorflow as tf

from typing import List, Mapping, Text, Tuple, Optional


@gin.configurable
class MAMLInnerLoopGradientDescent(object):
Expand Down Expand Up @@ -167,7 +170,7 @@ def _compute_and_apply_gradients(self, loss):
# The new cache will contain the updated variables.
self._custom_getter_variable_cache = {}

variable_list = variable_cache_old.keys()
variable_list = list(variable_cache_old.keys())
gradients = tf.gradients(
[loss], [variable_cache_old[name] for name in variable_list])
for name, gradient in zip(variable_list, gradients):
Expand Down

0 comments on commit 58aa81a

Please sign in to comment.