Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model export does not work for N2V2 #128

Closed
tibuch opened this issue Oct 11, 2022 · 1 comment
Closed

Model export does not work for N2V2 #128

tibuch opened this issue Oct 11, 2022 · 1 comment
Labels
bug Something isn't working

Comments

@tibuch
Copy link
Collaborator

tibuch commented Oct 11, 2022

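The BlurPool implementation (`MaxBlurPool2D`) does not support model export. Its `__init__` takes a `pool` argument, but the layer does not override `get_config()`. As a result, `keras.models.clone_model`, which CSBDeep's `export_SavedModel` calls, fails with a `NotImplementedError`: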

1/1 [==============================] - 0s 173ms/step

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [21], line 1
----> 1 model.export_TF(name='Noise2Void - 2D SEM Example', 
      2                 description='This is the 2D Noise2Void example trained on SEM data in python.', 
      3                 authors=["Tim-Oliver Buchholz", "Alexander Krull", "Florian Jug"],
      4                 test_img=X_val[0,...,0], axes='YX',
      5                 patch_shape=patch_shape)

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/csbdeep/models/base_model.py:32, in suppress_without_basedir.<locals>._suppress_without_basedir.<locals>.wrapper(*args, **kwargs)
     30     warn is False or warnings.warn("Suppressing call of '%s' (due to basedir=None)." % f.__name__)
     31 else:
---> 32     return f(*args, **kwargs)

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/n2v/models/n2v_standard.py:473, in N2V.export_TF(self, name, description, authors, test_img, axes, patch_shape, fname)
    464 # CSBDeep Export
    465 meta = {
    466     'type': self.__class__.__name__,
    467     'version': package_version,
   (...)
    471     'tile_overlap': self._axes_tile_overlap(self.config.axes),
    472 }
--> 473 export_SavedModel(self.keras_model, str(fname), meta=meta)
    474 # CSBDeep Export Done
    475 
    476 # Replace : with -
    477 name = name.replace(':', ' -')

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/csbdeep/utils/tf.py:230, in export_SavedModel(model, outpath, meta, format)
    228 with tempfile.TemporaryDirectory() as tmpdir:
    229     tmpsubdir = os.path.join(tmpdir,'model')
--> 230     export_to_dir(tmpsubdir)
    231     shutil.make_archive(outpath, format, tmpsubdir)

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/csbdeep/utils/tf.py:207, in export_SavedModel.<locals>.export_to_dir(dirname)
    204 weights = model.get_weights()
    205 with tf.Graph().as_default():
    206     # clone model in new graph and set weights
--> 207     _model = clone_model(model)
    208     _model.set_weights(weights)
    209     _export(_model)

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:505, in clone_model(model, input_tensors, clone_function)
    501     return _clone_sequential_model(
    502         model, input_tensors=input_tensors, layer_fn=clone_function
    503     )
    504 else:
--> 505     return _clone_functional_model(
    506         model, input_tensors=input_tensors, layer_fn=clone_function
    507     )

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:208, in _clone_functional_model(model, input_tensors, layer_fn)
    202 if not callable(layer_fn):
    203     raise ValueError(
    204         "Expected `layer_fn` argument to be a callable. "
    205         f"Received: layer_fn={layer_fn}"
    206     )
--> 208 model_configs, created_layers = _clone_layers_and_model_config(
    209     model, new_input_layers, layer_fn
    210 )
    211 # Reconstruct model from the config, using the cloned layers.
    212 (
    213     input_tensors,
    214     output_tensors,
   (...)
    217     model_configs, created_layers=created_layers
    218 )

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:273, in _clone_layers_and_model_config(model, input_layers, layer_fn)
    270         created_layers[layer.name] = layer_fn(layer)
    271     return {}
--> 273 config = functional.get_network_config(
    274     model, serialize_layer_fn=_copy_layer
    275 )
    276 return config, created_layers

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/engine/functional.py:1563, in get_network_config(network, serialize_layer_fn, config)
   1558         node_data = node.serialize(
   1559             _make_node_key, node_conversion_map
   1560         )
   1561         filtered_inbound_nodes.append(node_data)
-> 1563 layer_config = serialize_layer_fn(layer)
   1564 layer_config["name"] = layer.name
   1565 layer_config["inbound_nodes"] = filtered_inbound_nodes

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:270, in _clone_layers_and_model_config.<locals>._copy_layer(layer)
    268     created_layers[layer.name] = InputLayer(**layer.get_config())
    269 else:
--> 270     created_layers[layer.name] = layer_fn(layer)
    271 return {}

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:50, in _clone_layer(layer)
     49 def _clone_layer(layer):
---> 50     return layer.__class__.from_config(layer.get_config())

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/engine/base_layer.py:786, in Layer.get_config(self)
    783 # Check that either the only argument in the `__init__` is  `self`,
    784 # or that `get_config` has been overridden:
    785 if extra_args and hasattr(self.get_config, "_is_default"):
--> 786     raise NotImplementedError(
    787         textwrap.dedent(
    788             f"""
    789   Layer {self.__class__.__name__} has arguments {extra_args}
    790   in `__init__` and therefore must override `get_config()`.
    791 
    792   Example:
    793 
    794   class CustomLayer(keras.layers.Layer):
    795       def __init__(self, arg1, arg2):
    796           super().__init__()
    797           self.arg1 = arg1
    798           self.arg2 = arg2
    799 
    800       def get_config(self):
    801           config = super().get_config()
    802           config.update({{
    803               "arg1": self.arg1,
    804               "arg2": self.arg2,
    805           }})
    806           return config"""
    807         )
    808     )
    810 return config

NotImplementedError: 
Layer MaxBlurPool2D has arguments ['pool']
in `__init__` and therefore must override `get_config()`.

Example:

class CustomLayer(keras.layers.Layer):
    def __init__(self, arg1, arg2):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2

    def get_config(self):
        config = super().get_config()
        config.update({
            "arg1": self.arg1,
            "arg2": self.arg2,
        })
        return config
@tibuch tibuch added the bug Something isn't working label Oct 11, 2022
@jdeschamps
Copy link
Member

Fixed in #130.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants