In [1]:
import tensorflow as tf 
from tensorflow import keras

In [2]:
tf.__version__

'2.3.0'

# 使用add_weight()创建一个自定义的全连接层

In [3]:
class Old_Linear(keras.layers.Layer):
    """Fully connected layer that creates its weights eagerly in ``__init__``.

    Computes ``matmul(inputs, kernel) + bias``. Because the weights are
    allocated in the constructor, the size of the inputs' last dimension has
    to be supplied up front.

    Args:
        units: Number of output features.
        input_dim: Size of the last dimension of the inputs.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer.__init__``.
    """

    def __init__(self, units, input_dim, **kwargs):
        super(Old_Linear, self).__init__(**kwargs)
        # Stored so that get_config() can report the constructor arguments.
        self.units = units
        self.input_dim = input_dim
        # Explicit names keep the two variables distinguishable; without them
        # both fall back to the same default variable name.
        self._weights = self.add_weight(name='kernel',
                                        shape=(input_dim, units),
                                        initializer='random_normal',
                                        trainable=True)
        self._bias = self.add_weight(name='bias',
                                     shape=(units,),
                                     initializer='zeros',
                                     trainable=True)

    def call(self, inputs):
        """Apply the affine transform to ``inputs``.

        Args:
            inputs: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the last dimension replaced by ``units``.
        """
        return tf.matmul(inputs, self._weights) + self._bias

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = super(Old_Linear, self).get_config()
        config.update({'units': self.units, 'input_dim': self.input_dim})
        return config

# 使用 build 推迟 weights 的初始化到获得输入数据的 shape 之后

In [4]:
class Linear(keras.layers.Layer):
    """Fully connected layer whose weights are created lazily in ``build``.

    Computes ``matmul(inputs, kernel) + bias``. The kernel shape is derived
    from the last entry of the input shape passed to ``build``, so the input
    size does not have to be given to the constructor.

    Args:
        units: Number of output features.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer.__init__``.
    """

    def __init__(self, units, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create the kernel and bias variables.

        Args:
            input_shape: Shape of the inputs; only its last entry is used.
        """
        # Explicit names keep the two variables distinguishable; without them
        # both fall back to the same default variable name.
        self._weights = self.add_weight(name='kernel',
                                        shape=(input_shape[-1], self.units),
                                        initializer='random_normal',
                                        trainable=True)
        self._bias = self.add_weight(name='bias',
                                     shape=(self.units,),
                                     initializer='zeros',
                                     trainable=True)

    def call(self, inputs):
        """Apply the affine transform to ``inputs``.

        Args:
            inputs: Tensor whose last dimension matches the one seen in ``build``.

        Returns:
            Tensor with the last dimension replaced by ``units``.
        """
        return tf.matmul(inputs, self._weights) + self._bias

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = super(Linear, self).get_config()
        config.update({'units': self.units})
        return config

# 层的递归组合

In [5]:
class MLP_Block(keras.layers.Layer):
    """Composite layer chaining three ``Linear`` sub-layers (32, 64 and 1 units).

    ReLU is applied to the outputs of the first and second sub-layers; the
    output of the third sub-layer is returned without an activation.
    """

    def __init__(self):
        super().__init__()
        # Sub-layers are created in the order they are applied in call().
        self.layer1 = Linear(32)
        self.layer2 = Linear(64)
        self.layer3 = Linear(1)

    def call(self, inputs):
        """Run ``inputs`` through the three sub-layers and return the result."""
        hidden = tf.nn.relu(self.layer1(inputs))
        hidden = tf.nn.relu(self.layer2(hidden))
        return self.layer3(hidden)

In [6]:
# Create an MLP_Block instance; no input is passed to it in this cell.
mlp = MLP_Block()

In [7]:
# Example invocation, left disabled: `inputs` is not defined in the cells
# visible here (NOTE(review): supply a real tensor before enabling).
# mlp(inputs)

In [8]:
class MLP_Block_2(keras.layers.Layer):
    """Composite layer of three ``Dense`` sub-layers with a skip concatenation.

    The inputs pass through ``Dense(32)`` + ReLU, ``Dense(64)`` + ReLU and
    ``Dense(32)``. The ReLU-activated output of the first sub-layer is then
    concatenated with the output of the third sub-layer along the last axis.
    """

    def __init__(self):
        super(MLP_Block_2, self).__init__()
        self.layer1 = tf.keras.layers.Dense(32)
        self.layer2 = tf.keras.layers.Dense(64)
        self.layer3 = tf.keras.layers.Dense(32)

    def call(self, inputs):
        """Run the three sub-layers and concatenate the first and last outputs.

        Args:
            inputs: Input tensor fed to the first ``Dense`` sub-layer.

        Returns:
            Concatenation of the activated first-layer output and the
            third-layer output along the last axis.
        """
        x1 = self.layer1(inputs)
        x1 = tf.nn.relu(x1)
        x2 = self.layer2(x1)
        x2 = tf.nn.relu(x2)
        x3 = self.layer3(x2)
        # tf.concat has no default axis; omitting it raises a TypeError.
        # Join along the last (feature) axis.
        x = tf.concat([x1, x3], axis=-1)
        return x

In [9]:
# Create an MLP_Block_2 instance; no input is passed to it in this cell.
mlp_2 = MLP_Block_2()