In [98]:
import tensorflow as tf

In [102]:
class Linear_0(tf.keras.layers.Layer):  # 继承自:tf.keras.layers.Layer
    """自定义全连接层方法一"""

    def __init__(self, units, input_dim, name='dcdmm'):
        super(Linear_0, self).__init__(name=name)
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(initial_value=w_init(shape=(input_dim, units),
                                                  dtype=tf.float32),
                             trainable=True)
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(initial_value=b_init(shape=(units,),
                                                  dtype=tf.float32),
                             trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


class Linear_1(tf.keras.layers.Layer):
    """自定义全连接层方式二"""

    def __init__(self, units, input_dim, name='dcdmm'):
        super(Linear_1, self).__init__(name=name)
        # Adds a new variable to the layer.
        self.w = self.add_weight(shape=(input_dim, units),
                                 initializer=tf.random_normal_initializer,
                                 trainable=True)
        self.b = self.add_weight(shape=(units,),
                                 initializer=tf.zeros_initializer,
                                 trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


class Linear_2(tf.keras.layers.Layer):
    """自定义全连接层方式三"""

    def __init__(self, units, name="dcdmm"):
        """__init__ , where you can do all input-independent initialization"""
        super(Linear_2, self).__init__(name=name)
        self.units = units

    def build(self, input_shape):
        """build, where you know the shapes of the input tensors and can do the rest of the initialization"""
        print("input_shape:", input_shape)
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True)

    def call(self, inputs):
        """call, where you do the forward computation"""
        return tf.matmul(inputs, self.w) + self.b

In [103]:
x = tf.ones((2, 2))

linear_layer_0 = Linear_0(4, 2)
y_0 = linear_layer_0(x)
print(y_0, end='\n\n')

linear_layer_1 = Linear_1(4, 2)
y_1 = linear_layer_1(x)
print(y_1, end='\n\n')

linear_layer_2 = Linear_2(4)
y_2 = linear_layer_2(x)  # 进入build函数
print(y_2, end='\n\n')

linear_layer_2_1 = Linear_2(4)
linear_layer_2_1.build(input_shape=(None, 2))  # 进入build函数

tf.Tensor(
[[ 0.05823753  0.02533753 -0.06461746 -0.00608682]
 [ 0.05823753  0.02533753 -0.06461746 -0.00608682]], shape=(2, 4), dtype=float32)

tf.Tensor(
[[ 0.00143165 -0.06142656  0.02683384 -0.09348118]
 [ 0.00143165 -0.06142656  0.02683384 -0.09348118]], shape=(2, 4), dtype=float32)

input_shape: (2, 2)
23
tf.Tensor(
[[0.02280284 0.13814521 0.04022688 0.04280937]
 [0.02280284 0.13814521 0.04022688 0.04280937]], shape=(2, 4), dtype=float32)

input_shape: (None, 2)


In [101]:
# The concatenation of the lists trainable_weights and non_trainable_weights (in this order).
print('weight:', linear_layer_2.weights)
# List of variables that should not be included in backprop.
print('non-trainable weight:', linear_layer_2.non_trainable_weights)
# List of variables to be included in backprop.
print('trainable weight:', linear_layer_2.trainable_weights)
# The name of the layer (string).
print('name:', linear_layer_2.name)

weight: [<tf.Variable 'dcdmm/Variable:0' shape=(2, 4) dtype=float32, numpy=
array([[ 0.07897102, -0.01195054,  0.00275214, -0.0507881 ],
       [-0.07438576, -0.00409146,  0.03593483, -0.02142463]],
      dtype=float32)>, <tf.Variable 'dcdmm/Variable:0' shape=(4,) dtype=float32, numpy=array([0.00342475, 0.0139913 , 0.0307123 , 0.07513855], dtype=float32)>]
non-trainable weight: []
trainable weight: [<tf.Variable 'dcdmm/Variable:0' shape=(2, 4) dtype=float32, numpy=
array([[ 0.07897102, -0.01195054,  0.00275214, -0.0507881 ],
       [-0.07438576, -0.00409146,  0.03593483, -0.02142463]],
      dtype=float32)>, <tf.Variable 'dcdmm/Variable:0' shape=(4,) dtype=float32, numpy=array([0.00342475, 0.0139913 , 0.0307123 , 0.07513855], dtype=float32)>]
name: dcdmm
