In [1]:
import tensorflow as tf

In [2]:
class Linear_0(tf.keras.layers.Layer):  # subclasses tf.keras.layers.Layer
    """Custom fully connected layer, approach 1: weights built as raw ``tf.Variable``s.

    The forward pass computes ``matmul(inputs, w) + b``.

    Args:
        units: Number of output features (second dimension of ``w``, length of ``b``).
        input_dim: Number of input features (first dimension of ``w``).
        name: Layer name forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, units, input_dim, name='dcdmm'):
        super(Linear_0, self).__init__(name=name)
        # Kernel: drawn from a random-normal initializer with default settings.
        w_init = tf.random_normal_initializer()
        # Always give variables an explicit name; the kernel is named 'w' so
        # the variable name matches the attribute that holds it.
        self.w = tf.Variable(initial_value=w_init(shape=(input_dim, units),
                                                  dtype=tf.float32),
                             trainable=True, name='w')
        # Bias: starts at zero.
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(initial_value=b_init(shape=(units,),
                                                  dtype=tf.float32),
                             trainable=True, name='b')

    def call(self, inputs):
        """Return ``matmul(inputs, w) + b``."""
        return tf.matmul(inputs, self.w) + self.b


class Linear_1(tf.keras.layers.Layer):
    """Custom fully connected layer, approach 2: weights created with ``add_weight``.

    Both weights are created in the constructor, so the input feature count
    must be supplied up front.

    Args:
        units: Number of output features.
        input_dim: Number of input features.
        name: Layer name forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, units, input_dim, name='dcdmm'):
        super().__init__(name=name)
        kernel_shape = (input_dim, units)
        bias_shape = (units,)
        # add_weight adds a new variable to the layer. The initializers are
        # handed over as classes (uninstantiated).
        self.w = self.add_weight(name='w',
                                 shape=kernel_shape,
                                 initializer=tf.random_normal_initializer,
                                 trainable=True)
        self.b = self.add_weight(name='b',
                                 shape=bias_shape,
                                 initializer=tf.zeros_initializer,
                                 trainable=True)

    def call(self, inputs):
        """Return ``matmul(inputs, w) + b``."""
        projected = tf.matmul(inputs, self.w)
        return projected + self.b


class Linear_2(tf.keras.layers.Layer):
    """Custom fully connected layer, approach 3: weights created lazily in ``build``.

    Only the output size is given to the constructor; the input feature count
    is read from the last dimension of the shape passed to ``build``.

    Args:
        units: Number of output features.
        name: Layer name forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, units, name="dcdmm"):
        """__init__ , where you can do all input-independent initialization"""
        super(Linear_2, self).__init__(name=name)
        self.units = units

    def build(self, input_shape):
        """build, where you know the shapes of the input tensors and can do the rest of the initialization"""
        print("input_shape:", input_shape)
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True, name='w')
        # NOTE: the bias is also drawn from 'random_normal' here (not zeros).
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True, name='b')
        # Let the base class record that the layer is built, so that a direct
        # call to build() is not followed by a second build (and a second set
        # of weights) the first time the layer is called.
        super(Linear_2, self).build(input_shape)

    def call(self, inputs):
        """call, where you do the forward computation"""
        return tf.matmul(inputs, self.w) + self.b

In [3]:
# Shared input for every layer below: a 2x2 tensor of ones.
x = tf.ones(shape=(2, 2))

# Approach 1: weights are raw tf.Variable objects.
linear_layer_0 = Linear_0(units=4, input_dim=2)
y_0 = linear_layer_0(x)
print(y_0)
print()

# Approach 2: weights come from add_weight in the constructor.
linear_layer_1 = Linear_1(units=4, input_dim=2)
y_1 = linear_layer_1(x)
print(y_1)
print()

# Approach 3: weights are created in build().
linear_layer_2 = Linear_2(units=4)
y_2 = linear_layer_2(x)  # build() runs once, automatically, on this first call
print(y_2)
print()

# Second call on the same layer: build() does not run again.
y_2_1 = linear_layer_2(x)
print(y_2_1)
print()

linear_layer_2_1 = Linear_2(units=4)
linear_layer_2_1.build(input_shape=(None, 2))  # call build() explicitly

tf.Tensor(
[[-0.01174007  0.0963226   0.06558874  0.09499393]
 [-0.01174007  0.0963226   0.06558874  0.09499393]], shape=(2, 4), dtype=float32)

tf.Tensor(
[[ 0.07614176 -0.02075787 -0.14038478  0.07890347]
 [ 0.07614176 -0.02075787 -0.14038478  0.07890347]], shape=(2, 4), dtype=float32)

input_shape: (2, 2)
tf.Tensor(
[[ 0.07164451 -0.09050271 -0.03292338 -0.04957008]
 [ 0.07164451 -0.09050271 -0.03292338 -0.04957008]], shape=(2, 4), dtype=float32)

tf.Tensor(
[[ 0.07164451 -0.09050271 -0.03292338 -0.04957008]
 [ 0.07164451 -0.09050271 -0.03292338 -0.04957008]], shape=(2, 4), dtype=float32)

input_shape: (None, 2)


In [4]:
# Inspect the built layer: print each attribute after its label.
for label, value in (
    # The concatenation of the lists trainable_weights and
    # non_trainable_weights (in this order).
    ('weight:', linear_layer_2.weights),
    # List of variables that should not be included in backprop.
    ('non-trainable weight:', linear_layer_2.non_trainable_weights),
    # List of variables to be included in backprop.
    ('trainable weight:', linear_layer_2.trainable_weights),
    # The name of the layer (string).
    ('name:', linear_layer_2.name),
):
    print(label, value)

weight: [<tf.Variable 'dcdmm/w:0' shape=(2, 4) dtype=float32, numpy=
array([[ 0.06234207,  0.01973487, -0.01235966, -0.0502708 ],
       [-0.02465934, -0.09822895,  0.02504545, -0.02374667]],
      dtype=float32)>, <tf.Variable 'dcdmm/b:0' shape=(4,) dtype=float32, numpy=array([ 0.03396178, -0.01200864, -0.04560917,  0.02444738], dtype=float32)>]
non-trainable weight: []
trainable weight: [<tf.Variable 'dcdmm/w:0' shape=(2, 4) dtype=float32, numpy=
array([[ 0.06234207,  0.01973487, -0.01235966, -0.0502708 ],
       [-0.02465934, -0.09822895,  0.02504545, -0.02374667]],
      dtype=float32)>, <tf.Variable 'dcdmm/b:0' shape=(4,) dtype=float32, numpy=array([ 0.03396178, -0.01200864, -0.04560917,  0.02444738], dtype=float32)>]
name: dcdmm
