In [1]:
import tensorflow as tf

In [None]:
# Custom layer: a fully connected (dense) transform with an optional activation

class MyLayer(tf.keras.layers.Layer):
    """Fully connected layer computing ``activation(inputs @ kernel + bias)``.

    Args:
        units: Number of output features (size of the kernel's second axis
            and of the bias vector).
        activation: Activation identifier resolved via
            ``tf.keras.activations.get`` (e.g. a name such as ``'relu'``, a
            callable, or ``None`` for the linear/identity activation).
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Resolve the identifier to a callable once, up front.
        # ``activations.get(None)`` returns the linear activation.
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        """Create the ``kernel`` and ``bias`` weights.

        The kernel has shape ``(input_shape[-1], units)`` and the bias has
        shape ``(units,)``.

        Raises:
            ValueError: If the last dimension of ``input_shape`` is undefined
                (``None``), because the kernel shape cannot be determined.
        """
        last_dim = input_shape[-1]
        if last_dim is None:
            # Fail early with a clear message instead of an opaque shape
            # error from ``add_weight``.
            raise ValueError(
                "The last dimension of the inputs to `MyLayer` must be "
                "defined, but got input_shape={}.".format(input_shape)
            )
        self.kernel = self.add_weight(
            shape=(last_dim, self.units),
            initializer='random_normal',
            trainable=True,
            name='kernel',
        )
        self.bias = self.add_weight(
            shape=(self.units,),
            initializer='zeros',
            trainable=True,
            name='bias',
        )
        # Keras convention: let the base class mark the layer as built.
        super().build(input_shape)

    def call(self, inputs):
        """Return ``activation(inputs @ kernel + bias)``."""
        z = tf.matmul(inputs, self.kernel) + self.bias
        # ``self.activation`` is normally never None (see ``__init__``).
        # The check is kept as a fallback in case it is reassigned to None.
        if self.activation is not None:
            return self.activation(z)
        return z

    def get_config(self):
        """Return the layer config, extended with ``units`` and ``activation``.

        The activation is stored in its serialized form so that the layer can
        be re-created via ``from_config``.
        """
        config = super().get_config()
        config.update({
            'units': self.units,
            'activation': tf.keras.activations.serialize(self.activation),
        })
        return config