* https://www.tensorflow.org/guide/keras/custom_layers_and_models#you_can_optionally_enable_serialization_on_your_layers

In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Layer, Input
from functools import wraps # This convenience func preserves name and docstring

print(tf.__version__)

2.0.0


## A simple dense layer

In [2]:
class MyDenseLayer(Layer):
    def __init__(self, num_outputs, **kwargs):
        super(MyDenseLayer, self).__init__(**kwargs)
        self.num_outputs = num_outputs
    
    def build(self, input_shape):
        self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
        super(MyDenseLayer, self).build(input_shape)  # Be sure to call this at the end
        
    def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

## Add `get_config` method to class

In [3]:
def add_method(cls):
    def decorator(func):
        @wraps(func) 
        def wrapper(self, *args, **kwargs): 
#             return func(*args, **kwargs)
            return func(self, *args, **kwargs)
        setattr(cls, func.__name__, wrapper)
        # Note we are not binding func, but wrapper which accepts self but does exactly the same as func
        return func # returning func means func can still be used normally
    return decorator

In [4]:
@add_method(MyDenseLayer)
def get_config(self):
    config = super(MyDenseLayer, self).get_config()
    config.update({'num_outputs': self.num_outputs})
    return config

In [5]:
a = MyDenseLayer(10)
a.get_config()

{'name': 'my_dense_layer',
 'trainable': True,
 'dtype': 'float32',
 'num_outputs': 10}

If you need more flexibility when deserializing the layer from its config, you can also override the `from_config` class method. This is the base implementation of `from_config`
```python
def from_config(cls, config):
    return cls(**config)
```

In [6]:
b = MyDenseLayer.from_config(a.get_config())
b.get_config()

{'name': 'my_dense_layer',
 'trainable': True,
 'dtype': 'float32',
 'num_outputs': 10}

## Composing layers

In [7]:
class DenseBlock(Layer):
    def __init__(self, **kwargs):
        super(DenseBlock, self).__init__(**kwargs)
        self.dense_1 = MyDenseLayer(32)
        self.dense_2 = MyDenseLayer(64)
        self.dense_3 = MyDenseLayer(3)
    def call(self, inputs):
        x = self.dense_1(inputs)
        x = tf.nn.relu(x)
        x = self.dense_2(x)
        x = tf.nn.relu(x)
        return self.dense_3(x)

In [8]:
x = DenseBlock()
x.get_config()

{'name': 'dense_block', 'trainable': True, 'dtype': 'float32'}

In [9]:
y = DenseBlock.from_config(x.get_config())
y.get_config()

{'name': 'dense_block', 'trainable': True, 'dtype': 'float32'}

## Save model

In [10]:
inputs = Input(shape=(10,))
x = MyDenseLayer(20)(inputs)

model = Model(inputs, x)
model.summary()

W1123 15:22:45.628622 25044 deprecation.py:323] From <ipython-input-2-37c6602cccc5>:7: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
my_dense_layer_7 (MyDenseLay (None, 20)                200       
Total params: 200
Trainable params: 200
Non-trainable params: 0
_________________________________________________________________


In [11]:
model.save('model.h5')

In [12]:
x2 = DenseBlock()(inputs)

model2 = Model(inputs, x2)
model2.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
dense_block_1 (DenseBlock)   (None, 3)                 2560      
Total params: 2,560
Trainable params: 2,560
Non-trainable params: 0
_________________________________________________________________


In [13]:
model2.save('model2.h5')

## Load model

In [14]:
model = load_model('model.h5', custom_objects={'MyDenseLayer': MyDenseLayer})
model.summary()

W1123 15:22:46.453179 25044 hdf5_format.py:177] No training configuration found in save file: the model was *not* compiled. Compile it manually.


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
my_dense_layer_7 (MyDenseLay (None, 20)                200       
Total params: 200
Trainable params: 200
Non-trainable params: 0
_________________________________________________________________


In [15]:
model.get_config()

{'name': 'model',
 'layers': [{'class_name': 'InputLayer',
   'config': {'batch_input_shape': (None, 10),
    'dtype': 'float32',
    'sparse': False,
    'name': 'input_1'},
   'name': 'input_1',
   'inbound_nodes': []},
  {'class_name': 'MyDenseLayer',
   'config': {'name': 'my_dense_layer_7',
    'trainable': True,
    'dtype': 'float32',
    'num_outputs': 20},
   'name': 'my_dense_layer_7',
   'inbound_nodes': [[['input_1', 0, 0, {}]]]}],
 'input_layers': [['input_1', 0, 0]],
 'output_layers': [['my_dense_layer_7', 0, 0]]}

In [16]:
model2 = load_model('model2.h5', custom_objects={'DenseBlock': DenseBlock})
model2.summary()

W1123 15:22:46.756617 25044 hdf5_format.py:177] No training configuration found in save file: the model was *not* compiled. Compile it manually.


Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
dense_block_1 (DenseBlock)   (None, 3)                 2560      
Total params: 2,560
Trainable params: 2,560
Non-trainable params: 0
_________________________________________________________________


In [17]:
model2.get_config()

{'name': 'model_1',
 'layers': [{'class_name': 'InputLayer',
   'config': {'batch_input_shape': (None, 10),
    'dtype': 'float32',
    'sparse': False,
    'name': 'input_1'},
   'name': 'input_1',
   'inbound_nodes': []},
  {'class_name': 'DenseBlock',
   'config': {'name': 'dense_block_1', 'trainable': True, 'dtype': 'float32'},
   'name': 'dense_block_1',
   'inbound_nodes': [[['input_1', 0, 0, {}]]]}],
 'input_layers': [['input_1', 0, 0]],
 'output_layers': [['dense_block_1', 0, 0]]}