# keras训练和评估

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

`Model` 中用于训练、评估和预测的方法：

fit:训练

evaluate：评估

predict：预测

In [4]:
# Example: a small fully connected classifier built with the functional API.
# It takes a flat 784-feature vector and ends in a 10-unit softmax layer.
inputs = keras.Input(shape=(784,), name="digits")

# Stack the two 64-unit ReLU hidden layers in a loop; after the loop `x` is
# bound to the output of the last hidden layer ("dense_2").
x = inputs
for hidden_name in ("dense_1", "dense_2"):
    x = layers.Dense(64, activation="relu", name=hidden_name)(x)

outputs = layers.Dense(10, activation="softmax", name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

In [5]:
# Load the MNIST dataset, which comes pre-split into a training pair and a
# test pair of (images, labels).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

In [6]:
# Preprocess the data (NumPy arrays) and split it into training, validation
# and test sets.
#
# NOTE: this cell rebinds x_train / y_train, so run it exactly once after
# loading. A second run fails at the reshape below, because x_train holds only
# 50,000 rows once the validation samples have been split off.

# Flatten every image into a 784-element vector. The cast to float32 happens
# before the division by 255 so that the division itself is done in float32.
x_train, x_test = (
    x_train.reshape((60000, 784)).astype("float32") / 255,
    x_test.reshape((10000, 784)).astype("float32") / 255,
)

# Labels are cast to float32 as well.
y_train, y_test = y_train.astype("float32"), y_test.astype("float32")

# Reserve the last 10,000 training samples for validation. Both right-hand
# sides are evaluated before any name is rebound, so each slice is taken from
# the full 60,000-sample arrays.
x_train, x_val = x_train[:-10000], x_train[-10000:]
y_train, y_val = y_train[:-10000], y_train[-10000:]

In [7]:
# Configure the training process: optimizer, loss and the metric to report.
model.compile(
    # Adam optimizer constructed with 1e-4 as its first argument.
    optimizer=keras.optimizers.Adam(1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)


In [8]:
print("Fit model on training data")
# Train for 2 epochs using mini-batches of 64 samples. The validation pair is
# passed so that validation loss and metrics are monitored at the end of each
# epoch. The returned object is kept in `history`.
history = model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_val, y_val),
    epochs=2,
    batch_size=64,
)

Fit model on training data
Train on 50000 samples, validate on 10000 samples
Epoch 1/2
Epoch 2/2


In [9]:
# Display the dict of per-epoch values recorded during fit(): training and
# validation loss plus the compiled metric, one list entry per epoch.
history.history

{'loss': [1.000363467540741, 0.3663938018035889],
 'sparse_categorical_accuracy': [0.74288, 0.90022],
 'val_loss': [0.400281676030159, 0.292415634059906],
 'val_sparse_categorical_accuracy': [0.8968, 0.9213]}

In [10]:
# Measure loss and metrics on the held-out test data with `evaluate`,
# processing 128 samples per batch.
print("Evaluate on test data")
results = model.evaluate(x=x_test, y=y_test, batch_size=128)
print("test loss, test acc:", results)

# Run `predict` on new data. It returns the output of the last layer, which
# for this model is one row of 10 softmax probabilities per input sample.
print("Generate predictions for 3 samples")
predictions = model.predict(x=x_test[:3])
print("predictions shape:", predictions.shape)

Evaluate on test data


test loss, test acc: [0.303735422039032, 0.9134]
Generate predictions for 3 samples
predictions shape: (3, 10)


## compile 中常用的内置优化器、损失函数和评估指标

- Optimizers:
    - SGD()
    - RMSprop()
    - Adam()
    - etc.
- Losses:
    - MeanSquaredError()
    - KLDivergence()
    - CosineSimilarity()
    - etc.
- Metrics:
    - AUC()
    - Precision()
    - Recall()
    - etc.
    
如果要自定义loss，需要继承`tf.keras.losses.Loss`类，并重写`__init__(self)`和`call(self, y_true, y_pred)`这两个方法。

示例如下

In [11]:
class CustomMSE(keras.losses.Loss):
    """Mean squared error plus a penalty that pulls predictions towards 0.5.

    The loss for one sample is
    ``mean((y_true - y_pred)^2) + regularization_factor * mean((0.5 - y_pred)^2)``,
    where both means are taken over the last axis.

    Args:
        regularization_factor: Weight applied to the penalty term.
        name: Name given to the loss instance.
    """

    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        """Compute the unreduced loss.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values, broadcast-compatible with ``y_true``.

        Returns:
            A tensor with the last axis averaged away, i.e. one loss value per
            sample for ``(batch, features)`` inputs.
        """
        # Average over the last axis only. Collapsing every axis to a single
        # scalar here would stop the base class from applying per-sample
        # `sample_weight` values and the configured `reduction`; the base
        # class performs the final reduction over the batch itself.
        mse = tf.math.reduce_mean(tf.square(y_true - y_pred), axis=-1)
        reg = tf.math.reduce_mean(tf.square(0.5 - y_pred), axis=-1)
        return mse + reg * self.regularization_factor

如果要自定义metrics，需要继承`tf.keras.metrics.Metric`，并实现以下4种方法：
- `__init__(self)`：创建该metric所需的状态变量
- `update_state(self, y_true, y_pred, sample_weight=None)`：利用真实值和预测值更新状态变量
- `result(self)`：利用状态变量计算最终结果
- `reset_states(self)`：重新初始化状态变量（较新版本的 TensorFlow/Keras 中该方法已更名为`reset_state`）

示例如下：

In [12]:
class CategoricalTruePositives(keras.metrics.Metric):
    """Running count of samples whose arg-max prediction equals the label.

    Labels are compared directly against ``argmax(y_pred, axis=1)``, so
    ``y_true`` must hold integer class indices (not one-hot vectors).

    Args:
        name: Name given to the metric instance.
        **kwargs: Forwarded to ``keras.metrics.Metric``.
    """

    def __init__(self, name="categorical_true_positives", **kwargs):
        super().__init__(name=name, **kwargs)
        # Scalar accumulator holding the (optionally weighted) match count.
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the matches found in one batch to the running count.

        Args:
            y_true: Integer class labels, shaped ``(batch,)`` or ``(batch, 1)``.
            y_pred: Per-class scores, shaped ``(batch, num_classes)``.
            sample_weight: Optional per-sample weights, shaped ``(batch,)`` or
                ``(batch, 1)``, or a scalar.
        """
        # Flatten labels, predicted classes and weights to rank 1 before
        # combining them. Mixing a (batch,) tensor with a (batch, 1) tensor
        # would broadcast to (batch, batch) and silently inflate the count.
        y_true = tf.reshape(tf.cast(y_true, "int32"), shape=(-1,))
        y_pred = tf.reshape(
            tf.cast(tf.argmax(y_pred, axis=1), "int32"), shape=(-1,)
        )
        values = tf.cast(tf.equal(y_true, y_pred), "float32")
        if sample_weight is not None:
            sample_weight = tf.reshape(
                tf.cast(sample_weight, "float32"), shape=(-1,)
            )
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        """Return the accumulated count."""
        return self.true_positives

    def reset_state(self):
        """Zero the accumulator."""
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)

    def reset_states(self):
        """Alias of ``reset_state`` for Keras versions that call the old name."""
        self.reset_state()