In [1]:
import numpy as np
import tensorflow as tf

In [6]:
class MyDense(tf.keras.layers.Layer):
    """Fully connected layer without a bias term: ``activation(inputs @ kernel)``.

    Args:
        inp_dim: Size of the last axis of the inputs (number of kernel rows).
        outp_dim: Number of output units (number of kernel columns).
        activation: Anything ``tf.keras.activations.get`` accepts -- a name such
            as ``'relu'``, a callable, or ``None`` for the identity. Defaults to
            ``'relu'``, matching the original hard-coded ReLU.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, inp_dim, outp_dim, activation='relu', **kwargs):
        super().__init__(**kwargs)
        # Keyword arguments keep this portable: in Keras 3 the first positional
        # parameter of add_weight is `shape`, not `name`. glorot_uniform is the
        # tf.keras default for float weights; spelled out for clarity.
        self.kernel = self.add_weight(
            name='w',
            shape=(inp_dim, outp_dim),
            initializer='glorot_uniform',
            trainable=True,
        )
        self.activation = tf.keras.activations.get(activation)

    def call(self, inputs, training=None):
        # `training` is accepted for Keras API compatibility; this layer has no
        # train/inference-specific behavior.
        return self.activation(inputs @ self.kernel)

In [7]:
# A 4 -> 3 layer: its kernel is created in __init__, so the variables exist
# before the layer is ever called.
net = MyDense(inp_dim=4, outp_dim=3)
# Show both collections side by side; the single kernel appears in each.
(net.variables, net.trainable_variables)

([<tf.Variable 'w:0' shape=(4, 3) dtype=float32, numpy=
  array([[-0.6147055 ,  0.8135537 , -0.47653013],
         [-0.658579  , -0.22676665,  0.22495866],
         [-0.5980568 ,  0.63575566,  0.14694881],
         [ 0.3571409 , -0.7591137 , -0.19019789]], dtype=float32)>],
 [<tf.Variable 'w:0' shape=(4, 3) dtype=float32, numpy=
  array([[-0.6147055 ,  0.8135537 , -0.47653013],
         [-0.658579  , -0.22676665,  0.22495866],
         [-0.5980568 ,  0.63575566,  0.14694881],
         [ 0.3571409 , -0.7591137 , -0.19019789]], dtype=float32)>])

In [9]:
# Stack five MyDense layers (784 -> 256 -> 128 -> 64 -> 32 -> 10) with the Sequential API.
network = tf.keras.Sequential(
    [
        MyDense(d_in, d_out)
        for d_in, d_out in ((28 * 28, 256), (256, 128), (128, 64), (64, 32), (32, 10))
    ]
)
# Batch dimension left as None; each sample has 28 * 28 = 784 features.
network.build(input_shape=(None, 28 * 28))
network.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
my_dense_3 (MyDense)         multiple                  200704    
_________________________________________________________________
my_dense_4 (MyDense)         multiple                  32768     
_________________________________________________________________
my_dense_5 (MyDense)         multiple                  8192      
_________________________________________________________________
my_dense_6 (MyDense)         multiple                  2048      
_________________________________________________________________
my_dense_7 (MyDense)         multiple                  320       
Total params: 244,032
Trainable params: 244,032
Non-trainable params: 0
_________________________________________________________________


In [10]:
class MyModel(tf.keras.Model):
    """Subclassed five-layer MLP: 784 -> 256 -> 128 -> 64 -> 32 -> 10 units.

    Each stage is a ``MyDense`` layer constructed with its default activation.

    NOTE(review): that default is ReLU, so the final layer (fc5) is also
    ReLU-activated and its outputs are clamped to be non-negative. If these
    outputs are meant to be logits, the last layer is usually linear.

    Args:
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.fc1 = MyDense(28 * 28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """Apply fc1..fc5 in order, forwarding the ``training`` flag to each layer."""
        x = inputs
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            x = layer(x, training=training)
        return x

In [11]:
# Instantiate the subclassed model and build it for a batch-agnostic input
# (None batch dimension) before printing the summary.
network = MyModel()
network.build(input_shape=(None, 784))  # 784 == 28 * 28
network.summary()

Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
my_dense_8 (MyDense)         multiple                  200704    
_________________________________________________________________
my_dense_9 (MyDense)         multiple                  32768     
_________________________________________________________________
my_dense_10 (MyDense)        multiple                  8192      
_________________________________________________________________
my_dense_11 (MyDense)        multiple                  2048      
_________________________________________________________________
my_dense_12 (MyDense)        multiple                  320       
Total params: 244,032
Trainable params: 244,032
Non-trainable params: 0
_________________________________________________________________
