# 神经层

## 建立模型结构

In [1]:
import tensorflow as tf
from tensorflow.keras import layers

# Build the classifier layer by layer:
#   Flatten : 28x28 input -> 784-element vector
#   layer1  : fully connected, 128 units, ReLU
#   Dropout : drops 20% of activations during training
#   layer2  : fully connected, 10 units, softmax output
# The explicit names "layer1"/"layer2" are what the get_layer() lookups below rely on.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu', name="layer1"))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax', name="layer2"))

# Configure training: Adam optimizer, sparse categorical cross-entropy loss,
# and accuracy as the reported metric.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Print a per-layer summary: output shapes and parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
layer1 (Dense)               (None, 128)               100480    
_________________________________________________________________
dropout (Dropout)            (None, 128)               0         
_________________________________________________________________
layer2 (Dense)               (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


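## 第一层 Dense 参数个数计算

Dense 层的参数个数 = 输入维度 × 神经元个数（权重 weights）+ 神经元个数（偏差 bias）。第一层的输入是 Flatten 后的 $28 \times 28 = 784$ 维向量，因此参数个数为 $784 \times 128 + 128 = 100480$。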

In [2]:
# Build a sub-model that shares the model's input but stops at "layer1",
# so calling it returns that layer's output activations.
feature_extractor = tf.keras.Model(
    inputs=model.inputs,
    outputs=model.get_layer(name="layer1").output,
)

# Feed a single all-ones dummy sample. Its shape is taken from the model's own
# input spec (batch dimension replaced by 1) instead of hard-coding (28, 28),
# so this cell stays correct if the input shape in the model definition changes.
x = tf.ones((1, *model.inputs[0].shape[1:]))
features = feature_extractor(x)
features.shape

TensorShape([1, 128])

In [3]:
# Parameter count of the first Dense layer ("layer1"):
#   weights = input_dim * units   (one weight per input-unit connection)
#   biases  = units               (one bias per unit)
# input_dim is read from the layer's input tensor (the Flatten output) instead of
# hard-coding 28 * 28, and the hand count is cross-checked against Keras itself.
layer1 = model.get_layer(name="layer1")
input_dim = layer1.input.shape[-1]
units = features.shape[1]
parameter_count = input_dim * units + units
assert parameter_count == layer1.count_params(), "hand count disagrees with Keras"
print(f'参数(parameter)个数：{parameter_count}')

参数(parameter)个数：100480


## 第二层 Dense 参数个数计算

第二层的输入是第一层经 Dropout 后的 128 维向量（Dropout 层本身没有参数），因此参数个数为 $128 \times 10 + 10 = 1290$。

In [4]:
# Build a sub-model that shares the model's input but ends at the output layer "layer2".
layer2 = model.get_layer(name="layer2")
feature_extractor = tf.keras.Model(
    inputs=model.inputs,
    outputs=layer2.output,
)

# Single all-ones dummy sample shaped like the model input (batch of 1),
# derived from the model rather than hard-coded.
x = tf.ones((1, *model.inputs[0].shape[1:]))
features = feature_extractor(x)

# Dense parameter count = input_dim * units (weights) + units (biases).
# input_dim is read from layer2's input tensor (the Dropout output) instead of
# hard-coding 128, and the hand count is cross-checked against Keras itself.
input_dim = layer2.input.shape[-1]
units = features.shape[1]
parameter_count = input_dim * units + units
assert parameter_count == layer2.count_params(), "hand count disagrees with Keras"
print(f'参数(parameter)个数：{parameter_count}')

参数(parameter)个数：1290
