# 神經層

本筆記以一個 28×28 輸入的 Dense 模型為例，逐層驗證 `model.summary()` 中 Param # 欄位的計算方式。

## 建立模型結構

In [22]:
import tensorflow as tf
from tensorflow.keras import layers

# Build the model: a 28x28 input is flattened to a 784-vector, then passed through
# a 128-unit ReLU hidden layer ("layer1"), dropout, and a 10-unit softmax ("layer2").
# The input is declared with an explicit `Input` instead of `Flatten(input_shape=...)`:
# passing `input_shape` to a layer is deprecated in Keras 3 and emits a warning,
# while an `Input` entry works for Sequential models in both Keras 2 and Keras 3.
model = tf.keras.models.Sequential([
  tf.keras.Input(shape=(28, 28)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu', name="layer1"),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax', name="layer2")
])

# Configure the optimizer, the loss function, and the evaluation metrics.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Print the model summary (per-layer output shapes and parameter counts).
model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_3 (Flatten)          (None, 784)               0         
_________________________________________________________________
layer1 (Dense)               (None, 128)               100480    
_________________________________________________________________
dropout_4 (Dropout)          (None, 128)               0         
_________________________________________________________________
layer2 (Dense)               (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


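## 第一層 Dense 參數個數計算

Dense 層的參數包含權重與偏差(bias)：$\text{params} = n_{\text{in}} \times n_{\text{out}} + n_{\text{out}}$。第一層的輸入為攤平後的 $28 \times 28 = 784$ 維，輸出 128 維，因此 $784 \times 128 + 128 = 100480$。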

In [23]:
# Build a sub-model that reuses `model`'s layers but stops at "layer1",
# so calling it returns that layer's activations.
feature_extractor = tf.keras.Model(
    inputs=model.inputs,
    outputs=model.get_layer(name="layer1").output,
)

# Feed one dummy sample (batch size 1). The per-sample shape is taken from
# the model itself rather than repeating the (28, 28) literal from its definition.
x = tf.ones((1, *model.input_shape[1:]))
features = feature_extractor(x)
features.shape

TensorShape([1, 128])

In [24]:
# Parameter count of the first Dense layer:
# one weight per (input, unit) pair plus one bias per unit.
layer1 = model.get_layer(name="layer1")
input_dim = layer1.input.shape[-1]   # width of the flattened input fed to layer1 (784 here)
units = features.shape[-1]           # output width observed via the feature extractor above
parameter_count = input_dim * units + units

# Cross-check the hand computation against Keras' own count. This also fails
# loudly if `features` was overwritten by a later cell instead of silently
# printing a wrong number.
assert parameter_count == layer1.count_params(), (parameter_count, layer1.count_params())
print(f'參數(parameter)個數：{parameter_count}')

參數(parameter)個數：100480


## 第二層 Dense 參數個數計算

In [26]:
# Build a sub-model that reuses `model`'s layers and ends at "layer2" (the output layer).
# Distinct variable names are used so this cell does not overwrite the
# `feature_extractor` / `features` / `x` globals from the layer1 cells; the
# layer1 parameter-count cell therefore stays correct if it is re-run afterwards.
layer2_extractor = tf.keras.Model(
    inputs=model.inputs,
    outputs=model.get_layer(name="layer2").output,
)

# Feed one dummy sample (batch size 1) shaped like the model input.
layer2_input = tf.ones((1, *model.input_shape[1:]))
layer2_features = layer2_extractor(layer2_input)

# Parameters = input width x units + one bias per unit. The input width is read
# from whatever feeds layer2 (128 here) instead of being hard-coded.
layer2 = model.get_layer(name="layer2")
input_dim = layer2.input.shape[-1]
units = layer2_features.shape[-1]
parameter_count = input_dim * units + units

# Cross-check the hand computation against Keras' own count.
assert parameter_count == layer2.count_params(), (parameter_count, layer2.count_params())
print(f'參數(parameter)個數：{parameter_count}')

參數(parameter)個數：1290
