diff --git a/keras/engine/topology.py b/keras/engine/topology.py index eb1a82f6ab5..a4b9f6a222f 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -5,8 +5,6 @@ import numpy as np -import sys -import marshal import types as python_types import warnings import copy @@ -15,6 +13,7 @@ from .. import backend as K from ..utils.io_utils import ask_to_proceed_with_overwrite +from ..utils.generic_utils import func_dump, func_load def to_list(x): @@ -1414,13 +1413,8 @@ def compute_mask(self, inputs, mask=None): raise Exception('Invalid merge mode: {}'.format(self.mode)) def get_config(self): - py3 = sys.version_info[0] == 3 - if isinstance(self.mode, python_types.LambdaType): - if py3: - mode = marshal.dumps(self.mode.__code__).decode('raw_unicode_escape') - else: - mode = marshal.dumps(self.mode.func_code).decode('raw_unicode_escape') + mode = func_dump(self.mode) mode_type = 'lambda' elif callable(self.mode): mode = self.mode.__name__ @@ -1430,10 +1424,7 @@ def get_config(self): mode_type = 'raw' if isinstance(self._output_shape, python_types.LambdaType): - if py3: - output_shape = marshal.dumps(self._output_shape.__code__).decode('raw_unicode_escape') - else: - output_shape = marshal.dumps(self._output_shape.func_code).decode('raw_unicode_escape') + output_shape = func_dump(self._output_shape) output_shape_type = 'lambda' elif callable(self._output_shape): output_shape = self._output_shape.__name__ @@ -1456,8 +1447,7 @@ def from_config(cls, config): if mode_type == 'function': mode = globals()[config['mode']] elif mode_type == 'lambda': - mode = marshal.loads(config['mode'].encode('raw_unicode_escape')) - mode = python_types.FunctionType(mode, globals()) + mode = func_load(config['mode'], globs=globals()) else: mode = config['mode'] @@ -1465,8 +1455,7 @@ def from_config(cls, config): if output_shape_type == 'function': output_shape = globals()[config['output_shape']] elif output_shape_type == 'lambda': - output_shape = marshal.loads(config['output_shape'].encode('raw_unicode_escape')) - output_shape = python_types.FunctionType(output_shape, globals()) + output_shape = func_load(config['output_shape'], globs=globals()) else: output_shape = config['output_shape'] diff --git a/keras/layers/core.py b/keras/layers/core.py index f21f139c43e..0828873707f 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -7,14 +7,13 @@ import copy import inspect import types as python_types -import marshal -import sys import warnings from .. import backend as K from .. import activations, initializations, regularizers, constraints from ..engine import InputSpec, Layer, Merge from ..regularizers import ActivityRegularizer +from ..utils.generic_utils import func_dump, func_load class Masking(Layer): @@ -554,23 +553,15 @@ def call(self, x, mask=None): return self.function(x, **arguments) def get_config(self): - py3 = sys.version_info[0] == 3 - if isinstance(self.function, python_types.LambdaType): - if py3: - function = marshal.dumps(self.function.__code__).decode('raw_unicode_escape') - else: - function = marshal.dumps(self.function.func_code).decode('raw_unicode_escape') + function = func_dump(self.function) function_type = 'lambda' else: function = self.function.__name__ function_type = 'function' if isinstance(self._output_shape, python_types.LambdaType): - if py3: - output_shape = marshal.dumps(self._output_shape.__code__).decode('raw_unicode_escape') - else: - output_shape = marshal.dumps(self._output_shape.func_code).decode('raw_unicode_escape') + output_shape = func_dump(self._output_shape) output_shape_type = 'lambda' elif callable(self._output_shape): output_shape = self._output_shape.__name__ @@ -593,8 +584,7 @@ def from_config(cls, config): if function_type == 'function': function = globals()[config['function']] elif function_type == 'lambda': - function = marshal.loads(config['function'].encode('raw_unicode_escape')) - function = python_types.FunctionType(function, globals()) + function = func_load(config['function'], globs=globals()) else: raise Exception('Unknown function type: ' + function_type) @@ -602,8 +592,7 @@ def from_config(cls, config): if output_shape_type == 'function': output_shape = globals()[config['output_shape']] elif output_shape_type == 'lambda': - output_shape = marshal.loads(config['output_shape'].encode('raw_unicode_escape')) - output_shape = python_types.FunctionType(output_shape, globals()) + output_shape = func_load(config['output_shape'], globs=globals()) else: output_shape = config['output_shape'] diff --git a/keras/utils/generic_utils.py b/keras/utils/generic_utils.py index 04092ff9d58..9f06e9b4871 100644 --- a/keras/utils/generic_utils.py +++ b/keras/utils/generic_utils.py @@ -3,6 +3,8 @@ import time import sys import six +import marshal +import types as python_types def get_from_module(identifier, module_params, module_name, @@ -33,6 +35,43 @@ def make_tuple(*args): return args +def func_dump(func): + '''Serialize user defined function.''' + code = marshal.dumps(func.__code__).decode('raw_unicode_escape') + defaults = func.__defaults__ + if func.__closure__: + closure = tuple(c.cell_contents for c in func.__closure__) + else: + closure = None + return code, defaults, closure + + +def func_load(code, defaults=None, closure=None, globs=None): + '''Deserialize user defined function.''' + if isinstance(code, (tuple, list)): # unpack previous dump + code, defaults, closure = code + code = marshal.loads(code.encode('raw_unicode_escape')) + if closure is not None: + closure = func_reconstruct_closure(closure) + if globs is None: + globs = globals() + return python_types.FunctionType(code, globs, name=code.co_name, argdefs=defaults, closure=closure) + + +def func_reconstruct_closure(values): + '''Deserialization helper that reconstructs a closure.''' + nums = range(len(values)) + src = ["def func(arg):"] + src += [" _%d = arg[%d]" % (n, n) for n in nums] + src += [" return lambda:(%s)" % ','.join(["_%d" % n for n in nums]), ""] + src = '\n'.join(src) + try: + exec(src) + except: + raise SyntaxError(src) + return func(values).__closure__ + + class Progbar(object): def __init__(self, target, width=30, verbose=1, interval=0.01): '''