# 特征转换

## One-hot encoding

In [1]:
import tensorflow as tf

In [2]:
# One-hot encode the integer class labels 0, 1, 2 and 3.
# num_classes: number of classes (the width of every encoded row); it may be omitted.
tf.keras.utils.to_categorical(
    [0, 1, 2, 3],
    num_classes=9,
)

array([[1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0.]], dtype=float32)

## MNIST 手写阿拉伯数字辨识

In [3]:
mnist = tf.keras.datasets.mnist

# Number of target classes. Shared by the label encoding and the output layer
# so that the two widths can never drift apart.
num_classes = 10

# Load the MNIST handwritten-digit data (training and test splits).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Feature scaling with min-max normalization: (x - min) / (max - min).
# Dividing by 255.0 corresponds to min = 0 and max = 255
# (assumes the raw pixel values span 0..255).
x_train_norm, x_test_norm = x_train / 255.0, x_test / 255.0

# One-hot encode the labels. num_classes is passed explicitly so the encoded
# width always equals the size of the softmax output layer, instead of being
# inferred independently from the labels present in each split.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=num_classes)

# Build the model: flatten each 28x28 input, one hidden ReLU layer with
# dropout (rate 0.2), then a softmax layer with one unit per class.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

# Configure the optimizer, the loss function and the reported metrics.
# categorical_crossentropy is paired with the one-hot encoded labels above.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train the model on the scaled training data.
history = model.fit(x_train_norm, y_train, epochs=5, validation_split=0.2)

# Score the model on the scaled test data.
score = model.evaluate(x_test_norm, y_test, verbose=0)

# Print every metric name next to its value, pairing the two sequences
# directly rather than indexing both by position.
for metric_name, metric_value in zip(model.metrics_names, score):
    print(f'{metric_name}: {metric_value:.4f}')

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
loss: 0.0763
accuracy: 0.9766


## Normalization

In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

# Sample data: three rows of three values each.
data = np.array(
    [
        [0.1, 0.2, 0.3],
        [0.8, 0.9, 1.0],
        [1.5, 1.6, 1.7],
    ]
)

# NOTE(review): `preprocessing` comes from the experimental Keras namespace
# imported at the top of this cell; newer TensorFlow releases are understood
# to expose this layer as tf.keras.layers.Normalization - confirm against the
# installed version.
layer = preprocessing.Normalization()  # normalization layer
layer.adapt(data)                      # fit the layer to the sample data
normalized_data = layer(data)          # transform the sample data

# Show the mean and standard deviation taken over all normalized values.
print(
    f"平均数: {normalized_data.numpy().mean():.2f}",
    f"标准差: {normalized_data.numpy().std():.2f}",
    sep="\n",
)

平均数: 0.00
标准差: 1.00


In [5]:
# Display the normalized tensor as the cell's output.
normalized_data

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[-1.2247449, -1.2247449, -1.2247449],
       [ 0.       ,  0.       ,  0.       ],
       [ 1.2247449,  1.224745 ,  1.224745 ]], dtype=float32)>