Skip to content

Commit

Permalink
Fix #2814 lambda function serialization and deserialization (#3639)
Browse files Browse the repository at this point in the history
* Remove old-style function attributes.

* Fix lambda function serialization and deserialization.
  • Loading branch information
gw0 authored and fchollet committed Aug 31, 2016
1 parent c939ceb commit 6417d90
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 32 deletions.
21 changes: 5 additions & 16 deletions keras/engine/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

import numpy as np

import sys
import marshal
import types as python_types
import warnings
import copy
Expand All @@ -15,6 +13,7 @@

from .. import backend as K
from ..utils.io_utils import ask_to_proceed_with_overwrite
from ..utils.generic_utils import func_dump, func_load


def to_list(x):
Expand Down Expand Up @@ -1414,13 +1413,8 @@ def compute_mask(self, inputs, mask=None):
raise Exception('Invalid merge mode: {}'.format(self.mode))

def get_config(self):
py3 = sys.version_info[0] == 3

if isinstance(self.mode, python_types.LambdaType):
if py3:
mode = marshal.dumps(self.mode.__code__).decode('raw_unicode_escape')
else:
mode = marshal.dumps(self.mode.func_code).decode('raw_unicode_escape')
mode = func_dump(self.mode)
mode_type = 'lambda'
elif callable(self.mode):
mode = self.mode.__name__
Expand All @@ -1430,10 +1424,7 @@ def get_config(self):
mode_type = 'raw'

if isinstance(self._output_shape, python_types.LambdaType):
if py3:
output_shape = marshal.dumps(self._output_shape.__code__).decode('raw_unicode_escape')
else:
output_shape = marshal.dumps(self._output_shape.func_code).decode('raw_unicode_escape')
output_shape = func_dump(self._output_shape)
output_shape_type = 'lambda'
elif callable(self._output_shape):
output_shape = self._output_shape.__name__
Expand All @@ -1456,17 +1447,15 @@ def from_config(cls, config):
if mode_type == 'function':
mode = globals()[config['mode']]
elif mode_type == 'lambda':
mode = marshal.loads(config['mode'].encode('raw_unicode_escape'))
mode = python_types.FunctionType(mode, globals())
mode = func_load(config['mode'], globs=globals())
else:
mode = config['mode']

output_shape_type = config.pop('output_shape_type')
if output_shape_type == 'function':
output_shape = globals()[config['output_shape']]
elif output_shape_type == 'lambda':
output_shape = marshal.loads(config['output_shape'].encode('raw_unicode_escape'))
output_shape = python_types.FunctionType(output_shape, globals())
output_shape = func_load(config['output_shape'], globs=globals())
else:
output_shape = config['output_shape']

Expand Down
21 changes: 5 additions & 16 deletions keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import copy
import inspect
import types as python_types
import marshal
import sys
import warnings

from .. import backend as K
from .. import activations, initializations, regularizers, constraints
from ..engine import InputSpec, Layer, Merge
from ..regularizers import ActivityRegularizer
from ..utils.generic_utils import func_dump, func_load


class Masking(Layer):
Expand Down Expand Up @@ -554,23 +553,15 @@ def call(self, x, mask=None):
return self.function(x, **arguments)

def get_config(self):
py3 = sys.version_info[0] == 3

if isinstance(self.function, python_types.LambdaType):
if py3:
function = marshal.dumps(self.function.__code__).decode('raw_unicode_escape')
else:
function = marshal.dumps(self.function.func_code).decode('raw_unicode_escape')
function = func_dump(self.function)
function_type = 'lambda'
else:
function = self.function.__name__
function_type = 'function'

if isinstance(self._output_shape, python_types.LambdaType):
if py3:
output_shape = marshal.dumps(self._output_shape.__code__).decode('raw_unicode_escape')
else:
output_shape = marshal.dumps(self._output_shape.func_code).decode('raw_unicode_escape')
output_shape = func_dump(self._output_shape)
output_shape_type = 'lambda'
elif callable(self._output_shape):
output_shape = self._output_shape.__name__
Expand All @@ -593,17 +584,15 @@ def from_config(cls, config):
if function_type == 'function':
function = globals()[config['function']]
elif function_type == 'lambda':
function = marshal.loads(config['function'].encode('raw_unicode_escape'))
function = python_types.FunctionType(function, globals())
function = func_load(config['function'], globs=globals())
else:
raise Exception('Unknown function type: ' + function_type)

output_shape_type = config.pop('output_shape_type')
if output_shape_type == 'function':
output_shape = globals()[config['output_shape']]
elif output_shape_type == 'lambda':
output_shape = marshal.loads(config['output_shape'].encode('raw_unicode_escape'))
output_shape = python_types.FunctionType(output_shape, globals())
output_shape = func_load(config['output_shape'], globs=globals())
else:
output_shape = config['output_shape']

Expand Down
39 changes: 39 additions & 0 deletions keras/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import time
import sys
import six
import marshal
import types as python_types


def get_from_module(identifier, module_params, module_name,
Expand Down Expand Up @@ -33,6 +35,43 @@ def make_tuple(*args):
return args


def func_dump(func):
'''Serialize user defined function.'''
code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
defaults = func.__defaults__
if func.__closure__:
closure = tuple(c.cell_contents for c in func.__closure__)
else:
closure = None
return code, defaults, closure


def func_load(code, defaults=None, closure=None, globs=None):
'''Deserialize user defined function.'''
if isinstance(code, (tuple, list)): # unpack previous dump
code, defaults, closure = code
code = marshal.loads(code.encode('raw_unicode_escape'))
if closure is not None:
closure = func_reconstruct_closure(closure)
if globs is None:
globs = globals()
return python_types.FunctionType(code, globs, name=code.co_name, argdefs=defaults, closure=closure)


def func_reconstruct_closure(values):
'''Deserialization helper that reconstructs a closure.'''
nums = range(len(values))
src = ["def func(arg):"]
src += [" _%d = arg[%d]" % (n, n) for n in nums]
src += [" return lambda:(%s)" % ','.join(["_%d" % n for n in nums]), ""]
src = '\n'.join(src)
try:
exec(src)
except:
raise SyntaxError(src)
return func(values).__closure__


class Progbar(object):
def __init__(self, target, width=30, verbose=1, interval=0.01):
'''
Expand Down

0 comments on commit 6417d90

Please sign in to comment.